Exemplo n.º 1
0
  /**
   * Scores one test-set sample for a one-vs-rest network.
   *
   * <p>The network's raw output is returned unchanged when the sample's label equals {@code
   * featureToClassify} and sign-flipped otherwise. Maximizing the sum of these scores therefore
   * favours positive outputs on the target class and negative outputs on every other class.
   *
   * @param inputs feature vector of the sample
   * @param classification the sample's true class label
   * @param network the network being evaluated
   * @param featureToClassify the target class label
   * @return the raw output for target-class samples, its negation for all others
   */
  public static double get1FitnessFromTestSet(
      double[] inputs, int classification, StrokeNetwork network, int featureToClassify) {
    final double rawOutput = FullyConnectedNNBuilder.getOutput(inputs, network);
    final boolean isTargetClass = classification == featureToClassify;
    return isTargetClass ? rawOutput : -rawOutput;
  }
Exemplo n.º 2
0
  /**
   * Scores one training sample for a one-vs-rest network, down-weighting non-target samples.
   *
   * <p>When the sample's label equals {@code featureToClassify}, the network's raw output is
   * returned as-is. Otherwise the output is negated and scaled by {@code 1 / (num_possible - 1)}.
   * Presumably this balances the many non-target classes against the single target class.
   * TODO confirm intent.
   *
   * <p>Note: when {@code num_possible == 1}, a non-target sample divides by zero and yields an
   * infinite or NaN score.
   *
   * @param inputs feature vector of the sample
   * @param classification the sample's true class label
   * @param network the network being evaluated
   * @param featureToClassify the target class label
   * @param num_possible total number of class labels
   * @return the raw output for target-class samples, otherwise
   *     {@code -(1 / (num_possible - 1)) * output}
   */
  public static double get1Fitness(
      double[] inputs,
      int classification,
      StrokeNetwork network,
      int featureToClassify,
      int num_possible) {
    final double rawOutput = FullyConnectedNNBuilder.getOutput(inputs, network);
    if (classification != featureToClassify) {
      // Each non-target sample contributes only 1/(num_possible - 1) of its (negated) output.
      final double negativeWeight = 1.0 / (num_possible - 1.0);
      return -negativeWeight * rawOutput;
    }
    return rawOutput;
  }
Exemplo n.º 3
0
  /**
   * Evolves a population of one-vs-rest stroke networks and writes the one that scored best on
   * {@code testSet} to {@code fileOutput}.
   *
   * <p>Each generation runs these steps:
   * <ol>
   *   <li>Evolve the population with {@code evolvePopulation}.
   *   <li>Re-score it on the training set with {@code getFitnesses}.
   *   <li>Score the member picked by {@code findMax} on {@code testSet}.
   * </ol>
   * If that test score beats the best seen so far by more than 0.01, it becomes the new best.
   * Otherwise a stall counter is incremented. Training stops early once the counter exceeds
   * {@code maxOverfitGenerations}.
   *
   * @param trainer training samples used to rank the population each generation
   * @param testSet held-out samples used to choose the network that is saved
   * @param numInputs number of network inputs, passed to {@code networkBuilder}
   * @param populationSize size of the initial population; must be positive
   * @param numGenerations maximum number of generations to run
   * @param numToSelect forwarded to {@code evolvePopulation}
   * @param largeMutationsPerSelected forwarded to {@code evolvePopulation}
   * @param largeMutationFactor forwarded to {@code evolvePopulation}
   * @param smallMutationsPerSelected forwarded to {@code evolvePopulation}
   * @param smallMutationFactor forwarded to {@code evolvePopulation}
   * @param numChildrenToBreed forwarded to {@code evolvePopulation}
   * @param numToChoose currently unused by this method
   * @param fileOutput destination handed to {@code networkToFile} for the chosen network
   * @param featureToClassify the target class label the networks are scored against
   * @param hiddenLayerSizes hidden-layer sizes passed to {@code networkBuilder}
   * @param toggleFeatures currently unused by this method
   * @param maxThreshold currently unused by this method
   * @param maxOverfitGenerations consecutive non-improving generations tolerated before stopping
   * @param num_possible total number of class labels, used for training-set fitness
   * @throws IllegalArgumentException if {@code populationSize} is not positive
   * @throws IOException declared for network I/O (presumably raised by {@code networkToFile})
   * @throws NumberFormatException declared by this signature; not thrown directly here
   */
  public static void evolve_nn_2(
      TrainingSet trainer,
      TrainingSet testSet,
      int numInputs,
      int populationSize,
      int numGenerations,
      int numToSelect,
      int largeMutationsPerSelected,
      double largeMutationFactor,
      int smallMutationsPerSelected,
      double smallMutationFactor,
      int numChildrenToBreed,
      int numToChoose,
      String fileOutput,
      int featureToClassify,
      int[] hiddenLayerSizes,
      boolean[] toggleFeatures,
      double maxThreshold,
      int maxOverfitGenerations,
      int num_possible)
      throws NumberFormatException, IOException {
    if (populationSize <= 0) {
      throw new IllegalArgumentException("populationSize must be positive: " + populationSize);
    }
    // A test score must beat the best so far by more than this to count as an improvement.
    final double minTestImprovement = 0.01;

    System.out.println("training set size: " + trainer.inputs.length);

    Vector<StrokeNetwork> netpopulation = new Vector<StrokeNetwork>(populationSize);
    for (int i = 0; i < populationSize; i++) {
      netpopulation.add(
          FullyConnectedNNBuilder.networkBuilder(true, numInputs, 0.5, 1, hiddenLayerSizes));
    }
    double[] fitnesses = getFitnesses(trainer, netpopulation, featureToClassify, num_possible);

    System.out.println("testing set size: " + testSet.inputs.length);

    // Fittest member by training fitness; also the fallback result if no generation
    // produces a usable test score.
    StrokeNetwork bestNetModel = netpopulation.get(findMax(fitnesses));

    // Start at -infinity (not Integer.MIN_VALUE) so any finite first test score is an improvement,
    // however negative the summed fitness is.
    double bestNetTestedFitness = Double.NEGATIVE_INFINITY;
    StrokeNetwork bestNetTestedModel = null;
    int numGenerationsSinceImproved = 0;

    for (int i = 0; i < numGenerations; i++) {
      System.out.println("generation " + i);
      long time = System.currentTimeMillis();
      netpopulation =
          evolvePopulation(
              netpopulation,
              fitnesses,
              numToSelect,
              largeMutationsPerSelected,
              largeMutationFactor,
              smallMutationsPerSelected,
              smallMutationFactor,
              numChildrenToBreed,
              numGenerations,
              numGenerationsSinceImproved);
      long time2 = System.currentTimeMillis();
      System.out.println("mutation took: " + (time2 - time));
      fitnesses = getFitnesses(trainer, netpopulation, featureToClassify, num_possible);
      System.out.println("getting fitness took: " + (System.currentTimeMillis() - time2));
      System.out.println(
          "took: " + (System.currentTimeMillis() - time) + ", fitness: " + fitnesses[0]);

      bestNetModel = netpopulation.get(findMax(fitnesses));
      // Recomputed from zero every generation.
      double bestNetFitness = sumTestSetFitness(testSet, bestNetModel, featureToClassify);
      System.out.println("test fitness: " + bestNetFitness);
      System.out.println(bestNetTestedFitness);
      System.out.println(numGenerationsSinceImproved);

      if (bestNetFitness > bestNetTestedFitness + minTestImprovement) {
        numGenerationsSinceImproved = 0;
        bestNetTestedModel = bestNetModel;
        bestNetTestedFitness = bestNetFitness;
      } else {
        numGenerationsSinceImproved++;
      }

      if (numGenerationsSinceImproved > maxOverfitGenerations) {
        break;
      }
    }

    if (bestNetTestedModel == null) {
      // No generation ran (numGenerations <= 0), or every test score was NaN or -infinity.
      // Fall back to the fittest network by training fitness instead of passing null downstream.
      bestNetTestedModel = bestNetModel;
    }

    double trainFit =
        sumTrainingFitness(trainer, bestNetTestedModel, featureToClassify, num_possible);
    double testFit = sumTestSetFitness(testSet, bestNetTestedModel, featureToClassify);
    System.out.println("most fit (training): " + trainFit);
    System.out.println("most fit (testing): " + testFit);

    FullyConnectedNNBuilder.networkToFile(fileOutput, bestNetTestedModel);
  }

  /** Sums {@code get1FitnessFromTestSet} over every sample of {@code samples}, in order. */
  private static double sumTestSetFitness(
      TrainingSet samples, StrokeNetwork model, int featureToClassify) {
    double total = 0;
    for (int j = 0; j < samples.inputs.length; j++) {
      total +=
          get1FitnessFromTestSet(
              samples.inputs[j], samples.classifications[j], model, featureToClassify);
    }
    return total;
  }

  /** Sums {@code get1Fitness} over every sample of {@code samples}, in order. */
  private static double sumTrainingFitness(
      TrainingSet samples, StrokeNetwork model, int featureToClassify, int num_possible) {
    double total = 0;
    for (int j = 0; j < samples.inputs.length; j++) {
      total +=
          get1Fitness(
              samples.inputs[j],
              samples.classifications[j],
              model,
              featureToClassify,
              num_possible);
    }
    return total;
  }