@Override
protected Model train(AnalyticSource analyticSource) throws AnalysisException {
    ResultSet rs = null;
    Statement st = null;
    try {
        IDataSourceInfo dataSourceInfo =
                DataSourceInfoFactory.createConnectionInfo(analyticSource.getDataSourceType());
        dbtype = dataSourceInfo.getDBType();

        RandomForestModel lastResult = null;
        RandomForestIMP randomForestImpl = null;
        if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType)
                || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) {
            randomForestImpl = new RandomForestGreenplum();
        } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) {
            randomForestImpl = new RandomForestOracle();
        } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) {
            randomForestImpl = new RandomForestDB2();
            ((RandomForestDB2) randomForestImpl)
                    .setConnection(((DataBaseAnalyticSource) analyticSource).getConnection());
        } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) {
            randomForestImpl = new RandomForestNZ();
        } else {
            throw new AnalysisException("Database type is not supported for Random Forest: " + dbtype);
        }

        try {
            dataSet = getDataSet((DataBaseAnalyticSource) analyticSource,
                    analyticSource.getAnalyticConfig());
        } catch (OperatorException e1) {
            logger.error(e1);
            throw new OperatorException(e1.getLocalizedMessage());
        }
        setSpecifyColumn(dataSet, analyticSource.getAnalyticConfig());
        dataSet.computeAllColumnStatistics();

        RandomForestConfig rfConfig = (RandomForestConfig) analyticSource.getAnalyticConfig();
        String dbSystem = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getSystem();
        String url = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUrl();
        String userName = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUserName();
        String password = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getPassword();
        String inputSchema = ((DataBaseAnalyticSource) analyticSource).getTableInfo().getSchema();
        String tableName = ((DataBaseAnalyticSource) analyticSource).getTableInfo().getTableName();
        String useSSL = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUseSSL();
        String sampleWithReplacement = rfConfig.getSampleWithReplacement();

        long timeStamp = System.currentTimeMillis();
        pnewTable = "pnew" + timeStamp;
        sampleTable = "s" + timeStamp;
        String dependentColumn = rfConfig.getDependentColumn();
        String columnNames = rfConfig.getColumnNames();
        String[] totalColumns = columnNames.split(",");
        int subSize = Integer.parseInt(rfConfig.getNodeColumnNumber());
        int forestSize = Integer.parseInt(rfConfig.getForestSize());

        Connection connection = null;
        if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType)
                || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) {
            lastResult = new RandomForestModelGreenplum(dataSet);
        } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) {
            lastResult = new RandomForestModelOracle(dataSet);
        } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) {
            lastResult = new RandomForestModelDB2(dataSet);
        } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) {
            lastResult = new RandomForestModelNZ(dataSet);
        }
        lastResult.setColumnNames(columnNames);
        lastResult.setDependColumn(dependentColumn);
        lastResult.setTableName(tableName);

        connection = ((DataBaseAnalyticSource) analyticSource).getConnection();
        Model result = null;
        try {
            st = connection.createStatement();
        } catch (SQLException e) {
            logger.error(e);
            throw new AnalysisException(e);
        }

        if (dataSet.getColumns().getLabel() instanceof NominalColumn) {
            if (dataSet.getColumns().getLabel().getMapping().getValues().size() <= 1) {
                String e = SDKLanguagePack.getMessage(
                        SDKLanguagePack.ADABOOST_SAMPLE_ERRINFO, rfConfig.getLocale());
                logger.error(e);
                throw new AnalysisException(e);
            }
            if (dataSet.getColumns().getLabel().getMapping().getValues().size()
                    > AlpineMinerConfig.ADABOOST_MAX_DEPENDENT_COUNT) {
                String e = SDKLanguagePack.getMessage(
                        SDKLanguagePack.ADABOOST_MAX_DEPENDENT_COUNT_ERRINFO, rfConfig.getLocale());
                logger.error(e);
                throw new AnalysisException(e);
            }
        }

        try {
            randomForestImpl.randomForestTrainInit(
                    inputSchema, tableName, timeStamp, dependentColumn, st, dataSet);
        } catch (SQLException e) {
            logger.error(e);
            throw new AnalysisException(e);
        }

        // Every tree shares the same CART settings; pruning is disabled so each
        // tree is grown out fully, as is usual for random forests.
        CartConfig config = new CartConfig();
        config.setDependentColumn(dependentColumn);
        config.setConfidence(rfConfig.getConfidence());
        config.setMaximal_depth(rfConfig.getMaximal_depth());
        config.setMinimal_leaf_size(rfConfig.getMinimal_leaf_size());
        config.setMinimal_size_for_split(rfConfig.getMinimal_size_for_split());
        config.setNo_pre_pruning("true");
        config.setNo_pruning("true");

        for (int i = 0; i < forestSize; i++) {
            CartTrainer analyzer = new CartTrainer();
            // Bootstrap (or subsample without replacement) the training table
            // for this tree. Note: the original compared the strings with ==,
            // which tests reference identity in Java; equals() is intended.
            if (Resources.TrueOpt.equals(sampleWithReplacement)) {
                randomForestImpl.randomForestSample(
                        inputSchema, timeStamp + "" + i, dependentColumn, st, rs,
                        pnewTable, sampleTable + i, rfConfig.getLocale());
            } else {
                randomForestImpl.randomForestSampleNoReplace(
                        inputSchema, timeStamp + "" + i, dependentColumn, st, rs,
                        pnewTable, sampleTable + i, rfConfig.getLocale(), dataSet.size());
            }

            // Each tree sees only a random subset of the candidate columns.
            String subColumns = getSubColumns(totalColumns, subSize);
            config.setColumnNames(subColumns);
            DataBaseAnalyticSource tempsource = new DataBaseAnalyticSource(
                    dbSystem, url, userName, password, inputSchema, sampleTable + i, useSSL);
            tempsource.setAnalyticConfiguration(config);
            tempsource.setConenction(connection);
            result = ((AnalyzerOutPutTrainModel) analyzer.doAnalysis(tempsource))
                    .getEngineModel().getModel();

            // Score the out-of-bag rows with the tree just trained.
            String OOBTable = "OOB" + sampleTable + i;
            randomForestImpl.generateOOBTable(inputSchema, OOBTable, pnewTable, sampleTable + i, st, rs);
            DataBaseAnalyticSource tempPredictSource = new DataBaseAnalyticSource(
                    dbSystem, url, userName, password, inputSchema, OOBTable, useSSL);
            String predictOutTable = "OOBPredict" + sampleTable;
            EngineModel em = new EngineModel();
            em.setModel(result);
            PredictorConfig tempconfig = new PredictorConfig(em);
            tempconfig.setDropIfExist(dropIfExists);
            tempconfig.setOutputSchema(inputSchema);
            tempconfig.setOutputTable(predictOutTable);
            tempPredictSource.setAnalyticConfiguration(tempconfig);
            tempPredictSource.setConenction(connection);
            AbstractDBModelPredictor predictor = new CartPredictor();
            predictor.doAnalysis(tempPredictSource);

            // Accumulate the out-of-bag estimate: classification error for
            // decision trees, MSE and MAPE for regression trees.
            double OOBError = 0.0;
            if (result instanceof DecisionTreeModel) {
                OOBError = randomForestImpl.getOOBError(
                        tempPredictSource, dependentColumn, "P(" + dependentColumn + ")");
                lastResult.getOobEstimateError().add(OOBError);
            } else if (result instanceof RegressionTreeModel) {
                OOBError = randomForestImpl.getMSE(tempPredictSource, "P(" + dependentColumn + ")");
                lastResult.getOobLoss().add(OOBError);
                double OOBMape = randomForestImpl.getMAPE(
                        tempPredictSource, dependentColumn, "P(" + dependentColumn + ")");
                lastResult.getOobMape().add(OOBMape);
            } else {
                OOBError = Double.NaN;
                lastResult.getOobLoss().add(OOBError);
            }

            lastResult.addModel((SingleModel) result);
            randomForestImpl.clearTrainResult(inputSchema, sampleTable + i);
            randomForestImpl.clearTrainResult(inputSchema, predictOutTable);
            randomForestImpl.clearTrainResult(inputSchema, OOBTable);
        }
        return lastResult;
    } catch (Exception e) {
        logger.error(e);
        if (e instanceof WrongUsedException) {
            throw new AnalysisError(this, (WrongUsedException) e);
        } else if (e instanceof AnalysisError) {
            throw (AnalysisError) e;
        } else {
            throw new AnalysisException(e);
        }
    } finally {
        try {
            if (st != null) {
                st.close();
            }
            if (rs != null) {
                rs.close();
            }
        } catch (SQLException e) {
            logger.error(e);
            throw new AnalysisException(e.getLocalizedMessage());
        }
    }
}
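// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the original source: getSubColumns(...) is
// called above but not defined in this section. Judging from its call site it
// draws subSize random column names from totalColumns and returns them
// comma-joined, the standard per-tree feature subsampling step of a random
// forest. The name below and the shuffle-based strategy are assumptions; the
// real implementation may differ.
private static String getSubColumnsSketch(String[] totalColumns, int subSize) {
    java.util.List<String> pool =
            new java.util.ArrayList<String>(java.util.Arrays.asList(totalColumns));
    java.util.Collections.shuffle(pool); // random order => a random prefix is a random subset
    int n = Math.min(subSize, pool.size());
    StringBuilder joined = new StringBuilder();
    for (int i = 0; i < n; i++) {
        if (i > 0) {
            joined.append(",");
        }
        joined.append(pool.get(i));
    }
    return joined.toString();
}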
/*
 * (non-Javadoc)
 * @see com.alpine.datamining.api.impl.db.AbstractDBModelTrainer#train(com.alpine.datamining.api.AnalyticSource)
 */
@Override
protected Model train(AnalyticSource source) throws AnalysisException {
    ResultSet rs = null;
    Statement st = null;
    EMModel trainModel = null;
    try {
        IDataSourceInfo dataSourceInfo =
                DataSourceInfoFactory.createConnectionInfo(source.getDataSourceType());
        dbtype = dataSourceInfo.getDBType();

        EMConfig config = (EMConfig) source.getAnalyticConfig();
        String anaColumns = config.getColumnNames();
        String[] columnsArray = anaColumns.split(",");
        List<String> transformColumns = new ArrayList<String>();
        for (int i = 0; i < columnsArray.length; i++) {
            transformColumns.add(columnsArray[i]);
        }

        DataSet dataSet = getDataSet((DataBaseAnalyticSource) source, config);
        filerColumens(dataSet, transformColumns);
        dataSet.computeAllColumnStatistics();

        // Categorical columns are recoded as numeric before clustering; build a
        // numeric-code -> original-value map so results can be translated back.
        ColumnTypeTransformer transformer = new ColumnTypeTransformer();
        DataSet newDataSet = transformer.TransformCategoryToNumeric_new(dataSet);
        String tableName = ((DBTable) newDataSet.getDBTable()).getTableName();
        Columns columns = newDataSet.getColumns();
        List<String> newTransformColumns = new ArrayList<String>();
        HashMap<String, String> transformMap = new HashMap<String, String>();
        for (String key : transformer.getAllTransformMap_valueKey().keySet()) {
            HashMap<String, String> values = transformer.getAllTransformMap_valueKey().get(key);
            for (String lowKey : values.keySet()) {
                transformMap.put(values.get(lowKey), lowKey);
            }
        }
        Iterator<Column> attributeIter = columns.iterator();
        while (attributeIter.hasNext()) {
            Column column = attributeIter.next();
            newTransformColumns.add(column.getName());
        }

        int maxIterationNumber = Integer.parseInt(config.getMaxIterationNumber());
        int clusterNumber = Integer.parseInt(config.getClusterNumber());
        double epsilon = Double.parseDouble(config.getEpsilon());
        // TODO get it from config and make sure it will not be too large
        int initClusterSize = 10;
        if (config.getInitClusterSize() != null) {
            initClusterSize = Integer.parseInt(config.getInitClusterSize());
        }
        if (newDataSet.size() < initClusterSize * clusterNumber) {
            initClusterSize = (int) (newDataSet.size() / clusterNumber + 1);
        }

        EMClusterImpl emImpl = EMClusterFactory.createEMAnalyzer(dbtype);
        trainModel = EMClusterFactory.createEMModel(dbtype, newDataSet);
        Connection connection = ((DataBaseAnalyticSource) source).getConnection();
        st = connection.createStatement();
        ArrayList<Double> tempResult = emImpl.emTrain(
                connection, st, tableName, maxIterationNumber, epsilon,
                clusterNumber, newTransformColumns, initClusterSize, trainModel);
        trainModel = generateEMModel(trainModel, newTransformColumns, clusterNumber, tempResult);
        if (!newDataSet.equals(this.dataSet)) {
            trainModel.setAllTransformMap_valueKey(transformMap);
        }
    } catch (Exception e) {
        logger.error(e);
        if (e instanceof WrongUsedException) {
            throw new AnalysisError(this, (WrongUsedException) e);
        } else if (e instanceof AnalysisError) {
            throw (AnalysisError) e;
        } else {
            throw new AnalysisException(e);
        }
    } finally {
        try {
            if (st != null) {
                st.close();
            }
            if (rs != null) {
                rs.close();
            }
        } catch (SQLException e) {
            logger.debug(e.toString());
            throw new AnalysisException(e.getLocalizedMessage());
        }
    }
    return trainModel;
}
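// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the original source: the nested loop above
// flattens transformer.getAllTransformMap_valueKey() -- per-column maps of
// original category value -> numeric code -- into one code -> value map, so
// cluster output can be translated back to the original labels. Isolated, the
// inversion looks like this (method and parameter names are assumptions):
static java.util.HashMap<String, String> invertValueKeyMaps(
        java.util.HashMap<String, java.util.HashMap<String, String>> perColumnMaps) {
    java.util.HashMap<String, String> inverted = new java.util.HashMap<String, String>();
    for (java.util.HashMap<String, String> values : perColumnMaps.values()) {
        for (java.util.Map.Entry<String, String> e : values.entrySet()) {
            inverted.put(e.getValue(), e.getKey()); // swap: numeric code -> original value
        }
    }
    return inverted;
}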
public Model train(DataSet dataSet, SVMParameter parameter) throws OperatorException {
    para = parameter;
    setDataSourceInfo(DataSourceInfoFactory.createConnectionInfo(
            ((DBTable) dataSet.getDBTable()).getDatabaseConnection().getProperties().getName()));
    DatabaseConnection databaseConnection =
            ((DBTable) dataSet.getDBTable()).getDatabaseConnection();
    Column label = dataSet.getColumns().getLabel();
    String labelString = StringHandler.doubleQ(label.getName());

    // Categorical predictors are recoded as numeric; training runs on the
    // transformed table.
    DataSet newDataSet = getTransformer().TransformCategoryToNumeric_new(dataSet);
    String newTableName = ((DBTable) newDataSet.getDBTable()).getTableName();

    Statement st = null;
    ResultSet rs = null;
    try {
        st = databaseConnection.createStatement(false);
    } catch (SQLException e) {
        itsLogger.error(e);
        throw new OperatorException(e.getLocalizedMessage());
    }

    StringBuffer ind = getColumnArray(newDataSet);
    StringBuffer where = getColumnWhere(newDataSet);
    where.append(" and ").append(labelString).append(" is not null ");

    SVMRegressionModel model = new SVMRegressionModel(dataSet, newDataSet);
    if (!newDataSet.equals(dataSet)) {
        model.setAllTransformMap_valueKey(getTransformer().getAllTransformMap_valueKey());
    }
    model.setKernelType(para.getKernelType());
    model.setDegree(para.getDegree());
    model.setGamma(para.getGamma());

    // The training itself happens inside the database via the
    // alpine_miner_online_sv_reg UDF; the outer select unpacks the fields of
    // the returned composite model value.
    String sql = "select (model).inds, (model).cum_err, (model).epsilon, (model).rho, (model).b,"
            + " (model).nsvs, (model).ind_dim, (model).weights, (model).individuals"
            + " from (select alpine_miner_online_sv_reg('" + newTableName + "','" + ind + "','"
            + labelString + "','" + where + "'," + para.getKernelType() + "," + para.getDegree()
            + "," + para.getGamma() + "," + para.getEta() + ","
            + ((SVMRegressionParameter) para).getSlambda() + "," + para.getNu() + ") as model ";
    if (getDataSourceInfo().getDBType().equals(DataSourceInfoOracle.dBType)) {
        sql += " from dual ";
    }
    sql += ") a";

    try {
        itsLogger.debug("SVMRegression.train():sql=" + sql);
        rs = st.executeQuery(sql);
        setModel(rs, model);
        if (getTransformer().isTransform()) {
            dropTable(st, newTableName);
        }
        rs.close();
        st.close();
    } catch (SQLException e) {
        itsLogger.error(e);
        throw new OperatorException(e.getLocalizedMessage());
    }
    return model;
}
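// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the original source: getColumnArray(...)
// and getColumnWhere(...) are not defined in this section. From the way their
// results are spliced into the alpine_miner_online_sv_reg call -- a column
// list plus a predicate that later has " and <label> is not null " appended --
// a minimal getColumnWhere might look like this (the helper name and the
// leading "1=1" seed are assumptions):
static StringBuffer columnWhereSketch(java.util.List<String> columnNames) {
    StringBuffer where = new StringBuffer(" 1=1 ");
    for (String name : columnNames) {
        // StringHandler.doubleQ in the original double-quotes identifiers.
        where.append(" and \"").append(name).append("\" is not null ");
    }
    return where;
}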
private void performOperation(
        DatabaseConnection databaseConnection, DataSet dataSet, Locale locale)
        throws AnalysisError, OperatorException {
    String outputTableName = getQuotaedTableName(getOutputSchema(), getOutputTable());
    String inputTableName = getQuotaedTableName(getInputSchema(), getInputTable());
    Columns atts = dataSet.getColumns();
    String dbType = databaseConnection.getProperties().getName();
    IDataSourceInfo dataSourceInfo = DataSourceInfoFactory.createConnectionInfo(dbType);
    IMultiDBUtility multiDBUtility = MultiDBUtilityFactory.createConnectionInfo(dbType);
    ISqlGeneratorMultiDB sqlGenerator = SqlGeneratorMultiDBFactory.createConnectionInfo(dbType);
    dropIfExist(dataSet);
    DatabaseUtil.alterParallel(databaseConnection, getOutputType()); // for Oracle

    StringBuilder sb_create = new StringBuilder("create ");
    StringBuilder insertTable = new StringBuilder();
    if (getOutputType().equalsIgnoreCase("table")) {
        sb_create.append(" table ");
    } else {
        sb_create.append(" view ");
    }
    sb_create.append(outputTableName);
    sb_create.append(getOutputType().equalsIgnoreCase(Resources.TableType) ? getAppendOnlyString() : "");
    sb_create.append(DatabaseUtil.addParallel(databaseConnection, getOutputType())).append(" as (");

    StringBuilder selectSql = new StringBuilder(" select ");
    selectSql.append(StringHandler.doubleQ(groupColumn)).append(",");

    Column att = atts.get(columnNames);
    dataSet.computeColumnStatistics(att);
    if (att.isNumerical()) {
        logger.error("PivotTableAnalyzer cannot accept a numeric column");
        throw new AnalysisError(this, AnalysisErrorName.Not_numeric, locale,
                SDKLanguagePack.getMessage(SDKLanguagePack.PIVOT_NAME, locale));
    }
    String attName = StringHandler.doubleQ(att.getName());
    List<String> valueList = att.getMapping().getValues();
    if (!useArray
            && valueList.size() > Integer.parseInt(AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD)) {
        logger.error("Too many distinct values for column " + StringHandler.doubleQ(columnNames));
        throw new AnalysisError(this, AnalysisErrorName.Too_Many_Distinct_value, locale,
                StringHandler.doubleQ(columnNames), AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD);
    }
    if (valueList.size() <= 0) {
        logger.error("Empty table");
        throw new AnalysisError(this, AnalysisErrorName.Empty_table, locale);
    }

    String aggColumnName;
    if (!StringUtil.isEmpty(aggColumn)) {
        aggColumnName = StringHandler.doubleQ(aggColumn);
    } else {
        aggColumnName = "1";
    }

    // One aggregate expression per distinct value of the pivot column, either
    // packed into a single array column (useArray) or emitted as one output
    // column per value.
    Iterator<String> valueList_i = valueList.iterator();
    if (useArray) {
        if (dataSourceInfo.getDBType().equals(DataSourceInfoOracle.dBType)) {
            ArrayList<String> array = new ArrayList<String>();
            while (valueList_i.hasNext()) {
                String value = StringHandler.escQ(valueList_i.next());
                String newValue = "alpine_miner_null_to_0(" + aggrType + " (case when " + attName
                        + "=" + CommonUtility.quoteValue(dbType, att, value) + " then "
                        + aggColumnName + " end )) ";
                array.add(newValue);
            }
            selectSql.append(CommonUtility.array2OracleArray(array, CommonUtility.OracleDataType.Float));
        } else {
            selectSql.append(multiDBUtility.floatArrayHead());
            while (valueList_i.hasNext()) {
                String value = valueList_i.next();
                selectSql.append("alpine_miner_null_to_0(").append(aggrType);
                selectSql.append(" (case when ").append(attName).append("=");
                value = StringHandler.escQ(value);
                selectSql.append(CommonUtility.quoteValue(dbType, att, value))
                        .append(" then ").append(aggColumnName).append(" end )) "); // else 0
                selectSql.append(",");
            }
            selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
            selectSql.append(multiDBUtility.floatArrayTail());
        }
        selectSql.append(" " + StringHandler.doubleQ(att.getName()));
    } else {
        if (((DBTable) dataSet.getDBTable()).getDatabaseConnection().getProperties().getName()
                .equals(DataSourceInfoNZ.dBType)) {
            // Netezza: no alpine_miner_null_to_0 here; nulls are rewritten to 0
            // by the wrapping select built further below.
            while (valueList_i.hasNext()) {
                String value = valueList_i.next();
                selectSql.append("(").append(aggrType);
                selectSql.append(" (case when ").append(attName).append("=");
                value = StringHandler.escQ(value);
                selectSql.append(CommonUtility.quoteValue(dbType, att, value))
                        .append(" then ").append(aggColumnName).append(" end )) "); // else 0
                String colName = StringHandler.doubleQ(att.getName() + "_" + value);
                selectSql.append(colName);
                selectSql.append(",");
            }
            selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
        } else if (((DBTable) dataSet.getDBTable()).getDatabaseConnection().getProperties().getName()
                .equals(DataSourceInfoDB2.dBType)) {
            // DB2 needs an explicit double() cast around the case expression.
            while (valueList_i.hasNext()) {
                String value = valueList_i.next();
                selectSql.append("alpine_miner_null_to_0(").append(aggrType);
                selectSql.append(" (double(case when ").append(attName).append("=");
                value = StringHandler.escQ(value);
                selectSql.append(CommonUtility.quoteValue(dbType, att, value))
                        .append(" then ").append(aggColumnName).append(" end ))) "); // else 0
                String colName = StringHandler.doubleQ(att.getName() + "_" + value);
                selectSql.append(colName);
                selectSql.append(",");
            }
            selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
        } else {
            while (valueList_i.hasNext()) {
                String value = valueList_i.next();
                selectSql.append("alpine_miner_null_to_0(").append(aggrType);
                selectSql.append(" (case when ").append(attName).append("=");
                value = StringHandler.escQ(value);
                selectSql.append(CommonUtility.quoteValue(dbType, att, value))
                        .append(" then ").append(aggColumnName).append(" end )) "); // else 0
                String colName = StringHandler.doubleQ(att.getName() + "_" + value);
                selectSql.append(colName);
                selectSql.append(",");
            }
            selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
        }
    }

    selectSql.append(" from ").append(inputTableName).append(" foo group by ");
    selectSql.append(StringHandler.doubleQ(groupColumn));

    if (((DBTable) dataSet.getDBTable()).getDatabaseConnection().getProperties().getName()
            .equals(DataSourceInfoNZ.dBType)) {
        // Wrap the Netezza query so that null aggregates come back as 0.
        StringBuilder sb = new StringBuilder();
        sb.append("select ").append(StringHandler.doubleQ(groupColumn)).append(",");
        Iterator<String> valueList_new = valueList.iterator();
        while (valueList_new.hasNext()) {
            String value = valueList_new.next();
            String colName = StringHandler.doubleQ(att.getName() + "_" + value);
            sb.append("case when ").append(colName).append(" is null then 0 else ");
            sb.append(colName).append(" end ").append(colName).append(",");
        }
        sb = sb.deleteCharAt(sb.length() - 1);
        sb.append(" from (").append(selectSql).append(") foo ");
        selectSql = sb;
    }

    sb_create.append(selectSql).append(" )");
    if (getOutputType().equalsIgnoreCase("table")) {
        sb_create.append(getEndingString());
        insertTable.append(sqlGenerator.insertTable(selectSql.toString(), outputTableName));
    }

    try {
        Statement st = databaseConnection.createStatement(false);
        logger.debug("PivotTableAnalyzer.performOperation():sql=" + sb_create);
        st.execute(sb_create.toString());
        if (insertTable.length() > 0) {
            st.execute(insertTable.toString());
            logger.debug("PivotTableAnalyzer.performOperation():insertTableSql=" + insertTable);
        }
    } catch (SQLException e) {
        logger.error(e);
        if (e.getMessage().startsWith("ORA-03001")
                || e.getMessage().startsWith("ERROR: invalid identifier")) {
            throw new AnalysisError(this, AnalysisErrorName.Invalid_Identifier, locale);
        } else {
            throw new OperatorException(e.getLocalizedMessage());
        }
    }
}
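// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the original source: getQuotaedTableName
// (sic) is used above to build the qualified input/output table identifiers
// but is not defined in this section. A minimal version consistent with its
// call sites, double-quoting both parts so mixed-case names survive, might
// look like this (the exact per-database quoting rules are an assumption):
static String quotedTableNameSketch(String schema, String table) {
    return "\"" + schema + "\".\"" + table + "\"";
}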