Пример #1
0
 /**
  * Copies parameter values from an SVD operator to an SVD Calculator operator.
  *
  * <p>Intended to be called when the SVD operator's properties are updated and when an SVD
  * operator is connected to an SVD calculator. Each row of the mapping table pairs the calculator
  * parameter to write (index 0) with the SVD parameter to read (index 1). Values read from the
  * U-matrix, V-matrix and singular-value table parameters are rewritten as {@code schema.table},
  * with both parts passed through {@code StringHandler.doubleQ}; all other values are copied
  * unchanged.
  *
  * @param svdOperator operator whose parameter values are read
  * @param calculatorOperator operator whose parameters are overwritten
  */
 public static void syncSVDParams2SVDCalculator(
     Operator svdOperator, Operator calculatorOperator) {
   final int svdOperatorIdx = 1, svdCalculatorOperatorIdx = 0;
   final String[][] syncParamsNameMapping = {
     {OperatorParameter.NAME_UmatrixFullTable, OperatorParameter.NAME_UmatrixTable},
     {OperatorParameter.NAME_RowNameF, OperatorParameter.NAME_RowName},
     {OperatorParameter.NAME_UdependentColumn, OperatorParameter.NAME_dependentColumn},
     {OperatorParameter.NAME_VmatrixFullTable, OperatorParameter.NAME_VmatrixTable},
     {OperatorParameter.NAME_ColNameF, OperatorParameter.NAME_ColName},
     {OperatorParameter.NAME_VdependentColumn, OperatorParameter.NAME_dependentColumn},
     {OperatorParameter.NAME_SmatrixFullTable, OperatorParameter.NAME_singularValueTable},
     {OperatorParameter.NAME_SdependentColumn, OperatorParameter.NAME_dependentColumn}
   };
   for (String[] syncParamNameMapping : syncParamsNameMapping) {
     String svdParamName = syncParamNameMapping[svdOperatorIdx];
     Object value = svdOperator.getOperatorParameter(svdParamName).getValue();

     // Table parameters have a companion schema parameter. Names are compared with equals():
     // a reference comparison only succeeds while both sides are the same String instance.
     String schemaParamName = null;
     if (OperatorParameter.NAME_UmatrixTable.equals(svdParamName)) {
       schemaParamName = OperatorParameter.NAME_UmatrixSchema;
     } else if (OperatorParameter.NAME_VmatrixTable.equals(svdParamName)) {
       schemaParamName = OperatorParameter.NAME_VmatrixSchema;
     } else if (OperatorParameter.NAME_singularValueTable.equals(svdParamName)) {
       schemaParamName = OperatorParameter.NAME_singularValueSchema;
     }

     if (schemaParamName != null) {
       String schemaName =
           StringHandler.doubleQ(
               (String) svdOperator.getOperatorParameter(schemaParamName).getValue());
       // Qualify the table name with its schema, quoting both parts.
       value = schemaName + "." + StringHandler.doubleQ((String) value);
     }
     calculatorOperator
         .getOperatorParameter(syncParamNameMapping[svdCalculatorOperatorIdx])
         .setValue(value);
   }
 }
 /**
  * Builds a {@code where} clause requiring every given column, and the group-by column, to be
  * non-null.
  *
  * @param atts columns whose quoted names are each constrained with {@code is not null}
  * @return the clause, starting with {@code " where "} and ending after the last
  *     {@code is not null} condition
  */
 protected StringBuilder getWhere(Columns atts) {
   final String notNullAnd = " is not null and ";
   StringBuilder clause = new StringBuilder(" where ");
   for (Iterator<Column> columns = atts.iterator(); columns.hasNext(); ) {
     clause.append(StringHandler.doubleQ(columns.next().getName())).append(notNullAnd);
   }
   clause.append(StringHandler.doubleQ(this.groupbyColumn)).append(notNullAnd);
   // Drop the trailing "and " left behind by the last condition.
   clause.setLength(clause.length() - 4);
   return clause;
 }
Пример #3
0
  /**
   * Builds the {@code update} statement that writes the predicted label and the per-class
   * confidence columns of a nominal prediction.
   *
   * <p>The case expression is appended to the caller-supplied {@code caseSql} buffer: class
   * {@code c} is chosen when {@code biggerSql[c]} holds, and the last class is the fallback.
   *
   * @param dataSet data set whose special confidence columns are looked up by name
   * @param predictedLabel column receiving the case expression
   * @param newTableName table to update, inserted into the SQL as given
   * @param numberOfClasses number of label classes
   * @param classProbabilitiesSql SQL expression for each class's confidence value
   * @param caseSql buffer that receives the case expression (modified in place)
   * @param biggerSql condition selecting each class except the last
   * @return the update statement followed by the result of {@code getWherePredict()}
   */
  protected StringBuffer getNominalUpdate(
      DataSet dataSet,
      Column predictedLabel,
      String newTableName,
      int numberOfClasses,
      StringBuffer[] classProbabilitiesSql,
      StringBuffer caseSql,
      StringBuffer[] biggerSql) {
    final int lastClass = numberOfClasses - 1;

    caseSql.append(" (case ");
    int branch = 0;
    while (branch < lastClass) {
      caseSql.append(" when ");
      caseSql.append(biggerSql[branch]);
      caseSql.append(" then '");
      caseSql.append(StringHandler.escQ(getLabel().getMapping().mapIndex(branch)));
      caseSql.append("'");
      branch++;
    }
    caseSql.append(" else '");
    caseSql.append(StringHandler.escQ(getLabel().getMapping().mapIndex(lastClass)));
    caseSql.append("' end)");

    StringBuffer sql = new StringBuffer();
    sql.append(
        " update " + newTableName + " set " + StringHandler.doubleQ(predictedLabel.getName()));
    sql.append("=");
    sql.append(caseSql);

    // One "confidence column = probability expression" assignment per class.
    for (int classIndex = 0; classIndex < numberOfClasses; classIndex++) {
      String confidenceKey =
          Column.CONFIDENCE_NAME + "_" + getLabel().getMapping().mapIndex(classIndex);
      String confidenceColumn =
          StringHandler.doubleQ(dataSet.getColumns().getSpecial(confidenceKey).getName());
      sql.append(", " + confidenceColumn);
      sql.append(" = ");
      sql.append(classProbabilitiesSql[classIndex]);
    }
    sql.append(getWherePredict());
    return sql;
  }
Пример #4
0
 /**
  * Renders the node inputs as an Oracle float array expression.
  *
  * @param prediction when {@code true} each entry is the input's transform value; otherwise it is
  *     the input column's double-quoted name
  * @return the string form of {@code CommonUtility.array2OracleArray} applied to the entries
  */
 public String getColumnNames(boolean prediction) {
   ArrayList<String> entries = new ArrayList<String>();
   int index = 0;
   while (index < nodeInputs.length) {
     String entry =
         prediction
             ? nodeInputs[index].getTransformValue().toString()
             : StringHandler.doubleQ(nodeInputs[index].getColumn().getName());
     entries.add(entry);
     index++;
   }
   return CommonUtility.array2OracleArray(entries, CommonUtility.OracleDataType.Float).toString();
 }
Пример #5
0
 /**
  * Builds the {@code INSERT ... SELECT} statement that fills the sampling temp table.
  *
  * <p>Every input row is copied together with a per-partition row number
  * ({@code alpine_sample_id}, partitioned by the sample column and ordered randomly) and a
  * random {@code rand_order} value.
  *
  * @param tempTableName target table, inserted into the SQL as given
  * @param columnNames not used by this implementation
  * @param inputTableName source table, inserted into the SQL as given
  * @param sampleColumn column to partition by; double-quoted before use
  * @return the complete SQL statement
  */
 @Override
 public String insertTempTable(
     String tempTableName, String columnNames, String inputTableName, String sampleColumn) {
   String sampleIdExpression =
       "row_number() over(partition by "
           + StringHandler.doubleQ(sampleColumn)
           + " order by random()) AS alpine_sample_id";
   String selectClause =
       " SELECT *," + sampleIdExpression + ", random() as rand_order from " + inputTableName;
   return "INSERT INTO "
       + tempTableName
       + selectClause
       + sqlGenerator.setCreateTableEndingSql(null);
 }
Пример #6
0
  /**
   * Repeatedly draws a training sample until it contains more than one distinct value of the
   * dependent column, or the configured number of attempts is used up.
   *
   * <p>After each call to {@code adaboostTrainSampleOnce} the sample table is queried with
   * {@code count(distinct dependentColumn)}. The result set of each attempt is closed before the
   * next attempt starts.
   *
   * @param inputSchema schema of the sample table; double-quoted before use
   * @param timeStamp passed through to {@code adaboostTrainSampleOnce}
   * @param dependentColumn column whose distinct values are counted; double-quoted before use
   * @param st statement used to run the count query
   * @param rs ignored on entry; only reused as a local holder for the count query's result
   * @param pnewTable passed through to {@code adaboostTrainSampleOnce}
   * @param locale locale used for the failure message
   * @throws SQLException declared for callers; SQL failures raised here are wrapped instead
   * @throws AnalysisException if no attempt produced a sample with more than one distinct value,
   *     or if a SQL error occurred
   */
  public void adaboostTrainSample(
      String inputSchema,
      long timeStamp,
      String dependentColumn,
      Statement st,
      ResultSet rs,
      String pnewTable,
      Locale locale)
      throws SQLException, AnalysisException {
    try {
      boolean hasMultipleClasses = false;
      int remainingAttempts = AlpineMinerConfig.ADABOOST_SAMPLE;
      while (!hasMultipleClasses && remainingAttempts != 0) {
        adaboostTrainSampleOnce(inputSchema, timeStamp, st, pnewTable);
        StringBuffer sql = new StringBuffer();
        sql.append("select count(distinct ");
        sql.append(StringHandler.doubleQ(dependentColumn));
        sql.append(") from ");
        sql.append(StringHandler.doubleQ(inputSchema)).append(".");
        sql.append(StringHandler.doubleQ(sampleTable));
        sql.append(" ");
        logger.debug(sql.toString());
        rs = st.executeQuery(sql.toString());
        try {
          while (rs.next()) {
            if (rs.getInt(1) > 1) {
              hasMultipleClasses = true;
            }
          }
        } finally {
          // Release the cursor now; otherwise every retry leaves an open result set behind.
          rs.close();
        }
        remainingAttempts--;
      }
      if (!hasMultipleClasses) {
        String e = SDKLanguagePack.getMessage(SDKLanguagePack.ADABOOST_SAMPLE_FAIL, locale);

        throw new AnalysisException(e);
      }
    } catch (SQLException e) {
      logger.error(e.getMessage(), e);
      throw new AnalysisException(e);
    }
  }
Пример #7
0
  /**
   * Trains an SVM regression model by calling the database function
   * {@code alpine_miner_online_sv_reg} on the (possibly transformed) data set.
   *
   * <p>The data set is first passed through {@code TransformCategoryToNumeric_new}; when the
   * transformer reports that a transformation took place, the transformed table is dropped after
   * the model has been read. The statement and result set are always closed, including when the
   * query or model extraction fails.
   *
   * @param dataSet training data; its label column is the regression target
   * @param parameter SVM settings; must be an {@code SVMRegressionParameter} because the slambda
   *     value is read through that type
   * @return the populated {@code SVMRegressionModel}
   * @throws OperatorException if the statement cannot be created or the training query fails
   */
  public Model train(DataSet dataSet, SVMParameter parameter) throws OperatorException {
    para = parameter;
    setDataSourceInfo(
        DataSourceInfoFactory.createConnectionInfo(
            ((DBTable) dataSet.getDBTable()).getDatabaseConnection().getProperties().getName()));
    DatabaseConnection databaseConnection =
        ((DBTable) dataSet.getDBTable()).getDatabaseConnection();
    Column label = dataSet.getColumns().getLabel();
    String labelString = StringHandler.doubleQ(label.getName());

    DataSet newDataSet = getTransformer().TransformCategoryToNumeric_new(dataSet);
    String newTableName = ((DBTable) newDataSet.getDBTable()).getTableName();

    Statement st = null;
    ResultSet rs = null;
    try {
      st = databaseConnection.createStatement(false);
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    }
    try {
      StringBuffer ind = getColumnArray(newDataSet);
      StringBuffer where = getColumnWhere(newDataSet);
      where.append(" and ").append(labelString).append(" is not null ");
      SVMRegressionModel model = new SVMRegressionModel(dataSet, newDataSet);
      if (!newDataSet.equals(dataSet)) {
        model.setAllTransformMap_valueKey(getTransformer().getAllTransformMap_valueKey());
      }
      model.setKernelType(para.getKernelType());
      model.setDegree(para.getDegree());
      model.setGamma(para.getGamma());

      String sql =
          "select (model).inds, (model).cum_err, (model).epsilon, (model).rho, (model).b, (model).nsvs, (model).ind_dim, (model).weights, (model).individuals from (select alpine_miner_online_sv_reg('"
              + newTableName
              + "','"
              + ind
              + "','"
              + labelString
              + "','"
              + where
              + "',"
              + para.getKernelType()
              + ","
              + para.getDegree()
              + ","
              + para.getGamma()
              + ","
              + para.getEta()
              + ","
              + ((SVMRegressionParameter) para).getSlambda()
              + ","
              + para.getNu()
              + ") as model ";
      // Oracle needs a FROM clause even for a pure function call.
      if (getDataSourceInfo().getDBType().equals(DataSourceInfoOracle.dBType)) {
        sql += " from dual ";
      }
      sql += ") a";

      itsLogger.debug("SVMRegression.train():sql=" + sql);
      rs = st.executeQuery(sql);
      setModel(rs, model);
      if (getTransformer().isTransform()) {
        dropTable(st, newTableName);
      }
      return model;
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    } finally {
      // Close on every path so a failed query does not leak the cursor or the statement.
      if (rs != null) {
        try {
          rs.close();
        } catch (SQLException closeError) {
          itsLogger.error(closeError.getMessage(), closeError);
        }
      }
      try {
        st.close();
      } catch (SQLException closeError) {
        itsLogger.error(closeError.getMessage(), closeError);
      }
    }
  }
  /**
   * Computes R-squared for one group's model by running two queries restricted to that group.
   *
   * <p>The first query reads the average of the label; the second evaluates
   * {@code 1 - sum((predicted - label)^2) / sum((label - avg)^2)} using the model's predicted
   * expression. Both result sets are closed after use.
   *
   * @param dataSet data set handed to {@code model.generatePredictedString}
   * @param tableName table to query, inserted into the SQL as given
   * @param labelName label expression, inserted into the SQL as given
   * @param model model that supplies the predicted-value SQL expression
   * @param groupValue value of the group-by column selecting the group
   * @param sb_notNull where clause appended after the table name
   * @return the R-squared value, {@code 0.0} if the second query returns no row, or
   *     {@code Double.NaN} if either query fails
   * @throws OperatorException declared for callers; not thrown directly by this method
   */
  protected double cacluateGroupRSquare(
      DataSet dataSet,
      String tableName,
      String labelName,
      LinearRegressionModelDB model,
      String groupValue,
      StringBuilder sb_notNull)
      throws OperatorException {
    double RSquare = 0.0;

    // Both queries share the same "rows of this group" restriction.
    StringBuffer groupFilter =
        new StringBuffer(" and ")
            .append(StringHandler.doubleQ(groupbyColumn))
            .append("=")
            .append(StringHandler.singleQ(groupValue));

    StringBuffer avgSQL = new StringBuffer();
    avgSQL
        .append(" select avg(")
        .append(labelName)
        .append(")  from ")
        .append(tableName)
        .append(sb_notNull)
        .append(groupFilter);
    double avg = 0.0;
    ResultSet avgResult = null;
    try {
      itsLogger.debug(classLogInfo + ".cacluateRSquare():sql=" + avgSQL);
      avgResult = st.executeQuery(avgSQL.toString());
      if (avgResult.next()) {
        avg = avgResult.getDouble(1);
      }
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      return Double.NaN;
    } finally {
      if (avgResult != null) {
        try {
          avgResult.close();
        } catch (SQLException closeError) {
          itsLogger.error(closeError.getMessage(), closeError);
        }
      }
    }

    StringBuffer predictedValueSQL = new StringBuffer();
    predictedValueSQL.append(model.generatePredictedString(dataSet));
    StringBuffer RSquareSQL = new StringBuffer();
    RSquareSQL.append("select 1 - sum((")
        .append(predictedValueSQL)
        .append("-")
        .append(labelName)
        .append(")*(")
        .append(predictedValueSQL)
        .append("-")
        .append(labelName)
        .append("))*1.0/sum((")
        .append(labelName)
        .append("-(")
        .append(avg)
        .append("))*(")
        .append(labelName)
        .append("-(")
        .append(avg)
        .append("))) from ")
        .append(tableName)
        .append(sb_notNull)
        .append(groupFilter);
    ResultSet rSquareResult = null;
    try {
      itsLogger.debug("LinearRegressionImpPGGP.cacluateRSquare():sql=" + RSquareSQL);
      rSquareResult = st.executeQuery(RSquareSQL.toString());
      if (rSquareResult.next()) {
        RSquare = rSquareResult.getDouble(1);
      }
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      return Double.NaN;
    } finally {
      if (rSquareResult != null) {
        try {
          rSquareResult.close();
        } catch (SQLException closeError) {
          itsLogger.error(closeError.getMessage(), closeError);
        }
      }
    }
    return RSquare;
  }
  /**
   * Fits one linear regression model per value of the group-by column and returns the grouped
   * model.
   *
   * <p>Steps, in order: transform the data set, collect null and group-count information,
   * compute per-group coefficients and R-squared ({@code getCoefficientAndR2Group}), compute the
   * per-group degrees of freedom and standard error, obtain the variance-covariance matrices,
   * compute statistics, attach error strings for groups with a missing matrix or with columns
   * containing nulls, and finally drop the transformed table if one was created.
   *
   * @param dataSet training data
   * @param para regression settings; only the interaction model is read here
   * @param columnNames comma-separated column names handed to the transformer; may be
   *     {@code null} or blank
   * @return the grouped model held in the {@code model} field
   * @throws OperatorException if the statement cannot be created or any step fails
   */
  public Model learn(DataSet dataSet, LinearRegressionParameter para, String columnNames)
      throws OperatorException {
    this.dataSet = dataSet;
    ArrayList<String> columnNamesList = new ArrayList<String>();
    if (columnNames != null && !StringUtil.isEmpty(columnNames.trim())) {
      for (String s : columnNames.split(",")) {
        columnNamesList.add(s);
      }
    }
    transformer.setColumnNames(columnNamesList);
    transformer.setAnalysisInterActionModel(para.getAnalysisInterActionModel());
    newDataSet = transformer.TransformCategoryToNumeric_new(dataSet, groupbyColumn);
    DatabaseConnection databaseConnection =
        ((DBTable) newDataSet.getDBTable()).getDatabaseConnection();

    Column label = newDataSet.getColumns().getLabel();
    String labelName = StringHandler.doubleQ(label.getName());
    String tableName = ((DBTable) dataSet.getDBTable()).getTableName();

    String newTableName = ((DBTable) newDataSet.getDBTable()).getTableName();

    try {
      st = databaseConnection.createStatement(false);
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    }
    try {
      newDataSet.computeAllColumnStatistics();
      Columns atts = newDataSet.getColumns();

      Iterator<Column> atts_i = atts.iterator();

      int count = 0;
      String[] columnNamesArray = new String[atts.size()];
      while (atts_i.hasNext()) {
        Column att = atts_i.next();
        columnNamesArray[count] = att.getName();
        count++;
      }
      null_list = calculateNull(dataSet);
      null_list_group = calculateNullGroup(newDataSet, atts);
      StringBuilder sb_notNull = getWhere(atts);

      getCoefficientAndR2Group(
          columnNames,
          dataSet,
          labelName,
          tableName,
          newTableName,
          atts,
          columnNamesArray,
          sb_notNull);

      // Degrees of freedom per group: row count minus the number of fitted terms.
      HashMap<String, Long> degreeOfFreedom = new HashMap<String, Long>();
      for (String groupValue : groupCount.keySet()) {
        long tempDof = groupCount.get(groupValue) - columnNamesArray.length - 1;
        if (tempDof <= 0) {
          model.getOneModel(groupValue).setS(Double.NaN);
        }
        degreeOfFreedom.put(groupValue, tempDof);
      }

      StringBuffer sSQL =
          createSSQLLGroup(
              newDataSet, newTableName, label, columnNamesArray, coefficients, sb_notNull);
      HashMap<String, Double> sValueMap = new HashMap<String, Double>();

      try {
        itsLogger.debug(classLogInfo + ".learn():sql=" + sSQL);
        rs = st.executeQuery(sSQL.toString());
        try {
          while (rs.next()) {
            String groupValue = rs.getString(2);
            if (groupValue == null) {
              continue;
            }
            if (dataErrorList.contains(groupValue)) {
              sValueMap.put(groupValue, Double.NaN);
            } else {
              sValueMap.put(groupValue, rs.getDouble(1));
            }
          }
        } finally {
          // Close even when reading a row fails, so the cursor is not leaked.
          rs.close();
        }
      } catch (SQLException e) {
        itsLogger.error(e.getMessage(), e);
        throw new OperatorException(e.getLocalizedMessage());
      }

      HashMap<String, Matrix> varianceCovarianceMatrix =
          getVarianceCovarianceMatrixGroup(newTableName, columnNamesArray, st);
      for (String groupValue : varianceCovarianceMatrix.keySet()) {
        // Check this group's matrix; testing the map itself here could never be true.
        if (varianceCovarianceMatrix.get(groupValue) == null) {
          model
              .getOneModel(groupValue)
              .setErrorString(
                  AlpineDataAnalysisLanguagePack.getMessage(
                          AlpineDataAnalysisLanguagePack.MATRIX_IS_SIGULAR,
                          AlpineThreadLocal.getLocale())
                      + Tools.getLineSeparator());
        }
      }
      caculateStatistics(
          columnNamesArray,
          coefficients,
          model,
          sValueMap,
          varianceCovarianceMatrix,
          degreeOfFreedom);
      // Report, per group, the columns that were found to contain nulls.
      for (String groupValue : null_list_group.keySet()) {
        if (null_list_group.get(groupValue).size() != 0) {
          StringBuilder sb_null = new StringBuilder();
          for (int i = 0; i < null_list_group.get(groupValue).size(); i++) {
            sb_null
                .append(StringHandler.doubleQ(null_list_group.get(groupValue).get(i)))
                .append(",");
          }
          sb_null = sb_null.deleteCharAt(sb_null.length() - 1);
          String table_exist_null =
              AlpineDataAnalysisLanguagePack.getMessage(
                  AlpineDataAnalysisLanguagePack.TABLE_EXIST_NULL, AlpineThreadLocal.getLocale());
          // The message is split at ';' and the column list is inserted between the two halves.
          String[] temp = table_exist_null.split(";");
          model
              .getOneModel(groupValue)
              .setErrorString(temp[0] + sb_null.toString() + temp[1] + Tools.getLineSeparator());
        }
      }
      if (transformer.isTransform()) {
        dropTable(newTableName);
      }
      st.close();
      itsLogger.debug(LogUtils.exit(classLogInfo, "learn", model.toString()));
      model.setGroupByColumn(groupbyColumn);
      return model;

    } catch (Exception e) {
      itsLogger.error(e.getMessage(), e);
      // The success path closes the statement above; release it here on failure.
      try {
        st.close();
      } catch (SQLException closeError) {
        itsLogger.error(closeError.getMessage(), closeError);
      }
      throw new OperatorException(e.getLocalizedMessage());
    }
  }
  /**
   * Computes, for every group, the regression coefficients and R-squared, and registers one
   * model per group on the {@code model} field.
   *
   * <p>A single query runs {@code alpine_miner_mregr_coef} grouped by the group-by column. For
   * each returned row, elements {@code 1 .. columnNames.length + 1} of the result array fill the
   * X'Y vector and {@code getHessian} supplies the upper triangle of the symmetric Hessian. The
   * coefficients are then obtained as {@code SVDInverse(hessian) * X'Y}. The first element of
   * that product is stored in the last coefficient slot and the remaining elements are shifted
   * down by one.
   *
   * @param columNames comma-separated column names, passed through to the model constructors
   * @param dataSet original data set, passed to the model constructors and the R-squared query
   * @param labelName label expression, inserted into the SQL as given
   * @param tableName table used for the R-squared query
   * @param newTableName table used for the coefficient query
   * @param atts columns that form the regressor array
   * @param columnNames regressor names; its length fixes the matrix dimensions
   * @param sb_notNull where clause appended to the queries
   * @throws OperatorException if the coefficient query fails
   */
  protected void getCoefficientAndR2Group(
      String columNames,
      DataSet dataSet,
      String labelName,
      String tableName,
      String newTableName,
      Columns atts,
      String[] columnNames,
      StringBuilder sb_notNull)
      throws OperatorException {
    StringBuffer columnNamesArray = new StringBuffer();
    columnNamesArray.append("array[1.0,");

    Iterator<Column> atts_i = atts.iterator();
    boolean firstColumn = true;
    while (atts_i.hasNext()) {
      Column att = atts_i.next();
      if (!firstColumn) {
        columnNamesArray.append(",");
      }
      columnNamesArray.append(StringHandler.doubleQ(att.getName())).append("::float");
      firstColumn = false;
    }
    columnNamesArray.append("]");
    String sql =
        "select alpine_miner_mregr_coef("
            + labelName
            + "::float,"
            + columnNamesArray
            + ") , "
            + StringHandler.doubleQ(groupbyColumn)
            + " from "
            + newTableName
            + sb_notNull
            + "     group by "
            + StringHandler.doubleQ(groupbyColumn);
    itsLogger.debug(classLogInfo + ".getCoefficientAndR2():sql=" + sql);

    // Number of fitted terms: one per regressor plus one more.
    final int dimension = columnNames.length + 1;
    HashMap<String, Matrix> XY = new HashMap<String, Matrix>();
    ResultSet rs = null;
    try {
      rs = st.executeQuery(sql);
      while (rs.next()) {
        Matrix tempXY = new Matrix(dimension, 1);
        Matrix tempHessian = new Matrix(dimension, dimension);
        String groupValue = rs.getString(2);
        Object[] object = (Object[]) rs.getArray(1).getArray();
        for (int x = 0; x < dimension; x++) {
          // Element 0 of the result array is skipped; null entries count as 0.
          Object cell = object[x + 1];
          double doubleValue = (cell == null) ? 0.0 : ((Number) cell).doubleValue();
          tempXY.set(x, 0, doubleValue);
        }
        XY.put(groupValue, tempXY);
        double[] arrayarrayResult =
            getHessian(object, dimension + 1, dimension * (dimension + 1) / 2);

        // Unpack the row-major upper triangle into a full symmetric matrix; NaN becomes 0.
        int packedIndex = 0;
        for (int x = 0; x < dimension; x++) {
          for (int y = x; y < dimension; y++) {
            double h = 0.0;
            if (!Double.isNaN(arrayarrayResult[packedIndex])) {
              h = arrayarrayResult[packedIndex];
            }
            tempHessian.set(x, y, h);
            if (x != y) {
              tempHessian.set(y, x, h);
            }
            packedIndex++;
          }
        }
        hessian.put(groupValue, tempHessian);
      }
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    } finally {
      if (rs != null) {
        try {
          rs.close();
        } catch (SQLException closeError) {
          itsLogger.error(closeError.getMessage(), closeError);
        }
      }
    }

    boolean first = true;
    for (String tempString : hessian.keySet()) {
      Double[] tempCoefficients = new Double[dimension];
      for (int i = 0; i < tempCoefficients.length; i++) {
        tempCoefficients[i] = 0.0;
      }

      try {
        Matrix varianceCovarianceMatrix = hessian.get(tempString).SVDInverse();
        Matrix beta = varianceCovarianceMatrix.times(XY.get(tempString));

        // Row 0 of beta goes to the last slot; the other rows shift down by one.
        for (int i = 0; i < beta.getRowDimension(); i++) {
          if (i == 0) {
            tempCoefficients[beta.getRowDimension() - 1] = beta.get(i, 0);
          } else {
            tempCoefficients[i - 1] = beta.get(i, 0);
          }
        }
        coefficients.put(tempString, tempCoefficients);

        getGroupCoefficientMap(columnNames, tempString);
        if (first) {
          // The grouped container model is created from the first group processed.
          model =
              new LinearRegressionGroupGPModel(
                  dataSet,
                  columnNames,
                  columNames,
                  tempCoefficients,
                  coefficientmap.get(tempString));
          first = false;
        }
        LinearRegressionModelDB tempModel =
            new LinearRegressionModelDB(
                dataSet, columnNames, columNames, tempCoefficients, coefficientmap.get(tempString));
        if (!this.newDataSet.equals(this.dataSet)) {
          tempModel.setAllTransformMap_valueKey(transformer.getAllTransformMap_valueKey());
        }
        tempModel.setInteractionColumnExpMap(transformer.getInteractionColumnExpMap());
        tempModel.setInteractionColumnColumnMap(transformer.getInteractionColumnColumnMap());
        double r2 =
            cacluateGroupRSquare(dataSet, tableName, labelName, tempModel, tempString, sb_notNull);
        tempModel.setR2(r2);
        model.addOneModel(tempModel, tempString);
      } catch (Exception e) {
        // A failure for one group is logged and the remaining groups are still processed.
        itsLogger.error(e.getMessage(), e);
      }
    }
  }
  /**
   * Builds the query that returns, per group, the square root of the summed squared residuals
   * divided by that group's degrees of freedom.
   *
   * <p>The predicted value and the divisor are both {@code case} expressions keyed on the
   * group-by column. A group whose row count does not exceed {@code columnNames.length + 1} gets
   * a {@code null} divisor and is added to {@code dataErrorList}.
   *
   * @param dataSet not used by this implementation
   * @param tableName table to query, inserted into the SQL as given
   * @param label label column; its quoted name is the observed value
   * @param columnNames regressor names, wrapped in double quotes as given
   * @param coefficients2 per-group coefficients; the last element of each array is used as the
   *     constant term and element {@code i} multiplies {@code columnNames[i]}
   * @param sb_notNull where clause appended after the table name
   * @return the query text, selecting the statistic followed by the group-by column
   */
  protected StringBuffer createSSQLLGroup(
      DataSet dataSet,
      String tableName,
      Column label,
      String[] columnNames,
      HashMap<String, Double[]> coefficients2,
      StringBuilder sb_notNull) {
    StringBuffer predictedY = new StringBuffer("( case ");
    StringBuffer countString = new StringBuffer(" (case ");
    for (String groupValue : coefficients2.keySet()) {
      // Read every coefficient from the parameter, not partly from the coefficients field.
      Double[] groupCoefficients = coefficients2.get(groupValue);
      StringBuffer whenGroup =
          new StringBuffer(" when ")
              .append(StringHandler.doubleQ(groupbyColumn))
              .append("=")
              .append(StringHandler.singleQ(groupValue))
              .append(" then ");

      predictedY.append(whenGroup).append(groupCoefficients[groupCoefficients.length - 1]);
      for (int i = 0; i < columnNames.length; i++) {
        predictedY
            .append("+")
            .append(groupCoefficients[i])
            .append("*\"")
            .append(columnNames[i])
            .append("\"");
      }

      countString.append(whenGroup);
      if (groupCount.get(groupValue) > (columnNames.length + 1)) {
        countString
            .append(groupCount.get(groupValue))
            .append(" - ")
            .append(columnNames.length + 1);
      } else {
        // Not enough rows for this group: the divisor becomes null and the group is flagged.
        countString.append(" null ");
        dataErrorList.add(groupValue);
      }
    }
    countString.append(" end )::double precision");
    predictedY.append(" end ");
    String labelName = StringHandler.doubleQ(label.getName());
    predictedY.append(")");
    StringBuffer sSQL = new StringBuffer("select sqrt(");

    sSQL.append("sum((")
        .append(labelName)
        .append(" - ")
        .append(predictedY)
        .append(")*1.0*(")
        .append(labelName)
        .append(" - ")
        .append(predictedY)
        .append("))/")
        .append("( ")
        .append(countString)
        .append(")");
    sSQL.append("),")
        .append(StringHandler.doubleQ(groupbyColumn))
        .append(" from ")
        .append(tableName)
        .append(" ")
        .append(sb_notNull)
        .append(" group by ")
        .append(StringHandler.doubleQ(groupbyColumn));
    return sSQL;
  }
  /**
   * Records the row count of every group and lists, per group, the columns whose non-null count
   * differs from that row count.
   *
   * <p>Two grouped queries are run against the data set's table, both filtered by
   * {@code getWhere(atts)}. The first fills {@code groupCount}; the second selects
   * {@code count(column)} for every column and compares each count with the group's row count.
   * Each query uses its own statement, and both the statement and its result set are closed
   * afterwards.
   *
   * @param dataSet data set whose table and columns are inspected
   * @param atts columns used to build the where clause
   * @return for each group value, the names of the columns whose count differs from the group's
   *     row count
   * @throws OperatorException if either query fails
   */
  protected HashMap<String, ArrayList<String>> calculateNullGroup(DataSet dataSet, Columns atts)
      throws OperatorException {

    DatabaseConnection databaseConnection =
        ((DBTable) dataSet.getDBTable()).getDatabaseConnection();
    String tableName = ((DBTable) dataSet.getDBTable()).getTableName();
    HashMap<String, ArrayList<String>> resultList = new HashMap<String, ArrayList<String>>();

    // Query 1: number of rows in each group.
    StringBuilder sb_count = new StringBuilder("select ");
    sb_count.append(StringHandler.doubleQ(groupbyColumn)).append(", count(*) ");
    sb_count
        .append(" from ")
        .append(tableName)
        .append(this.getWhere(atts))
        .append(" group by ")
        .append(StringHandler.doubleQ(groupbyColumn));
    Statement countStatement = null;
    ResultSet countResult = null;
    try {
      countStatement = databaseConnection.createStatement(false);
      itsLogger.debug("LinearRegressionImp.calculateNull():sql=" + sb_count.toString());

      countResult = countStatement.executeQuery(sb_count.toString());
      while (countResult.next()) {
        String groupValue = countResult.getString(1);
        this.groupCount.put(groupValue, countResult.getInt(2));
      }
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    } finally {
      if (countResult != null) {
        try {
          countResult.close();
        } catch (SQLException closeError) {
          itsLogger.error(closeError.getMessage(), closeError);
        }
      }
      if (countStatement != null) {
        try {
          countStatement.close();
        } catch (SQLException closeError) {
          itsLogger.error(closeError.getMessage(), closeError);
        }
      }
    }

    // Query 2: non-null count of every column per group, each aliased to its column name.
    sb_count =
        new StringBuilder("select ").append(StringHandler.doubleQ(groupbyColumn)).append(", ");
    Iterator<Column> i = dataSet.getColumns().iterator();
    while (i.hasNext()) {
      Column att = i.next();
      sb_count.append("count(").append(StringHandler.doubleQ(att.getName())).append(")");
      sb_count.append(StringHandler.doubleQ(att.getName())).append(",");
    }
    sb_count = sb_count.deleteCharAt(sb_count.length() - 1);
    sb_count
        .append(" from ")
        .append(tableName)
        .append(this.getWhere(atts))
        .append(" group by ")
        .append(StringHandler.doubleQ(groupbyColumn));
    Statement nullStatement = null;
    ResultSet nullResult = null;
    try {
      nullStatement = databaseConnection.createStatement(false);
      itsLogger.debug("LinearRegressionImp.calculateNull():sql=" + sb_count.toString());

      nullResult = nullStatement.executeQuery(sb_count.toString());
      int columnCount = nullResult.getMetaData().getColumnCount();
      while (nullResult.next()) {
        ArrayList<String> null_list = new ArrayList<String>();
        String groupValue = nullResult.getString(1);
        // Result column 1 is the group value; the per-column counts start at column 2.
        for (int j = 1; j < columnCount; j++) {
          // Compare as integers: a float comparison cannot tell apart counts above 2^24.
          if (nullResult.getLong(j + 1) != groupCount.get(groupValue)) {
            null_list.add(
                dataSet.getColumns().get(nullResult.getMetaData().getColumnName(j + 1)).getName());
          }
        }
        resultList.put(groupValue, null_list);
      }
    } catch (SQLException e) {
      itsLogger.error(e.getMessage(), e);
      throw new OperatorException(e.getLocalizedMessage());
    } finally {
      if (nullResult != null) {
        try {
          nullResult.close();
        } catch (SQLException closeError) {
          itsLogger.error(closeError.getMessage(), closeError);
        }
      }
      if (nullStatement != null) {
        try {
          nullStatement.close();
        } catch (SQLException closeError) {
          itsLogger.error(closeError.getMessage(), closeError);
        }
      }
    }
    return resultList;
  }
Пример #13
0
  private void performOperation(
      DatabaseConnection databaseConnection, DataSet dataSet, Locale locale)
      throws AnalysisError, OperatorException {
    String outputTableName = getQuotaedTableName(getOutputSchema(), getOutputTable());

    String inputTableName = getQuotaedTableName(getInputSchema(), getInputTable());

    Columns atts = dataSet.getColumns();
    String dbType = databaseConnection.getProperties().getName();
    IDataSourceInfo dataSourceInfo = DataSourceInfoFactory.createConnectionInfo(dbType);

    IMultiDBUtility multiDBUtility = MultiDBUtilityFactory.createConnectionInfo(dbType);

    ISqlGeneratorMultiDB sqlGenerator = SqlGeneratorMultiDBFactory.createConnectionInfo(dbType);

    dropIfExist(dataSet);

    DatabaseUtil.alterParallel(databaseConnection, getOutputType()); // for oracle

    StringBuilder sb_create = new StringBuilder("create ");
    StringBuilder insertTable = new StringBuilder();

    if (getOutputType().equalsIgnoreCase("table")) {
      sb_create.append(" table ");
    } else {
      sb_create.append(" view ");
    }
    sb_create.append(outputTableName);
    sb_create.append(
        getOutputType().equalsIgnoreCase(Resources.TableType) ? getAppendOnlyString() : "");
    sb_create.append(DatabaseUtil.addParallel(databaseConnection, getOutputType())).append(" as (");
    StringBuilder selectSql = new StringBuilder(" select ");

    selectSql.append(StringHandler.doubleQ(groupColumn)).append(",");

    Column att = atts.get(columnNames);
    dataSet.computeColumnStatistics(att);
    if (att.isNumerical()) {
      logger.error("PivotTableAnalyzer cannot accept numeric type column");
      throw new AnalysisError(
          this,
          AnalysisErrorName.Not_numeric,
          locale,
          SDKLanguagePack.getMessage(SDKLanguagePack.PIVOT_NAME, locale));
    }
    String attName = StringHandler.doubleQ(att.getName());
    List<String> valueList = att.getMapping().getValues();
    if (!useArray
        && valueList.size() > Integer.parseInt(AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD)) {
      logger.error("Too many distinct value for column " + StringHandler.doubleQ(columnNames));
      throw new AnalysisError(
          this,
          AnalysisErrorName.Too_Many_Distinct_value,
          locale,
          StringHandler.doubleQ(columnNames),
          AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD);
    }

    if (valueList.size() <= 0) {
      logger.error("Empty table");
      throw new AnalysisError(this, AnalysisErrorName.Empty_table, locale);
    }

    String aggColumnName;
    if (!StringUtil.isEmpty(aggColumn)) {
      aggColumnName = StringHandler.doubleQ(aggColumn);
    } else {
      aggColumnName = "1";
    }

    Iterator<String> valueList_i = valueList.iterator();

    if (useArray) {
      if (dataSourceInfo.getDBType().equals(DataSourceInfoOracle.dBType)) {
        ArrayList<String> array = new ArrayList<String>();
        while (valueList_i.hasNext()) {
          String value = StringHandler.escQ(valueList_i.next());
          String newValue =
              "alpine_miner_null_to_0("
                  + aggrType
                  + " (case when "
                  + attName
                  + "="
                  + CommonUtility.quoteValue(dbType, att, value)
                  + " then "
                  + aggColumnName
                  + " end )) ";
          array.add(newValue);
        }
        selectSql.append(
            CommonUtility.array2OracleArray(array, CommonUtility.OracleDataType.Float));
      } else {
        selectSql.append(multiDBUtility.floatArrayHead());
        while (valueList_i.hasNext()) {
          String value = valueList_i.next();
          selectSql.append("alpine_miner_null_to_0(").append(aggrType);
          selectSql.append(" (case when ").append(attName).append("=");
          value = StringHandler.escQ(value);
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end )) "); // else 0
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
        selectSql.append(multiDBUtility.floatArrayTail());
      }
      selectSql.append(" " + StringHandler.doubleQ(att.getName()));
    } else {
      if (((DBTable) dataSet.getDBTable())
          .getDatabaseConnection()
          .getProperties()
          .getName()
          .equals(DataSourceInfoNZ.dBType)) {
        while (valueList_i.hasNext()) {
          String value = valueList_i.next();
          selectSql.append("(").append(aggrType);
          selectSql.append(" (case when ").append(attName).append("=");
          value = StringHandler.escQ(value);
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end )) "); // else 0
          String colName = StringHandler.doubleQ(att.getName() + "_" + value);
          selectSql.append(colName);
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
      } else if (((DBTable) dataSet.getDBTable())
          .getDatabaseConnection()
          .getProperties()
          .getName()
          .equals(DataSourceInfoDB2.dBType)) {
        while (valueList_i.hasNext()) {
          String value = valueList_i.next();
          selectSql.append("alpine_miner_null_to_0(").append(aggrType);
          selectSql.append(" (double(case when ").append(attName).append("=");
          value = StringHandler.escQ(value);
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end ))) "); // else 0
          String colName = StringHandler.doubleQ(att.getName() + "_" + value);
          selectSql.append(colName);
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
      } else {
        while (valueList_i.hasNext()) {
          String value = valueList_i.next();
          selectSql.append("alpine_miner_null_to_0(").append(aggrType);
          selectSql.append(" (case when ").append(attName).append("=");
          value = StringHandler.escQ(value);
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end )) "); // else 0
          String colName = StringHandler.doubleQ(att.getName() + "_" + value);
          selectSql.append(colName);
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
      }
    }
    selectSql.append(" from ").append(inputTableName).append(" foo group by ");
    selectSql.append(StringHandler.doubleQ(groupColumn));

    if (((DBTable) dataSet.getDBTable())
        .getDatabaseConnection()
        .getProperties()
        .getName()
        .equals(DataSourceInfoNZ.dBType)) {
      StringBuilder sb = new StringBuilder();
      sb.append("select ").append(StringHandler.doubleQ(groupColumn)).append(",");
      Iterator<String> valueList_new = valueList.iterator();
      while (valueList_new.hasNext()) {
        String value = valueList_new.next();
        String colName = StringHandler.doubleQ(att.getName() + "_" + value);
        sb.append("case when ").append(colName).append(" is null then 0 else ");
        sb.append(colName).append(" end ").append(colName).append(",");
      }
      sb = sb.deleteCharAt(sb.length() - 1);
      sb.append(" from (").append(selectSql).append(") foo ");
      selectSql = sb;
    }
    sb_create.append(selectSql).append(" )");

    if (getOutputType().equalsIgnoreCase("table")) {
      sb_create.append(getEndingString());
      insertTable.append(sqlGenerator.insertTable(selectSql.toString(), outputTableName));
    }
    try {
      Statement st = databaseConnection.createStatement(false);
      logger.debug("PivotTableAnalyzer.performOperation():sql=" + sb_create);
      st.execute(sb_create.toString());

      if (insertTable.length() > 0) {
        st.execute(insertTable.toString());
        logger.debug("PivotTableAnalyzer.performOperation():insertTableSql=" + insertTable);
      }
    } catch (SQLException e) {
      logger.error(e);
      if (e.getMessage().startsWith("ORA-03001")
          || e.getMessage().startsWith("ERROR:  invalid identifier")) {
        throw new AnalysisError(this, AnalysisErrorName.Invalid_Identifier, locale);
      } else {
        throw new OperatorException(e.getLocalizedMessage());
      }
    }
  }