public static void main(String... args) throws Exception { int numFeatures = 40; int iterations = 5; int seed = 123; int listenerFreq = iterations / 5; Nd4j.getRandom().setSeed(seed); log.info("Load dat...."); INDArray input = Nd4j.create( 2, numFeatures); // have to be at least two or else output layer gradient is a scalar and // cause exception INDArray labels = Nd4j.create(2, 2); INDArray row0 = Nd4j.create(1, numFeatures); row0.assign(0.1); input.putRow(0, row0); labels.put(0, 1, 1); // set the 4th column INDArray row1 = Nd4j.create(1, numFeatures); row1.assign(0.2); input.putRow(1, row1); labels.put(1, 0, 1); // set the 2nd column DataSet trainingSet = new DataSet(input, labels); log.info("Build model...."); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() .layer(new RBM()) .nIn(trainingSet.numInputs()) .nOut(trainingSet.numOutcomes()) .seed(seed) .weightInit(WeightInit.SIZE) .constrainGradientToUnitNorm(true) .iterations(iterations) .activationFunction("tanh") .visibleUnit(RBM.VisibleUnit.GAUSSIAN) .hiddenUnit(RBM.HiddenUnit.RECTIFIED) .lossFunction(LossFunctions.LossFunction.RMSE_XENT) .learningRate(1e-1f) .optimizationAlgo(OptimizationAlgorithm.ITERATION_GRADIENT_DESCENT) .build(); Layer model = LayerFactories.getFactory(conf).create(conf); model.setIterationListeners( Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq))); log.info("Evaluate weights...."); INDArray w = model.getParam(DefaultParamInitializer.WEIGHT_KEY); log.info("Weights: " + w); log.info("Train model...."); model.fit(trainingSet.getFeatureMatrix()); log.info("Visualize training results...."); // Work in progress to get NeuralNetPlotter functioning NeuralNetPlotter plotter = new NeuralNetPlotter(); plotter.plotNetworkGradient(model, model.gradient(), 10); }