Example #1
 public SparkRuntime(
     SparkPipeline pipeline,
     JavaSparkContext sparkContext,
     Configuration conf,
     Map<PCollectionImpl<?>, Set<Target>> outputTargets,
     Map<PCollectionImpl<?>, MaterializableIterable> toMaterialize,
     Map<PCollection<?>, StorageLevel> toCache,
     Map<PipelineCallable<?>, Set<Target>> allPipelineCallables) {
   this.pipeline = pipeline;
   this.sparkContext = sparkContext;
   this.conf = conf;
   this.counters =
       sparkContext.accumulator(
           Maps.<String, Map<String, Long>>newHashMap(), new CounterAccumulatorParam());
   this.ctxt =
       new SparkRuntimeContext(
           sparkContext.appName(),
           counters,
           sparkContext.broadcast(WritableUtils.toByteArray(conf)));
   this.outputTargets = Maps.newTreeMap(DEPTH_COMPARATOR);
   this.outputTargets.putAll(outputTargets);
   this.toMaterialize = toMaterialize;
   this.toCache = toCache;
   this.allPipelineCallables = allPipelineCallables;
   this.activePipelineCallables = allPipelineCallables.keySet();
   this.status.set(Status.READY);
   this.monitorThread =
       new Thread(
           new Runnable() {
             @Override
             public void run() {
               monitorLoop();
             }
           });
 }
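The constructor above wires together two Spark primitives that recur throughout these examples: an accumulator for collecting counters on the driver and a broadcast variable for shipping the serialized Hadoop Configuration to every executor. A minimal, self-contained sketch of that pattern (hypothetical names, Spark 1.x Java API, not part of the original class):

import java.util.Arrays;

import org.apache.spark.Accumulator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class AccumulatorBroadcastSketch {
  public static void main(String[] args) {
    JavaSparkContext sc =
        new JavaSparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"));
    // Driver-side counter that executors may only add to
    final Accumulator<Integer> counter = sc.accumulator(0);
    // Read-only value shipped once per executor instead of once per task
    final Broadcast<String> conf = sc.broadcast("serialized-conf");
    sc.parallelize(Arrays.asList(1, 2, 3))
        .foreach(
            x -> {
              counter.add(1); // executors add
              conf.value(); // executors read
            });
    // Only the driver may read the accumulated value
    System.out.println("Processed " + counter.value() + " records");
    sc.stop();
  }
}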
  public static JavaPairRDD<GATKRead, Iterable<GATKVariant>> join(
      final JavaRDD<GATKRead> reads, final JavaRDD<GATKVariant> variants) {
    final JavaSparkContext ctx = new JavaSparkContext(reads.context());
    final IntervalsSkipList<GATKVariant> variantSkipList =
        new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast =
        ctx.broadcast(variantSkipList);

    return reads.mapToPair(
        r -> {
          final IntervalsSkipList<GATKVariant> intervalsSkipList = variantsBroadcast.getValue();
          if (SimpleInterval.isValid(r.getContig(), r.getStart(), r.getEnd())) {
            return new Tuple2<>(r, intervalsSkipList.getOverlapping(new SimpleInterval(r)));
          } else {
           // Some reads do not form valid intervals (reads that do not
           // consume any ref bases, e.g. CIGAR 61S90I). In those cases, we
           // just say that nothing overlaps the read.
            return new Tuple2<>(r, Collections.emptyList());
          }
        });
  }
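The join above is a map-side (broadcast) join: the small variants RDD is collected, wrapped in an interval skip list, and broadcast, so the large reads RDD can be joined without a shuffle. A hedged, generic sketch of the same pattern with plain types (hypothetical data, not GATK's API):

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

import scala.Tuple2;

public class BroadcastJoinSketch {
  public static void main(String[] args) {
    JavaSparkContext sc =
        new JavaSparkContext(new SparkConf().setAppName("broadcast-join").setMaster("local[2]"));
    // Small side: collected to the driver and broadcast, like the variant skip list above
    Map<String, String> small = new HashMap<>();
    small.put("chr1", "annotation-A");
    small.put("chr2", "annotation-B");
    final Broadcast<Map<String, String>> lookup = sc.broadcast(small);
    // Large side: joined via a map-side lookup, no shuffle
    JavaPairRDD<String, String> joined =
        sc.parallelize(Arrays.asList("chr1", "chr2", "chr3"))
            .mapToPair(
                contig -> new Tuple2<>(contig, lookup.value().getOrDefault(contig, "no-overlap")));
    System.out.println(joined.collect());
    sc.stop();
  }
}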
  public void buildVocabCache() {

    // Tokenize
    JavaRDD<List<String>> tokenizedRDD = tokenize();

    // Update accumulator values and map to an RDD of sentence counts
    sentenceWordsCountRDD = updateAndReturnAccumulatorVal(tokenizedRDD).cache();

    // Get value from accumulator
    Counter<String> wordFreqCounter = wordFreqAcc.value();

    // Filter out low-count words, add the rest to the vocab cache, and feed it into the LookupCache
    filterMinWordAddVocab(wordFreqCounter);

    // The Huffman tree must be built BEFORE the vocab cache is broadcast
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);

    // At this point the vocab cache is built. Broadcast vocab cache
    vocabCacheBroadcast = sc.broadcast(vocabCache);
  }
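The ordering note in buildVocabCache matters because a broadcast captures a snapshot of the value at the moment sc.broadcast is called; on a cluster, executors deserialize that snapshot, so driver-side mutations made afterwards (such as the Huffman indexes) would never reach them. A small illustration of the pitfall, under that assumption:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class BroadcastSnapshotSketch {
  public static void main(String[] args) {
    JavaSparkContext sc =
        new JavaSparkContext(new SparkConf().setAppName("snapshot").setMaster("local[2]"));
    List<String> vocab = new ArrayList<>(Arrays.asList("the", "cat"));
    // WRONG ORDER: broadcast first, mutate afterwards. On a real cluster the
    // executors see the serialized snapshot ["the", "cat"], not the update.
    Broadcast<List<String>> early = sc.broadcast(vocab);
    vocab.add("sat");
    // RIGHT ORDER: finish all mutations, then broadcast.
    Broadcast<List<String>> late = sc.broadcast(new ArrayList<>(vocab));
    System.out.println(late.value()); // [the, cat, sat]
    sc.stop();
  }
}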
  public static void main(String[] args) throws IOException {
    Parameters param = new Parameters();
    long initTime = System.currentTimeMillis();

    SparkConf conf = new SparkConf().setAppName("StarJoin");

    // Serializer settings must be applied before the context is created,
    // otherwise they have no effect.
    if (param.useKryo) {
      conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
      conf.set("spark.kryo.registrator", MyBloomFilter.BloomFilterRegistrator.class.getName());
      conf.set("spark.kryoserializer.buffer.mb", param.buffer);
    }

    JavaSparkContext sc = new JavaSparkContext(conf);

    MyBloomFilter.BloomFilter<String> BFS =
        new MyBloomFilter.BloomFilter<>(1.0, param.bitsS, param.hashes);
    MyBloomFilter.BloomFilter<String> BFD =
        new MyBloomFilter.BloomFilter<>(1.0, param.bitsD, param.hashes);
    MyBloomFilter.BloomFilter<String> BFC =
        new MyBloomFilter.BloomFilter<>(1.0, param.bitsC, param.hashes);

    JavaPairRDD<String, String> supps =
        sc.textFile(param.suppPath)
            .map(
                new Function<String, String[]>() {
                  public String[] call(String line) {
                    return line.split("\\|");
                  }
                })
            .filter(
                new Function<String[], Boolean>() {
                  public Boolean call(String[] s) {
                    return s[3].equals("UNITED KI1") || s[3].equals("UNITED KI5");
                  }
                })
            .mapToPair(
                new PairFunction<String[], String, String>() {
                  public Tuple2<String, String> call(String[] s) {
                    return new Tuple2<String, String>(s[0], s[3]);
                  }
                });

    List<Tuple2<String, String>> s = supps.collect();
    for (int i = 0; i < s.size(); i++) {
      BFS.add(s.get(i)._1);
    }

    final Broadcast<MyBloomFilter.BloomFilter<String>> varS = sc.broadcast(BFS);

    JavaPairRDD<String, String> custs =
        sc.textFile(param.custPath)
            .map(
                new Function<String, String[]>() {
                  public String[] call(String line) {
                    return line.split("\\|");
                  }
                })
            .filter(
                new Function<String[], Boolean>() {
                  public Boolean call(String[] s) {
                    return s[3].equals("UNITED KI1") || s[3].equals("UNITED KI5");
                  }
                })
            .mapToPair(
                new PairFunction<String[], String, String>() {
                  public Tuple2<String, String> call(String[] s) {
                    return new Tuple2<String, String>(s[0], s[3]);
                  }
                });

    List<Tuple2<String, String>> c = custs.collect();
    for (int i = 0; i < c.size(); i++) {
      BFC.add(c.get(i)._1);
    }

    final Broadcast<MyBloomFilter.BloomFilter<String>> varC = sc.broadcast(BFC);

    JavaPairRDD<String, String> dates =
        sc.textFile(param.datePath)
            .map(
                new Function<String, String[]>() {
                  public String[] call(String line) {
                    return line.split("\\|");
                  }
                })
            .filter(
                new Function<String[], Boolean>() {
                  public Boolean call(String[] s) {
                    return s[6].equals("Dec1997");
                  }
                })
            .mapToPair(
                new PairFunction<String[], String, String>() {
                  public Tuple2<String, String> call(String[] s) {
                    return new Tuple2<String, String>(s[0], s[4]);
                  }
                });

    List<Tuple2<String, String>> d = dates.collect();
    for (int i = 0; i < d.size(); i++) {
      BFD.add(d.get(i)._1);
    }

    final Broadcast<MyBloomFilter.BloomFilter<String>> varD = sc.broadcast(BFD);

    JavaPairRDD<String, String[]> lines =
        sc.textFile(param.linePath)
            .map(
                new Function<String, String[]>() {
                  public String[] call(String line) {
                    return line.split("\\|");
                  }
                })
            .filter(
                new Function<String[], Boolean>() {
                  public Boolean call(String[] s) {
                    return varC.value().contains(s[2].getBytes())
                        && varS.value().contains(s[4].getBytes())
                        && varD.value().contains(s[5].getBytes());
                  }
                })
            .mapToPair(
                new PairFunction<String[], String, String[]>() {
                  public Tuple2<String, String[]> call(String[] s) {
                    String[] v = {s[2], s[5], s[12]};
                    return new Tuple2<String, String[]>(s[4], v);
                  }
                });

    JavaPairRDD<String, String[]> result =
        lines
            .join(supps)
            .mapToPair(
                new PairFunction<Tuple2<String, Tuple2<String[], String>>, String, String[]>() {
                  public Tuple2<String, String[]> call(Tuple2<String, Tuple2<String[], String>> s) {
                    String[] v = {s._2._1[1], s._2._1[2], s._2._2};
                    return new Tuple2<String, String[]>(s._2._1[0], v);
                  }
                });

    result =
        result
            .join(custs)
            .mapToPair(
                new PairFunction<Tuple2<String, Tuple2<String[], String>>, String, String[]>() {
                  public Tuple2<String, String[]> call(Tuple2<String, Tuple2<String[], String>> s) {
                    String[] v = {s._2._1[1], s._2._1[2], s._2._2};
                    return new Tuple2<String, String[]>(s._2._1[0], v);
                  }
                });

    JavaPairRDD<String, Long> final_result =
        result
            .join(dates)
            .mapToPair(
                new PairFunction<Tuple2<String, Tuple2<String[], String>>, String, Long>() {
                  public Tuple2<String, Long> call(Tuple2<String, Tuple2<String[], String>> s) {
                    return new Tuple2<String, Long>(
                        s._2._1[2] + "," + s._2._1[1] + "," + s._2._2, Long.parseLong(s._2._1[0]));
                  }
                })
            .reduceByKey(
                new Function2<Long, Long, Long>() {
                  public Long call(Long i1, Long i2) {
                    return i1 + i2;
                  }
                });

    JavaPairRDD<String, String> sub_result =
        final_result.mapToPair(
            new PairFunction<Tuple2<String, Long>, String, String>() {
              public Tuple2<String, String> call(Tuple2<String, Long> line) {
                return new Tuple2<>(line._1 + "," + line._2, null);
              }
            });

    final_result =
        sub_result
            .sortByKey(new Q3Comparator())
            .mapToPair(
                new PairFunction<Tuple2<String, String>, String, Long>() {
                  public Tuple2<String, Long> call(Tuple2<String, String> line) {
                    String[] s = line._1.split(",");
                    return new Tuple2<String, Long>(
                        s[0] + "," + s[1] + "," + s[2], Long.parseLong(s[3]));
                  }
                });

    Configuration HDFSconf = new Configuration();
    FileSystem fs = FileSystem.get(HDFSconf);
    fs.delete(new Path(param.output), true);

    final_result.saveAsTextFile(param.output);

    long finalTime = System.currentTimeMillis();
    System.out.print("Tempo total(ms): ");
    System.out.println(finalTime - initTime);

    sc.close();
  }
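MyBloomFilter is not shown on this page. As a hedged stand-in, the same broadcast semi-join idea can be sketched with Guava's BloomFilter: build the filter from the small dimension table's keys on the driver, broadcast it, and prune the fact table before the expensive shuffle join. False positives are possible, so the real join must still verify the key:

import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import com.google.common.hash.BloomFilter;
import com.google.common.hash.Funnels;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class BloomSemiJoinSketch {
  public static void main(String[] args) {
    JavaSparkContext sc =
        new JavaSparkContext(new SparkConf().setAppName("bloom-semi-join").setMaster("local[2]"));
    // Build the filter on the driver from the small dimension table's keys
    BloomFilter<String> keys =
        BloomFilter.create(Funnels.stringFunnel(StandardCharsets.UTF_8), 1000, 0.01);
    for (String k : Arrays.asList("s1", "s2")) {
      keys.put(k);
    }
    final Broadcast<BloomFilter<String>> bf = sc.broadcast(keys);
    // Prune fact-table rows whose key cannot be in the dimension table
    JavaRDD<String> pruned =
        sc.parallelize(Arrays.asList("s1|x", "s3|y", "s2|z"))
            .filter(line -> bf.value().mightContain(line.split("\\|")[0]));
    System.out.println(pruned.collect());
    sc.stop();
  }
}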
  /**
   * Train on the corpus
   *
   * @param rdd the RDD to train on
   * @return the vocab and weights
   */
  public Pair<VocabCache, GloveWeightLookupTable> train(JavaRDD<String> rdd) {
    TextPipeline pipeline = new TextPipeline(rdd);
    final Pair<VocabCache, Long> vocabAndNumWords = pipeline.process();
    SparkConf conf = rdd.context().getConf();
    JavaSparkContext sc = new JavaSparkContext(rdd.context());
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());

    final GloveWeightLookupTable gloveWeightLookupTable =
        new GloveWeightLookupTable.Builder()
            .cache(vocabAndNumWords.getFirst())
            .lr(conf.getDouble(GlovePerformer.ALPHA, 0.025))
            .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
            .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
            .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75))
            .build();
    gloveWeightLookupTable.resetWeights();

    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient =
        Nd4j.zeros(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient =
        Nd4j.create(gloveWeightLookupTable.getSyn0().shape());

    log.info(
        "Created lookup table of size "
            + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
    CounterMap<String, String> coOccurrenceCounts =
        rdd.map(new TokenizerFunction(tokenizerFactoryClazz))
            .map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize))
            .fold(new CounterMap<String, String>(), new CoOccurrenceCounts());

    List<Triple<String, String, Double>> counts = new ArrayList<>();
    Iterator<Pair<String, String>> pairIter = coOccurrenceCounts.getPairIterator();
    while (pairIter.hasNext()) {
      Pair<String, String> pair = pairIter.next();
      counts.add(
          new Triple<>(
              pair.getFirst(),
              pair.getSecond(),
              coOccurrenceCounts.getCount(pair.getFirst(), pair.getSecond())));
    }

    log.info("Calculated co occurrences");

    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    JavaPairRDD<String, Tuple2<String, Double>> pairs =
        parallel.mapToPair(
            new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {
              @Override
              public Tuple2<String, Tuple2<String, Double>> call(
                  Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
                return new Tuple2<>(
                    stringStringDoubleTriple.getFirst(),
                    new Tuple2<>(
                        stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
              }
            });

    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab =
        pairs.mapToPair(
            new PairFunction<
                Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {
              @Override
              public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(
                  Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
                return new Tuple2<>(
                    vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1()),
                    new Tuple2<>(
                        vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1()),
                        stringTuple2Tuple2._2()._2()));
              }
            });

    for (int i = 0; i < iterations; i++) {

      JavaRDD<GloveChange> change =
          pairsVocab.map(
              new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {
                @Override
                public GloveChange call(
                    Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2)
                    throws Exception {
                  VocabWord w1 = vocabWordTuple2Tuple2._1();
                  VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                  INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                  INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                  INDArray bias = gloveWeightLookupTable.getBias();
                  double score = vocabWordTuple2Tuple2._2()._2();
                  double xMax = gloveWeightLookupTable.getxMax();
                  double maxCount = gloveWeightLookupTable.getMaxCount();
                  // w1 * w2 + bias
                  double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                  prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());

                  double weight = Math.pow(Math.min(1.0, (score / maxCount)), xMax);

                  double fDiff =
                      score > xMax ? prediction : weight * (prediction - Math.log(score));
                  if (Double.isNaN(fDiff)) fDiff = Nd4j.EPS_THRESHOLD;
                  // amount of change
                  double gradient = fDiff;
                  // update(w1,w1Vector,w2Vector,gradient);
                  // update(w2,w2Vector,w1Vector,gradient);

                  Pair<INDArray, Double> w1Update =
                      update(
                          gloveWeightLookupTable.getWeightAdaGrad(),
                          gloveWeightLookupTable.getBiasAdaGrad(),
                          gloveWeightLookupTable.getSyn0(),
                          gloveWeightLookupTable.getBias(),
                          w1,
                          w1Vector,
                          w2Vector,
                          gradient);
                  Pair<INDArray, Double> w2Update =
                      update(
                          gloveWeightLookupTable.getWeightAdaGrad(),
                          gloveWeightLookupTable.getBiasAdaGrad(),
                          gloveWeightLookupTable.getSyn0(),
                          gloveWeightLookupTable.getBias(),
                          w2,
                          w2Vector,
                          w1Vector,
                          gradient);
                  return new GloveChange(
                      w1,
                      w2,
                      w1Update.getFirst(),
                      w2Update.getFirst(),
                      w1Update.getSecond(),
                      w2Update.getSecond(),
                      fDiff);
                }
              });

      JavaRDD<Double> error =
          change.map(
              new Function<GloveChange, Double>() {
                @Override
                public Double call(GloveChange gloveChange) throws Exception {
                  gloveChange.apply(gloveWeightLookupTable);
                  return gloveChange.getError();
                }
              });

      final Accumulator<Double> d = sc.accumulator(0.0);
      error.foreach(
          new VoidFunction<Double>() {
            @Override
            public void call(Double aDouble) throws Exception {
              d.add(aDouble);
            }
          });

      log.info("Error at iteration " + i + " was " + d.value());
    }

    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
  }
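For reference, the weight computed inside the map above is GloVe's weighting function f(X_ij) = min(1, X_ij / x_max)^alpha; note that in this class the field named maxCount plays the role of x_max and the field named xMax plays the role of the exponent alpha. A standalone sketch:

public class GloveWeightSketch {
  // f(x) = min(1, x / xMax)^alpha: down-weights rare pairs, caps frequent ones at 1
  static double gloveWeight(double count, double xMax, double alpha) {
    return Math.pow(Math.min(1.0, count / xMax), alpha);
  }

  public static void main(String[] args) {
    System.out.println(gloveWeight(10, 100, 0.75)); // ~0.178
    System.out.println(gloveWeight(250, 100, 0.75)); // 1.0 (capped)
  }
}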
Example #6
 private void monitorLoop() {
   status.set(Status.RUNNING);
   long start = System.currentTimeMillis();
   Map<PCollectionImpl<?>, Set<Target>> targetDeps = Maps.newTreeMap(DEPTH_COMPARATOR);
   Set<Target> unfinished = Sets.newHashSet();
   for (PCollectionImpl<?> pcollect : outputTargets.keySet()) {
     targetDeps.put(pcollect, pcollect.getTargetDependencies());
     unfinished.addAll(outputTargets.get(pcollect));
   }
   runCallables(unfinished);
   while (!targetDeps.isEmpty() && doneSignal.getCount() > 0) {
     Set<Target> allTargets = Sets.newHashSet();
     for (PCollectionImpl<?> pcollect : targetDeps.keySet()) {
       allTargets.addAll(outputTargets.get(pcollect));
     }
     Map<PCollectionImpl<?>, JavaRDDLike<?, ?>> pcolToRdd = Maps.newTreeMap(DEPTH_COMPARATOR);
     for (PCollectionImpl<?> pcollect : targetDeps.keySet()) {
       if (Sets.intersection(allTargets, targetDeps.get(pcollect)).isEmpty()) {
         JavaRDDLike<?, ?> rdd = ((SparkCollection) pcollect).getJavaRDDLike(this);
         pcolToRdd.put(pcollect, rdd);
       }
     }
     distributeFiles();
     for (Map.Entry<PCollectionImpl<?>, JavaRDDLike<?, ?>> e : pcolToRdd.entrySet()) {
       JavaRDDLike<?, ?> rdd = e.getValue();
       PType<?> ptype = e.getKey().getPType();
       Set<Target> targets = outputTargets.get(e.getKey());
       if (targets.size() > 1) {
         rdd.rdd().cache();
       }
       for (Target t : targets) {
         Configuration conf = new Configuration(getConfiguration());
         getRuntimeContext().setConf(sparkContext.broadcast(WritableUtils.toByteArray(conf)));
         if (t instanceof MapReduceTarget) { // TODO: check this earlier
           Converter c = t.getConverter(ptype);
           IdentityFn ident = IdentityFn.getInstance();
           JavaPairRDD<?, ?> outRDD;
           if (rdd instanceof JavaRDD) {
             outRDD =
                 ((JavaRDD) rdd)
                     .map(
                         new MapFunction(
                             c.applyPTypeTransforms() ? ptype.getOutputMapFn() : ident, ctxt))
                     .mapToPair(new OutputConverterFunction(c));
           } else {
             outRDD =
                 ((JavaPairRDD) rdd)
                     .map(
                         new PairMapFunction(
                             c.applyPTypeTransforms() ? ptype.getOutputMapFn() : ident, ctxt))
                     .mapToPair(new OutputConverterFunction(c));
           }
           try {
              Job job = Job.getInstance(conf);
             if (t instanceof PathTarget) {
               PathTarget pt = (PathTarget) t;
               pt.configureForMapReduce(job, ptype, pt.getPath(), "out0");
               CrunchOutputs.OutputConfig outConfig =
                   CrunchOutputs.getNamedOutputs(job.getConfiguration()).get("out0");
               job.setOutputFormatClass(outConfig.bundle.getFormatClass());
               job.setOutputKeyClass(outConfig.keyClass);
               job.setOutputValueClass(outConfig.valueClass);
               outConfig.bundle.configure(job.getConfiguration());
               Path tmpPath = pipeline.createTempPath();
               outRDD.saveAsNewAPIHadoopFile(
                   tmpPath.toString(),
                   c.getKeyClass(),
                   c.getValueClass(),
                   job.getOutputFormatClass(),
                   job.getConfiguration());
               pt.handleOutputs(job.getConfiguration(), tmpPath, -1);
             } else { // if (t instanceof MapReduceTarget) {
               MapReduceTarget mrt = (MapReduceTarget) t;
               mrt.configureForMapReduce(job, ptype, new Path("/tmp"), "out0");
               CrunchOutputs.OutputConfig outConfig =
                   CrunchOutputs.getNamedOutputs(job.getConfiguration()).get("out0");
               job.setOutputFormatClass(outConfig.bundle.getFormatClass());
               job.setOutputKeyClass(outConfig.keyClass);
               job.setOutputValueClass(outConfig.valueClass);
               outRDD.saveAsHadoopDataset(new JobConf(job.getConfiguration()));
             }
           } catch (Exception et) {
             LOG.error("Spark Exception", et);
             status.set(Status.FAILED);
             set(PipelineResult.EMPTY);
             doneSignal.countDown();
           }
         }
       }
       unfinished.removeAll(targets);
     }
     if (status.get() == Status.RUNNING) {
       for (PCollectionImpl<?> output : pcolToRdd.keySet()) {
         if (toMaterialize.containsKey(output)) {
           MaterializableIterable mi = toMaterialize.get(output);
           if (mi.isSourceTarget()) {
             output.materializeAt((SourceTarget) mi.getSource());
           }
         }
         targetDeps.remove(output);
       }
     }
     runCallables(unfinished);
   }
   if (status.get() != Status.FAILED && status.get() != Status.KILLED) {
     status.set(Status.SUCCEEDED);
     set(
         new PipelineResult(
             ImmutableList.of(
                 new PipelineResult.StageResult(
                     "Spark", getCounters(), start, System.currentTimeMillis())),
             Status.SUCCEEDED));
   } else {
     set(PipelineResult.EMPTY);
   }
   doneSignal.countDown();
 }
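One detail worth noting in monitorLoop: an RDD that feeds more than one target is cached before the output loop, so its lineage is computed only once. A minimal sketch of that rule (hypothetical data, standard Spark Java API):

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class CacheForMultipleTargetsSketch {
  public static void main(String[] args) {
    JavaSparkContext sc =
        new JavaSparkContext(new SparkConf().setAppName("cache").setMaster("local[2]"));
    // Cache because two actions (two "targets") will consume this RDD
    JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4)).cache();
    long evens = nums.filter(x -> x % 2 == 0).count(); // first action materializes the cache
    long odds = nums.filter(x -> x % 2 != 0).count(); // second action reuses it
    System.out.println(evens + " even, " + odds + " odd");
    sc.stop();
  }
}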
Example #7
  // Train word2vec on the corpus
  public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");

    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());

    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();

    // Variables to fill in during training
    // final JavaRDD<AtomicLong> sentenceWordsCountRDD;
    final JavaRDD<List<VocabWord>> vocabWordListRDD;
    // final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
    final VocabCache vocabCache;
    final JavaRDD<Long> sentenceCumSumCountRDD;

    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process every sentence and build a VocabCache, which is fed into a LookupCache
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline =
        new TextPipeline(corpusRDD.repartition(numPartitions), broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    // Get total word count and put into word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount() / numPartitions);

    // Two RDDs: (vocab word list) and (sentence count), already cached
    // sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    vocabWordListRDD = pipeline.getVocabWordListRDD();

    // Get the vocabCache and the broadcast vocabCache
    Broadcast<VocabCache> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    vocabCache = vocabCacheBroadcast.getValue();

    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Building the Huffman tree updates the code and point fields of each VocabWord in vocabCache
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();

    /////////////////////////////////////
    log.info("Training word2vec sentences ...");

    word2vecVarMap.put("vecNum", vocabCache.numWords());

    Map<Pair<Integer, Integer>, INDArray> s0 = new HashMap<>();
    for (int k = 0; k < K; k++) {
      for (int i = 0; i < vocabCache.numWords(); i++) {
        s0.put(new Pair<>(i, k), getRandomSyn0Vec(vectorLength));
      }
    }
    for (int i = vocabCache.numWords(); i < vocabCache.numWords() * 2 - 1; i++) {
      s0.put(new Pair<>(i, 0), Nd4j.zeros(1, vectorLength));
    }

    for (int i = 0; i < iterations; i++) {
      System.out.println("iteration: " + i);

      word2vecVarMap.put("alpha", alpha - (alpha - minAlpha) / iterations * i);
      word2vecVarMap.put("minAlpha", alpha - (alpha - minAlpha) / iterations * (i + 1));

      FlatMapFunction firstIterationFunction =
          new FirstIterationFunction(word2vecVarMap, expTable, sc.broadcast(s0));

      class MapPairFunction
          implements PairFunction<Map.Entry<Integer, INDArray>, Integer, INDArray> {
        public Tuple2<Integer, INDArray> call(Map.Entry<Integer, INDArray> pair) {
          return new Tuple2<>(pair.getKey(), pair.getValue());
        }
      }

      class Sum implements Function2<INDArray, INDArray, INDArray> {
        public INDArray call(INDArray a, INDArray b) {
          return a.add(b);
        }
      }

      // @SuppressWarnings("unchecked")
      JavaPairRDD<Pair<Integer, Integer>, INDArray> indexSyn0UpdateEntryRDD =
          vocabWordListRDD
              .mapPartitions(firstIterationFunction)
              .mapToPair(new MapPairFunction())
              .cache();
      Map<Pair<Integer, Integer>, Object> count = indexSyn0UpdateEntryRDD.countByKey();
      indexSyn0UpdateEntryRDD = indexSyn0UpdateEntryRDD.reduceByKey(new Sum());

      // Get all the syn0 updates into a list in driver
      List<Tuple2<Pair<Integer, Integer>, INDArray>> syn0UpdateEntries =
          indexSyn0UpdateEntryRDD.collect();

      // Updating syn0
      s0 = new HashMap<>();
      for (Tuple2<Pair<Integer, Integer>, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        int cc = Integer.parseInt(count.get(syn0UpdateEntry._1).toString());
        // int cc = 1;
        if (cc > 0) {
          INDArray tmp = Nd4j.zeros(1, vectorLength).addi(syn0UpdateEntry._2).divi(cc);
          s0.put(syn0UpdateEntry._1, tmp);
        }
      }
    }

    syn0 = Nd4j.zeros(vocabCache.numWords() * K, vectorLength);
    for (Map.Entry<Pair<Integer, Integer>, INDArray> ss : s0.entrySet()) {
      if (ss.getKey().getFirst() < vocabCache.numWords()) {
        syn0.getRow(ss.getKey().getSecond() * vocabCache.numWords() + ss.getKey().getFirst())
            .addi(ss.getValue());
      }
    }

    vocab = vocabCache;
    syn0.diviRowVector(syn0.norm2(1));

    BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false));
    for (int i = 0; i < syn0.rows(); i++) {
      String word = vocab.wordAtIndex(i % vocab.numWords());
      if (word == null) {
        continue;
      }
      word = word + "(" + i / vocab.numWords() + ")";
      StringBuilder sb = new StringBuilder();
      sb.append(word.replaceAll(" ", "_"));
      sb.append(" ");
      INDArray wordVector = syn0.getRow(i);
      for (int j = 0; j < wordVector.length(); j++) {
        sb.append(wordVector.getDouble(j));
        if (j < wordVector.length() - 1) {
          sb.append(" ");
        }
      }
      sb.append("\n");
      write.write(sb.toString());
    }
    write.flush();
    write.close();
  }
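The per-iteration alpha/minAlpha values above implement a linear learning-rate decay from alpha down to minAlpha across the iterations. Isolated as a sketch:

public class LearningRateDecaySketch {
  // Linear decay: rate at iteration i of n, falling from alpha to minAlpha
  static double rate(double alpha, double minAlpha, int iterations, int i) {
    return alpha - (alpha - minAlpha) / iterations * i;
  }

  public static void main(String[] args) {
    for (int i = 0; i <= 5; i++) {
      System.out.println("iteration " + i + ": " + rate(0.025, 0.0001, 5, i));
    }
  }
}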
  public static void main(String[] args) {

    SparkConf conf = new SparkConf().setAppName("SpatialJoinQuery Application");

    JavaSparkContext sc = new JavaSparkContext(conf);
    // Read the input CSV file holding the set of polygons into an RDD of
    // String objects
    JavaRDD<String> firstInputPoints =
        sc.textFile("hdfs://192.168.139.149:54310/harsh/spatialJoinFirstInput.csv");

    // Map the above RDD of strings to an RDD of rectangles for the first
    // input

    // Repeat the above process to initialize the RDD for the query
    // window
    JavaRDD<String> secondInputPoints =
        sc.textFile("hdfs://192.168.139.149:54310/harsh/spatialJoinSecondInput.csv");

    JavaRDD<Tuple2<Integer, ArrayList<Integer>>> joinQueryRDD = null;

    if (args[0].equalsIgnoreCase("rectangle")) {

      System.out.println(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>");
      System.out.println("inside>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>");

      final JavaRDD<Rectangle> firstInputRDD = firstInputPoints.map(mapInputStringToRectRDD());
      System.out.println(firstInputRDD.collect());

      // Map the query window to an RDD object
      final JavaRDD<Rectangle> secondInputRDD = secondInputPoints.map(mapInputStringToRectRDD());

      // Broadcast the first set of rectangles to each of the workers
      final Broadcast<List<Rectangle>> firstInput = sc.broadcast(firstInputRDD.collect());
      // Map the ID of each query-window rectangle in the second input to the
      // IDs of the first-input rectangles it contains.
      joinQueryRDD =
          secondInputRDD.map(
              new Function<Rectangle, Tuple2<Integer, ArrayList<Integer>>>() {
                public Tuple2<Integer, ArrayList<Integer>> call(Rectangle rectangle)
                    throws Exception {
                  // Get the list of rectangles from the broadcast first input.
                  List<Rectangle> firstInputCollection = firstInput.value();
                  ArrayList<Integer> secondInputIds = new ArrayList<Integer>();
                  // Collect the IDs of the first-input rectangles that lie
                  // inside this query-window rectangle.
                  for (Rectangle firstRects : firstInputCollection) {
                    if (rectangle.isRectangleinsideQueryWindow(firstRects)) {
                      secondInputIds.add(firstRects.getRectangleId());
                    }
                  }
                  // Build the result tuple for this query window and return it.
                  Tuple2<Integer, ArrayList<Integer>> resultList =
                      new Tuple2<Integer, ArrayList<Integer>>(
                          rectangle.getRectangleId(), secondInputIds);
                  return resultList;
                }
              });

    } else if (args[0].equalsIgnoreCase("point")) {

      final JavaRDD<Point> firstInputRDD =
          firstInputPoints.map(SpatialRangeQuery.mapInputStringToPointRDD());

      // Broadcast the first set of points to each of the workers
      final Broadcast<List<Point>> firstInput = sc.broadcast(firstInputRDD.collect());

      // Map the query window to an RDD object
      final JavaRDD<Rectangle> secondInputRDD = secondInputPoints.map(mapInputStringToRectRDD());

      joinQueryRDD =
          secondInputRDD.map(
              new Function<Rectangle, Tuple2<Integer, ArrayList<Integer>>>() {
                public Tuple2<Integer, ArrayList<Integer>> call(Rectangle rectangle)
                    throws Exception {
                  // Get the list of points from the broadcast first input.
                  List<Point> firstInputCollection = firstInput.getValue();
                  ArrayList<Integer> secondInputIds = new ArrayList<Integer>();
                  // Collect the IDs of the first-input points that lie inside
                  // this query-window rectangle.
                  for (Point point : firstInputCollection) {
                    if (point.isPointinsideQueryWindow(rectangle)) {
                      secondInputIds.add(point.getPointID());
                    }
                  }
                  // Build the result tuple for this query window and return it.
                  Tuple2<Integer, ArrayList<Integer>> resultList =
                      new Tuple2<Integer, ArrayList<Integer>>(
                          rectangle.getRectangleId(), secondInputIds);
                  return resultList;
                }
              });
    }

    JavaRDD<String> result =
        joinQueryRDD.map(
            new Function<Tuple2<Integer, ArrayList<Integer>>, String>() {
              public String call(Tuple2<Integer, ArrayList<Integer>> inputPoint) {

                Integer containingRect = inputPoint._1();
                ArrayList<Integer> containedRects = inputPoint._2();

                StringBuilder intermediateBuffer = new StringBuilder();

                intermediateBuffer.append(containingRect);

                for (Integer rects : containedRects) {
                  intermediateBuffer.append(", " + rects);
                }

                return intermediateBuffer.toString();
              }
            });

    result.coalesce(1).saveAsTextFile("hdfs://192.168.139.149:54310/harsh/jQueryResult.csv");

    sc.close();
  }
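Rectangle, Point, and their containment tests are not shown on this page. A hedged sketch of the axis-aligned containment check that the broadcast nested-loop join above relies on (hypothetical class, not the original):

public class RectangleSketch {
  final int id; // mirrors getRectangleId() in the original
  final double xMin, yMin, xMax, yMax;

  RectangleSketch(int id, double xMin, double yMin, double xMax, double yMax) {
    this.id = id;
    this.xMin = xMin;
    this.yMin = yMin;
    this.xMax = xMax;
    this.yMax = yMax;
  }

  // True if 'other' lies completely inside this query window
  boolean contains(RectangleSketch other) {
    return other.xMin >= xMin && other.xMax <= xMax
        && other.yMin >= yMin && other.yMax <= yMax;
  }

  public static void main(String[] args) {
    RectangleSketch window = new RectangleSketch(1, 0, 0, 10, 10);
    RectangleSketch inner = new RectangleSketch(2, 2, 2, 5, 5);
    System.out.println(window.contains(inner)); // true
  }
}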
 private void setup() {
   // Set up accumulators and broadcast stopwords
   this.sc = new JavaSparkContext(corpusRDD.context());
   this.wordFreqAcc = sc.accumulator(new Counter<String>(), new WordFreqAccumulator());
   this.stopWordBroadCast = sc.broadcast(stopWords);
 }
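WordFreqAccumulator itself is not shown. A hedged sketch of a custom Spark 1.x AccumulatorParam that merges per-partition word counts the way such an accumulator would, using a plain Map rather than the Counter class:

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.AccumulatorParam;

public class MapAccumulatorParam implements AccumulatorParam<Map<String, Long>> {
  // Usage (Spark 1.x):
  //   Accumulator<Map<String, Long>> acc =
  //       sc.accumulator(new HashMap<String, Long>(), new MapAccumulatorParam());

  @Override
  public Map<String, Long> zero(Map<String, Long> initial) {
    return new HashMap<>();
  }

  @Override
  public Map<String, Long> addInPlace(Map<String, Long> a, Map<String, Long> b) {
    // Merge b into a, summing counts for keys present in both
    b.forEach((k, v) -> a.merge(k, v, Long::sum));
    return a;
  }

  @Override
  public Map<String, Long> addAccumulator(Map<String, Long> a, Map<String, Long> b) {
    return addInPlace(a, b);
  }
}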