/**
 * A test case for predicting the diabetes response for two inline data points.
 *
 * <p>When {@code skipDecoding} is {@code true}, the {@code skipDecoding=true} query parameter is
 * appended to the predict URL and every returned prediction is expected to be a {@code Double}.
 *
 * @param skipDecoding whether to request predictions with {@code skipDecoding=true}
 * @throws MLHttpClientException propagated from the ML HTTP client calls
 * @throws JSONException if the reply cannot be parsed as a JSON array
 */
private void testPredictDiabetes(boolean skipDecoding) throws MLHttpClientException, JSONException {
    String payload = "[[1,89,66,23,94,28.1,0.167,21],[2,197,70,45,543,30.5,0.158,53]]";
    String url = "/api/models/" + modelId + "/predict" + (skipDecoding ? "?skipDecoding=true" : "");
    response = mlHttpclient.doHttpPost(url, payload);
    try {
        assertEquals(
                "Unexpected response received",
                Response.Status.OK.getStatusCode(),
                response.getStatusLine().getStatusCode());
        String reply = mlHttpclient.getResponseAsString(response);
        JSONArray predictions = new JSONArray(reply);
        // The message reports the actual count, which may be larger or smaller than expected.
        assertEquals(
                "Expected 2 predictions but received " + predictions.length(),
                2,
                predictions.length());
        if (skipDecoding) {
            // Without decoding, every prediction should be a raw numeric value.
            for (int i = 0; i < predictions.length(); i++) {
                Object prediction = predictions.get(i);
                assertEquals(
                        "Expected a double value but found " + prediction,
                        true,
                        prediction instanceof Double);
            }
        }
    } finally {
        try {
            response.close();
        } catch (IOException ignored) {
            // Best-effort cleanup: a failure to close must not mask the test outcome.
        }
    }
}
/** * A test case for building a model with the given learning algorithm * * @param algorithmName Name of the learning algorithm * @param algorithmType Type of the learning algorithm * @throws MLHttpClientException * @throws IOException * @throws JSONException * @throws InterruptedException */ private void buildModelWithLearningAlgorithm(String algorithmName, String algorithmType) throws MLHttpClientException, IOException, JSONException, InterruptedException { modelName = MLTestUtils.createModelWithConfigurations( algorithmName, algorithmType, MLIntegrationTestConstants.RESPONSE_ATTRIBUTE_DIABETES, MLIntegrationTestConstants.TRAIN_DATA_FRACTION, projectId, versionSetId, mlHttpclient); modelId = mlHttpclient.getModelId(modelName); response = mlHttpclient.doHttpPost("/api/models/" + modelId); assertEquals( "Unexpected response received", Response.Status.OK.getStatusCode(), response.getStatusLine().getStatusCode()); response.close(); // Waiting for model building to end boolean status = MLTestUtils.checkModelStatusCompleted( modelName, mlHttpclient, MLIntegrationTestConstants.THREAD_SLEEP_TIME_LARGE, 1000); // Checks whether model building completed successfully assertEquals("Model building did not complete successfully", true, status); }
/**
 * A test case for predicting for the data points in the
 * {@code MLIntegrationTestConstants.DIABETES_DATASET_TEST} CSV file.
 *
 * @throws MLHttpClientException propagated from the ML HTTP client calls
 * @throws JSONException if the reply cannot be parsed as a JSON array
 */
private void testPredictDiabetesFromFile() throws MLHttpClientException, JSONException {
    response = mlHttpclient.predictFromCSV(modelId, MLIntegrationTestConstants.DIABETES_DATASET_TEST);
    try {
        assertEquals(
                "Unexpected response received",
                Response.Status.OK.getStatusCode(),
                response.getStatusLine().getStatusCode());
        String reply = mlHttpclient.getResponseAsString(response);
        JSONArray predictions = new JSONArray(reply);
        // NOTE(review): 7 is presumably the number of data rows in the test CSV — confirm.
        assertEquals(
                "Expected 7 predictions but received " + predictions.length(),
                7,
                predictions.length());
    } finally {
        try {
            response.close();
        } catch (IOException ignored) {
            // Best-effort cleanup: a failure to close must not mask the test outcome.
        }
    }
}
/**
 * A test case for predicting for two inline gamma-telescope data points.
 *
 * @throws MLHttpClientException propagated from the ML HTTP client calls
 * @throws JSONException if the reply cannot be parsed as a JSON array
 */
private void testPredictGammaTelescope() throws MLHttpClientException, JSONException {
    String payload = "[[18.8562,16.46,2.4385,0.5282,0.2933,25.1269,-6.5401,-16.9327,11.461,162.848],"
            + "[191.8036,49.7183,3.0006,0.2093,0.1225,146.2148,143.6098,31.6216,44.3492,245.4199]]";
    response = mlHttpclient.doHttpPost("/api/models/" + modelId + "/predict", payload);
    try {
        assertEquals(
                "Unexpected response received",
                Response.Status.OK.getStatusCode(),
                response.getStatusLine().getStatusCode());
        String reply = mlHttpclient.getResponseAsString(response);
        JSONArray predictions = new JSONArray(reply);
        // One prediction is expected per submitted data point.
        assertEquals(
                "Expected 2 predictions but received " + predictions.length(),
                2,
                predictions.length());
    } finally {
        try {
            response.close();
        } catch (IOException ignored) {
            // Best-effort cleanup: a failure to close must not mask the test outcome.
        }
    }
}
/**
 * A test case for predicting with a dataset incompatible with the trained dataset in terms of
 * number of features.
 *
 * <p>Each data point in the payload has 7 values, whereas the valid diabetes payloads elsewhere
 * in this class have 8. The server is expected to respond with HTTP 500.
 *
 * @throws MLHttpClientException propagated from the ML HTTP client calls
 * @throws JSONException declared for consistency with the other prediction tests
 */
private void testPredictDiabetesInvalidNumberOfFeatures() throws MLHttpClientException, JSONException {
    String payload = "[[1,89,66,23,94,28.1,0.167],[2,197,70,45,543,30.5,0.158]]";
    response = mlHttpclient.doHttpPost("/api/models/" + modelId + "/predict", payload);
    try {
        assertEquals(
                "Unexpected response received",
                Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(),
                response.getStatusLine().getStatusCode());
    } finally {
        // The error body is never read, so close explicitly to release the connection.
        try {
            response.close();
        } catch (IOException ignored) {
            // Best-effort cleanup: a failure to close must not mask the test outcome.
        }
    }
}