예제 #1
0
  /**
   * Performs one LambdaRank gradient accumulation pass over a single query.
   *
   * <p>Every document in {@code query.m_docList} is re-scored with the current {@code m_weight}
   * and the query is then passed to {@code m_eval.eval}. For each document a signed lambda is
   * built from its preference partners:
   *
   * <ul>
   *   <li>each entry of {@code m_worseURLs} adds
   *       {@code Utils.logistic(worse.m_score - doc.m_score) * m_eval.delta(doc, worse)}
   *       (force moving the document up);
   *   <li>each entry of {@code m_betterURLs} subtracts
   *       {@code Utils.logistic(doc.m_score - better.m_score) * m_eval.delta(better, doc)}
   *       (force moving the document down).
   * </ul>
   *
   * A non-zero lambda is applied to the gradient buffer as
   * {@code m_g[k] -= lambda * doc.m_rankFv[k]} for every feature index {@code k}.
   *
   * @param query the query whose documents are scored and used for the update
   * @return the total number of partner entries visited across all documents' worse/better
   *     lists (if those lists mirror each other, each preference pair is counted twice)
   */
  protected int gradientUpdate(_Query query) {
    // Step 1: score every document with the latest weights, then evaluate the ranking.
    for (_QUPair doc : query.m_docList) {
      doc.score(m_weight);
    }
    // The return value is not used here; the call is presumably needed so that
    // m_eval.delta(...) below sees up-to-date ranking state -- TODO confirm.
    m_eval.eval(query);

    int visited = 0;

    // Step 2: accumulate a lambda for each document from its preference partners.
    for (_QUPair doc : query.m_docList) {
      double docLambda = 0;

      if (doc.m_worseURLs != null) {
        for (_QUPair worse : doc.m_worseURLs) {
          docLambda += Utils.logistic(worse.m_score - doc.m_score) * m_eval.delta(doc, worse);
          visited++;
        }
      }

      if (doc.m_betterURLs != null) {
        for (_QUPair better : doc.m_betterURLs) {
          docLambda -= Utils.logistic(doc.m_score - better.m_score) * m_eval.delta(better, doc);
          visited++;
        }
      }

      // Step 3: a zero lambda leaves the gradient untouched.
      if (docLambda == 0) {
        continue;
      }
      for (int k = 0; k < doc.m_rankFv.length; k++) {
        m_g[k] -= docLambda * doc.m_rankFv[k];
      }
    }

    return visited;
  }
예제 #2
0
  /**
   * Re-scores every query in {@code m_queries} with the current weights and recomputes the
   * summary statistics {@code m_perf}, {@code m_obj} and {@code m_misorder}.
   *
   * <ul>
   *   <li>{@code m_perf} is the sum of the non-negative values returned by
   *       {@code m_eval.eval(query)}; negative results are skipped.
   *   <li>{@code m_obj} sums {@code Math.log} of {@code Utils.logistic(better - worse)} over the
   *       preference pairs, skipping non-positive logistic values.
   *   <li>{@code m_misorder} counts pairs whose preferred document does not score strictly
   *       higher, and is halved at the end because each pair is visited from both the worse-
   *       and the better-list side.
   * </ul>
   *
   * <p>NOTE(review): {@code m_obj} is not halved the way {@code m_misorder} is, so if the
   * worse/better lists mirror each other it holds twice the pairwise log-likelihood -- confirm
   * this scaling is intended before relying on its absolute value.
   */
  protected void evaluate() {
    double r;

    m_obj = 0;
    m_perf = 0;
    m_misorder = 0;

    for (_Query query : m_queries) {
      // Calculate ranking scores with the latest weights; m_eval.eval expects them to be set.
      for (_QUPair pair : query.m_docList) {
        pair.score(m_weight);
      }

      // Braces make explicit that only this accumulation is conditional.
      if ((r = m_eval.eval(query)) >= 0) {
        m_perf += r;
      }

      for (_QUPair pair : query.m_docList) {
        if (pair.m_worseURLs != null) {
          for (_QUPair worseURL : pair.m_worseURLs) {
            // Guard against log(0) when the logistic underflows.
            if ((r = Utils.logistic(pair.m_score - worseURL.m_score)) > 0) {
              m_obj += Math.log(r);
            }
            if (pair.m_score <= worseURL.m_score) {
              m_misorder++;
            }
          }
        }

        if (pair.m_betterURLs != null) {
          for (_QUPair betterURL : pair.m_betterURLs) {
            if ((r = Utils.logistic(betterURL.m_score - pair.m_score)) > 0) {
              m_obj += Math.log(r);
            }
            if (pair.m_score >= betterURL.m_score) {
              m_misorder++;
            }
          }
        }
      }
    }
    // Each misordered pair was counted once from each side.
    m_misorder /= 2;
  }
예제 #3
0
  /**
   * Creates a LambdaRank worker with zeroed weights and gradient buffer, an empty query list,
   * and an evaluator selected by {@code otype}.
   *
   * @param maxIter value stored in {@code m_maxIter}
   * @param featureSize length of the weight vector {@code m_weight} and gradient buffer
   *     {@code m_g}
   * @param windowSize value stored in {@code m_windowSize}
   * @param initStep initial value of {@code m_step}
   * @param shrinkage value stored in {@code m_shrinkage}
   * @param lambda value stored in {@code m_lambda}
   * @param otype selects the evaluator: {@code OT_MAP} gives {@code MAP_Evaluator},
   *     {@code OT_NDCG} gives {@code NDCG_Evaluator(LambdaRank.NDCG_K)}, anything else gives the
   *     plain {@code Evaluator}; must be non-null ({@code otype.equals} is called on it)
   */
  public LambdaRankWorker(
      int maxIter,
      int featureSize,
      int windowSize,
      double initStep,
      double shrinkage,
      double lambda,
      OptimizationType otype) {
    // Model state: weights and gradient share the feature dimension.
    m_weight = new double[featureSize];
    m_g = new double[featureSize];
    m_queries = new ArrayList<_Query>();

    // Optimization settings.
    m_maxIter = maxIter;
    m_windowSize = windowSize;
    m_step = initStep;
    m_shrinkage = shrinkage;
    m_lambda = lambda;

    // Pick the evaluator matching the requested optimization target.
    if (otype.equals(OptimizationType.OT_MAP)) {
      m_eval = new MAP_Evaluator();
    } else if (otype.equals(OptimizationType.OT_NDCG)) {
      m_eval = new NDCG_Evaluator(LambdaRank.NDCG_K);
    } else {
      m_eval = new Evaluator();
    }
    // The meaning of the rate is defined by the evaluator classes (not visible here).
    m_eval.setRate(0.5);
  }