  /**
   * Trains a stacked autoencoder deep learning model on the given training data.
   *
   * @param trainData Training dataset as a JavaRDD of labeled points
   * @param batchSize Size of a training mini-batch
   * @param layerSizes Number of neurons in each hidden layer
   * @param activationType Name of the activation function to use
   * @param epochs Number of epochs to train
   * @param responseColumn Name of the response column
   * @param modelName Name of the model
   * @param mlModel Model configuration holding the feature list and the response variable
   * @param modelID Identifier of the model, used in error reporting
   * @return the trained DeepLearningModel, or null if training fails
   */
  public DeepLearningModel train(
      JavaRDD<LabeledPoint> trainData,
      int batchSize,
      int[] layerSizes,
      String activationType,
      int epochs,
      String responseColumn,
      String modelName,
      MLModel mlModel,
      long modelID) {
    // Build the stacked autoencoder by training the model on the training data.

    // Use the full dataset for training; with a fraction of 1 the validation split
    // below ends up empty.
    double trainingFraction = 1;
    try {
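      // Enter an H2O scope; keys registered via Scope.track() are removed on Scope.exit().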
      Scope.enter();
      if (trainData != null) {

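        // Build the column names: all feature names, with the response variable last.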
        int numberOfFeatures = mlModel.getFeatures().size();
        List<Feature> features = mlModel.getFeatures();
        String[] names = new String[numberOfFeatures + 1];
        for (int i = 0; i < numberOfFeatures; i++) {
          names[i] = features.get(i).getName();
        }
        names[numberOfFeatures] = mlModel.getResponseVariable();

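        // Convert the Spark RDD of labeled points into an H2O Frame with the given
        // column names.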
        Frame frame = DeeplearningModelUtils.javaRDDToFrame(names, trainData);

        // H2O names unlabeled columns C1..C<n> by default; here the response column is
        // referenced by its actual name instead.
        // String classifColName = "C" + frame.numCols();
        String classifColName = mlModel.getResponseVariable();

        // Convert the response column to a categorical (enum) vector so that H2O treats
        // this as a classification problem.
        int ci = frame.find(classifColName);
        Scope.track(frame.replace(ci, frame.vecs()[ci].toEnum())._key);

        // Split the input frame into train and validation frames.
        // Using FrameSplitter (instead of ShuffleSplitFrame) throws a cryptic exception:
        // barrier onExCompletion for hex.deeplearning.DeepLearning$DeepLearningDriver@78ec854
        double[] ratios = new double[] {trainingFraction, 1 - trainingFraction};
        @SuppressWarnings("unchecked")
        Frame[] splits =
            ShuffleSplitFrame.shuffleSplitFrame(
                frame, generateNumKeys(frame._key, ratios.length), ratios, 123456789);

        Frame trainFrame = splits[0];
        Frame vframe = splits[1];

        if (log.isDebugEnabled()) {
          log.debug("Creating Deeplearning parameters");
        }

        DeepLearningParameters deeplearningParameters = new DeepLearningParameters();

        // Sanitize the model name so it can be used in an H2O key.
        String dlModelName = modelName.replace('.', '_').replace('-', '_');

        // populate model parameters
        deeplearningParameters._model_id = Key.make(dlModelName + "_dl");
        deeplearningParameters._train = trainFrame._key;
        deeplearningParameters._valid = vframe._key;
        deeplearningParameters._response_column = classifColName; // the response is the last column
        // Enabling autoencoder mode caused all predictions to be 0.0, so it stays disabled:
        // deeplearningParameters._autoencoder = true;
        deeplearningParameters._activation = getActivationType(activationType);
        deeplearningParameters._hidden = layerSizes;
        deeplearningParameters._train_samples_per_iteration = batchSize;
        deeplearningParameters._input_dropout_ratio = 0.2;
        deeplearningParameters._l1 = 1e-5;
        deeplearningParameters._max_w2 = 10;
        deeplearningParameters._epochs = epochs;

        // speed up training
        // Use the adaptive per-weight learning rate (ADADELTA); manually tuned learning
        // rate and momentum settings tend to converge slowly.
        deeplearningParameters._adaptive_rate = true;
        // Replicate the full training dataset onto every node (one-time cost) so each
        // node has enough data for load balancing.
        deeplearningParameters._replicate_training_data = true;
        // Overwrite the final model with the best model found during training.
        deeplearningParameters._overwrite_with_best_model = true;
        // Skip computing diagnostics for the hidden layers during training.
        deeplearningParameters._diagnostics = false;
        // Disable early stopping on classification error.
        deeplearningParameters._classification_stop = -1;
        // Score and print a progress report at most once every 60 seconds.
        deeplearningParameters._score_interval = 60;
        // Score only a small sample of the training set to avoid spending too much time
        // scoring (note: there will be at least 1 row per chunk).
        deeplearningParameters._score_training_samples = batchSize / 10;

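        // Publish the split frames into H2O's distributed key-value store (DKV) so the
        // training job can resolve them by key.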
        DKV.put(trainFrame);
        DKV.put(vframe);

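        // Create the H2O deep learning job from the configured parameters.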
        deeplearning = new DeepLearning(deeplearningParameters);

        if (log.isDebugEnabled()) {
          log.debug("Start training deeplearning model ....");
        }

        try {
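          // trainModel() launches the H2O training job; get() blocks until it finishes.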
          dlModel = deeplearning.trainModel().get();
          if (log.isDebugEnabled()) {
            log.debug("Successfully finished Training deeplearning model.");
          }

        } catch (RuntimeException ex) {
          log.error("Error in training Stacked Autoencoder classifier model", ex);
        }
      } else {
        log.error("Train file not found!");
      }
    } catch (RuntimeException ex) {
      log.error("Failed to train the deeplearning model [id] " + modelID + ". " + ex.getMessage());
    } finally {
      Scope.exit();
    }

    return dlModel;
  }

  /**
   * Smoke test: trains a deep learning model on the MNIST dataset from h2o-3's big-data
   * test files. Ignored by default because the dataset must be synced locally first.
   */
  @Test
  @Ignore
  public void run() {
    Scope.enter();
    try {
      File file = find_test_file("bigdata/laptop/mnist/train.csv.gz");
      File valid = find_test_file("bigdata/laptop/mnist/test.csv.gz");
      if (file != null && valid != null) {
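        // Load the gzipped MNIST CSVs as file-backed vectors and parse them into H2O Frames.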
        NFSFileVec trainfv = NFSFileVec.make(file);
        Frame frame = ParseDataset.parse(Key.make(), trainfv._key);
        NFSFileVec validfv = NFSFileVec.make(valid);
        Frame vframe = ParseDataset.parse(Key.make(), validfv._key);
        DeepLearningParameters p = new DeepLearningParameters();

        // populate model parameters
        p._model_id = Key.make("dl_mnist_model");
        p._train = frame._key;
        //        p._valid = vframe._key;
        p._response_column = "C785"; // last column is the response
        p._activation = DeepLearningParameters.Activation.RectifierWithDropout;
        //        p._activation = DeepLearningParameters.Activation.MaxoutWithDropout;
        p._hidden = new int[] {800, 800};
        p._input_dropout_ratio = 0.2;
        p._mini_batch_size = 1;
        p._train_samples_per_iteration = 50000;
        p._score_duty_cycle = 0;
        //        p._shuffle_training_data = true;
        //        p._l1= 1e-5;
        //        p._max_w2= 10;
        p._epochs = 10 * 5. / 6; // roughly 8.33 epochs

        // Convert response 'C785' to categorical (digits 0 to 9)
        int ci = frame.find("C785");
        Scope.track(frame.replace(ci, frame.vecs()[ci].toEnum())._key);
        Scope.track(vframe.replace(ci, vframe.vecs()[ci].toEnum())._key);
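        // Publish the updated frames back into H2O's distributed key-value store (DKV).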
        DKV.put(frame);
        DKV.put(vframe);

        // speed up training
        // Use the adaptive per-weight learning rate (ADADELTA); manually tuned learning
        // rate and momentum settings tend to converge slowly.
        p._adaptive_rate = true;
        // Replicate the full training dataset onto every node (one-time cost) so each
        // node has enough data for load balancing.
        p._replicate_training_data = true;
        // Overwrite the final model with the best model found during training.
        p._overwrite_with_best_model = true;
        // Disable early stopping on classification error.
        p._classification_stop = -1;
        // Score and print a progress report at most once every 60 seconds.
        p._score_interval = 60;
        // Score only a small sample of the training set to avoid spending too much time
        // scoring (note: there will be at least 1 row per chunk).
        p._score_training_samples = 10000;

        DeepLearning dl = new DeepLearning(p);
        DeepLearningModel model = null;
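        // Train the model, then always clean up the job and the model to free cluster memory.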
        try {
          model = dl.trainModel().get();
        } catch (Throwable t) {
          t.printStackTrace();
          throw new RuntimeException(t);
        } finally {
          dl.remove();
          if (model != null) {
            model.delete();
          }
        }
      } else {
        Log.info("Please run ./gradlew syncBigDataLaptop in the top-level directory of h2o-3.");
      }
    } catch (Throwable t) {
      t.printStackTrace();
      throw new RuntimeException(t);
    } finally {
      Scope.exit();
    }
  }