/**
 * Assembles a {@link FilteredClassifier} that feeds a {@link StringToWordVector} filter into a
 * {@link Logistic} classifier, with the filter configured from the runtime parameters reported by
 * {@code hasParam}.
 *
 * <p>When {@code TRAIN_AND_TEST} is set, the filter's input format is initialised from the
 * training data and the classifier is built on it before being returned; otherwise the returned
 * classifier is configured but untrained.
 *
 * @return the configured (and possibly already trained) filtered classifier
 * @throws Exception if setting the filter input format or building the classifier fails
 */
public weka.classifiers.Classifier getClassifier() throws Exception {

    StringToWordVector stwv = new StringToWordVector();
    // TF and IDF weighting are toggled together by the single USE_TFIDF switch.
    stwv.setTFTransform(hasParam(Constant.RUNTIME_PARAMS.USE_TFIDF));
    stwv.setIDFTransform(hasParam(Constant.RUNTIME_PARAMS.USE_TFIDF));
    stwv.setLowerCaseTokens(hasParam(Constant.RUNTIME_PARAMS.CONV_LOWERCASE));
    stwv.setUseStoplist(hasParam(Constant.RUNTIME_PARAMS.REM_STOP_WORDS));
    stwv.setOutputWordCounts(hasParam(Constant.RUNTIME_PARAMS.USE_WORD_FREQ));

    // Bigram takes precedence over trigram when both parameters are present.
    // NOTE(review): only the minimum n-gram size is set here; the maximum stays at the
    // tokenizer's default, so "bigram" may also emit longer n-grams -- confirm intent.
    if (hasParam(Constant.RUNTIME_PARAMS.USE_BIGRAM)) {
      NGramTokenizer tokenizer = new NGramTokenizer();
      tokenizer.setNGramMinSize(2);
      stwv.setTokenizer(tokenizer);
    } else if (hasParam(Constant.RUNTIME_PARAMS.USE_TRIGRAM)) {
      NGramTokenizer tokenizer = new NGramTokenizer();
      tokenizer.setNGramMinSize(3);
      stwv.setTokenizer(tokenizer);
    }
    if (hasParam(Constant.RUNTIME_PARAMS.USE_STEMMER)) {
      SnowballStemmer stemmer = new SnowballStemmer("porter");
      stwv.setStemmer(stemmer);
    }

    // The input format is set only after every filter option (tokenizer, stemmer, ...) has been
    // applied: a Weka filter expects setInputFormat to be the last call before it is used.
    if (hasParam(Constant.RUNTIME_PARAMS.TRAIN_AND_TEST)) {
      stwv.setInputFormat(getTrainData());
    }

    Logistic l = new Logistic();

    FilteredClassifier cls = new FilteredClassifier();
    cls.setClassifier(l);
    cls.setFilter(stwv);
    if (hasParam(Constant.RUNTIME_PARAMS.TRAIN_AND_TEST)) {
      cls.buildClassifier(getTrainData());
    }

    return cls;
  }
예제 #2
0
 /**
  * Wraps the given classifier in a {@link FilteredClassifier} whose filter is a {@link Remove}
  * configured with the attribute index specification {@code "first"}. Debug output is enabled on
  * the wrapper.
  *
  * @param c the base classifier to wrap
  * @return a new filtered classifier delegating to {@code c}
  */
 static FilteredClassifier wrapRemoveFirst(Classifier c) {
   Remove dropFirst = new Remove();
   dropFirst.setAttributeIndices("first");

   FilteredClassifier wrapped = new FilteredClassifier();
   wrapped.setDebug(true);
   wrapped.setClassifier(c);
   wrapped.setFilter(dropFirst);
   return wrapped;
 }
예제 #3
0
  /**
   * Wraps a classifier in a {@link FilteredClassifier} that uses a {@code RemoveWithPrefix}
   * filter configured with the given prefix, class matching enabled and the selection inverted.
   *
   * @param c the base classifier to wrap
   * @param prefix the prefix handed to the filter
   * @return a new filtered classifier combining {@code c} with the prefix filter
   */
  private static FilteredClassifier prefixFilteredClassifier(Classifier c, String prefix) {
    RemoveWithPrefix prefixFilter = new RemoveWithPrefix();
    prefixFilter.setMatchClass(true);
    prefixFilter.setPrefix(prefix);
    prefixFilter.setInvertSelection(true);

    FilteredClassifier wrapped = new FilteredClassifier();
    wrapped.setFilter(prefixFilter);
    wrapped.setClassifier(c);
    return wrapped;
  }
  /**
   * Cross-validates a {@link NaiveBayes} model, wrapped in a {@link FilteredClassifier}, on the
   * given data set and prints the summary, the per-class details and the confusion matrix to
   * standard output.
   *
   * <p>Side effects: replaces the shared {@code classifier} field and sets the class index of
   * {@code data} to its last attribute.
   *
   * @param data the data set to evaluate on; its last attribute is used as the class
   * @throws Exception if the evaluation cannot be set up or the cross-validation fails
   */
  public static void wekaAlgorithms(Instances data) throws Exception {
    final int folds = 10;
    final long seed = 1; // fixed seed so the fold assignment is reproducible

    classifier = new FilteredClassifier();
    classifier.setClassifier(new NaiveBayes());

    data.setClassIndex(data.numAttributes() - 1);
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(classifier, data, folds, new Random(seed));

    System.out.println("===== Evaluating on filtered (training) dataset =====");
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toClassDetailsString());

    double[][] mat = eval.confusionMatrix();
    System.out.println("========= Confusion Matrix =========");
    // Each row is bounded by its own length rather than by the number of rows.
    for (double[] row : mat) {
      for (double cell : row) {
        System.out.print(cell + "  ");
      }
      System.out.println(" ");
    }
  }
  /**
   * Builds a one-instance test set from the feature values that {@code makeData} returns for
   * {@code file}, trains a fresh {@link FilteredClassifier} on {@code train}, and prints the
   * predicted class label for every test instance.
   *
   * <p>Side effect: replaces the shared {@code classifier} field.
   *
   * @param train the training data handed to {@code buildClassifier}
   * @param file the input passed to {@code makeData} to obtain the feature values
   * @throws Exception if feature extraction, training or classification fails
   */
  public static void classify(Instances train, File file) throws Exception {
    // Numeric feature attributes, in the order their values are read from makeData's result.
    final String[] featureNames = {
      "Zero_Crossings",
      "LPC",
      "Spectral_Centroid",
      "Spectral_Rolloff_Point",
      "Peak_Detection",
      "Strongest_Beat",
      "Beat_Sum",
      "RMS",
      "ConstantQ",
      "MagnitudeFFT",
      "MFCC"
    };
    final String[] genres = {"classical", "hiphop", "pop", "rock"};

    // Schema: one numeric attribute per feature, followed by the nominal class attribute.
    FastVector schema = new FastVector();
    Attribute[] features = new Attribute[featureNames.length];
    for (int f = 0; f < featureNames.length; f++) {
      features[f] = new Attribute(featureNames[f]);
      schema.addElement(features[f]);
    }

    FastVector genreValues = new FastVector();
    for (int g = 0; g < genres.length; g++) {
      genreValues.addElement(genres[g]);
    }
    schema.addElement(new Attribute("class", genreValues));

    // Single test instance populated from the extracted feature values.
    Instances test = new Instances("AudioSamples", schema, 0);
    double[] val = makeData(file, null, genreValues, test.numAttributes());
    Instance sample = new Instance(test.numAttributes());
    for (int f = 0; f < features.length; f++) {
      sample.setValue(features[f], val[f]);
    }
    test.add(sample);

    // The class attribute is the last one in the schema.
    test.setClassIndex(test.numAttributes() - 1);
    System.out.println(test);

    // Train on the supplied training data.
    classifier = new FilteredClassifier();
    classifier.buildClassifier(train);

    // Predict and report the label of each test instance.
    int total = test.numInstances();
    for (int idx = 0; idx < total; idx++) {
      int predictedIndex = (int) classifier.classifyInstance(test.instance(idx));
      System.out.println("===== Classified instance =====");
      System.out.println("Class predicted: " + test.classAttribute().value(predictedIndex));
    }
  }
예제 #6
0
  /**
   * Trains one {@link NaiveBayesMultinomial} model per label (each behind a
   * {@link StringToWordVector} + information-gain attribute-selection filter chain) on that
   * label's ARFF file, then runs the instances of the given test set through every model.
   *
   * <p>For label {@code i}, {@code prd[i]} receives entry {@code i} of the class distribution
   * that model {@code i} produced for the most recently classified test instance; entries that
   * are never written stay {@code 0.0}. Progress and per-instance results are printed to
   * standard output.
   *
   * <p>Exceptions raised while training or classifying are printed and end the processing
   * early; whatever was written to the result array up to that point is still returned.
   *
   * @param test location of the test data set, as accepted by {@link DataSource}
   * @return an array with one value per label, in the order of the internal label table
   * @throws IllegalStateException if the test set or any per-label training set cannot be loaded
   */
  private double[] classify(String test) {

    final String[] lab = {
      "I.2", "I.3", "I.5", "I.6", "I.2.1", "I.2.6", "I.2.8", "I.3.5", "I.3.6", "I.3.7", "I.5.1",
      "I.5.2", "I.5.4", "I.6.3", "I.6.5", "I.6.8",
    };
    // Directory holding one "<label>.arff" training file per entry of lab.
    final String trainDir = "D:/Users/nma1g11/workspace2/WebScraperFlatNew/dataPernode/new/";
    final int numToSelect = 1000; // number of attributes kept by the ranker

    double[] prd = new double[lab.length];

    // ---- Load the test set; fail fast instead of continuing with a null data set. ----
    Instances testset;
    try {
      testset = new DataSource(test).getDataSet();
    } catch (Exception e) {
      throw new IllegalStateException("Could not load test set: " + test, e);
    }
    if (testset == null) {
      throw new IllegalStateException("Could not load test set: " + test);
    }

    // ---- Load one training set per label. ----
    Instances[] dataRaw = new Instances[lab.length];
    for (int j = 0; j < lab.length; j++) {
      String path = trainDir + lab[j] + ".arff";
      try {
        dataRaw[j] = new DataSource(path).getDataSet();
      } catch (Exception e) {
        throw new IllegalStateException("Could not load training set: " + path, e);
      }
      if (dataRaw[j] == null) {
        throw new IllegalStateException("Could not load training set: " + path);
      }

      System.out.println(lab[j]);
      // Default to the last attribute as class when the file does not declare one.
      if (dataRaw[j].classIndex() == -1) {
        dataRaw[j].setClassIndex(dataRaw[j].numAttributes() - 1);
      }
    }
    if (testset.classIndex() == -1) {
      testset.setClassIndex(testset.numAttributes() - 1);
    }

    // ---- Configure the filter chain: word vectors, then info-gain attribute ranking. ----
    WordTokenizer wtok = new WordTokenizer();
    wtok.setDelimiters(" \r\n\t.,;:'\"()?!$*-&[]+/|\\");

    StringToWordVector stwv = new StringToWordVector();
    Ranker rank = new Ranker();
    InfoGainAttributeEval ig = new InfoGainAttributeEval();
    weka.filters.supervised.attribute.AttributeSelection featSel =
        new weka.filters.supervised.attribute.AttributeSelection();

    try {
      stwv.setOptions(
          weka.core.Utils.splitOptions(
              "-R first-last -W 1000000 -prune-rate -1.0 -C -T -I -N 1 -L -S -stemmer weka.core.stemmers.LovinsStemmer -M 2 "));
      stwv.setTokenizer(wtok);

      rank.setOptions(weka.core.Utils.splitOptions("-T -1.7976931348623157E308 -N 100"));
      // Applied after the option string, so this value replaces the "-N 100" given above.
      rank.setNumToSelect(numToSelect);
      featSel.setEvaluator(ig);
      featSel.setSearch(rank);
    } catch (Exception e) {
      // Best effort: processing continues with whatever configuration was applied so far.
      e.printStackTrace();
    }

    System.out.println("Loading is Done!");

    MultiFilter mfilter = new MultiFilter();
    mfilter.setFilters(new Filter[] {stwv, featSel});

    // Template from which one independent classifier per label is copied.
    FilteredClassifier template = new FilteredClassifier();
    template.setClassifier(new NaiveBayesMultinomial());
    template.setFilter(mfilter);

    try {
      Classifier[] clsArr = Classifier.makeCopies(template, lab.length);
      Evaluation eval = null;
      String prev = ""; // value of attribute 0 of the previously processed instance
      String lb = ""; // label that was active for the previously processed instance

      for (int i = 0; i < lab.length; i++) {

        // NOTE(review): the flag is overwritten on every iteration, so after the loop it only
        // tells whether the LAST test instance carries lab[i] in attribute 1 -- confirm whether
        // "any instance matches" was intended. Behaviour is kept as in the original.
        boolean flag = false;
        for (int k = 0; k < testset.numInstances(); k++) {
          flag = lab[i].equals(testset.instance(k).stringValue(1));
        }

        clsArr[i].buildClassifier(dataRaw[i]);
        // NOTE(review): no evaluate* method is called on eval in this method, so the statistics
        // printed below are taken from an evaluation that has seen no predictions -- confirm.
        eval = new Evaluation(dataRaw[i]);

        for (int j = 0; j < testset.numInstances(); j++) {
          // Remember the current value so it can be restored after the temporary relabelling.
          // NOTE(review): presumably attribute 1 is the class attribute -- confirm.
          String originalClass = testset.instance(j).stringValue(1);

          if (flag) {
            testset.instance(j).setClassValue(lab[i]);
          }

          // Relabel the instance with lab[i] when its value and lab[i] agree on the first
          // three characters (the top-level part of the label).
          String strcls = testset.instance(j).stringValue(1);
          if (i < 4) {
            if (strcls.substring(0, 3).equals(lab[i])) {
              testset.instance(j).setClassValue(lab[i]);
            }
          } else if (lab[i].substring(0, 3).equals(strcls)) {
            testset.instance(j).setClassValue(lab[i]);
          }

          System.out.println(
              dataRaw[i].classAttribute().value(i)
                  + " --- > Correct%:"
                  + eval.pctCorrect()
                  + "  F-measure:"
                  + eval.fMeasure(i));

          // Skip the classification when neither attribute 0 nor the active label changed
          // since the previous instance.
          if (!prev.equals(testset.instance(j).stringValue(0)) || !lab[i].equals(lb)) {
            double clsLabel = clsArr[i].classifyInstance(testset.instance(j));
            double[] x = clsArr[i].distributionForInstance(testset.instance(j));

            // NOTE(review): indexes the distribution by label position i; this relies on the
            // distribution having more than i entries -- confirm.
            prd[i] = x[i];
            System.out.println(" --- > prob: " + clsLabel);
            System.out.println(" --- > x :" + x[i]);
            System.out.println(clsLabel + " --> " + testset.classAttribute().value((int) clsLabel));
          }

          // Undo the temporary relabelling before moving on.
          testset.instance(j).setClassValue(originalClass);

          prev = testset.instance(j).stringValue(0);
          lb = lab[i];
        }

        System.out.println("Done with " + lab[i].replace("99", "") + " !!!!!!!!!!!");
      }
      System.out.println(eval.correct());

    } catch (Exception e) {
      // Best effort: report the failure and return whatever has been collected so far.
      e.printStackTrace();
    }

    return prd;
  }