Пример #1
0
  /**
   * Trains a logistic regression model on a comma-separated text file and saves it.
   *
   * <p>Each line is split on {@code ','}. Column 0 is used as the label and columns 1-3 as the
   * three features. Lines with fewer than four fields, or with non-numeric fields (for example a
   * header row), are skipped instead of failing the job.
   *
   * <p>NOTE(review): the original code took the label from column 3, which was also used as
   * {@code features[2]}. That leaks the label into the features. Column 0 is assumed to be the
   * label here; confirm against the input data layout.
   *
   * @param args command-line arguments, validated by {@code parseArgs}
   * @throws IllegalArgumentException if {@code parseArgs} rejects the arguments
   */
  public static void main(String[] args) {
    if (!parseArgs(args)) {
      throw new IllegalArgumentException("Parse arguments failed.");
    }

    SparkConf conf = new SparkConf().setAppName("Logistic Regression with SGD");
    SparkContext sc = new SparkContext(conf);
    try {
      JavaRDD<String> data = sc.textFile(inputFile, 1).toJavaRDD();
      JavaRDD<LabeledPoint> training =
          data.map(
                  new Function<String, LabeledPoint>() {
                    @Override
                    public LabeledPoint call(String line) {
                      String[] splits = line.split(",");
                      // Short or blank lines previously threw ArrayIndexOutOfBoundsException
                      // and aborted the whole job; treat them like any other unusable row.
                      if (splits.length < 4) {
                        return null;
                      }
                      try {
                        double label = Double.parseDouble(splits[0]);
                        double[] features = {
                          Double.parseDouble(splits[1]),
                          Double.parseDouble(splits[2]),
                          Double.parseDouble(splits[3])
                        };
                        return new LabeledPoint(label, Vectors.dense(features));
                      } catch (NumberFormatException e) {
                        // Non-numeric row (e.g. a CSV header): drop it; filtered out below.
                        return null;
                      }
                    }
                  })
              .filter(
                  new Function<LabeledPoint, Boolean>() {
                    @Override
                    public Boolean call(LabeledPoint p) {
                      return p != null;
                    }
                  })
              .cache();

      LogisticRegressionModel model = lrs.run(training.rdd());
      model.save(sc, outputFile);
    } finally {
      // Always release the SparkContext, even if reading, training or saving fails.
      sc.stop();
    }
  }
Пример #2
0
  /**
   * Scores an incoming event with the prediction model and emits the enriched result.
   *
   * <p>Only tuples whose {@code eventType} field equals {@code "Normal"} are scored. Other tuples
   * are acked without emitting anything.
   *
   * <p>For a scored tuple, one tuple anchored to {@code input} is emitted. It contains:
   * <ul>
   *   <li>the predicted class ({@code "normal"} for 0.0, otherwise {@code "violation"});
   *   <li>driver, route and truck information;
   *   <li>the formatted event time and the position;
   *   <li>the model features, with hours and miles scaled back up.
   * </ul>
   * A prediction of {@code 1.0} is additionally written to HDFS.
   *
   * <p>The input tuple is acked on every path, including when the HDFS write fails.
   *
   * @param input the incoming event tuple
   * @throws RuntimeException if the prediction cannot be written to HDFS; the tuple is still
   *     acked first
   */
  public void execute(Tuple input) {
    LOG.info("Entered prediction bolt execute...");
    try {
      String eventType = input.getStringByField("eventType");

      // Constant-first comparison: a missing eventType is treated as "not Normal" instead of
      // throwing a NullPointerException.
      if ("Normal".equals(eventType)) {
        double[] predictionParams = enrichEvent(input);
        double prediction = model.predict(Vectors.dense(predictionParams));

        LOG.info("Prediction is: " + prediction);

        String driverName = input.getStringByField("driverName");
        String routeName = input.getStringByField("routeName");
        int truckId = input.getIntegerByField("truckId");
        Timestamp eventTime = (Timestamp) input.getValueByField("eventTime");
        double longitude = input.getDoubleByField("longitude");
        double latitude = input.getDoubleByField("latitude");
        // NOTE(review): driverId is read as an Integer but widened to double, so it is emitted
        // as a Double (unlike truckId). Kept as-is because downstream consumers may rely on the
        // emitted type; confirm whether an int was intended.
        double driverId = input.getIntegerByField("driverId");
        SimpleDateFormat sdf = new SimpleDateFormat();

        collector.emit(
            input,
            new Values(
                prediction == 0.0 ? "normal" : "violation",
                driverName,
                routeName,
                driverId,
                truckId,
                sdf.format(new Date(eventTime.getTime())),
                longitude,
                latitude,
                predictionParams[0] == 1 ? "Y" : "N", // driver certification status
                predictionParams[1] == 1 ? "miles" : "hourly", // driver wage plan
                predictionParams[2] * 100, // hours feature was scaled down by 100
                predictionParams[3] * 1000, // miles feature was scaled down by 1000
                predictionParams[4] == 1 ? "Y" : "N", // foggy weather
                predictionParams[5] == 1 ? "Y" : "N", // rainy weather
                predictionParams[6] == 1 ? "Y" : "N" // windy weather
                ));

        if (prediction == 1.0) {
          try {
            writePredictionToHDFS(input, predictionParams, prediction);
          } catch (Exception e) {
            // Chain the cause so its stack trace is preserved instead of printed and dropped.
            throw new RuntimeException("Couldn't write prediction to hdfs: " + e, e);
          }
        }
      }
    } finally {
      // Acknowledge even if there is an error. Previously a failed HDFS write threw before
      // reaching the ack, contradicting this intent.
      collector.ack(input);
    }
  }
  /**
   * Builds, evaluates and summarizes a logistic regression model.
   *
   * <p>The steps are:
   * <ol>
   *   <li>Train with SGD (binary classification only) or with L-BFGS.
   *   <li>Reject models whose weights fail {@code isValidWeights}.
   *   <li>Evaluate on {@code testingData} twice: a thresholded pass for the confusion matrix and
   *       accuracy, then a pass with the threshold cleared to obtain probabilistic scores.
   *   <li>Store the trained model (threshold cleared) in {@code mlModel}.
   * </ol>
   *
   * <p>{@code trainingData} is unpersisted as soon as training finishes, and {@code testingData}
   * is cached for the evaluation passes. Both are unpersisted before the method returns, on
   * success and on failure.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (not referenced by this method)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow supplying the hyper-parameters and dataset version
   * @param mlModel Deployable machine learning model; receives the trained model
   * @param includedFeatures included feature names keyed by an integer index (presumably the
   *     column index; TODO confirm), used to label the feature importances
   * @param isSGD {@code true} for Logistic regression with SGD, {@code false} for L-BFGS
   * @return the probabilistic classification summary of the trained model
   * @throws MLModelBuilderException if SGD is requested for more than two classes, if the trained
   *     weights are invalid, or if any other error occurs while training or evaluating
   */
  private ModelSummary buildLogisticRegressionModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures,
      boolean isSGD)
      throws MLModelBuilderException {
    try {
      LogisticRegression logisticRegression = new LogisticRegression();
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      LogisticRegressionModel logisticRegressionModel;
      String algorithmName;

      int noOfClasses = getNoOfClasses(mlModel);

      if (isSGD) {
        algorithmName = SUPERVISED_ALGORITHM.LOGISTIC_REGRESSION.toString();

        if (noOfClasses > 2) {
          throw new MLModelBuilderException(
              "A binary classification algorithm cannot have more than "
                  + "two distinct values in response variable.");
        }

        logisticRegressionModel =
            logisticRegression.trainWithSGD(
                trainingData,
                Double.parseDouble(hyperParameters.get(MLConstants.LEARNING_RATE)),
                Integer.parseInt(hyperParameters.get(MLConstants.ITERATIONS)),
                hyperParameters.get(MLConstants.REGULARIZATION_TYPE),
                Double.parseDouble(hyperParameters.get(MLConstants.REGULARIZATION_PARAMETER)),
                Double.parseDouble(hyperParameters.get(MLConstants.SGD_DATA_FRACTION)));
      } else {
        algorithmName = SUPERVISED_ALGORITHM.LOGISTIC_REGRESSION_LBFGS.toString();
        logisticRegressionModel =
            logisticRegression.trainWithLBFGS(
                trainingData, hyperParameters.get(MLConstants.REGULARIZATION_TYPE), noOfClasses);
      }

      // Training is done: release the training data early to free memory for evaluation.
      trainingData.unpersist();
      // The test data is scanned twice below, so keep it in memory.
      testingData.cache();

      Vector weights = logisticRegressionModel.weights();
      if (!isValidWeights(weights)) {
        throw new MLModelBuilderException(
            "Weights of the model generated are null or infinity. [Weights] "
                + vectorToString(weights));
      }

      // getting scores and labels without clearing threshold to get confusion matrix
      JavaRDD<Tuple2<Object, Object>> scoresAndLabelsThresholded =
          logisticRegression.test(logisticRegressionModel, testingData);
      MulticlassMetrics multiclassMetrics =
          new MulticlassMetrics(JavaRDD.toRDD(scoresAndLabelsThresholded));
      MulticlassConfusionMatrix multiclassConfusionMatrix =
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel);

      // clearing the threshold value to get a probability as the output of the prediction
      logisticRegressionModel.clearThreshold();
      JavaRDD<Tuple2<Object, Object>> scoresAndLabels =
          logisticRegression.test(logisticRegressionModel, testingData);
      ProbabilisticClassificationModelSummary probabilisticClassificationModelSummary =
          SparkModelUtils.generateProbabilisticClassificationModelSummary(
              sparkContext, testingData, scoresAndLabels);
      mlModel.setModel(new MLClassificationModel(logisticRegressionModel));

      // Clearing the threshold does not change the weights, so reuse the vector fetched above.
      List<FeatureImportance> featureWeights =
          getFeatureWeights(includedFeatures, weights.toArray());
      probabilisticClassificationModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      probabilisticClassificationModelSummary.setFeatureImportance(featureWeights);
      probabilisticClassificationModelSummary.setAlgorithm(algorithmName);

      probabilisticClassificationModelSummary.setMulticlassConfusionMatrix(
          multiclassConfusionMatrix);
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      probabilisticClassificationModelSummary.setModelAccuracy(modelAccuracy);
      probabilisticClassificationModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return probabilisticClassificationModelSummary;
    } catch (MLModelBuilderException e) {
      // Already descriptive; don't wrap it in a second MLModelBuilderException.
      throw e;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building logistic regression model: " + e.getMessage(), e);
    } finally {
      // Release cached RDDs on every path. Previously a failure after testingData.cache()
      // (e.g. invalid weights) left the test data cached. Unpersisting an RDD that is no longer
      // cached is harmless, so the repeat call on trainingData is safe.
      trainingData.unpersist();
      testingData.unpersist();
    }
  }