コード例 #1
0
ファイル: NGram.java プロジェクト: patterncat/sina-services
 /**
  * Replaces the instance's data with the word n-grams built from its tokens.
  *
  * <p>The data may be a {@code String}, which is split on runs of whitespace, or a
  * {@code List<String>} of tokens. Any other type yields an empty result. For each size
  * {@code len} in {@code gramSizes} (in array order) with {@code 0 < len <= tokens.size()}, every
  * run of {@code len} consecutive tokens is joined with single spaces, interned and appended to
  * the output list.
  *
  * @param inst the instance whose data is read and then replaced by an {@code ArrayList<String>}
  */
 @Override
 public void addThruPipe(Instance inst) {
   Object data = inst.getData();
   List<String> tokens = Collections.emptyList();
   if (data instanceof String) {
     String[] parts = ((String) data).split("\\s+");
     // split() keeps a leading empty string when the text starts with whitespace (and returns
     // [""] for ""); drop it so it never becomes an n-gram token.
     int first = (parts.length > 0 && parts[0].length() == 0) ? 1 : 0;
     tokens = Arrays.asList(parts).subList(first, parts.length);
   } else if (data instanceof List) {
     @SuppressWarnings("unchecked") // elements are assumed to be Strings; TODO confirm upstream
     List<String> given = (List<String>) data;
     tokens = given;
   }

   ArrayList<String> list = new ArrayList<String>();
   StringBuilder buf = new StringBuilder();
   for (int len : gramSizes) {
     if (len <= 0 || len > tokens.size()) continue;
     for (int start = 0; start + len <= tokens.size(); start++) {
       buf.setLength(0);
       buf.append(tokens.get(start));
       for (int k = 1; k < len; k++) {
         buf.append(' ').append(tokens.get(start + k));
       }
       list.add(buf.toString().intern());
     }
   }
   inst.setData(list);
 }
コード例 #2
0
  /**
   * Builds and initialises the inference lattice for one instance.
   *
   * <p>The lattice has one row per position ({@code carrier.length()}) and one {@code Node} per
   * label ({@code ysize}). Each entry {@code data[l][i]} is a base offset into {@code weights} for
   * feature column {@code i}:
   *
   * <ul>
   *   <li>order-0 features add {@code weights[base + c]} to the score of node {@code c};
   *   <li>order-1 features add {@code weights[base + p * ysize + c]} to {@code trans[p]} of node
   *       {@code c}, for every label {@code p}.
   * </ul>
   *
   * Features with a negative base, or whose weight block does not lie entirely inside
   * {@code weights}, are skipped.
   *
   * @param carrier sample instance; its data is read as an {@code int[][]}
   * @return the initialised lattice, indexed {@code [position][label]}
   */
  protected Node[][] initialLattice(Instance carrier) {
    int[][] data = (int[][]) carrier.getData();

    int length = carrier.length();

    Node[][] lattice = new Node[length][];
    for (int l = 0; l < length; l++) {
      lattice[l] = new Node[ysize];
      for (int c = 0; c < ysize; c++) {
        lattice[l][c] = new Node(ysize);
        for (int i = 0; i < orders.length; i++) {
          int base = data[l][i];
          // -1 marks an absent feature. TODO (xpqiu 2013.2.1): features whose weights are not
          // (fully) present are skipped. Checking only "base < weights.length" is not enough,
          // because base + c / base + offset below may still run past the end of weights.
          if (base < 0) continue;
          if (orders[i] == 0) {
            if (base + ysize > weights.length) continue;
            lattice[l][c].score += weights[base + c];
          } else if (orders[i] == 1) {
            if (base + ysize * ysize > weights.length) continue;
            int offset = c;
            for (int p = 0; p < ysize; p++) {
              // weights holds trans(c, p) unrolled row by row: index base + p * ysize + c.
              lattice[l][c].trans[p] += weights[base + offset];
              offset += ysize;
            }
          }
        }
      }
    }

    return lattice;
  }
コード例 #3
0
ファイル: PATrainer.java プロジェクト: Ericva/java
  /**
   * Evaluates the current model on {@code testSet} and prints its accuracy to standard output.
   *
   * <p>Each instance is decoded with {@code msolver.getBest(inst, 1)} and scored with
   * {@code loss.calc}.
   *
   * <ul>
   *   <li>Tag accuracy is {@code 1 - totalLoss / totalTags}, where the tag count is the summed
   *       length of the {@code int[]} targets.
   *   <li>Sentence accuracy is the fraction of instances whose loss is not greater than zero.
   * </ul>
   *
   * In simple-output mode only a tab and the tag accuracy are printed.
   *
   * @param testSet instances to evaluate; each target is read as an {@code int[]}
   */
  public void test(InstanceSet testSet) {
    int tagCount = 0;
    double lossSum = 0;
    double badSentences = 0;
    for (int idx = 0; idx < testSet.size(); idx++) {
      Instance sample = testSet.getInstance(idx);
      tagCount += ((int[]) sample.getTarget()).length;
      Results best = (Results) msolver.getBest(sample, 1);
      double sampleLoss = loss.calc(best.getPredAt(0), sample.getTarget());
      if (sampleLoss > 0) { // the prediction for this instance is wrong
        badSentences += 1.0;
        lossSum += sampleLoss;
      }
    }

    double tagAccuracy = 1 - lossSum / tagCount;
    if (simpleOutput) {
      System.out.print('\t');
      System.out.print(tagAccuracy);
    } else {
      System.out.print("Test:\t" + (tagCount - lossSum) + '/' + tagCount + "\tTag acc:");
      System.out.print(tagAccuracy);
      System.out.print("\tSentence acc:");
      System.out.println(1 - badSentences / testSet.size());
    }
    System.out.println();
  }
コード例 #4
0
 /**
  * Sets one binary feature on the instance for every template in {@code group} that matches the
  * instance source.
  *
  * <p>The source is read as a {@code String} and the data as a {@code BinarySparseVector}. For
  * each template whose {@code matches} score is positive, the id returned by
  * {@code features.lookupIndex("template:" + comment)} is put into that vector in place.
  *
  * @param inst instance whose data vector is updated
  * @throws Exception as declared by the signature
  */
 @Override
 public void addThruPipe(Instance inst) throws Exception {
   String str = (String) inst.getSource();
   BinarySparseVector sv = (BinarySparseVector) inst.getData();
   for (int i = 0; i < group.size(); i++) {
     RETemplate qt = group.get(i);
     float w = qt.matches(str);
     if (w > 0) {
       int id = features.lookupIndex("template:" + qt.comment);
       sv.put(id);
     }
   }
 }
コード例 #5
0
ファイル: MultiLinearMax.java プロジェクト: Ericva/java
  /**
   * Scores every leaf label for {@code inst} and returns the ranked results.
   *
   * <p>For each class, one {@code Multiplesolve} task is submitted to {@code pool} to compute that
   * class's inner product with the instance's feature vector. With a label {@code tree}, a leaf's
   * score is the sum of the class scores along {@code tree.getPath(leaf)}; without one it is the
   * leaf's own class score. When {@code isUseTarget} is set, the target leaf is recorded via
   * {@code addOracle} and every other leaf via {@code addPred}.
   *
   * @param inst instance to score; its target is read as an {@code Integer} when
   *     {@code isUseTarget} is set
   * @param n value passed to {@code new Results(n)}
   * @return the scored results, or {@code null} if a scoring task failed or the calling thread was
   *     interrupted while waiting. In the interrupted case the thread's interrupt status is
   *     restored, and in both cases the remaining tasks are cancelled.
   */
  public Results getBest(Instance inst, int n) {
    Integer target = null;
    if (isUseTarget) target = (Integer) inst.getTarget();

    SparseVector fv = featureGen.getVector(inst);

    // Inner product of the feature vector with each class's weights.
    double[] sw = new double[alphabet.size()];
    Callable<Double>[] c = new Multiplesolve[numClass];
    Future<Double>[] f = new Future[numClass];

    // Submit one scoring task per class.
    for (int i = 0; i < numClass; i++) {
      c[i] = new Multiplesolve(fv, i);
      f[i] = pool.submit(c[i]);
    }

    // Wait for every task and collect its score.
    for (int i = 0; i < numClass; i++) {
      try {
        sw[i] = f[i].get();
      } catch (Exception e) {
        if (e instanceof InterruptedException) {
          // Do not swallow the interrupt: let the caller / executor observe it.
          Thread.currentThread().interrupt();
        }
        // The result will be discarded, so stop the tasks that have not finished yet.
        for (int j = i; j < numClass; j++) {
          f[j].cancel(true);
        }
        e.printStackTrace();
        return null;
      }
    }

    Results res = new Results(n);
    if (target != null) {
      res.buildOracle();
    }

    Iterator<Integer> it = leafs.iterator();

    while (it.hasNext()) {
      double score = 0.0;
      Integer i = it.next();

      if (tree != null) { // Hierarchical score: sum the class scores along the leaf's path.
        ArrayList<Integer> anc = tree.getPath(i);
        for (int j = 0; j < anc.size(); j++) {
          score += sw[anc.get(j)];
        }
      } else {
        score = sw[i];
      }

      // When a target is known, its leaf goes to the oracle slot; all other leaves are predictions.
      if (target != null && target.equals(i)) {
        res.addOracle(score, i);
      } else {
        res.addPred(score, i);
      }
    }
    return res;
  }
コード例 #6
0
  /**
   * Counts the cliques on which the predicted and reference label sequences disagree, and calls
   * {@code diffClique(weights, p)} for each such clique.
   *
   * <p>Clique 0 covers only position 0; clique {@code p >= 1} covers positions {@code p - 1} and
   * {@code p}. A clique differs when any position it covers has a different label. As a side
   * effect, the fields {@code data}, {@code golds} and {@code preds} are set from the arguments.
   *
   * @param inst instance whose data (an {@code int[][]}) is read, plus its target when
   *     {@code targets} is null
   * @param weights weight vector passed through to {@code diffClique}
   * @param targets reference labels as {@code int[]}, or {@code null} to use the instance target
   * @param predicts predicted labels as {@code int[]}
   * @return number of differing cliques
   */
  @Override
  protected int diff(Instance inst, float[] weights, Object targets, Object predicts) {
    data = (int[][]) inst.getData();
    golds = (int[]) (targets != null ? targets : inst.getTarget());
    preds = (int[]) predicts;

    int mismatched = 0;

    // The first clique spans position 0 only.
    if (preds[0] != golds[0]) {
      mismatched++;
      diffClique(weights, 0);
    }

    // Every later clique spans the pair (pos - 1, pos).
    for (int pos = 1; pos < data.length; pos++) {
      boolean leftWrong = golds[pos - 1] != preds[pos - 1];
      if (leftWrong || golds[pos] != preds[pos]) {
        mismatched++;
        diffClique(weights, pos);
      }
    }

    return mismatched;
  }
コード例 #7
0
ファイル: AR_Reader.java プロジェクト: jilen/snlp
 /**
  * Rebuilds {@code list} with one {@code Instance} per unordered pair of entities drained from
  * {@code ll}.
  *
  * <p>Entities are polled from {@code ll} one at a time, and each polled entity is paired with
  * every entity still queued. Each pair is therefore produced once, and {@code ll} is empty
  * afterwards. For each pair:
  *
  * <ul>
  *   <li>the features from {@code FeatureGeter} are rendered with {@code intArrayToString};
  *   <li>the result is split on whitespace into tokens;
  *   <li>the tokens are wrapped in an {@code Instance} (second constructor argument {@code null})
  *       whose source is the {@code EntityGroup}.
  * </ul>
  *
  * @throws Exception as declared by the signature
  */
 private void dothis() throws Exception {
   list = new LinkedList<Instance>();
   while (ll.size() > 0) {
     Entity ss = (Entity) ll.poll();
     Iterator<Entity> it = ll.iterator();
     while (it.hasNext()) {
       Entity s2 = (Entity) it.next();
       EntityGroup eg = new EntityGroup(ss, s2);
       FeatureGeter fp = new FeatureGeter(eg);
       // "\\s+" already matches tabs. The former "\\t+|\\s+" matched a tab run and the
       // following spaces as two separate delimiters, leaving an empty token between them.
       String[] tokens = this.intArrayToString(fp.getFeatrue()).split("\\s+");
       List<String> newdata = Arrays.asList(tokens);
       Instance in = new Instance(newdata, null);
       in.setSource(eg);
       list.add(in);
     }
   }
 }
コード例 #8
0
ファイル: DictLabel.java プロジェクト: HarveyTvT/SocialPlus
  /**
   * Builds a per-position dictionary-label matrix for the instance and stores it via
   * {@code setDicData}.
   *
   * <p>Data is read as {@code String[][]}, and only the row {@code data[0]} is scanned. The output
   * is an {@code int[length][labels.size()]} matrix. The scan works as follows:
   *
   * <ul>
   *   <li>At each start position {@code i} with at least {@code dict.getIndexLen()} tokens
   *       remaining, the key built by {@code getNextN} is looked up with {@code dict.getIndex}.
   *   <li>Each returned value {@code n} is handled by {@code label} (fast path when
   *       {@code n == indexLen}) and/or {@code check}.
   *   <li>Unless {@code mutiple} is set, scanning jumps ahead after the first match.
   * </ul>
   *
   * Finally, every row for which {@code hasWay} returns true has all of its entries incremented by
   * one.
   *
   * @param instance instance whose data is read and whose dictionary matrix is set
   * @throws Exception as declared by the signature
   */
  public void addThruPipe(Instance instance) throws Exception {
    String[][] data = (String[][]) instance.getData();

    int length = data[0].length;
    int[][] dicData = new int[length][labels.size()];

    int indexLen = dict.getIndexLen();
    for (int i = 0; i < length; i++) {
      if (i + indexLen <= length) {
        WordInfo s = getNextN(data[0], i, indexLen);
        int[] index = dict.getIndex(s.word);
        if (index != null) {
          for (int k = 0; k < index.length; k++) {
            int n = index[k];
            if (n == indexLen) { // Special case of the check() call below, only to speed it up.
              // NOTE(review): when mutiple is true this does not skip check() below, so check()
              // also runs for the same n.
              label(i, s.len, dicData);
              if (!mutiple) {
                // NOTE(review): the outer for-loop's i++ also runs after this break, so the next
                // scan starts at (old i) + s.len + 1. Confirm that skipping that position is intended.
                i = i + s.len;
                break;
              }
            }
            int len = check(i, n, length, data[0], dicData);
            if (len > 0 && !mutiple) {
              // Same i++ interaction as above: next scan starts at (old i) + len + 1.
              i = i + len;
              break;
            }
          }
        }
      }
    }

    // Rows that hasWay() accepts get every label entry incremented by one.
    for (int i = 0; i < length; i++)
      if (hasWay(dicData[i])) for (int j = 0; j < dicData[i].length; j++) dicData[i][j]++;

    instance.setDicData(dicData);
  }
コード例 #9
0
ファイル: PATrainer.java プロジェクト: Ericva/java
  /**
   * Trains the model online over {@code trainingList} and returns it wrapped in a {@code Linear}.
   *
   * <p>Each pass (at most {@code maxIter}) decodes every training instance with
   * {@code msolver.getBest(inst, 1)}:
   *
   * <ul>
   *   <li>If the loss against the gold target is positive, {@code update.update} is called with
   *       the best prediction.
   *   <li>Otherwise, when a second prediction exists, {@code update.update} is called with that
   *       one.
   * </ul>
   *
   * After each pass:
   *
   * <ul>
   *   <li>the tag and sentence accuracy are printed;
   *   <li>{@code testList} (if non-null) is evaluated with {@link #test};
   *   <li>training stops when the tag error rate changes by less than {@code eps}.
   * </ul>
   *
   * When {@code interim} is set, a snapshot is saved to {@code tmp.model} after each pass. When
   * {@code isOptimized} is set, the weights selected by
   * {@code MyArrays.getTop(weights.clone(), threshold, false)} are zeroed.
   *
   * @param trainingList training instances; each target is read as an {@code int[]}
   * @param testList optional evaluation set, or {@code null}
   * @return a {@code Linear} built from {@code msolver} and the training alphabet factory
   */
  public Linear train(InstanceSet trainingList, InstanceSet testList) {
    int numSamples = trainingList.size();
    // Total number of tags in the training set; denominator of the tag error rate.
    count = 0;
    for (int ii = 0; ii < numSamples; ii++) {
      Instance inst = trainingList.getInstance(ii);
      count += ((int[]) inst.getTarget()).length;
    }

    System.out.println("Chars Number: " + count);

    double oldErrorRate = Double.MAX_VALUE;

    // Main training loop.
    long beginTime, endTime;
    long beginTimeIter, endTimeIter;
    beginTime = System.currentTimeMillis();
    int iter = 0;
    // Step between progress dots. It is clamped to 1 because with fewer than 10 samples
    // numSamples / 10 is 0, and "ii % progress" below would throw ArithmeticException.
    int frac = Math.max(1, numSamples / 10);
    while (iter++ < maxIter) {
      if (!simpleOutput) {
        System.out.print("iter:");
        System.out.print(iter + "\t");
      }
      double err = 0;
      double errorAll = 0;
      beginTimeIter = System.currentTimeMillis();
      int progress = frac;
      for (int ii = 0; ii < numSamples; ii++) {
        Instance inst = trainingList.getInstance(ii);
        Results pred = (Results) msolver.getBest(inst, 1);
        double l = loss.calc(pred.getPredAt(0), inst.getTarget());
        if (l > 0) { // Wrong prediction: update the weights using the best prediction.
          errorAll += 1.0;
          err += l;
          update.update(inst, weights, pred.getPredAt(0), c);
        } else {
          // Correct prediction: if a runner-up exists, update using it as well.
          if (pred.size() > 1) update.update(inst, weights, pred.getPredAt(1), c);
        }
        if (!simpleOutput && ii % progress == 0) { // Show progress.
          System.out.print('.');
          progress += frac;
        }
      }
      double errRate = err / count;

      endTimeIter = System.currentTimeMillis();
      if (!simpleOutput) {
        System.out.println("\ttime:" + (endTimeIter - beginTimeIter) / 1000.0 + "s");
        System.out.print("Train:");
        System.out.print("\tTag acc:");
      }
      System.out.print(1 - errRate);
      if (!simpleOutput) {
        System.out.print("\tSentence acc:");
        System.out.print(1 - errorAll / numSamples);
        System.out.println();
      }
      if (testList != null) {
        test(testList);
      }
      if (Math.abs(errRate - oldErrorRate) < eps) {
        System.out.println("Convergence!");
        break;
      }
      oldErrorRate = errRate;
      if (interim) {
        Linear p = new Linear(msolver, trainingList.getAlphabetFactory());
        try {
          p.saveTo("tmp.model");
        } catch (IOException e) {
          System.err.println("write model error!");
        }
      }
      if (isOptimized) { // Model pruning: zero out insignificant feature weights.
        int[] idx = MyArrays.getTop(weights.clone(), threshold, false);
        System.out.print("Opt: weight numbers: " + MyArrays.countNoneZero(weights));
        MyArrays.set(weights, idx, 0.0);
        System.out.println(" -> " + MyArrays.countNoneZero(weights));
      }
    }
    endTime = System.currentTimeMillis();
    System.out.println("done!");
    System.out.println("time escape:" + (endTime - beginTime) / 1000.0 + "s");
    Linear p = new Linear(msolver, trainingList.getAlphabetFactory());
    return p;
  }