예제 #1
0
 /**
  * Create a new Model.
  *
  * <p>The caller's tenant id and user name are read from the thread-local Carbon context and set
  * on {@code model} before it is passed to the model handler.
  *
  * @param model {@link org.wso2.carbon.ml.commons.domain.MLModelData} object; its analysis id and
  *     version set id must both be non-zero
  * @return 200 with the JSON of the inserted {@link org.wso2.carbon.ml.commons.domain.MLModelData}
  *     object, 400 if the body is missing or a required id is 0, or 500 with an {@code
  *     MLErrorBean} if the model handler fails
  */
 @POST
 @Produces("application/json")
 @Consumes("application/json")
 public Response createModel(MLModelData model) {
   // An empty request body may be bound as null; reject it here rather than failing with an NPE
   // on the field checks below.
   if (model == null || model.getAnalysisId() == 0 || model.getVersionSetId() == 0) {
     logger.error("Required parameters missing");
     return Response.status(Response.Status.BAD_REQUEST)
         .entity("Required parameters missing")
         .build();
   }
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   try {
     int tenantId = carbonContext.getTenantId();
     String userName = carbonContext.getUsername();
     // Ownership always comes from the authenticated context, never from the request body.
     model.setTenantId(tenantId);
     model.setUserName(userName);
     MLModelData insertedModel = mlModelHandler.createModel(model);
     return Response.ok(insertedModel).build();
   } catch (MLModelHandlerException e) {
     String msg = MLUtils.getErrorMsg("Error occurred while creating a model : " + model, e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }
예제 #2
0
 /**
  * Look up a single model, by name, on behalf of the calling tenant and user.
  *
  * @param modelName Name of the model
  * @return 200 with the JSON of the matching {@link org.wso2.carbon.ml.commons.domain.MLModelData}
  *     object, 404 when the handler returns no model, or 500 with an {@code MLErrorBean} when the
  *     handler fails
  */
 @GET
 @Path("/{modelName}")
 @Produces("application/json")
 public Response getModel(@PathParam("modelName") String modelName) {
   PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenant = ctx.getTenantId();
   String user = ctx.getUsername();

   MLModelData found;
   try {
     found = mlModelHandler.getModel(tenant, user, modelName);
   } catch (MLModelHandlerException e) {
     String context =
         String.format(
             "Error occurred while retrieving a model [name] %s of tenant [id] %s and [user] %s .",
             modelName, tenant, user);
     logger.error(MLUtils.getErrorMsg(context, e), e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }

   // A null result means no model with that name is visible to this tenant/user.
   return found == null
       ? Response.status(Response.Status.NOT_FOUND).build()
       : Response.ok(found).build();
 }
예제 #3
0
 /**
  * Make predictions using a model.
  *
  * @param modelId Unique id of the model
  * @param data List of string arrays containing the feature values used for predictions
  *     (presumably one array per row to predict — TODO confirm against the handler)
  * @return 200 with a JSON array of predicted values, 400 with an {@code MLErrorBean} when no data
  *     is supplied, or 500 with an {@code MLErrorBean} when prediction fails
  */
 @POST
 @Path("/{modelId}/predict")
 @Produces("application/json")
 @Consumes("application/json")
 public Response predict(@PathParam("modelId") long modelId, List<String[]> data) {
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = carbonContext.getTenantId();
   String userName = carbonContext.getUsername();
   // An empty request body may be bound as null; answer with a client error instead of passing
   // null through to the handler.
   if (data == null) {
     String msg =
         String.format(
             "Prediction data is missing for model [id] %s of tenant [id] %s and [user] %s.",
             modelId, tenantId, userName);
     logger.error(msg);
     return Response.status(Response.Status.BAD_REQUEST).entity(new MLErrorBean(msg)).build();
   }
   try {
     // nanoTime is monotonic, so the measured duration is immune to wall-clock adjustments.
     long startNanos = System.nanoTime();
     List<?> predictions = mlModelHandler.predict(tenantId, userName, modelId, data);
     long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;
     logger.info(
         String.format(
             "Prediction from model [id] %s finished in %s seconds.",
             modelId, elapsedMillis / 1000.0));
     return Response.ok(predictions).build();
   } catch (MLModelHandlerException e) {
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }
예제 #4
0
 /**
  * Build the model.
  *
  * @param modelId Unique id of the model to be built.
  * @return 200 with an empty body once the handler's {@code buildModel} call returns, or 500 with
  *     an {@code MLErrorBean} if that call throws
  */
 @POST
 @Path("/{modelId}")
 @Produces("application/json")
 @Consumes("application/json")
 public Response buildModel(@PathParam("modelId") long modelId) {
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = carbonContext.getTenantId();
   String userName = carbonContext.getUsername();
   try {
     mlModelHandler.buildModel(tenantId, userName, modelId);
     return Response.ok().build();
   } catch (MLModelHandlerException | MLModelBuilderException e) {
     // Both failure types were handled identically; a single multi-catch avoids the duplication.
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while building the model [id] %s of tenant [id] %s and [user] %s .",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }
예제 #5
0
 /**
  * Predict using a file and return as a list of predicted values.
  *
  * @param modelId Unique id of the model
  * @param dataFormat Data format of the file (CSV or TSV)
  * @param inputStream File input stream generated from the file used for predictions
  * @return 200 with a JSON array of predictions, 400 with an {@code MLErrorBean} when the file is
  *     missing, empty or unreadable, or 500 with an {@code MLErrorBean} when prediction fails
  */
 @POST
 @Path("/predict")
 @Produces(MediaType.APPLICATION_JSON)
 @Consumes(MediaType.MULTIPART_FORM_DATA)
 public Response predict(
     @Multipart("modelId") long modelId,
     @Multipart("dataFormat") String dataFormat,
     @Multipart("file") InputStream inputStream) {
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = carbonContext.getTenantId();
   String userName = carbonContext.getUsername();
   try {
     // Check that a non-empty file was actually uploaded. InputStream.available() is only an
     // estimate and may return 0 for a stream that still has data, so peek at one byte instead
     // and push it back so the handler still receives the complete file.
     java.io.PushbackInputStream fileStream = null;
     int firstByte = -1;
     if (inputStream != null) {
       fileStream = new java.io.PushbackInputStream(inputStream);
       firstByte = fileStream.read();
     }
     if (firstByte == -1) {
       String msg =
           String.format(
               "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s .",
               modelId, tenantId, userName);
       logger.error(msg);
       // A missing or empty upload is a client error, like the unreadable-file case below.
       return Response.status(Response.Status.BAD_REQUEST).entity(new MLErrorBean(msg)).build();
     }
     fileStream.unread(firstByte);
     List<?> predictions =
         mlModelHandler.predict(tenantId, userName, modelId, dataFormat, fileStream);
     return Response.ok(predictions).build();
   } catch (IOException e) {
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.BAD_REQUEST)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   } catch (MLModelHandlerException e) {
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }
예제 #6
0
 /**
  * List every model visible to the calling tenant and user.
  *
  * @return 200 with a JSON array of {@link org.wso2.carbon.ml.commons.domain.MLModelData} objects,
  *     or 500 with an {@code MLErrorBean} when the model handler fails
  */
 @GET
 @Produces("application/json")
 public Response getAllModels() {
   PrivilegedCarbonContext context = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   final int currentTenant = context.getTenantId();
   final String currentUser = context.getUsername();

   List<MLModelData> allModels;
   try {
     allModels = mlModelHandler.getAllModels(currentTenant, currentUser);
   } catch (MLModelHandlerException e) {
     String description =
         String.format(
             "Error occurred while retrieving all models of tenant [id] %s and [user] %s .",
             currentTenant, currentUser);
     logger.error(MLUtils.getErrorMsg(description, e), e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
   return Response.ok(allModels).build();
 }
예제 #7
0
  /**
   * Download the model.
   *
   * <p>The stored model is written to the response body using Java object serialization.
   *
   * @param modelName Name of the model
   * @return 200 with a {@link org.wso2.carbon.ml.commons.domain.MLModel} as a {@link
   *     javax.ws.rs.core.StreamingOutput} attachment named after the model, 404 when the handler
   *     finds no model with that name for the caller, or 500 with an {@code MLErrorBean} when the
   *     handler fails
   */
  @GET
  @Path("/{modelName}/export")
  @Produces(MediaType.APPLICATION_OCTET_STREAM)
  public Response exportModel(@PathParam("modelName") String modelName) {

    PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = carbonContext.getTenantId();
    String userName = carbonContext.getUsername();
    try {
      MLModelData model = mlModelHandler.getModel(tenantId, userName, modelName);
      if (model == null) {
        return Response.status(Response.Status.NOT_FOUND).build();
      }
      final MLModel generatedModel = mlModelHandler.retrieveModel(model.getId());
      StreamingOutput stream =
          new StreamingOutput() {
            @Override
            public void write(OutputStream outputStream) throws IOException {
              ObjectOutputStream out = new ObjectOutputStream(outputStream);
              out.writeObject(generatedModel);
              // Explicitly push any bytes buffered inside the ObjectOutputStream through to the
              // response stream. The stream is not closed here; its lifecycle is left to the
              // JAX-RS runtime.
              out.flush();
            }
          };
      return Response.ok(stream)
          .header("Content-disposition", "attachment; filename=" + modelName)
          .build();
    } catch (MLModelHandlerException e) {
      String msg =
          MLUtils.getErrorMsg(
              String.format(
                  "Error occurred while retrieving model [name] %s of tenant [id] %s and [user] %s .",
                  modelName, tenantId, userName),
              e);
      logger.error(msg, e);
      return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
          .entity(new MLErrorBean(e.getMessage()))
          .build();
    }
  }
예제 #8
0
 /**
  * Predict using a file and return predictions as a CSV.
  *
  * @param modelId Unique id of the model
  * @param dataFormat Data format of the file (CSV or TSV)
  * @param columnHeader Whether the file contains the column header as the first row (YES or NO)
  * @param inputStream Input stream generated from the file used for predictions
  * @return 200 with a UTF-8 file as a {@link javax.ws.rs.core.StreamingOutput} attachment, 400 with
  *     an {@code MLErrorBean} when the file is missing, empty or unreadable, or 500 with an {@code
  *     MLErrorBean} when prediction fails
  */
 @POST
 @Path("/predictionStreams")
 @Produces(MediaType.APPLICATION_OCTET_STREAM)
 @Consumes(MediaType.MULTIPART_FORM_DATA)
 public Response streamingPredict(
     @Multipart("modelId") long modelId,
     @Multipart("dataFormat") String dataFormat,
     @Multipart("columnHeader") String columnHeader,
     @Multipart("file") InputStream inputStream) {
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = carbonContext.getTenantId();
   String userName = carbonContext.getUsername();
   try {
     // Check that a non-empty file was actually uploaded. InputStream.available() is only an
     // estimate and may return 0 for a stream that still has data, so peek at one byte instead
     // and push it back so the handler still receives the complete file.
     java.io.PushbackInputStream fileStream = null;
     int firstByte = -1;
     if (inputStream != null) {
       fileStream = new java.io.PushbackInputStream(inputStream);
       firstByte = fileStream.read();
     }
     if (firstByte == -1) {
       String msg =
           String.format(
               "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s .",
               modelId, tenantId, userName);
       logger.error(msg);
       return Response.status(Response.Status.BAD_REQUEST).entity(new MLErrorBean(msg)).build();
     }
     fileStream.unread(firstByte);
     final String predictions =
         mlModelHandler.streamingPredict(
             tenantId, userName, modelId, dataFormat, columnHeader, fileStream);
     StreamingOutput stream =
         new StreamingOutput() {
           @Override
           public void write(OutputStream outputStream) throws IOException {
             // try-with-resources closes (and therefore flushes) the writer even when write()
             // throws part-way through.
             try (Writer writer =
                 new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8))) {
               writer.write(predictions);
             }
           }
         };
     return Response.ok(stream)
         .header(
             "Content-disposition",
             "attachment; filename=Predictions_"
                 + modelId
                 + "_"
                 + MLUtils.getDate()
                 + MLConstants.CSV)
         .build();
   } catch (IOException e) {
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.BAD_REQUEST)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   } catch (MLModelHandlerException e) {
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }