Ejemplo n.º 1
0
 /**
  * Collects the column base of every input node, in node order, and converts the list via
  * {@code CommonUtility.array2OracleArray} with the {@code Float} element type.
  *
  * <p>A NaN column base is written as {@code 0.0}.
  *
  * @return the string form of the converted array
  */
 public String getInputBaseArray() {
   ArrayList<String> bases = new ArrayList<String>();
   for (int idx = 0; idx < nodeInputs.length; idx++) {
     double base = nodeInputs[idx].getColumnBase();
     // NaN is not a usable SQL float literal, so substitute zero.
     bases.add(String.valueOf(Double.isNaN(base) ? 0.0 : base));
   }
   return CommonUtility.array2OracleArray(bases, CommonUtility.OracleDataType.Float).toString();
 }
Ejemplo n.º 2
0
 /**
  * Produces one entry per input node and converts the list via
  * {@code CommonUtility.array2OracleArray} with the {@code Float} element type.
  *
  * <p>NOTE(review): both branches are emitted with the Float element type; presumably the
  * referenced columns / transform expressions are numeric — confirm with callers.
  *
  * @param prediction when {@code true} each entry is the node's transform value
  *     ({@code toString()}); otherwise it is the node's column name, double-quoted
  * @return the string form of the converted array
  */
 public String getColumnNames(boolean prediction) {
   ArrayList<String> names = new ArrayList<String>();
   for (int idx = 0; idx < nodeInputs.length; idx++) {
     String entry =
         prediction
             ? nodeInputs[idx].getTransformValue().toString()
             : StringHandler.doubleQ(nodeInputs[idx].getColumn().getName());
     names.add(entry);
   }
   return CommonUtility.array2OracleArray(names, CommonUtility.OracleDataType.Float).toString();
 }
Ejemplo n.º 3
0
  /**
   * Flattens the weights of all inner nodes and converts them via
   * {@code CommonUtility.array2OracleArray} with the {@code Float} element type.
   *
   * <p>Nodes whose layer index is not {@code NNNode.OUTPUT} are emitted first, in iteration
   * order; output-layer nodes are emitted afterwards. A NaN weight is written as {@code 0.0}.
   *
   * @return the string form of the converted array
   */
  public String getWeightArray() {
    ArrayList<String> weightArray = new ArrayList<String>();
    // First pass: every non-output node.
    for (NodeInner nodeInner : nodeInners) {
      if (nodeInner.getLayerIndex() != NNNode.OUTPUT) {
        appendNodeWeights(nodeInner, weightArray);
      }
    }
    // Second pass: output nodes go last.
    for (NodeInner nodeInner : nodeInners) {
      if (nodeInner.getLayerIndex() == NNNode.OUTPUT) {
        appendNodeWeights(nodeInner, weightArray);
      }
    }
    return CommonUtility.array2OracleArray(weightArray, CommonUtility.OracleDataType.Float)
        .toString();
  }

  /**
   * Appends {@code inputNodes.length + 1} weights of {@code nodeInner} (indices
   * {@code 0..inputNodes.length} inclusive) to {@code target}, writing NaN as {@code 0.0}.
   *
   * <p>NOTE(review): the extra trailing entry is presumably the bias/threshold weight, so
   * {@code getWeights()} must hold at least {@code inputNodes.length + 1} values — confirm.
   */
  private static void appendNodeWeights(NodeInner nodeInner, List<String> target) {
    double[] weights = nodeInner.getWeights();
    NNNode[] inputNodes = nodeInner.getInputNodes();
    for (int i = 0; i <= inputNodes.length; i++) {
      double temp = weights[i];
      if (Double.isNaN(temp)) {
        temp = 0.0;
      }
      target.add(String.valueOf(temp));
    }
  }
Ejemplo n.º 4
0
  /**
   * Trains the network in batch ("whole data") mode against the rows of {@code tableName}.
   *
   * <p>Each cycle runs one query, {@code select <getArraySum()>(<getAllWeightChange(..)>) from
   * <tableName>}. The query is rewritten into a {@code floatarraysum_cursor(...) from dual} call
   * when {@code useFloatArraySumCursor()} is true. The cycle then reads the change vector with
   * {@code getResult} and applies it with {@code updateWeight}.
   *
   * <p>The last element of the vector, divided by {@code numberOfClasses} and the data set
   * size, is tracked as the cycle error. Whenever the error improves on {@code bestError}, the
   * current weights are saved, and the best saved weights are restored before returning.
   *
   * <p>Training ends early, with the best weights restored and no exception thrown, in two
   * cases:
   * <ul>
   *   <li>the change vector contains a NaN or infinite value (auto-commit is also switched back
   *       on);
   *   <li>a query fails with an {@link SQLException}.
   * </ul>
   * It ends normally when the error falls below {@code maxError} or after {@code maxCycles}
   * cycles. If the error itself is NaN or infinite, {@code learningRate} is halved and
   * {@code train(...)} is invoked again.
   *
   * <p>The JDBC statement and every per-cycle result set are closed on all exit paths.
   *
   * @throws OperatorException if the JDBC statement cannot be created
   * @throws RuntimeException if a learning-rate reset is needed but the rate is already
   *     {@code <= 0}
   */
  protected void adjustPerWholeData(
      DataSet dataSet,
      List<String[]> hiddenLayers,
      int maxCycles,
      double maxError,
      double learningRate,
      double momentum,
      boolean decay,
      boolean normalize,
      AlpineRandom randomGenerator,
      int fetchSize,
      boolean adjustPerRow,
      Column label,
      int numberOfClasses,
      DatabaseConnection databaseConnection,
      String tableName)
      throws OperatorException {
    double totalSize = dataSet.size();

    Statement st;
    try {
      st = databaseConnection.createStatement(true);
    } catch (SQLException e) {
      throw new OperatorException(e.getLocalizedMessage());
    }
    try {
      double error = 0;
      // optimization loop
      for (int cycle = 0; cycle < maxCycles; cycle++) {
        double tempRate = learningRate;
        if (decay) {
          tempRate /= (cycle + 1);
        }

        StringBuffer sql = new StringBuffer();
        sql.append("select ").append(getArraySum()).append("(");
        sql.append(getAllWeightChange(getNumberOfClasses(label)));
        sql.append(") from ").append(tableName);

        // One result set per cycle; the finally below closes it so cycles do not leak cursors.
        ResultSet rs = null;
        try {
          if (useFloatArraySumCursor()) {
            StringBuffer varcharArray = CommonUtility.splitOracleSqlToVarcharArray(sql);
            sql = new StringBuffer();
            sql.append("select floatarraysum_cursor(").append(varcharArray).append(") from dual");
          }
          itsLogger.debug("NNModel.adjustPerWholeData():sql=" + sql);

          rs = st.executeQuery(sql.toString());
          if (rs.next()) {
            Double[] currentChanges = getResult(rs);
            for (int i = 0; i < currentChanges.length; i++) {
              if (Double.isNaN(currentChanges[i]) || Double.isInfinite(currentChanges[i])) {
                // Training diverged: release JDBC resources, switch auto-commit back on and
                // fall back to the best weights seen so far. Closing again in the finally
                // blocks is harmless (JDBC close() on a closed object has no effect).
                try {
                  rs.close();
                  st.close();
                  databaseConnection.getConnection().setAutoCommit(true);
                } catch (SQLException e) {
                  e.printStackTrace();
                }
                copyBestErrorWeightsToWeights();
                return;
              }
            }
            error = currentChanges[currentChanges.length - 1] / numberOfClasses / totalSize;
            if (error < bestError) {
              bestError = error;
              copyWeightsToBestErrorWeights();
            }
            updateWeight(currentChanges, tempRate, momentum);
          }
        } catch (SQLException e) {
          // A failed query ends training quietly with the best weights restored.
          e.printStackTrace();
          copyBestErrorWeightsToWeights();
          return;
        } finally {
          if (rs != null) {
            try {
              rs.close();
            } catch (SQLException e) {
              e.printStackTrace();
            }
          }
        }
        itsLogger.debug("cycle" + cycle + ";error:" + error);
        if (error < maxError) {
          itsLogger.debug("loop break : " + cycle + " error: " + error);
          break;
        }

        if (Double.isInfinite(error) || Double.isNaN(error)) {
          if (Tools.isLessEqual(learningRate, 0.0d)) { // should hardly happen
            throw new RuntimeException("Cannot reset network to a smaller learning rate.");
          }
          learningRate /= 2;
          train(
              dataSet,
              hiddenLayers,
              maxCycles,
              maxError,
              learningRate,
              momentum,
              decay,
              normalize,
              randomGenerator,
              fetchSize,
              adjustPerRow);
        }
      }
      copyBestErrorWeightsToWeights();
    } finally {
      // Always release the statement, including on early returns and unchecked exceptions.
      try {
        st.close();
      } catch (SQLException e) {
        e.printStackTrace();
      }
    }
  }
Ejemplo n.º 5
0
  /**
   * Creates the pivot output table or view for the configured pivot column.
   *
   * <p>After calling {@code dropIfExist(dataSet)}, builds
   * {@code create table|view <output> as ( select <groupColumn>, <aggregates> from <input> foo
   * group by <groupColumn> )}. The aggregate list has one {@code aggrType} aggregate per
   * distinct value of the pivot column:
   * <ul>
   *   <li>When {@code useArray} is set, the aggregates are packed into a single float-array
   *       column named after the pivot column.
   *   <li>Otherwise each value becomes its own column aliased
   *       {@code "<column>_<escaped value>"}.
   * </ul>
   * For the NZ data source, the select is wrapped in an outer query that replaces NULL
   * aggregates with 0. For table output, the statement produced by
   * {@code sqlGenerator.insertTable(...)} is executed after the create.
   *
   * @param databaseConnection connection used to execute the generated statements
   * @param dataSet source data set; supplies the pivot column and its distinct values
   * @param locale locale for error messages
   * @throws AnalysisError if the pivot column is numeric, has too many distinct values (only
   *     checked when not using arrays), has no values, or the database reports an invalid
   *     identifier
   * @throws OperatorException for any other SQL failure
   */
  private void performOperation(
      DatabaseConnection databaseConnection, DataSet dataSet, Locale locale)
      throws AnalysisError, OperatorException {
    String outputTableName = getQuotaedTableName(getOutputSchema(), getOutputTable());

    String inputTableName = getQuotaedTableName(getInputSchema(), getInputTable());

    Columns atts = dataSet.getColumns();
    String dbType = databaseConnection.getProperties().getName();
    IDataSourceInfo dataSourceInfo = DataSourceInfoFactory.createConnectionInfo(dbType);

    IMultiDBUtility multiDBUtility = MultiDBUtilityFactory.createConnectionInfo(dbType);

    ISqlGeneratorMultiDB sqlGenerator = SqlGeneratorMultiDBFactory.createConnectionInfo(dbType);

    dropIfExist(dataSet);

    DatabaseUtil.alterParallel(databaseConnection, getOutputType()); // for oracle

    StringBuilder sb_create = new StringBuilder("create ");
    StringBuilder insertTable = new StringBuilder();

    if (getOutputType().equalsIgnoreCase("table")) {
      sb_create.append(" table ");
    } else {
      sb_create.append(" view ");
    }
    sb_create.append(outputTableName);
    sb_create.append(
        getOutputType().equalsIgnoreCase(Resources.TableType) ? getAppendOnlyString() : "");
    sb_create.append(DatabaseUtil.addParallel(databaseConnection, getOutputType())).append(" as (");
    StringBuilder selectSql = new StringBuilder(" select ");

    selectSql.append(StringHandler.doubleQ(groupColumn)).append(",");

    Column att = atts.get(columnNames);
    dataSet.computeColumnStatistics(att);
    if (att.isNumerical()) {
      logger.error("PivotTableAnalyzer cannot accept numeric type column");
      throw new AnalysisError(
          this,
          AnalysisErrorName.Not_numeric,
          locale,
          SDKLanguagePack.getMessage(SDKLanguagePack.PIVOT_NAME, locale));
    }
    String attName = StringHandler.doubleQ(att.getName());
    List<String> valueList = att.getMapping().getValues();
    if (!useArray
        && valueList.size() > Integer.parseInt(AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD)) {
      logger.error("Too many distinct value for column " + StringHandler.doubleQ(columnNames));
      throw new AnalysisError(
          this,
          AnalysisErrorName.Too_Many_Distinct_value,
          locale,
          StringHandler.doubleQ(columnNames),
          AlpineMinerConfig.PIVOT_DISTINCTVALUE_THRESHOLD);
    }

    if (valueList.size() <= 0) {
      logger.error("Empty table");
      throw new AnalysisError(this, AnalysisErrorName.Empty_table, locale);
    }

    String aggColumnName;
    if (!StringUtil.isEmpty(aggColumn)) {
      aggColumnName = StringHandler.doubleQ(aggColumn);
    } else {
      aggColumnName = "1";
    }

    // Type name of the connection owning the input table; selects the per-dialect SQL below.
    String tableDbType =
        ((DBTable) dataSet.getDBTable()).getDatabaseConnection().getProperties().getName();

    if (useArray) {
      if (dataSourceInfo.getDBType().equals(DataSourceInfoOracle.dBType)) {
        ArrayList<String> array = new ArrayList<String>();
        for (String rawValue : valueList) {
          String value = StringHandler.escQ(rawValue);
          String newValue =
              "alpine_miner_null_to_0("
                  + aggrType
                  + " (case when "
                  + attName
                  + "="
                  + CommonUtility.quoteValue(dbType, att, value)
                  + " then "
                  + aggColumnName
                  + " end )) ";
          array.add(newValue);
        }
        selectSql.append(
            CommonUtility.array2OracleArray(array, CommonUtility.OracleDataType.Float));
      } else {
        selectSql.append(multiDBUtility.floatArrayHead());
        for (String rawValue : valueList) {
          String value = StringHandler.escQ(rawValue);
          selectSql.append("alpine_miner_null_to_0(").append(aggrType);
          selectSql.append(" (case when ").append(attName).append("=");
          selectSql
              .append(CommonUtility.quoteValue(dbType, att, value))
              .append(" then ")
              .append(aggColumnName)
              .append(" end )) "); // else 0
          selectSql.append(",");
        }
        selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
        selectSql.append(multiDBUtility.floatArrayTail());
      }
      selectSql.append(" " + StringHandler.doubleQ(att.getName()));
    } else if (tableDbType.equals(DataSourceInfoNZ.dBType)) {
      // NZ: no null-to-zero wrapper here; NULLs are replaced in the outer query further down.
      for (String rawValue : valueList) {
        String value = StringHandler.escQ(rawValue);
        selectSql.append("(").append(aggrType);
        selectSql.append(" (case when ").append(attName).append("=");
        selectSql
            .append(CommonUtility.quoteValue(dbType, att, value))
            .append(" then ")
            .append(aggColumnName)
            .append(" end )) "); // else 0
        String colName = StringHandler.doubleQ(att.getName() + "_" + value);
        selectSql.append(colName);
        selectSql.append(",");
      }
      selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
    } else if (tableDbType.equals(DataSourceInfoDB2.dBType)) {
      // DB2: cast the aggregated expression with double(...).
      for (String rawValue : valueList) {
        String value = StringHandler.escQ(rawValue);
        selectSql.append("alpine_miner_null_to_0(").append(aggrType);
        selectSql.append(" (double(case when ").append(attName).append("=");
        selectSql
            .append(CommonUtility.quoteValue(dbType, att, value))
            .append(" then ")
            .append(aggColumnName)
            .append(" end ))) "); // else 0
        String colName = StringHandler.doubleQ(att.getName() + "_" + value);
        selectSql.append(colName);
        selectSql.append(",");
      }
      selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
    } else {
      for (String rawValue : valueList) {
        String value = StringHandler.escQ(rawValue);
        selectSql.append("alpine_miner_null_to_0(").append(aggrType);
        selectSql.append(" (case when ").append(attName).append("=");
        selectSql
            .append(CommonUtility.quoteValue(dbType, att, value))
            .append(" then ")
            .append(aggColumnName)
            .append(" end )) "); // else 0
        String colName = StringHandler.doubleQ(att.getName() + "_" + value);
        selectSql.append(colName);
        selectSql.append(",");
      }
      selectSql = selectSql.deleteCharAt(selectSql.length() - 1);
    }
    selectSql.append(" from ").append(inputTableName).append(" foo group by ");
    selectSql.append(StringHandler.doubleQ(groupColumn));

    if (tableDbType.equals(DataSourceInfoNZ.dBType)) {
      // NZ: wrap the aggregate query so NULL aggregates become 0. The aliases referenced here
      // must be built exactly like the inner ones, i.e. from the escQ-escaped value; using the
      // raw value would reference non-existent columns whenever escQ alters it.
      // NOTE(review): only the non-array NZ branch emits these aliases — confirm useArray is
      // never combined with NZ.
      StringBuilder sb = new StringBuilder();
      sb.append("select ").append(StringHandler.doubleQ(groupColumn)).append(",");
      for (String rawValue : valueList) {
        String colName =
            StringHandler.doubleQ(att.getName() + "_" + StringHandler.escQ(rawValue));
        sb.append("case when ").append(colName).append(" is null then 0 else ");
        sb.append(colName).append(" end ").append(colName).append(",");
      }
      sb = sb.deleteCharAt(sb.length() - 1);
      sb.append(" from (").append(selectSql).append(") foo ");
      selectSql = sb;
    }
    sb_create.append(selectSql).append(" )");

    if (getOutputType().equalsIgnoreCase("table")) {
      sb_create.append(getEndingString());
      insertTable.append(sqlGenerator.insertTable(selectSql.toString(), outputTableName));
    }
    Statement st = null;
    try {
      st = databaseConnection.createStatement(false);
      logger.debug("PivotTableAnalyzer.performOperation():sql=" + sb_create);
      st.execute(sb_create.toString());

      if (insertTable.length() > 0) {
        st.execute(insertTable.toString());
        logger.debug("PivotTableAnalyzer.performOperation():insertTableSql=" + insertTable);
      }
    } catch (SQLException e) {
      logger.error(e);
      // getMessage() may be null for some drivers; guard so the catch itself cannot NPE.
      String message = e.getMessage();
      if (message != null
          && (message.startsWith("ORA-03001")
              || message.startsWith("ERROR:  invalid identifier"))) {
        throw new AnalysisError(this, AnalysisErrorName.Invalid_Identifier, locale);
      } else {
        throw new OperatorException(e.getLocalizedMessage());
      }
    } finally {
      // Release the statement on every path (it was previously never closed).
      if (st != null) {
        try {
          st.close();
        } catch (SQLException e) {
          logger.error(e);
        }
      }
    }
  }