Пример #1
0
  /**
   * Scores every class label for a test example.
   *
   * <p>For each label {@code i} in {@code 0..numOfClass-1}, the example is mapped through {@code
   * classifierData.phi} (with {@code classifierData.kernel}) and dotted with {@code W}. All labels
   * and their scores are returned in a {@link PredictedLabels} ordered by {@code
   * MapValueComparatorDescending} (presumably highest score first — confirm against that
   * comparator).
   *
   * <p>When {@code returnAll} equals {@code Consts.ERROR_NUMBER}, the call is delegated to {@link
   * #predictForTrain} with {@code epsilonArgMax = 0}, i.e. a single best label without the
   * task-loss term.
   *
   * @param vector the example to classify; must have {@code sizeOfVector > 0}
   * @param W the weight vector
   * @param realClass only forwarded to {@code predictForTrain}; unused on the all-labels path
   * @param classifierData supplies the phi converter and the kernel
   * @param returnAll compared against {@code Consts.ERROR_NUMBER} to select the mode above
   * @return the scored labels, or {@code null} if the vector is empty or if any exception is thrown
   *     on the all-labels path (its stack trace is printed)
   */
  @Override
  public PredictedLabels predictForTest(
      Example vector, Vector W, String realClass, ClassifierData classifierData, int returnAll) {
    // Sentinel value: fall back to the single-best-label argmax with no task loss.
    if (returnAll == Consts.ERROR_NUMBER) {
      return predictForTrain(vector, W, realClass, classifierData, 0.0);
    }

    try {
      PredictedLabels scores = new PredictedLabels();

      if (vector.sizeOfVector <= 0) {
        Logger.error(ErrorConstants.PHI_VECTOR_DATA);
        return null;
      }

      for (int label = 0; label < numOfClass; label++) {
        String labelName = String.valueOf(label);
        Example phi = classifierData.phi.convert(vector, labelName, classifierData.kernel);
        scores.put(labelName, MathHelpers.multipleVectors(W, phi.getFeatures()));
      }

      // Copy into a map ordered by MapValueComparatorDescending (presumably descending score).
      PredictedLabels ranked = new PredictedLabels(new MapValueComparatorDescending(scores));
      ranked.putAll(scores);
      return ranked;
    } catch (Exception e) {
      e.printStackTrace();
      return null;
    }
  }
Пример #2
0
  /**
   * Loss-augmented argmax over class labels, used during training.
   *
   * <p>Returns the single label {@code y} in {@code 0..numOfClass-1} that maximizes {@code
   * W · phi(vector, y) + epsilonArgMax * taskLoss(y, realClass)}. The task-loss term is skipped
   * entirely when {@code epsilonArgMax == 0}, which makes this a plain argmax. On ties the lowest
   * label wins, because a later label must score strictly higher to replace the current best.
   *
   * <p>NOTE(review): an older comment stated that {@code vector} is expected to be phi-converted
   * already. However, it is passed through {@code classifierData.phi.convert} once per label
   * below — confirm which is intended.
   *
   * @param vector the example; must have {@code sizeOfVector > 0}
   * @param W the weight vector
   * @param realClass the gold label, passed to the task loss
   * @param classifierData supplies the phi converter, kernel, task loss and its arguments
   * @param epsilonArgMax weight of the task-loss term; {@code 0} disables it
   * @return a {@link PredictedLabels} holding exactly one entry (best label mapped to its score),
   *     or {@code null} if the vector is empty, there are no classes, or any exception is thrown
   *     (its stack trace is printed)
   */
  @Override
  public PredictedLabels predictForTrain(
      Example vector,
      Vector W,
      String realClass,
      ClassifierData classifierData,
      double epsilonArgMax) {
    try {
      if (vector.sizeOfVector <= 0) {
        Logger.error(ErrorConstants.PHI_VECTOR_DATA);
        return null;
      }
      if (this.numOfClass <= 0) {
        // No candidate labels to choose from.
        return null;
      }

      // maxLabel == null means "nothing scored yet": label 0 always seeds the running maximum,
      // so maxScore's initial value is never compared against.
      String maxLabel = null;
      double maxScore = 0;
      for (int i = 0; i < numOfClass; i++) {
        String label = String.valueOf(i);
        Example phiData = classifierData.phi.convert(vector, label, classifierData.kernel);
        double score = MathHelpers.multipleVectors(W, phiData.getFeatures());

        if (epsilonArgMax != 0) {
          // Add the weighted task-loss term.
          score +=
              epsilonArgMax
                  * classifierData.taskLoss.computeTaskLoss(
                      label, realClass, classifierData.arguments);
        }

        // Strict '>' keeps the earliest label on ties.
        if (maxLabel == null || score > maxScore) {
          maxLabel = label;
          maxScore = score;
        }
      }

      PredictedLabels result = new PredictedLabels();
      result.put(maxLabel, maxScore);
      return result;
    } catch (Exception e) {
      e.printStackTrace();
      return null;
    }
  }