public static Evaluator createEvaluator(InputStream is) throws SAXException, JAXBException { Source source = ImportFilter.apply(new InputSource(is)); PMML pmml = JAXBUtil.unmarshalPMML(source); // If the SAX Locator information is available, then transform it to java.io.Serializable // representation LocatorTransformer locatorTransformer = new LocatorTransformer(); locatorTransformer.applyTo(pmml); ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); ModelEvaluator<?> modelEvaluator = modelEvaluatorFactory.newModelManager(pmml); modelEvaluator.verify(); return modelEvaluator; }
/**
 * Walks the {@code FieldResolverTest} PMML fixture with three {@link FieldResolver}
 * subclasses.
 *
 * <p>At every visited {@code Apply}, {@code RegressionTable} and {@code SimplePredicate},
 * the set reported by {@code getFields()} is compared (via {@code checkFields}) against the
 * field names expected at that location. Any unrecognised function name, derived-field name
 * or segment id fails with an {@link AssertionError}. After each traversal the resolver's
 * field set is asserted to be empty.
 */
@Test
public void resolve() throws Exception {
	PMML document;

	try (InputStream in = PMMLUtil.getResourceAsStream(FieldResolverTest.class)) {
		document = JAXBUtil.unmarshalPMML(new StreamSource(in));
	}

	final Set<FieldName> dataNames = FieldNameUtil.create("y", "x1", "x2", "x3");
	final Set<FieldName> baseNames = FieldNameUtil.create(dataNames, "x1_squared", "x1_cubed");

	FieldResolver applyResolver = new FieldResolver() {

		@Override
		public VisitorAction visit(Apply apply) {
			Set<Field> visible = getFields();

			String function = apply.getFunction();
			// switch on a null String would throw NPE; the original if/else chain
			// treated null as "unrecognised" and failed with AssertionError instead
			if (function == null) {
				throw new AssertionError();
			}

			switch (function) {
				case "*": {
					DerivedField parentField = (DerivedField) VisitorUtil.getParent(this);
					String derivedName = parentField.getName().getValue();

					if ("x1_squared".equals(derivedName)) {
						checkFields(dataNames, visible);
					} else if ("x1_cubed".equals(derivedName)) {
						checkFields(FieldNameUtil.create(dataNames, "x1_squared"), visible);
					} else {
						throw new AssertionError();
					}
					break;
				}
				case "pow":
					checkFields(FieldNameUtil.create("x"), visible);
					break;
				case "square":
					checkFields(FieldNameUtil.create(baseNames, "first_output"), visible);
					break;
				case "cube":
					checkFields(FieldNameUtil.create(baseNames, "first_output", "x2_squared"), visible);
					break;
				default:
					throw new AssertionError();
			}

			return super.visit(apply);
		}
	};
	applyResolver.applyTo(document);

	assertEquals(Collections.emptySet(), applyResolver.getFields());

	FieldResolver regressionTableResolver = new FieldResolver() {

		@Override
		public VisitorAction visit(RegressionTable regressionTable) {
			Set<Field> visible = getFields();

			// NOTE(review): getParent(this, 1) is expected to yield the enclosing Segment
			// (presumably the table's grandparent) — confirm against VisitorUtil
			Segment owner = (Segment) VisitorUtil.getParent(this, 1);

			String segmentId = owner.getId();
			if (segmentId == null) {
				throw new AssertionError();
			}

			switch (segmentId) {
				case "first":
					checkFields(baseNames, visible);
					break;
				case "second":
					checkFields(FieldNameUtil.create(baseNames, "first_output", "x2_squared", "x2_cubed"), visible);
					break;
				case "third":
					checkFields(FieldNameUtil.create(baseNames, "first_output", "second_output"), visible);
					break;
				case "sum":
					checkFields(FieldNameUtil.create(baseNames, "first_output", "second_output", "third_output"), visible);
					break;
				default:
					throw new AssertionError();
			}

			return super.visit(regressionTable);
		}
	};
	regressionTableResolver.applyTo(document);

	assertEquals(Collections.emptySet(), regressionTableResolver.getFields());

	FieldResolver predicateResolver = new FieldResolver() {

		@Override
		public VisitorAction visit(SimplePredicate simplePredicate) {
			Set<Field> visible = getFields();

			Segment owner = (Segment) VisitorUtil.getParent(this);

			String segmentId = owner.getId();
			if (segmentId == null) {
				throw new AssertionError();
			}

			switch (segmentId) {
				case "first":
					checkFields(baseNames, visible);
					break;
				case "second":
					checkFields(FieldNameUtil.create(baseNames, "first_output"), visible);
					break;
				case "third":
					checkFields(FieldNameUtil.create(baseNames, "first_output", "second_output"), visible);
					break;
				default:
					throw new AssertionError();
			}

			return super.visit(simplePredicate);
		}
	};
	predicateResolver.applyTo(document);

	assertEquals(Collections.emptySet(), predicateResolver.getFields());
}