Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable GroupValueBytesView for aggregation with StringView types #11519

Merged
merged 12 commits into from
Jul 20, 2024
4 changes: 4 additions & 0 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use ahash::RandomState;
use datafusion_physical_expr_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use std::collections::HashSet;
use std::ops::BitAnd;
use std::{fmt::Debug, sync::Arc};
Expand Down Expand Up @@ -230,6 +231,9 @@ impl AggregateUDFImpl for Count {
DataType::Utf8 => {
Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
}
DataType::Utf8View => {
Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8))
}
DataType::LargeUtf8 => {
Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values

use crate::binary_map::{ArrowBytesSet, OutputType};
use crate::binary_view_map::ArrowBytesViewSet;
use arrow::array::{ArrayRef, OffsetSizeTrait};
use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array_nullable;
Expand Down Expand Up @@ -88,3 +89,63 @@ impl<O: OffsetSizeTrait> Accumulator for BytesDistinctCountAccumulator<O> {
std::mem::size_of_val(self) + self.0.size()
}
}

/// Specialized implementation of
/// `COUNT DISTINCT` for [`StringViewArray`] and [`BinaryViewArray`].
///
/// [`StringViewArray`]: arrow::array::StringViewArray
/// [`BinaryViewArray`]: arrow::array::BinaryViewArray
#[derive(Debug)]
pub struct BytesViewDistinctCountAccumulator(ArrowBytesViewSet);

impl BytesViewDistinctCountAccumulator {
pub fn new(output_type: OutputType) -> Self {
Self(ArrowBytesViewSet::new(output_type))
}
}

impl Accumulator for BytesViewDistinctCountAccumulator {
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
let set = self.0.take();
let arr = set.into_state();
let list = Arc::new(array_into_list_array_nullable(arr));
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
if values.is_empty() {
return Ok(());
}

self.0.insert(&values[0]);

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
self.0.insert(&list);
};
Ok(())
})
}

fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.0.non_null_len() as i64)))
}

fn size(&self) -> usize {
std::mem::size_of_val(self) + self.0.size()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ mod bytes;
mod native;

pub use bytes::BytesDistinctCountAccumulator;
pub use bytes::BytesViewDistinctCountAccumulator;
pub use native::FloatDistinctCountAccumulator;
pub use native::PrimitiveDistinctCountAccumulator;
6 changes: 6 additions & 0 deletions datafusion/physical-expr-common/src/binary_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ use std::sync::Arc;
pub enum OutputType {
/// `StringArray` or `LargeStringArray`
Utf8,
/// `StringViewArray`
Utf8View,
/// `BinaryArray` or `LargeBinaryArray`
Binary,
/// `BinaryViewArray`
BinaryView,
}

/// HashSet optimized for storing string or binary values that can produce that
Expand Down Expand Up @@ -318,6 +322,7 @@ where
observe_payload_fn,
)
}
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
};
}

Expand Down Expand Up @@ -516,6 +521,7 @@ where
GenericStringArray::new_unchecked(offsets, values, nulls)
})
}
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
}
}

Expand Down
21 changes: 13 additions & 8 deletions datafusion/physical-expr-common/src/binary_view_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,7 @@ use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
use std::fmt::Debug;
use std::sync::Arc;

/// Should the output be a String or Binary?
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputType {
/// `StringViewArray`
Utf8View,
/// `BinaryViewArray`
BinaryView,
}
use crate::binary_map::OutputType;

/// HashSet optimized for storing string or binary values that can produce that
/// the final set as a `GenericBinaryViewArray` with minimal copies.
Expand All @@ -55,6 +48,14 @@ impl ArrowBytesViewSet {
.insert_if_new(values, make_payload_fn, observe_payload_fn);
}

/// Return the contents of this map and replace it with a new empty map with
/// the same output type
pub fn take(&mut self) -> Self {
let mut new_self = Self::new(self.0.output_type);
std::mem::swap(self, &mut new_self);
new_self
}

/// Converts this set into a `StringViewArray` or `BinaryViewArray`
/// containing each distinct value that was interned.
/// This is done without copying the values.
Expand Down Expand Up @@ -216,6 +217,7 @@ where
observe_payload_fn,
)
}
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
};
}

Expand Down Expand Up @@ -327,6 +329,9 @@ where
let array = unsafe { array.to_string_view_unchecked() };
Arc::new(array)
}
_ => {
unreachable!("Utf8/Binary should use `ArrowBytesMap`")
}
}
}

Expand Down
129 changes: 129 additions & 0 deletions datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::aggregates::group_values::GroupValues;
use arrow_array::{Array, ArrayRef, RecordBatch};
use datafusion_expr::EmitTo;
use datafusion_physical_expr::binary_map::OutputType;
use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;

/// A [`GroupValues`] storing single column of Utf8View/BinaryView values
///
/// This specialization is significantly faster than using the more general
/// purpose `Row`s format
pub struct GroupValuesBytesView {
/// Map string/binary values to group index
map: ArrowBytesViewMap<usize>,
/// The total number of groups so far (used to assign group_index)
num_groups: usize,
}

impl GroupValuesBytesView {
pub fn new(output_type: OutputType) -> Self {
Self {
map: ArrowBytesViewMap::new(output_type),
num_groups: 0,
}
}
}

impl GroupValues for GroupValuesBytesView {
fn intern(
&mut self,
cols: &[ArrayRef],
groups: &mut Vec<usize>,
) -> datafusion_common::Result<()> {
assert_eq!(cols.len(), 1);

// look up / add entries in the table
let arr = &cols[0];

groups.clear();
self.map.insert_if_new(
arr,
// called for each new group
|_value| {
// assign new group index on each insert
let group_idx = self.num_groups;
self.num_groups += 1;
group_idx
},
// called for each group
|group_idx| {
groups.push(group_idx);
},
);

// ensure we assigned a group to for each row
assert_eq!(groups.len(), arr.len());
Ok(())
}

fn size(&self) -> usize {
self.map.size() + std::mem::size_of::<Self>()
}

fn is_empty(&self) -> bool {
self.num_groups == 0
}

fn len(&self) -> usize {
self.num_groups
}

fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
// Reset the map to default, and convert it into a single array
let map_contents = self.map.take().into_state();

let group_values = match emit_to {
EmitTo::All => {
self.num_groups -= map_contents.len();
map_contents
}
EmitTo::First(n) if n == self.len() => {
self.num_groups -= map_contents.len();
map_contents
}
EmitTo::First(n) => {
// if we only wanted to take the first n, insert the rest back
// into the map we could potentially avoid this reallocation, at
// the expense of much more complex code.
// see https://github.com/apache/datafusion/issues/9195
let emit_group_values = map_contents.slice(0, n);
let remaining_group_values =
map_contents.slice(n, map_contents.len() - n);

self.num_groups = 0;
let mut group_indexes = vec![];
self.intern(&[remaining_group_values], &mut group_indexes)?;

// Verify that the group indexes were assigned in the correct order
assert_eq!(0, group_indexes[0]);

emit_group_values
}
};

Ok(vec![group_values])
}

fn clear_shrink(&mut self, _batch: &RecordBatch) {
// in theory we could potentially avoid this reallocation and clear the
// contents of the maps, but for now we just reset the map from the beginning
self.map.take();
}
}
33 changes: 22 additions & 11 deletions datafusion/physical-plan/src/aggregates/group_values/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use arrow::record_batch::RecordBatch;
use arrow_array::{downcast_primitive, ArrayRef};
use arrow_schema::{DataType, SchemaRef};
use bytes_view::GroupValuesBytesView;
use datafusion_common::Result;

pub(crate) mod primitive;
Expand All @@ -28,6 +29,7 @@ mod row;
use row::GroupValuesRows;

mod bytes;
mod bytes_view;
use bytes::GroupValuesByes;
use datafusion_physical_expr::binary_map::OutputType;

Expand Down Expand Up @@ -67,17 +69,26 @@ pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
_ => {}
}

if let DataType::Utf8 = d {
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Utf8)));
}
if let DataType::LargeUtf8 = d {
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Utf8)));
}
if let DataType::Binary = d {
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Binary)));
}
if let DataType::LargeBinary = d {
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Binary)));
match d {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👏

DataType::Utf8 => {
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Utf8)));
}
DataType::LargeUtf8 => {
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Utf8)));
}
DataType::Utf8View => {
return Ok(Box::new(GroupValuesBytesView::new(OutputType::Utf8View)));
}
DataType::Binary => {
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Binary)));
}
DataType::LargeBinary => {
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Binary)));
}
DataType::BinaryView => {
return Ok(Box::new(GroupValuesBytesView::new(OutputType::BinaryView)));
}
_ => {}
}
}

Expand Down