/**
   * Returns a textual description of this classifier.
   *
   * @return a textual description of this classifier.
   */
  @Override
  public String toString() {

    if (m_probOfClass == null) {
      return "NaiveBayesMultinomialText: No model built yet.\n";
    }

    StringBuffer result = new StringBuffer();

    // build a master dictionary over all classes
    HashSet<String> master = new HashSet<String>();
    for (int i = 0; i < m_data.numClasses(); i++) {
      LinkedHashMap<String, Count> classDict = m_probOfWordGivenClass.get(i);

      for (String key : classDict.keySet()) {
        master.add(key);
      }
    }

    result.append("Dictionary size: " + master.size()).append("\n\n");

    result.append("The independent frequency of a class\n");
    result.append("--------------------------------------\n");

    for (int i = 0; i < m_data.numClasses(); i++) {
      result
          .append(m_data.classAttribute().value(i))
          .append("\t")
          .append(Double.toString(m_probOfClass[i]))
          .append("\n");
    }

    result.append("\nThe frequency of a word given the class\n");
    result.append("-----------------------------------------\n");

    for (int i = 0; i < m_data.numClasses(); i++) {
      result.append(Utils.padLeft(m_data.classAttribute().value(i), 11)).append("\t");
    }

    result.append("\n");

    Iterator<String> masterIter = master.iterator();
    while (masterIter.hasNext()) {
      String word = masterIter.next();

      for (int i = 0; i < m_data.numClasses(); i++) {
        LinkedHashMap<String, Count> classDict = m_probOfWordGivenClass.get(i);
        Count c = classDict.get(word);
        if (c == null) {
          result.append("<laplace=1>\t");
        } else {
          result.append(Utils.padLeft(Double.toString(c.m_count), 11)).append("\t");
        }
      }
      result.append(word);
      result.append("\n");
    }

    return result.toString();
  }
  /**
   * Generates the classifier.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  @Override
  public void buildClassifier(Instances data) throws Exception {
    reset();

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    m_data = new Instances(data, 0);
    data = new Instances(data);

    m_wordsPerClass = new double[data.numClasses()];
    m_probOfClass = new double[data.numClasses()];
    m_probOfWordGivenClass = new HashMap<Integer, LinkedHashMap<String, Count>>();

    double laplace = 1.0;
    for (int i = 0; i < data.numClasses(); i++) {
      LinkedHashMap<String, Count> dict =
          new LinkedHashMap<String, Count>(10000 / data.numClasses());
      m_probOfWordGivenClass.put(i, dict);
      m_probOfClass[i] = laplace;

      // this needs to be updated for laplace correction every time we see a new
      // word (attribute)
      m_wordsPerClass[i] = 0;
    }

    for (int i = 0; i < data.numInstances(); i++) {
      updateClassifier(data.instance(i));
    }
  }
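
A minimal usage sketch for the two methods above, assuming the surrounding class is Weka's NaiveBayesMultinomialText and that "reviews.arff" is a placeholder file with a string attribute plus a nominal class:

  // requires java.io.BufferedReader, java.io.FileReader, weka.core.Instances,
  // and weka.classifiers.bayes.NaiveBayesMultinomialText
  Instances train = new Instances(new BufferedReader(new FileReader("reviews.arff")));
  train.setClassIndex(train.numAttributes() - 1);
  NaiveBayesMultinomialText nb = new NaiveBayesMultinomialText();
  nb.buildClassifier(train);         // reset(), per-class dictionaries, one updateClassifier() per instance
  System.out.println(nb.toString()); // dictionary size, class priors and per-class word counts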
  public Vector<int[]> getMultiKmodesResults(
      Instances data, Instances dataforcluster, int ensemblesize) throws Exception {
    Vector<int[]> cls = new Vector<int[]>();
    int k = data.numClasses();

    for (int i = 0; i < ensemblesize; ++i) {

      int tmpk = Rnd.nextInt(ensemblesize * k);
      tmpk = tmpk <= k ? (tmpk + k) : tmpk;
      SimpleKMeans km = new SimpleKMeans();
      km.setMaxIterations(100);
      km.setNumClusters(tmpk);
      km.setDontReplaceMissingValues(true);
      km.setSeed(Rnd.nextInt());
      km.setPreserveInstancesOrder(true);
      km.buildClusterer(dataforcluster);
      SquaredError[i] = km.getSquaredError();

      /*EM em = new EM();
      em.setMaxIterations(100);
      em.setNumClusters(k);
      //em.setSeed(Rnd.nextInt());
      em.buildClusterer(dataforcluster);
      int[] res2 = new int[dataforcluster.numInstances()];
      for(int r=0;r<dataforcluster.numInstances();++r){
      	res2[r]=em.clusterInstance(dataforcluster.instance(r));
      }*/

      // System.out.println(Arrays.toString(km.getAssignments()));
      cls.add(km.getAssignments());
      // cls.add(res2);

    }
    return cls;
  }
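
One way the returned assignments might be combined, sketched under the assumption that data and dataforcluster are as in the method above; the co-association matrix is a common consensus step for clustering ensembles, not necessarily what this class does with the result:

  Vector<int[]> ensemble = getMultiKmodesResults(data, dataforcluster, 10);
  int n = dataforcluster.numInstances();
  int[][] coassoc = new int[n][n]; // how often each pair of instances lands in the same cluster
  for (int[] assignment : ensemble) {
    for (int a = 0; a < n; a++) {
      for (int b = 0; b < n; b++) {
        if (assignment[a] == assignment[b]) {
          coassoc[a][b]++;
        }
      }
    }
  }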
  /**
   * Builds the model of the base learner.
   *
   * @param data the training data
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (m_Classifier == null) {
      throw new Exception("No base classifier has been set!");
    }
    if (m_MatrixSource == MATRIX_ON_DEMAND) {
      String costName = data.relationName() + CostMatrix.FILE_EXTENSION;
      File costFile = new File(getOnDemandDirectory(), costName);
      if (!costFile.exists()) {
        throw new Exception("On-demand cost file doesn't exist: " + costFile);
      }
      setCostMatrix(new CostMatrix(new BufferedReader(new FileReader(costFile))));
    } else if (m_CostMatrix == null) {
      // try loading an old format cost file
      m_CostMatrix = new CostMatrix(data.numClasses());
      m_CostMatrix.readOldFormat(new BufferedReader(new FileReader(m_CostFile)));
    }

    if (!m_MinimizeExpectedCost) {
      Random random = null;
      if (!(m_Classifier instanceof WeightedInstancesHandler)) {
        random = new Random(m_Seed);
      }
      data = m_CostMatrix.applyCostMatrix(data, random);
    }
    m_Classifier.buildClassifier(data);
  }
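
A hedged wiring sketch, assuming this method belongs to (or mirrors) weka.classifiers.meta.CostSensitiveClassifier; the cost file name is a placeholder and J48 stands in for any base learner:

  CostSensitiveClassifier csc = new CostSensitiveClassifier();
  csc.setClassifier(new J48());
  csc.setCostMatrix(new CostMatrix(new BufferedReader(new FileReader("credit-g.cost"))));
  csc.setMinimizeExpectedCost(false); // false: reweight/resample the training data, as in the branch above
  csc.buildClassifier(data);          // data: the training set (assumed)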
Example #5
  @Override
  public void buildClassifier(Instances data) throws Exception {
    trainingData = data;
    Attribute classAttribute = data.classAttribute();
    prototypes = new ArrayList<>();

    classedData = new HashMap<String, ArrayList<Sequence>>();
    indexClassedDataInFullData = new HashMap<String, ArrayList<Integer>>();
    for (int c = 0; c < data.numClasses(); c++) {
      classedData.put(data.classAttribute().value(c), new ArrayList<Sequence>());
      indexClassedDataInFullData.put(data.classAttribute().value(c), new ArrayList<Integer>());
    }

    sequences = new Sequence[data.numInstances()];
    classMap = new String[sequences.length];
    for (int i = 0; i < sequences.length; i++) {
      Instance sample = data.instance(i);
      MonoDoubleItemSet[] sequence = new MonoDoubleItemSet[sample.numAttributes() - 1];
      int shift = (sample.classIndex() == 0) ? 1 : 0;
      for (int t = 0; t < sequence.length; t++) {
        sequence[t] = new MonoDoubleItemSet(sample.value(t + shift));
      }
      sequences[i] = new Sequence(sequence);
      String clas = sample.stringValue(classAttribute);
      classMap[i] = clas;
      classedData.get(clas).add(sequences[i]);
      indexClassedDataInFullData.get(clas).add(i);
      // System.out.println("Element " + i + " of train is classed " + clas
      //     + " and went to element " + (indexClassedDataInFullData.get(clas).size() - 1));
    }
    buildSpecificClassifier(data);
  }
Example #6
  private static void writePredictedDistributions(
      Classifier c, Instances data, int idIndex, Writer out) throws Exception {
    // header
    out.write("id");
    for (int i = 0; i < data.numClasses(); i++) {
      out.write(",\"");
      out.write(data.classAttribute().value(i).replaceAll("[\"\\\\]", "_"));
      out.write("\"");
    }
    out.write("\n");

    // data
    for (int i = 0; i < data.numInstances(); i++) {
      final String id = data.instance(i).stringValue(idIndex);
      double[] distribution = c.distributionForInstance(data.instance(i));

      // final String label = data.attribute(classIndex).value();
      out.write(id);
      for (double probability : distribution) {
        out.write(",");
        out.write(String.valueOf(probability > 1e-5 ? (float) probability : 0f));
      }
      out.write("\n");
    }
  }
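
A usage sketch for the helper above, assuming 'test' carries a string ID attribute named "id" (hypothetical name) and 'classifier' is already trained:

  // requires java.io.BufferedWriter, java.io.FileWriter, java.io.Writer
  try (Writer out = new BufferedWriter(new FileWriter("predictions.csv"))) {
    writePredictedDistributions(classifier, test, test.attribute("id").index(), out);
  }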
Example #7
  /**
   * Stratify the given data into the given number of bags based on the class values. It differs
   * from the <code>Instances.stratify(int fold)</code> that before stratification it sorts the
   * instances according to the class order in the header file. It assumes no missing values in the
   * class.
   *
   * @param data the given data
   * @param folds the given number of folds
   * @param rand the random object used to randomize the instances
   * @return the stratified instances
   */
  public static final Instances stratify(Instances data, int folds, Random rand) {
    if (!data.classAttribute().isNominal()) return data;

    Instances result = new Instances(data, 0);
    Instances[] bagsByClasses = new Instances[data.numClasses()];

    for (int i = 0; i < bagsByClasses.length; i++) bagsByClasses[i] = new Instances(data, 0);

    // Sort by class
    for (int j = 0; j < data.numInstances(); j++) {
      Instance datum = data.instance(j);
      bagsByClasses[(int) datum.classValue()].add(datum);
    }

    // Randomize each class
    for (int j = 0; j < bagsByClasses.length; j++) bagsByClasses[j].randomize(rand);

    for (int k = 0; k < folds; k++) {
      int offset = k, bag = 0;
      oneFold:
      while (true) {
        while (offset >= bagsByClasses[bag].numInstances()) {
          offset -= bagsByClasses[bag].numInstances();
          if (++bag >= bagsByClasses.length) { // Next bag
            break oneFold;
          }
        }

        result.add(bagsByClasses[bag].instance(offset));
        offset += folds;
      }
    }

    return result;
  }
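
A typical follow-up, sketched with Weka's standard fold extraction on the stratified copy:

  Instances stratified = stratify(data, 10, new Random(42));
  for (int fold = 0; fold < 10; fold++) {
    Instances train = stratified.trainCV(10, fold);
    Instances test = stratified.testCV(10, fold);
    // build and evaluate a classifier on each train/test pair
  }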
Example #8
  /**
   * @param args command-line arguments (ignored)
   * @throws Exception if the dataset cannot be loaded or the reduction fails
   */
  public static void main(String[] args) throws Exception {
    // TODO Auto-generated method stub

    oneAlgorithm oneAlg = new oneAlgorithm();
    oneAlg.category = xCategory.RSandFCBFalg;
    oneAlg.style = xStyle.fuzzySU;
    oneAlg.flag = false;
    oneAlg.alpha = 2.0;
    // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/Data/wine.arff";
    // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/Data/wdbc.arff";
    String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/Data/glass.arff";
    // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/shen/wine-shen.arff";
    // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/fuzzy/fuzzy-ex.arff";
    // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/derm.arff";
    oneFile onef = new oneFile(new File(fn));
    Instances dataset = new Instances(new FileReader(fn));
    dataset.setClassIndex(dataset.numAttributes() - 1);
    onef.ins = dataset.numInstances();
    onef.att = dataset.numAttributes();
    onef.cla = dataset.numClasses();

    RSandFCBFReduceMethod rs = new RSandFCBFReduceMethod(onef, oneAlg);

    boolean[] B = new boolean[rs.NumAttr];
    boolean[] rq = rs.getOneReduction(B);
    System.out.println(Arrays.toString(Utils.boolean2select(rq)));
  }
  public double ExpectedClassificationError(Instances pool, int attr_i) {

    // initialize alpha's to one
    int alpha[][][];
    int NumberOfFeatures = pool.numAttributes() - 1;
    int NumberOfLabels = pool.numClasses();

    alpha = new int[NumberOfFeatures][NumberOfLabels][];
    for (int i = 0; i < NumberOfFeatures; i++)
      for (int j = 0; j < NumberOfLabels; j++) alpha[i][j] = new int[pool.attribute(i).numValues()];

    for (int i = 0; i < NumberOfFeatures; i++)
      for (int j = 0; j < NumberOfLabels; j++)
        for (int k = 0; k < alpha[i][j].length; k++) alpha[i][j][k] = 1;

    // construct alpha's
    for (int i = 0; i < NumberOfFeatures; i++) // for each attribute
    {
      if (i == pool.classIndex()) { // skip the class attribute
        i++;
      }
      for (Enumeration<Instance> e = pool.enumerateInstances();
          e.hasMoreElements(); ) // for each instance
      {
        Instance inst = e.nextElement();
        if (!inst.isMissing(i)) // if attribute i is not missing (i.e. it's been bought)
        {
          int j = (int) inst.classValue();
          int k = (int) inst.value(i);
          alpha[i][j][k]++;
        }
      }
    }
    return ExpectedClassificationError(alpha, attr_i);
  }
Example #10
File: Id3.java Project: alishakiba/jDenetX
  /**
   * Computes the entropy of a dataset.
   *
   * @param data the data for which entropy is to be computed
   * @return the entropy of the data's class distribution
   * @throws Exception if computation fails
   */
  private double computeEntropy(Instances data) throws Exception {

    double[] classCounts = new double[data.numClasses()];
    Enumeration instEnum = data.enumerateInstances();
    while (instEnum.hasMoreElements()) {
      Instance inst = (Instance) instEnum.nextElement();
      classCounts[(int) inst.classValue()]++;
    }
    double entropy = 0;
    for (int j = 0; j < data.numClasses(); j++) {
      if (classCounts[j] > 0) {
        entropy -= classCounts[j] * Utils.log2(classCounts[j]);
      }
    }
    entropy /= (double) data.numInstances();
    return entropy + Utils.log2(data.numInstances());
  }
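
For example, with class counts {2, 2} over four instances the expression evaluates to log2(4) - (2*log2(2) + 2*log2(2)) / 4 = 2 - 1 = 1 bit, the entropy of a uniform two-class split, matching -sum(p * log2 p).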
Example #11
 /** Initializes the m_Attributes of the class. */
 private void init_m_Attributes() {
   try {
     m_NumInstances = m_Train.numInstances();
     m_NumClasses = m_Train.numClasses();
     m_NumAttributes = m_Train.numAttributes();
     m_ClassType = m_Train.classAttribute().type();
     m_InitFlag = ON;
   } catch (Exception e) {
     e.printStackTrace();
   }
 }
Example #12
File: Id3.java Project: alishakiba/jDenetX
  /**
   * Method for building an Id3 tree.
   *
   * @param data the training data
   * @exception Exception if decision tree can't be built successfully
   */
  private void makeTree(Instances data) throws Exception {

    // Check if no instances have reached this node.
    if (data.numInstances() == 0) {
      m_Attribute = null;
      m_ClassValue = Utils.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    // Compute attribute with maximum information gain.
    double[] infoGains = new double[data.numAttributes()];
    Enumeration attEnum = data.enumerateAttributes();
    while (attEnum.hasMoreElements()) {
      Attribute att = (Attribute) attEnum.nextElement();
      infoGains[att.index()] = computeInfoGain(data, att);
    }
    m_Attribute = data.attribute(Utils.maxIndex(infoGains));

    // Make leaf if information gain is zero.
    // Otherwise create successors.
    if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
      m_Attribute = null;
      m_Distribution = new double[data.numClasses()];
      Enumeration instEnum = data.enumerateInstances();
      while (instEnum.hasMoreElements()) {
        Instance inst = (Instance) instEnum.nextElement();
        m_Distribution[(int) inst.classValue()]++;
      }
      Utils.normalize(m_Distribution);
      m_ClassValue = Utils.maxIndex(m_Distribution);
      m_ClassAttribute = data.classAttribute();
    } else {
      Instances[] splitData = splitData(data, m_Attribute);
      m_Successors = new Id3[m_Attribute.numValues()];
      for (int j = 0; j < m_Attribute.numValues(); j++) {
        m_Successors[j] = new Id3();
        m_Successors[j].makeTree(splitData[j]);
      }
    }
  }
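
The recursion above is normally driven through buildClassifier(); a minimal sketch, assuming purely nominal attributes with no missing values (the ARFF path is a placeholder):

  Instances data = new Instances(new BufferedReader(new FileReader("weather.nominal.arff")));
  data.setClassIndex(data.numAttributes() - 1);
  Id3 tree = new Id3();
  tree.buildClassifier(data); // capability checks, then makeTree(data)
  System.out.println(tree);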
Example #13
  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  @Override
  public String toString() {

    if (m_Instances == null) {
      return "Naive Bayes (simple): No model built yet.";
    }
    try {
      StringBuffer text = new StringBuffer("Naive Bayes (simple)");
      int attIndex;

      for (int i = 0; i < m_Instances.numClasses(); i++) {
        text.append(
            "\n\nClass "
                + m_Instances.classAttribute().value(i)
                + ": P(C) = "
                + Utils.doubleToString(m_Priors[i], 10, 8)
                + "\n\n");
        Enumeration<Attribute> enumAtts = m_Instances.enumerateAttributes();
        attIndex = 0;
        while (enumAtts.hasMoreElements()) {
          Attribute attribute = enumAtts.nextElement();
          text.append("Attribute " + attribute.name() + "\n");
          if (attribute.isNominal()) {
            for (int j = 0; j < attribute.numValues(); j++) {
              text.append(attribute.value(j) + "\t");
            }
            text.append("\n");
            for (int j = 0; j < attribute.numValues(); j++) {
              text.append(Utils.doubleToString(m_Counts[i][attIndex][j], 10, 8) + "\t");
            }
          } else {
            text.append("Mean: " + Utils.doubleToString(m_Means[i][attIndex], 10, 8) + "\t");
            text.append("Standard Deviation: " + Utils.doubleToString(m_Devs[i][attIndex], 10, 8));
          }
          text.append("\n\n");
          attIndex++;
        }
      }

      return text.toString();
    } catch (Exception e) {
      return "Can't print Naive Bayes classifier!";
    }
  }
Example #14
  private static void analyse(Instances train, Instances datapredict) {
    String mpOptions = "-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a";

    try {

      train.setClassIndex(train.numAttributes() - 1);
      train.deleteAttributeAt(0);
      int numClasses = train.numClasses();

      for (int i = 0; i < numClasses; i++) {
        System.out.println("class value [" + i + "]=" + train.classAttribute().value(i) + "");
      }

      // Instance of NN
      MultilayerPerceptron mlp = new MultilayerPerceptron();
      mlp.setOptions(weka.core.Utils.splitOptions(mpOptions));
      mlp.buildClassifier(train);

      datapredict.setClassIndex(datapredict.numAttributes() - 1);
      datapredict.deleteAttributeAt(0);

      // Instances predicteddata = new Instances(datapredict);
      for (int i = 0; i < datapredict.numInstances(); i++) {

        Instance newInst = datapredict.instance(i);
        double pred = mlp.classifyInstance(newInst);
        int predInt = (int) pred; // Math.round(pred);
        String predString = train.classAttribute().value(predInt);
        System.out.println(
            "cliente["
                + i
                + "] pred["
                + pred
                + "] predInt["
                + predInt
                + "] desertor["
                + predString
                + "]");
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
  private Vector<int[]> getMultiKmodesResultswithRandomSelectFeature(
      Instances data, Instances dataforcluster, int ensemblesize) throws Exception {
    // TODO Auto-generated method stub
    Vector<int[]> cls = new Vector<int[]>();

    for (int i = 0; i < ensemblesize; ++i) {
      SimpleKMeans km = new SimpleKMeans();
      km.setMaxIterations(100);
      km.setNumClusters(data.numClasses());
      km.setDontReplaceMissingValues(true);
      km.setSeed(Rnd.nextInt());
      Instances newData = getNewRandomData(dataforcluster, Rnd.nextInt());
      km.setPreserveInstancesOrder(true);
      km.buildClusterer(newData);
      SquaredError[i] = km.getSquaredError();
      cls.add(km.getAssignments());
    }
    return cls;
  }
  public int SelectRow_L2Norm(
      Instances pool, Classifier myEstimator, int desiredAttr, int desiredLabel) {

    // for each instance with unbought desiredAttr and label = desiredLabel
    // measure distance from uniform
    // choose (row) that is closest to uniform as your instance to buy from

    double leastDistance = Double.MAX_VALUE;
    int leastIndex = -1;
    Instance inst;
    int n = pool.numClasses();
    double[] uniform;
    double[] probs;
    uniform = new double[n];
    for (int i = 0; i < n; i++) uniform[i] = 1.0 / (double) n;

    for (int i = 0; i < pool.numInstances(); i++) {
      inst = pool.instance(i);
      // System.out.println("currentlabel="+(int)inst.classValue()+"
      // isMissing="+inst.isMissing(desiredAttr));
      if ((int) inst.classValue() == desiredLabel && inst.isMissing(desiredAttr)) {
        // valid instance
        // measure the distance from uniform:
        // sqrt{ sum_i (a_i - b_i)^2 }
        probs = new double[n];
        try {
          probs = myEstimator.distributionForInstance(inst);
        } catch (Exception e) {
          // TODO Auto-generated catch block
          e.printStackTrace();
        }
        double distance = 0.0;
        for (int j = 0; j < n; j++) distance += (probs[j] - uniform[j]) * (probs[j] - uniform[j]);
        distance = Math.sqrt(distance);
        // System.out.println("current distance="+distance);
        if (distance < leastDistance) {
          leastDistance = distance;
          leastIndex = i;
        }
      }
    }
    return leastIndex;
  }
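
A hedged sketch of how the selected row might be used in an attribute-buying loop; oracleValueFor() is a hypothetical helper standing in for whatever supplies the purchased value:

  int row = SelectRow_L2Norm(pool, myEstimator, desiredAttr, desiredLabel);
  if (row >= 0) {
    double bought = oracleValueFor(row, desiredAttr); // hypothetical oracle lookup
    pool.instance(row).setValue(desiredAttr, bought); // the attribute is no longer missing for this row
  }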
Example #17
  /**
   * Creates split on enumerated attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleEnumeratedAttribute(Instances trainInstances) throws Exception {

    Distribution newDistribution, secondDistribution;
    int numAttValues;
    double currIG, currGR;
    Instance instance;
    int i;

    numAttValues = trainInstances.attribute(m_attIndex).numValues();
    newDistribution = new Distribution(numAttValues, trainInstances.numClasses());

    // Only Instances with known values are relevant.
    Enumeration enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      instance = (Instance) enu.nextElement();
      if (!instance.isMissing(m_attIndex))
        newDistribution.add((int) instance.value(m_attIndex), instance);
    }
    m_distribution = newDistribution;

    // For all values
    for (i = 0; i < numAttValues; i++) {

      if (Utils.grOrEq(newDistribution.perBag(i), m_minNoObj)) {
        secondDistribution = new Distribution(newDistribution, i);

        // Check if minimum number of Instances in the two
        // subsets.
        if (secondDistribution.check(m_minNoObj)) {
          m_numSubsets = 2;
          currIG = m_infoGainCrit.splitCritValue(secondDistribution, m_sumOfWeights);
          currGR = m_gainRatioCrit.splitCritValue(secondDistribution, m_sumOfWeights, currIG);
          if ((i == 0) || Utils.gr(currGR, m_gainRatio)) {
            m_gainRatio = currGR;
            m_infoGain = currIG;
            m_splitPoint = (double) i;
            m_distribution = secondDistribution;
          }
        }
      }
    }
  }
Example #18
  /**
   * Creates split on enumerated attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleEnumeratedAttribute(Instances trainInstances) throws Exception {

    Instance instance;

    m_distribution = new Distribution(m_complexityIndex, trainInstances.numClasses());

    // Only Instances with known values are relevant.
    Enumeration<Instance> enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      instance = enu.nextElement();
      if (!instance.isMissing(m_attIndex)) {
        m_distribution.add((int) instance.value(m_attIndex), instance);
      }
    }

    // Check if minimum number of Instances in at least two
    // subsets.
    if (m_distribution.check(m_minNoObj)) {
      m_numSubsets = m_complexityIndex;
      m_infoGain = infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights);
      m_gainRatio = gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights, m_infoGain);
    }
  }
Example #19
 public Map<FV, Collection<FV>> extractValuesFromData(Instances inst) {
   Multimap<FV, FV> fv_list = ArrayListMultimap.create();
   // Instances outFormat = getOutputFormat();
   for (int i = 0; i < inst.numInstances(); i++) {
     Instance ins = inst.instance(i);
     // Skip the class label
     for (int x = 0; x < ins.numAttributes() - 1; x++) {
       Object value = null;
       try {
         value = ins.stringValue(x);
       } catch (Exception e) {
         value = ins.value(x);
       }
       FV fv = new FV(x, value, ins.classValue());
       fv.setNumLabels(inst.numClasses());
       if (!fv_list.put(fv, fv)) {
         System.err.println("Couldn't put duplicates: " + fv);
       }
     }
   }
   Map<FV, Collection<FV>> original_map = fv_list.asMap();
   return original_map;
 }
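
A small consumption sketch: because identical FV keys are grouped by the multimap, each value collection's size is the number of times that (attribute, value, class) triple occurred in the data:

  Map<FV, Collection<FV>> grouped = extractValuesFromData(train);
  for (Map.Entry<FV, Collection<FV>> e : grouped.entrySet()) {
    System.out.println(e.getKey() + " occurred " + e.getValue().size() + " time(s)");
  }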
Example #20
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_NumClasses = instances.numClasses();
    m_ClassType = instances.classAttribute().type();
    m_Train = new Instances(instances, 0, instances.numInstances());

    // Throw away initial instances until within the specified window size
    if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
      m_Train = new Instances(m_Train, m_Train.numInstances() - m_WindowSize, m_WindowSize);
    }

    m_NumAttributesUsed = 0.0;
    for (int i = 0; i < m_Train.numAttributes(); i++) {
      if ((i != m_Train.classIndex())
          && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) {
        m_NumAttributesUsed += 1.0;
      }
    }

    m_NNSearch.setInstances(m_Train);

    // Invalidate any currently cross-validation selected k
    m_kNNValid = false;

    m_defaultModel = new ZeroR();
    m_defaultModel.buildClassifier(instances);
    m_defaultModel.setOptions(getOptions());
    // System.out.println("hello world");

  }
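
The window size, NN-search and ZeroR fallback above match weka.classifiers.lazy.IBk; a hedged configuration sketch:

  IBk knn = new IBk();
  knn.setKNN(5);
  knn.setCrossValidate(true); // re-select k by hold-one-out, which is why m_kNNValid is invalidated above
  knn.setWindowSize(1000);    // keep only the most recent 1000 training instances
  knn.buildClassifier(train);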
Example #21
  /**
   * Selects cutpoints for sorted subset.
   *
   * @param instances the (sorted) instances to operate on
   * @param attIndex the index of the numeric attribute being discretized
   * @param first index of the first instance in the subset
   * @param lastPlusOne index one past the last instance in the subset
   * @return the selected cut points, or null if no acceptable split exists
   */
  private double[] cutPointsForSubset(
      Instances instances, int attIndex, int first, int lastPlusOne) {

    double[][] counts, bestCounts;
    double[] priorCounts, left, right, cutPoints;
    double currentCutPoint = -Double.MAX_VALUE,
        bestCutPoint = -1,
        currentEntropy,
        bestEntropy,
        priorEntropy,
        gain;
    int bestIndex = -1, numCutPoints = 0;
    double numInstances = 0;

    // Compute number of instances in set
    if ((lastPlusOne - first) < 2) {
      return null;
    }

    // Compute class counts.
    counts = new double[2][instances.numClasses()];
    for (int i = first; i < lastPlusOne; i++) {
      numInstances += instances.instance(i).weight();
      counts[1][(int) instances.instance(i).classValue()] += instances.instance(i).weight();
    }

    // Save prior counts
    priorCounts = new double[instances.numClasses()];
    System.arraycopy(counts[1], 0, priorCounts, 0, instances.numClasses());

    // Entropy of the full set
    priorEntropy = ContingencyTables.entropy(priorCounts);
    bestEntropy = priorEntropy;

    // Find best entropy.
    bestCounts = new double[2][instances.numClasses()];
    for (int i = first; i < (lastPlusOne - 1); i++) {
      counts[0][(int) instances.instance(i).classValue()] += instances.instance(i).weight();
      counts[1][(int) instances.instance(i).classValue()] -= instances.instance(i).weight();
      if (instances.instance(i).value(attIndex) < instances.instance(i + 1).value(attIndex)) {
        currentCutPoint =
            (instances.instance(i).value(attIndex) + instances.instance(i + 1).value(attIndex))
                / 2.0;
        currentEntropy = ContingencyTables.entropyConditionedOnRows(counts);
        if (currentEntropy < bestEntropy) {
          bestCutPoint = currentCutPoint;
          bestEntropy = currentEntropy;
          bestIndex = i;
          System.arraycopy(counts[0], 0, bestCounts[0], 0, instances.numClasses());
          System.arraycopy(counts[1], 0, bestCounts[1], 0, instances.numClasses());
        }
        numCutPoints++;
      }
    }

    // Use worse encoding?
    if (!m_UseBetterEncoding) {
      numCutPoints = (lastPlusOne - first) - 1;
    }

    // Checks if gain is zero
    gain = priorEntropy - bestEntropy;
    if (gain <= 0) {
      return null;
    }

    // Check if split is to be accepted
    if ((m_UseKononenko && KononenkosMDL(priorCounts, bestCounts, numInstances, numCutPoints))
        || (!m_UseKononenko
            && FayyadAndIranisMDL(priorCounts, bestCounts, numInstances, numCutPoints))) {

      // Select split points for the left and right subsets
      left = cutPointsForSubset(instances, attIndex, first, bestIndex + 1);
      right = cutPointsForSubset(instances, attIndex, bestIndex + 1, lastPlusOne);

      // Merge cutpoints and return them
      if ((left == null) && (right == null)) {
        cutPoints = new double[1];
        cutPoints[0] = bestCutPoint;
      } else if (right == null) {
        cutPoints = new double[left.length + 1];
        System.arraycopy(left, 0, cutPoints, 0, left.length);
        cutPoints[left.length] = bestCutPoint;
      } else if (left == null) {
        cutPoints = new double[1 + right.length];
        cutPoints[0] = bestCutPoint;
        System.arraycopy(right, 0, cutPoints, 1, right.length);
      } else {
        cutPoints = new double[left.length + right.length + 1];
        System.arraycopy(left, 0, cutPoints, 0, left.length);
        cutPoints[left.length] = bestCutPoint;
        System.arraycopy(right, 0, cutPoints, left.length + 1, right.length);
      }

      return cutPoints;
    } else {
      return null;
    }
  }
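
This recursive routine is the core of Weka's supervised, MDL-based discretization; direct callers normally go through the filter instead, as in this hedged sketch (assuming weka.filters.supervised.attribute.Discretize):

  Discretize disc = new Discretize();
  disc.setUseBetterEncoding(true); // corresponds to m_UseBetterEncoding above
  disc.setInputFormat(train);
  Instances discretized = Filter.useFilter(train, disc);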
Example #22
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_headerInfo = new Instances(instances, 0);
    m_numClasses = instances.numClasses();
    m_numAttributes = instances.numAttributes();
    m_probOfWordGivenClass = new double[m_numClasses][];

    /*
      initialising the matrix of word counts
      NOTE: Laplace estimator introduced in case a word that does not appear for a class in the
      training set does so for the test set
    */
    for (int c = 0; c < m_numClasses; c++) {
      m_probOfWordGivenClass[c] = new double[m_numAttributes];
      for (int att = 0; att < m_numAttributes; att++) {
        m_probOfWordGivenClass[c][att] = 1;
      }
    }

    // enumerate through the instances
    Instance instance;
    int classIndex;
    double numOccurences;
    double[] docsPerClass = new double[m_numClasses];
    double[] wordsPerClass = new double[m_numClasses];

    java.util.Enumeration enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      instance = (Instance) enumInsts.nextElement();
      classIndex = (int) instance.value(instance.classIndex());
      docsPerClass[classIndex] += instance.weight();

      for (int a = 0; a < instance.numValues(); a++)
        if (instance.index(a) != instance.classIndex()) {
          if (!instance.isMissing(a)) {
            numOccurences = instance.valueSparse(a) * instance.weight();
            if (numOccurences < 0)
              throw new Exception("Numeric attribute values must all be greater or equal to zero.");
            wordsPerClass[classIndex] += numOccurences;
            m_probOfWordGivenClass[classIndex][instance.index(a)] += numOccurences;
          }
        }
    }

    /*
      normalising probOfWordGivenClass values
      and saving each value as the log of each value
    */
    for (int c = 0; c < m_numClasses; c++)
      for (int v = 0; v < m_numAttributes; v++)
        m_probOfWordGivenClass[c][v] =
            Math.log(m_probOfWordGivenClass[c][v] / (wordsPerClass[c] + m_numAttributes - 1));

    /*
      calculating Pr(H)
      NOTE: Laplace estimator introduced in case a class does not get mentioned in the set of
      training instances
    */
    final double numDocs = instances.sumOfWeights() + m_numClasses;
    m_probOfClass = new double[m_numClasses];
    for (int h = 0; h < m_numClasses; h++)
      m_probOfClass[h] = (double) (docsPerClass[h] + 1) / numDocs;
  }
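
This learner expects word-count attributes, so raw text is usually pushed through StringToWordVector first; a hedged sketch combining the two with FilteredClassifier:

  FilteredClassifier fc = new FilteredClassifier();
  fc.setFilter(new StringToWordVector());       // turns the string attribute into word counts
  fc.setClassifier(new NaiveBayesMultinomial());
  fc.buildClassifier(textData);                 // textData: string attribute + nominal class (assumed)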
Example #23
File: VFI.java Project: SuperWan/weka
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (!m_weightByConfidence) {
      TINY = 0.0;
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_ClassIndex = instances.classIndex();
    m_NumClasses = instances.numClasses();
    m_globalCounts = new double[m_NumClasses];
    m_maxEntrop = Math.log(m_NumClasses) / Math.log(2);

    m_Instances = new Instances(instances, 0); // Copy the structure for ref

    m_intervalBounds = new double[instances.numAttributes()][2 + (2 * m_NumClasses)];

    for (int j = 0; j < instances.numAttributes(); j++) {
      boolean alt = false;
      for (int i = 0; i < m_NumClasses * 2 + 2; i++) {
        if (i == 0) {
          m_intervalBounds[j][i] = Double.NEGATIVE_INFINITY;
        } else if (i == m_NumClasses * 2 + 1) {
          m_intervalBounds[j][i] = Double.POSITIVE_INFINITY;
        } else {
          if (alt) {
            m_intervalBounds[j][i] = Double.NEGATIVE_INFINITY;
            alt = false;
          } else {
            m_intervalBounds[j][i] = Double.POSITIVE_INFINITY;
            alt = true;
          }
        }
      }
    }

    // find upper and lower bounds for numeric attributes
    for (int j = 0; j < instances.numAttributes(); j++) {
      if (j != m_ClassIndex && instances.attribute(j).isNumeric()) {
        for (int i = 0; i < instances.numInstances(); i++) {
          Instance inst = instances.instance(i);
          if (!inst.isMissing(j)) {
            if (inst.value(j) < m_intervalBounds[j][((int) inst.classValue() * 2 + 1)]) {
              m_intervalBounds[j][((int) inst.classValue() * 2 + 1)] = inst.value(j);
            }
            if (inst.value(j) > m_intervalBounds[j][((int) inst.classValue() * 2 + 2)]) {
              m_intervalBounds[j][((int) inst.classValue() * 2 + 2)] = inst.value(j);
            }
          }
        }
      }
    }

    m_counts = new double[instances.numAttributes()][][];

    // sort intervals
    for (int i = 0; i < instances.numAttributes(); i++) {
      if (instances.attribute(i).isNumeric()) {
        int[] sortedIntervals = Utils.sort(m_intervalBounds[i]);
        // remove any duplicate bounds
        int count = 1;
        for (int j = 1; j < sortedIntervals.length; j++) {
          if (m_intervalBounds[i][sortedIntervals[j]]
              != m_intervalBounds[i][sortedIntervals[j - 1]]) {
            count++;
          }
        }
        double[] reordered = new double[count];
        count = 1;
        reordered[0] = m_intervalBounds[i][sortedIntervals[0]];
        for (int j = 1; j < sortedIntervals.length; j++) {
          if (m_intervalBounds[i][sortedIntervals[j]]
              != m_intervalBounds[i][sortedIntervals[j - 1]]) {
            reordered[count] = m_intervalBounds[i][sortedIntervals[j]];
            count++;
          }
        }
        m_intervalBounds[i] = reordered;
        m_counts[i] = new double[count][m_NumClasses];
      } else if (i != m_ClassIndex) { // nominal attribute
        m_counts[i] = new double[instances.attribute(i).numValues()][m_NumClasses];
      }
    }

    // collect class counts
    for (int i = 0; i < instances.numInstances(); i++) {
      Instance inst = instances.instance(i);
      m_globalCounts[(int) instances.instance(i).classValue()] += inst.weight();
      for (int j = 0; j < instances.numAttributes(); j++) {
        if (!inst.isMissing(j) && j != m_ClassIndex) {
          if (instances.attribute(j).isNumeric()) {
            double val = inst.value(j);

            int k;
            for (k = m_intervalBounds[j].length - 1; k >= 0; k--) {
              if (val > m_intervalBounds[j][k]) {
                m_counts[j][k][(int) inst.classValue()] += inst.weight();
                break;
              } else if (val == m_intervalBounds[j][k]) {
                m_counts[j][k][(int) inst.classValue()] += (inst.weight() / 2.0);
                m_counts[j][k - 1][(int) inst.classValue()] += (inst.weight() / 2.0);
                break;
              }
            }

          } else {
            // nominal attribute
            m_counts[j][(int) inst.value(j)][(int) inst.classValue()] += inst.weight();
          }
        }
      }
    }
  }
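
A hedged evaluation sketch for the VFI learner built above (weka.classifiers.misc.VFI assumed as the host class):

  VFI vfi = new VFI();
  vfi.buildClassifier(train);
  Evaluation eval = new Evaluation(train);
  eval.crossValidateModel(vfi, train, 10, new Random(1));
  System.out.println(eval.toSummaryString());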
Example #24
  public static void analyze_accuracy_NHBS(int rng_seed) throws Exception {
    HashMap<String, Object> population_params = load_defaults(null);
    RawLoader rl = new RawLoader(population_params, true, false, rng_seed);
    List<DrugUser> learningData = rl.getLearningData();

    Instances nhbs_data =
        new Instances("learning_instances", DrugUser.getAttInfo(), learningData.size());
    for (DrugUser du : learningData) {
      nhbs_data.add(du.getInstance());
    }
    System.out.println(nhbs_data.toSummaryString());
    nhbs_data.setClass(DrugUser.getAttribMap().get("hcv_state"));

    // wishlist: remove infrequent values
    // weka.filters.unsupervised.instance.RemoveFrequentValues()
    Filter f1 = new RemoveUseless();
    f1.setInputFormat(nhbs_data);
    nhbs_data = Filter.useFilter(nhbs_data, f1);

    System.out.println("NHBS IDU 2009 Dataset");
    System.out.println("Summary of input:");
    // System.out.println(nhbs_data.toSummaryString());
    System.out.println("  Num of classes: " + nhbs_data.numClasses());
    System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
    for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
      Attribute attr = nhbs_data.attribute(idx);
      System.out.println("" + idx + ": " + attr.toString());
      System.out.println("     distinct values:" + nhbs_data.numDistinctValues(idx));
      // System.out.println("" + attr.enumerateValues());
    }

    ArrayList<String> options = new ArrayList<String>();
    options.add("-Q");
    options.add("" + rng_seed);
    // System.exit(0);
    // nhbs_data.deleteAttributeAt(0); //response ID
    // nhbs_data.deleteAttributeAt(16); //zip

    // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00
    // ROC=0.60
    // Classifier classifier = new MINND();
    // Classifier classifier = new CitationKNN();
    // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7%
    // Classifier classifier = new SMOreg();
    Classifier classifier = new Logistic();
    // ROC=0.686
    // Classifier classifier = new LinearNNSearch();

    // LinearRegression: Cannot handle multi-valued nominal class!
    // Classifier classifier = new LinearRegression();

    // Classifier classifier = new RandomForest();
    // String[] options = {"-I", "100", "-K", "4"}; //-I trees, -K features per tree.  generally,
    // might want to optimize (or not
    // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
    // options.add("-I"); options.add("100"); options.add("-K"); options.add("4");
    // ROC=0.673

    // KStar classifier = new KStar();
    // classifier.setGlobalBlend(20); //the amount of not greedy, in percent
    // ROC=0.633

    // Classifier classifier = new AdaBoostM1();
    // ROC=0.66
    // Classifier classifier = new MultiBoostAB();
    // ROC=0.67
    // Classifier classifier = new Stacking();
    // ROC=0.495

    // J48 classifier = new J48(); // new instance of tree //building a C45 tree classifier
    // ROC=0.585
    // String[] options = new String[1];
    // options[0] = "-U"; // unpruned tree
    // classifier.setOptions(options); // set the options

    classifier.setOptions((String[]) options.toArray(new String[0]));

    // not needed before CV: http://weka.wikispaces.com/Use+WEKA+in+your+Java+code
    // classifier.buildClassifier(nhbs_data); // build classifier

    // evaluation
    Evaluation eval = new Evaluation(nhbs_data);
    eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1)); // 10-fold cross validation
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
    System.out.println(eval.toClassDetailsString());
    // System.out.println(eval.toCumulativeMarginDistributionString());
  }
Example #25
  public static void test_NHBS_old() throws Exception {
    // load the data
    CSVLoader loader = new CSVLoader();
    // these must come before the getDataSet()
    // loader.setEnclosureCharacters(",\'\"S");
    // loader.setNominalAttributes("16,71"); //zip code, drug name
    // loader.setStringAttributes("");
    // loader.setDateAttributes("0,1");
    // loader.setSource(new File("hcv/data/NHBS/IDU2_HCV_model_012913_cleaned_for_weka.csv"));
    loader.setSource(new File("/home/sasha/hcv/code/data/IDU2_HCV_model_012913_cleaned.csv"));
    Instances nhbs_data = loader.getDataSet();
    loader.setMissingValue("NOVALUE");
    // loader.setMissingValue("");

    nhbs_data.deleteAttributeAt(12); // zip code
    nhbs_data.deleteAttributeAt(1); // date - redundant with age
    nhbs_data.deleteAttributeAt(0); // date
    System.out.println("classifying attribute:");
    nhbs_data.setClassIndex(1); // new index  3->2->1
    nhbs_data.attribute(1).getMetadata().toString(); // HCVEIARSLT1

    // wishlist: perhaps it would be smarter to throw out unclassified instance?  they interfere
    // with the scoring
    nhbs_data.deleteWithMissingClass();
    // nhbs_data.setClass(new Attribute("HIVRSLT"));//.setClassIndex(1); //2nd column.  all are
    // mostly negative
    // nhbs_data.setClass(new Attribute("HCVEIARSLT1"));//.setClassIndex(2); //3rd column

    // #14, i.e. rds_fem, should be made numeric
    System.out.println("NHBS IDU 2009 Dataset");
    System.out.println("Summary of input:");
    // System.out.println(nhbs_data.toSummaryString());
    System.out.println("  Num of classes: " + nhbs_data.numClasses());
    System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
    for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
      Attribute attr = nhbs_data.attribute(idx);
      System.out.println("" + idx + ": " + attr.toString());
      System.out.println("     distinct values:" + nhbs_data.numDistinctValues(idx));
      // System.out.println("" + attr.enumerateValues());
    }

    // System.exit(0);
    // nhbs_data.deleteAttributeAt(0); //response ID
    // nhbs_data.deleteAttributeAt(16); //zip

    // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00
    // Classifier classifier = new MINND();
    // Classifier classifier = new CitationKNN();
    // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7%
    // Classifier classifier = new SMOreg();
    // Classifier classifier = new LinearNNSearch();

    // LinearRegression: Cannot handle multi-valued nominal class!
    // Classifier classifier = new LinearRegression();

    Classifier classifier = new RandomForest();
    String[] options = {
      "-I", "100", "-K", "4"
    }; // -I trees, -K features per tree.  generally, might want to optimize (or not
       // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
    classifier.setOptions(options);
    // Classifier classifier = new Logistic();

    // KStar classifier = new KStar();
    // classifier.setGlobalBlend(20); //the amount of not greedy, in percent

    // does poorly
    // Classifier classifier = new AdaBoostM1();
    // Classifier classifier = new MultiBoostAB();
    // Classifier classifier = new Stacking();

    // building a C45 tree classifier
    // J48 classifier = new J48(); // new instance of tree
    // String[] options = new String[1];
    // options[0] = "-U"; // unpruned tree
    // classifier.setOptions(options); // set the options
    // classifier.buildClassifier(nhbs_data); // build classifier

    // wishlist: remove infrequent values
    // weka.filters.unsupervised.instance.RemoveFrequentValues()
    Filter f1 = new RemoveUseless();
    f1.setInputFormat(nhbs_data);
    nhbs_data = Filter.useFilter(nhbs_data, f1);

    // evaluation
    Evaluation eval = new Evaluation(nhbs_data);
    eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1));
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
    System.out.println(eval.toClassDetailsString());
    // System.out.println(eval.toCumulativeMarginDistributionString());
  }
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if there is a problem generating the prediction
   */
  @Override
  public double[] distributionForInstance(Instance instance) throws Exception {

    tokenizeInstance(instance, false);

    double[] probOfClassGivenDoc = new double[m_data.numClasses()];

    double[] logDocGivenClass = new double[m_data.numClasses()];
    for (int i = 0; i < m_data.numClasses(); i++) {
      logDocGivenClass[i] += Math.log(m_probOfClass[i]);

      LinkedHashMap<String, Count> dictForClass = m_probOfWordGivenClass.get(i);

      int allWords = 0;
      // for document normalization (if in use)
      double iNorm = 0;
      double fv = 0;

      if (m_normalize) {
        for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
          String word = feature.getKey();
          Count c = feature.getValue();

          // check the word against all the dictionaries (all classes)
          boolean ok = false;
          for (int clss = 0; clss < m_data.numClasses(); clss++) {
            if (m_probOfWordGivenClass.get(clss).get(word) != null) {
              ok = true;
              break;
            }
          }

          // only normalize with respect to those words that we've seen during
          // training
          // (i.e. dictionary over all classes)
          if (ok) {
            // word counts or bag-of-words?
            fv = (m_wordFrequencies) ? c.m_count : 1.0;
            iNorm += Math.pow(Math.abs(fv), m_lnorm);
          }
        }
        iNorm = Math.pow(iNorm, 1.0 / m_lnorm);
      }

      // System.out.println("---- " + m_inputVector.size());
      for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
        String word = feature.getKey();
        Count dictCount = dictForClass.get(word);
        // System.out.print(word + " ");
        /*
         * if (dictCount != null) { System.out.println(dictCount.m_count); }
         * else { System.out.println("*1"); }
         */
        // check the word against all the dictionaries (all classes)
        boolean ok = false;
        for (int clss = 0; clss < m_data.numClasses(); clss++) {
          if (m_probOfWordGivenClass.get(clss).get(word) != null) {
            ok = true;
            break;
          }
        }

        // ignore words we haven't seen in the training data
        if (ok) {
          double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0;
          // double freq = (feature.getValue().m_count / iNorm * m_norm);
          if (m_normalize) {
            freq /= iNorm * m_norm;
          }
          allWords += freq;

          if (dictCount != null) {
            logDocGivenClass[i] += freq * Math.log(dictCount.m_count);
          } else {
            // laplace count for zero frequency
            logDocGivenClass[i] += freq * Math.log(m_leplace);
          }
        }
      }

      if (m_wordsPerClass[i] > 0) {
        logDocGivenClass[i] -= allWords * Math.log(m_wordsPerClass[i]);
      }
    }

    double max = logDocGivenClass[Utils.maxIndex(logDocGivenClass)];

    for (int i = 0; i < m_data.numClasses(); i++) {
      probOfClassGivenDoc[i] = Math.exp(logDocGivenClass[i] - max);
    }

    Utils.normalize(probOfClassGivenDoc);

    return probOfClassGivenDoc;
  }
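
Turning the returned distribution into a label, sketched with Weka's usual argmax idiom (inst is assumed to share the training header):

  double[] dist = distributionForInstance(inst);
  int predicted = Utils.maxIndex(dist);
  System.out.println("predicted: " + m_data.classAttribute().value(predicted)
      + " (p = " + dist[predicted] + ")");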
Example #27
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // only class? -> build ZeroR model
    if (instances.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
              + "using ZeroR model instead!");
      m_ZeroR = new weka.classifiers.rules.ZeroR();
      m_ZeroR.buildClassifier(instances);
      return;
    } else {
      m_ZeroR = null;
    }

    // reset variable
    m_NumClasses = instances.numClasses();
    m_ClassIndex = instances.classIndex();
    m_NumAttributes = instances.numAttributes();
    m_NumInstances = instances.numInstances();
    m_TotalAttValues = 0;

    // allocate space for attribute reference arrays
    m_StartAttIndex = new int[m_NumAttributes];
    m_NumAttValues = new int[m_NumAttributes];

    // set the starting index of each attribute and the number of values for
    // each attribute and the total number of values for all attributes (not including class).
    for (int i = 0; i < m_NumAttributes; i++) {
      if (i != m_ClassIndex) {
        m_StartAttIndex[i] = m_TotalAttValues;
        m_NumAttValues[i] = instances.attribute(i).numValues();
        m_TotalAttValues += m_NumAttValues[i];
      } else {
        m_StartAttIndex[i] = -1;
        m_NumAttValues[i] = m_NumClasses;
      }
    }

    // allocate space for counts and frequencies
    m_ClassCounts = new double[m_NumClasses];
    m_AttCounts = new double[m_TotalAttValues];
    m_AttAttCounts = new double[m_TotalAttValues][m_TotalAttValues];
    m_ClassAttAttCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues];
    m_Header = new Instances(instances, 0);

    // Calculate the counts
    for (int k = 0; k < m_NumInstances; k++) {
      int classVal = (int) instances.instance(k).classValue();
      m_ClassCounts[classVal]++;
      int[] attIndex = new int[m_NumAttributes];
      for (int i = 0; i < m_NumAttributes; i++) {
        if (i == m_ClassIndex) {
          attIndex[i] = -1;
        } else {
          attIndex[i] = m_StartAttIndex[i] + (int) instances.instance(k).value(i);
          m_AttCounts[attIndex[i]]++;
        }
      }
      for (int Att1 = 0; Att1 < m_NumAttributes; Att1++) {
        if (attIndex[Att1] == -1) continue;
        for (int Att2 = 0; Att2 < m_NumAttributes; Att2++) {
          if ((attIndex[Att2] != -1)) {
            m_AttAttCounts[attIndex[Att1]][attIndex[Att2]]++;
            m_ClassAttAttCounts[classVal][attIndex[Att1]][attIndex[Att2]]++;
          }
        }
      }
    }

    // compute mutual information between each attribute and class
    m_mutualInformation = new double[m_NumAttributes];
    for (int att = 0; att < m_NumAttributes; att++) {
      if (att == m_ClassIndex) continue;
      m_mutualInformation[att] = mutualInfo(att);
    }
  }
Example #28
  /**
   * Creates split on numeric attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleNumericAttribute(Instances trainInstances) throws Exception {

    int firstMiss;
    int next = 1;
    int last = 0;
    int index = 0;
    int splitIndex = -1;
    double currentInfoGain;
    double defaultEnt;
    double minSplit;
    Instance instance;
    int i;

    // Current attribute is a numeric attribute.
    m_distribution = new Distribution(2, trainInstances.numClasses());

    // Only Instances with known values are relevant.
    Enumeration enu = trainInstances.enumerateInstances();
    i = 0;
    while (enu.hasMoreElements()) {
      instance = (Instance) enu.nextElement();
      if (instance.isMissing(m_attIndex)) break;
      m_distribution.add(1, instance);
      i++;
    }
    firstMiss = i;

    // Compute minimum number of Instances required in each
    // subset.
    minSplit = 0.1 * (m_distribution.total()) / ((double) trainInstances.numClasses());
    if (Utils.smOrEq(minSplit, m_minNoObj)) minSplit = m_minNoObj;
    else if (Utils.gr(minSplit, 25)) minSplit = 25;

    // Enough Instances with known values?
    if (Utils.sm((double) firstMiss, 2 * minSplit)) return;

    // Compute values of criteria for all possible split
    // indices.
    defaultEnt = m_infoGainCrit.oldEnt(m_distribution);
    while (next < firstMiss) {

      if (trainInstances.instance(next - 1).value(m_attIndex) + 1e-5
          < trainInstances.instance(next).value(m_attIndex)) {

        // Move class values for all Instances up to next
        // possible split point.
        m_distribution.shiftRange(1, 0, trainInstances, last, next);

        // Check if enough Instances in each subset and compute
        // values for criteria.
        if (Utils.grOrEq(m_distribution.perBag(0), minSplit)
            && Utils.grOrEq(m_distribution.perBag(1), minSplit)) {
          currentInfoGain =
              m_infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights, defaultEnt);
          if (Utils.gr(currentInfoGain, m_infoGain)) {
            m_infoGain = currentInfoGain;
            splitIndex = next - 1;
          }
          index++;
        }
        last = next;
      }
      next++;
    }

    // Was there any useful split?
    if (index == 0) return;

    // Compute modified information gain for best split.
    if (m_useMDLcorrection) {
      m_infoGain = m_infoGain - (Utils.log2(index) / m_sumOfWeights);
    }
    if (Utils.smOrEq(m_infoGain, 0)) return;

    // Set instance variables' values to values for
    // best split.
    m_numSubsets = 2;
    m_splitPoint =
        (trainInstances.instance(splitIndex + 1).value(m_attIndex)
                + trainInstances.instance(splitIndex).value(m_attIndex))
            / 2;

    // In case we have a numerical precision problem we need to choose the
    // smaller value
    if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {
      m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);
    }

    // Restore distribution for best split.
    m_distribution = new Distribution(2, trainInstances.numClasses());
    m_distribution.addRange(0, trainInstances, 0, splitIndex + 1);
    m_distribution.addRange(1, trainInstances, splitIndex + 1, firstMiss);

    // Compute modified gain ratio for best split.
    m_gainRatio = m_gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights, m_infoGain);
  }
  protected void tokenizeInstance(Instance instance, boolean updateDictionary) {
    if (m_inputVector == null) {
      m_inputVector = new LinkedHashMap<String, Count>();
    } else {
      m_inputVector.clear();
    }

    if (m_useStopList && m_stopwords == null) {
      m_stopwords = new Stopwords();
      try {
        if (getStopwords().exists() && !getStopwords().isDirectory()) {
          m_stopwords.read(getStopwords());
        }
      } catch (Exception ex) {
        ex.printStackTrace();
      }
    }

    for (int i = 0; i < instance.numAttributes(); i++) {
      if (instance.attribute(i).isString() && !instance.isMissing(i)) {
        m_tokenizer.tokenize(instance.stringValue(i));

        while (m_tokenizer.hasMoreElements()) {
          String word = m_tokenizer.nextElement();
          if (m_lowercaseTokens) {
            word = word.toLowerCase();
          }

          word = m_stemmer.stem(word);

          if (m_useStopList) {
            if (m_stopwords.is(word)) {
              continue;
            }
          }

          Count docCount = m_inputVector.get(word);
          if (docCount == null) {
            m_inputVector.put(word, new Count(instance.weight()));
          } else {
            docCount.m_count += instance.weight();
          }
        }
      }
    }

    if (updateDictionary) {
      int classValue = (int) instance.classValue();
      LinkedHashMap<String, Count> dictForClass = m_probOfWordGivenClass.get(classValue);

      // document normalization
      double iNorm = 0;
      double fv = 0;

      if (m_normalize) {
        for (Count c : m_inputVector.values()) {
          // word counts or bag-of-words?
          fv = (m_wordFrequencies) ? c.m_count : 1.0;
          iNorm += Math.pow(Math.abs(fv), m_lnorm);
        }
        iNorm = Math.pow(iNorm, 1.0 / m_lnorm);
      }

      for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
        String word = feature.getKey();
        double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0;
        // double freq = (feature.getValue().m_count / iNorm * m_norm);

        if (m_normalize) {
          freq /= (iNorm * m_norm);
        }

        // check all classes
        for (int i = 0; i < m_data.numClasses(); i++) {
          LinkedHashMap<String, Count> dict = m_probOfWordGivenClass.get(i);
          if (dict.get(word) == null) {
            dict.put(word, new Count(m_leplace));
            m_wordsPerClass[i] += m_leplace;
          }
        }

        Count dictCount = dictForClass.get(word);
        /*
         * if (dictCount == null) { dictForClass.put(word, new Count(m_leplace +
         * freq)); m_wordsPerClass[classValue] += (m_leplace + freq); } else {
         */
        dictCount.m_count += freq;
        m_wordsPerClass[classValue] += freq;
        // }
      }

      pruneDictionary();
    }
  }
Example #30
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  @Override
  public void buildClassifier(Instances instances) throws Exception {

    int attIndex = 0;
    double sum;

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_Instances = new Instances(instances, 0);

    // Reserve space
    m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0];
    m_Means = new double[instances.numClasses()][instances.numAttributes() - 1];
    m_Devs = new double[instances.numClasses()][instances.numAttributes() - 1];
    m_Priors = new double[instances.numClasses()];
    Enumeration<Attribute> enu = instances.enumerateAttributes();
    while (enu.hasMoreElements()) {
      Attribute attribute = enu.nextElement();
      if (attribute.isNominal()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          m_Counts[j][attIndex] = new double[attribute.numValues()];
        }
      } else {
        for (int j = 0; j < instances.numClasses(); j++) {
          m_Counts[j][attIndex] = new double[1];
        }
      }
      attIndex++;
    }

    // Compute counts and sums
    Enumeration<Instance> enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = enumInsts.nextElement();
      if (!instance.classIsMissing()) {
        Enumeration<Attribute> enumAtts = instances.enumerateAttributes();
        attIndex = 0;
        while (enumAtts.hasMoreElements()) {
          Attribute attribute = enumAtts.nextElement();
          if (!instance.isMissing(attribute)) {
            if (attribute.isNominal()) {
              m_Counts[(int) instance.classValue()][attIndex][(int) instance.value(attribute)]++;
            } else {
              m_Means[(int) instance.classValue()][attIndex] += instance.value(attribute);
              m_Counts[(int) instance.classValue()][attIndex][0]++;
            }
          }
          attIndex++;
        }
        m_Priors[(int) instance.classValue()]++;
      }
    }

    // Compute means
    Enumeration<Attribute> enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = enumAtts.nextElement();
      if (attribute.isNumeric()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          if (m_Counts[j][attIndex][0] < 2) {
            throw new Exception(
                "attribute "
                    + attribute.name()
                    + ": less than two values for class "
                    + instances.classAttribute().value(j));
          }
          m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
        }
      }
      attIndex++;
    }

    // Compute standard deviations
    enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = enumInsts.nextElement();
      if (!instance.classIsMissing()) {
        enumAtts = instances.enumerateAttributes();
        attIndex = 0;
        while (enumAtts.hasMoreElements()) {
          Attribute attribute = enumAtts.nextElement();
          if (!instance.isMissing(attribute)) {
            if (attribute.isNumeric()) {
              m_Devs[(int) instance.classValue()][attIndex] +=
                  (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute))
                      * (m_Means[(int) instance.classValue()][attIndex]
                          - instance.value(attribute));
            }
          }
          attIndex++;
        }
      }
    }
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = enumAtts.nextElement();
      if (attribute.isNumeric()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          if (m_Devs[j][attIndex] <= 0) {
            throw new Exception(
                "attribute "
                    + attribute.name()
                    + ": standard deviation is 0 for class "
                    + instances.classAttribute().value(j));
          } else {
            m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
            m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
          }
        }
      }
      attIndex++;
    }

    // Normalize counts
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = enumAtts.nextElement();
      if (attribute.isNominal()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          sum = Utils.sum(m_Counts[j][attIndex]);
          for (int i = 0; i < attribute.numValues(); i++) {
            m_Counts[j][attIndex][i] =
                (m_Counts[j][attIndex][i] + 1) / (sum + attribute.numValues());
          }
        }
      }
      attIndex++;
    }

    // Normalize priors
    sum = Utils.sum(m_Priors);
    for (int j = 0; j < instances.numClasses(); j++) {
      m_Priors[j] = (m_Priors[j] + 1) / (sum + instances.numClasses());
    }
  }
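
A closing usage sketch for this NaiveBayesSimple-style learner, tying it back to the toString() dump in Example #13; numeric attributes need at least two values and non-zero variance per class, otherwise the exceptions above are thrown:

  NaiveBayesSimple nb = new NaiveBayesSimple(); // weka.classifiers.bayes.NaiveBayesSimple (assumed host class)
  nb.buildClassifier(train);
  System.out.println(nb); // per-class priors, nominal counts, numeric means and std devs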