/**
 * Creates a multi-class linear inferencer that scores each class on a fixed pool of worker threads.
 *
 * @param featureGen converts an instance into a sparse feature vector
 * @param alphabet   label alphabet; its size fixes the number of classes
 * @param tree       optional label hierarchy; when {@code null}, the leaf labels are all labels
 *                   of {@code alphabet} ({@code alphabet.toSet()}), otherwise {@code tree.getLeafs()}
 * @param n          number of threads in the fixed-size worker pool
 */
public MultiLinearMax(Generator featureGen, LabelAlphabet alphabet, Tree tree, int n) {
    this.featureGen = featureGen;
    this.alphabet = alphabet;
    this.numThread = n;
    this.tree = tree;
    this.pool = Executors.newFixedThreadPool(this.numThread);
    this.numClass = alphabet.size();
    // Without a hierarchy every label is a leaf; otherwise only the tree's leaves are candidates.
    if (tree != null) {
        this.leafs = tree.getLeafs();
    } else {
        this.leafs = alphabet.toSet();
    }
}
public Results getBest(Instance inst, int n) { Integer target = null; if (isUseTarget) target = (Integer) inst.getTarget(); SparseVector fv = featureGen.getVector(inst); // 每个类对应的内积 double[] sw = new double[alphabet.size()]; Callable<Double>[] c = new Multiplesolve[numClass]; Future<Double>[] f = new Future[numClass]; for (int i = 0; i < numClass; i++) { c[i] = new Multiplesolve(fv, i); f[i] = pool.submit(c[i]); } // 执行任务并获取Future对象 for (int i = 0; i < numClass; i++) { try { sw[i] = (Double) f[i].get(); } catch (Exception e) { e.printStackTrace(); return null; } } Results res = new Results(n); if (target != null) { res.buildOracle(); } Iterator<Integer> it = leafs.iterator(); while (it.hasNext()) { double score = 0.0; Integer i = it.next(); if (tree != null) { // 计算含层次信息的内积 ArrayList<Integer> anc = tree.getPath(i); for (int j = 0; j < anc.size(); j++) { score += sw[anc.get(j)]; } } else { score = sw[i]; } // 给定目标范围是,只计算目标范围的值 if (target != null && target.equals(i)) { res.addOracle(score, i); } else { res.addPred(score, i); } } return res; }