public static void trainModel(String filteredDataPath, String modelpath) throws IOException { String line = ""; String combline = ""; // read and process raw data BufferedReader br = new BufferedReader(new FileReader(filteredDataPath)); while ((line = br.readLine()) != null) combline = combline + " " + line; List<String> words = Lists.newArrayList(combline.split(" ")); List<List<String>> localDoc = Lists.newArrayList(words, words); // build a context object JavaSparkContext sc = new JavaSparkContext("local", "Word2VecSuite"); JavaRDD<List<String>> doc = sc.parallelize(localDoc); // training settings Word2Vec word2vec = new Word2Vec().setVectorSize(100).setMinCount(50).setSeed(42L); // train Word2VecModel model = word2vec.fit(doc); // save model SparkContext sc1 = sc.toSparkContext(sc); model.save(sc1, modelpath); System.out.println("Model has been saved in folder: " + modelpath); }
public static double cos50(String word, Tuple2<String, Object>[] syms, Word2VecModel model) { double[] v1 = model.transform(word).toArray(); // Tuple2<String, Object>[] syms2 = model.findSynonyms(word, 20); int count = 0; double cosSum = 0; double cosAvg = 0; for (Tuple2<String, Object> t : syms) { double[] v2 = model.transform(t._1).toArray(); cosSum = cosSum + cosineSimilarity(v1, v2); // cosSum=cosSum+distance(v1, v2); } cosAvg = cosSum / syms.length; return cosAvg; }
public static List<List<String>> getSynonyms(List<String> words, String modelpath) throws IOException { /** * //read term file to get the list of phrases List<String> terms = new ArrayList<>(); * logger.log(Level.INFO, "*********Reading term file...."); String combline=""; String line * =""; try { BufferedReader br = new BufferedReader(new FileReader(termFile)); int i=0; while * ((line=br.readLine()) != null) { String phr = line.replaceAll("[^A-Za-z]"," ").trim(); * terms.add(phr); } } catch (FileNotFoundException e) { // TODO Auto-generated catch block * e.printStackTrace(); } */ SparkContext sc = new SparkContext("local", "appName"); // load a trained model Word2VecModel model = Word2VecModel.load(sc, modelpath); // find similar words List<List<String>> synsSet = new ArrayList<List<String>>(); for (String word : words) { word = word.replaceAll("\\s", "-"); Tuple2<String, Object>[] syms1 = model.findSynonyms(word, 50); // cos50 value of the target word // double cos50Value = cos50(word,syms, model); // System.out.println("CosAvg-50 value of the target word " + word +":"+cos50Value); HashMap<String, Double> closeList = new HashMap<String, Double>(); System.out.println("The nearests of the word " + word + ":"); int count = 0; double cos50Value1 = 0; double cos50Value2 = 0; List<String> syns = new ArrayList<>(); syns.add(word); // add itself for (Tuple2<String, Object> t : syms1) { syns.add(t._1); /** * String term=""; int termIndex=0; if (t._1.contains("_")){ termIndex = Integer.parseInt( * t._1.substring(t._1.indexOf("_")+1)); if (termIndex != 0) term = terms.get(termIndex);} */ // cos50Value1 = cos50(word,syms1, model); cos50Value1 = cosineSimilarity(model.transform(word).toArray(), model.transform(t._1).toArray()); // cos50Value1=distance(model.transform(word).toArray(), model.transform(t._1).toArray()); cos50Value2 = 0; // Tuple2<String, Object>[] syms2 = model.findSynonyms(t._1,10); // cos50Value1 = cos50(word,syms2, model); // cos50Value2 = cos50(t._1,syms2, model); if (t._1.contains("/NN") || t._1.contains("/NNS")) { // if (!word.contains(t._1) && !t._1.contains(word)) { // if (!term.contains(word)) { count = count + 1; System.out.println("No " + count + ": " + t._1); } double dif = Math.abs(cos50Value1 - cos50Value2); closeList.put(t._1, dif); } synsSet.add(syns); /** * //sort the list by difference System.out.println("--------------------------------------"); * System.out.println("The nearests of the word " + word +":"); Set<Entry<String, Double>> set * = closeList.entrySet(); List<Entry<String, Double>> sortedList = new * ArrayList<Entry<String, Double>>( set); Collections.sort(sortedList, new * Comparator<Map.Entry<String, Double>>() { public int compare(Map.Entry<String, Double> o1, * Map.Entry<String, Double> o2) { return o2.getValue().compareTo(o1.getValue()); } }); * count=0; for (Entry<String, Double> entry : sortedList) { count=count+1; * * <p>System.out.println("No " + count+ ": " +entry.getKey() + " " + entry.getValue()); } * * <p>//retrieve vector space scala.collection.immutable.Map<String,float[]> m = * model.getVectors(); * * <p>double[] v1 = model.transform(word).toArray(); * * <p>System.out.println(Arrays.toString(v1)); */ } return synsSet; }