Exemplo n.º 1
0
  /**
   * 显示决策树
   *
   * @param node 待显示的节点
   * @param blankNum 行空格符,用于显示树型结构
   */
  private void showDecisionTree(AttrNode node, int blankNum) {
    System.out.println();
    for (int i = 0; i < blankNum; i++) {
      System.out.print("\t");
    }
    System.out.print("--");
    // 显示分类的属性值
    if (node.getParentAttrValue() != null && node.getParentAttrValue().length() > 0) {
      System.out.print(node.getParentAttrValue());
    } else {
      System.out.print("--");
    }
    System.out.print("--");

    if (node.getChildDataIndex() != null && node.getChildDataIndex().size() > 0) {
      String i = node.getChildDataIndex().get(0);
      System.out.print("类别:" + data[Integer.parseInt(i)][attrNames.length - 1]);
      System.out.print("[");
      for (String index : node.getChildDataIndex()) {
        System.out.print(index + ", ");
      }
      System.out.print("]");
    } else {
      // 递归显示子节点
      System.out.print("【" + node.getAttrName() + "】");
      for (AttrNode childNode : node.getChildAttrNode()) {
        showDecisionTree(childNode, 2 * blankNum);
      }
    }
  }
Exemplo n.º 2
0
  /** 利用源数据构造决策树 */
  private void buildDecisionTree(
      AttrNode node,
      String parentAttrValue,
      String[][] remainData,
      ArrayList<String> remainAttr,
      boolean isID3) {
    node.setParentAttrValue(parentAttrValue);

    String attrName = "";
    double gainValue = 0;
    double tempValue = 0;

    // 如果只有1个属性则直接返回
    if (remainAttr.size() == 1) {
      System.out.println("attr null");
      return;
    }

    // 选择剩余属性中信息增益最大的作为下一个分类的属性
    for (int i = 0; i < remainAttr.size(); i++) {
      // 判断是否用ID3算法还是C4.5算法
      if (isID3) {
        // ID3算法采用的是按照信息增益的值来比
        tempValue = computeGain(remainData, remainAttr.get(i));
      } else {
        // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
        tempValue = computeGainRatio(remainData, remainAttr.get(i));
      }

      if (tempValue > gainValue) {
        gainValue = tempValue;
        attrName = remainAttr.get(i);
      }
    }

    node.setAttrName(attrName);
    ArrayList<String> valueTypes = attrValue.get(attrName);
    remainAttr.remove(attrName);

    AttrNode[] childNode = new AttrNode[valueTypes.size()];
    String[][] rData;
    for (int i = 0; i < valueTypes.size(); i++) {
      // 移除非此值类型的数据
      rData = removeData(remainData, attrName, valueTypes.get(i));

      childNode[i] = new AttrNode();
      boolean sameClass = true;
      ArrayList<String> indexArray = Lists.newArrayList();
      for (int k = 1; k < rData.length; k++) {
        indexArray.add(rData[k][0]);
        // 判断是否为同一类的
        if (!rData[k][attrNames.length - 1].equals(rData[1][attrNames.length - 1])) {
          // 只要有1个不相等,就不是同类型的
          sameClass = false;
          break;
        }
      }

      if (!sameClass) {
        // 创建新的对象属性,对象的同个引用会出错
        ArrayList<String> rAttr = Lists.newArrayList();
        for (String str : remainAttr) {
          rAttr.add(str);
        }

        buildDecisionTree(childNode[i], valueTypes.get(i), rData, rAttr, isID3);
      } else {
        // 如果是同种类型,则直接为数据节点
        childNode[i].setParentAttrValue(valueTypes.get(i));
        childNode[i].setChildDataIndex(indexArray);
      }
    }
    node.setChildAttrNode(childNode);
  }