Skip to content

Commit

Permalink
implement batch updating/merging in aggregates
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangli20 committed Nov 10, 2023
1 parent a90cf5f commit 1cf9347
Show file tree
Hide file tree
Showing 11 changed files with 725 additions and 591 deletions.
38 changes: 38 additions & 0 deletions native-engine/datafusion-ext-plans/src/agg/agg_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,21 @@ impl AggContext {
Ok(())
}

/// Runs the batch partial-update step of every aggregate listed in
/// `need_partial_update_aggs` against `agg_bufs`.
///
/// Each aggregate receives the address slice returned by `agg_addrs` for its
/// index together with `input_arrays[idx]`. The `usize` values reported by
/// the aggregates are summed and returned (presumably the change in buffer
/// memory usage -- confirm against the `Agg` trait).
///
/// Returns `Ok(0)` without touching `agg_bufs` when `need_partial_update`
/// is false. The first error produced by an aggregate is propagated.
pub fn partial_batch_update_input(
    &self,
    agg_bufs: &mut [AggBuf],
    input_arrays: &[Vec<ArrayRef>],
) -> Result<usize> {
    // Nothing to do when no aggregate requires a partial update.
    if !self.need_partial_update {
        return Ok(0);
    }

    let mut total = 0;
    for (idx, agg) in &self.need_partial_update_aggs {
        let addrs = self.agg_addrs(*idx);
        let inputs = &input_arrays[*idx];
        total += agg.partial_batch_update(agg_bufs, addrs, inputs)?;
    }
    Ok(total)
}

pub fn partial_update_input_all(
&self,
agg_buf: &mut AggBuf,
Expand Down Expand Up @@ -331,6 +346,29 @@ impl AggContext {
Ok(())
}

/// Runs the batch partial-merge step of every aggregate listed in
/// `need_partial_merge_aggs`, merging buffers decoded from `agg_buf_array`
/// into `agg_bufs`.
///
/// Every element of `agg_buf_array` is loaded into a fresh clone of
/// `initial_input_agg_buf` via `load_from_bytes`. The `usize` values
/// reported by the aggregates are summed and returned (presumably the
/// change in buffer memory usage -- confirm against the `Agg` trait).
///
/// Returns `Ok(0)` when `need_partial_merge` is false. Errors from
/// `load_from_bytes` or from an aggregate are propagated.
///
/// # Panics
///
/// Panics if `agg_buf_array` yields a null element, because the value is
/// unwrapped unconditionally.
pub fn partial_batch_merge_input(
    &self,
    agg_bufs: &mut [AggBuf],
    agg_buf_array: &BinaryArray,
) -> Result<usize> {
    // Nothing to do when no aggregate requires a partial merge.
    if !self.need_partial_merge {
        return Ok(0);
    }

    // Decode one input buffer per array element, stopping at the first failure.
    let mut merging_bufs = Vec::new();
    for bytes in agg_buf_array.iter() {
        let mut decoded = self.initial_input_agg_buf.clone();
        decoded.load_from_bytes(bytes.unwrap())?;
        merging_bufs.push(decoded);
    }

    let mut total = 0;
    for (idx, agg) in &self.need_partial_merge_aggs {
        let addrs = self.agg_addrs(*idx);
        total += agg.partial_batch_merge(agg_bufs, &mut merging_bufs, addrs)?;
    }
    Ok(total)
}

pub fn partial_merge_input_all(
&self,
agg_buf: &mut AggBuf,
Expand Down
48 changes: 27 additions & 21 deletions native-engine/datafusion-ext-plans/src/agg/agg_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use ahash::RandomState;
use std::hash::{BuildHasher, Hash, Hasher};
use std::io::{BufReader, Read, Write};
use std::mem::size_of;
use std::mem::{size_of, ManuallyDrop};
use std::sync::{Arc, Weak};

use arrow::row::{RowConverter, Rows};
Expand Down Expand Up @@ -84,10 +84,10 @@ impl AggTables {
pub async fn update_entries(
&self,
key_rows: Rows,
fn_entry: impl Fn(usize, &mut AggBuf) -> Result<()>,
fn_entries: impl Fn(&mut [AggBuf]) -> Result<usize>,
) -> Result<()> {
let mut in_mem = self.in_mem.lock().await;
in_mem.update_entries(&self.agg_ctx, key_rows, fn_entry)?;
in_mem.update_entries(&self.agg_ctx, key_rows, fn_entries)?;

let mem_used = in_mem.mem_used();
drop(in_mem);
Expand Down Expand Up @@ -363,25 +363,26 @@ impl InMemTable {
&mut self,
agg_ctx: &Arc<AggContext>,
key_rows: Rows,
fn_entry: impl Fn(usize, &mut AggBuf) -> Result<()>,
fn_entries: impl Fn(&mut [AggBuf]) -> Result<usize>,
) -> Result<()> {
if self.is_hash {
self.update_hash_entries(agg_ctx, key_rows, fn_entry)
self.update_hash_entries(agg_ctx, key_rows, fn_entries)
} else {
self.update_unsorted_entries(agg_ctx, key_rows, fn_entry)
self.update_unsorted_entries(agg_ctx, key_rows, fn_entries)
}
}

fn update_hash_entries(
&mut self,
agg_ctx: &Arc<AggContext>,
key_rows: Rows,
fn_entry: impl Fn(usize, &mut AggBuf) -> Result<()>,
fn_entries: impl Fn(&mut [AggBuf]) -> Result<usize>,
) -> Result<()> {
let hashes: Vec<u64> = key_rows
.iter()
.map(|row| RANDOM_STATE.hash_one(row.as_ref()))
.collect();
let mut agg_bufs = Vec::with_capacity(key_rows.num_rows());

for (row_idx, row) in key_rows.iter().enumerate() {
match self
Expand All @@ -390,35 +391,40 @@ impl InMemTable {
.from_hash(hashes[row_idx], |&addr| {
self.map_keys.get(addr) == row.as_ref()
}) {
RawEntryMut::Occupied(mut view) => {
self.agg_buf_mem_used -= view.get().mem_size();
fn_entry(row_idx, view.get_mut())?;
self.agg_buf_mem_used += view.get().mem_size();
RawEntryMut::Occupied(view) => {
// safety: agg_buf lives longer than this function call.
// items in agg_bufs are later moved into ManuallyDrop to avoid double drop.
agg_bufs.push(unsafe { std::ptr::read(view.get() as *const AggBuf) });
}
RawEntryMut::Vacant(view) => {
let new_key_addr = self.map_keys.add(row.as_ref());
let mut new_entry = agg_ctx.initial_agg_buf.clone();
fn_entry(row_idx, &mut new_entry)?;
self.agg_buf_mem_used += new_entry.mem_size();
let new_entry = agg_ctx.initial_agg_buf.clone();
// safety: agg_buf lives longer than this function call.
// items in agg_bufs are later moved into ManuallyDrop to avoid double drop.
agg_bufs.push(unsafe { std::ptr::read(&new_entry as *const AggBuf) });
view.insert(new_key_addr, new_entry);
}
}
}

self.agg_buf_mem_used += fn_entries(&mut agg_bufs)?;
for agg_buf in agg_bufs {
let _ = ManuallyDrop::new(agg_buf);
}
Ok(())
}

fn update_unsorted_entries(
&mut self,
agg_ctx: &Arc<AggContext>,
key_rows: Rows,
fn_entry: impl Fn(usize, &mut AggBuf) -> Result<()>,
fn_entries: impl Fn(&mut [AggBuf]) -> Result<usize>,
) -> Result<()> {
for i in 0..key_rows.num_rows() {
let mut new_entry = agg_ctx.initial_agg_buf.clone();
fn_entry(i, &mut new_entry)?;
self.agg_buf_mem_used += new_entry.mem_size();
self.unsorted_values.push(new_entry);
}
let beg = self.unsorted_values.len();
let len = key_rows.num_rows();
self.unsorted_values
.extend((0..len).map(|_| agg_ctx.initial_agg_buf.clone()));
self.agg_buf_mem_used += fn_entries(&mut self.unsorted_values[beg..][..len])?;
self.unsorted_keys_mem_used += key_rows.size();
self.unsorted_keys.push(key_rows);
Ok(())
Expand Down
32 changes: 32 additions & 0 deletions native-engine/datafusion-ext-plans/src/agg/avg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ impl Agg for AggAvg {
Ok(())
}

/// Batch partial update for AVG: forwards the batch to the inner sum
/// aggregate and then to the inner count aggregate, returning the sum of
/// the two values they report.
///
/// `agg_buf_addrs` holds the sum's address(es) first, followed by the
/// count's. `final_merge` of this aggregate reads the count through
/// `&agg_buf_addrs[1..]`, so the count aggregate must be driven with the
/// same offset here; passing the unsliced addresses would make it write
/// into the sum's slot.
///
/// # Panics
///
/// Panics if `agg_buf_addrs` is empty (slicing from index 1).
fn partial_batch_update(
    &self,
    agg_bufs: &mut [AggBuf],
    agg_buf_addrs: &[u64],
    values: &[ArrayRef],
) -> Result<usize> {
    let sum_mem_diff = self
        .agg_sum
        .partial_batch_update(agg_bufs, agg_buf_addrs, values)?;
    // The count slot follows the sum slot, matching `final_merge`.
    let count_mem_diff =
        self.agg_count
            .partial_batch_update(agg_bufs, &agg_buf_addrs[1..], values)?;
    Ok(sum_mem_diff + count_mem_diff)
}

fn partial_update_all(
&self,
agg_buf: &mut AggBuf,
Expand All @@ -129,6 +145,22 @@ impl Agg for AggAvg {
Ok(())
}

/// Batch partial merge for AVG: merges `merging_agg_bufs` into `agg_bufs`
/// through the inner sum aggregate and then the inner count aggregate,
/// returning the sum of the two values they report.
///
/// `agg_buf_addrs` holds the sum's address(es) first, followed by the
/// count's. `final_merge` of this aggregate reads the count through
/// `&agg_buf_addrs[1..]`, so the count aggregate must be merged with the
/// same offset here; passing the unsliced addresses would make it merge
/// into the sum's slot.
///
/// # Panics
///
/// Panics if `agg_buf_addrs` is empty (slicing from index 1).
fn partial_batch_merge(
    &self,
    agg_bufs: &mut [AggBuf],
    merging_agg_bufs: &mut [AggBuf],
    agg_buf_addrs: &[u64],
) -> Result<usize> {
    let sum_mem_diff = self
        .agg_sum
        .partial_batch_merge(agg_bufs, merging_agg_bufs, agg_buf_addrs)?;
    // The count slot follows the sum slot, matching `final_merge`.
    let count_mem_diff =
        self.agg_count
            .partial_batch_merge(agg_bufs, merging_agg_bufs, &agg_buf_addrs[1..])?;
    Ok(sum_mem_diff + count_mem_diff)
}

fn final_merge(&self, agg_buf: &mut AggBuf, agg_buf_addrs: &[u64]) -> Result<ScalarValue> {
let sum = self.agg_sum.final_merge(agg_buf, agg_buf_addrs)?;
let count = match self.agg_count.final_merge(agg_buf, &agg_buf_addrs[1..])? {
Expand Down
29 changes: 29 additions & 0 deletions native-engine/datafusion-ext-plans/src/agg/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ impl Agg for AggCount {
Ok(())
}

/// Batch partial update for COUNT: for every row whose entry in `values[0]`
/// is valid, adds one to the `i64` stored at `agg_buf_addrs[0]` in the
/// buffer at the same position in `agg_bufs`.
///
/// Always returns `Ok(0)`.
///
/// # Panics
///
/// Panics if `agg_buf_addrs` or `values` is empty.
fn partial_batch_update(
    &self,
    agg_bufs: &mut [AggBuf],
    agg_buf_addrs: &[u64],
    values: &[ArrayRef],
) -> Result<usize> {
    let counter_addr = agg_buf_addrs[0];
    let column = &values[0];

    agg_bufs
        .iter_mut()
        .enumerate()
        .filter(|(row_idx, _)| column.is_valid(*row_idx))
        .for_each(|(_, buf)| {
            buf.update_fixed_value::<i64>(counter_addr, |count| count + 1);
        });
    Ok(0)
}

fn partial_update_all(
&self,
agg_buf: &mut AggBuf,
Expand All @@ -99,6 +115,19 @@ impl Agg for AggCount {
Ok(())
}

/// Batch partial merge for COUNT: pairs each buffer in `agg_bufs` with the
/// buffer at the same position in `merging_agg_bufs` and adds the latter's
/// `i64` at `agg_buf_addrs[0]` to the former's.
///
/// Pairing stops at the shorter of the two slices. Always returns `Ok(0)`.
///
/// # Panics
///
/// Panics if `agg_buf_addrs` is empty.
fn partial_batch_merge(
    &self,
    agg_bufs: &mut [AggBuf],
    merging_agg_bufs: &mut [AggBuf],
    agg_buf_addrs: &[u64],
) -> Result<usize> {
    let counter_addr = agg_buf_addrs[0];

    agg_bufs
        .iter_mut()
        .zip(merging_agg_bufs.iter_mut())
        .for_each(|(target, incoming)| {
            let delta = incoming.fixed_value::<i64>(counter_addr);
            target.update_fixed_value::<i64>(counter_addr, |count| count + delta);
        });
    Ok(0)
}
fn final_merge(&self, agg_buf: &mut AggBuf, agg_buf_addrs: &[u64]) -> Result<ScalarValue> {
let addr = agg_buf_addrs[0];
Ok(ScalarValue::from(agg_buf.fixed_value::<i64>(addr)))
Expand Down
Loading

0 comments on commit 1cf9347

Please sign in to comment.