/**
   * Initializes training. Runs through all data points in the training set and updates the weight
   * vector whenever a classification error occurs.
   *
   * <p>Can be called multiple times.
   *
   * @param labelset the set of labels, one for each data point. If the cardinalities of the
   *     dataset and the labelset do not match, a CardinalityException is thrown.
   * @param dataset the dataset to train on. Each column is treated as a data point.
   */
  public void train(Vector labelset, Matrix dataset) throws TrainingException {
    if (labelset.size() != dataset.columnSize()) {
      throw new CardinalityException(labelset.size(), dataset.columnSize());
    }

    boolean converged = false;
    int iteration = 0;
    while (!converged) {
      if (iteration > 1000) {
        throw new TrainingException("Too many iterations needed to find hyperplane.");
      }

      converged = true;
      int columnCount = dataset.columnSize();
      for (int i = 0; i < columnCount; i++) {
        Vector dataPoint = dataset.viewColumn(i);
        log.debug("Training point: {}", dataPoint);

        synchronized (this.model) {
          boolean prediction = model.classify(dataPoint);
          double label = labelset.get(i);
          if ((label <= 0 && prediction) || (label > 0 && !prediction)) {
            log.debug("updating");
            converged = false;
            update(label, dataPoint, this.model);
          }
        }
      }
      iteration++;
    }
  }
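  // Usage sketch for train(). The enclosing trainer class is not shown above, so
  // PerceptronTrainer and its no-arg constructor are hypothetical stand-ins; the
  // label convention (<= 0 vs. > 0) matches the misclassification test in the loop.
  public static void trainUsageSketch() throws TrainingException {
    // two 2-D points, one per column
    Matrix dataset = new DenseMatrix(new double[][] {
        {1.0, -1.0},
        {2.0, -2.0}
    });
    Vector labelset = new DenseVector(new double[] {1.0, -1.0});
    PerceptronTrainer trainer = new PerceptronTrainer(); // hypothetical
    trainer.train(labelset, dataset);
  }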
  public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
    FileSystem fs = output.getFileSystem(conf);

    Vector weightsPerLabel = null;
    Vector perLabelThetaNormalizer = null;
    Vector weightsPerFeature = null;
    Matrix weightsPerLabelAndFeature;
    float alphaI;

    FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
    try {
      alphaI = in.readFloat();
      weightsPerFeature = VectorWritable.readVector(in);
      weightsPerLabel = VectorWritable.readVector(in);
      perLabelThetaNormalizer = VectorWritable.readVector(in);

      weightsPerLabelAndFeature =
          new SparseMatrix(weightsPerLabel.size(), weightsPerFeature.size());
      for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
        weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
      }
    } finally {
      Closeables.closeQuietly(in);
    }
    NaiveBayesModel model =
        new NaiveBayesModel(
            weightsPerLabelAndFeature,
            weightsPerFeature,
            weightsPerLabel,
            perLabelThetaNormalizer,
            alphaI);
    model.validate();
    return model;
  }
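  // Usage sketch, with an illustrative path: materialize() expects the file
  // "naiveBayesModel.bin" under the given directory, laid out in the read order
  // above (alphaI as a float, three vectors, then one weight row per label).
  public static NaiveBayesModel materializeUsageSketch() throws IOException {
    Configuration conf = new Configuration();
    return materialize(new Path("/tmp/naive-bayes-model-dir"), conf);
  }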
    @Override
    protected void map(IntWritable row, VectorWritable similaritiesWritable, Context ctx)
        throws IOException, InterruptedException {
      Vector similarities = similaritiesWritable.get();
      // For performance, transposedPartial is created once outside the while loop and reused
      // for every non-zero element inside it
      Vector transposedPartial = new RandomAccessSparseVector(similarities.size(), 1);
      TopElementsQueue topKQueue = new TopElementsQueue(maxSimilaritiesPerRow);
      Iterator<Vector.Element> nonZeroElements = similarities.iterateNonZero();
      while (nonZeroElements.hasNext()) {
        Vector.Element nonZeroElement = nonZeroElements.next();

        MutableElement top = topKQueue.top();
        double candidateValue = nonZeroElement.get();
        if (candidateValue > top.get()) {
          top.setIndex(nonZeroElement.index());
          top.set(candidateValue);
          topKQueue.updateTop();
        }

        transposedPartial.setQuick(row.get(), candidateValue);
        ctx.write(new IntWritable(nonZeroElement.index()), new VectorWritable(transposedPartial));
        transposedPartial.setQuick(row.get(), 0.0);
      }
      Vector topKSimilarities =
          new RandomAccessSparseVector(similarities.size(), maxSimilaritiesPerRow);
      for (Vector.Element topKSimilarity : topKQueue.getTopElements()) {
        topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
      }
      ctx.write(row, new VectorWritable(topKSimilarities));
    }
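    // The queue above keeps only the maxSimilaritiesPerRow largest similarities.
    // A plain-Java sketch of the same bounded top-k idea, using a PriorityQueue
    // as a min-heap of size k (k >= 1); this is not Mahout's TopElementsQueue:
    static double[] topK(double[] values, int k) {
      java.util.PriorityQueue<Double> minHeap = new java.util.PriorityQueue<Double>(k);
      for (double v : values) {
        if (minHeap.size() < k) {
          minHeap.offer(v);
        } else if (v > minHeap.peek()) {
          minHeap.poll(); // evict the smallest of the current top k
          minHeap.offer(v);
        }
      }
      double[] top = new double[minHeap.size()];
      for (int i = top.length - 1; i >= 0; i--) {
        top[i] = minHeap.poll(); // heap yields ascending order, so fill back-to-front
      }
      return top; // descending: top[0] is the largest value
    }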
  static NaiveBayesModel readModelFromTempDir(Path base, Configuration conf) {

    float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);

    // read feature sums and label sums
    Vector scoresPerLabel = null;
    Vector scoresPerFeature = null;
    for (Pair<Text, VectorWritable> record :
        new SequenceFileDirIterable<Text, VectorWritable>(
            new Path(base, TrainNaiveBayesJob.WEIGHTS),
            PathType.LIST,
            PathFilters.partFilter(),
            conf)) {
      String key = record.getFirst().toString();
      VectorWritable value = record.getSecond();
      if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
        scoresPerFeature = value.get();
      } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
        scoresPerLabel = value.get();
      }
    }

    Preconditions.checkNotNull(scoresPerFeature);
    Preconditions.checkNotNull(scoresPerLabel);

    Matrix scoresPerLabelAndFeature =
        new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
    for (Pair<IntWritable, VectorWritable> entry :
        new SequenceFileDirIterable<IntWritable, VectorWritable>(
            new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS),
            PathType.LIST,
            PathFilters.partFilter(),
            conf)) {
      scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
    }

    Vector perlabelThetaNormalizer = null;
    for (Pair<Text, VectorWritable> entry :
        new SequenceFileDirIterable<Text, VectorWritable>(
            new Path(base, TrainNaiveBayesJob.THETAS),
            PathType.LIST,
            PathFilters.partFilter(),
            conf)) {
      if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
        perlabelThetaNormalizer = entry.getSecond().get();
      }
    }

    Preconditions.checkNotNull(perlabelThetaNormalizer);

    return new NaiveBayesModel(
        scoresPerLabelAndFeature,
        scoresPerFeature,
        scoresPerLabel,
        perlabelThetaNormalizer,
        alphaI);
  }
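  // Usage sketch with an illustrative path: "base" is the temp directory written
  // by TrainNaiveBayesJob (its WEIGHTS, SUMMED_OBSERVATIONS and THETAS outputs).
  static NaiveBayesModel readModelUsageSketch() {
    Configuration conf = new Configuration();
    conf.setFloat(ThetaMapper.ALPHA_I, 1.0f); // optional; the method defaults to 1.0f
    return readModelFromTempDir(new Path("/tmp/train-naive-bayes"), conf);
  }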
 @Override
 public Vector select(Vector probabilities) {
   int maxValueIndex = probabilities.maxValueIndex();
   Vector weights = new SequentialAccessSparseVector(probabilities.size());
   weights.set(maxValueIndex, 1.0);
   return weights;
 }
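 // Behavior sketch: select() is winner-take-all. Called on the enclosing policy
 // instance (not shown), probabilities [0.2, 0.5, 0.3] yield the one-hot sparse
 // vector [0, 1, 0], since index 1 holds the maximum.
 void selectBehaviorSketch() {
   Vector probabilities = new DenseVector(new double[] {0.2, 0.5, 0.3});
   Vector weights = select(probabilities); // 1.0 at index 1, zero elsewhere
 }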
 @Override
 public Matrix getDiagonalMatrix() {
   if (diagonalMatrix == null) {
     diagonalMatrix = new DenseMatrix(desiredRank, desiredRank);
   }
   if (diagonalMatrix.get(0, 1) <= 0) {
     try {
       Vector norms = fetchVector(new Path(baseDir, "norms"), 0);
       Vector projections = fetchVector(new Path(baseDir, "projections"), 0);
       if (norms != null && projections != null) {
         int i = 0;
         while (i < projections.size() - 1) {
           diagonalMatrix.set(i, i, projections.get(i));
           diagonalMatrix.set(i, i + 1, norms.get(i));
           diagonalMatrix.set(i + 1, i, norms.get(i));
           i++;
         }
         diagonalMatrix.set(i, i, projections.get(i));
       }
     } catch (IOException e) {
       log.error("Could not load diagonal matrix of norms and projections: ", e);
     }
   }
   return diagonalMatrix;
 }
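 // Shape sketch: the loop above fills a symmetric tridiagonal matrix. For three
 // projections p0..p2 and two norms n0..n1 the result is
 //
 //   [ p0  n0   0 ]
 //   [ n0  p1  n1 ]
 //   [ 0   n1  p2 ]
 //
 // consistent with the tridiagonal form produced by a Lanczos-style iteration.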
  static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (!parseArgs(args)) {
      return;
    }
    AdaptiveLogisticModelParameters lmp =
        AdaptiveLogisticModelParameters.loadFromFile(new File(modelFile));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    csv.setIdName(idColumn);

    AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();

    State<Wrapper, CrossFoldLearner> best = lr.getBest();
    if (best == null) {
      output.println("AdaptiveLogisticRegression has probably not been trained yet.");
      return;
    }
    CrossFoldLearner learner = best.getPayload().getLearner();

    BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
    BufferedWriter out =
        new BufferedWriter(
            new OutputStreamWriter(new FileOutputStream(outputFile), Charsets.UTF_8));

    out.write(idColumn + ",target,score");
    out.newLine();

    String line = in.readLine();
    csv.firstLine(line);
    line = in.readLine();
    Map<String, Double> results = new HashMap<String, Double>();
    int k = 0;
    while (line != null) {
      Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
      csv.processLine(line, v, false);
      Vector scores = learner.classifyFull(v);
      results.clear();
      if (maxScoreOnly) {
        results.put(csv.getTargetLabel(scores.maxValueIndex()), scores.maxValue());
      } else {
        for (int i = 0; i < scores.size(); i++) {
          results.put(csv.getTargetLabel(i), scores.get(i));
        }
      }

      for (Map.Entry<String, Double> entry : results.entrySet()) {
        out.write(csv.getIdString(line) + ',' + entry.getKey() + ',' + entry.getValue());
        out.newLine();
      }
      k++;
      if (k % 100 == 0) {
        output.println(k + " records processed");
      }
      line = in.readLine();
    }
    out.flush();
    out.close();
    in.close();
    output.println(k + " records processed in total.");
  }
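  // Usage sketch (only if the enclosing CLI class does not already define a main):
  // arguments are forwarded untouched, and parseArgs() defines the accepted flags
  // (modelFile, inputFile, outputFile, idColumn, maxScoreOnly are the static
  // fields it appears to fill in).
  public static void main(String[] args) throws Exception {
    mainToOutput(args, new PrintWriter(System.out, true));
  }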
 private static String getFormatedOutput(VectorWritable vw) {
   // StringBuilder avoids repeated String concatenation inside the loop
   StringBuilder formatedString = new StringBuilder();
   int formatWidth = 8;
   Vector vector = vw.get();
   for (int i = 0; i < vector.size(); ++i) {
     formatedString.append(String.format("%" + formatWidth + ".4f", vector.get(i)));
   }
   return formatedString.toString();
 }
 @Override
 public double pdf(VectorWritable v) {
   Vector x = v.get();
   // return the product of the component pdfs
   // TODO: is this reasonable? correct?
   double pdf = pdf(x, stdDev.get(0));
   for (int i = 1; i < x.size(); i++) {
     pdf *= pdf(x, stdDev.get(i));
   }
   return pdf;
 }
  /**
   * Handles GET requests: queries Graphite for location data, clusters it with k-means, and
   * updates the graph. Expects the request parameters target, from, until, cluster and zoom.
   *
   * @see HttpServlet#doGet(HttpServletRequest request, HttpServletResponse response)
   */
  protected void doGet(HttpServletRequest request, HttpServletResponse response)
      throws ServletException, IOException {

    // Make Query From Request
    String url = makeQuery(request);
    if (url == null) return;

    System.out.println(" [Location Clusterer] : Get Data From Graphite.............");
    // Get Data From Graphite
    String locationData = getDataFromGraphite(url);

    System.out.println(" [Location Clusterer] : Parse and Fuse obtained data.............");
    // Parse and Fuse obtained data
    JSONArray jArr = new JSONArray(locationData);
    String manipulatedString = fusionGraphiteData(jArr);

    System.out.println(" [Location Clusterer] : Manipulate Data.............");
    // Manipulate Data ( Separate / Filter )
    List<Period> indoorPeriods = getIndoorData(manipulatedString);
    manipulatedString = filterIndoorData(manipulatedString);

    System.out.println(" [Location Clusterer] : Run K-means Clustering Algorithm.............");
    int numOfClusters = 4;
    if (request.getParameter("cluster") != null)
      numOfClusters = Integer.parseInt(request.getParameter("cluster"));
    List<Cluster> clusters = runKMeansClustering(manipulatedString, numOfClusters);
    if (clusters.size() == 1) {
      Vector radi = clusters.get(0).getRadius();
      if (radi.size() == 0) {
        System.out.println(" [Location Clusterer] : Null Cluster Exit");
        return;
      }
    }
    System.out.println(" [Location Clusterer] : Filtering Result.............");
    String[] centers = getCenters(clusters);
    List<String> clusterMembers = getClusterMembers(manipulatedString, centers);
    clusterMembers = filterNullClusters(clusterMembers);

    System.out.println(" [Location Clusterer] : Make Regions for Clusters.............");
    List<ClusterRectangle> clusterRectangles = new ArrayList<ClusterRectangle>();
    for (int i = 0; i < clusterMembers.size(); i++) {
      ClusterRectangle clusterRectangle = getClusterRectangle(clusterMembers.get(i));
      clusterRectangles.add(clusterRectangle);
    }

    // Do Visualization
    doVisualization(clusterMembers, manipulatedString, indoorPeriods, request);

    // Update Graph
    System.out.println(" [Location Clusterer] : Update Graph.............");
    String target = request.getParameter("target"); // getParameter already returns a String
    updateGraph(clusterRectangles, indoorPeriods, target);
  }
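  // Illustrative request, with parameter names taken from the Javadoc and the
  // handler above (the servlet path is hypothetical):
  //
  //   GET /locationClusterer?target=...&from=...&until=...&cluster=4&zoom=...
  //
  // Only "cluster" and "target" are read directly in doGet(); the others are
  // presumably consumed inside makeQuery().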
  public static Job createTimesSquaredJob(
      Configuration initialConf,
      Vector v,
      Path matrixInputPath,
      Path outputVectorPathBase,
      Class<? extends TimesSquaredMapper> mapClass,
      Class<? extends VectorSummingReducer> redClass)
      throws IOException {

    return createTimesSquaredJob(
        initialConf, v, v.size(), matrixInputPath, outputVectorPathBase, mapClass, redClass);
  }
 /**
  * Returns a human-readable formatted string representation of the vector; it is not intended to
  * be complete or usable as an input/output representation.
  */
 public static String formatVector(Vector v, String[] bindings) {
   StringBuilder buf = new StringBuilder();
   if (v instanceof NamedVector) {
     buf.append(((NamedVector) v).getName()).append(" = ");
   }
   int nzero = 0;
   Iterator<Vector.Element> iterateNonZero = v.iterateNonZero();
   while (iterateNonZero.hasNext()) {
     iterateNonZero.next();
     nzero++;
   }
   // if vector is sparse or if we have bindings, use sparse notation
   if (nzero < v.size() || bindings != null) {
     buf.append('[');
     for (int i = 0; i < v.size(); i++) {
       double elem = v.get(i);
       if (elem == 0.0) {
         continue;
       }
       String label;
       if (bindings != null && (label = bindings[i]) != null) {
         buf.append(label).append(':');
       } else {
         buf.append(i).append(':');
       }
       buf.append(String.format(Locale.ENGLISH, "%.3f", elem)).append(", ");
     }
   } else {
     buf.append('[');
     for (int i = 0; i < v.size(); i++) {
       double elem = v.get(i);
       buf.append(String.format(Locale.ENGLISH, "%.3f", elem)).append(", ");
     }
   }
   if (buf.length() > 1) {
     buf.setLength(buf.length() - 2);
   }
   buf.append(']');
   return buf.toString();
 }
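 // Behavior sketch: with bindings == null and no zero entries, the dense branch
 // prints every value; any zero entry (or non-null bindings) switches to
 // index:value pairs for the non-zeros.
 //
 //   formatVector(new DenseVector(new double[] {1.0, 2.0, 2.5}), null)
 //     -> "[1.000, 2.000, 2.500]"
 //   formatVector(new DenseVector(new double[] {1.0, 0.0, 2.5}), null)
 //     -> "[0:1.000, 2:2.500]"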
 public double[][] toDoubleArray() {
   final double[][] matrix = new double[this.numRows][this.numCols];
   Iterator<MatrixSlice> iterator = this.iterateAll();
   int i = 0;
   while (iterator.hasNext()) {
     Vector rowVector = iterator.next().vector();
     for (int j = 0; j < rowVector.size(); j++) {
       matrix[i][j] = rowVector.getElement(j).get();
     }
     i++;
   }
   return matrix;
 }
 /**
  * @param topicDistribution vector of p(topicId) for all topicId < model.numTopics()
  * @param numSamples the number of times to sample (with replacement) from the model
  * @return array of length numSamples, with each entry being a sample from the model. There may
  *     be repeats
  */
 public int[] sample(Vector topicDistribution, int numSamples) {
   Preconditions.checkNotNull(topicDistribution);
   Preconditions.checkArgument(numSamples > 0, "numSamples must be positive");
   Preconditions.checkArgument(
       topicDistribution.size() == samplers.length,
       "topicDistribution must have same cardinality as the sampling model");
   int[] samples = new int[numSamples];
   Sampler topicSampler = new Sampler(random, topicDistribution);
   for (int i = 0; i < numSamples; i++) {
     samples[i] = samplers[topicSampler.sample()].sample();
   }
   return samples;
 }
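 // Usage sketch, called on the enclosing sampler instance (not shown): draw ten
 // samples from a three-topic distribution; with weights 0.7/0.2/0.1 most draws
 // come from the topic-0 sampler.
 void sampleUsageSketch() {
   Vector topicDistribution = new DenseVector(new double[] {0.7, 0.2, 0.1});
   int[] draws = sample(topicDistribution, 10);
 }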
 @Override
 public double getScaleFactor() {
   if (scaleFactor <= 0) {
     try {
       Vector v = fetchVector(new Path(baseDir, "scaleFactor"), 0);
       if (v != null && v.size() > 0) {
         scaleFactor = v.get(0);
       }
     } catch (IOException e) {
       log.error("could not load scaleFactor:", e);
     }
   }
   return scaleFactor;
 }
 /**
  * A version that computes yRow as a sparse vector, for use with extremely sparse matrices.
  *
  * @param aRow row of matrix A (size n)
  * @param yRowOut row of matrix Y (result); must be pre-allocated to size (k+p)
  */
 public void computeYRow(Vector aRow, Vector yRowOut) {
   yRowOut.assign(0.0);
   if (aRow.isDense()) {
     int n = aRow.size();
     for (int j = 0; j < n; j++) {
       accumDots(j, aRow.getQuick(j), yRowOut);
     }
   } else {
     for (Iterator<Element> iter = aRow.iterateNonZero(); iter.hasNext(); ) {
       Element el = iter.next();
       accumDots(el.index(), el.get(), yRowOut);
     }
   }
 }
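 // Contract sketch, assuming accumDots(j, aj, y) accumulates aj * Omega(j, c)
 // into y(c) for every output column c (i.e. yRowOut = aRow * Omega): for
 // aRow = (0, 3, 0) only accumDots(1, 3.0, yRowOut) runs in the sparse branch,
 // so the cost scales with the non-zeros of aRow rather than with n.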
  public boolean verify(DistributedRowMatrix other) {

    Iterator<MatrixSlice> iteratorThis = this.iterateAll();
    Iterator<MatrixSlice> iteratorOther = other.iterateAll();

    while (iteratorThis.hasNext()) {
      if (!iteratorOther.hasNext()) {
        // other has fewer rows than this matrix
        return false;
      }
      Vector thisVector = iteratorThis.next().vector();
      Vector otherVector = iteratorOther.next().vector();

      if (thisVector.size() != otherVector.size()) {
        return false;
      }

      for (int j = 0; j < thisVector.size(); j++) {
        if (thisVector.getElement(j).get() != otherVector.getElement(j).get()) {
          // System.out.println("Verify failed!");
          // System.out.println("  Vector1: " + thisVector.toString());
          // System.out.println("  Vector2: " + otherVector.toString());
          return false;
        }
      }
    }
    return true;
  }
 /**
  * Computes yRow = aRow * Omega.
  *
  * @param aRow row of matrix A (size n)
  * @param yRow row of matrix Y (result) must be pre-allocated to size of (k+p)
  */
 @Deprecated
 public void computeYRow(Vector aRow, double[] yRow) {
   // assert yRow.length == kp;
   Arrays.fill(yRow, 0.0);
   if (aRow.isDense()) {
     int n = aRow.size();
     for (int j = 0; j < n; j++) {
       accumDots(j, aRow.getQuick(j), yRow);
     }
   } else {
     for (Iterator<Element> iter = aRow.iterateNonZero(); iter.hasNext(); ) {
       Element el = iter.next();
       accumDots(el.index(), el.get(), yRow);
     }
   }
 }
 public static Matrix sampledCorpus(
     Matrix matrix, Random random, int numDocs, int numSamples, int numTopicsPerDoc) {
   Matrix corpus = new SparseRowMatrix(numDocs, matrix.numCols());
   LDASampler modelSampler = new LDASampler(matrix, random);
   Vector topicVector = new DenseVector(matrix.numRows());
   for (int i = 0; i < numTopicsPerDoc; i++) {
     int topic = random.nextInt(topicVector.size());
     topicVector.set(topic, topicVector.get(topic) + 1);
   }
   for (int docId = 0; docId < numDocs; docId++) {
     for (int sample : modelSampler.sample(topicVector, numSamples)) {
       corpus.set(docId, sample, corpus.get(docId, sample) + 1);
     }
   }
   return corpus;
 }
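 // Usage sketch: build a 100-document corpus from an assumed topic-term matrix,
 // 50 samples per document, with 3 topic increments shared by all documents.
 static Matrix sampledCorpusUsageSketch(Matrix topicTermMatrix) {
   return sampledCorpus(topicTermMatrix, new java.util.Random(42), 100, 50, 3);
 }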
 public int printDistributedRowMatrix() {
   System.out.println("RowPath: " + this.rowPath);
   Iterator<MatrixSlice> iterator = this.iterateAll();
   int count = 0;
   while (iterator.hasNext()) {
     MatrixSlice slice = iterator.next();
     Vector v = slice.vector();
     int size = v.size();
     for (int i = 0; i < size; i++) {
       Element e = v.getElement(i);
       count++;
       System.out.print(e.get() + " ");
     }
     System.out.println();
   }
   return count;
 }
 @Override
 public void reduce(IntWritable id, Iterable<VectorWritable> sums, Context context)
     throws IOException, InterruptedException {
   Iterator<VectorWritable> it = sums.iterator();
   if (!it.hasNext()) {
     return;
   }
   DenseVector sumVector = null;
   while (it.hasNext()) {
     Vector vec = it.next().get();
     if (sumVector == null) {
       sumVector = new DenseVector(vec.size());
     }
     sumVector.assign(vec, Functions.PLUS);
   }
   double max = sumVector.maxValue();
   context.write(id, new DoubleWritable(max));
 }
 private static Matrix loadVectors(String vectorPathString, Configuration conf)
     throws IOException {
   Path vectorPath = new Path(vectorPathString);
   FileSystem fs = vectorPath.getFileSystem(conf);
   List<Path> subPaths = Lists.newArrayList();
   if (fs.isFile(vectorPath)) {
     subPaths.add(vectorPath);
   } else {
     for (FileStatus fileStatus : fs.listStatus(vectorPath, PathFilters.logsCRCFilter())) {
       subPaths.add(fileStatus.getPath());
     }
   }
   List<Pair<Integer, Vector>> rowList = Lists.newArrayList();
   int numRows = Integer.MIN_VALUE;
   int numCols = -1;
   boolean sequentialAccess = false;
   for (Path subPath : subPaths) {
     for (Pair<IntWritable, VectorWritable> record :
         new SequenceFileIterable<IntWritable, VectorWritable>(subPath, true, conf)) {
       int id = record.getFirst().get();
       Vector vector = record.getSecond().get();
       if (vector instanceof NamedVector) {
         vector = ((NamedVector) vector).getDelegate();
       }
       if (numCols < 0) {
         numCols = vector.size();
         sequentialAccess = vector.isSequentialAccess();
       }
       rowList.add(Pair.of(id, vector));
       numRows = Math.max(numRows, id);
     }
   }
   numRows++;
   Vector[] rowVectors = new Vector[numRows];
   for (Pair<Integer, Vector> pair : rowList) {
     rowVectors[pair.getFirst()] = pair.getSecond();
   }
   return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess);
 }
  @Test
  public void testInitialization() {
    // start with super clusterable data
    List<? extends WeightedVector> data = cubishTestData(0.01);

    // just do initialization of ball k-means.  This should drop a point into each of the clusters
    BallKMeans r = new BallKMeans(new BruteSearch(new EuclideanDistanceMeasure()), 6, 20);
    r.cluster(data);

    // put the centroids into a matrix
    Matrix x = new DenseMatrix(6, 5);
    int row = 0;
    for (Centroid c : r) {
      x.viewRow(row).assign(c.viewPart(0, 5));
      row++;
    }

    // verify that each column looks right.  Should contain zeros except for a single 6.
    final Vector columnNorms =
        x.aggregateColumns(
            new VectorFunction() {
              @Override
              public double apply(Vector f) {
                // return the sum of three discrepancy measures
                return Math.abs(f.minValue())
                    + Math.abs(f.maxValue() - 6)
                    + Math.abs(f.norm(1) - 6);
              }
            });
    // verify all errors are nearly zero
    assertEquals(0, columnNorms.norm(1) / columnNorms.size(), 0.1);

    // verify that the centroids are a permutation of the original ones
    SingularValueDecomposition svd = new SingularValueDecomposition(x);
    Vector s = svd.getS().viewDiagonal().assign(Functions.div(6));
    assertEquals(5, s.getLengthSquared(), 0.05);
    assertEquals(5, s.norm(1), 0.05);
  }
  @Test
  public void testMatrixDiagonalizeReducer() throws Exception {
    MatrixDiagonalizeMapper mapper = new MatrixDiagonalizeMapper();
    Configuration conf = getConfiguration();
    conf.setInt(Keys.AFFINITY_DIMENSIONS, RAW_DIMENSIONS);

    // set up the dummy writers
    DummyRecordWriter<NullWritable, IntDoublePairWritable> mapWriter = new DummyRecordWriter<>();
    Mapper<IntWritable, VectorWritable, NullWritable, IntDoublePairWritable>.Context mapContext =
        DummyRecordWriter.build(mapper, conf, mapWriter);

    // perform the mapping
    for (int i = 0; i < RAW_DIMENSIONS; i++) {
      RandomAccessSparseVector toAdd = new RandomAccessSparseVector(RAW_DIMENSIONS);
      toAdd.assign(RAW[i]);
      mapper.map(new IntWritable(i), new VectorWritable(toAdd), mapContext);
    }

    // now perform the reduction
    MatrixDiagonalizeReducer reducer = new MatrixDiagonalizeReducer();
    DummyRecordWriter<NullWritable, VectorWritable> redWriter = new DummyRecordWriter<>();
    Reducer<NullWritable, IntDoublePairWritable, NullWritable, VectorWritable>.Context redContext =
        DummyRecordWriter.build(
            reducer, conf, redWriter, NullWritable.class, IntDoublePairWritable.class);

    // only need one reduction
    reducer.reduce(NullWritable.get(), mapWriter.getValue(NullWritable.get()), redContext);

    // first, make sure there's only one result
    List<VectorWritable> list = redWriter.getValue(NullWritable.get());
    assertEquals("Only a single resulting vector", 1, list.size());
    Vector v = list.get(0).get();
    for (int i = 0; i < v.size(); i++) {
      assertEquals("Element sum is correct", rowSum(RAW[i]), v.get(i), 0.01);
    }
  }
 public int numLabels() {
   return weightsPerLabel.size();
 }
  /**
   * Solves the system Ax = b, where A is a linear operator and b is a vector. Uses the specified
   * preconditioner to improve numeric stability and possibly speed convergence. This version of
   * solve() allows control over the termination and iteration parameters.
   *
   * @param a The matrix A.
   * @param b The vector b.
   * @param preconditioner The preconditioner to apply.
   * @param maxIterations The maximum number of iterations to run.
   * @param maxError The maximum amount of residual error to tolerate. The algorithm will run until
   *     the residual falls below this value or until maxIterations are completed.
   * @return The result x of solving the system.
   * @throws IllegalArgumentException if the matrix is not square, if the size of b is not equal to
   *     the number of columns of A, if maxError is less than zero, or if maxIterations is not
   *     positive.
   */
  public Vector solve(
      VectorIterable a,
      Vector b,
      Preconditioner preconditioner,
      int maxIterations,
      double maxError) {

    if (a.numRows() != a.numCols()) {
      throw new IllegalArgumentException("Matrix must be square, symmetric and positive definite.");
    }

    if (a.numCols() != b.size()) {
      throw new CardinalityException(a.numCols(), b.size());
    }

    if (maxIterations <= 0) {
      throw new IllegalArgumentException("Max iterations must be positive.");
    }

    if (maxError < 0.0) {
      throw new IllegalArgumentException("Max error must be non-negative.");
    }

    Vector x = new DenseVector(b.size());

    iterations = 0;
    Vector residual = b.minus(a.times(x));
    residualNormSquared = residual.dot(residual);

    log.info("Conjugate gradient initial residual norm = {}", Math.sqrt(residualNormSquared));
    double previousConditionedNormSqr = 0.0;
    Vector updateDirection = null;
    while (Math.sqrt(residualNormSquared) > maxError && iterations < maxIterations) {
      Vector conditionedResidual;
      double conditionedNormSqr;
      if (preconditioner == null) {
        conditionedResidual = residual;
        conditionedNormSqr = residualNormSquared;
      } else {
        conditionedResidual = preconditioner.precondition(residual);
        conditionedNormSqr = residual.dot(conditionedResidual);
      }

      ++iterations;

      if (iterations == 1) {
        updateDirection = new DenseVector(conditionedResidual);
      } else {
        double beta = conditionedNormSqr / previousConditionedNormSqr;

        // updateDirection = conditionedResidual + beta * updateDirection
        updateDirection.assign(Functions.MULT, beta);
        updateDirection.assign(conditionedResidual, Functions.PLUS);
      }

      Vector aTimesUpdate = a.times(updateDirection);

      double alpha = conditionedNormSqr / updateDirection.dot(aTimesUpdate);

      // x = x + alpha * updateDirection
      PLUS_MULT.setMultiplicator(alpha);
      x.assign(updateDirection, PLUS_MULT);

      // residual = residual - alpha * A * updateDirection
      PLUS_MULT.setMultiplicator(-alpha);
      residual.assign(aTimesUpdate, PLUS_MULT);

      previousConditionedNormSqr = conditionedNormSqr;
      residualNormSquared = residual.dot(residual);

      log.info(
          "Conjugate gradient iteration {} residual norm = {}",
          iterations,
          Math.sqrt(residualNormSquared));
    }
    return x;
  }
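  // Usage sketch, called on the enclosing solver instance: a 2x2 symmetric
  // positive definite system whose exact solution is x = (1, 1). A DenseMatrix
  // can be passed here since Mahout's Matrix implements VectorIterable.
  Vector solveUsageSketch() {
    Matrix a = new DenseMatrix(new double[][] {
        {4.0, 1.0},
        {1.0, 3.0}
    });
    Vector b = new DenseVector(new double[] {5.0, 4.0});
    return solve(a, b, null, 10, 1.0e-9); // no preconditioner
  }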
 /**
  * Solves the system Ax = b with default termination criteria. A must be symmetric, square, and
  * positive definite. Only the squareness of a is checked, since testing for symmetry and positive
  * definiteness are too expensive. If an invalid matrix is specified, then the algorithm may not
  * yield a valid result.
  *
  * @param a The linear operator A.
  * @param b The vector b.
  * @return The result x of solving the system.
  * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the
  *     number of columns of a.
  */
 public Vector solve(VectorIterable a, Vector b) {
   return solve(a, b, null, b.size() + 2, DEFAULT_MAX_ERROR);
 }
 /**
  * Solves the system Ax = b with default termination criteria using the specified preconditioner.
  * A must be symmetric, square, and positive definite. Only the squareness of a is checked, since
  * testing for symmetry and positive definiteness are too expensive. If an invalid matrix is
  * specified, then the algorithm may not yield a valid result.
  *
  * @param a The linear operator A.
  * @param b The vector b.
  * @param precond A preconditioner to use on A during the solution process.
  * @return The result x of solving the system.
  * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the
  *     number of columns of a.
  */
 public Vector solve(VectorIterable a, Vector b, Preconditioner precond) {
   return solve(a, b, precond, b.size() + 2, DEFAULT_MAX_ERROR);
 }