Example #1
  /**
   * Builds a regression model for the given data.
   *
   * @param data the training data to be used for generating the linear regression function
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    if (!m_checksTurnedOff) {
      // can classifier handle the data?
      getCapabilities().testWithFail(data);

      // remove instances with missing class
      data = new Instances(data);
      data.deleteWithMissingClass();
    }

    // Preprocess instances
    if (!m_checksTurnedOff) {
      m_TransformFilter = new NominalToBinary();
      m_TransformFilter.setInputFormat(data);
      data = Filter.useFilter(data, m_TransformFilter);
      m_MissingFilter = new ReplaceMissingValues();
      m_MissingFilter.setInputFormat(data);
      data = Filter.useFilter(data, m_MissingFilter);
      data.deleteWithMissingClass();
    } else {
      m_TransformFilter = null;
      m_MissingFilter = null;
    }

    m_ClassIndex = data.classIndex();
    m_TransformedData = data;

    // Turn all attributes on for a start
    m_SelectedAttributes = new boolean[data.numAttributes()];
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i != m_ClassIndex) {
        m_SelectedAttributes[i] = true;
      }
    }
    m_Coefficients = null;

    // Compute means and standard deviations
    m_Means = new double[data.numAttributes()];
    m_StdDevs = new double[data.numAttributes()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j != data.classIndex()) {
        m_Means[j] = data.meanOrMode(j);
        m_StdDevs[j] = Math.sqrt(data.variance(j));
        if (m_StdDevs[j] == 0) {
          m_SelectedAttributes[j] = false;
        }
      }
    }

    m_ClassStdDev = Math.sqrt(data.variance(m_TransformedData.classIndex()));
    m_ClassMean = data.meanOrMode(m_TransformedData.classIndex());

    // Perform the regression
    findBestModel();

    // Save memory
    m_TransformedData = new Instances(data, 0);
  }
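This appears to be the buildClassifier of Weka's LinearRegression. A minimal usage sketch, assuming a numeric-class ARFF file (the path and the demo class name are hypothetical):

  import weka.classifiers.functions.LinearRegression;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class LinearRegressionDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/housing.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);              // numeric class attribute
      LinearRegression lr = new LinearRegression();
      lr.buildClassifier(data); // runs the checks, filtering and findBestModel() shown above
      System.out.println(lr);   // prints the fitted coefficients
    }
  }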
Example #2
File: PART.java Project: CSLeicester/weka
  /**
   * Generates the classifier.
   *
   * @param instances the data to train with
   * @throws Exception if classifier can't be built successfully
   */
  @Override
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    ModelSelection modSelection;

    if (m_binarySplits) {
      modSelection =
          new BinC45ModelSelection(
              m_minNumObj, instances, m_useMDLcorrection, m_doNotMakeSplitPointActualValue);
    } else {
      modSelection =
          new C45ModelSelection(
              m_minNumObj, instances, m_useMDLcorrection, m_doNotMakeSplitPointActualValue);
    }
    if (m_unpruned) {
      m_root = new MakeDecList(modSelection, m_minNumObj);
    } else if (m_reducedErrorPruning) {
      m_root = new MakeDecList(modSelection, m_numFolds, m_minNumObj, m_Seed);
    } else {
      m_root = new MakeDecList(modSelection, m_CF, m_minNumObj);
    }
    m_root.buildClassifier(instances);
    if (m_binarySplits) {
      ((BinC45ModelSelection) modSelection).cleanup();
    } else {
      ((C45ModelSelection) modSelection).cleanup();
    }
  }
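A short usage sketch for the PART method above (path and demo class name hypothetical); enabling reduced-error pruning selects the m_numFolds/m_Seed constructor of MakeDecList:

  import weka.classifiers.rules.PART;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class PartDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/train.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      PART part = new PART();
      part.setReducedErrorPruning(true); // takes the reduced-error-pruning branch above
      part.setNumFolds(3);
      part.buildClassifier(data);
      System.out.println(part); // prints the decision list
    }
  }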
Example #3
  /**
   * Builds the model of the base learner.
   *
   * @param data the training data
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (m_Classifier == null) {
      throw new Exception("No base classifier has been set!");
    }
    if (m_MatrixSource == MATRIX_ON_DEMAND) {
      String costName = data.relationName() + CostMatrix.FILE_EXTENSION;
      File costFile = new File(getOnDemandDirectory(), costName);
      if (!costFile.exists()) {
        throw new Exception("On-demand cost file doesn't exist: " + costFile);
      }
      setCostMatrix(new CostMatrix(new BufferedReader(new FileReader(costFile))));
    } else if (m_CostMatrix == null) {
      // try loading an old format cost file
      m_CostMatrix = new CostMatrix(data.numClasses());
      m_CostMatrix.readOldFormat(new BufferedReader(new FileReader(m_CostFile)));
    }

    if (!m_MinimizeExpectedCost) {
      Random random = null;
      if (!(m_Classifier instanceof WeightedInstancesHandler)) {
        random = new Random(m_Seed);
      }
      data = m_CostMatrix.applyCostMatrix(data, random);
    }
    m_Classifier.buildClassifier(data);
  }
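This method looks like Weka's CostSensitiveClassifier. A sketch of wiring it up, reusing the same CostMatrix reader constructor the method itself uses (paths and demo class name hypothetical):

  import java.io.BufferedReader;
  import java.io.FileReader;
  import weka.classifiers.CostMatrix;
  import weka.classifiers.meta.CostSensitiveClassifier;
  import weka.classifiers.trees.J48;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class CostSensitiveDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/train.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      CostSensitiveClassifier csc = new CostSensitiveClassifier();
      csc.setClassifier(new J48());
      csc.setCostMatrix(new CostMatrix(new BufferedReader(new FileReader("train.cost")))); // hypothetical file
      csc.setMinimizeExpectedCost(false); // false -> reweight the training data, as in the branch above
      csc.buildClassifier(data);
    }
  }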
Example #4
  /**
   * Builds the classifier.
   *
   * @param data the data to train with
   * @throws Exception if classifier can't be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    Instances filteredData = new Instances(data);
    filteredData.deleteWithMissingClass();

    // replace missing values
    m_replaceMissing = new ReplaceMissingValues();
    m_replaceMissing.setInputFormat(filteredData);
    filteredData = Filter.useFilter(filteredData, m_replaceMissing);

    // possibly convert nominal attributes globally
    if (m_convertNominal) {
      m_nominalToBinary = new NominalToBinary();
      m_nominalToBinary.setInputFormat(filteredData);
      filteredData = Filter.useFilter(filteredData, m_nominalToBinary);
    }

    int minNumInstances = 2;

    // create a FT  tree root
    if (m_modelType == 0)
      m_tree =
          new FTNode(
              m_errorOnProbabilities,
              m_numBoostingIterations,
              m_minNumInstances,
              m_weightTrimBeta,
              m_useAIC);

    // create a FTLeaves  tree root
    if (m_modelType == 1) {
      m_tree =
          new FTLeavesNode(
              m_errorOnProbabilities,
              m_numBoostingIterations,
              m_minNumInstances,
              m_weightTrimBeta,
              m_useAIC);
    }
    // create a FTInner  tree root
    if (m_modelType == 2)
      m_tree =
          new FTInnerNode(
              m_errorOnProbabilities,
              m_numBoostingIterations,
              m_minNumInstances,
              m_weightTrimBeta,
              m_useAIC);

    // build tree
    m_tree.buildClassifier(filteredData);
    // prune tree
    m_tree.prune();
    m_tree.assignIDs(0);
    m_tree.cleanup();
  }
Example #5
File: LWL.java Project: alishakiba/jDenetX
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (!(m_Classifier instanceof WeightedInstancesHandler)) {
      throw new IllegalArgumentException("Classifier must be a " + "WeightedInstancesHandler!");
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (instances.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
              + "using ZeroR model instead!");
      m_ZeroR = new weka.classifiers.rules.ZeroR();
      m_ZeroR.buildClassifier(instances);
      return;
    } else {
      m_ZeroR = null;
    }

    m_Train = new Instances(instances, 0, instances.numInstances());

    m_NNSearch.setInstances(m_Train);
  }
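The guard at the top requires a base learner that implements WeightedInstancesHandler; DecisionStump (LWL's default) qualifies. A minimal sketch (path and demo class name hypothetical):

  import weka.classifiers.lazy.LWL;
  import weka.classifiers.trees.DecisionStump;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class LwlDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/train.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      LWL lwl = new LWL();
      lwl.setClassifier(new DecisionStump()); // WeightedInstancesHandler, so the check passes
      lwl.buildClassifier(data);              // stores m_Train and primes the NN search
      System.out.println(java.util.Arrays.toString(lwl.distributionForInstance(data.instance(0))));
    }
  }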
Example #6
  /**
   * Build the associator on the filtered data.
   *
   * @param data the training data
   * @throws Exception if the Associator could not be built successfully
   */
  public void buildAssociations(Instances data) throws Exception {
    if (m_Associator == null) throw new Exception("No base associator has been set!");

    // create copy and set class-index
    data = new Instances(data);
    if (getClassIndex() == 0) {
      data.setClassIndex(data.numAttributes() - 1);
    } else {
      data.setClassIndex(getClassIndex() - 1);
    }

    if (getClassIndex() != -1) {
      // remove instances with missing class
      data.deleteWithMissingClass();
    }

    m_Filter.setInputFormat(data); // filter capabilities are checked here
    data = Filter.useFilter(data, m_Filter);

    // can associator handle the data?
    getAssociator().getCapabilities().testWithFail(data);

    m_FilteredInstances = data.stringFreeStructure();
    m_Associator.buildAssociations(data);
  }
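This looks like Weka's FilteredAssociator. A sketch, assuming Apriori as the base associator (path and demo class name hypothetical):

  import weka.associations.Apriori;
  import weka.associations.FilteredAssociator;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class FilteredAssociatorDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/transactions.arff"); // hypothetical path
      FilteredAssociator fa = new FilteredAssociator();
      fa.setAssociator(new Apriori());
      fa.buildAssociations(data); // filtering and class-index handling happen inside, as above
      System.out.println(fa);
    }
  }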
Example #7
 public void buildClassifier(Instances data) throws Exception {
   // remove instances with missing class
   data = new Instances(data);
   data.deleteWithMissingClass();
   buildTree(data);
   collapse();
   prune();
   cleanup(new Instances(data, 0));
 }
Example #8
  /**
   * The standard collective classifier accepts only nominal, binary classes; otherwise an
   * exception is thrown.
   *
   * @throws Exception if the data doesn't have a nominal, binary class
   */
  protected void checkRestrictions() throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(m_Trainset);

    // remove instances with missing class
    m_Trainset = new Instances(m_Trainset);
    m_Trainset.deleteWithMissingClass();
    if (m_Testset != null) m_Testset = new Instances(m_Testset);
  }
Example #9
  /**
   * buildClassifier selects a classifier from the set of classifiers by minimising error on the
   * training data.
   *
   * @param data the training data to be used for generating the boosted classifier.
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    if (m_Classifiers.length == 0) {
      throw new Exception("No base classifiers have been set!");
    }
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();
    newData.randomize(new Random(m_Seed));
    if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1))
      newData.stratify(m_NumXValFolds);
    Instances train = newData; // train on all data by default
    Instances test = newData; // test on training data by default
    Classifier bestClassifier = null;
    int bestIndex = -1;
    double bestPerformance = Double.NaN;
    int numClassifiers = m_Classifiers.length;
    for (int i = 0; i < numClassifiers; i++) {
      Classifier currentClassifier = getClassifier(i);
      Evaluation evaluation;
      if (m_NumXValFolds > 1) {
        evaluation = new Evaluation(newData);
        for (int j = 0; j < m_NumXValFolds; j++) {
          train = newData.trainCV(m_NumXValFolds, j);
          test = newData.testCV(m_NumXValFolds, j);
          currentClassifier.buildClassifier(train);
          evaluation.setPriors(train);
          evaluation.evaluateModel(currentClassifier, test);
        }
      } else {
        currentClassifier.buildClassifier(train);
        evaluation = new Evaluation(train);
        evaluation.evaluateModel(currentClassifier, test);
      }

      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println(
            "Error rate: "
                + Utils.doubleToString(error, 6, 4)
                + " for classifier "
                + currentClassifier.getClass().getName());
      }

      if ((i == 0) || (error < bestPerformance)) {
        bestClassifier = currentClassifier;
        bestPerformance = error;
        bestIndex = i;
      }
    }
    m_ClassifierIndex = bestIndex;
    m_Classifier = bestClassifier;
    if (m_NumXValFolds > 1) {
      m_Classifier.buildClassifier(newData);
    }
  }
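This appears to be the select-by-minimum-error pattern of Weka's MultiScheme. A sketch that hands it two candidate schemes and lets cross-validation pick (path and demo class name hypothetical):

  import weka.classifiers.Classifier;
  import weka.classifiers.bayes.NaiveBayes;
  import weka.classifiers.meta.MultiScheme;
  import weka.classifiers.trees.J48;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class MultiSchemeDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/train.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      MultiScheme ms = new MultiScheme();
      ms.setClassifiers(new Classifier[] { new J48(), new NaiveBayes() });
      ms.setNumFolds(5); // m_NumXValFolds > 1 -> each candidate is cross-validated
      ms.buildClassifier(data);
      System.out.println(ms); // reports which classifier won
    }
  }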
Example #10
File: Id3.java Project: alishakiba/jDenetX
  /**
   * Builds Id3 decision tree classifier.
   *
   * @param data the training data
   * @exception Exception if classifier can't be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    makeTree(data);
  }
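Id3 handles only nominal attributes and rejects missing values other than the class. A sketch, assuming the weka.classifiers.trees.Id3 packaging of older Weka releases (path and demo class name hypothetical):

  import weka.classifiers.trees.Id3;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class Id3Demo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/weather.nominal.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      Id3 id3 = new Id3();
      id3.buildClassifier(data); // capability check + makeTree, as above
      System.out.println(id3);
    }
  }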
Example #11
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    String debug = "(KStar.buildClassifier) ";

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_Train = new Instances(instances, 0, instances.numInstances());

    // initializes class attributes ** java-speaking! :-) **
    init_m_Attributes();
  }
Example #12
  /**
   * Boosting method.
   *
   * @param data the training data to be used for generating the boosted classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    super.buildClassifier(data);

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    m_SumOfWeights = data.sumOfWeights();

    if ((!m_UseResampling) && (m_Classifier instanceof WeightedInstancesHandler)) {
      buildClassifierWithWeights(data);
    } else {
      buildClassifierUsingResampling(data);
    }
  }
Example #13
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();

    if (!(m_Classifier instanceof OptionHandler)) {
      throw new IllegalArgumentException("Base classifier should be OptionHandler.");
    }
    m_InitOptions = ((OptionHandler) m_Classifier).getOptions();
    m_BestPerformance = -99;
    m_NumAttributes = trainData.numAttributes();
    Random random = new Random(m_Seed);
    trainData.randomize(random);
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

    // Check whether there are any parameters to optimize
    if (m_CVParams.size() == 0) {
      m_Classifier.buildClassifier(trainData);
      m_BestClassifierOptions = m_InitOptions;
      return;
    }

    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
    m_BestClassifierOptions = null;

    // Set up m_ClassifierOptions -- take getOptions() and remove
    // those being optimised.
    m_ClassifierOptions = ((OptionHandler) m_Classifier).getOptions();
    for (int i = 0; i < m_CVParams.size(); i++) {
      Utils.getOption(((CVParameter) m_CVParams.elementAt(i)).m_ParamChar, m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData, random);

    String[] options = (String[]) m_BestClassifierOptions.clone();
    ((OptionHandler) m_Classifier).setOptions(options);
    m_Classifier.buildClassifier(trainData);
  }
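This is recognizably Weka's CVParameterSelection: each CVParameter describes one option to sweep, and the best setting found by cross-validation is applied before the final build. A sketch tuning J48's confidence factor (path and demo class name hypothetical):

  import weka.classifiers.meta.CVParameterSelection;
  import weka.classifiers.trees.J48;
  import weka.core.Instances;
  import weka.core.Utils;
  import weka.core.converters.ConverterUtils.DataSource;

  public class CvParamDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/train.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      CVParameterSelection ps = new CVParameterSelection();
      ps.setClassifier(new J48());
      ps.setNumFolds(10);
      ps.addCVParameter("C 0.05 0.5 10"); // sweep -C from 0.05 to 0.5 in 10 steps
      ps.buildClassifier(data);           // findParamsByCrossValidation, then the final build
      System.out.println(Utils.joinOptions(ps.getBestClassifierOptions()));
    }
  }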
Example #14
  /**
   * Generates an attribute evaluator. Has to initialise all fields of the evaluator that are not
   * being set via options.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the evaluator has not been generated successfully
   */
  public void buildEvaluator(Instances data) throws Exception {

    // can evaluator handle data?
    getCapabilities().testWithFail(data);

    m_trainInstances = new Instances(data);
    m_trainInstances.deleteWithMissingClass();

    m_numAttribs = m_trainInstances.numAttributes();
    m_numInstances = m_trainInstances.numInstances();

    // if the data has no decision feature, m_classIndex is negative
    m_classIndex = m_trainInstances.classIndex();

    // supervised
    if (m_classIndex >= 0) {
      m_isNumeric = m_trainInstances.attribute(m_classIndex).isNumeric();

      if (m_isNumeric) {
        m_DecisionSimilarity = m_Similarity;
      } else m_DecisionSimilarity = m_SimilarityEq;
    }

    m_Similarity.setInstances(m_trainInstances);
    m_DecisionSimilarity.setInstances(m_trainInstances);
    m_SimilarityEq.setInstances(m_trainInstances);
    m_composition = m_Similarity.getTNorm();

    m_FuzzyMeasure.set(
        m_Similarity,
        m_DecisionSimilarity,
        m_TNorm,
        m_composition,
        m_Implicator,
        m_SNorm,
        m_numInstances,
        m_numAttribs,
        m_classIndex,
        m_trainInstances);
  }
Example #15
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_NumClasses = instances.numClasses();
    m_ClassType = instances.classAttribute().type();
    m_Train = new Instances(instances, 0, instances.numInstances());

    // Throw away initial instances until within the specified window size
    if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
      m_Train = new Instances(m_Train, m_Train.numInstances() - m_WindowSize, m_WindowSize);
    }

    m_NumAttributesUsed = 0.0;
    for (int i = 0; i < m_Train.numAttributes(); i++) {
      if ((i != m_Train.classIndex())
          && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) {
        m_NumAttributesUsed += 1.0;
      }
    }

    m_NNSearch.setInstances(m_Train);

    // Invalidate any currently cross-validation selected k
    m_kNNValid = false;

    m_defaultModel = new ZeroR();
    m_defaultModel.buildClassifier(instances);
    m_defaultModel.setOptions(getOptions());
  }
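This is IBk's buildClassifier; note the optional training window that discards all but the most recent m_WindowSize instances. A short sketch (path and demo class name hypothetical):

  import weka.classifiers.lazy.IBk;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class IbkDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/train.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      IBk knn = new IBk(3);    // k = 3
      knn.setWindowSize(1000); // keep at most the 1000 most recent training instances
      knn.buildClassifier(data);
      System.out.println(knn.classifyInstance(data.instance(0)));
    }
  }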
Example #16
  public static void test_NHBS_old() throws Exception {
    // load the data
    CSVLoader loader = new CSVLoader();
    // these must come before the getDataSet()
    // loader.setEnclosureCharacters(",\'\"S");
    // loader.setNominalAttributes("16,71"); //zip code, drug name
    // loader.setStringAttributes("");
    // loader.setDateAttributes("0,1");
    // loader.setSource(new File("hcv/data/NHBS/IDU2_HCV_model_012913_cleaned_for_weka.csv"));
    loader.setSource(new File("/home/sasha/hcv/code/data/IDU2_HCV_model_012913_cleaned.csv"));
    Instances nhbs_data = loader.getDataSet();
    loader.setMissingValue("NOVALUE");
    // loader.setMissingValue("");

    nhbs_data.deleteAttributeAt(12); // zip code
    nhbs_data.deleteAttributeAt(1); // date - redundant with age
    nhbs_data.deleteAttributeAt(0); // date
    System.out.println("classifying attribute:");
    nhbs_data.setClassIndex(1); // new index  3->2->1
    nhbs_data.attribute(1).getMetadata().toString(); // HCVEIARSLT1

    // wishlist: perhaps it would be smarter to throw out unclassified instances; they
    // interfere with the scoring
    nhbs_data.deleteWithMissingClass();
    // nhbs_data.setClass(new Attribute("HIVRSLT"));//.setClassIndex(1); //2nd column.  all are
    // mostly negative
    // nhbs_data.setClass(new Attribute("HCVEIARSLT1"));//.setClassIndex(2); //3rd column

    // #14, i.e. rds_fem, should be made numeric
    System.out.println("NHBS IDU 2009 Dataset");
    System.out.println("Summary of input:");
    // System.out.println(nhbs_data.toSummaryString());
    System.out.println("  Num of classes: " + nhbs_data.numClasses());
    System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
    for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
      Attribute attr = nhbs_data.attribute(idx);
      System.out.println("" + idx + ": " + attr.toString());
      System.out.println("     distinct values:" + nhbs_data.numDistinctValues(idx));
      // System.out.println("" + attr.enumerateValues());
    }

    // System.exit(0);
    // nhbs_data.deleteAttributeAt(0); //response ID
    // nhbs_data.deleteAttributeAt(16); //zip

    // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00
    // Classifier classifier = new MINND();
    // Classifier classifier = new CitationKNN();
    // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7%
    // Classifier classifier = new SMOreg();
    // Classifier classifier = new LinearNNSearch();

    // LinearRegression: Cannot handle multi-valued nominal class!
    // Classifier classifier = new LinearRegression();

    Classifier classifier = new RandomForest();
    String[] options = {
      "-I", "100", "-K", "4"
    }; // -I trees, -K features per tree.  generally, might want to optimize (or not
       // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
    classifier.setOptions(options);
    // Classifier classifier = new Logistic();

    // KStar classifier = new KStar();
    // classifier.setGlobalBlend(20); //the amount of not greedy, in percent

    // does poorly
    // Classifier classifier = new AdaBoostM1();
    // Classifier classifier = new MultiBoostAB();
    // Classifier classifier = new Stacking();

    // building a C45 tree classifier
    // J48 classifier = new J48(); // new instance of tree
    // String[] options = new String[1];
    // options[0] = "-U"; // unpruned tree
    // classifier.setOptions(options); // set the options
    // classifier.buildClassifier(nhbs_data); // build classifier

    // wishlist: remove infrequent values
    // weka.filters.unsupervised.instance.RemoveFrequentValues()
    Filter f1 = new RemoveUseless();
    f1.setInputFormat(nhbs_data);
    nhbs_data = Filter.useFilter(nhbs_data, f1);

    // evaluation
    Evaluation eval = new Evaluation(nhbs_data);
    eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1));
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
    System.out.println(eval.toClassDetailsString());
    // System.out.println(eval.toCumulativeMarginDistributionString());
  }
Example #17
  /**
   * Build Decorate classifier
   *
   * @param data the training data to be used for generating the classifier
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    if (data.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    if (data.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("Decorate can't handle a numeric class!");
    }
    if (m_NumIterations < m_DesiredSize)
      throw new Exception("Max number of iterations must be >= desired ensemble size!");

    // initialize random number generator
    if (m_Seed == -1) m_Random = new Random();
    else m_Random = new Random(m_Seed);

    int i = 1; // current committee size
    int numTrials = 1; // number of Decorate iterations
    Instances divData = new Instances(data); // local copy of data - diversity data
    divData.deleteWithMissingClass();
    Instances artData = null; // artificial data

    // compute number of artificial instances to add at each iteration
    int artSize = (int) (Math.abs(m_ArtSize) * divData.numInstances());
    if (artSize == 0) artSize = 1; // at least add one random example
    computeStats(data); // Compute training data stats for creating artificial examples

    // initialize new committee
    m_Committee = new Vector();
    Classifier newClassifier = m_Classifier;
    newClassifier.buildClassifier(divData);
    m_Committee.add(newClassifier);
    double eComm = computeError(divData); // compute ensemble error
    if (m_Debug)
      System.out.println(
          "Initialize:\tClassifier " + i + " added to ensemble. Ensemble error = " + eComm);

    // repeat till desired committee size is reached OR the max number of iterations is exceeded
    while (i < m_DesiredSize && numTrials < m_NumIterations) {
      // Generate artificial training examples
      artData = generateArtificialData(artSize, data);

      // Label artificial examples
      labelData(artData);
      addInstances(divData, artData); // Add new artificial data

      // Build new classifier
      Classifier tmp[] = Classifier.makeCopies(m_Classifier, 1);
      newClassifier = tmp[0];
      newClassifier.buildClassifier(divData);
      // Remove all the artificial data
      removeInstances(divData, artSize);

      // Test if the new classifier should be added to the ensemble
      m_Committee.add(newClassifier); // add new classifier to current committee
      double currError = computeError(divData);
      if (currError <= eComm) { // adding the new member did not increase the error
        i++;
        eComm = currError;
        if (m_Debug)
          System.out.println(
              "Iteration: "
                  + (1 + numTrials)
                  + "\tClassifier "
                  + i
                  + " added to ensemble. Ensemble error = "
                  + eComm);
      } else { // reject the current classifier because it increased the ensemble error
        m_Committee.removeElementAt(m_Committee.size() - 1); // pop the last member
      }
      numTrials++;
    }
  }
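This is the Decorate ensemble builder (Melville & Mooney), which grows the committee with artificial "diversity" examples. A sketch, assuming the weka.classifiers.meta.Decorate class of Weka 3.4-era releases (path and demo class name hypothetical):

  import weka.classifiers.meta.Decorate;
  import weka.classifiers.trees.J48;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class DecorateDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/train.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      Decorate dec = new Decorate();
      dec.setClassifier(new J48());
      dec.setDesiredSize(15);   // target committee size
      dec.setNumIterations(50); // must be >= the desired size, per the check above
      dec.buildClassifier(data);
    }
  }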
Example #18
  public static void main(String[] args) {

    if (args.length < 1) {
      System.out.println("usage: C4_5TweetTopicCategorization <root_path>");
      System.exit(-1);
    }

    String rootPath = args[0];
    File dataFolder = new File(rootPath + "/data");
    String resultFolderPath = rootPath + "/results/C4_5/";

    CrisisMailer crisisMailer = CrisisMailer.getCrisisMailer();
    Logger logger = Logger.getLogger(C4_5TweetTopicCategorization.class);
    PropertyConfigurator.configure(Constants.LOG4J_PROPERTIES_FILE_PATH);

    File resultFolder = new File(resultFolderPath);
    if (!resultFolder.exists()) resultFolder.mkdir();

    CSVLoader csvLoader = new CSVLoader();

    try {
      for (File dataSetName : dataFolder.listFiles()) {

        Instances data = null;
        try {
          csvLoader.setSource(dataSetName);
          csvLoader.setStringAttributes("2");
          data = csvLoader.getDataSet();
        } catch (IOException ioe) {
          logger.error(ioe);
          crisisMailer.sendEmailAlert(ioe);
          System.exit(-1);
        }

        data.setClassIndex(data.numAttributes() - 1);
        data.deleteWithMissingClass();

        Instances vectorizedData = null;
        StringToWordVector stringToWordVectorFilter = new StringToWordVector();
        try {
          stringToWordVectorFilter.setInputFormat(data);
          stringToWordVectorFilter.setAttributeIndices("2");
          stringToWordVectorFilter.setIDFTransform(true);
          stringToWordVectorFilter.setLowerCaseTokens(true);
          stringToWordVectorFilter.setOutputWordCounts(false);
          stringToWordVectorFilter.setUseStoplist(true);

          vectorizedData = Filter.useFilter(data, stringToWordVectorFilter);
          vectorizedData.deleteAttributeAt(0);
          // System.out.println(vectorizedData);
        } catch (Exception exception) {
          logger.error(exception);
          crisisMailer.sendEmailAlert(exception);
          System.exit(-1);
        }

        J48 j48Classifier = new J48();

        /*
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(stringToWordVectorFilter);
        filteredClassifier.setClassifier(j48Classifier);
        */

        try {
          Evaluation eval = new Evaluation(vectorizedData);
          eval.crossValidateModel(
              j48Classifier, vectorizedData, 5, new Random(System.currentTimeMillis()));

          FileOutputStream resultOutputStream =
              new FileOutputStream(new File(resultFolderPath + dataSetName.getName()));

          resultOutputStream.write(eval.toSummaryString("=== Summary ===", false).getBytes());
          resultOutputStream.write(eval.toMatrixString().getBytes());
          resultOutputStream.write(eval.toClassDetailsString().getBytes());
          resultOutputStream.close();

        } catch (Exception exception) {
          logger.error(exception);
          crisisMailer.sendEmailAlert(exception);
          System.exit(-1);
        }
      }
    } catch (Exception exception) {
      logger.error(exception);
      crisisMailer.sendEmailAlert(exception);
      System.exit(-1);
    }
  }
Example #19
  /**
   * Build the classifier on the supplied data
   *
   * @param data the training data
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    super.buildClassifier(data);

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();

    double sum = 0;
    double temp_sum = 0;
    // Add the model for the mean first
    m_zeroR = new ZeroR();
    m_zeroR.buildClassifier(newData);

    // only class? -> use only ZeroR model
    if (newData.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
              + "using ZeroR model instead!");
      m_SuitableData = false;
      return;
    } else {
      m_SuitableData = true;
    }

    newData = residualReplace(newData, m_zeroR, false);
    for (int i = 0; i < newData.numInstances(); i++) {
      sum +=
          newData.instance(i).weight()
              * newData.instance(i).classValue()
              * newData.instance(i).classValue();
    }
    if (m_Debug) {
      System.err.println("Sum of squared residuals " + "(predicting the mean) : " + sum);
    }

    m_NumIterationsPerformed = 0;
    do {
      temp_sum = sum;

      // +++++ CHANGES FROM LEFMAN START ++++++++

      Resample resample = new Resample();
      resample.setRandomSeed(m_NumIterationsPerformed);
      resample.setNoReplacement(true);
      resample.setSampleSizePercent(getPercentage());
      resample.setInputFormat(newData);
      Instances sampledData = Filter.useFilter(newData, resample);

      // Build the classifier
      // m_Classifiers[m_NumIterationsPerformed].buildClassifier(newData);

      m_Classifiers[m_NumIterationsPerformed].buildClassifier(sampledData);
      // output the number of nodes in the tree!
      double numNodes =
          ((REPTree) m_Classifiers[m_NumIterationsPerformed]).getMeasure("measureTreeSize");
      if (m_Debug) {
        System.err.println("It#: " + m_NumIterationsPerformed + " #nodes: " + numNodes);
      }

      // +++++ CHANGES FROM LEFMAN END ++++++++

      newData = residualReplace(newData, m_Classifiers[m_NumIterationsPerformed], true);
      sum = 0;
      for (int i = 0; i < newData.numInstances(); i++) {
        sum +=
            newData.instance(i).weight()
                * newData.instance(i).classValue()
                * newData.instance(i).classValue();
      }
      if (m_Debug) {
        System.err.println("Sum of squared residuals : " + sum);
      }
      m_NumIterationsPerformed++;
    } while (((temp_sum - sum) > Utils.SMALL) && (m_NumIterationsPerformed < m_Classifiers.length));
  }
Example #20
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  @Override
  public void buildClassifier(Instances instances) throws Exception {

    int attIndex = 0;
    double sum;

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_Instances = new Instances(instances, 0);

    // Reserve space
    m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0];
    m_Means = new double[instances.numClasses()][instances.numAttributes() - 1];
    m_Devs = new double[instances.numClasses()][instances.numAttributes() - 1];
    m_Priors = new double[instances.numClasses()];
    Enumeration<Attribute> enu = instances.enumerateAttributes();
    while (enu.hasMoreElements()) {
      Attribute attribute = enu.nextElement();
      if (attribute.isNominal()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          m_Counts[j][attIndex] = new double[attribute.numValues()];
        }
      } else {
        for (int j = 0; j < instances.numClasses(); j++) {
          m_Counts[j][attIndex] = new double[1];
        }
      }
      attIndex++;
    }

    // Compute counts and sums
    Enumeration<Instance> enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = enumInsts.nextElement();
      if (!instance.classIsMissing()) {
        Enumeration<Attribute> enumAtts = instances.enumerateAttributes();
        attIndex = 0;
        while (enumAtts.hasMoreElements()) {
          Attribute attribute = enumAtts.nextElement();
          if (!instance.isMissing(attribute)) {
            if (attribute.isNominal()) {
              m_Counts[(int) instance.classValue()][attIndex][(int) instance.value(attribute)]++;
            } else {
              m_Means[(int) instance.classValue()][attIndex] += instance.value(attribute);
              m_Counts[(int) instance.classValue()][attIndex][0]++;
            }
          }
          attIndex++;
        }
        m_Priors[(int) instance.classValue()]++;
      }
    }

    // Compute means
    Enumeration<Attribute> enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = enumAtts.nextElement();
      if (attribute.isNumeric()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          if (m_Counts[j][attIndex][0] < 2) {
            throw new Exception(
                "attribute "
                    + attribute.name()
                    + ": less than two values for class "
                    + instances.classAttribute().value(j));
          }
          m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
        }
      }
      attIndex++;
    }

    // Compute standard deviations
    enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = enumInsts.nextElement();
      if (!instance.classIsMissing()) {
        enumAtts = instances.enumerateAttributes();
        attIndex = 0;
        while (enumAtts.hasMoreElements()) {
          Attribute attribute = enumAtts.nextElement();
          if (!instance.isMissing(attribute)) {
            if (attribute.isNumeric()) {
              m_Devs[(int) instance.classValue()][attIndex] +=
                  (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute))
                      * (m_Means[(int) instance.classValue()][attIndex]
                          - instance.value(attribute));
            }
          }
          attIndex++;
        }
      }
    }
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = enumAtts.nextElement();
      if (attribute.isNumeric()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          if (m_Devs[j][attIndex] <= 0) {
            throw new Exception(
                "attribute "
                    + attribute.name()
                    + ": standard deviation is 0 for class "
                    + instances.classAttribute().value(j));
          } else {
            m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
            m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
          }
        }
      }
      attIndex++;
    }

    // Normalize counts
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = enumAtts.nextElement();
      if (attribute.isNominal()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          sum = Utils.sum(m_Counts[j][attIndex]);
          for (int i = 0; i < attribute.numValues(); i++) {
            m_Counts[j][attIndex][i] =
                (m_Counts[j][attIndex][i] + 1) / (sum + attribute.numValues());
          }
        }
      }
      attIndex++;
    }

    // Normalize priors
    sum = Utils.sum(m_Priors);
    for (int j = 0; j < instances.numClasses(); j++) {
      m_Priors[j] = (m_Priors[j] + 1) / (sum + instances.numClasses());
    }
  }
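The two normalization loops at the end are Laplace ("add-one") estimates. For a nominal attribute with k values, the stored conditional probability is

$$\hat{P}(a = v_i \mid c) = \frac{n_i + 1}{\sum_{l=1}^{k} n_l + k},$$

where n_i is the (weighted) count of value v_i within class c; the class priors are smoothed the same way, with k equal to the number of classes.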
Example #21
  /**
   * Method for building the classifier.
   *
   * @param instances the set of training instances
   * @throws Exception if the classifier can't be built successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // Removes all the instances with weight equal to 0.
    // MUST be done since condition (8) of Keerthi's paper
    // is made with the assertion Ci > 0 (see equation (3a)).
    Instances data = new Instances(instances, 0);
    for (int i = 0; i < instances.numInstances(); i++) {
      if (instances.instance(i).weight() > 0) {
        data.add(instances.instance(i));
      }
    }

    if (data.numInstances() == 0) {
      throw new Exception(
          "No training instances left after removing "
              + "instances with zero weight or a missing class!");
    }
    instances = data;

    m_onlyNumeric = true;
    for (int i = 0; i < instances.numAttributes(); i++) {
      if (i != instances.classIndex()) {
        if (!instances.attribute(i).isNumeric()) {
          m_onlyNumeric = false;
          break;
        }
      }
    }
    m_Missing = new ReplaceMissingValues();
    m_Missing.setInputFormat(instances);
    instances = Filter.useFilter(instances, m_Missing);

    if (getCapabilities().handles(Capability.NUMERIC_ATTRIBUTES)) {
      if (!m_onlyNumeric) {
        m_NominalToBinary = new NominalToBinary();
        m_NominalToBinary.setInputFormat(instances);
        instances = Filter.useFilter(instances, m_NominalToBinary);
      } else {
        m_NominalToBinary = null;
      }
    } else {
      m_NominalToBinary = null;
    }

    // retrieve two different class values used to determine filter transformation
    double y0 = instances.instance(0).classValue();
    int index = 1;
    while (index < instances.numInstances() && instances.instance(index).classValue() == y0) {
      index++;
    }
    if (index == instances.numInstances()) {
      // degenerate case, all class values are equal
      // we don't want to deal with this, too much hassle
      throw new Exception(
          "All class values are the same. At least two class values should be different");
    }
    double y1 = instances.instance(index).classValue();

    // apply filters
    if (m_filterType == FILTER_STANDARDIZE) {
      m_Filter = new Standardize();
      ((Standardize) m_Filter).setIgnoreClass(true);
      m_Filter.setInputFormat(instances);
      instances = Filter.useFilter(instances, m_Filter);
    } else if (m_filterType == FILTER_NORMALIZE) {
      m_Filter = new Normalize();
      ((Normalize) m_Filter).setIgnoreClass(true);
      m_Filter.setInputFormat(instances);
      instances = Filter.useFilter(instances, m_Filter);
    } else {
      m_Filter = null;
    }
    if (m_Filter != null) {
      double z0 = instances.instance(0).classValue();
      double z1 = instances.instance(index).classValue();
      m_x1 =
          (y0 - y1) / (z0 - z1); // no division by zero, since y0 != y1 guaranteed => z0 != z1 ???
      m_x0 = (y0 - m_x1 * z0); // = y1 - m_x1 * z1
    } else {
      m_x1 = 1.0;
      m_x0 = 0.0;
    }

    m_optimizer.setSMOReg(this);
    m_optimizer.buildClassifier(instances);
  }
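The m_x0/m_x1 bookkeeping records the affine map between filtered and original class values. If standardization or normalization sends class value y to z, then two training instances with y0 != y1 (hence z0 != z1) determine

$$y = m_{x0} + m_{x1}\,z, \qquad m_{x1} = \frac{y_0 - y_1}{z_0 - z_1}, \qquad m_{x0} = y_0 - m_{x1} z_0,$$

so predictions made in filtered space can be mapped back to the original scale.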
Example #22
  /**
   * Generates the classifier.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    m_theInstances = new Instances(data);
    m_theInstances.deleteWithMissingClass();

    m_rr = new Random(1);

    if (m_theInstances.classAttribute().isNominal()) { // Set up class priors
      m_classPriorCounts = new double[data.classAttribute().numValues()];
      Arrays.fill(m_classPriorCounts, 1.0);
      for (int i = 0; i < data.numInstances(); i++) {
        Instance curr = data.instance(i);
        m_classPriorCounts[(int) curr.classValue()] += curr.weight();
      }
      m_classPriors = m_classPriorCounts.clone();
      Utils.normalize(m_classPriors);
    }

    setUpEvaluator();

    if (m_theInstances.classAttribute().isNumeric()) {
      m_disTransform = new weka.filters.unsupervised.attribute.Discretize();
      m_classIsNominal = false;

      // use binned discretisation if the class is numeric
      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setBins(10);
      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setInvertSelection(true);

      // Discretize all attributes EXCEPT the class
      String rangeList = "";
      rangeList += (m_theInstances.classIndex() + 1);
      // System.out.println("The class col: "+m_theInstances.classIndex());

      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform)
          .setAttributeIndices(rangeList);
    } else {
      m_disTransform = new weka.filters.supervised.attribute.Discretize();
      ((weka.filters.supervised.attribute.Discretize) m_disTransform).setUseBetterEncoding(true);
      m_classIsNominal = true;
    }

    m_disTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

    m_numAttributes = m_theInstances.numAttributes();
    m_numInstances = m_theInstances.numInstances();
    m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

    // Perform the search
    int[] selected = m_search.search(m_evaluator, m_theInstances);

    m_decisionFeatures = new int[selected.length + 1];
    System.arraycopy(selected, 0, m_decisionFeatures, 0, selected.length);
    m_decisionFeatures[m_decisionFeatures.length - 1] = m_theInstances.classIndex();

    // reduce instances to selected features
    m_delTransform = new Remove();
    m_delTransform.setInvertSelection(true);

    // set features to keep
    m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
    m_delTransform.setInputFormat(m_theInstances);
    m_dtInstances = Filter.useFilter(m_theInstances, m_delTransform);

    // reset the number of attributes
    m_numAttributes = m_dtInstances.numAttributes();

    // create hash table
    m_entries = new Hashtable((int) (m_dtInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (int i = 0; i < m_numInstances; i++) {
      Instance inst = m_dtInstances.instance(i);
      insertIntoTable(inst, null);
    }

    // Replace the global table majority with nearest neighbour?
    if (m_useIBk) {
      m_ibk = new IBk();
      m_ibk.buildClassifier(m_theInstances);
    }

    // Save memory
    if (m_saveMemory) {
      m_theInstances = new Instances(m_theInstances, 0);
      m_dtInstances = new Instances(m_dtInstances, 0);
    }
    m_evaluation = null;
  }
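This appears to be Weka's DecisionTable. A short usage sketch; setUseIBk(true) enables the nearest-neighbour fallback built near the end of the method (path and demo class name hypothetical):

  import weka.classifiers.rules.DecisionTable;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class DecisionTableDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("/path/to/train.arff"); // hypothetical path
      data.setClassIndex(data.numAttributes() - 1);
      DecisionTable dt = new DecisionTable();
      dt.setUseIBk(true); // table majority replaced by IBk, as in the m_useIBk branch above
      dt.buildClassifier(data);
      System.out.println(dt);
    }
  }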
Example #23
File: VFI.java Project: SuperWan/weka
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (!m_weightByConfidence) {
      TINY = 0.0;
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_ClassIndex = instances.classIndex();
    m_NumClasses = instances.numClasses();
    m_globalCounts = new double[m_NumClasses];
    m_maxEntrop = Math.log(m_NumClasses) / Math.log(2);

    m_Instances = new Instances(instances, 0); // Copy the structure for ref

    m_intervalBounds = new double[instances.numAttributes()][2 + (2 * m_NumClasses)];

    for (int j = 0; j < instances.numAttributes(); j++) {
      boolean alt = false;
      for (int i = 0; i < m_NumClasses * 2 + 2; i++) {
        if (i == 0) {
          m_intervalBounds[j][i] = Double.NEGATIVE_INFINITY;
        } else if (i == m_NumClasses * 2 + 1) {
          m_intervalBounds[j][i] = Double.POSITIVE_INFINITY;
        } else {
          if (alt) {
            m_intervalBounds[j][i] = Double.NEGATIVE_INFINITY;
            alt = false;
          } else {
            m_intervalBounds[j][i] = Double.POSITIVE_INFINITY;
            alt = true;
          }
        }
      }
    }

    // find upper and lower bounds for numeric attributes
    for (int j = 0; j < instances.numAttributes(); j++) {
      if (j != m_ClassIndex && instances.attribute(j).isNumeric()) {
        for (int i = 0; i < instances.numInstances(); i++) {
          Instance inst = instances.instance(i);
          if (!inst.isMissing(j)) {
            if (inst.value(j) < m_intervalBounds[j][((int) inst.classValue() * 2 + 1)]) {
              m_intervalBounds[j][((int) inst.classValue() * 2 + 1)] = inst.value(j);
            }
            if (inst.value(j) > m_intervalBounds[j][((int) inst.classValue() * 2 + 2)]) {
              m_intervalBounds[j][((int) inst.classValue() * 2 + 2)] = inst.value(j);
            }
          }
        }
      }
    }

    m_counts = new double[instances.numAttributes()][][];

    // sort intervals
    for (int i = 0; i < instances.numAttributes(); i++) {
      if (instances.attribute(i).isNumeric()) {
        int[] sortedIntervals = Utils.sort(m_intervalBounds[i]);
        // remove any duplicate bounds
        int count = 1;
        for (int j = 1; j < sortedIntervals.length; j++) {
          if (m_intervalBounds[i][sortedIntervals[j]]
              != m_intervalBounds[i][sortedIntervals[j - 1]]) {
            count++;
          }
        }
        double[] reordered = new double[count];
        count = 1;
        reordered[0] = m_intervalBounds[i][sortedIntervals[0]];
        for (int j = 1; j < sortedIntervals.length; j++) {
          if (m_intervalBounds[i][sortedIntervals[j]]
              != m_intervalBounds[i][sortedIntervals[j - 1]]) {
            reordered[count] = m_intervalBounds[i][sortedIntervals[j]];
            count++;
          }
        }
        m_intervalBounds[i] = reordered;
        m_counts[i] = new double[count][m_NumClasses];
      } else if (i != m_ClassIndex) { // nominal attribute
        m_counts[i] = new double[instances.attribute(i).numValues()][m_NumClasses];
      }
    }

    // collect class counts
    for (int i = 0; i < instances.numInstances(); i++) {
      Instance inst = instances.instance(i);
      m_globalCounts[(int) instances.instance(i).classValue()] += inst.weight();
      for (int j = 0; j < instances.numAttributes(); j++) {
        if (!inst.isMissing(j) && j != m_ClassIndex) {
          if (instances.attribute(j).isNumeric()) {
            double val = inst.value(j);

            int k;
            for (k = m_intervalBounds[j].length - 1; k >= 0; k--) {
              if (val > m_intervalBounds[j][k]) {
                m_counts[j][k][(int) inst.classValue()] += inst.weight();
                break;
              } else if (val == m_intervalBounds[j][k]) {
                m_counts[j][k][(int) inst.classValue()] += (inst.weight() / 2.0);
                m_counts[j][k - 1][(int) inst.classValue()] += (inst.weight() / 2.0);
                break;
              }
            }

          } else {
            // nominal attribute
            m_counts[j][(int) inst.value(j)][(int) inst.classValue()] += inst.weight();
          }
        }
      }
    }
  }
Example #24
  /**
   * Carry out the bias-variance decomposition
   *
   * @throws Exception if the decomposition couldn't be carried out
   */
  public void decompose() throws Exception {

    Reader dataReader = new BufferedReader(new FileReader(m_DataFileName));
    Instances data = new Instances(dataReader);

    if (m_ClassIndex < 0) {
      data.setClassIndex(data.numAttributes() - 1);
    } else {
      data.setClassIndex(m_ClassIndex);
    }
    if (data.classAttribute().type() != Attribute.NOMINAL) {
      throw new Exception("Class attribute must be nominal");
    }
    int numClasses = data.numClasses();

    data.deleteWithMissingClass();
    if (data.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }

    if (data.numInstances() < 2 * m_TrainPoolSize) {
      throw new Exception(
          "The dataset must contain at least " + (2 * m_TrainPoolSize) + " instances");
    }
    Random random = new Random(m_Seed);
    data.randomize(random);
    Instances trainPool = new Instances(data, 0, m_TrainPoolSize);
    Instances test = new Instances(data, m_TrainPoolSize, data.numInstances() - m_TrainPoolSize);
    int numTest = test.numInstances();
    double[][] instanceProbs = new double[numTest][numClasses];

    m_Error = 0;
    for (int i = 0; i < m_TrainIterations; i++) {
      if (m_Debug) {
        System.err.println("Iteration " + (i + 1));
      }
      trainPool.randomize(random);
      Instances train = new Instances(trainPool, 0, m_TrainPoolSize / 2);

      Classifier current = AbstractClassifier.makeCopy(m_Classifier);
      current.buildClassifier(train);

      //// Evaluate the classifier on test, updating BVD stats
      for (int j = 0; j < numTest; j++) {
        int pred = (int) current.classifyInstance(test.instance(j));
        if (pred != test.instance(j).classValue()) {
          m_Error++;
        }
        instanceProbs[j][pred]++;
      }
    }
    m_Error /= (m_TrainIterations * numTest);

    // Average the BV over each instance in test.
    m_Bias = 0;
    m_Variance = 0;
    m_Sigma = 0;
    for (int i = 0; i < numTest; i++) {
      Instance current = test.instance(i);
      double[] predProbs = instanceProbs[i];
      double pActual, pPred;
      double bsum = 0, vsum = 0, ssum = 0;
      for (int j = 0; j < numClasses; j++) {
        pActual = (current.classValue() == j) ? 1 : 0; // Or via 1NN from test data?
        pPred = predProbs[j] / m_TrainIterations;
        bsum +=
            (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_TrainIterations - 1);
        vsum += pPred * pPred;
        ssum += pActual * pActual;
      }
      m_Bias += bsum;
      m_Variance += (1 - vsum);
      m_Sigma += (1 - ssum);
    }
    m_Bias /= (2 * numTest);
    m_Variance /= (2 * numTest);
    m_Sigma /= (2 * numTest);

    if (m_Debug) {
      System.err.println("Decomposition finished");
    }
  }
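The loops implement a Kohavi-Wolpert style bias-variance decomposition. With T training iterations, N test instances, p̂_ij the fraction of the T classifiers predicting class j on instance i, and p_ij the 0/1 indicator of the true class,

$$\mathrm{bias}^2 \approx \frac{1}{2N}\sum_i \sum_j \left[(p_{ij}-\hat{p}_{ij})^2 - \frac{\hat{p}_{ij}(1-\hat{p}_{ij})}{T-1}\right], \qquad \mathrm{variance} \approx \frac{1}{2N}\sum_i \left(1-\sum_j \hat{p}_{ij}^2\right), \qquad \sigma^2 \approx \frac{1}{2N}\sum_i \left(1-\sum_j p_{ij}^2\right),$$

the second term in the bias sum being the small-sample correction visible in bsum.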
Example #25
  public static void main(String[] args) {

    final boolean precomputed = false;

    Instances trainData = null, testData = null;
    final boolean crossValidation = args.length == 1;

    try {
      String dataFile = args[0];
      System.err.println("INFO: Loading dataset from '" + dataFile + "' ...");

      CSVLoader csvLoader = new CSVLoader();
      // csvLoader.setStringAttributes("first"); // id
      // csvLoader.setNominalAttributes("last"); // label
      csvLoader.setNominalAttributes("first,last"); // id, label
      csvLoader.setSource(new FileInputStream(dataFile));

      Instances data = csvLoader.getDataSet();

      if (!crossValidation) {
        System.err.println(
            "INFO: Loading splits from '" + args[1] + "' (train), '" + args[2] + "' (test) ...");
        Set<String> trainLabels =
            Collections.unmodifiableSet(new HashSet<String>(readAsList(args[1])));
        Set<String> testLabels =
            Collections.unmodifiableSet(new HashSet<String>(readAsList(args[2])));
        {
          Set<String> intersection = new HashSet<String>(trainLabels);
          intersection.retainAll(testLabels);
          if (!intersection.isEmpty())
            throw new IllegalStateException(
                "Train and test sets intersect: " + Arrays.toString(intersection.toArray()));
        }
        RemoveWithStringValues rwsv = new RemoveWithStringValues();
        rwsv.setAttributeIndex("first");
        trainData = rwsv.setValues(trainLabels).split(data, true, false);
        testData = rwsv.setValues(testLabels).split(data, true, false);

        if (trainData.classIndex() == -1) trainData.setClassIndex(trainData.numAttributes() - 1);
        if (testData.classIndex() == -1) testData.setClassIndex(testData.numAttributes() - 1);

        data = null;
      } else if (!precomputed) {
        if (data.classIndex() == -1) data.setClassIndex(data.numAttributes() - 1);
        data.deleteWithMissingClass();
      }

      // feature selection
      // if (crossValidation)
      // {
      // InfoGainFeatureSelection igfs = new InfoGainFeatureSelection(10,
      // true);
      // igfs.build(data);
      //
      // System.err.println("INFO: Selected " +
      // igfs.selectedAttributes().size() +" of " +
      // (data.numAttributes()-2) + " attributes.");
      //
      // Remove r = new Remove();
      // r.setAttributeIndices("first," + igfs.selectedAttributeRange()
      // +",last");
      // r.setInvertSelection(true);
      //
      // r.setInputFormat(data);
      // data = Filter.useFilter(data, r);
      // if (data.classIndex() == -1)
      // data.setClassIndex(data.numAttributes() - 1);
      // }

      // standardize
      // {
      // Standardize st = new Standardize();
      //
      // st.setInputFormat(data);
      // data = Filter.useFilter(data, st);
      // if (data.classIndex() == -1)
      // data.setClassIndex(data.numAttributes() - 1);
      // }

      // weighting to balance classes
      // {
      // System.err.println("INFO: Weighting instances...");
      // // collect statistics
      // int[] classSupport = new int[data.numClasses()];
      // int hasClass = 0;
      // for (int i = 0; i < data.numInstances(); i++) {
      // if (data.instance(i).classIsMissing())
      // continue;
      // classSupport[(int) data.instance(i).classValue()]++;
      // hasClass++;
      // }
      //
      // // calculate weights
      // double[] classWeight = new double[data.numClasses()];
      // final int smoothFactor = 2;
      // double expectedFrequency = 1.0d / data.numClasses();
      // for (int i = 0; i < data.numClasses(); i++) {
      // final double frequency = classSupport[i]
      // / (double) hasClass;
      // final double ratio = expectedFrequency / frequency;
      // classWeight[i] = (smoothFactor + ratio)/(smoothFactor + 1);
      // System.err.println("INFO: Class '" +
      // data.classAttribute().value(i) + "' instance weight set to " +
      // (float)classWeight[i] );
      // }
      //
      // for (int i = 0; i < data.numInstances(); i++) {
      // if (data.instance(i).classIsMissing())
      // continue;
      // Instance inst = data.instance(i);
      // inst.setWeight(classWeight[(int) inst.classValue()]);
      // }
      //
      // }

      if (crossValidation) {
        for (int i = 0; i < data.numAttributes(); i++) {
          if (i == 0 || i == data.classIndex()) continue;
          Attribute a = data.attribute(i);
          if (!a.isNumeric()) throw new IllegalStateException("attribute is not numeric: " + a);
        }
      }

      final Classifier c;
      {
        if (!precomputed) {
          // classifier
          SMO smo = new SMO();
          // c = smo;

          FilteredClassifier smoRegex = prefixFilteredClassifier(new SMO(), "regex_");

          HackedClassifier smoRegexHacked = new HackedClassifier();
          smoRegexHacked.setSureClasses(Collections.singleton(Label.DM_Dermatology.code));
          smoRegexHacked.setClassifier(smoRegex);

          PrecomputedClassifier pc =
              new PrecomputedClassifier(
                  Maps.readStringStringMapFromFile(
                      new File("resources/classificationResults_sampleRun.train+test.csv"),
                      Charset.forName("US-ASCII"),
                      ",",
                      new HashMap<String, String>(),
                      true,
                      1));
          pc.setSmoothFactor(0.7);
          weka.classifiers.meta.Vote v = new Vote();
          v.setOptions(new String[] {"-R", "AVG"});
          v.setClassifiers(
              new Classifier[] {wrapRemoveFirst(smo), wrapRemoveFirst(smoRegexHacked), pc});

          //					c = v;
          c = wrapRemoveFirst(smo);

          // GaussianProcesses gp = new GaussianProcesses ();
          // c = gp;

          // LADTree t = new LADTree();
          // c = t;

          // LibSVM svm = new LibSVM();
          // c = svm;

          // JRip jrip = new JRip();
          // c = jrip;

          // LogitBoost lb = new LogitBoost();
          // lb.setClassifier(new SMOreg());
          // lb.setDebug(true);
          // c = lb;
          // BayesNet bn = new BayesNet();
          // // c = bn;
          //
          // J48 j48 = new J48();
          // // j48.setMinNumObj(10);
          // // c = j48;
          //
          // ClassificationViaRegression cvr = new
          // ClassificationViaRegression();
          // cvr.setClassifier(new SMOreg());
          // c = cvr;

          // Vote v = new Vote();
          // v.setClassifiers(new Classifier[] {smo, bn, j48});
          // c = v;

          // // HIERARCHY -- groups
          // HierachyNode root = new HierachyNode("ROOT");
          // Map<Group, HierachyNode> groupNodes = new HashMap<Group,
          // HierachyNode>();
          // for (Group g : Group.values()) {
          // HierachyNode groupNode = new HierachyNode(g.name());
          // root.addChild(groupNode);
          // groupNodes.put(g, groupNode);
          // }
          // for (Label l : Label.values()) {
          // if (l.code != null) {
          // HierachyNode labelNode = new HierachyNode(l.code);
          // groupNodes.get(l.group).addChild(labelNode);
          // } else
          // System.err
          // .println("WARNING: Skipping label without code  "
          // + l);
          // }
          // System.err.println(root.toString());
          //
          // HierarchicalClassifier hc = new
          // HierarchicalClassifier(smo,
          // root);
          // c = hc;

          // // HIERARCHY -- custom
          // HierachyNode root = new HierachyNode("ROOT");
          // HierachyNode radiology = new HierachyNode("RADIOLOGY");
          // HierachyNode graphic = new HierachyNode("GRAPHIC");
          // HierachyNode photo = new HierachyNode("PHOTO");
          // root.addChild(radiology);
          // root.addChild(graphic);
          // root.addChild(photo);
          // for (Label l : Label.values()) {
          // if (l.code != null) {
          // HierachyNode labelNode = new HierachyNode(l.code);
          //
          // switch (l) {
          // case _3D_ThreeDee:
          // graphic.addChild(labelNode);
          // break;
          // case AN_Angiography:
          // radiology.addChild(labelNode);
          // break;
          // case CM_CompoundFigure:
          // graphic.addChild(labelNode);
          // break;
          // case CT_ComputedTomography:
          // radiology.addChild(labelNode);
          // break;
          // case DM_Dermatology:
          // graphic.addChild(labelNode);
          // break;
          // case DR_Drawing:
          // graphic.addChild(labelNode);
          // break;
          // case EM_ElectronMicroscopy:
          // photo.addChild(labelNode);
          // break;
          // case EN_Endoscope:
          // photo.addChild(labelNode);
          // break;
          // case FL_Fluorescense:
          // radiology.addChild(labelNode);
          // break;
          // case GL_Gel:
          // graphic.addChild(labelNode);
          // break;
          // case GR_GrossPathology:
          // photo.addChild(labelNode);
          // break;
          // case GX_Graphs:
          // graphic.addChild(labelNode);
          // break;
          // case HX_Histopathology:
          // radiology.addChild(labelNode);
          // break;
          // case MR_MagneticResonance:
          // radiology.addChild(labelNode);
          // break;
          // case PX_Photo:
          // photo.addChild(labelNode);
          // break;
          // case RN_Retinograph:
          // radiology.addChild(labelNode);
          // break;
          // case US_Ultrasound:
          // radiology.addChild(labelNode);
          // break;
          // case XR_XRay:
          // radiology.addChild(labelNode);
          // break;
          // default:
          // throw new IllegalStateException(l.toString());
          // }
          //
          // } else
          // System.err
          // .println("WARNING: Skipping label without code  "+ l);
          // }
          // System.err.println(root.toString());
          //
          // HierarchicalClassifier hc = new
          // HierarchicalClassifier(smo,
          // root);
          // c = hc;

          // Cluster membership

          // EM em = new EM();
          // em.setNumClusters(Label.values().length); // TODO
          // heuristics
          //
          // ClusterMembership cm = new AddClusterMembership();
          // cm.setDensityBasedClusterer(em);
          // cm.setIgnoredAttributeIndices(String.valueOf(data.classIndex()+1));
          // // ignore class label
          //
          // FilteredClassifier fc = new FilteredClassifier();
          // fc.setClassifier(hc);
          // fc.setFilter(cm);
          //
          // c = fc;

          // ASEvaluation ae = new InfoGainAttributeEval();
          //
          // Ranker ranker = new Ranker();
          // ranker.setNumToSelect(data.numAttributes()/2);
          // AttributeSelectedClassifier asc = new
          // AttributeSelectedClassifier();
          // asc.setClassifier(smo);
          // asc.setEvaluator(ae);
          // asc.setSearch(ranker);
          //
          // c = asc;
        } else {
          c =
              new PrecomputedClassifier(
                  Maps.readStringStringMapFromFile(
                      new File("resources/classificationResults_sampleRun.txt"),
                      Charset.forName("US-ASCII"),
                      " ",
                      new HashMap<String, String>(),
                      true));
        }
      }

      c.setDebug(true);

      if (precomputed || crossValidation) evaluateClassifier(c, data, 10);
      else
        // train/test split
        evaluateClassifier(c, trainData, testData);

    } catch (Exception e) {
      System.err.println("ERROR: " + e.getMessage());
      e.printStackTrace();
      System.exit(1);
    }
  }
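Note on the example above: the Vote meta-classifier is fully configured but then discarded, since c = wrapRemoveFirst(smo) overrides it. A minimal, self-contained sketch of averaging classifiers with weka.classifiers.meta.Vote — the ARFF path is hypothetical, and this is not the author's pipeline — could look like:

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SMO;
import weka.classifiers.meta.Vote;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class VoteSketch {
  public static void main(String[] args) throws Exception {
    // load data and mark the last attribute as the class (path is a placeholder)
    Instances data = DataSource.read("data.arff");
    if (data.classIndex() == -1) data.setClassIndex(data.numAttributes() - 1);

    // average the probability distributions of two base classifiers
    Vote vote = new Vote();
    vote.setOptions(new String[] {"-R", "AVG"});
    vote.setClassifiers(new Classifier[] {new SMO(), new J48()});

    // 10-fold cross-validation, mirroring evaluateClassifier(c, data, 10) above
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(vote, data, 10, new java.util.Random(1));
    System.out.println(eval.toSummaryString());
  }
}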
Example #26
0
File: TLD.java Project: 0x0539/weka
  /**
   * @param exs the training exemplars
   * @throws Exception if the model cannot be built properly
   */
  public void buildClassifier(Instances exs) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(exs);

    // remove instances with missing class
    exs = new Instances(exs);
    exs.deleteWithMissingClass();

    int numegs = exs.numInstances();
    m_Dimension = exs.attribute(1).relation().numAttributes();
    Instances pos = new Instances(exs, 0), neg = new Instances(exs, 0);

    for (int u = 0; u < numegs; u++) {
      Instance example = exs.instance(u);
      if (example.classValue() == 1) pos.add(example);
      else neg.add(example);
    }

    int pnum = pos.numInstances(), nnum = neg.numInstances();

    m_MeanP = new double[pnum][m_Dimension];
    m_VarianceP = new double[pnum][m_Dimension];
    m_SumP = new double[pnum][m_Dimension];
    m_MeanN = new double[nnum][m_Dimension];
    m_VarianceN = new double[nnum][m_Dimension];
    m_SumN = new double[nnum][m_Dimension];
    m_ParamsP = new double[4 * m_Dimension];
    m_ParamsN = new double[4 * m_Dimension];

    // Estimate the parameters, used as starting values for the search
    double[] pSumVal = new double[m_Dimension], // for m
        nSumVal = new double[m_Dimension];
    double[] maxVarsP = new double[m_Dimension], // for a
        maxVarsN = new double[m_Dimension];
    // Mean of sample variances: for b, b=a/E(\sigma^2)+2
    double[] varMeanP = new double[m_Dimension], varMeanN = new double[m_Dimension];
    // Variances of sample means: for w, w=E[var(\mu)]/E[\sigma^2]
    double[] meanVarP = new double[m_Dimension], meanVarN = new double[m_Dimension];
    // number of exemplars without all values missing
    double[] numExsP = new double[m_Dimension], numExsN = new double[m_Dimension];

    // Extract metadata from both positive and negative bags
    for (int v = 0; v < pnum; v++) {
      /*Exemplar px = pos.exemplar(v);
      m_MeanP[v] = px.meanOrMode();
      m_VarianceP[v] = px.variance();
      Instances pxi =  px.getInstances();
      */

      Instances pxi = pos.instance(v).relationalValue(1);
      for (int k = 0; k < pxi.numAttributes(); k++) {
        m_MeanP[v][k] = pxi.meanOrMode(k);
        m_VarianceP[v][k] = pxi.variance(k);
      }

      for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
        // if((t==m_ClassIndex) || (t==m_IdIndex))
        //  t++;

        if (!Double.isNaN(m_MeanP[v][w])) {
          for (int u = 0; u < pxi.numInstances(); u++) {
            Instance ins = pxi.instance(u);
            if (!ins.isMissing(t)) m_SumP[v][w] += ins.weight();
          }
          numExsP[w]++;
          pSumVal[w] += m_MeanP[v][w];
          meanVarP[w] += m_MeanP[v][w] * m_MeanP[v][w];
          if (maxVarsP[w] < m_VarianceP[v][w]) maxVarsP[w] = m_VarianceP[v][w];
          varMeanP[w] += m_VarianceP[v][w];
          m_VarianceP[v][w] *= (m_SumP[v][w] - 1.0);
          if (m_VarianceP[v][w] < 0.0) m_VarianceP[v][w] = 0.0;
        }
      }
    }

    for (int v = 0; v < nnum; v++) {
      /*Exemplar nx = neg.exemplar(v);
      m_MeanN[v] = nx.meanOrMode();
      m_VarianceN[v] = nx.variance();
      Instances nxi =  nx.getInstances();
      */
      Instances nxi = neg.instance(v).relationalValue(1);
      for (int k = 0; k < nxi.numAttributes(); k++) {
        m_MeanN[v][k] = nxi.meanOrMode(k);
        m_VarianceN[v][k] = nxi.variance(k);
      }

      for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
        // if((t==m_ClassIndex) || (t==m_IdIndex))
        //  t++;

        if (!Double.isNaN(m_MeanN[v][w])) {
          for (int u = 0; u < nxi.numInstances(); u++)
            if (!nxi.instance(u).isMissing(t)) m_SumN[v][w] += nxi.instance(u).weight();
          numExsN[w]++;
          nSumVal[w] += m_MeanN[v][w];
          meanVarN[w] += m_MeanN[v][w] * m_MeanN[v][w];
          if (maxVarsN[w] < m_VarianceN[v][w]) maxVarsN[w] = m_VarianceN[v][w];
          varMeanN[w] += m_VarianceN[v][w];
          m_VarianceN[v][w] *= (m_SumN[v][w] - 1.0);
          if (m_VarianceN[v][w] < 0.0) m_VarianceN[v][w] = 0.0;
        }
      }
    }

    for (int w = 0; w < m_Dimension; w++) {
      pSumVal[w] /= numExsP[w];
      nSumVal[w] /= numExsN[w];
      if (numExsP[w] > 1)
        meanVarP[w] =
            meanVarP[w] / (numExsP[w] - 1.0) - pSumVal[w] * numExsP[w] / (numExsP[w] - 1.0);
      if (numExsN[w] > 1)
        meanVarN[w] =
            meanVarN[w] / (numExsN[w] - 1.0) - nSumVal[w] * numExsN[w] / (numExsN[w] - 1.0);
      varMeanP[w] /= numExsP[w];
      varMeanN[w] /= numExsN[w];
    }

    // Bounds and parameter values for each run
    double[][] bounds = new double[2][4];
    double[] pThisParam = new double[4], nThisParam = new double[4];

    // Initial values for parameters
    double a, b, w, m;

    // Optimize for one dimension
    for (int x = 0; x < m_Dimension; x++) {
      if (getDebug()) System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Dimension #" + x);

      // Positive exemplars: first run
      a = (maxVarsP[x] > ZERO) ? maxVarsP[x] : 1.0;
      if (varMeanP[x] <= ZERO) varMeanP[x] = ZERO; // modified by LinDong (09/2005)
      b = a / varMeanP[x] + 2.0; // a/(b-2) = E(\sigma^2)
      w = meanVarP[x] / varMeanP[x]; // E[var(\mu)] = w*E[\sigma^2]	
      if (w <= ZERO) w = 1.0;

      m = pSumVal[x];
      pThisParam[0] = a; // a
      pThisParam[1] = b; // b
      pThisParam[2] = w; // w
      pThisParam[3] = m; // m

      // Negative exemplars: first run
      a = (maxVarsN[x] > ZERO) ? maxVarsN[x] : 1.0;
      if (varMeanN[x] <= ZERO) varMeanN[x] = ZERO; // modified by LinDong (09/2005)
      b = a / varMeanN[x] + 2.0; // a/(b-2) = E(\sigma^2)
      w = meanVarN[x] / varMeanN[x]; // E[var(\mu)] = w*E[\sigma^2]	
      if (w <= ZERO) w = 1.0;

      m = nSumVal[x];
      nThisParam[0] = a; // a
      nThisParam[1] = b; // b
      nThisParam[2] = w; // w
      nThisParam[3] = m; // m

      // Bound constraints
      bounds[0][0] = ZERO; // a > 0
      bounds[0][1] = 2.0 + ZERO; // b > 2
      bounds[0][2] = ZERO; // w > 0
      bounds[0][3] = Double.NaN;

      for (int t = 0; t < 4; t++) {
        bounds[1][t] = Double.NaN;
        m_ParamsP[4 * x + t] = pThisParam[t];
        m_ParamsN[4 * x + t] = nThisParam[t];
      }
      double pminVal = Double.MAX_VALUE, nminVal = Double.MAX_VALUE;
      Random whichEx = new Random(m_Seed);
      TLD_Optm pOp = null, nOp = null;
      boolean isRunValid = true;
      double[] sumP = new double[pnum], meanP = new double[pnum], varP = new double[pnum];
      double[] sumN = new double[nnum], meanN = new double[nnum], varN = new double[nnum];

      // One dimension
      for (int p = 0; p < pnum; p++) {
        sumP[p] = m_SumP[p][x];
        meanP[p] = m_MeanP[p][x];
        varP[p] = m_VarianceP[p][x];
      }
      for (int q = 0; q < nnum; q++) {
        sumN[q] = m_SumN[q][x];
        meanN[q] = m_MeanN[q][x];
        varN[q] = m_VarianceN[q][x];
      }

      for (int y = 0; y < m_Run; ) { // y is advanced at the end of the body; invalid runs repeat
        if (getDebug()) System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Run #" + y);
        double thisMin;

        if (getDebug()) System.err.println("\nPositive exemplars");
        pOp = new TLD_Optm();
        pOp.setNum(sumP);
        pOp.setSSquare(varP);
        pOp.setXBar(meanP);

        pThisParam = pOp.findArgmin(pThisParam, bounds);
        while (pThisParam == null) {
          pThisParam = pOp.getVarbValues();
          if (getDebug()) System.err.println("!!! 200 iterations finished, not enough!");
          pThisParam = pOp.findArgmin(pThisParam, bounds);
        }

        thisMin = pOp.getMinFunction();
        if (!Double.isNaN(thisMin) && (thisMin < pminVal)) {
          pminVal = thisMin;
          for (int z = 0; z < 4; z++) m_ParamsP[4 * x + z] = pThisParam[z];
        }

        if (Double.isNaN(thisMin)) {
          pThisParam = new double[4];
          isRunValid = false;
        }

        if (getDebug()) System.err.println("\nNegative exemplars");
        nOp = new TLD_Optm();
        nOp.setNum(sumN);
        nOp.setSSquare(varN);
        nOp.setXBar(meanN);

        nThisParam = nOp.findArgmin(nThisParam, bounds);
        while (nThisParam == null) {
          nThisParam = nOp.getVarbValues();
          if (getDebug()) System.err.println("!!! 200 iterations finished, not enough!");
          nThisParam = nOp.findArgmin(nThisParam, bounds);
        }
        thisMin = nOp.getMinFunction();
        if (!Double.isNaN(thisMin) && (thisMin < nminVal)) {
          nminVal = thisMin;
          for (int z = 0; z < 4; z++) m_ParamsN[4 * x + z] = nThisParam[z];
        }

        if (Double.isNaN(thisMin)) {
          nThisParam = new double[4];
          isRunValid = false;
        }

        if (!isRunValid) {
          y--;
          isRunValid = true;
        }

        if (++y < m_Run) {
          // Change the initial parameters and restart
          int pone = whichEx.nextInt(pnum), // Randomly pick one pos. exemplar
              none = whichEx.nextInt(nnum);

          // Positive exemplars: next run
          while ((m_SumP[pone][x] <= 1.0) || Double.isNaN(m_MeanP[pone][x]))
            pone = whichEx.nextInt(pnum);

          a = m_VarianceP[pone][x] / (m_SumP[pone][x] - 1.0);
          if (a <= ZERO) a = m_ParamsN[4 * x]; // Change to negative params
          m = m_MeanP[pone][x];
          double sq = (m - m_ParamsP[4 * x + 3]) * (m - m_ParamsP[4 * x + 3]);

          b = a * m_ParamsP[4 * x + 2] / sq + 2.0; // b=a/Var+2, assuming Var=Sq/w'
          if ((b <= ZERO) || Double.isNaN(b) || Double.isInfinite(b)) b = m_ParamsN[4 * x + 1];

          w =
              sq
                  * (m_ParamsP[4 * x + 1] - 2.0)
                  / m_ParamsP[4 * x]; // w=Sq/Var, assuming Var=a'/(b'-2)
          if ((w <= ZERO) || Double.isNaN(w) || Double.isInfinite(w)) w = m_ParamsN[4 * x + 2];

          pThisParam[0] = a; // a
          pThisParam[1] = b; // b
          pThisParam[2] = w; // w
          pThisParam[3] = m; // m	

          // Negative exemplars: next run
          while ((m_SumN[none][x] <= 1.0) || Double.isNaN(m_MeanN[none][x]))
            none = whichEx.nextInt(nnum);

          a = m_VarianceN[none][x] / (m_SumN[none][x] - 1.0);
          if (a <= ZERO) a = m_ParamsP[4 * x];
          m = m_MeanN[none][x];
          sq = (m - m_ParamsN[4 * x + 3]) * (m - m_ParamsN[4 * x + 3]);

          b = a * m_ParamsN[4 * x + 2] / sq + 2.0; // b=a/Var+2, assuming Var=Sq/w'
          if ((b <= ZERO) || Double.isNaN(b) || Double.isInfinite(b)) b = m_ParamsP[4 * x + 1];

          w =
              sq
                  * (m_ParamsN[4 * x + 1] - 2.0)
                  / m_ParamsN[4 * x]; // w=Sq/Var, assuming Var=a'/(b'-2)
          if ((w <= ZERO) || Double.isNaN(w) || Double.isInfinite(w)) w = m_ParamsP[4 * x + 2];

          nThisParam[0] = a; // a
          nThisParam[1] = b; // b
          nThisParam[2] = w; // w
          nThisParam[3] = m; // m	    	
        }
      }
    }

    for (int x = 0, y = 0; x < m_Dimension; x++, y++) {
      // if((x==exs.classIndex()) || (x==exs.idIndex()))
      // y++;
      a = m_ParamsP[4 * x];
      b = m_ParamsP[4 * x + 1];
      w = m_ParamsP[4 * x + 2];
      m = m_ParamsP[4 * x + 3];
      if (getDebug())
        System.err.println(
            "\n\n???Positive: ( "
                + exs.attribute(1).relation().attribute(y)
                + "): a="
                + a
                + ", b="
                + b
                + ", w="
                + w
                + ", m="
                + m);

      a = m_ParamsN[4 * x];
      b = m_ParamsN[4 * x + 1];
      w = m_ParamsN[4 * x + 2];
      m = m_ParamsN[4 * x + 3];
      if (getDebug())
        System.err.println(
            "???Negative: ("
                + exs.attribute(1).relation().attribute(y)
                + "): a="
                + a
                + ", b="
                + b
                + ", w="
                + w
                + ", m="
                + m);
    }

    if (m_UseEmpiricalCutOff) {
      // Find the empirical cut-off
      double[] pLogOdds = new double[pnum], nLogOdds = new double[nnum];
      for (int p = 0; p < pnum; p++)
        pLogOdds[p] = likelihoodRatio(m_SumP[p], m_MeanP[p], m_VarianceP[p]);

      for (int q = 0; q < nnum; q++)
        nLogOdds[q] = likelihoodRatio(m_SumN[q], m_MeanN[q], m_VarianceN[q]);

      // Update m_Cutoff
      findCutOff(pLogOdds, nLogOdds);
    } else m_Cutoff = -Math.log((double) pnum / (double) nnum);

    if (getDebug()) System.err.println("???Cut-off=" + m_Cutoff);
  }
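One way to read the starting-value heuristics in the comments above (a restatement of the code, not of the TLD paper): with per-bag sample means \bar{x}_i and sample variances s_i^2 along each dimension,

\[
a^{(0)} = \max_i s_i^2, \qquad
b^{(0)} = \frac{a^{(0)}}{\overline{s^2}} + 2 \quad\text{(from } \tfrac{a}{b-2} = E[\sigma^2]\text{)}, \qquad
w^{(0)} = \frac{\widehat{\operatorname{Var}}(\bar{x})}{\overline{s^2}} \quad\text{(from } E[\operatorname{var}(\mu)] = w\,E[\sigma^2]\text{)}, \qquad
m^{(0)} = \overline{\bar{x}},
\]

computed separately for the positive and negative bags; the optimizer is then restarted m_Run times from values derived from randomly drawn exemplars.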
Example #27
0
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    NaiveBayes nb = new NaiveBayes();

    // System.out.println("number of instances		"+m_Train.numInstances());

    if (m_Train.numInstances() == 0) {
      // throw new Exception("No training instances!");
      return m_defaultModel.distributionForInstance(instance);
    }
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      m_kNNValid = false;
      boolean deletedInstance = false;
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
        deletedInstance = true;
      }
      // rebuild the search data structure (KDTree currently can't delete)
      if (deletedInstance) m_NNSearch.setInstances(m_Train);
    }

    // Select k by cross validation
    if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
      crossValidate();
    }

    m_NNSearch.addInstanceInfo(instance);
    m_kNN = 1000; // hard-coded neighbourhood size; overrides any k selected by crossValidate() above
    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
    double[] distances = m_NNSearch.getDistances();

    // System.out.println("--------------classify instance--------- ");
    // System.out.println("neighbours.numInstances"+neighbours.numInstances());
    // System.out.println("distances.length"+distances.length);
    // System.out.println("--------------classify instance--------- ");

    /*	for (int k = 0; k < distances.length; k++) {
    	//System.out.println("-------");
    	//System.out.println("distance of "+k+"	"+distances[k]);
    	//System.out.println("instance of "+k+"	"+neighbours.instance(k));
    	//distances[k] = distances[k]+0.1;
    	//System.out.println("------- after add 0.1");
    	//System.out.println("distance of "+k+"	"+distances[k]);
    }
    */
    Instances instances = new Instances(m_Train);
    // int attrnum = instances.numAttributes();
    instances.deleteWithMissingClass();

    Instances newm_Train = new Instances(instances, 0, instances.numInstances());

    for (int k = 0; k < neighbours.numInstances(); k++) {
      // System.out.println("-------");
      // Instance in = new Instance();
      Instance insk = neighbours.instance(k);
      // System.out.println("instance "+k+"	"+neighbours.instance(k));
      // System.out.println("-------");
      double dis = distances[k] + 0.1;
      // System.out.println("dis		"+dis);
      dis = (1 / dis) * 10;
      // System.out.println("1/dis		"+dis);
      int weightnum = (int) dis;
      // System.out.println("weightnum		"+weightnum);

      for (int s = 0; s < weightnum; s++) {

        newm_Train.add(insk);
      }
    }

    // System.out.println("number of instances		"+newm_Train.numInstances());

    /*  for (int k = 0; k < newm_Train.numInstances(); k++) {
    		System.out.println("-------");
    		System.out.println("instance "+k+"	"+newm_Train.instance(k));
    		System.out.println("-------");
    	}
    */

    /*
    	for (int k = 0; k < distances.length; k++) {
    		System.out.println("-------");
    		System.out.println("distance of "+k+"	"+distances[k]);
    		System.out.println("-------");
    	}*/

    nb.buildClassifier(newm_Train);
    double[] dis = nb.distributionForInstance(instance);
    // double[] distribution = makeDistribution(neighbours, distances);
    return dis;
    // return distribution;
  }
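The distance weighting above emulates fractional weights by duplicating each neighbour roughly 10/(d + 0.1) times, which inflates the training set handed to NaiveBayes. Since weka.classifiers.bayes.NaiveBayes implements WeightedInstancesHandler, a sketch that sets instance weights directly (an alternative, not the author's method) would be:

import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instance;
import weka.core.Instances;

public final class WeightedNeighbours {
  /** Builds a NaiveBayes model over the neighbours, weighted by inverse distance. */
  public static NaiveBayes buildWeighted(Instances neighbours, double[] distances)
      throws Exception {
    Instances weighted = new Instances(neighbours); // copy; leave the caller's data intact
    for (int k = 0; k < weighted.numInstances(); k++) {
      Instance insk = weighted.instance(k);
      insk.setWeight(1.0 / (distances[k] + 0.1)); // +0.1 smooths zero distances, as above
    }
    NaiveBayes nb = new NaiveBayes();
    nb.buildClassifier(weighted); // NaiveBayes honours instance weights
    return nb;
  }
}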
Example #28
0
  /**
   * @param exs the training exemplars
   * @throws Exception if the model cannot be built properly
   */
  public void buildClassifier(Instances exs) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(exs);

    // remove instances with missing class
    exs = new Instances(exs);
    exs.deleteWithMissingClass();

    int numegs = exs.numInstances();
    m_Dimension = exs.attribute(1).relation().numAttributes();
    m_Attribute = exs.attribute(1).relation().stringFreeStructure();
    Instances pos = new Instances(exs, 0), neg = new Instances(exs, 0);

    // Divide into two groups
    for (int u = 0; u < numegs; u++) {
      Instance example = exs.instance(u);
      if (example.classValue() == 1) pos.add(example);
      else neg.add(example);
    }
    int pnum = pos.numInstances(), nnum = neg.numInstances();

    // xBar, n
    m_MeanP = new double[pnum][m_Dimension];
    m_SumP = new double[pnum][m_Dimension];
    m_MeanN = new double[nnum][m_Dimension];
    m_SumN = new double[nnum][m_Dimension];
    // w, m
    m_ParamsP = new double[2 * m_Dimension];
    m_ParamsN = new double[2 * m_Dimension];
    // \sigma^2
    m_SgmSqP = new double[m_Dimension];
    m_SgmSqN = new double[m_Dimension];
    // S^2
    double[][] varP = new double[pnum][m_Dimension], varN = new double[nnum][m_Dimension];
    // number of exemplars without all values missing
    double[] effNumExP = new double[m_Dimension], effNumExN = new double[m_Dimension];
    // For the starting values
    double[] pMM = new double[m_Dimension],
        nMM = new double[m_Dimension],
        pVM = new double[m_Dimension],
        nVM = new double[m_Dimension];
    // # of exemplars with only one instance
    double[] numOneInsExsP = new double[m_Dimension], numOneInsExsN = new double[m_Dimension];
    // sum_i(1/n_i)
    double[] pInvN = new double[m_Dimension], nInvN = new double[m_Dimension];

    // Extract metadata from both positive and negative bags
    for (int v = 0; v < pnum; v++) {
      // Instance px = pos.instance(v);
      Instances pxi = pos.instance(v).relationalValue(1);
      for (int k = 0; k < pxi.numAttributes(); k++) {
        m_MeanP[v][k] = pxi.meanOrMode(k);
        varP[v][k] = pxi.variance(k);
      }

      for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
        // if((t==m_ClassIndex) || (t==m_IdIndex))
        //  t++;
        if (varP[v][w] <= 0.0) varP[v][w] = 0.0;
        if (!Double.isNaN(m_MeanP[v][w])) {

          for (int u = 0; u < pxi.numInstances(); u++)
            if (!pxi.instance(u).isMissing(t)) m_SumP[v][w] += pxi.instance(u).weight();

          pMM[w] += m_MeanP[v][w];
          pVM[w] += m_MeanP[v][w] * m_MeanP[v][w];
          if ((m_SumP[v][w] > 1) && (varP[v][w] > ZERO)) {

            m_SgmSqP[w] += varP[v][w] * (m_SumP[v][w] - 1.0) / m_SumP[v][w];

            // m_SgmSqP[w] += varP[v][w]*(m_SumP[v][w]-1.0);
            effNumExP[w]++; // Don't count exemplars with only one instance
            pInvN[w] += 1.0 / m_SumP[v][w];
            // pInvN[w] += m_SumP[v][w];
          } else numOneInsExsP[w]++;
        }
      }
    }

    for (int v = 0; v < nnum; v++) {
      // Instance nx = neg.instance(v);
      Instances nxi = neg.instance(v).relationalValue(1);
      for (int k = 0; k < nxi.numAttributes(); k++) {
        m_MeanN[v][k] = nxi.meanOrMode(k);
        varN[v][k] = nxi.variance(k);
      }
      // Instances nxi =  nx.getInstances();

      for (int w = 0, t = 0; w < m_Dimension; w++, t++) {

        // if((t==m_ClassIndex) || (t==m_IdIndex))
        //  t++;
        if (varN[v][w] <= 0.0) varN[v][w] = 0.0;
        if (!Double.isNaN(m_MeanN[v][w])) {
          for (int u = 0; u < nxi.numInstances(); u++)
            if (!nxi.instance(u).isMissing(t)) m_SumN[v][w] += nxi.instance(u).weight();

          nMM[w] += m_MeanN[v][w];
          nVM[w] += m_MeanN[v][w] * m_MeanN[v][w];
          if ((m_SumN[v][w] > 1) && (varN[v][w] > ZERO)) {
            m_SgmSqN[w] += varN[v][w] * (m_SumN[v][w] - 1.0) / m_SumN[v][w];
            // m_SgmSqN[w] += varN[v][w]*(m_SumN[v][w]-1.0);
            effNumExN[w]++; // Don't count exemplars with only one instance
            nInvN[w] += 1.0 / m_SumN[v][w];
            // nInvN[w] += m_SumN[v][w];
          } else numOneInsExsN[w]++;
        }
      }
    }

    // Expected \sigma^2
    /* If m_SgmSqP[u] or m_SgmSqN[u] is 0, assign 0 to sigma^2.
     * Otherwise, m_SgmSqP / m_SgmSqN may end up NaN.
     * Modified by Lin Dong (Sep. 2005)
     */
    for (int u = 0; u < m_Dimension; u++) {
      // For exemplars with only one instance, use avg(\sigma^2) of other exemplars
      if (m_SgmSqP[u] != 0) m_SgmSqP[u] /= (effNumExP[u] - pInvN[u]);
      else m_SgmSqP[u] = 0;
      if (m_SgmSqN[u] != 0) m_SgmSqN[u] /= (effNumExN[u] - nInvN[u]);
      else m_SgmSqN[u] = 0;

      // m_SgmSqP[u] /= (pInvN[u]-effNumExP[u]);
      // m_SgmSqN[u] /= (nInvN[u]-effNumExN[u]);
      effNumExP[u] += numOneInsExsP[u];
      effNumExN[u] += numOneInsExsN[u];
      pMM[u] /= effNumExP[u];
      nMM[u] /= effNumExN[u];
      pVM[u] =
          pVM[u] / (effNumExP[u] - 1.0) - pMM[u] * pMM[u] * effNumExP[u] / (effNumExP[u] - 1.0);
      nVM[u] =
          nVM[u] / (effNumExN[u] - 1.0) - nMM[u] * nMM[u] * effNumExN[u] / (effNumExN[u] - 1.0);
    }

    // Bounds and parameter values for each run
    double[][] bounds = new double[2][2];
    double[] pThisParam = new double[2], nThisParam = new double[2];

    // Initial values for parameters
    double w, m;
    Random whichEx = new Random(m_Seed);

    // Optimize for one dimension
    for (int x = 0; x < m_Dimension; x++) {
      // System.out.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Dimension #"+x);

      // Positive exemplars: first run
      pThisParam[0] = pVM[x]; // w
      if (pThisParam[0] <= ZERO) pThisParam[0] = 1.0;
      pThisParam[1] = pMM[x]; // m

      // Negative exemplars: first run
      nThisParam[0] = nVM[x]; // w
      if (nThisParam[0] <= ZERO) nThisParam[0] = 1.0;
      nThisParam[1] = nMM[x]; // m

      // Bound constraints
      bounds[0][0] = ZERO; // w > 0
      bounds[0][1] = Double.NaN;
      bounds[1][0] = Double.NaN;
      bounds[1][1] = Double.NaN;

      double pminVal = Double.MAX_VALUE, nminVal = Double.MAX_VALUE;
      TLDSimple_Optm pOp = null, nOp = null;
      boolean isRunValid = true;
      double[] sumP = new double[pnum], meanP = new double[pnum];
      double[] sumN = new double[nnum], meanN = new double[nnum];

      // One dimension
      for (int p = 0; p < pnum; p++) {
        sumP[p] = m_SumP[p][x];
        meanP[p] = m_MeanP[p][x];
      }
      for (int q = 0; q < nnum; q++) {
        sumN[q] = m_SumN[q][x];
        meanN[q] = m_MeanN[q][x];
      }

      for (int y = 0; y < m_Run; y++) {
        // System.out.println("\n\n!!!!!!!!!Positive exemplars: Run #"+y);
        double thisMin;
        pOp = new TLDSimple_Optm();
        pOp.setNum(sumP);
        pOp.setSgmSq(m_SgmSqP[x]);
        if (getDebug()) System.out.println("m_SgmSqP[" + x + "]= " + m_SgmSqP[x]);
        pOp.setXBar(meanP);
        // pOp.setDebug(true);
        pThisParam = pOp.findArgmin(pThisParam, bounds);
        while (pThisParam == null) {
          pThisParam = pOp.getVarbValues();
          if (getDebug()) System.out.println("!!! 200 iterations finished, not enough!");
          pThisParam = pOp.findArgmin(pThisParam, bounds);
        }

        thisMin = pOp.getMinFunction();
        if (!Double.isNaN(thisMin) && (thisMin < pminVal)) {
          pminVal = thisMin;
          for (int z = 0; z < 2; z++) m_ParamsP[2 * x + z] = pThisParam[z];
        }

        if (Double.isNaN(thisMin)) {
          pThisParam = new double[2];
          isRunValid = false;
        }
        if (!isRunValid) {
          y--;
          isRunValid = true;
        }

        // Change the initial parameters and restart
        int pone = whichEx.nextInt(pnum);

        // Positive exemplars: next run
        while (Double.isNaN(m_MeanP[pone][x])) pone = whichEx.nextInt(pnum);

        m = m_MeanP[pone][x];
        w = (m - pThisParam[1]) * (m - pThisParam[1]);
        pThisParam[0] = w; // w
        pThisParam[1] = m; // m	
      }

      for (int y = 0; y < m_Run; y++) {
        // System.out.println("\n\n!!!!!!!!!Negative exemplars: Run #"+y);
        double thisMin;
        nOp = new TLDSimple_Optm();
        nOp.setNum(sumN);
        nOp.setSgmSq(m_SgmSqN[x]);
        if (getDebug()) System.out.println(m_SgmSqN[x]);
        nOp.setXBar(meanN);
        // nOp.setDebug(true);
        nThisParam = nOp.findArgmin(nThisParam, bounds);

        while (nThisParam == null) {
          nThisParam = nOp.getVarbValues();
          if (getDebug()) System.out.println("!!! 200 iterations finished, not enough!");
          nThisParam = nOp.findArgmin(nThisParam, bounds);
        }

        thisMin = nOp.getMinFunction();
        if (!Double.isNaN(thisMin) && (thisMin < nminVal)) {
          nminVal = thisMin;
          for (int z = 0; z < 2; z++) m_ParamsN[2 * x + z] = nThisParam[z];
        }

        if (Double.isNaN(thisMin)) {
          nThisParam = new double[2];
          isRunValid = false;
        }

        if (!isRunValid) {
          y--;
          isRunValid = true;
        }

        // Change the initial parameters and restart
        int none = whichEx.nextInt(nnum); // Randomly pick one neg. exemplar

        // Negative exemplars: next run
        while (Double.isNaN(m_MeanN[none][x])) none = whichEx.nextInt(nnum);

        m = m_MeanN[none][x];
        w = (m - nThisParam[1]) * (m - nThisParam[1]);
        nThisParam[0] = w; // w
        nThisParam[1] = m; // m	 	
      }
    }

    m_LkRatio = new double[m_Dimension];

    if (m_UseEmpiricalCutOff) {
      // Find the empirical cut-off
      double[] pLogOdds = new double[pnum], nLogOdds = new double[nnum];
      for (int p = 0; p < pnum; p++) pLogOdds[p] = likelihoodRatio(m_SumP[p], m_MeanP[p]);

      for (int q = 0; q < nnum; q++) nLogOdds[q] = likelihoodRatio(m_SumN[q], m_MeanN[q]);

      // Update m_Cutoff
      findCutOff(pLogOdds, nLogOdds);
    } else m_Cutoff = -Math.log((double) pnum / (double) nnum);

    /*
    for(int x=0, y=0; x<m_Dimension; x++, y++){
    if((x==exs.classIndex()) || (x==exs.idIndex()))
    y++;

    w=m_ParamsP[2*x]; m=m_ParamsP[2*x+1];
    System.err.println("\n\n???Positive: ( "+exs.attribute(y)+
    "):  w="+w+", m="+m+", sgmSq="+m_SgmSqP[x]);

    w=m_ParamsN[2*x]; m=m_ParamsN[2*x+1];
    System.err.println("???Negative: ("+exs.attribute(y)+
    "):  w="+w+", m="+m+", sgmSq="+m_SgmSqN[x]+
    "\nAvg. log-likelihood ratio in training data="
    +(m_LkRatio[x]/(pnum+nnum)));
    }
    */
    if (getDebug()) System.err.println("\n\n???Cut-off=" + m_Cutoff);
  }
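When the empirical cut-off is not used, both TLD variants fall back to the negative log prior odds of the bag classes. Reading the code, with n_+ positive and n_- negative bags,

\[
\text{m\_Cutoff} = -\log\frac{n_+}{n_-},
\qquad\text{so a bag } B \text{ is called positive when }
\log\frac{P(B \mid +)}{P(B \mid -)} > -\log\frac{n_+}{n_-},
\]

which is exactly the MAP rule P(+ \mid B) > P(- \mid B) under the empirical class priors.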
Example #29
0
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_headerInfo = new Instances(instances, 0);
    m_numClasses = instances.numClasses();
    m_numAttributes = instances.numAttributes();
    m_probOfWordGivenClass = new double[m_numClasses][];

    /*
      initialising the matrix of word counts
      NOTE: Laplace estimator introduced in case a word that does not appear for a class in the
      training set does so for the test set
    */
    for (int c = 0; c < m_numClasses; c++) {
      m_probOfWordGivenClass[c] = new double[m_numAttributes];
      for (int att = 0; att < m_numAttributes; att++) {
        m_probOfWordGivenClass[c][att] = 1;
      }
    }

    // enumerate through the instances
    Instance instance;
    int classIndex;
    double numOccurences;
    double[] docsPerClass = new double[m_numClasses];
    double[] wordsPerClass = new double[m_numClasses];

    java.util.Enumeration enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      instance = (Instance) enumInsts.nextElement();
      classIndex = (int) instance.value(instance.classIndex());
      docsPerClass[classIndex] += instance.weight();

      for (int a = 0; a < instance.numValues(); a++)
        if (instance.index(a) != instance.classIndex()) {
          if (!instance.isMissing(a)) {
            numOccurences = instance.valueSparse(a) * instance.weight();
            if (numOccurences < 0)
              throw new Exception("Numeric attribute values must all be greater than or equal to zero.");
            wordsPerClass[classIndex] += numOccurences;
            m_probOfWordGivenClass[classIndex][instance.index(a)] += numOccurences;
          }
        }
    }

    /*
      normalising probOfWordGivenClass values
      and saving each value as the log of each value
    */
    for (int c = 0; c < m_numClasses; c++)
      for (int v = 0; v < m_numAttributes; v++)
        m_probOfWordGivenClass[c][v] =
            Math.log(m_probOfWordGivenClass[c][v] / (wordsPerClass[c] + m_numAttributes - 1));

    /*
      calculating Pr(H)
      NOTE: Laplace estimator introduced in case a class does not get mentioned in the set of
      training instances
    */
    final double numDocs = instances.sumOfWeights() + m_numClasses;
    m_probOfClass = new double[m_numClasses];
    for (int h = 0; h < m_numClasses; h++)
      m_probOfClass[h] = (double) (docsPerClass[h] + 1) / numDocs;
  }
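Written out, the smoothed estimates this multinomial model stores are (with V the total number of attributes, so V - 1 word attributes once the class is excluded; counts are instance-weighted):

\[
\log \hat{P}(w_t \mid c) = \log \frac{1 + N_{tc}}{(V - 1) + \sum_{t'} N_{t'c}},
\qquad
\hat{P}(c) = \frac{D_c + 1}{D + |C|},
\]

where N_{tc} is the count of word t in class-c documents, D_c the (weighted) number of class-c documents, D the total weight, and |C| the number of classes. The +1 terms are the Laplace estimators mentioned in the comments.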
Example #30
0
File: Logistic.java Project: Faelg5/weka
  /**
   * Builds the classifier
   *
   * @param train the training data to be used for generating the boosted classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  @Override
  public void buildClassifier(Instances train) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();

    // Replace missing values
    m_ReplaceMissingValues = new ReplaceMissingValues();
    m_ReplaceMissingValues.setInputFormat(train);
    train = Filter.useFilter(train, m_ReplaceMissingValues);

    // Remove useless attributes
    m_AttFilter = new RemoveUseless();
    m_AttFilter.setInputFormat(train);
    train = Filter.useFilter(train, m_AttFilter);

    // Transform attributes
    m_NominalToBinary = new NominalToBinary();
    m_NominalToBinary.setInputFormat(train);
    train = Filter.useFilter(train, m_NominalToBinary);

    // Save the structure for printing the model
    m_structure = new Instances(train, 0);

    // Extract data
    m_ClassIndex = train.classIndex();
    m_NumClasses = train.numClasses();

    int nK = m_NumClasses - 1; // Only K-1 class labels needed
    int nR = m_NumPredictors = train.numAttributes() - 1;
    int nC = train.numInstances();

    m_Data = new double[nC][nR + 1]; // Data values
    int[] Y = new int[nC]; // Class labels
    double[] xMean = new double[nR + 1]; // Attribute means
    double[] xSD = new double[nR + 1]; // Attribute stddev's
    double[] sY = new double[nK + 1]; // Number of instances in each class
    double[] weights = new double[nC]; // Weights of instances
    double totWeights = 0; // Total weights of the instances
    m_Par = new double[nR + 1][nK]; // Optimized parameter values

    if (m_Debug) {
      System.out.println("Extracting data...");
    }

    for (int i = 0; i < nC; i++) {
      // initialize X[][]
      Instance current = train.instance(i);
      Y[i] = (int) current.classValue(); // Class value starts from 0
      weights[i] = current.weight(); // Dealing with weights
      totWeights += weights[i];

      m_Data[i][0] = 1;
      int j = 1;
      for (int k = 0; k <= nR; k++) {
        if (k != m_ClassIndex) {
          double x = current.value(k);
          m_Data[i][j] = x;
          xMean[j] += weights[i] * x;
          xSD[j] += weights[i] * x * x;
          j++;
        }
      }

      // Class count
      sY[Y[i]]++;
    }

    if ((totWeights <= 1) && (nC > 1)) {
      throw new Exception("Sum of weights of instances less than 1, please reweight!");
    }

    xMean[0] = 0;
    xSD[0] = 1;
    for (int j = 1; j <= nR; j++) {
      xMean[j] = xMean[j] / totWeights;
      if (totWeights > 1) {
        xSD[j] = Math.sqrt(Math.abs(xSD[j] - totWeights * xMean[j] * xMean[j]) / (totWeights - 1));
      } else {
        xSD[j] = 0;
      }
    }

    if (m_Debug) {
      // Output stats about input data
      System.out.println("Descriptives...");
      for (int m = 0; m <= nK; m++) {
        System.out.println(sY[m] + " cases have class " + m);
      }
      System.out.println("\n Variable     Avg       SD    ");
      for (int j = 1; j <= nR; j++) {
        System.out.println(
            Utils.doubleToString(j, 8, 4)
                + Utils.doubleToString(xMean[j], 10, 4)
                + Utils.doubleToString(xSD[j], 10, 4));
      }
    }

    // Normalise input data
    for (int i = 0; i < nC; i++) {
      for (int j = 0; j <= nR; j++) {
        if (xSD[j] != 0) {
          m_Data[i][j] = (m_Data[i][j] - xMean[j]) / xSD[j];
        }
      }
    }

    if (m_Debug) {
      System.out.println("\nIteration History...");
    }

    double[] x = new double[(nR + 1) * nK];
    double[][] b = new double[2][x.length]; // Boundary constraints, N/A here

    // Initialize
    for (int p = 0; p < nK; p++) {
      int offset = p * (nR + 1);
      x[offset] = Math.log(sY[p] + 1.0) - Math.log(sY[nK] + 1.0); // Null model
      b[0][offset] = Double.NaN;
      b[1][offset] = Double.NaN;
      for (int q = 1; q <= nR; q++) {
        x[offset + q] = 0.0;
        b[0][offset + q] = Double.NaN;
        b[1][offset + q] = Double.NaN;
      }
    }

    OptObject oO = new OptObject();
    oO.setWeights(weights);
    oO.setClassLabels(Y);

    Optimization opt = null;
    if (m_useConjugateGradientDescent) {
      opt = new OptEngCG(oO);
    } else {
      opt = new OptEng(oO);
    }
    opt.setDebug(m_Debug);

    if (m_MaxIts == -1) { // Search until convergence
      x = opt.findArgmin(x, b);
      while (x == null) {
        x = opt.getVarbValues();
        if (m_Debug) {
          System.out.println("First set of iterations finished, not enough!");
        }
        x = opt.findArgmin(x, b);
      }
      if (m_Debug) {
        System.out.println(" -------------<Converged>--------------");
      }
    } else {
      opt.setMaxIteration(m_MaxIts);
      x = opt.findArgmin(x, b);
      if (x == null) {
        x = opt.getVarbValues();
      }
    }

    m_LL = -opt.getMinFunction(); // Log-likelihood

    // Don't need data matrix anymore
    m_Data = null;

    // Convert coefficients back to non-normalized attribute units
    for (int i = 0; i < nK; i++) {
      m_Par[0][i] = x[i * (nR + 1)];
      for (int j = 1; j <= nR; j++) {
        m_Par[j][i] = x[i * (nR + 1) + j];
        if (xSD[j] != 0) {
          m_Par[j][i] /= xSD[j];
          m_Par[0][i] -= m_Par[j][i] * xMean[j];
        }
      }
    }
  }
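The closing loop undoes the standardization: if \beta^{\text{std}} are the coefficients fitted on normalized inputs with attribute means \bar{x}_j and standard deviations s_j, the reported model for each of the K - 1 logits is

\[
\beta_j = \frac{\beta_j^{\text{std}}}{s_j} \quad (s_j \neq 0),
\qquad
\beta_0 = \beta_0^{\text{std}} - \sum_{j=1}^{nR} \beta_j\,\bar{x}_j,
\]

so predictions on raw attribute values match those made on the standardized data.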