  public GrammarParallelCscSpmvParser(
      final ParserDriver opts, final LeftCscSparseMatrixGrammar grammar) {
    super(opts, grammar);

    final ConfigProperties props = GlobalConfigProperties.singleton();
    // Split the binary grammar rules into segments of roughly equal size
    final int requestedThreads = props.getIntProperty(ParserDriver.OPT_GRAMMAR_THREAD_COUNT);
    final int[] segments = new int[requestedThreads + 1];
    final int segmentSize = grammar.cscBinaryRowIndices.length / requestedThreads + 1;
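    // segments[k] records the populated-column index at which segment k begins; the greedy
    // scan below aims for roughly segmentSize binary rules per segment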
    segments[0] = 0;
    int i = 1;
    // Examine each populated column
    for (int j = 1; j < grammar.cscBinaryPopulatedColumns.length - 1; j++) {
      if (grammar.cscBinaryPopulatedColumnOffsets[j]
              - grammar.cscBinaryPopulatedColumnOffsets[segments[i - 1]]
          >= segmentSize) {
        segments[i++] = j;
      }
    }
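    // Close the final segment at the end of the populated-column offsets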
    segments[i] = grammar.cscBinaryPopulatedColumnOffsets.length - 1;

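    // i is the number of segments actually created (possibly fewer than requested for a
    // small grammar), and becomes the effective grammar-level parallelism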
    this.grammarThreads = i;
    this.cpvSegments = grammarThreads * 2;
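    // Publish the total runtime thread count: cell-level threads x grammar-level threads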
    GlobalConfigProperties.singleton()
        .setProperty(
            ParserDriver.RUNTIME_CONFIGURED_THREAD_COUNT,
            Integer.toString(
                props.getIntProperty(ParserDriver.OPT_CELL_THREAD_COUNT, 1) * grammarThreads));

    this.binaryRowSegments = new int[i + 1];
    System.arraycopy(segments, 0, binaryRowSegments, 0, binaryRowSegments.length);

    if (BaseLogger.singleton().isLoggable(Level.FINE)) {
      final StringBuilder sb = new StringBuilder();
      for (int j = 1; j < binaryRowSegments.length; j++) {
        sb.append(
            (grammar.cscBinaryPopulatedColumnOffsets[binaryRowSegments[j]]
                    - grammar.cscBinaryPopulatedColumnOffsets[binaryRowSegments[j - 1]])
                + " ");
      }
      BaseLogger.singleton().fine("INFO: CSC Binary Grammar segments of length: " + sb.toString());
    }

    // Temporary cell storage for each grammar-level thread
    this.threadLocalTemporaryCellArrays =
        new ThreadLocal<PackedArrayChart.TemporaryChartCell[]>() {

          @Override
          protected PackedArrayChart.TemporaryChartCell[] initialValue() {
            final PackedArrayChart.TemporaryChartCell[] tcs =
                new PackedArrayChart.TemporaryChartCell[grammarThreads];
            for (int j = 0; j < grammarThreads; j++) {
              tcs[j] = new PackedArrayChart.TemporaryChartCell(grammar, false);
            }
            return tcs;
          }
        };
  }
/**
 * Performs pruned inference on a development set. The 'normal' implementation ({@link
 * DiscriminativeMergeObjectiveFunction}) evaluates each merge candidate separately; {@link
 * SamplingMergeObjective} shares much of the same infrastructure, but evaluates a set of merge
 * candidates in combination.
 *
 * @author Aaron Dunlop
 */
public abstract class InferenceInformedMergeObjectiveFunction extends MergeObjectiveFunction {

  protected float splitF1;
  protected float splitSpeed;

  protected Grammar splitGrammar;
  protected Lexicon splitLexicon;
  private float minRuleProbability;
  private List<String> trainingCorpus;
  private List<String> developmentSet;
  private CompleteClosureModel ccModel;
  protected int beamWidth;

  private static final String PROPERTY_BEAM_WIDTHS = "beamWidths";
  private static final String DEFAULT_BEAM_WIDTHS = "15,15,20,20,30,30";

  private static final String PROPERTY_PARSE_FRACTION = "discParseFraction";
  private static final float DEFAULT_PARSE_FRACTION = .5f;

  /**
   * Fraction of merge candidates to parse. If the fraction x < 1, we estimate likelihood loss (as
   * in the likelihood-based {@link MergeRanking}), retain the top (1 - x) / 2, discard the bottom
   * (1 - x) / 2, and perform inference to evaluate the middle x of the candidates. E.g., with x =
   * .5, the top 25% are retained outright, the bottom 25% are discarded, and the middle 50% are
   * evaluated by parsing.
   */
  protected static final float PARSE_FRACTION =
      GlobalConfigProperties.singleton()
          .getFloatProperty(PROPERTY_PARSE_FRACTION, DEFAULT_PARSE_FRACTION);

  @SuppressWarnings("hiding")
  public void init(
      final CompleteClosureModel ccModel,
      final List<String> trainingCorpus,
      final List<String> developmentSet,
      final float minRuleProbability) {

    // Store the CC model, training corpus, and dev-set (they remain consistent across all
    // training cycles)
    this.ccModel = ccModel;
    this.minRuleProbability = minRuleProbability;
    this.trainingCorpus = trainingCorpus;

    // Un-binarize the dev-set
    this.developmentSet = new ArrayList<String>();
    for (final String binarizedTree : developmentSet) {
      this.developmentSet.add(
          BinaryTree.read(binarizedTree, String.class)
              .unfactor(GrammarFormatType.Berkeley)
              .toString());
    }
  }

  /**
   * Evaluates speed and accuracy with the baseline (fully-split) grammar.
   *
   * @param grammar Fully-split grammar for this merge cycle
   * @param lexicon Lexicon accompanying the grammar
   * @param cycle 1-indexed merge cycle
   */
  @Override
  public void initMergeCycle(final Grammar grammar, final Lexicon lexicon, final int cycle) {

    this.splitGrammar = grammar;
    this.splitLexicon = lexicon;
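    // Look up this cycle's beam width; the property is a comma-separated list indexed by
    // (1-based) merge cycle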
    final String[] split =
        GlobalConfigProperties.singleton()
            .getProperty(PROPERTY_BEAM_WIDTHS, DEFAULT_BEAM_WIDTHS)
            .split(",");
    this.beamWidth = Integer.parseInt(split[cycle - 1]);

    // Convert the grammar to BUBS sparse-matrix format and train a Boundary POS FOM
    BaseLogger.singleton()
        .info("Constrained parsing the training-set and training a prioritization model");
    final LeftCscSparseMatrixGrammar sparseMatrixGrammar =
        convertGrammarToSparseMatrix(splitGrammar, splitLexicon);
    final BoundaryPosModel posFom = trainPosFom(sparseMatrixGrammar);

    // Record accuracy with the full split grammar (so we can later compare with MergeCandidates)
    BaseLogger.singleton().info("Parsing the dev-set with the fully-split grammar");
    final float[] parseResult = parseDevSet(sparseMatrixGrammar, posFom, beamWidth);
    splitF1 = parseResult[0];
    splitSpeed = parseResult[1];
    BaseLogger.singleton()
        .info(String.format("F1 = %.3f  Speed = %.3f w/s", splitF1 * 100, splitSpeed));
  }

  /**
   * Parses the development set with the specified grammar and FOM, returning the accuracy (F1) and
   * speed of the resulting parses. Uses the specified <code>beamWidth</code> for all cells spanning
   * &gt;= 2 words. For lexical cells, uses <code>beamWidth</code> x 3 (since the FOM does not
   * prioritize entries in lexical cells), and allocates <code>beamWidth</code> entries in lexical
   * cells for unary productions.
   *
   * @param sparseMatrixGrammar Grammar in BUBS sparse-matrix format
   * @param posFom Boundary POS prioritization model (FOM)
   * @param beamWidth Maximum beam width for cells spanning 2 or more words
   * @return Accuracy (F1) and speed (words/second)
   */
  @SuppressWarnings("hiding")
  protected float[] parseDevSet(
      final LeftCscSparseMatrixGrammar sparseMatrixGrammar,
      final BoundaryPosModel posFom,
      final int beamWidth) {

    // Initialize the parser
    final ParserDriver opts = new ParserDriver();
    opts.researchParserType = ResearchParserType.CartesianProductHashMl;
    opts.cellSelectorModel = ccModel;
    opts.fomModel = posFom;

    // Set beam-width configuration properties
    GlobalConfigProperties.singleton()
        .setProperty(Parser.PROPERTY_MAX_BEAM_WIDTH, Integer.toString(beamWidth));
    GlobalConfigProperties.singleton()
        .setProperty(Parser.PROPERTY_LEXICAL_ROW_BEAM_WIDTH, Integer.toString(beamWidth * 3));
    GlobalConfigProperties.singleton()
        .setProperty(Parser.PROPERTY_LEXICAL_ROW_UNARIES, Integer.toString(beamWidth));

    // Parse the dev-set
    final CartesianProductHashSpmlParser parser =
        new CartesianProductHashSpmlParser(opts, sparseMatrixGrammar);
    final long t0 = System.currentTimeMillis();
    int words = 0;
    final BracketEvaluator evaluator = new BracketEvaluator();
    for (final String inputTree : developmentSet) {
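      // Each dev-set entry is an unfactored gold tree; the parser extracts the sentence to
      // parse, and the evaluator scores the resulting brackets against the gold tree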
      parser.parseSentence(inputTree).evaluate(evaluator);
      final NaryTree<String> naryTree = NaryTree.read(inputTree, String.class);
      words += naryTree.leaves();
    }
    final long t1 = System.currentTimeMillis();

    return new float[] {(float) evaluator.accumulatedResult().f1(), words * 1000f / (t1 - t0)};
  }

  /**
   * Converts a {@link Grammar} and {@link Lexicon} to BUBS sparse-matrix format, specifically
   * {@link LeftCscSparseMatrixGrammar}. Prunes rules below the minimum rule probability threshold
   * specified when the {@link DiscriminativeMergeObjectiveFunction} was initialized with {@link
   * #init(CompleteClosureModel, List, List, float)}.
   *
   * @param grammar Grammar to convert
   * @param lexicon Lexicon to convert
   * @return {@link LeftCscSparseMatrixGrammar} encoding the pruned grammar and lexicon
   */
  protected LeftCscSparseMatrixGrammar convertGrammarToSparseMatrix(
      final Grammar grammar, final Lexicon lexicon) {
    try {
      final Writer w = new StringWriter(150 * 1024 * 1024);
      // Note: we could use a PipedOutputStream / PipedInputStream combination (with 2 threads) to
      // write and read at the same time, avoiding the memory needed to buffer the entire
      // serialized grammar. But memory isn't a huge constraint during training, and the threading
      // would add complexity, so we skip that for now.
      w.write(grammar.toString(lexicon.totalRules(minRuleProbability), minRuleProbability, 0, 0));
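      // In the serialized text format, the lexicon follows this separator line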
      w.write("===== LEXICON =====\n");
      w.write(lexicon.toString(minRuleProbability));
      return new LeftCscSparseMatrixGrammar(
          new StringReader(w.toString()), new DecisionTreeTokenClassifier());

    } catch (final IOException e) {
      // StringWriter and StringReader should never IOException
      throw new AssertionError(e);
    }
  }

  /**
   * Trains a boundary POS prioritization model (AKA a figure-of-merit, or FOM). Parses the training
   * corpus, constrained by the gold trees, and learns prioritization probabilities from the
   * resulting parses.
   *
   * @param sparseMatrixGrammar Grammar in BUBS sparse-matrix format
   * @return a boundary POS figure of merit model.
   */
  protected BoundaryPosModel trainPosFom(final LeftCscSparseMatrixGrammar sparseMatrixGrammar) {

    try {
      // Constrained parse the training corpus
      final ParserDriver opts = new ParserDriver();
      opts.cellSelectorModel = ConstrainedCellSelector.MODEL;
      opts.researchParserType = ResearchParserType.ConstrainedCartesianProductHashMl;
      final ConstrainedCphSpmlParser constrainedParser =
          new ConstrainedCphSpmlParser(opts, sparseMatrixGrammar);

      final StringWriter binaryConstrainedParses = new StringWriter(30 * 1024 * 1024);
      for (final String inputTree : trainingCorpus) {
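        // Each constrained parse reproduces the gold tree structure with the split grammar's
        // nonterminal annotations, in binarized form suitable for FOM training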
        final ParseTask parseTask = constrainedParser.parseSentence(inputTree);
        binaryConstrainedParses.write(parseTask.binaryParse.toString());
        binaryConstrainedParses.write('\n');
      }

      final StringWriter serializedFomModel = new StringWriter(30 * 1024 * 1024);
      BoundaryPosModel.train(
          sparseMatrixGrammar,
          new BufferedReader(new StringReader(binaryConstrainedParses.toString())),
          new BufferedWriter(serializedFomModel),
          .5f,
          false,
          2);

      final BufferedReader fomModelReader =
          new BufferedReader(new StringReader(serializedFomModel.toString()));
      return new BoundaryPosModel(FOMType.BoundaryPOS, sparseMatrixGrammar, fomModelReader);

    } catch (final IOException e) {
      // StringWriter and StringReader should never IOException
      throw new AssertionError(e);
    }
  }
}