/** Explain the score we computed for doc */
 @Override
 public Explanation explain(LeafReaderContext context, int doc) throws IOException {
   boolean match = false;
   float max = 0.0f, sum = 0.0f;
   List<Explanation> subs = new ArrayList<>();
   for (Weight wt : weights) {
     Explanation e = wt.explain(context, doc);
     if (e.isMatch()) {
       match = true;
       subs.add(e);
       sum += e.getValue();
       max = Math.max(max, e.getValue());
     }
   }
   if (match) {
     final float score = max + (sum - max) * tieBreakerMultiplier;
     final String desc =
         tieBreakerMultiplier == 0.0f
             ? "max of:"
             : "max plus " + tieBreakerMultiplier + " times others of:";
     return Explanation.match(score, desc, subs);
   } else {
     return Explanation.noMatch("No matching clause");
   }
 }
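 // A worked example of the tie-breaker formula used above (illustrative values,
 // not from the source): the best sub-score counts in full, all other matching
 // clauses are scaled by tieBreakerMultiplier.
 static float disMaxScore(float[] subScores, float tieBreakerMultiplier) {
   float max = 0.0f, sum = 0.0f;
   for (float s : subScores) {
     sum += s;
     max = Math.max(max, s);
   }
   return max + (sum - max) * tieBreakerMultiplier;
 }
 // disMaxScore(new float[] {0.8f, 0.5f}, 0.1f) -> 0.8f + 0.5f * 0.1f = 0.85f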
 /** Apply the computed normalization factor to our subqueries */
 @Override
 public void normalize(float norm, float topLevelBoost) {
   topLevelBoost *= getBoost(); // Incorporate our boost
   for (Weight wt : weights) {
     wt.normalize(norm, topLevelBoost);
   }
 }
    @Override
    public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
      if (used.compareAndSet(false, true)) {
        policy.onUse(getQuery());
      }
      DocIdSet docIdSet = get(in.getQuery(), context);
      if (docIdSet == null) {
        if (shouldCache(context)) {
          docIdSet = cache(context);
          putIfAbsent(in.getQuery(), context, docIdSet);
        } else {
          return in.bulkScorer(context);
        }
      }

      assert docIdSet != null;
      if (docIdSet == DocIdSet.EMPTY) {
        return null;
      }
      final DocIdSetIterator disi = docIdSet.iterator();
      if (disi == null) {
        return null;
      }

      return new DefaultBulkScorer(new ConstantScoreScorer(this, 0f, disi));
    }
 /**
  * Compute the sum of the squared weights of our subqueries. Used for normalization.
  */
 @Override
 public float getValueForNormalization() throws IOException {
   float max = 0.0f, sum = 0.0f;
   for (Weight currentWeight : weights) {
     float sub = currentWeight.getValueForNormalization();
     sum += sub;
     max = Math.max(max, sub);
   }
   return (((sum - max) * tieBreakerMultiplier * tieBreakerMultiplier) + max);
 }
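 // Why tieBreakerMultiplier is squared above: getValueForNormalization() operates in
 // squared-weight space (each sub returns its sum of squared weights), so the factor
 // that scales the non-max clauses at score time must itself be squared here to stay
 // consistent. Worked numbers (illustrative, not from the source): subs return 0.09
 // and 0.04 (weights 0.3 and 0.2) with tieBreakerMultiplier = 0.1:
 //   max = 0.09, sum = 0.13
 //   (0.13 - 0.09) * 0.1 * 0.1 + 0.09 = 0.0904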
  public void testSkipToFirsttimeHit() throws IOException {
    final DisjunctionMaxQuery dq = new DisjunctionMaxQuery(0.0f);
    dq.add(tq("dek", "albino"));
    dq.add(tq("dek", "DOES_NOT_EXIST"));

    QueryUtils.check(dq, s);

    final Weight dw = dq.weight(s);
    final Scorer ds = dw.scorer(s.getIndexReader(), true, false);
    assertTrue("firsttime skipTo found no match", ds.advance(3) != DocIdSetIterator.NO_MORE_DOCS);
    assertEquals("found wrong docid", "d4", r.document(ds.docID()).get("id"));
  }
  public void testSkipToFirsttimeMiss() throws IOException {
    final DisjunctionMaxQuery dq = new DisjunctionMaxQuery(0.0f);
    dq.add(tq("id", "d1"));
    dq.add(tq("dek", "DOES_NOT_EXIST"));

    QueryUtils.check(dq, s);

    final Weight dw = dq.weight(s);
    final Scorer ds = dw.scorer(s.getIndexReader(), true, false);
    final boolean skipOk = ds.advance(3) != DocIdSetIterator.NO_MORE_DOCS;
    if (skipOk) {
      fail("firsttime skipTo found a match? ... " + r.document(ds.docID()).get("id"));
    }
  }
 @Override
 public void normalize(float norm, float topLevelBoost) {
   this.queryNorm = norm * topLevelBoost;
   queryWeight *= this.queryNorm;
   // we normalize the inner weight, but ignore it (just to initialize everything)
   if (innerWeight != null) innerWeight.normalize(norm, topLevelBoost);
 }
    @Override
    public Scorer scorer(
        AtomicReaderContext context,
        boolean scoreDocsInOrder,
        boolean topScorer,
        final Bits acceptDocs)
        throws IOException {
      final DocIdSetIterator disi;
      if (filter != null) {
        assert query == null;
        final DocIdSet dis = filter.getDocIdSet(context, acceptDocs);
        if (dis == null) {
          return null;
        }
        disi = dis.iterator();
      } else {
        assert query != null && innerWeight != null;
        disi = innerWeight.scorer(context, scoreDocsInOrder, topScorer, acceptDocs);
      }

      if (disi == null) {
        return null;
      }
      return new ConstantScorer(disi, this, queryWeight);
    }
  // inherit javadoc
  public void search(Weight weight, Filter filter, final HitCollector results) throws IOException {

    Scorer scorer = weight.scorer(reader);
    if (scorer == null) return;

    if (filter == null) {
      scorer.score(results);
      return;
    }
    DocIdSetIterator filterDocIdIterator =
        filter.getDocIdSet(reader).iterator(); // CHECKME: use ConjunctionScorer here?

    // Leapfrog: repeatedly advance whichever of the two iterators is behind until
    // both land on the same document.
    boolean more = filterDocIdIterator.next() && scorer.skipTo(filterDocIdIterator.doc());

    while (more) {
      int filterDocId = filterDocIdIterator.doc();
      if (filterDocId > scorer.doc() && !scorer.skipTo(filterDocId)) {
        more = false;
      } else {
        int scorerDocId = scorer.doc();
        if (scorerDocId == filterDocId) { // permitted by filter
          results.collect(scorerDocId, scorer.score());
          more = filterDocIdIterator.next();
        } else {
          more = filterDocIdIterator.skipTo(scorerDocId);
        }
      }
    }
  }
 @Override
 public float getValueForNormalization() throws IOException {
   // we calculate sumOfSquaredWeights of the inner weight, but ignore it (just to initialize
   // everything)
   if (innerWeight != null) innerWeight.getValueForNormalization();
   queryWeight = getBoost();
   return queryWeight * queryWeight;
 }
 private DocIdSet cache(LeafReaderContext context) throws IOException {
   final BulkScorer scorer = in.bulkScorer(context);
   if (scorer == null) {
     return DocIdSet.EMPTY;
   } else {
     return cacheImpl(scorer, context.reader().maxDoc());
   }
 }
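 // A minimal sketch of what the cacheImpl(...) referenced above could look like
 // (hypothetical helper; assumes Lucene's FixedBitSet, BitDocIdSet and the
 // BulkScorer.score(LeafCollector, Bits) entry point): drain the scorer once and
 // record every matching doc id in a bit set.
 private static DocIdSet cacheImpl(BulkScorer scorer, int maxDoc) throws IOException {
   final FixedBitSet bits = new FixedBitSet(maxDoc);
   scorer.score(
       new LeafCollector() {
         @Override
         public void setScorer(Scorer scorer) {} // scores are irrelevant for caching

         @Override
         public void collect(int doc) {
           bits.set(doc); // remember the match
         }
       },
       null); // no acceptDocs: cache raw matches
   return new BitDocIdSet(bits);
 }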
 @Override
 public Scorer scorer(IndexReader reader, boolean scoreDocsInOrder, boolean topScorer)
     throws IOException {
   Scorer subQueryScorer = qWeight.scorer(reader, true, false);
   if (subQueryScorer == null) {
     return null;
   }
   return new BoostedQuery.CustomScorer(
       getSimilarity(searcher), searcher, reader, this, subQueryScorer, boostVal);
 }
 public synchronized boolean equals(java.lang.Object obj) {
   if (obj == null) return false;
   if (this == obj) return true;
   if (!(obj instanceof Weight)) return false;
   Weight other = (Weight) obj;
   if (__equalsCalc != null) {
     return (__equalsCalc == obj);
   }
   __equalsCalc = obj;
   boolean _equals;
   _equals =
       true
           && ((this.units == null && other.getUnits() == null)
               || (this.units != null && this.units.equals(other.getUnits())))
           && ((this.value == null && other.getValue() == null)
               || (this.value != null && this.value.equals(other.getValue())));
   __equalsCalc = null;
   return _equals;
 }
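 // Axis-generated beans conventionally pair the cycle-guarded equals above with a
 // similarly guarded hashCode; a sketch of that convention (assumed, not from the
 // source):
 private boolean __hashCodeCalc = false;

 public synchronized int hashCode() {
   if (__hashCodeCalc) {
     return 0; // break the recursion for self-referential object graphs
   }
   __hashCodeCalc = true;
   int _hashCode = 1;
   if (getUnits() != null) {
     _hashCode += getUnits().hashCode();
   }
   if (getValue() != null) {
     _hashCode += getValue().hashCode();
   }
   __hashCodeCalc = false;
   return _hashCode;
 }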
 /** Create the scorer used to score our associated DisjunctionMaxQuery */
 @Override
 public Scorer scorer(LeafReaderContext context) throws IOException {
   List<Scorer> scorers = new ArrayList<>();
   for (Weight w : weights) {
     // we will advance() subscorers
     Scorer subScorer = w.scorer(context);
     if (subScorer != null) {
       scorers.add(subScorer);
     }
   }
   if (scorers.isEmpty()) {
     // no sub-scorers had any documents
     return null;
   } else if (scorers.size() == 1) {
     // only one sub-scorer in this segment
     return scorers.get(0);
   } else {
     return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers, needsScores);
   }
 }
 /** Explain the score we computed for doc */
 @Override
 public Explanation explain(LeafReaderContext context, int doc) throws IOException {
   if (disjuncts.size() == 1) return weights.get(0).explain(context, doc);
   ComplexExplanation result = new ComplexExplanation();
   float max = 0.0f, sum = 0.0f;
   result.setDescription(
       tieBreakerMultiplier == 0.0f
           ? "max of:"
           : "max plus " + tieBreakerMultiplier + " times others of:");
   for (Weight wt : weights) {
     Explanation e = wt.explain(context, doc);
     if (e.isMatch()) {
       result.setMatch(Boolean.TRUE);
       result.addDetail(e);
       sum += e.getValue();
       max = Math.max(max, e.getValue());
     }
   }
   result.setValue(max + (sum - max) * tieBreakerMultiplier);
   return result;
 }
  public static Weight[] calculateWeights(List<?> list, String valueRef)
      throws NoSuchFieldException, SecurityException, IllegalArgumentException,
          IllegalAccessException {
    Weight weight;
    int sum = 0;
    int size = list.size();
    Weight[] resultList = new Weight[size];

    for (int i = 0; i < size; i++) {
      Object item = list.get(i);
      Class<?> cls = item.getClass();
      Field field = cls.getDeclaredField(valueRef);

      int value = field.getInt(item);
      sum += value;

      weight = new Weight();
      weight.data = item;
      weight.value = value;
      resultList[i] = weight;
    }

    double inc = 0;

    for (int j = 0; j < size; j++) {
      weight = resultList[j];
      weight.proportion = (double) weight.value / sum; // cast to avoid integer division
      weight.lower = inc;
      inc += weight.proportion;
      weight.upper = inc;
    }

    return resultList;
  }
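  // One possible consumer of the [lower, upper) ranges computed above (hypothetical
  // helper, not from the source): weighted random selection returns the item whose
  // range contains a uniform draw from [0, 1).
  public static Object pickWeighted(Weight[] weights, java.util.Random rnd) {
    double r = rnd.nextDouble();
    for (Weight w : weights) {
      if (r >= w.lower && r < w.upper) {
        return w.data;
      }
    }
    return weights[weights.length - 1].data; // guard against rounding up to 1.0
  }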
 @Override
 public Scorer scorer(IndexReader reader, boolean scoreDocsInOrder, boolean topScorer)
     throws IOException {
   final DocIdSetIterator disi;
   if (filter != null) {
     assert query == null;
     final DocIdSet dis = filter.getDocIdSet(reader);
     if (dis == null) return null;
     disi = dis.iterator();
   } else {
     assert query != null && innerWeight != null;
     disi = innerWeight.scorer(reader, scoreDocsInOrder, topScorer);
   }
   if (disi == null) return null;
   return new ConstantScorer(similarity, disi, this);
 }
 public ConstantScorer(Similarity similarity, IndexReader reader, Weight w) throws IOException {
   super(similarity);
   this.reader = reader;
   theScore = w.getValue();
   DocIdSet docIdSet = filter.getDocIdSet(reader);
   if (docIdSet == null) {
     _innerIter = DocIdSet.EMPTY_DOCIDSET.iterator();
   } else {
     DocIdSetIterator iter = docIdSet.iterator();
     if (iter == null) {
       _innerIter = DocIdSet.EMPTY_DOCIDSET.iterator();
     } else {
       _innerIter = iter;
     }
   }
 }
 @Override
 public Scorer scorer(IndexReader reader, boolean scoreDocsInOrder, boolean topScorer)
     throws IOException {
   Scorer scorer = w.scorer(reader, scoreDocsInOrder, topScorer);
   if (scorer != null) {
     // check that scorer obeys disi contract for docID() before next()/advance
     try {
       int docid = scorer.docID();
       assert docid == -1 || docid == DocIdSetIterator.NO_MORE_DOCS;
     } catch (UnsupportedOperationException ignored) {
       // from a top-level BS1
       assert topScorer;
     }
   }
   return scorer;
 }
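 // The DocIdSetIterator contract the assertion above enforces, as a small usage
 // sketch (iterateAll is a hypothetical helper on the same pre-4.x API):
 static void iterateAll(Weight w, IndexReader reader) throws IOException {
   Scorer s = w.scorer(reader, true, false);
   if (s == null) return; // no matching docs for this reader
   // before nextDoc()/advance(), the scorer is unpositioned (or already exhausted)
   assert s.docID() == -1 || s.docID() == DocIdSetIterator.NO_MORE_DOCS;
   for (int doc = s.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = s.nextDoc()) {
     assert doc == s.docID(); // docID() tracks the current match during iteration
   }
 }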
    @Override
    public Explanation explain(IndexReader reader, int doc) throws IOException {
      SolrIndexReader topReader = (SolrIndexReader) reader;
      SolrIndexReader[] subReaders = topReader.getLeafReaders();
      int[] offsets = topReader.getLeafOffsets();
      int readerPos = SolrIndexReader.readerIndex(doc, offsets);
      int readerBase = offsets[readerPos];

      Explanation subQueryExpl = qWeight.explain(reader, doc);
      if (!subQueryExpl.isMatch()) {
        return subQueryExpl;
      }

      DocValues vals = boostVal.getValues(context, subReaders[readerPos]);
      float sc = subQueryExpl.getValue() * vals.floatVal(doc - readerBase);
      Explanation res =
          new ComplexExplanation(true, sc, BoostedQuery.this.toString() + ", product of:");
      res.addDetail(subQueryExpl);
      res.addDetail(vals.explain(doc - readerBase));
      return res;
    }
 @Override
 public boolean scoresDocsOutOfOrder() {
    return innerWeight != null && innerWeight.scoresDocsOutOfOrder();
 }
 public ConstantScorer(Similarity similarity, IndexReader reader, Weight w) throws IOException {
   super(similarity);
   theScore = w.getValue();
   // assumes filter.getDocIdSet(reader) never returns null; contrast the null-safe variant above
   docIdSetIterator = filter.getDocIdSet(reader).iterator();
 }
  @Test
  public void testGetThatFieldProbabilityRatioIsReflectedInBoost() throws Exception {

    ArgumentCaptor<Float> normalizeCaptor = ArgumentCaptor.forClass(Float.class);

    DocumentFrequencyCorrection dfc = new DocumentFrequencyCorrection();

    Directory directory = newDirectory();

    Analyzer analyzer =
        new Analyzer() {
          @Override
          protected TokenStreamComponents createComponents(String fieldName) {
            Tokenizer source = new WhitespaceTokenizer();
            TokenStream filter =
                new WordDelimiterFilter(
                    source,
                    WordDelimiterFilter.GENERATE_WORD_PARTS
                        | WordDelimiterFilter.SPLIT_ON_CASE_CHANGE,
                    null);
            filter = new LowerCaseFilter(filter);
            return new TokenStreamComponents(source, filter);
          }
        };

    IndexWriterConfig conf = new IndexWriterConfig(analyzer);
    conf.setCodec(Codec.forName(TestUtil.LUCENE_CODEC));
    IndexWriter indexWriter = new IndexWriter(directory, conf);

    // Both fields f1 and f2 have 10 terms in total.
    // f1: the search terms (abc def) make up 100% of all terms in f1
    // f2: the search terms (abc def) make up 50% of all terms in f2
    // --> we expect the sum of the boost factors for terms in bq(+f1:abc, +f1:def)
    // to equal 2 * the sum of the boost factors for terms in bq(+f2:abc, +f2:def)
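    // Working those numbers out from the addNumDocs calls below:
    //   f1: "abc def" x2 + "abc" x4 + "def" x2 = 10 terms; abc occurs 6x, def 4x
    //       -> P(abc|f1) + P(def|f1) = 0.6 + 0.4 = 1.0
    //   f2: "abc def" x1 + "abc" x2 + "def" x1 + "ghi" x5 = 10 terms; abc 3x, def 2x
    //       -> P(abc|f2) + P(def|f2) = 0.3 + 0.2 = 0.5
    //   ratio of the sums: 1.0 / 0.5 = 2, the factor asserted at the end of the test.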

    PRMSFieldBoostTest.addNumDocs("f1", "abc def", indexWriter, 2);
    PRMSFieldBoostTest.addNumDocs("f1", "abc", indexWriter, 4);
    PRMSFieldBoostTest.addNumDocs("f1", "def", indexWriter, 2);
    PRMSFieldBoostTest.addNumDocs("f2", "abc def", indexWriter, 1);
    PRMSFieldBoostTest.addNumDocs("f2", "abc", indexWriter, 2);
    PRMSFieldBoostTest.addNumDocs("f2", "def", indexWriter, 1);
    PRMSFieldBoostTest.addNumDocs("f2", "ghi", indexWriter, 5);

    indexWriter.close();

    IndexReader indexReader = DirectoryReader.open(directory);
    IndexSearcher indexSearcher = new IndexSearcher(indexReader);
    indexSearcher.setSimilarity(similarity);

    Map<String, Float> fields = new HashMap<>();
    fields.put("f1", 1f);
    fields.put("f2", 1f);
    SearchFieldsAndBoosting searchFieldsAndBoosting =
        new SearchFieldsAndBoosting(FieldBoostModel.PRMS, fields, fields, 0.8f);

    LuceneQueryBuilder queryBuilder =
        new LuceneQueryBuilder(dfc, analyzer, searchFieldsAndBoosting, 0.01f, null);

    WhiteSpaceQuerqyParser parser = new WhiteSpaceQuerqyParser();

    Query query = queryBuilder.createQuery(parser.parse("AbcDef"));
    dfc.finishedUserQuery();

    assertTrue(query instanceof DisjunctionMaxQuery);

    DisjunctionMaxQuery dmq = (DisjunctionMaxQuery) query;
    List<Query> disjuncts = dmq.getDisjuncts();
    assertEquals(2, disjuncts.size());

    Query disjunct1 = disjuncts.get(0);
    if (disjunct1 instanceof BoostQuery) {
      disjunct1 = ((BoostQuery) disjunct1).getQuery();
    }
    assertTrue(disjunct1 instanceof BooleanQuery);

    BooleanQuery bq1 = (BooleanQuery) disjunct1;

    Query disjunct2 = disjuncts.get(1);
    if (disjunct2 instanceof BoostQuery) {
      disjunct2 = ((BoostQuery) disjunct2).getQuery();
    }
    assertTrue(disjunct2 instanceof BooleanQuery);

    BooleanQuery bq2 = (BooleanQuery) disjunct2;

    final Weight weight1 = bq1.createWeight(indexSearcher, true);
    weight1.normalize(0.1f, 4f);

    final Weight weight2 = bq2.createWeight(indexSearcher, true);
    weight2.normalize(0.1f, 4f);

    Mockito.verify(simWeight, times(4)).normalize(eq(0.1f), normalizeCaptor.capture());
    final List<Float> capturedBoosts = normalizeCaptor.getAllValues();

    // capturedBoosts = boosts of [bq1.term1, bq1.term2, bq2.term1, bq2.term2 ]
    assertEquals(capturedBoosts.get(0), capturedBoosts.get(1), 0.00001);
    assertEquals(capturedBoosts.get(2), capturedBoosts.get(3), 0.00001);
    assertEquals(2f, capturedBoosts.get(0) / capturedBoosts.get(3), 0.00001);

    indexReader.close();
    directory.close();
    analyzer.close();
  }
 @Override
 public void normalize(float norm) {
   norm *= getBoost();
   qWeight.normalize(norm);
 }
 @Override
 public float sumOfSquaredWeights() throws IOException {
   float sum = qWeight.sumOfSquaredWeights();
   sum *= getBoost() * getBoost();
   return sum;
 }
 /** Apply the computed normalization factor to our subqueries */
 @Override
 public void normalize(float norm, float boost) {
   for (Weight wt : weights) {
     wt.normalize(norm, boost);
   }
 }
  @Test
  public void quickTest() throws IOException {

    double[] gradientError = new double[NUM_EPOCHS];
    double[] ecogError = new double[NUM_EPOCHS];

    network.reset();
    weights = network.getFlat().getWeights();

    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];

    Weight weightCalculator = null;

    for (int i = 0; i < workers.length; i++) {
      workers[i] = initGradient(subsets[i]);
      workers[i].setWeights(weights);
    }

    log.info("Running QuickPropagtaion testing! ");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);

    for (int i = 0; i < NUM_EPOCHS; i++) {

      double error = 0.0;

      // each worker does its share of the job
      for (int j = 0; j < workers.length; j++) {
        workers[j].run();
        error += workers[j].getError();
      }

      gradientError[i] = error / workers.length;

      log.info("The #" + i + " training error: " + gradientError[i]);

      // master
      globalParams.reset();

      for (int j = 0; j < workers.length; j++) {
        globalParams.accumulateGradients(workers[j].getGradients());
        globalParams.accumulateTrainSize(subsets[j].getRecordCount());
      }

      if (weightCalculator == null) {
        weightCalculator =
            new Weight(
                globalParams.getGradients().length,
                globalParams.getTrainSize(),
                this.rate,
                NNUtils.QUICK_PROPAGATION);
      }

      double[] interWeight =
          weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients());

      globalParams.setWeights(interWeight);

      // set weights
      for (int j = 0; j < workers.length; j++) {
        workers[j].setWeights(interWeight);
      }
    }

    // encog
    network.reset();
    // NNUtils.randomize(numSplit, weights);
    network.getFlat().setWeights(weights);

     Propagation p = new QuickPropagation(network, training, rate);
    // p = new ManhattanPropagation(network, training, rate);
    p.setThreadCount(numSplit);

    for (int i = 0; i < NUM_EPOCHS; i++) {
      p.iteration(1);
      // System.out.println("the #" + i + " training error: " + p.getError());
      ecogError[i] = p.getError();
    }

    // assert
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
      diff += Math.abs(ecogError[i] - gradientError[i]);
    }

    Assert.assertTrue(diff / NUM_EPOCHS < 0.1);
  }
 @Override
 public void extractTerms(Set<Term> terms) {
   for (Weight weight : weights) {
     weight.extractTerms(terms);
   }
 }
 public ConstantScorer(Similarity similarity, DocIdSetIterator docIdSetIterator, Weight w)
     throws IOException {
   super(similarity, w);
   theScore = w.getValue();
   this.docIdSetIterator = docIdSetIterator;
 }
 @Override
 public float getValueForNormalization() throws IOException {
   float sum = parentWeight.getValueForNormalization();
   sum *= getBoost() * getBoost();
   return sum;
 }