Пример #1
0
  /**
   * Computes the area-under-curve (AUC) metric for each output of the neural network.
   *
   * <p>The network is loaded with the position of the algorithm's best solution. Every pattern of
   * the problem's generalisation set is then evaluated, and the targets and network outputs are
   * gathered column by column, one column per network output. {@code areaUnderCurve} is finally
   * applied to each target/output column pair.
   *
   * <p>The layout is decided from the first pattern only. A {@link Vector} target yields one
   * column per vector element. Any other target yields a single column whose targets are read as
   * {@link Real}. The generalisation set is assumed to be non-empty, since its first row is read
   * unconditionally.
   *
   * @param algorithm the optimisation algorithm; its optimisation problem is expected to be an
   *     {@link NNTrainingProblem} and its best solution's position a {@link Vector}.
   * @return a {@link Vector} containing one AUC value per network output.
   */
  @Override
  public Vector getValue(Algorithm algorithm) {
    Vector bestPosition = (Vector) algorithm.getBestSolution().getPosition();
    NNTrainingProblem trainingProblem = (NNTrainingProblem) algorithm.getOptimisationProblem();
    StandardPatternDataTable patterns = trainingProblem.getGeneralisationSet();
    NeuralNetwork network = trainingProblem.getNeuralNetwork();
    network.setWeights(bestPosition);

    // The first pattern decides between the multi-output (Vector target) layout
    // and the single-output (Real target) layout.
    Object firstTarget = patterns.getRow(0).getTarget();
    boolean multiOutput = firstTarget instanceof Vector;
    int columnCount = multiOutput ? ((Vector) firstTarget).size() : 1;

    ArrayList<ArrayList<Real>> targetColumns = new ArrayList<ArrayList<Real>>();
    ArrayList<ArrayList<Real>> outputColumns = new ArrayList<ArrayList<Real>>();
    for (int column = 0; column < columnCount; column++) {
      targetColumns.add(new ArrayList<Real>());
      outputColumns.add(new ArrayList<Real>());
    }

    for (StandardPattern pattern : patterns) {
      if (multiOutput) {
        Vector expected = (Vector) pattern.getTarget();
        Vector actual = network.evaluatePattern(pattern);
        for (int column = 0; column < expected.size(); column++) {
          targetColumns.get(column).add((Real) expected.get(column));
          outputColumns.get(column).add((Real) actual.get(column));
        }
      } else {
        Real expected = (Real) pattern.getTarget();
        Vector actual = network.evaluatePattern(pattern);
        targetColumns.get(0).add(expected);
        outputColumns.get(0).add((Real) actual.get(0));
      }
    }

    // One AUC value per target/output column pair.
    Vector aucPerOutput = Vector.of();
    int index = 0;
    while (index < outputColumns.size()) {
      aucPerOutput.add(
          Real.valueOf(areaUnderCurve(targetColumns.get(index), outputColumns.get(index))));
      index++;
    }
    return aucPerOutput;
  }
  /**
   * Verifies that {@code getRow} returns patterns whose inputs and targets match the corresponding
   * rows of the typed source data, for both a string-target table and a vector-target table.
   */
  @Test
  public void testGetRow() {
    // String-target table: typed row 0, with the target in the last column.
    StandardPattern stringPattern = stringTargetPatterns.getRow(0);
    List<Type> stringRow = typedData.getRow(0);
    int lastStringColumn = stringRow.size() - 1;
    for (int col = 0; col < lastStringColumn; col++) {
      Assert.assertEquals(stringRow.get(col), stringPattern.getVector().get(col));
    }
    Assert.assertEquals(stringRow.get(lastStringColumn), stringPattern.getTarget());

    // Vector-target table: typed row 2, with the last three columns forming the target vector.
    StandardPattern vectorPattern = vectorTargetPatterns.getRow(0);
    List<Type> vectorRow = typedData.getRow(2);
    int firstTargetColumn = vectorRow.size() - 3;
    for (int col = 0; col < firstTargetColumn; col++) {
      Assert.assertEquals(vectorRow.get(col), vectorPattern.getVector().get(col));
    }
    Vector vectorTarget = (Vector) vectorPattern.getTarget();
    for (int col = firstTargetColumn; col < vectorRow.size(); col++) {
      Assert.assertEquals(vectorRow.get(col), vectorTarget.get(col - firstTargetColumn));
    }
  }
  /**
   * Verifies that {@code removeRow} returns the removed pattern with the inputs and target of the
   * corresponding typed source row, and that exactly one pattern remains in each table afterwards.
   */
  @Test
  public void testRemoveRow() {
    // String-target table: remove pattern 0 (typed row 0, target in the last column).
    StandardPattern removedStringPattern = stringTargetPatterns.removeRow(0);
    List<Type> stringRow = typedData.getRow(0);
    int lastStringColumn = stringRow.size() - 1;
    for (int col = 0; col < lastStringColumn; col++) {
      Assert.assertEquals(stringRow.get(col), removedStringPattern.getVector().get(col));
    }
    Assert.assertEquals(stringRow.get(lastStringColumn), removedStringPattern.getTarget());

    // Vector-target table: remove pattern 1 (typed row 3, last three columns form the target).
    StandardPattern removedVectorPattern = vectorTargetPatterns.removeRow(1);
    List<Type> vectorRow = typedData.getRow(3);
    int firstTargetColumn = vectorRow.size() - 3;
    for (int col = 0; col < firstTargetColumn; col++) {
      Assert.assertEquals(vectorRow.get(col), removedVectorPattern.getVector().get(col));
    }
    Vector removedTarget = (Vector) removedVectorPattern.getTarget();
    for (int col = firstTargetColumn; col < vectorRow.size(); col++) {
      Assert.assertEquals(vectorRow.get(col), removedTarget.get(col - firstTargetColumn));
    }

    // Exactly one pattern remains in each table.
    Assert.assertEquals(1, stringTargetPatterns.size());
    Assert.assertEquals(1, vectorTargetPatterns.size());
  }
Пример #4
0
  /**
   * Performs one gradient descent backpropagation step. The previous {@link StandardPattern} is
   * used as input, together with the weight updates from the previous backpropagation step, which
   * supply the momentum term. If the previous weight updates do not exist yet, the visitor creates
   * them; Java zero-initialises new {@code double} arrays.
   *
   * <p>The step runs in four phases:
   *
   * <ol>
   *   <li>compute the output-layer deltas from the pattern target and the neuron activations;
   *   <li>propagate the deltas backwards through the hidden layers;
   *   <li>lazily allocate the momentum storage {@code previousWeightUpdates};
   *   <li>update the weights of every neuron in every non-input layer.
   * </ol>
   *
   * @param architecture the network architecture whose layer weights are updated in place. Layer 0
   *     is treated as the input layer and is cast to {@link ForwardingLayer}; the last layer is the
   *     output layer.
   */
  @Override
  public void visit(Architecture architecture) {
    List<Layer> layers = architecture.getLayers();
    int numLayers = layers.size();

    // Delta storage: one row per non-input layer; row (i - 1) belongs to layer i.
    layerWeightsDelta = new double[numLayers - 1][];

    // Phase 1: output layer deltas, delta_k = -(t_k - o_k) * f'(o_k).
    // NOTE(review): the output layer's bias flag is not consulted here (unlike phase 4);
    // presumably output layers never carry a bias neuron - confirm.
    int outputLayerIdx = numLayers - 1;
    Layer outputLayer = layers.get(outputLayerIdx);
    int outputLayerSize = outputLayer.size();
    double[] outputDeltas = new double[outputLayerSize];
    layerWeightsDelta[outputLayerIdx - 1] = outputDeltas;
    for (int k = 0; k < outputLayerSize; k++) {
      Neuron neuron = outputLayer.get(k);
      // A multi-neuron output layer expects a Vector target; a single neuron expects a Real.
      double t_k =
          outputLayerSize > 1
              ? ((Vector) previousPattern.getTarget()).doubleValueOf(k)
              : ((Real) previousPattern.getTarget()).doubleValue();
      double o_k = neuron.getActivation();
      outputDeltas[k] = -1.0 * (t_k - o_k) * neuron.getActivationFunction().getGradient(o_k);
    }

    // Phase 2: back-propagate into the hidden layers (indices numLayers-2 .. 1),
    // delta_j = (sum_k w_kj * delta_k) * f'(o_j).
    // When a layer has a bias neuron, the code assumes it is the last neuron and gives it no delta.
    for (int layerIdx = numLayers - 2; layerIdx > 0; layerIdx--) {
      Layer layer = layers.get(layerIdx);
      int layerSize = layer.isBias() ? layer.size() - 1 : layer.size();
      Layer nextLayer = layers.get(layerIdx + 1);
      int nextLayerSize = nextLayer.isBias() ? nextLayer.size() - 1 : nextLayer.size();
      double[] nextDeltas = layerWeightsDelta[layerIdx];
      double[] deltas = new double[layerSize];
      layerWeightsDelta[layerIdx - 1] = deltas;
      for (int j = 0; j < layerSize; j++) {
        double sum = 0.0;
        for (int k = 0; k < nextLayerSize; k++) {
          double w_kj = nextLayer.get(k).getWeights().doubleValueOf(j);
          sum += w_kj * nextDeltas[k];
        }
        Neuron neuron = layer.get(j);
        deltas[j] = sum * neuron.getActivationFunction().getGradient(neuron.getActivation());
      }
    }

    // Phase 3: allocate momentum storage once per layer. Neuron k of layer i uses the slots
    // [k * previousLayerSize, (k + 1) * previousLayerSize) of row (i - 1).
    // Previously this allocation sat inside a redundant loop bounded by a stale layerSize,
    // which skipped a row entirely when that stale value was 0.
    if (previousWeightUpdates == null) {
      previousWeightUpdates = new double[numLayers - 1][];
      for (int layerIdx = numLayers - 1; layerIdx > 0; layerIdx--) {
        Layer layer = layers.get(layerIdx);
        int layerSize = layer.isBias() ? layer.size() - 1 : layer.size();
        int previousLayerSize = layers.get(layerIdx - 1).size();
        // Only layerSize * previousLayerSize slots are indexed below; the extra
        // previousLayerSize + 1 slots keep the original allocation size unchanged.
        previousWeightUpdates[layerIdx - 1] =
            new double[layerSize * previousLayerSize + previousLayerSize + 1];
      }
    }

    // NOTE(review): presumably makes layer 0's getNeuralInput(j) return the previous pattern's
    // inputs, which phase 4 multiplies into the updates - verify.
    ((ForwardingLayer) layers.get(0)).setSource(new PatternInputSource(previousPattern));

    // Phase 4: for every non-input layer,
    // w_kj += -learningRate * delta_k * x_j + momentum * previousUpdate_kj,
    // where j spans every neuron of the previous layer (its bias included).
    for (int layerIdx = numLayers - 1; layerIdx > 0; layerIdx--) {
      Layer layer = layers.get(layerIdx);
      int layerSize = layer.isBias() ? layer.size() - 1 : layer.size();
      Layer previousLayer = layers.get(layerIdx - 1);
      int previousLayerSize = previousLayer.size();
      double[] deltas = layerWeightsDelta[layerIdx - 1];
      double[] updates = previousWeightUpdates[layerIdx - 1];

      for (int k = 0; k < layerSize; k++) {
        Neuron neuron = layer.get(k);
        double scaledDelta = (-1.0 * learningRate) * deltas[k];
        int offset = k * previousLayerSize;
        for (int j = 0; j < previousLayerSize; j++) {
          double weight = neuron.getWeights().doubleValueOf(j);
          double newWeightUpdate = scaledDelta * previousLayer.getNeuralInput(j);
          double update = newWeightUpdate + momentum * updates[offset + j];
          neuron.getWeights().setReal(j, weight + update);
          // NOTE(review): only the gradient term is remembered, so momentum reaches back exactly
          // one step; classical momentum stores the full update - confirm this is intended.
          updates[offset + j] = newWeightUpdate;
        }
      }
    }
  }