@Override
  public AnalyticOutPut doAnalysis(AnalyticSource source) throws AnalysisException {
    try {
      DataSet dataSet = getDataSet((DataBaseAnalyticSource) source, source.getAnalyticConfig());
      //			dataSet.recalculateAllcolumnStatistics();
      DatabaseConnection databaseConnection =
          ((DBTable) dataSet.getDBTable()).getDatabaseConnection();
      PivotTableConfig config = (PivotTableConfig) source.getAnalyticConfig();

      setInputSchema(((DataBaseAnalyticSource) source).getTableInfo().getSchema());
      setInputTable(((DataBaseAnalyticSource) source).getTableInfo().getTableName());
      setOutputType(config.getOutputType());
      setOutputSchema(config.getOutputSchema());
      setOutputTable(config.getOutputTable());
      setDropIfExist(config.getDropIfExist());
      columnNames = config.getPivotColumn();
      groupColumn = config.getGroupByColumn();
      aggColumn = config.getAggregateColumn();
      aggrType = config.getAggregateType();
      String dbType = databaseConnection.getProperties().getName();
      // Use the array form only when requested and when the database supports it
      // (DB2 and Netezza take the per-column form instead).
      useArray =
          "true".equalsIgnoreCase(config.getUseArray())
              && !dbType.equals(DataSourceInfoDB2.dBType)
              && !dbType.equals(DataSourceInfoNZ.dBType);
      generateStoragePrameterString((DataBaseAnalyticSource) source);
      performOperation(databaseConnection, dataSet, config.getLocale());

      DataBaseInfo dbInfo = ((DataBaseAnalyticSource) source).getDataBaseInfo();
      AnalyzerOutPutTableObject outPut = getResultTableSampleRow(databaseConnection, dbInfo);
      outPut.setAnalyticNodeMetaInfo(createNodeMetaInfo(config.getLocale()));
      outPut.setDbInfo(dbInfo);
      outPut.setSchemaName(getOutputSchema());
      outPut.setTableName(getOutputTable());
      return outPut;

    } catch (Exception e) {
      logger.error(e);
      if (e instanceof WrongUsedException) {
        throw new AnalysisError(this, (WrongUsedException) e);
      } else if (e instanceof AnalysisError) {
        throw (AnalysisError) e;
      } else {
        throw new AnalysisException(e);
      }
    }
  }
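  /**
   * Builds the stopping criteria for tree construction: a label-purity check (the database-side
   * variant when the data is not loaded), no-remaining-columns and no-remaining-rows checks, and
   * a maximum-depth check that defaults to the data set size when the configured depth is not
   * positive.
   */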
  public List<Stop> getStop(DataSet dataSet, boolean loadData) throws OperatorException {
    List<Stop> result = new LinkedList<Stop>();
    if (loadData) {
      result.add(new ClassPureStop());
    } else {
      result.add(new DBPureStop());
    }
    result.add(new NoColumnStop());
    result.add(new NoDataStop());
    long maxDepth = para.getMaxDepth();
    if (maxDepth <= 0) {
      maxDepth = dataSet.size();
    }
    result.add(new DepthStop(maxDepth));
    return result;
  }
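  /**
   * Builds the UPDATE statement that writes a nominal prediction back to the result table: a
   * CASE expression assembled from the per-class comparison fragments in biggerSql fills the
   * predicted-label column, and each class confidence column is set from the corresponding entry
   * of classProbabilitiesSql. The WHERE clause comes from getWherePredict().
   */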
  protected StringBuffer getNominalUpdate(
      DataSet dataSet,
      Column predictedLabel,
      String newTableName,
      int numberOfClasses,
      StringBuffer[] classProbabilitiesSql,
      StringBuffer caseSql,
      StringBuffer[] biggerSql) {
    StringBuffer sql = new StringBuffer();
    caseSql.append(" (case ");
    for (int c = 0; c < numberOfClasses - 1; c++) {
      caseSql
          .append(" when ")
          .append(biggerSql[c])
          .append(" then '")
          .append(StringHandler.escQ(getLabel().getMapping().mapIndex(c)))
          .append("'");
    }
    caseSql
        .append(" else '")
        .append(StringHandler.escQ(getLabel().getMapping().mapIndex(numberOfClasses - 1)))
        .append("' end)");

    sql.append(
            " update " + newTableName + " set " + StringHandler.doubleQ(predictedLabel.getName()))
        .append("=")
        .append(caseSql);
    for (int c = 0; c < numberOfClasses; c++) {
      sql.append(
              ", "
                  + StringHandler.doubleQ(
                      dataSet
                          .getColumns()
                          .getSpecial(
                              Column.CONFIDENCE_NAME + "_" + getLabel().getMapping().mapIndex(c))
                          .getName()))
          .append(" = ")
          .append(classProbabilitiesSql[c]);
    }
    sql.append(getWherePredict());
    return sql;
  }
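  /**
   * Creates the tree builder for nominal labels (or WOE mode): stopping criteria, leaf builders,
   * and pruning are wired from the parameter object, with the chi-square flag passed through
   * when chi-square splitting is requested. Returns null for non-nominal labels outside WOE
   * mode.
   */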
  protected ConstructTree getTB(DataSet dataSet) throws OperatorException {
    if (dataSet.getColumns().getLabel().isNominal() || para.isForWoe()) {

      if (para.isUseChiSquare() || para.isForWoe()) {
        return new ConstructTree(
            createStandard(false),
            createStandard(true),
            getStop(dataSet, false),
            getStop(dataSet, true),
            new DBBuildLeaf(),
            new BuildLeaf(),
            getPrune(false),
            para.isNoPrePruning(),
            para.getPrepruningAlternativesNumber(),
            para.getSplitMinSize(),
            para.getMinLeafSize(),
            para.getThresholdLoadData(),
            para.isUseChiSquare());
      } else {
        return new ConstructTree(
            createStandard(false),
            createStandard(true),
            getStop(dataSet, false),
            getStop(dataSet, true),
            new DBBuildLeaf(),
            new BuildLeaf(),
            getPrune(false),
            para.isNoPrePruning(),
            para.getPrepruningAlternativesNumber(),
            para.getSplitMinSize(),
            para.getMinLeafSize(),
            para.getThresholdLoadData());
      }
    }
    return null;
  }
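  /**
   * Trains a random forest in-database: selects a DB-specific implementation
   * (Greenplum/Postgres, Oracle, DB2, or Netezza), draws one sample table per tree (with or
   * without replacement), trains a CART tree on each sample, scores the corresponding
   * out-of-bag table to record the OOB error (classification) or MSE/MAPE (regression), and
   * collects the trees into a RandomForestModel. Temporary sample, OOB, and prediction tables
   * are dropped after each iteration.
   */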
  @Override
  protected Model train(AnalyticSource analyticSource) throws AnalysisException {
    ResultSet rs = null;
    Statement st = null;
    try {
      IDataSourceInfo dataSourceInfo =
          DataSourceInfoFactory.createConnectionInfo(analyticSource.getDataSourceType());
      dbtype = dataSourceInfo.getDBType();
      RandomForestModel lastResult = null;
      RandomForestIMP randomForestImpl = null;
      // if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) {
      // randomForestTrainer = new AdaboostOracle();
      //
      // } else
      if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType)
          || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) {
        randomForestImpl = new RandomForestGreenplum();

      } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) {
        randomForestImpl = new RandomForestOracle();

      } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) {
        randomForestImpl = new RandomForestDB2();
        ((RandomForestDB2) randomForestImpl)
            .setConnection(((DataBaseAnalyticSource) analyticSource).getConnection());
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) {
        randomForestImpl = new RandomForestNZ();
      } else {
        throw new AnalysisException("Databse type is not supported for Random Forest:" + dbtype);
        //				return null;
      }

      try {
        dataSet =
            getDataSet((DataBaseAnalyticSource) analyticSource, analyticSource.getAnalyticConfig());
      } catch (OperatorException e1) {
        logger.error(e1);
        throw e1;
      }
      setSpecifyColumn(dataSet, analyticSource.getAnalyticConfig());
      dataSet.computeAllColumnStatistics();

      RandomForestConfig rfConfig = (RandomForestConfig) analyticSource.getAnalyticConfig();

      String dbSystem = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getSystem();

      String url = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUrl();
      String userName = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUserName();
      String password = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getPassword();
      String inputSchema = ((DataBaseAnalyticSource) analyticSource).getTableInfo().getSchema();
      String tableName = ((DataBaseAnalyticSource) analyticSource).getTableInfo().getTableName();
      String useSSL = ((DataBaseAnalyticSource) analyticSource).getDataBaseInfo().getUseSSL();
      String sampleWithReplacement = rfConfig.getSampleWithReplacement();
      long timeStamp = System.currentTimeMillis();
      pnewTable = "pnew" + timeStamp;
      sampleTable = "s" + timeStamp;
      String dependentColumn = rfConfig.getDependentColumn();
      String columnNames = rfConfig.getColumnNames();
      String[] totalColumns = columnNames.split(",");
      int subSize = Integer.parseInt(rfConfig.getNodeColumnNumber());
      int forestSize = Integer.parseInt(rfConfig.getForestSize());

      Connection connection = null;

      if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType)
          || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) {

        lastResult = new RandomForestModelGreenplum(dataSet);
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) {
        lastResult = new RandomForestModelOracle(dataSet);

      } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) {
        lastResult = new RandomForestModelDB2(dataSet);
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) {
        lastResult = new RandomForestModelNZ(dataSet);
      }
      lastResult.setColumnNames(columnNames);
      lastResult.setDependColumn(dependentColumn);
      lastResult.setTableName(tableName);

      connection = ((DataBaseAnalyticSource) analyticSource).getConnection();

      Model result = null;

      try {
        st = connection.createStatement();
      } catch (SQLException e) {
        logger.error(e);
        throw new AnalysisException(e);
      }

      //			Iterator<String> dependvalueIterator = dataSet.getColumns()
      //					.getLabel().getMapping().getValues().iterator();
      if (dataSet.getColumns().getLabel() instanceof NominalColumn) {
        if (dataSet.getColumns().getLabel().getMapping().getValues().size() <= 1) {
          String e =
              SDKLanguagePack.getMessage(
                  SDKLanguagePack.ADABOOST_SAMPLE_ERRINFO, rfConfig.getLocale());
          logger.error(e);
          throw new AnalysisException(e);
        }
        if (dataSet.getColumns().getLabel().getMapping().getValues().size()
            > AlpineMinerConfig.ADABOOST_MAX_DEPENDENT_COUNT) {
          String e =
              SDKLanguagePack.getMessage(
                  SDKLanguagePack.ADABOOST_MAX_DEPENDENT_COUNT_ERRINFO, rfConfig.getLocale());
          logger.error(e);
          throw new AnalysisException(e);
        }
      }

      try {

        randomForestImpl.randomForestTrainInit(
            inputSchema, tableName, timeStamp, dependentColumn, st, dataSet);

      } catch (SQLException e) {
        logger.error(e);
        throw new AnalysisException(e);
      }

      CartConfig config = new CartConfig();
      config.setDependentColumn(dependentColumn);
      config.setConfidence(rfConfig.getConfidence());
      config.setMaximal_depth(rfConfig.getMaximal_depth());
      config.setMinimal_leaf_size(rfConfig.getMinimal_leaf_size());
      config.setMinimal_size_for_split(rfConfig.getMinimal_size_for_split());
      config.setNo_pre_pruning("true");
      config.setNo_pruning("true");

      for (int i = 0; i < forestSize; i++) {
        CartTrainer analyzer = new CartTrainer();

        // Compare string option values with equals(), not ==.
        if (Resources.TrueOpt.equals(sampleWithReplacement)) {
          randomForestImpl.randomForestSample(
              inputSchema,
              timeStamp + "" + i,
              dependentColumn,
              st,
              rs,
              pnewTable,
              sampleTable + i,
              rfConfig.getLocale());
        } else {
          randomForestImpl.randomForestSampleNoReplace(
              inputSchema,
              timeStamp + "" + i,
              dependentColumn,
              st,
              rs,
              pnewTable,
              sampleTable + i,
              rfConfig.getLocale(),
              dataSet.size());
        }
        String subColumns = getSubColumns(totalColumns, subSize);

        config.setColumnNames(subColumns);
        DataBaseAnalyticSource tempsource =
            new DataBaseAnalyticSource(
                dbSystem, url, userName, password, inputSchema, sampleTable + i, useSSL);
        tempsource.setAnalyticConfiguration(config);
        tempsource.setConenction(connection);
        result =
            ((AnalyzerOutPutTrainModel) analyzer.doAnalysis(tempsource))
                .getEngineModel()
                .getModel();
        String OOBTable = "OOB" + sampleTable + i;

        randomForestImpl.generateOOBTable(
            inputSchema, OOBTable, pnewTable, sampleTable + i, st, rs);

        DataBaseAnalyticSource tempPredictSource =
            new DataBaseAnalyticSource(
                dbSystem, url, userName, password, inputSchema, OOBTable, useSSL);

        String predictOutTable = "OOBPredict" + sampleTable;
        EngineModel em = new EngineModel();
        em.setModel(result);
        PredictorConfig tempconfig = new PredictorConfig(em);
        tempconfig.setDropIfExist(dropIfExists);
        tempconfig.setOutputSchema(inputSchema);
        tempconfig.setOutputTable(predictOutTable);
        tempPredictSource.setAnalyticConfiguration(tempconfig);
        tempPredictSource.setConenction(connection);
        AbstractDBModelPredictor predictor = new CartPredictor();
        predictor.doAnalysis(tempPredictSource); // use the weak alg , do

        double OOBError = 0.0;
        if (result instanceof DecisionTreeModel) {
          OOBError =
              randomForestImpl.getOOBError(
                  tempPredictSource, dependentColumn, "P(" + dependentColumn + ")");
          lastResult.getOobEstimateError().add(OOBError);
        } else if (result instanceof RegressionTreeModel) {
          OOBError = randomForestImpl.getMSE(tempPredictSource, "P(" + dependentColumn + ")");
          lastResult.getOobLoss().add(OOBError);
          double OOBMape =
              randomForestImpl.getMAPE(
                  tempPredictSource, dependentColumn, "P(" + dependentColumn + ")");
          lastResult.getOobMape().add(OOBMape);
        } else {
          OOBError = Double.NaN;
          lastResult.getOobLoss().add(OOBError);
        }

        lastResult.addModel((SingleModel) result);
        randomForestImpl.clearTrainResult(inputSchema, sampleTable + i);
        randomForestImpl.clearTrainResult(inputSchema, predictOutTable);
        randomForestImpl.clearTrainResult(inputSchema, OOBTable);
      }
      return lastResult;

    } catch (Exception e) {
      logger.error(e);
      if (e instanceof WrongUsedException) {
        throw new AnalysisError(this, (WrongUsedException) e);
      } else if (e instanceof AnalysisError) {
        throw (AnalysisError) e;
      } else {
        throw new AnalysisException(e);
      }
    } finally {
      try {
        if (st != null) {
          st.close();
        }
        if (rs != null) {
          rs.close();
        }
      } catch (SQLException e) {
        logger.error(e);
        throw new AnalysisException(e.getLocalizedMessage());
      }
    }
  }
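  /**
   * Trains an SVM regression model by calling the in-database alpine_miner_online_sv_reg
   * function on the (category-to-numeric transformed) input table and reading the returned
   * composite model fields into an SVMRegressionModel. The temporary transformed table is
   * dropped afterwards when a transformation was applied.
   */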
  public Model train(DataSet dataSet, SVMParameter parameter) throws OperatorException {
    para = parameter;
    setDataSourceInfo(
        DataSourceInfoFactory.createConnectionInfo(
            ((DBTable) dataSet.getDBTable()).getDatabaseConnection().getProperties().getName()));
    DatabaseConnection databaseConnection =
        ((DBTable) dataSet.getDBTable()).getDatabaseConnection();
    Column label = dataSet.getColumns().getLabel();
    String labelString = StringHandler.doubleQ(label.getName());

    DataSet newDataSet = getTransformer().TransformCategoryToNumeric_new(dataSet);
    String newTableName = ((DBTable) newDataSet.getDBTable()).getTableName();

    Statement st = null;
    ResultSet rs = null;
    try {
      st = databaseConnection.createStatement(false);
    } catch (SQLException e) {
      e.printStackTrace();
      throw new OperatorException(e.getLocalizedMessage());
    }
    StringBuffer ind = getColumnArray(newDataSet);
    StringBuffer where = getColumnWhere(newDataSet);
    where.append(" and ").append(labelString).append(" is not null ");
    SVMRegressionModel model = new SVMRegressionModel(dataSet, newDataSet);
    if (!newDataSet.equals(dataSet)) {
      model.setAllTransformMap_valueKey(getTransformer().getAllTransformMap_valueKey());
    }
    model.setKernelType(para.getKernelType());
    model.setDegree(para.getDegree());
    model.setGamma(para.getGamma());

    String sql =
        "select (model).inds, (model).cum_err, (model).epsilon, (model).rho, (model).b, (model).nsvs, (model).ind_dim, (model).weights, (model).individuals from (select alpine_miner_online_sv_reg('"
            + newTableName
            + "','"
            + ind
            + "','"
            + labelString
            + "','"
            + where
            + "',"
            + para.getKernelType()
            + ","
            + para.getDegree()
            + ","
            + para.getGamma()
            + ","
            + para.getEta()
            + ","
            + ((SVMRegressionParameter) para).getSlambda()
            + ","
            + para.getNu()
            + ") as model ";
    if (getDataSourceInfo().getDBType().equals(DataSourceInfoOracle.dBType)) {
      sql += " from dual ";
    }
    sql += ") a";
    try {
      itsLogger.debug("SVMRegression.train():sql=" + sql);
      rs = st.executeQuery(sql.toString());
      setModel(rs, model);
      if (getTransformer().isTransform()) {
        dropTable(st, newTableName);
      }
      rs.close();
      st.close();
    } catch (SQLException e) {
      e.printStackTrace();
      throw new OperatorException(e.getLocalizedMessage());
    }
    return model;
  }
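  /**
   * Batch (per-epoch) weight adjustment for the neural network: in each cycle the summed weight
   * changes and error are computed in the database with an aggregate query, the weights are
   * updated with the (optionally decayed) learning rate and momentum, and the best weights seen
   * so far are kept. Training stops early when the error drops below maxError; on NaN or
   * infinite error the learning rate is halved and training is restarted.
   */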
  protected void adjustPerWholeData(
      DataSet dataSet,
      List<String[]> hiddenLayers,
      int maxCycles,
      double maxError,
      double learningRate,
      double momentum,
      boolean decay,
      boolean normalize,
      AlpineRandom randomGenerator,
      int fetchSize,
      boolean adjustPerRow,
      Column label,
      int numberOfClasses,
      DatabaseConnection databaseConnection,
      String tableName)
      throws OperatorException {
    double totalSize = dataSet.size();

    Statement st = null;
    ResultSet rs = null;

    try {
      st = databaseConnection.createStatement(true);
    } catch (SQLException e) {
      throw new OperatorException(e.getLocalizedMessage());
    }
    double error = 0;
    // optimization loop
    for (int cycle = 0; cycle < maxCycles; cycle++) {
      double tempRate = learningRate;
      if (decay) {
        tempRate /= (cycle + 1);
      }

      StringBuffer sql = new StringBuffer();
      sql.append("select ").append(getArraySum()).append("(");
      //			sql.append("select ").append("(");
      sql.append(getAllWeightChange(getNumberOfClasses(label)));
      sql.append(") from ").append(tableName);

      try {
        if (useFloatArraySumCursor()) {
          StringBuffer varcharArray = CommonUtility.splitOracleSqlToVarcharArray(sql);
          sql = new StringBuffer();
          sql.append("select floatarraysum_cursor(").append(varcharArray).append(") from dual");
        }
        itsLogger.debug("NNModel.adjustPerWholeData():sql=" + sql);

        rs = st.executeQuery(sql.toString());
        if (rs.next()) {
          Double[] currentChanges = getResult(rs);
          //				String currentChangesString= new String();
          for (int i = 0; i < currentChanges.length; i++) {
            if (Double.isNaN(currentChanges[i]) || Double.isInfinite(currentChanges[i])) {
              try {
                if (rs != null) {
                  rs.close();
                }
                if (st != null) {
                  st.close();
                }
                databaseConnection.getConnection().setAutoCommit(true);
              } catch (SQLException e) {
                e.printStackTrace();
                //							throw new OperatorException(e.getLocalizedMessage ());
                copyBestErrorWeightsToWeights();
                return;
              }
              copyBestErrorWeightsToWeights();
              return;
            }
          }
          error = currentChanges[currentChanges.length - 1] / numberOfClasses / totalSize;
          if (error < bestError) {
            bestError = error;
            copyWeightsToBestErrorWeights();
          }
          updateWeight(currentChanges, tempRate, momentum);
        }
      } catch (SQLException e) {
        e.printStackTrace();
        copyBestErrorWeightsToWeights();
        //				throw new OperatorException(e.getLocalizedMessage());
        return;
      }
      //			error /= totalSize;
      itsLogger.debug("cycle" + cycle + ";error:" + error);
      if (error < maxError) {
        itsLogger.debug("loop break : " + cycle + " error: " + error);
        break;
      }

      if (Double.isInfinite(error) || Double.isNaN(error)) {
        if (Tools.isLessEqual(learningRate, 0.0d)) { // should hardly happen
          throw new RuntimeException("Cannot reset network to a smaller learning rate.");
        }
        learningRate /= 2;
        train(
            dataSet,
            hiddenLayers,
            maxCycles,
            maxError,
            learningRate,
            momentum,
            decay,
            normalize,
            randomGenerator,
            fetchSize,
            adjustPerRow);
      }
    }
    copyBestErrorWeightsToWeights();

    try {
      if (rs != null) {
        rs.close();
      }
      if (st != null) {
        st.close();
      }
    } catch (SQLException e) {
      e.printStackTrace();
      //			throw new OperatorException(e.getLocalizedMessage ());
      //			copyBestErrorWeightsToWeights();
      return;
    }
  }
  /**
   * Trains an EM clustering model against a database table: the configured columns are
   * transformed from categorical to numeric, a DB-specific EM implementation is obtained from
   * EMClusterFactory, and the resulting cluster parameters are packed into an EMModel, together
   * with the category-to-numeric mapping when a transformation occurred.
   *
   * @see com.alpine.datamining.api.impl.db.AbstractDBModelTrainer#train(com.alpine.datamining.api.AnalyticSource)
   */
  @Override
  protected Model train(AnalyticSource source) throws AnalysisException {
    ResultSet rs = null;
    Statement st = null;
    EMModel trainModel = null;
    try {

      IDataSourceInfo dataSourceInfo =
          DataSourceInfoFactory.createConnectionInfo(source.getDataSourceType());
      dbtype = dataSourceInfo.getDBType();

      EMConfig config = (EMConfig) source.getAnalyticConfig();

      String anaColumns = config.getColumnNames();
      String[] columnsArray = anaColumns.split(",");
      List<String> transformColumns = new ArrayList<String>();
      for (String columnName : columnsArray) {
        transformColumns.add(columnName);
      }
      DataSet dataSet = getDataSet((DataBaseAnalyticSource) source, config);
      filerColumens(dataSet, transformColumns);
      dataSet.computeAllColumnStatistics();
      ColumnTypeTransformer transformer = new ColumnTypeTransformer();
      DataSet newDataSet = transformer.TransformCategoryToNumeric_new(dataSet);
      String tableName = ((DBTable) newDataSet.getDBTable()).getTableName();
      Columns columns = newDataSet.getColumns();
      List<String> newTransformColumns = new ArrayList<String>();
      HashMap<String, String> transformMap = new HashMap<String, String>();
      for (String key : transformer.getAllTransformMap_valueKey().keySet()) {
        HashMap<String, String> values = (transformer.getAllTransformMap_valueKey()).get(key);
        for (String lowKey : values.keySet()) {
          transformMap.put(values.get(lowKey), lowKey);
        }
      }

      Iterator<Column> attributeIter = columns.iterator();
      while (attributeIter.hasNext()) {
        Column column = attributeIter.next();
        newTransformColumns.add(column.getName());
      }

      int maxIterationNumber = Integer.parseInt(config.getMaxIterationNumber());
      int clusterNumber = Integer.parseInt(config.getClusterNumber());
      double epsilon = Double.parseDouble(config.getEpsilon());
      int initClusterSize = 10;
      if (config.getInitClusterSize() != null) {
        initClusterSize = Integer.parseInt(config.getInitClusterSize());
      }
      if (newDataSet.size() < initClusterSize * clusterNumber) {
        initClusterSize = (int) (newDataSet.size() / clusterNumber + 1);
      } // TODO  get it from config and make sure it will not be too large
      EMClusterImpl emImpl = EMClusterFactory.createEMAnalyzer(dbtype);
      trainModel = EMClusterFactory.createEMModel(dbtype, newDataSet);
      Connection connection = null;
      connection = ((DataBaseAnalyticSource) source).getConnection();

      st = connection.createStatement();

      ArrayList<Double> tempResult =
          emImpl.emTrain(
              connection,
              st,
              tableName,
              maxIterationNumber,
              epsilon,
              clusterNumber,
              newTransformColumns,
              initClusterSize,
              trainModel);
      trainModel = generateEMModel(trainModel, newTransformColumns, clusterNumber, tempResult);
      if (!newDataSet.equals(this.dataSet)) {
        trainModel.setAllTransformMap_valueKey(transformMap);
      }
    } catch (Exception e) {
      logger.error(e);
      if (e instanceof WrongUsedException) {
        throw new AnalysisError(this, (WrongUsedException) e);
      } else if (e instanceof AnalysisError) {
        throw (AnalysisError) e;
      } else {
        throw new AnalysisException(e);
      }
    } finally {
      try {
        if (st != null) {
          st.close();
        }
        if (rs != null) {
          rs.close();
        }
      } catch (SQLException e) {
        logger.debug(e.toString());
        throw new AnalysisException(e.getLocalizedMessage());
      }
    }
    return trainModel;
  }
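  /**
   * Generates lift data for evaluation: when trained models are supplied, each model is applied
   * to the data set and one lift curve is produced per model; otherwise the existing confidence
   * column for the configured target class is used directly.
   */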
  @Override
  protected AnalyzerOutPutObject doValidate(DataBaseAnalyticSource source, EvaluatorConfig config)
      throws AnalysisException {

    DoubleListData result = null;
    List<DoubleListData> resultList = new ArrayList<DoubleListData>();

    try {
      DataSet dataSet = getDataSet(source, source.getAnalyticConfig());
      //			dataSet.recalculateAllcolumnStatistics();
      Operator operator = OperatorUtil.createOperator(LiftDataGeneratorGeneral.class);
      EvaluatorParameter parameter = new EvaluatorParameter();
      if (config.getUseModel() != null) {
        //				operator.setParameter(LiftDataGeneratorGeneral.PARAMETER_USE_MODEL,
        // config.getUseModel());
        parameter.setUseModel(Boolean.parseBoolean(config.getUseModel()));
      }
      if (config.getColumnValue() != null) {
        //				operator.setParameter(LiftDataGeneratorGeneral.PARAMETER_TARGET_CLASS,
        // config.getColumnValue());
        parameter.setColumnValue(config.getColumnValue());
      }
      operator.setParameter(parameter);
      if ("true".equals(config.getUseModel())) {
        List<EngineModel> models = config.getTrainedModel();
        Model model = null;
        for (int i = 0; i < models.size(); i++) {
          Container container = new Container();
          model = models.get(i).getModel();
          container = container.append(dataSet);
          container = container.append(model);
          Container resultContainer = operator.apply(container);
          result = resultContainer.get(DoubleListData.class);
          result.setSourceName(models.get(i).getName());
          resultList.add(result);
        }
      } else {
        Container container = new Container();
        String targetClass = config.getColumnValue();
        Column confidence =
            dataSet.getColumns().get(Column.CONFIDENCE_NAME + "(" + targetClass + ")");
        dataSet
            .getColumns()
            .setSpecialColumn(confidence, Column.CONFIDENCE_NAME + "_" + targetClass);
        container = container.append(dataSet);
        Container resultContainer = operator.apply(container);
        result = resultContainer.get(DoubleListData.class);
        result.setSourceName(getName());
        resultList.add(result);
      }

    } catch (Exception e) {
      logger.error(e);
      if (e instanceof WrongUsedException) {
        throw new AnalysisError(this, (WrongUsedException) e);
      } else if (e instanceof AnalysisError) {
        throw (AnalysisError) e;
      } else {
        throw new AnalysisException(e);
      }
    }
    AnalyzerOutPutObject out = new AnalyzerOutPutObject(resultList);
    out.setAnalyticNodeMetaInfo(createNodeMetaInfo(config.getLocale()));
    return out;
  }
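  /**
   * Fits per-group linear regression models in-database: categorical columns are transformed to
   * numeric, coefficients and R-squared are computed for each value of the group-by column, the
   * S statistic and variance-covariance matrix are derived per group, and groups with too few
   * rows, null-containing columns, or a singular matrix get an error message or a NaN S value on
   * the returned model.
   */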
  public Model learn(DataSet dataSet, LinearRegressionParameter para, String columnNames)
      throws OperatorException {
    this.dataSet = dataSet;
    ArrayList<String> columnNamesList = new ArrayList<String>();
    if (columnNames != null && !StringUtil.isEmpty(columnNames.trim())) {
      String[] columnNamesArray = columnNames.split(",");
      for (String s : columnNamesArray) {
        columnNamesList.add(s);
      }
    }
    transformer.setColumnNames(columnNamesList);
    transformer.setAnalysisInterActionModel(para.getAnalysisInterActionModel());
    newDataSet = transformer.TransformCategoryToNumeric_new(dataSet, groupbyColumn);
    DatabaseConnection databaseConnection =
        ((DBTable) newDataSet.getDBTable()).getDatabaseConnection();

    Column label = newDataSet.getColumns().getLabel();
    String labelName = StringHandler.doubleQ(label.getName());
    String tableName = ((DBTable) dataSet.getDBTable()).getTableName();

    String newTableName = ((DBTable) newDataSet.getDBTable()).getTableName();

    try {
      st = databaseConnection.createStatement(false);
    } catch (SQLException e) {
      e.printStackTrace();
      throw new OperatorException(e.getLocalizedMessage());
    }
    try {
      newDataSet.computeAllColumnStatistics();
      Columns atts = newDataSet.getColumns();

      Iterator<Column> atts_i = atts.iterator();

      int count = 0;
      String[] columnNamesArray = new String[atts.size()];
      while (atts_i.hasNext()) {
        Column att = atts_i.next();
        columnNamesArray[count] = att.getName();
        count++;
      }
      null_list = calculateNull(dataSet);
      null_list_group = calculateNullGroup(newDataSet, atts);
      StringBuilder sb_notNull = getWhere(atts);

      getCoefficientAndR2Group(
          columnNames,
          dataSet,
          labelName,
          tableName,
          newTableName,
          atts,
          columnNamesArray,
          sb_notNull);

      HashMap<String, Long> degreeOfFreedom = new HashMap<String, Long>();
      for (String groupValue : groupCount.keySet()) {
        long tempDof = groupCount.get(groupValue) - columnNamesArray.length - 1;
        if (tempDof <= 0) {
          model.getOneModel(groupValue).setS(Double.NaN);
        }
        degreeOfFreedom.put(groupValue, tempDof);
      }

      //			if (dof <= 0)
      //			{
      //				model.setS(Double.NaN);
      //				return model;
      //			}
      StringBuffer sSQL =
          createSSQLLGroup(
              newDataSet, newTableName, label, columnNamesArray, coefficients, sb_notNull);
      HashMap<String, Double> sValueMap = new HashMap<String, Double>();

      try {
        itsLogger.debug(classLogInfo + ".learn():sql=" + sSQL);
        rs = st.executeQuery(sSQL.toString());
        while (rs.next()) {
          String groupValue = rs.getString(2);
          if (groupValue == null) {
            continue;
          }
          if (dataErrorList.contains(groupValue)) {
            sValueMap.put(groupValue, Double.NaN);
          } else {
            sValueMap.put(groupValue, rs.getDouble(1));
          }
        }
        rs.close();
      } catch (SQLException e) {
        e.printStackTrace();
        itsLogger.error(e.getMessage(), e);
        throw new OperatorException(e.getLocalizedMessage());
      }

      HashMap<String, Matrix> varianceCovarianceMatrix =
          getVarianceCovarianceMatrixGroup(newTableName, columnNamesArray, st);
      for (String groupValue : varianceCovarianceMatrix.keySet()) {
        // The map itself cannot be null inside its own key iteration; check the per-group matrix.
        if (varianceCovarianceMatrix.get(groupValue) == null) {
          model
              .getOneModel(groupValue)
              .setErrorString(
                  AlpineDataAnalysisLanguagePack.getMessage(
                          AlpineDataAnalysisLanguagePack.MATRIX_IS_SIGULAR,
                          AlpineThreadLocal.getLocale())
                      + Tools.getLineSeparator());
        }
      }
      caculateStatistics(
          columnNamesArray,
          coefficients,
          model,
          sValueMap,
          varianceCovarianceMatrix,
          degreeOfFreedom);
      for (String groupValue : null_list_group.keySet()) {
        if (null_list_group.get(groupValue).size() != 0) {
          StringBuilder sb_null = new StringBuilder();
          for (int i = 0; i < null_list_group.get(groupValue).size(); i++) {
            sb_null
                .append(StringHandler.doubleQ(null_list_group.get(groupValue).get(i)))
                .append(",");
          }
          sb_null = sb_null.deleteCharAt(sb_null.length() - 1);
          String table_exist_null =
              AlpineDataAnalysisLanguagePack.getMessage(
                  AlpineDataAnalysisLanguagePack.TABLE_EXIST_NULL, AlpineThreadLocal.getLocale());
          String[] temp = table_exist_null.split(";");
          model
              .getOneModel(groupValue)
              .setErrorString(temp[0] + sb_null.toString() + temp[1] + Tools.getLineSeparator());
        }
      }
      if (transformer.isTransform()) {
        dropTable(newTableName);
      }
      st.close();
      itsLogger.debug(LogUtils.exit(classLogInfo, "learn", model.toString()));
      model.setGroupByColumn(groupbyColumn);
      return model;

    } catch (Exception e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    }
  }
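  /**
   * For each value of the group-by column, counts the rows per group and then compares
   * count(column) with the group size to find columns containing null values; returns a map from
   * group value to the list of column names with nulls in that group.
   */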
  protected HashMap<String, ArrayList<String>> calculateNullGroup(DataSet dataSet, Columns atts)
      throws OperatorException {

    DatabaseConnection databaseConnection =
        ((DBTable) dataSet.getDBTable()).getDatabaseConnection();
    String tableName = ((DBTable) dataSet.getDBTable()).getTableName();
    HashMap<String, ArrayList<String>> resultList = new HashMap<String, ArrayList<String>>();
    Iterator<Column> i = dataSet.getColumns().iterator();
    StringBuilder sb_count = new StringBuilder("select ");
    sb_count.append(StringHandler.doubleQ(groupbyColumn)).append(", count(*) ");

    sb_count
        .append(" from ")
        .append(tableName)
        .append(this.getWhere(atts))
        .append(" group by ")
        .append(StringHandler.doubleQ(groupbyColumn));
    try {
      Statement st = databaseConnection.createStatement(false);
      itsLogger.debug("LinearRegressionImp.calculateNull():sql=" + sb_count.toString());

      ResultSet rs = st.executeQuery(sb_count.toString());
      while (rs.next()) {
        String groupValue = "";
        groupValue = rs.getString(1);
        this.groupCount.put(groupValue, rs.getInt(2));
      }
    } catch (SQLException e) {
      e.printStackTrace();
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    }
    sb_count =
        new StringBuilder("select ").append(StringHandler.doubleQ(groupbyColumn)).append(", ");
    while (i.hasNext()) {
      Column att = i.next();
      sb_count.append("count(").append(StringHandler.doubleQ(att.getName())).append(")");
      sb_count.append(StringHandler.doubleQ(att.getName())).append(",");
    }
    sb_count = sb_count.deleteCharAt(sb_count.length() - 1);
    sb_count
        .append(" from ")
        .append(tableName)
        .append(this.getWhere(atts))
        .append(" group by ")
        .append(StringHandler.doubleQ(groupbyColumn));
    try {
      Statement st = databaseConnection.createStatement(false);
      itsLogger.debug("LinearRegressionImp.calculateNull():sql=" + sb_count.toString());

      ResultSet rs = st.executeQuery(sb_count.toString());
      while (rs.next()) {
        ArrayList<String> null_list = new ArrayList<String>();
        String groupValue = "";
        groupValue = rs.getString(1);
        for (int j = 1; j < rs.getMetaData().getColumnCount(); j++) {
          if (rs.getFloat(j + 1) != groupCount.get(groupValue)) {
            null_list.add(
                dataSet.getColumns().get(rs.getMetaData().getColumnName(j + 1)).getName());
          }
        }
        resultList.put(groupValue, null_list);
      }
    } catch (SQLException e) {
      e.printStackTrace();
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    }
    return resultList;
  }
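  /**
   * Builds and runs the pivot SQL: one aggregated CASE expression per distinct value of the
   * pivot column (emitted either as individual columns or as a single float array, depending on
   * useArray and the database type), grouped by the group-by column and written into the
   * configured output table or view. Netezza output is additionally wrapped to replace null
   * aggregates with 0.
   */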
  private void performOperation(
      DatabaseConnection databaseConnection, DataSet dataSet, Locale locale)
      throws AnalysisError, OperatorException {
    String outputTableName = getQuotaedTableName(getOutputSchema(), getOutputTable());

    String inputTableName = getQuotaedTableName(getInputSchema(), getInputTable());

    Columns atts = dataSet.getColumns();
    String dbType = databaseConnection.getProperties().getName();
    IDataSourceInfo dataSourceInfo = DataSourceInfoFactory.createConnectionInfo(dbType);

    IMultiDBUtility multiDBUtility = MultiDBUtilityFactory.createConnectionInfo(dbType);

    ISqlGeneratorMultiDB sqlGenerator = SqlGeneratorMultiDBFactory.createConnectionInfo(dbType);

    dropIfExist(dataSet);

    DatabaseUtil.alterParallel(databaseConnection, getOutputType()); // for oracle

    StringBuilder sb_create = new StringBuilder("create ");
    StringBuilder insertTable = new StringBuilder();

    if (getOutputType().equalsIgnoreCase("table")) {
      sb_create.append(" table ");
    } else {
      sb_create.append(" view ");
    }
    sb_create.append(outputTableName);
    sb_create.append(
        getOutputType().equalsIgnoreCase(Resources.TableType) ? getAppendOnlyString() : "");
    sb_create.append(DatabaseUtil.addParallel(databaseConnection, getOutputType())).append(" as (");
    StringBuilder selectSql = new StringBuilder(" select ");

    selectSql.append(StringHandler.doubleQ(groupColumn)).append(",");

    Column att = atts.get(columnNames);
    dataSet.computeColumnStatistics(att);
    if (att.isNumerical()) {
      logger.error("PivotTableAnalyzer cannot accept numeric type column");
      throw new AnalysisError(
          this,
          AnalysisErrorName.Not_numeric,
          locale,
          SDKLanguagePack.getMessage(SDKLanguagePack.PIVOT_NAME, locale));
    }
    String attName = StringHandler.doubleQ(att.getName());
    List<String> valueList = att.getMapping().getValues();
    if (!useArray
        && valueList.size() > Integer.parseInt(AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD)) {
      logger.error("Too many distinct value for column " + StringHandler.doubleQ(columnNames));
      throw new AnalysisError(
          this,
          AnalysisErrorName.Too_Many_Distinct_value,
          locale,
          StringHandler.doubleQ(columnNames),
          AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD);
    }

    if (valueList.size() <= 0) {
      logger.error("Empty table");
      throw new AnalysisError(this, AnalysisErrorName.Empty_table, locale);
    }

    String aggColumnName;
    if (!StringUtil.isEmpty(aggColumn)) {
      aggColumnName = StringHandler.doubleQ(aggColumn);
    } else {
      aggColumnName = "1";
    }

    Iterator<String> valueList_i = valueList.iterator();

    if (useArray) {
      if (dataSourceInfo.getDBType().equals(DataSourceInfoOracle.dBType)) {
        ArrayList<String> array = new ArrayList<String>();
        while (valueList_i.hasNext()) {
          String value = StringHandler.escQ(valueList_i.next());
          String newValue =
              "alpine_miner_null_to_0("
                  + aggrType
                  + " (case when "
                  + attName
                  + "="
                  + CommonUtility.quoteValue(dbType, att, value)
                  + " then "
                  + aggColumnName
                  + " end )) ";
          array.add(newValue);
        }
        selectSql.append(
            CommonUtility.array2OracleArray(array, CommonUtility.OracleDataType.Float));
      } else {
        selectSql.append(multiDBUtility.floatArrayHead());
        while (valueList_i.hasNext()) {
          String value = valueList_i.next();
          selectSql.append("alpine_miner_null_to_0(").append(aggrType);
          selectSql.append(" (case when ").append(attName).append("=");
          value = StringHandler.escQ(value);
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end )) "); // else 0
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
        selectSql.append(multiDBUtility.floatArrayTail());
      }
      selectSql.append(" " + StringHandler.doubleQ(att.getName()));
    } else {
      if (((DBTable) dataSet.getDBTable())
          .getDatabaseConnection()
          .getProperties()
          .getName()
          .equals(DataSourceInfoNZ.dBType)) {
        while (valueList_i.hasNext()) {
          String value = valueList_i.next();
          selectSql.append("(").append(aggrType);
          selectSql.append(" (case when ").append(attName).append("=");
          value = StringHandler.escQ(value);
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end )) "); // else 0
          String colName = StringHandler.doubleQ(att.getName() + "_" + value);
          selectSql.append(colName);
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
      } else if (((DBTable) dataSet.getDBTable())
          .getDatabaseConnection()
          .getProperties()
          .getName()
          .equals(DataSourceInfoDB2.dBType)) {
        while (valueList_i.hasNext()) {
          String value = valueList_i.next();
          selectSql.append("alpine_miner_null_to_0(").append(aggrType);
          selectSql.append(" (double(case when ").append(attName).append("=");
          value = StringHandler.escQ(value);
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end ))) "); // else 0
          String colName = StringHandler.doubleQ(att.getName() + "_" + value);
          selectSql.append(colName);
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
      } else {
        while (valueList_i.hasNext()) {
          String value = valueList_i.next();
          selectSql.append("alpine_miner_null_to_0(").append(aggrType);
          selectSql.append(" (case when ").append(attName).append("=");
          value = StringHandler.escQ(value);
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end )) "); // else 0
          String colName = StringHandler.doubleQ(att.getName() + "_" + value);
          selectSql.append(colName);
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
      }
    }
    selectSql.append(" from ").append(inputTableName).append(" foo group by ");
    selectSql.append(StringHandler.doubleQ(groupColumn));

    if (((DBTable) dataSet.getDBTable())
        .getDatabaseConnection()
        .getProperties()
        .getName()
        .equals(DataSourceInfoNZ.dBType)) {
      StringBuilder sb = new StringBuilder();
      sb.append("select ").append(StringHandler.doubleQ(groupColumn)).append(",");
      Iterator<String> valueList_new = valueList.iterator();
      while (valueList_new.hasNext()) {
        String value = valueList_new.next();
        String colName = StringHandler.doubleQ(att.getName() + "_" + value);
        sb.append("case when ").append(colName).append(" is null then 0 else ");
        sb.append(colName).append(" end ").append(colName).append(",");
      }
      sb = sb.deleteCharAt(sb.length() - 1);
      sb.append(" from (").append(selectSql).append(") foo ");
      selectSql = sb;
    }
    sb_create.append(selectSql).append(" )");

    if (getOutputType().equalsIgnoreCase("table")) {
      sb_create.append(getEndingString());
      insertTable.append(sqlGenerator.insertTable(selectSql.toString(), outputTableName));
    }
    try {
      Statement st = databaseConnection.createStatement(false);
      logger.debug("PivotTableAnalyzer.performOperation():sql=" + sb_create);
      st.execute(sb_create.toString());

      if (insertTable.length() > 0) {
        st.execute(insertTable.toString());
        logger.debug("PivotTableAnalyzer.performOperation():insertTableSql=" + insertTable);
      }
    } catch (SQLException e) {
      logger.error(e);
      if (e.getMessage().startsWith("ORA-03001")
          || e.getMessage().startsWith("ERROR:  invalid identifier")) {
        throw new AnalysisError(this, AnalysisErrorName.Invalid_Identifier, locale);
      } else {
        throw new OperatorException(e.getLocalizedMessage());
      }
    }
  }