/**
 * Publishes the model identified by {@code modelId} to the ML registry.
 *
 * @param modelId Unique id of the model to be published
 * @return JSON of {@link org.wso2.carbon.ml.rest.api.model.MLResponseBean} containing the
 *     published location of the model
 */
@POST
@Path("/{modelId}/publish")
@Produces("application/json")
@Consumes("application/json")
public Response publishModel(@PathParam("modelId") long modelId) {
    PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = carbonContext.getTenantId();
    String userName = carbonContext.getUsername();
    try {
        String registryPath = mlModelHandler.publishModel(tenantId, userName, modelId);
        return Response.ok(new MLResponseBean(registryPath)).build();
    } catch (InvalidRequestException | MLModelPublisherException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while publishing the model [id] %s of tenant [id] %s and [user] %s .",
                                modelId, tenantId, userName),
                        e);
        logger.error(msg, e);
        // An invalid request is the caller's fault; a publisher failure is server-side.
        Response.Status status =
                (e instanceof InvalidRequestException)
                        ? Response.Status.BAD_REQUEST
                        : Response.Status.INTERNAL_SERVER_ERROR;
        return Response.status(status).entity(new MLErrorBean(e.getMessage())).build();
    }
}
/** * Predict using a file and return as a list of predicted values. * * @param modelId Unique id of the model * @param dataFormat Data format of the file (CSV or TSV) * @param inputStream File input stream generated from the file used for predictions * @return JSON array of predictions */ @POST @Path("/predict") @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.MULTIPART_FORM_DATA) public Response predict( @Multipart("modelId") long modelId, @Multipart("dataFormat") String dataFormat, @Multipart("file") InputStream inputStream) { PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext(); int tenantId = carbonContext.getTenantId(); String userName = carbonContext.getUsername(); try { // validate input parameters // if it is a file upload, check whether the file is sent if (inputStream == null || inputStream.available() == 0) { String msg = String.format( "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s .", modelId, tenantId, userName); logger.error(msg); return Response.status(Response.Status.INTERNAL_SERVER_ERROR) .entity(new MLErrorBean(msg)) .build(); } List<?> predictions = mlModelHandler.predict(tenantId, userName, modelId, dataFormat, inputStream); return Response.ok(predictions).build(); } catch (IOException e) { String msg = MLUtils.getErrorMsg( String.format( "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s.", modelId, tenantId, userName), e); logger.error(msg, e); return Response.status(Response.Status.BAD_REQUEST) .entity(new MLErrorBean(e.getMessage())) .build(); } catch (MLModelHandlerException e) { String msg = MLUtils.getErrorMsg( String.format( "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.", modelId, tenantId, userName), e); logger.error(msg, e); return Response.status(Response.Status.INTERNAL_SERVER_ERROR) .entity(new MLErrorBean(e.getMessage())) .build(); } }
/**
 * Create a new Project. No validation happens here. Please call {@link #getProject(String)}
 * before this.
 *
 * @param project {@link org.wso2.carbon.ml.commons.domain.MLProject} object
 */
@POST
@Produces("application/json")
@Consumes("application/json")
public Response createProject(MLProject project) {
    // Both a project name and a backing dataset name are mandatory.
    boolean nameMissing = project.getName() == null || project.getName().isEmpty();
    boolean datasetMissing =
            project.getDatasetName() == null || project.getDatasetName().isEmpty();
    if (nameMissing || datasetMissing) {
        logger.error("Required parameters missing");
        return Response.status(Response.Status.BAD_REQUEST)
                .entity("Required parameters missing")
                .build();
    }
    PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = carbonContext.getTenantId();
    String userName = carbonContext.getUsername();
    try {
        // Stamp the project with the caller's tenant and user before persisting.
        project.setTenantId(tenantId);
        project.setUserName(userName);
        mlProjectHandler.createProject(project);
        return Response.ok().build();
    } catch (MLProjectHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while creating a [project] %s of tenant [id] %s and [user] %s .",
                                project, tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Delete a project.
 *
 * @param projectId Unique id of the project
 */
@DELETE
@Path("/{projectId}")
@Produces("application/json")
public Response deleteProject(@PathParam("projectId") long projectId) {
    PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = ctx.getTenantId();
    String userName = ctx.getUsername();
    try {
        mlProjectHandler.deleteProject(tenantId, userName, projectId);
        // Record the successful deletion in the audit trail.
        auditLog.info(
                String.format(
                        "User [name] %s of tenant [id] %s deleted a project [id] %s ",
                        userName, tenantId, projectId));
        return Response.ok().build();
    } catch (MLProjectHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while deleting a project [id] %s of tenant [id] %s and [user] %s .",
                                projectId, tenantId, userName),
                        e);
        logger.error(msg, e);
        auditLog.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Get analysis of a project given the analysis name.
 *
 * @param projectId Unique id of the project
 * @param analysisName Name of the analysis
 * @return JSON of {@link org.wso2.carbon.ml.commons.domain.MLAnalysis} object
 */
@GET
@Path("/{projectId}/analyses/{analysisName}")
@Produces("application/json")
public Response getAnalysisOfProject(
        @PathParam("projectId") long projectId, @PathParam("analysisName") String analysisName) {
    PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = ctx.getTenantId();
    String userName = ctx.getUsername();
    try {
        MLAnalysis analysis =
                mlProjectHandler.getAnalysisOfProject(tenantId, userName, projectId, analysisName);
        // No such analysis under this project → 404.
        return analysis == null
                ? Response.status(Response.Status.NOT_FOUND).build()
                : Response.ok(analysis).build();
    } catch (MLProjectHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while retrieving analysis with [name] %s of project [id] %s of tenant [id] %s and [user] %s .",
                                analysisName, projectId, tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Create a new Model.
 *
 * @param model {@link org.wso2.carbon.ml.commons.domain.MLModelData} object
 * @return JSON of {@link org.wso2.carbon.ml.commons.domain.MLModelData} object
 */
@POST
@Produces("application/json")
@Consumes("application/json")
public Response createModel(MLModelData model) {
    // A model must reference both an analysis and a version set.
    if (model.getAnalysisId() == 0 || model.getVersionSetId() == 0) {
        logger.error("Required parameters missing");
        return Response.status(Response.Status.BAD_REQUEST)
                .entity("Required parameters missing")
                .build();
    }
    PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    try {
        // Stamp the model with the caller's tenant and user before persisting.
        model.setTenantId(ctx.getTenantId());
        model.setUserName(ctx.getUsername());
        MLModelData insertedModel = mlModelHandler.createModel(model);
        return Response.ok(insertedModel).build();
    } catch (MLModelHandlerException e) {
        String msg = MLUtils.getErrorMsg("Error occurred while creating a model : " + model, e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Get the model data.
 *
 * @param modelName Name of the model
 * @return JSON of {@link org.wso2.carbon.ml.commons.domain.MLModelData} object
 */
@GET
@Path("/{modelName}")
@Produces("application/json")
public Response getModel(@PathParam("modelName") String modelName) {
    PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = ctx.getTenantId();
    String userName = ctx.getUsername();
    try {
        MLModelData model = mlModelHandler.getModel(tenantId, userName, modelName);
        // Unknown model name → 404.
        return model == null
                ? Response.status(Response.Status.NOT_FOUND).build()
                : Response.ok(model).build();
    } catch (MLModelHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while retrieving a model [name] %s of tenant [id] %s and [user] %s .",
                                modelName, tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Make predictions using a model.
 *
 * @param modelId Unique id of the model
 * @param data List of string arrays containing the feature values used for predictions
 * @return JSON array of predicted values
 */
@POST
@Path("/{modelId}/predict")
@Produces("application/json")
@Consumes("application/json")
public Response predict(@PathParam("modelId") long modelId, List<String[]> data) {
    PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = ctx.getTenantId();
    String userName = ctx.getUsername();
    try {
        long startMillis = System.currentTimeMillis();
        List<?> predictions = mlModelHandler.predict(tenantId, userName, modelId, data);
        // Log wall-clock prediction time in seconds.
        double elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000.0;
        logger.info(
                String.format(
                        "Prediction from model [id] %s finished in %s seconds.",
                        modelId, elapsedSeconds));
        return Response.ok(predictions).build();
    } catch (MLModelHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.",
                                modelId, tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Pre-processes the raw dataset for model training: drops rows whose DISCARD-imputed
 * features are missing, removes excluded features, encodes categorical values, and either
 * mean-imputes or directly converts the remaining tokens to doubles.
 *
 * @param context Model configuration context carrying the Spark context, workflow facts,
 *     raw lines, header row, column separator, summary statistics and index mappings
 * @return a JavaRDD of double arrays, one array per retained data row
 * @throws DatasetPreProcessingException if pre-processing fails
 */
public static JavaRDD<double[]> preProcess(MLModelConfigurationContext context)
        throws DatasetPreProcessingException {
    JavaSparkContext sc = context.getSparkContext();
    Workflow workflow = context.getFacts();
    JavaRDD<String> lines = context.getLines();
    String headerRow = context.getHeaderRow();
    String columnSeparator = context.getColumnSeparator();
    Map<String, String> summaryStatsOfFeatures = context.getSummaryStatsOfFeatures();
    List<Integer> newToOldIndicesList = context.getNewToOldIndicesList();
    int responseIndex = context.getResponseIndex();

    // Build per-column categorical encodings and publish them on the context so later
    // stages (and the final MLModel) can decode predictions.
    List<Map<String, Integer>> encodings =
            buildEncodings(
                    workflow.getFeatures(), summaryStatsOfFeatures, newToOldIndicesList, responseIndex);
    context.setEncodings(encodings);
    // Apply the filter to discard rows with missing values.
    JavaRDD<String[]> tokensDiscardedRemoved =
            MLUtils.filterRows(
                    columnSeparator,
                    headerRow,
                    lines,
                    MLUtils.getImputeFeatureIndices(
                            workflow, new ArrayList<Integer>(), MLConstants.DISCARD));
    // Drop the columns the workflow excluded, then map categorical tokens to their codes.
    JavaRDD<String[]> filteredTokens =
            tokensDiscardedRemoved.map(new RemoveDiscardedFeatures(newToOldIndicesList, responseIndex));
    JavaRDD<String[]> encodedTokens = filteredTokens.map(new BasicEncoder(encodings));
    JavaRDD<double[]> features = null;
    // get feature indices for mean imputation
    List<Integer> meanImputeIndices =
            MLUtils.getImputeFeatureIndices(workflow, newToOldIndicesList, MLConstants.MEAN_IMPUTATION);
    if (meanImputeIndices.size() > 0) {
        // Estimate column means from a 1% sample of the dataset (sampleFraction = 0.01).
        Map<Integer, Double> means = getMeans(sc, encodedTokens, meanImputeIndices, 0.01);
        // Replace missing values in impute indices with the mean for that column
        MeanImputation meanImputation = new MeanImputation(means);
        features = encodedTokens.map(meanImputation);
    } else {
        /**
         * Mean imputation mapper will convert string tokens to doubles as a part of the operation. If
         * there is no mean imputation for any columns, tokens has to be converted into doubles.
         */
        features = encodedTokens.map(new StringArrayToDoubleArray());
    }
    return features;
}
/** * Get all projects created with the given dataset * * @param datasetName Name of the dataset * @return JSON array of {@link org.wso2.carbon.ml.rest.api.model.MLProjectBean} objects */ @GET @Path("/analyses") @Produces("application/json") public Response getAllProjectsWithAnalyses(@QueryParam("datasetName") String datasetName) { PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext(); int tenantId = carbonContext.getTenantId(); String userName = carbonContext.getUsername(); try { List<MLProject> projects = mlProjectHandler.getAllProjects(tenantId, userName); List<MLProjectBean> projectBeans = new ArrayList<MLProjectBean>(); for (MLProject mlProject : projects) { if (!StringUtils.isEmpty(datasetName) && !datasetName.equals(mlProject.getDatasetName())) { continue; } MLProjectBean projectBean = new MLProjectBean(); long projectId = mlProject.getId(); projectBean.setId(projectId); projectBean.setCreatedTime(mlProject.getCreatedTime()); projectBean.setDatasetId(mlProject.getDatasetId()); projectBean.setDatasetName(mlProject.getDatasetName()); projectBean.setDatasetStatus(mlProject.getDatasetStatus()); projectBean.setDescription(mlProject.getDescription()); projectBean.setName(mlProject.getName()); List<MLAnalysisBean> analysisBeans = new ArrayList<MLAnalysisBean>(); List<MLAnalysis> analyses = mlProjectHandler.getAllAnalysesOfProject(tenantId, userName, projectId); for (MLAnalysis mlAnalysis : analyses) { MLAnalysisBean analysisBean = new MLAnalysisBean(); analysisBean.setId(mlAnalysis.getId()); analysisBean.setName(mlAnalysis.getName()); analysisBean.setProjectId(mlAnalysis.getProjectId()); analysisBean.setComments(mlAnalysis.getComments()); analysisBeans.add(analysisBean); } projectBean.setAnalyses(analysisBeans); projectBeans.add(projectBean); } return Response.ok(projectBeans).build(); } catch (MLProjectHandlerException e) { String msg = MLUtils.getErrorMsg( String.format( "Error occurred while retrieving all analyses of 
tenant [id] %s and [user] %s .", tenantId, userName), e); logger.error(msg, e); return Response.status(Response.Status.INTERNAL_SERVER_ERROR) .entity(new MLErrorBean(e.getMessage())) .build(); } }
/** * This method returns multiclass confusion matrix for a given multiclass metric object * * @param multiclassMetrics Multiclass metric object */ private MulticlassConfusionMatrix getMulticlassConfusionMatrix( MulticlassMetrics multiclassMetrics, MLModel mlModel) { MulticlassConfusionMatrix multiclassConfusionMatrix = new MulticlassConfusionMatrix(); if (multiclassMetrics != null) { int size = multiclassMetrics.confusionMatrix().numCols(); double[] matrixArray = multiclassMetrics.confusionMatrix().toArray(); double[][] matrix = new double[size][size]; // set values of matrix into a 2D array for (int i = 0; i < size; i++) { for (int j = 0; j < size; j++) { matrix[i][j] = matrixArray[(j * size) + i]; } } multiclassConfusionMatrix.setMatrix(matrix); List<Map<String, Integer>> encodings = mlModel.getEncodings(); // decode only if encodings are available if (encodings != null) { // last index is response variable encoding Map<String, Integer> encodingMap = encodings.get(encodings.size() - 1); List<String> decodedLabels = new ArrayList<String>(); for (double label : multiclassMetrics.labels()) { Integer labelInt = (int) label; String decodedLabel = MLUtils.getKeyByValue(encodingMap, labelInt); if (decodedLabel != null) { decodedLabels.add(decodedLabel); } else { continue; } } multiclassConfusionMatrix.setLabels(decodedLabels); } else { List<String> labelList = toStringList(multiclassMetrics.labels()); multiclassConfusionMatrix.setLabels(labelList); } multiclassConfusionMatrix.setSize(size); } return multiclassConfusionMatrix; }
/**
 * Get all projects of the calling tenant/user.
 *
 * @return JSON array of {@link org.wso2.carbon.ml.commons.domain.MLProject} objects
 */
@GET
@Produces("application/json")
public Response getAllProjects() {
    PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = ctx.getTenantId();
    String userName = ctx.getUsername();
    try {
        return Response.ok(mlProjectHandler.getAllProjects(tenantId, userName)).build();
    } catch (MLProjectHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while retrieving all projects of tenant [id] %s and [user] %s .",
                                tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Download the model as a serialized Java object.
 *
 * @param modelName Name of the model
 * @return A {@link org.wso2.carbon.ml.commons.domain.MLModel} as a {@link
 *     javax.ws.rs.core.StreamingOutput}
 */
@GET
@Path("/{modelName}/export")
@Produces(MediaType.APPLICATION_OCTET_STREAM)
public Response exportModel(@PathParam("modelName") String modelName) {
    PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = carbonContext.getTenantId();
    String userName = carbonContext.getUsername();
    try {
        MLModelData model = mlModelHandler.getModel(tenantId, userName, modelName);
        if (model == null) {
            return Response.status(Response.Status.NOT_FOUND).build();
        }
        final MLModel generatedModel = mlModelHandler.retrieveModel(model.getId());
        StreamingOutput stream =
                new StreamingOutput() {
                    @Override
                    public void write(OutputStream outputStream) throws IOException {
                        // ObjectOutputStream buffers internally; without an explicit flush
                        // the tail of the serialized model may never reach the client.
                        ObjectOutputStream out = new ObjectOutputStream(outputStream);
                        out.writeObject(generatedModel);
                        out.flush();
                    }
                };
        return Response.ok(stream)
                .header("Content-disposition", "attachment; filename=" + modelName)
                .build();
    } catch (MLModelHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while retrieving model [name] %s of tenant [id] %s and [user] %s .",
                                modelName, tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Get the model summary.
 *
 * @param modelId Unique id of the model
 * @return JSON of {@link org.wso2.carbon.ml.commons.domain.ModelSummary} object
 */
@GET
@Path("/{modelId}/summary")
@Produces("application/json")
@Consumes("application/json")
public Response getModelSummary(@PathParam("modelId") long modelId) {
    PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = ctx.getTenantId();
    String userName = ctx.getUsername();
    try {
        return Response.ok(mlModelHandler.getModelSummary(modelId)).build();
    } catch (MLModelHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while retrieving summary of the model [id] %s of tenant [id] %s and [user] %s .",
                                modelId, tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}
/**
 * Build a supervised model.
 *
 * <p>Validates the response variable, pre-processes the dataset into labeled points,
 * splits it into train/test sets, dispatches to the algorithm-specific builder selected
 * by the workflow, and persists the resulting model summary.
 *
 * @return the deployable {@link MLModel}
 * @throws MLModelBuilderException on any failure during model building
 */
public MLModel build() throws MLModelBuilderException {
    MLModelConfigurationContext context = getContext();
    JavaSparkContext sparkContext = null;
    DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
    MLModel mlModel = new MLModel();
    try {
        sparkContext = context.getSparkContext();
        Workflow workflow = context.getFacts();
        long modelId = context.getModelId();
        // Verify validity of response variable
        String typeOfResponseVariable =
                getTypeOfResponseVariable(workflow.getResponseVariable(), workflow.getFeatures());
        if (typeOfResponseVariable == null) {
            throw new MLModelBuilderException(
                    "Type of response variable cannot be null for supervised learning "
                            + "algorithms.");
        }
        // Stops model building if a categorical attribute is used with numerical prediction
        if (workflow.getAlgorithmClass().equals(AlgorithmType.NUMERICAL_PREDICTION.getValue())
                && typeOfResponseVariable.equals(FeatureType.CATEGORICAL)) {
            throw new MLModelBuilderException(
                    "Categorical attribute "
                            + workflow.getResponseVariable()
                            + " cannot be used as the response variable of the Numerical Prediction algorithm: "
                            + workflow.getAlgorithmName());
        }
        // generate train and test datasets by converting tokens to labeled points
        int responseIndex = context.getResponseIndex();
        SortedMap<Integer, String> includedFeatures =
                MLUtils.getIncludedFeaturesAfterReordering(
                        workflow, context.getNewToOldIndicesList(), responseIndex);
        // gets the pre-processed dataset; cached because it is consumed by randomSplit
        JavaRDD<LabeledPoint> labeledPoints = preProcess().cache();
        JavaRDD<LabeledPoint>[] dataSplit =
                labeledPoints.randomSplit(
                        new double[] {workflow.getTrainDataFraction(), 1 - workflow.getTrainDataFraction()},
                        MLConstants.RANDOM_SEED);
        // remove from cache
        labeledPoints.unpersist();
        // Training data is reused by the iterative builders, so keep it cached.
        JavaRDD<LabeledPoint> trainingData = dataSplit[0].cache();
        JavaRDD<LabeledPoint> testingData = dataSplit[1];
        // create a deployable MLModel object
        mlModel.setAlgorithmName(workflow.getAlgorithmName());
        mlModel.setAlgorithmClass(workflow.getAlgorithmClass());
        mlModel.setFeatures(workflow.getIncludedFeatures());
        mlModel.setResponseVariable(workflow.getResponseVariable());
        mlModel.setEncodings(context.getEncodings());
        mlModel.setNewToOldIndicesList(context.getNewToOldIndicesList());
        mlModel.setResponseIndex(responseIndex);
        ModelSummary summaryModel = null;
        Map<Integer, Integer> categoricalFeatureInfo;
        // build a machine learning model according to user selected algorithm
        SUPERVISED_ALGORITHM supervisedAlgorithm =
                SUPERVISED_ALGORITHM.valueOf(workflow.getAlgorithmName());
        switch (supervisedAlgorithm) {
            case LOGISTIC_REGRESSION:
                // last argument: true → SGD-based optimization
                summaryModel =
                        buildLogisticRegressionModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures,
                                true);
                break;
            case LOGISTIC_REGRESSION_LBFGS:
                // last argument: false → L-BFGS-based optimization
                summaryModel =
                        buildLogisticRegressionModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures,
                                false);
                break;
            case DECISION_TREE:
                // tree algorithms need to know which features are categorical
                categoricalFeatureInfo = getCategoricalFeatureInfo(context.getEncodings());
                summaryModel =
                        buildDecisionTreeModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures,
                                categoricalFeatureInfo);
                break;
            case RANDOM_FOREST:
                categoricalFeatureInfo = getCategoricalFeatureInfo(context.getEncodings());
                summaryModel =
                        buildRandomForestTreeModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures,
                                categoricalFeatureInfo);
                break;
            case SVM:
                summaryModel =
                        buildSVMModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures);
                break;
            case NAIVE_BAYES:
                summaryModel =
                        buildNaiveBayesModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures);
                break;
            case LINEAR_REGRESSION:
                summaryModel =
                        buildLinearRegressionModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures);
                break;
            case RIDGE_REGRESSION:
                summaryModel =
                        buildRidgeRegressionModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures);
                break;
            case LASSO_REGRESSION:
                summaryModel =
                        buildLassoRegressionModel(
                                sparkContext,
                                modelId,
                                trainingData,
                                testingData,
                                workflow,
                                mlModel,
                                includedFeatures);
                break;
            default:
                throw new AlgorithmNameException("Incorrect algorithm name");
        }
        // persist model summary
        databaseService.updateModelSummary(modelId, summaryModel);
        return mlModel;
    } catch (Exception e) {
        throw new MLModelBuilderException(
                "An error occurred while building supervised machine learning model: " + e.getMessage(),
                e);
    }
}
/**
 * Predict using a file and return predictions as a CSV.
 *
 * @param modelId Unique id of the model
 * @param dataFormat Data format of the file (CSV or TSV)
 * @param columnHeader Whether the file contains the column header as the first row (YES or NO)
 * @param inputStream Input stream generated from the file used for predictions
 * @return A file as a {@link javax.ws.rs.core.StreamingOutput}
 */
@POST
@Path("/predictionStreams")
@Produces(MediaType.APPLICATION_OCTET_STREAM)
@Consumes(MediaType.MULTIPART_FORM_DATA)
public Response streamingPredict(
        @Multipart("modelId") long modelId,
        @Multipart("dataFormat") String dataFormat,
        @Multipart("columnHeader") String columnHeader,
        @Multipart("file") InputStream inputStream) {
    PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = carbonContext.getTenantId();
    String userName = carbonContext.getUsername();
    try {
        // validate input parameters
        // if it is a file upload, check whether the file is sent
        if (inputStream == null || inputStream.available() == 0) {
            String msg =
                    String.format(
                            "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s .",
                            modelId, tenantId, userName);
            logger.error(msg);
            return Response.status(Response.Status.BAD_REQUEST).entity(new MLErrorBean(msg)).build();
        }
        // The full CSV is produced eagerly here; the StreamingOutput below only writes it out.
        final String predictions =
                mlModelHandler.streamingPredict(
                        tenantId, userName, modelId, dataFormat, columnHeader, inputStream);
        StreamingOutput stream =
                new StreamingOutput() {
                    @Override
                    public void write(OutputStream outputStream) throws IOException {
                        Writer writer =
                                new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8));
                        writer.write(predictions);
                        writer.flush();
                        writer.close();
                    }
                };
        // Serve the predictions as a downloadable, timestamped CSV attachment.
        return Response.ok(stream)
                .header(
                        "Content-disposition",
                        "attachment; filename=Predictions_" + modelId + "_" + MLUtils.getDate() + MLConstants.CSV)
                .build();
    } catch (IOException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s.",
                                modelId, tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.BAD_REQUEST)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    } catch (MLModelHandlerException e) {
        String msg =
                MLUtils.getErrorMsg(
                        String.format(
                                "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.",
                                modelId, tenantId, userName),
                        e);
        logger.error(msg, e);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                .entity(new MLErrorBean(e.getMessage()))
                .build();
    }
}