/**
 * Trains the wrapped classifier on {@code x} and {@code y} and wraps the trained model in a
 * {@link ProbabilityEstimateNonconformity} using this learner's error function.
 *
 * @param x the training data; must not be {@code null}
 * @param y the training targets; must not be {@code null} and must have one entry per row of
 *     {@code x}
 * @return a nonconformity measure backed by the trained classifier
 * @throws NullPointerException if {@code x} or {@code y} is {@code null}
 */
@Override
 public Nonconformity fit(DataFrame x, Vector y) {
   Objects.requireNonNull(x, "Input data is required.");
   Objects.requireNonNull(y, "Input target is required.");
   Check.argument(x.rows() == y.size(), "The size of input data and input target don't match");

   Classifier estimator = classifier.fit(x, y);
   // The nonconformity score is computed from probability estimates, so the trained model must
   // advertise the ESTIMATOR characteristic.
   boolean canEstimateProbabilities =
       estimator != null
           && estimator.getCharacteristics().contains(ClassifierCharacteristic.ESTIMATOR);
   Check.state(canEstimateProbabilities, "The produced classifier can't estimate probabilities");

   return new ProbabilityEstimateNonconformity(estimator, errorFunction);
 }
 /**
  * Computes a nonconformity score for each row of {@code x} against the matching target in
  * {@code y}. It applies the error function to the classifier's estimates for {@code x} and the
  * classifier's known classes.
  *
  * @param x the data to score; must not be {@code null}
  * @param y the targets; must not be {@code null} and must have one entry per row of {@code x}
  * @return the scores produced by the error function
  * @throws NullPointerException if {@code x} or {@code y} is {@code null}
  */
 @Override
 public DoubleArray estimate(DataFrame x, Vector y) {
   Objects.requireNonNull(x, "Input data required.");
   Objects.requireNonNull(y, "Input target required.");
   boolean sizesAgree = x.rows() == y.size();
   Check.argument(sizesAgree, "The size of input data and input target don't match.");

   return errorFunction.apply(
       classifier.estimate(x),
       y,
       classifier.getClasses());
 }
 /**
  * Computes the nonconformity score of a single example with respect to {@code label}.
  *
  * <p>The label is looked up among the classifier's classes. A label that is not found there
  * is rejected.
  *
  * @param example the example to score; must not be {@code null}
  * @param label the candidate label; must be one of the classifier's classes
  * @return the error function applied to the classifier's estimate and the label's position
  * @throws NullPointerException if {@code example} is {@code null}
  */
 @Override
 public double estimate(Vector example, Object label) {
   Objects.requireNonNull(example, "Require an example.");

   int labelPosition = classifier.getClasses().loc().indexOf(label);
   // A negative position means the label is not one of the classifier's classes.
   Check.argument(labelPosition >= 0, "illegal label %s", label);

   return errorFunction.apply(classifier.estimate(example), labelPosition);
 }
 /**
  * Trains the nonconformity learner on {@code x} and {@code y} and pairs the result with the
  * distinct values of {@code y} in an {@link InductiveConformalClassifier}.
  *
  * @param x the training data; must not be {@code null}
  * @param y the training targets; must not be {@code null} and must have one entry per row of
  *     {@code x}
  * @return the inductive conformal classifier
  * @throws NullPointerException if {@code x} or {@code y} is {@code null}
  */
 @Override
 public InductiveConformalClassifier fit(DataFrame x, Vector y) {
   Objects.requireNonNull(x, "Input data is required.");
   Objects.requireNonNull(y, "Input target is required.");
   boolean sizesAgree = x.rows() == y.size();
   Check.argument(sizesAgree, "The size of input data and input target don't match.");

   // The learner is trained first; the unique targets become the classifier's class labels.
   return new InductiveConformalClassifier(
       learner.fit(x, y),
       Vectors.unique(y));
 }
Example #5
0
  /**
   * Runs the configured query against the configured database and reads the whole result set
   * into a data frame created from {@code supplier}.
   *
   * <p>Column labels come from {@code header} when it is set. Otherwise they are taken from the
   * result-set metadata, and any label with an entry in {@code headerReMap} is replaced by that
   * entry. Column types come from {@code types} when it is set. Otherwise they come from
   * {@link SqlEntryReader#getTypes()}.
   *
   * <p>The connection, statement and result set are closed before this method returns, whether
   * it succeeds or fails.
   *
   * @param supplier supplies the builder used to construct the data frame
   * @return a data frame holding every row returned by the query
   * @throws EntryReaderException wrapping the {@link SQLException} if any JDBC operation,
   *     including closing the JDBC resources, fails
   */
  @Override
  public DataFrame parse(Supplier<? extends DataFrame.Builder> supplier) {
    Check.state(url != null, "No database provided");
    Check.state(query != null, "No query provided");

    // Release the connection, statement and result set on every path. All rows are copied into
    // the builder (readAll) and the frame is built before the resources are closed.
    try (Connection connection = DriverManager.getConnection(url, properties);
        PreparedStatement stmt = connection.prepareStatement(query);
        ResultSet resultSet = stmt.executeQuery()) {
      ObjectIndex.Builder index = createColumnIndex(resultSet);

      SqlEntryReader entryReader = new SqlEntryReader(resultSet);
      DataFrame.Builder builder = supplier.get();
      List<Class<?>> columnTypes;
      if (types != null) {
        columnTypes = types;
      } else {
        columnTypes = entryReader.getTypes();
      }
      columnTypes.stream().map(VectorType::of).forEach(builder::add);
      builder.readAll(entryReader);
      builder.setColumnIndex(index.build());
      return builder.build();
    } catch (SQLException e) {
      throw new EntryReaderException(e);
    }
  }

  /**
   * Builds the column index. It uses {@code header} when that is set. Otherwise it uses the
   * result-set column labels, each replaced by its {@code headerReMap} entry when one exists.
   */
  private ObjectIndex.Builder createColumnIndex(ResultSet resultSet) throws SQLException {
    ObjectIndex.Builder index = new ObjectIndex.Builder();
    if (header != null) {
      header.forEach(index::add);
      return index;
    }

    ResultSetMetaData metaData = resultSet.getMetaData();
    int columnCount = metaData.getColumnCount();
    // JDBC column indexes are 1-based.
    for (int column = 1; column <= columnCount; column++) {
      String columnLabel = metaData.getColumnLabel(column);
      Object remappedColumnLabel = headerReMap.get(columnLabel);
      if (remappedColumnLabel != null) {
        index.add(remappedColumnLabel);
      } else {
        index.add(columnLabel);
      }
    }
    return index;
  }