예제 #1
0
  /**
   * Fits, applies and reverts a {@code NormalizerStandardize} on a feature set in which every
   * value is the same constant, and checks that none of the steps produces NaN, that the
   * transformed values are close to zero, and that reverting restores the original values within
   * the tolerance.
   */
  @Test
  public void testConstant() {
    double tolerancePerc = 10.0; // 10% of correct value
    int nSamples = 500;
    int nFeatures = 3;
    int constant = 100;
    // Absolute tolerance shared by both closeness checks below.
    double tolerance = constant * tolerancePerc / 100.0;

    INDArray featureSet = Nd4j.zeros(nSamples, nFeatures).add(constant);
    INDArray labelSet = Nd4j.zeros(nSamples, 1);
    DataSet sampleDataSet = new DataSet(featureSet, labelSet);

    NormalizerStandardize myNormalizer = new NormalizerStandardize();
    myNormalizer.fit(sampleDataSet);
    // The fitted standard deviation must not be NaN.
    assertFalse(Double.isNaN(myNormalizer.getStd().getDouble(0)));

    myNormalizer.transform(sampleDataSet);
    // Constant features mean a zero std dev; transforming must still not yield NaN.
    assertFalse(Double.isNaN(sampleDataSet.getFeatures().min(0, 1).getDouble(0)));
    // The largest absolute transformed value must be close enough to zero.
    // Argument order is (expected, actual, delta) so failure messages read correctly.
    assertEquals(
        0.0,
        Transforms.abs(sampleDataSet.getFeatures()).max(0, 1).getDouble(0, 0),
        tolerance);

    myNormalizer.revert(sampleDataSet);
    // Reverting with a zero std dev must not yield NaN either.
    assertFalse(Double.isNaN(sampleDataSet.getFeatures().min(0, 1).getDouble(0)));
    // Use the LARGEST absolute difference: checking the smallest one would only prove that
    // a single element was restored, not all of them.
    assertEquals(
        0.0,
        Transforms.abs(sampleDataSet.getFeatures().sub(featureSet)).max(0, 1).getDouble(0),
        tolerance);
  }
예제 #2
0
 /**
  * Pins the known parsing outcome for {@code test4.json}: the data set read from it carries no
  * metadata.
  */
 @Test
 public void testColumnAtEnd() throws Exception {
   // The file declares 4 columns, but Jackson does not take them into account when they sit at
   // the end of the content. This is a known quirk rather than the desired outcome; the test
   // only guards against that behavior changing unnoticed.
   final DataSet parsed = from(getClass().getResourceAsStream("test4.json"));
   assertThat(parsed.getMetadata(), nullValue());
 }
예제 #3
0
 /**
  * Reads {@code test3.json}, sets the CSV separator parameter and the format guess id on its
  * content, writes the data set back out and expects the result to be JSON-equivalent to the
  * source file.
  */
 @Test
 public void testRoundTrip() throws Exception {
   DataSet dataSet = from(DataSetJSONTest.class.getResourceAsStream("test3.json"));
   final DataSetMetadata metadata = dataSet.getMetadata();
   // Assert before the first dereference; placed after it, this check could never fail
   // because a null metadata would already have raised a NullPointerException.
   assertNotNull(metadata);
   metadata.getContent().addParameter(CSVFormatGuess.SEPARATOR_PARAMETER, ",");
   metadata.getContent().setFormatGuessId(new CSVFormatGuess().getBeanId());
   StringWriter writer = new StringWriter();
   to(dataSet, writer);
   assertThat(
       writer.toString(), sameJSONAsFile(DataSetJSONTest.class.getResourceAsStream("test3.json")));
 }
예제 #4
0
 /**
  * For every generated list of unique attribute names, builds a data set whose attributes are
  * set in list order, then checks that the built data set exposes the same number of attributes
  * and that its key set matches the list via the {@code contains} matcher.
  */
 @Test
 public void testAttributeOrder() {
   for (List<String> attributeNames : someNonEmptyLists(uniqueValues(nonEmptyStrings(), 10))) {
     DataSetBuilder dataSetBuilder = new DataSetBuilder(nonEmptyStrings().next());

     // Train and test sources are both backed by their own empty rating collection.
     GenericDataSource trainSource =
         new GenericDataSource(
             "train", EventCollectionDAO.create(Collections.<Rating>emptyList()));
     dataSetBuilder.setTrain(trainSource);
     GenericDataSource testSource =
         new GenericDataSource(
             "test", EventCollectionDAO.create(Collections.<Rating>emptyList()));
     dataSetBuilder.setTest(testSource);

     // Register the attributes in list order, each with a freshly generated value.
     for (String attributeName : attributeNames) {
       dataSetBuilder.setAttribute(attributeName, nonEmptyStrings().next());
     }

     DataSet built = dataSetBuilder.build();
     assertThat(built.getAttributes().size(), equalTo(attributeNames.size()));
     assertThat(built.getAttributes().keySet(), contains(attributeNames.toArray()));
   }
 }
예제 #5
0
  /**
   * Parses {@code dataSetRowMetadata.json} into a {@code DataSet}, walks every record and checks
   * that each one carries a non-empty column list, that exactly 10 records were read, and that
   * the columns of the last record read have the expected ids, in order.
   *
   * @throws IOException declared by the signature; any {@code Exception} raised inside the try
   *     block is rethrown wrapped in a {@code TDPException}
   */
  @Test
  public void should_iterate_row_with_metadata() throws IOException {
    // given
    final String[] expectedColumnIds = {
      "id",
      "firstname",
      "lastname",
      "state",
      "registration",
      "city",
      "birth",
      "nbCommands",
      "avgAmount"
    };

    final InputStream input = this.getClass().getResourceAsStream("dataSetRowMetadata.json");
    final ObjectMapper mapper = builder.build();
    try (JsonParser parser = mapper.getFactory().createParser(input)) {
      final DataSet dataSet = mapper.readerFor(DataSet.class).readValue(parser);

      // Keeps the column list of the most recently visited record.
      List<ColumnMetadata> lastColumns = new ArrayList<>();
      int rowsSeen = 0;
      for (Iterator<DataSetRow> rows = dataSet.getRecords().iterator(); rows.hasNext(); ) {
        lastColumns = rows.next().getRowMetadata().getColumns();
        assertThat(lastColumns, not(empty()));
        rowsSeen++;
      }

      // then
      assertEquals(10, rowsSeen);
      int position = 0;
      for (ColumnMetadata column : lastColumns) {
        assertEquals(expectedColumnIds[position], column.getId());
        position++;
      }
    } catch (Exception e) {
      throw new TDPException(CommonErrorCodes.UNABLE_TO_PARSE_JSON, e);
    }
  }
예제 #6
0
  /**
   * Builds a data set metadata with a single string column, fills in content, lifecycle and
   * location details, serializes the enclosing data set and expects JSON equivalent to
   * {@code test2.json}.
   */
  @Test
  public void testWrite1() throws Exception {
    // Single column description used as the metadata row.
    final ColumnMetadata.Builder column =
        ColumnMetadata.Builder.column()
            .id(5)
            .name("column1")
            .type(Type.STRING)
            .empty(0)
            .invalid(10)
            .valid(50);

    final DataSetMetadata metadata =
        metadataBuilder
            .metadata()
            .id("1234")
            .name("name")
            .author("author")
            .created(0)
            .row(column)
            .build();

    // Content: CSV format with a comma separator.
    final DataSetContent content = metadata.getContent();
    content.addParameter(CSVFormatGuess.SEPARATOR_PARAMETER, ",");
    content.setFormatGuessId(new CSVFormatGuess().getBeanId());
    content.setMediaType("text/csv");

    // Lifecycle: both analyses flagged as done.
    metadata.getLifecycle().qualityAnalyzed(true);
    metadata.getLifecycle().schemaAnalyzed(true);

    // Location: HTTP source.
    final HttpLocation httpLocation = new HttpLocation();
    httpLocation.setUrl("http://estcequecestbientotleweekend.fr");
    metadata.setLocation(httpLocation);

    final DataSet dataSet = new DataSet();
    dataSet.setMetadata(metadata);

    final StringWriter json = new StringWriter();
    to(dataSet, json);
    assertThat(
        json.toString(), sameJSONAsFile(DataSetJSONTest.class.getResourceAsStream("test2.json")));
  }
예제 #7
0
  /**
   * Reads {@code test1.json} and checks the parsed metadata: identity fields, record and
   * header/footer line counts, creation date, column count, and the details of the first and
   * last columns.
   */
  @Test
  public void testRead1() throws Exception {
    final DataSet parsed = from(this.getClass().getResourceAsStream("test1.json"));
    assertNotNull(parsed);

    // General metadata
    final DataSetMetadata metadata = parsed.getMetadata();
    assertEquals("410d2196-8f90-478f-a817-7e8b6694ac91", metadata.getId());
    assertEquals("test", metadata.getName());
    assertEquals("anonymousUser", metadata.getAuthor());
    assertEquals(2, metadata.getContent().getNbRecords());
    assertEquals(1, metadata.getContent().getNbLinesInHeader());
    assertEquals(0, metadata.getContent().getNbLinesInFooter());

    // Creation date, with the expected value parsed as UTC
    final SimpleDateFormat utcFormat = new SimpleDateFormat("MM-dd-yyyy HH:mm");
    utcFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
    final Date expectedCreation = utcFormat.parse("02-17-2015 09:02");
    assertEquals(expectedCreation, new Date(metadata.getCreationDate()));

    // Columns
    final List<ColumnMetadata> columns = metadata.getRowMetadata().getColumns();
    assertEquals(6, columns.size());

    final ColumnMetadata first = columns.get(0);
    assertEquals("0001", first.getId());
    assertEquals("id", first.getName());
    assertEquals("integer", first.getType());
    assertQuality(first, 20, 26, 54);

    final ColumnMetadata last = columns.get(5);
    assertEquals("0007", last.getId());
    assertEquals("string", last.getType());
    assertQuality(last, 8, 25, 67);
  }

  /** Asserts the empty, invalid and valid counts of a column's quality, in that order. */
  private static void assertQuality(ColumnMetadata column, int empty, int invalid, int valid) {
    assertEquals(empty, column.getQuality().getEmpty());
    assertEquals(invalid, column.getQuality().getInvalid());
    assertEquals(valid, column.getQuality().getValid());
  }
예제 #8
0
  /**
   * Checks that reverting a transformed copy of a random data set restores the original feature
   * values: the largest relative error, as a percentage of the original value, must stay below
   * the tolerance.
   */
  @Test
  public void testRevert() {
    double tolerancePerc = 0.01; // 0.01% of correct value
    int nSamples = 500;
    int nFeatures = 3;

    INDArray featureSet = Nd4j.randn(nSamples, nFeatures);
    INDArray labelSet = Nd4j.zeros(nSamples, 1);
    DataSet sampleDataSet = new DataSet(featureSet, labelSet);

    NormalizerStandardize myNormalizer = new NormalizerStandardize();
    myNormalizer.fit(sampleDataSet);
    // Work on a copy so the untouched sample data set remains available as the reference.
    DataSet transformed = sampleDataSet.copy();
    myNormalizer.transform(transformed);
    myNormalizer.revert(transformed);

    // Relative error per element. The absolute value is taken AFTER dividing by the original
    // features: dividing an absolute difference by a negative original would give a negative
    // ratio that the max below could never select, leaving those elements unchecked.
    INDArray delta =
        Transforms.abs(
            transformed
                .getFeatures()
                .sub(sampleDataSet.getFeatures())
                .div(sampleDataSet.getFeatures()));
    double maxdeltaPerc = delta.max(0, 1).mul(100).getDouble(0, 0);
    assertTrue(
        "max relative error after revert was " + maxdeltaPerc + "%, tolerance " + tolerancePerc + "%",
        maxdeltaPerc < tolerancePerc);
  }
예제 #9
0
 /** Sanity check: the data set under test and the metadata it returns are both non-null. */
 @Test
 public void genericTest() {
   // The data set itself is checked first, so a missing fixture fails as an assertion.
   assertNotNull(this.dataSet);
   assertNotNull(this.dataSet.getMetaData());
 }
예제 #10
0
 /**
  * Closes the data set under test, if there is one, and clears the reference to it.
  */
 @After
 public void tearDown() {
   // Nothing to release when the data set was never created; dereferencing it here would
   // raise a NullPointerException on top of whatever failure prevented its creation.
   if (dataSet == null) {
     return;
   }
   try {
     dataSet.close();
   } finally {
     // Clear the reference even when close() throws.
     dataSet = null;
   }
 }