feat: add support for u64,i64,f64 fields in term aggregation
PSeitz committed Feb 20, 2023
1 parent 02bebf4 commit a04cbbc
Showing 5 changed files with 208 additions and 65 deletions.
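
For context on what the feature enables: a terms aggregation can now bucket by the distinct values of a numeric fast field, not just by string terms. Below is a minimal sketch of such a request, built with the agg_req types that the new tests in this commit use; the field name "score_f64" is borrowed from those test fixtures, and execution (the tests use an exec_request helper) is omitted.

    use tantivy::aggregation::agg_req::{
        Aggregation, Aggregations, BucketAggregation, BucketAggregationType, TermsAggregation,
    };

    // One bucket per distinct f64 value of the "score_f64" fast field.
    let agg_req: Aggregations = vec![(
        "my_scores".to_string(),
        Aggregation::Bucket(BucketAggregation {
            bucket_agg: BucketAggregationType::Terms(TermsAggregation {
                field: "score_f64".to_string(),
                ..Default::default()
            }),
            sub_aggregation: Default::default(),
        }),
    )]
    .into_iter()
    .collect();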
6 changes: 2 additions & 4 deletions examples/fuzzy_search.rs
@@ -10,14 +10,12 @@
 // - search for the best document matching a basic query
 // - retrieve the best document's original content.
 
-use std::collections::HashSet;
-
 // ---
 // Importing tantivy...
 use tantivy::collector::{Count, TopDocs};
-use tantivy::query::{FuzzyTermQuery, QueryParser};
+use tantivy::query::FuzzyTermQuery;
 use tantivy::schema::*;
-use tantivy::{doc, DocId, Index, ReloadPolicy, Score, SegmentReader};
+use tantivy::{doc, Index, ReloadPolicy};
 use tempfile::TempDir;
 
 fn main() -> tantivy::Result<()> {
209 changes: 165 additions & 44 deletions src/aggregation/bucket/term_agg.rs
@@ -15,8 +15,9 @@ use crate::aggregation::intermediate_agg_result::{
 use crate::aggregation::segment_agg_result::{
     build_segment_agg_collector, SegmentAggregationCollector,
 };
-use crate::aggregation::VecWithNames;
+use crate::aggregation::{f64_from_fastfield_u64, Key, VecWithNames};
 use crate::error::DataCorruption;
+use crate::schema::Type;
 use crate::TantivyError;
 
 /// Creates a bucket for every unique term and counts the number of occurrences.
@@ -25,6 +26,10 @@ use crate::TantivyError;
 /// If the text is untokenized and single value, that means one term per document and therefore it
 /// is in fact doc count.
 ///
+/// ## Prerequisite
+/// Term aggregations work only on [fast fields](`crate::fastfield`) of type `u64`, `f64`, `i64` and
+/// text.
+///
 /// ### Terminology
 /// Shard parameters are supposed to be equivalent to elasticsearch shard parameter.
 /// Since they are
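
The prerequisite above implies the schema side: a numeric field must be declared as a fast field before it can drive a terms aggregation. A minimal sketch using tantivy's schema builder (field names are illustrative):

    use tantivy::schema::{Schema, FAST};

    // FAST stores a per-document column of values, which is what the
    // aggregation reads at collection time.
    let mut schema_builder = Schema::builder();
    schema_builder.add_u64_field("rank", FAST);
    schema_builder.add_i64_field("offset", FAST);
    schema_builder.add_f64_field("score", FAST);
    let schema = schema_builder.build();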
@@ -199,9 +204,9 @@ impl TermsAggregationInternal {
 }
 
 #[derive(Clone, Debug, Default)]
-/// Container to store term_ids and their buckets.
+/// Container to store term_ids or u64 values and their buckets.
 struct TermBuckets {
-    pub(crate) entries: FxHashMap<u32, TermBucketEntry>,
+    pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
 }
 
 #[derive(Clone, Default)]
@@ -262,6 +267,7 @@ pub struct SegmentTermCollector {
     term_buckets: TermBuckets,
     req: TermsAggregationInternal,
     blueprint: Option<Box<dyn SegmentAggregationCollector>>,
+    field_type: Type,
     accessor_idx: usize,
 }
 
@@ -310,7 +316,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
             let entry = self
                 .term_buckets
                 .entries
-                .entry(term_id as u32)
+                .entry(term_id)
                 .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
             entry.doc_count += 1;
             if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
@@ -323,7 +329,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
             let entry = self
                 .term_buckets
                 .entries
-                .entry(term_id as u32)
+                .entry(term_id)
                 .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
             entry.doc_count += 1;
             if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
@@ -348,6 +354,7 @@ impl SegmentTermCollector {
     pub(crate) fn from_req_and_validate(
         req: &TermsAggregation,
         sub_aggregations: &AggregationsWithAccessor,
+        field_type: Type,
         accessor_idx: usize,
     ) -> crate::Result<Self> {
         let term_buckets = TermBuckets::default();
@@ -378,6 +385,7 @@
             req: TermsAggregationInternal::from_req(req),
             term_buckets,
             blueprint,
+            field_type,
             accessor_idx,
         })
     }
@@ -386,7 +394,7 @@
         self,
         agg_with_accessor: &BucketAggregationWithAccessor,
     ) -> crate::Result<IntermediateBucketResult> {
-        let mut entries: Vec<(u32, TermBucketEntry)> =
+        let mut entries: Vec<(u64, TermBucketEntry)> =
             self.term_buckets.entries.into_iter().collect();
 
         let order_by_sub_aggregation =
@@ -423,41 +431,52 @@
             cut_off_buckets(&mut entries, self.req.segment_size as usize)
         };
 
-        let inverted_index = agg_with_accessor
-            .str_dict_column
-            .as_ref()
-            .expect("internal error: inverted index not loaded for term aggregation");
-        let term_dict = inverted_index;
-
-        let mut dict: FxHashMap<String, IntermediateTermBucketEntry> = Default::default();
-        let mut buffer = String::new();
-        for (term_id, entry) in entries {
-            if !term_dict.ord_to_str(term_id as u64, &mut buffer)? {
-                return Err(TantivyError::InternalError(format!(
-                    "Couldn't find term_id {} in dict",
-                    term_id
-                )));
-            }
-            dict.insert(
-                buffer.to_string(),
-                entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
-            );
-        }
-        if self.req.min_doc_count == 0 {
-            // TODO: Handle rev streaming for descending sorting by keys
-            let mut stream = term_dict.dictionary().stream()?;
-            while let Some((key, _ord)) = stream.next() {
-                if dict.len() >= self.req.segment_size as usize {
-                    break;
-                }
-
-                let key = std::str::from_utf8(key)
-                    .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?;
-                if !dict.contains_key(key) {
-                    dict.insert(key.to_owned(), Default::default());
-                }
-            }
-        }
+        let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
+        dict.reserve(entries.len());
+        if self.field_type == Type::Str {
+            let term_dict = agg_with_accessor
+                .str_dict_column
+                .as_ref()
+                .expect("internal error: term dictionary not found for term aggregation");
+
+            let mut buffer = String::new();
+            for (term_id, entry) in entries {
+                if !term_dict.ord_to_str(term_id, &mut buffer)? {
+                    return Err(TantivyError::InternalError(format!(
+                        "Couldn't find term_id {} in dict",
+                        term_id
+                    )));
+                }
+                dict.insert(
+                    Key::Str(buffer.to_string()),
+                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
+                );
+            }
+            if self.req.min_doc_count == 0 {
+                // TODO: Handle rev streaming for descending sorting by keys
+                let mut stream = term_dict.dictionary().stream()?;
+                while let Some((key, _ord)) = stream.next() {
+                    if dict.len() >= self.req.segment_size as usize {
+                        break;
+                    }
+
+                    let key = Key::Str(
+                        std::str::from_utf8(key)
+                            .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?
+                            .to_string(),
+                    );
+                    dict.entry(key).or_insert_with(Default::default);
+                }
+            }
+        } else {
+            for (val, entry) in entries {
+                let val = f64_from_fastfield_u64(val, &self.field_type);
+                dict.insert(
+                    Key::F64(val),
+                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
+                );
+            }
+        };
 
         Ok(IntermediateBucketResult::Terms(
             IntermediateTermBucketResult {
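
The else branch above relies on fast fields storing every numeric type as u64 under an order-preserving encoding, which is why the collected u64 bucket keys can be mapped back to f64 with f64_from_fastfield_u64. The following is a sketch of the standard encoding trick, for illustration only; tantivy's actual conversions live behind f64_from_fastfield_u64 and its helpers:

    // Order-preserving mappings between i64/f64 and u64 (illustrative,
    // not tantivy's exact implementation).
    fn i64_to_u64(val: i64) -> u64 {
        // Flipping the sign bit makes unsigned order match signed order.
        (val as u64) ^ (1u64 << 63)
    }

    fn f64_to_u64(val: f64) -> u64 {
        let bits = val.to_bits();
        if val.is_sign_positive() {
            bits ^ (1u64 << 63) // positives: set the sign bit
        } else {
            !bits // negatives: invert all bits to reverse their order
        }
    }

    fn u64_to_f64(val: u64) -> f64 {
        f64::from_bits(if val & (1u64 << 63) != 0 {
            val ^ (1u64 << 63) // was a positive float
        } else {
            !val // was a negative float
        })
    }

    fn main() {
        assert!(i64_to_u64(-2) < i64_to_u64(3));
        assert!(f64_to_u64(-1.5) < f64_to_u64(2.5));
        assert_eq!(u64_to_f64(f64_to_u64(-1.5)), -1.5);
    }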
@@ -477,6 +496,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
         self.1.doc_count
     }
 }
+impl GetDocCount for (u64, TermBucketEntry) {
+    fn doc_count(&self) -> u64 {
+        self.1.doc_count
+    }
+}
 impl GetDocCount for (String, IntermediateTermBucketEntry) {
     fn doc_count(&self) -> u64 {
         self.1.doc_count
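
The new impl above gives the (u64, TermBucketEntry) tuples produced by numeric fields the same doc-count accessor the string path already had, so the generic truncation code can stay type-agnostic. A sketch of the pattern (the trait shape matches the diff; the cutoff helper is illustrative, not tantivy's cut_off_buckets):

    use std::cmp::Reverse;

    trait GetDocCount {
        fn doc_count(&self) -> u64;
    }

    // Each (key, entry) tuple shape gets a tiny adapter impl.
    impl GetDocCount for (u64, u64) {
        fn doc_count(&self) -> u64 {
            self.1
        }
    }

    // Keep only the `len` buckets with the highest doc counts.
    fn cut_off<T: GetDocCount>(buckets: &mut Vec<T>, len: usize) {
        buckets.sort_by_key(|bucket| Reverse(bucket.doc_count()));
        buckets.truncate(len);
    }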
@@ -627,7 +651,8 @@ mod tests {
     fn terms_aggregation_test_order_count_merge_segment(merge_segments: bool) -> crate::Result<()> {
         let segment_and_terms = vec![
             vec![(5.0, "terma".to_string())],
-            vec![(4.0, "termb".to_string())],
+            vec![(2.0, "termb".to_string())],
+            vec![(2.0, "terma".to_string())],
             vec![(1.0, "termc".to_string())],
             vec![(1.0, "termc".to_string())],
             vec![(1.0, "termc".to_string())],
@@ -668,7 +693,7 @@
                 }),
                 ..Default::default()
             }),
-            sub_aggregation: sub_agg,
+            sub_aggregation: sub_agg.clone(),
         }),
     )]
     .into_iter()
@@ -677,18 +702,114 @@
         let res = exec_request(agg_req, &index)?;
         assert_eq!(res["my_texts"]["buckets"][0]["key"], "termb");
         assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);
-        assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 6.0);
+        assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 5.0);
 
         assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc");
         assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 3);
         assert_eq!(res["my_texts"]["buckets"][1]["avg_score"]["value"], 1.0);
 
         assert_eq!(res["my_texts"]["buckets"][2]["key"], "terma");
-        assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 5);
-        assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 5.0);
+        assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 6);
+        assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 4.5);
 
         assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
 
+        // Agg on non string
+        //
+        let agg_req: Aggregations = vec![
+            (
+                "my_scores1".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg.clone(),
+                }),
+            ),
+            (
+                "my_scores2".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score_f64".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg.clone(),
+                }),
+            ),
+            (
+                "my_scores3".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score_i64".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg,
+                }),
+            ),
+        ]
+        .into_iter()
+        .collect();
+
+        let res = exec_request(agg_req, &index)?;
+        assert_eq!(res["my_scores1"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores1"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores1"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores1"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores1"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores1"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores1"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][3]["key"], 5.0);
+        assert_eq!(res["my_scores1"]["buckets"][3]["doc_count"], 5);
+        assert_eq!(res["my_scores1"]["buckets"][3]["avg_score"]["value"], 5.0);
+
+        assert_eq!(res["my_scores1"]["sum_other_doc_count"], 0);
+
+        assert_eq!(res["my_scores2"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores2"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores2"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores2"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores2"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores2"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores2"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores2"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores2"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores2"]["sum_other_doc_count"], 0);
+
+        assert_eq!(res["my_scores3"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores3"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores3"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores3"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores3"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores3"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores3"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores3"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores3"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores3"]["sum_other_doc_count"], 0);
+
         Ok(())
     }
 
@@ -1088,9 +1209,9 @@ mod tests {
 
         assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
        assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 4);
-        assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc");
+        assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
         assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 0);
-        assert_eq!(res["my_texts"]["buckets"][2]["key"], "termb");
+        assert_eq!(res["my_texts"]["buckets"][2]["key"], "termc");
         assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 0);
         assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
         assert_eq!(res["my_texts"]["doc_count_error_upper_bound"], 0);