From f2dad194eaca1519c93b2e33e8bec38de37893d4 Mon Sep 17 00:00:00 2001 From: Adrien Guillo Date: Fri, 13 Jan 2023 14:33:00 -0500 Subject: [PATCH] Add count, min, max, and sum aggregations --- src/aggregation/agg_req.rs | 62 ++++++++++++++++++-- src/aggregation/agg_req_with_accessor.rs | 11 +++- src/aggregation/agg_result.rs | 40 ++++++++++--- src/aggregation/intermediate_agg_result.rs | 62 +++++++++++++++++--- src/aggregation/metric/average.rs | 37 ++++++------ src/aggregation/metric/count.rs | 59 +++++++++++++++++++ src/aggregation/metric/max.rs | 59 +++++++++++++++++++ src/aggregation/metric/min.rs | 59 +++++++++++++++++++ src/aggregation/metric/mod.rs | 66 ++++++++++++++++++++++ src/aggregation/metric/stats.rs | 46 ++++++++------- src/aggregation/metric/sum.rs | 59 +++++++++++++++++++ src/aggregation/mod.rs | 4 +- src/aggregation/segment_agg_result.rs | 29 ++++++++-- 13 files changed, 522 insertions(+), 71 deletions(-) create mode 100644 src/aggregation/metric/count.rs create mode 100644 src/aggregation/metric/max.rs create mode 100644 src/aggregation/metric/min.rs create mode 100644 src/aggregation/metric/sum.rs diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 0fba56dc14..c412898d68 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -51,7 +51,10 @@ use serde::{Deserialize, Serialize}; pub use super::bucket::RangeAggregation; use super::bucket::{HistogramAggregation, TermsAggregation}; -use super::metric::{AverageAggregation, StatsAggregation}; +use super::metric::{ + AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation, + SumAggregation, +}; use super::VecWithNames; /// The top-level aggregation request structure, which contains [`Aggregation`] and their user @@ -237,20 +240,37 @@ impl BucketAggregationType { /// called multi-value numeric metrics aggregation. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum MetricAggregation { - /// Calculates the average. 
+ /// Computes the average. #[serde(rename = "avg")] Average(AverageAggregation), + /// Counts the number of extracted values. + #[serde(rename = "value_count")] + Count(CountAggregation), + /// Finds the maximum value. + #[serde(rename = "max")] + Max(MaxAggregation), + /// Finds the minimum value. + #[serde(rename = "min")] + Min(MinAggregation), /// Calculates stats sum, average, min, max, standard_deviation on a field. #[serde(rename = "stats")] Stats(StatsAggregation), + /// Computes the sum. + #[serde(rename = "sum")] + Sum(SumAggregation), } impl MetricAggregation { fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { - match self { - MetricAggregation::Average(avg) => fast_field_names.insert(avg.field.to_string()), - MetricAggregation::Stats(stats) => fast_field_names.insert(stats.field.to_string()), + let fast_field_name = match self { + MetricAggregation::Average(avg) => avg.field_name(), + MetricAggregation::Count(count) => count.field_name(), + MetricAggregation::Max(max) => max.field_name(), + MetricAggregation::Min(min) => min.field_name(), + MetricAggregation::Stats(stats) => stats.field_name(), + MetricAggregation::Sum(sum) => sum.field_name(), }; + fast_field_names.insert(fast_field_name.to_string()); } } @@ -258,6 +278,38 @@ impl MetricAggregation { mod tests { use super::*; + #[test] + fn test_metric_aggregations_deser() { + let agg_req_json = r#"{ + "price_avg": { "avg": { "field": "price" } }, + "price_count": { "value_count": { "field": "price" } }, + "price_max": { "max": { "field": "price" } }, + "price_min": { "min": { "field": "price" } }, + "price_stats": { "stats": { "field": "price" } }, + "price_sum": { "sum": { "field": "price" } } + }"#; + let agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); + + assert!( + matches!(agg_req.get("price_avg").unwrap(), Aggregation::Metric(MetricAggregation::Average(avg)) if avg.field == "price") + ); + assert!( + matches!(agg_req.get("price_count").unwrap(), 
Aggregation::Metric(MetricAggregation::Count(count)) if count.field == "price") + ); + assert!( + matches!(agg_req.get("price_max").unwrap(), Aggregation::Metric(MetricAggregation::Max(max)) if max.field == "price") + ); + assert!( + matches!(agg_req.get("price_min").unwrap(), Aggregation::Metric(MetricAggregation::Min(min)) if min.field == "price") + ); + assert!( + matches!(agg_req.get("price_stats").unwrap(), Aggregation::Metric(MetricAggregation::Stats(stats)) if stats.field == "price") + ); + assert!( + matches!(agg_req.get("price_sum").unwrap(), Aggregation::Metric(MetricAggregation::Sum(sum)) if sum.field == "price") + ); + } + #[test] fn serialize_to_json_test() { let agg_req1: Aggregations = vec![( diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 23aba24095..0bc25c1e4b 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -8,7 +8,10 @@ use fastfield_codecs::Column; use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation}; -use super::metric::{AverageAggregation, StatsAggregation}; +use super::metric::{ + AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation, + SumAggregation, +}; use super::segment_agg_result::BucketCount; use super::VecWithNames; use crate::fastfield::{type_and_cardinality, MultiValuedFastFieldReader}; @@ -134,7 +137,11 @@ impl MetricAggregationWithAccessor { ) -> crate::Result { match &metric { MetricAggregation::Average(AverageAggregation { field: field_name }) - | MetricAggregation::Stats(StatsAggregation { field: field_name }) => { + | MetricAggregation::Count(CountAggregation { field: field_name }) + | MetricAggregation::Max(MaxAggregation { field: field_name }) + | MetricAggregation::Min(MinAggregation { field: field_name }) + | MetricAggregation::Stats(StatsAggregation { field: 
field_name }) + | MetricAggregation::Sum(SumAggregation { field: field_name }) => { let (accessor, field_type) = get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?; diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index be5bfd0835..7122f3bee5 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -30,7 +30,7 @@ impl AggregationResults { } else { // Validation is be done during request parsing, so we can't reach this state. Err(TantivyError::InternalError(format!( - "Can't find aggregation {:?} in sub_aggregations", + "Can't find aggregation {:?} in sub-aggregations", name ))) } @@ -70,27 +70,51 @@ impl AggregationResult { pub enum MetricResult { /// Average metric result. Average(SingleMetricResult), + /// Count metric result. + Count(SingleMetricResult), + /// Max metric result. + Max(SingleMetricResult), + /// Min metric result. + Min(SingleMetricResult), /// Stats metric result. Stats(Stats), + /// Sum metric result. 
+ Sum(SingleMetricResult), } impl MetricResult { fn get_value(&self, agg_property: &str) -> crate::Result> { match self { MetricResult::Average(avg) => Ok(avg.value), + MetricResult::Count(count) => Ok(count.value), + MetricResult::Max(max) => Ok(max.value), + MetricResult::Min(min) => Ok(min.value), MetricResult::Stats(stats) => stats.get_value(agg_property), + MetricResult::Sum(sum) => Ok(sum.value), } } } impl From for MetricResult { fn from(metric: IntermediateMetricResult) -> Self { match metric { - IntermediateMetricResult::Average(avg_data) => { - MetricResult::Average(avg_data.finalize().into()) + IntermediateMetricResult::Average(intermediate_avg) => { + MetricResult::Average(intermediate_avg.finalize().into()) + } + IntermediateMetricResult::Count(intermediate_count) => { + MetricResult::Count(intermediate_count.finalize().into()) + } + IntermediateMetricResult::Max(intermediate_max) => { + MetricResult::Max(intermediate_max.finalize().into()) + } + IntermediateMetricResult::Min(intermediate_min) => { + MetricResult::Min(intermediate_min.finalize().into()) } IntermediateMetricResult::Stats(intermediate_stats) => { MetricResult::Stats(intermediate_stats.finalize()) } + IntermediateMetricResult::Sum(intermediate_sum) => { + MetricResult::Sum(intermediate_sum.finalize().into()) + } } } } @@ -100,13 +124,13 @@ impl From for MetricResult { #[serde(untagged)] pub enum BucketResult { /// This is the range entry for a bucket, which contains a key, count, from, to, and optionally - /// sub_aggregations. + /// sub-aggregations. Range { /// The range buckets sorted by range. buckets: BucketEntries, }, /// This is the histogram entry for a bucket, which contains a key, count, and optionally - /// sub_aggregations. + /// sub-aggregations. Histogram { /// The buckets. /// @@ -151,7 +175,7 @@ pub enum BucketEntries { } /// This is the default entry for a bucket, which contains a key, count, and optionally -/// sub_aggregations. +/// sub-aggregations. 
/// /// # JSON Format /// ```json @@ -201,7 +225,7 @@ impl GetDocCount for BucketEntry { } /// This is the range entry for a bucket, which contains a key, count, and optionally -/// sub_aggregations. +/// sub-aggregations. /// /// # JSON Format /// ```json @@ -237,7 +261,7 @@ pub struct RangeBucketEntry { /// Number of documents in the bucket. pub doc_count: u64, #[serde(flatten)] - /// sub-aggregations in this bucket. + /// Sub-aggregations in this bucket. pub sub_aggregation: AggregationResults, /// The from range of the bucket. Equals `f64::MIN` when `None`. #[serde(skip_serializing_if = "Option::is_none")] diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 38bda09cac..d6d5889642 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -17,7 +17,10 @@ use super::bucket::{ cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets, GetDocCount, Order, OrderTarget, SegmentHistogramBucketEntry, TermsAggregation, }; -use super::metric::{IntermediateAverage, IntermediateStats}; +use super::metric::{ + IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats, + IntermediateSum, +}; use super::segment_agg_result::SegmentMetricResultCollector; use super::{format_date, Key, SerializedKey, VecWithNames}; use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry}; @@ -204,22 +207,42 @@ pub enum IntermediateAggregationResult { /// Holds the intermediate data for metric results #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum IntermediateMetricResult { - /// Intermediate average result + /// Intermediate average result. Average(IntermediateAverage), - /// Intermediate stats result + /// Intermediate count result. + Count(IntermediateCount), + /// Intermediate max result. + Max(IntermediateMax), + /// Intermediate min result. 
+ Min(IntermediateMin), + /// Intermediate stats result. Stats(IntermediateStats), + /// Intermediate sum result. + Sum(IntermediateSum), } impl From for IntermediateMetricResult { fn from(tree: SegmentMetricResultCollector) -> Self { match tree { SegmentMetricResultCollector::Stats(collector) => match collector.collecting_for { + super::metric::SegmentStatsType::Average => IntermediateMetricResult::Average( + IntermediateAverage::from_collector(collector), + ), + super::metric::SegmentStatsType::Count => { + IntermediateMetricResult::Count(IntermediateCount::from_collector(collector)) + } + super::metric::SegmentStatsType::Max => { + IntermediateMetricResult::Max(IntermediateMax::from_collector(collector)) + } + super::metric::SegmentStatsType::Min => { + IntermediateMetricResult::Min(IntermediateMin::from_collector(collector)) + } super::metric::SegmentStatsType::Stats => { IntermediateMetricResult::Stats(collector.stats) } - super::metric::SegmentStatsType::Avg => IntermediateMetricResult::Average( - IntermediateAverage::from_collector(collector), - ), + super::metric::SegmentStatsType::Sum => { + IntermediateMetricResult::Sum(IntermediateSum::from_collector(collector)) + } }, } } @@ -231,18 +254,36 @@ impl IntermediateMetricResult { MetricAggregation::Average(_) => { IntermediateMetricResult::Average(IntermediateAverage::default()) } + MetricAggregation::Count(_) => { + IntermediateMetricResult::Count(IntermediateCount::default()) + } + MetricAggregation::Max(_) => IntermediateMetricResult::Max(IntermediateMax::default()), + MetricAggregation::Min(_) => IntermediateMetricResult::Min(IntermediateMin::default()), MetricAggregation::Stats(_) => { IntermediateMetricResult::Stats(IntermediateStats::default()) } + MetricAggregation::Sum(_) => IntermediateMetricResult::Sum(IntermediateSum::default()), } } fn merge_fruits(&mut self, other: IntermediateMetricResult) { match (self, other) { ( - IntermediateMetricResult::Average(avg_data_left), - 
IntermediateMetricResult::Average(avg_data_right), + IntermediateMetricResult::Average(avg_left), + IntermediateMetricResult::Average(avg_right), + ) => { + avg_left.merge_fruits(avg_right); + } + ( + IntermediateMetricResult::Count(count_left), + IntermediateMetricResult::Count(count_right), ) => { - avg_data_left.merge_fruits(avg_data_right); + count_left.merge_fruits(count_right); + } + (IntermediateMetricResult::Max(max_left), IntermediateMetricResult::Max(max_right)) => { + max_left.merge_fruits(max_right); + } + (IntermediateMetricResult::Min(min_left), IntermediateMetricResult::Min(min_right)) => { + min_left.merge_fruits(min_right); } ( IntermediateMetricResult::Stats(stats_left), @@ -250,6 +291,9 @@ impl IntermediateMetricResult { ) => { stats_left.merge_fruits(stats_right); } + (IntermediateMetricResult::Sum(sum_left), IntermediateMetricResult::Sum(sum_right)) => { + sum_left.merge_fruits(sum_right); + } _ => { panic!("incompatible fruit types in tree"); } diff --git a/src/aggregation/metric/average.rs b/src/aggregation/metric/average.rs index 733faacb3c..2d16de9a86 100644 --- a/src/aggregation/metric/average.rs +++ b/src/aggregation/metric/average.rs @@ -2,9 +2,8 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; -use super::SegmentStatsCollector; +use super::{IntermediateStats, SegmentStatsCollector}; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] /// A single-value metric aggregation that computes the average of numeric values that are /// extracted from the aggregated documents. /// Supported field types are u64, i64, and f64. @@ -18,47 +17,43 @@ use super::SegmentStatsCollector; /// } /// } /// ``` +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct AverageAggregation { - /// The field name to compute the stats on. + /// The field name to compute the average on. pub field: String, } + impl AverageAggregation { - /// Create new AverageAggregation from a field. 
+ /// Creates a new [`AverageAggregation`] instance from a field name. pub fn from_field_name(field_name: String) -> Self { - AverageAggregation { field: field_name } + Self { field: field_name } } - /// Return the field name. + /// Returns the field name the aggregation is computed on. pub fn field_name(&self) -> &str { &self.field } } -/// Contains mergeable version of average data. +/// Intermediate result of the average aggregation that can be combined with other intermediate +/// results. #[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateAverage { - pub(crate) sum: f64, - pub(crate) doc_count: u64, + stats: IntermediateStats, } impl IntermediateAverage { + /// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`]. pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { Self { - sum: collector.stats.sum, - doc_count: collector.stats.count, + stats: collector.stats, } } - - /// Merge average data into this instance. + /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateAverage) { - self.sum += other.sum; - self.doc_count += other.doc_count; + self.stats.merge_fruits(other.stats); } - /// compute final result + /// Computes the final average value. pub fn finalize(&self) -> Option { - if self.doc_count == 0 { - None - } else { - Some(self.sum / self.doc_count as f64) - } + self.stats.finalize().avg } } diff --git a/src/aggregation/metric/count.rs b/src/aggregation/metric/count.rs new file mode 100644 index 0000000000..2a49c8dfe7 --- /dev/null +++ b/src/aggregation/metric/count.rs @@ -0,0 +1,59 @@ +use std::fmt::Debug; + +use serde::{Deserialize, Serialize}; + +use super::{IntermediateStats, SegmentStatsCollector}; + +/// A single-value metric aggregation that counts the number of values that are +/// extracted from the aggregated documents. +/// Supported field types are u64, i64, and f64. 
+/// See [super::SingleMetricResult] for return value. +/// +/// # JSON Format +/// ```json +/// { +/// "value_count": { +/// "field": "score", +/// } +/// } +/// ``` +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct CountAggregation { + /// The field name to compute the count on. + pub field: String, +} + +impl CountAggregation { + /// Creates a new [`CountAggregation`] instance from a field name. + pub fn from_field_name(field_name: String) -> Self { + Self { field: field_name } + } + /// Returns the field name the aggregation is computed on. + pub fn field_name(&self) -> &str { + &self.field + } +} + +/// Intermediate result of the count aggregation that can be combined with other intermediate +/// results. +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateCount { + stats: IntermediateStats, +} + +impl IntermediateCount { + /// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`]. + pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { + Self { + stats: collector.stats, + } + } + /// Merges the other intermediate result into self. + pub fn merge_fruits(&mut self, other: IntermediateCount) { + self.stats.merge_fruits(other.stats); + } + /// Computes the final count value. + pub fn finalize(&self) -> Option { + Some(self.stats.finalize().count as f64) + } +} diff --git a/src/aggregation/metric/max.rs b/src/aggregation/metric/max.rs new file mode 100644 index 0000000000..842d183c93 --- /dev/null +++ b/src/aggregation/metric/max.rs @@ -0,0 +1,59 @@ +use std::fmt::Debug; + +use serde::{Deserialize, Serialize}; + +use super::{IntermediateStats, SegmentStatsCollector}; + +/// A single-value metric aggregation that computes the maximum of numeric values that are +/// extracted from the aggregated documents. +/// Supported field types are u64, i64, and f64. +/// See [super::SingleMetricResult] for return value. 
+/// +/// # JSON Format +/// ```json +/// { +/// "max": { +/// "field": "score", +/// } +/// } +/// ``` +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct MaxAggregation { + /// The field name to compute the maximum on. + pub field: String, +} + +impl MaxAggregation { + /// Creates a new [`MaxAggregation`] instance from a field name. + pub fn from_field_name(field_name: String) -> Self { + Self { field: field_name } + } + /// Returns the field name the aggregation is computed on. + pub fn field_name(&self) -> &str { + &self.field + } +} + +/// Intermediate result of the maximum aggregation that can be combined with other intermediate +/// results. +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateMax { + stats: IntermediateStats, +} + +impl IntermediateMax { + /// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`]. + pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { + Self { + stats: collector.stats, + } + } + /// Merges the other intermediate result into self. + pub fn merge_fruits(&mut self, other: IntermediateMax) { + self.stats.merge_fruits(other.stats); + } + /// Computes the final maximum value. + pub fn finalize(&self) -> Option { + self.stats.finalize().max + } +} diff --git a/src/aggregation/metric/min.rs b/src/aggregation/metric/min.rs new file mode 100644 index 0000000000..548423da06 --- /dev/null +++ b/src/aggregation/metric/min.rs @@ -0,0 +1,59 @@ +use std::fmt::Debug; + +use serde::{Deserialize, Serialize}; + +use super::{IntermediateStats, SegmentStatsCollector}; + +/// A single-value metric aggregation that computes the minimum of numeric values that are +/// extracted from the aggregated documents. +/// Supported field types are u64, i64, and f64. +/// See [super::SingleMetricResult] for return value. 
+/// +/// # JSON Format +/// ```json +/// { +/// "min": { +/// "field": "score", +/// } +/// } +/// ``` +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct MinAggregation { + /// The field name to compute the minimum on. + pub field: String, +} + +impl MinAggregation { + /// Creates a new [`MinAggregation`] instance from a field name. + pub fn from_field_name(field_name: String) -> Self { + Self { field: field_name } + } + /// Returns the field name the aggregation is computed on. + pub fn field_name(&self) -> &str { + &self.field + } +} + +/// Intermediate result of the minimum aggregation that can be combined with other intermediate +/// results. +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateMin { + stats: IntermediateStats, +} + +impl IntermediateMin { + /// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`]. + pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { + Self { + stats: collector.stats, + } + } + /// Merges the other intermediate result into self. + pub fn merge_fruits(&mut self, other: IntermediateMin) { + self.stats.merge_fruits(other.stats); + } + /// Computes the final minimum value. + pub fn finalize(&self) -> Option { + self.stats.finalize().min + } +} diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index f4be52a538..87fd64e455 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -3,10 +3,18 @@ //! The aggregations in this family compute metrics, see [super::agg_req::MetricAggregation] for //! details. mod average; +mod count; +mod max; +mod min; mod stats; +mod sum; pub use average::*; +pub use count::*; +pub use max::*; +pub use min::*; use serde::{Deserialize, Serialize}; pub use stats::*; +pub use sum::*; /// Single-metric aggregations use this common result structure. 
/// @@ -28,3 +36,61 @@ impl From> for SingleMetricResult { Self { value } } } + +#[cfg(test)] +mod tests { + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::agg_result::AggregationResults; + use crate::aggregation::AggregationCollector; + use crate::query::AllQuery; + use crate::schema::{Cardinality, NumericOptions, Schema}; + use crate::Index; + + #[test] + fn test_metric_aggregations() { + let mut schema_builder = Schema::builder(); + let field_options = NumericOptions::default().set_fast(Cardinality::SingleValue); + let field = schema_builder.add_f64_field("price", field_options); + let index = Index::create_in_ram(schema_builder.build()); + let mut index_writer = index.writer_for_tests().unwrap(); + + for i in 0..3 { + index_writer + .add_document(doc!( + field => i as f64, + )) + .unwrap(); + } + index_writer.commit().unwrap(); + + for i in 3..6 { + index_writer + .add_document(doc!( + field => i as f64, + )) + .unwrap(); + } + index_writer.commit().unwrap(); + + let aggregations_json = r#"{ + "price_avg": { "avg": { "field": "price" } }, + "price_count": { "value_count": { "field": "price" } }, + "price_max": { "max": { "field": "price" } }, + "price_min": { "min": { "field": "price" } }, + "price_stats": { "stats": { "field": "price" } }, + "price_sum": { "sum": { "field": "price" } } + }"#; + let aggregations: Aggregations = serde_json::from_str(&aggregations_json).unwrap(); + let collector = AggregationCollector::from_aggs(aggregations, None, index.schema()); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + let aggregations_res_json = serde_json::to_value(&aggregations_res).unwrap(); + + assert_eq!(aggregations_res_json["price_avg"]["value"], 2.5); + assert_eq!(aggregations_res_json["price_count"]["value"], 6.0); + assert_eq!(aggregations_res_json["price_max"]["value"], 5.0); + 
assert_eq!(aggregations_res_json["price_min"]["value"], 0.0); + assert_eq!(aggregations_res_json["price_sum"]["value"], 15.0); + } +} diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index a28419716d..965bf6e019 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -5,8 +5,8 @@ use crate::aggregation::f64_from_fastfield_u64; use crate::schema::Type; use crate::{DocId, TantivyError}; -/// A multi-value metric aggregation that computes stats of numeric values that are -/// extracted from the aggregated documents. +/// A multi-value metric aggregation that computes a collection of statistics on numeric values that +/// are extracted from the aggregated documents. /// Supported field types are `u64`, `i64`, and `f64`. /// See [`Stats`] for returned statistics. /// @@ -26,11 +26,11 @@ pub struct StatsAggregation { } impl StatsAggregation { - /// Create new StatsAggregation from a field. + /// Creates a new [`StatsAggregation`] instance from a field name. pub fn from_field_name(field_name: String) -> Self { StatsAggregation { field: field_name } } - /// Return the field name. + /// Returns the field name the aggregation is computed on. pub fn field_name(&self) -> &str { &self.field } @@ -43,13 +43,13 @@ pub struct Stats { pub count: u64, /// The sum of the fast field values. pub sum: f64, - /// The standard deviation of the fast field values. `None` for count == 0. + /// The standard deviation of the fast field values. `None` if count equals zero. pub standard_deviation: Option, /// The min value of the fast field values. pub min: Option, /// The max value of the fast field values. pub max: Option, - /// The average of the values. `None` for count == 0. + /// The average of the fast field values. `None` if count equals zero. 
pub avg: Option, } @@ -63,27 +63,29 @@ impl Stats { "max" => Ok(self.max), "avg" => Ok(self.avg), _ => Err(TantivyError::InvalidArgument(format!( - "unknown property {} on stats metric aggregation", + "Unknown property {} on stats metric aggregation", agg_property ))), } } } -/// `IntermediateStats` contains the mergeable version for stats. +/// Intermediate result of the stats aggregation that can be combined with other intermediate +/// results. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateStats { - /// the number of values + /// The number of values. pub count: u64, - /// the sum of the values + /// The sum of the values. pub sum: f64, - /// the squared sum of the values + /// The sum of the squared values. pub squared_sum: f64, - /// the min value of the values + /// The min value of the values. pub min: f64, - /// the max value of the values + /// The max value of the values. pub max: f64, } + impl Default for IntermediateStats { fn default() -> Self { Self { @@ -97,7 +99,7 @@ impl Default for IntermediateStats { } impl IntermediateStats { - pub(crate) fn avg(&self) -> Option { + fn avg(&self) -> Option { if self.count == 0 { None } else { @@ -109,12 +111,12 @@ impl IntermediateStats { self.squared_sum / (self.count as f64) } - pub(crate) fn standard_deviation(&self) -> Option { + fn standard_deviation(&self) -> Option { self.avg() .map(|average| (self.square_mean() - average * average).sqrt()) } - /// Merge data from other stats into this instance. + /// Merges the other stats intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateStats) { self.count += other.count; self.sum += other.sum; @@ -123,7 +125,7 @@ impl IntermediateStats { self.max = self.max.max(other.max); } - /// compute final resultimprove_docs + /// Computes the final stats value. 
pub fn finalize(&self) -> Stats { let min = if self.count == 0 { None @@ -157,23 +159,27 @@ impl IntermediateStats { #[derive(Clone, Debug, PartialEq)] pub(crate) enum SegmentStatsType { + Average, + Count, + Max, + Min, Stats, - Avg, + Sum, } #[derive(Clone, Debug, PartialEq)] pub(crate) struct SegmentStatsCollector { - pub(crate) stats: IntermediateStats, field_type: Type, pub(crate) collecting_for: SegmentStatsType, + pub(crate) stats: IntermediateStats, } impl SegmentStatsCollector { pub fn from_req(field_type: Type, collecting_for: SegmentStatsType) -> Self { Self { field_type, - stats: IntermediateStats::default(), collecting_for, + stats: IntermediateStats::default(), } } pub(crate) fn collect_block(&mut self, doc: &[DocId], field: &dyn Column) { diff --git a/src/aggregation/metric/sum.rs b/src/aggregation/metric/sum.rs new file mode 100644 index 0000000000..204598bbc8 --- /dev/null +++ b/src/aggregation/metric/sum.rs @@ -0,0 +1,59 @@ +use std::fmt::Debug; + +use serde::{Deserialize, Serialize}; + +use super::{IntermediateStats, SegmentStatsCollector}; + +/// A single-value metric aggregation that sums up numeric values that are +/// extracted from the aggregated documents. +/// Supported field types are u64, i64, and f64. +/// See [super::SingleMetricResult] for return value. +/// +/// # JSON Format +/// ```json +/// { +/// "sum": { +/// "field": "score", +/// } +/// } +/// ``` +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct SumAggregation { + /// The field name to compute the sum on. + pub field: String, +} + +impl SumAggregation { + /// Creates a new [`SumAggregation`] instance from a field name. + pub fn from_field_name(field_name: String) -> Self { + Self { field: field_name } + } + /// Returns the field name the aggregation is computed on. + pub fn field_name(&self) -> &str { + &self.field + } +} + +/// Intermediate result of the sum aggregation that can be combined with other intermediate +/// results. 
+#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateSum { + stats: IntermediateStats, +} + +impl IntermediateSum { + /// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`]. + pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { + Self { + stats: collector.stats, + } + } + /// Merges the other intermediate result into self. + pub fn merge_fruits(&mut self, other: IntermediateSum) { + self.stats.merge_fruits(other.stats); + } + /// Computes the final sum value. + pub fn finalize(&self) -> Option { + Some(self.stats.finalize().sum) + } +} diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index ba17e8f298..df16871619 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -216,8 +216,8 @@ impl VecWithNames { fn from_entries(mut entries: Vec<(String, T)>) -> Self { // Sort to ensure order of elements match across multiple instances entries.sort_by(|left, right| left.0.cmp(&right.0)); - let mut data = vec![]; - let mut data_names = vec![]; + let mut data = Vec::with_capacity(entries.len()); + let mut data_names = Vec::with_capacity(entries.len()); for entry in entries { data_names.push(entry.0); data.push(entry.1); diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index cc24692706..f6c7f8a3f7 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -15,7 +15,8 @@ use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTer use super::collector::MAX_BUCKET_COUNT; use super::intermediate_agg_result::{IntermediateAggregationResults, IntermediateBucketResult}; use super::metric::{ - AverageAggregation, SegmentStatsCollector, SegmentStatsType, StatsAggregation, + AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, SegmentStatsCollector, + SegmentStatsType, StatsAggregation, SumAggregation, }; use super::VecWithNames; use 
crate::aggregation::agg_req::BucketAggregationType; @@ -169,16 +170,36 @@ pub(crate) enum SegmentMetricResultCollector { impl SegmentMetricResultCollector { pub fn from_req_and_validate(req: &MetricAggregationWithAccessor) -> crate::Result { match &req.metric { - MetricAggregation::Average(AverageAggregation { field: _ }) => { + MetricAggregation::Average(AverageAggregation { .. }) => { Ok(SegmentMetricResultCollector::Stats( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Avg), + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Average), )) } - MetricAggregation::Stats(StatsAggregation { field: _ }) => { + MetricAggregation::Count(CountAggregation { .. }) => { + Ok(SegmentMetricResultCollector::Stats( + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count), + )) + } + MetricAggregation::Max(MaxAggregation { .. }) => { + Ok(SegmentMetricResultCollector::Stats( + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max), + )) + } + MetricAggregation::Min(MinAggregation { .. }) => { + Ok(SegmentMetricResultCollector::Stats( + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min), + )) + } + MetricAggregation::Stats(StatsAggregation { .. }) => { Ok(SegmentMetricResultCollector::Stats( SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats), )) } + MetricAggregation::Sum(SumAggregation { .. }) => { + Ok(SegmentMetricResultCollector::Stats( + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum), + )) + } } } pub(crate) fn collect_block(&mut self, doc: &[DocId], metric: &MetricAggregationWithAccessor) {