/**
 * Replaces the instance data with the list of word n-grams for every configured gram size.
 *
 * <p>A {@code String} payload is tokenized on runs of whitespace; a {@code List} payload is
 * taken to already hold {@code String} tokens; any other payload yields no tokens. Each n-gram
 * is its tokens joined by single spaces and interned. Sizes that are non-positive or longer than
 * the token sequence are skipped.
 *
 * @param inst instance whose data is read and then overwritten with an {@code ArrayList<String>}
 */
@Override
public void addThruPipe(Instance inst) {
    Object payload = inst.getData();
    List<String> words;
    if (payload instanceof String) {
        words = Arrays.asList(((String) payload).split("\\s+"));
    } else if (payload instanceof List) {
        words = (List<String>) payload;
    } else {
        words = Collections.emptyList();
    }

    ArrayList<String> grams = new ArrayList<String>();
    StringBuffer sb = new StringBuffer();
    for (int size : gramSizes) {
        if (size <= 0 || size > words.size()) {
            continue;
        }
        int lastStart = words.size() - size;
        for (int start = 0; start <= lastStart; start++) {
            // Reuse one buffer: first token, then " token" for each remaining position.
            sb.setLength(0);
            sb.append(words.get(start));
            for (int off = 1; off < size; off++) {
                sb.append(' ').append(words.get(start + off));
            }
            grams.add(sb.toString().intern());
        }
    }
    inst.setData(grams);
}
/** * 构造并初始化网格 * * @param carrier 样本实例 * @return 推理网格 */ protected Node[][] initialLattice(Instance carrier) { int[][] data = (int[][]) carrier.getData(); int length = carrier.length(); Node[][] lattice = new Node[length][]; for (int l = 0; l < length; l++) { lattice[l] = new Node[ysize]; for (int c = 0; c < ysize; c++) { lattice[l][c] = new Node(ysize); for (int i = 0; i < orders.length; i++) { if (data[l][i] == -1 || data[l][i] >= weights.length) // TODO: xpqiu 2013.2.1 continue; if (orders[i] == 0) { lattice[l][c].score += weights[data[l][i] + c]; } else if (orders[i] == 1) { int offset = c; for (int p = 0; p < ysize; p++) { // weights对应trans(c,p)的按行展开 lattice[l][c].trans[p] += weights[data[l][i] + offset]; offset += ysize; } } } } } return lattice; }
/** * 用当前模型在测试集上进行测试 输出正确率 * * @param testSet */ public void test(InstanceSet testSet) { double err = 0; double errorAll = 0; int total = 0; for (int i = 0; i < testSet.size(); i++) { Instance inst = testSet.getInstance(i); total += ((int[]) inst.getTarget()).length; Results pred = (Results) msolver.getBest(inst, 1); double l = loss.calc(pred.getPredAt(0), inst.getTarget()); if (l > 0) { // 预测错误 errorAll += 1.0; err += l; } } if (!simpleOutput) { System.out.print("Test:\t"); System.out.print(total - err); System.out.print('/'); System.out.print(total); System.out.print("\tTag acc:"); } else { System.out.print('\t'); } System.out.print(1 - err / total); if (!simpleOutput) { System.out.print("\tSentence acc:"); System.out.println(1 - errorAll / testSet.size()); } System.out.println(); }
@Override public void addThruPipe(Instance inst) throws Exception { String str = (String) inst.getSource(); BinarySparseVector sv = (BinarySparseVector) inst.getData(); List<RETemplate> templates = new ArrayList<RETemplate>(); for (int i = 0; i < group.size(); i++) { RETemplate qt = group.get(i); float w = qt.matches(str); if (w > 0) { // System.out.println(qt.comment); int id = features.lookupIndex("template:" + qt.comment); sv.put(id); } } }
public Results getBest(Instance inst, int n) { Integer target = null; if (isUseTarget) target = (Integer) inst.getTarget(); SparseVector fv = featureGen.getVector(inst); // 每个类对应的内积 double[] sw = new double[alphabet.size()]; Callable<Double>[] c = new Multiplesolve[numClass]; Future<Double>[] f = new Future[numClass]; for (int i = 0; i < numClass; i++) { c[i] = new Multiplesolve(fv, i); f[i] = pool.submit(c[i]); } // 执行任务并获取Future对象 for (int i = 0; i < numClass; i++) { try { sw[i] = (Double) f[i].get(); } catch (Exception e) { e.printStackTrace(); return null; } } Results res = new Results(n); if (target != null) { res.buildOracle(); } Iterator<Integer> it = leafs.iterator(); while (it.hasNext()) { double score = 0.0; Integer i = it.next(); if (tree != null) { // 计算含层次信息的内积 ArrayList<Integer> anc = tree.getPath(i); for (int j = 0; j < anc.size(); j++) { score += sw[anc.get(j)]; } } else { score = sw[i]; } // 给定目标范围是,只计算目标范围的值 if (target != null && target.equals(i)) { res.addOracle(score, i); } else { res.addPred(score, i); } } return res; }
/**
 * Counts the cliques on which the predicted sequence differs from the reference sequence.
 *
 * <p>Position 0 differs when its labels differ; every later position {@code p} differs when
 * either the labels at {@code p - 1} or at {@code p} differ. {@code diffClique} is called for
 * each differing position. The data, reference and prediction arrays are stored in the
 * {@code data}, {@code golds} and {@code preds} fields.
 *
 * @param inst instance whose data is read as {@code int[][]}
 * @param weights weight vector passed through to {@code diffClique}
 * @param targets reference labels as {@code int[]}, or {@code null} to use the instance target
 * @param predicts predicted labels as {@code int[]}
 * @return number of differing cliques between the prediction and the reference sequence
 */
@Override
protected int diff(Instance inst, float[] weights, Object targets, Object predicts) {
    data = (int[][]) inst.getData();
    golds = (int[]) (targets == null ? inst.getTarget() : targets);
    preds = (int[]) predicts;

    int mismatches = 0;
    if (golds[0] != preds[0]) {
        mismatches++;
        diffClique(weights, 0);
    }
    for (int pos = 1; pos < data.length; pos++) {
        boolean cliqueMatches = golds[pos - 1] == preds[pos - 1] && golds[pos] == preds[pos];
        if (cliqueMatches) {
            continue;
        }
        mismatches++;
        diffClique(weights, pos);
    }
    return mismatches;
}
/**
 * Builds one {@code Instance} for every pair of entities taken from {@code ll} and stores
 * them in a fresh {@code list}.
 *
 * <p>Repeatedly polls the head of {@code ll} until it is empty, pairing each polled entity with
 * every entity still remaining. Each pair's features are rendered via {@code intArrayToString},
 * split on runs of tabs/whitespace, and used as the instance data; the {@code EntityGroup} is
 * set as the instance source.
 *
 * @throws Exception as declared by the original signature
 */
private void dothis() throws Exception {
    list = new LinkedList<Instance>();
    while (ll.size() > 0) {
        Entity head = (Entity) ll.poll();
        for (Iterator<Entity> rest = ll.iterator(); rest.hasNext();) {
            Entity other = (Entity) rest.next();
            EntityGroup pair = new EntityGroup(head, other);
            FeatureGeter getter = new FeatureGeter(pair);
            String rendered = this.intArrayToString(getter.getFeatrue());
            List<String> featureTokens = Arrays.asList(rendered.split("\\t+|\\s+"));
            Instance instance = new Instance(featureTokens, null);
            instance.setSource(pair);
            list.add(instance);
        }
    }
}
/**
 * Marks dictionary matches over row 0 of the instance data and stores the result with
 * {@code instance.setDicData}.
 *
 * <p>Scans {@code data[0]} left to right. At each position where {@code indexLen} items still
 * fit, the next {@code indexLen} items are looked up in {@code dict}. For each value returned
 * by the index, matches are marked in {@code dicData} via {@code label}/{@code check}. Unless
 * {@code mutiple} is set, the scan jumps ahead after the first match. Finally, every column of
 * each row for which {@code hasWay} is true is incremented.
 *
 * @param instance instance whose data is read as {@code String[][]}; only row 0 is scanned
 * @throws Exception as declared by the pipe contract
 */
public void addThruPipe(Instance instance) throws Exception {
    String[][] data = (String[][]) instance.getData();
    int length = data[0].length;
    // One row per position of data[0], one column per label.
    int[][] dicData = new int[length][labels.size()];
    int indexLen = dict.getIndexLen();
    for (int i = 0; i < length; i++) {
        // Only look up when indexLen items fit before the end.
        if (i + indexLen <= length) {
            WordInfo s = getNextN(data[0], i, indexLen);
            int[] index = dict.getIndex(s.word);
            if (index != null) {
                for (int k = 0; k < index.length; k++) {
                    int n = index[k];
                    if (n == indexLen) { // special case of the check() call below, handled directly for speed
                        label(i, s.len, dicData);
                        // NOTE(review): when mutiple is true, execution falls through to
                        // check() below for the same n after label() — confirm intended.
                        if (!mutiple) {
                            // NOTE(review): the outer loop's i++ still runs after this, so the
                            // scan resumes at i + s.len + 1 — confirm skipping that extra
                            // position is intended.
                            i = i + s.len;
                            break;
                        }
                    }
                    int len = check(i, n, length, data[0], dicData);
                    if (len > 0 && !mutiple) {
                        // NOTE(review): as above, the scan resumes at i + len + 1.
                        i = i + len;
                        break;
                    }
                }
            }
        }
    }
    // Increment every column of each row for which hasWay(...) holds.
    for (int i = 0; i < length; i++)
        if (hasWay(dicData[i]))
            for (int j = 0; j < dicData[i].length; j++)
                dicData[i][j]++;
    instance.setDicData(dicData);
}
/** 训练 */ public Linear train(InstanceSet trainingList, InstanceSet testList) { int numSamples = trainingList.size(); count = 0; for (int ii = 0; ii < trainingList.size(); ii++) { Instance inst = trainingList.getInstance(ii); count += ((int[]) inst.getTarget()).length; } System.out.println("Chars Number: " + count); double oldErrorRate = Double.MAX_VALUE; // 开始循环 long beginTime, endTime; long beginTimeIter, endTimeIter; beginTime = System.currentTimeMillis(); double pE = 0; int iter = 0; int frac = numSamples / 10; while (iter++ < maxIter) { if (!simpleOutput) { System.out.print("iter:"); System.out.print(iter + "\t"); } double err = 0; double errorAll = 0; beginTimeIter = System.currentTimeMillis(); int progress = frac; for (int ii = 0; ii < numSamples; ii++) { Instance inst = trainingList.getInstance(ii); Results pred = (Results) msolver.getBest(inst, 1); double l = loss.calc(pred.getPredAt(0), inst.getTarget()); if (l > 0) { // 预测错误,更新权重 errorAll += 1.0; err += l; update.update(inst, weights, pred.getPredAt(0), c); } else { if (pred.size() > 1) update.update(inst, weights, pred.getPredAt(1), c); } if (!simpleOutput && ii % progress == 0) { // 显示进度 System.out.print('.'); progress += frac; } } double errRate = err / count; endTimeIter = System.currentTimeMillis(); if (!simpleOutput) { System.out.println("\ttime:" + (endTimeIter - beginTimeIter) / 1000.0 + "s"); System.out.print("Train:"); System.out.print("\tTag acc:"); } System.out.print(1 - errRate); if (!simpleOutput) { System.out.print("\tSentence acc:"); System.out.print(1 - errorAll / numSamples); System.out.println(); } if (testList != null) { test(testList); } if (Math.abs(errRate - oldErrorRate) < eps) { System.out.println("Convergence!"); break; } oldErrorRate = errRate; if (interim) { Linear p = new Linear(msolver, trainingList.getAlphabetFactory()); try { p.saveTo("tmp.model"); } catch (IOException e) { System.err.println("write model error!"); } } if (isOptimized) { // 模型优化,去掉不显著的特征 int[] idx = 
MyArrays.getTop(weights.clone(), threshold, false); System.out.print("Opt: weight numbers: " + MyArrays.countNoneZero(weights)); MyArrays.set(weights, idx, 0.0); System.out.println(" -> " + MyArrays.countNoneZero(weights)); } } endTime = System.currentTimeMillis(); System.out.println("done!"); System.out.println("time escape:" + (endTime - beginTime) / 1000.0 + "s"); Linear p = new Linear(msolver, trainingList.getAlphabetFactory()); return p; }