/**
   * Evaluates the regularized, masked squared-error cost and/or its gradient for the two factor
   * matrices packed into {@code params}.
   *
   * <p>The leading {@code rows * features} entries of {@code params} are reshaped into the {@code
   * rows x features} matrix stored in the field {@code x}; the remaining entries are reshaped into
   * a local {@code columns x features} matrix {@code theta}. The fields {@code y} and {@code r}
   * are combined element-wise with {@code x * theta^T} (presumably both are {@code rows x
   * columns}; confirm against where they are initialised).
   *
   * @param params flattened {@code x} followed by flattened {@code theta}
   * @param flag {@code 1} computes the cost, {@code 2} computes the gradient, {@code 3} computes
   *     both; any other value computes nothing
   * @return the {@code cost} field when {@code flag == 1}, otherwise the {@code gradient} field
   *     (for flags other than 2 and 3 this is whatever value the field already held)
   */
  @Override
  public Object compute(FloatMatrix params, int flag) {
    final int split = rows * features;

    // Unpack the parameter vector: x comes first, theta takes the rest.
    x = params.getRange(0, split);
    FloatMatrix theta = params.getRange(split, params.length);

    x = x.reshape(rows, features);
    theta = theta.reshape(columns, features);

    final boolean wantCost = flag == 1 || flag == 3;
    final boolean wantGradient = flag == 2 || flag == 3;

    if (wantCost) {
      // Squared residual of x * theta^T against y, weighted element-wise by r.
      FloatMatrix residual = x.mmul(theta.transpose()).sub(y);
      FloatMatrix squared = MatrixFunctions.pow(residual, 2);
      this.cost = squared.mul(r).columnSums().rowSums().get(0) / 2;

      if (lambda != 0) {
        // Regularization: lambda / 2 times the sum of squares of both factor matrices.
        float thetaSquares = MatrixFunctions.pow(theta, 2).columnSums().rowSums().get(0);
        float xSquares = MatrixFunctions.pow(x, 2).columnSums().rowSums().get(0);
        float penalty = (lambda / 2) * (thetaSquares + xSquares);
        this.cost += penalty;
      }
    }

    if (wantGradient) {
      FloatMatrix xGrad = FloatMatrix.zeros(x.rows, x.columns);
      FloatMatrix thetaGrad = FloatMatrix.zeros(theta.rows, theta.columns);

      // Gradient with respect to each row of x, using only the columns where r equals 1.
      for (int row = 0; row < rows; row++) {
        int[] observed = r.getRow(row).eq(1).findIndices();
        if (observed.length > 0) {
          FloatMatrix thetaObserved = theta.getRows(observed);
          FloatMatrix yObserved = y.getRow(row).get(observed);
          FloatMatrix error = x.getRow(row).mmul(thetaObserved.transpose()).sub(yObserved);
          xGrad.putRow(row, error.mmul(thetaObserved));
        }
      }
      xGrad = xGrad.add(x.mmul(lambda));

      // Gradient with respect to each row of theta, using only the rows where r equals 1.
      for (int col = 0; col < columns; col++) {
        int[] observed = r.getColumn(col).eq(1).findIndices();
        if (observed.length > 0) {
          FloatMatrix xObserved = x.getRows(observed);
          FloatMatrix yObserved = y.getColumn(col).get(observed);
          FloatMatrix error =
              xObserved.mmul(theta.getRow(col).transpose()).sub(yObserved).transpose();
          thetaGrad.putRow(col, error.mmul(xObserved));
        }
      }
      thetaGrad = thetaGrad.add(theta.mmul(lambda));

      // Pack both gradients back into one flat array, x first, mirroring the params layout.
      this.gradient = MatrixUtil.merge(xGrad.data, thetaGrad.data);
    }

    if (flag == 1) {
      return cost;
    }
    return gradient;
  }
  /**
   * Strips the dataset down to the specified labels and remaps them.
   *
   * <p>The examples are first narrowed with {@code filterBy(labels)}. Each remaining example's
   * outcome is then replaced by the position of that outcome within {@code labels}, encoded as an
   * outcome vector of width {@code labels.length}. Finally this dataset's first (input) and
   * second (label) matrices are replaced with the filtered inputs and the remapped labels.
   *
   * @param labels the labels to strip down to
   * @throws IllegalStateException if a filtered example carries an outcome that is not present in
   *     {@code labels}
   */
  public void filterAndStrip(int[] labels) {
    FloatDataSet filtered = filterBy(labels);

    // map each retained label to its index in the passed in labels
    Map<Integer, Integer> labelMap = new HashMap<>();
    for (int i = 0; i < labels.length; i++) {
      labelMap.put(labels[i], i);
    }

    int numExamples = filtered.numExamples();
    FloatMatrix newLabelMatrix = new FloatMatrix(numExamples, labels.length);

    // map examples
    for (int i = 0; i < numExamples; i++) {
      int original = filtered.get(i).outcome();
      // Keep the lookup boxed: unboxing a missing entry would throw a bare
      // NullPointerException instead of the descriptive error below.
      Integer remapped = labelMap.get(original);
      if (remapped == null) {
        throw new IllegalStateException("Label not found on row " + i);
      }
      newLabelMatrix.putRow(i, MatrixUtil.toOutcomeVectorFloat(remapped, labels.length));
    }

    setFirst(filtered.getFirst());
    setSecond(newLabelMatrix);
  }
 /**
  * Sample a dataset
  *
  * <p>If {@code numSamples} is at least the number of examples, this dataset itself is returned
  * unchanged. Otherwise rows are drawn uniformly at random via {@code rng.nextInt} and copied
  * into a new dataset.
  *
  * @param numSamples the number of samples to draw
  * @param rng the rng to use
  * @param withReplacement whether to allow duplicates (only tracked by example row number)
  * @return the sample dataset
  */
 public FloatDataSet sample(int numSamples, RandomGenerator rng, boolean withReplacement) {
   if (numSamples >= numExamples()) {
     return this;
   }

   FloatMatrix examples = new FloatMatrix(numSamples, getFirst().columns);
   FloatMatrix outcomes = new FloatMatrix(numSamples, numOutcomes());
   // Row numbers already drawn; only consulted when sampling without replacement.
   Set<Integer> added = new HashSet<Integer>();
   for (int i = 0; i < numSamples; i++) {
     int picked = rng.nextInt(numExamples());
     if (!withReplacement) {
       // Set.add returns false for a row that was already drawn, so redraw until a new
       // row is recorded. This terminates because numSamples < numExamples() here.
       while (!added.add(picked)) {
         picked = rng.nextInt(numExamples());
       }
     }
     FloatDataSet example = get(picked);
     examples.putRow(i, example.getFirst());
     outcomes.putRow(i, example.getSecond());
   }
   return new FloatDataSet(examples, outcomes);
 }
  /**
   * Concatenates the examples of every dataset in {@code data}, in list order, into a single
   * dataset.
   *
   * <p>The output matrices are allocated with {@code totalExamples(data)} rows and take their
   * column counts from the first dataset in the list; every example row is then copied across one
   * at a time.
   *
   * @param data the datasets to merge; must not be empty
   * @return a new dataset holding the copied input and label rows
   * @throws IllegalArgumentException if {@code data} is empty
   */
  public static FloatDataSet merge(List<FloatDataSet> data) {
    if (data.isEmpty()) {
      throw new IllegalArgumentException("Unable to merge empty dataset");
    }

    FloatDataSet template = data.get(0);
    int total = totalExamples(data);
    FloatMatrix mergedInputs = new FloatMatrix(total, template.getFirst().columns);
    FloatMatrix mergedLabels = new FloatMatrix(total, template.getSecond().columns);

    // nextRow is the running destination row across all source datasets.
    int nextRow = 0;
    for (FloatDataSet dataSet : data) {
      for (int j = 0; j < dataSet.numExamples(); j++, nextRow++) {
        FloatDataSet example = dataSet.get(j);
        mergedInputs.putRow(nextRow, example.getFirst());
        mergedLabels.putRow(nextRow, example.getSecond());
      }
    }
    return new FloatDataSet(mergedInputs, mergedLabels);
  }
  /**
   * Loads the google binary model Credit to:
   * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
   *
   * <p>The stream is read as: a word count, a vector size, then for each word its string followed
   * by {@code size} floats. Each vector is scaled to unit Euclidean length before being stored as
   * a row of the weight matrix; a vector whose length is zero is stored unscaled, since dividing
   * by zero would fill the row with NaN. Paths ending in {@code .gz} are decompressed on the fly.
   *
   * @param path path to model
   * @return a model whose word index and {@code syn0} weights are populated from the file
   * @throws IOException if the file cannot be opened or read
   */
  public static Word2Vec loadGoogleModel(String path) throws IOException {
    Word2Vec ret = new Word2Vec();
    Index wordIndex = new Index();
    FloatMatrix wordVectors;

    // The file stream is its own resource so it is closed even if wrapping it (e.g. reading
    // the gzip header) fails; try-with-resources also avoids closing null references when
    // the file cannot be opened at all.
    try (FileInputStream file = new FileInputStream(path);
        DataInputStream dis =
            new DataInputStream(
                new BufferedInputStream(
                    path.endsWith(".gz") ? new GZIPInputStream(file) : file))) {
      // number of words
      int words = Integer.parseInt(readString(dis));
      // word vector size
      int size = Integer.parseInt(readString(dis));
      wordVectors = new FloatMatrix(words, size);

      for (int i = 0; i < words; i++) {
        String word = readString(dis);
        log.info("Loaded " + word);

        float[] vectors = new float[size];
        double len = 0;
        for (int j = 0; j < size; j++) {
          float vector = readFloat(dis);
          len += vector * vector;
          vectors[j] = vector;
        }
        len = Math.sqrt(len);

        // Normalize to unit length; skip all-zero vectors to avoid 0 / 0 = NaN.
        if (len > 0) {
          for (int j = 0; j < size; j++) {
            vectors[j] /= len;
          }
        }
        wordIndex.add(word);
        wordVectors.putRow(i, new FloatMatrix(vectors));
      }
    }

    ret.setWordIndex(wordIndex);
    ret.setSyn0(wordVectors);

    return ret;
  }