diff --git a/quickwit/quickwit-query/src/query_ast/tantivy_query_ast.rs b/quickwit/quickwit-query/src/query_ast/tantivy_query_ast.rs index b89130424cf..e9ff7edc68c 100644 --- a/quickwit/quickwit-query/src/query_ast/tantivy_query_ast.rs +++ b/quickwit/quickwit-query/src/query_ast/tantivy_query_ast.rs @@ -36,6 +36,18 @@ pub(crate) enum TantivyQueryAst { ConstPredicate(MatchAllOrNone), } +impl Clone for TantivyQueryAst { + fn clone(&self) -> Self { + match self { + TantivyQueryAst::Bool(bool_query) => TantivyQueryAst::Bool(bool_query.clone()), + TantivyQueryAst::ConstPredicate(predicate) => { + TantivyQueryAst::ConstPredicate(*predicate) + } + TantivyQueryAst::Leaf(query) => TantivyQueryAst::Leaf(query.box_clone()), + } + } +} + impl From for TantivyQueryAst { fn from(match_all_or_none: MatchAllOrNone) -> Self { TantivyQueryAst::ConstPredicate(match_all_or_none) @@ -142,7 +154,7 @@ fn remove_with_guard( } } -#[derive(Default, Debug, Eq, PartialEq)] +#[derive(Default, Debug, Clone, Eq, PartialEq)] pub(crate) struct TantivyBoolQuery { pub must: Vec, pub must_not: Vec, @@ -169,6 +181,7 @@ impl TantivyBoolQuery { } pub fn simplify(mut self) -> TantivyQueryAst { + // simplify sub branches self.must = simplify_asts(self.must); self.should = simplify_asts(self.should); self.must_not = simplify_asts(self.must_not); @@ -188,6 +201,72 @@ impl TantivyBoolQuery { // This is just a convention mimicking Elastic/Commonsearch's behavior. return TantivyQueryAst::match_all(); } + + let mut new_must = Vec::with_capacity(self.must.len()); + for must in self.must { + let mut must_bool = match must { + TantivyQueryAst::Bool(bool_query) => bool_query, + _ => { + new_must.push(must); + continue; + } + }; + if must_bool.should.is_empty() { + new_must.append(&mut must_bool.must); + self.filter.append(&mut must_bool.filter); + self.must_not.append(&mut must_bool.must_not); + } else { + new_must.push(TantivyQueryAst::Bool(must_bool)); + } + } + self.must = new_must; + + let mut new_filter = Vec::with_capacity(self.filter.len()); + for filter in self.filter { + let mut filter_bool = match filter { + TantivyQueryAst::Bool(bool_query) => bool_query, + _ => { + new_filter.push(filter); + continue; + } + }; + if filter_bool.should.is_empty() { + new_filter.append(&mut filter_bool.must); + new_filter.append(&mut filter_bool.filter); + // must_not doeen't contribute to score, no need to move it to some filter_not kind + // of thing + self.must_not.append(&mut filter_bool.must_not); + } else { + new_filter.push(TantivyQueryAst::Bool(filter_bool)); + } + } + self.filter = new_filter; + + let mut new_should = Vec::with_capacity(self.should.len()); + for should in self.should { + let mut should_bool = match should { + TantivyQueryAst::Bool(bool_query) => bool_query, + _ => { + new_should.push(should); + continue; + } + }; + if should_bool.must.is_empty() + && should_bool.filter.is_empty() + && should_bool.must_not.is_empty() + { + new_should.append(&mut should_bool.should); + } else { + new_should.push(TantivyQueryAst::Bool(should_bool)); + } + } + self.should = new_should; + + // TODO we could turn must_not(must_not(abc, def)) into should(filter(abc), filter(def)), + // we can't simply have should(abc, def) because of scoring, and should(filter(abc, def)) + // has a different meaning + + // remove sub-queries which don't impact the result remove_with_guard(&mut self.must, MatchAllOrNone::MatchAll, true); let mut has_no_positive_ast_so_far = self.must.is_empty(); remove_with_guard( @@ -196,24 +275,24 @@ impl TantivyBoolQuery { has_no_positive_ast_so_far, ); has_no_positive_ast_so_far &= self.filter.is_empty(); + if !self.filter.is_empty() { + // if filter is not empty, we can re-try cleaning must. we can't just check + // has_no_positive_ast_so_far as it would clean must if must or filter contained + // something + remove_with_guard(&mut self.must, MatchAllOrNone::MatchAll, false); + } remove_with_guard( &mut self.should, MatchAllOrNone::MatchNone, has_no_positive_ast_so_far, ); has_no_positive_ast_so_far &= self.should.is_empty(); - // we do that a second time in case must happens to have a MatchAll and nothing else, - // but filter and/or should had something - remove_with_guard( - &mut self.must, - MatchAllOrNone::MatchAll, - has_no_positive_ast_so_far, - ); remove_with_guard( &mut self.must_not, MatchAllOrNone::MatchNone, has_no_positive_ast_so_far, ); + for must_child in self.must.iter().chain(self.filter.iter()) { if must_child.const_predicate() == Some(MatchAllOrNone::MatchNone) { return TantivyQueryAst::ConstPredicate(MatchAllOrNone::MatchNone); @@ -287,11 +366,21 @@ impl From for Box { #[cfg(test)] mod tests { - use tantivy::query::EmptyQuery; + use proptest::prelude::*; + use tantivy::query::{EmptyQuery, TermQuery}; use super::TantivyBoolQuery; use crate::query_ast::tantivy_query_ast::{remove_with_guard, MatchAllOrNone, TantivyQueryAst}; + fn term(val: &str) -> TantivyQueryAst { + use tantivy::schema::{Field, Term}; + TermQuery::new( + Term::from_field_text(Field::from_field_id(0), val), + Default::default(), + ) + .into() + } + #[test] fn test_simplify_bool_query_with_no_clauses() { let bool_query = TantivyBoolQuery::default(); @@ -458,22 +547,6 @@ mod tests { #[test] fn test_simplify_bool_query_with_match_must_and_other_positive_clauses() { - let bool_query = TantivyBoolQuery { - must: vec![TantivyQueryAst::match_all()], - should: vec![EmptyQuery.into()], - ..Default::default() - } - .simplify(); - assert_eq!(bool_query, EmptyQuery.into()); - - let bool_query = TantivyBoolQuery { - must: vec![TantivyQueryAst::match_all()], - should: vec![EmptyQuery.into()], - ..Default::default() - } - .simplify(); - assert_eq!(bool_query, EmptyQuery.into()); - let bool_query = TantivyBoolQuery { must: vec![TantivyQueryAst::match_all()], filter: vec![EmptyQuery.into()], @@ -523,4 +596,310 @@ mod tests { Some(MatchAllOrNone::MatchAll) ); } + + #[test] + fn test_simplify_lift_bool_bool() { + let bool_query = TantivyBoolQuery { + must: vec![ + TantivyBoolQuery { + must: vec![term("abc"), term("def")], + ..Default::default() + } + .into(), + TantivyBoolQuery { + must: vec![term("ghi"), term("jkl")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .simplify(); + assert_eq!( + bool_query, + TantivyBoolQuery { + must: vec![term("abc"), term("def"), term("ghi"), term("jkl"),], + ..Default::default() + } + .into() + ); + + let bool_query = TantivyBoolQuery { + should: vec![ + TantivyBoolQuery { + should: vec![term("abc"), term("def")], + ..Default::default() + } + .into(), + TantivyBoolQuery { + should: vec![term("ghi"), term("jkl")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .simplify(); + assert_eq!( + bool_query, + TantivyBoolQuery { + should: vec![term("abc"), term("def"), term("ghi"), term("jkl"),], + ..Default::default() + } + .into() + ); + + let bool_query = TantivyBoolQuery { + must: vec![ + TantivyBoolQuery { + must: vec![term("abc"), term("def")], + ..Default::default() + } + .into(), + TantivyBoolQuery { + should: vec![term("ghi"), term("jkl")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .simplify(); + assert_eq!( + bool_query, + TantivyBoolQuery { + must: vec![ + term("abc"), + term("def"), + TantivyBoolQuery { + should: vec![term("ghi"), term("jkl")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .into() + ); + + let bool_query = TantivyBoolQuery { + should: vec![ + TantivyBoolQuery { + must: vec![term("abc")], + ..Default::default() + } + .into(), + TantivyBoolQuery { + filter: vec![term("ghi")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .simplify(); + assert_eq!( + bool_query, + TantivyBoolQuery { + should: vec![ + term("abc"), + // filter can't get optimized for scoring reasons + TantivyBoolQuery { + filter: vec![term("ghi")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .into() + ); + + let bool_query = TantivyBoolQuery { + must: vec![ + TantivyBoolQuery { + should: vec![term("abc")], + ..Default::default() + } + .into(), + TantivyBoolQuery { + should: vec![term("def")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .simplify(); + assert_eq!( + bool_query, + TantivyBoolQuery { + must: vec![term("abc"), term("def"),], + ..Default::default() + } + .into() + ); + + let bool_query = TantivyBoolQuery { + must_not: vec![ + TantivyBoolQuery { + should: vec![term("abc")], + ..Default::default() + } + .into(), + TantivyBoolQuery { + must: vec![term("def")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .simplify(); + assert_eq!( + bool_query, + TantivyBoolQuery { + must: vec![MatchAllOrNone::MatchAll.into()], + must_not: vec![term("abc"), term("def"),], + ..Default::default() + } + .into() + ); + + let bool_query = TantivyBoolQuery { + must: vec![ + TantivyBoolQuery { + must_not: vec![term("abc"), term("def")], + ..Default::default() + } + .into(), + TantivyBoolQuery { + must_not: vec![term("ghi")], + ..Default::default() + } + .into(), + ], + ..Default::default() + } + .simplify(); + assert_eq!( + bool_query, + TantivyBoolQuery { + must: vec![MatchAllOrNone::MatchAll.into()], + must_not: vec![term("abc"), term("def"), term("ghi"),], + ..Default::default() + } + .into() + ); + } + + #[derive(Debug, Clone)] + struct ConstQuery(bool, u32); + + impl tantivy::query::Query for ConstQuery { + fn weight( + &self, + _: tantivy::query::EnableScoring<'_>, + ) -> tantivy::Result> { + unimplemented!() + } + } + + impl TantivyQueryAst { + fn evaluate_test(&self) -> Option { + match self { + TantivyQueryAst::ConstPredicate(MatchAllOrNone::MatchNone) => None, + TantivyQueryAst::ConstPredicate(MatchAllOrNone::MatchAll) => Some(0), + TantivyQueryAst::Bool(bool_query) => bool_query.evaluate_test(), + TantivyQueryAst::Leaf(query) => { + let const_query = query + .downcast_ref::() + .expect("query wasn't a ConstQuery"); + const_query.0.then_some(const_query.1) + } + } + } + } + + impl TantivyBoolQuery { + fn evaluate_test(&self) -> Option { + if self + .must_not + .iter() + .any(|sub_ast| sub_ast.evaluate_test().is_some()) + { + return None; + } + let should_score: u32 = self + .should + .iter() + .filter_map(|should| should.evaluate_test()) + .sum(); + if self.must.len() + self.filter.len() > 0 { + if self + .must + .iter() + .all(|sub_ast| sub_ast.evaluate_test().is_some()) + && self + .filter + .iter() + .all(|sub_ast| sub_ast.evaluate_test().is_some()) + { + Some( + self.must + .iter() + .map(|sub_ast| sub_ast.evaluate_test().unwrap()) + .sum::() + + should_score, + ) + } else { + None + } + } else { + if self.should.is_empty() { + // by convention, an empty query returns all match. + return Some(0); + } + self.should + .iter() + .any(|sub_ast| sub_ast.evaluate_test().is_some()) + .then_some(should_score) + } + } + } + + fn ast_strategy() -> impl Strategy { + let ast_leaf = proptest::prop_oneof![ + Just(TantivyQueryAst::ConstPredicate(MatchAllOrNone::MatchNone)), + Just(TantivyQueryAst::ConstPredicate(MatchAllOrNone::MatchAll)), + (prop::bool::ANY, 0u32..5) + .prop_map(|(matc, score)| TantivyQueryAst::Leaf(Box::new(ConstQuery(matc, score)))), + ]; + + ast_leaf.prop_recursive(4, 32, 16, |element| { + let must = proptest::collection::vec(element.clone(), 0..4); + let filter = proptest::collection::vec(element.clone(), 0..4); + let should = proptest::collection::vec(element.clone(), 0..4); + let must_not = proptest::collection::vec(element.clone(), 0..4); + (must, filter, should, must_not).prop_map(|(must, filter, should, must_not)| { + TantivyQueryAst::Bool(TantivyBoolQuery { + must, + filter, + should, + must_not, + }) + }) + }) + } + + proptest::proptest! { + #![proptest_config(ProptestConfig { + cases: 10000, .. ProptestConfig::default() + })] + #[test] + fn test_proptest_simplify_never_change_result(ast in ast_strategy()) { + let simplified_ast = ast.clone().simplify(); + assert_eq!(dbg!(simplified_ast).evaluate_test(), ast.evaluate_test()); + } + } }