/**
 * Serializes the current network ({@code net.getNN()}) to a newly created temporary file via
 * {@code ModelSerializer.writeModel(model, file, true)}, where {@code true} is the
 * {@code saveUpdater} flag.
 *
 * <p>This is best-effort: an {@link IOException} is logged rather than propagated, so callers
 * see the same (non-throwing) contract as before.
 *
 * <p>NOTE(review): the temporary file is neither returned nor stored on this object, so callers
 * cannot reload the saved model from it. Its path is logged at INFO so it is at least discoverable.
 */
public void saveModel() {
    Logger logger = Logger.getLogger(MLSegmentationService.class.getName());
    try {
        File tmpFile = File.createTempFile("model", null);
        ModelSerializer.writeModel(net.getNN(), tmpFile, true);
        logger.log(Level.INFO, "Model saved to {0}", tmpFile.getAbsolutePath());
    } catch (IOException ioe) {
        // Previously an empty catch: a failed save was completely invisible.
        logger.log(Level.SEVERE, "Failed to save model", ioe);
    }
}
/**
 * Calls {@link #clearData()} and then clears the network, if one is present.
 *
 * <p>A {@code null} network is treated as "nothing to clear" rather than as an error.
 */
public void clearAll() {
    clearData();
    // Explicit null check instead of catching NullPointerException. Catching NPE also masked
    // genuine NPEs thrown from inside net.clear(), and logged an expected state as SEVERE.
    if (net != null) {
        net.clear();
    }
}
public INDArray classify() { DataSetIterator iter = new ProfileIterator(testData, imgDatasets); INDArray predict = net.output(iter); // int size0 = predict.size(0); // int size1 = predict.size(1); // int size2 = predict.size(2); // INDArray element = predict.getRow(0); // System.out.println("predict DONE"); return predict; }
/**
 * Builds a synthetic "classification" with the same shape as the network's real output on the
 * test data.
 *
 * <p>The network is still run on the test data, but only to obtain the output dimensions; its
 * predicted values are discarded. For every index along dimension 0, channel 0 is set to 1.0 at
 * positions 10 through 14 (inclusive) of dimension 2 and to 0.0 at all other positions. Channels
 * other than 0 are never written and keep whatever {@code Nd4j.create} initialised them with.
 *
 * @return a new array of shape {@code [size(0), size(1), size(2)]} of the real output, created
 *     in 'f' order
 */
public INDArray dummyClassification() {
    // Keep the DataSetIterator static type so the same net.output overload is selected.
    DataSetIterator testIter = new ProfileIterator(testData, imgDatasets);
    INDArray realOutput = net.output(testIter);

    int numExamples = realOutput.size(0);
    int numChannels = realOutput.size(1);
    int seqLength = realOutput.size(2);
    INDArray fake = Nd4j.create(new int[] {numExamples, numChannels, seqLength}, 'f');

    for (int example = 0; example < numExamples; example++) {
        for (int t = 0; t < seqLength; t++) {
            boolean insideWindow = t >= 10 && t <= 14;
            fake.putScalar(new int[] {example, 0, t}, insideWindow ? 1.0 : 0.0);
        }
    }
    return fake;
}
public void train() { DataSetIterator iter = new ProfileIterator(trainingData, confirmationSet, imgDatasets, true); int iEpoch = 0; int nEpochs = 300; while (iEpoch < nEpochs) { System.out.printf("EPOCH %d\n", iEpoch); Evaluation eval = new Evaluation(); while (iter.hasNext()) { DataSet ds = iter.next(); net.train(ds); INDArray predict2 = net.output(ds.getFeatureMatrix()); INDArray labels2 = ds.getLabels(); // eval.evalTimeSeries(labels2, predict2); } iter.reset(); // System.out.println(eval.stats()); iEpoch++; } System.out.println("Fitting : DONE"); }