示例#1
0
  /**
   * Adapts a trained model to a test dataset with different enums: trains a small GBM on
   * KDDTrain and scores KDDTest, whose data set has a few more enum levels than the training one.
   *
   * <p>Silently returns if either data file is missing. Every key created here (raw file keys,
   * parsed frames, model, response vec, prediction frame) is released in the {@code finally}
   * block, also when training or scoring fails part-way.
   */
  /*@Test*/ public void testModelAdapt() {
    File file1 = TestUtil.find_test_file("./smalldata/kaggle/KDDTrain.arff.gz");
    File file2 = TestUtil.find_test_file("./smalldata/kaggle/KDDTest.arff.gz");
    if (file1 == null || file2 == null) return; // Silently ignore if either file is missing
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("KDDTrain.hex");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("KDDTest.hex");
    GBM gbm = null;
    Frame preds = null;
    try {
      gbm = new GBM();
      gbm.source = ParseDataset2.parse(dest1, new Key[] {fkey1});
      UKV.remove(fkey1);
      gbm.response = gbm.source.remove(41); // Response is col 41
      gbm.ntrees = 2;
      gbm.max_depth = 8;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 50;
      gbm.invoke();

      // The test data set has a few more enums than the train
      Frame ftest = ParseDataset2.parse(dest2, new Key[] {fkey2});
      UKV.remove(fkey2);
      preds = gbm.score(ftest);
    } finally {
      // Raw file keys are normally released right after parsing; repeat here so they do not
      // leak when a parse fails (assumes removing an already-removed key is harmless).
      UKV.remove(fkey1);
      UKV.remove(fkey2);
      UKV.remove(dest1); // Remove original (training) hex frame key
      UKV.remove(dest2); // Remove parsed test frame key
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        // response is still null if we failed before extracting it; don't mask that with an NPE
        if (gbm.response != null) UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
      }
      if (preds != null) preds.remove(); // Remove the scored prediction frame
    }
  }
示例#2
0
  // ==========================================================================
  /**
   * Parses {@code fname} into a frame stored under {@code hexname}, lets {@code prep} choose the
   * response column, trains a small GBM (4 trees, depth 4) and scores the training frame.
   *
   * <p>{@code prep.prep(frame)} returns the index of the response column; a bitwise-complemented
   * (negative) index switches the model to regression ({@code classification = false}). The
   * response is moved to the last column and {@code gbm.cols} is set to every column index of the
   * frame. Silently returns if the data file cannot be found. All keys created here are released
   * in the {@code finally} block, also when a step fails part-way.
   *
   * @param fname test data file, located via {@code TestUtil.find_test_file}
   * @param hexname key name under which the parsed frame is stored
   * @param prep callback given the parsed frame; returns the (possibly complemented) response index
   */
  public void basicGBM(String fname, String hexname, PrepData prep) {
    File file = TestUtil.find_test_file(fname);
    if (file == null) return; // Silently abort test if the file is missing
    Key fkey = NFSFileVec.make(file);
    Key dest = Key.make(hexname);
    GBM gbm = null;
    // Holds the parsed source frame until scoring succeeds, then the prediction frame;
    // whichever it holds at exit is removed in the finally block.
    Frame fr = null;
    try {
      gbm = new GBM();
      gbm.source = fr = ParseDataset2.parse(dest, new Key[] {fkey});
      UKV.remove(fkey);
      int idx = prep.prep(fr);
      if (idx < 0) { // Complemented index requests regression on column ~idx
        gbm.classification = false;
        idx = ~idx;
      }
      String rname = fr._names[idx];
      gbm.response = fr.vecs()[idx];
      fr.remove(idx); // Move response to the end
      fr.add(rname, gbm.response);
      gbm.ntrees = 4;
      gbm.max_depth = 4;
      gbm.min_rows = 1;
      gbm.nbins = 50;
      gbm.cols = new int[fr.numCols()];
      for (int i = 0; i < gbm.cols.length; i++) gbm.cols[i] = i;
      gbm.learn_rate = .2f;
      gbm.invoke();

      fr = gbm.score(gbm.source);
    } finally {
      // Normally already removed right after parsing; repeated in case parsing failed
      // (assumes removing an already-removed key is harmless).
      UKV.remove(fkey);
      UKV.remove(dest); // Remove original hex frame key
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        // response is still null if we failed before choosing it; don't mask that with an NPE
        if (gbm.response != null) UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
      }
      if (fr != null) fr.remove();
    }
  }
示例#3
0
文件: Expr2Test.java 项目: Jfeng3/h2o
  /**
   * Exercises the expression parser/evaluator against a small frame bound to {@code h.hex},
   * parsed from smalldata/tnc3_10.csv.
   *
   * <p>Judging by the call sites, {@code checkStr(expr, double)} expects a scalar result,
   * {@code checkStr(expr, String)} expects the given error text (with a caret line marking the
   * offending span), and the one-argument form supplies no expected value. NOTE(review): confirm
   * against the checkStr helpers.
   *
   * <p>The checks are order-dependent: many of them assign globals (x, y, a, mean) or modify
   * {@code h.hex} in place, and later checks rely on that state (e.g. the ddply column count).
   */
  @Test
  public void testBasicExpr1() {
    Key dest = Key.make("h.hex");
    try {
      File file = TestUtil.find_test_file("smalldata/tnc3_10.csv");
      // File file = TestUtil.find_test_file("smalldata/iris/iris_wheader.csv");
      // File file = TestUtil.find_test_file("smalldata/cars.csv");
      Key fkey = NFSFileVec.make(file);
      // Parse into the frame referenced as "h.hex" by the expressions below.
      // NOTE(review): fkey is not removed explicitly here; confirm the parse releases it.
      ParseDataset2.parse(dest, new Key[] {fkey});

      // Simple numbers & simple expressions
      checkStr("1.23", 1.23);
      checkStr(" 1.23 + 2.34", 3.57);
      checkStr(" 1.23 + 2.34 * 3", 8.25); // op precedence of * over +
      checkStr(
          " 1.23 2.34", "Junk at end of line\n" + " 1.23 2.34\n" + "      ^--^\n"); // Syntax error
      checkStr("1.23 < 2.34", 1);
      checkStr("1.23 <=2.34", 1);
      checkStr("1.23 > 2.34", 0);
      checkStr("1.23 >=2.34", 0);
      checkStr("1.23 ==2.34", 0);
      checkStr("1.23 !=2.34", 1);
      checkStr("1 & 2", 1);
      checkStr("NA&0", 0); // R-spec: 0 not NA
      checkStr("0&NA", 0); // R-spec: 0 not NA
      checkStr("NA&1", Double.NaN); // R-spec: NA not 1
      checkStr("1&NA", Double.NaN);
      checkStr("1|NA", 1);
      checkStr("1&&2", 1);
      checkStr("1||0", 1);
      checkStr("NA||1", 1);
      checkStr("NA||0", Double.NaN);
      checkStr("0||NA", Double.NaN);
      checkStr("!1", 0);
      checkStr("(!)(1)", 0);
      checkStr(
          "(!!)(1)",
          "Arg 'x' typed as dblary but passed dblary(dblary)\n" + "(!!)(1)\n" + " ^-^\n");
      checkStr("-1", -1);
      checkStr("-(1)", -1);
      checkStr("(-)(1)", "Passed 1 args but expected 2\n" + "(-)(1)\n" + "   ^--^\n");
      checkStr("-T", -1);
      checkStr(
          "* + 1",
          "Arg 'x' typed as dblary but passed anyary{dblary,dblary,}(dblary,dblary)\n"
              + "* + 1\n"
              + "^----^\n");
      // Simple op as prefix calls
      checkStr(
          "+(1.23,2.34)",
          "Missing ')'\n"
              + "+(1.23,2.34)\n"
              + "  ^---^\n"); // Syntax error: looks like unary op application
      checkStr("+(1.23)", 1.23); // Unary operator

      // Simple scalar assignment
      checkStr("1=2", "Junk at end of line\n" + "1=2\n" + " ^^\n");
      checkStr("x", "Unknown var x\n" + "x\n" + "^^\n");
      checkStr("x+2", "Unknown var x\n" + "x+2\n" + "^^\n");
      checkStr("2+x", "Missing expr or unknown ID\n" + "2+x\n" + "  ^\n");
      checkStr("x=1", 1);
      checkStr("x<-1", 1); // Alternative R assignment syntax
      checkStr("x=3;y=4", 4); // Return value is last expr

      // Ambiguity & Language
      checkStr("x=mean"); // Assign x to the built-in fcn mean
      checkStr(
          "x=mean=3",
          3); // Assign x & id mean with 3; "mean" here is not related to any built-in fcn
      checkStr("x=mean(c(3))", 3); // Assign x to the result of running fcn mean(3)
      checkStr("x=mean(c(\n3))", 3); // Assign x to the result of running fcn mean(3)
      checkStr(
          "x=mean+3",
          "Arg 'x' typed as dblary but passed dbl(ary)\n"
              + "x=mean+3\n"
              + "  ^-----^\n"); // Error: "mean" is a function; cannot add a function and a number
      checkStr(
          "apply(c(1,2,3),,nrow)", "Missing argument\napply(c(1,2,3),,nrow)\n               ^\n");
      checkStr(
          "foo==bar",
          "Unknown var foo\nfoo==bar\n^--^\n"); // Error is about "foo==", not a new assignment "foo="

      // Simple array handling; broadcast operators
      checkStr("h.hex"); // Simple ref
      checkStr("h.hex[2,3]", 1); // Scalar selection
      checkStr(
          "h.hex[2,+]",
          "Must be scalar or array\n" + "h.hex[2,+]\n" + "        ^-^\n"); // Function not allowed
      checkStr("h.hex[2+4,-4]"); // Select row 6, all-cols but 4
      checkStr("h.hex[1,-1]; h.hex[2,-2]; h.hex[3,-3]"); // Partial results are freed
      checkStr(
          "h.hex[2+3,h.hex]",
          "Selector must be a single column: [pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked, boat, body, home.dest, survived]"); // Error: col selector has too many columns
      checkStr("h.hex[2,]"); // Row 2 all cols
      checkStr("h.hex[,3]"); // Col 3 all rows
      checkStr("h.hex+1"); // Broadcast scalar over ary
      checkStr("h.hex-h.hex");
      checkStr("1.23+(h.hex-h.hex)");
      checkStr("(1.23+h.hex)-h.hex");
      checkStr("min(h.hex,1+2)", 0);
      checkStr("max(h.hex,1+2)", 211.3375);
      checkStr("min.na.rm(h.hex,NA)", 0); // 0
      checkStr("max.na.rm(h.hex,NA)", 211.3375); // 211.3375
      checkStr("min.na.rm(c(NA, 1), -1)", -1); // -1
      checkStr("max.na.rm(c(NA, 1), -1)", 1); // 1
      checkStr("max(c(Inf,1),  2 )", Double.POSITIVE_INFINITY); // Infinity
      checkStr("min(c(Inf,1),-Inf)", Double.NEGATIVE_INFINITY); // -Infinity
      checkStr("is.na(h.hex)");
      checkStr("sum(is.na(h.hex))", 0);
      checkStr("nrow(h.hex)*3", 30); // i.e. the frame has 10 rows
      checkStr("h.hex[nrow(h.hex)-1,ncol(h.hex)-1]");
      checkStr("x=1;x=h.hex"); // Allowed to change types via shadowing at REPL level
      checkStr("a=h.hex"); // Top-level assignment back to H2O.STORE

      checkStr(
          "(h.hex+1)<-2",
          "Junk at end of line\n" + "(h.hex+1)<-2\n" + "         ^-^\n"); // No L-value
      checkStr(
          "h.hex[nrow(h.hex=1),]",
          "Arg 'x' typed as ary but passed dbl\n"
              + "h.hex[nrow(h.hex=1),]\n"
              + "          ^--------^\n"); // Passing a scalar 1.0 to nrow
      checkStr(
          "h.hex[{h.hex=10},]"); // KNOWN BROKEN: should parse a statement list here, then do the
      // evil side-effect of overwriting h.hex while also using 10 to select the last row
      checkStr("h.hex[3,4]<-4;", 4);
      checkStr("c(1,3,5)");
      // Column row subselection
      checkStr("h.hex[,c(1,3,5)]");
      checkStr("h.hex[c(1,3,5),]");
      checkStr("a=c(11,22,33,44,55,66); a[c(2,6,1),]");
      // Named column selection
      checkStr("h.hex$ 2", "Missing column name after $\nh.hex$ 2\n      ^^\n");
      checkStr(
          "h.hex$crunk",
          "Missing column crunk in frame [pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked, boat, body, home.dest, survived]");
      checkStr("h.hex$pclass");
      checkStr("mean(h.hex$pclass)", 1);

      // More complicated operator precedence
      checkStr("c(1,0)&c(2,3)"); // 1,0
      checkStr("c(2,NA)&&T", 1); // 1
      checkStr("-(x = 3)", -3);
      checkStr("x<-+");
      checkStr(
          "x<-+;x(2)",
          "Passed 1 args but expected 2\nx<-+;x(2)\n      ^--^\n"); // Error: + is binary as prefix call
      checkStr("x<-+;x(1,2)", 3); // 3
      checkStr("x<-*;x(2,3)", 6); // 6
      checkStr("x=c(0,1);!x+1"); // ! has lower precedence
      checkStr("x=c(1,-2);-+---x");
      checkStr("x=c(1,-2);--!--x");
      checkStr("!(y=c(3,4))");
      checkStr("!x!=1");
      checkStr("(!x)!=1");
      checkStr("1+x^2");
      checkStr("1+x**2");
      checkStr("x + 2/y");
      checkStr("x + (2/y)");
      checkStr("-x + y");
      checkStr("-(x + y)");
      checkStr("-x % y");
      checkStr("-(x % y)");
      checkStr("T|F&F", 1); // Evals as T|(F&F)==1 not as (T|F)&F==0
      checkStr("T||F&&F", 1); // Evals as T||(F&&F)==1 not as (T||F)&&F==0

      // User functions
      checkStr("function(=){x+1}(2)", "Invalid var\nfunction(=){x+1}(2)\n         ^\n");
      checkStr("function(x,=){x+1}(2)", "Invalid var\nfunction(x,=){x+1}(2)\n           ^\n");
      checkStr("function(x,<-){x+1}(2)", "Invalid var\nfunction(x,<-){x+1}(2)\n           ^\n");
      checkStr(
          "function(x,x){x+1}(2)", "Repeated argument\nfunction(x,x){x+1}(2)\n           ^^\n");
      checkStr("function(x,y,z){x[]}(h.hex,1,2)");
      checkStr(
          "function(x){x[]}(2)",
          "Arg 'x' typed as ary but passed dbl\nfunction(x){x[]}(2)\n                ^--^\n");
      checkStr("function(x){x+1}(2)", 3);
      checkStr("function(x){y=x+y}(2)");
      checkStr("function(x){}(2)");
      checkStr("function(x){y=x*2; y+1}(2)", 5);
      checkStr("function(x){y=1+2}(2)", 3);
      checkStr("function(x){y=1+2;y=c(1,2)}"); // Not allowed to change types in inner scopes
      checkStr("a=function(x) x+1; 7", 7); // Function def w/out curly-braces; return 7
      checkStr("a=function(x) {x+1}; 7", 7); // Function def w/ curly-braces; return 7
      checkStr("a=function(x) {x+1; 7}"); // Function def of 7
      checkStr("c(1,c(2,3))");
      checkStr("a=c(1,Inf);c(2,a)");
      // Test sum flattening all args
      checkStr("sum(1,2,3)", 6);
      checkStr("sum(c(1,3,5))", 9);
      checkStr("sum(4,c(1,3,5),2,6)", 21);
      checkStr("sum(1,h.hex,3)"); // should report an error because h.hex has enums
      checkStr("sum(c(NA,-1,1))", Double.NaN);
      checkStr("sum.na.rm(c(NA,-1,1))", 0);

      checkStr("function(a){a[];a=1}");
      checkStr("a=1;a=2;function(x){x=a;a=3}");
      checkStr("a=h.hex;function(x){x=a;a=3;nrow(x)*a}(a)", 30);
      checkStr("a=h.hex;a[,1]=(a[,1]==8)");
      // Higher-order function typing: fun is typed in the body of function(x)
      checkStr("function(funy){function(x){funy(x)*funy(x)}}(sgn)(-2)", 1);
      // Filter/selection
      checkStr("h.hex[h.hex[,4]>30,]");
      checkStr("a=c(1,2,3);a[a[,1]>10,1]");
      checkStr("sapply(a,sum)[1,1]", 6);
      checkStr("apply(h.hex,2,sum)"); // ERROR BROKEN: the ENUM cols should fold to NA
      checkStr("y=5;apply(h.hex,2,function(x){x[]+y})");
      checkStr(
          "apply(h.hex,2,function(x){x=1;h.hex})",
          "Arg 'fcn' typed as ary(ary) but passed ary(dbl)\napply(h.hex,2,function(x){x=1;h.hex})\n     ^-------------------------------^\n");
      checkStr(
          "apply(h.hex,2,function(x){h.hex})",
          "apply requires that ary fun(ary x) return 1 column");
      checkStr("apply(h.hex,2,function(x){sum(x)/nrow(x)})");
      checkStr("mean=function(x){apply(x,2,sum)/nrow(x)};mean(h.hex)");
      checkStr("sum(apply(h.hex[,c(4,5)],1,mean))", 184.96); // Row-wise apply on mean

      // Conditional selection;
      checkStr("ifelse(0,1,2)", 2);
      checkStr("ifelse(0,h.hex+1,h.hex+2)");
      checkStr("ifelse(h.hex>3,99,h.hex)"); // Broadcast selection
      checkStr("ifelse(0,+,*)(1,2)", 2); // Select functions
      checkStr("(0 ? + : *)(1,2)", 2); // Trinary select
      checkStr("(1? h.hex : (h.hex+1))[1,2]", 0); // True (vs false) test
      // Impute the mean
      checkStr(
          "apply(h.hex,2,function(x){total=sum(ifelse(is.na(x),0,x)); rcnt=nrow(x)-sum(is.na(x)); mean=total / rcnt; ifelse(is.na(x),mean,x)})");
      checkStr("factor(h.hex[,5])");

      // Slice assignment & map
      checkStr("h.hex[,2]");
      checkStr("h.hex[,2]+1");
      checkStr("h.hex[,3]=3.3;h.hex"); // Replace a col with a constant
      checkStr("h.hex[,3]=h.hex[,2]+1"); // Replace a col
      checkStr("h.hex[,ncol(h.hex)+1]=4"); // Extend a col
      checkStr("a=ncol(h.hex);h.hex[,c(a+1,a+2)]=5"); // Extend two cols
      checkStr("h.hex[,7]=x=3; !(x+2)");
      checkStr("table(h.hex)");
      checkStr("table(h.hex[,5])");
      checkStr("table(h.hex[,c(2,7)])");
      checkStr("table(h.hex[,c(2,9)])");
      checkStr("a=cbind(c(1,2,3), c(4,5,6))");
      checkStr("a[,1] = factor(a[,1])");
      checkStr("is.factor(a[,1])", 1);
      checkStr("isTRUE(c(1,3))", 0);
      checkStr("a=1;isTRUE(1)", 1);
      checkStr("a=c(1,2);isTRUE(a)", 0);
      checkStr("isTRUE(min)", 0);
      checkStr("seq_len(0)", "Error in seq_len(0): argument must be coercible to positive integer");
      checkStr(
          "seq_len(-1)", "Error in seq_len(-1): argument must be coercible to positive integer");
      checkStr("seq_len(10)");
      checkStr("3 < 4 |  F &  3 > 4", 1); // Evals as (3<4) | (F & (3>4))
      checkStr("3 < 4 || F && 3 > 4", 1);
      checkStr("h.hex[,4] != 29 || h.hex[,2] < 305 && h.hex[,2] < 81", Double.NaN);
      // checkStr("h.hex[h.hex[,4]>40,]=-99");
      // checkStr("h.hex[2,]=h.hex[7,]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4,6),2]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4),2]");
      // checkStr("map()");
      // checkStr("map(1)");
      // checkStr("map(+,h.hex,1)");
      // checkStr("map(+,1,2)");
      // checkStr("map(function(x){x[];1},h.hex)");
      // checkStr("map(function(a,b,d){a+b+d},h.hex,h.hex,1)");
      // checkStr("map(function(a,b){a+ncol(b)},h.hex,h.hex)");

      // Quantile
      checkStr("quantile(seq_len(10),seq_len(10)/10)");
      checkStr("quantile(runif(seq_len(10000),-1),seq_len(10)/10)");
      checkStr("quantile(h.hex[,4],c(0,.05,0.3,0.55,0.7,0.95,0.99))");

      // ddply error checks
      checkStr("ddply(h.hex,h.hex,sum)", "Only one column-of-columns for column selection");
      checkStr("ddply(h.hex,seq_len(10000),sum)", "Too many columns selected");
      checkStr("ddply(h.hex,NA,sum)", "NA not a valid column");
      checkStr("ddply(h.hex,c(1,NA,3),sum)", "NA not a valid column");
      // 17 = the 14 parsed columns + the 3 columns appended by the slice assignments above
      checkStr("ddply(h.hex,c(1,99,3),sum)", "Column 99 out of range for frame columns 17");

      checkStr("nrow(unique(h.hex[,5]))", 3);
      checkStr("nrow(unique(h.hex[,6]))", 2);
      checkStr("nrow(unique(h.hex[,c(5,6)]))", 4); // multi-column unique

      // Newlines as statement-ends
      checkStr("3*4+5*6", 42);
      checkStr("(h.hex[1,1]=2)", 2);
      checkStr("(h.hex[1,1]=2\n)", 2);
      checkStr("(h.hex[1,1]\n=2)", 2);
      checkStr("(h.hex\n[1,1]=2)", 2);
      checkStr("function(){x=1.23;(x=4.5)\n}()", 4.5);
      checkStr("function(){x=1.23;x=\n4.5\n}()", 4.5);
      checkStr("x=3\nfunction()x=1.23\nx", 3);
      checkStr("x=3\nfunction(){(x=1.23)}\nx", 3);
      checkStr("x=function(df)\n{\nmin(df$age)\n}\n;x(h.hex)", 0.92);
      checkStr("1.23\n-4", -4);
      checkStr("1.23 +\n-4", -2.77);
      checkStr("x=3;3*-x", -9); // *- is not a token
      checkStr(
          "x=3;3\n*\n-\nx", 3); // Each of '3' and '*' and '-' and 'x' is a standalone statement

      // No strings, yet
      checkStr(
          "function(df) { min(df[,\"age\"]) }",
          "The current Exec does not handle strings\nfunction(df) { min(df[,\"age\"]) }\n                        ^\n");

      // Cleanup testing temps
      checkStr("a=0;x=0;y=0", 0); // Delete keys from global scope

    } finally {
      Lockable.delete(dest); // Remove original hex frame key
    }
  }
示例#4
0
  /**
   * Smoke test for the expression evaluator against a frame bound to {@code h.hex}, parsed from
   * smalldata/tnc3_10.csv.
   *
   * <p>Every {@code checkStr} call here passes only the expression. Expected results and errors
   * appear only in trailing comments. NOTE(review): confirm what the one-argument checkStr
   * actually verifies.
   *
   * <p>The checks are order-dependent: several assign globals (x, y, a, mean) or modify
   * {@code h.hex} in place, and later checks see that state.
   */
  @Test
  public void testBasicExpr1() {
    Key dest = Key.make("h.hex");
    try {
      File file = TestUtil.find_test_file("smalldata/tnc3_10.csv");
      // File file = TestUtil.find_test_file("smalldata/iris/iris_wheader.csv");
      // File file = TestUtil.find_test_file("smalldata/cars.csv");
      Key fkey = NFSFileVec.make(file);
      ParseDataset2.parse(dest, new Key[] {fkey});
      UKV.remove(fkey); // Raw file key is no longer needed once parsed

      checkStr("1.23"); // 1.23
      checkStr(" 1.23 + 2.34"); // 3.57
      checkStr(" 1.23 + 2.34 * 3"); // 10.71 under strict left-to-right eval (8.25 with * over +)
      checkStr(" 1.23 2.34"); // Syntax error
      checkStr("1.23 < 2.34"); // 1
      checkStr("1.23 <=2.34"); // 1
      checkStr("1.23 > 2.34"); // 0
      checkStr("1.23 >=2.34"); // 0
      checkStr("1.23 ==2.34"); // 0
      checkStr("1.23 !=2.34"); // 1
      checkStr("h.hex"); // Simple ref
      checkStr("+(1.23,2.34)"); // prefix 3.57
      checkStr("+(1.23)"); // Syntax error, not enough args
      checkStr("+(1.23,2,3)"); // Syntax error, too many args
      checkStr("h.hex[2,3]"); // Scalar selection
      checkStr("h.hex[2,+]"); // Function not allowed
      checkStr("h.hex[2+4,-4]"); // Select row 6, all-cols but 4
      checkStr("h.hex[1,-1]; h.hex[2,-2]; h.hex[3,-3]"); // Partial results are freed
      checkStr("h.hex[2+3,h.hex]"); // Error: col selector has too many columns
      checkStr("h.hex[2,]"); // Row 2 all cols
      checkStr("h.hex[,3]"); // Col 3 all rows
      checkStr("h.hex+1"); // Broadcast scalar over ary
      checkStr("h.hex-h.hex");
      checkStr("1.23+(h.hex-h.hex)");
      checkStr("(1.23+h.hex)-h.hex");
      checkStr("min(h.hex,1+2)");
      checkStr("max(h.hex,1+2)");
      checkStr("is.na(h.hex)");
      checkStr("nrow(h.hex)*3");
      checkStr("h.hex[nrow(h.hex)-1,ncol(h.hex)-1]");
      checkStr("1=2");
      checkStr("x");
      checkStr("x+2");
      checkStr("2+x");
      checkStr("x=1");
      checkStr("x<-1"); // Alternative R assignment syntax
      checkStr("x=1;x=h.hex"); // Allowed to change types via shadowing at REPL level
      checkStr("a=h.hex"); // Top-level assignment back to H2O.STORE
      checkStr("x<-+");
      checkStr("(h.hex+1)<-2");
      checkStr("h.hex[nrow(h.hex=1),]");
      checkStr("h.hex[2,3]<-4;");
      checkStr("c(1,3,5)");
      checkStr("function(=){x+1}(2)");
      checkStr("function(x,=){x+1}(2)");
      checkStr("function(x,<-){x+1}(2)");
      checkStr("function(x,x){x+1}(2)");
      checkStr("function(x,y,z){x[]}(h.hex,1,2)");
      checkStr("function(x){x[]}(2)");
      checkStr("function(x){x+1}(2)");
      checkStr("function(x){y=x+y}(2)");
      checkStr("function(x){}(2)");
      checkStr("function(x){y=x*2; y+1}(2)");
      checkStr("function(x){y=1+2}(2)");
      checkStr("function(x){y=1+2;y=c(1,2)}"); // Not allowed to change types in inner scopes
      checkStr("sum(1,2,3)");
      checkStr("sum(c(1,3,5))");
      checkStr("sum(4,c(1,3,5),2,6)");
      checkStr("sum(1,h.hex,3)");
      checkStr("h.hex[,c(1,3,5)]");
      checkStr("h.hex[c(1,3,5),]");
      checkStr("a=c(11,22,33,44,55,66); a[c(2,6,1),]");
      checkStr("function(a){a[];a=1}");
      checkStr("a=1;a=2;function(x){x=a;a=3}");
      checkStr("a=h.hex;function(x){x=a;a=3;nrow(x)*a}(a)");
      checkStr("a=h.hex;a[,1]=(a[,1]==8)");
      // Higher-order function typing: fun is typed in the body of function(x)
      checkStr("function(funy){function(x){funy(x)*funy(x)}}(sgn)(-2)");
      // Filter/selection
      checkStr("h.hex[h.hex[,2]>4,]");
      checkStr("a=c(1,2,3);a[a[,1]>10,1]");
      checkStr("apply(h.hex,2,sum)");
      checkStr("y=5;apply(h.hex,2,function(x){x[]+y})");
      checkStr("apply(h.hex,2,function(x){x=1;h.hex})");
      checkStr("apply(h.hex,2,function(x){h.hex})");
      checkStr("mean=function(x){apply(x,2,sum)/nrow(x)};mean(h.hex)");

      // Conditional selection;
      checkStr("ifelse(0,1,2)");
      checkStr("ifelse(0,h.hex+1,h.hex+2)");
      checkStr("ifelse(h.hex>3,99,h.hex)"); // Broadcast selection
      checkStr("ifelse(0,+,*)(1,2)"); // Select functions
      checkStr("(0 ? + : *)(1,2)"); // Trinary select
      checkStr("(1? h.hex : (h.hex+1))[1,2]"); // True (vs false) test
      // Impute the mean
      checkStr(
          "apply(h.hex,2,function(x){total=sum(ifelse(is.na(x),0,x)); rcnt=nrow(x)-sum(is.na(x)); mean=total / rcnt; ifelse(is.na(x),mean,x)})");
      checkStr("factor(h.hex[,5])");

      // Slice assignment & map
      checkStr("h.hex[,2]");
      checkStr("h.hex[,2]+1");
      checkStr("h.hex[,3]=3.3;h.hex"); // Replace a col with a constant
      checkStr("h.hex[,3]=h.hex[,2]+1"); // Replace a col
      checkStr("h.hex[,ncol(h.hex)+1]=4"); // Extend a col
      checkStr("a=ncol(h.hex);h.hex[,c(a+1,a+2)]=5"); // Extend two cols
      checkStr("table(h.hex)");
      checkStr("table(h.hex[,3])");
      checkStr("h.hex[,4] != 29 || h.hex[,2] < 305 && h.hex[,2] < 81");
      checkStr("a=cbind(c(1,2,3), c(4,5,6))");
      // checkStr("h.hex[h.hex[,2]>4,]=-99");
      // checkStr("h.hex[2,]=h.hex[7,]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4,6),2]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4),2]");
      // checkStr("map()");
      // checkStr("map(1)");
      // checkStr("map(+,h.hex,1)");
      // checkStr("map(+,1,2)");
      // checkStr("map(function(x){x[];1},h.hex)");
      // checkStr("map(function(a,b,d){a+b+d},h.hex,h.hex,1)");
      // checkStr("map(function(a,b){a+ncol(b)},h.hex,h.hex)");

      checkStr("a=0;x=0"); // Delete keys from global scope

    } finally {
      UKV.remove(dest); // Remove original hex frame key
    }
  }
示例#5
0
  /**
   * Trains a GBM on the {@code classifcation1Train.txt} data and scores the separate
   * {@code classification1Test.txt} data, then prints a confusion matrix of actual vs. predicted
   * {@code outcome}. Slow test, needed to build a good model.
   *
   * <p>Silently returns if either data file is missing. NOTE(review): the training file is spelled
   * "classifcation1Train.txt" while the test file is "classification1Test.txt"; confirm the on-disk
   * name, otherwise this test silently never runs.
   */
  @Test
  public void testGBMTrainTest() {
    File file1 = TestUtil.find_test_file("..//classifcation1Train.txt");
    File file2 = TestUtil.find_test_file("..//classification1Test.txt");
    if (file1 == null || file2 == null) return; // Silently ignore if either file is not found
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("train.hex");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("test.hex");
    GBM gbm = null;
    Frame fr = null, fpreds = null;
    try {
      gbm = new GBM();
      fr = ParseDataset2.parse(dest1, new Key[] {fkey1});
      UKV.remove(fkey1);
      UKV.remove(fr.remove("agentId")._key); // Remove unique ID; too predictive
      gbm.response = fr.remove("outcome"); // Train on the outcome
      gbm.source = fr;
      gbm.ntrees = 5;
      gbm.max_depth = 10;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 100;
      gbm.invoke();

      // Score the separate test data with the trained model
      Frame ftest = ParseDataset2.parse(dest2, new Key[] {fkey2});
      UKV.remove(fkey2);
      fpreds = gbm.score(ftest);

      // Build a confusion matrix of actual vs. predicted outcome
      ConfusionMatrix confusion = new ConfusionMatrix();
      confusion.actual = ftest;
      confusion.vactual = ftest.vecs()[ftest.find("outcome")];
      confusion.predict = fpreds;
      confusion.vpredict = fpreds.vecs()[fpreds.find("predict")];
      confusion.serve(); // Start it, do it

      // Row (actual) and column (predicted) totals of cm[actual][predicted]
      long[][] cm = confusion.cm;
      long[] acts = new long[cm.length];
      long[] preds = new long[cm[0].length];
      for (int a = 0; a < cm.length; a++) {
        for (int p = 0; p < cm[a].length; p++) {
          acts[a] += cm[a][p];
          preds[p] += cm[a][p];
        }
      }
      String[] adomain = ConfusionMatrix.show(acts, confusion.vactual.domain());
      String[] pdomain = ConfusionMatrix.show(preds, confusion.vpredict.domain());

      System.out.println(
          formatConfusionMatrix(cm, adomain, pdomain, acts, preds, confusion.vactual.length()));

    } finally {
      // Raw file keys are normally released right after parsing; repeat here so they do not
      // leak when a parse fails (assumes removing an already-removed key is harmless).
      UKV.remove(fkey1);
      UKV.remove(fkey2);
      UKV.remove(dest1); // Remove original hex frame key
      UKV.remove(dest2);
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        // response is still null if we failed before extracting it; don't mask that with an NPE
        if (gbm.response != null) UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
      }
      if (fr != null) fr.remove();
      if (fpreds != null) fpreds.remove();
    }
  }

  /**
   * Renders a confusion matrix as a tab-separated table: a header of predicted labels, one row
   * per actual class ending in its error rate, and a totals row ending in the overall error rate.
   * Classes whose label is {@code null} are skipped.
   *
   * @param cm counts indexed as {@code cm[actual][predicted]}
   * @param adomain labels for the actual classes (presumably parallel to the rows of {@code cm})
   * @param pdomain labels for the predicted classes (presumably parallel to the columns of
   *     {@code cm})
   * @param acts per-actual-class row totals of {@code cm}
   * @param preds per-predicted-class column totals of {@code cm}
   * @param nrows denominator for the overall error rate
   * @return the formatted table, one line per row
   */
  private static String formatConfusionMatrix(
      long[][] cm, String[] adomain, String[] pdomain, long[] acts, long[] preds, long nrows) {
    StringBuilder sb = new StringBuilder();
    sb.append("Act/Prd\t");
    for (String s : pdomain) if (s != null) sb.append(s).append('\t');
    sb.append("Error\n");

    long terr = 0;
    for (int a = 0; a < cm.length; a++) {
      if (adomain[a] == null) continue;
      sb.append(adomain[a]).append('\t');
      long correct = 0;
      for (int p = 0; p < pdomain.length; p++) {
        if (pdomain[p] == null) continue;
        if (adomain[a].equals(pdomain[p])) correct = cm[a][p]; // diagonal = correct predictions
        sb.append(cm[a][p]).append('\t');
      }
      long err = acts[a] - correct;
      terr += err; // Bump totals
      sb.append(String.format("%5.3f = %d / %d\n", (double) err / acts[a], err, acts[a]));
    }
    sb.append("Totals\t");
    for (int p = 0; p < pdomain.length; p++)
      if (pdomain[p] != null) sb.append(preds[p]).append('\t');
    sb.append(String.format("%5.3f = %d / %d\n", (double) terr / nrows, terr, nrows));
    return sb.toString();
  }