/**
 * Computes the mean of every input feature for one statistics slot.
 *
 * <p>Each mean is the slot's running {@code SUM} for that feature divided by the slot's counter.
 *
 * @param output index of the statistics slot to read (forwarded to {@code getCounter} and
 *     {@code getIndex})
 * @param state accumulator array holding the counters and running sums
 * @param meta dependencies descriptor; the inputs of its origin define the number of features
 * @return one mean per input feature; every entry is {@code 0.0} when the slot's counter is zero
 */
public double[] getAvg(int output, Array1D state, KMetaDependencies meta) {
  final int inputCount = meta.origin().inputs().length;
  final double[] means = new double[inputCount];
  final double count = state.get(getCounter(output, meta));
  if (count == 0) {
    // Nothing recorded for this slot yet: keep the zero-initialised array.
    return means;
  }
  for (int feature = 0; feature < inputCount; feature++) {
    means[feature] = state.get(getIndex(feature, output, SUM, meta)) / count;
  }
  return means;
}
public double[] getVariance(int output, Array1D state, double[] avg, KMetaDependencies meta) { double[] variances = new double[meta.origin().inputs().length]; double total = state.get(getCounter(output, meta)); if (total != 0) { for (int i = 0; i < meta.origin().inputs().length; i++) { variances[i] = state.get(getIndex(i, output, SUMSQUARE, meta)) / total - avg[i] * avg[i]; // x count/ (count-1) } } return variances; }
@Override public void train( double[][] trainingSet, double[][] expectedResultSet, KObject origin, KInternalDataManager manager) { int maxOutput = ((MetaEnum) origin.metaClass().outputs()[0].type()).literals().length; KMemoryChunk ks = manager.preciseChunk( origin.universe(), origin.now(), origin.uuid(), origin.metaClass(), ((AbstractKObject) origin).previousResolved()); int dependenciesIndex = origin.metaClass().dependencies().index(); // Create initial chunk if empty int size = (maxOutput + 1) * (origin.metaClass().inputs().length * NUMOFFIELDS + 1); if (ks.getDoubleArraySize(dependenciesIndex, origin.metaClass()) == 0) { ks.extendDoubleArray(origin.metaClass().dependencies().index(), size, origin.metaClass()); for (int i = 0; i < size; i++) { ks.setDoubleArrayElem(dependenciesIndex, i, 0, origin.metaClass()); } } Array1D state = new Array1D(size, 0, origin.metaClass().dependencies().index(), ks, origin.metaClass()); // update the state for (int i = 0; i < trainingSet.length; i++) { int output = (int) expectedResultSet[i][0]; for (int j = 0; j < origin.metaClass().inputs().length; j++) { // If this is the first datapoint if (state.get(getCounter(output, origin.metaClass().dependencies())) == 0) { state.set(getIndex(j, output, MIN, origin.metaClass().dependencies()), trainingSet[i][j]); state.set(getIndex(j, output, MAX, origin.metaClass().dependencies()), trainingSet[i][j]); state.set(getIndex(j, output, SUM, origin.metaClass().dependencies()), trainingSet[i][j]); state.set( getIndex(j, output, SUMSQUARE, origin.metaClass().dependencies()), trainingSet[i][j] * trainingSet[i][j]); } else { if (trainingSet[i][j] < state.get(getIndex(j, output, MIN, origin.metaClass().dependencies()))) { state.set( getIndex(j, output, MIN, origin.metaClass().dependencies()), trainingSet[i][j]); } if (trainingSet[i][j] > state.get(getIndex(j, output, MAX, origin.metaClass().dependencies()))) { state.set( getIndex(j, output, MAX, origin.metaClass().dependencies()), trainingSet[i][j]); 
} state.add(getIndex(j, output, SUM, origin.metaClass().dependencies()), trainingSet[i][j]); state.add( getIndex(j, output, SUMSQUARE, origin.metaClass().dependencies()), trainingSet[i][j] * trainingSet[i][j]); } // update global stat if (state.get(getCounter(maxOutput, origin.metaClass().dependencies())) == 0) { state.set( getIndex(j, maxOutput, MIN, origin.metaClass().dependencies()), trainingSet[i][j]); state.set( getIndex(j, maxOutput, MAX, origin.metaClass().dependencies()), trainingSet[i][j]); state.set( getIndex(j, maxOutput, SUM, origin.metaClass().dependencies()), trainingSet[i][j]); state.set( getIndex(j, maxOutput, SUMSQUARE, origin.metaClass().dependencies()), trainingSet[i][j] * trainingSet[i][j]); } else { if (trainingSet[i][j] < state.get(getIndex(j, maxOutput, MIN, origin.metaClass().dependencies()))) { state.set( getIndex(j, maxOutput, MIN, origin.metaClass().dependencies()), trainingSet[i][j]); } if (trainingSet[i][j] > state.get(getIndex(j, maxOutput, MAX, origin.metaClass().dependencies()))) { state.set( getIndex(j, maxOutput, MAX, origin.metaClass().dependencies()), trainingSet[i][j]); } state.add( getIndex(j, maxOutput, SUM, origin.metaClass().dependencies()), trainingSet[i][j]); state.add( getIndex(j, maxOutput, SUMSQUARE, origin.metaClass().dependencies()), trainingSet[i][j] * trainingSet[i][j]); } } // Update Global counters state.add(getCounter(output, origin.metaClass().dependencies()), 1); state.add(getCounter(maxOutput, origin.metaClass().dependencies()), 1); } }