コード例 #1
0
  /**
   * Checks that z indicators can be copied between a {@code SerialCollapsedLDA} and an
   * {@code UncollapsedParallelLDA} in both directions.
   *
   * <p>Both samplers get the same seed and the same instances. The test first asserts that their
   * type-topic counts and model log likelihoods agree, then samples with one sampler, copies its
   * z indicators into the other, and asserts that indicators, type-topic counts and log
   * likelihoods agree again.
   *
   * @throws IOException propagated from the setup calls below (which call declares it is not
   *     visible from this method)
   */
  @Test
  public void testSettingZIndicators() throws IOException {
    // Settings handed to the configuration constructor.
    final String scheme = "uncollapsed";
    final Integer topicCount = 20;
    final Double alphaTotal = 1.0;
    final Double betaValue = 0.01;
    final Integer iterations = 1000;
    final Integer batches = 6;
    final Integer rareThreshold = 0;
    final Integer topicInterval = 50;
    final Integer diagnosticStart = 500;

    SimpleLDAConfiguration config =
        new SimpleLDAConfiguration(
            new LoggingUtils(),
            scheme,
            topicCount,
            alphaTotal,
            betaValue,
            iterations,
            batches,
            rareThreshold,
            topicInterval,
            diagnosticStart,
            4711,
            "src/main/resources/datasets/nips.txt");

    // Swap in a logging util whose log directory has been prepared under "Runs".
    LoggingUtils logging = new LoggingUtils();
    logging.checkAndCreateCurrentLogDir("Runs");
    config.setLoggingUtil(logging);
    config.activateSubconfig("demo-nips");

    System.out.println("Using Config: " + config.whereAmI());

    String datasetPath = config.getDatasetFilename();
    System.out.println("Using dataset: " + datasetPath);
    System.out.println("Scheme: " + scheme);

    InstanceList corpus =
        LDAUtils.loadInstances(
            datasetPath,
            "stoplist.txt",
            config.getRareThreshold(LDAConfiguration.RARE_WORD_THRESHOLD));

    // Collapsed sampler: configured explicitly, then fed the shared instances.
    SerialCollapsedLDA collapsed =
        new SerialCollapsedLDA(
            topicCount,
            config.getAlpha(LDAConfiguration.ALPHA_DEFAULT),
            config.getBeta(LDAConfiguration.BETA_DEFAULT));
    collapsed.setRandomSeed(config.getSeed(LDAConfiguration.SEED_DEFAULT));
    collapsed.setConfiguration(config);
    collapsed.addInstances(corpus);

    // Uncollapsed sampler: same seed, same instances.
    UncollapsedParallelLDA uncollapsed = new UncollapsedParallelLDA(config);
    uncollapsed.setRandomSeed(config.getSeed(LDAConfiguration.SEED_DEFAULT));
    uncollapsed.addInstances(corpus);

    TestUtils.assertEqualArrays(collapsed.getTypeTopicCounts(), uncollapsed.getTypeTopicCounts());
    System.out.println("TTCounts 1 ok!");

    double collapsedLlBefore = collapsed.modelLogLikelihood();
    double uncollapsedLlBefore = uncollapsed.modelLogLikelihood();

    // Re-check the counts after the log likelihood calls.
    TestUtils.assertEqualArrays(collapsed.getTypeTopicCounts(), uncollapsed.getTypeTopicCounts());
    System.out.println("TTCounts 2 ok!");

    String mismatchBefore =
        "Collapsed and UnCollapsed LogLikelihoods are not the same: "
            + collapsedLlBefore
            + " != "
            + uncollapsedLlBefore
            + " Diff: "
            + (collapsedLlBefore - uncollapsedLlBefore);
    assertEquals(mismatchBefore, collapsedLlBefore, uncollapsedLlBefore, epsilon);

    System.out.println("Precheck ok!");

    // Run 50 collapsed iterations so the samplers are compared in a state other than the
    // start state.
    collapsed.sample(50);
    System.out.println("Finished sampling!");

    // Collapsed -> uncollapsed: what is read back must equal what was set.
    int[][] zFromCollapsed = collapsed.getZIndicators();
    uncollapsed.setZIndicators(zFromCollapsed);
    int[][] zReadBackFromUncollapsed = uncollapsed.getZIndicators();
    TestUtils.assertEqualArrays(zFromCollapsed, zReadBackFromUncollapsed);

    // Run 5 uncollapsed iterations to change z, then copy in the other direction.
    uncollapsed.sample(5);

    int[][] zFromUncollapsed = uncollapsed.getZIndicators();
    collapsed.setZIndicators(zFromUncollapsed);
    int[][] zReadBackFromCollapsed = collapsed.getZIndicators();
    TestUtils.assertEqualArrays(zFromUncollapsed, zReadBackFromCollapsed);

    System.out.println("Z indicators are equal...");

    TestUtils.assertEqualArrays(collapsed.getTypeTopicCounts(), uncollapsed.getTypeTopicCounts());

    double collapsedLlAfter = collapsed.modelLogLikelihood();
    double uncollapsedLlAfter = uncollapsed.modelLogLikelihood();

    String mismatchAfter =
        "Collapsed and UnCollapsed LogLikelihoods are not the same: "
            + collapsedLlAfter
            + " != "
            + uncollapsedLlAfter
            + " Diff: "
            + (collapsedLlAfter - uncollapsedLlAfter);
    assertEquals(mismatchAfter, collapsedLlAfter, uncollapsedLlAfter, epsilon);
  }
コード例 #2
0
  /**
   * Verifies that {@code SerialCollapsedLDA}, {@code UncollapsedParallelLDA} and {@code ADLDA}
   * report the same initial state when given the same seed, configuration and instances.
   *
   * <p>Compared, in order: type-topic counts, per-topic totals, model log likelihoods (within
   * {@code epsilon}) and z indicators. Array dimensions are asserted before elements are
   * compared, so a size mismatch is reported as an assertion failure rather than being skipped
   * or surfacing as an {@code ArrayIndexOutOfBoundsException}.
   *
   * @throws ParseException declared by the original signature; kept for compatibility
   * @throws ConfigurationException declared by the original signature; kept for compatibility
   * @throws IOException propagated from the setup calls below (which call declares it is not
   *     visible from this method)
   */
  @Test
  public void testEqualInitialization() throws ParseException, ConfigurationException, IOException {
    // The configuration is built programmatically instead of being parsed from a .cfg file.
    int seed = 20150326;
    Integer numTopics = 20;
    Double alpha = 0.1;
    Double beta = 0.01;
    Integer numIter = 1000;
    Integer numBatches = 4;
    Integer rareWordThreshold = 10;
    Integer showTopicsInterval = 50;
    Integer startDiagnosticOutput = 500;

    SimpleLDAConfiguration config =
        new SimpleLDAConfiguration(
            new LoggingUtils(),
            "ALL",
            numTopics,
            alpha,
            beta,
            numIter,
            numBatches,
            rareWordThreshold,
            showTopicsInterval,
            startDiagnosticOutput,
            seed,
            "src/main/resources/datasets/nips.txt");

    // Swap in a logging util whose log directory has been prepared under "Runs".
    LoggingUtils lu = new LoggingUtils();
    lu.checkAndCreateCurrentLogDir("Runs");
    config.setLoggingUtil(lu);
    config.activateSubconfig("demo-nips");

    System.out.println("Using Config: " + config.whereAmI());

    String dataset_fn = config.getDatasetFilename();
    System.out.println("Using dataset: " + dataset_fn);

    InstanceList instances =
        LDAUtils.loadInstances(
            dataset_fn,
            "stoplist.txt",
            config.getRareThreshold(LDAConfiguration.RARE_WORD_THRESHOLD));

    // All three samplers get the same seed and the same instances.
    SerialCollapsedLDA collapsed =
        new SerialCollapsedLDA(
            numTopics,
            config.getAlpha(LDAConfiguration.ALPHA_DEFAULT),
            config.getBeta(LDAConfiguration.BETA_DEFAULT));
    collapsed.setRandomSeed(config.getSeed(LDAConfiguration.SEED_DEFAULT));
    collapsed.setConfiguration(config);
    collapsed.addInstances(instances);

    UncollapsedParallelLDA uncollapsed = new UncollapsedParallelLDA(config);
    uncollapsed.setRandomSeed(config.getSeed(LDAConfiguration.SEED_DEFAULT));
    uncollapsed.addInstances(instances);

    ADLDA adlda = new ADLDA(config);
    adlda.setRandomSeed(config.getSeed(LDAConfiguration.SEED_DEFAULT));
    adlda.addInstances(instances);

    // --- Type-topic counts ---
    int[][] collapsedTopicIndicators = collapsed.getTypeTopicCounts();
    int[][] uncollapsedTopicIndicators = uncollapsed.getTypeTopicCounts();
    int[][] adldaTopicIndicators = adlda.getTypeTopicCounts();

    assertEquals(
        "Collapsed and UnCollapsed type-topic counts have different numbers of rows",
        collapsedTopicIndicators.length,
        uncollapsedTopicIndicators.length);
    assertEquals(
        "Collapsed and ADLDA type-topic counts have different numbers of rows",
        collapsedTopicIndicators.length,
        adldaTopicIndicators.length);

    for (int i = 0; i < collapsedTopicIndicators.length; i++) {
      // Check each row's own length instead of assuming every row matches row 0.
      assertEquals(
          "Collapsed and UnCollapsed type-topic rows differ in length at row " + i,
          collapsedTopicIndicators[i].length,
          uncollapsedTopicIndicators[i].length);
      assertEquals(
          "Collapsed and ADLDA type-topic rows differ in length at row " + i,
          collapsedTopicIndicators[i].length,
          adldaTopicIndicators[i].length);
      for (int j = 0; j < collapsedTopicIndicators[i].length; j++) {
        assertEquals(
            "Collapsed and UnCollapsed are not the same: "
                + collapsedTopicIndicators[i][j]
                + "!="
                + uncollapsedTopicIndicators[i][j],
            collapsedTopicIndicators[i][j],
            uncollapsedTopicIndicators[i][j]);
        assertEquals(
            "Collapsed and ADLDA are not the same: "
                + collapsedTopicIndicators[i][j]
                + "!="
                + adldaTopicIndicators[i][j],
            collapsedTopicIndicators[i][j],
            adldaTopicIndicators[i][j]);
      }
    }

    // --- Per-topic totals ---
    int[] collapsedTokensPerTopic = collapsed.getTopicTotals();
    int[] uncollapsedTokensPerTopic = uncollapsed.getTopicTotals();
    int[] adldaTokensPerTopic = adlda.getTopicTotals();

    assertEquals(
        "Collapsed and ADLDA topic totals differ in length",
        collapsedTokensPerTopic.length,
        adldaTokensPerTopic.length);
    assertEquals(
        "Collapsed and UnCollapsed topic totals differ in length",
        collapsedTokensPerTopic.length,
        uncollapsedTokensPerTopic.length);

    for (int i = 0; i < collapsedTokensPerTopic.length; i++) {
      assertEquals(
          "Collapsed and ADLDA token counts are not the same: "
              + collapsedTokensPerTopic[i]
              + "!="
              + adldaTokensPerTopic[i],
          collapsedTokensPerTopic[i],
          adldaTokensPerTopic[i]);
      assertEquals(
          "Collapsed and UnCollapsed token counts are not the same: "
              + collapsedTokensPerTopic[i]
              + "!="
              + uncollapsedTokensPerTopic[i],
          collapsedTokensPerTopic[i],
          uncollapsedTokensPerTopic[i]);
    }

    // --- Model log likelihoods ---
    double collapsedModelLogLikelihood = collapsed.modelLogLikelihood();
    double uncollapsedModelLogLikelihood = uncollapsed.modelLogLikelihood();
    double adldaModelLogLikelihood = adlda.modelLogLikelihood();

    assertEquals(
        "ADLDA and Collapsed LogLikelihoods are not the same: "
            + adldaModelLogLikelihood
            + " != "
            + collapsedModelLogLikelihood
            + " Diff: "
            + (adldaModelLogLikelihood - collapsedModelLogLikelihood),
        adldaModelLogLikelihood,
        collapsedModelLogLikelihood,
        epsilon);

    assertEquals(
        "Collapsed and UnCollapsed LogLikelihoods are not the same: "
            + collapsedModelLogLikelihood
            + " != "
            + uncollapsedModelLogLikelihood
            + " Diff: "
            + (collapsedModelLogLikelihood - uncollapsedModelLogLikelihood),
        collapsedModelLogLikelihood,
        uncollapsedModelLogLikelihood,
        epsilon);

    // Re-check the counts after the log likelihood calls.
    TestUtils.assertEqualArrays(collapsed.getTypeTopicCounts(), uncollapsed.getTypeTopicCounts());

    // --- Z indicators ---
    int[][] collapsedZIndicators = collapsed.getZIndicators();
    int[][] uncollapsedZIndicators = uncollapsed.getZIndicators();
    int[][] adldaZIndicators = adlda.getZIndicators();

    assertEquals(
        "Collapsed and UnCollapsed z indicators have different numbers of rows",
        collapsedZIndicators.length,
        uncollapsedZIndicators.length);
    assertEquals(
        "Collapsed and ADLDA z indicators have different numbers of rows",
        collapsedZIndicators.length,
        adldaZIndicators.length);

    for (int i = 0; i < collapsedZIndicators.length; i++) {
      assertEquals(
          "Collapsed and UnCollapsed z indicator rows differ in length at row " + i,
          collapsedZIndicators[i].length,
          uncollapsedZIndicators[i].length);
      assertEquals(
          "Collapsed and ADLDA z indicator rows differ in length at row " + i,
          collapsedZIndicators[i].length,
          adldaZIndicators[i].length);
      for (int j = 0; j < collapsedZIndicators[i].length; j++) {
        assertEquals(
            "Collapsed and UnCollapsed are not the same: "
                + collapsedZIndicators[i][j]
                + "!="
                + uncollapsedZIndicators[i][j],
            collapsedZIndicators[i][j],
            uncollapsedZIndicators[i][j]);
        assertEquals(
            "Collapsed and ADLDA are not the same: "
                + collapsedZIndicators[i][j]
                + "!="
                + adldaZIndicators[i][j],
            collapsedZIndicators[i][j],
            adldaZIndicators[i][j]);
      }
    }
  }