/**
 * Builds a text classifier: a {@code Logistic} model wrapped in a {@code FilteredClassifier}
 * whose filter is a {@code StringToWordVector} configured from the runtime parameters.
 *
 * <p>If the {@code TRAIN_AND_TEST} parameter is set, the classifier is trained on
 * {@code getTrainData()} before it is returned; otherwise it is returned untrained.
 *
 * @return the configured (and possibly trained) filtered classifier
 * @throws Exception if initialising the filter or training the classifier fails
 */
public weka.classifiers.Classifier getClassifier() throws Exception {

    StringToWordVector stwv = new StringToWordVector();
    stwv.setTFTransform(hasParam(Constant.RUNTIME_PARAMS.USE_TFIDF));
    stwv.setIDFTransform(hasParam(Constant.RUNTIME_PARAMS.USE_TFIDF));
    stwv.setLowerCaseTokens(hasParam(Constant.RUNTIME_PARAMS.CONV_LOWERCASE));
    stwv.setUseStoplist(hasParam(Constant.RUNTIME_PARAMS.REM_STOP_WORDS));
    stwv.setOutputWordCounts(hasParam(Constant.RUNTIME_PARAMS.USE_WORD_FREQ));

    if (hasParam(Constant.RUNTIME_PARAMS.USE_BIGRAM)) {
      NGramTokenizer tokenizer = new NGramTokenizer();
      // Pin both bounds: with only the minimum set, the tokenizer's default maximum
      // would let trigrams through as well.
      tokenizer.setNGramMinSize(2);
      tokenizer.setNGramMaxSize(2);
      stwv.setTokenizer(tokenizer);
    } else if (hasParam(Constant.RUNTIME_PARAMS.USE_TRIGRAM)) {
      NGramTokenizer tokenizer = new NGramTokenizer();
      tokenizer.setNGramMinSize(3);
      tokenizer.setNGramMaxSize(3);
      stwv.setTokenizer(tokenizer);
    }
    if (hasParam(Constant.RUNTIME_PARAMS.USE_STEMMER)) {
      SnowballStemmer stemmer = new SnowballStemmer("porter");
      stwv.setStemmer(stemmer);
    }

    // Weka filters must receive all their options before setInputFormat() is called,
    // so the input format is set only after the tokenizer and stemmer are in place.
    if (hasParam(Constant.RUNTIME_PARAMS.TRAIN_AND_TEST)) stwv.setInputFormat(getTrainData());

    Logistic l = new Logistic();

    FilteredClassifier cls = new FilteredClassifier();
    cls.setClassifier(l);
    cls.setFilter(stwv);
    if (hasParam(Constant.RUNTIME_PARAMS.TRAIN_AND_TEST)) cls.buildClassifier(getTrainData());

    return cls;
  }
  /**
   * Classifies a single text message.
   *
   * @param message the raw text to classify
   * @return the value returned by the classifier; cast to {@code int} it is the index of the
   *     predicted class label
   * @throws Exception if filtering or classification fails
   */
  public double classifyMessage(String message) throws Exception {
    Instance rawInstance = makeInstance(message, instances.stringFreeStructure());

    // The message has to go through the very same filter instance that was used for training.
    filter.input(rawInstance);
    Instance vectorizedInstance = filter.output();

    return classifier.classifyInstance(vectorizedInstance);
  }
  /**
   * Loads the data set from {@code source}, converts its string attributes to word vectors with
   * a {@code StringToWordVector} filter and stores the result in {@code inst}, using the first
   * attribute as the class.
   *
   * @throws Exception if loading or filtering the data fails
   */
  public void filterData() throws Exception {
    Instances rawData = source.getDataSet();

    String[] vectorizerOptions =
        weka.core.Utils.splitOptions(
            "-R first-last -W 1000 "
                + "-prune-rate -1.0 -N 0 "
                + "-stemmer weka.core.stemmers.NullStemmer -M 1 "
                + "-tokenizer \"weka.core.tokenizers.WordTokenizer -delimiters  \\\" \\r\\n\\t.,;:\\\'\\\"()?!\"");
    StringToWordVector wordVectorizer = new StringToWordVector();
    wordVectorizer.setOptions(vectorizerOptions);
    wordVectorizer.setInputFormat(rawData);

    this.inst = Filter.useFilter(rawData, wordVectorizer);
    this.inst.setClassIndex(0);
  }
  /**
   * Runs a 10-fold cross-validation of {@code this.classifier} on the given data set and prints
   * the summary and the confusion matrix to standard output.
   *
   * <p>The data is first converted to word vectors; the first attribute of the filtered data is
   * used as the class.
   *
   * @param traindata location of the training data, handed to {@code DataSource}
   * @throws Exception if loading, filtering or evaluating fails
   */
  @Override
  public void crossValidation(String traindata) throws Exception {
    Instances rawData = new DataSource(traindata).getDataSet();

    String[] vectorizerOptions =
        weka.core.Utils.splitOptions(
            "-R first-last -W 1000 "
                + "-prune-rate -1.0 -N 0 "
                + "-stemmer weka.core.stemmers.NullStemmer -M 1 "
                + "-tokenizer \"weka.core.tokenizers.WordTokenizer -delimiters  \\\" \\r\\n\\t.,;:\\\'\\\"()?!\"");
    StringToWordVector wordVectorizer = new StringToWordVector();
    wordVectorizer.setOptions(vectorizerOptions);
    wordVectorizer.setInputFormat(rawData);

    Instances vectorized = Filter.useFilter(rawData, wordVectorizer);
    vectorized.setClassIndex(0);

    Evaluation evaluation = new Evaluation(vectorized);
    evaluation.crossValidateModel(this.classifier, vectorized, 10, new Random(1));

    System.out.println(evaluation.toSummaryString());
    System.out.println(evaluation.toMatrixString());
  }
  /**
   * Loads a training and a test directory of text files, turns both into n-gram word vectors
   * with one shared filter and hands them to {@code trainModelNaiveBayes}.
   *
   * <p>Any exception is printed to standard error and otherwise ignored.
   *
   * @param filePathTrain directory holding the training documents
   * @param filePathTest directory holding the test documents
   * @param gram n-gram size; used as both the minimum and the maximum size of the tokenizer
   */
  public static void makeDataSet(String filePathTrain, String filePathTest, int gram) {
    TextDirectoryLoader directoryLoader = new TextDirectoryLoader();
    try {
      directoryLoader.setDirectory(new File(filePathTrain));
      Instances rawTrain = directoryLoader.getDataSet();

      directoryLoader.setDirectory(new File(filePathTest));
      Instances rawTest = directoryLoader.getDataSet();

      NGramTokenizer nGramTokenizer = new NGramTokenizer();
      nGramTokenizer.setNGramMinSize(gram);
      nGramTokenizer.setNGramMaxSize(gram);

      StringToWordVector wordVectorizer = new StringToWordVector();
      wordVectorizer.setTokenizer(nGramTokenizer);
      wordVectorizer.setInputFormat(rawTrain);

      // The filter is initialised on the training data only and then reused, without calling
      // setInputFormat again, for the test data.
      Instances vectorizedTrain = Filter.useFilter(rawTrain, wordVectorizer);
      Instances vectorizedTest = Filter.useFilter(rawTest, wordVectorizer);

      // Swap this call to train and test a different model.
      trainModelNaiveBayes(vectorizedTrain, vectorizedTest);
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
  /**
   * Evaluates a J48 (C4.5) decision tree on every CSV data set found in
   * {@code <root_path>/data}.
   *
   * <p>For each file the second column is loaded as a string attribute, converted to a word
   * vector, and a 5-fold cross-validation is run. The summary, confusion matrix and per-class
   * details are written to {@code <root_path>/results/C4_5/<data set file name>}.
   *
   * <p>Every failure is logged, reported through {@code CrisisMailer} and terminates the JVM
   * with exit status -1.
   *
   * @param args {@code args[0]} is the root path containing the {@code data} folder
   */
  public static void main(String[] args) {

    if (args.length < 1) {
      System.out.println("usage: C4_5TweetTopicCategorization <root_path>");
      System.exit(-1);
    }

    String rootPath = args[0];
    File dataFolder = new File(rootPath + "/data");
    String resultFolderPath = rootPath + "/results/C4_5/";

    CrisisMailer crisisMailer = CrisisMailer.getCrisisMailer();
    Logger logger = Logger.getLogger(C4_5TweetTopicCategorization.class);
    PropertyConfigurator.configure(Constants.LOG4J_PROPERTIES_FILE_PATH);

    File resultFolder = new File(resultFolderPath);
    // mkdirs() rather than mkdir(): the intermediate "results" folder may not exist yet.
    if (!resultFolder.exists()) resultFolder.mkdirs();

    CSVLoader csvLoader = new CSVLoader();

    try {
      // listFiles() returns null when the folder is missing or unreadable; report that
      // explicitly instead of failing with a NullPointerException.
      File[] dataSetFiles = dataFolder.listFiles();
      if (dataSetFiles == null) {
        throw new IOException("Cannot list data folder: " + dataFolder);
      }

      for (File dataSetName : dataSetFiles) {

        Instances data = null;
        try {
          csvLoader.setSource(dataSetName);
          csvLoader.setStringAttributes("2");
          data = csvLoader.getDataSet();
        } catch (IOException ioe) {
          logger.error(ioe);
          crisisMailer.sendEmailAlert(ioe);
          System.exit(-1);
        }

        // The class is the last column; instances without a class value are dropped.
        data.setClassIndex(data.numAttributes() - 1);
        data.deleteWithMissingClass();

        Instances vectorizedData = null;
        StringToWordVector stringToWordVectorFilter = new StringToWordVector();
        try {
          // Weka filters must be fully configured before setInputFormat() is called.
          stringToWordVectorFilter.setAttributeIndices("2");
          stringToWordVectorFilter.setIDFTransform(true);
          stringToWordVectorFilter.setLowerCaseTokens(true);
          stringToWordVectorFilter.setOutputWordCounts(false);
          stringToWordVectorFilter.setUseStoplist(true);
          stringToWordVectorFilter.setInputFormat(data);

          vectorizedData = Filter.useFilter(data, stringToWordVectorFilter);
          vectorizedData.deleteAttributeAt(0);
        } catch (Exception exception) {
          logger.error(exception);
          crisisMailer.sendEmailAlert(exception);
          System.exit(-1);
        }

        J48 j48Classifier = new J48();

        try {
          Evaluation eval = new Evaluation(vectorizedData);
          eval.crossValidateModel(
              j48Classifier, vectorizedData, 5, new Random(System.currentTimeMillis()));

          FileOutputStream resultOutputStream =
              new FileOutputStream(new File(resultFolderPath + dataSetName.getName()));
          try {
            resultOutputStream.write(eval.toSummaryString("=== Summary ===", false).getBytes());
            resultOutputStream.write(eval.toMatrixString().getBytes());
            resultOutputStream.write(eval.toClassDetailsString().getBytes());
          } finally {
            // Close the stream even when one of the writes fails.
            resultOutputStream.close();
          }

        } catch (Exception exception) {
          logger.error(exception);
          crisisMailer.sendEmailAlert(exception);
          System.exit(-1);
        }
      }
    } catch (Exception exception) {
      logger.error(exception);
      crisisMailer.sendEmailAlert(exception);
      // Exit with the error status like every other failure path (was a println of -1).
      System.exit(-1);
    }
  }
Пример #7
0
  /**
   * Trains one filtered {@code NaiveBayesMultinomial} classifier per label in {@code lab} and
   * applies each of them to the instances of the test set loaded from {@code test}.
   *
   * <p>Steps, in order:
   *
   * <ol>
   *   <li>load the test set from {@code test};
   *   <li>load one training set per label from a hard-coded directory, file name
   *       {@code <label>.arff};
   *   <li>configure a {@code StringToWordVector} filter followed by an information-gain
   *       attribute selection, combined in a {@code MultiFilter};
   *   <li>for every label, build a copy of the filtered classifier on that label's training
   *       set, then walk the test instances, printing diagnostics and recording a probability.
   * </ol>
   *
   * <p>All checked exceptions are printed with {@code printStackTrace()} and not rethrown.
   *
   * @param test location of the test data, passed to {@code new DataSource(test)}
   * @return a 16-element array; slot {@code i} holds {@code x[i]} of the last distribution
   *     computed by the classifier of label {@code i}, or 0.0 if none was computed
   */
  private double[] classify(String test) {

    // The 16 labels; lab[j] also names the training file loaded for index j.
    String[] lab = {
      "I.2", "I.3", "I.5", "I.6", "I.2.1", "I.2.6", "I.2.8", "I.3.5", "I.3.6", "I.3.7", "I.5.1",
      "I.5.2", "I.5.4", "I.6.3", "I.6.5", "I.6.8",
    };

    int NSel = 1000; // value passed to rank.setNumToSelect(...) below
    Filter[] filters = new Filter[2];
    double[] x = new double[16];
    double[] prd = new double[16];
    double clsLabel;
    Ranker rank = new Ranker();
    Evaluation eval = null;

    StringToWordVector stwv = new StringToWordVector();
    weka.filters.supervised.attribute.AttributeSelection featSel =
        new weka.filters.supervised.attribute.AttributeSelection();

    WordTokenizer wtok = new WordTokenizer();
    String delim = " \r\n\t.,;:'\"()?!$*-&[]+/|\\";

    InfoGainAttributeEval ig = new InfoGainAttributeEval();

    String[] stwvOpts;
    wtok.setDelimiters(delim);

    // NOTE(review): only indices 0-15 are ever used; the array is sized 10000.
    Instances[] dataRaw = new Instances[10000];

    DataSource[] source = new DataSource[16];

    String str;

    Instances testset = null;
    DataSource testsrc = null;
    try {
      testsrc = new DataSource(test);
      testset = testsrc.getDataSet();
    } catch (Exception e1) {
      // NOTE(review): on failure testset stays null and the testset.classIndex() call after
      // the loading loop will throw a NullPointerException.
      e1.printStackTrace();
    }

    // Load one training set per label (indices 0-15).
    for (int j = 0; j < 16; j++)
    {
      try {
        str = lab[j];
        source[j] =
            new DataSource(
                "D:/Users/nma1g11/workspace2/WebScraperFlatNew/dataPernode/new/" + str + ".arff");
        dataRaw[j] = source[j].getDataSet();
      } catch (Exception e) {
        // NOTE(review): on failure dataRaw[j] stays null and the classIndex() call below
        // will throw a NullPointerException.
        e.printStackTrace();
      }

      System.out.println(lab[j]);
      // Default the class to the last attribute when the file does not declare one.
      if (dataRaw[j].classIndex() == -1) dataRaw[j].setClassIndex(dataRaw[j].numAttributes() - 1);
    }
    if (testset.classIndex() == -1) testset.setClassIndex(testset.numAttributes() - 1);

    try {
      stwvOpts =
          weka.core.Utils.splitOptions(
              "-R first-last -W 1000000 -prune-rate -1.0 -C -T -I -N 1 -L -S -stemmer weka.core.stemmers.LovinsStemmer -M 2 ");
      stwv.setOptions(stwvOpts);
      stwv.setTokenizer(wtok);

      // NOTE(review): the option string asks for "-N 100" and setNumToSelect(NSel) is called
      // right after with 1000; presumably the later call wins - confirm.
      rank.setOptions(weka.core.Utils.splitOptions("-T -1.7976931348623157E308 -N 100"));
      rank.setNumToSelect(NSel);
      featSel.setEvaluator(ig);
      featSel.setSearch(rank);
    } catch (Exception e) {
      e.printStackTrace();
    }

    // Filter chain: word vectorisation first, attribute selection second.
    filters[0] = stwv;
    filters[1] = featSel;

    System.out.println("Loading is Done!");

    MultiFilter mfilter = new MultiFilter();

    mfilter.setFilters(filters);

    FilteredClassifier classify = new FilteredClassifier();
    // The classification algorithm applied after the filter chain.
    classify.setClassifier(
        new NaiveBayesMultinomial());
    classify.setFilter(mfilter);

    // NOTE(review): ss2 is never read.
    String ss2 = "";

    try {
      // One independent copy of the filtered classifier per label.
      Classifier[] clsArr = new Classifier[16];
      clsArr = Classifier.makeCopies(classify, 16);
      String strcls = "";

      // NOTE(review): clsList is only appended to, newcls only assigned and q never used.
      List<String> clsList = new ArrayList<String>();
      String s = null;
      String newcls = null;
      String lb = "";
      String prev = "";
      boolean flag = false;
      String Ocls = null;
      int q = 0;

      for (int i = 0; i < 16; i++) {

        // flag is reset on every iteration, so after this loop it only tells whether the LAST
        // test instance's attribute 1 equals lab[i]; s holds that last value.
        for (int k = 0; k < testset.numInstances(); k++) {
          flag = false;

          s = testset.instance(k).stringValue(1);
          clsList.add(s);
          if (lab[i].equals(s)) {
            flag = true;
            newcls = s;
          }
        }

        clsArr[i].buildClassifier(dataRaw[i]);
        // NOTE(review): no evaluateModel/crossValidateModel call is made on eval in this
        // method, so the statistics printed below are those of a fresh Evaluation.
        eval = new Evaluation(dataRaw[i]);
        for (int j = 0; j < testset.numInstances(); j++) {
          // Remember the original value of attribute 1 so it can be restored below.
          Ocls = testset.instance(j).stringValue(1);

          // NOTE(review): s.equals(null) is always false for a non-null s, so the second
          // operand adds nothing; "s != null" was probably intended.
          if (flag && !s.equals(null)) testset.instance(j).setClassValue(lab[i]);

          // Temporarily relabel the instance with lab[i] when its current label and lab[i]
          // agree on their first three characters (comparison direction depends on i < 4).
          // NOTE(review): substring(0, 3) throws if strcls is shorter than three characters.
          strcls = testset.instance(j).stringValue(1);
          if (i < 4) {
            if (strcls.substring(0, 3).equals(lab[i])) testset.instance(j).setClassValue(lab[i]);
          } else if (lab[i].substring(0, 3).equals(strcls))
            testset.instance(j).setClassValue(lab[i]);

          System.out.println(
              dataRaw[i].classAttribute().value(i)
                  + " --- > Correct%:"
                  + eval.pctCorrect()
                  + "  F-measure:"
                  + eval.fMeasure(i));
          // Classify unless both attribute 0 and the label are unchanged since the previous
          // iteration.
          if (!prev.equals(testset.instance(j).stringValue(0)) || !lab[i].equals(lb)) {

            clsLabel = clsArr[i].classifyInstance(testset.instance(j));
            x = clsArr[i].distributionForInstance(testset.instance(j));

            // NOTE(review): x is indexed with the label index i; this assumes the class
            // attribute of dataRaw[i] has at least i + 1 values - confirm.
            prd[i] = x[i];
            System.out.println(" --- > prob: " + clsLabel);
            System.out.println(" --- > x :" + x[i]);
            System.out.println(clsLabel + " --> " + testset.classAttribute().value((int) clsLabel));
          }
          // Undo the temporary relabelling.
          testset.instance(j).setClassValue(Ocls);

          prev = testset.instance(j).stringValue(0);
          lb = lab[i];
        }

        System.out.println("Done with " + lab[i].replace("99", "") + " !!!!!!!!!!!");
      }
      System.out.println(eval.correct());

    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }

    return prd;
  }
 /**
  * Trains the classifier on the whole collected batch.
  *
  * <p>Text classification is a special case: when StringToWordVector computes the weights of
  * the terms (attributes) it needs collection-wide statistics such as document frequency, so
  * the data has to be processed as a batch. Note that in Weka some learning algorithms work in
  * batch mode and others do not.
  *
  * @throws Exception if filtering the data or building the classifier fails
  */
 public void finishBatch() throws Exception {
   filter.setIDFTransform(true);
   filter.setInputFormat(instances);

   // Only the filtered data has the format Weka's learning algorithms expect; building the
   // classifier on it is the actual training step.
   Instances trainingVectors = Filter.useFilter(instances, filter);
   classifier.buildClassifier(trainingVectors);
 }