work in batches of docs
PSeitz committed Mar 13, 2023
1 parent 8459efa commit 116f2fb
Showing 15 changed files with 137 additions and 110 deletions.
2 changes: 1 addition & 1 deletion columnar/src/column_values/mod.rs
@@ -72,7 +72,7 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
         let cutoff = indexes.len() - indexes.len() % step_size;
 
         for idx in cutoff..indexes.len() {
-            output[idx] = self.get_val(indexes[idx] as u32);
+            output[idx] = self.get_val(indexes[idx]);
         }
     }

65 changes: 9 additions & 56 deletions src/aggregation/bucket/term_agg.rs
Expand Up @@ -53,7 +53,7 @@ use crate::TantivyError;
 /// into segment_size.
 ///
 /// Result type is [`BucketResult`](crate::aggregation::agg_result::BucketResult) with
-/// [`TermBucketEntry`](crate::aggregation::agg_result::BucketEntry) on the
+/// [`BucketEntry`](crate::aggregation::agg_result::BucketEntry) on the
 /// `AggregationCollector`.
 ///
 /// Result type is
@@ -209,45 +209,6 @@ struct TermBuckets {
     pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
 }
 
-#[derive(Clone, Default)]
-struct TermBucketEntry {
-    doc_count: u64,
-    sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
-}
-
-impl Debug for TermBucketEntry {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("TermBucketEntry")
-            .field("doc_count", &self.doc_count)
-            .finish()
-    }
-}
-
-impl TermBucketEntry {
-    fn from_blueprint(blueprint: &Option<Box<dyn SegmentAggregationCollector>>) -> Self {
-        Self {
-            doc_count: 0,
-            sub_aggregations: blueprint.clone(),
-        }
-    }
-
-    pub(crate) fn into_intermediate_bucket_entry(
-        self,
-        agg_with_accessor: &AggregationsWithAccessor,
-    ) -> crate::Result<IntermediateTermBucketEntry> {
-        let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregations {
-            sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor)?
-        } else {
-            Default::default()
-        };
-
-        Ok(IntermediateTermBucketEntry {
-            doc_count: self.doc_count,
-            sub_aggregation,
-        })
-    }
-}
-
 impl TermBuckets {
     fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
         for sub_aggregations in &mut self.sub_aggs.values_mut() {
@@ -314,7 +275,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         if accessor.get_cardinality() == Cardinality::Full {
             self.val_cache.resize(docs.len(), 0);
             accessor.values.get_vals(docs, &mut self.val_cache);
-            for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
+            for term_id in self.val_cache.iter().cloned() {
                 let entry = self.term_buckets.entries.entry(term_id).or_default();
                 *entry += 1;
             }
@@ -445,17 +406,19 @@ impl SegmentTermCollector {

         let mut into_intermediate_bucket_entry =
             |id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
-                let intermediate_entry = if let Some(blueprint) = self.blueprint.as_ref() {
+                let intermediate_entry = if self.blueprint.as_ref().is_some() {
                     IntermediateTermBucketEntry {
                         doc_count,
                         sub_aggregation: self
                             .term_buckets
                             .sub_aggs
                             .remove(&id)
-                            .expect(&format!(
-                                "Internal Error: could not find subaggregation for id {}",
-                                id
-                            ))
+                            .unwrap_or_else(|| {
+                                panic!(
+                                    "Internal Error: could not find subaggregation for id {}",
+                                    id
+                                )
+                            })
                             .into_intermediate_aggregations_result(
                                 &agg_with_accessor.sub_aggregation,
                             )?,
@@ -525,21 +488,11 @@
 pub(crate) trait GetDocCount {
     fn doc_count(&self) -> u64;
 }
-impl GetDocCount for (u32, TermBucketEntry) {
-    fn doc_count(&self) -> u64 {
-        self.1.doc_count
-    }
-}
 impl GetDocCount for (u64, u64) {
     fn doc_count(&self) -> u64 {
         self.1
     }
 }
-impl GetDocCount for (u64, TermBucketEntry) {
-    fn doc_count(&self) -> u64 {
-        self.1.doc_count
-    }
-}
 impl GetDocCount for (String, IntermediateTermBucketEntry) {
     fn doc_count(&self) -> u64 {
         self.1.doc_count
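
With `TermBucketEntry` gone, the Full-cardinality fast path above reduces to plain counting: one `u64` per term id in a hash map, with sub-aggregations kept in a separate map. A minimal standalone sketch of that counting pattern (illustrative names; the committed code uses an `FxHashMap` keyed by term ordinal):

use std::collections::HashMap;

// Count how often each term id occurs in a block of values.
fn count_terms(term_ids: &[u64]) -> HashMap<u64, u64> {
    let mut entries: HashMap<u64, u64> = HashMap::new();
    for &term_id in term_ids {
        *entries.entry(term_id).or_default() += 1;
    }
    entries
}

fn main() {
    let counts = count_terms(&[3, 3, 7, 3]);
    assert_eq!(counts[&3], 3);
    assert_eq!(counts[&7], 1);
}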
5 changes: 2 additions & 3 deletions src/aggregation/buf_collector.rs
@@ -64,9 +64,8 @@ impl SegmentAggregationCollector for BufAggregationCollector {
         docs: &[crate::DocId],
         agg_with_accessor: &AggregationsWithAccessor,
     ) -> crate::Result<()> {
-        for doc in docs {
-            self.collect(*doc, agg_with_accessor)?;
-        }
+        self.collector.collect_block(docs, agg_with_accessor)?;
+
         Ok(())
     }

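
For context, `BufAggregationCollector` is a buffering wrapper around another segment collector: single docs are staged and flushed to the wrapped collector in blocks, and with this change an already-assembled block skips the staging step and goes straight to the inner `collect_block`. A simplified sketch of that pattern (hypothetical types, not the tantivy implementation):

const BLOCK_LEN: usize = 64;

// `inner` stands in for the wrapped collector, reduced to a closure
// that consumes blocks of doc ids.
struct BufferedCollector<F: FnMut(&[u32])> {
    inner: F,
    staged: Vec<u32>,
}

impl<F: FnMut(&[u32])> BufferedCollector<F> {
    fn collect(&mut self, doc: u32) {
        self.staged.push(doc);
        if self.staged.len() == BLOCK_LEN {
            self.flush();
        }
    }

    // Whole blocks bypass the staging buffer, mirroring the change above
    // (flushing first keeps doc order here; the real collector forwards
    // directly, since the two paths are not interleaved in practice).
    fn collect_block(&mut self, docs: &[u32]) {
        self.flush();
        (self.inner)(docs);
    }

    fn flush(&mut self) {
        if !self.staged.is_empty() {
            (self.inner)(&self.staged);
            self.staged.clear();
        }
    }
}

fn main() {
    let mut total = 0u64;
    {
        let mut collector = BufferedCollector {
            inner: |docs: &[u32]| total += docs.len() as u64,
            staged: Vec::new(),
        };
        collector.collect(1);
        collector.collect_block(&[2, 3, 4]);
        collector.flush();
    }
    assert_eq!(total, 4);
}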
29 changes: 22 additions & 7 deletions src/aggregation/collector.rs
@@ -8,7 +8,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults;
 use super::segment_agg_result::{build_segment_agg_collector, SegmentAggregationCollector};
 use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
 use crate::collector::{Collector, SegmentCollector};
-use crate::{SegmentReader, TantivyError};
+use crate::{DocId, SegmentReader, TantivyError};
 
 /// The default max bucket count, before the aggregation fails.
 pub const MAX_BUCKET_COUNT: u32 = 65000;
@@ -135,7 +135,7 @@ fn merge_fruits(
 /// `AggregationSegmentCollector` does the aggregation collection on a segment.
 pub struct AggregationSegmentCollector {
     aggs_with_accessor: AggregationsWithAccessor,
-    result: BufAggregationCollector,
+    agg_collector: BufAggregationCollector,
     error: Option<TantivyError>,
 }

@@ -153,7 +153,7 @@ impl AggregationSegmentCollector {
             BufAggregationCollector::new(build_segment_agg_collector(&aggs_with_accessor)?);
         Ok(AggregationSegmentCollector {
             aggs_with_accessor,
-            result,
+            agg_collector: result,
             error: None,
         })
     }
@@ -163,11 +163,26 @@ impl SegmentCollector for AggregationSegmentCollector {
     type Fruit = crate::Result<IntermediateAggregationResults>;
 
     #[inline]
-    fn collect(&mut self, doc: crate::DocId, _score: crate::Score) {
+    fn collect(&mut self, doc: DocId, _score: crate::Score) {
         if self.error.is_some() {
             return;
         }
-        if let Err(err) = self.result.collect(doc, &self.aggs_with_accessor) {
+        if let Err(err) = self.agg_collector.collect(doc, &self.aggs_with_accessor) {
             self.error = Some(err);
         }
     }
 
+    /// The query pushes the documents to the collector via this method.
+    ///
+    /// Only valid for collectors that ignore scores.
+    fn collect_block(&mut self, docs: &[DocId]) {
+        if self.error.is_some() {
+            return;
+        }
+        if let Err(err) = self
+            .agg_collector
+            .collect_block(docs, &self.aggs_with_accessor)
+        {
+            self.error = Some(err);
+        }
+    }
@@ -176,7 +191,7 @@
         if let Some(err) = self.error {
             return Err(err);
         }
-        self.result.flush(&self.aggs_with_accessor)?;
-        Box::new(self.result).into_intermediate_aggregations_result(&self.aggs_with_accessor)
+        self.agg_collector.flush(&self.aggs_with_accessor)?;
+        Box::new(self.agg_collector).into_intermediate_aggregations_result(&self.aggs_with_accessor)
     }
 }
19 changes: 14 additions & 5 deletions src/collector/mod.rs
@@ -180,9 +180,11 @@ pub trait Collector: Sync + Send {
                 })?;
             }
             (Some(alive_bitset), false) => {
-                weight.for_each_no_score(reader, &mut |doc| {
-                    if alive_bitset.is_alive(doc) {
-                        segment_collector.collect(doc, 0.0);
+                weight.for_each_no_score(reader, &mut |docs| {
+                    for doc in docs.iter().cloned() {
+                        if alive_bitset.is_alive(doc) {
+                            segment_collector.collect(doc, 0.0);
+                        }
                     }
                 })?;
             }
@@ -192,8 +194,8 @@
                 })?;
             }
             (None, false) => {
-                weight.for_each_no_score(reader, &mut |doc| {
-                    segment_collector.collect(doc, 0.0);
+                weight.for_each_no_score(reader, &mut |docs| {
+                    segment_collector.collect_block(docs);
                 })?;
             }
         }
@@ -270,6 +272,13 @@ pub trait SegmentCollector: 'static {
     /// The query pushes the scored document to the collector via this method.
     fn collect(&mut self, doc: DocId, score: Score);
 
+    /// The query pushes a block of documents to the collector via this method.
+    fn collect_block(&mut self, docs: &[DocId]) {
+        for doc in docs {
+            self.collect(*doc, 0.0);
+        }
+    }
+
     /// Extract the fruit of the collection from the `SegmentCollector`.
     fn harvest(self) -> Self::Fruit;
 }
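
The new `collect_block` has a default body that loops over `collect`, so existing collectors keep working unchanged; a collector that ignores scores can override it to handle a whole buffered block per virtual call. A minimal sketch of such an override (hypothetical collector, not part of this commit):

use tantivy::collector::SegmentCollector;
use tantivy::{DocId, Score};

// Hypothetical per-segment match counter.
struct CountSegmentCollector {
    count: u64,
}

impl SegmentCollector for CountSegmentCollector {
    type Fruit = u64;

    fn collect(&mut self, _doc: DocId, _score: Score) {
        self.count += 1;
    }

    // One call per block instead of one call per document.
    fn collect_block(&mut self, docs: &[DocId]) {
        self.count += docs.len() as u64;
    }

    fn harvest(self) -> u64 {
        self.count
    }
}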
9 changes: 8 additions & 1 deletion src/docset.rs
@@ -9,6 +9,8 @@ use crate::DocId;
 /// to compare `[u32; 4]`.
 pub const TERMINATED: DocId = i32::MAX as u32;
 
+pub const BUFFER_LEN: usize = 64;
+
 /// Represents an iterable set of sorted doc ids.
 pub trait DocSet: Send {
     /// Goes to the next element.
@@ -59,7 +61,7 @@ pub trait DocSet: Send {
     /// This method is only here for the specific high-performance
     /// use case where batching is needed. The normal way to
     /// go through the `DocId`s is to call `.advance()`.
-    fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
         if self.doc() == TERMINATED {
             return 0;
         }
@@ -149,6 +151,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
         unboxed.seek(target)
     }
 
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
+        let unboxed: &mut TDocSet = self.borrow_mut();
+        unboxed.fill_buffer(buffer)
+    }
+
     fn doc(&self) -> DocId {
         let unboxed: &TDocSet = self.borrow();
         unboxed.doc()
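
Switching the buffer from a slice to a fixed-size `[DocId; BUFFER_LEN]` array gives every implementation a batch size known at compile time, and callers can keep a single stack-allocated buffer. A sketch of the intended draining pattern, written as crate-internal code against the declarations above (`drain_docset` and `process` are illustrative names):

use crate::docset::{DocSet, BUFFER_LEN};
use crate::DocId;

// Drain a DocSet in fixed-size batches; `process` stands in for any
// block-oriented consumer. A partial batch signals exhaustion, so no
// extra TERMINATED check is needed here.
fn drain_docset<D: DocSet>(docset: &mut D, mut process: impl FnMut(&[DocId])) {
    let mut buffer = [0u32; BUFFER_LEN];
    loop {
        let num_filled = docset.fill_buffer(&mut buffer);
        process(&buffer[..num_filled]);
        if num_filled < BUFFER_LEN {
            break;
        }
    }
}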
10 changes: 6 additions & 4 deletions src/indexer/index_writer.rs
@@ -94,10 +94,12 @@ fn compute_deleted_bitset(
             // document that were inserted before it.
             delete_op
                 .target
-                .for_each_no_score(segment_reader, &mut |doc_matching_delete_query| {
-                    if doc_opstamps.is_deleted(doc_matching_delete_query, delete_op.opstamp) {
-                        alive_bitset.remove(doc_matching_delete_query);
-                        might_have_changed = true;
+                .for_each_no_score(segment_reader, &mut |docs_matching_delete_query| {
+                    for doc_matching_delete_query in docs_matching_delete_query.iter().cloned() {
+                        if doc_opstamps.is_deleted(doc_matching_delete_query, delete_op.opstamp) {
+                            alive_bitset.remove(doc_matching_delete_query);
+                            might_have_changed = true;
+                        }
                     }
                 })?;
             delete_cursor.advance();
27 changes: 26 additions & 1 deletion src/query/all_query.rs
@@ -1,5 +1,5 @@
 use crate::core::SegmentReader;
-use crate::docset::{DocSet, TERMINATED};
+use crate::docset::{DocSet, BUFFER_LEN, TERMINATED};
 use crate::query::boost_query::BoostScorer;
 use crate::query::explanation::does_not_match;
 use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
@@ -44,6 +44,7 @@ pub struct AllScorer {
 }
 
 impl DocSet for AllScorer {
+    #[inline(always)]
     fn advance(&mut self) -> DocId {
         if self.doc + 1 >= self.max_doc {
             self.doc = TERMINATED;
@@ -53,6 +54,30 @@ impl DocSet for AllScorer {
         self.doc
     }
 
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
+        if self.doc() == TERMINATED {
+            return 0;
+        }
+        let is_safe_distance = self.doc() + (buffer.len() as u32) < self.max_doc;
+        if is_safe_distance {
+            let num_items = buffer.len();
+            for buffer_val in buffer {
+                *buffer_val = self.doc();
+                self.doc += 1;
+            }
+            num_items
+        } else {
+            for (i, buffer_val) in buffer.iter_mut().enumerate() {
+                *buffer_val = self.doc();
+                if self.advance() == TERMINATED {
+                    return i + 1;
+                }
+            }
+            buffer.len()
+        }
+    }
+
+    #[inline(always)]
     fn doc(&self) -> DocId {
         self.doc
     }
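
`AllScorer` can specialize `fill_buffer` because its docs are dense: while the cursor is at least a full buffer away from `max_doc`, no element can hit `TERMINATED`, so the check is lifted out of the loop. A self-contained analogue of the two paths (a sketch mirroring the logic above, not the tantivy type itself):

const BUFFER_LEN: usize = 64;
const TERMINATED: u32 = i32::MAX as u32;

// Dense cursor over 0..max_doc, like AllScorer.
struct DenseDocs {
    doc: u32,
    max_doc: u32,
}

impl DenseDocs {
    fn advance(&mut self) -> u32 {
        if self.doc + 1 >= self.max_doc {
            self.doc = TERMINATED;
        } else {
            self.doc += 1;
        }
        self.doc
    }

    fn fill_buffer(&mut self, buffer: &mut [u32; BUFFER_LEN]) -> usize {
        if self.doc == TERMINATED {
            return 0;
        }
        if self.doc + (BUFFER_LEN as u32) < self.max_doc {
            // Fast path: no per-element termination check.
            for slot in buffer.iter_mut() {
                *slot = self.doc;
                self.doc += 1;
            }
            BUFFER_LEN
        } else {
            // Fallback: the generic one-doc-at-a-time loop.
            for (i, slot) in buffer.iter_mut().enumerate() {
                *slot = self.doc;
                if self.advance() == TERMINATED {
                    return i + 1;
                }
            }
            BUFFER_LEN
        }
    }
}

fn main() {
    let mut docs = DenseDocs { doc: 0, max_doc: 100 };
    let mut buffer = [0u32; BUFFER_LEN];
    let mut batch_sizes = Vec::new();
    loop {
        let n = docs.fill_buffer(&mut buffer);
        if n == 0 {
            break;
        }
        batch_sizes.push(n);
        if n < BUFFER_LEN {
            break;
        }
    }
    assert_eq!(batch_sizes, vec![64, 36]); // 100 dense docs in two batches
}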
1 change: 1 addition & 0 deletions src/query/bitset/mod.rs
@@ -45,6 +45,7 @@ impl From<BitSet> for BitSetDocSet {
 }
 
 impl DocSet for BitSetDocSet {
+    #[inline]
     fn advance(&mut self) -> DocId {
         if let Some(lower) = self.cursor_tinybitset.pop_lowest() {
             self.doc = (self.cursor_bucket * 64u32) | lower;
11 changes: 7 additions & 4 deletions src/query/boolean_query/boolean_weight.rs
@@ -1,11 +1,12 @@
 use std::collections::HashMap;
 
 use crate::core::SegmentReader;
+use crate::docset::BUFFER_LEN;
 use crate::postings::FreqReadingOption;
 use crate::query::explanation::does_not_match;
 use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
 use crate::query::term_query::TermScorer;
-use crate::query::weight::{for_each_docset, for_each_pruning_scorer, for_each_scorer};
+use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer};
 use crate::query::{
     intersect_scorers, EmptyScorer, Exclude, Explanation, Occur, RequiredOptionalScorer, Scorer,
     Union, Weight,
@@ -222,16 +223,18 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
     fn for_each_no_score(
         &self,
         reader: &SegmentReader,
-        callback: &mut dyn FnMut(DocId),
+        callback: &mut dyn FnMut(&[DocId]),
     ) -> crate::Result<()> {
         let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
+        let mut buffer = [0u32; BUFFER_LEN];
+
         match scorer {
             SpecializedScorer::TermUnion(term_scorers) => {
                 let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn);
-                for_each_docset(&mut union_scorer, callback);
+                for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
             }
             SpecializedScorer::Other(mut scorer) => {
-                for_each_docset(scorer.as_mut(), callback);
+                for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
             }
         }
         Ok(())
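
`for_each_docset_buffered` is defined in src/query/weight.rs, which is not expanded in this view. Presumably it pairs `fill_buffer` with the block callback roughly as follows (a sketch under that assumption, not the committed code):

use crate::docset::{DocSet, BUFFER_LEN};
use crate::DocId;

// Drive a DocSet to exhaustion, handing the callback one buffered
// block at a time; a partial block means the docset is done.
pub(crate) fn for_each_docset_buffered<T: DocSet + ?Sized>(
    docset: &mut T,
    buffer: &mut [DocId; BUFFER_LEN],
    callback: &mut dyn FnMut(&[DocId]),
) {
    loop {
        let num_filled = docset.fill_buffer(buffer);
        callback(&buffer[..num_filled]);
        if num_filled < BUFFER_LEN {
            return;
        }
    }
}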
3 changes: 2 additions & 1 deletion src/query/boost_query.rs
@@ -1,5 +1,6 @@
 use std::fmt;
 
+use crate::docset::BUFFER_LEN;
 use crate::fastfield::AliveBitSet;
 use crate::query::explanation::does_not_match;
 use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
@@ -106,7 +107,7 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
         self.underlying.seek(target)
     }
 
-    fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
         self.underlying.fill_buffer(buffer)
     }
 