Ejemplo n.º 1
0
  // 构造一个tri-trainer分类器。
  public Tritrainer(
      String classifier, String trainingIns_File, String testIns_File, double precentage) {
    try {
      this.classifier1 = (Classifier) Class.forName(classifier).newInstance();
      this.classifier2 = (Classifier) Class.forName(classifier).newInstance();
      this.classifier3 = (Classifier) Class.forName(classifier).newInstance();

      Instances trainingInstances = Util.getInstances(trainingIns_File);

      // 将trainIns_File按照precentage和(1-precentage)的比例切割成labeledIns和unlabeledIns;
      int length = trainingInstances.numInstances();
      int i = new Double(length * precentage).intValue();
      labeledIns = new Instances(trainingInstances, 0);
      for (int j = 0; j < i; j++) {
        labeledIns.add(trainingInstances.firstInstance());
        trainingInstances.delete(0);
      }
      unlabeledIns = trainingInstances;
      testIns = Util.getInstances(testIns_File);

      Init();
    } catch (Exception e) {

    }
  }
Ejemplo n.º 2
0
 // 使用bootstrap方法从固定小样本集中获得三个样本集,并用这三个样本集训练三个不同的分类器;
 public void bootstrap() {
   Instances x = new Instances(labeledIns, 0);
   System.out.println("before training:");
   try {
     for (int i = 0; i < 3; i++) {
       x = this.labeledIns.resample(new Random());
       this.instan_Array[i] = x;
       class_Array[i].buildClassifier(x);
       err_Classifier[i] = Util.errorRate(this.class_Array[i], this.testIns);
       System.out.println(
           "classifier[" + i + "]:=" + Util.errorRate(this.class_Array[i], this.testIns));
     }
   } catch (Exception e) {
     System.out.println(e);
   }
 }
Ejemplo n.º 3
0
  // 打印出最终三个分类器的分类错误率;
  public void argMax() {

    System.out.println("after training:");
    for (int i = 0; i < 3; i++) {
      err_Classifier[i] = Util.errorRate(this.class_Array[i], this.testIns);
      System.out.println("classifier[" + i + "]:=" + err_Classifier[i]);
    }
  }
Ejemplo n.º 4
0
 // this method use MajorityVoting to decide the tritrainer's errorRate;
 public double errorRateByMajorityVoting() {
   return Util.errorRate(this, this.testIns);
 }
Ejemplo n.º 5
0
  // tri-training学习过程;
  public void training() {
    int length_L = 0;
    int up_int = 0;
    double temp = 0.0;

    // 直到没有分类器发生更新时,跳出循环;
    while (update1 || update2 || update3) {
      update1 = false;
      update2 = false;
      update3 = false;
      for (int i = 0; i < 3; i++) {

        ins_Array[i] = new Instances(testIns, 0);

        ins_Array[i].setClassIndex(testIns.numAttributes() - 1);

        switch (i) {
          case 0:
            j = 1;
            k = 2;
            break;
          case 1:
            j = 0;
            k = 2;
            break;
          case 2:
            j = 0;
            k = 1;
            break;
        }

        // 获得用于加强第i个分类器的其它两个分类器j,k的分类错误率;
        err_Array[i] = measureBothError(class_Array[j], class_Array[k], this.unlabeledIns);

        // 如果这个两个分类器j,k的分类错误率小于前一次的时候,运用这两个分类器为第i个分类器标记样本;
        if (err_Array[i] < error[i]) {

          // 获得两个分类器j,k分类做出相同决策得到的样本集合ins_Array[i]
          this.updateL(class_Array[j], class_Array[k], ins_Array[i], this.unlabeledIns);

          length_L = ins_Array[i].numInstances();

          if (length[i] == 0) {

            //	System.out.println("err_array[i] =" + err_Array[i] + " err=" + error );
            length[i] = this.getDownInt(err_Array[i], error[i]);
            //	System.out.println("length[i] =" + length[i]);
          }

          if (length[i] < length_L) {
            if (err_Array[i] * length_L < error[i] * length[i]) {
              this.update[i] = true;
            } else if (length[i] > (err_Array[i] / (error[i] - err_Array[i]))) {
              up_int = this.getUpInt(err_Array[i], error[i], length[i]);
              //	System.out.println("err_array[i] =" + err_Array[i] + " err=" + error + "length:=" +
              // length[i]);
              //	System.out.println("up_int=" + up_int );
              this.SubSample(this.ins_Array[i], up_int);
              this.update[i] = true;
            }
          }
        }
      }

      // 更新分类器
      for (int i = 0; i < 3; i++) {
        // 如果第i个分类器的update为true,更新该分类器;
        if (this.update[i]) {
          try {
            this.class_Array[i].buildClassifier(Util.add(this.instan_Array[i], this.ins_Array[i]));
            temp = Util.errorRate(this.class_Array[i], this.testIns);

            // 如果分类器更新以后的分类错误率比以前高,则恢复分类器到未更新时的状态。 这一点与论文中的算法有一点点不同。
            // 论文中没有这一步的判断。
            if (temp > err_Classifier[i]) {
              this.update[i] = false;
              this.class_Array[i].buildClassifier(this.instan_Array[i]);
            } else {
              // 如果分类器的分类错误率下降了,则更新length[i]以及error[i];
              length[i] = this.ins_Array[i].numInstances();
              error[i] = err_Array[i];
            }
          } catch (Exception e) {
            System.out.println(e);
          }
        }
      }
    }
  }