Пример #1
0
  /**
   * Trains a Spark MLlib {@code Word2Vec} model on a space-separated text file and saves it.
   *
   * <p>Every line of {@code filteredDataPath} is split on single spaces, and the non-empty tokens
   * are collected, in file order, into one token list. Empty tokens (from blank lines or repeated
   * spaces) are skipped so that the empty string never becomes a vocabulary entry. The token list
   * is passed to Spark twice, as two identical sentences. This doubles every token's count
   * relative to the {@code minCount} threshold of 50.
   *
   * <p>A local {@link JavaSparkContext} is created for the training run. It is always stopped
   * before returning, so that another Spark context can be created later in the same JVM.
   *
   * @param filteredDataPath path of the pre-processed text file to train on
   * @param modelpath directory the trained model is saved to
   * @throws IOException if the input file cannot be read
   */
  public static void trainModel(String filteredDataPath, String modelpath) throws IOException {
    // Read and tokenize the raw data. Splitting each line directly yields the same non-empty
    // tokens as joining all lines with " " and splitting once, without quadratic concatenation.
    List<String> words = new ArrayList<>();
    try (BufferedReader br = new BufferedReader(new FileReader(filteredDataPath))) {
      String line;
      while ((line = br.readLine()) != null) {
        for (String token : line.split(" ")) {
          if (!token.isEmpty()) {
            words.add(token);
          }
        }
      }
    }
    // NOTE(review): the same token list is used as two sentences; presumably intentional.
    List<List<String>> localDoc = Lists.newArrayList(words, words);

    JavaSparkContext sc = new JavaSparkContext("local", "Word2VecSuite");
    try {
      JavaRDD<List<String>> doc = sc.parallelize(localDoc);

      // Training settings.
      Word2Vec word2vec = new Word2Vec().setVectorSize(100).setMinCount(50).setSeed(42L);
      Word2VecModel model = word2vec.fit(doc);

      // MLlib's save API takes the underlying Scala SparkContext.
      model.save(sc.sc(), modelpath);
      System.out.println("Model has been saved in folder: " + modelpath);
    } finally {
      // Only one SparkContext may be active per JVM; release it for later callers.
      sc.stop();
    }
  }
Пример #2
0
  /**
   * Returns the mean cosine similarity between {@code word} and each word in {@code syms}.
   *
   * <p>Word vectors come from {@code model.transform}. The per-pair similarity is computed by the
   * {@code cosineSimilarity} helper (defined outside this method). If {@code syms} is empty, the
   * result is {@code NaN} ({@code 0.0 / 0} in floating point); no exception is thrown for that
   * case.
   *
   * @param word target word whose vector is compared against the others
   * @param syms words to compare against, e.g. the result of {@code model.findSynonyms}; only
   *     the word ({@code _1}) of each tuple is used
   * @param model model supplying the word vectors
   * @return average cosine similarity over {@code syms}, or {@code NaN} when it is empty
   */
  public static double cos50(String word, Tuple2<String, Object>[] syms, Word2VecModel model) {
    double[] v1 = model.transform(word).toArray();

    double cosSum = 0;
    for (Tuple2<String, Object> t : syms) {
      double[] v2 = model.transform(t._1).toArray();
      cosSum += cosineSimilarity(v1, v2);
    }
    // Floating-point division: an empty array gives NaN rather than an ArithmeticException.
    return cosSum / syms.length;
  }
Пример #3
0
  /**
   * Loads a saved Word2Vec model and returns, for each input word, the word followed by up to 50
   * of its nearest neighbours.
   *
   * <p>Each input word has every whitespace character replaced by {@code '-'} before lookup. That
   * normalized form is printed and appears first in the word's result list. For every word, the
   * neighbours whose text contains {@code "/NN"} (which also covers {@code "/NNS"}) are printed to
   * standard output with a running number. All returned neighbours are included in the result
   * regardless.
   *
   * <p>A local {@link SparkContext} is created to load the model. It is always stopped before
   * returning, including when a lookup fails. An exception from {@code findSynonyms} (e.g. for a
   * word missing from the model's vocabulary) propagates to the caller.
   *
   * @param words words to look up; whitespace inside a word is replaced by {@code '-'}
   * @param modelpath directory of a model previously saved with {@code Word2VecModel.save}
   * @return one list per input word, in input order: the normalized word followed by its
   *     neighbours in the order returned by {@code findSynonyms}
   * @throws IOException declared for backward compatibility; not thrown by the current body
   */
  public static List<List<String>> getSynonyms(List<String> words, String modelpath)
      throws IOException {
    SparkContext sc = new SparkContext("local", "appName");
    try {
      // Load a trained model.
      Word2VecModel model = Word2VecModel.load(sc, modelpath);

      List<List<String>> synsSet = new ArrayList<List<String>>();
      for (String rawWord : words) {
        // NOTE(review): presumably multi-word phrases were joined with '-' at training time.
        String word = rawWord.replaceAll("\\s", "-");
        Tuple2<String, Object>[] syms1 = model.findSynonyms(word, 50);

        System.out.println("The nearests of the word " + word + ":");
        List<String> syns = new ArrayList<>(syms1.length + 1);
        syns.add(word); // the (normalized) query word itself comes first
        int count = 0;
        for (Tuple2<String, Object> t : syms1) {
          syns.add(t._1);
          // "/NN" also matches "/NNS"; presumably these are POS-tagged noun tokens.
          if (t._1.contains("/NN")) {
            count++;
            System.out.println("No " + count + ": " + t._1);
          }
        }
        synsSet.add(syns);
      }
      return synsSet;
    } finally {
      // Only one SparkContext may be active per JVM; release it for later callers.
      sc.stop();
    }
  }