コード例 #1
0
  /**
   * Loads the data set for {@code source}, applies the trained model from {@code config} to it and
   * wraps the outcome in an {@link AnalyzerOutPutDataBaseUpdate}.
   *
   * <p>Any exception raised while loading the data set or applying the model is logged. A {@code
   * WrongUsedException} is rethrown wrapped in an {@code AnalysisError}, an {@code AnalysisError}
   * is rethrown unchanged, and everything else is wrapped in an {@code AnalysisException}.
   *
   * @param source analytic source the data set is read from; its database info is copied to the
   *     result
   * @param config predictor configuration providing the trained model and the locale
   * @return the result carrying the data set, the database info, the model's update columns and
   *     the node meta info
   * @throws AnalysisException if loading the data set or applying the model fails
   */
  @Override
  protected AnalyzerOutPutDataBaseUpdate doPredict(
      DataBaseAnalyticSource source, PredictorConfig config) throws AnalysisException {
    DataSet scoredData = null;
    try {
      scoredData = getDataSet(source, source.getAnalyticConfig());
      config.getTrainedModel().getModel().apply(scoredData);
    } catch (Exception failure) {
      logger.error(failure);
      if (failure instanceof WrongUsedException) {
        throw new AnalysisError(this, (WrongUsedException) failure);
      }
      if (failure instanceof AnalysisError) {
        throw (AnalysisError) failure;
      }
      throw new AnalysisException(failure);
    }

    AnalyzerOutPutDataBaseUpdate output = new AnalyzerOutPutDataBaseUpdate();
    output.setDataset(scoredData);

    // Set url, user, password, schema and table on the result.
    fillDBInfo(output, source);

    // The update columns are only exposed by NBModel, hence the downcast.
    Model trainedModel = config.getTrainedModel().getModel();
    NBModel bayesModel = (NBModel) trainedModel;
    output.setUpdatedColumns(bayesModel.getUpdateColumns());

    output.setAnalyticNodeMetaInfo(createNodeMetaInfo(config.getLocale()));
    return output;
  }
コード例 #2
0
  /**
   * Trains a random forest model on the table described by {@code analyticSource}.
   *
   * <p>The method picks a database-specific {@code RandomForestIMP} and {@code RandomForestModel}
   * from the source's database type, loads the data set, validates a nominal label, and then, once
   * per configured forest member:
   *
   * <ol>
   *   <li>creates a sample table (with or without replacement, per the configuration),
   *   <li>trains a {@code CartTrainer} on the columns returned by {@code getSubColumns},
   *   <li>generates an out-of-bag table and runs a {@code CartPredictor} over it,
   *   <li>records an out-of-bag metric on the forest model according to the tree's type,
   *   <li>adds the tree to the forest and calls {@code clearTrainResult} for the per-tree tables.
   * </ol>
   *
   * @param analyticSource source to train on; it is cast to {@code DataBaseAnalyticSource} and its
   *     analytic config is cast to {@code RandomForestConfig}
   * @return the populated random forest model
   * @throws AnalysisException if the database type is unsupported, a nominal label has too few or
   *     too many distinct values, the JDBC statement cannot be created or closed, or any other
   *     exception occurs during training (wrapped); a {@code WrongUsedException} is rethrown as
   *     {@code AnalysisError}
   */
  @Override
  protected Model train(AnalyticSource analyticSource) throws AnalysisException {
    // rs is never assigned in this method; it is only handed to the sampling/OOB helpers, which
    // cannot rebind it, so it is still null when the finally block runs.
    ResultSet rs = null;
    Statement st = null;
    try {
      IDataSourceInfo dataSourceInfo =
          DataSourceInfoFactory.createConnectionInfo(analyticSource.getDataSourceType());
      dbtype = dataSourceInfo.getDBType();
      RandomForestModel lastResult = null;
      RandomForestIMP randomForestImpl = null;

      // Pick the database-specific implementation; unsupported types are rejected here.
      if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType)
          || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) {
        randomForestImpl = new RandomForestGreenplum();
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) {
        randomForestImpl = new RandomForestOracle();
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) {
        randomForestImpl = new RandomForestDB2();
        ((RandomForestDB2) randomForestImpl)
            .setConnection(((DataBaseAnalyticSource) analyticSource).getConnection());
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) {
        randomForestImpl = new RandomForestNZ();
      } else {
        throw new AnalysisException("Databse type is not supported for Random Forest:" + dbtype);
      }

      // Everything below needs the database view of the source, so cast once.
      DataBaseAnalyticSource dbSource = (DataBaseAnalyticSource) analyticSource;

      try {
        dataSet = getDataSet(dbSource, analyticSource.getAnalyticConfig());
      } catch (OperatorException e1) {
        logger.error(e1);
        throw new OperatorException(e1.getLocalizedMessage());
      }
      setSpecifyColumn(dataSet, analyticSource.getAnalyticConfig());
      dataSet.computeAllColumnStatistics();

      RandomForestConfig rfConfig = (RandomForestConfig) analyticSource.getAnalyticConfig();

      String dbSystem = dbSource.getDataBaseInfo().getSystem();
      String url = dbSource.getDataBaseInfo().getUrl();
      String userName = dbSource.getDataBaseInfo().getUserName();
      String password = dbSource.getDataBaseInfo().getPassword();
      String inputSchema = dbSource.getTableInfo().getSchema();
      String tableName = dbSource.getTableInfo().getTableName();
      String useSSL = dbSource.getDataBaseInfo().getUseSSL();
      String sampleWithReplacement = rfConfig.getSampleWithReplacement();

      // Temporary table names are derived from the current time in milliseconds.
      long timeStamp = System.currentTimeMillis();
      pnewTable = "pnew" + timeStamp;
      sampleTable = "s" + timeStamp;

      String dependentColumn = rfConfig.getDependentColumn();
      String columnNames = rfConfig.getColumnNames();
      String[] totalColumns = columnNames.split(",");
      int subSize = Integer.parseInt(rfConfig.getNodeColumnNumber());
      int forestSize = Integer.parseInt(rfConfig.getForestSize());

      // Same database types as above; the unsupported case has already thrown, so one branch
      // always assigns lastResult.
      if (dbtype.equalsIgnoreCase(DataSourceInfoGreenplum.dBType)
          || dbtype.equalsIgnoreCase(DataSourceInfoPostgres.dBType)) {
        lastResult = new RandomForestModelGreenplum(dataSet);
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoOracle.dBType)) {
        lastResult = new RandomForestModelOracle(dataSet);
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoDB2.dBType)) {
        lastResult = new RandomForestModelDB2(dataSet);
      } else if (dbtype.equalsIgnoreCase(DataSourceInfoNZ.dBType)) {
        lastResult = new RandomForestModelNZ(dataSet);
      }
      lastResult.setColumnNames(columnNames);
      lastResult.setDependColumn(dependentColumn);
      lastResult.setTableName(tableName);

      Connection connection = dbSource.getConnection();

      try {
        st = connection.createStatement();
      } catch (SQLException e) {
        logger.error(e);
        throw new AnalysisException(e);
      }

      // A nominal label must have more than one distinct value and no more than the configured
      // maximum.
      if (dataSet.getColumns().getLabel() instanceof NominalColumn) {
        if (dataSet.getColumns().getLabel().getMapping().getValues().size() <= 1) {
          String e =
              SDKLanguagePack.getMessage(
                  SDKLanguagePack.ADABOOST_SAMPLE_ERRINFO, rfConfig.getLocale());
          logger.error(e);
          throw new AnalysisException(e);
        }
        if (dataSet.getColumns().getLabel().getMapping().getValues().size()
            > AlpineMinerConfig.ADABOOST_MAX_DEPENDENT_COUNT) {
          String e =
              SDKLanguagePack.getMessage(
                  SDKLanguagePack.ADABOOST_MAX_DEPENDENT_COUNT_ERRINFO, rfConfig.getLocale());
          logger.error(e);
          throw new AnalysisException(e);
        }
      }

      try {
        randomForestImpl.randomForestTrainInit(
            inputSchema, tableName, timeStamp, dependentColumn, st, dataSet);
      } catch (SQLException e) {
        logger.error(e);
        throw new AnalysisException(e);
      }

      // One CART configuration is shared by all trees; only the column subset changes per tree.
      CartConfig config = new CartConfig();
      config.setDependentColumn(dependentColumn);
      config.setConfidence(rfConfig.getConfidence());
      config.setMaximal_depth(rfConfig.getMaximal_depth());
      config.setMinimal_leaf_size(rfConfig.getMinimal_leaf_size());
      config.setMinimal_size_for_split(rfConfig.getMinimal_size_for_split());
      config.setNo_pre_pruning("true");
      config.setNo_pruning("true");

      // Name of the prediction column read back when computing the out-of-bag metrics.
      String predictionColumn = "P(" + dependentColumn + ")";

      for (int i = 0; i < forestSize; i++) {
        CartTrainer analyzer = new CartTrainer();

        // Compare by value: a reference comparison (==) only succeeds when both strings happen
        // to be the same instance, which silently disabled sampling with replacement.
        if (Resources.TrueOpt.equals(sampleWithReplacement)) {
          randomForestImpl.randomForestSample(
              inputSchema,
              timeStamp + "" + i,
              dependentColumn,
              st,
              rs,
              pnewTable,
              sampleTable + i,
              rfConfig.getLocale());
        } else {
          randomForestImpl.randomForestSampleNoReplace(
              inputSchema,
              timeStamp + "" + i,
              dependentColumn,
              st,
              rs,
              pnewTable,
              sampleTable + i,
              rfConfig.getLocale(),
              dataSet.size());
        }
        String subColumns = getSubColumns(totalColumns, subSize);

        // Train one tree on this iteration's sample table.
        config.setColumnNames(subColumns);
        DataBaseAnalyticSource tempsource =
            new DataBaseAnalyticSource(
                dbSystem, url, userName, password, inputSchema, sampleTable + i, useSSL);
        tempsource.setAnalyticConfiguration(config);
        tempsource.setConenction(connection);
        Model result =
            ((AnalyzerOutPutTrainModel) analyzer.doAnalysis(tempsource))
                .getEngineModel()
                .getModel();

        String oobTable = "OOB" + sampleTable + i;
        randomForestImpl.generateOOBTable(
            inputSchema, oobTable, pnewTable, sampleTable + i, st, rs);

        DataBaseAnalyticSource tempPredictSource =
            new DataBaseAnalyticSource(
                dbSystem, url, userName, password, inputSchema, oobTable, useSSL);

        // NOTE(review): unlike the sample and OOB tables, this name carries no tree index; it is
        // passed to clearTrainResult at the end of every iteration.
        String predictOutTable = "OOBPredict" + sampleTable;
        EngineModel em = new EngineModel();
        em.setModel(result);
        PredictorConfig tempconfig = new PredictorConfig(em);
        tempconfig.setDropIfExist(dropIfExists);
        tempconfig.setOutputSchema(inputSchema);
        tempconfig.setOutputTable(predictOutTable);
        tempPredictSource.setAnalyticConfiguration(tempconfig);
        tempPredictSource.setConenction(connection);
        AbstractDBModelPredictor predictor = new CartPredictor();
        // Run the freshly trained tree over the out-of-bag table.
        predictor.doAnalysis(tempPredictSource);

        // Record the out-of-bag metric that matches the kind of tree that was trained.
        if (result instanceof DecisionTreeModel) {
          double oobError =
              randomForestImpl.getOOBError(tempPredictSource, dependentColumn, predictionColumn);
          lastResult.getOobEstimateError().add(oobError);
        } else if (result instanceof RegressionTreeModel) {
          double oobMse = randomForestImpl.getMSE(tempPredictSource, predictionColumn);
          lastResult.getOobLoss().add(oobMse);
          double oobMape =
              randomForestImpl.getMAPE(tempPredictSource, dependentColumn, predictionColumn);
          lastResult.getOobMape().add(oobMape);
        } else {
          lastResult.getOobLoss().add(Double.NaN);
        }

        lastResult.addModel((SingleModel) result);
        randomForestImpl.clearTrainResult(inputSchema, sampleTable + i);
        randomForestImpl.clearTrainResult(inputSchema, predictOutTable);
        randomForestImpl.clearTrainResult(inputSchema, oobTable);
      }
      return lastResult;

    } catch (Exception e) {
      logger.error(e);
      if (e instanceof WrongUsedException) {
        throw new AnalysisError(this, (WrongUsedException) e);
      } else if (e instanceof AnalysisError) {
        throw (AnalysisError) e;
      } else {
        throw new AnalysisException(e);
      }
    } finally {
      try {
        if (st != null) {
          st.close();
        }
        if (rs != null) {
          rs.close();
        }
      } catch (SQLException e) {
        logger.error(e);
        throw new AnalysisException(e.getLocalizedMessage());
      }
    }
  }
コード例 #3
0
  /**
   * Runs the {@code LiftDataGeneratorGeneral} operator over the data set loaded from {@code source}
   * and returns the collected {@code DoubleListData} results.
   *
   * <p>When {@code config.getUseModel()} is exactly {@code "true"}, the operator is applied once
   * per trained model in {@code config.getTrainedModel()} and each result is named after its
   * model. Otherwise (including when the setting is {@code null}) the operator is applied once to
   * the data set alone, after registering the confidence column for the configured column value as
   * a special column; that result is named after {@code getName()}.
   *
   * @param source analytic source the data set is loaded from
   * @param config evaluator configuration (use-model flag, column value, trained models, locale)
   * @return output object wrapping the list of results, with node meta info set
   * @throws AnalysisException if any step fails; a {@code WrongUsedException} is rethrown as
   *     {@code AnalysisError}
   */
  @Override
  protected AnalyzerOutPutObject doValidate(DataBaseAnalyticSource source, EvaluatorConfig config)
      throws AnalysisException {

    List<DoubleListData> resultList = new ArrayList<DoubleListData>();

    try {
      DataSet dataSet = getDataSet(source, source.getAnalyticConfig());
      Operator operator = OperatorUtil.createOperator(LiftDataGeneratorGeneral.class);

      // Only settings that are present are copied onto the operator parameter.
      EvaluatorParameter parameter = new EvaluatorParameter();
      String useModel = config.getUseModel();
      if (useModel != null) {
        parameter.setUseModel(Boolean.parseBoolean(useModel));
      }
      if (config.getColumnValue() != null) {
        parameter.setColumnValue(config.getColumnValue());
      }
      operator.setParameter(parameter);

      // Null-safe comparison: useModel is optional (see the null check above), so it must not be
      // dereferenced here. NOTE(review): parseBoolean above ignores case while this comparison is
      // case-sensitive; kept as-is to preserve behavior for non-null values.
      if ("true".equals(useModel)) {
        List<EngineModel> models = config.getTrainedModel();
        for (EngineModel engineModel : models) {
          Container container = new Container();
          Model model = engineModel.getModel();
          container = container.append(dataSet);
          container = container.append(model);
          Container resultContainer = operator.apply(container);
          DoubleListData result = resultContainer.get(DoubleListData.class);
          result.setSourceName(engineModel.getName());
          resultList.add(result);
        }
      } else {
        Container container = new Container();
        String targetClass = config.getColumnValue();
        // Register the existing confidence column for the target class as a special column
        // before handing the data set to the operator.
        Column confidence =
            dataSet.getColumns().get(Column.CONFIDENCE_NAME + "(" + targetClass + ")");
        dataSet
            .getColumns()
            .setSpecialColumn(confidence, Column.CONFIDENCE_NAME + "_" + targetClass);
        container = container.append(dataSet);
        Container resultContainer = operator.apply(container);
        DoubleListData result = resultContainer.get(DoubleListData.class);
        result.setSourceName(getName());
        resultList.add(result);
      }

    } catch (Exception e) {
      logger.error(e);
      if (e instanceof WrongUsedException) {
        throw new AnalysisError(this, (WrongUsedException) e);
      } else if (e instanceof AnalysisError) {
        throw (AnalysisError) e;
      } else {
        throw new AnalysisException(e);
      }
    }
    AnalyzerOutPutObject out = new AnalyzerOutPutObject(resultList);
    out.setAnalyticNodeMetaInfo(createNodeMetaInfo(config.getLocale()));
    return out;
  }