// Adapt a trained model to a test dataset with different enums /*@Test*/ public void testModelAdapt() { File file1 = TestUtil.find_test_file("./smalldata/kaggle/KDDTrain.arff.gz"); Key fkey1 = NFSFileVec.make(file1); Key dest1 = Key.make("KDDTrain.hex"); File file2 = TestUtil.find_test_file("./smalldata/kaggle/KDDTest.arff.gz"); Key fkey2 = NFSFileVec.make(file2); Key dest2 = Key.make("KDDTest.hex"); GBM gbm = null; Frame fr = null; try { gbm = new GBM(); gbm.source = ParseDataset2.parse(dest1, new Key[] {fkey1}); UKV.remove(fkey1); gbm.response = gbm.source.remove(41); // Response is col 41 gbm.ntrees = 2; gbm.max_depth = 8; gbm.learn_rate = 0.2f; gbm.min_rows = 10; gbm.nbins = 50; gbm.invoke(); // The test data set has a few more enums than the train Frame ftest = ParseDataset2.parse(dest2, new Key[] {fkey2}); Frame preds = gbm.score(ftest); } finally { UKV.remove(dest1); // Remove original hex frame key if (gbm != null) { UKV.remove(gbm.dest()); // Remove the model UKV.remove(gbm.response._key); gbm.remove(); // Remove GBM Job if (fr != null) fr.remove(); } } }
/**
 * Locates a folder of CSV files and parses every regular file in it into one Frame. The files
 * are handed to the parser in sorted order. NPE if the folder is not found.
 *
 * @param fname Test folder name
 * @return Frame or NPE
 */
protected Frame parse_test_folder(String fname) {
  File dir = find_test_file(fname);
  assert dir.isDirectory();
  File[] entries = dir.listFiles();
  Arrays.sort(entries);
  ArrayList<Key> fileKeys = new ArrayList<Key>();
  for (File entry : entries) {
    if (!entry.isFile()) continue; // Only regular files are parsed; nested folders are skipped
    fileKeys.add(NFSFileVec.make(entry)._key);
  }
  Key[] keyArray = fileKeys.toArray(new Key[fileKeys.size()]);
  return ParseDataset.parse(Key.make(), keyArray);
}
// ========================================================================== public void basicGBM(String fname, String hexname, PrepData prep) { File file = TestUtil.find_test_file(fname); if (file == null) return; // Silently abort test if the file is missing Key fkey = NFSFileVec.make(file); Key dest = Key.make(hexname); GBM gbm = null; Frame fr = null; try { gbm = new GBM(); gbm.source = fr = ParseDataset2.parse(dest, new Key[] {fkey}); UKV.remove(fkey); int idx = prep.prep(fr); if (idx < 0) { gbm.classification = false; idx = ~idx; } String rname = fr._names[idx]; gbm.response = fr.vecs()[idx]; fr.remove(idx); // Move response to the end fr.add(rname, gbm.response); gbm.ntrees = 4; gbm.max_depth = 4; gbm.min_rows = 1; gbm.nbins = 50; gbm.cols = new int[fr.numCols()]; for (int i = 0; i < gbm.cols.length; i++) gbm.cols[i] = i; gbm.learn_rate = .2f; gbm.invoke(); fr = gbm.score(gbm.source); GBM.GBMModel gbmmodel = UKV.get(gbm.dest()); // System.out.println(gbmmodel.toJava()); } finally { UKV.remove(dest); // Remove original hex frame key if (gbm != null) { UKV.remove(gbm.dest()); // Remove the model UKV.remove(gbm.response._key); gbm.remove(); // Remove GBM Job if (fr != null) fr.remove(); } } }
/**
 * Parses smalldata/tnc3_10.csv into the key "h.hex" and then passes a long series of expression
 * strings to checkStr. Each call supplies either the expression alone, the expression plus an
 * expected numeric result, or the expression plus an expected error message.
 *
 * <p>NOTE(review): checkStr is defined outside this view; presumably it evaluates the expression
 * and compares the outcome with the second argument - confirm. Later expressions reuse names
 * (x, y, a, mean, h.hex) that earlier expressions assign, so the order of the calls looks
 * significant and should not be changed casually.
 *
 * <p>The parsed frame key is deleted in the finally block.
 */
@Test
public void testBasicExpr1() {
  Key dest = Key.make("h.hex");
  try {
    File file = TestUtil.find_test_file("smalldata/tnc3_10.csv");
    // File file = TestUtil.find_test_file("smalldata/iris/iris_wheader.csv");
    // File file = TestUtil.find_test_file("smalldata/cars.csv");
    Key fkey = NFSFileVec.make(file);
    ParseDataset2.parse(dest, new Key[] {fkey});

    // Simple numbers & simple expressions
    checkStr("1.23", 1.23);
    checkStr(" 1.23 + 2.34", 3.57);
    checkStr(" 1.23 + 2.34 * 3", 8.25); // op precedence of * over +
    checkStr(" 1.23 2.34", "Junk at end of line\n" + " 1.23 2.34\n" + " ^--^\n"); // Syntax error
    checkStr("1.23 < 2.34", 1);
    checkStr("1.23 <=2.34", 1);
    checkStr("1.23 > 2.34", 0);
    checkStr("1.23 >=2.34", 0);
    checkStr("1.23 ==2.34", 0);
    checkStr("1.23 !=2.34", 1);
    checkStr("1 & 2", 1);
    checkStr("NA&0", 0); // R-spec: 0 not NA
    checkStr("0&NA", 0); // R-spec: 0 not NA
    checkStr("NA&1", Double.NaN); // R-spec: NA not 1
    checkStr("1&NA", Double.NaN);
    checkStr("1|NA", 1);
    checkStr("1&&2", 1);
    checkStr("1||0", 1);
    checkStr("NA||1", 1);
    checkStr("NA||0", Double.NaN);
    checkStr("0||NA", Double.NaN);
    checkStr("!1", 0);
    checkStr("(!)(1)", 0);
    checkStr("(!!)(1)", "Arg 'x' typed as dblary but passed dblary(dblary)\n" + "(!!)(1)\n" + " ^-^\n");
    checkStr("-1", -1);
    checkStr("-(1)", -1);
    checkStr("(-)(1)", "Passed 1 args but expected 2\n" + "(-)(1)\n" + " ^--^\n");
    checkStr("-T", -1);
    checkStr("* + 1", "Arg 'x' typed as dblary but passed anyary{dblary,dblary,}(dblary,dblary)\n" + "* + 1\n" + "^----^\n");

    // Simple op as prefix calls
    checkStr("+(1.23,2.34)", "Missing ')'\n" + "+(1.23,2.34)\n" + " ^---^\n"); // Syntax error: looks like unary op application
    checkStr("+(1.23)", 1.23); // Unary operator

    // Simple scalar assignment
    checkStr("1=2", "Junk at end of line\n" + "1=2\n" + " ^^\n");
    checkStr("x", "Unknown var x\n" + "x\n" + "^^\n");
    checkStr("x+2", "Unknown var x\n" + "x+2\n" + "^^\n");
    checkStr("2+x", "Missing expr or unknown ID\n" + "2+x\n" + " ^\n");
    checkStr("x=1", 1);
    checkStr("x<-1", 1); // Alternative R assignment syntax
    checkStr("x=3;y=4", 4); // Return value is last expr

    // Ambiguity & Language
    checkStr("x=mean"); // Assign x to the built-in fcn mean
    checkStr("x=mean=3", 3); // Assign x & id mean with 3; "mean" here is not related to any built-in fcn
    checkStr("x=mean(c(3))", 3); // Assign x to the result of running fcn mean(3)
    checkStr("x=mean(c(\n3))", 3); // Assign x to the result of running fcn mean(3)
    checkStr("x=mean+3", "Arg 'x' typed as dblary but passed dbl(ary)\n" + "x=mean+3\n" + " ^-----^\n"); // Error: "mean" is a function; cannot add a function and a number
    checkStr("apply(c(1,2,3),,nrow)", "Missing argument\napply(c(1,2,3),,nrow)\n ^\n");
    // Error msg is about "foo==" and not new assignment "foo="
    checkStr("foo==bar", "Unknown var foo\nfoo==bar\n^--^\n");

    // Simple array handling; broadcast operators
    checkStr("h.hex"); // Simple ref
    checkStr("h.hex[2,3]", 1); // Scalar selection
    checkStr("h.hex[2,+]", "Must be scalar or array\n" + "h.hex[2,+]\n" + " ^-^\n"); // Function not allowed
    checkStr("h.hex[2+4,-4]"); // Select row 6, all-cols but 4
    checkStr("h.hex[1,-1]; h.hex[2,-2]; h.hex[3,-3]"); // Partial results are freed
    checkStr("h.hex[2+3,h.hex]", "Selector must be a single column: [pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked, boat, body, home.dest, survived]"); // Error: col selector has too many columns
    checkStr("h.hex[2,]"); // Row 2 all cols
    checkStr("h.hex[,3]"); // Col 3 all rows
    checkStr("h.hex+1"); // Broadcast scalar over ary
    checkStr("h.hex-h.hex");
    checkStr("1.23+(h.hex-h.hex)");
    checkStr("(1.23+h.hex)-h.hex");
    checkStr("min(h.hex,1+2)", 0);
    checkStr("max(h.hex,1+2)", 211.3375);
    checkStr("min.na.rm(h.hex,NA)", 0); // 0
    checkStr("max.na.rm(h.hex,NA)", 211.3375); // 211.3375
    checkStr("min.na.rm(c(NA, 1), -1)", -1); // -1
    checkStr("max.na.rm(c(NA, 1), -1)", 1); // 1
    checkStr("max(c(Inf,1), 2 )", Double.POSITIVE_INFINITY); // Infinity
    checkStr("min(c(Inf,1),-Inf)", Double.NEGATIVE_INFINITY); // -Infinity
    checkStr("is.na(h.hex)");
    checkStr("sum(is.na(h.hex))", 0);
    checkStr("nrow(h.hex)*3", 30);
    checkStr("h.hex[nrow(h.hex)-1,ncol(h.hex)-1]");
    checkStr("x=1;x=h.hex"); // Allowed to change types via shadowing at REPL level
    checkStr("a=h.hex"); // Top-level assignment back to H2O.STORE
    checkStr("(h.hex+1)<-2", "Junk at end of line\n" + "(h.hex+1)<-2\n" + " ^-^\n"); // No L-value
    checkStr("h.hex[nrow(h.hex=1),]", "Arg 'x' typed as ary but passed dbl\n" + "h.hex[nrow(h.hex=1),]\n" + " ^--------^\n"); // Passing a scalar 1.0 to nrow
    // ERROR BROKEN: SHOULD PARSE statement list here; then do evil
    // side-effect killing h.hex but also using 10 to select last row
    checkStr("h.hex[{h.hex=10},]");
    checkStr("h.hex[3,4]<-4;", 4);
    checkStr("c(1,3,5)");

    // Column row subselection
    checkStr("h.hex[,c(1,3,5)]");
    checkStr("h.hex[c(1,3,5),]");
    checkStr("a=c(11,22,33,44,55,66); a[c(2,6,1),]");

    // Named column selection
    checkStr("h.hex$ 2", "Missing column name after $\nh.hex$ 2\n ^^\n");
    checkStr("h.hex$crunk", "Missing column crunk in frame [pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked, boat, body, home.dest, survived]");
    checkStr("h.hex$pclass");
    checkStr("mean(h.hex$pclass)", 1);

    // More complicated operator precedence
    checkStr("c(1,0)&c(2,3)"); // 1,0
    checkStr("c(2,NA)&&T", 1); // 1
    checkStr("-(x = 3)", -3);
    checkStr("x<-+");
    checkStr("x<-+;x(2)", "Passed 1 args but expected 2\nx<-+;x(2)\n ^--^\n"); // Error, + is binary if used as prefix
    checkStr("x<-+;x(1,2)", 3); // 3
    checkStr("x<-*;x(2,3)", 6); // 6
    checkStr("x=c(0,1);!x+1"); // ! has lower precedence
    checkStr("x=c(1,-2);-+---x");
    checkStr("x=c(1,-2);--!--x");
    checkStr("!(y=c(3,4))");
    checkStr("!x!=1");
    checkStr("(!x)!=1");
    checkStr("1+x^2");
    checkStr("1+x**2");
    checkStr("x + 2/y");
    checkStr("x + (2/y)");
    checkStr("-x + y");
    checkStr("-(x + y)");
    checkStr("-x % y");
    checkStr("-(x % y)");
    checkStr("T|F&F", 1); // Evals as T|(F&F)==1 not as (T|F)&F==0
    checkStr("T||F&&F", 1); // Evals as T|(F&F)==1 not as (T|F)&F==0

    // User functions
    checkStr("function(=){x+1}(2)", "Invalid var\nfunction(=){x+1}(2)\n ^\n");
    checkStr("function(x,=){x+1}(2)", "Invalid var\nfunction(x,=){x+1}(2)\n ^\n");
    checkStr("function(x,<-){x+1}(2)", "Invalid var\nfunction(x,<-){x+1}(2)\n ^\n");
    checkStr("function(x,x){x+1}(2)", "Repeated argument\nfunction(x,x){x+1}(2)\n ^^\n");
    checkStr("function(x,y,z){x[]}(h.hex,1,2)");
    checkStr("function(x){x[]}(2)", "Arg 'x' typed as ary but passed dbl\nfunction(x){x[]}(2)\n ^--^\n");
    checkStr("function(x){x+1}(2)", 3);
    checkStr("function(x){y=x+y}(2)");
    checkStr("function(x){}(2)");
    checkStr("function(x){y=x*2; y+1}(2)", 5);
    checkStr("function(x){y=1+2}(2)", 3);
    checkStr("function(x){y=1+2;y=c(1,2)}"); // Not allowed to change types in inner scopes
    checkStr("a=function(x) x+1; 7", 7); // Function def w/out curly-braces; return 7
    checkStr("a=function(x) {x+1}; 7", 7); // Function def w/ curly-braces; return 7
    checkStr("a=function(x) {x+1; 7}"); // Function def of 7
    checkStr("c(1,c(2,3))");
    checkStr("a=c(1,Inf);c(2,a)");

    // Test sum flattening all args
    checkStr("sum(1,2,3)", 6);
    checkStr("sum(c(1,3,5))", 9);
    checkStr("sum(4,c(1,3,5),2,6)", 21);
    checkStr("sum(1,h.hex,3)"); // should report an error because h.hex has enums
    checkStr("sum(c(NA,-1,1))", Double.NaN);
    checkStr("sum.na.rm(c(NA,-1,1))", 0);
    checkStr("function(a){a[];a=1}");
    checkStr("a=1;a=2;function(x){x=a;a=3}");
    checkStr("a=h.hex;function(x){x=a;a=3;nrow(x)*a}(a)", 30);
    checkStr("a=h.hex;a[,1]=(a[,1]==8)");

    // Higher-order function typing: fun is typed in the body of function(x)
    checkStr("function(funy){function(x){funy(x)*funy(x)}}(sgn)(-2)", 1);

    // Filter/selection
    checkStr("h.hex[h.hex[,4]>30,]");
    checkStr("a=c(1,2,3);a[a[,1]>10,1]");
    checkStr("sapply(a,sum)[1,1]", 6);
    checkStr("apply(h.hex,2,sum)"); // ERROR BROKEN: the ENUM cols should fold to NA
    checkStr("y=5;apply(h.hex,2,function(x){x[]+y})");
    checkStr("apply(h.hex,2,function(x){x=1;h.hex})", "Arg 'fcn' typed as ary(ary) but passed ary(dbl)\napply(h.hex,2,function(x){x=1;h.hex})\n ^-------------------------------^\n");
    checkStr("apply(h.hex,2,function(x){h.hex})", "apply requires that ary fun(ary x) return 1 column");
    checkStr("apply(h.hex,2,function(x){sum(x)/nrow(x)})");
    checkStr("mean=function(x){apply(x,2,sum)/nrow(x)};mean(h.hex)");
    checkStr("sum(apply(h.hex[,c(4,5)],1,mean))", 184.96); // Row-wise apply on mean

    // Conditional selection;
    checkStr("ifelse(0,1,2)", 2);
    checkStr("ifelse(0,h.hex+1,h.hex+2)");
    checkStr("ifelse(h.hex>3,99,h.hex)"); // Broadcast selection
    checkStr("ifelse(0,+,*)(1,2)", 2); // Select functions
    checkStr("(0 ? + : *)(1,2)", 2); // Trinary select
    checkStr("(1? h.hex : (h.hex+1))[1,2]", 0); // True (vs false) test
    // Impute the mean
    checkStr("apply(h.hex,2,function(x){total=sum(ifelse(is.na(x),0,x)); rcnt=nrow(x)-sum(is.na(x)); mean=total / rcnt; ifelse(is.na(x),mean,x)})");
    checkStr("factor(h.hex[,5])");

    // Slice assignment & map
    checkStr("h.hex[,2]");
    checkStr("h.hex[,2]+1");
    checkStr("h.hex[,3]=3.3;h.hex"); // Replace a col with a constant
    checkStr("h.hex[,3]=h.hex[,2]+1"); // Replace a col
    checkStr("h.hex[,ncol(h.hex)+1]=4"); // Extend a col
    checkStr("a=ncol(h.hex);h.hex[,c(a+1,a+2)]=5"); // Extend two cols
    checkStr("h.hex[,7]=x=3; !(x+2)");
    checkStr("table(h.hex)");
    checkStr("table(h.hex[,5])");
    checkStr("table(h.hex[,c(2,7)])");
    checkStr("table(h.hex[,c(2,9)])");
    checkStr("a=cbind(c(1,2,3), c(4,5,6))");
    checkStr("a[,1] = factor(a[,1])");
    checkStr("is.factor(a[,1])", 1);
    checkStr("isTRUE(c(1,3))", 0);
    checkStr("a=1;isTRUE(1)", 1);
    checkStr("a=c(1,2);isTRUE(a)", 0);
    checkStr("isTRUE(min)", 0);
    checkStr("seq_len(0)", "Error in seq_len(0): argument must be coercible to positive integer");
    checkStr("seq_len(-1)", "Error in seq_len(-1): argument must be coercible to positive integer");
    checkStr("seq_len(10)");
    checkStr("3 < 4 | F & 3 > 4", 1); // Evals as (3<4) | (F & (3>4))
    checkStr("3 < 4 || F && 3 > 4", 1);
    checkStr("h.hex[,4] != 29 || h.hex[,2] < 305 && h.hex[,2] < 81", Double.NaN);

    // Disabled cases, kept as in the original:
    // checkStr("h.hex[h.hex[,4]>40,]=-99");
    // checkStr("h.hex[2,]=h.hex[7,]");
    // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4,6),2]");
    // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4),2]");
    // checkStr("map()");
    // checkStr("map(1)");
    // checkStr("map(+,h.hex,1)");
    // checkStr("map(+,1,2)");
    // checkStr("map(function(x){x[];1},h.hex)");
    // checkStr("map(function(a,b,d){a+b+d},h.hex,h.hex,1)");
    // checkStr("map(function(a,b){a+ncol(b)},h.hex,h.hex)");

    // Quantile
    checkStr("quantile(seq_len(10),seq_len(10)/10)");
    checkStr("quantile(runif(seq_len(10000),-1),seq_len(10)/10)");
    checkStr("quantile(h.hex[,4],c(0,.05,0.3,0.55,0.7,0.95,0.99))");

    // ddply error checks
    checkStr("ddply(h.hex,h.hex,sum)", "Only one column-of-columns for column selection");
    checkStr("ddply(h.hex,seq_len(10000),sum)", "Too many columns selected");
    checkStr("ddply(h.hex,NA,sum)", "NA not a valid column");
    checkStr("ddply(h.hex,c(1,NA,3),sum)", "NA not a valid column");
    checkStr("ddply(h.hex,c(1,99,3),sum)", "Column 99 out of range for frame columns 17");
    checkStr("nrow(unique(h.hex[,5]))", 3);
    checkStr("nrow(unique(h.hex[,6]))", 2);
    checkStr("nrow(unique(h.hex[,c(5,6)]))", 4); // multi-column unique

    // Newlines as statement-ends
    checkStr("3*4+5*6", 42);
    checkStr("(h.hex[1,1]=2)", 2);
    checkStr("(h.hex[1,1]=2\n)", 2);
    checkStr("(h.hex[1,1]\n=2)", 2);
    checkStr("(h.hex\n[1,1]=2)", 2);
    checkStr("function(){x=1.23;(x=4.5)\n}()", 4.5);
    checkStr("function(){x=1.23;x=\n4.5\n}()", 4.5);
    checkStr("x=3\nfunction()x=1.23\nx", 3);
    checkStr("x=3\nfunction(){(x=1.23)}\nx", 3);
    checkStr("x=function(df)\n{\nmin(df$age)\n}\n;x(h.hex)", 0.92);
    checkStr("1.23\n-4", -4);
    checkStr("1.23 +\n-4", -2.77);
    checkStr("x=3;3*-x", -9); // *- is not a token
    checkStr("x=3;3\n*\n-\nx", 3); // Each of '3' and '*' and '-' and 'x' is a standalone statement

    // No strings, yet
    checkStr("function(df) { min(df[,\"age\"]) }", "The current Exec does not handle strings\nfunction(df) { min(df[,\"age\"]) }\n ^\n");

    // Cleanup testing temps
    checkStr("a=0;x=0;y=0", 0); // Delete keys from global scope
  } finally {
    Lockable.delete(dest); // Remove original hex frame key
  }
}
/**
 * Parses smalldata/tnc3_10.csv into the key "h.hex", removes the raw file key, and then passes a
 * series of expression strings to the single-argument checkStr. No expected values are supplied
 * here; the trailing comments record the outcome the author anticipated.
 *
 * <p>NOTE(review): checkStr is defined outside this view; presumably it only checks that each
 * expression evaluates or fails cleanly - confirm. Later expressions reuse names (x, y, a, mean,
 * h.hex) assigned by earlier ones, so the call order looks significant.
 *
 * <p>The parsed frame key is removed in the finally block.
 */
@Test
public void testBasicExpr1() {
  Key dest = Key.make("h.hex");
  try {
    File file = TestUtil.find_test_file("smalldata/tnc3_10.csv");
    // File file = TestUtil.find_test_file("smalldata/iris/iris_wheader.csv");
    // File file = TestUtil.find_test_file("smalldata/cars.csv");
    Key fkey = NFSFileVec.make(file);
    ParseDataset2.parse(dest, new Key[] {fkey});
    UKV.remove(fkey);

    // Scalars, comparisons and prefix operator calls
    checkStr("1.23"); // 1.23
    checkStr(" 1.23 + 2.34"); // 3.57
    checkStr(" 1.23 + 2.34 * 3"); // 10.71, L2R eval order
    checkStr(" 1.23 2.34"); // Syntax error
    checkStr("1.23 < 2.34"); // 1
    checkStr("1.23 <=2.34"); // 1
    checkStr("1.23 > 2.34"); // 0
    checkStr("1.23 >=2.34"); // 0
    checkStr("1.23 ==2.34"); // 0
    checkStr("1.23 !=2.34"); // 1
    checkStr("h.hex"); // Simple ref
    checkStr("+(1.23,2.34)"); // prefix 3.57
    checkStr("+(1.23)"); // Syntax error, not enuf args
    checkStr("+(1.23,2,3)"); // Syntax error, too many args

    // Frame selection and broadcast operators
    checkStr("h.hex[2,3]"); // Scalar selection
    checkStr("h.hex[2,+]"); // Function not allowed
    checkStr("h.hex[2+4,-4]"); // Select row 6, all-cols but 4
    checkStr("h.hex[1,-1]; h.hex[2,-2]; h.hex[3,-3]"); // Partial results are freed
    checkStr("h.hex[2+3,h.hex]"); // Error: col selector has too many columns
    checkStr("h.hex[2,]"); // Row 2 all cols
    checkStr("h.hex[,3]"); // Col 3 all rows
    checkStr("h.hex+1"); // Broadcast scalar over ary
    checkStr("h.hex-h.hex");
    checkStr("1.23+(h.hex-h.hex)");
    checkStr("(1.23+h.hex)-h.hex");
    checkStr("min(h.hex,1+2)");
    checkStr("max(h.hex,1+2)");
    checkStr("is.na(h.hex)");
    checkStr("nrow(h.hex)*3");
    checkStr("h.hex[nrow(h.hex)-1,ncol(h.hex)-1]");

    // Assignment
    checkStr("1=2");
    checkStr("x");
    checkStr("x+2");
    checkStr("2+x");
    checkStr("x=1");
    checkStr("x<-1"); // Alternative R assignment syntax
    checkStr("x=1;x=h.hex"); // Allowed to change types via shadowing at REPL level
    checkStr("a=h.hex"); // Top-level assignment back to H2O.STORE
    checkStr("x<-+");
    checkStr("(h.hex+1)<-2");
    checkStr("h.hex[nrow(h.hex=1),]");
    checkStr("h.hex[2,3]<-4;");
    checkStr("c(1,3,5)");

    // User functions
    checkStr("function(=){x+1}(2)");
    checkStr("function(x,=){x+1}(2)");
    checkStr("function(x,<-){x+1}(2)");
    checkStr("function(x,x){x+1}(2)");
    checkStr("function(x,y,z){x[]}(h.hex,1,2)");
    checkStr("function(x){x[]}(2)");
    checkStr("function(x){x+1}(2)");
    checkStr("function(x){y=x+y}(2)");
    checkStr("function(x){}(2)");
    checkStr("function(x){y=x*2; y+1}(2)");
    checkStr("function(x){y=1+2}(2)");
    checkStr("function(x){y=1+2;y=c(1,2)}"); // Not allowed to change types in inner scopes
    checkStr("sum(1,2,3)");
    checkStr("sum(c(1,3,5))");
    checkStr("sum(4,c(1,3,5),2,6)");
    checkStr("sum(1,h.hex,3)");
    checkStr("h.hex[,c(1,3,5)]");
    checkStr("h.hex[c(1,3,5),]");
    checkStr("a=c(11,22,33,44,55,66); a[c(2,6,1),]");
    checkStr("function(a){a[];a=1}");
    checkStr("a=1;a=2;function(x){x=a;a=3}");
    checkStr("a=h.hex;function(x){x=a;a=3;nrow(x)*a}(a)");
    checkStr("a=h.hex;a[,1]=(a[,1]==8)");

    // Higher-order function typing: fun is typed in the body of function(x)
    checkStr("function(funy){function(x){funy(x)*funy(x)}}(sgn)(-2)");

    // Filter/selection
    checkStr("h.hex[h.hex[,2]>4,]");
    checkStr("a=c(1,2,3);a[a[,1]>10,1]");
    checkStr("apply(h.hex,2,sum)");
    checkStr("y=5;apply(h.hex,2,function(x){x[]+y})");
    checkStr("apply(h.hex,2,function(x){x=1;h.hex})");
    checkStr("apply(h.hex,2,function(x){h.hex})");
    checkStr("mean=function(x){apply(x,2,sum)/nrow(x)};mean(h.hex)");

    // Conditional selection;
    checkStr("ifelse(0,1,2)");
    checkStr("ifelse(0,h.hex+1,h.hex+2)");
    checkStr("ifelse(h.hex>3,99,h.hex)"); // Broadcast selection
    checkStr("ifelse(0,+,*)(1,2)"); // Select functions
    checkStr("(0 ? + : *)(1,2)"); // Trinary select
    checkStr("(1? h.hex : (h.hex+1))[1,2]"); // True (vs false) test
    // Impute the mean
    checkStr("apply(h.hex,2,function(x){total=sum(ifelse(is.na(x),0,x)); rcnt=nrow(x)-sum(is.na(x)); mean=total / rcnt; ifelse(is.na(x),mean,x)})");
    checkStr("factor(h.hex[,5])");

    // Slice assignment & map
    checkStr("h.hex[,2]");
    checkStr("h.hex[,2]+1");
    checkStr("h.hex[,3]=3.3;h.hex"); // Replace a col with a constant
    checkStr("h.hex[,3]=h.hex[,2]+1"); // Replace a col
    checkStr("h.hex[,ncol(h.hex)+1]=4"); // Extend a col
    checkStr("a=ncol(h.hex);h.hex[,c(a+1,a+2)]=5"); // Extend two cols
    checkStr("table(h.hex)");
    checkStr("table(h.hex[,3])");
    checkStr("h.hex[,4] != 29 || h.hex[,2] < 305 && h.hex[,2] < 81");
    checkStr("a=cbind(c(1,2,3), c(4,5,6))");

    // Disabled cases, kept as in the original:
    // checkStr("h.hex[h.hex[,2]>4,]=-99");
    // checkStr("h.hex[2,]=h.hex[7,]");
    // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4,6),2]");
    // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4),2]");
    // checkStr("map()");
    // checkStr("map(1)");
    // checkStr("map(+,h.hex,1)");
    // checkStr("map(+,1,2)");
    // checkStr("map(function(x){x[];1},h.hex)");
    // checkStr("map(function(a,b,d){a+b+d},h.hex,h.hex,1)");
    // checkStr("map(function(a,b){a+ncol(b)},h.hex,h.hex)");

    checkStr("a=0;x=0"); // Delete keys from global scope
  } finally {
    UKV.remove(dest); // Remove original hex frame key
  }
}
/**
 * Locates a CSV test file and parses it, storing the result under a freshly made key. NPE if
 * the file is not found.
 *
 * @param fname Test filename
 * @return Frame or NPE
 */
public static Frame parse_test_file(String fname) {
  Key freshKey = Key.make();
  return parse_test_file(freshKey, fname);
}
/**
 * The train/valid Frame instances are sorted by categorical (themselves sorted by cardinality
 * greatest to least) with all numerical columns following. The response column(s) are placed at
 * the end.
 *
 * <p>Interactions: 1. Num-Num (Note: N(0,1) * N(0,1) ~ N(0,1) ) 2. Num-Enum 3. Enum-Enum
 *
 * <p>Interactions are produced on the fly and are dense (in all 3 cases). Consumers of DataInfo
 * should not have to care how these interactions are generated. Any heuristic using the fullN
 * value should continue functioning the same.
 *
 * <p>Interactions are specified in two ways: A. As a list of pairs of column indices. B. As a
 * list of pairs of column indices with limited enums.
 *
 * <p>Side effects visible in this constructor: {@code train} (and {@code valid}, when given) are
 * restructured in place to the new column order, and the predictor/response transforms are
 * applied through setPredictorTransform / setResponseTransform.
 *
 * @param train training frame; its trailing columns are treated as response, weight, offset and
 *     fold columns according to the flags below
 * @param valid validation frame, or null when there is none
 * @param nResponses number of response columns at the end of the frame
 * @param useAllFactorLevels when false, one level is subtracted from the offset width of every
 *     categorical column that is not an InteractionWrappedVec
 * @param predictor_transform transform for predictors; must not be null (asserted)
 * @param response_transform transform for responses; must not be null (asserted); only applied
 *     when nResponses is greater than zero
 * @param skipMissing stored in _skipMissing and passed to Model.makeInteractions; must not be
 *     combined with imputeMissing (asserted)
 * @param imputeMissing when true, each categorical mode is taken from imputeCat
 * @param missingBucket when true, every categorical column gets one extra level in its offset
 *     width
 * @param weight true when the frame carries a weight column (one column excluded from predictors)
 * @param offset true when the frame carries an offset column (one column excluded from predictors)
 * @param fold true when the frame carries a fold column (one column excluded from predictors)
 * @param interactions interaction pairs to materialise and prepend to the frames, or null
 */
public DataInfo(
    Frame train,
    Frame valid,
    int nResponses,
    boolean useAllFactorLevels,
    TransformType predictor_transform,
    TransformType response_transform,
    boolean skipMissing,
    boolean imputeMissing,
    boolean missingBucket,
    boolean weight,
    boolean offset,
    boolean fold,
    Model.InteractionPair[] interactions) {
  super(Key.<DataInfo>make());
  _valid = valid != null;
  assert predictor_transform != null;
  assert response_transform != null;
  _offset = offset;
  _weights = weight;
  _fold = fold;
  assert !(skipMissing && imputeMissing) : "skipMissing and imputeMissing cannot both be true";
  _skipMissing = skipMissing;
  _imputeMissing = imputeMissing;
  _predictor_transform = predictor_transform;
  _response_transform = response_transform;
  _responses = nResponses;
  _useAllFactorLevels = useAllFactorLevels;
  _interactions = interactions;

  // create dummy InteractionWrappedVecs and shove them onto the front
  if (_interactions != null) {
    _interactionVecs = new int[_interactions.length];
    // The interaction frame comes first and the original columns are appended to it.
    train =
        Model.makeInteractions(
                train,
                false,
                _interactions,
                _useAllFactorLevels,
                _skipMissing,
                predictor_transform == TransformType.STANDARDIZE)
            .add(train);
    if (valid != null)
      valid =
          Model.makeInteractions(
                  valid,
                  true,
                  _interactions,
                  _useAllFactorLevels,
                  _skipMissing,
                  predictor_transform == TransformType.STANDARDIZE)
              .add(valid); // FIXME: should be using the training subs/muls!
  }

  _permutation = new int[train.numCols()];
  final Vec[] tvecs = train.vecs();

  // Count categorical-vs-numerical
  // n is the number of predictor columns: everything except responses/offset/weight/fold.
  final int n = tvecs.length - _responses - (offset ? 1 : 0) - (weight ? 1 : 0) - (fold ? 1 : 0);
  int[] nums = MemoryManager.malloc4(n);
  int[] cats = MemoryManager.malloc4(n);
  int nnums = 0, ncats = 0;
  for (int i = 0; i < n; ++i)
    if (tvecs[i].isCategorical()) cats[ncats++] = i;
    else nums[nnums++] = i;
  _nums = nnums;
  _cats = ncats;
  _catLvls = new int[ncats][];

  // sort the cats in the decreasing order according to their size
  for (int i = 0; i < ncats; ++i)
    for (int j = i + 1; j < ncats; ++j)
      if (tvecs[cats[i]].domain().length < tvecs[cats[j]].domain().length) {
        int x = cats[i];
        cats[i] = cats[j];
        cats[j] = x;
      }

  // names/tvecs2 hold the reordered layout: categoricals, then numericals, then trailing columns.
  String[] names = new String[train.numCols()];
  Vec[] tvecs2 = new Vec[train.numCols()];

  // Compute the cardinality of each cat
  _catModes = new int[ncats];
  _catOffsets = MemoryManager.malloc4(ncats + 1);
  _catMissing = new boolean[ncats];
  int len = _catOffsets[0] = 0; // running total of expanded columns, shared with the numeric loop
  int interactionIdx = 0; // simple index into the _interactionVecs array
  ArrayList<Integer> interactionIds;
  if (_interactions == null) {
    // No explicit interaction spec: size _interactionVecs from the InteractionWrappedVecs already
    // present in the frame. The entries are overwritten with reordered indices in the loops below.
    interactionIds = new ArrayList<>();
    for (int i = 0; i < tvecs.length; ++i)
      if (tvecs[i] instanceof InteractionWrappedVec) {
        interactionIds.add(i);
      }
    _interactionVecs = new int[interactionIds.size()];
    for (int i = 0; i < _interactionVecs.length; ++i) _interactionVecs[i] = interactionIds.get(i);
  }
  for (int i = 0; i < ncats; ++i) {
    names[i] = train._names[cats[i]];
    Vec v = (tvecs2[i] = tvecs[cats[i]]);
    _catMissing[i] = missingBucket; // needed for test time
    if (v instanceof InteractionWrappedVec) {
      if (_interactions != null) _interactions[interactionIdx].vecIdx = i;
      _interactionVecs[interactionIdx++] = i; // i (and not cats[i]) because this is the index in _adaptedFrame
      _catOffsets[i + 1] = (len += v.domain().length + (missingBucket ? 1 : 0));
    } else
      _catOffsets[i + 1] = (len += v.domain().length - (useAllFactorLevels ? 0 : 1) + (missingBucket ? 1 : 0)); // missing values turn into a new factor level
    // Mode: imputed level when imputing, the extra level index when a missing bucket is used,
    // otherwise -100.
    // NOTE(review): -100 looks like a "no mode" sentinel; confirm against consumers of _catModes.
    _catModes[i] = imputeMissing ? imputeCat(train.vec(cats[i])) : _catMissing[i] ? v.domain().length : -100;
    _permutation[i] = cats[i];
  }
  _numMeans = new double[nnums];
  _numOffsets = MemoryManager.malloc4(nnums + 1);
  _numOffsets[0] = len; // numeric offsets continue where the categorical offsets ended
  boolean isIWV; // is InteractionWrappedVec?
  for (int i = 0; i < nnums; ++i) {
    names[i + ncats] = train._names[nums[i]];
    Vec v = train.vec(nums[i]);
    tvecs2[i + ncats] = v;
    isIWV = v instanceof InteractionWrappedVec;
    if (isIWV) {
      if (null != _interactions) _interactions[interactionIdx].vecIdx = i + ncats;
      _interactionVecs[interactionIdx++] = i + ncats;
    }
    // A numeric interaction vec occupies expandedLength() slots; a plain numeric occupies one.
    _numOffsets[i + 1] = (len += (isIWV ? ((InteractionWrappedVec) v).expandedLength() : 1));
    _numMeans[i] = train.vec(nums[i]).mean();
    _permutation[i + ncats] = nums[i];
  }
  // The trailing response/weight/offset/fold columns keep their original positions.
  for (int i = names.length - nResponses - (weight ? 1 : 0) - (offset ? 1 : 0) - (fold ? 1 : 0); i < names.length; ++i) {
    names[i] = train._names[i];
    tvecs2[i] = train.vec(i);
  }
  _adaptedFrame = new Frame(names, tvecs2);
  train.restructure(names, tvecs2);
  if (valid != null) valid.restructure(names, valid.vecs(names));
  // _adaptedFrame = train;
  setPredictorTransform(predictor_transform);
  if (_responses > 0) setResponseTransform(response_transform);
}
// Test-on-Train. Slow test, needed to build a good model. @Test public void testGBMTrainTest() { File file1 = TestUtil.find_test_file("..//classifcation1Train.txt"); if (file1 == null) return; // Silently ignore if file not found Key fkey1 = NFSFileVec.make(file1); Key dest1 = Key.make("train.hex"); File file2 = TestUtil.find_test_file("..//classification1Test.txt"); Key fkey2 = NFSFileVec.make(file2); Key dest2 = Key.make("test.hex"); GBM gbm = null; Frame fr = null, fpreds = null; try { gbm = new GBM(); fr = ParseDataset2.parse(dest1, new Key[] {fkey1}); UKV.remove(fkey1); UKV.remove(fr.remove("agentId")._key); // Remove unique ID; too predictive gbm.response = fr.remove("outcome"); // Train on the outcome gbm.source = fr; gbm.ntrees = 5; gbm.max_depth = 10; gbm.learn_rate = 0.2f; gbm.min_rows = 10; gbm.nbins = 100; gbm.invoke(); // Test on the train data Frame ftest = ParseDataset2.parse(dest2, new Key[] {fkey2}); UKV.remove(fkey2); fpreds = gbm.score(ftest); // Build a confusion matrix ConfusionMatrix CM = new ConfusionMatrix(); CM.actual = ftest; CM.vactual = ftest.vecs()[ftest.find("outcome")]; CM.predict = fpreds; CM.vpredict = fpreds.vecs()[fpreds.find("predict")]; CM.serve(); // Start it, do it // Really crappy cut-n-paste of what should be in the ConfusionMatrix class itself long cm[][] = CM.cm; long acts[] = new long[cm.length]; long preds[] = new long[cm[0].length]; for (int a = 0; a < cm.length; a++) { long sum = 0; for (int p = 0; p < cm[a].length; p++) { sum += cm[a][p]; preds[p] += cm[a][p]; } acts[a] = sum; } String adomain[] = ConfusionMatrix.show(acts, CM.vactual.domain()); String pdomain[] = ConfusionMatrix.show(preds, CM.vpredict.domain()); StringBuilder sb = new StringBuilder(); sb.append("Act/Prd\t"); for (String s : pdomain) if (s != null) sb.append(s).append('\t'); sb.append("Error\n"); long terr = 0; for (int a = 0; a < cm.length; a++) { if (adomain[a] == null) continue; sb.append(adomain[a]).append('\t'); long correct = 0; for (int p = 
0; p < pdomain.length; p++) { if (pdomain[p] == null) continue; if (adomain[a].equals(pdomain[p])) correct = cm[a][p]; sb.append(cm[a][p]).append('\t'); } long err = acts[a] - correct; terr += err; // Bump totals sb.append(String.format("%5.3f = %d / %d\n", (double) err / acts[a], err, acts[a])); } sb.append("Totals\t"); for (int p = 0; p < pdomain.length; p++) if (pdomain[p] != null) sb.append(preds[p]).append("\t"); sb.append( String.format( "%5.3f = %d / %d\n", (double) terr / CM.vactual.length(), terr, CM.vactual.length())); System.out.println(sb); } finally { UKV.remove(dest1); // Remove original hex frame key UKV.remove(fkey2); UKV.remove(dest2); if (gbm != null) { UKV.remove(gbm.dest()); // Remove the model UKV.remove(gbm.response._key); gbm.remove(); // Remove GBM Job } if (fr != null) fr.remove(); if (fpreds != null) fpreds.remove(); } }