@Override
public ConnectorPageSource createPageSource(Session session, Split split, List<ColumnHandle> columns)
{
    // Only splits produced by FunctionAssertions.TestSplit are accepted here.
    assertInstanceOf(split.getConnectorSplit(), FunctionAssertions.TestSplit.class);
    FunctionAssertions.TestSplit testSplit = (FunctionAssertions.TestSplit) split.getConnectorSplit();

    // Page-backed split: hand back the canned SOURCE_PAGE unchanged.
    if (!testSplit.isRecordSet()) {
        return new FixedPageSource(ImmutableList.of(SOURCE_PAGE));
    }

    // Record-backed split: build one in-memory row with the same column types and values as SOURCE_PAGE.
    ImmutableList<Type> columnTypes = ImmutableList.<Type>of(
            BIGINT,
            VARCHAR,
            DOUBLE,
            BOOLEAN,
            BIGINT,
            VARCHAR,
            VARCHAR,
            TIMESTAMP_WITH_TIME_ZONE);
    long timestampMillis = new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis();
    long packedTimestampWithZone = packDateTimeWithZone(
            new DateTime(1970, 1, 1, 0, 1, 0, 999, DateTimeZone.UTC).getMillis(),
            TimeZoneKey.getTimeZoneKey("Z"));

    RecordSet rowSource = InMemoryRecordSet.builder(columnTypes)
            .addRow(
                    1234L,
                    "hello",
                    12.34,
                    true,
                    timestampMillis,
                    "%el%",
                    null,
                    packedTimestampWithZone)
            .build();
    return new RecordPageSource(rowSource);
}
public final class FunctionAssertions { private static final ExecutorService EXECUTOR = newCachedThreadPool(daemonThreadsNamed("test-%s")); private static final SqlParser SQL_PARSER = new SqlParser(); private static final Page SOURCE_PAGE = new Page( createLongsBlock(1234L), createStringsBlock("hello"), createDoublesBlock(12.34), createBooleansBlock(true), createLongsBlock(new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis()), createStringsBlock("%el%"), createStringsBlock((String) null), createTimestampsWithTimezoneBlock( packDateTimeWithZone( new DateTime(1970, 1, 1, 0, 1, 0, 999, DateTimeZone.UTC).getMillis(), TimeZoneKey.getTimeZoneKey("Z")))); private static final Page ZERO_CHANNEL_PAGE = new Page(1); private static final Map<Integer, Type> INPUT_TYPES = ImmutableMap.<Integer, Type>builder() .put(0, BIGINT) .put(1, VARCHAR) .put(2, DOUBLE) .put(3, BOOLEAN) .put(4, BIGINT) .put(5, VARCHAR) .put(6, VARCHAR) .put(7, TIMESTAMP_WITH_TIME_ZONE) .build(); private static final Map<Symbol, Integer> INPUT_MAPPING = ImmutableMap.<Symbol, Integer>builder() .put(new Symbol("bound_long"), 0) .put(new Symbol("bound_string"), 1) .put(new Symbol("bound_double"), 2) .put(new Symbol("bound_boolean"), 3) .put(new Symbol("bound_timestamp"), 4) .put(new Symbol("bound_pattern"), 5) .put(new Symbol("bound_null_string"), 6) .put(new Symbol("bound_timestamp_with_timezone"), 7) .build(); private static final Map<Symbol, Type> SYMBOL_TYPES = ImmutableMap.<Symbol, Type>builder() .put(new Symbol("bound_long"), BIGINT) .put(new Symbol("bound_string"), VARCHAR) .put(new Symbol("bound_double"), DOUBLE) .put(new Symbol("bound_boolean"), BOOLEAN) .put(new Symbol("bound_timestamp"), BIGINT) .put(new Symbol("bound_pattern"), VARCHAR) .put(new Symbol("bound_null_string"), VARCHAR) .put(new Symbol("bound_timestamp_with_timezone"), TIMESTAMP_WITH_TIME_ZONE) .build(); private static final PageSourceProvider PAGE_SOURCE_PROVIDER = new TestPageSourceProvider(); private static final 
PlanNodeId SOURCE_ID = new PlanNodeId("scan"); private final Session session; private final LocalQueryRunner runner; private final Metadata metadata; private final ExpressionCompiler compiler; public FunctionAssertions() { this(TEST_SESSION); } public FunctionAssertions(Session session) { this.session = requireNonNull(session, "session is null"); runner = new LocalQueryRunner(session); metadata = runner.getMetadata(); compiler = new ExpressionCompiler(metadata); } public Metadata getMetadata() { return metadata; } public FunctionAssertions addFunctions(List<SqlFunction> functionInfos) { metadata.addFunctions(functionInfos); return this; } public FunctionAssertions addScalarFunctions(Class<?> clazz) { metadata.addFunctions( new FunctionListBuilder(metadata.getTypeManager()).scalar(clazz).getFunctions()); return this; } public void assertFunction(String projection, Type expectedType, Object expected) { if (expected instanceof Integer) { expected = ((Integer) expected).longValue(); } else if (expected instanceof Slice) { expected = ((Slice) expected).toString(UTF_8); } Object actual = selectSingleValue(projection, expectedType, compiler); try { assertEquals(actual, expected); } catch (Throwable e) { throw e; } } public void tryEvaluate(String expression, Type expectedType) { tryEvaluate(expression, expectedType, session); } public void tryEvaluate(String expression, Type expectedType, Session session) { selectUniqueValue(expression, expectedType, session, compiler); } public void tryEvaluateWithAll(String expression, Type expectedType, Session session) { executeProjectionWithAll(expression, expectedType, session, compiler); } private Object selectSingleValue( String projection, Type expectedType, ExpressionCompiler compiler) { return selectUniqueValue(projection, expectedType, session, compiler); } private Object selectUniqueValue( String projection, Type expectedType, Session session, ExpressionCompiler compiler) { List<Object> results = 
executeProjectionWithAll(projection, expectedType, session, compiler); HashSet<Object> resultSet = new HashSet<>(results); // we should only have a single result assertTrue( resultSet.size() == 1, "Expected only one result unique result, but got " + resultSet); return Iterables.getOnlyElement(resultSet); } private List<Object> executeProjectionWithAll( String projection, Type expectedType, Session session, ExpressionCompiler compiler) { requireNonNull(projection, "projection is null"); Expression projectionExpression = createExpression(projection, metadata, SYMBOL_TYPES); List<Object> results = new ArrayList<>(); // // If the projection does not need bound values, execute query using full engine if (!needsBoundValue(projectionExpression)) { MaterializedResult result = runner.execute("SELECT " + projection); assertType(result.getTypes(), expectedType); assertEquals(result.getTypes().size(), 1); assertEquals(result.getMaterializedRows().size(), 1); Object queryResult = Iterables.getOnlyElement(result.getMaterializedRows()).getField(0); results.add(queryResult); } // execute as standalone operator OperatorFactory operatorFactory = compileFilterProject(TRUE_LITERAL, projectionExpression, compiler); assertType(operatorFactory.getTypes(), expectedType); Object directOperatorValue = selectSingleValue(operatorFactory, session); results.add(directOperatorValue); // interpret Operator interpretedFilterProject = interpretedFilterProject(TRUE_LITERAL, projectionExpression, session); assertType(interpretedFilterProject.getTypes(), expectedType); Object interpretedValue = selectSingleValue(interpretedFilterProject); results.add(interpretedValue); // execute over normal operator SourceOperatorFactory scanProjectOperatorFactory = compileScanFilterProject(TRUE_LITERAL, projectionExpression, compiler); assertType(scanProjectOperatorFactory.getTypes(), expectedType); Object scanOperatorValue = selectSingleValue(scanProjectOperatorFactory, createNormalSplit(), session); 
results.add(scanOperatorValue); // execute over record set Object recordValue = selectSingleValue(scanProjectOperatorFactory, createRecordSetSplit(), session); results.add(recordValue); // // If the projection does not need bound values, execute query using full engine if (!needsBoundValue(projectionExpression)) { MaterializedResult result = runner.execute("SELECT " + projection); assertType(result.getTypes(), expectedType); assertEquals(result.getTypes().size(), 1); assertEquals(result.getMaterializedRows().size(), 1); Object queryResult = Iterables.getOnlyElement(result.getMaterializedRows()).getField(0); results.add(queryResult); } return results; } private Object selectSingleValue(OperatorFactory operatorFactory, Session session) { Operator operator = operatorFactory.createOperator(createDriverContext(session)); return selectSingleValue(operator); } private Object selectSingleValue( SourceOperatorFactory operatorFactory, Split split, Session session) { SourceOperator operator = operatorFactory.createOperator(createDriverContext(session)); operator.addSplit(split); operator.noMoreSplits(); return selectSingleValue(operator); } private Object selectSingleValue(Operator operator) { Page output = getAtMostOnePage(operator, SOURCE_PAGE); assertNotNull(output); assertEquals(output.getPositionCount(), 1); assertEquals(output.getChannelCount(), 1); Type type = operator.getTypes().get(0); Block block = output.getBlock(0); assertEquals(block.getPositionCount(), 1); return type.getObjectValue(session.toConnectorSession(), block, 0); } public void assertFilter(String filter, boolean expected, boolean withNoInputColumns) { assertFilter(filter, expected, withNoInputColumns, compiler); } private void assertFilter( String filter, boolean expected, boolean withNoInputColumns, ExpressionCompiler compiler) { List<Boolean> results = executeFilterWithAll(filter, TEST_SESSION, withNoInputColumns, compiler); HashSet<Boolean> resultSet = new HashSet<>(results); // we should only have 
a single result assertTrue( resultSet.size() == 1, "Expected only [" + expected + "] result unique result, but got " + resultSet); assertEquals((boolean) Iterables.getOnlyElement(resultSet), expected); } private List<Boolean> executeFilterWithAll( String filter, Session session, boolean executeWithNoInputColumns, ExpressionCompiler compiler) { requireNonNull(filter, "filter is null"); Expression filterExpression = createExpression(filter, metadata, SYMBOL_TYPES); List<Boolean> results = new ArrayList<>(); // execute as standalone operator OperatorFactory operatorFactory = compileFilterProject(filterExpression, TRUE_LITERAL, compiler); results.add(executeFilter(operatorFactory, session)); if (executeWithNoInputColumns) { // execute as standalone operator operatorFactory = compileFilterWithNoInputColumns(filterExpression, compiler); results.add(executeFilterWithNoInputColumns(operatorFactory, session)); } // interpret boolean interpretedValue = executeFilter(interpretedFilterProject(filterExpression, TRUE_LITERAL, session)); results.add(interpretedValue); // execute over normal operator SourceOperatorFactory scanProjectOperatorFactory = compileScanFilterProject(filterExpression, TRUE_LITERAL, compiler); boolean scanOperatorValue = executeFilter(scanProjectOperatorFactory, createNormalSplit(), session); results.add(scanOperatorValue); // execute over record set boolean recordValue = executeFilter(scanProjectOperatorFactory, createRecordSetSplit(), session); results.add(recordValue); // // If the filter does not need bound values, execute query using full engine if (!needsBoundValue(filterExpression)) { MaterializedResult result = runner.execute("SELECT TRUE WHERE " + filter); assertEquals(result.getTypes().size(), 1); Boolean queryResult; if (result.getMaterializedRows().isEmpty()) { queryResult = false; } else { assertEquals(result.getMaterializedRows().size(), 1); queryResult = (Boolean) Iterables.getOnlyElement(result.getMaterializedRows()).getField(0); } 
results.add(queryResult); } return results; } public static Expression createExpression( String expression, Metadata metadata, Map<Symbol, Type> symbolTypes) { Expression parsedExpression = SQL_PARSER.createExpression(expression); final ExpressionAnalysis analysis = analyzeExpressionsWithSymbols( TEST_SESSION, metadata, SQL_PARSER, symbolTypes, ImmutableList.of(parsedExpression)); Expression rewrittenExpression = ExpressionTreeRewriter.rewriteWith( new ExpressionRewriter<Void>() { @Override public Expression rewriteExpression( Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { Expression rewrittenExpression = treeRewriter.defaultRewrite(node, context); // cast expression if coercion is registered Type coercion = analysis.getCoercion(node); if (coercion != null) { rewrittenExpression = new Cast(rewrittenExpression, coercion.getTypeSignature().toString()); } return rewrittenExpression; } @Override public Expression rewriteDereferenceExpression( DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { if (analysis.getColumnReferences().contains(node)) { return rewriteExpression(node, context, treeRewriter); } // Rewrite all row field reference to function call. 
QualifiedName mangledName = QualifiedName.of(mangleFieldReference(node.getFieldName())); FunctionCall functionCall = new FunctionCall(mangledName, ImmutableList.of(node.getBase())); Expression rewrittenExpression = rewriteFunctionCall(functionCall, context, treeRewriter); // cast expression if coercion is registered Type coercion = analysis.getCoercion(node); if (coercion != null) { rewrittenExpression = new Cast(rewrittenExpression, coercion.getTypeSignature().toString()); } return rewrittenExpression; } }, parsedExpression); return canonicalizeExpression(rewrittenExpression); } private static boolean executeFilterWithNoInputColumns( OperatorFactory operatorFactory, Session session) { return executeFilterWithNoInputColumns( operatorFactory.createOperator(createDriverContext(session))); } private static boolean executeFilter(OperatorFactory operatorFactory, Session session) { return executeFilter(operatorFactory.createOperator(createDriverContext(session))); } private static boolean executeFilter( SourceOperatorFactory operatorFactory, Split split, Session session) { SourceOperator operator = operatorFactory.createOperator(createDriverContext(session)); operator.addSplit(split); operator.noMoreSplits(); return executeFilter(operator); } private static boolean executeFilter(Operator operator) { Page page = getAtMostOnePage(operator, SOURCE_PAGE); boolean value; if (page != null) { assertEquals(page.getPositionCount(), 1); assertEquals(page.getChannelCount(), 1); assertTrue(operator.getTypes().get(0).getBoolean(page.getBlock(0), 0)); value = true; } else { value = false; } return value; } private static boolean executeFilterWithNoInputColumns(Operator operator) { Page page = getAtMostOnePage(operator, ZERO_CHANNEL_PAGE); boolean value; if (page != null) { assertEquals(page.getPositionCount(), 1); assertEquals(page.getChannelCount(), 0); value = true; } else { value = false; } return value; } private static boolean needsBoundValue(Expression projectionExpression) { 
final AtomicBoolean hasQualifiedNameReference = new AtomicBoolean(); projectionExpression.accept( new DefaultTraversalVisitor<Void, Void>() { @Override protected Void visitQualifiedNameReference(QualifiedNameReference node, Void context) { hasQualifiedNameReference.set(true); return null; } }, null); return hasQualifiedNameReference.get(); } private Operator interpretedFilterProject( Expression filter, Expression projection, Session session) { FilterFunction filterFunction = new InterpretedFilterFunction( filter, SYMBOL_TYPES, INPUT_MAPPING, metadata, SQL_PARSER, session); ProjectionFunction projectionFunction = new InterpretedProjectionFunction( projection, SYMBOL_TYPES, INPUT_MAPPING, metadata, SQL_PARSER, session); OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory( 0, new GenericPageProcessor(filterFunction, ImmutableList.of(projectionFunction)), toTypes(ImmutableList.of(projectionFunction))); return operatorFactory.createOperator(createDriverContext(session)); } private OperatorFactory compileFilterWithNoInputColumns( Expression filter, ExpressionCompiler compiler) { filter = ExpressionTreeRewriter.rewriteWith( new SymbolToInputRewriter(ImmutableMap.<Symbol, Integer>of()), filter); IdentityHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput( TEST_SESSION, metadata, SQL_PARSER, INPUT_TYPES, ImmutableList.of(filter)); try { PageProcessor processor = compiler.compilePageProcessor( toRowExpression(filter, expressionTypes), ImmutableList.of()); return new FilterAndProjectOperator.FilterAndProjectOperatorFactory( 0, processor, ImmutableList.<Type>of()); } catch (Throwable e) { if (e instanceof UncheckedExecutionException) { e = e.getCause(); } throw new RuntimeException("Error compiling " + filter + ": " + e.getMessage(), e); } } private OperatorFactory compileFilterProject( Expression filter, Expression projection, ExpressionCompiler compiler) { filter = ExpressionTreeRewriter.rewriteWith(new 
SymbolToInputRewriter(INPUT_MAPPING), filter); projection = ExpressionTreeRewriter.rewriteWith(new SymbolToInputRewriter(INPUT_MAPPING), projection); IdentityHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput( TEST_SESSION, metadata, SQL_PARSER, INPUT_TYPES, ImmutableList.of(filter, projection)); try { List<RowExpression> projections = ImmutableList.of(toRowExpression(projection, expressionTypes)); PageProcessor processor = compiler.compilePageProcessor(toRowExpression(filter, expressionTypes), projections); return new FilterAndProjectOperator.FilterAndProjectOperatorFactory( 0, processor, ImmutableList.of(expressionTypes.get(projection))); } catch (Throwable e) { if (e instanceof UncheckedExecutionException) { e = e.getCause(); } throw new RuntimeException("Error compiling " + projection + ": " + e.getMessage(), e); } } private SourceOperatorFactory compileScanFilterProject( Expression filter, Expression projection, ExpressionCompiler compiler) { filter = ExpressionTreeRewriter.rewriteWith(new SymbolToInputRewriter(INPUT_MAPPING), filter); projection = ExpressionTreeRewriter.rewriteWith(new SymbolToInputRewriter(INPUT_MAPPING), projection); IdentityHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput( TEST_SESSION, metadata, SQL_PARSER, INPUT_TYPES, ImmutableList.of(filter, projection)); try { CursorProcessor cursorProcessor = compiler.compileCursorProcessor( toRowExpression(filter, expressionTypes), ImmutableList.of(toRowExpression(projection, expressionTypes)), SOURCE_ID); PageProcessor pageProcessor = compiler.compilePageProcessor( toRowExpression(filter, expressionTypes), ImmutableList.of(toRowExpression(projection, expressionTypes))); return new ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory( 0, SOURCE_ID, PAGE_SOURCE_PROVIDER, cursorProcessor, pageProcessor, ImmutableList.<ColumnHandle>of(), ImmutableList.of(expressionTypes.get(projection))); } catch (Throwable e) { if (e instanceof 
UncheckedExecutionException) { e = e.getCause(); } throw new RuntimeException("Error compiling " + projection + ": " + e.getMessage(), e); } } private RowExpression toRowExpression( Expression projection, IdentityHashMap<Expression, Type> expressionTypes) { return SqlToRowExpressionTranslator.translate( projection, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); } private static Page getAtMostOnePage(Operator operator, Page sourcePage) { // add our input page if needed if (operator.needsInput()) { operator.addInput(sourcePage); } // try to get the output page Page result = operator.getOutput(); // tell operator to finish operator.finish(); // try to get output until the operator is finished while (!operator.isFinished()) { // operator should never block assertTrue(operator.isBlocked().isDone()); Page output = operator.getOutput(); if (output != null) { assertNull(result); result = output; } } return result; } private static DriverContext createDriverContext(Session session) { return createTaskContext(EXECUTOR, session).addPipelineContext(true, true).addDriverContext(); } private static void assertType(List<Type> types, Type expectedType) { assertTrue(types.size() == 1, "Expected one type, but got " + types); Type actualType = types.get(0); assertEquals(actualType, expectedType); } private static class TestPageSourceProvider implements PageSourceProvider { @Override public ConnectorPageSource createPageSource( Session session, Split split, List<ColumnHandle> columns) { assertInstanceOf(split.getConnectorSplit(), FunctionAssertions.TestSplit.class); FunctionAssertions.TestSplit testSplit = (FunctionAssertions.TestSplit) split.getConnectorSplit(); if (testSplit.isRecordSet()) { RecordSet records = InMemoryRecordSet.builder( ImmutableList.<Type>of( BIGINT, VARCHAR, DOUBLE, BOOLEAN, BIGINT, VARCHAR, VARCHAR, TIMESTAMP_WITH_TIME_ZONE)) .addRow( 1234L, "hello", 12.34, true, new DateTime(2001, 8, 22, 3, 4, 5, 321, 
DateTimeZone.UTC).getMillis(), "%el%", null, packDateTimeWithZone( new DateTime(1970, 1, 1, 0, 1, 0, 999, DateTimeZone.UTC).getMillis(), TimeZoneKey.getTimeZoneKey("Z"))) .build(); return new RecordPageSource(records); } else { return new FixedPageSource(ImmutableList.of(SOURCE_PAGE)); } } } static class TestSplit implements ConnectorSplit { static Split createRecordSetSplit() { return new Split("test", new TestSplit(true)); } static Split createNormalSplit() { return new Split("test", new TestSplit(false)); } private final boolean recordSet; private TestSplit(boolean recordSet) { this.recordSet = recordSet; } private boolean isRecordSet() { return recordSet; } @Override public boolean isRemotelyAccessible() { return false; } @Override public List<HostAddress> getAddresses() { return ImmutableList.of(); } @Override public Object getInfo() { return this; } } }