@Override
  public Object compute(FloatMatrix params, int flag) {

    x = params.getRange(0, rows * features);
    FloatMatrix theta = params.getRange(rows * features, params.length);

    x = x.reshape(rows, features);
    theta = theta.reshape(columns, features);

    if (flag == 1 || flag == 3) {
      FloatMatrix M = MatrixFunctions.pow(x.mmul(theta.transpose()).sub(y), 2);
      this.cost = M.mul(r).columnSums().rowSums().get(0) / 2;

      if (lambda != 0) {
        float cost1 =
            (lambda / 2)
                * (MatrixFunctions.pow(theta, 2).columnSums().rowSums().get(0)
                    + MatrixFunctions.pow(x, 2).columnSums().rowSums().get(0));
        this.cost += cost1;
      }
    }

    if (flag == 2 || flag == 3) {

      FloatMatrix xGrad = FloatMatrix.zeros(x.rows, x.columns);
      FloatMatrix thetaGrad = FloatMatrix.zeros(theta.rows, theta.columns);

      int[] indices;
      FloatMatrix thetaTemp;
      FloatMatrix xTemp;
      FloatMatrix yTemp;
      for (int i = 0; i < rows; i++) {
        indices = r.getRow(i).eq(1).findIndices();
        if (indices.length == 0) continue;

        thetaTemp = theta.getRows(indices);
        yTemp = y.getRow(i).get(indices);
        xGrad.putRow(i, x.getRow(i).mmul(thetaTemp.transpose()).sub(yTemp).mmul(thetaTemp));
      }
      xGrad = xGrad.add(x.mmul(lambda));

      for (int i = 0; i < columns; i++) {
        indices = r.getColumn(i).eq(1).findIndices();
        if (indices.length == 0) continue;

        xTemp = x.getRows(indices);
        yTemp = y.getColumn(i).get(indices);
        thetaGrad.putRow(
            i, xTemp.mmul(theta.getRow(i).transpose()).sub(yTemp).transpose().mmul(xTemp));
      }
      thetaGrad = thetaGrad.add(theta.mmul(lambda));

      this.gradient = MatrixUtil.merge(xGrad.data, thetaGrad.data);
    }

    return flag == 1 ? cost : gradient;
  }
 /**
  * Computes half of the summed squared reconstruction error of {@code input}, divided by the
  * number of input rows.
  *
  * @return the halved squared reconstruction error per input row
  */
 @Override
 public double mse() {
   DoubleMatrix error = reconstruct(input).sub(input);
   DoubleMatrix squaredError = MatrixFunctions.pow(error, 2);
   return 0.5 * squaredError.columnSums().sum() / input.rows;
 }
  /**
   * Computes the summed squared reconstruction error of {@code input} divided by the number of
   * input rows, plus {@code 0.5 * l2 * sum(W^2)} when regularization is enabled.
   *
   * @return the squared reconstruction loss, including the L2 term when enabled
   */
  @Override
  public double squaredLoss() {
    DoubleMatrix error = reconstruct(input).sub(input);
    double reconstructionLoss = pow(error, 2).columnSums().sum() / input.rows;
    if (!this.useRegularization) {
      return reconstructionLoss;
    }
    return reconstructionLoss + 0.5 * l2 * MatrixFunctions.pow(W, 2).sum();
  }
  /**
   * Negative log likelihood of the given input with respect to its reconstruction.
   *
   * <p>The value is the negated mean, over columns, of the column sums of {@code input * log(z) +
   * (1 - input) * log(1 - z)}, where {@code z} is the reconstruction of {@code input}. When
   * regularization is enabled, {@code 0.5 * l2 * sum(W^2)} is added.
   *
   * @param input the data to reconstruct and score
   * @return the negative log likelihood of {@code input}, including the L2 term when enabled
   */
  public double negativeLoglikelihood(DoubleMatrix input) {
    DoubleMatrix z = this.reconstruct(input);
    double likelihood =
        -input.mul(log(z)).add(oneMinus(input).mul(log(oneMinus(z)))).columnSums().mean();

    if (this.useRegularization) {
      // The penalty must scale with l2, not its reciprocal: (2 / l2) shrinks as the coefficient
      // grows and is non-finite for l2 == 0.
      likelihood += 0.5 * l2 * MatrixFunctions.pow(this.W, 2).sum();
    }

    return likelihood;
  }
  /**
   * Negative log likelihood of the current input with respect to its reconstruction.
   *
   * <p>The value is the negated mean, over columns, of the column sums of {@code input * log(z) +
   * (1 - input) * log(1 - z)}, where {@code z} is the reconstruction of {@code input}. When
   * regularization is enabled, {@code 0.5 * l2 * sum(W^2)} is added. When {@code
   * normalizeByInputRows} is set, the total (including any L2 term) is divided by the number of
   * input rows.
   *
   * @return the negative log likelihood of the current input
   */
  @Override
  public double negativeLogLikelihood() {
    DoubleMatrix z = this.reconstruct(input);
    double likelihood =
        -input.mul(log(z)).add(oneMinus(input).mul(log(oneMinus(z)))).columnSums().mean();

    if (this.useRegularization) {
      // The penalty must scale with l2, not its reciprocal: (2 / l2) shrinks as the coefficient
      // grows and is non-finite for l2 == 0.
      likelihood += 0.5 * l2 * MatrixFunctions.pow(this.W, 2).sum();
    }

    if (this.normalizeByInputRows) {
      likelihood /= input.rows;
    }

    return likelihood;
  }
Beispiel #6
0
  /**
   * Recursively backpropagates the classification error through {@code tree}, adding each node's
   * gradient contributions to the supplied accumulators and recording the node's error on the
   * node via {@code setError}. Leaves are skipped.
   *
   * @param tree the node to process; its children are visited recursively
   * @param binaryTD accumulator for binary transform gradients, keyed by (left, right) category
   * @param binaryCD accumulator for binary classification gradients, keyed by (left, right)
   *     category; used only when {@code combineClassification} is false
   * @param binaryFloatTensorTD accumulator for tensor gradients, keyed by (left, right) category;
   *     used only when {@code useFloatTensors} is true
   * @param unaryCD accumulator for classification gradients of pre-terminal nodes, keyed by
   *     category; also receives binary-node gradients under the key {@code ""} when {@code
   *     combineClassification} is true
   * @param wordVectorD accumulator for word vector gradients, keyed by vocabulary word
   * @param deltaUp the delta passed down from the parent node, added to this node's own delta
   */
  private void backpropDerivativesAndError(
      Tree tree,
      MultiDimensionalMap<String, String, FloatMatrix> binaryTD,
      MultiDimensionalMap<String, String, FloatMatrix> binaryCD,
      MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD,
      Map<String, FloatMatrix> unaryCD,
      Map<String, FloatMatrix> wordVectorD,
      FloatMatrix deltaUp) {
    if (tree.isLeaf()) {
      return;
    }

    FloatMatrix currentVector = tree.vector();
    String category = basicCategory(tree.label());

    // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class
    FloatMatrix goldLabel = new FloatMatrix(numOuts, 1);
    int goldClass = tree.goldLabel();
    if (goldClass >= 0) {
      goldLabel.put(goldClass, 1.0f);
    }

    // Classes without an explicit weight count with weight 1.
    Float nodeWeight = classWeights.get(goldClass);
    if (nodeWeight == null) {
      nodeWeight = 1.0f;
    }
    FloatMatrix predictions = tree.prediction();

    // If this is an unlabeled class, set deltaClass to 0.  We could
    // make this more efficient by eliminating various of the below
    // calculations, but this would be the easiest way to handle the
    // unlabeled class
    FloatMatrix deltaClass =
        goldClass >= 0
            ? SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel))
            : new FloatMatrix(predictions.rows, predictions.columns);
    FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose());

    // Weighted cross-entropy of the prediction against the gold label.
    float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum());
    error = error * nodeWeight;
    tree.setError(error);

    if (tree.isPreTerminal()) { // below us is a word vector
      unaryCD.put(category, unaryCD.get(category).add(localCD));

      String word = tree.children().get(0).label();
      word = getVocabWord(word);

      // The delta must be scaled by the activation's derivative (not the activation itself),
      // exactly as the binary branch below does.
      FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector);
      FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass);
      deltaFromClass =
          deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative);
      FloatMatrix deltaFull = deltaFromClass.add(deltaUp);
      wordVectorD.put(word, wordVectorD.get(word).add(deltaFull));

    } else {
      // Otherwise, this must be a binary node
      Tree leftChild = tree.children().get(0);
      Tree rightChild = tree.children().get(1);
      String leftCategory = basicCategory(leftChild.label());
      String rightCategory = basicCategory(rightChild.label());
      if (combineClassification) {
        unaryCD.put("", unaryCD.get("").add(localCD));
      } else {
        binaryCD.put(
            leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD));
      }

      FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector);
      FloatMatrix deltaFromClass =
          getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);

      // Drop the bias row before scaling by the derivative.
      FloatMatrix hiddenDelta = deltaFromClass.get(interval(0, numHidden), interval(0, 1));
      deltaFromClass = hiddenDelta.muli(currentVectorDerivative);
      FloatMatrix deltaFull = deltaFromClass.add(deltaUp);

      FloatMatrix leftVector = leftChild.vector();
      FloatMatrix rightVector = rightChild.vector();

      FloatMatrix childrenVector = appendBias(leftVector, rightVector);

      // The transform gradient uses the full delta (classification delta plus the delta from the
      // parent), matching the tensor gradient and deltaDown below; using only the classification
      // delta would drop the error propagated from above.
      FloatMatrix accumulatedTD = binaryTD.get(leftCategory, rightCategory);
      FloatMatrix W_df = deltaFull.mmul(childrenVector.transpose());
      binaryTD.put(leftCategory, rightCategory, accumulatedTD.add(W_df));

      FloatMatrix deltaDown;
      if (useFloatTensors) {
        FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector);
        binaryFloatTensorTD.put(
            leftCategory,
            rightCategory,
            binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df));
        deltaDown =
            computeFloatTensorDeltaDown(
                deltaFull,
                leftVector,
                rightVector,
                getBinaryTransform(leftCategory, rightCategory),
                getBinaryFloatTensor(leftCategory, rightCategory));
      } else {
        deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
      }

      // Each child's incoming delta is scaled by the derivative of its own activation.
      FloatMatrix leftDerivative = activationFunction.applyDerivative(leftVector);
      FloatMatrix rightDerivative = activationFunction.applyDerivative(rightVector);
      // deltaDown stacks the left child's delta on top of the right child's.
      FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1));
      FloatMatrix rightDeltaDown =
          deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1));
      backpropDerivativesAndError(
          leftChild,
          binaryTD,
          binaryCD,
          binaryFloatTensorTD,
          unaryCD,
          wordVectorD,
          leftDerivative.mul(leftDeltaDown));
      backpropDerivativesAndError(
          rightChild,
          binaryTD,
          binaryCD,
          binaryFloatTensorTD,
          unaryCD,
          wordVectorD,
          rightDerivative.mul(rightDeltaDown));
    }
  }
Beispiel #7
0
  /**
   * Reads a distance matrix from standard input, builds an additive tree from it by repeatedly
   * joining pairs of leaves, and prints the results.
   *
   * <p>Input format: the leaf count {@code n} followed by {@code n * n} distances, row by row.
   * Leaves are joined until three remain; those are attached to one final vertex. The method
   * prints the extended distance matrix, every vertex's adjacency list, the matrices {@code A}
   * and {@code g} passed to {@code nk} (presumably populated there), the solution {@code b} of
   * the normal equations {@code (A^T A) b = A^T g}, and the value returned by {@code nk}.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    Scanner s = new Scanner(System.in);
    int n = s.nextInt();
    int nn = n; // original leaf count; n shrinks as leaves are joined
    // d has room for the vertices added while joining; d2 keeps only the input distances.
    DoubleMatrix d = DoubleMatrix.zeros(2 * n, 2 * n);
    DoubleMatrix d2 = DoubleMatrix.zeros(n, n);

    for (int i = 0; i < n; i++) {
      for (int j = 0; j < n; j++) {
        d.put(i, j, s.nextDouble());
        d2.put(i, j, d.get(i, j));
      }
    }
    // NOTE(review): d is printed twice; the second line may have been meant to print d2.
    System.out.println(d);
    System.out.println(d);

    // L holds the labels of the current leaves in its first n slots; R holds their r values.
    List<Integer> L = new ArrayList<Integer>(2 * n);
    List<Double> R = new ArrayList<Double>(2 * n);
    for (int i = 0; i < 2 * n; i++) {
      L.add(0);
      R.add(0.0);
    }
    // initialise the leaves - labelled with consecutive numbers starting at 0
    for (int i = 0; i < n; i++) {
      L.set(i, i);
    }

    // V - the additive tree being built, stored as adjacency lists.
    // Kept as a raw ArrayList[] because it is passed to nk, whose signature is not visible here.
    ArrayList[] V = new ArrayList[2 * n];
    for (int i = 0; i < V.length; i++) {
      V[i] = new ArrayList<Integer>();
    }

    double suma, rmin, rr;
    int i, j, vertNum = n;

    while (n > 3) {
      // compute r for every leaf
      for (int a = 0; a < n; a++) {
        suma = 0;
        for (int b = 0; b < n; b++) {
          suma = suma + d.get(L.get(a), L.get(b));
        }
        suma = suma / (n - 2);
        R.set(a, suma);
      }
      // pick the pair of neighbours based on r
      i = 0;
      j = 1;
      rmin = d.get(L.get(0), L.get(1)) - (R.get(0) + R.get(1));
      for (int a = 0; a < n - 1; a++) {
        for (int b = a + 1; b < n; b++) {
          rr = d.get(L.get(a), L.get(b)) - (R.get(a) + R.get(b));
          if (rr < rmin) {
            rmin = rr;
            i = a;
            j = b;
          }
        }
      }

      // remove leaves i and j from the leaf set and add the new vertex.
      // After both removals the new vertex sits at index n - 1 (with n already decremented).
      L.set(n, vertNum);
      vertNum++;
      i = L.remove(i); // from here on i and j are vertex labels, not list positions
      j = L.remove(j - 1); // j - 1 because removing i shifted the later entries left
      n = n - 1;

      // update d between the new vertex and every remaining leaf
      for (int l = 0; l < n - 1; l++) {
        double value = (d.get(i, L.get(l)) + d.get(j, L.get(l)) - d.get(i, j)) / 2;
        d.put(L.get(n - 1), L.get(l), value);
        d.put(L.get(l), L.get(n - 1), value);
      }

      // add the corresponding edges to the tree
      V[i].add((vertNum - 1));
      V[j].add((vertNum - 1));
      V[vertNum - 1].add(i);
      V[vertNum - 1].add(j);

      // compute the distances between the new vertex and i, j
      double value = (d.get(i, j) + d.get(i, L.get(0)) - d.get(j, L.get(0))) / 2;
      d.put(i, vertNum - 1, value);
      d.put(vertNum - 1, i, value);
      d.put(j, vertNum - 1, d.get(i, j) - d.get(i, vertNum - 1));
      d.put(vertNum - 1, j, d.get(i, j) - d.get(i, vertNum - 1));
    }

    // three remaining leaves: attach all of them to one final vertex
    double value;
    value = (d.get(L.get(0), L.get(1)) + d.get(L.get(0), L.get(2)) - d.get(L.get(1), L.get(2))) / 2;
    d.put(L.get(0), vertNum, value);
    d.put(vertNum, L.get(0), value);

    value = (d.get(L.get(0), L.get(1)) + d.get(L.get(1), L.get(2)) - d.get(L.get(0), L.get(2))) / 2;
    d.put(L.get(1), vertNum, value);
    d.put(vertNum, L.get(1), value);

    value = (d.get(L.get(0), L.get(2)) + d.get(L.get(1), L.get(2)) - d.get(L.get(0), L.get(1))) / 2;
    d.put(L.get(2), vertNum, value);
    d.put(vertNum, L.get(2), value);

    V[vertNum].add(L.get(0));
    V[vertNum].add(L.get(1));
    V[vertNum].add(L.get(2));
    V[L.get(0)].add(vertNum);
    V[L.get(1)].add(vertNum);
    V[L.get(2)].add(vertNum);

    // print the results
    System.out.println(d);

    for (int a = 0; a <= vertNum; a++) {
      System.out.print(a);
      System.out.print(" : ");
      for (int b = 0; b < V[a].size(); b++) {
        System.out.print(V[a].get(b));
        System.out.print(" ");
      }
      System.out.println("");
    }

    DoubleMatrix A = DoubleMatrix.zeros((nn * (nn - 1)) / 2, vertNum);
    DoubleMatrix g = DoubleMatrix.zeros((nn * (nn - 1)) / 2, 1);
    double blad = nk(A, d2, g, V, vertNum);

    System.out.println(A);

    // Solve the normal equations: b = (A^T A)^-1 A^T g.
    DoubleMatrix p = A.transpose();
    System.out.println(p.rows + " " + p.columns);
    DoubleMatrix p2 = p.mmul(A);
    System.out.println("p2: " + p2);
    // MatrixFunctions.pow(p2, -1) would only take element-wise reciprocals; the formula needs the
    // matrix inverse, obtained here by solving p2 * X = I.
    DoubleMatrix p3 = org.jblas.Solve.solve(p2, DoubleMatrix.eye(p2.rows));
    System.out.println("p3: " + p3);
    DoubleMatrix p4 = p3.mmul(p);
    DoubleMatrix b = p4.mmul(g);

    System.out.println(g);
    System.out.println(b);
    System.out.println("Kwadrat bledu wynosi " + blad);
  }
 /**
  * Computes the L2 penalty: half the sum of the squared entries of {@code getW()}, scaled by
  * {@code l2}, plus a constant {@code 1e-6}.
  *
  * @return the L2 regularization term
  */
 @Override
 public double l2RegularizedCoefficient() {
   double halfSquaredWeights = MatrixFunctions.pow(getW(), 2).sum() / 2.0;
   return halfSquaredWeights * l2 + 1e-6;
 }