@Override public AnalyticOutPut doAnalysis(AnalyticSource source) throws AnalysisException { try { DataSet dataSet = getDataSet((DataBaseAnalyticSource) source, source.getAnalyticConfig()); // dataSet.recalculateAllcolumnStatistics(); DatabaseConnection databaseConnection = ((DBTable) dataSet.getDBTable()).getDatabaseConnection(); PivotTableConfig config = (PivotTableConfig) source.getAnalyticConfig(); setInputSchema(((DataBaseAnalyticSource) source).getTableInfo().getSchema()); setInputTable(((DataBaseAnalyticSource) source).getTableInfo().getTableName()); setOutputType(config.getOutputType()); setOutputSchema(config.getOutputSchema()); setOutputTable(config.getOutputTable()); setDropIfExist(config.getDropIfExist()); columnNames = config.getPivotColumn(); groupColumn = config.getGroupByColumn(); aggColumn = config.getAggregateColumn(); aggrType = config.getAggregateType(); String dbType = databaseConnection.getProperties().getName(); if (config.getUseArray() != null && config.getUseArray().equalsIgnoreCase("true")) { if (dbType.equals(DataSourceInfoDB2.dBType) || dbType.equals(DataSourceInfoNZ.dBType)) { useArray = false; } else { useArray = true; } } else { useArray = false; } generateStoragePrameterString((DataBaseAnalyticSource) source); performOperation(databaseConnection, dataSet, config.getLocale()); DataBaseInfo dbInfo = ((DataBaseAnalyticSource) source).getDataBaseInfo(); AnalyzerOutPutTableObject outPut = getResultTableSampleRow(databaseConnection, dbInfo); outPut.setAnalyticNodeMetaInfo(createNodeMetaInfo(config.getLocale())); outPut.setDbInfo(dbInfo); outPut.setSchemaName(getOutputSchema()); outPut.setTableName(getOutputTable()); return outPut; } catch (Exception e) { logger.error(e); if (e instanceof WrongUsedException) { throw new AnalysisError(this, (WrongUsedException) e); } else if (e instanceof AnalysisError) { throw (AnalysisError) e; } else { throw new AnalysisException(e); } } }
/**
 * Builds the list of stop criteria used while constructing a tree.
 *
 * <p>The list always contains, in order: a purity stop ({@code ClassPureStop} when
 * {@code loadData} is true, otherwise {@code DBPureStop}), a {@code NoColumnStop}, a
 * {@code NoDataStop} and a {@code DepthStop}. When the configured maximum depth is not positive,
 * the data set size is used as the depth limit instead.
 *
 * @param dataSet the data set whose size is the fallback depth limit
 * @param loadData selects which purity stop is added
 * @return the stop criteria in evaluation order
 * @throws OperatorException declared for callers; not raised directly by this method
 */
public List<Stop> getStop(DataSet dataSet, boolean loadData) throws OperatorException {
  List<Stop> stops = new LinkedList<Stop>();

  Stop purityStop = loadData ? new ClassPureStop() : new DBPureStop();
  stops.add(purityStop);
  stops.add(new NoColumnStop());
  stops.add(new NoDataStop());

  long configuredDepth = para.getMaxDepth();
  long depthLimit = configuredDepth > 0 ? configuredDepth : dataSet.size();
  stops.add(new DepthStop(depthLimit));

  return stops;
}
/**
 * Builds the SQL {@code update} statement that writes the predicted nominal label and the
 * per-class confidence columns into {@code newTableName}.
 *
 * <p>{@code caseSql} is appended to in place: it receives a {@code case} expression that yields
 * the label value of the first class whose {@code biggerSql} condition holds, falling back to the
 * last class.
 *
 * @param dataSet data set whose special confidence columns are looked up by name
 * @param predictedLabel column that receives the predicted label
 * @param newTableName table to update
 * @param numberOfClasses number of label values
 * @param classProbabilitiesSql SQL expression per class assigned to its confidence column
 * @param caseSql buffer that is extended with the {@code case} expression (mutated)
 * @param biggerSql condition per class used in the {@code when} branches
 * @return the complete update statement, ending with the result of {@code getWherePredict()}
 */
protected StringBuffer getNominalUpdate(
    DataSet dataSet,
    Column predictedLabel,
    String newTableName,
    int numberOfClasses,
    StringBuffer[] classProbabilitiesSql,
    StringBuffer caseSql,
    StringBuffer[] biggerSql) {
  int lastClass = numberOfClasses - 1;

  // case expression: one "when" per class except the last, which becomes the "else" branch
  caseSql.append(" (case ");
  for (int idx = 0; idx < lastClass; idx++) {
    caseSql.append(" when ");
    caseSql.append(biggerSql[idx]);
    caseSql.append(" then '");
    caseSql.append(StringHandler.escQ(getLabel().getMapping().mapIndex(idx)));
    caseSql.append("'");
  }
  caseSql.append(" else '");
  caseSql.append(StringHandler.escQ(getLabel().getMapping().mapIndex(lastClass)));
  caseSql.append("' end)");

  StringBuffer update = new StringBuffer();
  update
      .append(" update ")
      .append(newTableName)
      .append(" set ")
      .append(StringHandler.doubleQ(predictedLabel.getName()))
      .append("=")
      .append(caseSql);

  // one assignment per class confidence column
  for (int idx = 0; idx < numberOfClasses; idx++) {
    String confidenceKey = Column.CONFIDENCE_NAME + "_" + getLabel().getMapping().mapIndex(idx);
    update
        .append(", ")
        .append(StringHandler.doubleQ(dataSet.getColumns().getSpecial(confidenceKey).getName()))
        .append(" = ")
        .append(classProbabilitiesSql[idx]);
  }

  update.append(getWherePredict());
  return update;
}
/**
 * Creates the tree builder for the given data set.
 *
 * <p>A builder is only created when the label column is nominal or the parameters request WOE
 * mode; otherwise {@code null} is returned. When chi-square or WOE mode is enabled, the
 * chi-square flag is passed to the builder as an additional constructor argument.
 *
 * @param dataSet the training data set
 * @return the configured {@code ConstructTree}, or {@code null} if the label is not nominal and
 *     WOE mode is off
 * @throws OperatorException if one of the helper factories fails
 */
protected ConstructTree getTB(DataSet dataSet) throws OperatorException {
  boolean nominalLabel = dataSet.getColumns().getLabel().isNominal();
  if (!nominalLabel && !para.isForWoe()) {
    return null;
  }

  if (para.isUseChiSquare() || para.isForWoe()) {
    return new ConstructTree(
        createStandard(false),
        createStandard(true),
        getStop(dataSet, false),
        getStop(dataSet, true),
        new DBBuildLeaf(),
        new BuildLeaf(),
        getPrune(false),
        para.isNoPrePruning(),
        para.getPrepruningAlternativesNumber(),
        para.getSplitMinSize(),
        para.getMinLeafSize(),
        para.getThresholdLoadData(),
        para.isUseChiSquare());
  }

  return new ConstructTree(
      createStandard(false),
      createStandard(true),
      getStop(dataSet, false),
      getStop(dataSet, true),
      new DBBuildLeaf(),
      new BuildLeaf(),
      getPrune(false),
      para.isNoPrePruning(),
      para.getPrepruningAlternativesNumber(),
      para.getSplitMinSize(),
      para.getMinLeafSize(),
      para.getThresholdLoadData());
}
@Override protected Model train(AnalyticSource analyticSource) throws AnalysisException { ResultSet rs = null; Statement st = null; try { IDataSourceInfo dataSourceInfo = DataSourceInfoFactory.createConnectionInfo(analyticSource.getDataSourceType()); dbtype = dataSourceInfo.getDBType(); RandomForestModel lastResult = null; RandomForestIMP randomForestImpl = null; // if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) { // randomForestTrainer = new AdaboostOracle(); // // } else if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType) || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) { randomForestImpl = new RandomForestGreenplum(); } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) { randomForestImpl = new RandomForestOracle(); } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) { randomForestImpl = new RandomForestDB2(); ((RandomForestDB2) randomForestImpl) .setConnection(((DataBaseAnalyticSource) analyticSource).getConnection()); } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) { randomForestImpl = new RandomForestNZ(); } else { throw new AnalysisException("Databse type is not supported for Random Forest:" + dbtype); // return null; } try { dataSet = getDataSet((DataBaseAnalyticSource) analyticSource, analyticSource.getAnalyticConfig()); } catch (OperatorException e1) { logger.error(e1); throw new OperatorException(e1.getLocalizedMessage()); } setSpecifyColumn(dataSet, analyticSource.getAnalyticConfig()); dataSet.computeAllColumnStatistics(); RandomForestConfig rfConfig = (RandomForestConfig) analyticSource.getAnalyticConfig(); String dbSystem = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getSystem(); String url = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUrl(); String userName = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUserName(); String password = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getPassword(); String inputSchema = 
((DataBaseAnalyticSource) analyticSource).getTableInfo().getSchema(); String tableName = ((DataBaseAnalyticSource) analyticSource).getTableInfo().getTableName(); String useSSL = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUseSSL(); String sampleWithReplacement = rfConfig.getSampleWithReplacement(); long timeStamp = System.currentTimeMillis(); pnewTable = "pnew" + timeStamp; sampleTable = "s" + timeStamp; String dependentColumn = rfConfig.getDependentColumn(); String columnNames = rfConfig.getColumnNames(); String[] totalColumns = columnNames.split(","); int subSize = Integer.parseInt(rfConfig.getNodeColumnNumber()); int forestSize = Integer.parseInt(rfConfig.getForestSize()); Connection conncetion = null; if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType) || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) { lastResult = new RandomForestModelGreenplum(dataSet); } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) { lastResult = new RandomForestModelOracle(dataSet); } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) { lastResult = new RandomForestModelDB2(dataSet); } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) { lastResult = new RandomForestModelNZ(dataSet); } lastResult.setColumnNames(columnNames); lastResult.setDependColumn(dependentColumn); lastResult.setTableName(tableName); conncetion = ((DataBaseAnalyticSource) analyticSource).getConnection(); Model result = null; try { st = conncetion.createStatement(); } catch (SQLException e) { logger.error(e); throw new AnalysisException(e); } // Iterator<String> dependvalueIterator = dataSet.getColumns() // .getLabel().getMapping().getValues().iterator(); if (dataSet.getColumns().getLabel() instanceof NominalColumn) { if (dataSet.getColumns().getLabel().getMapping().getValues().size() <= 1) { String e = SDKLanguagePack.getMessage( SDKLanguagePack.ADABOOST_SAMPLE_ERRINFO, rfConfig.getLocale()); logger.error(e); throw new AnalysisException(e); } if 
(dataSet.getColumns().getLabel().getMapping().getValues().size() > AlpineMinerConfig.ADABOOST_MAX_DEPENDENT_COUNT) { String e = SDKLanguagePack.getMessage( SDKLanguagePack.ADABOOST_MAX_DEPENDENT_COUNT_ERRINFO, rfConfig.getLocale()); logger.error(e); throw new AnalysisException(e); } } try { randomForestImpl.randomForestTrainInit( inputSchema, tableName, timeStamp, dependentColumn, st, dataSet); } catch (SQLException e) { logger.error(e); throw new AnalysisException(e); } CartConfig config = new CartConfig(); config.setDependentColumn(dependentColumn); config.setConfidence(rfConfig.getConfidence()); config.setMaximal_depth(rfConfig.getMaximal_depth()); config.setMinimal_leaf_size(rfConfig.getMinimal_leaf_size()); config.setMinimal_size_for_split(rfConfig.getMinimal_size_for_split()); config.setNo_pre_pruning("true"); config.setNo_pruning("true"); for (int i = 0; i < forestSize; i++) { CartTrainer analyzer = new CartTrainer(); if (sampleWithReplacement == Resources.TrueOpt) { randomForestImpl.randomForestSample( inputSchema, timeStamp + "" + i, dependentColumn, st, rs, pnewTable, sampleTable + i, rfConfig.getLocale()); } else { randomForestImpl.randomForestSampleNoReplace( inputSchema, timeStamp + "" + i, dependentColumn, st, rs, pnewTable, sampleTable + i, rfConfig.getLocale(), dataSet.size()); } String subColumns = getSubColumns(totalColumns, subSize); config.setColumnNames(subColumns); DataBaseAnalyticSource tempsource = new DataBaseAnalyticSource( dbSystem, url, userName, password, inputSchema, sampleTable + i, useSSL); tempsource.setAnalyticConfiguration(config); tempsource.setConenction(conncetion); result = ((AnalyzerOutPutTrainModel) analyzer.doAnalysis(tempsource)) .getEngineModel() .getModel(); String OOBTable = "OOB" + sampleTable + i; randomForestImpl.generateOOBTable( inputSchema, OOBTable, pnewTable, sampleTable + i, st, rs); DataBaseAnalyticSource tempPredictSource = new DataBaseAnalyticSource( dbSystem, url, userName, password, inputSchema, OOBTable, 
useSSL); String predictOutTable = "OOBPredict" + sampleTable; EngineModel em = new EngineModel(); em.setModel(result); PredictorConfig tempconfig = new PredictorConfig(em); tempconfig.setDropIfExist(dropIfExists); tempconfig.setOutputSchema(inputSchema); tempconfig.setOutputTable(predictOutTable); tempPredictSource.setAnalyticConfiguration(tempconfig); tempPredictSource.setConenction(conncetion); AbstractDBModelPredictor predictor = new CartPredictor(); predictor.doAnalysis(tempPredictSource); // use the weak alg , do double OOBError = 0.0; if (result instanceof DecisionTreeModel) { OOBError = randomForestImpl.getOOBError( tempPredictSource, dependentColumn, "P(" + dependentColumn + ")"); lastResult.getOobEstimateError().add(OOBError); } else if (result instanceof RegressionTreeModel) { OOBError = randomForestImpl.getMSE(tempPredictSource, "P(" + dependentColumn + ")"); lastResult.getOobLoss().add(OOBError); double OOBMape = randomForestImpl.getMAPE( tempPredictSource, dependentColumn, "P(" + dependentColumn + ")"); lastResult.getOobMape().add(OOBMape); } else { OOBError = Double.NaN; lastResult.getOobLoss().add(OOBError); } lastResult.addModel((SingleModel) result); randomForestImpl.clearTrainResult(inputSchema, sampleTable + i); randomForestImpl.clearTrainResult(inputSchema, predictOutTable); randomForestImpl.clearTrainResult(inputSchema, OOBTable); } return lastResult; } catch (Exception e) { logger.error(e); if (e instanceof WrongUsedException) { throw new AnalysisError(this, (WrongUsedException) e); } else if (e instanceof AnalysisError) { throw (AnalysisError) e; } else { throw new AnalysisException(e); } } finally { try { if (st != null) { st.close(); } if (rs != null) { rs.close(); } } catch (SQLException e) { logger.error(e); throw new AnalysisException(e.getLocalizedMessage()); } } }
/**
 * Trains an SVM regression model inside the database.
 *
 * <p>Categorical columns are first transformed to numeric ones via the transformer, then the
 * database function {@code alpine_miner_online_sv_reg} is invoked with the transformed table,
 * the column array, the label, the not-null filter and the kernel parameters. The returned row
 * is loaded into the model with {@code setModel}. If the transformer created a temporary table,
 * it is dropped after the model has been read.
 *
 * @param dataSet the training data; its label column is the regression target
 * @param parameter SVM parameters; it is cast to {@code SVMRegressionParameter} to read
 *     {@code getSlambda()}
 * @return the trained {@code SVMRegressionModel}
 * @throws OperatorException if a statement cannot be created or the training query fails
 */
public Model train(DataSet dataSet, SVMParameter parameter) throws OperatorException {
  para = parameter;
  setDataSourceInfo(
      DataSourceInfoFactory.createConnectionInfo(
          ((DBTable) dataSet.getDBTable()).getDatabaseConnection().getProperties().getName()));
  DatabaseConnection databaseConnection =
      ((DBTable) dataSet.getDBTable()).getDatabaseConnection();
  Column label = dataSet.getColumns().getLabel();
  String labelString = StringHandler.doubleQ(label.getName());
  DataSet newDataSet = getTransformer().TransformCategoryToNumeric_new(dataSet);
  String newTableName = ((DBTable) newDataSet.getDBTable()).getTableName();

  Statement st = null;
  ResultSet rs = null;
  try {
    st = databaseConnection.createStatement(false);
  } catch (SQLException e) {
    itsLogger.error(e.getMessage(), e);
    throw new OperatorException(e.getLocalizedMessage());
  }

  try {
    StringBuffer ind = getColumnArray(newDataSet);
    StringBuffer where = getColumnWhere(newDataSet);
    where.append(" and ").append(labelString).append(" is not null ");

    SVMRegressionModel model = new SVMRegressionModel(dataSet, newDataSet);
    if (!newDataSet.equals(dataSet)) {
      model.setAllTransformMap_valueKey(getTransformer().getAllTransformMap_valueKey());
    }
    model.setKernelType(para.getKernelType());
    model.setDegree(para.getDegree());
    model.setGamma(para.getGamma());

    String sql =
        "select (model).inds, (model).cum_err, (model).epsilon, (model).rho, (model).b, (model).nsvs, (model).ind_dim, (model).weights, (model).individuals from (select alpine_miner_online_sv_reg('"
            + newTableName
            + "','"
            + ind
            + "','"
            + labelString
            + "','"
            + where
            + "',"
            + para.getKernelType()
            + ","
            + para.getDegree()
            + ","
            + para.getGamma()
            + ","
            + para.getEta()
            + ","
            + ((SVMRegressionParameter) para).getSlambda()
            + ","
            + para.getNu()
            + ") as model ";
    // Oracle requires a FROM clause for a select without a table.
    if (getDataSourceInfo().getDBType().equals(DataSourceInfoOracle.dBType)) {
      sql += " from dual ";
    }
    sql += ") a";

    try {
      itsLogger.debug("SVMRegression.train():sql=" + sql);
      rs = st.executeQuery(sql);
      setModel(rs, model);
      if (getTransformer().isTransform()) {
        dropTable(st, newTableName);
      }
      rs.close();
      st.close();
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    }
    return model;
  } finally {
    // Release the cursor and statement on every failure path. On the success path both are
    // already closed, and closing an already-closed JDBC object is a no-op.
    if (rs != null) {
      try {
        rs.close();
      } catch (SQLException closeError) {
        itsLogger.debug("SVMRegression.train(): could not close result set: " + closeError);
      }
    }
    try {
      st.close();
    } catch (SQLException closeError) {
      itsLogger.debug("SVMRegression.train(): could not close statement: " + closeError);
    }
  }
}
protected void adjustPerWholeData( DataSet dataSet, List<String[]> hiddenLayers, int maxCycles, double maxError, double learningRate, double momentum, boolean decay, boolean normalize, AlpineRandom randomGenerator, int fetchSize, boolean adjustPerRow, Column label, int numberOfClasses, DatabaseConnection databaseConnection, String tableName) throws OperatorException { double totalSize = dataSet.size(); Statement st = null; ResultSet rs = null; try { st = databaseConnection.createStatement(true); } catch (SQLException e) { throw new OperatorException(e.getLocalizedMessage()); } double error = 0; // optimization loop for (int cycle = 0; cycle < maxCycles; cycle++) { double tempRate = learningRate; if (decay) { tempRate /= (cycle + 1); } StringBuffer sql = new StringBuffer(); sql.append("select ").append(getArraySum()).append("("); // sql.append("select ").append("("); sql.append(getAllWeightChange(getNumberOfClasses(label))); sql.append(") from ").append(tableName); try { if (useFloatArraySumCursor()) { StringBuffer varcharArray = CommonUtility.splitOracleSqlToVarcharArray(sql); sql = new StringBuffer(); sql.append("select floatarraysum_cursor(").append(varcharArray).append(") from dual"); } itsLogger.debug("NNModel.adjustPerWholeData():sql=" + sql); rs = st.executeQuery(sql.toString()); if (rs.next()) { Double[] currentChanges = getResult(rs); // String currentChangesString= new String(); for (int i = 0; i < currentChanges.length; i++) { if (Double.isNaN(currentChanges[i]) || Double.isInfinite(currentChanges[i])) { try { if (rs != null) { rs.close(); } if (st != null) { st.close(); } databaseConnection.getConnection().setAutoCommit(true); } catch (SQLException e) { e.printStackTrace(); // throw new OperatorException(e.getLocalizedMessage ()); copyBestErrorWeightsToWeights(); return; } copyBestErrorWeightsToWeights(); return; } } error = currentChanges[currentChanges.length - 1] / numberOfClasses / totalSize; if (error < bestError) { bestError = error; 
copyWeightsToBestErrorWeights(); } updateWeight(currentChanges, tempRate, momentum); } } catch (SQLException e) { e.printStackTrace(); copyBestErrorWeightsToWeights(); // throw new OperatorException(e.getLocalizedMessage()); return; } // error /= totalSize; itsLogger.debug("cycle" + cycle + ";error:" + error); if (error < maxError) { itsLogger.debug("loop break : " + cycle + " error: " + error); break; } if (Double.isInfinite(error) || Double.isNaN(error)) { if (Tools.isLessEqual(learningRate, 0.0d)) // should hardly happen throw new RuntimeException("Cannot reset network to a smaller learning rate."); learningRate /= 2; train( dataSet, hiddenLayers, maxCycles, maxError, learningRate, momentum, decay, normalize, randomGenerator, fetchSize, adjustPerRow); } } copyBestErrorWeightsToWeights(); try { if (rs != null) { rs.close(); } if (st != null) { st.close(); } } catch (SQLException e) { e.printStackTrace(); // throw new OperatorException(e.getLocalizedMessage ()); // copyBestErrorWeightsToWeights(); return; } }
/* (non-Javadoc) * @see com.alpine.datamining.api.impl.db.AbstractDBModelTrainer#train(com.alpine.datamining.api.AnalyticSource) */ @Override protected Model train(AnalyticSource source) throws AnalysisException { ResultSet rs = null; Statement st = null; EMModel trainModel = null; try { IDataSourceInfo dataSourceInfo = DataSourceInfoFactory.createConnectionInfo(source.getDataSourceType()); dbtype = dataSourceInfo.getDBType(); EMConfig config = (EMConfig) source.getAnalyticConfig(); String anaColumns = config.getColumnNames(); String[] columnsArray = anaColumns.split(","); List<String> transformColumns = new ArrayList<String>(); for (int i = 0; i < columnsArray.length; i++) { transformColumns.add(columnsArray[i]); } DataSet dataSet = getDataSet((DataBaseAnalyticSource) source, config); filerColumens(dataSet, transformColumns); dataSet.computeAllColumnStatistics(); ColumnTypeTransformer transformer = new ColumnTypeTransformer(); DataSet newDataSet = transformer.TransformCategoryToNumeric_new(dataSet); String tableName = ((DBTable) newDataSet.getDBTable()).getTableName(); Columns columns = newDataSet.getColumns(); List<String> newTransformColumns = new ArrayList<String>(); HashMap<String, String> transformMap = new HashMap<String, String>(); for (String key : transformer.getAllTransformMap_valueKey().keySet()) { HashMap<String, String> values = (transformer.getAllTransformMap_valueKey()).get(key); for (String lowKey : values.keySet()) { transformMap.put(values.get(lowKey), lowKey); } } Iterator<Column> attributeIter = columns.iterator(); while (attributeIter.hasNext()) { Column column = attributeIter.next(); newTransformColumns.add(column.getName()); } int maxIterationNumber = Integer.parseInt(config.getMaxIterationNumber()); int clusterNumber = Integer.parseInt(config.getClusterNumber()); double epsilon = Double.parseDouble(config.getEpsilon()); int initClusterSize = 10; if (config.getInitClusterSize() != null) { initClusterSize = 
Integer.parseInt(config.getInitClusterSize()); } if (newDataSet.size() < initClusterSize * clusterNumber) { initClusterSize = (int) (newDataSet.size() / clusterNumber + 1); } // TODO get it from config and make sure it will not be too large EMClusterImpl emImpl = EMClusterFactory.createEMAnalyzer(dbtype); trainModel = EMClusterFactory.createEMModel(dbtype, newDataSet); Connection connection = null; connection = ((DataBaseAnalyticSource) source).getConnection(); st = connection.createStatement(); ArrayList<Double> tempResult = emImpl.emTrain( connection, st, tableName, maxIterationNumber, epsilon, clusterNumber, newTransformColumns, initClusterSize, trainModel); trainModel = generateEMModel(trainModel, newTransformColumns, clusterNumber, tempResult); if (!newDataSet.equals(this.dataSet)) { trainModel.setAllTransformMap_valueKey(transformMap); } } catch (Exception e) { logger.error(e); if (e instanceof WrongUsedException) { throw new AnalysisError(this, (WrongUsedException) e); } else if (e instanceof AnalysisError) { throw (AnalysisError) e; } else { throw new AnalysisException(e); } } finally { try { if (st != null) { st.close(); } if (rs != null) { rs.close(); } } catch (SQLException e) { logger.debug(e.toString()); throw new AnalysisException(e.getLocalizedMessage()); } } return trainModel; }
@Override protected AnalyzerOutPutObject doValidate(DataBaseAnalyticSource source, EvaluatorConfig config) throws AnalysisException { DoubleListData result = null; List<DoubleListData> resultList = new ArrayList<DoubleListData>(); try { DataSet dataSet = getDataSet(source, source.getAnalyticConfig()); // dataSet.recalculateAllcolumnStatistics(); Operator operator = OperatorUtil.createOperator(LiftDataGeneratorGeneral.class); EvaluatorParameter parameter = new EvaluatorParameter(); if (config.getUseModel() != null) { // operator.setParameter(LiftDataGeneratorGeneral.PARAMETER_USE_MODEL, // config.getUseModel()); parameter.setUseModel(Boolean.parseBoolean(config.getUseModel())); } if (config.getColumnValue() != null) { // operator.setParameter(LiftDataGeneratorGeneral.PARAMETER_TARGET_CLASS, // config.getColumnValue()); parameter.setColumnValue(config.getColumnValue()); } operator.setParameter(parameter); if (config.getUseModel().equals("true")) { List<EngineModel> models = config.getTrainedModel(); Model model = null; for (int i = 0; i < models.size(); i++) { Container container = new Container(); model = models.get(i).getModel(); container = container.append(dataSet); container = container.append(model); Container resultContainer = operator.apply(container); result = resultContainer.get(DoubleListData.class); result.setSourceName(models.get(i).getName()); resultList.add(result); } } else { Container container = new Container(); String targetClass = config.getColumnValue(); Column confidence = dataSet.getColumns().get(Column.CONFIDENCE_NAME + "(" + targetClass + ")"); dataSet .getColumns() .setSpecialColumn(confidence, Column.CONFIDENCE_NAME + "_" + targetClass); container = container.append(dataSet); Container resultContainer = operator.apply(container); result = resultContainer.get(DoubleListData.class); result.setSourceName(getName()); resultList.add(result); } } catch (Exception e) { logger.error(e); if (e instanceof WrongUsedException) { throw new 
AnalysisError(this, (WrongUsedException) e); } else if (e instanceof AnalysisError) { throw (AnalysisError) e; } else { throw new AnalysisException(e); } } AnalyzerOutPutObject out = new AnalyzerOutPutObject(resultList); out.setAnalyticNodeMetaInfo(createNodeMetaInfo(config.getLocale())); return out; }
/**
 * Fits one linear regression per group of the group-by column and collects the per-group
 * statistics in the grouped model.
 *
 * <p>Steps, as implemented below: transform categorical columns to numeric ones, compute the
 * coefficients and R2 per group ({@code getCoefficientAndR2Group}), derive the degrees of freedom
 * per group from the group counts, query the per-group S value, compute the variance/covariance
 * matrices and finally the remaining statistics. Groups whose matrix could not be computed get
 * the "matrix is singular" error text; groups whose columns contain nulls get the "table
 * contains null" error text listing those columns.
 *
 * @param dataSet the training data
 * @param para regression parameters; only the interaction model is read here
 * @param columnNames comma-separated column names handed to the transformer; may be null or blank
 * @return the grouped regression model held in the {@code model} field
 * @throws OperatorException if a statement cannot be created or any step fails
 */
public Model learn(DataSet dataSet, LinearRegressionParameter para, String columnNames)
    throws OperatorException {
  this.dataSet = dataSet;
  ArrayList<String> columnNamesList = new ArrayList<String>();
  if (columnNames != null && !StringUtil.isEmpty(columnNames.trim())) {
    for (String s : columnNames.split(",")) {
      columnNamesList.add(s);
    }
  }
  transformer.setColumnNames(columnNamesList);
  transformer.setAnalysisInterActionModel(para.getAnalysisInterActionModel());
  newDataSet = transformer.TransformCategoryToNumeric_new(dataSet, groupbyColumn);
  DatabaseConnection databaseConnection =
      ((DBTable) newDataSet.getDBTable()).getDatabaseConnection();
  Column label = newDataSet.getColumns().getLabel();
  String labelName = StringHandler.doubleQ(label.getName());
  String tableName = ((DBTable) dataSet.getDBTable()).getTableName();
  String newTableName = ((DBTable) newDataSet.getDBTable()).getTableName();
  try {
    st = databaseConnection.createStatement(false);
  } catch (SQLException e) {
    itsLogger.error(e.getMessage(), e);
    throw new OperatorException(e.getLocalizedMessage());
  }
  try {
    newDataSet.computeAllColumnStatistics();
    Columns atts = newDataSet.getColumns();
    Iterator<Column> atts_i = atts.iterator();
    int count = 0;
    String[] columnNamesArray = new String[atts.size()];
    while (atts_i.hasNext()) {
      Column att = atts_i.next();
      columnNamesArray[count] = att.getName();
      count++;
    }
    null_list = calculateNull(dataSet);
    null_list_group = calculateNullGroup(newDataSet, atts);
    StringBuilder sb_notNull = getWhere(atts);
    getCoefficientAndR2Group(
        columnNames,
        dataSet,
        labelName,
        tableName,
        newTableName,
        atts,
        columnNamesArray,
        sb_notNull);

    // Degrees of freedom per group; S is undefined for groups without positive freedom.
    HashMap<String, Long> degreeOfFreedom = new HashMap<String, Long>();
    for (String groupValue : groupCount.keySet()) {
      long tempDof = groupCount.get(groupValue) - columnNamesArray.length - 1;
      if (tempDof <= 0) {
        model.getOneModel(groupValue).setS(Double.NaN);
      }
      degreeOfFreedom.put(groupValue, tempDof);
    }

    StringBuffer sSQL =
        createSSQLLGroup(
            newDataSet, newTableName, label, columnNamesArray, coefficients, sb_notNull);
    HashMap<String, Double> sValueMap = new HashMap<String, Double>();
    try {
      itsLogger.debug(classLogInfo + ".learn():sql=" + sSQL);
      rs = st.executeQuery(sSQL.toString());
      while (rs.next()) {
        String groupValue = rs.getString(2);
        if (groupValue == null) {
          continue;
        }
        if (dataErrorList.contains(groupValue)) {
          sValueMap.put(groupValue, Double.NaN);
        } else {
          sValueMap.put(groupValue, rs.getDouble(1));
        }
      }
      rs.close();
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    }

    HashMap<String, Matrix> varianceCovarianceMatrix =
        getVarianceCovarianceMatrixGroup(newTableName, columnNamesArray, st);
    for (String groupValue : varianceCovarianceMatrix.keySet()) {
      // Test the matrix of this group; testing the map itself can never be true inside a loop
      // over its own key set.
      if (varianceCovarianceMatrix.get(groupValue) == null) {
        model
            .getOneModel(groupValue)
            .setErrorString(
                AlpineDataAnalysisLanguagePack.getMessage(
                        AlpineDataAnalysisLanguagePack.MATRIX_IS_SIGULAR,
                        AlpineThreadLocal.getLocale())
                    + Tools.getLineSeparator());
      }
    }
    caculateStatistics(
        columnNamesArray,
        coefficients,
        model,
        sValueMap,
        varianceCovarianceMatrix,
        degreeOfFreedom);

    // Report the columns that contain nulls, per group.
    for (String groupValue : null_list_group.keySet()) {
      if (null_list_group.get(groupValue).size() != 0) {
        StringBuilder sb_null = new StringBuilder();
        for (int i = 0; i < null_list_group.get(groupValue).size(); i++) {
          sb_null
              .append(StringHandler.doubleQ(null_list_group.get(groupValue).get(i)))
              .append(",");
        }
        sb_null = sb_null.deleteCharAt(sb_null.length() - 1);
        String table_exist_null =
            AlpineDataAnalysisLanguagePack.getMessage(
                AlpineDataAnalysisLanguagePack.TABLE_EXIST_NULL, AlpineThreadLocal.getLocale());
        // The message is split at ';' and the column list is inserted between the two parts.
        String[] temp = table_exist_null.split(";");
        model
            .getOneModel(groupValue)
            .setErrorString(temp[0] + sb_null.toString() + temp[1] + Tools.getLineSeparator());
      }
    }
    if (transformer.isTransform()) {
      dropTable(newTableName);
    }
    st.close();
    itsLogger.debug(LogUtils.exit(classLogInfo, "learn", model.toString()));
    model.setGroupByColumn(groupbyColumn);
    return model;
  } catch (Exception e) {
    itsLogger.error(e.getMessage(), e);
    throw new OperatorException(e.getLocalizedMessage());
  } finally {
    // Release the cursor and statement on failure paths as well. On the success path they are
    // already closed, and closing an already-closed JDBC object is a no-op.
    if (rs != null) {
      try {
        rs.close();
      } catch (SQLException closeError) {
        itsLogger.debug(classLogInfo + ".learn(): could not close result set: " + closeError);
      }
    }
    if (st != null) {
      try {
        st.close();
      } catch (SQLException closeError) {
        itsLogger.debug(classLogInfo + ".learn(): could not close statement: " + closeError);
      }
    }
  }
}
/**
 * Determines, for every value of the group-by column, which columns contain nulls.
 *
 * <p>Two queries are run against the data set's table, both filtered by {@code getWhere(atts)}
 * and grouped by the group-by column. The first stores {@code count(*)} per group in
 * {@code groupCount}. The second selects {@code count(column)} for every column; a column whose
 * non-null count differs from the group's row count is reported for that group.
 *
 * <p>The statement and result sets are closed in a {@code finally} block, and counts are read as
 * {@code long} so the comparison stays exact for large groups.
 *
 * @param dataSet data set whose table and columns are inspected
 * @param atts columns used to build the where clause
 * @return map from group value to the names of the columns containing nulls in that group
 * @throws OperatorException if a query fails
 */
protected HashMap<String, ArrayList<String>> calculateNullGroup(DataSet dataSet, Columns atts)
    throws OperatorException {
  DatabaseConnection databaseConnection =
      ((DBTable) dataSet.getDBTable()).getDatabaseConnection();
  String tableName = ((DBTable) dataSet.getDBTable()).getTableName();
  HashMap<String, ArrayList<String>> resultList = new HashMap<String, ArrayList<String>>();
  String quotedGroupColumn = StringHandler.doubleQ(groupbyColumn);

  Statement statement = null;
  ResultSet rows = null;
  try {
    statement = databaseConnection.createStatement(false);

    // Query 1: number of rows per group.
    StringBuilder sb_count = new StringBuilder("select ");
    sb_count.append(quotedGroupColumn).append(", count(*) ");
    sb_count
        .append(" from ")
        .append(tableName)
        .append(this.getWhere(atts))
        .append(" group by ")
        .append(quotedGroupColumn);
    itsLogger.debug("LinearRegressionImp.calculateNull():sql=" + sb_count.toString());
    rows = statement.executeQuery(sb_count.toString());
    while (rows.next()) {
      String groupValue = rows.getString(1);
      this.groupCount.put(groupValue, rows.getInt(2));
    }
    rows.close();
    rows = null;

    // Query 2: number of non-null values per column and group.
    sb_count = new StringBuilder("select ").append(quotedGroupColumn).append(", ");
    Iterator<Column> i = dataSet.getColumns().iterator();
    while (i.hasNext()) {
      Column att = i.next();
      sb_count.append("count(").append(StringHandler.doubleQ(att.getName())).append(")");
      sb_count.append(StringHandler.doubleQ(att.getName())).append(",");
    }
    sb_count = sb_count.deleteCharAt(sb_count.length() - 1);
    sb_count
        .append(" from ")
        .append(tableName)
        .append(this.getWhere(atts))
        .append(" group by ")
        .append(quotedGroupColumn);
    itsLogger.debug("LinearRegressionImp.calculateNull():sql=" + sb_count.toString());
    rows = statement.executeQuery(sb_count.toString());
    int columnCount = rows.getMetaData().getColumnCount();
    while (rows.next()) {
      ArrayList<String> null_list = new ArrayList<String>();
      String groupValue = rows.getString(1);
      // Result column 1 is the group value; the per-column counts start at index 2.
      for (int j = 1; j < columnCount; j++) {
        if (rows.getLong(j + 1) != groupCount.get(groupValue)) {
          null_list.add(
              dataSet.getColumns().get(rows.getMetaData().getColumnName(j + 1)).getName());
        }
      }
      resultList.put(groupValue, null_list);
    }
  } catch (SQLException e) {
    itsLogger.error(e.getMessage(), e);
    throw new OperatorException(e.getLocalizedMessage());
  } finally {
    if (rows != null) {
      try {
        rows.close();
      } catch (SQLException closeError) {
        itsLogger.debug("LinearRegressionImp.calculateNull(): close failed: " + closeError);
      }
    }
    if (statement != null) {
      try {
        statement.close();
      } catch (SQLException closeError) {
        itsLogger.debug("LinearRegressionImp.calculateNull(): close failed: " + closeError);
      }
    }
  }
  return resultList;
}
private void performOperation( DatabaseConnection databaseConnection, DataSet dataSet, Locale locale) throws AnalysisError, OperatorException { String outputTableName = getQuotaedTableName(getOutputSchema(), getOutputTable()); String inputTableName = getQuotaedTableName(getInputSchema(), getInputTable()); Columns atts = dataSet.getColumns(); String dbType = databaseConnection.getProperties().getName(); IDataSourceInfo dataSourceInfo = DataSourceInfoFactory.createConnectionInfo(dbType); IMultiDBUtility multiDBUtility = MultiDBUtilityFactory.createConnectionInfo(dbType); ISqlGeneratorMultiDB sqlGenerator = SqlGeneratorMultiDBFactory.createConnectionInfo(dbType); dropIfExist(dataSet); DatabaseUtil.alterParallel(databaseConnection, getOutputType()); // for oracle StringBuilder sb_create = new StringBuilder("create "); StringBuilder insertTable = new StringBuilder(); if (getOutputType().equalsIgnoreCase("table")) { sb_create.append(" table "); } else { sb_create.append(" view "); } sb_create.append(outputTableName); sb_create.append( getOutputType().equalsIgnoreCase(Resources.TableType) ? 
getAppendOnlyString() : ""); sb_create.append(DatabaseUtil.addParallel(databaseConnection, getOutputType())).append(" as ("); StringBuilder selectSql = new StringBuilder(" select "); selectSql.append(StringHandler.doubleQ(groupColumn)).append(","); Column att = atts.get(columnNames); dataSet.computeColumnStatistics(att); if (att.isNumerical()) { logger.error("PivotTableAnalyzer cannot accept numeric type column"); throw new AnalysisError( this, AnalysisErrorName.Not_numeric, locale, SDKLanguagePack.getMessage(SDKLanguagePack.PIVOT_NAME, locale)); } String attName = StringHandler.doubleQ(att.getName()); List<String> valueList = att.getMapping().getValues(); if (!useArray && valueList.size() > Integer.parseInt(AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD)) { logger.error("Too many distinct value for column " + StringHandler.doubleQ(columnNames)); throw new AnalysisError( this, AnalysisErrorName.Too_Many_Distinct_value, locale, StringHandler.doubleQ(columnNames), AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD); } if (valueList.size() <= 0) { logger.error("Empty table"); throw new AnalysisError(this, AnalysisErrorName.Empty_table, locale); } String aggColumnName; if (!StringUtil.isEmpty(aggColumn)) { aggColumnName = StringHandler.doubleQ(aggColumn); } else { aggColumnName = "1"; } Iterator<String> valueList_i = valueList.iterator(); if (useArray) { if (dataSourceInfo.getDBType().equals(DataSourceInfoOracle.dBType)) { ArrayList<String> array = new ArrayList<String>(); while (valueList_i.hasNext()) { String value = StringHandler.escQ(valueList_i.next()); String newValue = "alpine_miner_null_to_0(" + aggrType + " (case when " + attName + "=" + CommonUtility.quoteValue(dbType, att, value) + " then " + aggColumnName + " end )) "; array.add(newValue); } selectSql.append( CommonUtility.array2OracleArray(array, CommonUtility.OracleDataType.Float)); } else { selectSql.append(multiDBUtility.floatArrayHead()); while (valueList_i.hasNext()) { String value = valueList_i.next(); 
selectSql.append("alpine_miner_null_to_0(").append(aggrType); selectSql.append(" (case when ").append(attName).append("="); value = StringHandler.escQ(value); selectSql .append(CommonUtility.quoteValue(dbType, att, value)) .append(" then ") .append(aggColumnName) .append(" end )) "); // else 0 selectSql.append(","); } selectSql = selectSql.deleteCharAt(selectSql.length() - 1); selectSql.append(multiDBUtility.floatArrayTail()); } selectSql.append(" " + StringHandler.doubleQ(att.getName())); } else { if (((DBTable) dataSet.getDBTable()) .getDatabaseConnection() .getProperties() .getName() .equals(DataSourceInfoNZ.dBType)) { while (valueList_i.hasNext()) { String value = valueList_i.next(); selectSql.append("(").append(aggrType); selectSql.append(" (case when ").append(attName).append("="); value = StringHandler.escQ(value); selectSql .append(CommonUtility.quoteValue(dbType, att, value)) .append(" then ") .append(aggColumnName) .append(" end )) "); // else 0 String colName = StringHandler.doubleQ(att.getName() + "_" + value); selectSql.append(colName); selectSql.append(","); } selectSql = selectSql.deleteCharAt(selectSql.length() - 1); } else if (((DBTable) dataSet.getDBTable()) .getDatabaseConnection() .getProperties() .getName() .equals(DataSourceInfoDB2.dBType)) { while (valueList_i.hasNext()) { String value = valueList_i.next(); selectSql.append("alpine_miner_null_to_0(").append(aggrType); selectSql.append(" (double(case when ").append(attName).append("="); value = StringHandler.escQ(value); selectSql .append(CommonUtility.quoteValue(dbType, att, value)) .append(" then ") .append(aggColumnName) .append(" end ))) "); // else 0 String colName = StringHandler.doubleQ(att.getName() + "_" + value); selectSql.append(colName); selectSql.append(","); } selectSql = selectSql.deleteCharAt(selectSql.length() - 1); } else { while (valueList_i.hasNext()) { String value = valueList_i.next(); selectSql.append("alpine_miner_null_to_0(").append(aggrType); selectSql.append(" (case 
when ").append(attName).append("="); value = StringHandler.escQ(value); selectSql .append(CommonUtility.quoteValue(dbType, att, value)) .append(" then ") .append(aggColumnName) .append(" end )) "); // else 0 String colName = StringHandler.doubleQ(att.getName() + "_" + value); selectSql.append(colName); selectSql.append(","); } selectSql = selectSql.deleteCharAt(selectSql.length() - 1); } } selectSql.append(" from ").append(inputTableName).append(" foo group by "); selectSql.append(StringHandler.doubleQ(groupColumn)); if (((DBTable) dataSet.getDBTable()) .getDatabaseConnection() .getProperties() .getName() .equals(DataSourceInfoNZ.dBType)) { StringBuilder sb = new StringBuilder(); sb.append("select ").append(StringHandler.doubleQ(groupColumn)).append(","); Iterator<String> valueList_new = valueList.iterator(); while (valueList_new.hasNext()) { String value = valueList_new.next(); String colName = StringHandler.doubleQ(att.getName() + "_" + value); sb.append("case when ").append(colName).append(" is null then 0 else "); sb.append(colName).append(" end ").append(colName).append(","); } sb = sb.deleteCharAt(sb.length() - 1); sb.append(" from (").append(selectSql).append(") foo "); selectSql = sb; } sb_create.append(selectSql).append(" )"); if (getOutputType().equalsIgnoreCase("table")) { sb_create.append(getEndingString()); insertTable.append(sqlGenerator.insertTable(selectSql.toString(), outputTableName)); } try { Statement st = databaseConnection.createStatement(false); logger.debug("PivotTableAnalyzer.performOperation():sql=" + sb_create); st.execute(sb_create.toString()); if (insertTable.length() > 0) { st.execute(insertTable.toString()); logger.debug("PivotTableAnalyzer.performOperation():insertTableSql=" + insertTable); } } catch (SQLException e) { logger.error(e); if (e.getMessage().startsWith("ORA-03001") || e.getMessage().startsWith("ERROR: invalid identifier")) { throw new AnalysisError(this, AnalysisErrorName.Invalid_Identifier, locale); } else { throw new 
OperatorException(e.getLocalizedMessage()); } } }