/**
   * Exercises {@code binarySpmv} with simple grammar 2, feeding it hand-built cartesian-product
   * vectors modeled on the ones checked in {@link
   * #testUnfilteredCartesianProductVectorSimpleGrammar2()}.
   *
   * <p>Two top cells are checked: cell 0,4 (expected to hold NP and S) and cell 0,5 (expected to
   * hold only S). For each best edge the inside probability and both child cells are verified.
   *
   * @throws Exception if something bad happens
   */
  @Test
  public void testBinarySpMVMultiplySimpleGrammar2() throws Exception {

    // Parser over the shared simple-grammar-2 instance
    final SparseMatrixGrammar grammar2 = (SparseMatrixGrammar) simpleGrammar2;
    final SparseMatrixVectorParser<?, ?> parser =
        createParser(grammar2, parserOptions(), configProperties());
    parser.initSentence(
        new ParseTask(
            "The fish market stands last",
            Parser.InputFormat.Text,
            grammar2,
            DecodeMethod.ViterbiMax));
    final Chart chart = parser.chart;

    // Dense cartesian-product storage, indexed by packed child pair
    final int vectorLength = grammar2.packingFunction().packedArraySize();
    final float[] probabilities = new float[vectorLength];
    Arrays.fill(probabilities, Float.NEGATIVE_INFINITY);
    final short[] midpoints = new short[vectorLength];

    populateSimpleGrammar2Rows1_3(chart, grammar2);

    final int dtNt = grammar2.mapNonterminal("DT");
    final int sNt = grammar2.mapNonterminal("S");
    final int npNt = grammar2.mapNonterminal("NP");
    final int npNnNt = grammar2.mapNonterminal("NP|NN");
    final int vpVbNt = grammar2.mapNonterminal("VP|VB");
    final int vpNt = grammar2.mapNonterminal("VP");
    final int nnNt = grammar2.mapNonterminal("NN");
    final int vbNt = grammar2.mapNonterminal("VB");
    final int rbNt = grammar2.mapNonterminal("RB");

    //
    // Cell 0,4 -- each row is {left child, right child, midpoint}; scores04 holds the probabilities
    //
    final int[][] entries04 = {
      {dtNt, sNt, 1}, {dtNt, npNt, 1}, // midpoint 1
      {npNt, npNnNt, 2}, {npNt, vpVbNt, 2}, {npNt, npNt, 2}, // midpoint 2
      {npNt, vpNt, 3}, {npNt, nnNt, 3}, {npNt, vbNt, 3} // midpoint 3
    };
    final float[] scores04 = {
      -2.890f, -2.890f, -2.485f, -4.277f, -4.277f, -5.663f, -4.277f, -4.277f
    };
    for (int row = 0; row < entries04.length; row++) {
      final int child = pack(grammar2, entries04[row][0], entries04[row][1]);
      probabilities[child] = scores04[row];
      midpoints[child] = (short) entries04[row][2];
    }

    CartesianProductVector vector =
        new CartesianProductVector(grammar2, probabilities, midpoints, 8);

    final ChartCell top04 = parser.chart.getCell(0, 4);
    parser.binarySpmv(vector, top04);

    assertEquals(2, top04.getNumNTs());

    final ChartEdge npEdge04 = top04.getBestEdge(grammar2.mapNonterminal("NP"));
    assertEquals(-4.27667, npEdge04.inside(), .001f);
    assertEquals("Wrong left child cell", chart.getCell(0, 1), npEdge04.leftCell);
    assertEquals("Wrong right child cell", chart.getCell(1, 4), npEdge04.rightCell);

    final ChartEdge sEdge04 = top04.getBestEdge(grammar2.mapNonterminal("S"));
    assertEquals(-5.66296f, sEdge04.inside(), .001f);
    assertEquals("Wrong left child cell", chart.getCell(0, 3), sEdge04.leftCell);
    assertEquals("Wrong right child cell", chart.getCell(3, 4), sEdge04.rightCell);

    //
    // Cell 0,5 -- repopulate the lower rows, add row 4, and reset the probability vector
    //
    populateSimpleGrammar2Rows1_3(chart, grammar2);
    populateSimpleGrammar2Row4(chart, grammar2);
    Arrays.fill(probabilities, Float.NEGATIVE_INFINITY);

    final int[][] entries05 = {
      {npNt, vpNt, 3}, {npNt, vpVbNt, 3}, {npNt, npNt, 3}, // midpoint 3
      {npNt, rbNt, 4}, {npNt, vbNt, 4} // midpoint 4
    };
    final float[] scores05 = {-5.37528f, -6.474f, -6.474f, -4.682f, -5.375f};
    for (int row = 0; row < entries05.length; row++) {
      final int child = pack(grammar2, entries05[row][0], entries05[row][1]);
      probabilities[child] = scores05[row];
      midpoints[child] = (short) entries05[row][2];
    }

    vector = new CartesianProductVector(grammar2, probabilities, midpoints, 8);

    final ChartCell top05 = parser.chart.getCell(0, 5);
    parser.binarySpmv(vector, top05);

    assertEquals(1, top05.getNumNTs());

    final ChartEdge sEdge05 = top05.getBestEdge(grammar2.mapNonterminal("S"));
    assertEquals(-5.37528f, sEdge05.inside(), .001f);
    assertEquals("Wrong left child cell", chart.getCell(0, 3), sEdge05.leftCell);
    assertEquals("Wrong right child cell", chart.getCell(3, 5), sEdge05.rightCell);
  }
  /**
   * Checks {@code cartesianProductUnion(0, 4)} against a hand-built chart, based very loosely on the
   * top-cell computation of the 'systems analyst arbitration chef' example.
   *
   * <p>Three midpoints contribute child pairs to cell 0,4:
   *
   * <ul>
   *   <li>0,1 x 1,4: NN/NN (-5, midpoint 1), NN/NP (-6, midpoint 1)
   *   <li>0,2 x 2,4: NN/NN (-6, 2), NN/NP (-6, 2), NP/NN (-7, 2), NP/NP (-7, 2)
   *   <li>0,3 x 3,4: NP/NP (-4, 3)
   * </ul>
   *
   * The expected union is NN/NN (-5, 1), NN/NP (-6, 1), NP/NN (-7, 2), NP/NP (-4, 3).
   *
   * @throws Exception if something bad happens
   */
  @Test
  public void testCartesianProductVectorExample() throws Exception {

    final SparseMatrixGrammar sparseGrammar = (SparseMatrixGrammar) simpleGrammar1;
    final P parser = createParser(sparseGrammar, parserOptions(), configProperties());
    parser.initSentence(
        new ParseTask(
            "systems analyst arbitration chef",
            Parser.InputFormat.Text,
            sparseGrammar,
            DecodeMethod.ViterbiMax));
    final Chart chart = parser.chart;

    final int nnIndex = sparseGrammar.mapNonterminal("NN");
    final int npIndex = sparseGrammar.mapNonterminal("NP");

    // Midpoint 1: cell 0,1 holds NN (-2); cell 1,4 holds NN (-3) and NP (-4)
    final ChartCell span01 = chart.getCell(0, 1);
    span01.updateInside(new Production("NN", "NN", -2, false, sparseGrammar), span01, null, -2f);
    span01.finalizeCell();

    final ChartCell span13 = chart.getCell(1, 3);
    final ChartCell span14 = chart.getCell(1, 4);
    span14.updateInside(new Production("NN", "NN", -3f, false, sparseGrammar), span13, null, -3f);
    span14.updateInside(new Production("NP", "NP", -4f, false, sparseGrammar), span13, null, -4f);
    span14.finalizeCell();

    // Midpoint 2: cell 0,2 holds NN (-2) and NP (-3); cell 2,4 holds NN (-4) and NP (-4)
    final ChartCell span02 = chart.getCell(0, 2);
    span02.updateInside(
        new Production("NN", "NN", -2f, false, sparseGrammar), chart.getCell(0, 1), null, -2f);
    span02.updateInside(
        new Production("NP", "NP", -3f, false, sparseGrammar), chart.getCell(0, 1), null, -3f);
    span02.finalizeCell();

    final ChartCell span24 = chart.getCell(2, 4);
    span24.updateInside(
        new Production("NN", "NN", -4f, false, sparseGrammar), chart.getCell(2, 3), null, -4f);
    span24.updateInside(
        new Production("NP", "NP", -4f, false, sparseGrammar), chart.getCell(2, 3), null, -4f);
    span24.finalizeCell();

    // Midpoint 3: cells 0,3 and 3,4 each hold NP (-2)
    final ChartCell span03 = chart.getCell(0, 3);
    span03.updateInside(
        new Production("NP", "NP", -2, false, sparseGrammar), chart.getCell(0, 2), null, -2f);
    span03.finalizeCell();

    final ChartCell span34 = chart.getCell(3, 4);
    span34.updateInside(
        new Production("NP", "NP", -2f, false, sparseGrammar), chart.getCell(3, 4), null, -2f);
    span34.finalizeCell();

    final SparseMatrixVectorParser.CartesianProductVector union =
        parser.cartesianProductUnion(0, 4);

    // Each row: {packed child pair, expected midpoint}; expectedScores holds the probabilities
    final int[][] expectedRows = {
      {pack(sparseGrammar, nnIndex, nnIndex), 1},
      {pack(sparseGrammar, nnIndex, npIndex), 1},
      {pack(sparseGrammar, npIndex, nnIndex), 2},
      {pack(sparseGrammar, npIndex, npIndex), 3}
    };
    final float[] expectedScores = {-5f, -6f, -7f, -4f};

    for (int k = 0; k < expectedScores.length; k++) {
      final int childPair = expectedRows[k][0];
      assertEquals(
          "Wrong probability #" + k, expectedScores[k], union.probability(childPair), .01f);
      assertEquals("Wrong midpoint #" + k, expectedRows[k][1], union.midpoint(childPair));
    }
  }
  /**
   * Tests the unfiltered cartesian-product vectors computed for the top cells (0,4 and 0,5) of the
   * 'The fish market stands last' example, using simple grammar 2 with a {@code LeftShiftFunction}
   * packing function.
   *
   * @throws Exception if something bad happens
   */
  @Test
  public void testUnfilteredCartesianProductVectorSimpleGrammar2() throws Exception {

    final SparseMatrixGrammar g =
        (SparseMatrixGrammar) createGrammar(simpleGrammar2(), LeftShiftFunction.class);

    // Create the parser
    final P p = createParser(g, parserOptions(), configProperties());
    final ParseTask parseTask =
        new ParseTask(
            "The fish market stands last", Parser.InputFormat.Text, g, DecodeMethod.ViterbiMax);
    p.initSentence(parseTask);
    final Chart chart = p.chart;

    populateSimpleGrammar2Rows1_3(chart, g);
    populateSimpleGrammar2Row4(chart, g);

    // Row of span 5. Productions are built against g (the grammar this chart and parser use),
    // not the shared simpleGrammar2 instance, so they agree with every other cell in the chart.
    final ChartCell cell_0_5 = chart.getCell(0, 5);
    cell_0_5.updateInside(
        new Production("S", "NP", "VP", -5.37528f, g),
        chart.getCell(0, 3),
        chart.getCell(3, 5),
        -5.37528f);
    cell_0_5.updateInside(new Production("TOP", "S", -5.37528f, false, g), cell_0_5, null, -5.37528f);

    // Finalize all chart cells
    for (int i = 0; i < chart.size(); i++) {
      for (int j = i + 1; j <= chart.size(); j++) {
        chart.getCell(i, j).finalizeCell();
      }
    }

    final int dt = g.mapNonterminal("DT");
    final int np = g.mapNonterminal("NP");
    final int nn = g.mapNonterminal("NN");
    final int vp = g.mapNonterminal("VP");
    final int rb = g.mapNonterminal("RB");
    final int npNn = g.mapNonterminal("NP|NN");
    final int vpVb = g.mapNonterminal("VP|VB");

    // Cross-product union for cell 0,4
    SparseMatrixVectorParser.CartesianProductVector cartesianProductVector =
        p.cartesianProductUnion(0, 4);
    assertEquals(14, cartesianProductVector.size());

    // Each row: {left child, right child, expected midpoint}; expected probabilities are parallel
    final int[][] expected04 = {
      // Midpoint 1
      {dt, vpVb, 1},
      {dt, np, 1},
      // Midpoint 2
      {np, npNn, 2},
      {np, vpVb, 2},
      {np, np, 2},
      // Midpoint 3
      {np, vp, 3},
      {np, nn, 3},
      // TODO(review): an NP/VB entry (-4.277, midpoint 3) was left unasserted (commented out) in
      // the original test; confirm whether it belongs in this vector.
    };
    final float[] probabilities04 = {
      -2.890f, -2.890f, -2.485f, -4.277f, -4.277f, -5.663f, -4.277f
    };
    for (int i = 0; i < expected04.length; i++) {
      final int child = pack(g, expected04[i][0], expected04[i][1]);
      assertEquals(
          "Wrong probability for cell 0,4 entry #" + i,
          probabilities04[i],
          cartesianProductVector.probability(child),
          .001f);
      assertEquals(
          "Wrong midpoint for cell 0,4 entry #" + i,
          expected04[i][2],
          cartesianProductVector.midpoint(child));
    }

    // Cross-product union for cell 0,5
    cartesianProductVector = p.cartesianProductUnion(0, 5);
    assertEquals(12, cartesianProductVector.size());

    final int[][] expected05 = {
      // Midpoint 3
      {np, vp, 3},
      {np, vpVb, 3},
      {np, np, 3},
      // Midpoint 4
      {np, rb, 4},
      // TODO(review): an NP/VB entry (-5.375, midpoint 4) was left unasserted (commented out) in
      // the original test; confirm whether it belongs in this vector.
    };
    final float[] probabilities05 = {-5.37528f, -6.474f, -6.474f, -4.682f};
    for (int i = 0; i < expected05.length; i++) {
      final int child = pack(g, expected05[i][0], expected05[i][1]);
      assertEquals(
          "Wrong probability for cell 0,5 entry #" + i,
          probabilities05[i],
          cartesianProductVector.probability(child),
          .001f);
      assertEquals(
          "Wrong midpoint for cell 0,5 entry #" + i,
          expected05[i][2],
          cartesianProductVector.midpoint(child));
    }
  }
  /**
   * Multiplies {@code cartesianProductVector} by the binary grammar rule matrix and leaves the
   * merged result in the temporary storage of {@code chartCell}.
   *
   * <p>One task per grammar thread is submitted to {@code threadPool}. Task {@code i} calls {@code
   * binarySpmvMultiply} with the segment bounds {@code binaryRowSegments[i]} and {@code
   * binaryRowSegments[i + 1]}, writing into its own entry of the thread-local temporary-cell array.
   * The per-task results are then merged position by position, keeping the entry with the higher
   * inside probability (together with its packed children and midpoint).
   *
   * <p>When the cell selector has cell constraints and reports this cell as only-factored, the
   * factored grammar matrix is used; otherwise the full binary grammar matrix is used.
   *
   * @param cartesianProductVector child-pair probabilities and midpoints for this cell
   * @param chartCell target cell; it is cast to {@link PackedArrayChartCell} unconditionally
   * @throws IllegalStateException if any grammar task fails; the cell's temporary storage would
   *     otherwise be left silently incomplete
   */
  @Override
  public final void binarySpmv(
      final CartesianProductVector cartesianProductVector, final ChartCell chartCell) {

    final Future<?>[] futures = new Future[grammarThreads];
    final PackedArrayChart.TemporaryChartCell[] temporaryCells =
        threadLocalTemporaryCellArrays.get();

    // The matrix choice depends only on the cell, not on the segment, so evaluate it once
    final boolean factoredOnly =
        cellSelector.hasCellConstraints()
            && cellSelector.isCellOnlyFactored(chartCell.start(), chartCell.end());

    // Iterate over binary grammar segments
    for (int i = 0; i < grammarThreads; i++) {
      final int segmentStart = binaryRowSegments[i];
      final int segmentEnd = binaryRowSegments[i + 1];
      final PackedArrayChart.TemporaryChartCell tmpCell = temporaryCells[i];

      if (factoredOnly) {
        futures[i] =
            threadPool.submit(
                new Runnable() {

                  @Override
                  public void run() {
                    tmpCell.clear();
                    // Multiply by the factored grammar rule matrix
                    binarySpmvMultiply(
                        cartesianProductVector,
                        grammar.factoredCscBinaryPopulatedColumns,
                        grammar.factoredCscBinaryPopulatedColumnOffsets,
                        grammar.factoredCscBinaryRowIndices,
                        grammar.factoredCscBinaryProbabilities,
                        tmpCell.packedChildren,
                        tmpCell.insideProbabilities,
                        tmpCell.midpoints,
                        segmentStart,
                        segmentEnd);
                  }
                });
      } else {
        futures[i] =
            threadPool.submit(
                new Runnable() {

                  @Override
                  public void run() {
                    tmpCell.clear();
                    // Multiply by the full grammar rule matrix
                    binarySpmvMultiply(
                        cartesianProductVector,
                        grammar.cscBinaryPopulatedColumns,
                        grammar.cscBinaryPopulatedColumnOffsets,
                        grammar.cscBinaryRowIndices,
                        grammar.cscBinaryProbabilities,
                        tmpCell.packedChildren,
                        tmpCell.insideProbabilities,
                        tmpCell.midpoints,
                        segmentStart,
                        segmentEnd);
                  }
                });
      }
    }

    final PackedArrayChartCell packedArrayCell = (PackedArrayChartCell) chartCell;
    packedArrayCell.allocateTemporaryStorage();
    try {
      // Wait for the first task to finish and use its arrays as the temporary cell storage
      futures[0].get();

      final TemporaryChartCell tmpCell = packedArrayCell.tmpCell;
      // TODO Eliminate this extra arraycopy
      final int arrayLength = temporaryCells[0].insideProbabilities.length;
      System.arraycopy(
          temporaryCells[0].insideProbabilities, 0, tmpCell.insideProbabilities, 0, arrayLength);
      System.arraycopy(temporaryCells[0].packedChildren, 0, tmpCell.packedChildren, 0, arrayLength);
      System.arraycopy(temporaryCells[0].midpoints, 0, tmpCell.midpoints, 0, arrayLength);

      // Wait for each other task to complete and merge results into the main temporary storage,
      // keeping the entry with the higher inside probability at each position
      for (int i = 1; i < grammarThreads; i++) {
        futures[i].get();
        final PackedArrayChart.TemporaryChartCell threadTmpCell = temporaryCells[i];
        for (int j = 0; j < arrayLength; j++) {
          if (threadTmpCell.insideProbabilities[j] > tmpCell.insideProbabilities[j]) {
            tmpCell.insideProbabilities[j] = threadTmpCell.insideProbabilities[j];
            tmpCell.packedChildren[j] = threadTmpCell.packedChildren[j];
            tmpCell.midpoints[j] = threadTmpCell.midpoints[j];
          }
        }
      }
    } catch (final InterruptedException e) {
      // Restore the interrupt status so callers (and the thread pool) can observe it; the cell's
      // temporary storage may be incomplete in this case.
      Thread.currentThread().interrupt();
    } catch (final ExecutionException e) {
      // A failed grammar task leaves this cell only partially computed; surface the failure
      // instead of silently continuing with incorrect results.
      throw new IllegalStateException(
          "Binary SpMV task failed for cell " + chartCell.start() + "," + chartCell.end(),
          e.getCause());
    }
  }