/**
   * Create a neighborhood trainer ({@link BasicTrainSOM}) for a self-organizing map ({@link SOM}).
   *
   * <p>Arguments recognized in {@code argsStr}:
   *
   * <ul>
   *   <li>{@code MLTrainFactory.PROPERTY_LEARNING_RATE} - learning rate, default {@code 0.7}.
   *   <li>{@code MLTrainFactory.PROPERTY_NEIGHBORHOOD} - one of {@code bubble}, {@code rbf},
   *       {@code rbf1d} or {@code single} (case-insensitive), default {@code rbf}.
   *   <li>{@code MLTrainFactory.PROPERTY_RBF_TYPE} - RBF shape used by the {@code rbf} and
   *       {@code rbf1d} neighborhoods: {@code gaussian}, {@code multiquadric},
   *       {@code inversemultiquadric} or {@code mexicanhat} (case-insensitive). Any other value
   *       falls back to Gaussian.
   *   <li>{@code MLTrainFactory.PROPERTY_DIMENSIONS} - required when the neighborhood is
   *       {@code rbf}; parsed as an integer list and passed to {@code NeighborhoodRBF}.
   *   <li>{@code MLTrainFactory.PROPERTY_ITERATIONS} - when present, auto-decay is enabled with
   *       the planned iteration count (default 1000), start/end learning rate (both default 0.05)
   *       and start/end radius (defaults 10 and 1).
   * </ul>
   *
   * @param method The method to train; must be a {@link SOM}.
   * @param training The training data to use.
   * @param argsStr The arguments to use.
   * @return The newly created trainer.
   * @throws EncogError if {@code method} is not a {@link SOM} (including {@code null}), or if
   *     the requested neighborhood function is not recognized.
   */
  public MLTrain create(final MLMethod method, final MLDataSet training, final String argsStr) {

    // `null instanceof SOM` is false, so guard the getClass() call to report a proper error
    // instead of a NullPointerException.
    if (!(method instanceof SOM)) {
      final String typeName = (method == null) ? "null" : method.getClass().getName();
      throw new EncogError("Neighborhood training cannot be used on a method of type: " + typeName);
    }

    final Map<String, String> args = ArchitectureParse.parseParams(argsStr);
    final ParamsHolder holder = new ParamsHolder(args);

    final double learningRate = holder.getDouble(MLTrainFactory.PROPERTY_LEARNING_RATE, false, 0.7);
    final String neighborhoodStr =
        holder.getString(MLTrainFactory.PROPERTY_NEIGHBORHOOD, false, "rbf");
    final String rbfTypeStr = holder.getString(MLTrainFactory.PROPERTY_RBF_TYPE, false, "gaussian");

    // Unrecognized RBF names deliberately fall back to Gaussian (the documented default).
    final RBFEnum rbfType;
    if (rbfTypeStr.equalsIgnoreCase("Multiquadric")) {
      rbfType = RBFEnum.Multiquadric;
    } else if (rbfTypeStr.equalsIgnoreCase("InverseMultiquadric")) {
      rbfType = RBFEnum.InverseMultiquadric;
    } else if (rbfTypeStr.equalsIgnoreCase("MexicanHat")) {
      rbfType = RBFEnum.MexicanHat;
    } else {
      rbfType = RBFEnum.Gaussian;
    }

    final NeighborhoodFunction nf;
    if (neighborhoodStr.equalsIgnoreCase("bubble")) {
      nf = new NeighborhoodBubble(1);
    } else if (neighborhoodStr.equalsIgnoreCase("rbf")) {
      // Dimensions are only required (and only read) for the multi-dimensional RBF neighborhood.
      final String str = holder.getString(MLTrainFactory.PROPERTY_DIMENSIONS, true, null);
      final int[] size = NumberList.fromListInt(CSVFormat.EG_FORMAT, str);
      nf = new NeighborhoodRBF(size, rbfType);
    } else if (neighborhoodStr.equalsIgnoreCase("rbf1d")) {
      nf = new NeighborhoodRBF1D(rbfType);
    } else if (neighborhoodStr.equalsIgnoreCase("single")) {
      nf = new NeighborhoodSingle();
    } else {
      // Previously an unknown name left the neighborhood null and handed it to the trainer;
      // fail fast here with a descriptive error instead.
      throw new EncogError("Unknown neighborhood function: " + neighborhoodStr);
    }

    final BasicTrainSOM result = new BasicTrainSOM((SOM) method, learningRate, training, nf);

    // Auto-decay is only configured when an iteration count is explicitly supplied.
    if (args.containsKey(MLTrainFactory.PROPERTY_ITERATIONS)) {
      final int plannedIterations = holder.getInt(MLTrainFactory.PROPERTY_ITERATIONS, false, 1000);
      final double startRate =
          holder.getDouble(MLTrainFactory.PROPERTY_START_LEARNING_RATE, false, 0.05);
      final double endRate =
          holder.getDouble(MLTrainFactory.PROPERTY_END_LEARNING_RATE, false, 0.05);
      final double startRadius = holder.getDouble(MLTrainFactory.PROPERTY_START_RADIUS, false, 10);
      final double endRadius = holder.getDouble(MLTrainFactory.PROPERTY_END_RADIUS, false, 1);
      result.setAutoDecay(plannedIterations, startRate, endRate, startRadius, endRadius);
    }

    return result;
  }
  // Example #2
  /**
   * Create a NEAT population.
   *
   * <p>Arguments recognized in {@code architecture}:
   *
   * <ul>
   *   <li>{@code MLMethodFactory.PROPERTY_POPULATION_SIZE} - population size, default 1000.
   *   <li>{@code MLMethodFactory.PROPERTY_CYCLES} - activation cycles, default
   *       {@code NEATPopulation.DEFAULT_CYCLES}.
   *   <li>{@code MLMethodFactory.PROPERTY_AF} - activation function name, default
   *       {@code MLActivationFactory.AF_SSIGMOID}.
   * </ul>
   *
   * @param architecture The architecture string to use.
   * @param input The input count; must be positive.
   * @param output The output count; must be positive.
   * @return The population, configured and then reset.
   * @throws EncogError if {@code input} or {@code output} is not positive, or if the activation
   *     function name cannot be resolved.
   */
  public MLMethod create(final String architecture, final int input, final int output) {

    if (input <= 0) {
      throw new EncogError("Must have at least one input for NEAT.");
    }

    if (output <= 0) {
      throw new EncogError("Must have at least one output for NEAT.");
    }

    final Map<String, String> args = ArchitectureParse.parseParams(architecture);
    final ParamsHolder holder = new ParamsHolder(args);

    final int populationSize = holder.getInt(MLMethodFactory.PROPERTY_POPULATION_SIZE, false, 1000);

    final int cycles =
        holder.getInt(MLMethodFactory.PROPERTY_CYCLES, false, NEATPopulation.DEFAULT_CYCLES);

    final String afName =
        holder.getString(MLMethodFactory.PROPERTY_AF, false, MLActivationFactory.AF_SSIGMOID);
    final ActivationFunction af = this.factory.create(afName);
    // NOTE(review): the activation factory presumably returns null for an unrecognized name;
    // report that here rather than letting a null activation function reach the population.
    if (af == null) {
      throw new EncogError("Unknown activation function for NEAT: " + afName);
    }

    final NEATPopulation pop = new NEATPopulation(input, output, populationSize);
    // Apply the requested configuration before reset(), so whatever reset() (re)builds sees the
    // chosen activation function and cycle count rather than the constructor defaults.
    pop.setActivationCycles(cycles);
    pop.setNEATActivationFunction(af);
    pop.reset();

    return pop;
  }