/** * Tests the binary SpMV multiplication of the cartesian-product computed in {@link * #testUnfilteredCartesianProductVectorSimpleGrammar2()} with simple grammar 2. * * @throws Exception if something bad happens */ @Test public void testBinarySpMVMultiplySimpleGrammar2() throws Exception { // Create the parser final SparseMatrixGrammar g = (SparseMatrixGrammar) simpleGrammar2; final SparseMatrixVectorParser<?, ?> p = createParser(g, parserOptions(), configProperties()); final ParseTask parseTask = new ParseTask( "The fish market stands last", Parser.InputFormat.Text, g, DecodeMethod.ViterbiMax); p.initSentence(parseTask); final Chart chart = p.chart; final float[] probabilities = new float[g.packingFunction().packedArraySize()]; Arrays.fill(probabilities, Float.NEGATIVE_INFINITY); final short[] midpoints = new short[g.packingFunction().packedArraySize()]; populateSimpleGrammar2Rows1_3(chart, g); // // Test SpMV for cell 0,4 // // Midpoint 1 probabilities[pack(g, g.mapNonterminal("DT"), g.mapNonterminal("S"))] = -2.890f; midpoints[pack(g, g.mapNonterminal("DT"), g.mapNonterminal("S"))] = 1; probabilities[pack(g, g.mapNonterminal("DT"), g.mapNonterminal("NP"))] = -2.890f; midpoints[pack(g, g.mapNonterminal("DT"), g.mapNonterminal("NP"))] = 1; // Midpoint 2 probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP|NN"))] = -2.485f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP|NN"))] = 2; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))] = -4.277f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))] = 2; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))] = -4.277f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))] = 2; // Midpoint 3 probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))] = -5.663f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))] = 3; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NN"))] = -4.277f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NN"))] = 3; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))] = -4.277f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))] = 3; CartesianProductVector crossProductVector = new CartesianProductVector(g, probabilities, midpoints, 8); // Check the SpMV multiplication final ChartCell cell_0_4 = p.chart.getCell(0, 4); p.binarySpmv(crossProductVector, cell_0_4); assertEquals(2, cell_0_4.getNumNTs()); final ChartEdge np = cell_0_4.getBestEdge(g.mapNonterminal("NP")); assertEquals(-4.27667, np.inside(), .001f); assertEquals("Wrong left child cell", chart.getCell(0, 1), np.leftCell); assertEquals("Wrong right child cell", chart.getCell(1, 4), np.rightCell); ChartEdge s = cell_0_4.getBestEdge(g.mapNonterminal("S")); assertEquals(-5.66296f, s.inside(), .001f); assertEquals("Wrong left child cell", chart.getCell(0, 3), s.leftCell); assertEquals("Wrong right child cell", chart.getCell(3, 4), s.rightCell); // // Test SpMV for cell 0,5 // populateSimpleGrammar2Rows1_3(chart, g); populateSimpleGrammar2Row4(chart, g); Arrays.fill(probabilities, Float.NEGATIVE_INFINITY); // Midpoint 3 probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))] = -5.37528f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP"))] = 3; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))] = -6.474f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VP|VB"))] = 3; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))] = -6.474f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("NP"))] = 3; // Midpoint 4 probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("RB"))] = -4.682f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("RB"))] = 4; probabilities[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))] = -5.375f; midpoints[pack(g, g.mapNonterminal("NP"), g.mapNonterminal("VB"))] = 4; crossProductVector = new CartesianProductVector(g, probabilities, midpoints, 8); // Check the SpMV multiplication final ChartCell cell_0_5 = p.chart.getCell(0, 5); p.binarySpmv(crossProductVector, cell_0_5); assertEquals(1, cell_0_5.getNumNTs()); s = cell_0_5.getBestEdge(g.mapNonterminal("S")); assertEquals(-5.37528f, s.inside(), .001f); assertEquals("Wrong left child cell", chart.getCell(0, 3), s.leftCell); assertEquals("Wrong right child cell", chart.getCell(3, 5), s.rightCell); }