Exemplo n.º 1
0
 /**
  * Computes, for one output slot, the average of every input.
  *
  * <p>Each entry is the SUM field read from {@code state} divided by the counter read for
  * {@code output}. If that counter is zero, the array is returned with its default zeros.
  *
  * @param output slot whose counter and SUM fields are read
  * @param state array holding the counters and per-input fields
  * @param meta dependencies meta used to locate entries and to size the result
  * @return a new array with one entry per input of {@code meta.origin()}
  */
 public double[] getAvg(int output, Array1D state, KMetaDependencies meta) {
   final double[] means = new double[meta.origin().inputs().length];
   final double count = state.get(getCounter(output, meta));
   if (count == 0) {
     // Nothing recorded for this slot: keep the zero-initialised result.
     return means;
   }
   for (int input = 0; input < means.length; input++) {
     final double sum = state.get(getIndex(input, output, SUM, meta));
     means[input] = sum / count;
   }
   return means;
 }
Exemplo n.º 2
0
 /**
  * Computes, for one output slot, a variance value for every input.
  *
  * <p>Each entry is the SUMSQUARE field read from {@code state} divided by the counter read for
  * {@code output}, minus {@code avg[i]} squared. No {@code count / (count - 1)} correction is
  * applied. If the counter is zero, the array is returned with its default zeros.
  *
  * @param output slot whose counter and SUMSQUARE fields are read
  * @param state array holding the counters and per-input fields
  * @param avg per-input averages, indexed like the result
  * @param meta dependencies meta used to locate entries and to size the result
  * @return a new array with one entry per input of {@code meta.origin()}
  */
 public double[] getVariance(int output, Array1D state, double[] avg, KMetaDependencies meta) {
   final double[] result = new double[meta.origin().inputs().length];
   final double count = state.get(getCounter(output, meta));
   if (count == 0) {
     // Nothing recorded for this slot: keep the zero-initialised result.
     return result;
   }
   for (int input = 0; input < result.length; input++) {
     final double meanOfSquares = state.get(getIndex(input, output, SUMSQUARE, meta)) / count;
     final double mean = avg[input];
     result[input] = meanOfSquares - mean * mean;
   }
   return result;
 }
Exemplo n.º 3
0
  /**
   * Folds a batch of samples into the MIN / MAX / SUM / SUMSQUARE statistics stored in the
   * dependencies double array of {@code origin}'s memory chunk.
   *
   * <p>Statistics are kept per output slot, plus one extra slot for global statistics. The extra
   * slot's index is the number of literals of the enum type of the first output. The backing
   * array is allocated and zero-filled the first time, when its current size is zero.
   *
   * @param trainingSet one row per sample, one column per input
   * @param expectedResultSet one row per sample; column 0 is cast to {@code int} and used as the
   *     output slot
   * @param origin object whose meta class describes inputs, outputs and dependencies
   * @param manager data manager used to obtain the memory chunk
   */
  @Override
  public void train(
      double[][] trainingSet,
      double[][] expectedResultSet,
      KObject origin,
      KInternalDataManager manager) {
    final int globalSlot = ((MetaEnum) origin.metaClass().outputs()[0].type()).literals().length;
    final KMemoryChunk chunk =
        manager.preciseChunk(
            origin.universe(),
            origin.now(),
            origin.uuid(),
            origin.metaClass(),
            ((AbstractKObject) origin).previousResolved());
    final int dependenciesIndex = origin.metaClass().dependencies().index();

    // One counter plus NUMOFFIELDS values per input, for every output slot and the global slot.
    final int size = (globalSlot + 1) * (origin.metaClass().inputs().length * NUMOFFIELDS + 1);
    if (chunk.getDoubleArraySize(dependenciesIndex, origin.metaClass()) == 0) {
      chunk.extendDoubleArray(origin.metaClass().dependencies().index(), size, origin.metaClass());
      int position = 0;
      while (position < size) {
        chunk.setDoubleArrayElem(dependenciesIndex, position, 0, origin.metaClass());
        position++;
      }
    }

    final Array1D state =
        new Array1D(size, 0, origin.metaClass().dependencies().index(), chunk, origin.metaClass());

    for (int row = 0; row < trainingSet.length; row++) {
      final int output = (int) expectedResultSet[row][0];
      for (int feature = 0; feature < origin.metaClass().inputs().length; feature++) {
        // Per-output statistics first, then the global ones, in that order.
        foldSample(state, origin, trainingSet[row], feature, output);
        foldSample(state, origin, trainingSet[row], feature, globalSlot);
      }

      // Counters are bumped only after all features of the sample have been folded, so the
      // "first sample" test inside foldSample sees the same counter for every feature.
      state.add(getCounter(output, origin.metaClass().dependencies()), 1);
      state.add(getCounter(globalSlot, origin.metaClass().dependencies()), 1);
    }
  }

  /**
   * Folds {@code sample[feature]} into the statistics of one slot: initialises all four fields
   * when the slot's counter is zero, otherwise updates MIN / MAX and accumulates SUM / SUMSQUARE.
   */
  private void foldSample(Array1D state, KObject origin, double[] sample, int feature, int slot) {
    final boolean firstSample =
        state.get(getCounter(slot, origin.metaClass().dependencies())) == 0;
    final double value = sample[feature];

    if (firstSample) {
      state.set(getIndex(feature, slot, MIN, origin.metaClass().dependencies()), value);
      state.set(getIndex(feature, slot, MAX, origin.metaClass().dependencies()), value);
      state.set(getIndex(feature, slot, SUM, origin.metaClass().dependencies()), value);
      state.set(
          getIndex(feature, slot, SUMSQUARE, origin.metaClass().dependencies()), value * value);
      return;
    }

    if (value < state.get(getIndex(feature, slot, MIN, origin.metaClass().dependencies()))) {
      state.set(getIndex(feature, slot, MIN, origin.metaClass().dependencies()), value);
    }
    if (value > state.get(getIndex(feature, slot, MAX, origin.metaClass().dependencies()))) {
      state.set(getIndex(feature, slot, MAX, origin.metaClass().dependencies()), value);
    }
    state.add(getIndex(feature, slot, SUM, origin.metaClass().dependencies()), value);
    state.add(
        getIndex(feature, slot, SUMSQUARE, origin.metaClass().dependencies()), value * value);
  }