Code example #1
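A unit test for Spark ML's VectorAssembler: it builds a row containing a double, a dense vector, a sparse vector, and a long, assembles those four input columns into a single "features" column, and asserts that the result is the expected sparse vector of size 6.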
 @Test
 public void testVectorAssembler() {
   StructType schema =
       createStructType(
           new StructField[] {
             createStructField("id", IntegerType, false),
             createStructField("x", DoubleType, false),
             createStructField("y", new VectorUDT(), false),
             createStructField("name", StringType, false),
             createStructField("z", new VectorUDT(), false),
             createStructField("n", LongType, false)
           });
   Row row =
       RowFactory.create(
           0,
           0.0,
           Vectors.dense(1.0, 2.0),
           "a",
           Vectors.sparse(2, new int[] {1}, new double[] {3.0}),
           10L);
   Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
   VectorAssembler assembler =
       new VectorAssembler()
           .setInputCols(new String[] {"x", "y", "z", "n"})
           .setOutputCol("features");
   Dataset<Row> output = assembler.transform(dataset);
   Assert.assertEquals(
       Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
       output.select("features").first().<Vector>getAs(0));
 }
Code example #2
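Reconstructs an MLlib LogisticRegressionModel inside a Storm bolt from an array stored on HDFS: every value except the last is a weight, and the last value is the intercept.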
  private LogisticRegressionModel instantiateSparkModel() {
    Configuration conf = new Configuration();
    conf.set("fs.defaultFS", topologyConfig.getProperty("hdfs.url"));

    double[] sparkModelInfo = null;

    try {
      sparkModelInfo =
          getSparkModelInfoFromHDFS(
              new Path(topologyConfig.getProperty("hdfs.url") + "/tmp/sparkML_weights"), conf);
    } catch (Exception e) {
      LOG.error("Couldn't instantiate Spark model in prediction bolt: " + e.getMessage());
      e.printStackTrace();

      throw new RuntimeException(e);
    }

    // all numbers besides the last value are the weights
    double[] weights = Arrays.copyOfRange(sparkModelInfo, 0, sparkModelInfo.length - 1);

    // the last number in the array is the intercept
    double intercept = sparkModelInfo[sparkModelInfo.length - 1];

    org.apache.spark.mllib.linalg.Vector weightsV = Vectors.dense(weights);
    return new LogisticRegressionModel(weightsV, intercept);
  }
Code example #3
File: KMeansMP.java Project: pashadow/cloudapp-mp5
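A map function that splits a whitespace-separated line, skips the first token, and parses the remaining tokens into a dense Vector.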
 public Vector call(String line) {
   String[] tok = SPACE.split(line);
   double[] point = new double[tok.length - 1];
   for (int i = 1; i < tok.length; ++i) {
     point[i - 1] = Double.parseDouble(tok[i]);
   }
   return Vectors.dense(point);
 }
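A function like this is typically passed to JavaRDD.map and the resulting vectors fed to k-means. A minimal usage sketch, assuming the method above lives in a class ParsePoint implementing Spark's Function<String, Vector> (the class name, input path, and k/iteration values are illustrative, not from the snippet):

   JavaRDD<String> lines = sc.textFile("data/points.txt"); // sc: an existing JavaSparkContext
   JavaRDD<Vector> points = lines.map(new ParsePoint());   // apply the parser above to every line
   points.cache();                                         // k-means makes several passes over the data
   KMeansModel model = KMeans.train(points.rdd(), 3, 20);  // k = 3, maxIterations = 20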
Code example #4
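A map function that parses a whitespace-separated line into a LabeledPoint, taking the last token as the label and the preceding tokens as the features.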
 public LabeledPoint call(String line) throws Exception {
   String[] tok = SPACE.split(line);
   double label = Double.parseDouble(tok[tok.length - 1]);
   double[] point = new double[tok.length - 1];
   for (int i = 0; i < tok.length - 1; ++i) {
     point[i] = Double.parseDouble(tok[i]);
   }
   return new LabeledPoint(label, Vectors.dense(point));
 }
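The same parsing pattern feeds supervised MLlib learners. A minimal sketch, assuming the method above sits in a class ParseLabeledPoint implementing Function<String, LabeledPoint> (class name, path, and iteration count are illustrative):

   JavaRDD<LabeledPoint> parsed = sc.textFile("data/train.txt").map(new ParseLabeledPoint());
   parsed.cache();                                          // SGD iterates over the data repeatedly
   LinearRegressionModel model = LinearRegressionWithSGD.train(parsed.rdd(), 100);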
Code example #5
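A Java-compatibility test for Spark ML's VectorIndexer: it fits an indexer with maxCategories set to 2 on three dense points, then checks the number of features and the category maps.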
 @Test
 public void vectorIndexerAPI() {
   // The tests are to check Java compatibility.
   List<FeatureData> points =
       Arrays.asList(
           new FeatureData(Vectors.dense(0.0, -2.0)),
           new FeatureData(Vectors.dense(1.0, 3.0)),
           new FeatureData(Vectors.dense(1.0, 4.0)));
   SQLContext sqlContext = new SQLContext(sc);
   DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
   VectorIndexer indexer =
       new VectorIndexer().setInputCol("features").setOutputCol("indexed").setMaxCategories(2);
   VectorIndexerModel model = indexer.fit(data);
    Assert.assertEquals(2, model.numFeatures());
   Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps();
    Assert.assertEquals(1, categoryMaps.size());
   DataFrame indexedData = model.transform(data);
 }
Code example #6
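The execute method of a Storm prediction bolt: for "Normal" events it builds the feature vector, scores it with the logistic regression model, emits the event enriched with the prediction, writes violations (prediction == 1.0) to HDFS, and finally acks the tuple.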
  public void execute(Tuple input) {

    LOG.info("Entered prediction bolt execute...");
    String eventType = input.getStringByField("eventType");

    double prediction;

    if (eventType.equals("Normal")) {
      double[] predictionParams = enrichEvent(input);
      prediction = model.predict(Vectors.dense(predictionParams));

      LOG.info("Prediction is: " + prediction);

      String driverName = input.getStringByField("driverName");
      String routeName = input.getStringByField("routeName");
      int truckId = input.getIntegerByField("truckId");
      Timestamp eventTime = (Timestamp) input.getValueByField("eventTime");
      double longitude = input.getDoubleByField("longitude");
      double latitude = input.getDoubleByField("latitude");
      double driverId = input.getIntegerByField("driverId");
      SimpleDateFormat sdf = new SimpleDateFormat();

      collector.emit(
          input,
          new Values(
              prediction == 0.0 ? "normal" : "violation",
              driverName,
              routeName,
              driverId,
              truckId,
              sdf.format(new Date(eventTime.getTime())),
              longitude,
              latitude,
              predictionParams[0] == 1 ? "Y" : "N", // driver certification status
              predictionParams[1] == 1 ? "miles" : "hourly", // driver wage plan
              predictionParams[2] * 100, // hours feature was scaled down by 100
              predictionParams[3] * 1000, // miles feature was scaled down by 1000
              predictionParams[4] == 1 ? "Y" : "N", // foggy weather
              predictionParams[5] == 1 ? "Y" : "N", // rainy weather
              predictionParams[6] == 1 ? "Y" : "N" // windy weather
              ));

      if (prediction == 1.0) {

        try {
          writePredictionToHDFS(input, predictionParams, prediction);
        } catch (Exception e) {
          e.printStackTrace();
          throw new RuntimeException("Couldn't write prediction to hdfs" + e);
        }
      }
    }

    // acknowledge even if there is an error
    collector.ack(input);
  }
Code example #7
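A map function that converts a CSV line into a LabeledPoint, with the first field as the label and the remaining fields as the features.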
 // Function for converting a CSV line to a LabeledPoint
 public LabeledPoint call(String line) {
   Logger logger = Logger.getLogger(this.getClass());
   logger.debug(line);
   // System.out.println(line);
   String[] parts = COMMA.split(line);
   double y = Double.parseDouble(parts[0]);
   double[] x = new double[parts.length - 1];
   for (int i = 1; i < parts.length; ++i) {
     x[i - 1] = Double.parseDouble(parts[i]);
   }
   return new LabeledPoint(y, Vectors.dense(x));
 }
Code example #8
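Demonstrates MLlib's distributed matrix types: a RowMatrix built from an RDD of dense vectors, an IndexedRowMatrix built both directly and by converting the RowMatrix via zipWithIndex, and a CoordinateMatrix built from MatrixEntry values.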
  public static void main(String[] args) throws IOException {

    SparkConf config = new SparkConf().setAppName("003-distributed-matrices").setMaster("local[*]");

    try (JavaSparkContext sc = new JavaSparkContext(config)) {

      /* Create a RowMatrix */
      List<Vector> vectors = new ArrayList<>(10);
      for (int i = 0; i < 10; i++) {
        vectors.add(Vectors.dense(getVectorElements()));
      }

      JavaRDD<Vector> rowsRDD = sc.parallelize(vectors, 4);

      RowMatrix rowMatrix = new RowMatrix(rowsRDD.rdd());
      System.out.println(rowMatrix.toString());

      /* Create an IndexedRowMatrix */
      JavaRDD<IndexedRow> indexedRows =
          sc.parallelize(
              Arrays.asList(new IndexedRow(0, vectors.get(0)), new IndexedRow(1, vectors.get(1))));
      IndexedRowMatrix indexedRowMatrix = new IndexedRowMatrix(indexedRows.rdd());
      System.out.println(indexedRowMatrix);

      /* convert */
      JavaRDD<IndexedRow> indexedRowsFromRowMatrix =
          rowMatrix
              .rows()
              .toJavaRDD()
              .zipWithIndex()
              .map((Tuple2<Vector, Long> t) -> new IndexedRow(t._2(), t._1()));
      IndexedRowMatrix indexedRowMatrixFromRowMatrix =
          new IndexedRowMatrix(indexedRowsFromRowMatrix.rdd());
      System.out.println(indexedRowMatrixFromRowMatrix);

      /* Create a CoordinateMatrix from (row, column, value) entries
       *     M = [ 5 0
       *           0 3
       *           1 4 ]
       */
      JavaRDD<MatrixEntry> matrixEntries =
          sc.parallelize(
              Arrays.asList(
                  new MatrixEntry(0, 0, 5.),
                  new MatrixEntry(1, 1, 3.),
                  new MatrixEntry(2, 0, 1.),
                  new MatrixEntry(2, 1, 4.)));
      CoordinateMatrix coordMatrix = new CoordinateMatrix(matrixEntries.rdd());
      System.out.println(coordMatrix);
      printSeparator();
    }
  }
Code example #9
File: Test.java Project: Stronhold/NewsClasifier
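Loads a saved NaiveBayesModel, builds a term-frequency vector for each file in the categories directory (aligned with the word list in words.txt), predicts each file's category, and prints the mapping from file name to category.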
  public static void main(String[] args) {

    // Results path
    String pathResults = "results";

    String pathToCategories = "values.txt";
    String pathToWords = "words.txt";
    File file = new File(pathToWords);

    HashMap<Double, String> categoriesDict = new HashMap<>();
    HashMap<String, String> resultado = new HashMap<>();

    FileInputStream fis = null;
    try {
      fis = new FileInputStream(pathToCategories);
      // Construct BufferedReader from InputStreamReader
      BufferedReader br = new BufferedReader(new InputStreamReader(fis));

      String line = null;
      while ((line = br.readLine()) != null) {
        String[] words = line.split(" ");
        categoriesDict.put(Double.valueOf(words[0]), words[1]);
      }
      br.close();
    } catch (FileNotFoundException e) {
      e.printStackTrace();
    } catch (IOException e) {
      e.printStackTrace();
    }

    // Path where the categories will be
    String pathCategories = "src/main/resources/categoriestest/";

    // Basic application configuration
    SparkConf sparkConf = new SparkConf().setAppName("NaiveBayesTest").setMaster("local[*]");

    // Create the context
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);

    NaiveBayesModel model = NaiveBayesModel.load(jsc.sc(), pathResults);

    HashMap<String, String> dictionary = loadDictionary();

    JavaRDD<String> fileWords = null;

    if (file.exists()) {
      JavaRDD<String> input = jsc.textFile(pathToWords);
      fileWords =
          input.flatMap(
              new FlatMapFunction<String, String>() {
                @Override
                public Iterable<String> call(String s) throws Exception {
                  return Arrays.asList(s.split(" "));
                }
              });
    } else {
      System.out.println("Error, there is no words");
      System.exit(-1);
    }
    ArrayList<String> aFileWords = (ArrayList<String>) fileWords.collect();

    // Pick up the files in which the categories are found
    File dir = new File(pathCategories);
    for (File f : dir.listFiles()) {
      JavaRDD<String> input = jsc.textFile(f.getPath());
      JavaRDD<String> words =
          input.flatMap(
              new FlatMapFunction<String, String>() {
                @Override
                public Iterable<String> call(String s) throws Exception {
                  return Arrays.asList(s.split(" "));
                }
              });
      JavaPairRDD<String, Double> wordCount = Reducer.parseWords(words, dictionary);
      List<Tuple2<String, Double>> total = wordCount.collect();
      List<Tuple2<String, Double>> elementsRemoved = new ArrayList<>();
      for (Tuple2<String, Double> t : total) {
        if (!t._1.equals("")) {
          elementsRemoved.add(new Tuple2<>(t._1, t._2 / wordCount.count()));
        }
      }
      ArrayList<Tuple2<String, Double>> freqFinal = new ArrayList<>();
      for (String s : aFileWords) {
        boolean found = false;
        for (Tuple2<String, Double> t : elementsRemoved) {
          if (t._1.equals(s)) {
            found = true;
            freqFinal.add(t);
            break;
          }
        }
        if (!found) {
          freqFinal.add(new Tuple2<String, Double>(s, 0.0));
        }
      }
      double[] v = new double[freqFinal.size()];
      for (int i = 0; i < freqFinal.size(); i++) {
        Tuple2<String, Double> t = freqFinal.get(i);
        v[i] = t._2;
      }
      org.apache.spark.mllib.linalg.Vector vector = Vectors.dense(v);
      double d = model.predict(vector);
      System.out.println(categoriesDict.get(d));
      resultado.put(f.getName(), categoriesDict.get(d));
    }
    jsc.stop();
    try {
      Thread.sleep(2000);
    } catch (InterruptedException e) {
      e.printStackTrace();
    }
    for (String key : resultado.keySet()) {
      System.out.println(key + " - " + resultado.get(key));
    }
  }