Skip to content

Commit

Permalink
ConstScoreQuery (#1463)
Browse files Browse the repository at this point in the history
  • Loading branch information
shikhar authored Aug 23, 2022
1 parent df0ac9e commit 4c6c6e4
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 54 deletions.
177 changes: 177 additions & 0 deletions src/query/const_score_query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
use std::collections::BTreeMap;
use std::fmt;

use crate::query::{Explanation, Query, Scorer, Weight};
use crate::{DocId, DocSet, Score, Searcher, SegmentReader, TantivyError, Term};

/// `ConstScoreQuery` is a wrapper over a query to provide a constant score.
/// It can avoid unnecessary score computation on the wrapped query.
///
/// The document set matched by the `ConstScoreQuery` is strictly the same as the underlying query.
/// The configured score is used for each document.
pub struct ConstScoreQuery {
query: Box<dyn Query>,
score: Score,
}

impl ConstScoreQuery {
/// Builds a const score query.
pub fn new(query: Box<dyn Query>, score: Score) -> ConstScoreQuery {
ConstScoreQuery { query, score }
}
}

impl Clone for ConstScoreQuery {
fn clone(&self) -> Self {
ConstScoreQuery {
query: self.query.box_clone(),
score: self.score,
}
}
}

impl fmt::Debug for ConstScoreQuery {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Const(score={}, query={:?})", self.score, self.query)
}
}

impl Query for ConstScoreQuery {
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> {
let inner_weight = self.query.weight(searcher, scoring_enabled)?;
Ok(if scoring_enabled {
Box::new(ConstWeight::new(inner_weight, self.score))
} else {
inner_weight
})
}

fn query_terms(&self, terms: &mut BTreeMap<Term, bool>) {
self.query.query_terms(terms)
}
}

struct ConstWeight {
weight: Box<dyn Weight>,
score: Score,
}

impl ConstWeight {
pub fn new(weight: Box<dyn Weight>, score: Score) -> Self {
ConstWeight { weight, score }
}
}

impl Weight for ConstWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let inner_scorer = self.weight.scorer(reader, boost)?;
Ok(Box::new(ConstScorer::new(inner_scorer, boost * self.score)))
}

fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(TantivyError::InvalidArgument(format!(
"Document #({}) does not match",
doc
)));
}
let mut explanation = Explanation::new("Const", self.score);
let underlying_explanation = self.weight.explain(reader, doc)?;
explanation.add_detail(underlying_explanation);
Ok(explanation)
}

fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
self.weight.count(reader)
}
}

/// Wraps a `DocSet` and simply returns a constant `Scorer`.
/// The `ConstScorer` is useful if you have a `DocSet` where
/// you needed a scorer.
///
/// The `ConstScorer`'s constant score can be set
/// by calling `.set_score(...)`.
pub struct ConstScorer<TDocSet: DocSet> {
docset: TDocSet,
score: Score,
}

impl<TDocSet: DocSet> ConstScorer<TDocSet> {
/// Creates a new `ConstScorer`.
pub fn new(docset: TDocSet, score: Score) -> ConstScorer<TDocSet> {
ConstScorer { docset, score }
}
}

impl<TDocSet: DocSet> From<TDocSet> for ConstScorer<TDocSet> {
fn from(docset: TDocSet) -> Self {
ConstScorer::new(docset, 1.0)
}
}

impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
fn advance(&mut self) -> DocId {
self.docset.advance()
}

fn seek(&mut self, target: DocId) -> DocId {
self.docset.seek(target)
}

fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
self.docset.fill_buffer(buffer)
}

fn doc(&self) -> DocId {
self.docset.doc()
}

fn size_hint(&self) -> u32 {
self.docset.size_hint()
}
}

impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
fn score(&mut self) -> Score {
self.score
}
}

#[cfg(test)]
mod tests {
use super::ConstScoreQuery;
use crate::query::{AllQuery, Query};
use crate::schema::Schema;
use crate::{DocAddress, Document, Index};

#[test]
fn test_const_score_query_explain() -> crate::Result<()> {
let schema = Schema::builder().build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
index_writer.add_document(Document::new())?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let query = ConstScoreQuery::new(Box::new(AllQuery), 0.42);
let explanation = query.explain(&searcher, DocAddress::new(0, 0u32)).unwrap();
assert_eq!(
explanation.to_pretty_json(),
r#"{
"value": 0.42,
"description": "Const",
"details": [
{
"value": 1.0,
"description": "AllQuery",
"context": []
}
],
"context": []
}"#
);
Ok(())
}
}
4 changes: 3 additions & 1 deletion src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod bitset;
mod bm25;
mod boolean_query;
mod boost_query;
mod const_score_query;
mod disjunction_max_query;
mod empty_query;
mod exclude;
Expand Down Expand Up @@ -37,6 +38,7 @@ pub(crate) use self::bm25::Bm25Weight;
pub use self::boolean_query::BooleanQuery;
pub(crate) use self::boolean_query::BooleanWeight;
pub use self::boost_query::BoostQuery;
pub use self::const_score_query::{ConstScoreQuery, ConstScorer};
pub use self::disjunction_max_query::DisjunctionMaxQuery;
pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight};
pub use self::exclude::Exclude;
Expand All @@ -55,7 +57,7 @@ pub use self::reqopt_scorer::RequiredOptionalScorer;
pub use self::score_combiner::{
DisjunctionMaxCombiner, ScoreCombiner, SumCombiner, SumWithCoordsCombiner,
};
pub use self::scorer::{ConstScorer, Scorer};
pub use self::scorer::Scorer;
pub use self::term_query::TermQuery;
pub use self::union::Union;
#[cfg(test)]
Expand Down
54 changes: 1 addition & 53 deletions src/query/scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::DerefMut;
use downcast_rs::impl_downcast;

use crate::docset::DocSet;
use crate::{DocId, Score};
use crate::Score;

/// Scored set of documents matching a query within a specific segment.
///
Expand All @@ -22,55 +22,3 @@ impl Scorer for Box<dyn Scorer> {
self.deref_mut().score()
}
}

/// Wraps a `DocSet` and simply returns a constant `Scorer`.
/// The `ConstScorer` is useful if you have a `DocSet` where
/// you needed a scorer.
///
/// The `ConstScorer`'s constant score can be set
/// by calling `.set_score(...)`.
pub struct ConstScorer<TDocSet: DocSet> {
docset: TDocSet,
score: Score,
}

impl<TDocSet: DocSet> ConstScorer<TDocSet> {
/// Creates a new `ConstScorer`.
pub fn new(docset: TDocSet, score: Score) -> ConstScorer<TDocSet> {
ConstScorer { docset, score }
}
}

impl<TDocSet: DocSet> From<TDocSet> for ConstScorer<TDocSet> {
fn from(docset: TDocSet) -> Self {
ConstScorer::new(docset, 1.0)
}
}

impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
fn advance(&mut self) -> DocId {
self.docset.advance()
}

fn seek(&mut self, target: DocId) -> DocId {
self.docset.seek(target)
}

fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
self.docset.fill_buffer(buffer)
}

fn doc(&self) -> DocId {
self.docset.doc()
}

fn size_hint(&self) -> u32 {
self.docset.size_hint()
}
}

impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
fn score(&mut self) -> Score {
self.score
}
}

0 comments on commit 4c6c6e4

Please sign in to comment.