Skip to content

Commit

Permalink
extend FuzzyTermQuery to support json field (#2173)
Browse files Browse the repository at this point in the history
* extend fuzzy search for json field

* comments

* comments

* fmt fix

* comments
  • Loading branch information
PingXia-at authored Sep 11, 2023
1 parent 1932513 commit e4e416a
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 17 deletions.
29 changes: 28 additions & 1 deletion src/query/automaton_weight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use common::BitSet;
use tantivy_fst::Automaton;

use super::phrase_prefix_query::prefix_end;
use crate::core::SegmentReader;
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
Expand All @@ -14,6 +15,10 @@ use crate::{DocId, Score, TantivyError};
pub struct AutomatonWeight<A> {
field: Field,
automaton: Arc<A>,
// For JSON fields, the term dictionary include terms from all paths.
// We apply additional filtering based on the given JSON path, when searching within the term
// dictionary. This prevents terms from unrelated paths from matching the search criteria.
json_path_bytes: Option<Box<[u8]>>,
}

impl<A> AutomatonWeight<A>
Expand All @@ -26,6 +31,20 @@ where
AutomatonWeight {
field,
automaton: automaton.into(),
json_path_bytes: None,
}
}

/// Create a new AutomationWeight for a json path
pub fn new_for_json_path<IntoArcA: Into<Arc<A>>>(
field: Field,
automaton: IntoArcA,
json_path_bytes: &[u8],
) -> AutomatonWeight<A> {
AutomatonWeight {
field,
automaton: automaton.into(),
json_path_bytes: Some(json_path_bytes.to_vec().into_boxed_slice()),
}
}

Expand All @@ -34,7 +53,15 @@ where
term_dict: &'a TermDictionary,
) -> io::Result<TermStreamer<'a, &'a A>> {
let automaton: &A = &self.automaton;
let term_stream_builder = term_dict.search(automaton);
let mut term_stream_builder = term_dict.search(automaton);

if let Some(json_path_bytes) = &self.json_path_bytes {
term_stream_builder = term_stream_builder.ge(json_path_bytes);
if let Some(end) = prefix_end(json_path_bytes) {
term_stream_builder = term_stream_builder.lt(&end);
}
}

term_stream_builder.into_stream()
}
}
Expand Down
126 changes: 117 additions & 9 deletions src/query/fuzzy_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use once_cell::sync::OnceCell;
use tantivy_fst::Automaton;

use crate::query::{AutomatonWeight, EnableScoring, Query, Weight};
use crate::schema::Term;
use crate::schema::{Term, Type};
use crate::TantivyError::InvalidArgument;

pub(crate) struct DfaWrapper(pub DFA);
Expand Down Expand Up @@ -132,18 +132,46 @@ impl FuzzyTermQuery {
});

let term_value = self.term.value();
let term_text = term_value.as_str().ok_or_else(|| {
InvalidArgument("The fuzzy term query requires a string term.".to_string())
})?;

let term_text = if term_value.typ() == Type::Json {
if let Some(json_path_type) = term_value.json_path_type() {
if json_path_type != Type::Str {
return Err(InvalidArgument(format!(
"The fuzzy term query requires a string path type for a json term. Found \
{:?}",
json_path_type
)));
}
}

std::str::from_utf8(self.term.serialized_value_bytes()).map_err(|_| {
InvalidArgument(
"Failed to convert json term value bytes to utf8 string.".to_string(),
)
})?
} else {
term_value.as_str().ok_or_else(|| {
InvalidArgument("The fuzzy term query requires a string term.".to_string())
})?
};
let automaton = if self.prefix {
automaton_builder.build_prefix_dfa(term_text)
} else {
automaton_builder.build_dfa(term_text)
};
Ok(AutomatonWeight::new(
self.term.field(),
DfaWrapper(automaton),
))

if let Some((json_path_bytes, _)) = term_value.as_json() {
Ok(AutomatonWeight::new_for_json_path(
self.term.field(),
DfaWrapper(automaton),
json_path_bytes,
))
} else {
Ok(AutomatonWeight::new(
self.term.field(),
DfaWrapper(automaton),
))
}
}
}

Expand All @@ -157,9 +185,89 @@ impl Query for FuzzyTermQuery {
mod test {
use super::FuzzyTermQuery;
use crate::collector::{Count, TopDocs};
use crate::schema::{Schema, TEXT};
use crate::indexer::NoMergePolicy;
use crate::query::QueryParser;
use crate::schema::{Schema, STORED, TEXT};
use crate::{assert_nearly_equals, Index, Term};

#[test]
pub fn test_fuzzy_json_path() -> crate::Result<()> {
// # Defining the schema
let mut schema_builder = Schema::builder();
let attributes = schema_builder.add_json_field("attributes", TEXT | STORED);
let schema = schema_builder.build();

// # Indexing documents
let index = Index::create_in_ram(schema.clone());

let mut index_writer = index.writer_for_tests()?;
index_writer.set_merge_policy(Box::new(NoMergePolicy));
let doc = schema.parse_document(
r#"{
"attributes": {
"a": "japan"
}
}"#,
)?;
index_writer.add_document(doc)?;
let doc = schema.parse_document(
r#"{
"attributes": {
"aa": "japan"
}
}"#,
)?;
index_writer.add_document(doc)?;
index_writer.commit()?;

let reader = index.reader()?;
let searcher = reader.searcher();

// # Fuzzy search
let query_parser = QueryParser::for_index(&index, vec![attributes]);

let get_json_path_term = |query: &str| -> crate::Result<Term> {
let query = query_parser.parse_query(query)?;
let mut terms = Vec::new();
query.query_terms(&mut |term, _| {
terms.push(term.clone());
});

Ok(terms[0].clone())
};

// shall not match the first document due to json path mismatch
{
let term = get_json_path_term("attributes.aa:japan")?;
let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document");
}

// shall match the first document because Levenshtein distance is 1 (substitute 'o' with
// 'a')
{
let term = get_json_path_term("attributes.a:japon")?;

let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document");
}

// shall not match because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')
{
let term = get_json_path_term("attributes.a:jap")?;

let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
assert_eq!(top_docs.len(), 0, "Expected no document");
}

Ok(())
}

#[test]
pub fn test_fuzzy_term() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
Expand Down
2 changes: 1 addition & 1 deletion src/query/phrase_prefix_query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use phrase_prefix_query::PhrasePrefixQuery;
pub use phrase_prefix_scorer::PhrasePrefixScorer;
pub use phrase_prefix_weight::PhrasePrefixWeight;

fn prefix_end(prefix_start: &[u8]) -> Option<Vec<u8>> {
pub(crate) fn prefix_end(prefix_start: &[u8]) -> Option<Vec<u8>> {
let mut res = prefix_start.to_owned();
while !res.is_empty() {
let end = res.len() - 1;
Expand Down
24 changes: 18 additions & 6 deletions src/schema/term.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,20 +397,29 @@ where B: AsRef<[u8]>
Some(Ipv6Addr::from_u128(ip_u128))
}

/// Returns the json path (without non-human friendly separators),
/// Returns the json path type.
///
/// Returns `None` if the value is not JSON.
pub fn json_path_type(&self) -> Option<Type> {
let json_value_bytes = self.as_json_value_bytes()?;

Some(json_value_bytes.typ())
}

/// Returns the json path bytes (including the JSON_END_OF_PATH byte),
/// and the encoded ValueBytes after the json path.
///
/// Returns `None` if the value is not JSON.
pub(crate) fn as_json(&self) -> Option<(&str, ValueBytes<&[u8]>)> {
pub(crate) fn as_json(&self) -> Option<(&[u8], ValueBytes<&[u8]>)> {
if self.typ() != Type::Json {
return None;
}
let bytes = self.value_bytes();

let pos = bytes.iter().cloned().position(|b| b == JSON_END_OF_PATH)?;
let (json_path_bytes, term) = bytes.split_at(pos);
let json_path = str::from_utf8(json_path_bytes).ok()?;
Some((json_path, ValueBytes::wrap(&term[1..])))
// split at pos + 1, so that json_path_bytes includes the JSON_END_OF_PATH byte.
let (json_path_bytes, term) = bytes.split_at(pos + 1);
Some((json_path_bytes, ValueBytes::wrap(&term)))
}

/// Returns the encoded ValueBytes after the json path.
Expand Down Expand Up @@ -469,7 +478,10 @@ where B: AsRef<[u8]>
write_opt(f, self.as_bytes())?;
}
Type::Json => {
if let Some((path, sub_value_bytes)) = self.as_json() {
if let Some((path_bytes, sub_value_bytes)) = self.as_json() {
// Remove the JSON_END_OF_PATH byte & convert to utf8.
let path = str::from_utf8(&path_bytes[..path_bytes.len() - 1])
.map_err(|_| std::fmt::Error)?;
let path_pretty = path.replace(JSON_PATH_SEGMENT_SEP_STR, ".");
write!(f, "path={path_pretty}, ")?;
sub_value_bytes.debug_value_bytes(f)?;
Expand Down

0 comments on commit e4e416a

Please sign in to comment.