コード例 #1
0
ファイル: ThreeLayer.java プロジェクト: pbloem/Lilian
  /**
   * Runs a single backpropagation iteration, updating both weight matrices of this map from the
   * given error values on the output nodes.
   *
   * <p>The update reads the layer states currently stored on this object ({@code stateIn} and
   * {@code stateHidden}); presumably a forward pass has to populate them first (TODO confirm
   * against callers).
   *
   * @param error error value for each output node, indexed by node
   * @param learningRate factor by which the outer-product gradients are scaled (and negated)
   *     before being added to the weights
   */
  public void train(List<Double> error, double learningRate) {
    final int outputCount = error.size();
    final double step = -1.0 * learningRate;

    // Unbox the error list into a vector the matrix routines can work with.
    RealVector outputError = new ArrayRealVector(outputCount);
    for (int idx : series(outputCount)) {
      outputError.setEntry(idx, error.get(idx));
    }

    // * gHidden: delta for the non-bias nodes of the hidden layer, obtained by applying the
    //   activation derivative to the copied hidden state.
    // NOTE(review): the slice below has length n, whereas every later hidden-layer index uses
    // h; confirm that n is really the intended length here.
    gHidden.setSubVector(0, stateHidden.getSubVector(0, n)); // optimize
    for (int idx : Series.series(gHidden.getDimension())) {
      double derivative = activation.derivative(gHidden.getEntry(idx));
      gHidden.setEntry(idx, derivative);
    }

    // Push the output error back through the (not yet updated) second weight matrix, keep the
    // first h components and scale each one by the matching derivative.
    eHiddenL = weights1.transpose().operate(outputError);
    eHidden.setSubVector(0, eHiddenL.getSubVector(0, h));
    for (int idx : series(h)) {
      double scaled = eHidden.getEntry(idx) * gHidden.getEntry(idx);
      eHidden.setEntry(idx, scaled);
    }

    // Weight changes: outer product of a layer's error with the state feeding it, times -rate.
    weights0Delta = MatrixTools.outer(eHidden, stateIn).scalarMultiply(step);
    weights1Delta = MatrixTools.outer(outputError, stateHidden).scalarMultiply(step); // optimize

    weights1 = weights1.add(weights1Delta);
    weights0 = weights0.add(weights0Delta);
  }
コード例 #2
0
  /**
   * Solves an estimation problem using a least squares criterion.
   *
   * <p>Starting from their current values, the unbound parameters of the given problem are
   * adjusted over several iterations. At each step the parameters are changed in order to
   * minimize a weighted least squares criterion based on the measurements of the problem.
   *
   * <p>The iterations stop either when the criterion goes below a physical threshold under which
   * improvements are considered useless, or when the algorithm is unable to improve it (even if
   * it is still high). The first condition that is met stops the iterations. If convergence is
   * not reached before the maximum number of iterations, an {@link EstimationException} is
   * thrown.
   *
   * @param problem estimation problem to solve
   * @exception EstimationException if the problem cannot be solved
   * @see EstimationProblem
   */
  @Override
  public void estimate(EstimationProblem problem) throws EstimationException {

    initializeEstimate(problem);

    final int dim = parameters.length;

    // scratch storage reused for every measurement
    double[] partials = new double[dim];
    ArrayRealVector rhsTerm = new ArrayRealVector(dim);
    // rhsTerm is filled through this array and then added to the right-hand side as a vector
    double[] rhsTermData = rhsTerm.getDataRef();
    RealMatrix lhsTerm = MatrixUtils.createRealMatrix(dim, dim);

    double previous = Double.POSITIVE_INFINITY;
    boolean keepIterating;
    do {

      // assemble the normal equations: sum the contribution of every non-ignored measurement
      incrementJacobianEvaluationsCounter();
      RealVector rhs = new ArrayRealVector(dim);
      RealMatrix normal = MatrixUtils.createRealMatrix(dim, dim);
      for (int m = 0; m < measurements.length; ++m) {
        if (measurements[m].isIgnored()) {
          continue;
        }

        double weight = measurements[m].getWeight();
        double weightedResidual = weight * measurements[m].getResidual();

        // partial derivatives and right-hand side contribution of measurement m
        for (int p = 0; p < dim; ++p) {
          double partial = measurements[m].getPartial(parameters[p]);
          partials[p] = partial;
          rhsTermData[p] = weightedResidual * partial;
        }

        // weighted outer product of the partials: matrix contribution of measurement m
        for (int row = 0; row < dim; ++row) {
          double weightedPartial = weight * partials[row];
          for (int col = 0; col < dim; ++col) {
            lhsTerm.setEntry(row, col, weightedPartial * partials[col]);
          }
        }

        normal = normal.add(lhsTerm);
        rhs = rhs.add(rhsTerm);
      }

      try {

        // solve the linearized problem, then shift every parameter estimate by the solution
        RealVector correction = new LUDecompositionImpl(normal).getSolver().solve(rhs);
        for (int p = 0; p < dim; ++p) {
          double updated = parameters[p].getEstimate() + correction.getEntry(p);
          parameters[p].setEstimate(updated);
        }

      } catch (InvalidMatrixException e) {
        throw new EstimationException("unable to solve: singular problem");
      }

      previous = cost;
      updateResidualsAndCost();

      // continue while fewer than two cost evaluations have been made, or while the cost is
      // both still changing by more than the relative threshold and above the convergence level
      boolean tooFewEvaluations = getCostEvaluations() < 2;
      boolean stillChanging = Math.abs(previous - cost) > (cost * steadyStateThreshold);
      boolean aboveConvergence = Math.abs(cost) > convergence;
      keepIterating = tooFewEvaluations || (stillChanging && aboveConvergence);

    } while (keepIterating);
  }