Example #1
  private static void analyse(Instances train, Instances datapredict) {
    String mpOptions = "-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a";

    try {

      train.setClassIndex(train.numAttributes() - 1);
      train.deleteAttributeAt(0);
      int numClasses = train.numClasses();

      for (int i = 0; i < numClasses; i++) {
        System.out.println("class value [" + i + "]=" + train.classAttribute().value(i) + "");
      }

      // Instantiate the neural network (multilayer perceptron) with the options above
      MultilayerPerceptron mlp = new MultilayerPerceptron();
      mlp.setOptions(weka.core.Utils.splitOptions(mpOptions));
      mlp.buildClassifier(train);

      datapredict.setClassIndex(datapredict.numAttributes() - 1);
      datapredict.deleteAttributeAt(0);

      // Instances predicteddata = new Instances(datapredict);
      for (int i = 0; i < datapredict.numInstances(); i++) {

        Instance newInst = datapredict.instance(i);
        double pred = mlp.classifyInstance(newInst);
        int predInt = (int) pred; // Math.round(pred);
        String predString = train.classAttribute().value(predInt);
        System.out.println(
            "cliente["
                + i
                + "] pred["
                + pred
                + "] predInt["
                + predInt
                + "] desertor["
                + predString
                + "]");
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
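
A minimal caller for analyse() might look like the following sketch; both ARFF paths are placeholders, and it assumes Weka's standard DataSource loader (weka.core.converters.ConverterUtils.DataSource):

  // Hypothetical usage sketch for analyse(); the file paths are placeholders.
  public static void main(String[] args) throws Exception {
    Instances train = weka.core.converters.ConverterUtils.DataSource.read("train.arff");
    Instances datapredict = weka.core.converters.ConverterUtils.DataSource.read("predict.arff");
    analyse(train, datapredict);
  }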
  // Create 70% training data set (the split step itself is sketched after this method)
  public void generateTrainingDataSet() {

    trainingDataSet = new Instances(instances);
    int size = trainingDataSet.numInstances();

    // Remove the closing price ("close") attribute
    trainingDataSet.deleteAttributeAt(0);

    // Randomize data set
    trainingDataSet.randomize(trainingDataSet.getRandomNumberGenerator(1));
  }
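
The comment above promises a 70% training set, but the body only copies and shuffles the data; the split itself could be taken with the Instances(source, first, count) copy constructor. A minimal sketch, assuming the trainingDataSet field populated by the method above (the method name splitTrainingAndTest is hypothetical):

  public void splitTrainingAndTest() {
    int size = trainingDataSet.numInstances();
    int trainSize = (int) Math.round(size * 0.7);
    // First 70% of the shuffled data for training, the remainder for testing
    Instances trainSplit = new Instances(trainingDataSet, 0, trainSize);
    Instances testSplit = new Instances(trainingDataSet, trainSize, size - trainSize);
    System.out.println("train: " + trainSplit.numInstances() + ", test: " + testSplit.numInstances());
  }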
Example #3
File: Wavelet.java Project: dachylong/weka
  /**
   * processes the instances using the Haar algorithm
   *
   * @param instances the data to process
   * @return the modified data
   * @throws Exception in case the processing goes wrong
   */
  protected Instances processHAAR(Instances instances) throws Exception {
    Instances result;
    int i;
    int n;
    int j;
    int clsIdx;
    double[] oldVal;
    double[] newVal;
    int level;
    int length;
    double[] clsVal;
    Attribute clsAtt;

    clsIdx = instances.classIndex();
    clsVal = null;
    clsAtt = null;
    if (clsIdx > -1) {
      clsVal = instances.attributeToDoubleArray(clsIdx);
      clsAtt = (Attribute) instances.classAttribute().copy();
      instances.setClassIndex(-1);
      instances.deleteAttributeAt(clsIdx);
    }
    result = new Instances(instances, 0);
    level = (int) StrictMath.ceil(StrictMath.log(instances.numAttributes()) / StrictMath.log(2.0));

    for (i = 0; i < instances.numInstances(); i++) {
      oldVal = instances.instance(i).toDoubleArray();
      newVal = new double[oldVal.length];

      for (n = level; n > 0; n--) {
        length = (int) StrictMath.pow(2, n - 1);

        for (j = 0; j < length; j++) {
          newVal[j] = (oldVal[j * 2] + oldVal[j * 2 + 1]) / StrictMath.sqrt(2);
          newVal[j + length] = (oldVal[j * 2] - oldVal[j * 2 + 1]) / StrictMath.sqrt(2);
        }

        System.arraycopy(newVal, 0, oldVal, 0, newVal.length);
      }

      // add new transformed instance
      result.add(new DenseInstance(1, newVal));
    }

    // add class again
    if (clsIdx > -1) {
      result.insertAttributeAt(clsAtt, clsIdx);
      result.setClassIndex(clsIdx);
      for (i = 0; i < clsVal.length; i++) result.instance(i).setClassValue(clsVal[i]);
    }

    return result;
  }
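
To see what the inner loop computes, here is a small self-contained check of one Haar step on a four-value row (plain Java, no Weka dependency): pairwise sums scaled by 1/sqrt(2) fill the first half of the output, pairwise differences the second half. Note the method above implicitly assumes the attribute count is a power of two, with any padding presumably handled elsewhere in the filter.

  // Standalone sketch of a single Haar level on {4, 6, 10, 12};
  // prints [7.07..., 15.55..., -1.41..., -1.41...]
  public static void main(String[] args) {
    double[] v = {4, 6, 10, 12};
    double[] out = new double[v.length];
    for (int j = 0; j < v.length / 2; j++) {
      out[j] = (v[2 * j] + v[2 * j + 1]) / StrictMath.sqrt(2); // approximation coefficients
      out[j + v.length / 2] = (v[2 * j] - v[2 * j + 1]) / StrictMath.sqrt(2); // detail coefficients
    }
    System.out.println(java.util.Arrays.toString(out));
  }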
Example #4
  public static Double runClassify(String trainFile, String testFile) {
    double predictOrder = 0.0;
    double trueOrder = 0.0;
    try {
      String trainWekaFileName = trainFile;
      String testWekaFileName = testFile;

      Instances train = DataSource.read(trainWekaFileName);
      Instances test = DataSource.read(testWekaFileName);

      train.setClassIndex(0);
      test.setClassIndex(0);

      // Drop unwanted attributes; deleting from the highest index first
      // keeps the lower indices stable across deletions.
      train.deleteAttributeAt(8);
      test.deleteAttributeAt(8);
      train.deleteAttributeAt(6);
      test.deleteAttributeAt(6);
      train.deleteAttributeAt(5);
      test.deleteAttributeAt(5);
      train.deleteAttributeAt(4);
      test.deleteAttributeAt(4);

      // AdditiveRegression classifier = new AdditiveRegression();

      // NaiveBayes classifier = new NaiveBayes();

      RandomForest classifier = new RandomForest();
      // LibSVM classifier = new LibSVM();

      classifier.buildClassifier(train);
      Evaluation eval = new Evaluation(train);
      eval.evaluateModel(classifier, test);

      System.out.println(eval.toSummaryString("\nResults\n\n", true));
      // System.out.println(eval.toClassDetailsString());
      // System.out.println(eval.toMatrixString());
      int k = 892; // running row id prefixed to each printed prediction
      for (int i = 0; i < test.numInstances(); i++) {
        predictOrder = classifier.classifyInstance(test.instance(i));
        trueOrder = test.instance(i).classValue();
        System.out.println((k++) + "," + (int) predictOrder);
      }

    } catch (Exception e) {
      e.printStackTrace();
    }
    return predictOrder;
  }
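
A hypothetical call site for runClassify(); the file names are placeholders, and the return value is simply the prediction for the last test instance printed above:

  // Hypothetical usage; train.arff and test.arff are placeholder paths.
  public static void main(String[] args) {
    Double lastPrediction = runClassify("train.arff", "test.arff");
    System.out.println("last prediction: " + lastPrediction);
  }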
Example #5
  public static void test_NHBS_old() throws Exception {
    // load the data
    CSVLoader loader = new CSVLoader();
    // these must come before the getDataSet()
    // loader.setEnclosureCharacters(",\'\"S");
    // loader.setNominalAttributes("16,71"); //zip code, drug name
    // loader.setStringAttributes("");
    // loader.setDateAttributes("0,1");
    // loader.setSource(new File("hcv/data/NHBS/IDU2_HCV_model_012913_cleaned_for_weka.csv"));
    loader.setSource(new File("/home/sasha/hcv/code/data/IDU2_HCV_model_012913_cleaned.csv"));
    Instances nhbs_data = loader.getDataSet();
    loader.setMissingValue("NOVALUE");
    // loader.setMissingValue("");

    nhbs_data.deleteAttributeAt(12); // zip code
    nhbs_data.deleteAttributeAt(1); // date - redundant with age
    nhbs_data.deleteAttributeAt(0); // date
    System.out.println("classifying attribute:");
    nhbs_data.setClassIndex(1); // new index  3->2->1
    System.out.println(nhbs_data.attribute(1).getMetadata().toString()); // HCVEIARSLT1

    // wishlist: perhaps it would be smarter to throw out unclassified instances?  they interfere
    // with the scoring
    nhbs_data.deleteWithMissingClass();
    // nhbs_data.setClass(new Attribute("HIVRSLT"));//.setClassIndex(1); //2nd column.  all are
    // mostly negative
    // nhbs_data.setClass(new Attribute("HCVEIARSLT1"));//.setClassIndex(2); //3rd column

    // #14, i.e. rds_fem, should be made numeric
    System.out.println("NHBS IDU 2009 Dataset");
    System.out.println("Summary of input:");
    // System.out.println(nhbs_data.toSummaryString());
    System.out.println("  Num of classes: " + nhbs_data.numClasses());
    System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
    for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
      Attribute attr = nhbs_data.attribute(idx);
      System.out.println("" + idx + ": " + attr.toString());
      System.out.println("     distinct values:" + nhbs_data.numDistinctValues(idx));
      // System.out.println("" + attr.enumerateValues());
    }

    // System.exit(0);
    // nhbs_data.deleteAttributeAt(0); //response ID
    // nhbs_data.deleteAttributeAt(16); //zip

    // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00
    // Classifier classifier = new MINND();
    // Classifier classifier = new CitationKNN();
    // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7%
    // Classifier classifier = new SMOreg();
    // Classifier classifier = new LinearNNSearch();

    // LinearRegression: Cannot handle multi-valued nominal class!
    // Classifier classifier = new LinearRegression();

    Classifier classifier = new RandomForest();
    String[] options = {
      "-I", "100", "-K", "4"
    }; // -I: number of trees, -K: features per tree.  Generally these might want optimizing (or not:
       // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
    classifier.setOptions(options);
    // Classifier classifier = new Logistic();

    // KStar classifier = new KStar();
    // classifier.setGlobalBlend(20); // the amount of non-greedy search, in percent

    // does poorly
    // Classifier classifier = new AdaBoostM1();
    // Classifier classifier = new MultiBoostAB();
    // Classifier classifier = new Stacking();

    // building a C45 tree classifier
    // J48 classifier = new J48(); // new instance of tree
    // String[] options = new String[1];
    // options[0] = "-U"; // unpruned tree
    // classifier.setOptions(options); // set the options
    // classifier.buildClassifier(nhbs_data); // build classifier

    // wishlist: remove infrequent values
    // weka.filters.unsupervised.instance.RemoveFrequentValues()
    Filter f1 = new RemoveUseless();
    f1.setInputFormat(nhbs_data);
    nhbs_data = Filter.useFilter(nhbs_data, f1);

    // evaluation
    Evaluation eval = new Evaluation(nhbs_data);
    eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1));
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
    System.out.println(eval.toClassDetailsString());
    // System.out.println(eval.toCumulativeMarginDistributionString());
  }
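
One caveat with the evaluation above: RemoveUseless is fitted on the full data set before cross-validation. To re-fit the filter inside each training fold instead, Weka's FilteredClassifier can wrap the model. A sketch under the same imports; the method name evaluateWithFilterPerFold is hypothetical, and it expects the raw, unfiltered data:

  public static void evaluateWithFilterPerFold(Instances raw, Classifier base) throws Exception {
    // Wrap filter and classifier so RemoveUseless is re-fitted per fold
    weka.classifiers.meta.FilteredClassifier fc = new weka.classifiers.meta.FilteredClassifier();
    fc.setFilter(new RemoveUseless());
    fc.setClassifier(base);
    Evaluation eval = new Evaluation(raw);
    eval.crossValidateModel(fc, raw, 10, new Random(1));
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
  }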
  public static void main(String[] args) {

    if (args.length < 1) {
      System.out.println("usage: C4_5TweetTopicCategorization <root_path>");
      System.exit(-1);
    }

    String rootPath = args[0];
    File dataFolder = new File(rootPath + "/data");
    String resultFolderPath = rootPath + "/results/C4_5/";

    CrisisMailer crisisMailer = CrisisMailer.getCrisisMailer();
    Logger logger = Logger.getLogger(C4_5TweetTopicCategorization.class);
    PropertyConfigurator.configure(Constants.LOG4J_PROPERTIES_FILE_PATH);

    File resultFolder = new File(resultFolderPath);
    if (!resultFolder.exists()) resultFolder.mkdir();

    CSVLoader csvLoader = new CSVLoader();

    try {
      for (File dataSetName : dataFolder.listFiles()) {

        Instances data = null;
        try {
          csvLoader.setSource(dataSetName);
          csvLoader.setStringAttributes("2");
          data = csvLoader.getDataSet();
        } catch (IOException ioe) {
          logger.error(ioe);
          crisisMailer.sendEmailAlert(ioe);
          System.exit(-1);
        }

        data.setClassIndex(data.numAttributes() - 1);
        data.deleteWithMissingClass();

        Instances vectorizedData = null;
        StringToWordVector stringToWordVectorFilter = new StringToWordVector();
        try {
          stringToWordVectorFilter.setInputFormat(data);
          stringToWordVectorFilter.setAttributeIndices("2");
          stringToWordVectorFilter.setIDFTransform(true);
          stringToWordVectorFilter.setLowerCaseTokens(true);
          stringToWordVectorFilter.setOutputWordCounts(false);
          stringToWordVectorFilter.setUseStoplist(true);

          vectorizedData = Filter.useFilter(data, stringToWordVectorFilter);
          vectorizedData.deleteAttributeAt(0);
          // System.out.println(vectorizedData);
        } catch (Exception exception) {
          logger.error(exception);
          crisisMailer.sendEmailAlert(exception);
          System.exit(-1);
        }

        J48 j48Classifier = new J48();

        /*
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(stringToWordVectorFilter);
        filteredClassifier.setClassifier(j48Classifier);
        */

        try {
          Evaluation eval = new Evaluation(vectorizedData);
          eval.crossValidateModel(
              j48Classifier, vectorizedData, 5, new Random(System.currentTimeMillis()));

          FileOutputStream resultOutputStream =
              new FileOutputStream(new File(resultFolderPath + dataSetName.getName()));

          resultOutputStream.write(eval.toSummaryString("=== Summary ===", false).getBytes());
          resultOutputStream.write(eval.toMatrixString().getBytes());
          resultOutputStream.write(eval.toClassDetailsString().getBytes());
          resultOutputStream.close();

        } catch (Exception exception) {
          logger.error(exception);
          crisisMailer.sendEmailAlert(exception);
          System.exit(-1);
        }
      }
    } catch (Exception exception) {
      logger.error(exception);
      crisisMailer.sendEmailAlert(exception);
      System.exit(-1);
    }
  }
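
The commented-out block above hints at an alternative that keeps StringToWordVector inside each fold rather than vectorizing once up front; a sketch of how that could be wired, assuming weka.classifiers.meta.FilteredClassifier is imported (the helper name crossValidateWithFilter is hypothetical):

  // Hedged sketch mirroring the commented-out FilteredClassifier block:
  // the string-to-vector filter is fitted per training fold instead of once.
  private static void crossValidateWithFilter(
      Instances data, StringToWordVector filter, J48 classifier) throws Exception {
    FilteredClassifier filteredClassifier = new FilteredClassifier();
    filteredClassifier.setFilter(filter);
    filteredClassifier.setClassifier(classifier);
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(filteredClassifier, data, 5, new Random(System.currentTimeMillis()));
    System.out.println(eval.toSummaryString("=== Summary ===", false));
  }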
  public void generateDataSet() {

    // Read all the instances in the file (ARFF, CSV, XRFF, ...)
    try {
      source = new DataSource("data\\bne.csv");
    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }

    // Create data set
    try {
      instances = source.getDataSet();
    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }

    // Reverse the order of instances in the data set to place them in
    // chronological order
    for (int i = 0; i < (instances.numInstances() / 2); i++) {
      instances.swap(i, instances.numInstances() - 1 - i);
    }

    // Remove "volume", "low price", "high price", "opening price" and
    // "data" from data set
    instances.deleteAttributeAt(instances.numAttributes() - 1);
    instances.deleteAttributeAt(instances.numAttributes() - 2);
    instances.deleteAttributeAt(instances.numAttributes() - 2);
    instances.deleteAttributeAt(instances.numAttributes() - 2);
    instances.deleteAttributeAt(instances.numAttributes() - 2);

    // Create list to hold nominal values "purchase", "sale", "retain"
    List<String> my_nominal_values = new ArrayList<String>(3);
    my_nominal_values.add("purchase");
    my_nominal_values.add("sale");
    my_nominal_values.add("retain");

    // Create nominal attribute "classIndex"
    Attribute classIndex = new Attribute("classIndex", my_nominal_values);

    // Add "classIndex" as an attribute to each instance
    instances.insertAttributeAt(classIndex, instances.numAttributes());

    // Set the value of "classIndex" for each instance
    for (int i = 0; i < instances.numInstances() - 1; i++) {
      if (instances.get(i + 1).value(instances.numAttributes() - 2)
          > instances.get(i).value(instances.numAttributes() - 2)) {
        instances.get(i).setValue(instances.numAttributes() - 1, "purchase");
      } else if (instances.get(i + 1).value(instances.numAttributes() - 2)
          < instances.get(i).value(instances.numAttributes() - 2)) {
        instances.get(i).setValue(instances.numAttributes() - 1, "sale");
      } else if (instances.get(i + 1).value(instances.numAttributes() - 2)
          == instances.get(i).value(instances.numAttributes() - 2)) {
        instances.get(i).setValue(instances.numAttributes() - 1, "retain");
      }
    }

    // Make the last attribute be the class
    instances.setClassIndex(instances.numAttributes() - 1);

    // Calculate and insert technical analysis attributes into data set
    Strategies strategies = new Strategies();
    strategies.applyStrategies();

    // Print header and instances
    System.out.println("\nDataset:\n");
    System.out.println(instances);
    System.out.println(instances.numInstances());
  }
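
Once generated, the data set could be scored with the same cross-validation pattern used in the NHBS example above; a minimal sketch, assuming the instances field populated by the method (the method name evaluateDataSet is hypothetical):

  public void evaluateDataSet() throws Exception {
    // 10-fold cross-validation of a J48 tree over the generated data set
    weka.classifiers.trees.J48 tree = new weka.classifiers.trees.J48();
    weka.classifiers.Evaluation eval = new weka.classifiers.Evaluation(instances);
    eval.crossValidateModel(tree, instances, 10, new java.util.Random(1));
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
  }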