예제 #1
0
  /**
   * Predicts the location name for one sample using the classifiers stored on the user.
   *
   * <p>A dataset header is built from the user's BSSIDs (attributes) and location names (class
   * values); the sample is turned into a single instance on that header and handed to the
   * {@code classify} overload that takes the decoded classifiers.
   *
   * @param user supplies the BSSIDs, the locations and the serialized classifiers
   * @param sample the reading to classify; its location, when non-null, is set as the class value
   * @return the class-attribute value at the predicted index
   */
  @Override
  public String classify(User user, Sample sample) {

    Instances header =
        new TrainingSetBuilder()
            .setAttributes(user.getBssids())
            .setClassAttribute(
                "Location",
                user.getLocations().stream().map(Location::getName).collect(Collectors.toList()))
            .build("TrainingSet", 1);

    Map<String, Integer> levelsByBssid = getBSSIDLevelMap(sample);
    Instance query = new Instance(header.numAttributes());

    // Fill every non-class attribute; a BSSID missing from the sample gets level 0.
    Enumeration attributes = header.enumerateAttributes();
    while (attributes.hasMoreElements()) {
      Attribute attr = (Attribute) attributes.nextElement();
      String bssid = attr.name();
      int level = 0;
      if (levelsByBssid.containsKey(bssid)) {
        level = levelsByBssid.get(bssid);
      }
      query.setValue(attr, level);
    }

    if (sample.getLocation() != null) {
      query.setValue(header.classAttribute(), sample.getLocation());
    }

    query.setDataset(header);
    header.add(query);

    int predictedIndex = classify(fromBase64(user.getClassifiers()), query);
    return header.classAttribute().value(predictedIndex);
  }
예제 #2
0
  /**
   * Renders the model as tab-separated text.
   *
   * <p>The output has two sections: one line per class with its stored probability, followed by a
   * table with one row per attribute and one column per class. The per-word values are passed
   * through {@code Math.exp} before being printed.
   *
   * @return the textual form of the classifier
   */
  public String toString() {
    StringBuilder out = new StringBuilder();

    // Section 1: class probabilities.
    out.append("The independent probability of a class\n--------------------------------------\n");
    for (int cls = 0; cls < m_numClasses; cls++) {
      out.append(m_headerInfo.classAttribute().value(cls));
      out.append("\t");
      out.append(Double.toString(m_probOfClass[cls]));
      out.append("\n");
    }

    // Section 2: header row of class names, then one row per attribute.
    out.append(
        "\nThe probability of a word given the class\n-----------------------------------------\n\t");
    for (int cls = 0; cls < m_numClasses; cls++) {
      out.append(m_headerInfo.classAttribute().value(cls));
      out.append("\t");
    }
    out.append("\n");

    for (int word = 0; word < m_numAttributes; word++) {
      out.append(m_headerInfo.attribute(word).name());
      out.append("\t");
      for (int cls = 0; cls < m_numClasses; cls++) {
        double shown = Math.exp(m_probOfWordGivenClass[cls][word]);
        out.append(Double.toString(shown));
        out.append("\t");
      }
      out.append("\n");
    }

    return out.toString();
  }
예제 #3
0
  /**
   * Converts every training instance into a {@code Sequence}, groups the sequences by class label
   * and then calls {@code buildSpecificClassifier(data)}.
   *
   * <p>Fields populated here: {@code trainingData}, {@code prototypes} (reset to empty), {@code
   * classedData}, {@code indexClassedDataInFullData}, {@code sequences} and {@code classMap}.
   *
   * @param data the training set; its class attribute must be set
   * @throws Exception declared by the overridden method
   */
  @Override
  public void buildClassifier(Instances data) throws Exception {
    trainingData = data;
    Attribute classAttribute = data.classAttribute();
    prototypes = new ArrayList<>();

    // One empty bucket per declared class value, so lookups below never miss.
    classedData = new HashMap<String, ArrayList<Sequence>>();
    indexClassedDataInFullData = new HashMap<String, ArrayList<Integer>>();
    int classCount = data.numClasses();
    for (int cls = 0; cls < classCount; cls++) {
      String label = classAttribute.value(cls);
      classedData.put(label, new ArrayList<Sequence>());
      indexClassedDataInFullData.put(label, new ArrayList<Integer>());
    }

    int total = data.numInstances();
    sequences = new Sequence[total];
    classMap = new String[total];
    for (int row = 0; row < total; row++) {
      Instance sample = data.instance(row);

      // Skip the class column when it is first; otherwise values are read from index 0.
      // NOTE(review): this presumably assumes the class is either the first or the last attribute.
      int offset = 0;
      if (sample.classIndex() == 0) {
        offset = 1;
      }

      MonoDoubleItemSet[] items = new MonoDoubleItemSet[sample.numAttributes() - 1];
      for (int pos = 0; pos < items.length; pos++) {
        items[pos] = new MonoDoubleItemSet(sample.value(pos + offset));
      }
      sequences[row] = new Sequence(items);

      String label = sample.stringValue(classAttribute);
      classMap[row] = label;
      classedData.get(label).add(sequences[row]);
      indexClassedDataInFullData.get(label).add(row);
    }

    buildSpecificClassifier(data);
  }
  /**
   * Produces a textual dump of the model.
   *
   * <p>The dump lists the size of the combined dictionary over all classes, the stored value for
   * each class, and then one row per dictionary word with its per-class count. A word that is
   * absent from a class dictionary is shown as {@code <laplace=1>}.
   *
   * @return the description, or a "no model built yet" message when {@code m_probOfClass} is null
   */
  @Override
  public String toString() {

    if (m_probOfClass == null) {
      return "NaiveBayesMultinomialText: No model built yet.\n";
    }

    final int classCount = m_data.numClasses();
    StringBuilder out = new StringBuilder();

    // Union of the keys of every per-class dictionary.
    HashSet<String> vocabulary = new HashSet<String>();
    for (int cls = 0; cls < classCount; cls++) {
      LinkedHashMap<String, Count> dictionary = m_probOfWordGivenClass.get(cls);
      vocabulary.addAll(dictionary.keySet());
    }

    out.append("Dictionary size: " + vocabulary.size()).append("\n\n");

    out.append("The independent frequency of a class\n");
    out.append("--------------------------------------\n");
    for (int cls = 0; cls < classCount; cls++) {
      out.append(m_data.classAttribute().value(cls));
      out.append("\t");
      out.append(Double.toString(m_probOfClass[cls]));
      out.append("\n");
    }

    out.append("\nThe frequency of a word given the class\n");
    out.append("-----------------------------------------\n");
    for (int cls = 0; cls < classCount; cls++) {
      out.append(Utils.padLeft(m_data.classAttribute().value(cls), 11)).append("\t");
    }
    out.append("\n");

    // One row per word: per-class cells first, the word itself last.
    for (String word : vocabulary) {
      for (int cls = 0; cls < classCount; cls++) {
        LinkedHashMap<String, Count> dictionary = m_probOfWordGivenClass.get(cls);
        Count count = dictionary.get(word);
        if (count != null) {
          out.append(Utils.padLeft(Double.toString(count.m_count), 11)).append("\t");
        } else {
          out.append("<laplace=1>\t");
        }
      }
      out.append(word);
      out.append("\n");
    }

    return out.toString();
  }
예제 #5
0
  /**
   * Adds an instance to the hash table, creating or updating the entry for its key.
   *
   * <p>For a nominal class the entry is an array with one slot per class value; for a numeric
   * class it is a two-element array holding the weighted class-value sum and the weight sum.
   *
   * @param inst instance to be inserted
   * @param instA values to build the hash key from; when null the key is built from {@code inst}
   * @throws Exception if the instance can't be inserted
   */
  private void insertIntoTable(Instance inst, double[] instA) throws Exception {

    DecisionTableHashKey key =
        (instA != null)
            ? new DecisionTableHashKey(instA)
            : new DecisionTableHashKey(inst, inst.numAttributes(), false);

    double[] existing = (double[]) m_entries.get(key);

    if (existing != null) {
      // Key already present: accumulate into the stored distribution and store it back.
      if (m_classIsNominal) {
        existing[(int) inst.classValue()] += inst.weight();
      } else {
        existing[0] += (inst.classValue() * inst.weight());
        existing[1] += inst.weight();
      }
      m_entries.put(key, existing);
      return;
    }

    // First occurrence of this key: build a fresh distribution.
    double[] fresh;
    if (m_classIsNominal) {
      int numValues = m_theInstances.classAttribute().numValues();
      fresh = new double[numValues];

      // Laplace estimation: every class value starts at 1.
      for (int v = 0; v < numValues; v++) {
        fresh[v] = 1.0;
      }

      // NOTE(review): this overwrites the Laplace prior for the observed class instead of adding
      // to it - confirm that is intended.
      fresh[(int) inst.classValue()] = inst.weight();
    } else {
      fresh = new double[2];
      fresh[0] = inst.classValue() * inst.weight();
      fresh[1] = inst.weight();
    }
    m_entries.put(key, fresh);
  }
예제 #6
0
  /**
   * Reorders the data so that class values are spread evenly over the given number of folds.
   *
   * <p>Unlike {@code Instances.stratify(int)}, the instances are first grouped by class in the
   * order the class values appear in the header, and each group is shuffled. The groups are then
   * treated as one long sequence from which every {@code folds}-th instance is taken, starting at
   * offsets 0, 1, ..., {@code folds - 1}. No missing class values are expected.
   *
   * @param data the given data
   * @param folds the given number of folds
   * @param rand the random object used to randomize the instances
   * @return the stratified instances, or {@code data} itself when the class is not nominal
   */
  public static final Instances stratify(Instances data, int folds, Random rand) {
    if (!data.classAttribute().isNominal()) {
      return data;
    }

    Instances result = new Instances(data, 0);

    // One empty bag per class value, then distribute the instances.
    Instances[] perClass = new Instances[data.numClasses()];
    for (int cls = 0; cls < perClass.length; cls++) {
      perClass[cls] = new Instances(data, 0);
    }
    for (int row = 0; row < data.numInstances(); row++) {
      Instance datum = data.instance(row);
      perClass[(int) datum.classValue()].add(datum);
    }

    // Shuffle within each class.
    for (int cls = 0; cls < perClass.length; cls++) {
      perClass[cls].randomize(rand);
    }

    for (int fold = 0; fold < folds; fold++) {
      int pos = fold; // position relative to the start of the current bag
      int cls = 0;
      boolean exhausted = false;
      while (!exhausted) {
        // Step over bags until pos falls inside one, or we run out of bags.
        while (!exhausted && pos >= perClass[cls].numInstances()) {
          pos -= perClass[cls].numInstances();
          cls++;
          exhausted = cls >= perClass.length;
        }
        if (!exhausted) {
          result.add(perClass[cls].instance(pos));
          pos += folds;
        }
      }
    }

    return result;
  }
예제 #7
0
  /**
   * Writes the predicted class distribution of every instance as CSV.
   *
   * <p>The header row is {@code id} followed by each class value in double quotes, with any
   * double quote or backslash in the value replaced by an underscore. Each data row is the
   * instance's id followed by one probability per class; probabilities not greater than 1e-5 are
   * written as 0.0 and the rest are narrowed to {@code float}.
   *
   * @param c the classifier used to obtain distributions
   * @param data the instances to predict
   * @param idIndex index of the attribute whose string value is used as the row id
   * @param out destination; it is not closed by this method
   * @throws Exception if prediction or writing fails
   */
  private static void writePredictedDistributions(
      Classifier c, Instances data, int idIndex, Writer out) throws Exception {
    // Header row.
    out.write("id");
    final int classCount = data.numClasses();
    for (int cls = 0; cls < classCount; cls++) {
      out.write(",\"");
      String safeLabel = data.classAttribute().value(cls).replaceAll("[\"\\\\]", "_");
      out.write(safeLabel);
      out.write("\"");
    }
    out.write("\n");

    // One row per instance.
    for (int row = 0; row < data.numInstances(); row++) {
      final String id = data.instance(row).stringValue(idIndex);
      double[] distribution = c.distributionForInstance(data.instance(row));

      out.write(id);
      for (int cls = 0; cls < distribution.length; cls++) {
        float shown = 0f;
        if (distribution[cls] > 1e-5) {
          shown = (float) distribution[cls];
        }
        out.write(",");
        out.write(String.valueOf(shown));
      }
      out.write("\n");
    }
  }
예제 #8
0
  /**
   * Formats the fitted linear regression as a sum of {@code coefficient * attribute} terms
   * followed by a final standalone coefficient.
   *
   * @return the model as a string, a "no model built yet" message when no transformed data is
   *     present, or a fixed error message if formatting throws
   */
  public String toString() {

    if (m_TransformedData == null) {
      return "Linear Regression: No model built yet.";
    }
    try {
      StringBuilder out = new StringBuilder("\nLinear Regression Model\n\n");
      out.append(m_TransformedData.classAttribute().name() + " =\n\n");

      // termIndex walks m_Coefficients; it only advances for attributes that are selected.
      int termIndex = 0;
      final int attributeCount = m_TransformedData.numAttributes();
      for (int att = 0; att < attributeCount; att++) {
        if ((att == m_ClassIndex) || !m_SelectedAttributes[att]) {
          continue;
        }
        if (termIndex > 0) {
          out.append(" +\n");
        }
        out.append(Utils.doubleToString(m_Coefficients[termIndex], 12, 4) + " * ");
        out.append(m_TransformedData.attribute(att).name());
        termIndex++;
      }

      // The coefficient after the last attribute term is printed on its own.
      out.append(" +\n" + Utils.doubleToString(m_Coefficients[termIndex], 12, 4));
      return out.toString();
    } catch (Exception e) {
      return "Can't print Linear Regression!";
    }
  }
예제 #9
0
  /**
   * Clusters the sequences of every non-empty class and records each resulting cluster center as
   * a prototype labelled with that class.
   *
   * <p>The number of clusters requested for a class is read from {@code nbPrototypesPerClass} at
   * the index of the class value in the training data's class attribute.
   *
   * @param data the training data; not read by this method
   */
  protected void buildSpecificClassifier(Instances data) {
    if (distancesPerClass == null) {
      initDistances();
    }

    // Iterate over a snapshot of the labels rather than the live key set.
    ArrayList<String> labels = new ArrayList<String>(classedData.keySet());

    for (String label : labels) {
      if (!classedData.get(label).isEmpty()) {
        int classIndex = trainingData.classAttribute().indexOfValue(label);
        KMeansCachedSymbolicSequence clusterer =
            new KMeansCachedSymbolicSequence(
                nbPrototypesPerClass[classIndex],
                classedData.get(label),
                distancesPerClass.get(label));
        clusterer.cluster();

        for (int center = 0; center < clusterer.centers.length; center++) {
          // A null center means the cluster ended up empty; it yields no prototype.
          if (clusterer.centers[center] == null) {
            continue;
          }
          prototypes.add(new ClassedSequence(clusterer.centers[center], label));
        }
      }
    }
  }
예제 #10
0
  /**
   * Builds a training set from the given samples and trains classifiers on it.
   *
   * <p>Attributes are the user's BSSIDs and the class values are the user's location names. Each
   * sample becomes one instance whose attribute values are the sample's BSSID levels (0 for a
   * BSSID the sample does not contain) and whose class is the sample's location.
   *
   * @param user supplies the BSSIDs and the locations
   * @param validSamples the samples to train on
   * @return whatever the {@code buildClassifiers} overload taking the training set returns
   */
  @Override
  public List<Classifier> buildClassifiers(User user, List<Sample> validSamples) {

    Instances trainingSet =
        new TrainingSetBuilder()
            .setAttributes(user.getBssids())
            .setClassAttribute(
                "Location",
                user.getLocations().stream().map(Location::getName).collect(Collectors.toList()))
            .build("TrainingSet", validSamples.size());

    for (Sample sample : validSamples) {
      Map<String, Integer> levelsByBssid = getBSSIDLevelMap(sample);
      Instance row = new Instance(trainingSet.numAttributes());

      Enumeration attributes = trainingSet.enumerateAttributes();
      while (attributes.hasMoreElements()) {
        Attribute attr = (Attribute) attributes.nextElement();
        String bssid = attr.name();
        int level = 0;
        if (levelsByBssid.containsKey(bssid)) {
          level = levelsByBssid.get(bssid);
        }
        row.setValue(attr, level);
      }

      row.setValue(trainingSet.classAttribute(), sample.getLocation());
      row.setDataset(trainingSet);
      trainingSet.add(row);
    }

    return buildClassifiers(trainingSet);
  }
예제 #11
0
  /**
   * Trains a multilayer perceptron on {@code train} and prints its prediction for every instance
   * of {@code datapredict} to standard output.
   *
   * <p>Both datasets are modified in place: the last attribute is made the class and the first
   * attribute is deleted. Any exception is printed and swallowed.
   *
   * @param train the training data
   * @param datapredict the data to predict
   */
  private static void analyse(Instances train, Instances datapredict) {
    final String perceptronOptions = "-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a";

    try {
      train.setClassIndex(train.numAttributes() - 1);
      train.deleteAttributeAt(0);

      final int classCount = train.numClasses();
      for (int cls = 0; cls < classCount; cls++) {
        String label = train.classAttribute().value(cls);
        System.out.println("class value [" + cls + "]=" + label + "");
      }

      MultilayerPerceptron network = new MultilayerPerceptron();
      network.setOptions(weka.core.Utils.splitOptions(perceptronOptions));
      network.buildClassifier(train);

      datapredict.setClassIndex(datapredict.numAttributes() - 1);
      datapredict.deleteAttributeAt(0);

      for (int row = 0; row < datapredict.numInstances(); row++) {
        Instance candidate = datapredict.instance(row);
        double pred = network.classifyInstance(candidate);
        // The prediction is truncated (not rounded) to obtain the class-value index.
        int predInt = (int) pred;
        String predString = train.classAttribute().value(predInt);
        System.out.println(
            "cliente[" + row + "] pred[" + pred + "] predInt[" + predInt + "] desertor["
                + predString + "]");
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
예제 #12
0
 /**
  * Converts every instance with {@code convert} and passes the resulting rows, together with the
  * index of the class attribute, to the wrapped classifier's {@code learnModel}.
  *
  * @param instances the training data; its class attribute must be set
  * @throws Exception declared by the overridden method
  */
 @Override
 public void buildClassifier(Instances instances) throws Exception {
   int classColumn = instances.classAttribute().index();
   determineNumericAttributes(instances);

   List<List<Object>> rows = new LinkedList<List<Object>>();
   for (int i = 0; i < instances.numInstances(); i++) {
     rows.add(convert(instances.instance(i)));
   }
   classifier.learnModel(rows, classColumn);
 }
예제 #13
0
  /**
   * Picks, among the configured base classifiers, the one with the lowest error rate and keeps
   * it as the active classifier.
   *
   * <p>When more than one cross-validation fold is configured the error is estimated by
   * cross-validation and the winner is rebuilt on all data afterwards; otherwise each classifier
   * is trained and evaluated on the full (shuffled) training data. On ties the earlier classifier
   * wins.
   *
   * @param data the training data; it is copied, and instances with a missing class are dropped
   *     from the copy
   * @exception Exception if no base classifier is set or a classifier could not be built
   */
  public void buildClassifier(Instances data) throws Exception {

    if (m_Classifiers.length == 0) {
      throw new Exception("No base classifiers have been set!");
    }

    Instances prepared = new Instances(data);
    prepared.deleteWithMissingClass();
    prepared.randomize(new Random(m_Seed));
    final boolean crossValidate = m_NumXValFolds > 1;
    if (prepared.classAttribute().isNominal() && crossValidate) {
      prepared.stratify(m_NumXValFolds);
    }

    Classifier winner = null;
    int winnerIndex = -1;
    double lowestError = Double.NaN;

    final int candidateCount = m_Classifiers.length;
    for (int idx = 0; idx < candidateCount; idx++) {
      Classifier candidate = getClassifier(idx);
      Evaluation evaluation;

      if (crossValidate) {
        evaluation = new Evaluation(prepared);
        for (int fold = 0; fold < m_NumXValFolds; fold++) {
          Instances foldTrain = prepared.trainCV(m_NumXValFolds, fold);
          Instances foldTest = prepared.testCV(m_NumXValFolds, fold);
          candidate.buildClassifier(foldTrain);
          evaluation.setPriors(foldTrain);
          evaluation.evaluateModel(candidate, foldTest);
        }
      } else {
        // No cross-validation: train and test on the same data.
        candidate.buildClassifier(prepared);
        evaluation = new Evaluation(prepared);
        evaluation.evaluateModel(candidate, prepared);
      }

      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println(
            "Error rate: "
                + Utils.doubleToString(error, 6, 4)
                + " for classifier "
                + candidate.getClass().getName());
      }

      // The first candidate always becomes the initial best (lowestError starts as NaN).
      boolean improves = (idx == 0) || (error < lowestError);
      if (improves) {
        winner = candidate;
        lowestError = error;
        winnerIndex = idx;
      }
    }

    m_ClassifierIndex = winnerIndex;
    m_Classifier = winner;
    if (crossValidate) {
      // The winner was last trained on a single fold; refit it on everything.
      m_Classifier.buildClassifier(prepared);
    }
  }
예제 #14
0
 /**
  * Caches basic facts about the training data {@code m_Train} in member fields: the number of
  * instances, classes and attributes, and the type of the class attribute. On success {@code
  * m_InitFlag} is set to {@code ON}.
  *
  * <p>Any exception is printed and swallowed; in that case the assignments after the failing
  * statement, including the flag, are not performed.
  */
 private void init_m_Attributes() {
   try {
     m_NumInstances = m_Train.numInstances();
     m_NumClasses = m_Train.numClasses();
     m_NumAttributes = m_Train.numAttributes();
     m_ClassType = m_Train.classAttribute().type();
     // Only reached when every lookup above succeeded.
     m_InitFlag = ON;
   } catch (Exception e) {
     e.printStackTrace();
   }
 }
예제 #15
0
파일: Wavelet.java 프로젝트: dachylong/weka
  /**
   * Processes the instances using the HAAR algorithm.
   *
   * <p>For every instance, adjacent value pairs are repeatedly replaced by their sum and their
   * difference, each divided by the square root of 2. When a class attribute is set it is
   * removed before the transform and re-inserted at its original index afterwards. Note that the
   * class attribute is removed from the passed-in {@code instances} object itself.
   *
   * <p>NOTE(review): the pair indexing presumably requires the number of non-class attributes to
   * be a power of two - confirm the caller pads the data.
   *
   * @param instances the data to process
   * @return the modified data
   * @throws Exception in case the processing goes wrong
   */
  protected Instances processHAAR(Instances instances) throws Exception {
    final int clsIdx = instances.classIndex();
    final boolean hasClass = clsIdx > -1;

    // Set the class column aside so that only the remaining values are transformed.
    double[] clsVal = null;
    Attribute clsAtt = null;
    if (hasClass) {
      clsVal = instances.attributeToDoubleArray(clsIdx);
      clsAtt = (Attribute) instances.classAttribute().copy();
      instances.setClassIndex(-1);
      instances.deleteAttributeAt(clsIdx);
    }

    Instances result = new Instances(instances, 0);
    final int level =
        (int) StrictMath.ceil(StrictMath.log(instances.numAttributes()) / StrictMath.log(2.0));
    final double root2 = StrictMath.sqrt(2);

    for (int row = 0; row < instances.numInstances(); row++) {
      double[] current = instances.instance(row).toDoubleArray();
      double[] transformed = new double[current.length];

      for (int n = level; n > 0; n--) {
        int half = (int) StrictMath.pow(2, n - 1);

        for (int j = 0; j < half; j++) {
          double left = current[j * 2];
          double right = current[j * 2 + 1];
          transformed[j] = (left + right) / root2;
          transformed[j + half] = (left - right) / root2;
        }

        // The output of this pass is the input of the next one.
        System.arraycopy(transformed, 0, current, 0, transformed.length);
      }

      result.add(new DenseInstance(1, transformed));
    }

    // Put the class column back where it was.
    if (hasClass) {
      result.insertAttributeAt(clsAtt, clsIdx);
      result.setClassIndex(clsIdx);
      for (int row = 0; row < clsVal.length; row++) {
        result.instance(row).setClassValue(clsVal[row]);
      }
    }

    return result;
  }
예제 #16
0
  /**
   * Classifies a single ARFF data row and returns the predicted class label.
   *
   * <p>The row is written, under a fixed ARFF header, to a temporary file which is then loaded
   * back as a dataset. The class attribute is the second-to-last one, i.e. {@code MoveType} in
   * the header below. The temporary file is removed whether or not classification succeeds, and
   * the writer and reader are always closed.
   *
   * @param newInst one ARFF data line matching the header written by this method
   * @return the predicted class value, or {@code null} if anything failed (the failure is printed
   *     to standard error)
   */
  public String classifyInstance(String newInst) {

    File f = new File("/data/data/com.example.gpstracker/tmp.arff");
    String type = null;
    try {
      f.createNewFile();

      try (BufferedWriter bw = new BufferedWriter(new FileWriter(f))) {
        bw.write("@relation gps_tracking");
        bw.newLine();
        bw.newLine();
        bw.write("@attribute Longtitude numeric");
        bw.newLine();
        bw.write("@attribute Latitude numeric");
        bw.newLine();
        bw.write("@attribute CurrentSpeed numeric");
        bw.newLine();
        bw.write("@attribute Timestamp date \"yyyy-MM-dd HH:mm:ss\"");
        bw.newLine();
        bw.write("@attribute MoveType {Walking,Running,Biking,Driving,Metro,Bus,Motionless}");
        bw.newLine();
        bw.write("@attribute IsGpsFixed {yes,no}");
        bw.newLine();
        bw.newLine();
        bw.write("@data");
        bw.newLine();
        bw.write(newInst);
      }

      // load unlabeled data
      Instances unlabeled;
      try (BufferedReader reader = new BufferedReader(new FileReader(f))) {
        unlabeled = new Instances(reader);
      }
      // set class attribute
      unlabeled.setClassIndex(unlabeled.numAttributes() - 2);

      // label instances
      double clsLabel = classifier.classifyInstance(unlabeled.instance(0));
      type = unlabeled.classAttribute().value((int) clsLabel);
    } catch (Exception e) {
      // Best effort: report the problem and fall through to return null.
      e.printStackTrace();
    } finally {
      // Never leave the scratch file behind; a failed delete is not actionable here.
      f.delete();
    }
    return type;
  }
 /**
  * Loads a dataset from the given file, classifies every instance and returns a labelled copy.
  *
  * <p>The last attribute is used as the class. Each prediction is also printed to standard
  * output. The file reader is closed as soon as the dataset has been loaded.
  *
  * @param data path of the file to load
  * @return a copy of the loaded dataset with the class value of every instance set to the
  *     prediction
  * @throws Exception if the file cannot be read or classification fails
  */
 @Override
 public Instances labelData(String data) throws Exception {
   Instances unlabeled;
   try (BufferedReader reader = new BufferedReader(new FileReader(data))) {
     unlabeled = new Instances(reader);
   }
   // set class attribute
   unlabeled.setClassIndex(unlabeled.numAttributes() - 1);
   // create copy
   Instances labeled = new Instances(unlabeled);
   for (int i = 0; i < unlabeled.numInstances(); i++) {
     Instance ui = unlabeled.instance(i);
     double clsLabel = this.classifier.classifyInstance(ui);
     labeled.instance(i).setClassValue(clsLabel);
     System.out.println(ui.toString() + " -> " + unlabeled.classAttribute().value((int) clsLabel));
   }
   return labeled;
 }
예제 #18
0
  /**
   * Sets the format of the input instances and resets the cached indices.
   *
   * @param instanceInfo an Instances object containing the input instance structure (any
   *     instances contained in the object are ignored - only the structure is required)
   * @return true when the class attribute is nominal, false otherwise
   * @throws UnassignedClassException if no class attribute has been set
   * @throws Exception if the input format can't be set successfully
   */
  @Override
  public boolean setInputFormat(Instances instanceInfo) throws Exception {

    super.setInputFormat(instanceInfo);

    boolean classAssigned = instanceInfo.classIndex() >= 0;
    if (!classAssigned) {
      throw new UnassignedClassException("No class has been assigned to the instances");
    }

    setOutputFormat();
    m_Indices = null;
    return instanceInfo.classAttribute().isNominal();
  }
예제 #19
0
  /**
   * Appends a rule to the ruleset and records its statistics.
   *
   * <p>The rule is evaluated on element 1 of the most recently stored filtered pair, or on the
   * full data when nothing has been stored yet. The filtered result, the six simple statistics
   * and the class counts are each appended to their respective vector; all vectors are created on
   * first use.
   *
   * @param lastRule the rule to be added
   */
  public void addAndUpdate(Rule lastRule) {
    if (m_Ruleset == null) {
      m_Ruleset = new FastVector();
    }
    m_Ruleset.addElement(lastRule);

    Instances remaining;
    if (m_Filtered == null) {
      remaining = m_Data;
    } else {
      remaining = ((Instances[]) m_Filtered.lastElement())[1];
    }

    double[] stats = new double[6];
    double[] classCounts = new double[m_Data.classAttribute().numValues()];
    int ruleIndex = m_Ruleset.size() - 1;
    Instances[] filtered = computeSimpleStats(ruleIndex, remaining, stats, classCounts);

    if (m_Filtered == null) {
      m_Filtered = new FastVector();
    }
    m_Filtered.addElement(filtered);

    if (m_SimpleStats == null) {
      m_SimpleStats = new FastVector();
    }
    m_SimpleStats.addElement(stats);

    if (m_Distributions == null) {
      m_Distributions = new FastVector();
    }
    m_Distributions.addElement(classCounts);
  }
예제 #20
0
  /**
   * Describes the model class by class.
   *
   * <p>For each class the prior is printed, followed by every non-class attribute: nominal
   * attributes list their values and the stored per-value numbers, all other attributes list the
   * stored mean and standard deviation.
   *
   * @return the description, a "no model built yet" message when no instances are stored, or a
   *     fixed error message if formatting throws
   */
  @Override
  public String toString() {

    if (m_Instances == null) {
      return "Naive Bayes (simple): No model built yet.";
    }
    try {
      StringBuilder out = new StringBuilder("Naive Bayes (simple)");
      final int classCount = m_Instances.numClasses();

      for (int cls = 0; cls < classCount; cls++) {
        out.append("\n\nClass ")
            .append(m_Instances.classAttribute().value(cls))
            .append(": P(C) = ")
            .append(Utils.doubleToString(m_Priors[cls], 10, 8))
            .append("\n\n");

        // attIndex counts the enumerated (non-class) attributes and indexes the model arrays.
        int attIndex = 0;
        for (Enumeration<Attribute> atts = m_Instances.enumerateAttributes();
            atts.hasMoreElements();
            attIndex++) {
          Attribute attribute = atts.nextElement();
          out.append("Attribute ").append(attribute.name()).append("\n");

          if (!attribute.isNominal()) {
            out.append("Mean: ")
                .append(Utils.doubleToString(m_Means[cls][attIndex], 10, 8))
                .append("\t");
            out.append("Standard Deviation: ")
                .append(Utils.doubleToString(m_Devs[cls][attIndex], 10, 8));
          } else {
            final int valueCount = attribute.numValues();
            for (int v = 0; v < valueCount; v++) {
              out.append(attribute.value(v)).append("\t");
            }
            out.append("\n");
            for (int v = 0; v < valueCount; v++) {
              out.append(Utils.doubleToString(m_Counts[cls][attIndex][v], 10, 8)).append("\t");
            }
          }
          out.append("\n\n");
        }
      }

      return out.toString();
    } catch (Exception e) {
      return "Can't print Naive Bayes classifier!";
    }
  }
예제 #21
0
  /**
   * Filters the data through the ruleset, rule by rule, and stores the statistics computed for
   * each rule.
   *
   * <p>Does nothing when statistics were already computed or when the ruleset or the data is
   * missing. Each rule is evaluated on the instances left uncovered by the rules before it.
   */
  public void countData() {
    if (m_Filtered != null) {
      return; // already computed
    }
    if ((m_Ruleset == null) || (m_Data == null)) {
      return; // nothing to compute from
    }

    final int ruleCount = m_Ruleset.size();
    m_Filtered = new FastVector(ruleCount);
    m_SimpleStats = new FastVector(ruleCount);
    m_Distributions = new FastVector(ruleCount);

    Instances remaining = new Instances(m_Data);
    for (int rule = 0; rule < ruleCount; rule++) {
      double[] stats = new double[6]; // 6 statistics parameters
      double[] classCounts = new double[m_Data.classAttribute().numValues()];
      Instances[] split = computeSimpleStats(rule, remaining, stats, classCounts);

      m_Filtered.addElement(split);
      m_SimpleStats.addElement(stats);
      m_Distributions.addElement(classCounts);

      // Element 1 holds the data this rule did not cover; the next rule works on that.
      remaining = split[1];
    }
  }
  /**
   * Generates the classifier, searching for the best parameter values by cross-validation when
   * parameters to optimise have been configured.
   *
   * <p>Without configured parameters the base classifier is simply built on the shuffled
   * training data. Otherwise the options being optimised are stripped from the base classifier's
   * option list, the search is run, and the base classifier is rebuilt with the best options
   * found.
   *
   * @param instances set of instances serving as training data
   * @throws IllegalArgumentException if the base classifier is not an {@code OptionHandler}
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // work on a copy without missing-class instances
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();

    if (!(m_Classifier instanceof OptionHandler)) {
      throw new IllegalArgumentException("Base classifier should be OptionHandler.");
    }
    OptionHandler configurable = (OptionHandler) m_Classifier;

    m_InitOptions = configurable.getOptions();
    m_BestPerformance = -99;
    m_NumAttributes = trainData.numAttributes();
    Random random = new Random(m_Seed);
    trainData.randomize(random);
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

    // Nothing to optimise: plain build with the initial options.
    if (m_CVParams.size() == 0) {
      m_Classifier.buildClassifier(trainData);
      m_BestClassifierOptions = m_InitOptions;
      return;
    }

    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
    m_BestClassifierOptions = null;

    // Start from the current options and remove every option that is being optimised.
    m_ClassifierOptions = configurable.getOptions();
    for (int p = 0; p < m_CVParams.size(); p++) {
      CVParameter parameter = (CVParameter) m_CVParams.elementAt(p);
      Utils.getOption(parameter.m_ParamChar, m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData, random);

    // Apply a copy of the winning options and do the final build.
    String[] bestOptions = (String[]) m_BestClassifierOptions.clone();
    ((OptionHandler) m_Classifier).setOptions(bestOptions);
    m_Classifier.buildClassifier(trainData);
  }
예제 #23
0
파일: Id3.java 프로젝트: alishakiba/jDenetX
  /**
   * Recursively builds an Id3 tree for the given data.
   *
   * <p>An empty dataset yields a node with a missing class value and an all-zero distribution.
   * Otherwise the attribute with the highest information gain is selected; if that gain equals
   * zero the node becomes a leaf holding the normalised class distribution, else one successor is
   * built per value of the selected attribute.
   *
   * @param data the training data
   * @exception Exception if decision tree can't be built successfully
   */
  private void makeTree(Instances data) throws Exception {

    // No instances reached this node.
    if (data.numInstances() == 0) {
      m_Attribute = null;
      m_ClassValue = Utils.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    // Information gain of every non-class attribute, indexed by attribute position.
    double[] gains = new double[data.numAttributes()];
    for (Enumeration atts = data.enumerateAttributes(); atts.hasMoreElements(); ) {
      Attribute candidate = (Attribute) atts.nextElement();
      gains[candidate.index()] = computeInfoGain(data, candidate);
    }
    m_Attribute = data.attribute(Utils.maxIndex(gains));

    boolean worthSplitting = !Utils.eq(gains[m_Attribute.index()], 0);
    if (worthSplitting) {
      // Internal node: one subtree per value of the chosen attribute.
      Instances[] partitions = splitData(data, m_Attribute);
      int branchCount = m_Attribute.numValues();
      m_Successors = new Id3[branchCount];
      for (int branch = 0; branch < branchCount; branch++) {
        m_Successors[branch] = new Id3();
        m_Successors[branch].makeTree(partitions[branch]);
      }
      return;
    }

    // Leaf: count the classes, normalise, and predict the most frequent one.
    m_Attribute = null;
    m_Distribution = new double[data.numClasses()];
    for (int row = 0; row < data.numInstances(); row++) {
      Instance inst = data.instance(row);
      m_Distribution[(int) inst.classValue()]++;
    }
    Utils.normalize(m_Distribution);
    m_ClassValue = Utils.maxIndex(m_Distribution);
    m_ClassAttribute = data.classAttribute();
  }
예제 #24
0
  /**
   * Trains the classifier on {@code trainData}, evaluates it on {@code testData}, prints the
   * evaluation to standard output and writes the predictions to {@code prediction.trec} and
   * {@code predicted_distribution.csv} in the working directory.
   *
   * <p>The class-attribute check is performed before anything touches the class attribute, and
   * both output files are closed even when writing fails.
   *
   * @param c the classifier to train and evaluate
   * @param trainData the training split; its class attribute must be set
   * @param testData the test split
   * @throws IllegalStateException if {@code trainData} has no class attribute
   * @throws Exception if training, evaluation or writing fails
   */
  private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData)
      throws Exception {
    // Must come first: the log line below dereferences the class attribute.
    if (trainData.classIndex() < 0) {
      throw new IllegalStateException("class attribute not set");
    }

    System.err.println(
        "INFO: Starting split validation to predict '"
            + trainData.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' (#train="
            + trainData.numInstances()
            + ",#test="
            + testData.numInstances()
            + ") ...");

    c.buildClassifier(trainData);
    Evaluation eval = new Evaluation(testData);
    eval.useNoPriors();
    double[] predictions = eval.evaluateModel(c, testData);

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));

    // write predictions to file
    System.err.println("INFO: Writing predictions to file ...");
    try (Writer out = new FileWriter("prediction.trec")) {
      writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out);
    }

    // write predicted distributions to CSV
    System.err.println("INFO: Writing predicted distributions to CSV ...");
    try (Writer out = new FileWriter("predicted_distribution.csv")) {
      writePredictedDistributions(c, testData, 0, out);
    }
  }
  /**
   * Classifies one sample.
   *
   * <p>The sample is converted to a sequence and compared against every centroid of every class.
   * The class owning the single centroid with the highest probability wins. Centroids whose sigma
   * is NaN or zero are skipped with a message on standard error.
   *
   * <p>NOTE(review): when no centroid yields a probability above zero the winning label stays
   * {@code null} and is passed to {@code indexOfValue} as is - confirm how callers expect that
   * case to surface.
   *
   * @param sample the instance to classify
   * @return the index, in the sample's class attribute, of the predicted class value
   * @throws Exception declared for compatibility with the classifier contract
   */
  public double classifyInstance(Instance sample) throws Exception {
    // transform instance to sequence
    MonoDoubleItemSet[] sequence = new MonoDoubleItemSet[sample.numAttributes() - 1];
    int shift = (sample.classIndex() == 0) ? 1 : 0;
    for (int t = 0; t < sequence.length; t++) {
      sequence[t] = new MonoDoubleItemSet(sample.value(t + shift));
    }
    Sequence seq = new Sequence(sequence);

    String classValue = null;
    double maxProb = 0.0;
    for (String clas : classedData.keySet()) {
      int c = trainingData.classAttribute().indexOfValue(clas);
      for (int k = 0; k < centroidsPerClass[c].length; k++) {
        double sigma = sigmasPerClass[c][k];
        // NaN never compares equal to anything, so it has to be tested with isNaN.
        if (Double.isNaN(sigma) || sigma == 0) {
          System.err.println("sigma=NAN||sigma=0");
          continue;
        }
        // compute P(Q|k_c)
        double dist = seq.distanceEuc(centroidsPerClass[c][k]);
        double p = computeProbaForQueryAndCluster(sigma, dist);
        if (p > maxProb) {
          maxProb = p;
          classValue = clas;
        }
      }
    }
    return sample.classAttribute().indexOfValue(classValue);
  }
예제 #26
0
  /**
   * Merges label data into training data and records the result on the supplied meta object.
   *
   * <p>The two datasets are combined with {@code Instances.mergeInstances}, the last attribute of
   * the merged set is made the class attribute, and the class attribute's values are joined into
   * a comma-separated string that is stored on {@code instancesWithMeta} together with the merged
   * dataset.
   *
   * @param instances the training instances
   * @param instancesWithMeta meta object that receives the class values string and merged data
   * @param labelInstances the label instances to merge in
   * @return the merged dataset with its class index set to the last attribute
   * @throws Exception declared by the signature; propagated from the calls made here
   */
  public static Instances addNominalLabelsForClassificationToTrainingData(
      Instances instances, AttributeFilterMeta instancesWithMeta, Instances labelInstances)
      throws Exception {

    Instances merged = Instances.mergeInstances(instances, labelInstances);
    merged.setClassIndex(merged.numAttributes() - 1);

    // Join all class values with commas.
    Attribute classAttribute = merged.classAttribute();
    int valueCount = classAttribute.numValues();
    StringBuilder joinedValues = new StringBuilder();
    for (int i = 0; i < valueCount; i++) {
      if (i > 0) {
        joinedValues.append(",");
      }
      joinedValues.append(classAttribute.value(i));
    }

    instancesWithMeta.setClassAtrributeValues(joinedValues.toString());
    instancesWithMeta.setInstances(merged);

    return merged;
  }
  /**
   * Decides how {@code traindata} is processed: either cross-validation folds are generated via
   * {@code ce.generateFolds}, or {@code uc.autoProbClass} is run directly.
   *
   * <p>In both paths the last attribute of {@code traindata} is set as the class attribute.
   * Exceptions thrown by either path are logged at SEVERE level and not rethrown.
   *
   * <p>NOTE(review): as written, the branch condition appears to be always true (see the note at
   * the {@code if}), so the unsupervised branch looks unreachable - confirm the intended
   * condition.
   */
  public void chooseClassifier() {
    int classIndex = 0; // number of attributes must be greater than 1
    /**
     * We can use either a supervised or an unsupervised algorithm if a class attribute already
     * exists in the dataset (meaning some labeled instances exist); the decision depends on the
     * size of the training set.
     */
    // The last attribute is designated as the class attribute before the check below.
    classIndex = traindata.numAttributes() - 1;
    traindata.setClassIndex(classIndex);
    // NOTE(review): classIndex was just assigned traindata.numAttributes() - 1, so the first
    // operand is true (assuming setClassIndex does not change the attribute count) and the ||
    // chain short-circuits; the remaining operands are never evaluated.
    // NOTE(review): && binds tighter than ||, so the size comparison only qualifies the "Class"
    // lookup, not the whole expression - confirm whether parentheses were intended.
    if (classIndex == traindata.numAttributes() - 1
        || traindata.attribute("class") != null
        || traindata.attribute("Class") != null && traindata.size() >= testdata.size()) {
      System.out.println("class attribute found....");
      // NOTE(review): this message is printed without the sizes having been compared on this path.
      System.out.println("Initial training set is larger than the test set...." + traindata.size());

      // Go ahead to generate folds, then call classifier
      try {
        ce.generateFolds(traindata);
      } catch (Exception ex) {
        // Failure is only logged; the method returns normally.
        Logger.getLogger(FileTypeEnablerAndProcessor.class.getName()).log(Level.SEVERE, null, ex);
      }
    }
    /**
     * When there is no class attribute to show that labeled instances exist, use an unsupervised
     * algorithm directly; no need for the cross-validation folds.
     */
    else {
      try {
        System.out.println("class attribute not found");
        // Repeats the class-index assignment already done before the if.
        classIndex = traindata.numAttributes() - 1;
        traindata.setClassIndex(classIndex);
        System.out.println("Class to predict is = " + traindata.classAttribute() + "\n");
        uc.autoProbClass(traindata);
      } catch (Exception ex) {
        // Failure is only logged; the method returns normally.
        Logger.getLogger(FileTypeEnablerAndProcessor.class.getName()).log(Level.SEVERE, null, ex);
      }
    }
  }
예제 #28
0
  /**
   * Builds the classifier from the given training data.
   *
   * <p>Works on a copy of the data with missing-class instances removed, keeps at most the last
   * {@code m_WindowSize} instances when a positive window size is set, counts the non-class
   * nominal/numeric attributes, hands the training set to the neighbour search, and builds a
   * {@code ZeroR} default model.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // Reject data this classifier cannot handle.
    getCapabilities().testWithFail(instances);

    // Copy the data so the caller's set is untouched, then drop instances without a class value.
    Instances cleaned = new Instances(instances);
    cleaned.deleteWithMissingClass();

    m_NumClasses = cleaned.numClasses();
    m_ClassType = cleaned.classAttribute().type();
    m_Train = new Instances(cleaned, 0, cleaned.numInstances());

    // Keep only the most recent m_WindowSize instances when a window is configured.
    boolean windowed = m_WindowSize > 0;
    if (windowed && cleaned.numInstances() > m_WindowSize) {
      int firstKept = m_Train.numInstances() - m_WindowSize;
      m_Train = new Instances(m_Train, firstKept, m_WindowSize);
    }

    // Count the non-class attributes that are nominal or numeric.
    double usableAttributes = 0.0;
    for (int a = 0; a < m_Train.numAttributes(); a++) {
      if (a == m_Train.classIndex()) {
        continue;
      }
      if (m_Train.attribute(a).isNominal() || m_Train.attribute(a).isNumeric()) {
        usableAttributes += 1.0;
      }
    }
    m_NumAttributesUsed = usableAttributes;

    m_NNSearch.setInstances(m_Train);

    // Any k previously selected by cross-validation is no longer valid.
    m_kNNValid = false;

    // Default model built on the cleaned (unwindowed) data.
    m_defaultModel = new ZeroR();
    m_defaultModel.buildClassifier(cleaned);
    m_defaultModel.setOptions(getOptions());
  }
예제 #29
0
  /**
   * Cross-validates a classifier on the given data, writes the collected prediction output to
   * {@code cv.log} and prints the per-class details and summary to standard output.
   *
   * @param c the classifier to evaluate
   * @param data the dataset to cross-validate on; its class attribute must be set
   * @param folds the number of cross-validation folds
   * @throws Exception if the evaluation fails or {@code cv.log} cannot be written
   */
  static void evaluateClassifier(Classifier c, Instances data, int folds) throws Exception {
    System.err.println(
        "INFO: Starting crossvalidation to predict '"
            + data.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' ...");

    // Buffer that receives the prediction output produced during cross-validation.
    StringBuffer sb = new StringBuffer();
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(c, data, folds, new Random(1), sb, new Range("first"), Boolean.FALSE);

    // Write predictions to file; try-with-resources closes the writer even if write() throws.
    try (Writer out = new FileWriter("cv.log")) {
      out.write(sb.toString());
    }

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
  }
  /**
   * Folds another model into this one.
   *
   * <p>Class probabilities are summed element-wise and the per-class word counts of {@code
   * toAggregate} are added to this model's dictionaries, creating dictionaries and entries that
   * do not exist yet. The model counter is incremented afterwards.
   *
   * @param toAggregate the model to merge into this one
   * @return this model
   * @throws Exception if this model is already finalized, has not been built, or the class
   *     attributes of the two models differ
   */
  @Override
  public NaiveBayesMultinomialText aggregate(NaiveBayesMultinomialText toAggregate)
      throws Exception {
    // Integer.MIN_VALUE marks a model that has already been finalized.
    if (m_numModels == Integer.MIN_VALUE) {
      throw new Exception(
          "Can't aggregate further - model has already been " + "aggregated and finalized");
    }

    if (m_probOfClass == null) {
      throw new Exception("No model built yet, can't aggregate");
    }

    // Only the class attribute needs to match, since the dictionaries are merged.
    boolean sameClassAttribute =
        m_data.classAttribute().equals(toAggregate.m_data.classAttribute());
    if (!sameClassAttribute) {
      throw new Exception(
          "Can't aggregate - class attribute in data headers "
              + "does not match: "
              + m_data.classAttribute().equalsMsg(toAggregate.m_data.classAttribute()));
    }

    // Sum the class probabilities.
    for (int cls = 0; cls < m_probOfClass.length; cls++) {
      m_probOfClass[cls] += toAggregate.m_probOfClass[cls];
    }

    // Merge the per-class word dictionaries.
    Map<Integer, LinkedHashMap<String, Count>> incomingDicts = toAggregate.m_probOfWordGivenClass;
    for (Map.Entry<Integer, LinkedHashMap<String, Count>> incoming : incomingDicts.entrySet()) {
      Integer classKey = incoming.getKey();

      LinkedHashMap<String, Count> target = m_probOfWordGivenClass.get(classKey);
      if (target == null) {
        // This class was not seen while training this model.
        target = new LinkedHashMap<String, Count>();
        m_probOfWordGivenClass.put(classKey, target);
      }

      for (Map.Entry<String, Count> word : incoming.getValue().entrySet()) {
        Count existing = target.get(word.getKey());
        if (existing != null) {
          // Known word: add the counts.
          existing.m_count += word.getValue().m_count;
        } else {
          // Unknown (or pruned) word: start from the incoming count.
          existing = new Count(word.getValue().m_count);
          target.put(word.getKey(), existing);
        }
      }
    }

    m_numModels++;

    return this;
  }