Ejemplo n.º 1
0
 /**
  * Publishes the given model to the ML registry on behalf of the tenant and user bound to the
  * calling thread's Carbon context.
  *
  * @param modelId Unique id of the model to be published
  * @return HTTP 200 with a JSON {@link org.wso2.carbon.ml.rest.api.model.MLResponseBean} built
  *     from the registry path returned by the model handler; HTTP 400 with an error bean when
  *     the handler throws {@code InvalidRequestException}; HTTP 500 with an error bean when it
  *     throws {@code MLModelPublisherException}
  */
 @POST
 @Path("/{modelId}/publish")
 @Produces("application/json")
 @Consumes("application/json")
 public Response publishModel(@PathParam("modelId") long modelId) {
   PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = ctx.getTenantId();
   String userName = ctx.getUsername();
   try {
     String publishedPath = mlModelHandler.publishModel(tenantId, userName, modelId);
     MLResponseBean body = new MLResponseBean(publishedPath);
     return Response.ok(body).build();
   } catch (InvalidRequestException e) {
     // The handler rejected the request itself, so this is reported as a client error.
     String context =
         String.format(
             "Error occurred while publishing the model [id] %s of tenant [id] %s and [user] %s .",
             modelId, tenantId, userName);
     logger.error(MLUtils.getErrorMsg(context, e), e);
     MLErrorBean error = new MLErrorBean(e.getMessage());
     return Response.status(Response.Status.BAD_REQUEST).entity(error).build();
   } catch (MLModelPublisherException e) {
     // Publishing itself failed, so this is reported as a server error.
     String context =
         String.format(
             "Error occurred while publishing the model [id] %s of tenant [id] %s and [user] %s .",
             modelId, tenantId, userName);
     logger.error(MLUtils.getErrorMsg(context, e), e);
     MLErrorBean error = new MLErrorBean(e.getMessage());
     return Response.serverError().entity(error).build();
   }
 }
Ejemplo n.º 2
0
 /**
  * Predict using a file and return as a list of predicted values.
  *
  * @param modelId Unique id of the model
  * @param dataFormat Data format of the file (CSV or TSV)
  * @param inputStream File input stream generated from the file used for predictions
  * @return HTTP 200 with a JSON array of predictions; HTTP 400 with an error bean when no file
  *     (or an empty file) was uploaded or the upload could not be read; HTTP 500 with an error
  *     bean when the model handler fails to predict
  */
 @POST
 @Path("/predict")
 @Produces(MediaType.APPLICATION_JSON)
 @Consumes(MediaType.MULTIPART_FORM_DATA)
 public Response predict(
     @Multipart("modelId") long modelId,
     @Multipart("dataFormat") String dataFormat,
     @Multipart("file") InputStream inputStream) {
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = carbonContext.getTenantId();
   String userName = carbonContext.getUsername();
   try {
     // validate input parameters
     // if it is a file upload, check whether the file is sent
     if (inputStream == null || inputStream.available() == 0) {
       String msg =
           String.format(
               "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s .",
               modelId, tenantId, userName);
       logger.error(msg);
       // A missing or empty upload is a problem with the request, not with the server; this
       // matches the status streamingPredict returns for the same condition.
       return Response.status(Response.Status.BAD_REQUEST).entity(new MLErrorBean(msg)).build();
     }
     List<?> predictions =
         mlModelHandler.predict(tenantId, userName, modelId, dataFormat, inputStream);
     return Response.ok(predictions).build();
   } catch (IOException e) {
     // Raised by inputStream.available() when the uploaded stream cannot be inspected.
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.BAD_REQUEST)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   } catch (MLModelHandlerException e) {
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }
Ejemplo n.º 3
0
 /**
  * Create a new Project for the tenant and user bound to the calling thread's Carbon context.
  *
  * <p>Only the presence of the project name and dataset name is checked here; no further
  * validation happens. Please call {@link #getProject(String)} before this.
  *
  * @param project {@link org.wso2.carbon.ml.commons.domain.MLProject} object
  * @return HTTP 200 with no body on success; HTTP 400 with a JSON error bean when the payload is
  *     absent or lacks a name or dataset name; HTTP 500 with a JSON error bean when the project
  *     handler fails
  */
 @POST
 @Produces("application/json")
 @Consumes("application/json")
 public Response createProject(MLProject project) {
   // Guard against a null payload so it is reported as 400 instead of failing with a
   // NullPointerException on the first getter call.
   if (project == null
       || project.getName() == null
       || project.getName().isEmpty()
       || project.getDatasetName() == null
       || project.getDatasetName().isEmpty()) {
     String msg = "Required parameters missing";
     logger.error(msg);
     // Wrap the message in MLErrorBean like every other error response of this resource, so the
     // body matches the declared application/json content type.
     return Response.status(Response.Status.BAD_REQUEST).entity(new MLErrorBean(msg)).build();
   }
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = carbonContext.getTenantId();
   String userName = carbonContext.getUsername();
   try {
     // Ownership always comes from the authenticated context, overwriting whatever the payload
     // carried.
     project.setTenantId(tenantId);
     project.setUserName(userName);
     mlProjectHandler.createProject(project);
     return Response.ok().build();
   } catch (MLProjectHandlerException e) {
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while creating a [project] %s of tenant [id] %s and [user] %s .",
                 project, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }
Ejemplo n.º 4
0
 /**
  * Deletes a project owned by the tenant and user bound to the calling thread's Carbon context.
  * Both the successful deletion and any failure are written to the audit log.
  *
  * @param projectId Unique id of the project
  * @return HTTP 200 with no body on success; HTTP 500 with a JSON error bean when the project
  *     handler fails
  */
 @DELETE
 @Path("/{projectId}")
 @Produces("application/json")
 public Response deleteProject(@PathParam("projectId") long projectId) {
   PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = ctx.getTenantId();
   String userName = ctx.getUsername();
   try {
     mlProjectHandler.deleteProject(tenantId, userName, projectId);
     String auditEntry =
         String.format(
             "User [name] %s of tenant [id] %s deleted a project [id] %s ",
             userName, tenantId, projectId);
     auditLog.info(auditEntry);
     return Response.ok().build();
   } catch (MLProjectHandlerException e) {
     String context =
         String.format(
             "Error occurred while deleting a project [id]  %s of tenant [id] %s and [user] %s .",
             projectId, tenantId, userName);
     String failure = MLUtils.getErrorMsg(context, e);
     // The same message goes to the regular log and to the audit trail.
     logger.error(failure, e);
     auditLog.error(failure, e);
     MLErrorBean error = new MLErrorBean(e.getMessage());
     return Response.serverError().entity(error).build();
   }
 }
Ejemplo n.º 5
0
 /**
  * Looks up a single analysis of a project by the analysis name.
  *
  * @param projectId Unique id of the project
  * @param analysisName Name of the analysis
  * @return HTTP 200 with a JSON {@link org.wso2.carbon.ml.commons.domain.MLAnalysis}; HTTP 404
  *     with no body when the handler returns no analysis; HTTP 500 with a JSON error bean when
  *     the project handler fails
  */
 @GET
 @Path("/{projectId}/analyses/{analysisName}")
 @Produces("application/json")
 public Response getAnalysisOfProject(
     @PathParam("projectId") long projectId, @PathParam("analysisName") String analysisName) {
   PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = ctx.getTenantId();
   String userName = ctx.getUsername();
   try {
     MLAnalysis found =
         mlProjectHandler.getAnalysisOfProject(tenantId, userName, projectId, analysisName);
     if (found != null) {
       return Response.ok(found).build();
     }
     return Response.status(Response.Status.NOT_FOUND).build();
   } catch (MLProjectHandlerException e) {
     String context =
         String.format(
             "Error occurred while retrieving analysis with [name] %s of project [id]  %s of tenant [id] %s and [user] %s .",
             analysisName, projectId, tenantId, userName);
     logger.error(MLUtils.getErrorMsg(context, e), e);
     MLErrorBean error = new MLErrorBean(e.getMessage());
     return Response.serverError().entity(error).build();
   }
 }
Ejemplo n.º 6
0
 /**
  * Create a new Model for the tenant and user bound to the calling thread's Carbon context.
  *
  * @param model {@link org.wso2.carbon.ml.commons.domain.MLModelData} object; its analysis id
  *     and version set id must both be non-zero
  * @return HTTP 200 with a JSON {@link org.wso2.carbon.ml.commons.domain.MLModelData} as
  *     returned by the model handler; HTTP 400 with a JSON error bean when the payload is absent
  *     or a required id is zero; HTTP 500 with a JSON error bean when the model handler fails
  */
 @POST
 @Produces("application/json")
 @Consumes("application/json")
 public Response createModel(MLModelData model) {
   // Guard against a null payload so it is reported as 400 instead of failing with a
   // NullPointerException on the first getter call.
   if (model == null || model.getAnalysisId() == 0 || model.getVersionSetId() == 0) {
     String msg = "Required parameters missing";
     logger.error(msg);
     // Wrap the message in MLErrorBean like every other error response of this resource, so the
     // body matches the declared application/json content type.
     return Response.status(Response.Status.BAD_REQUEST).entity(new MLErrorBean(msg)).build();
   }
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   try {
     int tenantId = carbonContext.getTenantId();
     String userName = carbonContext.getUsername();
     // Ownership always comes from the authenticated context, overwriting whatever the payload
     // carried.
     model.setTenantId(tenantId);
     model.setUserName(userName);
     MLModelData insertedModel = mlModelHandler.createModel(model);
     return Response.ok(insertedModel).build();
   } catch (MLModelHandlerException e) {
     String msg = MLUtils.getErrorMsg("Error occurred while creating a model : " + model, e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }
Ejemplo n.º 7
0
 /**
  * Retrieves the stored data of a model by its name, scoped to the tenant and user bound to the
  * calling thread's Carbon context.
  *
  * @param modelName Name of the model
  * @return HTTP 200 with a JSON {@link org.wso2.carbon.ml.commons.domain.MLModelData}; HTTP 404
  *     with no body when the handler returns no model; HTTP 500 with a JSON error bean when the
  *     model handler fails
  */
 @GET
 @Path("/{modelName}")
 @Produces("application/json")
 public Response getModel(@PathParam("modelName") String modelName) {
   PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = ctx.getTenantId();
   String userName = ctx.getUsername();
   try {
     MLModelData found = mlModelHandler.getModel(tenantId, userName, modelName);
     if (found != null) {
       return Response.ok(found).build();
     }
     return Response.status(Response.Status.NOT_FOUND).build();
   } catch (MLModelHandlerException e) {
     String context =
         String.format(
             "Error occurred while retrieving a model [name] %s of tenant [id] %s and [user] %s .",
             modelName, tenantId, userName);
     logger.error(MLUtils.getErrorMsg(context, e), e);
     MLErrorBean error = new MLErrorBean(e.getMessage());
     return Response.serverError().entity(error).build();
   }
 }
Ejemplo n.º 8
0
 /**
  * Runs predictions against a model for the rows supplied in the request body and logs how long
  * the handler call took.
  *
  * @param modelId Unique id of the model
  * @param data List of string arrays containing the feature values used for predictions
  * @return HTTP 200 with a JSON array of predicted values; HTTP 500 with a JSON error bean when
  *     the model handler fails
  */
 @POST
 @Path("/{modelId}/predict")
 @Produces("application/json")
 @Consumes("application/json")
 public Response predict(@PathParam("modelId") long modelId, List<String[]> data) {
   PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = ctx.getTenantId();
   String userName = ctx.getUsername();
   try {
     long startMillis = System.currentTimeMillis();
     List<?> predicted = mlModelHandler.predict(tenantId, userName, modelId, data);
     // Wall-clock duration of the handler call, converted from milliseconds to seconds.
     double elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000.0;
     logger.info(
         String.format(
             "Prediction from model [id] %s finished in %s seconds.", modelId, elapsedSeconds));
     return Response.ok(predicted).build();
   } catch (MLModelHandlerException e) {
     String context =
         String.format(
             "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.",
             modelId, tenantId, userName);
     logger.error(MLUtils.getErrorMsg(context, e), e);
     MLErrorBean error = new MLErrorBean(e.getMessage());
     return Response.serverError().entity(error).build();
   }
 }
Ejemplo n.º 9
0
  /**
   * A utility method to pre-process the dataset lines held by the given context into numeric
   * feature arrays.
   *
   * <p>In order: builds the feature encodings and stores them on the context, filters rows using
   * the feature indices marked for {@code MLConstants.DISCARD}, removes discarded features,
   * applies the encodings, and finally either imputes means for the columns marked for mean
   * imputation or, when there are none, converts the string tokens to doubles directly.
   *
   * @param context Model configuration context supplying the Spark context, workflow, lines,
   *     header row, column separator, summary statistics, new-to-old index list and response
   *     index; its encodings are set as a side effect
   * @return Returns a JavaRDD of double arrays
   * @throws DatasetPreProcessingException declared by this method (NOTE(review): the step that
   *     raises it is not visible in this method — confirm against the helpers it calls)
   */
  public static JavaRDD<double[]> preProcess(MLModelConfigurationContext context)
      throws DatasetPreProcessingException {
    JavaSparkContext sparkContext = context.getSparkContext();
    Workflow facts = context.getFacts();
    JavaRDD<String> rawLines = context.getLines();
    String header = context.getHeaderRow();
    String separator = context.getColumnSeparator();
    Map<String, String> summaryStats = context.getSummaryStatsOfFeatures();
    List<Integer> newToOldIndices = context.getNewToOldIndicesList();
    int responseIdx = context.getResponseIndex();

    List<Map<String, Integer>> encodings =
        buildEncodings(facts.getFeatures(), summaryStats, newToOldIndices, responseIdx);
    context.setEncodings(encodings);

    // Apply the filter to discard rows with missing values.
    List<Integer> discardIndices =
        MLUtils.getImputeFeatureIndices(facts, new ArrayList<Integer>(), MLConstants.DISCARD);
    JavaRDD<String[]> keptRows = MLUtils.filterRows(separator, header, rawLines, discardIndices);

    // Drop the discarded feature columns, then replace categorical tokens by their encodings.
    JavaRDD<String[]> encodedRows =
        keptRows
            .map(new RemoveDiscardedFeatures(newToOldIndices, responseIdx))
            .map(new BasicEncoder(encodings));

    // get feature indices for mean imputation
    List<Integer> meanImputeIndices =
        MLUtils.getImputeFeatureIndices(facts, newToOldIndices, MLConstants.MEAN_IMPUTATION);
    if (meanImputeIndices.isEmpty()) {
      // Mean imputation would have converted string tokens to doubles as part of its work; with
      // no columns to impute, the conversion has to be done explicitly.
      return encodedRows.map(new StringArrayToDoubleArray());
    }

    // The last argument is presumably the sample fraction used to estimate the means (1.0 would
    // be the whole dataset) — TODO confirm against getMeans.
    Map<Integer, Double> means = getMeans(sparkContext, encodedRows, meanImputeIndices, 0.01);
    // Replace missing values in impute indices with the mean for that column
    return encodedRows.map(new MeanImputation(means));
  }
Ejemplo n.º 10
0
  /**
   * Lists the projects of the calling tenant and user together with the analyses of each
   * project, optionally restricted to projects created with the given dataset.
   *
   * @param datasetName Name of the dataset; when empty, projects are not filtered by dataset
   * @return HTTP 200 with a JSON array of {@link org.wso2.carbon.ml.rest.api.model.MLProjectBean}
   *     objects; HTTP 500 with a JSON error bean when the project handler fails
   */
  @GET
  @Path("/analyses")
  @Produces("application/json")
  public Response getAllProjectsWithAnalyses(@QueryParam("datasetName") String datasetName) {
    PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = ctx.getTenantId();
    String userName = ctx.getUsername();
    try {
      List<MLProjectBean> result = new ArrayList<MLProjectBean>();
      for (MLProject source : mlProjectHandler.getAllProjects(tenantId, userName)) {
        // Keep the project when no dataset filter was given or when its dataset matches.
        boolean selected =
            StringUtils.isEmpty(datasetName) || datasetName.equals(source.getDatasetName());
        if (!selected) {
          continue;
        }

        long projectId = source.getId();
        MLProjectBean target = new MLProjectBean();
        target.setId(projectId);
        target.setCreatedTime(source.getCreatedTime());
        target.setDatasetId(source.getDatasetId());
        target.setDatasetName(source.getDatasetName());
        target.setDatasetStatus(source.getDatasetStatus());
        target.setDescription(source.getDescription());
        target.setName(source.getName());

        // Copy the analyses of this project into their REST bean counterparts.
        List<MLAnalysisBean> analysisBeans = new ArrayList<MLAnalysisBean>();
        for (MLAnalysis analysis :
            mlProjectHandler.getAllAnalysesOfProject(tenantId, userName, projectId)) {
          MLAnalysisBean bean = new MLAnalysisBean();
          bean.setId(analysis.getId());
          bean.setName(analysis.getName());
          bean.setProjectId(analysis.getProjectId());
          bean.setComments(analysis.getComments());
          analysisBeans.add(bean);
        }
        target.setAnalyses(analysisBeans);
        result.add(target);
      }
      return Response.ok(result).build();
    } catch (MLProjectHandlerException e) {
      String context =
          String.format(
              "Error occurred while retrieving all analyses of tenant [id] %s and [user] %s .",
              tenantId, userName);
      logger.error(MLUtils.getErrorMsg(context, e), e);
      MLErrorBean error = new MLErrorBean(e.getMessage());
      return Response.serverError().entity(error).build();
    }
  }
  /**
   * This method returns multiclass confusion matrix for a given multiclass metric object.
   *
   * <p>The flat array obtained from {@code confusionMatrix().toArray()} is read in column-major
   * order (element {@code (i, j)} at index {@code j * size + i}) into a square 2D array. Labels
   * are decoded through the last entry of the model's encodings when encodings are present;
   * labels without a matching encoding key are left out. When the model has no encodings, the
   * raw metric labels are converted with {@code toStringList}.
   *
   * @param multiclassMetrics Multiclass metric object; when {@code null} an empty, unpopulated
   *     matrix bean is returned
   * @param mlModel Model whose encodings are used to decode the numeric class labels
   * @return The populated confusion matrix bean (matrix, labels and size)
   */
  private MulticlassConfusionMatrix getMulticlassConfusionMatrix(
      MulticlassMetrics multiclassMetrics, MLModel mlModel) {
    MulticlassConfusionMatrix multiclassConfusionMatrix = new MulticlassConfusionMatrix();
    if (multiclassMetrics != null) {
      int size = multiclassMetrics.confusionMatrix().numCols();
      double[] matrixArray = multiclassMetrics.confusionMatrix().toArray();
      double[][] matrix = new double[size][size];
      // set values of matrix into a 2D array
      for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
          matrix[i][j] = matrixArray[(j * size) + i];
        }
      }
      multiclassConfusionMatrix.setMatrix(matrix);

      List<Map<String, Integer>> encodings = mlModel.getEncodings();
      // decode only if encodings are available; an empty list is treated like a missing one so
      // that the lookup of the last element below cannot fail with an index of -1
      if (encodings != null && !encodings.isEmpty()) {
        // last index is response variable encoding
        Map<String, Integer> encodingMap = encodings.get(encodings.size() - 1);
        List<String> decodedLabels = new ArrayList<String>();
        for (double label : multiclassMetrics.labels()) {
          Integer labelInt = (int) label;
          String decodedLabel = MLUtils.getKeyByValue(encodingMap, labelInt);
          // labels that have no key in the encoding map are skipped
          if (decodedLabel != null) {
            decodedLabels.add(decodedLabel);
          }
        }
        multiclassConfusionMatrix.setLabels(decodedLabels);
      } else {
        List<String> labelList = toStringList(multiclassMetrics.labels());
        multiclassConfusionMatrix.setLabels(labelList);
      }

      multiclassConfusionMatrix.setSize(size);
    }
    return multiclassConfusionMatrix;
  }
Ejemplo n.º 12
0
 /**
  * Lists every project of the tenant and user bound to the calling thread's Carbon context.
  *
  * @return HTTP 200 with a JSON array of {@link org.wso2.carbon.ml.commons.domain.MLProject}
  *     objects; HTTP 500 with a JSON error bean when the project handler fails
  */
 @GET
 @Produces("application/json")
 public Response getAllProjects() {
   PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = ctx.getTenantId();
   String userName = ctx.getUsername();
   try {
     List<MLProject> allProjects = mlProjectHandler.getAllProjects(tenantId, userName);
     return Response.ok(allProjects).build();
   } catch (MLProjectHandlerException e) {
     String context =
         String.format(
             "Error occurred while retrieving all projects of tenant [id] %s and [user] %s .",
             tenantId, userName);
     logger.error(MLUtils.getErrorMsg(context, e), e);
     MLErrorBean error = new MLErrorBean(e.getMessage());
     return Response.serverError().entity(error).build();
   }
 }
Ejemplo n.º 13
0
  /**
   * Download the model. The model is looked up by name for the calling tenant and user, loaded
   * through the model handler and written to the response with Java object serialization.
   *
   * @param modelName Name of the model
   * @return HTTP 200 with a {@link org.wso2.carbon.ml.commons.domain.MLModel} as a {@link
   *     javax.ws.rs.core.StreamingOutput} sent as an attachment named after the model; HTTP 404
   *     with no body when no model with that name is found; HTTP 500 with an error bean when the
   *     model handler fails
   */
  @GET
  @Path("/{modelName}/export")
  @Produces(MediaType.APPLICATION_OCTET_STREAM)
  public Response exportModel(@PathParam("modelName") String modelName) {

    PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
    int tenantId = carbonContext.getTenantId();
    String userName = carbonContext.getUsername();
    try {
      MLModelData model = mlModelHandler.getModel(tenantId, userName, modelName);
      if (model != null) {
        final MLModel generatedModel = mlModelHandler.retrieveModel(model.getId());
        StreamingOutput stream =
            new StreamingOutput() {
              @Override
              public void write(OutputStream outputStream) throws IOException {
                ObjectOutputStream out = new ObjectOutputStream(outputStream);
                out.writeObject(generatedModel);
                // Push the serialized bytes through to the response stream explicitly instead
                // of relying on the container to flush them. The stream is left open on
                // purpose: it belongs to the JAX-RS runtime.
                out.flush();
              }
            };
        return Response.ok(stream)
            .header("Content-disposition", "attachment; filename=" + modelName)
            .build();
      } else {
        return Response.status(Response.Status.NOT_FOUND).build();
      }
    } catch (MLModelHandlerException e) {
      String msg =
          MLUtils.getErrorMsg(
              String.format(
                  "Error occurred while retrieving model [name] %s of tenant [id] %s and [user] %s .",
                  modelName, tenantId, userName),
              e);
      logger.error(msg, e);
      return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
          .entity(new MLErrorBean(e.getMessage()))
          .build();
    }
  }
Ejemplo n.º 14
0
 /**
  * Retrieves the summary of a model.
  *
  * <p>NOTE(review): the tenant id and user name are only used in the error message; the handler
  * lookup is by model id alone — confirm that access is restricted elsewhere.
  *
  * @param modelId Unique id of the model
  * @return HTTP 200 with a JSON {@link org.wso2.carbon.ml.commons.domain.ModelSummary}; HTTP
  *     500 with a JSON error bean when the model handler fails
  */
 @GET
 @Path("/{modelId}/summary")
 @Produces("application/json")
 @Consumes("application/json")
 public Response getModelSummary(@PathParam("modelId") long modelId) {
   PrivilegedCarbonContext ctx = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = ctx.getTenantId();
   String userName = ctx.getUsername();
   try {
     ModelSummary summary = mlModelHandler.getModelSummary(modelId);
     return Response.ok(summary).build();
   } catch (MLModelHandlerException e) {
     String context =
         String.format(
             "Error occurred while retrieving summary of the model [id] %s of tenant [id] %s and [user] %s .",
             modelId, tenantId, userName);
     logger.error(MLUtils.getErrorMsg(context, e), e);
     MLErrorBean error = new MLErrorBean(e.getMessage());
     return Response.serverError().entity(error).build();
   }
 }
  /**
   * Build a supervised model.
   *
   * <p>Reads the workflow and model id from the builder's configuration context, validates the
   * response variable, splits the pre-processed data into training and testing sets, fills in a
   * deployable {@link MLModel}, dispatches to the algorithm-specific builder selected by the
   * workflow's algorithm name and finally persists the resulting model summary through the
   * database service.
   *
   * @return the populated {@link MLModel}
   * @throws MLModelBuilderException if anything fails while building the model; every exception
   *     raised inside the method, including the validation failures thrown here, is wrapped in
   *     a new {@code MLModelBuilderException}
   */
  public MLModel build() throws MLModelBuilderException {
    MLModelConfigurationContext context = getContext();
    JavaSparkContext sparkContext = null;
    DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
    MLModel mlModel = new MLModel();
    try {
      sparkContext = context.getSparkContext();
      Workflow workflow = context.getFacts();
      long modelId = context.getModelId();

      // Verify validity of response variable
      String typeOfResponseVariable =
          getTypeOfResponseVariable(workflow.getResponseVariable(), workflow.getFeatures());

      if (typeOfResponseVariable == null) {
        throw new MLModelBuilderException(
            "Type of response variable cannot be null for supervised learning " + "algorithms.");
      }

      // Stops model building if a categorical attribute is used with numerical prediction
      if (workflow.getAlgorithmClass().equals(AlgorithmType.NUMERICAL_PREDICTION.getValue())
          && typeOfResponseVariable.equals(FeatureType.CATEGORICAL)) {
        throw new MLModelBuilderException(
            "Categorical attribute "
                + workflow.getResponseVariable()
                + " cannot be used as the response variable of the Numerical Prediction algorithm: "
                + workflow.getAlgorithmName());
      }

      // generate train and test datasets by converting tokens to labeled points
      int responseIndex = context.getResponseIndex();
      SortedMap<Integer, String> includedFeatures =
          MLUtils.getIncludedFeaturesAfterReordering(
              workflow, context.getNewToOldIndicesList(), responseIndex);

      // gets the pre-processed dataset
      JavaRDD<LabeledPoint> labeledPoints = preProcess().cache();

      // Split into [training, testing] using the workflow's train fraction and a fixed seed
      JavaRDD<LabeledPoint>[] dataSplit =
          labeledPoints.randomSplit(
              new double[] {workflow.getTrainDataFraction(), 1 - workflow.getTrainDataFraction()},
              MLConstants.RANDOM_SEED);

      // remove from cache
      // NOTE(review): randomSplit is presumably lazy, so unpersisting here, before any action has
      // run on the splits, may mean the cached data is never actually reused — confirm.
      labeledPoints.unpersist();

      // NOTE(review): trainingData is cached here and not unpersisted in this method — confirm
      // whether the algorithm-specific builders release it.
      JavaRDD<LabeledPoint> trainingData = dataSplit[0].cache();
      JavaRDD<LabeledPoint> testingData = dataSplit[1];
      // create a deployable MLModel object
      mlModel.setAlgorithmName(workflow.getAlgorithmName());
      mlModel.setAlgorithmClass(workflow.getAlgorithmClass());
      mlModel.setFeatures(workflow.getIncludedFeatures());
      mlModel.setResponseVariable(workflow.getResponseVariable());
      mlModel.setEncodings(context.getEncodings());
      mlModel.setNewToOldIndicesList(context.getNewToOldIndicesList());
      mlModel.setResponseIndex(responseIndex);

      ModelSummary summaryModel = null;
      Map<Integer, Integer> categoricalFeatureInfo;

      // build a machine learning model according to user selected algorithm
      // valueOf throws IllegalArgumentException for an unknown name; that is wrapped by the
      // catch-all below.
      SUPERVISED_ALGORITHM supervisedAlgorithm =
          SUPERVISED_ALGORITHM.valueOf(workflow.getAlgorithmName());
      switch (supervisedAlgorithm) {
        case LOGISTIC_REGRESSION:
          // Same builder as LOGISTIC_REGRESSION_LBFGS; only the trailing boolean differs
          // (presumably it selects the non-LBFGS optimizer — confirm in the builder).
          summaryModel =
              buildLogisticRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures,
                  true);
          break;
        case LOGISTIC_REGRESSION_LBFGS:
          summaryModel =
              buildLogisticRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures,
                  false);
          break;
        case DECISION_TREE:
          // Tree-based builders additionally receive categorical feature info derived from the
          // encodings.
          categoricalFeatureInfo = getCategoricalFeatureInfo(context.getEncodings());
          summaryModel =
              buildDecisionTreeModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures,
                  categoricalFeatureInfo);
          break;
        case RANDOM_FOREST:
          categoricalFeatureInfo = getCategoricalFeatureInfo(context.getEncodings());
          summaryModel =
              buildRandomForestTreeModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures,
                  categoricalFeatureInfo);
          break;
        case SVM:
          summaryModel =
              buildSVMModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        case NAIVE_BAYES:
          summaryModel =
              buildNaiveBayesModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        case LINEAR_REGRESSION:
          summaryModel =
              buildLinearRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        case RIDGE_REGRESSION:
          summaryModel =
              buildRidgeRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        case LASSO_REGRESSION:
          summaryModel =
              buildLassoRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        default:
          throw new AlgorithmNameException("Incorrect algorithm name");
      }

      // persist model summary
      databaseService.updateModelSummary(modelId, summaryModel);
      return mlModel;
    } catch (Exception e) {
      // Catch-all: this also catches the MLModelBuilderException instances thrown above and
      // wraps them in another MLModelBuilderException, keeping the original as the cause.
      throw new MLModelBuilderException(
          "An error occurred while building supervised machine learning model: " + e.getMessage(),
          e);
    }
  }
Ejemplo n.º 16
0
 /**
  * Predict using a file and return predictions as a CSV.
  *
  * @param modelId Unique id of the model
  * @param dataFormat Data format of the file (CSV or TSV)
  * @param columnHeader Whether the file contains the column header as the first row (YES or NO)
  * @param inputStream Input stream generated from the file used for predictions
  * @return HTTP 200 with a file as a {@link javax.ws.rs.core.StreamingOutput}, sent as an
  *     attachment; HTTP 400 with an error bean when no file (or an empty file) was uploaded or
  *     the upload could not be read; HTTP 500 with an error bean when the model handler fails
  */
 @POST
 @Path("/predictionStreams")
 @Produces(MediaType.APPLICATION_OCTET_STREAM)
 @Consumes(MediaType.MULTIPART_FORM_DATA)
 public Response streamingPredict(
     @Multipart("modelId") long modelId,
     @Multipart("dataFormat") String dataFormat,
     @Multipart("columnHeader") String columnHeader,
     @Multipart("file") InputStream inputStream) {
   PrivilegedCarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
   int tenantId = carbonContext.getTenantId();
   String userName = carbonContext.getUsername();
   try {
     // validate input parameters
     // if it is a file upload, check whether the file is sent
     if (inputStream == null || inputStream.available() == 0) {
       String msg =
           String.format(
               "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s .",
               modelId, tenantId, userName);
       logger.error(msg);
       return Response.status(Response.Status.BAD_REQUEST).entity(new MLErrorBean(msg)).build();
     }
     // The handler returns the whole prediction output as one string, which is written to the
     // response when the container invokes the StreamingOutput below.
     final String predictions =
         mlModelHandler.streamingPredict(
             tenantId, userName, modelId, dataFormat, columnHeader, inputStream);
     StreamingOutput stream =
         new StreamingOutput() {
           @Override
           public void write(OutputStream outputStream) throws IOException {
             // try-with-resources closes (and thereby flushes) the writer even when write()
             // throws, which the previous explicit flush()/close() sequence did not guarantee.
             try (Writer writer =
                 new BufferedWriter(
                     new OutputStreamWriter(outputStream, StandardCharsets.UTF_8))) {
               writer.write(predictions);
             }
           }
         };
     return Response.ok(stream)
         .header(
             "Content-disposition",
             "attachment; filename=Predictions_"
                 + modelId
                 + "_"
                 + MLUtils.getDate()
                 + MLConstants.CSV)
         .build();
   } catch (IOException e) {
     // Raised by inputStream.available() when the uploaded stream cannot be inspected.
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while reading the file for model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.BAD_REQUEST)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   } catch (MLModelHandlerException e) {
     String msg =
         MLUtils.getErrorMsg(
             String.format(
                 "Error occurred while predicting from model [id] %s of tenant [id] %s and [user] %s.",
                 modelId, tenantId, userName),
             e);
     logger.error(msg, e);
     return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
         .entity(new MLErrorBean(e.getMessage()))
         .build();
   }
 }