@SuppressWarnings("serial") public static void main(String[] args) throws Exception { if (!parseParameters(args)) { return; } ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); DataSet<Edge<Long, NullValue>> edges = getEdgesDataSet(env); Graph<Long, Long, NullValue> graph = Graph.fromDataSet( edges, new MapFunction<Long, Long>() { @Override public Long map(Long value) throws Exception { return value; } }, env); DataSet<Vertex<Long, Long>> verticesWithMinIds = graph.run(new GSAConnectedComponents<Long, Long, NullValue>(maxIterations)); // emit result if (fileOutput) { verticesWithMinIds.writeAsCsv(outputPath, "\n", ","); // since file sinks are lazy, we trigger the execution explicitly env.execute("Connected Components Example"); } else { verticesWithMinIds.print(); } }
/** {@inheritDoc} */
@Override
protected LogicalGraph<VD, ED, GD> executeInternal(
    LogicalGraph<VD, ED, GD> firstGraph,
    LogicalGraph<VD, ED, GD> secondGraph) {

    final Long newGraphID = FlinkConstants.EXCLUDE_GRAPH_ID;

    Graph<Long, VD, ED> graph1 = firstGraph.getGellyGraph();
    Graph<Long, VD, ED> graph2 = secondGraph.getGellyGraph();

    // Union both vertex sets and group them by vertex id. Keep a vertex only
    // if its group contains exactly one element and that element belongs to
    // the graph the operator is called on (the first graph).
    DataSet<Vertex<Long, VD>> newVertexSet = graph1.getVertices()
        .union(graph2.getVertices())
        .groupBy(new KeySelectors.VertexKeySelector<VD>())
        .reduceGroup(new VertexGroupReducer<VD>(1L, firstGraph.getId(), secondGraph.getId()))
        .map(new VertexToGraphUpdater<VD>(newGraphID));

    // The join function returns the edge unchanged; the join itself acts as a filter.
    JoinFunction<Edge<Long, ED>, Vertex<Long, VD>, Edge<Long, ED>> joinFunc =
        new JoinFunction<Edge<Long, ED>, Vertex<Long, VD>, Edge<Long, ED>>() {
            @Override
            public Edge<Long, ED> join(Edge<Long, ED> leftTuple, Vertex<Long, VD> rightTuple) throws Exception {
                return leftTuple;
            }
        };

    // In exclude(), we are only interested in edges whose endpoints both
    // survive the vertex exclusion. Thus, we join the edges of the first
    // graph with the new vertex set on source and then on target ids.
    DataSet<Edge<Long, ED>> newEdgeSet = graph1.getEdges()
        .join(newVertexSet)
        .where(new KeySelectors.EdgeSourceVertexKeySelector<ED>())
        .equalTo(new KeySelectors.VertexKeySelector<VD>())
        .with(joinFunc)
        .join(newVertexSet)
        .where(new KeySelectors.EdgeTargetVertexKeySelector<ED>())
        .equalTo(new KeySelectors.VertexKeySelector<VD>())
        .with(joinFunc)
        .map(new EdgeToGraphUpdater<ED>(newGraphID));

    return LogicalGraph.fromGraph(
        Graph.fromDataSet(newVertexSet, newEdgeSet, graph1.getContext()),
        firstGraph.getGraphDataFactory().createGraphData(newGraphID),
        firstGraph.getVertexDataFactory(),
        firstGraph.getEdgeDataFactory(),
        firstGraph.getGraphDataFactory());
}
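// For context: this method implements a binary EXCLUDE operator on logical
// graphs. A hypothetical call site is sketched below; 'exclude' as the public
// entry point is an assumption based on the operator's semantics and is not
// confirmed by the excerpt.
//
//     LogicalGraph<VD, ED, GD> result = firstGraph.exclude(secondGraph);
//
// The result contains exactly the vertices of firstGraph that do not occur in
// secondGraph, plus the edges of firstGraph connecting those vertices.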
@Test
public void testTranslation() {
    try {
        final String ITERATION_NAME = "Test Name";
        final String AGGREGATOR_NAME = "AggregatorName";
        final String BC_SET_GATHER_NAME = "gather messages";
        final String BC_SET_SUM_NAME = "sum updates";
        final String BC_SET_APPLY_NAME = "apply updates";
        final int NUM_ITERATIONS = 13;
        final int ITERATION_PARALLELISM = 77;

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Long> bcGather = env.fromElements(1L);
        DataSet<Long> bcSum = env.fromElements(1L);
        DataSet<Long> bcApply = env.fromElements(1L);

        DataSet<Vertex<Long, Long>> result;

        // ------------ construct the test program ------------------
        {
            DataSet<Edge<Long, NullValue>> edges = env
                .fromElements(new Tuple3<Long, Long, NullValue>(1L, 2L, NullValue.getInstance()))
                .map(new Tuple3ToEdgeMap<Long, NullValue>());

            Graph<Long, Long, NullValue> graph = Graph.fromDataSet(edges, new InitVertices(), env);

            GSAConfiguration parameters = new GSAConfiguration();
            parameters.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
            parameters.setName(ITERATION_NAME);
            parameters.setParallelism(ITERATION_PARALLELISM);
            parameters.addBroadcastSetForGatherFunction(BC_SET_GATHER_NAME, bcGather);
            parameters.addBroadcastSetForSumFunction(BC_SET_SUM_NAME, bcSum);
            parameters.addBroadcastSetForApplyFunction(BC_SET_APPLY_NAME, bcApply);

            result = graph.runGatherSumApplyIteration(
                    new GatherNeighborIds(), new SelectMinId(), new UpdateComponentId(),
                    NUM_ITERATIONS, parameters)
                .getVertices();

            result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
        }

        // ------------- validate the java program ----------------

        assertTrue(result instanceof DeltaIterationResultSet);

        DeltaIterationResultSet<?, ?> resultSet = (DeltaIterationResultSet<?, ?>) result;
        DeltaIteration<?, ?> iteration = (DeltaIteration<?, ?>) resultSet.getIterationHead();

        // check the basic iteration properties
        assertEquals(NUM_ITERATIONS, resultSet.getMaxIterations());
        assertArrayEquals(new int[] {0}, resultSet.getKeyPositions());
        assertEquals(ITERATION_PARALLELISM, iteration.getParallelism());
        assertEquals(ITERATION_NAME, iteration.getName());
        assertEquals(AGGREGATOR_NAME,
            iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());

        // validate that the semantic properties are set as they should be
        TwoInputUdfOperator<?, ?, ?, ?> solutionSetJoin =
            (TwoInputUdfOperator<?, ?, ?, ?>) resultSet.getNextWorkset();
        assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(0, 0).contains(0));
        assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(1, 0).contains(0));

        SingleInputUdfOperator<?, ?, ?> sumReduce =
            (SingleInputUdfOperator<?, ?, ?>) solutionSetJoin.getInput1();
        SingleInputUdfOperator<?, ?, ?> gatherMap =
            (SingleInputUdfOperator<?, ?, ?>) sumReduce.getInput();

        // validate that the broadcast sets are forwarded
        assertEquals(bcGather, gatherMap.getBroadcastSets().get(BC_SET_GATHER_NAME));
        assertEquals(bcSum, sumReduce.getBroadcastSets().get(BC_SET_SUM_NAME));
        assertEquals(bcApply, solutionSetJoin.getBroadcastSets().get(BC_SET_APPLY_NAME));
    } catch (Exception e) {
        System.err.println(e.getMessage());
        e.printStackTrace();
        fail(e.getMessage());
    }
}
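// The test above uses four small UDF classes that are not shown in the
// excerpt. Below is a minimal sketch of what they could look like, assuming
// the Gelly GSA API (GatherFunction / SumFunction / ApplyFunction); the
// connected-components bodies are illustrative assumptions, since the test
// only checks plan translation and never executes them.

@SuppressWarnings("serial")
private static final class InitVertices implements MapFunction<Long, Long> {
    @Override
    public Long map(Long vertexId) {
        // each vertex starts with its own id as its component id
        return vertexId;
    }
}

@SuppressWarnings("serial")
private static final class GatherNeighborIds extends GatherFunction<Long, NullValue, Long> {
    @Override
    public Long gather(Neighbor<Long, NullValue> neighbor) {
        // propagate the neighbor's current component id
        return neighbor.getNeighborValue();
    }
}

@SuppressWarnings("serial")
private static final class SelectMinId extends SumFunction<Long, NullValue, Long> {
    @Override
    public Long sum(Long newValue, Long currentValue) {
        // keep the smallest component id seen so far
        return Math.min(newValue, currentValue);
    }
}

@SuppressWarnings("serial")
private static final class UpdateComponentId extends ApplyFunction<Long, Long, Long> {
    @Override
    public void apply(Long summedValue, Long origValue) {
        // only update (and re-activate) the vertex if its component id decreased
        if (summedValue < origValue) {
            setResult(summedValue);
        }
    }
}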