public static void main(String... args) throws Exception {
    int numFeatures = 40;
    int iterations = 5;
    int seed = 123;
    int listenerFreq = iterations / 5;
    Nd4j.getRandom().setSeed(seed);

    log.info("Load dat....");
    INDArray input =
        Nd4j.create(
            2,
            numFeatures); // have to be at least two or else output layer gradient is a scalar and
    // cause exception
    INDArray labels = Nd4j.create(2, 2);

    INDArray row0 = Nd4j.create(1, numFeatures);
    row0.assign(0.1);
    input.putRow(0, row0);
    labels.put(0, 1, 1); // set the 4th column

    INDArray row1 = Nd4j.create(1, numFeatures);
    row1.assign(0.2);

    input.putRow(1, row1);
    labels.put(1, 0, 1); // set the 2nd column

    DataSet trainingSet = new DataSet(input, labels);

    log.info("Build model....");
    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .layer(new RBM())
            .nIn(trainingSet.numInputs())
            .nOut(trainingSet.numOutcomes())
            .seed(seed)
            .weightInit(WeightInit.SIZE)
            .constrainGradientToUnitNorm(true)
            .iterations(iterations)
            .activationFunction("tanh")
            .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
            .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
            .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
            .learningRate(1e-1f)
            .optimizationAlgo(OptimizationAlgorithm.ITERATION_GRADIENT_DESCENT)
            .build();
    Layer model = LayerFactories.getFactory(conf).create(conf);
    model.setIterationListeners(
        Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq)));

    log.info("Evaluate weights....");
    INDArray w = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
    log.info("Weights: " + w);

    log.info("Train model....");
    model.fit(trainingSet.getFeatureMatrix());

    log.info("Visualize training results....");
    // Work in progress to get NeuralNetPlotter functioning
    NeuralNetPlotter plotter = new NeuralNetPlotter();
    plotter.plotNetworkGradient(model, model.gradient(), 10);
  }