/**
 * Trains a LibSVM relevance classifier from pre-built sparse-vector data and
 * reports its performance on the held-out test split.
 *
 * <p>Labels are fixed to a binary scheme: index 0 = "Non-Relevant",
 * index 1 = "Relevant". The word (feature) indexer is loaded from
 * {@code WORD_INDEXER_FILE}; train/test vectors come from
 * {@code TRAIN_DATA_FILE} / {@code TEST_DATA_FILE}.
 *
 * @throws Exception if any of the data or indexer files cannot be read
 */
public static void trainLibSVMs() throws Exception {
	System.out.println("train LibSVMs.");

	// Binary label space: 0 = Non-Relevant, 1 = Relevant (insertion order matters).
	Indexer<String> labels = new Indexer<String>();
	labels.add("Non-Relevant");
	labels.add("Relevant");

	Indexer<String> features = IOUtils.readIndexer(WORD_INDEXER_FILE);

	List<SparseVector> train = SparseVector.readList(TRAIN_DATA_FILE);
	List<SparseVector> test = SparseVector.readList(TEST_DATA_FILE);

	// Shuffle both splits so instance order does not bias training/evaluation.
	Collections.shuffle(train);
	Collections.shuffle(test);

	LibSvmWrapper wrapper = new LibSvmTrainer().train(labels, features, train);
	// "evalute" is the (misspelled) method name on LibSvmWrapper's public API.
	wrapper.evalute(test);
}
public static void generateData2() throws Exception { System.out.println("generate data."); List<BaseQuery> bqs = QueryReader.readTrecCdsQueries(MIRPath.TREC_CDS_QUERY_2014_FILE); Map<Integer, BaseQuery> queryMap = new HashMap<Integer, BaseQuery>(); for (BaseQuery bq : bqs) { int qid = Integer.parseInt(bq.getId()); queryMap.put(qid, bq); } Indexer<String> wordIndexer = new Indexer<String>(); Indexer<String> typeIndexer = new Indexer<String>(); TextFileWriter writer = new TextFileWriter(DATA_FILE); List<SparseVector> data = new ArrayList<SparseVector>(); TextFileReader reader = new TextFileReader(MIRPath.TREC_CDS_QUERY_DOC_FILE); while (reader.hasNext()) { List<String> lines = reader.getNextLines(); List<SparseVector> svs = new ArrayList<SparseVector>(); int qid = -1; for (int i = 0; i < lines.size(); i++) { String line = lines.get(i); String[] parts = line.split("\t"); double relevance = -1; if (i == 0) { qid = Integer.parseInt(parts[1]); } else { relevance = Double.parseDouble(parts[1]); } StrCounter c = new StrCounter(); String[] toks = parts[2].split(" "); for (int j = 0; j < toks.length; j++) { String[] two = StrUtils.split2Two(":", toks[j]); c.incrementCount(two[0], Double.parseDouble(two[1])); } SparseVector sv = VectorUtils.toSparseVector(c, wordIndexer, true); if (i > 0) { sv.setLabel((int) relevance); } svs.add(sv); } TrecCdsQuery tcq = (TrecCdsQuery) queryMap.get(qid); String type = tcq.getType(); int typeId = typeIndexer.getIndex(type); // SparseVector q = svs.get(0); for (int i = 1; i < svs.size(); i++) { SparseVector d = svs.get(i); double relevance = d.label(); if (relevance > 0) { d.setLabel(typeId); } else { d.setLabel(3); } // SparseVector qd = VectorMath.add(q, d); // qd.setLabel(d.label()); data.add(d); } } reader.close(); writer.close(); SparseVector.write(DATA_FILE, data); IOUtils.write(WORD_INDEXER_FILE, wordIndexer); }
/**
 * Trains a liblinear model on the pre-built sparse-vector training split,
 * evaluates it on the test split (printing a confusion CounterMap of
 * answer -> predicted counts), and saves the model to {@code MODEL_FILE}.
 *
 * @throws Exception if the indexer/data files cannot be read or the model
 *         cannot be saved
 */
public static void trainLibLinear() throws Exception {
	System.out.println("train SVMs.");
	Indexer<String> featureIndexer = IOUtils.readIndexer(WORD_INDEXER_FILE);
	List<SparseVector> trainData = SparseVector.readList(TRAIN_DATA_FILE);
	List<SparseVector> testData = SparseVector.readList(TEST_DATA_FILE);
	// Shuffle so instance order does not bias training/evaluation.
	Collections.shuffle(trainData);
	Collections.shuffle(testData);
	// List[] lists = new List[] { trainData, testData };
	//
	// for (int i = 0; i < lists.length; i++) {
	// List<SparseVector> list = lists[i];
	// for (int j = 0; j < list.size(); j++) {
	// SparseVector sv = list.get(j);
	// if (sv.label() > 0) {
	// sv.setLabel(1);
	// }
	// }
	// }
	// liblinear problem setup: l = #instances, n = feature-space dimension.
	Problem prob = new Problem();
	prob.l = trainData.size();
	prob.n = featureIndexer.size() + 1;
	prob.y = new double[prob.l];
	prob.x = new Feature[prob.l][];
	// bias < 0 disables the bias term, so the two bias branches below are dead.
	prob.bias = -1;
	if (prob.bias >= 0) {
		prob.n++;
	}
	for (int i = 0; i < trainData.size(); i++) {
		SparseVector x = trainData.get(i);
		Feature[] input = new Feature[prob.bias > 0 ? x.size() + 1 : x.size()];
		for (int j = 0; j < x.size(); j++) {
			// NOTE(review): the zero-based index gets +1 here AND +1 again in the
			// FeatureNode constructor below, so features land at zeroBased + 2.
			// liblinear only requires 1-based indices; the double shift looks
			// accidental, but train and predict use the same shift (and it stays
			// within prob.n = size + 1), so behavior is internally consistent.
			// Left untouched because MODEL_FILE may be loaded elsewhere with this
			// exact feature mapping — confirm before changing.
			int index = x.indexAtLoc(j) + 1;
			double value = x.valueAtLoc(j);
			assert index >= 0;
			input[j] = new FeatureNode(index + 1, value);
		}
		if (prob.bias >= 0) {
			input[input.length - 1] = new FeatureNode(prob.n, prob.bias);
		}
		prob.x[i] = input;
		prob.y[i] = x.label();
	}
	Model model = Linear.train(prob, getSVMParamter());
	// Confusion counts: rows are gold labels, columns are predictions.
	CounterMap<Integer, Integer> cm = new CounterMap<Integer, Integer>();
	for (int i = 0; i < testData.size(); i++) {
		SparseVector sv = testData.get(i);
		Feature[] input = new Feature[sv.size()];
		for (int j = 0; j < sv.size(); j++) {
			// Same double +1 shift as in training — must stay in sync with it.
			int index = sv.indexAtLoc(j) + 1;
			double value = sv.valueAtLoc(j);
			input[j] = new FeatureNode(index + 1, value);
		}
		// Predict by taking the class with the largest decision value.
		double[] dec_values = new double[model.getNrClass()];
		Linear.predictValues(model, input, dec_values);
		int max_id = ArrayMath.argmax(dec_values);
		int pred = model.getLabels()[max_id];
		int answer = sv.label();
		cm.incrementCount(answer, pred, 1);
	}
	System.out.println(cm);
	model.save(new File(MODEL_FILE));
}
/**
 * For each of the four IR collections (TREC CDS, CLEF eHealth, OHSUMED,
 * TREC Genomics), reads its queries and relevance judgments, maps judged
 * document ids to Lucene index ids, computes TF-IDF-weighted query and
 * document vectors, and writes a per-collection report file listing each
 * query, its weighted terms, and — for every relevant document — the
 * positions of query terms occurring in it.
 *
 * @throws Exception if a collection's query and relevance lists disagree in
 *         size, or on any read/write failure
 */
public void collect2() throws Exception {
	String[] queryFileNames = MIRPath.QueryFileNames;
	String[] indexDirNames = MIRPath.IndexDirNames;
	String[] relDataFileNames = MIRPath.RelevanceFileNames;
	String[] docMapFileNames = MIRPath.DocIdMapFileNames;
	String[] queryDocFileNames = MIRPath.QueryDocFileNames;

	IndexSearcher[] indexSearchers = SearcherUtils.getIndexSearchers(indexDirNames);
	Analyzer analyzer = MedicalEnglishAnalyzer.getAnalyzer();

	for (int i = 0; i < queryFileNames.length; i++) {
		List<BaseQuery> bqs = new ArrayList<BaseQuery>();
		CounterMap<String, String> queryRels = new CounterMap<String, String>();

		// Each collection ships with its own query/relevance file format.
		if (i == 0) {
			bqs = QueryReader.readTrecCdsQueries(queryFileNames[i]);
			queryRels = RelevanceReader.readTrecCdsRelevances(relDataFileNames[i]);
		} else if (i == 1) {
			bqs = QueryReader.readClefEHealthQueries(queryFileNames[i]);
			queryRels = RelevanceReader.readClefEHealthRelevances(relDataFileNames[i]);
		} else if (i == 2) {
			bqs = QueryReader.readOhsumedQueries(queryFileNames[i]);
			queryRels = RelevanceReader.readOhsumedRelevances(relDataFileNames[i]);
		} else if (i == 3) {
			bqs = QueryReader.readTrecGenomicsQueries(queryFileNames[i]);
			queryRels = RelevanceReader.readTrecGenomicsRelevances(relDataFileNames[i]);
		}

		// Analyzer-tokenized word counts for each query's search text.
		List<Counter<String>> qs = new ArrayList<Counter<String>>();
		for (int j = 0; j < bqs.size(); j++) {
			qs.add(AnalyzerUtils.getWordCounts(bqs.get(j).getSearchText(), analyzer));
		}

		BidMap<String, String> docIdMap = DocumentIdMapper.readDocumentIdMap(docMapFileNames[i]);
		// One SparseVector of (lucene doc id -> relevance) per query, aligned with bqs.
		List<SparseVector> docRelData = DocumentIdMapper.mapDocIdsToIndexIds(bqs, queryRels, docIdMap);
		IndexReader indexReader = indexSearchers[i].getIndexReader();

		if (bqs.size() != docRelData.size()) {
			throw new Exception();
		}

		// BUG FIX: the original never closed this writer — resource leak, and
		// buffered report output could be lost. Closed via try/finally now.
		TextFileWriter writer = new TextFileWriter(queryDocFileNames[i]);
		try {
			for (int j = 0; j < bqs.size(); j++) {
				BaseQuery bq = bqs.get(j);
				SparseVector docRels = docRelData.get(j);

				// Per-query word indexer: query and documents share one vocab space.
				Indexer<String> wordIndexer = new Indexer<String>();
				SparseVector q = VectorUtils.toSparseVector(qs.get(j), wordIndexer, true);
				{
					SparseVector docFreqs = VectorUtils.toSparseVector(
							WordCountBox.getDocFreqs(indexReader, IndexFieldName.CONTENT, qs.get(j).keySet()),
							wordIndexer, true);
					computeTFIDFs(q, docFreqs, indexReader.maxDoc());
				}

				WordCountBox wcb = WordCountBox.getWordCountBox(indexReader, docRels, wordIndexer);
				SparseMatrix sm = wcb.getDocWordCounts();
				SparseVector docFreqs = wcb.getCollDocFreqs();
				// Re-weight each judged document's word counts into TF-IDF in place.
				for (int k = 0; k < sm.rowSize(); k++) {
					computeTFIDFs(sm.vectorAtRowLoc(k), docFreqs, wcb.getNumDocsInCollection());
				}

				writer.write(String.format("#Query\t%d\t%s\n", j + 1, bq.toString()));
				writer.write(
						String.format("#Query Words\t%s\n", toString(VectorUtils.toCounter(q, wordIndexer))));

				docRels.sortByValue();
				for (int k = 0; k < docRels.size(); k++) {
					int docId = docRels.indexAtLoc(k);
					double rel = docRels.valueAtLoc(k);
					if (rel == 0) {
						continue;
					}

					// List the document positions at which query terms occur.
					List<Integer> ws = wcb.getDocWords().get(docId);
					StringBuffer sb = new StringBuffer();
					for (int l = 0; l < ws.size(); l++) {
						int w = ws.get(l);
						if (q.location(w) > -1) {
							// Mark column is always "1" here (only matches are printed).
							sb.append(String.format("%d\t%s\t%s\n", l + 1, wordIndexer.getObject(w), "1"));
						}
					}

					SparseVector sv = sm.rowAlways(docId);
					if (sv.size() == 0) {
						continue;
					}

					writer.write(String.format("DOC-ID\t%d\nRelevance\t%d\n", docId, (int) rel));
					writer.write(String.format("Loc\tWord\tMark\n%s\n", sb.toString()));
				}
				writer.write("\n");
			}
		} finally {
			writer.close();
		}
	}
}