@Override
  public double getUnnormalizedLogProbability(Assignment assignment) {
    // Returns the log score of a complete assignment (input feature vector plus output
    // values). When conditionalVars is empty, this is the raw linear score
    // logWeights . (outputIndicator (x) features). Otherwise the value is looked up in
    // getOutputLogProbTensor, which rescales the exponentiated scores by sums over
    // conditionalVars.
    Preconditions.checkArgument(
        assignment.containsAll(getVars().getVariableNumsArray()),
        "assignment must contain a value for every variable of this factor");
    // The input variable's value is cast to a Tensor; a non-Tensor value will throw
    // ClassCastException here.
    Tensor inputFeatureVector =
        (Tensor) assignment.getValue(getInputVariable().getOnlyVariableNum());

    if (conditionalVars.size() == 0) {
      // No normalization for any conditioned-on variables. This case
      // allows a more efficient implementation than the default
      // in ClassifierFactor: build a single-element indicator over the
      // assigned output values and take one inner product with the weights.
      VariableNumMap outputVars = getOutputVariables();
      Tensor outputTensor =
          SparseTensor.singleElement(
              outputVars.getVariableNumsArray(),
              outputVars.getVariableSizes(),
              outputVars.assignmentToIntArray(assignment),
              1.0);

      // Relabel the feature vector onto the input variable's dimension number, exactly as
      // getOutputLogProbTensor does. Without this, a vector stored under a different
      // dimension number would not line up with logWeights' input dimension, and it
      // could even collide with an output dimension in the outer product.
      Tensor relabeledFeatures = inputFeatureVector.relabelDimensions(inputVarNums);
      Tensor featureIndicator = outputTensor.outerProduct(relabeledFeatures);
      return logWeights.innerProduct(featureIndicator).getByDimKey();
    } else {
      // Default to looking up the answer in the output log probabilities.
      int[] outputIndexes = getOutputVariables().assignmentToIntArray(assignment);
      Tensor logProbs = getOutputLogProbTensor(inputFeatureVector);
      return logProbs.getByDimKey(outputIndexes);
    }
  }
  /**
   * Computes the log-score tensor over the output variables for one input feature vector.
   *
   * <p>The feature vector is first relabeled onto the input variable's dimension number,
   * then contracted with {@code logWeights}. If {@code conditionalVars} is non-empty, the
   * scores are exponentiated, divided by their sums over the {@code conditionalVars}
   * dimensions, and converted back to log space.
   *
   * <p>NOTE(review): the division uses sums taken over {@code conditionalVars}, so the
   * result sums to one across those dimensions. Confirm this matches the intended meaning
   * of "conditional" variables. Also, {@code elementwiseExp} on large scores may overflow;
   * a log-sum-exp formulation would be more stable.
   *
   * @param inputFeatureVector feature vector for the input variable
   * @return log scores (possibly normalized) over the output variables
   */
  @Override
  protected Tensor getOutputLogProbTensor(Tensor inputFeatureVector) {
    Tensor relabeledFeatures = inputFeatureVector.relabelDimensions(inputVarNums);
    Tensor scores = logWeights.innerProduct(relabeledFeatures);
    if (conditionalVars.size() == 0) {
      // Nothing to normalize over: the raw linear scores are the answer.
      return scores;
    }

    Tensor unnormalized = scores.elementwiseExp();
    int[] normalizedDims = conditionalVars.getVariableNumsArray();
    Tensor partition = unnormalized.sumOutDimensions(normalizedDims);
    Tensor normalized = unnormalized.elementwiseProduct(partition.elementwiseInverse());
    return normalized.elementwiseLog();
  }
  public LinearClassifierFactor(
      VariableNumMap inputVar,
      VariableNumMap outputVars,
      DiscreteVariable featureDictionary,
      Tensor logWeights) {
    super(inputVar, outputVars, featureDictionary);
    Preconditions.checkArgument(
        inputVar.union(outputVars).containsAll(logWeights.getDimensionNumbers()));
    Preconditions.checkArgument(outputVars.getDiscreteVariables().size() == outputVars.size());

    this.inputVarNums = new int[] {inputVar.getOnlyVariableNum()};
    this.conditionalVars = VariableNumMap.EMPTY;
    this.logWeights = logWeights;
  }