示例#1
0
  // Adapt a trained model to a test dataset with different enums.
  // Currently disabled (no @Test) and it makes no assertions: it only checks that scoring a frame
  // whose categorical domains differ from the training frame completes without throwing.
  /*@Test*/ public void testModelAdapt() {
    File file1 = TestUtil.find_test_file("./smalldata/kaggle/KDDTrain.arff.gz");
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("KDDTrain.hex");
    File file2 = TestUtil.find_test_file("./smalldata/kaggle/KDDTest.arff.gz");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("KDDTest.hex");
    GBM gbm = null;
    Frame preds = null;
    try {
      gbm = new GBM();
      gbm.source = ParseDataset2.parse(dest1, new Key[] {fkey1});
      UKV.remove(fkey1);
      gbm.response = gbm.source.remove(41); // Response is col 41
      gbm.ntrees = 2;
      gbm.max_depth = 8;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 50;
      gbm.invoke();

      // The test data set has a few more enums than the train
      Frame ftest = ParseDataset2.parse(dest2, new Key[] {fkey2});
      UKV.remove(fkey2);
      preds = gbm.score(ftest);

    } finally {
      // Repeat the raw-key removals so the file keys are also freed if a parse failed
      UKV.remove(fkey1);
      UKV.remove(fkey2);
      UKV.remove(dest1); // Remove original hex frame key
      UKV.remove(dest2); // Remove the parsed test frame
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        // response is still null if parsing/column removal failed; don't mask that exception
        if (gbm.response != null) UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
      }
      if (preds != null) preds.remove(); // Remove the predictions frame
    }
  }
 /**
  * Finds the named test folder and parses every regular file directly inside it, taken in sorted
  * order, into a single Frame stored under a freshly made key. Sub-directories are skipped.
  *
  * <p>NPE if the folder is not found (or cannot be listed).
  *
  * @param fname Test folder name, resolved via find_test_file
  * @return Frame parsed from all files in the folder, or NPE
  */
 protected Frame parse_test_folder(String fname) {
   final File dir = find_test_file(fname);
   assert dir.isDirectory();
   final File[] entries = dir.listFiles();
   Arrays.sort(entries); // deterministic file order for the parse
   final ArrayList<Key> fileKeys = new ArrayList<Key>();
   for (int i = 0; i < entries.length; i++) {
     final File entry = entries[i];
     if (!entry.isFile()) continue; // skip sub-directories and other non-regular entries
     fileKeys.add(NFSFileVec.make(entry)._key);
   }
   final Key[] srcKeys = fileKeys.toArray(new Key[fileKeys.size()]);
   return ParseDataset.parse(Key.make(), srcKeys);
 }
示例#3
0
  // ==========================================================================
  /**
   * Parses {@code fname} into a frame keyed {@code hexname}, lets {@code prep} pick the response
   * column, trains a small GBM and scores the training frame with it. Returns silently if the test
   * file cannot be found. All keys created here are removed in the finally block.
   *
   * @param fname test file to parse
   * @param hexname key name for the parsed frame
   * @param prep chooses the response column; a negative result {@code r} means "regression on
   *     column {@code ~r}" (classification is turned off)
   */
  public void basicGBM(String fname, String hexname, PrepData prep) {
    File file = TestUtil.find_test_file(fname);
    if (file == null) return; // Silently abort test if the file is missing
    Key fkey = NFSFileVec.make(file);
    Key dest = Key.make(hexname);
    GBM gbm = null;
    Frame fr = null;
    try {
      gbm = new GBM();
      gbm.source = fr = ParseDataset2.parse(dest, new Key[] {fkey});
      UKV.remove(fkey);
      int idx = prep.prep(fr);
      if (idx < 0) { // Negative index encodes regression on column ~idx
        gbm.classification = false;
        idx = ~idx;
      }
      String rname = fr._names[idx];
      gbm.response = fr.vecs()[idx];
      fr.remove(idx); // Move response to the end
      fr.add(rname, gbm.response);
      gbm.ntrees = 4;
      gbm.max_depth = 4;
      gbm.min_rows = 1;
      gbm.nbins = 50;
      gbm.cols = new int[fr.numCols()];
      for (int i = 0; i < gbm.cols.length; i++) gbm.cols[i] = i;
      gbm.learn_rate = .2f;
      gbm.invoke();

      fr = gbm.score(gbm.source); // From here on fr refers to the predictions frame

      // Kept as a debugging hook for the commented-out toJava() dump below
      GBM.GBMModel gbmmodel = UKV.get(gbm.dest());
      // System.out.println(gbmmodel.toJava());

    } finally {
      UKV.remove(fkey); // Also frees the raw file key if parsing failed
      UKV.remove(dest); // Remove original hex frame key
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        // response is still null if parse/prep failed; don't mask that exception with an NPE
        if (gbm.response != null) UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
      }
      if (fr != null) fr.remove();
    }
  }
示例#4
0
文件: Expr2Test.java 项目: Jfeng3/h2o
  /**
   * Runs a long, ordered script of expressions through the expression parser/evaluator, using the
   * frame parsed from smalldata/tnc3_10.csv under the key "h.hex".
   *
   * <p>{@code checkStr} is defined elsewhere in this test class. Presumably the two-argument forms
   * compare the result against an expected number or an expected error message (whose last line
   * marks the offending span with carets), and the one-argument form only checks that evaluation
   * completes -- confirm against its definition.
   *
   * <p>Order matters: later expressions read variables (x, y, a, ...) and modified columns of
   * h.hex that earlier expressions created. The "h.hex" key is deleted in the finally block.
   */
  @Test
  public void testBasicExpr1() {
    Key dest = Key.make("h.hex");
    try {
      File file = TestUtil.find_test_file("smalldata/tnc3_10.csv");
      // File file = TestUtil.find_test_file("smalldata/iris/iris_wheader.csv");
      // File file = TestUtil.find_test_file("smalldata/cars.csv");
      // NOTE(review): fkey is never removed explicitly; presumably parse consumes it -- confirm.
      Key fkey = NFSFileVec.make(file);
      ParseDataset2.parse(dest, new Key[] {fkey});

      // Simple numbers & simple expressions
      checkStr("1.23", 1.23);
      checkStr(" 1.23 + 2.34", 3.57);
      checkStr(" 1.23 + 2.34 * 3", 8.25); // op precedence of * over +
      checkStr(
          " 1.23 2.34", "Junk at end of line\n" + " 1.23 2.34\n" + "      ^--^\n"); // Syntax error
      checkStr("1.23 < 2.34", 1);
      checkStr("1.23 <=2.34", 1);
      checkStr("1.23 > 2.34", 0);
      checkStr("1.23 >=2.34", 0);
      checkStr("1.23 ==2.34", 0);
      checkStr("1.23 !=2.34", 1);
      checkStr("1 & 2", 1);
      checkStr("NA&0", 0); // R-spec: 0 not NA
      checkStr("0&NA", 0); // R-spec: 0 not NA
      checkStr("NA&1", Double.NaN); // R-spec: NA not 1
      checkStr("1&NA", Double.NaN);
      checkStr("1|NA", 1);
      checkStr("1&&2", 1);
      checkStr("1||0", 1);
      checkStr("NA||1", 1);
      checkStr("NA||0", Double.NaN);
      checkStr("0||NA", Double.NaN);
      checkStr("!1", 0);
      checkStr("(!)(1)", 0);
      checkStr(
          "(!!)(1)",
          "Arg 'x' typed as dblary but passed dblary(dblary)\n" + "(!!)(1)\n" + " ^-^\n");
      checkStr("-1", -1);
      checkStr("-(1)", -1);
      checkStr("(-)(1)", "Passed 1 args but expected 2\n" + "(-)(1)\n" + "   ^--^\n");
      checkStr("-T", -1);
      checkStr(
          "* + 1",
          "Arg 'x' typed as dblary but passed anyary{dblary,dblary,}(dblary,dblary)\n"
              + "* + 1\n"
              + "^----^\n");
      // Simple op as prefix calls
      checkStr(
          "+(1.23,2.34)",
          "Missing ')'\n"
              + "+(1.23,2.34)\n"
              + "  ^---^\n"); // Syntax error: looks like unary op application
      checkStr("+(1.23)", 1.23); // Unary operator

      // Simple scalar assignment
      checkStr("1=2", "Junk at end of line\n" + "1=2\n" + " ^^\n");
      checkStr("x", "Unknown var x\n" + "x\n" + "^^\n");
      checkStr("x+2", "Unknown var x\n" + "x+2\n" + "^^\n");
      checkStr("2+x", "Missing expr or unknown ID\n" + "2+x\n" + "  ^\n");
      checkStr("x=1", 1);
      checkStr("x<-1", 1); // Alternative R assignment syntax
      checkStr("x=3;y=4", 4); // Return value is last expr

      // Ambiguity & Language
      checkStr("x=mean"); // Assign x to the built-in fcn mean
      checkStr(
          "x=mean=3",
          3); // Assign x & id mean with 3; "mean" here is not related to any built-in fcn
      checkStr("x=mean(c(3))", 3); // Assign x to the result of running fcn mean(3)
      checkStr("x=mean(c(\n3))", 3); // Assign x to the result of running fcn mean(3)
      checkStr(
          "x=mean+3",
          "Arg 'x' typed as dblary but passed dbl(ary)\n"
              + "x=mean+3\n"
              + "  ^-----^\n"); // Error: "mean" is a function; cannot add a function and a number
      checkStr(
          "apply(c(1,2,3),,nrow)", "Missing argument\napply(c(1,2,3),,nrow)\n               ^\n");
      checkStr(
          "foo==bar",
          "Unknown var foo\nfoo==bar\n^--^\n"); // Error msg is about "foo==" and not new assignment
      // "foo="

      // Simple array handling; broadcast operators
      checkStr("h.hex"); // Simple ref
      checkStr("h.hex[2,3]", 1); // Scalar selection
      checkStr(
          "h.hex[2,+]",
          "Must be scalar or array\n" + "h.hex[2,+]\n" + "        ^-^\n"); // Function not allowed
      checkStr("h.hex[2+4,-4]"); // Select row 6, all-cols but 4
      checkStr("h.hex[1,-1]; h.hex[2,-2]; h.hex[3,-3]"); // Partial results are freed
      checkStr(
          "h.hex[2+3,h.hex]",
          "Selector must be a single column: [pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked, boat, body, home.dest, survived]"); // Error: col selector has too many columns
      checkStr("h.hex[2,]"); // Row 2 all cols
      checkStr("h.hex[,3]"); // Col 3 all rows
      checkStr("h.hex+1"); // Broadcast scalar over ary
      checkStr("h.hex-h.hex");
      checkStr("1.23+(h.hex-h.hex)");
      checkStr("(1.23+h.hex)-h.hex");
      checkStr("min(h.hex,1+2)", 0);
      checkStr("max(h.hex,1+2)", 211.3375);
      checkStr("min.na.rm(h.hex,NA)", 0); // 0
      checkStr("max.na.rm(h.hex,NA)", 211.3375); // 211.3375
      checkStr("min.na.rm(c(NA, 1), -1)", -1); // -1
      checkStr("max.na.rm(c(NA, 1), -1)", 1); // 1
      checkStr("max(c(Inf,1),  2 )", Double.POSITIVE_INFINITY); // Infinity
      checkStr("min(c(Inf,1),-Inf)", Double.NEGATIVE_INFINITY); // -Infinity
      checkStr("is.na(h.hex)");
      checkStr("sum(is.na(h.hex))", 0);
      checkStr("nrow(h.hex)*3", 30);
      checkStr("h.hex[nrow(h.hex)-1,ncol(h.hex)-1]");
      checkStr("x=1;x=h.hex"); // Allowed to change types via shadowing at REPL level
      checkStr("a=h.hex"); // Top-level assignment back to H2O.STORE

      checkStr(
          "(h.hex+1)<-2",
          "Junk at end of line\n" + "(h.hex+1)<-2\n" + "         ^-^\n"); // No L-value
      checkStr(
          "h.hex[nrow(h.hex=1),]",
          "Arg 'x' typed as ary but passed dbl\n"
              + "h.hex[nrow(h.hex=1),]\n"
              + "          ^--------^\n"); // Passing a scalar 1.0 to nrow
      checkStr(
          "h.hex[{h.hex=10},]"); // ERROR BROKEN: SHOULD PARSE statement list here; then do evil
      // side-effect killing h.hex but also using 10 to select last row
      checkStr("h.hex[3,4]<-4;", 4);
      checkStr("c(1,3,5)");
      // Column row subselection
      checkStr("h.hex[,c(1,3,5)]");
      checkStr("h.hex[c(1,3,5),]");
      checkStr("a=c(11,22,33,44,55,66); a[c(2,6,1),]");
      // Named column selection
      checkStr("h.hex$ 2", "Missing column name after $\nh.hex$ 2\n      ^^\n");
      checkStr(
          "h.hex$crunk",
          "Missing column crunk in frame [pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked, boat, body, home.dest, survived]");
      checkStr("h.hex$pclass");
      checkStr("mean(h.hex$pclass)", 1);

      // More complicated operator precedence
      checkStr("c(1,0)&c(2,3)"); // 1,0
      checkStr("c(2,NA)&&T", 1); // 1
      checkStr("-(x = 3)", -3);
      checkStr("x<-+");
      checkStr(
          "x<-+;x(2)",
          "Passed 1 args but expected 2\nx<-+;x(2)\n      ^--^\n"); // Error, + is binary if used as
      // prefix
      checkStr("x<-+;x(1,2)", 3); // 3
      checkStr("x<-*;x(2,3)", 6); // 6
      checkStr("x=c(0,1);!x+1"); // ! has lower precedence
      checkStr("x=c(1,-2);-+---x");
      checkStr("x=c(1,-2);--!--x");
      checkStr("!(y=c(3,4))");
      checkStr("!x!=1");
      checkStr("(!x)!=1");
      checkStr("1+x^2");
      checkStr("1+x**2");
      checkStr("x + 2/y");
      checkStr("x + (2/y)");
      checkStr("-x + y");
      checkStr("-(x + y)");
      checkStr("-x % y");
      checkStr("-(x % y)");
      checkStr("T|F&F", 1); // Evals as T|(F&F)==1 not as (T|F)&F==0
      checkStr("T||F&&F", 1); // Evals as T||(F&&F)==1 not as (T||F)&&F==0

      // User functions
      checkStr("function(=){x+1}(2)", "Invalid var\nfunction(=){x+1}(2)\n         ^\n");
      checkStr("function(x,=){x+1}(2)", "Invalid var\nfunction(x,=){x+1}(2)\n           ^\n");
      checkStr("function(x,<-){x+1}(2)", "Invalid var\nfunction(x,<-){x+1}(2)\n           ^\n");
      checkStr(
          "function(x,x){x+1}(2)", "Repeated argument\nfunction(x,x){x+1}(2)\n           ^^\n");
      checkStr("function(x,y,z){x[]}(h.hex,1,2)");
      checkStr(
          "function(x){x[]}(2)",
          "Arg 'x' typed as ary but passed dbl\nfunction(x){x[]}(2)\n                ^--^\n");
      checkStr("function(x){x+1}(2)", 3);
      checkStr("function(x){y=x+y}(2)");
      checkStr("function(x){}(2)");
      checkStr("function(x){y=x*2; y+1}(2)", 5);
      checkStr("function(x){y=1+2}(2)", 3);
      checkStr("function(x){y=1+2;y=c(1,2)}"); // Not allowed to change types in inner scopes
      checkStr("a=function(x) x+1; 7", 7); // Function def w/out curly-braces; return 7
      checkStr("a=function(x) {x+1}; 7", 7); // Function def w/ curly-braces; return 7
      checkStr("a=function(x) {x+1; 7}"); // Function def of 7
      checkStr("c(1,c(2,3))");
      checkStr("a=c(1,Inf);c(2,a)");
      // Test sum flattening all args
      checkStr("sum(1,2,3)", 6);
      checkStr("sum(c(1,3,5))", 9);
      checkStr("sum(4,c(1,3,5),2,6)", 21);
      checkStr("sum(1,h.hex,3)"); // should report an error because h.hex has enums
      checkStr("sum(c(NA,-1,1))", Double.NaN);
      checkStr("sum.na.rm(c(NA,-1,1))", 0);

      checkStr("function(a){a[];a=1}");
      checkStr("a=1;a=2;function(x){x=a;a=3}");
      checkStr("a=h.hex;function(x){x=a;a=3;nrow(x)*a}(a)", 30);
      checkStr("a=h.hex;a[,1]=(a[,1]==8)");
      // Higher-order function typing: fun is typed in the body of function(x)
      checkStr("function(funy){function(x){funy(x)*funy(x)}}(sgn)(-2)", 1);
      // Filter/selection
      checkStr("h.hex[h.hex[,4]>30,]");
      checkStr("a=c(1,2,3);a[a[,1]>10,1]");
      checkStr("sapply(a,sum)[1,1]", 6);
      checkStr("apply(h.hex,2,sum)"); // ERROR BROKEN: the ENUM cols should fold to NA
      checkStr("y=5;apply(h.hex,2,function(x){x[]+y})");
      checkStr(
          "apply(h.hex,2,function(x){x=1;h.hex})",
          "Arg 'fcn' typed as ary(ary) but passed ary(dbl)\napply(h.hex,2,function(x){x=1;h.hex})\n     ^-------------------------------^\n");
      checkStr(
          "apply(h.hex,2,function(x){h.hex})",
          "apply requires that ary fun(ary x) return 1 column");
      checkStr("apply(h.hex,2,function(x){sum(x)/nrow(x)})");
      checkStr("mean=function(x){apply(x,2,sum)/nrow(x)};mean(h.hex)");
      checkStr("sum(apply(h.hex[,c(4,5)],1,mean))", 184.96); // Row-wise apply on mean

      // Conditional selection;
      checkStr("ifelse(0,1,2)", 2);
      checkStr("ifelse(0,h.hex+1,h.hex+2)");
      checkStr("ifelse(h.hex>3,99,h.hex)"); // Broadcast selection
      checkStr("ifelse(0,+,*)(1,2)", 2); // Select functions
      checkStr("(0 ? + : *)(1,2)", 2); // Trinary select
      checkStr("(1? h.hex : (h.hex+1))[1,2]", 0); // True (vs false) test
      // Impute the mean
      checkStr(
          "apply(h.hex,2,function(x){total=sum(ifelse(is.na(x),0,x)); rcnt=nrow(x)-sum(is.na(x)); mean=total / rcnt; ifelse(is.na(x),mean,x)})");
      checkStr("factor(h.hex[,5])");

      // Slice assignment & map
      checkStr("h.hex[,2]");
      checkStr("h.hex[,2]+1");
      checkStr("h.hex[,3]=3.3;h.hex"); // Replace a col with a constant
      checkStr("h.hex[,3]=h.hex[,2]+1"); // Replace a col
      checkStr("h.hex[,ncol(h.hex)+1]=4"); // Extend a col
      checkStr("a=ncol(h.hex);h.hex[,c(a+1,a+2)]=5"); // Extend two cols
      checkStr("h.hex[,7]=x=3; !(x+2)");
      checkStr("table(h.hex)");
      checkStr("table(h.hex[,5])");
      checkStr("table(h.hex[,c(2,7)])");
      checkStr("table(h.hex[,c(2,9)])");
      checkStr("a=cbind(c(1,2,3), c(4,5,6))");
      checkStr("a[,1] = factor(a[,1])");
      checkStr("is.factor(a[,1])", 1);
      checkStr("isTRUE(c(1,3))", 0);
      checkStr("a=1;isTRUE(1)", 1);
      checkStr("a=c(1,2);isTRUE(a)", 0);
      checkStr("isTRUE(min)", 0);
      checkStr("seq_len(0)", "Error in seq_len(0): argument must be coercible to positive integer");
      checkStr(
          "seq_len(-1)", "Error in seq_len(-1): argument must be coercible to positive integer");
      checkStr("seq_len(10)");
      checkStr("3 < 4 |  F &  3 > 4", 1); // Evals as (3<4) | (F & (3>4))
      checkStr("3 < 4 || F && 3 > 4", 1);
      checkStr("h.hex[,4] != 29 || h.hex[,2] < 305 && h.hex[,2] < 81", Double.NaN);
      // checkStr("h.hex[h.hex[,4]>40,]=-99");
      // checkStr("h.hex[2,]=h.hex[7,]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4,6),2]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4),2]");
      // checkStr("map()");
      // checkStr("map(1)");
      // checkStr("map(+,h.hex,1)");
      // checkStr("map(+,1,2)");
      // checkStr("map(function(x){x[];1},h.hex)");
      // checkStr("map(function(a,b,d){a+b+d},h.hex,h.hex,1)");
      // checkStr("map(function(a,b){a+ncol(b)},h.hex,h.hex)");

      // Quantile
      checkStr("quantile(seq_len(10),seq_len(10)/10)");
      checkStr("quantile(runif(seq_len(10000),-1),seq_len(10)/10)");
      checkStr("quantile(h.hex[,4],c(0,.05,0.3,0.55,0.7,0.95,0.99))");

      // ddply error checks
      checkStr("ddply(h.hex,h.hex,sum)", "Only one column-of-columns for column selection");
      checkStr("ddply(h.hex,seq_len(10000),sum)", "Too many columns selected");
      checkStr("ddply(h.hex,NA,sum)", "NA not a valid column");
      checkStr("ddply(h.hex,c(1,NA,3),sum)", "NA not a valid column");
      checkStr("ddply(h.hex,c(1,99,3),sum)", "Column 99 out of range for frame columns 17");

      checkStr("nrow(unique(h.hex[,5]))", 3);
      checkStr("nrow(unique(h.hex[,6]))", 2);
      checkStr("nrow(unique(h.hex[,c(5,6)]))", 4); // multi-column unique

      // Newlines as statement-ends
      checkStr("3*4+5*6", 42);
      checkStr("(h.hex[1,1]=2)", 2);
      checkStr("(h.hex[1,1]=2\n)", 2);
      checkStr("(h.hex[1,1]\n=2)", 2);
      checkStr("(h.hex\n[1,1]=2)", 2);
      checkStr("function(){x=1.23;(x=4.5)\n}()", 4.5);
      checkStr("function(){x=1.23;x=\n4.5\n}()", 4.5);
      checkStr("x=3\nfunction()x=1.23\nx", 3);
      checkStr("x=3\nfunction(){(x=1.23)}\nx", 3);
      checkStr("x=function(df)\n{\nmin(df$age)\n}\n;x(h.hex)", 0.92);
      checkStr("1.23\n-4", -4);
      checkStr("1.23 +\n-4", -2.77);
      checkStr("x=3;3*-x", -9); // *- is not a token
      checkStr(
          "x=3;3\n*\n-\nx", 3); // Each of '3' and '*' and '-' and 'x' is a standalone statement

      // No strings, yet
      checkStr(
          "function(df) { min(df[,\"age\"]) }",
          "The current Exec does not handle strings\nfunction(df) { min(df[,\"age\"]) }\n                        ^\n");

      // Cleanup testing temps
      checkStr("a=0;x=0;y=0", 0); // Delete keys from global scope

    } finally {
      Lockable.delete(dest); // Remove original hex frame key
    }
  }
示例#5
0
  /**
   * Feeds an ordered list of expressions to {@code checkStr} (defined elsewhere in this test
   * class), using the frame parsed from smalldata/tnc3_10.csv under the key "h.hex".
   *
   * <p>Every call here passes only the expression. The trailing comments record the result or error
   * each expression is expected to produce; presumably they are not checked automatically -- confirm
   * against checkStr. Order matters: later expressions depend on variables and modified h.hex
   * columns created by earlier ones. The raw file key is removed right after parsing, and "h.hex"
   * is removed in the finally block.
   */
  @Test
  public void testBasicExpr1() {
    Key dest = Key.make("h.hex");
    try {
      File file = TestUtil.find_test_file("smalldata/tnc3_10.csv");
      // File file = TestUtil.find_test_file("smalldata/iris/iris_wheader.csv");
      // File file = TestUtil.find_test_file("smalldata/cars.csv");
      Key fkey = NFSFileVec.make(file);
      ParseDataset2.parse(dest, new Key[] {fkey});
      UKV.remove(fkey);

      checkStr("1.23"); // 1.23
      checkStr(" 1.23 + 2.34"); // 3.57
      checkStr(" 1.23 + 2.34 * 3"); // 10.71, L2R eval order
      checkStr(" 1.23 2.34"); // Syntax error
      checkStr("1.23 < 2.34"); // 1
      checkStr("1.23 <=2.34"); // 1
      checkStr("1.23 > 2.34"); // 0
      checkStr("1.23 >=2.34"); // 0
      checkStr("1.23 ==2.34"); // 0
      checkStr("1.23 !=2.34"); // 1
      checkStr("h.hex"); // Simple ref
      checkStr("+(1.23,2.34)"); // prefix 3.57
      checkStr("+(1.23)"); // Syntax error, not enough args
      checkStr("+(1.23,2,3)"); // Syntax error, too many args
      checkStr("h.hex[2,3]"); // Scalar selection
      checkStr("h.hex[2,+]"); // Function not allowed
      checkStr("h.hex[2+4,-4]"); // Select row 6, all-cols but 4
      checkStr("h.hex[1,-1]; h.hex[2,-2]; h.hex[3,-3]"); // Partial results are freed
      checkStr("h.hex[2+3,h.hex]"); // Error: col selector has too many columns
      checkStr("h.hex[2,]"); // Row 2 all cols
      checkStr("h.hex[,3]"); // Col 3 all rows
      checkStr("h.hex+1"); // Broadcast scalar over ary
      checkStr("h.hex-h.hex");
      checkStr("1.23+(h.hex-h.hex)");
      checkStr("(1.23+h.hex)-h.hex");
      checkStr("min(h.hex,1+2)");
      checkStr("max(h.hex,1+2)");
      checkStr("is.na(h.hex)");
      checkStr("nrow(h.hex)*3");
      checkStr("h.hex[nrow(h.hex)-1,ncol(h.hex)-1]");
      checkStr("1=2");
      checkStr("x");
      checkStr("x+2");
      checkStr("2+x");
      checkStr("x=1");
      checkStr("x<-1"); // Alternative R assignment syntax
      checkStr("x=1;x=h.hex"); // Allowed to change types via shadowing at REPL level
      checkStr("a=h.hex"); // Top-level assignment back to H2O.STORE
      checkStr("x<-+");
      checkStr("(h.hex+1)<-2");
      checkStr("h.hex[nrow(h.hex=1),]");
      checkStr("h.hex[2,3]<-4;");
      checkStr("c(1,3,5)");
      checkStr("function(=){x+1}(2)");
      checkStr("function(x,=){x+1}(2)");
      checkStr("function(x,<-){x+1}(2)");
      checkStr("function(x,x){x+1}(2)");
      checkStr("function(x,y,z){x[]}(h.hex,1,2)");
      checkStr("function(x){x[]}(2)");
      checkStr("function(x){x+1}(2)");
      checkStr("function(x){y=x+y}(2)");
      checkStr("function(x){}(2)");
      checkStr("function(x){y=x*2; y+1}(2)");
      checkStr("function(x){y=1+2}(2)");
      checkStr("function(x){y=1+2;y=c(1,2)}"); // Not allowed to change types in inner scopes
      checkStr("sum(1,2,3)");
      checkStr("sum(c(1,3,5))");
      checkStr("sum(4,c(1,3,5),2,6)");
      checkStr("sum(1,h.hex,3)");
      checkStr("h.hex[,c(1,3,5)]");
      checkStr("h.hex[c(1,3,5),]");
      checkStr("a=c(11,22,33,44,55,66); a[c(2,6,1),]");
      checkStr("function(a){a[];a=1}");
      checkStr("a=1;a=2;function(x){x=a;a=3}");
      checkStr("a=h.hex;function(x){x=a;a=3;nrow(x)*a}(a)");
      checkStr("a=h.hex;a[,1]=(a[,1]==8)");
      // Higher-order function typing: fun is typed in the body of function(x)
      checkStr("function(funy){function(x){funy(x)*funy(x)}}(sgn)(-2)");
      // Filter/selection
      checkStr("h.hex[h.hex[,2]>4,]");
      checkStr("a=c(1,2,3);a[a[,1]>10,1]");
      checkStr("apply(h.hex,2,sum)");
      checkStr("y=5;apply(h.hex,2,function(x){x[]+y})");
      checkStr("apply(h.hex,2,function(x){x=1;h.hex})");
      checkStr("apply(h.hex,2,function(x){h.hex})");
      checkStr("mean=function(x){apply(x,2,sum)/nrow(x)};mean(h.hex)");

      // Conditional selection;
      checkStr("ifelse(0,1,2)");
      checkStr("ifelse(0,h.hex+1,h.hex+2)");
      checkStr("ifelse(h.hex>3,99,h.hex)"); // Broadcast selection
      checkStr("ifelse(0,+,*)(1,2)"); // Select functions
      checkStr("(0 ? + : *)(1,2)"); // Trinary select
      checkStr("(1? h.hex : (h.hex+1))[1,2]"); // True (vs false) test
      // Impute the mean
      checkStr(
          "apply(h.hex,2,function(x){total=sum(ifelse(is.na(x),0,x)); rcnt=nrow(x)-sum(is.na(x)); mean=total / rcnt; ifelse(is.na(x),mean,x)})");
      checkStr("factor(h.hex[,5])");

      // Slice assignment & map
      checkStr("h.hex[,2]");
      checkStr("h.hex[,2]+1");
      checkStr("h.hex[,3]=3.3;h.hex"); // Replace a col with a constant
      checkStr("h.hex[,3]=h.hex[,2]+1"); // Replace a col
      checkStr("h.hex[,ncol(h.hex)+1]=4"); // Extend a col
      checkStr("a=ncol(h.hex);h.hex[,c(a+1,a+2)]=5"); // Extend two cols
      checkStr("table(h.hex)");
      checkStr("table(h.hex[,3])");
      checkStr("h.hex[,4] != 29 || h.hex[,2] < 305 && h.hex[,2] < 81");
      checkStr("a=cbind(c(1,2,3), c(4,5,6))");
      // checkStr("h.hex[h.hex[,2]>4,]=-99");
      // checkStr("h.hex[2,]=h.hex[7,]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4,6),2]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4),2]");
      // checkStr("map()");
      // checkStr("map(1)");
      // checkStr("map(+,h.hex,1)");
      // checkStr("map(+,1,2)");
      // checkStr("map(function(x){x[];1},h.hex)");
      // checkStr("map(function(a,b,d){a+b+d},h.hex,h.hex,1)");
      // checkStr("map(function(a,b){a+ncol(b)},h.hex,h.hex)");

      checkStr("a=0;x=0"); // Delete keys from global scope

    } finally {
      UKV.remove(dest); // Remove original hex frame key
    }
  }
 /**
  * Finds the named CSV test file and parses it into a Frame stored under a freshly made key. This
  * delegates to the two-argument overload, which presumably raises the NPE when the file is
  * missing.
  *
  * @param fname Test filename
  * @return Frame or NPE
  */
 public static Frame parse_test_file(String fname) {
   final Key freshKey = Key.make();
   return parse_test_file(freshKey, fname);
 }
示例#7
0
  /**
   * The train/valid Frame instances are sorted by categorical (themselves sorted by cardinality
   * greatest to least) with all numerical columns following. The response column(s) are placed at
   * the end.
   *
   * <p>Interactions: 1. Num-Num (Note: N(0,1) * N(0,1) ~ N(0,1) ) 2. Num-Enum 3. Enum-Enum
   *
   * <p>Interactions are produced on the fly and are dense (in all 3 cases). Consumers of DataInfo
   * should not have to care how these interactions are generated. Any heuristic using the fullN
   * value should continue functioning the same.
   *
   * <p>Interactions are specified in two ways: A. As a list of pairs of column indices. B. As a
   * list of pairs of column indices with limited enums.
   *
   * <p>Side effect: the (possibly interaction-augmented) train frame and, if non-null, the valid
   * frame are restructured in place into the adapted column order. When {@code interactions} is
   * null, the caller's {@code train} object itself is restructured.
   *
   * @param train training frame; its last {@code nResponses} plus weight/offset/fold columns are
   *     treated as non-predictors and keep their positions
   * @param valid optional validation frame (may be null); restructured to the same column names
   * @param nResponses number of response columns at the end of {@code train}
   * @param useAllFactorLevels if false, each non-interaction categorical contributes one fewer
   *     expanded column
   * @param predictor_transform predictor transform; must be non-null
   * @param response_transform response transform; must be non-null, applied only if nResponses > 0
   * @param skipMissing skip rows with missing values; must not be combined with imputeMissing
   * @param imputeMissing if true, categorical modes come from imputeCat
   * @param missingBucket if true, each categorical gets one extra expanded level for missing values
   * @param weight whether {@code train} carries a weight column among its trailing columns
   * @param offset whether {@code train} carries an offset column among its trailing columns
   * @param fold whether {@code train} carries a fold column among its trailing columns
   * @param interactions interaction pairs to materialize; if null, any InteractionWrappedVec
   *     columns already present in {@code train} are used instead
   */
  public DataInfo(
      Frame train,
      Frame valid,
      int nResponses,
      boolean useAllFactorLevels,
      TransformType predictor_transform,
      TransformType response_transform,
      boolean skipMissing,
      boolean imputeMissing,
      boolean missingBucket,
      boolean weight,
      boolean offset,
      boolean fold,
      Model.InteractionPair[] interactions) {
    super(Key.<DataInfo>make());
    _valid = valid != null;
    assert predictor_transform != null;
    assert response_transform != null;
    _offset = offset;
    _weights = weight;
    _fold = fold;
    assert !(skipMissing && imputeMissing) : "skipMissing and imputeMissing cannot both be true";
    _skipMissing = skipMissing;
    _imputeMissing = imputeMissing;
    _predictor_transform = predictor_transform;
    _response_transform = response_transform;
    _responses = nResponses;
    _useAllFactorLevels = useAllFactorLevels;
    _interactions = interactions;

    // create dummy InteractionWrappedVecs and shove them onto the front
    if (_interactions != null) {
      _interactionVecs = new int[_interactions.length];
      train =
          Model.makeInteractions(
                  train,
                  false,
                  _interactions,
                  _useAllFactorLevels,
                  _skipMissing,
                  predictor_transform == TransformType.STANDARDIZE)
              .add(train);
      if (valid != null)
        valid =
            Model.makeInteractions(
                    valid,
                    true,
                    _interactions,
                    _useAllFactorLevels,
                    _skipMissing,
                    predictor_transform == TransformType.STANDARDIZE)
                .add(valid); // FIXME: should be using the training subs/muls!
    }

    // _permutation[k] = original train column index of adapted column k (filled below for
    // predictors only).
    // NOTE(review): entries for the trailing response/weight/offset/fold columns stay 0.
    _permutation = new int[train.numCols()];
    final Vec[] tvecs = train.vecs();

    // Count categorical-vs-numerical
    // n = number of predictor columns (everything before the trailing non-predictor columns)
    final int n = tvecs.length - _responses - (offset ? 1 : 0) - (weight ? 1 : 0) - (fold ? 1 : 0);
    int[] nums = MemoryManager.malloc4(n);
    int[] cats = MemoryManager.malloc4(n);
    int nnums = 0, ncats = 0;
    for (int i = 0; i < n; ++i)
      if (tvecs[i].isCategorical()) cats[ncats++] = i;
      else nums[nnums++] = i;

    _nums = nnums;
    _cats = ncats;
    _catLvls = new int[ncats][];

    // sort the cats in the decreasing order according to their size
    // (simple O(ncats^2) exchange sort on domain length; ties keep no particular order)
    for (int i = 0; i < ncats; ++i)
      for (int j = i + 1; j < ncats; ++j)
        if (tvecs[cats[i]].domain().length < tvecs[cats[j]].domain().length) {
          int x = cats[i];
          cats[i] = cats[j];
          cats[j] = x;
        }
    String[] names = new String[train.numCols()];
    Vec[] tvecs2 = new Vec[train.numCols()];

    // Compute the cardinality of each cat
    _catModes = new int[ncats];
    _catOffsets = MemoryManager.malloc4(ncats + 1);
    _catMissing = new boolean[ncats];
    // len = running total of expanded (one-hot) columns; _catOffsets[i+1] is its value after cat i
    int len = _catOffsets[0] = 0;
    int interactionIdx = 0; // simple index into the _interactionVecs array

    // No interactions were requested: locate InteractionWrappedVec columns already in train
    ArrayList<Integer> interactionIds;
    if (_interactions == null) {
      interactionIds = new ArrayList<>();
      for (int i = 0; i < tvecs.length; ++i)
        if (tvecs[i] instanceof InteractionWrappedVec) {
          interactionIds.add(i);
        }
      _interactionVecs = new int[interactionIds.size()];
      for (int i = 0; i < _interactionVecs.length; ++i) _interactionVecs[i] = interactionIds.get(i);
    }
    for (int i = 0; i < ncats; ++i) {
      names[i] = train._names[cats[i]];
      Vec v = (tvecs2[i] = tvecs[cats[i]]);
      _catMissing[i] = missingBucket; // needed for test time
      if (v instanceof InteractionWrappedVec) {
        if (_interactions != null) _interactions[interactionIdx].vecIdx = i;
        _interactionVecs[interactionIdx++] =
            i; // i (and not cats[i]) because this is the index in _adaptedFrame
        // Interaction columns keep every level (no useAllFactorLevels drop)
        _catOffsets[i + 1] = (len += v.domain().length + (missingBucket ? 1 : 0));
      } else
        _catOffsets[i + 1] =
            (len +=
                v.domain().length
                    - (useAllFactorLevels ? 0 : 1)
                    + (missingBucket ? 1 : 0)); // missing values turn into a new factor level
      // Mode: imputed level, else the extra "missing" level index, else -100 (presumably a
      // "no mode" sentinel -- confirm with readers of _catModes)
      _catModes[i] =
          imputeMissing ? imputeCat(train.vec(cats[i])) : _catMissing[i] ? v.domain().length : -100;
      _permutation[i] = cats[i];
    }
    // Numeric columns follow the categoricals in the expanded layout, starting at offset len
    _numMeans = new double[nnums];
    _numOffsets = MemoryManager.malloc4(nnums + 1);
    _numOffsets[0] = len;
    boolean isIWV; // is InteractionWrappedVec?
    for (int i = 0; i < nnums; ++i) {
      names[i + ncats] = train._names[nums[i]];
      Vec v = train.vec(nums[i]);
      tvecs2[i + ncats] = v;
      isIWV = v instanceof InteractionWrappedVec;
      if (isIWV) {
        if (null != _interactions) _interactions[interactionIdx].vecIdx = i + ncats;
        _interactionVecs[interactionIdx++] = i + ncats;
      }
      // A numeric interaction column expands to expandedLength() columns; a plain numeric to 1
      _numOffsets[i + 1] = (len += (isIWV ? ((InteractionWrappedVec) v).expandedLength() : 1));
      _numMeans[i] = train.vec(nums[i]).mean();
      _permutation[i + ncats] = nums[i];
    }
    // Trailing response/weight/offset/fold columns keep their original positions and names
    for (int i = names.length - nResponses - (weight ? 1 : 0) - (offset ? 1 : 0) - (fold ? 1 : 0);
        i < names.length;
        ++i) {
      names[i] = train._names[i];
      tvecs2[i] = train.vec(i);
    }
    _adaptedFrame = new Frame(names, tvecs2);
    // Reorder the input frame(s) in place to match the adapted layout
    train.restructure(names, tvecs2);
    if (valid != null) valid.restructure(names, valid.vecs(names));
    //    _adaptedFrame = train;

    setPredictorTransform(predictor_transform);
    if (_responses > 0) setResponseTransform(response_transform);
  }
示例#8
0
  // Train on one data file, then score a separate test file.  Slow test, needed to build a good
  // model.  Prints the resulting confusion matrix; it makes no assertions.
  @Test
  public void testGBMTrainTest() {
    // NOTE(review): the training file is spelled "classifcation" while the test file is spelled
    // "classification" -- confirm the on-disk name, since a missing file silently skips this test.
    File file1 = TestUtil.find_test_file("..//classifcation1Train.txt");
    if (file1 == null) return; // Silently ignore if file not found
    File file2 = TestUtil.find_test_file("..//classification1Test.txt");
    if (file2 == null) return; // Likewise; checked before any keys are created so nothing leaks
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("train.hex");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("test.hex");
    GBM gbm = null;
    Frame fr = null, fpreds = null;
    try {
      gbm = new GBM();
      fr = ParseDataset2.parse(dest1, new Key[] {fkey1});
      UKV.remove(fkey1);
      UKV.remove(fr.remove("agentId")._key); // Remove unique ID; too predictive
      gbm.response = fr.remove("outcome"); // Train on the outcome
      gbm.source = fr;
      gbm.ntrees = 5;
      gbm.max_depth = 10;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 100;
      gbm.invoke();

      // Score the separate test data
      Frame ftest = ParseDataset2.parse(dest2, new Key[] {fkey2});
      UKV.remove(fkey2);
      fpreds = gbm.score(ftest);

      // Build a confusion matrix
      ConfusionMatrix CM = new ConfusionMatrix();
      CM.actual = ftest;
      CM.vactual = ftest.vecs()[ftest.find("outcome")];
      CM.predict = fpreds;
      CM.vpredict = fpreds.vecs()[fpreds.find("predict")];
      CM.serve(); // Start it, do it

      // Really crappy cut-n-paste of what should be in the ConfusionMatrix class itself
      // Row sums (acts) and column sums (preds) of the confusion matrix
      long cm[][] = CM.cm;
      long acts[] = new long[cm.length];
      long preds[] = new long[cm[0].length];
      for (int a = 0; a < cm.length; a++) {
        long sum = 0;
        for (int p = 0; p < cm[a].length; p++) {
          sum += cm[a][p];
          preds[p] += cm[a][p];
        }
        acts[a] = sum;
      }
      String adomain[] = ConfusionMatrix.show(acts, CM.vactual.domain());
      String pdomain[] = ConfusionMatrix.show(preds, CM.vpredict.domain());

      StringBuilder sb = new StringBuilder();
      sb.append("Act/Prd\t");
      for (String s : pdomain) if (s != null) sb.append(s).append('\t');
      sb.append("Error\n");

      long terr = 0;
      for (int a = 0; a < cm.length; a++) {
        if (adomain[a] == null) continue;
        sb.append(adomain[a]).append('\t');
        long correct = 0;
        for (int p = 0; p < pdomain.length; p++) {
          if (pdomain[p] == null) continue;
          if (adomain[a].equals(pdomain[p])) correct = cm[a][p];
          sb.append(cm[a][p]).append('\t');
        }
        long err = acts[a] - correct;
        terr += err; // Bump totals
        sb.append(String.format("%5.3f = %d / %d\n", (double) err / acts[a], err, acts[a]));
      }
      sb.append("Totals\t");
      for (int p = 0; p < pdomain.length; p++)
        if (pdomain[p] != null) sb.append(preds[p]).append("\t");
      sb.append(
          String.format(
              "%5.3f = %d / %d\n", (double) terr / CM.vactual.length(), terr, CM.vactual.length()));

      System.out.println(sb);

    } finally {
      UKV.remove(fkey1); // Also frees the raw training file key if parsing failed
      UKV.remove(dest1); // Remove original hex frame key
      UKV.remove(fkey2);
      UKV.remove(dest2);
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        // response is still null if parsing/column removal failed; don't mask that exception
        if (gbm.response != null) UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
      }
      if (fr != null) fr.remove();
      if (fpreds != null) fpreds.remove();
    }
  }