public static void trainLibSVMs() throws Exception {
    System.out.println("train LibSVMs.");

    Indexer<String> featureIndexer = IOUtils.readIndexer(WORD_INDEXER_FILE);
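    // Binary labels, indexed in insertion order: 0 = Non-Relevant, 1 = Relevant.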
    Indexer<String> labelIndexer = new Indexer<String>();
    labelIndexer.add("Non-Relevant");
    labelIndexer.add("Relevant");

    List<SparseVector> trainData = SparseVector.readList(TRAIN_DATA_FILE);
    List<SparseVector> testData = SparseVector.readList(TEST_DATA_FILE);

    Collections.shuffle(trainData);
    Collections.shuffle(testData);

    LibSvmTrainer trainer = new LibSvmTrainer();
    LibSvmWrapper wrapper = trainer.train(labelIndexer, featureIndexer, trainData);
    wrapper.evalute(testData);
  }

  public static void generateData2() throws Exception {
    System.out.println("generate data.");

    List<BaseQuery> bqs = QueryReader.readTrecCdsQueries(MIRPath.TREC_CDS_QUERY_2014_FILE);

    Map<Integer, BaseQuery> queryMap = new HashMap<Integer, BaseQuery>();

    for (BaseQuery bq : bqs) {
      int qid = Integer.parseInt(bq.getId());
      queryMap.put(qid, bq);
    }

    Indexer<String> wordIndexer = new Indexer<String>();
    Indexer<String> typeIndexer = new Indexer<String>();

    List<SparseVector> data = new ArrayList<SparseVector>();

    TextFileReader reader = new TextFileReader(MIRPath.TREC_CDS_QUERY_DOC_FILE);
    while (reader.hasNext()) {
      List<String> lines = reader.getNextLines();
      List<SparseVector> svs = new ArrayList<SparseVector>();
      int qid = -1;
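      // Each block's first line carries the query id; the following lines
      // carry per-document relevance grades.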

      for (int i = 0; i < lines.size(); i++) {
        String line = lines.get(i);
        String[] parts = line.split("\t");

        double relevance = -1;

        if (i == 0) {
          qid = Integer.parseInt(parts[1]);
        } else {
          relevance = Double.parseDouble(parts[1]);
        }

        StrCounter c = new StrCounter();
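        // parts[2] holds space-separated "word:count" tokens.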
        String[] toks = parts[2].split(" ");
        for (int j = 0; j < toks.length; j++) {
          String[] two = StrUtils.split2Two(":", toks[j]);
          c.incrementCount(two[0], Double.parseDouble(two[1]));
        }

        SparseVector sv = VectorUtils.toSparseVector(c, wordIndexer, true);

        if (i > 0) {
          sv.setLabel((int) relevance);
        }

        svs.add(sv);
      }

      TrecCdsQuery tcq = (TrecCdsQuery) queryMap.get(qid);
      String type = tcq.getType();
      int typeId = typeIndexer.getIndex(type);
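      // TREC CDS 2014 queries come in three types (diagnosis, test,
      // treatment), so typeId falls in 0-2.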

      // SparseVector q = svs.get(0);

      for (int i = 1; i < svs.size(); i++) {
        SparseVector d = svs.get(i);
        double relevance = d.label();

        // Relevant documents take the query type id as their label; all
        // non-relevant documents share label 3.
        if (relevance > 0) {
          d.setLabel(typeId);
        } else {
          d.setLabel(3);
        }

        // SparseVector qd = VectorMath.add(q, d);
        // qd.setLabel(d.label());

        data.add(d);
      }
    }
    reader.close();

    SparseVector.write(DATA_FILE, data);
    IOUtils.write(WORD_INDEXER_FILE, wordIndexer);
  }

  public static void trainLibLinear() throws Exception {
    System.out.println("train LibLinear.");

    Indexer<String> featureIndexer = IOUtils.readIndexer(WORD_INDEXER_FILE);

    List<SparseVector> trainData = SparseVector.readList(TRAIN_DATA_FILE);
    List<SparseVector> testData = SparseVector.readList(TEST_DATA_FILE);

    Collections.shuffle(trainData);
    Collections.shuffle(testData);

    // List[] lists = new List[] { trainData, testData };
    //
    // for (int i = 0; i < lists.length; i++) {
    // List<SparseVector> list = lists[i];
    // for (int j = 0; j < list.size(); j++) {
    // SparseVector sv = list.get(j);
    // if (sv.label() > 0) {
    // sv.setLabel(1);
    // }
    // }
    // }

    Problem prob = new Problem();
    prob.l = trainData.size();
    prob.n = featureIndexer.size() + 1;
    prob.y = new double[prob.l];
    prob.x = new Feature[prob.l][];
    prob.bias = -1;
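    // A bias of -1 disables the bias feature, so the branch below is inert
    // unless the value above is changed.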

    if (prob.bias >= 0) {
      prob.n++;
    }

    for (int i = 0; i < trainData.size(); i++) {
      SparseVector x = trainData.get(i);

      Feature[] input = new Feature[prob.bias >= 0 ? x.size() + 1 : x.size()];

      for (int j = 0; j < x.size(); j++) {
        // liblinear feature indices are 1-based, so shift the 0-based indexer
        // ids by one (the original shifted twice, leaving index 1 unused).
        int index = x.indexAtLoc(j) + 1;
        double value = x.valueAtLoc(j);

        assert index >= 1;

        input[j] = new FeatureNode(index, value);
      }

      if (prob.bias >= 0) {
        input[input.length - 1] = new FeatureNode(prob.n, prob.bias);
      }

      prob.x[i] = input;
      prob.y[i] = x.label();
    }

    Model model = Linear.train(prob, getSVMParamter());

    CounterMap<Integer, Integer> cm = new CounterMap<Integer, Integer>();
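    // Confusion matrix over the test set: gold label -> predicted label.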

    for (int i = 0; i < testData.size(); i++) {
      SparseVector sv = testData.get(i);
      Feature[] input = new Feature[sv.size()];
      for (int j = 0; j < sv.size(); j++) {
        int index = sv.indexAtLoc(j) + 1; // 1-based, matching the training features
        double value = sv.valueAtLoc(j);
        input[j] = new FeatureNode(index, value);
      }

      double[] dec_values = new double[model.getNrClass()];
      Linear.predictValues(model, input, dec_values);
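      // Pick the class with the largest one-vs-rest decision value.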
      int max_id = ArrayMath.argmax(dec_values);
      int pred = model.getLabels()[max_id];
      int answer = sv.label();

      cm.incrementCount(answer, pred, 1);
    }

    System.out.println(cm);

    model.save(new File(MODEL_FILE));
  }
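
  // getSVMParamter() is called above but not shown in this example. A minimal
  // sketch, assuming the liblinear-java Parameter API; the solver type, C, and
  // stopping tolerance here are placeholder assumptions, not the original
  // settings.
  private static Parameter getSVMParamter() {
    // L2-regularized L2-loss support vector classification, C = 1.0, eps = 0.01.
    return new Parameter(SolverType.L2R_L2LOSS_SVC, 1.0, 0.01);
  }
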
Example #4
  public void collect2() throws Exception {
    String[] queryFileNames = MIRPath.QueryFileNames;
    String[] indexDirNames = MIRPath.IndexDirNames;
    String[] relDataFileNames = MIRPath.RelevanceFileNames;
    String[] docMapFileNames = MIRPath.DocIdMapFileNames;
    String[] queryDocFileNames = MIRPath.QueryDocFileNames;

    IndexSearcher[] indexSearchers = SearcherUtils.getIndexSearchers(indexDirNames);

    Analyzer analyzer = MedicalEnglishAnalyzer.getAnalyzer();

    for (int i = 0; i < queryFileNames.length; i++) {
      List<BaseQuery> bqs = new ArrayList<BaseQuery>();
      CounterMap<String, String> queryRels = new CounterMap<String, String>();

      // Each collection has its own query and relevance-judgment formats.
      if (i == 0) {
        bqs = QueryReader.readTrecCdsQueries(queryFileNames[i]);
        queryRels = RelevanceReader.readTrecCdsRelevances(relDataFileNames[i]);
      } else if (i == 1) {
        bqs = QueryReader.readClefEHealthQueries(queryFileNames[i]);
        queryRels = RelevanceReader.readClefEHealthRelevances(relDataFileNames[i]);
      } else if (i == 2) {
        bqs = QueryReader.readOhsumedQueries(queryFileNames[i]);
        queryRels = RelevanceReader.readOhsumedRelevances(relDataFileNames[i]);
      } else if (i == 3) {
        bqs = QueryReader.readTrecGenomicsQueries(queryFileNames[i]);
        queryRels = RelevanceReader.readTrecGenomicsRelevances(relDataFileNames[i]);
      }

      List<Counter<String>> qs = new ArrayList<Counter<String>>();

      for (int j = 0; j < bqs.size(); j++) {
        BaseQuery bq = bqs.get(j);
        qs.add(AnalyzerUtils.getWordCounts(bq.getSearchText(), analyzer));
      }

      BidMap<String, String> docIdMap = DocumentIdMapper.readDocumentIdMap(docMapFileNames[i]);

      // queryRelevances = RelevanceReader.filter(queryRelevances, docIdMap);

      // baseQueries = QueryReader.filter(baseQueries, queryRelevances);

      List<SparseVector> docRelData =
          DocumentIdMapper.mapDocIdsToIndexIds(bqs, queryRels, docIdMap);

      IndexReader indexReader = indexSearchers[i].getIndexReader();

      if (bqs.size() != docRelData.size()) {
        throw new Exception("query count does not match relevance-data count");
      }

      TextFileWriter writer = new TextFileWriter(queryDocFileNames[i]);

      for (int j = 0; j < bqs.size(); j++) {
        BaseQuery bq = bqs.get(j);
        SparseVector docRels = docRelData.get(j);

        Indexer<String> wordIndexer = new Indexer<String>();

        Counter<String> qwcs = qs.get(j);

        SparseVector q = VectorUtils.toSparseVector(qwcs, wordIndexer, true);

        {
          // Reweight the query's term counts by TF-IDF using collection-wide
          // document frequencies.
          SparseVector docFreqs =
              VectorUtils.toSparseVector(
                  WordCountBox.getDocFreqs(indexReader, IndexFieldName.CONTENT, qwcs.keySet()),
                  wordIndexer,
                  true);
          computeTFIDFs(q, docFreqs, indexReader.maxDoc());
        }

        WordCountBox wcb = WordCountBox.getWordCountBox(indexReader, docRels, wordIndexer);
        SparseMatrix sm = wcb.getDocWordCounts();
        SparseVector docFreqs = wcb.getCollDocFreqs();
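        // sm holds per-document word counts for the judged documents;
        // docFreqs holds collection-wide document frequencies.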

        // Reweight each judged document's word counts by TF-IDF as well.
        for (int k = 0; k < sm.rowSize(); k++) {
          SparseVector sv = sm.vectorAtRowLoc(k);
          computeTFIDFs(sv, docFreqs, wcb.getNumDocsInCollection());
        }

        writer.write(String.format("#Query\t%d\t%s\n", j + 1, bq.toString()));
        writer.write(
            String.format("#Query Words\t%s\n", toString(VectorUtils.toCounter(q, wordIndexer))));

        docRels.sortByValue();
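        // Visit documents sorted by relevance grade; unjudged entries
        // (grade 0) are skipped below.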

        for (int k = 0; k < docRels.size(); k++) {
          int docId = docRels.indexAtLoc(k);
          double rel = docRels.valueAtLoc(k);

          if (rel == 0) {
            continue;
          }

          List<Integer> ws = wcb.getDocWords().get(docId);

          StringBuilder sb = new StringBuilder();

          for (int l = 0; l < ws.size(); l++) {
            int w = ws.get(l);

            // Only write out document words that also appear in the query;
            // the Mark column is always 1 for the words kept.
            if (q.location(w) > -1) {
              sb.append(String.format("%d\t%s\t%d\n", l + 1, wordIndexer.getObject(w), 1));
            }
          }

          SparseVector sv = sm.rowAlways(docId);

          if (sv.size() == 0) {
            continue;
          }

          writer.write(String.format("DOC-ID\t%d\nRelevance\t%d\n", docId, (int) rel));
          writer.write(String.format("Loc\tWord\tMark\n%s\n", sb.toString()));
        }
        writer.write("\n");
      }
    }
  }
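
  // computeTFIDFs() is used above but not shown in this example. A minimal
  // sketch of a standard tf * idf reweighting; the value(int) and
  // setAtLoc(int, double) accessors are assumptions about the SparseVector
  // API, and the add-one smoothing is a placeholder choice.
  private static void computeTFIDFs(SparseVector wordCounts, SparseVector docFreqs, int numDocs) {
    for (int i = 0; i < wordCounts.size(); i++) {
      int w = wordCounts.indexAtLoc(i);
      double tf = wordCounts.valueAtLoc(i);
      double df = docFreqs.value(w); // document frequency of word w
      // The +1 terms guard against words with zero document frequency.
      double idf = Math.log((numDocs + 1.0) / (df + 1.0));
      wordCounts.setAtLoc(i, tf * idf);
    }
  }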