@Override protected AnalyzerOutPutDataBaseUpdate doPredict( DataBaseAnalyticSource source, PredictorConfig config) throws AnalysisException { DataSet dataSet = null; try { dataSet = getDataSet(source, source.getAnalyticConfig()); config.getTrainedModel().getModel().apply(dataSet); } catch (Exception e) { logger.error(e); if (e instanceof WrongUsedException) { throw new AnalysisError(this, (WrongUsedException) e); } else if (e instanceof AnalysisError) { throw (AnalysisError) e; } else { throw new AnalysisException(e); } } AnalyzerOutPutDataBaseUpdate result = new AnalyzerOutPutDataBaseUpdate(); result.setDataset(dataSet); // set url user pwd ,schema, table fillDBInfo(result, (DataBaseAnalyticSource) source); Model model = config.getTrainedModel().getModel(); result.setUpdatedColumns(((NBModel) model).getUpdateColumns()); // good, bad , result.setAnalyticNodeMetaInfo(createNodeMetaInfo(config.getLocale())); return result; }
@Override protected Model train(AnalyticSource analyticSource) throws AnalysisException { ResultSet rs = null; Statement st = null; try { IDataSourceInfo dataSourceInfo = DataSourceInfoFactory.createConnectionInfo(analyticSource.getDataSourceType()); dbtype = dataSourceInfo.getDBType(); RandomForestModel lastResult = null; RandomForestIMP randomForestImpl = null; // if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) { // randomForestTrainer = new AdaboostOracle(); // // } else if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType) || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) { randomForestImpl = new RandomForestGreenplum(); } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) { randomForestImpl = new RandomForestOracle(); } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) { randomForestImpl = new RandomForestDB2(); ((RandomForestDB2) randomForestImpl) .setConnection(((DataBaseAnalyticSource) analyticSource).getConnection()); } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) { randomForestImpl = new RandomForestNZ(); } else { throw new AnalysisException("Databse type is not supported for Random Forest:" + dbtype); // return null; } try { dataSet = getDataSet((DataBaseAnalyticSource) analyticSource, analyticSource.getAnalyticConfig()); } catch (OperatorException e1) { logger.error(e1); throw new OperatorException(e1.getLocalizedMessage()); } setSpecifyColumn(dataSet, analyticSource.getAnalyticConfig()); dataSet.computeAllColumnStatistics(); RandomForestConfig rfConfig = (RandomForestConfig) analyticSource.getAnalyticConfig(); String dbSystem = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getSystem(); String url = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUrl(); String userName = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUserName(); String password = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getPassword(); String inputSchema = ((DataBaseAnalyticSource) analyticSource).getTableInfo().getSchema(); String tableName = ((DataBaseAnalyticSource) analyticSource).getTableInfo().getTableName(); String useSSL = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUseSSL(); String sampleWithReplacement = rfConfig.getSampleWithReplacement(); long timeStamp = System.currentTimeMillis(); pnewTable = "pnew" + timeStamp; sampleTable = "s" + timeStamp; String dependentColumn = rfConfig.getDependentColumn(); String columnNames = rfConfig.getColumnNames(); String[] totalColumns = columnNames.split(","); int subSize = Integer.parseInt(rfConfig.getNodeColumnNumber()); int forestSize = Integer.parseInt(rfConfig.getForestSize()); Connection conncetion = null; if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType) || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) { lastResult = new RandomForestModelGreenplum(dataSet); } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) { lastResult = new RandomForestModelOracle(dataSet); } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) { lastResult = new RandomForestModelDB2(dataSet); } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) { lastResult = new RandomForestModelNZ(dataSet); } lastResult.setColumnNames(columnNames); lastResult.setDependColumn(dependentColumn); lastResult.setTableName(tableName); conncetion = ((DataBaseAnalyticSource) analyticSource).getConnection(); Model result = null; try { st = conncetion.createStatement(); } catch (SQLException e) { logger.error(e); throw new AnalysisException(e); } // Iterator<String> dependvalueIterator = dataSet.getColumns() // .getLabel().getMapping().getValues().iterator(); if (dataSet.getColumns().getLabel() instanceof NominalColumn) { if (dataSet.getColumns().getLabel().getMapping().getValues().size() <= 1) { String e = SDKLanguagePack.getMessage( SDKLanguagePack.ADABOOST_SAMPLE_ERRINFO, rfConfig.getLocale()); logger.error(e); throw new AnalysisException(e); } if (dataSet.getColumns().getLabel().getMapping().getValues().size() > AlpineMinerConfig.ADABOOST_MAX_DEPENDENT_COUNT) { String e = SDKLanguagePack.getMessage( SDKLanguagePack.ADABOOST_MAX_DEPENDENT_COUNT_ERRINFO, rfConfig.getLocale()); logger.error(e); throw new AnalysisException(e); } } try { randomForestImpl.randomForestTrainInit( inputSchema, tableName, timeStamp, dependentColumn, st, dataSet); } catch (SQLException e) { logger.error(e); throw new AnalysisException(e); } CartConfig config = new CartConfig(); config.setDependentColumn(dependentColumn); config.setConfidence(rfConfig.getConfidence()); config.setMaximal_depth(rfConfig.getMaximal_depth()); config.setMinimal_leaf_size(rfConfig.getMinimal_leaf_size()); config.setMinimal_size_for_split(rfConfig.getMinimal_size_for_split()); config.setNo_pre_pruning("true"); config.setNo_pruning("true"); for (int i = 0; i < forestSize; i++) { CartTrainer analyzer = new CartTrainer(); if (sampleWithReplacement == Resources.TrueOpt) { randomForestImpl.randomForestSample( inputSchema, timeStamp + "" + i, dependentColumn, st, rs, pnewTable, sampleTable + i, rfConfig.getLocale()); } else { randomForestImpl.randomForestSampleNoReplace( inputSchema, timeStamp + "" + i, dependentColumn, st, rs, pnewTable, sampleTable + i, rfConfig.getLocale(), dataSet.size()); } String subColumns = getSubColumns(totalColumns, subSize); config.setColumnNames(subColumns); DataBaseAnalyticSource tempsource = new DataBaseAnalyticSource( dbSystem, url, userName, password, inputSchema, sampleTable + i, useSSL); tempsource.setAnalyticConfiguration(config); tempsource.setConenction(conncetion); result = ((AnalyzerOutPutTrainModel) analyzer.doAnalysis(tempsource)) .getEngineModel() .getModel(); String OOBTable = "OOB" + sampleTable + i; randomForestImpl.generateOOBTable( inputSchema, OOBTable, pnewTable, sampleTable + i, st, rs); DataBaseAnalyticSource tempPredictSource = new DataBaseAnalyticSource( dbSystem, url, userName, password, inputSchema, OOBTable, useSSL); String predictOutTable = "OOBPredict" + sampleTable; EngineModel em = new EngineModel(); em.setModel(result); PredictorConfig tempconfig = new PredictorConfig(em); tempconfig.setDropIfExist(dropIfExists); tempconfig.setOutputSchema(inputSchema); tempconfig.setOutputTable(predictOutTable); tempPredictSource.setAnalyticConfiguration(tempconfig); tempPredictSource.setConenction(conncetion); AbstractDBModelPredictor predictor = new CartPredictor(); predictor.doAnalysis(tempPredictSource); // use the weak alg , do double OOBError = 0.0; if (result instanceof DecisionTreeModel) { OOBError = randomForestImpl.getOOBError( tempPredictSource, dependentColumn, "P(" + dependentColumn + ")"); lastResult.getOobEstimateError().add(OOBError); } else if (result instanceof RegressionTreeModel) { OOBError = randomForestImpl.getMSE(tempPredictSource, "P(" + dependentColumn + ")"); lastResult.getOobLoss().add(OOBError); double OOBMape = randomForestImpl.getMAPE( tempPredictSource, dependentColumn, "P(" + dependentColumn + ")"); lastResult.getOobMape().add(OOBMape); } else { OOBError = Double.NaN; lastResult.getOobLoss().add(OOBError); } lastResult.addModel((SingleModel) result); randomForestImpl.clearTrainResult(inputSchema, sampleTable + i); randomForestImpl.clearTrainResult(inputSchema, predictOutTable); randomForestImpl.clearTrainResult(inputSchema, OOBTable); } return lastResult; } catch (Exception e) { logger.error(e); if (e instanceof WrongUsedException) { throw new AnalysisError(this, (WrongUsedException) e); } else if (e instanceof AnalysisError) { throw (AnalysisError) e; } else { throw new AnalysisException(e); } } finally { try { if (st != null) { st.close(); } if (rs != null) { rs.close(); } } catch (SQLException e) { logger.error(e); throw new AnalysisException(e.getLocalizedMessage()); } } }