/** * Tests the binary SpMV multiplication of the cartesian-product computed in {@link * #testUnfilteredCartesianProductVectorSimpleGrammar2()} with simple grammar 2. * * @throws Exception if something bad happens */ @Test public void testBinarySpMVMultiplySimpleGrammar2() throws Exception { // Create the parser final SparseMatrixGrammar g = (SparseMatrixGrammar) simpleGrammar2; final SparseMatrixVectorParser<?, ?> p = createParser(g, parserOptions(), configProperties()); final ParseTask parseTask = new ParseTask( "The fish market stands last", Parser.InputFormat.Text, g, DecodeMethod.ViterbiMax); p.initSentence(parseTask); final Chart chart = p.chart; final float[] probabilities = new float[g.packingFunction().packedArraySize()]; Arrays.fill(probabilities, Float.NEGATIVE_INFINITY); final short[] midpoints = new short[g.packingFunction().packedArraySize()]; populateSimpleGrammar2Rows1_3(chart, g); // // Test SpMV for cell 0,4 // // Midpoint 1 probabilities[pack(g, g.mapNonterminal("DT"), g.mapNonterminal("S"))] = -2.890f; midpoints[pack(g, g.mapNonterminal("DT"), g.mapNonterminal("S"))] = 1; probabilities[pack(g, g.mapNonterminal("DT"), g.mapNonterminal("NP"))] = -2.890f; midpoints[pack(g, g.mapNonterminal("DT"), g.mapNonterminal("NP"))] = 1; // Midpoint 2 probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP|NN"))] = -2.485f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP|NN"))] = 2; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))] = -4.277f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))] = 2; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))] = -4.277f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))] = 2; // Midpoint 3 probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))] = -5.663f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))] = 3; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NN"))] = -4.277f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NN"))] = 3; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))] = -4.277f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))] = 3; CartesianProductVector crossProductVector = new CartesianProductVector(g, probabilities, midpoints, 8); // Check the SpMV multiplication final ChartCell cell_0_4 = p.chart.getCell(0, 4); p.binarySpmv(crossProductVector, cell_0_4); assertEquals(2, cell_0_4.getNumNTs()); final ChartEdge np = cell_0_4.getBestEdge(g.mapNonterminal("NP")); assertEquals(-4.27667, np.inside(), .001f); assertEquals("Wrong left child cell", chart.getCell(0, 1), np.leftCell); assertEquals("Wrong right child cell", chart.getCell(1, 4), np.rightCell); ChartEdge s = cell_0_4.getBestEdge(g.mapNonterminal("S")); assertEquals(-5.66296f, s.inside(), .001f); assertEquals("Wrong left child cell", chart.getCell(0, 3), s.leftCell); assertEquals("Wrong right child cell", chart.getCell(3, 4), s.rightCell); // // Test SpMV for cell 0,5 // populateSimpleGrammar2Rows1_3(chart, g); populateSimpleGrammar2Row4(chart, g); Arrays.fill(probabilities, Float.NEGATIVE_INFINITY); // Midpoint 3 probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))] = -5.37528f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))] = 3; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))] = -6.474f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))] = 3; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))] = -6.474f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))] = 3; // Midpoint 4 probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("RB"))] = -4.682f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("RB"))] = 4; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))] = -5.375f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))] = 4; crossProductVector = new CartesianProductVector(g, probabilities, midpoints, 8); // Check the SpMV multiplication final ChartCell cell_0_5 = p.chart.getCell(0, 5); p.binarySpmv(crossProductVector, cell_0_5); assertEquals(1, cell_0_5.getNumNTs()); s = cell_0_5.getBestEdge(g.mapNonterminal("S")); assertEquals(-5.37528f, s.inside(), .001f); assertEquals("Wrong left child cell", chart.getCell(0, 3), s.leftCell); assertEquals("Wrong right child cell", chart.getCell(3, 5), s.rightCell); }
/** * Tests an imagined example cartesian-product vector (based very loosely on the computation of * the top cell in the 'systems analyst arbitration chef' example) * * @throws Exception if something bad happens */ @Test public void testCartesianProductVectorExample() throws Exception { // Create the parser final SparseMatrixGrammar g = (SparseMatrixGrammar) simpleGrammar1; final P p = createParser(g, parserOptions(), configProperties()); final ParseTask parseTask = new ParseTask( "systems analyst arbitration chef", Parser.InputFormat.Text, g, DecodeMethod.ViterbiMax); p.initSentence(parseTask); final Chart chart = p.chart; final int nn = g.mapNonterminal("NN"); final int np = g.mapNonterminal("NP"); // Cell 0,1 contains NN (-2) // Cell 1,4 contains NN (-3), NP (-4) // So: 0,1 X 1,4 cross-product = NN/NN (-5,1), NN/NP (-6,1) final ChartCell cell_0_1 = chart.getCell(0, 1); cell_0_1.updateInside(new Production("NN", "NN", -2, false, g), cell_0_1, null, -2f); cell_0_1.finalizeCell(); final ChartCell cell_1_3 = chart.getCell(1, 3); final ChartCell cell_1_4 = chart.getCell(1, 4); cell_1_4.updateInside(new Production("NN", "NN", -3f, false, g), cell_1_3, null, -3f); cell_1_4.updateInside(new Production("NP", "NP", -4f, false, g), cell_1_3, null, -4f); cell_1_4.finalizeCell(); // Cell 0,2 contains NN (-2), NP (-3) // Cell 2,4 contains NN (-4), NP (-4) // So: 0,2 X 2,4 cross-product = NN/NN (-6,2), NN/NP (-6,2), NP/NN (-7,2), NP/NP (-7,2) final ChartCell cell_0_2 = chart.getCell(0, 2); cell_0_2.updateInside( new Production("NN", "NN", -2f, false, g), chart.getCell(0, 1), null, -2f); cell_0_2.updateInside( new Production("NP", "NP", -3f, false, g), chart.getCell(0, 1), null, -3f); cell_0_2.finalizeCell(); final ChartCell cell_2_4 = chart.getCell(2, 4); cell_2_4.updateInside( new Production("NN", "NN", -4f, false, g), chart.getCell(2, 3), null, -4f); cell_2_4.updateInside( new Production("NP", "NP", -4f, false, g), chart.getCell(2, 3), null, -4f); cell_2_4.finalizeCell(); // Cell 0,3 contains NP (-2) // Cell 3,4 contains NP (-2) // So: 0,3 X 3,4 cross-product = NP/NP (-4,3) final ChartCell cell_0_3 = chart.getCell(0, 3); cell_0_3.updateInside(new Production("NP", "NP", -2, false, g), chart.getCell(0, 2), null, -2f); cell_0_3.finalizeCell(); final ChartCell cell_3_4 = chart.getCell(3, 4); cell_3_4.updateInside( new Production("NP", "NP", -2f, false, g), chart.getCell(3, 4), null, -2f); cell_3_4.finalizeCell(); // So: 0,1 X 1,4 cross-product = NN/NN (-5,1), NN/NP (-6,1) // So: 0,2 X 2,4 cross-product = NN/NN (-6,2), NN/NP (-6,2), NP/NN (-7,2), NP/NP (-7,2) // So: 0,3 X 3,4 cross-product = NP/NP (-4,3) // Cross-product union should be NN/NN (-5,1), NN/NP (-6,1), NP/NN (-7,2), NP/NP (-4,3) final SparseMatrixVectorParser.CartesianProductVector crossProductVector = p.cartesianProductUnion(0, 4); final int[] expectedChildren = new int[] {pack(g, nn, nn), pack(g, nn, np), pack(g, np, nn), pack(g, np, np)}; final float[] expectedProbabilities = new float[] {-5f, -6f, -7f, -4f}; final int[] expectedMidpoints = new int[] {1, 1, 2, 3}; for (int i = 0; i < expectedChildren.length; i++) { assertEquals( "Wrong probability #" + i, expectedProbabilities[i], crossProductVector.probability(expectedChildren[i]), .01f); assertEquals( "Wrong midpoint #" + i, expectedMidpoints[i], crossProductVector.midpoint(expectedChildren[i])); } }
/** * Tests the cartesian-product vector computed in the top cells of the 'The fish market stands * last' example. * * @throws Exception if something bad happens */ @Test public void testUnfilteredCartesianProductVectorSimpleGrammar2() throws Exception { final SparseMatrixGrammar g = (SparseMatrixGrammar) createGrammar(simpleGrammar2(), LeftShiftFunction.class); // Create the parser final P p = createParser(g, parserOptions(), configProperties()); final ParseTask parseTask = new ParseTask( "The fish market stands last", Parser.InputFormat.Text, g, DecodeMethod.ViterbiMax); p.initSentence(parseTask); final Chart chart = p.chart; populateSimpleGrammar2Rows1_3(chart, g); populateSimpleGrammar2Row4(chart, g); // Row of span 5 final ChartCell cell_0_5 = chart.getCell(0, 5); cell_0_5.updateInside( new Production("S", "NP", "VP", -5.37528f, simpleGrammar2), chart.getCell(0, 3), chart.getCell(3, 5), -5.37528f); cell_0_5.updateInside( new Production("TOP", "S", -5.37528f, false, simpleGrammar2), cell_0_5, null, -5.37528f); // Finalize all chart cells for (int i = 0; i < chart.size(); i++) { for (int j = i + 1; j <= chart.size(); j++) { chart.getCell(i, j).finalizeCell(); } } // Cross-product union for cell 0,4 SparseMatrixVectorParser.CartesianProductVector cartesianProductVector = p.cartesianProductUnion(0, 4); assertEquals(14, cartesianProductVector.size()); // Midpoint 1 assertEquals( -2.890f, cartesianProductVector.probability( pack(g, g.mapNonterminal("DT"), g.mapNonterminal("VP|VB"))), .001f); assertEquals( 1, cartesianProductVector.midpoint( pack(g, g.mapNonterminal("DT"), g.mapNonterminal("VP|VB")))); assertEquals( -2.890f, cartesianProductVector.probability(pack(g, g.mapNonterminal("DT"), g.mapNonterminal("NP"))), .001f); assertEquals( 1, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("DT"), g.mapNonterminal("NP")))); // Midpoint 2 assertEquals( -2.485f, cartesianProductVector.probability( pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP|NN"))), .001f); assertEquals( 2, cartesianProductVector.midpoint( pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP|NN")))); assertEquals( -4.277f, cartesianProductVector.probability( pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))), .001f); assertEquals( 2, cartesianProductVector.midpoint( pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB")))); assertEquals( -4.277f, cartesianProductVector.probability(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))), .001f); assertEquals( 2, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP")))); // Midpoint 3 assertEquals( -5.663f, cartesianProductVector.probability(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))), .001f); assertEquals( 3, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP")))); assertEquals( -4.277f, cartesianProductVector.probability(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NN"))), .001f); assertEquals( 3, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NN")))); // assertEquals(-4.277f, // cartesianProductVector.probability(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))), // .001f); // assertEquals(3, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("NP"), // g.mapNonterminal("VB")))); // Cross-product union for cell 0,5 cartesianProductVector = p.cartesianProductUnion(0, 5); assertEquals(12, cartesianProductVector.size()); // Midpoint 3 assertEquals( -5.37528f, cartesianProductVector.probability(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))), .001f); assertEquals( 3, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP")))); assertEquals( -6.474f, cartesianProductVector.probability( pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))), .001f); assertEquals( 3, cartesianProductVector.midpoint( pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB")))); assertEquals( -6.474f, cartesianProductVector.probability(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))), .001f); assertEquals( 3, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP")))); // Midpoint 4 assertEquals( -4.682f, cartesianProductVector.probability(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("RB"))), .001f); assertEquals( 4, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("RB")))); // assertEquals(-5.375f, // cartesianProductVector.probability(pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))), // .001f); // assertEquals(4, cartesianProductVector.midpoint(pack(g, g.mapNonterminal("NP"), // g.mapNonterminal("VB")))); }
@Override public final void binarySpmv( final CartesianProductVector cartesianProductVector, final ChartCell chartCell) { final Future<?>[] futures = new Future[grammarThreads]; final PackedArrayChart.TemporaryChartCell[] temporaryCells = threadLocalTemporaryCellArrays.get(); // Iterate over binary grammar segments for (int i = 0; i < grammarThreads; i++) { final int segmentStart = binaryRowSegments[i]; final int segmentEnd = binaryRowSegments[i + 1]; final PackedArrayChart.TemporaryChartCell tmpCell = temporaryCells[i]; if (cellSelector.hasCellConstraints() && cellSelector.isCellOnlyFactored(chartCell.start(), chartCell.end())) { futures[i] = threadPool.submit( new Runnable() { @Override public void run() { tmpCell.clear(); // Multiply by the factored grammar rule matrix binarySpmvMultiply( cartesianProductVector, grammar.factoredCscBinaryPopulatedColumns, grammar.factoredCscBinaryPopulatedColumnOffsets, grammar.factoredCscBinaryRowIndices, grammar.factoredCscBinaryProbabilities, tmpCell.packedChildren, tmpCell.insideProbabilities, tmpCell.midpoints, segmentStart, segmentEnd); } }); } else { futures[i] = threadPool.submit( new Runnable() { @Override public void run() { tmpCell.clear(); // Multiply by the full grammar rule matrix binarySpmvMultiply( cartesianProductVector, grammar.cscBinaryPopulatedColumns, grammar.cscBinaryPopulatedColumnOffsets, grammar.cscBinaryRowIndices, grammar.cscBinaryProbabilities, tmpCell.packedChildren, tmpCell.insideProbabilities, tmpCell.midpoints, segmentStart, segmentEnd); } }); } } final PackedArrayChartCell packedArrayCell = (PackedArrayChartCell) chartCell; packedArrayCell.allocateTemporaryStorage(); try { // Wait for the first task to finish and use its arrays as the temporary cell storage futures[0].get(); final TemporaryChartCell tmpCell = packedArrayCell.tmpCell; // TODO Eliminate this extra arraycopy final int arrayLength = temporaryCells[0].insideProbabilities.length; System.arraycopy( temporaryCells[0].insideProbabilities, 0, tmpCell.insideProbabilities, 0, arrayLength); System.arraycopy(temporaryCells[0].packedChildren, 0, tmpCell.packedChildren, 0, arrayLength); System.arraycopy(temporaryCells[0].midpoints, 0, tmpCell.midpoints, 0, arrayLength); // Wait for each other task to complete and merge results into the main temporary storage for (int i = 1; i < grammarThreads; i++) { futures[i].get(); final PackedArrayChart.TemporaryChartCell threadTmpCell = temporaryCells[i]; for (int j = 0; j < arrayLength; j++) { if (threadTmpCell.insideProbabilities[j] > tmpCell.insideProbabilities[j]) { tmpCell.insideProbabilities[j] = threadTmpCell.insideProbabilities[j]; tmpCell.packedChildren[j] = threadTmpCell.packedChildren[j]; tmpCell.midpoints[j] = threadTmpCell.midpoints[j]; } } } } catch (final InterruptedException e) { e.printStackTrace(); } catch (final ExecutionException e) { e.printStackTrace(); } }