/**
 * Verifies that a network containing a {@code SOMLayer} still picks the
 * expected winners after being saved to and reloaded from an
 * {@code EncogPersistedCollection} file.
 *
 * <p>The network is built from a 4-input {@code SOMLayer} (multiplicative
 * normalization) whose matrix is taken from {@code TestPersist.trainedData},
 * followed by a 2-neuron {@code BasicLayer}. It is written to
 * {@code encogtest.xml}, read back, and the reloaded copy is queried.
 *
 * @throws Exception if building, saving or loading the network fails
 */
public void testSOMPersist() throws Exception {
    final String fileName = "encogtest.xml";

    Matrix matrix = new Matrix(TestPersist.trainedData);

    double[] pattern1 = {-0.5, -0.5, -0.5, -0.5};
    double[] pattern2 = {0.5, 0.5, 0.5, 0.5};
    double[] pattern3 = {-0.5, -0.5, -0.5, 0.5};
    double[] pattern4 = {0.5, 0.5, 0.5, -0.5};

    NeuralData data1 = new BasicNeuralData(pattern1);
    NeuralData data2 = new BasicNeuralData(pattern2);
    NeuralData data3 = new BasicNeuralData(pattern3);
    NeuralData data4 = new BasicNeuralData(pattern4);

    SOMLayer layer = new SOMLayer(4, NormalizationType.MULTIPLICATIVE);
    BasicNetwork network = new BasicNetwork();
    network.addLayer(layer);
    network.addLayer(new BasicLayer(2));
    layer.setMatrix(matrix);

    EncogPersistedCollection encog = new EncogPersistedCollection();
    encog.add(network);

    EncogPersistedCollection encog2 = new EncogPersistedCollection();
    try {
        encog.save(fileName);
        encog2.load(fileName);
    } finally {
        // Always remove the scratch file, even when save/load throws, so a
        // failing run does not leave encogtest.xml behind.
        new File(fileName).delete();
    }

    BasicNetwork network2 = (BasicNetwork) encog2.getList().get(0);

    // Patterns 1 and 2 must map to different winning neurons.
    int data1Neuron = network2.winner(data1);
    int data2Neuron = network2.winner(data2);
    TestCase.assertTrue(data1Neuron != data2Neuron);

    // Pattern 3 must share a winner with pattern 1, and pattern 4 with pattern 2.
    int data3Neuron = network2.winner(data3);
    int data4Neuron = network2.winner(data4);
    TestCase.assertEquals(data1Neuron, data3Neuron);
    TestCase.assertEquals(data2Neuron, data4Neuron);
}
public static void main(String args[]) { int inputNeurons = CHAR_WIDTH * CHAR_HEIGHT; int outputNeurons = DIGITS.length; BasicNetwork network = new BasicNetwork(); Layer inputLayer = new BasicLayer(new ActivationLinear(), false, inputNeurons); Layer outputLayer = new BasicLayer(new ActivationLinear(), true, outputNeurons); network.addLayer(inputLayer); network.addLayer(outputLayer); network.getStructure().finalizeStructure(); (new RangeRandomizer(-0.5, 0.5)).randomize(network); // train it NeuralDataSet training = generateTraining(); Train train = new TrainAdaline(network, training, 0.01); int epoch = 1; do { train.iteration(); System.out.println("Epoch #" + epoch + " Error:" + train.getError()); epoch++; } while (train.getError() > 0.01); // System.out.println("Error:" + network.calculateError(training)); // test it for (int i = 0; i < DIGITS.length; i++) { int output = network.winner(image2data(DIGITS[i])); for (int j = 0; j < CHAR_HEIGHT; j++) { if (j == CHAR_HEIGHT - 1) System.out.println(DIGITS[i][j] + " -> " + output); else System.out.println(DIGITS[i][j]); } System.out.println(); } }