From a04cbbc046861cd3a27660346cf25ad8ade1afd2 Mon Sep 17 00:00:00 2001
From: Pascal Seitz
Date: Mon, 20 Feb 2023 14:16:30 +0800
Subject: [PATCH] feat: add support for u64,i64,f64 fields in term aggregation

---
 examples/fuzzy_search.rs                   |   6 +-
 src/aggregation/bucket/term_agg.rs         | 209 ++++++++++++++++-----
 src/aggregation/intermediate_agg_result.rs |  25 ++-
 src/aggregation/mod.rs                     |  20 +-
 src/aggregation/segment_agg_result.rs      |  13 +-
 5 files changed, 208 insertions(+), 65 deletions(-)

diff --git a/examples/fuzzy_search.rs b/examples/fuzzy_search.rs
index 541656d125..d64e4a11fc 100644
--- a/examples/fuzzy_search.rs
+++ b/examples/fuzzy_search.rs
@@ -10,14 +10,12 @@
 // - search for the best document matching a basic query
 // - retrieve the best document's original content.
 
-use std::collections::HashSet;
-
 // ---
 // Importing tantivy...
 use tantivy::collector::{Count, TopDocs};
-use tantivy::query::{FuzzyTermQuery, QueryParser};
+use tantivy::query::FuzzyTermQuery;
 use tantivy::schema::*;
-use tantivy::{doc, DocId, Index, ReloadPolicy, Score, SegmentReader};
+use tantivy::{doc, Index, ReloadPolicy};
 use tempfile::TempDir;
 
 fn main() -> tantivy::Result<()> {
diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs
index e8130fdd3f..673a536b62 100644
--- a/src/aggregation/bucket/term_agg.rs
+++ b/src/aggregation/bucket/term_agg.rs
@@ -15,8 +15,9 @@ use crate::aggregation::intermediate_agg_result::{
 use crate::aggregation::segment_agg_result::{
     build_segment_agg_collector, SegmentAggregationCollector,
 };
-use crate::aggregation::VecWithNames;
+use crate::aggregation::{f64_from_fastfield_u64, Key, VecWithNames};
 use crate::error::DataCorruption;
+use crate::schema::Type;
 use crate::TantivyError;
 
 /// Creates a bucket for every unique term and counts the number of occurrences.
@@ -25,6 +26,10 @@ use crate::TantivyError;
 /// If the text is untokenized and single value, that means one term per document and therefore it
 /// is in fact doc count.
 ///
+/// ## Prerequisite
+/// Term aggregations work only on [fast fields](`crate::fastfield`) of type `u64`, `f64`, `i64` and
+/// text.
+///
 /// ### Terminology
 /// Shard parameters are supposed to be equivalent to elasticsearch shard parameter.
 /// Since they are
@@ -199,9 +204,9 @@ impl TermsAggregationInternal {
 }
 
 #[derive(Clone, Debug, Default)]
-/// Container to store term_ids and their buckets.
+/// Container to store term_ids or u64 values and their buckets.
 struct TermBuckets {
-    pub(crate) entries: FxHashMap<u32, TermBucketEntry>,
+    pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
 }
 
 #[derive(Clone, Default)]
@@ -262,6 +267,7 @@ pub struct SegmentTermCollector {
     term_buckets: TermBuckets,
     req: TermsAggregationInternal,
     blueprint: Option<Box<dyn SegmentAggregationCollector>>,
+    field_type: Type,
     accessor_idx: usize,
 }
 
@@ -310,7 +316,7 @@
             let entry = self
                 .term_buckets
                 .entries
-                .entry(term_id as u32)
+                .entry(term_id)
                 .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
             entry.doc_count += 1;
             if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
@@ -323,7 +329,7 @@
             let entry = self
                 .term_buckets
                 .entries
-                .entry(term_id as u32)
+                .entry(term_id)
                 .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
             entry.doc_count += 1;
             if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
@@ -348,6 +354,7 @@ impl SegmentTermCollector {
     pub(crate) fn from_req_and_validate(
         req: &TermsAggregation,
         sub_aggregations: &AggregationsWithAccessor,
+        field_type: Type,
         accessor_idx: usize,
     ) -> crate::Result<Self> {
         let term_buckets = TermBuckets::default();
@@ -378,6 +385,7 @@ impl SegmentTermCollector {
             req: TermsAggregationInternal::from_req(req),
             term_buckets,
             blueprint,
+            field_type,
             accessor_idx,
         })
     }
@@ -386,7 +394,7 @@ impl SegmentTermCollector {
         self,
         agg_with_accessor: &BucketAggregationWithAccessor,
     ) -> crate::Result<IntermediateBucketResult> {
-        let mut entries: Vec<(u32, TermBucketEntry)> =
+        let mut entries: Vec<(u64, TermBucketEntry)> =
             self.term_buckets.entries.into_iter().collect();
 
         let order_by_sub_aggregation =
@@ -423,41 +431,52 @@
             cut_off_buckets(&mut entries, self.req.segment_size as usize)
         };
 
-        let inverted_index = agg_with_accessor
-            .str_dict_column
-            .as_ref()
-            .expect("internal error: inverted index not loaded for term aggregation");
-        let term_dict = inverted_index;
-
-        let mut dict: FxHashMap<String, IntermediateTermBucketEntry> = Default::default();
-        let mut buffer = String::new();
-        for (term_id, entry) in entries {
-            if !term_dict.ord_to_str(term_id as u64, &mut buffer)? {
-                return Err(TantivyError::InternalError(format!(
-                    "Couldn't find term_id {} in dict",
-                    term_id
-                )));
-            }
-            dict.insert(
-                buffer.to_string(),
-                entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
-            );
-        }
-        if self.req.min_doc_count == 0 {
-            // TODO: Handle rev streaming for descending sorting by keys
-            let mut stream = term_dict.dictionary().stream()?;
-            while let Some((key, _ord)) = stream.next() {
-                if dict.len() >= self.req.segment_size as usize {
-                    break;
+        let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
+        dict.reserve(entries.len());
+        if self.field_type == Type::Str {
+            let term_dict = agg_with_accessor
+                .str_dict_column
+                .as_ref()
+                .expect("internal error: term dictionary not found for term aggregation");
+
+            let mut buffer = String::new();
+            for (term_id, entry) in entries {
+                if !term_dict.ord_to_str(term_id, &mut buffer)? {
+                    return Err(TantivyError::InternalError(format!(
+                        "Couldn't find term_id {} in dict",
+                        term_id
+                    )));
                 }
+                dict.insert(
+                    Key::Str(buffer.to_string()),
+                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
+                );
+            }
+            if self.req.min_doc_count == 0 {
+                // TODO: Handle rev streaming for descending sorting by keys
+                let mut stream = term_dict.dictionary().stream()?;
+                while let Some((key, _ord)) = stream.next() {
+                    if dict.len() >= self.req.segment_size as usize {
+                        break;
+                    }
 
-                let key = std::str::from_utf8(key)
-                    .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?;
-                if !dict.contains_key(key) {
-                    dict.insert(key.to_owned(), Default::default());
+                    let key = Key::Str(
+                        std::str::from_utf8(key)
+                            .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?
+                            .to_string(),
+                    );
+                    dict.entry(key).or_insert_with(Default::default);
                 }
             }
-        }
+        } else {
+            for (val, entry) in entries {
+                let val = f64_from_fastfield_u64(val, &self.field_type);
+                dict.insert(
+                    Key::F64(val),
+                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
+                );
+            }
+        };
 
         Ok(IntermediateBucketResult::Terms(
             IntermediateTermBucketResult {
@@ -477,6 +496,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
         self.1.doc_count
     }
 }
+impl GetDocCount for (u64, TermBucketEntry) {
+    fn doc_count(&self) -> u64 {
+        self.1.doc_count
+    }
+}
 impl GetDocCount for (String, IntermediateTermBucketEntry) {
     fn doc_count(&self) -> u64 {
         self.1.doc_count
@@ -627,7 +651,8 @@ mod tests {
     fn terms_aggregation_test_order_count_merge_segment(merge_segments: bool) -> crate::Result<()> {
         let segment_and_terms = vec![
             vec![(5.0, "terma".to_string())],
-            vec![(4.0, "termb".to_string())],
+            vec![(2.0, "termb".to_string())],
+            vec![(2.0, "terma".to_string())],
             vec![(1.0, "termc".to_string())],
             vec![(1.0, "termc".to_string())],
             vec![(1.0, "termc".to_string())],
@@ -668,7 +693,7 @@
                     }),
                     ..Default::default()
                 }),
-                sub_aggregation: sub_agg,
+                sub_aggregation: sub_agg.clone(),
             }),
         )]
         .into_iter()
@@ -677,18 +702,114 @@
         .collect();
 
         let res = exec_request(agg_req, &index)?;
         assert_eq!(res["my_texts"]["buckets"][0]["key"], "termb");
         assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);
-        assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 6.0);
+        assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 5.0);
         assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc");
         assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 3);
         assert_eq!(res["my_texts"]["buckets"][1]["avg_score"]["value"], 1.0);
         assert_eq!(res["my_texts"]["buckets"][2]["key"], "terma");
-        assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 5);
-        assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 5.0);
+        assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 6);
+        assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 4.5);
         assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
 
+        // Aggregation on non-string (numeric) fast fields.
+        //
+        let agg_req: Aggregations = vec![
+            (
+                "my_scores1".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg.clone(),
+                }),
+            ),
+            (
+                "my_scores2".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score_f64".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg.clone(),
+                }),
+            ),
+            (
+                "my_scores3".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score_i64".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg,
+                }),
+            ),
+        ]
+        .into_iter()
+        .collect();
+
+        let res = exec_request(agg_req, &index)?;
+        assert_eq!(res["my_scores1"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores1"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores1"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores1"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores1"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores1"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores1"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][3]["key"], 5.0);
+        assert_eq!(res["my_scores1"]["buckets"][3]["doc_count"], 5);
+        assert_eq!(res["my_scores1"]["buckets"][3]["avg_score"]["value"], 5.0);
+
+        assert_eq!(res["my_scores1"]["sum_other_doc_count"], 0);
+
+        assert_eq!(res["my_scores2"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores2"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores2"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores2"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores2"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores2"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores2"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores2"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores2"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores2"]["sum_other_doc_count"], 0);
+
+        assert_eq!(res["my_scores3"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores3"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores3"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores3"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores3"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores3"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores3"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores3"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores3"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores3"]["sum_other_doc_count"], 0);
+
         Ok(())
     }
@@ -1088,9 +1209,9 @@
 
         assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
         assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 4);
-        assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc");
+        assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
         assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 0);
-        assert_eq!(res["my_texts"]["buckets"][2]["key"], "termb");
+        assert_eq!(res["my_texts"]["buckets"][2]["key"], "termc");
         assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 0);
         assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
         assert_eq!(res["my_texts"]["doc_count_error_upper_bound"], 0);
diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs
index b972bde2da..be1a77b64f 100644
--- a/src/aggregation/intermediate_agg_result.rs
+++ b/src/aggregation/intermediate_agg_result.rs
@@ -377,7 +377,7 @@ impl IntermediateBucketResult {
                 IntermediateBucketResult::Terms(term_res_left),
                 IntermediateBucketResult::Terms(term_res_right),
             ) => {
-                merge_maps(&mut term_res_left.entries, term_res_right.entries);
+                merge_key_maps(&mut term_res_left.entries, term_res_right.entries);
                 term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count;
                 term_res_left.doc_count_error_upper_bound +=
                     term_res_right.doc_count_error_upper_bound;
@@ -387,7 +387,7 @@ impl IntermediateBucketResult {
                 IntermediateBucketResult::Range(range_res_left),
                 IntermediateBucketResult::Range(range_res_right),
             ) => {
-                merge_maps(&mut range_res_left.buckets, range_res_right.buckets);
+                merge_serialized_key_maps(&mut range_res_left.buckets, range_res_right.buckets);
             }
             (
                 IntermediateBucketResult::Histogram {
@@ -438,7 +438,7 @@
 #[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
 /// Term aggregation including error counts
 pub struct IntermediateTermBucketResult {
-    pub(crate) entries: FxHashMap<String, IntermediateTermBucketEntry>,
+    pub(crate) entries: FxHashMap<Key, IntermediateTermBucketEntry>,
     pub(crate) sum_other_doc_count: u64,
     pub(crate) doc_count_error_upper_bound: u64,
 }
@@ -458,7 +458,7 @@ impl IntermediateTermBucketResult {
             .map(|(key, entry)| {
                 Ok(BucketEntry {
                     key_as_string: None,
-                    key: Key::Str(key),
+                    key,
                     doc_count: entry.doc_count,
                     sub_aggregation: entry
                         .sub_aggregation
@@ -536,7 +536,7 @@
 trait MergeFruits {
     fn merge_fruits(&mut self, other: Self);
 }
 
-fn merge_maps<V: MergeFruits + Clone>(
+fn merge_serialized_key_maps<V: MergeFruits + Clone>(
     entries_left: &mut FxHashMap<SerializedKey, V>,
     mut entries_right: FxHashMap<SerializedKey, V>,
 ) {
@@ -551,6 +551,21 @@
     }
 }
 
+fn merge_key_maps<V: MergeFruits + Clone>(
+    entries_left: &mut FxHashMap<Key, V>,
+    mut entries_right: FxHashMap<Key, V>,
+) {
+    for (name, entry_left) in entries_left.iter_mut() {
+        if let Some(entry_right) = entries_right.remove(name) {
+            entry_left.merge_fruits(entry_right);
+        }
+    }
+
+    for (key, res) in entries_right.into_iter() {
+        entries_left.entry(key).or_insert(res);
+    }
+}
+
 /// This is the histogram entry for a bucket, which contains a key, count, and optionally
 /// sub_aggregations.
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs
index b92c99947c..5a84bbaf50 100644
--- a/src/aggregation/mod.rs
+++ b/src/aggregation/mod.rs
@@ -10,7 +10,7 @@
 //! There are two categories: [Metrics](metric) and [Buckets](bucket).
 //!
 //! ## Prerequisite
-//! Currently aggregations work only on [fast fields](`crate::fastfield`). Single value fast fields
+//! Currently aggregations work only on [fast fields](`crate::fastfield`): fast fields
 //! of type `u64`, `f64`, `i64`, `date` and fast fields on text fields.
 //!
 //! ## Usage
@@ -262,7 +262,7 @@ impl<T> VecWithNames<T> {
 /// The serialized key is used in a `HashMap`.
 pub type SerializedKey = String;
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, PartialOrd)]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd)]
 /// The key to identify a bucket.
 #[serde(untagged)]
 pub enum Key {
@@ -271,6 +271,22 @@ pub enum Key {
     /// `String` key
     Str(String),
     /// `f64` key
     F64(f64),
 }
+impl Eq for Key {}
+impl std::hash::Hash for Key {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        core::mem::discriminant(self).hash(state);
+    }
+}
+
+impl PartialEq for Key {
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Str(l), Self::Str(r)) => l == r,
+            (Self::F64(l), Self::F64(r)) => l == r,
+            _ => false,
+        }
+    }
+}
 
 impl Display for Key {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs
index 3e6bf6b3cb..3d568c1ccb 100644
--- a/src/aggregation/segment_agg_result.rs
+++ b/src/aggregation/segment_agg_result.rs
@@ -150,6 +150,7 @@
             SegmentTermCollector::from_req_and_validate(
                 terms_req,
                 &req.sub_aggregation,
+                req.field_type,
                 accessor_idx,
             )?,
         )),
@@ -279,11 +280,7 @@ impl GenericSegmentAggregationResultsCollector {
             .iter()
             .enumerate()
             .map(|(accessor_idx, (_key, req))| {
-                Ok(build_bucket_segment_agg_collector(
-                    req,
-                    accessor_idx,
-                    false,
-                )?)
+                build_bucket_segment_agg_collector(req, accessor_idx, false)
             })
             .collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
         let metrics = req
@@ -291,11 +288,7 @@ impl GenericSegmentAggregationResultsCollector {
             .iter()
             .enumerate()
             .map(|(accessor_idx, (_key, req))| {
-                Ok(build_metric_segment_agg_collector(
-                    req,
-                    accessor_idx,
-                    false,
-                )?)
+                build_metric_segment_agg_collector(req, accessor_idx, false)
             })
             .collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
         let metrics = if metrics.is_empty() {
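
The sketch below mirrors the request shape used in the new tests, outside the test module, to show what this patch enables: a terms aggregation over a numeric fast field. The field name `score_f64` and the wrapper function are illustrative only; the tests drive the request through a test-internal `exec_request` helper, while application code would hand it to tantivy's aggregation collector. Import paths follow the modules touched by this patch, though exact re-exports can vary between versions.

    use tantivy::aggregation::agg_req::{
        Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
    };
    use tantivy::aggregation::bucket::{CustomOrder, Order, OrderTarget, TermsAggregation};

    // Terms aggregation over a numeric fast field ("score_f64" is a placeholder
    // name). Before this patch `field` had to be a text fast field; with it,
    // u64/i64/f64 fast fields work as well.
    fn numeric_terms_request() -> Aggregations {
        std::iter::once((
            "my_scores".to_string(),
            Aggregation::Bucket(BucketAggregation {
                bucket_agg: BucketAggregationType::Terms(TermsAggregation {
                    field: "score_f64".to_string(),
                    // Smallest doc_count first, as asserted in the tests above.
                    order: Some(CustomOrder {
                        order: Order::Asc,
                        target: OrderTarget::Count,
                    }),
                    ..Default::default()
                }),
                sub_aggregation: Default::default(),
            }),
        ))
        .collect()
    }

Buckets produced this way carry `Key::F64` keys, which is why the new assertions compare `res["my_scores1"]["buckets"][0]["key"]` against `8.0` rather than against a string.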
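
Why the segment-level map can stay `FxHashMap<u64, TermBucketEntry>` for every field type: fast fields store `i64`/`f64` values as order-preserving `u64`s, so during collection the raw `u64` serves as the bucket key, and only when the intermediate result is built does `f64_from_fastfield_u64` convert it back (the `Type::Str` branch resolves term ordinals against the dictionary instead). The snippet below is a self-contained sketch of such an order-preserving `f64` encoding — an illustration of the idea, not tantivy's exact implementation:

    // Order-preserving f64 <-> u64 mapping (illustrative sketch only).
    fn f64_to_u64(val: f64) -> u64 {
        let bits = val.to_bits();
        if bits >> 63 == 0 {
            // Positive values: set the sign bit so they sort above negatives.
            bits ^ (1u64 << 63)
        } else {
            // Negative values: flip all bits so "more negative" sorts lower.
            !bits
        }
    }

    fn u64_to_f64(val: u64) -> f64 {
        f64::from_bits(if val >> 63 == 1 {
            val ^ (1u64 << 63)
        } else {
            !val
        })
    }

    fn main() {
        for v in [-2.5f64, -0.0, 0.0, 1.0, 8.0] {
            assert_eq!(u64_to_f64(f64_to_u64(v)), v); // round-trips
        }
        assert!(f64_to_u64(-2.5) < f64_to_u64(1.0)); // preserves order
        assert!(f64_to_u64(1.0) < f64_to_u64(8.0));
    }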
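
The manual `Eq`/`Hash` impls on `Key` exist because `f64` provides neither. Hashing only `mem::discriminant` satisfies the `Hash`/`Eq` contract (equal keys must hash equally), though every key of one variant then lands in the same hash bucket and lookups fall back on `Eq` comparisons; the `Eq` impl also quietly assumes a bucket key is never NaN. A stand-alone reproduction of the pattern:

    use std::collections::HashMap;

    // Stand-alone copy of the Key pattern introduced by this patch.
    #[derive(Debug, Clone, PartialEq)]
    enum Key {
        Str(String),
        F64(f64), // assumed never NaN, otherwise Eq's reflexivity breaks
    }
    impl Eq for Key {}
    impl std::hash::Hash for Key {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            // Discriminant-only hashing: correct, if coarse.
            core::mem::discriminant(self).hash(state);
        }
    }

    fn main() {
        let mut doc_counts: HashMap<Key, u64> = HashMap::new();
        *doc_counts.entry(Key::F64(2.0)).or_insert(0) += 1;
        *doc_counts.entry(Key::F64(2.0)).or_insert(0) += 1;
        *doc_counts.entry(Key::Str("terma".into())).or_insert(0) += 1;
        assert_eq!(doc_counts[&Key::F64(2.0)], 2);
    }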