split term collection count and sub_agg (#1921)

use unrolled ColumnValues::get_vals

PSeitz authored Mar 13, 2023
1 parent 61cfd8d commit 8459efa

Showing 9 changed files with 124 additions and 44 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -13,3 +13,5 @@ benchmark
.idea
trace.dat
cargo-timing*
control
variable
19 changes: 15 additions & 4 deletions columnar/src/column_values/mod.rs
@@ -58,10 +58,21 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
/// # Panics
///
/// May panic if `idx` is greater than the column length.
fn get_vals(&self, idxs: &[u32], output: &mut [T]) {
assert!(idxs.len() == output.len());
for (out, &idx) in output.iter_mut().zip(idxs.iter()) {
*out = self.get_val(idx);
fn get_vals(&self, indexes: &[u32], output: &mut [T]) {
assert!(indexes.len() == output.len());
let out_and_idx_chunks = output.chunks_exact_mut(4).zip(indexes.chunks_exact(4));
for (out_x4, idx_x4) in out_and_idx_chunks {
out_x4[0] = self.get_val(idx_x4[0]);
out_x4[1] = self.get_val(idx_x4[1]);
out_x4[2] = self.get_val(idx_x4[2]);
out_x4[3] = self.get_val(idx_x4[3]);
}

let step_size = 4;
let cutoff = indexes.len() - indexes.len() % step_size;

for idx in cutoff..indexes.len() {
output[idx] = self.get_val(indexes[idx] as u32);
}
}

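A minimal, self-contained sketch of the unrolled batch read the commit message refers to. The trait below is a toy stand-in for the real `ColumnValues` (which, as the hunk above shows, carries `Send + Sync` bounds and more methods), and `VecColumn` is a hypothetical backing store used only for illustration:

```rust
/// Toy stand-in for tantivy's `ColumnValues` trait.
trait ColumnValues<T: Copy + Default> {
    fn get_val(&self, idx: u32) -> T;

    /// Default batch read mirroring the committed code: process chunks
    /// of 4 so the optimizer sees straight-line loads it can schedule
    /// or vectorize, then finish the remainder with a scalar loop.
    fn get_vals(&self, indexes: &[u32], output: &mut [T]) {
        assert!(indexes.len() == output.len());
        let out_and_idx_chunks = output.chunks_exact_mut(4).zip(indexes.chunks_exact(4));
        for (out_x4, idx_x4) in out_and_idx_chunks {
            out_x4[0] = self.get_val(idx_x4[0]);
            out_x4[1] = self.get_val(idx_x4[1]);
            out_x4[2] = self.get_val(idx_x4[2]);
            out_x4[3] = self.get_val(idx_x4[3]);
        }
        // Scalar tail: the last `len % 4` positions.
        let cutoff = indexes.len() - indexes.len() % 4;
        for idx in cutoff..indexes.len() {
            output[idx] = self.get_val(indexes[idx]);
        }
    }
}

/// Hypothetical in-memory column used only to exercise the trait.
struct VecColumn(Vec<u64>);

impl ColumnValues<u64> for VecColumn {
    fn get_val(&self, idx: u32) -> u64 {
        self.0[idx as usize]
    }
}

fn main() {
    let col = VecColumn((0..100u64).map(|v| v * 10).collect());
    let docs: Vec<u32> = vec![3, 7, 11, 42, 99];
    // Callers in this commit reuse one scratch buffer per collector,
    // so repeated blocks do not allocate.
    let mut out = vec![0u64; docs.len()];
    col.get_vals(&docs, &mut out);
    assert_eq!(out, vec![30, 70, 110, 420, 990]);
}
```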
2 changes: 1 addition & 1 deletion columnar/src/column_values/monotonic_column.rs
@@ -50,7 +50,7 @@ where
Input: PartialOrd + Send + Debug + Sync + Clone,
Output: PartialOrd + Send + Debug + Sync + Clone,
{
#[inline]
#[inline(always)]
fn get_val(&self, idx: u32) -> Output {
let from_val = self.from_column.get_val(idx);
self.monotonic_mapping.mapping(from_val)
14 changes: 14 additions & 0 deletions columnar/src/column_values/u64_based/tests.rs
@@ -99,14 +99,28 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(

let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
assert_eq!(reader.num_vals(), vals.len() as u32);
let mut buffer = Vec::new();
for (doc, orig_val) in vals.iter().copied().enumerate() {
let val = reader.get_val(doc as u32);
assert_eq!(
val, orig_val,
"val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
);

buffer.resize(1, 0);
reader.get_vals(&[doc as u32], &mut buffer);
let val = buffer[0];
assert_eq!(
val, orig_val,
"val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
);
}

let all_docs: Vec<u32> = (0..vals.len() as u32).collect();
buffer.resize(all_docs.len(), 0);
reader.get_vals(&all_docs, &mut buffer);
assert_eq!(vals, buffer);

if !vals.is_empty() {
let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
let expected_positions: Vec<u32> = vals
2 changes: 2 additions & 0 deletions src/aggregation/bucket/histogram/histogram.rs
@@ -230,6 +230,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
})
}

#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -238,6 +239,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
self.collect_block(&[doc], agg_with_accessor)
}

#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
2 changes: 2 additions & 0 deletions src/aggregation/bucket/range.rs
@@ -208,6 +208,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
})
}

#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -216,6 +217,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
self.collect_block(&[doc], agg_with_accessor)
}

#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
112 changes: 76 additions & 36 deletions src/aggregation/bucket/term_agg.rs
@@ -205,7 +205,8 @@ impl TermsAggregationInternal {
#[derive(Clone, Debug, Default)]
/// Container to store term_ids or u64 values and their buckets.
struct TermBuckets {
pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
pub(crate) entries: FxHashMap<u64, u64>,
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
}

#[derive(Clone, Default)]
@@ -249,10 +250,8 @@ impl TermBucketEntry {

impl TermBuckets {
fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
for entry in &mut self.entries.values_mut() {
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
sub_aggregations.flush(agg_with_accessor)?;
}
for sub_aggregations in &mut self.sub_aggs.values_mut() {
sub_aggregations.as_mut().flush(agg_with_accessor)?;
}
Ok(())
}
@@ -268,6 +267,7 @@ pub struct SegmentTermCollector {
blueprint: Option<Box<dyn SegmentAggregationCollector>>,
field_type: ColumnType,
accessor_idx: usize,
val_cache: Vec<u64>,
}

pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
@@ -292,6 +292,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
})
}

#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -300,6 +301,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
self.collect_block(&[doc], agg_with_accessor)
}

#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
@@ -310,28 +312,35 @@ impl SegmentAggregationCollector for SegmentTermCollector {
&agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;

if accessor.get_cardinality() == Cardinality::Full {
for doc in docs {
let term_id = accessor.values.get_val(*doc);
let entry = self
.term_buckets
.entries
.entry(term_id)
.or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
entry.doc_count += 1;
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
self.val_cache.resize(docs.len(), 0);
accessor.values.get_vals(docs, &mut self.val_cache);
for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
}
// has subagg
if let Some(blueprint) = self.blueprint.as_ref() {
for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
let sub_aggregations = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
}
}
} else {
for doc in docs {
for term_id in accessor.values_for_doc(*doc) {
let entry = self
.term_buckets
.entries
.entry(term_id)
.or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
entry.doc_count += 1;
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
// TODO: check if separate loop is faster (may depend on the codec)
if let Some(blueprint) = self.blueprint.as_ref() {
let sub_aggregations = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
}
}
@@ -386,15 +395,16 @@ impl SegmentTermCollector {
blueprint,
field_type,
accessor_idx,
val_cache: Default::default(),
})
}

#[inline]
pub(crate) fn into_intermediate_bucket_result(
self,
mut self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, TermBucketEntry)> =
self.term_buckets.entries.into_iter().collect();
let mut entries: Vec<(u64, u64)> = self.term_buckets.entries.into_iter().collect();

let order_by_sub_aggregation =
matches!(self.req.order.target, OrderTarget::SubAggregation(_));
@@ -417,9 +427,9 @@ impl SegmentTermCollector {
}
OrderTarget::Count => {
if self.req.order.order == Order::Desc {
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count()));
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1));
} else {
entries.sort_unstable_by_key(|bucket| bucket.doc_count());
entries.sort_unstable_by_key(|bucket| bucket.1);
}
}
}
@@ -432,24 +442,51 @@ impl SegmentTermCollector {

let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());

let mut into_intermediate_bucket_entry =
|id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
let intermediate_entry = if let Some(blueprint) = self.blueprint.as_ref() {
IntermediateTermBucketEntry {
doc_count,
sub_aggregation: self
.term_buckets
.sub_aggs
.remove(&id)
.expect(&format!(
"Internal Error: could not find subaggregation for id {}",
id
))
.into_intermediate_aggregations_result(
&agg_with_accessor.sub_aggregation,
)?,
}
} else {
IntermediateTermBucketEntry {
doc_count,
sub_aggregation: Default::default(),
}
};
Ok(intermediate_entry)
};

if self.field_type == ColumnType::Str {
let term_dict = agg_with_accessor
.str_dict_column
.as_ref()
.expect("internal error: term dictionary not found for term aggregation");

let mut buffer = String::new();
for (term_id, entry) in entries {
for (term_id, doc_count) in entries {
if !term_dict.ord_to_str(term_id, &mut buffer)? {
return Err(TantivyError::InternalError(format!(
"Couldn't find term_id {} in dict",
term_id
)));
}
dict.insert(
Key::Str(buffer.to_string()),
entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
);

let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;

dict.insert(Key::Str(buffer.to_string()), intermediate_entry);
}
if self.req.min_doc_count == 0 {
// TODO: Handle rev streaming for descending sorting by keys
@@ -468,12 +505,10 @@ impl SegmentTermCollector {
}
}
} else {
for (val, entry) in entries {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let val = f64_from_fastfield_u64(val, &self.field_type);
dict.insert(
Key::F64(val),
entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
);
dict.insert(Key::F64(val), intermediate_entry);
}
};

Expand All @@ -495,6 +530,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
self.1.doc_count
}
}
impl GetDocCount for (u64, u64) {
fn doc_count(&self) -> u64 {
self.1
}
}
impl GetDocCount for (u64, TermBucketEntry) {
fn doc_count(&self) -> u64 {
self.1.doc_count
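The heart of the "split" in the commit title: `TermBuckets` previously stored one `TermBucketEntry` per term, holding both a `doc_count` and an optional boxed sub-aggregation; it now keeps plain counts in `entries: FxHashMap<u64, u64>` and sub-aggregation state in a separate `sub_aggs` map that is populated only when a blueprint exists. A minimal sketch of that two-map layout, substituting `std::collections::HashMap` for `FxHashMap` and a toy struct for `Box<dyn SegmentAggregationCollector>`:

```rust
use std::collections::HashMap; // the real code uses FxHashMap from rustc-hash

/// Toy stand-in for a boxed sub-aggregation collector.
#[derive(Clone, Default)]
struct ToySubAgg {
    docs: Vec<u32>,
}

#[derive(Default)]
struct TermBuckets {
    /// Pure counting: the hot loop touches nothing but a u64 per term.
    entries: HashMap<u64, u64>,
    /// Sub-aggregation state, materialized only on demand.
    sub_aggs: HashMap<u64, ToySubAgg>,
}

fn collect_block(
    buckets: &mut TermBuckets,
    blueprint: Option<&ToySubAgg>,
    docs_and_terms: &[(u32, u64)],
) {
    // First pass: bump counts only.
    for &(_, term_id) in docs_and_terms {
        *buckets.entries.entry(term_id).or_default() += 1;
    }
    // Second pass, entered only when sub-aggregations were requested.
    if let Some(blueprint) = blueprint {
        for &(doc, term_id) in docs_and_terms {
            let sub = buckets
                .sub_aggs
                .entry(term_id)
                .or_insert_with(|| blueprint.clone());
            sub.docs.push(doc);
        }
    }
}

fn main() {
    let mut buckets = TermBuckets::default();
    collect_block(&mut buckets, None, &[(0, 7), (1, 7), (2, 3)]);
    assert_eq!(buckets.entries[&7], 2);
    assert!(buckets.sub_aggs.is_empty()); // no blueprint, no sub-agg state
}
```

When no sub-aggregation is requested, the collection loop degenerates to a bare integer increment per document, which is the point of the split.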
4 changes: 4 additions & 0 deletions src/aggregation/buf_collector.rs
@@ -34,13 +34,15 @@ impl BufAggregationCollector {
}

impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn into_intermediate_aggregations_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor)
}

#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -56,6 +58,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
Ok(())
}

#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
@@ -67,6 +70,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
Ok(())
}

#[inline]
fn flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
11 changes: 8 additions & 3 deletions src/aggregation/metric/stats.rs
@@ -156,6 +156,7 @@ pub(crate) struct SegmentStatsCollector {
pub(crate) collecting_for: SegmentStatsType,
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
}

impl SegmentStatsCollector {
@@ -169,14 +170,16 @@ impl SegmentStatsCollector {
collecting_for,
stats: IntermediateStats::default(),
accessor_idx,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(&mut self, docs: &[DocId], field: &Column<u64>) {
if field.get_cardinality() == Cardinality::Full {
for doc in docs {
let val = field.values.get_val(*doc);
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.val_cache.resize(docs.len(), 0);
field.values.get_vals(docs, &mut self.val_cache);
for val in self.val_cache.iter() {
let val1 = f64_from_fastfield_u64(*val, &self.field_type);
self.stats.collect(val1);
}
} else {
@@ -191,6 +194,7 @@ impl SegmentStatsCollector {
}

impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn into_intermediate_aggregations_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
@@ -227,6 +231,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
})
}

#[inline]
fn collect(
&mut self,
doc: crate::DocId,
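The stats path follows the same recipe: for `Cardinality::Full` columns it resizes a collector-owned scratch buffer, batch-fetches via `get_vals`, then folds the values. A compact sketch of the scratch-buffer pattern, where `fetch_batch` is a hypothetical stand-in for `ColumnValues::get_vals` and the fold simplifies the real `f64_from_fastfield_u64` conversion:

```rust
struct StatsCollector {
    sum: f64,
    count: u64,
    /// Reused across blocks; `resize` reallocates only when a block is
    /// larger than any seen before, so steady state is allocation-free.
    val_cache: Vec<u64>,
}

impl StatsCollector {
    fn collect_block(&mut self, docs: &[u32], fetch_batch: impl Fn(&[u32], &mut [u64])) {
        self.val_cache.resize(docs.len(), 0);
        fetch_batch(docs, &mut self.val_cache);
        for &val in &self.val_cache {
            self.sum += val as f64; // real code converts via f64_from_fastfield_u64
            self.count += 1;
        }
    }
}

fn main() {
    let column: Vec<u64> = (0..1000).collect();
    let mut stats = StatsCollector { sum: 0.0, count: 0, val_cache: Vec::new() };
    stats.collect_block(&[1, 2, 3], |docs, out| {
        for (slot, &doc) in out.iter_mut().zip(docs) {
            *slot = column[doc as usize];
        }
    });
    assert_eq!(stats.count, 3);
    assert_eq!(stats.sum, 6.0);
}
```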
