Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for keyed parameter in range and histgram aggregations #1424

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ fn main() -> tantivy::Result<()> {
(9f64..14f64).into(),
(14f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1.clone(),
}),
Expand Down
13 changes: 12 additions & 1 deletion src/aggregation/agg_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
//! field: "score".to_string(),
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//! keyed: false,
//! }),
//! sub_aggregation: Default::default(),
//! }),
Expand Down Expand Up @@ -100,6 +101,12 @@ pub(crate) struct BucketAggregationInternal {
}

impl BucketAggregationInternal {
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
match &self.bucket_agg {
BucketAggregationType::Range(range) => Some(range),
_ => None,
}
}
pub(crate) fn as_histogram(&self) -> Option<&HistogramAggregation> {
match &self.bucket_agg {
BucketAggregationType::Histogram(histogram) => Some(histogram),
Expand Down Expand Up @@ -264,6 +271,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
Expand All @@ -290,7 +298,8 @@ mod tests {
{
"from": 20.0
}
]
],
"keyed": true
}
}
}"#;
Expand All @@ -312,6 +321,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Expand All @@ -337,6 +347,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
..Default::default()
}),
sub_aggregation: agg_req2,
}),
Expand Down
3 changes: 1 addition & 2 deletions src/aggregation/agg_req_with_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ impl BucketAggregationWithAccessor {
let mut inverted_index = None;
let (accessor, field_type) = match &bucket {
BucketAggregationType::Range(RangeAggregation {
field: field_name,
ranges: _,
field: field_name, ..
}) => get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?,
BucketAggregationType::Histogram(HistogramAggregation {
field: field_name, ..
Expand Down
16 changes: 14 additions & 2 deletions src/aggregation/agg_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

use std::collections::HashMap;

use fnv::FnvHashMap;
use serde::{Deserialize, Serialize};

use super::agg_req::BucketAggregationInternal;
Expand Down Expand Up @@ -104,7 +105,7 @@ pub enum BucketResult {
/// sub_aggregations.
Range {
/// The range buckets sorted by range.
buckets: Vec<RangeBucketEntry>,
buckets: BucketEntries<RangeBucketEntry>,
},
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
Expand All @@ -114,7 +115,7 @@ pub enum BucketResult {
/// If there are holes depends on the request, if min_doc_count is 0, then there are no
/// holes between the first and last bucket.
/// See [HistogramAggregation](super::bucket::HistogramAggregation)
buckets: Vec<BucketEntry>,
buckets: BucketEntries<BucketEntry>,
},
/// This is the term result
Terms {
Expand All @@ -137,6 +138,17 @@ impl BucketResult {
}
}

/// This is the wrapper of buckets entries, which can be vector or hashmap
/// depending on if it's keyed or not.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum BucketEntries<T> {
/// Vector format bucket entries
Vec(Vec<T>),
/// HashMap format bucket entries
HashMap(FnvHashMap<String, T>),
}

/// This is the default entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
///
Expand Down
47 changes: 45 additions & 2 deletions src/aggregation/bucket/histogram/histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ use crate::{DocId, TantivyError};
///
/// # Limitations/Compatibility
///
/// The keyed parameter (elasticsearch) is not yet supported.
///
/// # JSON Format
/// ```json
/// {
Expand Down Expand Up @@ -117,6 +115,9 @@ pub struct HistogramAggregation {
/// Cannot be set in conjunction with min_doc_count > 0, since the empty buckets from extended
/// bounds would not be returned.
pub extended_bounds: Option<HistogramBounds>,
/// Whether to return the buckets as a hash map
#[serde(default)]
pub keyed: bool,
}

impl HistogramAggregation {
Expand Down Expand Up @@ -1395,4 +1396,46 @@ mod tests {

Ok(())
}

#[test]
fn histogram_keyed_buckets_test() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;

let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 50.0,
keyed: true,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();

let res = exec_request(agg_req, &index)?;

assert_eq!(
res,
json!({
"histogram": {
"buckets": {
"0": {
"key": 0.0,
"doc_count": 50
},
"50": {
"key": 50.0,
"doc_count": 50
}
}
}
})
);

Ok(())
}
}
52 changes: 49 additions & 3 deletions src/aggregation/bucket/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ use crate::{DocId, TantivyError};
/// # Limitations/Compatibility
/// Overlapping ranges are not yet supported.
///
/// The keyed parameter (elasticsearch) is not yet supported.
///
/// # Request JSON Format
/// ```json
/// {
Expand All @@ -51,13 +49,16 @@ use crate::{DocId, TantivyError};
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct RangeAggregation {
/// The field to aggregate on.
pub field: String,
/// Note that this aggregation includes the from value and excludes the to value for each
/// range. Extra buckets will be created until the first to, and last from, if necessary.
pub ranges: Vec<RangeAggregationRange>,
/// Whether to return the buckets as a hash map
#[serde(default)]
pub keyed: bool,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
Expand Down Expand Up @@ -406,6 +407,7 @@ mod tests {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
..Default::default()
};

SegmentRangeCollector::from_req_and_validate(
Expand All @@ -427,6 +429,7 @@ mod tests {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Expand Down Expand Up @@ -454,6 +457,49 @@ mod tests {
Ok(())
}

#[test]
fn range_keyed_buckets_test() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;

let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();

let collector = AggregationCollector::from_aggs(agg_req, None);

let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res = searcher.search(&AllQuery, &collector).unwrap();

let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;

assert_eq!(
res,
json!({
"range": {
"buckets": {
"*-0": { "key": "*-0", "doc_count": 0, "to": 0.0},
"0-0.1": {"key": "0-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
"0.1-0.2": {"key": "0.1-0.2", "doc_count": 10, "from": 0.1, "to": 0.2},
"0.2-*": {"key": "0.2-*", "doc_count": 80, "from": 0.2},
}
}
})
);

Ok(())
}

#[test]
fn bucket_test_extend_range_hole() {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
Expand Down
27 changes: 26 additions & 1 deletion src/aggregation/intermediate_agg_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use super::bucket::{
use super::metric::{IntermediateAverage, IntermediateStats};
use super::segment_agg_result::SegmentMetricResultCollector;
use super::{Key, SerializedKey, VecWithNames};
use crate::aggregation::agg_result::{AggregationResults, BucketEntry};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;

/// Contains the intermediate aggregation result, which is optimized to be merged with other
Expand Down Expand Up @@ -281,6 +281,21 @@ impl IntermediateBucketResult {
.unwrap_or(f64::MIN)
.total_cmp(&right.from.unwrap_or(f64::MIN))
});

let is_keyed = req
.as_range()
.expect("unexpected aggregation, expected range aggregation")
.keyed;
let buckets = if is_keyed {
let mut bucket_map =
FnvHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
for bucket in buckets {
bucket_map.insert(bucket.key.to_string(), bucket);
}
BucketEntries::HashMap(bucket_map)
} else {
BucketEntries::Vec(buckets)
};
Ok(BucketResult::Range { buckets })
}
IntermediateBucketResult::Histogram { buckets } => {
Expand All @@ -291,6 +306,16 @@ impl IntermediateBucketResult {
&req.sub_aggregation,
)?;

let buckets = if req.as_histogram().unwrap().keyed {
let mut bucket_map =
FnvHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
for bucket in buckets {
bucket_map.insert(bucket.key.to_string(), bucket);
}
BucketEntries::HashMap(bucket_map)
} else {
BucketEntries::Vec(buckets)
};
Ok(BucketResult::Histogram { buckets })
}
IntermediateBucketResult::Terms(terms) => terms.into_final_result(
Expand Down
1 change: 1 addition & 0 deletions src/aggregation/metric/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ mod tests {
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: iter::once((
"stats".to_string(),
Expand Down
Loading