Example #1
  public static void main(String[] args) throws Exception {

    /*
     * First we load the test data from our ARFF file
     */
    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("data/titanic/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute testAttribute = testDataSet.attribute(0);
    testDataSet.setClass(testAttribute);
    testDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

    /*
     * This part may be a little confusing. We load up the test data again
     * so we have a prediction data set to populate. As we iterate over the
     * first data set we also iterate over the second data set. After an
     * instance is classified, we set the value of the prediction data set
     * to be the value of the classification
     */
    ArffLoader test1Loader = new ArffLoader();
    test1Loader.setSource(new File("data/titanic/test.arff"));
    Instances test1DataSet = test1Loader.getDataSet();
    Attribute test1Attribute = test1DataSet.attribute(0);
    test1DataSet.setClass(test1Attribute);

    /*
     * Now we iterate over the test data and classify each entry and set the
     * value of the 'survived' column to the result of the classification
     */
    Enumeration testInstances = testDataSet.enumerateInstances();
    Enumeration test1Instances = test1DataSet.enumerateInstances();
    while (testInstances.hasMoreElements()) {
      Instance instance = (Instance) testInstances.nextElement();
      Instance instance1 = (Instance) test1Instances.nextElement();
      double classification = classifier.classifyInstance(instance);
      instance1.setClassValue(classification);
    }

    /*
     * Now we want to write out our predictions. The resulting file is in a
     * format suitable to submit to Kaggle.
     */
    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("data/titanic/predict.csv"));
    predictedCsvSaver.setInstances(test1DataSet);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
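For context, here is a minimal sketch (not part of the original example) of how the serialized data/titanic/titanic.model read above could have been produced. The RandomForest choice echoes the comment in Example #13; the train.arff path, the method name, and the usual weka.* imports are assumptions.

  public static void trainAndSaveTitanicModel() throws Exception {
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("data/titanic/train.arff"));
    Instances trainDataSet = trainLoader.getDataSet();

    // Classify the first column: survived
    trainDataSet.setClass(trainDataSet.attribute(0));
    // RandomForest cannot handle string attributes, so drop them before training
    trainDataSet.deleteStringAttributes();

    Classifier classifier = new RandomForest();
    classifier.buildClassifier(trainDataSet);
    SerializationHelper.write("data/titanic/titanic.model", classifier);
  }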
Example #2
  private Instances buildInstance(String data) {
    // Example input: data = "15073302,9,0,1,0,0,5,f";
    String[] s_parts = data.split(",");

    // Create the Instances object, named after the first field (the id)
    String id = s_parts[0];
    Instances instances = new Instances(id, attributes, 0);
    double[] values = new double[instances.numAttributes()];

    // Fill the attribute values; the last field is the nominal class label
    for (int j = 1; j < s_parts.length; j++) {
      String text = s_parts[j];
      if (j == s_parts.length - 1) {
        values[j - 1] = attributeReturn.indexOf(text);
      } else {
        values[j - 1] = Double.valueOf(text);
      }
    }

    // Add the single instance and set the last attribute as the class
    instances.add(new DenseInstance(1.0, values));
    instances.setClass(instances.attribute(s_parts.length - 2));
    return instances;
  }
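buildInstance() references the fields attributes and attributeReturn, which the snippet does not show. A hedged guess at their declarations, matching the 8-field example row (id, six numeric features, one nominal class); the attribute names, the "f"/"t" class values, and the java.util imports are assumptions:

  // Hypothetical declarations assumed by buildInstance()
  private final List<String> attributeReturn = Arrays.asList("f", "t");
  private final ArrayList<Attribute> attributes = buildAttributes();

  private ArrayList<Attribute> buildAttributes() {
    ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (int i = 1; i <= 6; i++) {
      atts.add(new Attribute("a" + i));                // numeric features
    }
    atts.add(new Attribute("class", attributeReturn)); // nominal class, last attribute
    return atts;
  }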
Example #3
  /**
   * Trains the classifier on the array of Signal objects. Implementations of this method should
   * also produce an ordered list of the class names which can be returned with the <code>
   * getClassNames</code> method.
   *
   * @param inputData the Signal array that the model should be trained on.
   * @throws noMetadataException Thrown if there is no class metadata to train the Gaussian model
   *     with
   */
  public void train(Signal[] inputData) {

    List classNamesList = new ArrayList();
    for (int i = 0; i < inputData.length; i++) {
      try {
        String className = inputData[i].getStringMetadata(Signal.PROP_CLASS);
        if ((className != null) && (!classNamesList.contains(className))) {
          classNamesList.add(className);
        }
      } catch (noMetadataException ex) {
        throw new RuntimeException("No class metadata found to train model on!", ex);
      }
    }
    Collections.sort(classNamesList);
    classnames = (String[]) classNamesList.toArray(new String[classNamesList.size()]);

    FastVector classValues = new FastVector(classnames.length);
    for (int i = 0; i < classnames.length; i++) {
      classValues.addElement(classnames[i]);
    }
    classAttribute = new Attribute(Signal.PROP_CLASS, classValues);
    Instances trainingDataSet =
        new Instances(Signal2Instance.convert(inputData[0], classAttribute));

    if (inputData.length > 1) {
      for (int i = 1; i < inputData.length; i++) {
        Instances aSignalInstance = Signal2Instance.convert(inputData[i], classAttribute);
        for (int j = 0; j < aSignalInstance.numInstances(); j++)
          trainingDataSet.add(aSignalInstance.instance(j));
      }
    }

    trainingDataSet.setClass(classAttribute);

    inputData = null;
    theRule = new MISMO();

    // parse options
    StringTokenizer stOption = new StringTokenizer(this.MISMOOptions, " ");
    String[] options = new String[stOption.countTokens()];
    for (int i = 0; i < options.length; i++) {
      options[i] = stOption.nextToken();
    }

    try {
      theRule.setOptions(options);
    } catch (Exception ex) {
      throw new RuntimeException("Failed to set MISMO classifier options!", ex);
    }
    try {
      theRule.buildClassifier(trainingDataSet);
      System.out.println("WEKA: outputting MISMO classifier; " + theRule.globalInfo());
    } catch (Exception ex) {
      throw new RuntimeException("Failed to train classifier!", ex);
    }
  }
Example #4
  public Instances initializeInstances() {

    FastVector wekaAttributes = buildCosineAttributes();
    Attribute label = (Attribute) wekaAttributes.lastElement();

    Instances data = new Instances("semantic-space", wekaAttributes, 1000);
    data.setClass(label);

    return data;
  }
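buildCosineAttributes() is called here and again in buildWekaInstance() further down, but its body is not shown. A sketch of what it plausibly returns, given that each instance carries a single cosine value plus a nominal label; the label values "0" and "1" are guesses:

  private FastVector buildCosineAttributes() {
    FastVector attributes = new FastVector(2);
    attributes.addElement(new Attribute("cosine")); // numeric cosine similarity

    FastVector labelValues = new FastVector(2);
    labelValues.addElement("0");
    labelValues.addElement("1");
    attributes.addElement(new Attribute("label", labelValues)); // nominal class
    return attributes;
  }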
Example #5
  public Instances getInstances(List<ImageData> data) {

    CSVLoader loader = new CSVLoader();
    Instances instances;

    try {
      // Create a temp csv file
      tempFile = new File("tmp");

      PrintWriter pw = null;

      try {
        pw = new PrintWriter(tempFile);
      } catch (FileNotFoundException e) {
        throw new Error(e);
      }

      // Load the data into the csv file
      for (int i = 0; i < Reader.featureSize; i++) {
        pw.print(i + ",");
      }

      pw.println("class");

      for (int i = 0; i < data.size(); i++) {

        List<Double> features = data.get(i).getFeatures();
        for (int j = 0; j < features.size(); j++) {
          pw.print(features.get(j) + ",");
        }

        pw.println(data.get(i).getClassType());
        pw.flush();
      }

      // Close the writer so all rows are on disk before loading
      pw.close();

      // Load the instances from the temp csv file
      loader.setSource(tempFile);
      instances = loader.getDataSet();
      instances.setClass(instances.attribute("class"));

      return instances;

    } catch (IOException e) {
      throw new Error(e);
    } finally {
      if (tempFile != null) {
        tempFile.delete();
        tempFile = null;
      }
    }
  }
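The fixed "tmp" file name can collide between runs or parallel processes. A variation sketch (not from the original source) that writes the same CSV to a unique temporary file; the name prefix and suffix here are arbitrary:

  private File createFeatureCsv(List<ImageData> data) throws IOException {
    File csv = File.createTempFile("weka-features", ".csv");
    csv.deleteOnExit();
    PrintWriter pw = new PrintWriter(csv);
    for (int i = 0; i < Reader.featureSize; i++) {
      pw.print(i + ",");
    }
    pw.println("class");
    for (ImageData d : data) {
      for (Double feature : d.getFeatures()) {
        pw.print(feature + ",");
      }
      pw.println(d.getClassType());
    }
    pw.close();
    return csv;
  }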
Example #6
  private static Instances initializeAttributes() {

    String nameOfDataset = "Badges";

    Instances instances;

    FastVector attributes = new FastVector(9);
    for (String featureName : features) {
      attributes.addElement(new Attribute(featureName, zeroOne));
    }
    Attribute classLabel = new Attribute("Class", labels);
    // labels is a FastVector of '+' and '-'
    attributes.addElement(classLabel);

    instances = new Instances(nameOfDataset, attributes, 0);

    instances.setClass(classLabel);

    return instances;
  }
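initializeAttributes() relies on the fields features, zeroOne, and labels, which the snippet does not show. A hedged reconstruction, consistent with the comment that labels holds '+' and '-'; the feature names themselves are hypothetical:

  // Hypothetical field declarations assumed by initializeAttributes()
  private static final String[] features = {
    "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8"
  };
  private static final FastVector zeroOne = new FastVector(2);
  private static final FastVector labels = new FastVector(2);

  static {
    zeroOne.addElement("0"); // binary feature values
    zeroOne.addElement("1");
    labels.addElement("+");  // class labels, per the comment above
    labels.addElement("-");
  }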
Example #7
  public static void main(String args[]) throws Exception {
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("src/train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();
    weka.core.Attribute trainAttribute = trainDataSet.attribute("class");

    trainDataSet.setClass(trainAttribute);
    // trainDataSet.deleteStringAttributes();

    NaiveBayes classifier = new NaiveBayes();

    final long startTime = System.currentTimeMillis();
    classifier.buildClassifier(trainDataSet);
    final long endTime = System.currentTimeMillis();
    double executionTime = (endTime - startTime) / 1000.0;
    System.out.println("Total execution time: " + executionTime);

    SerializationHelper.write("NaiveBayes.model", classifier);
    System.out.println("Saved trained model to classifier.model");
  }
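Before serializing, the model's accuracy could be estimated with cross-validation, much as Example #9 does. A small sketch (not in the original) reusing the classifier and training set from above; the Evaluation and java.util.Random imports are assumed:

  public static void crossValidate(NaiveBayes classifier, Instances trainDataSet)
      throws Exception {
    // 10-fold cross-validation on the training data
    Evaluation eval = new Evaluation(trainDataSet);
    eval.crossValidateModel(classifier, trainDataSet, 10, new Random(1));
    System.out.println(eval.toSummaryString("\nCross-validation results\n\n", false));
  }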
Example #8
  public Instance buildWekaInstance(QAPair pair) {

    double[] query = projector.transform(pair.getQueryList());
    double[] answer = projector.transform(pair.getAnswerList());
    double[] cosine = {projector.computeCosignSimilarity(query, answer), 0.0};

    FastVector attributes = buildCosineAttributes();
    Attribute label = (Attribute) attributes.lastElement();

    Instances testInstances = new Instances("test", attributes, 1);
    testInstances.setClass(label);

    Instance example = new Instance(1, cosine);
    testInstances.add(example);
    example.setDataset(testInstances);

    if (!pair.getLabel().equals("-1")) {
      example.setClassValue(pair.getLabel());
    } else {
      example.setClassMissing();
    }
    return example;
  }
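A possible way to consume the returned instance (a sketch; the trained classifier parameter and the assumption that index 1 is the positive label are not from the original):

  public double scorePair(QAPair pair, Classifier classifier) throws Exception {
    Instance example = buildWekaInstance(pair);
    double[] distribution = classifier.distributionForInstance(example);
    return distribution[1]; // assumed: index 1 is the positive label
  }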
Example #9
  public static void analyze_accuracy_NHBS(int rng_seed) throws Exception {
    HashMap<String, Object> population_params = load_defaults(null);
    RawLoader rl = new RawLoader(population_params, true, false, rng_seed);
    List<DrugUser> learningData = rl.getLearningData();

    Instances nhbs_data =
        new Instances("learning_instances", DrugUser.getAttInfo(), learningData.size());
    for (DrugUser du : learningData) {
      nhbs_data.add(du.getInstance());
    }
    System.out.println(nhbs_data.toSummaryString());
    nhbs_data.setClass(DrugUser.getAttribMap().get("hcv_state"));

    // wishlist: remove infrequent values
    // weka.filters.unsupervised.instance.RemoveFrequentValues()
    Filter f1 = new RemoveUseless();
    f1.setInputFormat(nhbs_data);
    nhbs_data = Filter.useFilter(nhbs_data, f1);

    System.out.println("NHBS IDU 2009 Dataset");
    System.out.println("Summary of input:");
    // System.out.println(nhbs_data.toSummaryString());
    System.out.println("  Num of classes: " + nhbs_data.numClasses());
    System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
    for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
      Attribute attr = nhbs_data.attribute(idx);
      System.out.println("" + idx + ": " + attr.toString());
      System.out.println("     distinct values:" + nhbs_data.numDistinctValues(idx));
      // System.out.println("" + attr.enumerateValues());
    }

    ArrayList<String> options = new ArrayList<String>();
    options.add("-Q");
    options.add("" + rng_seed);
    // System.exit(0);
    // nhbs_data.deleteAttributeAt(0); //response ID
    // nhbs_data.deleteAttributeAt(16); //zip

    // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00
    // ROC=0.60
    // Classifier classifier = new MINND();
    // Classifier classifier = new CitationKNN();
    // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7%
    // Classifier classifier = new SMOreg();
    Classifier classifier = new Logistic();
    // ROC=0.686
    // Classifier classifier = new LinearNNSearch();

    // LinearRegression: Cannot handle multi-valued nominal class!
    // Classifier classifier = new LinearRegression();

    // Classifier classifier = new RandomForest();
    // String[] options = {"-I", "100", "-K", "4"}; //-I trees, -K features per tree.  generally,
    // might want to optimize (or not
    // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
    // options.add("-I"); options.add("100"); options.add("-K"); options.add("4");
    // ROC=0.673

    // KStar classifier = new KStar();
    // classifier.setGlobalBlend(20); //the amount of not greedy, in percent
    // ROC=0.633

    // Classifier classifier = new AdaBoostM1();
    // ROC=0.66
    // Classifier classifier = new MultiBoostAB();
    // ROC=0.67
    // Classifier classifier = new Stacking();
    // ROC=0.495

    // J48 classifier = new J48(); // new instance of tree //building a C45 tree classifier
    // ROC=0.585
    // String[] options = new String[1];
    // options[0] = "-U"; // unpruned tree
    // classifier.setOptions(options); // set the options

    classifier.setOptions((String[]) options.toArray(new String[0]));

    // not needed before CV: http://weka.wikispaces.com/Use+WEKA+in+your+Java+code
    // classifier.buildClassifier(nhbs_data); // build classifier

    // evaluation
    Evaluation eval = new Evaluation(nhbs_data);
    eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1)); // 10-fold cross validation
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
    System.out.println(eval.toClassDetailsString());
    // System.out.println(eval.toCumulativeMarginDistributionString());
  }
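The commented-out classifier notes above track ROC figures; the per-class AUC can also be read directly from the Evaluation object. A short sketch (not part of the original) that could sit at the end of the method:

    for (int c = 0; c < nhbs_data.numClasses(); c++) {
      System.out.println("AUC (" + nhbs_data.classAttribute().value(c) + "): "
          + eval.areaUnderROC(c));
    }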
Example #10
  public static void main(String[] args) throws Exception {
    // NaiveBayesSimple nb = new NaiveBayesSimple();

    //		BufferedReader br_train = new BufferedReader(new FileReader("src/train.arff.txt"));
    //		String s = null;
    //		long st_time = System.currentTimeMillis();
    //		Instances inst_train = new Instances(br_train);
    //		System.out.println(inst_train.numAttributes());
    //		inst_train.setClassIndex(inst_train.numAttributes()-1);
    //		System.out.println("train time"+(System.currentTimeMillis()-st_time));
    // NaiveBayes nb1 = new NaiveBayes();
    // nb1.buildClassifier(inst_train);
    // br_train.close();
    long st_time = System.currentTimeMillis();

    Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

    //		BufferedReader br_test = new BufferedReader(new FileReader("src/test.arff.txt"));
    //		Instances inst_test = new Instances(br_test);
    //		inst_test.setClassIndex(inst_test.numAttributes()-1);
    //		System.out.println("test time"+(System.currentTimeMillis()-st_time));
    //

    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("src/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    Attribute testAttribute = testDataSet.attribute("class");
    testDataSet.setClass(testAttribute);

    int correct = 0;
    int incorrect = 0;
    FastVector attInfo = new FastVector();
    attInfo.addElement(new Attribute("Id"));
    attInfo.addElement(new Attribute("Category"));

    Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

    Enumeration testInstances = testDataSet.enumerateInstances();
    int index = 1;
    while (testInstances.hasMoreElements()) {
      Instance instance = (Instance) testInstances.nextElement();
      double classification = classifier.classifyInstance(instance);
      // Tally accuracy against the labelled test data
      if (!instance.classIsMissing()) {
        if (classification == instance.classValue()) {
          correct++;
        } else {
          incorrect++;
        }
      }
      Instance predictInstance = new Instance(outputInstances.numAttributes());
      predictInstance.setValue(0, index++);
      predictInstance.setValue(1, (int) classification + 1);
      outputInstances.add(predictInstance);
    }

    System.out.println("Correct Instance: " + correct);
    System.out.println("IncCorrect Instance: " + incorrect);
    double accuracy = (double) (correct) / (double) (correct + incorrect);
    System.out.println("Accuracy: " + accuracy);
    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("predict.csv"));
    predictedCsvSaver.setInstances(outputInstances);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
Example #11
  public SensitivityAnalysis(Instances d) {
    data = d;
    data.setClass(data.attribute(data.numAttributes() - 1));
    System.out.println(data.classIndex());
  }
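If the class is always the last column, setClassIndex() expresses the same intent without the attribute lookup; an equivalent line (a sketch, not in the original):

    // Equivalent to the setClass(...) call above
    data.setClassIndex(data.numAttributes() - 1);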
Example #12
  public MappingInfo(Instances dataSet, MiningSchema miningSchema, Logger log) throws Exception {
    m_log = log;
    // miningSchema.convertStringAttsToNominal();
    Instances fieldsI = miningSchema.getMiningSchemaAsInstances();

    m_fieldsMap = new int[fieldsI.numAttributes()];
    m_nominalValueMaps = new int[fieldsI.numAttributes()][];

    for (int i = 0; i < fieldsI.numAttributes(); i++) {
      String schemaAttName = fieldsI.attribute(i).name();
      boolean found = false;
      for (int j = 0; j < dataSet.numAttributes(); j++) {
        if (dataSet.attribute(j).name().equals(schemaAttName)) {
          Attribute miningSchemaAtt = fieldsI.attribute(i);
          Attribute incomingAtt = dataSet.attribute(j);
          // check type match
          if (miningSchemaAtt.type() != incomingAtt.type()) {
            throw new Exception(
                "[MappingInfo] type mismatch for field "
                    + schemaAttName
                    + ". Mining schema type "
                    + miningSchemaAtt.toString()
                    + ". Incoming type "
                    + incomingAtt.toString()
                    + ".");
          }

          // check nominal values (number, names...)
          if (miningSchemaAtt.numValues() != incomingAtt.numValues()) {
            String warningString =
                "[MappingInfo] WARNING: incoming nominal attribute "
                    + incomingAtt.name()
                    + " does not have the same "
                    + "number of values as the corresponding mining "
                    + "schema attribute.";
            if (m_log != null) {
              m_log.logMessage(warningString);
            } else {
              System.err.println(warningString);
            }
          }
          if (miningSchemaAtt.isNominal() || miningSchemaAtt.isString()) {
            int[] valuesMap = new int[incomingAtt.numValues()];
            for (int k = 0; k < incomingAtt.numValues(); k++) {
              String incomingNomVal = incomingAtt.value(k);
              int indexInSchema = miningSchemaAtt.indexOfValue(incomingNomVal);
              if (indexInSchema < 0) {
                String warningString =
                    "[MappingInfo] WARNING: incoming nominal attribute "
                        + incomingAtt.name()
                        + " has value "
                        + incomingNomVal
                        + " that doesn't occur in the mining schema.";
                if (m_log != null) {
                  m_log.logMessage(warningString);
                } else {
                  System.err.println(warningString);
                }
                valuesMap[k] = UNKNOWN_NOMINAL_VALUE;
              } else {
                valuesMap[k] = indexInSchema;
              }
            }
            m_nominalValueMaps[i] = valuesMap;
          }

          /*if (miningSchemaAtt.isNominal()) {
            for (int k = 0; k < miningSchemaAtt.numValues(); k++) {
              if (!miningSchemaAtt.value(k).equals(incomingAtt.value(k))) {
                throw new Exception("[PMMLUtils] value " + k + " (" +
                                    miningSchemaAtt.value(k) + ") does not match " +
                                    "incoming value (" + incomingAtt.value(k) +
                                    ") for attribute " + miningSchemaAtt.name() +
                                    ".");

              }
            }
          }*/
          found = true;
          m_fieldsMap[i] = j;
        }
      }
      if (!found) {
        throw new Exception(
            "[MappingInfo] Unable to find a match for mining schema "
                + "attribute "
                + schemaAttName
                + " in the "
                + "incoming instances!");
      }
    }

    // check class attribute (if set)
    if (fieldsI.classIndex() >= 0) {
      if (dataSet.classIndex() < 0) {
        // first see if we can find a matching class
        String className = fieldsI.classAttribute().name();
        Attribute classMatch = dataSet.attribute(className);
        if (classMatch == null) {
          throw new Exception(
              "[MappingInfo] Can't find match for target field "
                  + className
                  + " in incoming instances!");
        }
        dataSet.setClass(classMatch);
      } else if (!fieldsI.classAttribute().name().equals(dataSet.classAttribute().name())) {
        throw new Exception(
            "[MappingInfo] class attribute in mining schema does not match "
                + "class attribute in incoming instances!");
      }
    }

    // Set up the textual description of the mapping
    fieldsMappingString(fieldsI, dataSet);
  }
Example #13
  public static void main(String[] args) throws Exception {

    /*
     * First we load our predictions from the CSV formatted file.
     */
    CSVLoader predictCsvLoader = new CSVLoader();
    predictCsvLoader.setSource(new File("predict.csv"));

    /*
     * Since we are not using the ARFF format here, we have to give the
     * loader a little bit of information about the data types. Columns
     * 3,8,10 need to be of type string and columns 1,4,11 are nominal
     * types.
     */
    predictCsvLoader.setStringAttributes("3,8,10");
    predictCsvLoader.setNominalAttributes("1,4,11");
    Instances predictDataSet = predictCsvLoader.getDataSet();

    /*
     * Here we set the attribute we want to test the predictions with
     */
    Attribute testAttribute = predictDataSet.attribute(0);
    predictDataSet.setClass(testAttribute);

    /*
     * We still have to remove all string attributes before we can test
     */
    predictDataSet.deleteStringAttributes();

    /*
     * Next we load the training data from our ARFF file
     */
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute trainAttribute = trainDataSet.attribute(0);
    trainDataSet.setClass(trainAttribute);

    /*
     * The RandomForest implementation cannot handle columns of type string,
     * so we remove them for now.
     */
    trainDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("titanic.model");

    /*
     * Next we will use an Evaluation class to evaluate the performance of
     * our Classifier.
     */
    Evaluation evaluation = new Evaluation(trainDataSet);
    evaluation.evaluateModel(classifier, predictDataSet, new Object[] {});

    /*
     * After we evaluate the Classifier, we write out the summary
     * information to the screen.
     */
    System.out.println(classifier);
    System.out.println(evaluation.toSummaryString());
  }
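As a small follow-up (not in the original example), the Evaluation object can also print a confusion matrix for these predictions:

    System.out.println(evaluation.toMatrixString("Confusion matrix:\n"));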