From a80c8dc77b8a6ded3bcfd534242fc5faacba0652 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Tue, 18 Apr 2023 14:47:28 +0800 Subject: [PATCH 1/5] refactor result handling --- src/aggregation/agg_req.rs | 49 +++- src/aggregation/agg_req_deser.rs | 144 +++++++++++ src/aggregation/agg_result.rs | 11 - src/aggregation/agg_tests.rs | 68 +++++ src/aggregation/bucket/histogram/histogram.rs | 25 +- src/aggregation/bucket/range.rs | 26 +- src/aggregation/bucket/term_agg.rs | 47 ++-- src/aggregation/buf_collector.rs | 7 +- src/aggregation/collector.rs | 9 +- src/aggregation/intermediate_agg_result.rs | 241 +++++++----------- src/aggregation/metric/percentiles.rs | 20 +- src/aggregation/metric/stats.rs | 20 +- src/aggregation/mod.rs | 10 +- src/aggregation/segment_agg_result.rs | 38 +-- 14 files changed, 442 insertions(+), 273 deletions(-) create mode 100644 src/aggregation/agg_req_deser.rs diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 8e3974062d..de9c1c842a 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -28,6 +28,7 @@ use std::collections::{HashMap, HashSet}; +use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::bucket::{ @@ -37,7 +38,6 @@ use super::metric::{ AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation, }; -use super::VecWithNames; /// The top-level aggregation request structure, which contains [`Aggregation`] and their user /// defined names. It is also used in [buckets](BucketAggregation) to define sub-aggregations. 
@@ -48,32 +48,57 @@ pub type Aggregations = HashMap; /// Like Aggregations, but optimized to work with the aggregation result #[derive(Clone, Debug)] pub(crate) struct AggregationsInternal { - pub(crate) metrics: VecWithNames, - pub(crate) buckets: VecWithNames, + pub(crate) aggs: FxHashMap, } impl From for AggregationsInternal { fn from(aggs: Aggregations) -> Self { - let mut metrics = vec![]; - let mut buckets = vec![]; + let mut aggs_internal = FxHashMap::default(); for (key, agg) in aggs { match agg { Aggregation::Bucket(bucket) => { let sub_aggregation = bucket.get_sub_aggs().clone().into(); - buckets.push(( + aggs_internal.insert( key, - BucketAggregationInternal { + AggregationInternal::Bucket(Box::new(BucketAggregationInternal { bucket_agg: bucket.bucket_agg, sub_aggregation, - }, - )) + })), + ); + } + Aggregation::Metric(metric) => { + aggs_internal.insert(key, AggregationInternal::Metric(metric)); } - Aggregation::Metric(metric) => metrics.push((key, metric)), } } Self { - metrics: VecWithNames::from_entries(metrics), - buckets: VecWithNames::from_entries(buckets), + aggs: aggs_internal, + } + } +} + +/// Aggregation request of [`BucketAggregation`] or [`MetricAggregation`]. +/// +/// An aggregation is either a bucket or a metric. +#[derive(Clone, Debug)] +pub(crate) enum AggregationInternal { + /// Bucket aggregation, see [`BucketAggregation`] for details. + Bucket(Box), + /// Metric aggregation, see [`MetricAggregation`] for details. 
+ Metric(MetricAggregation), +} + +impl AggregationInternal { + pub fn as_bucket(&self) -> Option<&Box> { + match self { + AggregationInternal::Bucket(bucket) => Some(bucket), + _ => None, + } + } + pub fn as_metric(&self) -> Option<&MetricAggregation> { + match self { + AggregationInternal::Metric(metric) => Some(metric), + _ => None, } } } diff --git a/src/aggregation/agg_req_deser.rs b/src/aggregation/agg_req_deser.rs new file mode 100644 index 0000000000..2d7840b13f --- /dev/null +++ b/src/aggregation/agg_req_deser.rs @@ -0,0 +1,144 @@ +use std::collections::{HashMap, HashSet}; + +use serde::*; + +use super::bucket::*; +use super::metric::*; +pub type Aggregations = HashMap; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct AggregationDeser { + /// Bucket aggregation strategy to group documents. + #[serde(flatten)] + pub agg: AggregationVariants, + /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the + /// bucket. + #[serde(rename = "aggs")] + #[serde(default)] + #[serde(skip_serializing_if = "Aggregations::is_empty")] + pub sub_aggregation: Aggregations, +} + +impl AggregationDeser { + fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { + fast_field_names.insert(self.agg.get_fast_field_name().to_string()); + fast_field_names.extend(get_fast_field_names(&self.sub_aggregation)); + } +} + +/// Extract all fast field names used in the tree. +pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet { + let mut fast_field_names = Default::default(); + for el in aggs.values() { + el.get_fast_field_names(&mut fast_field_names) + } + fast_field_names +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum AggregationVariants { + // Bucket aggregation types + /// Put data into buckets of user-defined ranges. + #[serde(rename = "range")] + Range(RangeAggregation), + /// Put data into a histogram. 
+ #[serde(rename = "histogram")] + Histogram(HistogramAggregation), + /// Put data into a date histogram. + #[serde(rename = "date_histogram")] + DateHistogram(DateHistogramAggregationReq), + /// Put data into buckets of terms. + #[serde(rename = "terms")] + Terms(TermsAggregation), + + // Metric aggregation types + /// Computes the average of the extracted values. + #[serde(rename = "avg")] + Average(AverageAggregation), + /// Counts the number of extracted values. + #[serde(rename = "value_count")] + Count(CountAggregation), + /// Finds the maximum value. + #[serde(rename = "max")] + Max(MaxAggregation), + /// Finds the minimum value. + #[serde(rename = "min")] + Min(MinAggregation), + /// Computes a collection of statistics (`min`, `max`, `sum`, `count`, and `avg`) over the + /// extracted values. + #[serde(rename = "stats")] + Stats(StatsAggregation), + /// Computes the sum of the extracted values. + #[serde(rename = "sum")] + Sum(SumAggregation), + /// Computes percentiles over the extracted values. 
+ #[serde(rename = "percentiles")] + Percentiles(PercentilesAggregationReq), +} + +impl AggregationVariants { + fn get_fast_field_name(&self) -> &str { + match self { + AggregationVariants::Terms(terms) => terms.field.as_str(), + AggregationVariants::Range(range) => range.field.as_str(), + AggregationVariants::Histogram(histogram) => histogram.field.as_str(), + AggregationVariants::DateHistogram(histogram) => histogram.field.as_str(), + AggregationVariants::Average(avg) => avg.field_name(), + AggregationVariants::Count(count) => count.field_name(), + AggregationVariants::Max(max) => max.field_name(), + AggregationVariants::Min(min) => min.field_name(), + AggregationVariants::Stats(stats) => stats.field_name(), + AggregationVariants::Sum(sum) => sum.field_name(), + AggregationVariants::Percentiles(per) => per.field_name(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deser_json_test() { + let agg_req_json = r#"{ + "price_avg": { "avg": { "field": "price" } }, + "price_count": { "value_count": { "field": "price" } }, + "price_max": { "max": { "field": "price" } }, + "price_min": { "min": { "field": "price" } }, + "price_stats": { "stats": { "field": "price" } }, + "price_sum": { "sum": { "field": "price" } } + }"#; + let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); + } + + #[test] + fn deser_json_test_bucket() { + let agg_req_json = r#" + { + "termagg": { + "terms": { + "field": "json.mixed_type", + "order": { "min_price": "desc" } + }, + "aggs": { + "min_price": { "min": { "field": "json.mixed_type" } } + } + }, + "rangeagg": { + "range": { + "field": "json.mixed_type", + "ranges": [ + { "to": 3.0 }, + { "from": 19.0, "to": 20.0 }, + { "from": 20.0 } + ] + }, + "aggs": { + "average_in_range": { "avg": { "field": "json.mixed_type" } } + } + } + } "#; + + let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); + } +} diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 
09068f51d7..bb95858ba3 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -7,11 +7,8 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; -use super::agg_req::BucketAggregationInternal; use super::bucket::GetDocCount; -use super::intermediate_agg_result::IntermediateBucketResult; use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats}; -use super::segment_agg_result::AggregationLimits; use super::{AggregationError, Key}; use crate::TantivyError; @@ -164,14 +161,6 @@ impl BucketResult { } => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(), } } - - pub(crate) fn empty_from_req( - req: &BucketAggregationInternal, - limits: &AggregationLimits, - ) -> crate::Result { - let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg); - empty_bucket.into_final_bucket_result(req, limits) - } } /// This is the wrapper of buckets entries, which can be vector or hashmap diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 2da0ce36d7..3f3b73b4c3 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -195,6 +195,74 @@ fn test_aggregation_flushing_variants() { test_aggregation_flushing(true, true).unwrap(); } +#[test] +fn test_aggregation_level1_simple() -> crate::Result<()> { + let index = get_test_index_2_segments(true)?; + + let reader = index.reader()?; + let text_field = reader.searcher().schema().get_field("text").unwrap(); + + let term_query = TermQuery::new( + Term::from_field_text(text_field, "cool"), + IndexRecordOption::Basic, + ); + + let range_agg = |field_name: &str| -> Aggregation { + serde_json::from_value(json!({ + "range": { + "field": field_name, + "ranges": [ { "from": 3.0f64, "to": 7.0f64 }, { "from": 7.0f64, "to": 20.0f64 } ] + } + })) + .unwrap() + }; + + let agg_req_1: Aggregations = vec![ + ("average".to_string(), get_avg_req("score")), + ("range".to_string(), range_agg("score")), + ] + .into_iter() + .collect(); + + 
let collector = get_collector(agg_req_1); + + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); + + let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + assert_eq!(res["average"]["value"], 12.142857142857142); + assert_eq!( + res["range"]["buckets"], + json!( + [ + { + "key": "*-3", + "doc_count": 1, + "to": 3.0 + }, + { + "key": "3-7", + "doc_count": 2, + "from": 3.0, + "to": 7.0 + }, + { + "key": "7-20", + "doc_count": 3, + "from": 7.0, + "to": 20.0 + }, + { + "key": "20-*", + "doc_count": 1, + "from": 20.0 + } + ]) + ); + + Ok(()) +} + #[test] fn test_aggregation_level1() -> crate::Result<()> { let index = get_test_index_2_segments(true)?; diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 4056bcb4e7..d9906abe6f 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -14,12 +14,13 @@ use crate::aggregation::agg_req_with_accessor::{ }; use crate::aggregation::agg_result::BucketEntry; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, + IntermediateHistogramBucketEntry, }; use crate::aggregation::segment_agg_result::{ build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, }; -use crate::aggregation::{f64_from_fastfield_u64, format_date, VecWithNames}; +use crate::aggregation::{f64_from_fastfield_u64, format_date}; use crate::TantivyError; /// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`. 
@@ -190,11 +191,13 @@ impl SegmentHistogramBucketEntry { sub_aggregation: Box, agg_with_accessor: &AggregationsWithAccessor, ) -> crate::Result { + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + sub_aggregation + .add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?; Ok(IntermediateHistogramBucketEntry { key: self.key, doc_count: self.doc_count, - sub_aggregation: sub_aggregation - .into_intermediate_aggregations_result(agg_with_accessor)?, + sub_aggregation: sub_aggregation_res, }) } } @@ -215,20 +218,18 @@ pub struct SegmentHistogramCollector { } impl SegmentAggregationCollector for SegmentHistogramCollector { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); let agg_with_accessor = &agg_with_accessor.buckets.values[self.accessor_idx]; let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?; - let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)])); + results.push(name, IntermediateAggregationResult::Bucket(bucket)); - Ok(IntermediateAggregationResults { - metrics: None, - buckets, - }) + Ok(()) } #[inline] @@ -695,7 +696,7 @@ mod tests { assert_eq!( res.to_string(), "Aborting aggregation because memory limit was exceeded. 
Limit: 5.00 KB, Current: \ - 102.48 KB" + 59.71 KB" ); Ok(()) diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 591db38314..972bdd1dd0 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -7,14 +7,14 @@ use serde::{Deserialize, Serialize}; use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry, - IntermediateRangeBucketResult, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, + IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; use crate::aggregation::segment_agg_result::{ build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, }; use crate::aggregation::{ - f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, VecWithNames, + f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, }; use crate::TantivyError; @@ -157,8 +157,10 @@ impl SegmentRangeBucketEntry { self, agg_with_accessor: &AggregationsWithAccessor, ) -> crate::Result { - let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregation { - sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor)? + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + if let Some(sub_aggregation) = self.sub_aggregation { + sub_aggregation + .add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)? 
} else { Default::default() }; @@ -166,7 +168,7 @@ impl SegmentRangeBucketEntry { Ok(IntermediateRangeBucketEntry { key: self.key, doc_count: self.doc_count, - sub_aggregation, + sub_aggregation: sub_aggregation_res, from: self.from, to: self.to, }) @@ -174,10 +176,11 @@ impl SegmentRangeBucketEntry { } impl SegmentAggregationCollector for SegmentRangeCollector { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { let field_type = self.column_type; let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); let sub_agg = &agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; @@ -200,12 +203,9 @@ impl SegmentAggregationCollector for SegmentRangeCollector { column_type: Some(self.column_type), }); - let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)])); + results.push(name, IntermediateAggregationResult::Bucket(bucket)); - Ok(IntermediateAggregationResults { - metrics: None, - buckets, - }) + Ok(()) } #[inline] diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 329ea120bf..8646a15d8b 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -10,13 +10,13 @@ use crate::aggregation::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, }; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateBucketResult, IntermediateTermBucketEntry, - IntermediateTermBucketResult, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, + IntermediateTermBucketEntry, IntermediateTermBucketResult, }; use crate::aggregation::segment_agg_result::{ build_segment_agg_collector, SegmentAggregationCollector, }; -use crate::aggregation::{f64_from_fastfield_u64, Key, VecWithNames}; +use 
crate::aggregation::{f64_from_fastfield_u64, Key}; use crate::error::DataCorruption; use crate::TantivyError; @@ -246,20 +246,18 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { } impl SegmentAggregationCollector for SegmentTermCollector { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); let agg_with_accessor = &agg_with_accessor.buckets.values[self.accessor_idx]; let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?; - let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)])); + results.push(name, IntermediateAggregationResult::Bucket(bucket)); - Ok(IntermediateAggregationResults { - metrics: None, - buckets, - }) + Ok(()) } #[inline] @@ -410,21 +408,24 @@ impl SegmentTermCollector { let mut into_intermediate_bucket_entry = |id, doc_count| -> crate::Result { let intermediate_entry = if self.blueprint.as_ref().is_some() { + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + self.term_buckets + .sub_aggs + .remove(&id) + .unwrap_or_else(|| { + panic!( + "Internal Error: could not find subaggregation for id {}", + id + ) + }) + .add_intermediate_aggregation_result( + &agg_with_accessor.sub_aggregation, + &mut sub_aggregation_res, + )?; + IntermediateTermBucketEntry { doc_count, - sub_aggregation: self - .term_buckets - .sub_aggs - .remove(&id) - .unwrap_or_else(|| { - panic!( - "Internal Error: could not find subaggregation for id {}", - id - ) - }) - .into_intermediate_aggregations_result( - &agg_with_accessor.sub_aggregation, - )?, + sub_aggregation: sub_aggregation_res, } } else { IntermediateTermBucketEntry { diff --git a/src/aggregation/buf_collector.rs b/src/aggregation/buf_collector.rs index d8ec399bc4..15be6281bf 100644 --- 
a/src/aggregation/buf_collector.rs +++ b/src/aggregation/buf_collector.rs @@ -35,11 +35,12 @@ impl BufAggregationCollector { impl SegmentAggregationCollector for BufAggregationCollector { #[inline] - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { - Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor) + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + Box::new(self.collector).add_intermediate_aggregation_result(agg_with_accessor, results) } #[inline] diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index dc80a7e53f..183cc24258 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -184,6 +184,13 @@ impl SegmentCollector for AggregationSegmentCollector { return Err(err); } self.agg_collector.flush(&mut self.aggs_with_accessor)?; - Box::new(self.agg_collector).into_intermediate_aggregations_result(&self.aggs_with_accessor) + + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + Box::new(self.agg_collector).add_intermediate_aggregation_result( + &self.aggs_with_accessor, + &mut sub_aggregation_res, + )?; + + Ok(sub_aggregation_res) } } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 182fc90abb..0e4375bd9c 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -11,8 +11,8 @@ use serde::ser::SerializeSeq; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::agg_req::{ - Aggregations, AggregationsInternal, BucketAggregationInternal, BucketAggregationType, - MetricAggregation, + AggregationInternal, Aggregations, AggregationsInternal, BucketAggregationInternal, + BucketAggregationType, MetricAggregation, }; use super::agg_result::{AggregationResult, BucketResult, MetricResult, RangeBucketEntry}; use 
super::bucket::{ @@ -34,13 +34,14 @@ use crate::TantivyError; /// intermediate results. #[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateAggregationResults { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) metrics: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) buckets: Option>, + pub(crate) aggs_res: VecWithNames, } impl IntermediateAggregationResults { + pub fn push(&mut self, key: String, value: IntermediateAggregationResult) { + self.aggs_res.push(key, value); + } + /// Convert intermediate result and its aggregation request to the final result. pub fn into_final_result( self, @@ -69,64 +70,46 @@ impl IntermediateAggregationResults { req: &AggregationsInternal, limits: &AggregationLimits, ) -> crate::Result { - // Important assumption: - // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the - // request let mut results: FxHashMap = FxHashMap::default(); - - if let Some(buckets) = self.buckets { - convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, limits)? - } else { - // When there are no buckets, we create empty buckets, so that the serialized json - // format is constant - add_empty_final_buckets_to_result(&mut results, &req.buckets, limits)? 
- }; - - if let Some(metrics) = self.metrics { - convert_and_add_final_metrics_to_result(&mut results, metrics, &req.metrics); - } else { - // When there are no metrics, we create empty metric results, so that the serialized - // json format is constant - add_empty_final_metrics_to_result(&mut results, &req.metrics)?; + for (key, agg_res) in self.aggs_res.into_iter() { + let req = req.aggs.get(key.as_str()).unwrap(); + results.insert(key, agg_res.into_final_result(req, limits)?); + } + // Handle empty results + if results.len() != req.aggs.len() { + for (key, req) in req.aggs.iter() { + if !results.contains_key(key) { + let empty_res = match req { + AggregationInternal::Bucket(b) => IntermediateAggregationResult::Bucket( + IntermediateBucketResult::empty_from_req(&b.bucket_agg), + ), + AggregationInternal::Metric(m) => IntermediateAggregationResult::Metric( + IntermediateMetricResult::empty_from_req(m), + ), + }; + results.insert(key.to_string(), empty_res.into_final_result(req, limits)?); + } + } } Ok(AggregationResults(results)) } pub(crate) fn empty_from_req(req: &AggregationsInternal) -> Self { - let metrics = if req.metrics.is_empty() { - None - } else { - let metrics = req - .metrics - .iter() - .map(|(key, req)| { - ( - key.to_string(), - IntermediateMetricResult::empty_from_req(req), - ) - }) - .collect(); - Some(VecWithNames::from_entries(metrics)) - }; - - let buckets = if req.buckets.is_empty() { - None - } else { - let buckets = req - .buckets - .iter() - .map(|(key, req)| { - ( - key.to_string(), - IntermediateBucketResult::empty_from_req(&req.bucket_agg), - ) - }) - .collect(); - Some(VecWithNames::from_entries(buckets)) - }; + let mut aggs_res: VecWithNames = VecWithNames::default(); + for (key, req) in req.aggs.iter() { + let empty_res = match req { + AggregationInternal::Bucket(b) => IntermediateAggregationResult::Bucket( + IntermediateBucketResult::empty_from_req(&b.bucket_agg), + ), + AggregationInternal::Metric(m) => 
IntermediateAggregationResult::Metric( + IntermediateMetricResult::empty_from_req(m), + ), + }; + aggs_res.push(key.to_string(), empty_res); + } - Self { metrics, buckets } + Self { aggs_res } } /// Merge another intermediate aggregation result into this result. @@ -134,87 +117,13 @@ impl IntermediateAggregationResults { /// The order of the values need to be the same on both results. This is ensured when the same /// (key values) are present on the underlying `VecWithNames` struct. pub fn merge_fruits(&mut self, other: IntermediateAggregationResults) -> crate::Result<()> { - if let (Some(buckets_left), Some(buckets_right)) = (&mut self.buckets, other.buckets) { - for (bucket_left, bucket_right) in - buckets_left.values_mut().zip(buckets_right.into_values()) - { - bucket_left.merge_fruits(bucket_right)?; - } - } - - if let (Some(metrics_left), Some(metrics_right)) = (&mut self.metrics, other.metrics) { - for (metric_left, metric_right) in - metrics_left.values_mut().zip(metrics_right.into_values()) - { - metric_left.merge_fruits(metric_right)?; - } + for (left, right) in self.aggs_res.values_mut().zip(other.aggs_res.into_values()) { + left.merge_fruits(right)?; } Ok(()) } } -fn convert_and_add_final_metrics_to_result( - results: &mut FxHashMap, - metrics: VecWithNames, - metrics_req: &VecWithNames, -) { - let metric_result_with_request = metrics.into_iter().zip(metrics_req.values()); - results.extend( - metric_result_with_request - .into_iter() - .map(|((key, metric), req)| { - ( - key, - AggregationResult::MetricResult(metric.into_final_metric_result(req)), - ) - }), - ); -} - -fn add_empty_final_metrics_to_result( - results: &mut FxHashMap, - req_metrics: &VecWithNames, -) -> crate::Result<()> { - results.extend(req_metrics.iter().map(|(key, req)| { - let empty_bucket = IntermediateMetricResult::empty_from_req(req); - ( - key.to_string(), - AggregationResult::MetricResult(empty_bucket.into_final_metric_result(req)), - ) - })); - Ok(()) -} - -fn 
add_empty_final_buckets_to_result( - results: &mut FxHashMap, - req_buckets: &VecWithNames, - limits: &AggregationLimits, -) -> crate::Result<()> { - let requested_buckets = req_buckets.iter(); - for (key, req) in requested_buckets { - let empty_bucket = - AggregationResult::BucketResult(BucketResult::empty_from_req(req, limits)?); - results.insert(key.to_string(), empty_bucket); - } - Ok(()) -} - -fn convert_and_add_final_buckets_to_result( - results: &mut FxHashMap, - buckets: VecWithNames, - req_buckets: &VecWithNames, - limits: &AggregationLimits, -) -> crate::Result<()> { - assert_eq!(buckets.len(), req_buckets.len()); - - let buckets_with_request = buckets.into_iter().zip(req_buckets.values()); - for ((key, bucket), req) in buckets_with_request { - let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?); - results.insert(key, result); - } - Ok(()) -} - /// An aggregation is either a bucket or a metric. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum IntermediateAggregationResult { @@ -224,6 +133,44 @@ pub enum IntermediateAggregationResult { Metric(IntermediateMetricResult), } +impl IntermediateAggregationResult { + pub(crate) fn into_final_result( + self, + req: &AggregationInternal, + limits: &AggregationLimits, + ) -> crate::Result { + let res = match self { + IntermediateAggregationResult::Bucket(bucket) => AggregationResult::BucketResult( + bucket.into_final_bucket_result( + req.as_bucket() + .expect("mismatch bucket result and metric request type"), + limits, + )?, + ), + IntermediateAggregationResult::Metric(metric) => AggregationResult::MetricResult( + metric.into_final_metric_result( + req.as_metric() + .expect("mismatch metric result and bucket request type"), + ), + ), + }; + Ok(res) + } + fn merge_fruits(&mut self, other: IntermediateAggregationResult) -> crate::Result<()> { + match (self, other) { + ( + IntermediateAggregationResult::Bucket(b1), + IntermediateAggregationResult::Bucket(b2), + 
) => b1.merge_fruits(b2), + ( + IntermediateAggregationResult::Metric(m1), + IntermediateAggregationResult::Metric(m2), + ) => m1.merge_fruits(m2), + _ => panic!("aggregation result type mismatch (mixed metric and buckets)"), + } + } +} + /// Holds the intermediate data for metric results #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum IntermediateMetricResult { @@ -541,7 +488,9 @@ where fn deserialize_entries<'de, D>( deserializer: D, ) -> Result, D::Error> -where D: Deserializer<'de> { +where + D: Deserializer<'de>, +{ let vec_entries: Vec<(Key, IntermediateTermBucketEntry)> = Deserialize::deserialize(deserializer)?; Ok(vec_entries.into_iter().collect()) @@ -825,14 +774,15 @@ mod tests { } map.insert( "my_agg_level2".to_string(), - IntermediateBucketResult::Range(IntermediateRangeBucketResult { - buckets, - column_type: None, - }), + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range( + IntermediateRangeBucketResult { + buckets, + column_type: None, + }, + )), ); IntermediateAggregationResults { - buckets: Some(VecWithNames::from_entries(map.into_iter().collect())), - metrics: Default::default(), + aggs_res: VecWithNames::from_entries(map.into_iter().collect()), } } @@ -858,14 +808,15 @@ mod tests { } map.insert( "my_agg_level1".to_string(), - IntermediateBucketResult::Range(IntermediateRangeBucketResult { - buckets, - column_type: None, - }), + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range( + IntermediateRangeBucketResult { + buckets, + column_type: None, + }, + )), ); IntermediateAggregationResults { - buckets: Some(VecWithNames::from_entries(map.into_iter().collect())), - metrics: Default::default(), + aggs_res: VecWithNames::from_entries(map.into_iter().collect()), } } diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index cf51606a59..1e33a79166 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -8,10 +8,10 @@ 
use crate::aggregation::agg_req_with_accessor::{ AggregationsWithAccessor, MetricAggregationWithAccessor, }; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateMetricResult, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{f64_from_fastfield_u64, AggregationError, VecWithNames}; +use crate::aggregation::{f64_from_fastfield_u64, AggregationError}; use crate::{DocId, TantivyError}; /// # Percentiles @@ -255,22 +255,20 @@ impl SegmentPercentilesCollector { impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { let name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string(); let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); - let metrics = Some(VecWithNames::from_entries(vec![( + results.push( name, - intermediate_metric_result, - )])); + IntermediateAggregationResult::Metric(intermediate_metric_result), + ); - Ok(IntermediateAggregationResults { - metrics, - buckets: None, - }) + Ok(()) } #[inline] diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index b7bfe8f6cf..0500abab62 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -5,11 +5,11 @@ use super::*; use crate::aggregation::agg_req_with_accessor::{ AggregationsWithAccessor, MetricAggregationWithAccessor, }; +use crate::aggregation::f64_from_fastfield_u64; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateMetricResult, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; 
use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{f64_from_fastfield_u64, VecWithNames}; use crate::{DocId, TantivyError}; /// A multi-value metric aggregation that computes a collection of statistics on numeric values that @@ -194,10 +194,11 @@ impl SegmentStatsCollector { impl SegmentAggregationCollector for SegmentStatsCollector { #[inline] - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { let name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string(); let intermediate_metric_result = match self.collecting_for { @@ -219,15 +220,12 @@ impl SegmentAggregationCollector for SegmentStatsCollector { } }; - let metrics = Some(VecWithNames::from_entries(vec![( + results.push( name, - intermediate_metric_result, - )])); + IntermediateAggregationResult::Metric(intermediate_metric_result), + ); - Ok(IntermediateAggregationResults { - metrics, - buckets: None, - }) + Ok(()) } #[inline] diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 71df51ca19..38bfce95b4 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -157,6 +157,7 @@ mod agg_limits; pub mod agg_req; +pub mod agg_req_deser; mod agg_req_with_accessor; pub mod agg_result; pub mod bucket; @@ -216,9 +217,9 @@ impl From> for VecWithNames { } impl VecWithNames { - fn extend(&mut self, entries: VecWithNames) { - self.keys.extend(entries.keys); - self.values.extend(entries.values); + fn push(&mut self, key: String, value: T) { + self.keys.push(key); + self.values.push(value); } fn from_entries(mut entries: Vec<(String, T)>) -> Self { @@ -247,9 +248,6 @@ impl VecWithNames { fn into_values(self) -> impl Iterator { self.values.into_iter() } - fn values(&self) -> impl Iterator + '_ { - self.values.iter() - } fn values_mut(&mut self) -> impl Iterator + 
'_ { self.values.iter_mut() } diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 074ff782fb..aac4d96857 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -17,14 +17,14 @@ use super::metric::{ SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation, SumAggregation, }; -use super::VecWithNames; use crate::aggregation::agg_req::BucketAggregationType; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result; + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()>; fn collect( &mut self, @@ -50,7 +50,8 @@ pub(crate) trait CollectorClone { } impl CollectorClone for T -where T: 'static + SegmentAggregationCollector + Clone +where + T: 'static + SegmentAggregationCollector + Clone, { fn clone_box(&self) -> Box { Box::new(self.clone()) @@ -181,36 +182,23 @@ impl Debug for GenericSegmentAggregationResultsCollector { } impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { - let buckets = if let Some(buckets) = self.buckets { - let mut intermeditate_buckets = VecWithNames::default(); + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + if let Some(buckets) = self.buckets { for bucket in buckets { - // TODO too many allocations? 
- let res = bucket.into_intermediate_aggregations_result(agg_with_accessor)?; - // unwrap is fine since we only have buckets here - intermeditate_buckets.extend(res.buckets.unwrap()); + bucket.add_intermediate_aggregation_result(agg_with_accessor, results)?; } - Some(intermeditate_buckets) - } else { - None }; - let metrics = if let Some(metrics) = self.metrics { - let mut intermeditate_metrics = VecWithNames::default(); + if let Some(metrics) = self.metrics { for metric in metrics { - // TODO too many allocations? - let res = metric.into_intermediate_aggregations_result(agg_with_accessor)?; - // unwrap is fine since we only have metrics here - intermeditate_metrics.extend(res.metrics.unwrap()); + metric.add_intermediate_aggregation_result(agg_with_accessor, results)?; } - Some(intermeditate_metrics) - } else { - None }; - Ok(IntermediateAggregationResults { metrics, buckets }) + Ok(()) } fn collect( From 394546c9d5e0b97d81e598c3747f6f0e380ecf44 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Tue, 18 Apr 2023 16:44:20 +0800 Subject: [PATCH 2/5] remove Internal stuff --- src/aggregation/agg_req.rs | 66 ++----------------- src/aggregation/bucket/histogram/histogram.rs | 6 +- src/aggregation/intermediate_agg_result.rs | 38 +++++------ src/aggregation/segment_agg_result.rs | 3 +- 4 files changed, 28 insertions(+), 85 deletions(-) diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index de9c1c842a..b060809d06 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -28,7 +28,6 @@ use std::collections::{HashMap, HashSet}; -use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::bucket::{ @@ -45,76 +44,23 @@ use super::metric::{ /// The key is the user defined name of the aggregation. 
pub type Aggregations = HashMap; -/// Like Aggregations, but optimized to work with the aggregation result -#[derive(Clone, Debug)] -pub(crate) struct AggregationsInternal { - pub(crate) aggs: FxHashMap, -} - -impl From for AggregationsInternal { - fn from(aggs: Aggregations) -> Self { - let mut aggs_internal = FxHashMap::default(); - for (key, agg) in aggs { - match agg { - Aggregation::Bucket(bucket) => { - let sub_aggregation = bucket.get_sub_aggs().clone().into(); - aggs_internal.insert( - key, - AggregationInternal::Bucket(Box::new(BucketAggregationInternal { - bucket_agg: bucket.bucket_agg, - sub_aggregation, - })), - ); - } - Aggregation::Metric(metric) => { - aggs_internal.insert(key, AggregationInternal::Metric(metric)); - } - } - } - Self { - aggs: aggs_internal, - } - } -} - -/// Aggregation request of [`BucketAggregation`] or [`MetricAggregation`]. -/// -/// An aggregation is either a bucket or a metric. -#[derive(Clone, Debug)] -pub(crate) enum AggregationInternal { - /// Bucket aggregation, see [`BucketAggregation`] for details. - Bucket(Box), - /// Metric aggregation, see [`MetricAggregation`] for details. - Metric(MetricAggregation), -} - -impl AggregationInternal { - pub fn as_bucket(&self) -> Option<&Box> { +impl Aggregation { + pub fn as_bucket(&self) -> Option<&Box> { match self { - AggregationInternal::Bucket(bucket) => Some(bucket), + Aggregation::Bucket(bucket) => Some(bucket), _ => None, } } pub fn as_metric(&self) -> Option<&MetricAggregation> { match self { - AggregationInternal::Metric(metric) => Some(metric), + Aggregation::Metric(metric) => Some(metric), _ => None, } } } -#[derive(Clone, Debug)] -// Like BucketAggregation, but optimized to work with the result -pub(crate) struct BucketAggregationInternal { - /// Bucket aggregation strategy to group documents. - pub bucket_agg: BucketAggregationType, - /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the - /// bucket. 
- sub_aggregation: AggregationsInternal, -} - -impl BucketAggregationInternal { - pub(crate) fn sub_aggregation(&self) -> &AggregationsInternal { +impl BucketAggregation { + pub(crate) fn sub_aggregation(&self) -> &Aggregations { &self.sub_aggregation } pub(crate) fn as_range(&self) -> Option<&RangeAggregation> { diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index d9906abe6f..fdcb8095a0 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; use tantivy_bitpacker::minmax; use crate::aggregation::agg_limits::MemoryConsumption; -use crate::aggregation::agg_req::AggregationsInternal; +use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, }; @@ -385,7 +385,7 @@ fn get_bucket_key_from_pos(bucket_pos: f64, interval: f64, offset: f64) -> f64 { fn intermediate_buckets_to_final_buckets_fill_gaps( buckets: Vec, histogram_req: &HistogramAggregation, - sub_aggregation: &AggregationsInternal, + sub_aggregation: &Aggregations, limits: &AggregationLimits, ) -> crate::Result> { // Generate the full list of buckets without gaps. 
@@ -444,7 +444,7 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets( buckets: Vec, column_type: Option, histogram_req: &HistogramAggregation, - sub_aggregation: &AggregationsInternal, + sub_aggregation: &Aggregations, limits: &AggregationLimits, ) -> crate::Result> { let mut buckets = if histogram_req.min_doc_count() == 0 { diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 0e4375bd9c..79973489e0 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -11,8 +11,7 @@ use serde::ser::SerializeSeq; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::agg_req::{ - AggregationInternal, Aggregations, AggregationsInternal, BucketAggregationInternal, - BucketAggregationType, MetricAggregation, + Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation, }; use super::agg_result::{AggregationResult, BucketResult, MetricResult, RangeBucketEntry}; use super::bucket::{ @@ -38,6 +37,7 @@ pub struct IntermediateAggregationResults { } impl IntermediateAggregationResults { + /// Add a result pub fn push(&mut self, key: String, value: IntermediateAggregationResult) { self.aggs_res.push(key, value); } @@ -67,23 +67,23 @@ impl IntermediateAggregationResults { /// for internal processing, by splitting metric and buckets into separate groups. 
pub(crate) fn into_final_result_internal( self, - req: &AggregationsInternal, + req: &Aggregations, limits: &AggregationLimits, ) -> crate::Result { let mut results: FxHashMap = FxHashMap::default(); for (key, agg_res) in self.aggs_res.into_iter() { - let req = req.aggs.get(key.as_str()).unwrap(); + let req = req.get(key.as_str()).unwrap(); results.insert(key, agg_res.into_final_result(req, limits)?); } // Handle empty results - if results.len() != req.aggs.len() { - for (key, req) in req.aggs.iter() { + if results.len() != req.len() { + for (key, req) in req.iter() { if !results.contains_key(key) { let empty_res = match req { - AggregationInternal::Bucket(b) => IntermediateAggregationResult::Bucket( + Aggregation::Bucket(b) => IntermediateAggregationResult::Bucket( IntermediateBucketResult::empty_from_req(&b.bucket_agg), ), - AggregationInternal::Metric(m) => IntermediateAggregationResult::Metric( + Aggregation::Metric(m) => IntermediateAggregationResult::Metric( IntermediateMetricResult::empty_from_req(m), ), }; @@ -95,14 +95,14 @@ impl IntermediateAggregationResults { Ok(AggregationResults(results)) } - pub(crate) fn empty_from_req(req: &AggregationsInternal) -> Self { + pub(crate) fn empty_from_req(req: &Aggregations) -> Self { let mut aggs_res: VecWithNames = VecWithNames::default(); - for (key, req) in req.aggs.iter() { + for (key, req) in req.iter() { let empty_res = match req { - AggregationInternal::Bucket(b) => IntermediateAggregationResult::Bucket( + Aggregation::Bucket(b) => IntermediateAggregationResult::Bucket( IntermediateBucketResult::empty_from_req(&b.bucket_agg), ), - AggregationInternal::Metric(m) => IntermediateAggregationResult::Metric( + Aggregation::Metric(m) => IntermediateAggregationResult::Metric( IntermediateMetricResult::empty_from_req(m), ), }; @@ -136,7 +136,7 @@ pub enum IntermediateAggregationResult { impl IntermediateAggregationResult { pub(crate) fn into_final_result( self, - req: &AggregationInternal, + req: &Aggregation, limits: 
&AggregationLimits, ) -> crate::Result { let res = match self { @@ -302,7 +302,7 @@ pub enum IntermediateBucketResult { impl IntermediateBucketResult { pub(crate) fn into_final_bucket_result( self, - req: &BucketAggregationInternal, + req: &BucketAggregation, limits: &AggregationLimits, ) -> crate::Result { match self { @@ -488,9 +488,7 @@ where fn deserialize_entries<'de, D>( deserializer: D, ) -> Result, D::Error> -where - D: Deserializer<'de>, -{ +where D: Deserializer<'de> { let vec_entries: Vec<(Key, IntermediateTermBucketEntry)> = Deserialize::deserialize(deserializer)?; Ok(vec_entries.into_iter().collect()) @@ -500,7 +498,7 @@ impl IntermediateTermBucketResult { pub(crate) fn into_final_result( self, req: &TermsAggregation, - sub_aggregation_req: &AggregationsInternal, + sub_aggregation_req: &Aggregations, limits: &AggregationLimits, ) -> crate::Result { let req = TermsAggregationInternal::from_req(req); @@ -636,7 +634,7 @@ pub struct IntermediateHistogramBucketEntry { impl IntermediateHistogramBucketEntry { pub(crate) fn into_final_bucket_entry( self, - req: &AggregationsInternal, + req: &Aggregations, limits: &AggregationLimits, ) -> crate::Result { Ok(BucketEntry { @@ -681,7 +679,7 @@ pub struct IntermediateRangeBucketEntry { impl IntermediateRangeBucketEntry { pub(crate) fn into_final_bucket_entry( self, - req: &AggregationsInternal, + req: &Aggregations, _range_req: &RangeAggregation, column_type: Option, limits: &AggregationLimits, diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index aac4d96857..68bb254f1c 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -50,8 +50,7 @@ pub(crate) trait CollectorClone { } impl CollectorClone for T -where - T: 'static + SegmentAggregationCollector + Clone, +where T: 'static + SegmentAggregationCollector + Clone { fn clone_box(&self) -> Box { Box::new(self.clone()) From c3c007b8e4e1237ba4e89f21028f25637325f311 Mon Sep 17 00:00:00 
2001 From: Pascal Seitz Date: Tue, 18 Apr 2023 17:50:42 +0800 Subject: [PATCH 3/5] merge different accessors --- src/aggregation/agg_req_deser.rs | 29 +++ src/aggregation/agg_req_with_accessor.rs | 192 ++++++++---------- src/aggregation/bucket/histogram/histogram.rs | 4 +- src/aggregation/bucket/term_agg.rs | 4 +- src/aggregation/metric/percentiles.rs | 4 +- src/aggregation/metric/stats.rs | 4 +- src/aggregation/segment_agg_result.rs | 176 ++++++++-------- 7 files changed, 218 insertions(+), 195 deletions(-) diff --git a/src/aggregation/agg_req_deser.rs b/src/aggregation/agg_req_deser.rs index 2d7840b13f..6546024b40 100644 --- a/src/aggregation/agg_req_deser.rs +++ b/src/aggregation/agg_req_deser.rs @@ -92,6 +92,35 @@ impl AggregationVariants { AggregationVariants::Percentiles(per) => per.field_name(), } } + + pub(crate) fn as_range(&self) -> Option<&RangeAggregation> { + match &self { + AggregationVariants::Range(range) => Some(range), + _ => None, + } + } + pub(crate) fn as_histogram(&self) -> crate::Result> { + match &self { + AggregationVariants::Histogram(histogram) => Ok(Some(histogram.clone())), + AggregationVariants::DateHistogram(histogram) => { + Ok(Some(histogram.to_histogram_req()?)) + } + _ => Ok(None), + } + } + pub(crate) fn as_term(&self) -> Option<&TermsAggregation> { + match &self { + AggregationVariants::Terms(terms) => Some(terms), + _ => None, + } + } + + pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> { + match &self { + AggregationVariants::Percentiles(percentile_req) => Some(percentile_req), + _ => None, + } + } } #[cfg(test)] diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 7b26e2f326..3af5de2593 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -16,14 +16,14 @@ use crate::SegmentReader; #[derive(Clone, Default)] pub(crate) struct AggregationsWithAccessor { - pub metrics: VecWithNames, - pub buckets: 
VecWithNames, + pub metrics: VecWithNames, + pub buckets: VecWithNames, } impl AggregationsWithAccessor { fn from_data( - metrics: VecWithNames, - buckets: VecWithNames, + metrics: VecWithNames, + buckets: VecWithNames, ) -> Self { Self { metrics, buckets } } @@ -34,67 +34,85 @@ impl AggregationsWithAccessor { } #[derive(Clone)] -pub struct BucketAggregationWithAccessor { +pub struct AggregationWithAccessor { /// In general there can be buckets without fast field access, e.g. buckets that are created /// based on search terms. So eventually this needs to be Option or moved. pub(crate) accessor: Column, pub(crate) str_dict_column: Option, pub(crate) field_type: ColumnType, - pub(crate) bucket_agg: BucketAggregationType, pub(crate) sub_aggregation: AggregationsWithAccessor, pub(crate) limits: AggregationLimits, pub(crate) column_block_accessor: ColumnBlockAccessor, + pub(crate) agg: Aggregation, } -fn get_numeric_or_date_column_types() -> &'static [ColumnType] { - &[ - ColumnType::F64, - ColumnType::U64, - ColumnType::I64, - ColumnType::DateTime, - ] -} - -impl BucketAggregationWithAccessor { - fn try_from_bucket( - bucket: &BucketAggregationType, +impl AggregationWithAccessor { + fn try_from_agg( + agg: &Aggregation, sub_aggregation: &Aggregations, reader: &SegmentReader, limits: AggregationLimits, - ) -> crate::Result { + ) -> crate::Result { let mut str_dict_column = None; - let (accessor, field_type) = match &bucket { - BucketAggregationType::Range(RangeAggregation { - field: field_name, .. - }) => get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::Histogram(HistogramAggregation { - field: field_name, .. - }) => get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::DateHistogram(DateHistogramAggregationReq { - field: field_name, - .. 
- }) => get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::Terms(TermsAggregation { - field: field_name, .. - }) => { - str_dict_column = reader.fast_fields().str(field_name)?; - get_ff_reader_and_validate(reader, field_name, None)? - } + let (accessor, field_type) = match &agg { + Aggregation::Bucket(b) => match &b.bucket_agg { + BucketAggregationType::Range(RangeAggregation { + field: field_name, .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + BucketAggregationType::Histogram(HistogramAggregation { + field: field_name, + .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + BucketAggregationType::DateHistogram(DateHistogramAggregationReq { + field: field_name, + .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + BucketAggregationType::Terms(TermsAggregation { + field: field_name, .. + }) => { + str_dict_column = reader.fast_fields().str(field_name)?; + get_ff_reader_and_validate(reader, field_name, None)? 
+ } + }, + Aggregation::Metric(metric) => match &metric { + MetricAggregation::Average(AverageAggregation { field: field_name }) + | MetricAggregation::Count(CountAggregation { field: field_name }) + | MetricAggregation::Max(MaxAggregation { field: field_name }) + | MetricAggregation::Min(MinAggregation { field: field_name }) + | MetricAggregation::Stats(StatsAggregation { field: field_name }) + | MetricAggregation::Sum(SumAggregation { field: field_name }) => { + let (accessor, field_type) = get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?; + + (accessor, field_type) + } + MetricAggregation::Percentiles(percentiles) => { + let (accessor, field_type) = get_ff_reader_and_validate( + reader, + percentiles.field_name(), + Some(get_numeric_or_date_column_types()), + )?; + (accessor, field_type) + } + }, }; let sub_aggregation = sub_aggregation.clone(); - Ok(BucketAggregationWithAccessor { + Ok(AggregationWithAccessor { accessor, field_type, sub_aggregation: get_aggs_with_accessor_and_validate( @@ -102,7 +120,7 @@ impl BucketAggregationWithAccessor { reader, &limits.clone(), )?, - bucket_agg: bucket.clone(), + agg: agg.clone(), str_dict_column, limits, column_block_accessor: Default::default(), @@ -110,56 +128,13 @@ impl BucketAggregationWithAccessor { } } -/// Contains the metric request and the fast field accessor. 
-#[derive(Clone)] -pub struct MetricAggregationWithAccessor { - pub metric: MetricAggregation, - pub field_type: ColumnType, - pub accessor: Column, - pub column_block_accessor: ColumnBlockAccessor, -} - -impl MetricAggregationWithAccessor { - fn try_from_metric( - metric: &MetricAggregation, - reader: &SegmentReader, - ) -> crate::Result { - match &metric { - MetricAggregation::Average(AverageAggregation { field: field_name }) - | MetricAggregation::Count(CountAggregation { field: field_name }) - | MetricAggregation::Max(MaxAggregation { field: field_name }) - | MetricAggregation::Min(MinAggregation { field: field_name }) - | MetricAggregation::Stats(StatsAggregation { field: field_name }) - | MetricAggregation::Sum(SumAggregation { field: field_name }) => { - let (accessor, field_type) = get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?; - - Ok(MetricAggregationWithAccessor { - accessor, - field_type, - metric: metric.clone(), - column_block_accessor: Default::default(), - }) - } - MetricAggregation::Percentiles(percentiles) => { - let (accessor, field_type) = get_ff_reader_and_validate( - reader, - percentiles.field_name(), - Some(get_numeric_or_date_column_types()), - )?; - - Ok(MetricAggregationWithAccessor { - accessor, - field_type, - metric: metric.clone(), - column_block_accessor: Default::default(), - }) - } - } - } +fn get_numeric_or_date_column_types() -> &'static [ColumnType] { + &[ + ColumnType::F64, + ColumnType::U64, + ColumnType::I64, + ColumnType::DateTime, + ] } pub(crate) fn get_aggs_with_accessor_and_validate( @@ -167,22 +142,27 @@ pub(crate) fn get_aggs_with_accessor_and_validate( reader: &SegmentReader, limits: &AggregationLimits, ) -> crate::Result { - let mut metrics = vec![]; - let mut buckets = vec![]; + let mut metrics = Vec::new(); + let mut buckets = Vec::new(); for (key, agg) in aggs.iter() { match agg { Aggregation::Bucket(bucket) => buckets.push(( key.to_string(), - 
BucketAggregationWithAccessor::try_from_bucket( - &bucket.bucket_agg, + AggregationWithAccessor::try_from_agg( + &agg, bucket.get_sub_aggs(), reader, limits.clone(), )?, )), - Aggregation::Metric(metric) => metrics.push(( + Aggregation::Metric(_metric) => metrics.push(( key.to_string(), - MetricAggregationWithAccessor::try_from_metric(metric, reader)?, + AggregationWithAccessor::try_from_agg( + &agg, + &Default::default(), + reader, + limits.clone(), + )?, )), } } diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index fdcb8095a0..1eefa77b5c 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -10,7 +10,7 @@ use tantivy_bitpacker::minmax; use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_req_with_accessor::{ - AggregationsWithAccessor, BucketAggregationWithAccessor, + AggregationWithAccessor, AggregationsWithAccessor, }; use crate::aggregation::agg_result::BucketEntry; use crate::aggregation::intermediate_agg_result::{ @@ -309,7 +309,7 @@ impl SegmentHistogramCollector { } pub fn into_intermediate_bucket_result( self, - agg_with_accessor: &BucketAggregationWithAccessor, + agg_with_accessor: &AggregationWithAccessor, ) -> crate::Result { let mut buckets = Vec::with_capacity(self.buckets.len()); diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 8646a15d8b..779d765e12 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use super::{CustomOrder, Order, OrderTarget}; use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req_with_accessor::{ - AggregationsWithAccessor, BucketAggregationWithAccessor, + AggregationWithAccessor, AggregationsWithAccessor, }; use crate::aggregation::intermediate_agg_result::{ 
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, @@ -364,7 +364,7 @@ impl SegmentTermCollector { #[inline] pub(crate) fn into_intermediate_bucket_result( mut self, - agg_with_accessor: &BucketAggregationWithAccessor, + agg_with_accessor: &AggregationWithAccessor, ) -> crate::Result { let mut entries: Vec<(u64, u64)> = self.term_buckets.entries.into_iter().collect(); diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index 1e33a79166..17fbb4ef90 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use super::*; use crate::aggregation::agg_req_with_accessor::{ - AggregationsWithAccessor, MetricAggregationWithAccessor, + AggregationWithAccessor, AggregationsWithAccessor, }; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, @@ -240,7 +240,7 @@ impl SegmentPercentilesCollector { pub(crate) fn collect_block_with_field( &mut self, docs: &[DocId], - agg_accessor: &mut MetricAggregationWithAccessor, + agg_accessor: &mut AggregationWithAccessor, ) { agg_accessor .column_block_accessor diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 0500abab62..88c1ef83fd 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use super::*; use crate::aggregation::agg_req_with_accessor::{ - AggregationsWithAccessor, MetricAggregationWithAccessor, + AggregationWithAccessor, AggregationsWithAccessor, }; use crate::aggregation::f64_from_fastfield_u64; use crate::aggregation::intermediate_agg_result::{ @@ -179,7 +179,7 @@ impl SegmentStatsCollector { pub(crate) fn collect_block_with_field( &mut self, docs: &[DocId], - agg_accessor: &mut MetricAggregationWithAccessor, + agg_accessor: &mut AggregationWithAccessor, ) { 
agg_accessor .column_block_accessor diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 68bb254f1c..771bfce208 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -6,10 +6,8 @@ use std::fmt::Debug; pub(crate) use super::agg_limits::AggregationLimits; -use super::agg_req::MetricAggregation; -use super::agg_req_with_accessor::{ - AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor, -}; +use super::agg_req::{Aggregation, MetricAggregation}; +use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor}; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; use super::intermediate_agg_result::IntermediateAggregationResults; use super::metric::{ @@ -70,95 +68,111 @@ pub(crate) fn build_segment_agg_collector( if req.buckets.is_empty() && req.metrics.len() == 1 { let req = &req.metrics.values[0]; let accessor_idx = 0; - return build_metric_segment_agg_collector(req, accessor_idx); + return build_single_agg_segment_collector(req, accessor_idx); } // Single bucket special case if req.metrics.is_empty() && req.buckets.len() == 1 { let req = &req.buckets.values[0]; let accessor_idx = 0; - return build_bucket_segment_agg_collector(req, accessor_idx); + return build_single_agg_segment_collector(req, accessor_idx); } let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?; Ok(Box::new(agg)) } -pub(crate) fn build_metric_segment_agg_collector( - req: &MetricAggregationWithAccessor, +pub(crate) fn build_single_agg_segment_collector( + req: &AggregationWithAccessor, accessor_idx: usize, ) -> crate::Result> { - match &req.metric { - MetricAggregation::Average(AverageAggregation { .. }) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Average, - accessor_idx, - ))) - } - MetricAggregation::Count(CountAggregation { .. 
}) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count, accessor_idx), - )), - MetricAggregation::Max(MaxAggregation { .. }) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max, accessor_idx), - )), - MetricAggregation::Min(MinAggregation { .. }) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min, accessor_idx), - )), - MetricAggregation::Stats(StatsAggregation { .. }) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats, accessor_idx), - )), - MetricAggregation::Sum(SumAggregation { .. }) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum, accessor_idx), - )), - MetricAggregation::Percentiles(percentiles_req) => Ok(Box::new( - SegmentPercentilesCollector::from_req_and_validate( - percentiles_req, - req.field_type, - accessor_idx, - )?, - )), - } -} - -pub(crate) fn build_bucket_segment_agg_collector( - req: &BucketAggregationWithAccessor, - accessor_idx: usize, -) -> crate::Result> { - match &req.bucket_agg { - BucketAggregationType::Terms(terms_req) => { - Ok(Box::new(SegmentTermCollector::from_req_and_validate( - terms_req, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } - BucketAggregationType::Range(range_req) => { - Ok(Box::new(SegmentRangeCollector::from_req_and_validate( - range_req, - &req.sub_aggregation, - &req.limits, - req.field_type, - accessor_idx, - )?)) - } - BucketAggregationType::Histogram(histogram) => { - Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( - histogram, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } - BucketAggregationType::DateHistogram(histogram) => { - Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( - &histogram.to_histogram_req()?, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } + match &req.agg { + Aggregation::Bucket(bucket) => match &bucket.bucket_agg { 
+ BucketAggregationType::Terms(terms_req) => { + Ok(Box::new(SegmentTermCollector::from_req_and_validate( + terms_req, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)) + } + BucketAggregationType::Range(range_req) => { + Ok(Box::new(SegmentRangeCollector::from_req_and_validate( + range_req, + &req.sub_aggregation, + &req.limits, + req.field_type, + accessor_idx, + )?)) + } + BucketAggregationType::Histogram(histogram) => { + Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( + histogram, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)) + } + BucketAggregationType::DateHistogram(histogram) => { + Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( + &histogram.to_histogram_req()?, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)) + } + }, + Aggregation::Metric(metric) => match &metric { + MetricAggregation::Average(AverageAggregation { .. }) => { + Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Average, + accessor_idx, + ))) + } + MetricAggregation::Count(CountAggregation { .. }) => { + Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Count, + accessor_idx, + ))) + } + MetricAggregation::Max(MaxAggregation { .. }) => { + Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Max, + accessor_idx, + ))) + } + MetricAggregation::Min(MinAggregation { .. }) => { + Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Min, + accessor_idx, + ))) + } + MetricAggregation::Stats(StatsAggregation { .. }) => { + Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Stats, + accessor_idx, + ))) + } + MetricAggregation::Sum(SumAggregation { .. 
}) => { + Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Sum, + accessor_idx, + ))) + } + MetricAggregation::Percentiles(percentiles_req) => Ok(Box::new( + SegmentPercentilesCollector::from_req_and_validate( + percentiles_req, + req.field_type, + accessor_idx, + )?, + )), + }, } } @@ -252,7 +266,7 @@ impl GenericSegmentAggregationResultsCollector { .iter() .enumerate() .map(|(accessor_idx, (_key, req))| { - build_bucket_segment_agg_collector(req, accessor_idx) + build_single_agg_segment_collector(req, accessor_idx) }) .collect::>>>()?; let metrics = req @@ -260,7 +274,7 @@ impl GenericSegmentAggregationResultsCollector { .iter() .enumerate() .map(|(accessor_idx, (_key, req))| { - build_metric_segment_agg_collector(req, accessor_idx) + build_single_agg_segment_collector(req, accessor_idx) }) .collect::>>>()?; From 11a82869290b2b4fa45d26fc8c36b0f75b951277 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Tue, 18 Apr 2023 20:24:19 +0800 Subject: [PATCH 4/5] switch to Aggregation without serde_untagged --- examples/aggregation.rs | 6 +- src/aggregation/agg_bench.rs | 377 +++++------------- src/aggregation/agg_req.rs | 248 ++++++------ src/aggregation/agg_req_deser.rs | 173 -------- src/aggregation/agg_req_with_accessor.rs | 146 +++---- src/aggregation/agg_tests.rs | 54 ++- src/aggregation/bucket/histogram/histogram.rs | 10 +- src/aggregation/bucket/mod.rs | 10 + src/aggregation/bucket/range.rs | 8 +- src/aggregation/bucket/term_agg.rs | 32 +- src/aggregation/intermediate_agg_result.rs | 128 +++--- src/aggregation/metric/mod.rs | 9 +- src/aggregation/metric/percentiles.rs | 41 +- src/aggregation/metric/stats.rs | 83 ++-- src/aggregation/mod.rs | 1 - src/aggregation/segment_agg_result.rs | 234 ++++------- 16 files changed, 555 insertions(+), 1005 deletions(-) delete mode 100644 src/aggregation/agg_req_deser.rs diff --git a/examples/aggregation.rs b/examples/aggregation.rs index 0946c2a085..eb03e7815e 100644 --- 
a/examples/aggregation.rs +++ b/examples/aggregation.rs @@ -7,12 +7,8 @@ // --- use serde_json::{Deserializer, Value}; -use tantivy::aggregation::agg_req::{ - Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation, -}; +use tantivy::aggregation::agg_req::Aggregations; use tantivy::aggregation::agg_result::AggregationResults; -use tantivy::aggregation::bucket::{RangeAggregation, RangeAggregationRange}; -use tantivy::aggregation::metric::AverageAggregation; use tantivy::aggregation::AggregationCollector; use tantivy::query::AllQuery; use tantivy::schema::{self, IndexRecordOption, Schema, TextFieldIndexing, FAST}; diff --git a/src/aggregation/agg_bench.rs b/src/aggregation/agg_bench.rs index 9f8ca504fa..12a8989687 100644 --- a/src/aggregation/agg_bench.rs +++ b/src/aggregation/agg_bench.rs @@ -5,16 +5,10 @@ mod bench { use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rand_distr::Distribution; + use serde_json::json; use test::{self, Bencher}; - use crate::aggregation::agg_req::{ - Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation, - }; - use crate::aggregation::bucket::{ - CustomOrder, HistogramAggregation, HistogramBounds, Order, OrderTarget, RangeAggregation, - TermsAggregation, - }; - use crate::aggregation::metric::{AverageAggregation, StatsAggregation}; + use crate::aggregation::agg_req::Aggregations; use crate::aggregation::AggregationCollector; use crate::query::{AllQuery, TermQuery}; use crate::schema::{IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING}; @@ -153,14 +147,10 @@ mod bench { IndexRecordOption::Basic, ); - let agg_req_1: Aggregations = vec![( - "average".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average": { "avg": { "field": "score", } } + })) + .unwrap(); let collector = 
get_collector(agg_req_1); @@ -182,14 +172,10 @@ mod bench { IndexRecordOption::Basic, ); - let agg_req_1: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score_f64".to_string(), - ))), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average_f64": { "stats": { "field": "score_f64", } } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -211,14 +197,10 @@ mod bench { IndexRecordOption::Basic, ); - let agg_req_1: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average_f64": { "avg": { "field": "score_f64", } } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -265,22 +247,11 @@ mod bench { IndexRecordOption::Basic, ); - let agg_req_1: Aggregations = vec![ - ( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - ), - ( - "average".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - ), - ] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average_f64": { "avg": { "field": "score_f64" } }, + "average": { "avg": { "field": "score" } }, + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -296,21 +267,10 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req: Aggregations = vec![( - "my_texts".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Terms(TermsAggregation { - field: "text_few_terms".to_string(), - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - 
)] - .into_iter() - .collect(); + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { "terms": { "field": "text_few_terms" } }, + })) + .unwrap(); let collector = get_collector(agg_req); @@ -326,30 +286,15 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let sub_agg_req: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - )] - .into_iter() - .collect(); - - let agg_req: Aggregations = vec![( - "my_texts".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Terms(TermsAggregation { - field: "text_many_terms".to_string(), - ..Default::default() - }), - sub_aggregation: sub_agg_req, + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { + "terms": { "field": "text_many_terms" }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } } - .into(), - ), - )] - .into_iter() - .collect(); + }, + })) + .unwrap(); let collector = get_collector(agg_req); @@ -365,21 +310,10 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req: Aggregations = vec![( - "my_texts".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Terms(TermsAggregation { - field: "text_many_terms".to_string(), - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { "terms": { "field": "text_many_terms" } }, + })) + .unwrap(); let collector = get_collector(agg_req); @@ -395,25 +329,10 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req: Aggregations = vec![( - "my_texts".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Terms(TermsAggregation { - field: "text_many_terms".to_string(), - order: 
Some(CustomOrder { - order: Order::Desc, - target: OrderTarget::Key, - }), - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { "terms": { "field": "text_many_terms", "order": { "_key": "desc" } } }, + })) + .unwrap(); let collector = get_collector(agg_req); @@ -429,29 +348,17 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Range(RangeAggregation { - field: "score_f64".to_string(), - ranges: vec![ - (3f64..7000f64).into(), - (7000f64..20000f64).into(), - (20000f64..30000f64).into(), - (30000f64..40000f64).into(), - (40000f64..50000f64).into(), - (50000f64..60000f64).into(), - ], - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "range_f64": { "range": { "field": "score_f64", "ranges": [ + { "from": 3, "to": 7000 }, + { "from": 7000, "to": 20000 }, + { "from": 20000, "to": 30000 }, + { "from": 30000, "to": 40000 }, + { "from": 40000, "to": 50000 }, + { "from": 50000, "to": 60000 } + ] } }, + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -467,38 +374,25 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let sub_agg_req: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - )] - .into_iter() - .collect(); - - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Range(RangeAggregation { - field: "score_f64".to_string(), - ranges: vec![ - (3f64..7000f64).into(), - (7000f64..20000f64).into(), - 
(20000f64..30000f64).into(), - (30000f64..40000f64).into(), - (40000f64..50000f64).into(), - (50000f64..60000f64).into(), - ], - ..Default::default() - }), - sub_aggregation: sub_agg_req, + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { + "range": { + "field": "score_f64", + "ranges": [ + { "from": 3, "to": 7000 }, + { "from": 7000, "to": 20000 }, + { "from": 20000, "to": 30000 }, + { "from": 30000, "to": 40000 }, + { "from": 40000, "to": 50000 }, + { "from": 50000, "to": 60000 } + ] + }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } } - .into(), - ), - )] - .into_iter() - .collect(); + }, + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -519,26 +413,10 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { - field: "score_f64".to_string(), - interval: 100f64, - hard_bounds: Some(HistogramBounds { - min: 1000.0, - max: 300_000.0, - }), - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { "histogram": { "field": "score_f64", "interval": 100, "hard_bounds": { "min": 1000, "max": 300000 } } }, + })) + .unwrap(); let collector = get_collector(agg_req_1); let searcher = reader.searcher(); @@ -553,31 +431,15 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let sub_agg_req: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - )] - .into_iter() - .collect(); - - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { - 
field: "score_f64".to_string(), - interval: 100f64, // 1000 buckets - ..Default::default() - }), - sub_aggregation: sub_agg_req, + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { + "histogram": { "field": "score_f64", "interval": 100 }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } } - .into(), - ), - )] - .into_iter() - .collect(); + } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -593,22 +455,15 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { - field: "score_f64".to_string(), - interval: 100f64, // 1000 buckets - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { + "histogram": { + "field": "score_f64", + "interval": 100 // 1000 buckets + }, + } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -630,43 +485,23 @@ mod bench { IndexRecordOption::Basic, ); - let sub_agg_req_1: Aggregations = vec![( - "average_in_range".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - )] - .into_iter() - .collect(); - - let agg_req_1: Aggregations = vec![ - ( - "average".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - ), - ( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Range(RangeAggregation { - field: "score_f64".to_string(), - ranges: vec![ - (3f64..7000f64).into(), - (7000f64..20000f64).into(), - (20000f64..60000f64).into(), - ], - ..Default::default() - }), - sub_aggregation: sub_agg_req_1, - } - .into(), - ), - ), - ] - .into_iter() - 
.collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { + "range": { + "field": "score_f64", + "ranges": [ + { "from": 3, "to": 7000 }, + { "from": 7000, "to": 20000 }, + { "from": 20000, "to": 60000 } + ] + }, + "aggs": { + "average_in_range": { "avg": { "field": "score" } } + } + }, + "average": { "avg": { "field": "score" } } + })) + .unwrap(); let collector = get_collector(agg_req_1); diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index b060809d06..1aa7fa6a5e 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -44,45 +44,30 @@ use super::metric::{ /// The key is the user defined name of the aggregation. pub type Aggregations = HashMap; -impl Aggregation { - pub fn as_bucket(&self) -> Option<&Box> { - match self { - Aggregation::Bucket(bucket) => Some(bucket), - _ => None, - } - } - pub fn as_metric(&self) -> Option<&MetricAggregation> { - match self { - Aggregation::Metric(metric) => Some(metric), - _ => None, - } - } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +/// Aggregation request. +/// +/// An aggregation is either a bucket or a metric. +pub struct Aggregation { + /// The aggregation variant, which can be either a bucket or a metric. + #[serde(flatten)] + pub agg: AggregationVariants, + /// The sub_aggregations, only valid for bucket type aggregations. Each bucket will aggregate + /// on the document set in the bucket. 
+ #[serde(rename = "aggs")] + #[serde(default)] + #[serde(skip_serializing_if = "Aggregations::is_empty")] + pub sub_aggregation: Aggregations, } -impl BucketAggregation { +impl Aggregation { pub(crate) fn sub_aggregation(&self) -> &Aggregations { &self.sub_aggregation } - pub(crate) fn as_range(&self) -> Option<&RangeAggregation> { - match &self.bucket_agg { - BucketAggregationType::Range(range) => Some(range), - _ => None, - } - } - pub(crate) fn as_histogram(&self) -> crate::Result> { - match &self.bucket_agg { - BucketAggregationType::Histogram(histogram) => Ok(Some(histogram.clone())), - BucketAggregationType::DateHistogram(histogram) => { - Ok(Some(histogram.to_histogram_req()?)) - } - _ => Ok(None), - } - } - pub(crate) fn as_term(&self) -> Option<&TermsAggregation> { - match &self.bucket_agg { - BucketAggregationType::Terms(terms) => Some(terms), - _ => None, - } + + fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { + fast_field_names.insert(self.agg.get_fast_field_name().to_string()); + fast_field_names.extend(get_fast_field_names(&self.sub_aggregation)); } } @@ -95,100 +80,24 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet { fast_field_names } -/// Aggregation request of [`BucketAggregation`] or [`MetricAggregation`]. -/// -/// An aggregation is either a bucket or a metric. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -#[serde(untagged)] -pub enum Aggregation { - /// Bucket aggregation, see [`BucketAggregation`] for details. - Bucket(Box), - /// Metric aggregation, see [`MetricAggregation`] for details. - Metric(MetricAggregation), -} - -impl Aggregation { - fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { - match self { - Aggregation::Bucket(bucket) => bucket.get_fast_field_names(fast_field_names), - Aggregation::Metric(metric) => { - fast_field_names.insert(metric.get_fast_field_name().to_string()); - } - } - } -} - -/// BucketAggregations create buckets of documents. 
Each bucket is associated with a rule which -/// determines whether or not a document in the falls into it. In other words, the buckets -/// effectively define document sets. Buckets are not necessarily disjunct, therefore a document can -/// fall into multiple buckets. In addition to the buckets themselves, the bucket aggregations also -/// compute and return the number of documents for each bucket. Bucket aggregations, as opposed to -/// metric aggregations, can hold sub-aggregations. These sub-aggregations will be aggregated for -/// the buckets created by their "parent" bucket aggregation. There are different bucket -/// aggregators, each with a different "bucketing" strategy. Some define a single bucket, some -/// define fixed number of multiple buckets, and others dynamically create the buckets during the -/// aggregation process. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct BucketAggregation { - /// Bucket aggregation strategy to group documents. - #[serde(flatten)] - pub bucket_agg: BucketAggregationType, - /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the - /// bucket. - #[serde(rename = "aggs")] - #[serde(default)] - #[serde(skip_serializing_if = "Aggregations::is_empty")] - pub sub_aggregation: Aggregations, -} - -impl BucketAggregation { - pub(crate) fn get_sub_aggs(&self) -> &Aggregations { - &self.sub_aggregation - } - fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { - let fast_field_name = self.bucket_agg.get_fast_field_name(); - fast_field_names.insert(fast_field_name.to_string()); - fast_field_names.extend(get_fast_field_names(&self.sub_aggregation)); - } -} - -/// The bucket aggregation types. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum BucketAggregationType { +/// All aggregation types. +pub enum AggregationVariants { + // Bucket aggregation types /// Put data into buckets of user-defined ranges. 
#[serde(rename = "range")] Range(RangeAggregation), - /// Put data into buckets of user-defined ranges. + /// Put data into a histogram. #[serde(rename = "histogram")] Histogram(HistogramAggregation), - /// Put data into buckets of user-defined ranges. + /// Put data into a date histogram. #[serde(rename = "date_histogram")] DateHistogram(DateHistogramAggregationReq), /// Put data into buckets of terms. #[serde(rename = "terms")] Terms(TermsAggregation), -} -impl BucketAggregationType { - fn get_fast_field_name(&self) -> &str { - match self { - BucketAggregationType::Terms(terms) => terms.field.as_str(), - BucketAggregationType::Range(range) => range.field.as_str(), - BucketAggregationType::Histogram(histogram) => histogram.field.as_str(), - BucketAggregationType::DateHistogram(histogram) => histogram.field.as_str(), - } - } -} - -/// The aggregations in this family compute metrics based on values extracted -/// from the documents that are being aggregated. Values are extracted from the fast field of -/// the document. - -/// Some aggregations output a single numeric metric (e.g. Average) and are called -/// single-value numeric metrics aggregation, others generate multiple metrics (e.g. Stats) and are -/// called multi-value numeric metrics aggregation. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum MetricAggregation { + // Metric aggregation types /// Computes the average of the extracted values. 
#[serde(rename = "avg")] Average(AverageAggregation), @@ -213,31 +122,102 @@ pub enum MetricAggregation { Percentiles(PercentilesAggregationReq), } -impl MetricAggregation { - pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> { +impl AggregationVariants { + fn get_fast_field_name(&self) -> &str { + match self { + AggregationVariants::Terms(terms) => terms.field.as_str(), + AggregationVariants::Range(range) => range.field.as_str(), + AggregationVariants::Histogram(histogram) => histogram.field.as_str(), + AggregationVariants::DateHistogram(histogram) => histogram.field.as_str(), + AggregationVariants::Average(avg) => avg.field_name(), + AggregationVariants::Count(count) => count.field_name(), + AggregationVariants::Max(max) => max.field_name(), + AggregationVariants::Min(min) => min.field_name(), + AggregationVariants::Stats(stats) => stats.field_name(), + AggregationVariants::Sum(sum) => sum.field_name(), + AggregationVariants::Percentiles(per) => per.field_name(), + } + } + + pub(crate) fn as_range(&self) -> Option<&RangeAggregation> { + match &self { + AggregationVariants::Range(range) => Some(range), + _ => None, + } + } + pub(crate) fn as_histogram(&self) -> crate::Result> { + match &self { + AggregationVariants::Histogram(histogram) => Ok(Some(histogram.clone())), + AggregationVariants::DateHistogram(histogram) => { + Ok(Some(histogram.to_histogram_req()?)) + } + _ => Ok(None), + } + } + pub(crate) fn as_term(&self) -> Option<&TermsAggregation> { match &self { - MetricAggregation::Percentiles(percentile_req) => Some(percentile_req), + AggregationVariants::Terms(terms) => Some(terms), _ => None, } } - fn get_fast_field_name(&self) -> &str { - match self { - MetricAggregation::Average(avg) => avg.field_name(), - MetricAggregation::Count(count) => count.field_name(), - MetricAggregation::Max(max) => max.field_name(), - MetricAggregation::Min(min) => min.field_name(), - MetricAggregation::Stats(stats) => stats.field_name(), - 
MetricAggregation::Sum(sum) => sum.field_name(), - MetricAggregation::Percentiles(per) => per.field_name(), + pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> { + match &self { + AggregationVariants::Percentiles(percentile_req) => Some(percentile_req), + _ => None, } } } #[cfg(test)] mod tests { + use super::*; + #[test] + fn deser_json_test() { + let agg_req_json = r#"{ + "price_avg": { "avg": { "field": "price" } }, + "price_count": { "value_count": { "field": "price" } }, + "price_max": { "max": { "field": "price" } }, + "price_min": { "min": { "field": "price" } }, + "price_stats": { "stats": { "field": "price" } }, + "price_sum": { "sum": { "field": "price" } } + }"#; + let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); + } + + #[test] + fn deser_json_test_bucket() { + let agg_req_json = r#" + { + "termagg": { + "terms": { + "field": "json.mixed_type", + "order": { "min_price": "desc" } + }, + "aggs": { + "min_price": { "min": { "field": "json.mixed_type" } } + } + }, + "rangeagg": { + "range": { + "field": "json.mixed_type", + "ranges": [ + { "to": 3.0 }, + { "from": 19.0, "to": 20.0 }, + { "from": 20.0 } + ] + }, + "aggs": { + "average_in_range": { "avg": { "field": "json.mixed_type" } } + } + } + } "#; + + let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); + } + #[test] fn test_metric_aggregations_deser() { let agg_req_json = r#"{ @@ -251,22 +231,22 @@ mod tests { let agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); assert!( - matches!(agg_req.get("price_avg").unwrap(), Aggregation::Metric(MetricAggregation::Average(avg)) if avg.field == "price") + matches!(&agg_req.get("price_avg").unwrap().agg, AggregationVariants::Average(avg) if avg.field == "price") ); assert!( - matches!(agg_req.get("price_count").unwrap(), Aggregation::Metric(MetricAggregation::Count(count)) if count.field == "price") + matches!(&agg_req.get("price_count").unwrap().agg, 
AggregationVariants::Count(count) if count.field == "price") ); assert!( - matches!(agg_req.get("price_max").unwrap(), Aggregation::Metric(MetricAggregation::Max(max)) if max.field == "price") + matches!(&agg_req.get("price_max").unwrap().agg, AggregationVariants::Max(max) if max.field == "price") ); assert!( - matches!(agg_req.get("price_min").unwrap(), Aggregation::Metric(MetricAggregation::Min(min)) if min.field == "price") + matches!(&agg_req.get("price_min").unwrap().agg, AggregationVariants::Min(min) if min.field == "price") ); assert!( - matches!(agg_req.get("price_stats").unwrap(), Aggregation::Metric(MetricAggregation::Stats(stats)) if stats.field == "price") + matches!(&agg_req.get("price_stats").unwrap().agg, AggregationVariants::Stats(stats) if stats.field == "price") ); assert!( - matches!(agg_req.get("price_sum").unwrap(), Aggregation::Metric(MetricAggregation::Sum(sum)) if sum.field == "price") + matches!(&agg_req.get("price_sum").unwrap().agg, AggregationVariants::Sum(sum) if sum.field == "price") ); } diff --git a/src/aggregation/agg_req_deser.rs b/src/aggregation/agg_req_deser.rs deleted file mode 100644 index 6546024b40..0000000000 --- a/src/aggregation/agg_req_deser.rs +++ /dev/null @@ -1,173 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use serde::*; - -use super::bucket::*; -use super::metric::*; -pub type Aggregations = HashMap; - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct AggregationDeser { - /// Bucket aggregation strategy to group documents. - #[serde(flatten)] - pub agg: AggregationVariants, - /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the - /// bucket. 
- #[serde(rename = "aggs")] - #[serde(default)] - #[serde(skip_serializing_if = "Aggregations::is_empty")] - pub sub_aggregation: Aggregations, -} - -impl AggregationDeser { - fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { - fast_field_names.insert(self.agg.get_fast_field_name().to_string()); - fast_field_names.extend(get_fast_field_names(&self.sub_aggregation)); - } -} - -/// Extract all fast field names used in the tree. -pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet { - let mut fast_field_names = Default::default(); - for el in aggs.values() { - el.get_fast_field_names(&mut fast_field_names) - } - fast_field_names -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum AggregationVariants { - // Bucket aggregation types - /// Put data into buckets of user-defined ranges. - #[serde(rename = "range")] - Range(RangeAggregation), - /// Put data into a histogram. - #[serde(rename = "histogram")] - Histogram(HistogramAggregation), - /// Put data into a date histogram. - #[serde(rename = "date_histogram")] - DateHistogram(DateHistogramAggregationReq), - /// Put data into buckets of terms. - #[serde(rename = "terms")] - Terms(TermsAggregation), - - // Metric aggregation types - /// Computes the average of the extracted values. - #[serde(rename = "avg")] - Average(AverageAggregation), - /// Counts the number of extracted values. - #[serde(rename = "value_count")] - Count(CountAggregation), - /// Finds the maximum value. - #[serde(rename = "max")] - Max(MaxAggregation), - /// Finds the minimum value. - #[serde(rename = "min")] - Min(MinAggregation), - /// Computes a collection of statistics (`min`, `max`, `sum`, `count`, and `avg`) over the - /// extracted values. - #[serde(rename = "stats")] - Stats(StatsAggregation), - /// Computes the sum of the extracted values. - #[serde(rename = "sum")] - Sum(SumAggregation), - /// Computes the sum of the extracted values. 
- #[serde(rename = "percentiles")] - Percentiles(PercentilesAggregationReq), -} - -impl AggregationVariants { - fn get_fast_field_name(&self) -> &str { - match self { - AggregationVariants::Terms(terms) => terms.field.as_str(), - AggregationVariants::Range(range) => range.field.as_str(), - AggregationVariants::Histogram(histogram) => histogram.field.as_str(), - AggregationVariants::DateHistogram(histogram) => histogram.field.as_str(), - AggregationVariants::Average(avg) => avg.field_name(), - AggregationVariants::Count(count) => count.field_name(), - AggregationVariants::Max(max) => max.field_name(), - AggregationVariants::Min(min) => min.field_name(), - AggregationVariants::Stats(stats) => stats.field_name(), - AggregationVariants::Sum(sum) => sum.field_name(), - AggregationVariants::Percentiles(per) => per.field_name(), - } - } - - pub(crate) fn as_range(&self) -> Option<&RangeAggregation> { - match &self { - AggregationVariants::Range(range) => Some(range), - _ => None, - } - } - pub(crate) fn as_histogram(&self) -> crate::Result> { - match &self { - AggregationVariants::Histogram(histogram) => Ok(Some(histogram.clone())), - AggregationVariants::DateHistogram(histogram) => { - Ok(Some(histogram.to_histogram_req()?)) - } - _ => Ok(None), - } - } - pub(crate) fn as_term(&self) -> Option<&TermsAggregation> { - match &self { - AggregationVariants::Terms(terms) => Some(terms), - _ => None, - } - } - - pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> { - match &self { - AggregationVariants::Percentiles(percentile_req) => Some(percentile_req), - _ => None, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn deser_json_test() { - let agg_req_json = r#"{ - "price_avg": { "avg": { "field": "price" } }, - "price_count": { "value_count": { "field": "price" } }, - "price_max": { "max": { "field": "price" } }, - "price_min": { "min": { "field": "price" } }, - "price_stats": { "stats": { "field": "price" } }, - "price_sum": { "sum": 
{ "field": "price" } } - }"#; - let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); - } - - #[test] - fn deser_json_test_bucket() { - let agg_req_json = r#" - { - "termagg": { - "terms": { - "field": "json.mixed_type", - "order": { "min_price": "desc" } - }, - "aggs": { - "min_price": { "min": { "field": "json.mixed_type" } } - } - }, - "rangeagg": { - "range": { - "field": "json.mixed_type", - "ranges": [ - { "to": 3.0 }, - { "from": 19.0, "to": 20.0 }, - { "from": 20.0 } - ] - }, - "aggs": { - "average_in_range": { "avg": { "field": "json.mixed_type" } } - } - } - } "#; - - let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); - } -} diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 3af5de2593..1ac2b7da6c 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -2,7 +2,7 @@ use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn}; -use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; +use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, }; @@ -16,20 +16,16 @@ use crate::SegmentReader; #[derive(Clone, Default)] pub(crate) struct AggregationsWithAccessor { - pub metrics: VecWithNames, - pub buckets: VecWithNames, + pub aggs: VecWithNames, } impl AggregationsWithAccessor { - fn from_data( - metrics: VecWithNames, - buckets: VecWithNames, - ) -> Self { - Self { metrics, buckets } + fn from_data(aggs: VecWithNames) -> Self { + Self { aggs } } pub fn is_empty(&self) -> bool { - self.metrics.is_empty() && self.buckets.is_empty() + self.aggs.is_empty() } } @@ -54,62 +50,57 @@ impl AggregationWithAccessor { limits: AggregationLimits, ) -> crate::Result { let mut str_dict_column = None; - let (accessor, field_type) = match &agg { - Aggregation::Bucket(b) => match 
&b.bucket_agg { - BucketAggregationType::Range(RangeAggregation { - field: field_name, .. - }) => get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::Histogram(HistogramAggregation { - field: field_name, - .. - }) => get_ff_reader_and_validate( + use AggregationVariants::*; + let (accessor, field_type) = match &agg.agg { + Range(RangeAggregation { + field: field_name, .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + Histogram(HistogramAggregation { + field: field_name, .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + DateHistogram(DateHistogramAggregationReq { + field: field_name, .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + Terms(TermsAggregation { + field: field_name, .. + }) => { + str_dict_column = reader.fast_fields().str(field_name)?; + get_ff_reader_and_validate(reader, field_name, None)? + } + Average(AverageAggregation { field: field_name }) + | Count(CountAggregation { field: field_name }) + | Max(MaxAggregation { field: field_name }) + | Min(MinAggregation { field: field_name }) + | Stats(StatsAggregation { field: field_name }) + | Sum(SumAggregation { field: field_name }) => { + let (accessor, field_type) = get_ff_reader_and_validate( reader, field_name, Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::DateHistogram(DateHistogramAggregationReq { - field: field_name, - .. - }) => get_ff_reader_and_validate( + )?; + + (accessor, field_type) + } + Percentiles(percentiles) => { + let (accessor, field_type) = get_ff_reader_and_validate( reader, - field_name, + percentiles.field_name(), Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::Terms(TermsAggregation { - field: field_name, .. 
- }) => { - str_dict_column = reader.fast_fields().str(field_name)?; - get_ff_reader_and_validate(reader, field_name, None)? - } - }, - Aggregation::Metric(metric) => match &metric { - MetricAggregation::Average(AverageAggregation { field: field_name }) - | MetricAggregation::Count(CountAggregation { field: field_name }) - | MetricAggregation::Max(MaxAggregation { field: field_name }) - | MetricAggregation::Min(MinAggregation { field: field_name }) - | MetricAggregation::Stats(StatsAggregation { field: field_name }) - | MetricAggregation::Sum(SumAggregation { field: field_name }) => { - let (accessor, field_type) = get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?; - - (accessor, field_type) - } - MetricAggregation::Percentiles(percentiles) => { - let (accessor, field_type) = get_ff_reader_and_validate( - reader, - percentiles.field_name(), - Some(get_numeric_or_date_column_types()), - )?; - (accessor, field_type) - } - }, + )?; + (accessor, field_type) + } }; let sub_aggregation = sub_aggregation.clone(); Ok(AggregationWithAccessor { @@ -142,33 +133,20 @@ pub(crate) fn get_aggs_with_accessor_and_validate( reader: &SegmentReader, limits: &AggregationLimits, ) -> crate::Result { - let mut metrics = Vec::new(); - let mut buckets = Vec::new(); + let mut aggss = Vec::new(); for (key, agg) in aggs.iter() { - match agg { - Aggregation::Bucket(bucket) => buckets.push(( - key.to_string(), - AggregationWithAccessor::try_from_agg( - &agg, - bucket.get_sub_aggs(), - reader, - limits.clone(), - )?, - )), - Aggregation::Metric(_metric) => metrics.push(( - key.to_string(), - AggregationWithAccessor::try_from_agg( - &agg, - &Default::default(), - reader, - limits.clone(), - )?, - )), - } + aggss.push(( + key.to_string(), + AggregationWithAccessor::try_from_agg( + agg, + agg.sub_aggregation(), + reader, + limits.clone(), + )?, + )); } Ok(AggregationsWithAccessor::from_data( - VecWithNames::from_entries(metrics), - 
VecWithNames::from_entries(buckets), + VecWithNames::from_entries(aggss), )) } diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 3f3b73b4c3..ad1dc2494e 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -1,11 +1,10 @@ use serde_json::Value; -use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation}; +use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; -use crate::aggregation::metric::AverageAggregation; use crate::aggregation::segment_agg_result::AggregationLimits; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; @@ -14,9 +13,12 @@ use crate::schema::{IndexRecordOption, Schema, FAST}; use crate::{Index, Term}; fn get_avg_req(field_name: &str) -> Aggregation { - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name(field_name.to_string()), - )) + serde_json::from_value(json!({ + "avg": { + "field": field_name, + } + })) + .unwrap() } fn get_collector(agg_req: Aggregations) -> AggregationCollector { @@ -517,14 +519,14 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> { let reader = index.reader()?; let avg_on_field = |field_name: &str| { - let agg_req_1: Aggregations = vec![( - "average".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name(field_name.to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average": { + "avg": { + "field": field_name, + }, + } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -539,6 +541,32 @@ fn 
test_aggregation_invalid_requests() -> crate::Result<()> { r#"InvalidArgument("Field \"dummy_text\" is not configured as fast field")"# ); + let agg_req_1: Result = serde_json::from_value(json!({ + "average": { + "avg": { + "fieldd": "a", + }, + } + })); + + assert_eq!(agg_req_1.is_err(), true); + assert_eq!(agg_req_1.unwrap_err().to_string(), "missing field `field`"); + + let agg_req_1: Result = serde_json::from_value(json!({ + "average": { + "doesnotmatchanyagg": { + "field": "a", + }, + } + })); + + assert_eq!(agg_req_1.is_err(), true); + // TODO: This should list valid values + assert_eq!( + agg_req_1.unwrap_err().to_string(), + "no variant of enum AggregationVariants found in flattened data" + ); + // TODO: This should return an error // let agg_res = avg_on_field("not_exist_field").unwrap_err(); // assert_eq!( diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 1eefa77b5c..7d6112f597 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -223,8 +223,8 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { agg_with_accessor: &AggregationsWithAccessor, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); - let agg_with_accessor = &agg_with_accessor.buckets.values[self.accessor_idx]; + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?; results.push(name, IntermediateAggregationResult::Bucket(bucket)); @@ -247,7 +247,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx]; + let 
bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];

         let mem_pre = self.get_memory_consumption();

@@ -281,7 +281,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
         }

         let mem_delta = self.get_memory_consumption() - mem_pre;
-        let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits;
+        let limits = &agg_with_accessor.aggs.values[self.accessor_idx].limits;
         limits.add_memory_consumed(mem_delta as u64);
         limits.validate_memory_consumption()?;

@@ -290,7 +290,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {

     fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
         let sub_aggregation_accessor =
-            &mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
+            &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;

         for sub_aggregation in self.sub_aggregations.values_mut() {
             sub_aggregation.flush(sub_aggregation_accessor)?;
diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs
index 3d178d9625..3ccc53e974 100644
--- a/src/aggregation/bucket/mod.rs
+++ b/src/aggregation/bucket/mod.rs
@@ -2,6 +2,16 @@
 //!
 //! BucketAggregations create buckets of documents
 //! [`BucketAggregation`](super::agg_req::BucketAggregation).
+//! Each bucket is associated with a rule which
+//! determines whether or not a document falls into it. In other words, the buckets
+//! effectively define document sets. Buckets are not necessarily disjunct, therefore a document can
+//! fall into multiple buckets. In addition to the buckets themselves, the bucket aggregations also
+//! compute and return the number of documents for each bucket. Bucket aggregations, as opposed to
+//! metric aggregations, can hold sub-aggregations. These sub-aggregations will be aggregated for
+//! the buckets created by their "parent" bucket aggregation. There are different bucket
+//! aggregators, each with a different "bucketing" strategy.
Some define a single bucket, some +//! define fixed number of multiple buckets, and others dynamically create the buckets during the +//! aggregation process. //! //! Results of final buckets are [`BucketResult`](super::agg_result::BucketResult). //! Results of intermediate buckets are diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 972bdd1dd0..82c4cbddc3 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -182,8 +182,8 @@ impl SegmentAggregationCollector for SegmentRangeCollector { results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { let field_type = self.column_type; - let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); - let sub_agg = &agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let sub_agg = &agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; let buckets: FxHashMap = self .buckets @@ -223,7 +223,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx]; + let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; bucket_agg_accessor .column_block_accessor @@ -245,7 +245,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector { fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { let sub_aggregation_accessor = - &mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; + &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; for bucket in self.buckets.iter_mut() { if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 779d765e12..0a86ecd674 100644 --- 
a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -251,8 +251,8 @@ impl SegmentAggregationCollector for SegmentTermCollector { agg_with_accessor: &AggregationsWithAccessor, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); - let agg_with_accessor = &agg_with_accessor.buckets.values[self.accessor_idx]; + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?; results.push(name, IntermediateAggregationResult::Bucket(bucket)); @@ -275,7 +275,7 @@ impl SegmentAggregationCollector for SegmentTermCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx]; + let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; let mem_pre = self.get_memory_consumption(); @@ -299,7 +299,7 @@ impl SegmentAggregationCollector for SegmentTermCollector { } let mem_delta = self.get_memory_consumption() - mem_pre; - let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits; + let limits = &agg_with_accessor.aggs.values[self.accessor_idx].limits; limits.add_memory_consumed(mem_delta as u64); limits.validate_memory_consumption()?; @@ -308,7 +308,7 @@ impl SegmentAggregationCollector for SegmentTermCollector { fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { let sub_aggregation_accessor = - &mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; + &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; self.term_buckets.force_flush(sub_aggregation_accessor)?; Ok(()) @@ -335,7 +335,7 @@ impl SegmentTermCollector { if let OrderTarget::SubAggregation(sub_agg_name) 
= &custom_order.target { let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); - sub_aggregations.metrics.get(agg_name).ok_or_else(|| { + sub_aggregations.aggs.get(agg_name).ok_or_else(|| { TantivyError::InvalidArgument(format!( "could not find aggregation with name {} in metric sub_aggregations", agg_name @@ -523,8 +523,7 @@ pub(crate) fn cut_off_buckets( #[cfg(test)] mod tests { - use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation}; - use crate::aggregation::metric::{AverageAggregation, StatsAggregation}; + use crate::aggregation::agg_req::Aggregations; use crate::aggregation::tests::{ exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit, get_test_index_from_terms, get_test_index_from_values_and_terms, @@ -639,23 +638,6 @@ mod tests { ]; let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?; - let _sub_agg: Aggregations = vec![ - ( - "avg_score".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - ), - ( - "stats_score".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score".to_string(), - ))), - ), - ] - .into_iter() - .collect(); - let sub_agg: Aggregations = serde_json::from_value(json!({ "avg_score": { "avg": { diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 79973489e0..8956d8af96 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -10,9 +10,7 @@ use rustc_hash::FxHashMap; use serde::ser::SerializeSeq; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use super::agg_req::{ - Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation, -}; +use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; use super::agg_result::{AggregationResult, BucketResult, MetricResult, 
RangeBucketEntry}; use super::bucket::{ cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets, @@ -48,7 +46,7 @@ impl IntermediateAggregationResults { req: Aggregations, limits: &AggregationLimits, ) -> crate::Result { - let res = self.into_final_result_internal(&(req.into()), limits)?; + let res = self.into_final_result_internal(&req, limits)?; let bucket_count = res.get_bucket_count() as u32; if bucket_count > limits.get_bucket_limit() { return Err(TantivyError::AggregationError( @@ -79,14 +77,7 @@ impl IntermediateAggregationResults { if results.len() != req.len() { for (key, req) in req.iter() { if !results.contains_key(key) { - let empty_res = match req { - Aggregation::Bucket(b) => IntermediateAggregationResult::Bucket( - IntermediateBucketResult::empty_from_req(&b.bucket_agg), - ), - Aggregation::Metric(m) => IntermediateAggregationResult::Metric( - IntermediateMetricResult::empty_from_req(m), - ), - }; + let empty_res = empty_from_req(req); results.insert(key.to_string(), empty_res.into_final_result(req, limits)?); } } @@ -98,14 +89,7 @@ impl IntermediateAggregationResults { pub(crate) fn empty_from_req(req: &Aggregations) -> Self { let mut aggs_res: VecWithNames = VecWithNames::default(); for (key, req) in req.iter() { - let empty_res = match req { - Aggregation::Bucket(b) => IntermediateAggregationResult::Bucket( - IntermediateBucketResult::empty_from_req(&b.bucket_agg), - ), - Aggregation::Metric(m) => IntermediateAggregationResult::Metric( - IntermediateMetricResult::empty_from_req(m), - ), - }; + let empty_res = empty_from_req(req); aggs_res.push(key.to_string(), empty_res); } @@ -124,6 +108,45 @@ impl IntermediateAggregationResults { } } +pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult { + use AggregationVariants::*; + match req.agg { + Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms( + Default::default(), + )), + Range(_) => 
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range( + Default::default(), + )), + Histogram(_) | DateHistogram(_) => { + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Histogram { + buckets: Vec::new(), + column_type: None, + }) + } + Average(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Average( + IntermediateAverage::default(), + )), + Count(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Count( + IntermediateCount::default(), + )), + Max(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Max( + IntermediateMax::default(), + )), + Min(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Min( + IntermediateMin::default(), + )), + Stats(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Stats( + IntermediateStats::default(), + )), + Sum(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Sum( + IntermediateSum::default(), + )), + Percentiles(_) => IntermediateAggregationResult::Metric( + IntermediateMetricResult::Percentiles(PercentilesCollector::default()), + ), + } +} + /// An aggregation is either a bucket or a metric. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum IntermediateAggregationResult { @@ -140,19 +163,12 @@ impl IntermediateAggregationResult { limits: &AggregationLimits, ) -> crate::Result { let res = match self { - IntermediateAggregationResult::Bucket(bucket) => AggregationResult::BucketResult( - bucket.into_final_bucket_result( - req.as_bucket() - .expect("mismatch bucket result and metric request type"), - limits, - )?, - ), - IntermediateAggregationResult::Metric(metric) => AggregationResult::MetricResult( - metric.into_final_metric_result( - req.as_metric() - .expect("mismatch metric result and bucket request type"), - ), - ), + IntermediateAggregationResult::Bucket(bucket) => { + AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?) 
+ } + IntermediateAggregationResult::Metric(metric) => { + AggregationResult::MetricResult(metric.into_final_metric_result(req)) + } }; Ok(res) } @@ -191,7 +207,7 @@ pub enum IntermediateMetricResult { } impl IntermediateMetricResult { - fn into_final_metric_result(self, req: &MetricAggregation) -> MetricResult { + fn into_final_metric_result(self, req: &Aggregation) -> MetricResult { match self { IntermediateMetricResult::Average(intermediate_avg) => { MetricResult::Average(intermediate_avg.finalize().into()) @@ -212,30 +228,12 @@ impl IntermediateMetricResult { MetricResult::Sum(intermediate_sum.finalize().into()) } IntermediateMetricResult::Percentiles(percentiles) => MetricResult::Percentiles( - percentiles.into_final_result(req.as_percentile().expect("unexpected metric type")), + percentiles + .into_final_result(req.agg.as_percentile().expect("unexpected metric type")), ), } } - pub(crate) fn empty_from_req(req: &MetricAggregation) -> Self { - match req { - MetricAggregation::Average(_) => { - IntermediateMetricResult::Average(IntermediateAverage::default()) - } - MetricAggregation::Count(_) => { - IntermediateMetricResult::Count(IntermediateCount::default()) - } - MetricAggregation::Max(_) => IntermediateMetricResult::Max(IntermediateMax::default()), - MetricAggregation::Min(_) => IntermediateMetricResult::Min(IntermediateMin::default()), - MetricAggregation::Stats(_) => { - IntermediateMetricResult::Stats(IntermediateStats::default()) - } - MetricAggregation::Sum(_) => IntermediateMetricResult::Sum(IntermediateSum::default()), - MetricAggregation::Percentiles(_) => { - IntermediateMetricResult::Percentiles(PercentilesCollector::default()) - } - } - } fn merge_fruits(&mut self, other: IntermediateMetricResult) -> crate::Result<()> { match (self, other) { ( @@ -302,7 +300,7 @@ pub enum IntermediateBucketResult { impl IntermediateBucketResult { pub(crate) fn into_final_bucket_result( self, - req: &BucketAggregation, + req: &Aggregation, limits: 
&AggregationLimits, ) -> crate::Result { match self { @@ -313,7 +311,8 @@ impl IntermediateBucketResult { .map(|bucket| { bucket.into_final_bucket_entry( req.sub_aggregation(), - req.as_range() + req.agg + .as_range() .expect("unexpected aggregation, expected histogram aggregation"), range_res.column_type, limits, @@ -328,6 +327,7 @@ impl IntermediateBucketResult { }); let is_keyed = req + .agg .as_range() .expect("unexpected aggregation, expected range aggregation") .keyed; @@ -348,6 +348,7 @@ impl IntermediateBucketResult { buckets, } => { let histogram_req = &req + .agg .as_histogram()? .expect("unexpected aggregation, expected histogram aggregation"); let buckets = intermediate_histogram_buckets_to_final_buckets( @@ -371,7 +372,8 @@ impl IntermediateBucketResult { Ok(BucketResult::Histogram { buckets }) } IntermediateBucketResult::Terms(terms) => terms.into_final_result( - req.as_term() + req.agg + .as_term() .expect("unexpected aggregation, expected term aggregation"), req.sub_aggregation(), limits, @@ -379,18 +381,6 @@ impl IntermediateBucketResult { } } - pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self { - match req { - BucketAggregationType::Terms(_) => IntermediateBucketResult::Terms(Default::default()), - BucketAggregationType::Range(_) => IntermediateBucketResult::Range(Default::default()), - BucketAggregationType::Histogram(_) | BucketAggregationType::DateHistogram(_) => { - IntermediateBucketResult::Histogram { - buckets: vec![], - column_type: None, - } - } - } - } fn merge_fruits(&mut self, other: IntermediateBucketResult) -> crate::Result<()> { match (self, other) { ( diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index b1d05a05d3..50ff0389ff 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -1,7 +1,12 @@ //! Module for all metric aggregations. //! -//! The aggregations in this family compute metrics, see [super::agg_req::MetricAggregation] for -//! details. +//! 
The aggregations in this family compute metrics based on values extracted +//! from the documents that are being aggregated. Values are extracted from the fast field of +//! the document. +//! Some aggregations output a single numeric metric (e.g. Average) and are called +//! single-value numeric metrics aggregation, others generate multiple metrics (e.g. Stats) and are +//! called multi-value numeric metrics aggregation. + mod average; mod count; mod max; diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index 17fbb4ef90..db496a7dcd 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -260,7 +260,7 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { agg_with_accessor: &AggregationsWithAccessor, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string(); + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); results.push( @@ -277,7 +277,7 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { doc: crate::DocId, agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let field = &agg_with_accessor.metrics.values[self.accessor_idx].accessor; + let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor; for val in field.values_for_doc(doc) { let val1 = f64_from_fastfield_u64(val, &self.field_type); @@ -293,7 +293,7 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let field = &mut agg_with_accessor.metrics.values[self.accessor_idx]; + let field = &mut agg_with_accessor.aggs.values[self.accessor_idx]; self.collect_block_with_field(docs, field); Ok(()) } @@ -308,9 +308,8 @@ mod tests { use rand::SeedableRng; use 
serde_json::Value; - use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation}; + use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_result::AggregationResults; - use crate::aggregation::metric::PercentilesAggregationReq; use crate::aggregation::tests::{ get_test_index_from_values, get_test_index_from_values_and_terms, }; @@ -324,14 +323,14 @@ mod tests { let index = get_test_index_from_values(false, &values)?; - let agg_req_1: Aggregations = vec![( - "percentiles".to_string(), - Aggregation::Metric(MetricAggregation::Percentiles( - PercentilesAggregationReq::from_field_name("score".to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "percentiles": { + "percentiles": { + "field": "score", + } + }, + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); @@ -362,14 +361,14 @@ mod tests { let index = get_test_index_from_values(false, &values)?; - let agg_req_1: Aggregations = vec![( - "percentiles".to_string(), - Aggregation::Metric(MetricAggregation::Percentiles( - PercentilesAggregationReq::from_field_name("score".to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "percentiles": { + "percentiles": { + "field": "score", + } + }, + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 88c1ef83fd..bd63f08dd5 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -199,7 +199,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector { agg_with_accessor: &AggregationsWithAccessor, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string(); + let name = 
agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); let intermediate_metric_result = match self.collecting_for { SegmentStatsType::Average => { @@ -234,7 +234,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector { doc: crate::DocId, agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let field = &agg_with_accessor.metrics.values[self.accessor_idx].accessor; + let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor; for val in field.values_for_doc(doc) { let val1 = f64_from_fastfield_u64(val, &self.field_type); @@ -250,7 +250,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let field = &mut agg_with_accessor.metrics.values[self.accessor_idx]; + let field = &mut agg_with_accessor.aggs.values[self.accessor_idx]; self.collect_block_with_field(docs, field); Ok(()) } @@ -261,9 +261,8 @@ mod tests { use serde_json::Value; - use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation}; + use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; - use crate::aggregation::metric::StatsAggregation; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values}; use crate::aggregation::AggregationCollector; use crate::query::{AllQuery, TermQuery}; @@ -277,14 +276,14 @@ mod tests { let index = get_test_index_from_values(false, &values)?; - let agg_req_1: Aggregations = vec![( - "stats".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score".to_string(), - ))), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "stats": { + "stats": { + "field": "score", + }, + } + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); @@ -313,14 +312,14 @@ mod tests { let index = 
get_test_index_from_values(false, &values)?; - let agg_req_1: Aggregations = vec![( - "stats".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score".to_string(), - ))), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "stats": { + "stats": { + "field": "score", + }, + } + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); @@ -372,29 +371,25 @@ mod tests { .unwrap() }; - let agg_req_1: Aggregations = vec![ - ( - "stats_i64".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score_i64".to_string(), - ))), - ), - ( - "stats_f64".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score_f64".to_string(), - ))), - ), - ( - "stats".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score".to_string(), - ))), - ), - ("range".to_string(), range_agg), - ] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "stats_i64": { + "stats": { + "field": "score_i64", + }, + }, + "stats_f64": { + "stats": { + "field": "score_f64", + }, + }, + "stats": { + "stats": { + "field": "score", + }, + }, + "range": range_agg + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 38bfce95b4..2e8e0dbedb 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -157,7 +157,6 @@ mod agg_limits; pub mod agg_req; -pub mod agg_req_deser; mod agg_req_with_accessor; pub mod agg_result; pub mod bucket; diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 771bfce208..1353c56c68 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; pub(crate) 
use super::agg_limits::AggregationLimits; -use super::agg_req::{Aggregation, MetricAggregation}; +use super::agg_req::AggregationVariants; use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor}; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; use super::intermediate_agg_result::IntermediateAggregationResults; @@ -15,7 +15,6 @@ use super::metric::{ SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation, SumAggregation, }; -use crate::aggregation::agg_req::BucketAggregationType; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { fn add_intermediate_aggregation_result( @@ -64,16 +63,9 @@ impl Clone for Box { pub(crate) fn build_segment_agg_collector( req: &AggregationsWithAccessor, ) -> crate::Result> { - // Single metric special case - if req.buckets.is_empty() && req.metrics.len() == 1 { - let req = &req.metrics.values[0]; - let accessor_idx = 0; - return build_single_agg_segment_collector(req, accessor_idx); - } - - // Single bucket special case - if req.metrics.is_empty() && req.buckets.len() == 1 { - let req = &req.buckets.values[0]; + // Single collector special case + if req.aggs.is_empty() && req.aggs.len() == 1 { + let req = &req.aggs.values[0]; let accessor_idx = 0; return build_single_agg_segment_collector(req, accessor_idx); } @@ -86,93 +78,70 @@ pub(crate) fn build_single_agg_segment_collector( req: &AggregationWithAccessor, accessor_idx: usize, ) -> crate::Result> { - match &req.agg { - Aggregation::Bucket(bucket) => match &bucket.bucket_agg { - BucketAggregationType::Terms(terms_req) => { - Ok(Box::new(SegmentTermCollector::from_req_and_validate( - terms_req, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } - BucketAggregationType::Range(range_req) => { - Ok(Box::new(SegmentRangeCollector::from_req_and_validate( - range_req, - &req.sub_aggregation, - &req.limits, - req.field_type, - accessor_idx, - )?)) - } 
- BucketAggregationType::Histogram(histogram) => { - Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( - histogram, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } - BucketAggregationType::DateHistogram(histogram) => { - Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( - &histogram.to_histogram_req()?, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } - }, - Aggregation::Metric(metric) => match &metric { - MetricAggregation::Average(AverageAggregation { .. }) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Average, - accessor_idx, - ))) - } - MetricAggregation::Count(CountAggregation { .. }) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Count, - accessor_idx, - ))) - } - MetricAggregation::Max(MaxAggregation { .. }) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Max, - accessor_idx, - ))) - } - MetricAggregation::Min(MinAggregation { .. }) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Min, - accessor_idx, - ))) - } - MetricAggregation::Stats(StatsAggregation { .. }) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Stats, - accessor_idx, - ))) - } - MetricAggregation::Sum(SumAggregation { .. 
}) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Sum, - accessor_idx, - ))) - } - MetricAggregation::Percentiles(percentiles_req) => Ok(Box::new( - SegmentPercentilesCollector::from_req_and_validate( - percentiles_req, - req.field_type, - accessor_idx, - )?, - )), - }, + use AggregationVariants::*; + match &req.agg.agg { + Terms(terms_req) => Ok(Box::new(SegmentTermCollector::from_req_and_validate( + terms_req, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)), + Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( + range_req, + &req.sub_aggregation, + &req.limits, + req.field_type, + accessor_idx, + )?)), + Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( + histogram, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)), + DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( + &histogram.to_histogram_req()?, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)), + Average(AverageAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Average, + accessor_idx, + ))), + Count(CountAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Count, + accessor_idx, + ))), + Max(MaxAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Max, + accessor_idx, + ))), + Min(MinAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Min, + accessor_idx, + ))), + Stats(StatsAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Stats, + accessor_idx, + ))), + Sum(SumAggregation { .. 
}) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Sum, + accessor_idx, + ))), + Percentiles(percentiles_req) => Ok(Box::new( + SegmentPercentilesCollector::from_req_and_validate( + percentiles_req, + req.field_type, + accessor_idx, + )?, + )), } } @@ -181,15 +150,13 @@ pub(crate) fn build_single_agg_segment_collector( /// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one /// and can provide specialized versions instead, that remove some of its overhead. pub(crate) struct GenericSegmentAggregationResultsCollector { - pub(crate) metrics: Option>>, - pub(crate) buckets: Option>>, + pub(crate) aggs: Vec>, } impl Debug for GenericSegmentAggregationResultsCollector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SegmentAggregationResultsCollector") - .field("metrics", &self.metrics) - .field("buckets", &self.buckets) + .field("aggs", &self.aggs) .finish() } } @@ -200,16 +167,9 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { agg_with_accessor: &AggregationsWithAccessor, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - if let Some(buckets) = self.buckets { - for bucket in buckets { - bucket.add_intermediate_aggregation_result(agg_with_accessor, results)?; - } - }; - if let Some(metrics) = self.metrics { - for metric in metrics { - metric.add_intermediate_aggregation_result(agg_with_accessor, results)?; - } - }; + for agg in self.aggs { + agg.add_intermediate_aggregation_result(agg_with_accessor, results)?; + } Ok(()) } @@ -229,31 +189,16 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - if let Some(metrics) = self.metrics.as_mut() { - for collector in metrics { - collector.collect_block(docs, agg_with_accessor)?; - } - } - - if let Some(buckets) = 
self.buckets.as_mut() { - for collector in buckets { - collector.collect_block(docs, agg_with_accessor)?; - } + for collector in &mut self.aggs { + collector.collect_block(docs, agg_with_accessor)?; } Ok(()) } fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { - if let Some(metrics) = &mut self.metrics { - for collector in metrics { - collector.flush(agg_with_accessor)?; - } - } - if let Some(buckets) = &mut self.buckets { - for collector in buckets { - collector.flush(agg_with_accessor)?; - } + for collector in &mut self.aggs { + collector.flush(agg_with_accessor)?; } Ok(()) } @@ -261,34 +206,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { impl GenericSegmentAggregationResultsCollector { pub(crate) fn from_req_and_validate(req: &AggregationsWithAccessor) -> crate::Result { - let buckets = req - .buckets + let aggs = req + .aggs .iter() .enumerate() .map(|(accessor_idx, (_key, req))| { build_single_agg_segment_collector(req, accessor_idx) }) .collect::>>>()?; - let metrics = req - .metrics - .iter() - .enumerate() - .map(|(accessor_idx, (_key, req))| { - build_single_agg_segment_collector(req, accessor_idx) - }) - .collect::>>>()?; - - let metrics = if metrics.is_empty() { - None - } else { - Some(metrics) - }; - let buckets = if buckets.is_empty() { - None - } else { - Some(buckets) - }; - Ok(GenericSegmentAggregationResultsCollector { metrics, buckets }) + Ok(GenericSegmentAggregationResultsCollector { aggs }) } } From e841842fea663aa8ff5b91b83fbf47219ac2c6c7 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Wed, 19 Apr 2023 10:15:04 +0800 Subject: [PATCH 5/5] fix doctests --- src/aggregation/mod.rs | 70 ++++++++++-------------------------------- 1 file changed, 17 insertions(+), 53 deletions(-) diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 2e8e0dbedb..c476938177 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -49,34 +49,6 @@ //! 
Compute the average metric, by building [`agg_req::Aggregations`], which is built from an //! `(String, agg_req::Aggregation)` iterator. //! -//! ``` -//! use tantivy::aggregation::agg_req::{Aggregations, Aggregation, MetricAggregation}; -//! use tantivy::aggregation::AggregationCollector; -//! use tantivy::aggregation::metric::AverageAggregation; -//! use tantivy::query::AllQuery; -//! use tantivy::aggregation::agg_result::AggregationResults; -//! use tantivy::IndexReader; -//! -//! # #[allow(dead_code)] -//! fn aggregate_on_index(reader: &IndexReader) { -//! let agg_req: Aggregations = vec![ -//! ( -//! "average".to_string(), -//! Aggregation::Metric(MetricAggregation::Average( -//! AverageAggregation::from_field_name("score".to_string()), -//! )), -//! ), -//! ] -//! .into_iter() -//! .collect(); -//! -//! let collector = AggregationCollector::from_aggs(agg_req, Default::default()); -//! -//! let searcher = reader.searcher(); -//! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); -//! } -//! ``` -//! # Example JSON //! Requests are compatible with the elasticsearch JSON request format. //! //! ``` @@ -116,32 +88,24 @@ //! aggregation and then calculate the average on each bucket. //! ``` //! use tantivy::aggregation::agg_req::*; -//! use tantivy::aggregation::metric::AverageAggregation; -//! use tantivy::aggregation::bucket::RangeAggregation; -//! let sub_agg_req_1: Aggregations = vec![( -//! "average_in_range".to_string(), -//! Aggregation::Metric(MetricAggregation::Average( -//! AverageAggregation::from_field_name("score".to_string()), -//! )), -//! )] -//! .into_iter() -//! .collect(); +//! use serde_json::json; //! -//! let agg_req_1: Aggregations = vec![ -//! ( -//! "range".to_string(), -//! Aggregation::Bucket(Box::new(BucketAggregation { -//! bucket_agg: BucketAggregationType::Range(RangeAggregation{ -//! field: "score".to_string(), -//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], -//! keyed: false, -//! 
}), -//! sub_aggregation: sub_agg_req_1.clone(), -//! })), -//! ), -//! ] -//! .into_iter() -//! .collect(); +//! let agg_req_1: Aggregations = serde_json::from_value(json!({ +//! "rangef64": { +//! "range": { +//! "field": "score", +//! "ranges": [ +//! { "from": 3, "to": 7000 }, +//! { "from": 7000, "to": 20000 }, +//! { "from": 50000, "to": 60000 } +//! ] +//! }, +//! "aggs": { +//! "average_in_range": { "avg": { "field": "score" } } +//! } +//! }, +//! })) +//! .unwrap(); //! ``` //! //! # Distributed Aggregation