/**
 * Determines the classification method of the model, inferring it from the
 * structure of the model when the <code>classificationMethod</code> field is not set.
 *
 * @return the explicitly set method if present; otherwise
 *         {@link SvmClassificationMethodType#ONE_AGAINST_ONE} or
 *         {@link SvmClassificationMethodType#ONE_AGAINST_ALL}, depending on which
 *         target category attributes are present.
 *
 * @throws InvalidFeatureException if the attributes that are present do not allow
 *         a method to be inferred.
 */
private SvmClassificationMethodType getClassificationMethod() {
    SupportVectorMachineModel model = getModel();

    // Older versions of several popular PMML producer software are known to omit the
    // classificationMethod attribute.
    // The regular accessor of the model replaces a missing value with the default value
    // "OneAgainstAll", which may lead to incorrect behaviour.
    // The workaround is to read the raw field value using Java Reflection API, and, when it is
    // missing, to infer the correct classification method type based on evidence.
    // NOTE(review): the previous revision of this comment named
    // SupportVectorMachineModel#getSvmRepresentation() as the accessor in question, which does not
    // match the "classificationMethod" field that is read here - confirm.
    Field field = ReflectionUtil.getField(model, "classificationMethod");
    SvmClassificationMethodType declaredMethod = ReflectionUtil.getFieldValue(field, model);
    if (declaredMethod != null) {
        return declaredMethod;
    }

    List<SupportVectorMachine> machines = model.getSupportVectorMachines();

    // Binary shorthand: the model names the alternate category, and exactly one machine
    // names the target category.
    String alternateBinaryTargetCategory = model.getAlternateBinaryTargetCategory();
    if (alternateBinaryTargetCategory != null) {
        if (machines.size() != 1) {
            throw new InvalidFeatureException(model);
        }

        SupportVectorMachine onlyMachine = machines.get(0);
        if (onlyMachine.getTargetCategory() == null) {
            throw new InvalidFeatureException(onlyMachine);
        }

        return SvmClassificationMethodType.ONE_AGAINST_ONE;
    }

    // General case: the decision is based solely on the first machine.
    if (machines.isEmpty()) {
        throw new InvalidFeatureException(model);
    }

    SupportVectorMachine firstMachine = machines.get(0);
    String targetCategory = firstMachine.getTargetCategory();
    String alternateTargetCategory = firstMachine.getAlternateTargetCategory();
    if (targetCategory == null) {
        throw new InvalidFeatureException(firstMachine);
    }

    // A machine that names both categories separates one pair of categories;
    // a machine that names only the target category separates it from all the rest.
    return (alternateTargetCategory != null)
        ? SvmClassificationMethodType.ONE_AGAINST_ONE
        : SvmClassificationMethodType.ONE_AGAINST_ALL;
}
/**
 * Computes the numeric output of a single support vector machine for the given input.
 *
 * <p>
 * Coefficients and support vectors are paired up in document order. Each pair contributes
 * <code>coefficient * kernel(input, vector)</code>; the absolute value of the coefficients
 * element is added to the sum of the contributions.
 * </p>
 *
 * @param supportVectorMachine the machine to evaluate.
 * @param input the input vector.
 *
 * @return the numeric output of the machine.
 *
 * @throws InvalidFeatureException if a support vector refers to a vector that is not in
 *         the vector map, or if the numbers of coefficients and support vectors differ.
 */
private Double evaluateSupportVectorMachine(SupportVectorMachine supportVectorMachine, double[] input) {
    SupportVectorMachineModel model = getModel();

    Kernel kernel = model.getKernel();

    Coefficients coefficients = supportVectorMachine.getCoefficients();
    Iterator<Coefficient> coefficientIt = coefficients.iterator();

    SupportVectors supportVectors = supportVectorMachine.getSupportVectors();
    Iterator<SupportVector> supportVectorIt = supportVectors.iterator();

    Map<String, double[]> vectorsById = getVectorMap();

    double sum = 0d;

    // Walk both sequences in lockstep, stopping as soon as either one is exhausted
    while (coefficientIt.hasNext() && supportVectorIt.hasNext()) {
        Coefficient coefficient = coefficientIt.next();
        SupportVector supportVector = supportVectorIt.next();

        double[] vector = vectorsById.get(supportVector.getVectorId());
        if (vector == null) {
            throw new InvalidFeatureException(supportVector);
        }

        Double kernelValue = KernelUtil.evaluate(kernel, input, vector);

        sum += coefficient.getValue() * kernelValue;
    }

    // Leftover elements on either side mean that the two sequences had different lengths
    boolean bothExhausted = !coefficientIt.hasNext() && !supportVectorIt.hasNext();
    if (!bothExhausted) {
        throw new InvalidFeatureException(supportVectorMachine);
    }

    return sum + coefficients.getAbsoluteValue();
}
private Map<FieldName, ? extends Classification> evaluateClassification( ModelEvaluationContext context) { SupportVectorMachineModel supportVectorMachineModel = getModel(); List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines(); if (supportVectorMachines.size() < 1) { throw new InvalidFeatureException(supportVectorMachineModel); } String alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory(); Classification result; SvmClassificationMethodType svmClassificationMethod = getClassificationMethod(); switch (svmClassificationMethod) { case ONE_AGAINST_ALL: result = new Classification(Classification.Type.DISTANCE); break; case ONE_AGAINST_ONE: result = new VoteDistribution(); break; default: throw new UnsupportedFeatureException(supportVectorMachineModel, svmClassificationMethod); } double[] input = createInput(context); for (SupportVectorMachine supportVectorMachine : supportVectorMachines) { String targetCategory = supportVectorMachine.getTargetCategory(); String alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory(); Double value = evaluateSupportVectorMachine(supportVectorMachine, input); switch (svmClassificationMethod) { case ONE_AGAINST_ALL: { if (targetCategory == null || alternateTargetCategory != null) { throw new InvalidFeatureException(supportVectorMachine); } result.put(targetCategory, value); } break; case ONE_AGAINST_ONE: if (alternateBinaryTargetCategory != null) { if (targetCategory == null || alternateTargetCategory != null) { throw new InvalidFeatureException(supportVectorMachine); } String label; long roundedValue = Math.round(value); // "A rounded value of 1 corresponds to the targetCategory attribute of the // SupportVectorMachine element" if (roundedValue == 1) { label = targetCategory; } else // "A rounded value of 0 corresponds to the alternateBinaryTargetCategory attribute of // the SupportVectorMachineModel element" if (roundedValue == 0) 
{ label = alternateBinaryTargetCategory; } else // "The numeric prediction must be between 0 and 1" { throw new EvaluationException("Invalid numeric prediction " + value); } Double vote = result.get(label); if (vote == null) { vote = 0d; } result.put(label, (vote + 1d)); } else { if (targetCategory == null || alternateTargetCategory == null) { throw new InvalidFeatureException(supportVectorMachine); } Double threshold = supportVectorMachine.getThreshold(); if (threshold == null) { threshold = supportVectorMachineModel.getThreshold(); } String label; // "If the numeric prediction is smaller than the threshold, it corresponds to the // targetCategory attribute" if ((value).compareTo(threshold) < 0) { label = targetCategory; } else { label = alternateTargetCategory; } Double vote = result.get(label); if (vote == null) { vote = 0d; } result.put(label, (vote + 1d)); } break; default: break; } } return TargetUtil.evaluateClassification(result, context); }