diff --git a/native-engine/datafusion-ext-commons/src/bytes_arena.rs b/native-engine/datafusion-ext-commons/src/bytes_arena.rs index 43722ae2..dd6e45e9 100644 --- a/native-engine/datafusion-ext-commons/src/bytes_arena.rs +++ b/native-engine/datafusion-ext-commons/src/bytes_arena.rs @@ -31,7 +31,7 @@ impl Default for BytesArena { } impl BytesArena { - pub fn add(&mut self, bytes: &[u8]) -> u64 { + pub fn add(&mut self, bytes: &[u8]) -> BytesArenaAddr { // assume bytes_len < 2^32 let cur_buf_len = self.cur_buf().len(); let len = bytes.len(); @@ -45,16 +45,16 @@ impl BytesArena { let id = self.bufs.len() - 1; let offset = self.cur_buf().len(); self.cur_buf_mut().extend_from_slice(bytes); - make_arena_addr(id, offset, len) + BytesArenaAddr::new(id, offset, len) } - pub fn get(&self, addr: u64) -> &[u8] { - let (id, offset, len) = unapply_arena_addr(addr); + pub fn get(&self, addr: BytesArenaAddr) -> &[u8] { + let unpacked = addr.unpack(); unsafe { // safety - performance critical, assume addr is valid self.bufs - .get_unchecked(id) - .get_unchecked(offset..offset + len) + .get_unchecked(unpacked.id) + .get_unchecked(unpacked.offset..unpacked.offset + unpacked.len) } } @@ -64,17 +64,17 @@ impl BytesArena { /// specialized for merging two parts in sort-exec /// works like an IntoIterator, free memory of visited items - pub fn specialized_get_and_drop_last(&mut self, addr: u64) -> &[u8] { - let (id, offset, len) = unapply_arena_addr(addr); - if id > 0 && !self.bufs[id - 1].is_empty() { - self.bufs[id - 1].truncate(0); // drop last buf - self.bufs[id - 1].shrink_to_fit(); + pub fn specialized_get_and_drop_last(&mut self, addr: BytesArenaAddr) -> &[u8] { + let unpacked = addr.unpack(); + if unpacked.id > 0 && !self.bufs[unpacked.id - 1].is_empty() { + self.bufs[unpacked.id - 1].truncate(0); // drop last buf + self.bufs[unpacked.id - 1].shrink_to_fit(); } unsafe { // safety - performance critical, assume addr is valid self.bufs - .get_unchecked(id) - .get_unchecked(offset..offset + len) + .get_unchecked(unpacked.id) + .get_unchecked(unpacked.offset..unpacked.offset + unpacked.len) } } @@ -97,14 +97,27 @@ impl BytesArena { } } -fn make_arena_addr(id: usize, offset: usize, len: usize) -> u64 { - (id as u64 * BUF_CAPACITY_TARGET as u64 + offset as u64) << 32 | len as u64 +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct BytesArenaAddr(u64); + +impl BytesArenaAddr { + pub fn new(id: usize, offset: usize, len: usize) -> Self { + Self((id as u64 * BUF_CAPACITY_TARGET as u64 + offset as u64) << 32 | len as u64) + } + + pub fn unpack(self) -> UnpackedBytesArenaAddr { + let id_offset = self.0 >> 32; + let id = (id_offset / (BUF_CAPACITY_TARGET as u64)) as usize; + let offset = (id_offset % (BUF_CAPACITY_TARGET as u64)) as usize; + let len = (self.0 << 32 >> 32) as usize; + + UnpackedBytesArenaAddr { id, offset, len } + } } -fn unapply_arena_addr(addr: u64) -> (usize, usize, usize) { - let id_offset = addr >> 32; - let id = id_offset / (BUF_CAPACITY_TARGET as u64); - let offset = id_offset % (BUF_CAPACITY_TARGET as u64); - let len = addr << 32 >> 32; - (id as usize, offset as usize, len as usize) +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct UnpackedBytesArenaAddr { + pub id: usize, + pub offset: usize, + pub len: usize, } diff --git a/native-engine/datafusion-ext-commons/src/io/scalar_serde.rs b/native-engine/datafusion-ext-commons/src/io/scalar_serde.rs index d2357703..3f7ac0dd 100644 --- a/native-engine/datafusion-ext-commons/src/io/scalar_serde.rs +++ b/native-engine/datafusion-ext-commons/src/io/scalar_serde.rs @@ -22,169 +22,150 @@ use crate::{ io::{read_bytes_slice, read_len, read_u8, write_len, write_u8}, }; -pub fn write_scalar(value: &ScalarValue, output: &mut W) -> Result<()> { - fn write_primitive_valid_scalar(buf: &[u8], output: &mut W) -> Result<()> { - write_len(1 as usize, output)?; - output.write_all(buf)?; - Ok(()) +pub fn write_scalar(value: &ScalarValue, nullable: bool, output: &mut W) -> Result<()> { + assert!(nullable || !value.is_null()); + + macro_rules! write_prim { + ($v:expr) => {{ + if nullable { + if let Some(v) = $v { + write_u8(1, output)?; + output.write_all(&v.to_ne_bytes())?; + } else { + write_u8(0, output)?; + } + } else { + output.write_all(&$v.unwrap().to_ne_bytes())?; + } + }}; } + match value { - ScalarValue::Null => {} - ScalarValue::Boolean(Some(value)) => write_u8((*value as u8) + 1u8, output)?, - ScalarValue::Int8(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::Int16(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::Int32(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::Int64(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::UInt8(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::UInt16(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::UInt32(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::UInt64(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::Float32(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::Float64(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::Decimal128(Some(value), ..) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::Utf8(Some(value)) => { - let value_bytes = value.as_bytes(); - write_len(value_bytes.len() + 1, output)?; - output.write_all(value_bytes)?; - } - ScalarValue::Binary(Some(value)) => { - let value_byte = value.as_bytes(); - write_len(value_byte.len() + 1, output)?; - output.write_all(value_byte)?; - } - ScalarValue::Date32(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::Date64(Some(value)) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::TimestampSecond(Some(value), _) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::TimestampMillisecond(Some(value), _) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? - } - ScalarValue::TimestampMicrosecond(Some(value), _) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? + ScalarValue::Null => write_u8(0, output)?, + ScalarValue::Boolean(v) => write_prim!(v.map(|v| v as i8)), + ScalarValue::Int8(v) => write_prim!(v), + ScalarValue::Int16(v) => write_prim!(v), + ScalarValue::Int32(v) => write_prim!(v), + ScalarValue::Int64(v) => write_prim!(v), + ScalarValue::UInt8(v) => write_prim!(v), + ScalarValue::UInt16(v) => write_prim!(v), + ScalarValue::UInt32(v) => write_prim!(v), + ScalarValue::UInt64(v) => write_prim!(v), + ScalarValue::Float32(v) => write_prim!(v), + ScalarValue::Float64(v) => write_prim!(v), + ScalarValue::Decimal128(v, ..) => write_prim!(v), + ScalarValue::Date32(v) => write_prim!(v), + ScalarValue::Date64(v) => write_prim!(v), + ScalarValue::TimestampSecond(v, ..) => write_prim!(v), + ScalarValue::TimestampMillisecond(v, ..) => write_prim!(v), + ScalarValue::TimestampMicrosecond(v, ..) => write_prim!(v), + ScalarValue::TimestampNanosecond(v, ..) => write_prim!(v), + ScalarValue::Utf8(v) => { + if let Some(v) = v { + write_len(v.as_bytes().len() + 1, output)?; + output.write_all(v.as_bytes())?; + } else { + write_len(0, output)?; + } } - ScalarValue::TimestampNanosecond(Some(value), _) => { - write_primitive_valid_scalar(value.to_ne_bytes().as_slice(), output)? + ScalarValue::Binary(v) => { + if let Some(v) = v { + write_len(v.as_bytes().len() + 1, output)?; + output.write_all(v.as_bytes())?; + } else { + write_len(0, output)?; + } } - ScalarValue::List(Some(value), _field) => { - write_len(value.len() + 1, output)?; - if value.len() != 0 { - for element in value { - write_scalar(element, output)?; + ScalarValue::List(v, field) => { + if let Some(v) = v { + write_len(v.len() + 1, output)?; + for v in v { + write_scalar(v, field.is_nullable(), output)?; } + } else { + write_len(0, output)?; } } - ScalarValue::Struct(Some(value), _fields) => { - write_len(value.len() + 1, output)?; - if value.len() != 0 { - for element in value { - write_scalar(element, output)?; + ScalarValue::Struct(v, fields) => { + if nullable { + if let Some(v) = v { + write_u8(1, output)?; + for (v, field) in v.iter().zip(fields) { + write_scalar(v, field.is_nullable(), output)?; + } + } else { + write_u8(0, output)?; + } + } else { + for (v, field) in v.as_ref().unwrap().iter().zip(fields) { + write_scalar(v, field.is_nullable(), output)?; } } } ScalarValue::Map(value, _bool) => { - write_scalar(value, output)?; - } - ScalarValue::Boolean(None) - | ScalarValue::Int8(None) - | ScalarValue::Int16(None) - | ScalarValue::Int32(None) - | ScalarValue::Int64(None) - | ScalarValue::UInt8(None) - | ScalarValue::UInt16(None) - | ScalarValue::UInt32(None) - | ScalarValue::UInt64(None) - | ScalarValue::Float32(None) - | ScalarValue::Float64(None) - | ScalarValue::Decimal128(None, ..) - | ScalarValue::Binary(None) - | ScalarValue::Utf8(None) - | ScalarValue::Date32(None) - | ScalarValue::Date64(None) - | ScalarValue::TimestampSecond(None, _) - | ScalarValue::TimestampMillisecond(None, _) - | ScalarValue::TimestampMicrosecond(None, _) - | ScalarValue::TimestampNanosecond(None, _) - | ScalarValue::List(None, _) - | ScalarValue::Struct(None, ..) => write_len(0 as usize, output)?, + write_scalar(value, nullable, output)?; + } other => df_unimplemented_err!("unsupported scalarValue type: {other}")?, } Ok(()) } -pub fn read_scalar(input: &mut R, data_type: &DataType) -> Result { - macro_rules! read_primitive_scalar { - ($input:ident, $len:expr, $byte_kind:ident) => {{ - let valid = read_len(input)?; - if valid != 0 { - let mut buf = [0; $len]; - $input.read_exact(&mut buf)?; - Some($byte_kind::from_ne_bytes(buf)) +pub fn read_scalar( + input: &mut R, + data_type: &DataType, + nullable: bool, +) -> Result { + macro_rules! read_prim { + ($ty:ty) => {{ + if nullable { + let valid = read_u8(input)? != 0; + if valid { + let mut buf = [0u8; std::mem::size_of::<$ty>()]; + input.read_exact(&mut buf)?; + Some(<$ty>::from_ne_bytes(buf)) + } else { + None + } } else { - None + let mut buf = [0u8; std::mem::size_of::<$ty>()]; + input.read_exact(&mut buf)?; + Some(<$ty>::from_ne_bytes(buf)) } }}; } Ok(match data_type { - DataType::Null => ScalarValue::Null, - DataType::Boolean => match read_u8(input)? { - 0u8 => ScalarValue::Boolean(None), - 1u8 => ScalarValue::Boolean(Some(false)), - _ => ScalarValue::Boolean(Some(true)), - }, - DataType::Int8 => ScalarValue::Int8(read_primitive_scalar!(input, 1, i8)), - DataType::Int16 => ScalarValue::Int16(read_primitive_scalar!(input, 2, i16)), - DataType::Int32 => ScalarValue::Int32(read_primitive_scalar!(input, 4, i32)), - DataType::Int64 => ScalarValue::Int64(read_primitive_scalar!(input, 8, i64)), - DataType::UInt8 => ScalarValue::UInt8(read_primitive_scalar!(input, 1, u8)), - DataType::UInt16 => ScalarValue::UInt16(read_primitive_scalar!(input, 2, u16)), - DataType::UInt32 => ScalarValue::UInt32(read_primitive_scalar!(input, 4, u32)), - DataType::UInt64 => ScalarValue::UInt64(read_primitive_scalar!(input, 8, u64)), - DataType::Float32 => ScalarValue::Float32(read_primitive_scalar!(input, 4, f32)), - DataType::Float64 => ScalarValue::Float64(read_primitive_scalar!(input, 8, f64)), + DataType::Null => { + read_u8(input)?; + ScalarValue::Null + } + DataType::Boolean => ScalarValue::Boolean(read_prim!(u8).map(|v| v != 0)), + DataType::Int8 => ScalarValue::Int8(read_prim!(i8)), + DataType::Int16 => ScalarValue::Int16(read_prim!(i16)), + DataType::Int32 => ScalarValue::Int32(read_prim!(i32)), + DataType::Int64 => ScalarValue::Int64(read_prim!(i64)), + DataType::UInt8 => ScalarValue::UInt8(read_prim!(u8)), + DataType::UInt16 => ScalarValue::UInt16(read_prim!(u16)), + DataType::UInt32 => ScalarValue::UInt32(read_prim!(u32)), + DataType::UInt64 => ScalarValue::UInt64(read_prim!(u64)), + DataType::Float32 => ScalarValue::Float32(read_prim!(f32)), + DataType::Float64 => ScalarValue::Float64(read_prim!(f64)), DataType::Decimal128(precision, scale) => { - ScalarValue::Decimal128(read_primitive_scalar!(input, 16, i128), *precision, *scale) + ScalarValue::Decimal128(read_prim!(i128), *precision, *scale) } - DataType::Date32 => ScalarValue::Date32(read_primitive_scalar!(input, 4, i32)), - DataType::Date64 => ScalarValue::Date64(read_primitive_scalar!(input, 8, i64)), + DataType::Date32 => ScalarValue::Date32(read_prim!(i32)), + DataType::Date64 => ScalarValue::Date64(read_prim!(i64)), DataType::Timestamp(TimeUnit::Second, str) => { - ScalarValue::TimestampSecond(read_primitive_scalar!(input, 8, i64), str.clone()) + ScalarValue::TimestampSecond(read_prim!(i64), str.clone()) } DataType::Timestamp(TimeUnit::Millisecond, str) => { - ScalarValue::TimestampMillisecond(read_primitive_scalar!(input, 8, i64), str.clone()) + ScalarValue::TimestampMillisecond(read_prim!(i64), str.clone()) } DataType::Timestamp(TimeUnit::Microsecond, str) => { - ScalarValue::TimestampMicrosecond(read_primitive_scalar!(input, 8, i64), str.clone()) + ScalarValue::TimestampMicrosecond(read_prim!(i64), str.clone()) } DataType::Timestamp(TimeUnit::Nanosecond, str) => { - ScalarValue::TimestampNanosecond(read_primitive_scalar!(input, 8, i64), str.clone()) + ScalarValue::TimestampNanosecond(read_prim!(i64), str.clone()) } DataType::Binary => { let data_len = read_len(input)?; @@ -210,34 +191,88 @@ pub fn read_scalar(input: &mut R, data_type: &DataType) -> Result 0 { let data_len = data_len - 1; - let mut list_data: Vec = Vec::with_capacity(data_len); + let mut children = Vec::with_capacity(data_len); for _i in 0..data_len { - let child_value = read_scalar(input, field.data_type())?; - list_data.push(child_value); + children.push(read_scalar(input, field.data_type(), field.is_nullable())?); } - ScalarValue::List(Some(list_data), field.clone()) + ScalarValue::List(Some(children), field.clone()) } else { ScalarValue::List(None, field.clone()) } } DataType::Struct(fields) => { - let data_len = read_len(input)?; - if data_len > 0 { - let data_len = data_len - 1; - let mut struct_data: Vec = Vec::with_capacity(data_len); - for i in 0..data_len { - let child_value = read_scalar(input, fields[i].data_type())?; - struct_data.push(child_value); + if nullable { + let valid = read_u8(input)? != 0; + if valid { + let mut children = Vec::with_capacity(fields.len()); + for field in fields { + children.push(read_scalar(input, field.data_type(), field.is_nullable())?); + } + ScalarValue::Struct(Some(children), fields.clone()) + } else { + ScalarValue::Struct(None, fields.clone()) } - ScalarValue::Struct(Some(struct_data), fields.clone()) } else { - ScalarValue::Struct(None, fields.clone()) + let mut children = Vec::with_capacity(fields.len()); + for field in fields { + children.push(read_scalar(input, field.data_type(), field.is_nullable())?); + } + ScalarValue::Struct(Some(children), fields.clone()) } } DataType::Map(field, bool) => { - let map_value = read_scalar(input, field.data_type())?; + let map_value = read_scalar(input, field.data_type(), field.is_nullable())?; ScalarValue::Map(Box::new(map_value), *bool) } other => df_unimplemented_err!("unsupported data type: {other}")?, }) } + +#[cfg(test)] +mod test { + use std::io::Cursor; + + use arrow_schema::DataType; + use datafusion::common::{Result, ScalarValue}; + + use crate::io::{read_scalar, write_scalar}; + + #[test] + fn test() -> Result<()> { + let mut buf = vec![]; + + write_scalar(&ScalarValue::from(123), false, &mut buf)?; + write_scalar(&ScalarValue::from("Wooden"), false, &mut buf)?; + write_scalar(&ScalarValue::from("Slash"), true, &mut buf)?; + write_scalar(&ScalarValue::Utf8(None), true, &mut buf)?; + write_scalar(&ScalarValue::Null, true, &mut buf)?; + write_scalar(&ScalarValue::from(3.15), false, &mut buf)?; + + let mut cur = Cursor::new(&buf); + assert_eq!( + read_scalar(&mut cur, &DataType::Int32, false)?, + ScalarValue::from(123) + ); + assert_eq!( + read_scalar(&mut cur, &DataType::Utf8, false)?, + ScalarValue::from("Wooden") + ); + assert_eq!( + read_scalar(&mut cur, &DataType::Utf8, true)?, + ScalarValue::from("Slash") + ); + assert_eq!( + read_scalar(&mut cur, &DataType::Utf8, true)?, + ScalarValue::Utf8(None) + ); + assert_eq!( + read_scalar(&mut cur, &DataType::Null, true)?, + ScalarValue::Null + ); + assert_eq!( + read_scalar(&mut cur, &DataType::Float64, false)?, + ScalarValue::from(3.15) + ); + Ok(()) + } +} diff --git a/native-engine/datafusion-ext-plans/src/agg/acc.rs b/native-engine/datafusion-ext-plans/src/agg/acc.rs index 66983c3d..b18b3a4e 100644 --- a/native-engine/datafusion-ext-plans/src/agg/acc.rs +++ b/native-engine/datafusion-ext-plans/src/agg/acc.rs @@ -25,13 +25,16 @@ use datafusion::{ }; use datafusion_ext_commons::{ df_execution_err, downcast_any, - io::{read_bytes_slice, read_len, read_scalar, write_len, write_scalar}, + io::{read_bytes_slice, read_len, read_scalar, read_u8, write_len, write_scalar, write_u8}, slim_bytes::SlimBytes, }; -use hashbrown::HashSet; +use hashbrown::raw::RawTable; +use itertools::Itertools; use slimmer_box::SlimmerBox; use smallvec::SmallVec; +use crate::agg::agg_table::gx_hash; + pub type DynVal = Option>; const ACC_STORE_BLOCK_SIZE: usize = 65536; @@ -440,46 +443,64 @@ pub fn create_dyn_loaders_from_initial_value(values: &[AccumInitialValue]) -> Re other => { let dt = other.get_datatype(); Box::new(move |r: &mut LoadReader| { - Ok(Some(Box::new(AggDynScalar::new(read_scalar( - &mut r.0, &dt, - )?)))) + let valid = read_u8(&mut r.0)? != 0; + if valid { + let scalar = read_scalar(&mut r.0, &dt, false)?; + Ok(Some(Box::new(AggDynScalar::new(scalar)))) + } else { + Ok(None) + } }) } }, - AccumInitialValue::DynList(dt) => { - let dt = dt.clone(); - Box::new(move |r: &mut LoadReader| { - Ok(match read_len(&mut r.0)? { - 0 => None, - n => { - let data_len = n - 1; - let mut load_vec: SmallVec<[ScalarValue; 4]> = SmallVec::new(); - for _i in 0..data_len { - load_vec.push(read_scalar(&mut r.0, &dt)?); - } - Some(Box::new(AggDynList { values: load_vec })) - } - }) + AccumInitialValue::DynList(_dt) => Box::new(move |r: &mut LoadReader| { + Ok(match read_len(&mut r.0)? { + 0 => None, + n => { + let data_len = n - 1; + let raw = read_bytes_slice(&mut r.0, data_len)?.into_vec(); + Some(Box::new(AggDynList { raw })) + } }) - } - AccumInitialValue::DynSet(dt) => { - let dt = dt.clone(); - Box::new(move |r: &mut LoadReader| { - Ok(match read_len(&mut r.0)? { - 0 => None, - n => { - let vec_len = n - 1; - let mut scalar_vec: SmallVec<[ScalarValue; 4]> = SmallVec::new(); - for _i in 0..vec_len { - scalar_vec.push(read_scalar(&mut r.0, &dt)?); + }), + AccumInitialValue::DynSet(_dt) => Box::new(move |r: &mut LoadReader| { + Ok(match read_len(&mut r.0)? { + 0 => None, + n => { + let data_len = n - 1; + let raw = read_bytes_slice(&mut r.0, data_len)?.into_vec(); + let num_items = read_len(&mut r.0)?; + + let list = AggDynList { raw }; + let mut internal_set = if num_items <= 4 { + InternalSet::Small(SmallVec::new()) + } else { + InternalSet::Huge(RawTable::with_capacity(num_items)) + }; + + let mut pos = 0; + for _ in 0..num_items { + let pos_len = (pos, read_len(&mut r.0)? as u32); + pos += pos_len.1; + + match &mut internal_set { + InternalSet::Small(s) => s.push(pos_len), + InternalSet::Huge(s) => { + let raw = list.ref_raw(pos_len); + let hash = gx_hash::(raw); + s.insert(hash, pos_len, |&pos_len| { + gx_hash::(list.ref_raw(pos_len)) + }); + } } - Some(Box::new(AggDynSet { - values: OptimizedSet::SmallVec(scalar_vec), - })) } - }) + Some(Box::new(AggDynSet { + list, + set: internal_set, + })) + } }) - } + }), }; loaders.push(loader); } @@ -543,10 +564,13 @@ pub fn create_dyn_savers_from_initial_value(values: &[AccumInitialValue]) -> Res _other => { fn f(w: &mut SaveWriter, v: DynVal) -> Result<()> { if let Some(v) = v { - write_scalar(&downcast_any!(v, AggDynScalar)?.value, &mut w.0) - } else { - write_scalar(&ScalarValue::Int32(None), &mut w.0) + let scalar = &downcast_any!(v, AggDynScalar)?.value; + if !scalar.is_null() { + write_u8(1, &mut w.0)?; + return write_scalar(scalar, false, &mut w.0); + } } + return write_u8(0, &mut w.0); } let f: SaveFn = Box::new(f); f @@ -554,19 +578,15 @@ pub fn create_dyn_savers_from_initial_value(values: &[AccumInitialValue]) -> Res }, AccumInitialValue::DynList(_dt) => { fn f(w: &mut SaveWriter, v: DynVal) -> Result<()> { - match v { - None => write_len(0, &mut w.0)?, - Some(v) => { - let list = v - .as_any_boxed() - .downcast::() - .or_else(|_| df_execution_err!("error downcasting to AggDynList"))? - .into_values(); - write_len(list.len() + 1, &mut w.0)?; - for v in list { - write_scalar(&v, &mut w.0)?; - } - } + if let Some(v) = v { + let list = v + .as_any_boxed() + .downcast::() + .or_else(|_| df_execution_err!("error downcasting to AggDynList"))?; + write_len(list.raw.len() + 1, &mut w.0)?; + w.0.write_all(&list.raw)?; + } else { + write_len(0, &mut w.0)?; } Ok(()) } @@ -574,35 +594,28 @@ pub fn create_dyn_savers_from_initial_value(values: &[AccumInitialValue]) -> Res f } AccumInitialValue::DynSet(_dt) => { - fn f(w: &mut SaveWriter, v: DynVal) -> Result<()> { - match v { - None => write_len(0, &mut w.0)?, - Some(v) => { - let set = v - .as_any_boxed() - .downcast::() - .or_else(|_| df_execution_err!("error downcasting to AggDynSet"))? - .into_values(); - - match set { - OptimizedSet::SmallVec(vec) => { - write_len(vec.len() + 1, &mut w.0)?; - for v in vec { - write_scalar(&v, &mut w.0)?; - } - } - OptimizedSet::Set(set) => { - write_len(set.len() + 1, &mut w.0)?; - for v in set { - write_scalar(&v, &mut w.0)?; - } - } - } + let f: SaveFn = Box::new(move |w: &mut SaveWriter, v: DynVal| -> Result<()> { + if let Some(v) = v { + let mut set = v + .as_any_boxed() + .downcast::() + .or_else(|_| df_execution_err!("error downcasting to AggDynSet"))?; + write_len(set.list.raw.len() + 1, &mut w.0)?; + w.0.write_all(&set.list.raw)?; + + write_len(set.set.len(), &mut w.0)?; + for len in std::mem::take(&mut set.set) + .into_iter() + .sorted() + .map(|pos_len| pos_len.1) + { + write_len(len as usize, &mut w.0)?; } + } else { + write_len(0, &mut w.0)?; } Ok(()) - } - let f: SaveFn = Box::new(f); + }); f } }; @@ -754,26 +767,37 @@ impl AggDynValue for AggDynStr { } } -#[derive(Clone, Default, Eq, PartialEq)] +#[derive(Clone, Default)] pub struct AggDynList { - pub values: SmallVec<[ScalarValue; 4]>, + pub raw: Vec, } impl AggDynList { - pub fn append(&mut self, value: ScalarValue) { - self.values.push(value); + pub fn append(&mut self, value: &ScalarValue, nullable: bool) { + write_scalar(&value, nullable, &mut self.raw).unwrap(); } pub fn merge(&mut self, other: &mut Self) { - self.values.append(&mut other.values); + self.raw.extend(std::mem::take(&mut other.raw)); } - pub fn values(&self) -> &[ScalarValue] { - self.values.as_slice() + pub fn into_values(self, dt: DataType, nullable: bool) -> impl Iterator { + struct ValuesIterator(Cursor>, DataType, bool); + impl Iterator for ValuesIterator { + type Item = ScalarValue; + + fn next(&mut self) -> Option { + if self.0.position() < self.0.get_ref().len() as u64 { + return Some(read_scalar(&mut self.0, &self.1, self.2).unwrap()); + } + None + } + } + ValuesIterator(Cursor::new(self.raw), dt, nullable) } - pub fn into_values(self) -> SmallVec<[ScalarValue; 4]> { - self.values + fn ref_raw(&self, pos_len: (u32, u32)) -> &[u8] { + &self.raw[pos_len.0 as usize..][..pos_len.1 as usize] } } @@ -791,19 +815,7 @@ impl AggDynValue for AggDynList { } fn mem_size(&self) -> usize { - let spilled_size = if self.values.spilled() { - self.values.capacity() * (1 + size_of::()) - } else { - 0 - }; - let mem_size = size_of::() - + self - .values - .iter() - .map(|sv| sv.size() - size_of_val(sv)) - .sum::() - + spilled_size; - mem_size + size_of::() + self.raw.capacity() } fn clone_boxed(&self) -> Box { @@ -811,60 +823,126 @@ impl AggDynValue for AggDynList { } } -#[derive(Clone, Default, Eq, PartialEq)] +#[derive(Clone, Default)] pub struct AggDynSet { - pub values: OptimizedSet, + list: AggDynList, + set: InternalSet, } -impl AggDynSet { - pub fn append(&mut self, value: ScalarValue) { - match &mut self.values { - OptimizedSet::SmallVec(vec) => { - if vec.len() < vec.inline_size() { - vec.push(value); - } else { - let mut value_set = HashSet::from_iter(std::mem::take(vec).into_iter()); - value_set.insert(value); - self.values = OptimizedSet::Set(value_set); - } - } - OptimizedSet::Set(value_set) => { - value_set.insert(value); - } +#[derive(Clone)] +enum InternalSet { + Small(SmallVec<[(u32, u32); 4]>), + Huge(RawTable<(u32, u32)>), +} + +impl Default for InternalSet { + fn default() -> Self { + Self::Small(SmallVec::new()) + } +} + +impl InternalSet { + fn len(&self) -> usize { + match self { + InternalSet::Small(s) => s.len(), + InternalSet::Huge(s) => s.len(), } } - pub fn merge(&mut self, other: &mut Self) { - match (&mut self.values, &mut other.values) { - (OptimizedSet::SmallVec(vec1), OptimizedSet::SmallVec(vec2)) => { - if vec1.len() + vec2.len() <= vec1.inline_size() { - vec1.append(vec2); - } else { - let new_set = HashSet::from_iter( - std::mem::take(vec1).into_iter().chain(std::mem::take(vec2)), - ); - self.values = OptimizedSet::Set(new_set); - } + fn into_iter(self) -> impl Iterator { + let iter: Box> = match self { + InternalSet::Small(s) => Box::new(s.into_iter()), + InternalSet::Huge(s) => Box::new(s.into_iter()), + }; + iter + } + + fn insert(&mut self, list: &mut AggDynList, raw_value: &[u8]) { + if let Self::Small(s) = self { + if s.len() == s.inline_size() { + self.convert_to_huge(list); } - (OptimizedSet::SmallVec(vec), OptimizedSet::Set(set)) => { - set.extend(std::mem::take(vec).into_iter()); - self.values = OptimizedSet::Set(std::mem::take(set)); + } + + match self { + InternalSet::Small(s) => { + for &mut pos_len in &mut *s { + if list.ref_raw(pos_len) == raw_value { + return; + } + } + let new_pos_len = (list.raw.len() as u32, raw_value.len() as u32); + list.raw.extend_from_slice(raw_value); + s.push(new_pos_len); } - (OptimizedSet::Set(set), OptimizedSet::SmallVec(vec)) => { - set.extend(std::mem::take(vec).into_iter()); + InternalSet::Huge(s) => { + let hash = gx_hash::(raw_value); + match s.find_or_find_insert_slot( + hash, + |&pos_len| { + raw_value.len() == pos_len.1 as usize && raw_value == list.ref_raw(pos_len) + }, + |&pos_len| gx_hash::(list.ref_raw(pos_len)), + ) { + Ok(_found) => {} + Err(slot) => { + let new_pos_len = (list.raw.len() as u32, raw_value.len() as u32); + list.raw.extend_from_slice(&raw_value); + unsafe { + // safety: call unsafe `insert_in_slot` method + s.insert_in_slot(hash, slot, new_pos_len); + } + } + } } - (OptimizedSet::Set(set1), OptimizedSet::Set(set2)) => { - set1.extend(std::mem::take(set2).into_iter()); + } + } + + fn convert_to_huge(&mut self, list: &mut AggDynList) { + if let Self::Small(s) = self { + let mut huge = RawTable::default(); + + for &mut pos_len in s { + let raw = list.ref_raw(pos_len); + let hash = gx_hash::(raw); + huge.insert(hash, pos_len, |&pos_len| { + gx_hash::(list.ref_raw(pos_len)) + }); } + *self = Self::Huge(huge); } } +} - pub fn values(&self) -> &OptimizedSet { - &self.values +const AGG_DYN_SET_HASH_SEED: i64 = 0x7BCB48DA4C72B4F2; + +impl AggDynSet { + pub fn append(&mut self, value: &ScalarValue, nullable: bool) { + let mut raw_value = vec![]; + write_scalar(value, nullable, &mut raw_value).unwrap(); + self.append_raw(&raw_value); } - pub fn into_values(self) -> OptimizedSet { - self.values + pub fn merge(&mut self, other: &mut Self) { + for pos_len in std::mem::take(&mut other.set).into_iter() { + self.append_raw(other.ref_raw(pos_len)); + } + } + + pub fn into_values(self, dt: DataType, nullable: bool) -> impl Iterator { + self.list.into_values(dt, nullable) + } + + fn append_raw(&mut self, raw_value: &[u8]) { + let self_set = unsafe { + // safety: bypass borrow checking + std::mem::transmute::<_, &mut InternalSet>(&mut self.set) + }; + self_set.insert(&mut self.list, raw_value) + } + + fn ref_raw(&self, pos_len: (u32, u32)) -> &[u8] { + self.list.ref_raw(pos_len) } } @@ -882,7 +960,12 @@ impl AggDynValue for AggDynSet { } fn mem_size(&self) -> usize { - size_of::() + self.values.mem_size() - size_of_val(&self.values) + size_of::() + + self.list.raw.capacity() + + match &self.set { + InternalSet::Small(_) => 0, + InternalSet::Huge(s) => s.capacity() * size_of::<(u32, u32, u8)>(), + } } fn clone_boxed(&self) -> Box { @@ -890,40 +973,6 @@ impl AggDynValue for AggDynSet { } } -#[derive(Clone, Eq, PartialEq, Debug)] -pub enum OptimizedSet { - SmallVec(SmallVec<[ScalarValue; 4]>), - Set(HashSet), -} - -impl Default for OptimizedSet { - fn default() -> Self { - OptimizedSet::SmallVec(SmallVec::default()) - } -} - -impl OptimizedSet { - fn mem_size(&self) -> usize { - match self { - OptimizedSet::SmallVec(vec) => { - size_of::() - + vec - .iter() - .map(|sv| sv.size() - size_of_val(sv)) - .sum::() - } - OptimizedSet::Set(hash_set) => { - size_of::() - + hash_set.capacity() * size_of::() - + hash_set - .iter() - .map(|sv| sv.size() - size_of_val(sv)) - .sum::() - } - } - } -} - #[derive(Default, Clone, Copy)] pub struct AccumStateValAddr(u64); @@ -955,137 +1004,65 @@ impl AccumStateValAddr { #[cfg(test)] mod test { - use std::{io::Cursor, sync::Arc}; + use std::{collections::HashSet, io::Cursor}; - use arrow::datatypes::{DataType, Field, Fields}; + use arrow::datatypes::DataType; use datafusion::common::{Result, ScalarValue}; use datafusion_ext_commons::downcast_any; - use smallvec::SmallVec; use crate::agg::acc::{ create_acc_from_initial_value, create_dyn_loaders_from_initial_value, - create_dyn_savers_from_initial_value, AccumInitialValue, AccumStateRow, AggDynList, - AggDynSet, AggDynStr, LoadReader, OptimizedSet, SaveWriter, + create_dyn_savers_from_initial_value, AccumInitialValue, AccumStateRow, AggDynSet, + AggDynStr, LoadReader, SaveWriter, }; - #[test] - fn test_dyn_list() { - let list_field = Arc::new(Field::new("item", DataType::Int32, true)); - let l0 = ScalarValue::List( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - Arc::new(Field::new("item", DataType::Int32, true)), - ); - - let l1 = ScalarValue::List( - Some(vec![ScalarValue::from(4i32), ScalarValue::Int32(None)]), - Arc::new(Field::new("item", DataType::Int32, true)), - ); - - let l2 = ScalarValue::List(None, Arc::new(Field::new("item", DataType::Int32, true))); - - let loaders = create_dyn_loaders_from_initial_value(&[AccumInitialValue::DynList( - DataType::List(list_field.clone()), - )]) - .unwrap(); - let savers = create_dyn_savers_from_initial_value(&[AccumInitialValue::DynList( - DataType::List(list_field.clone()), - )]) - .unwrap(); - let mut dyn_list = AggDynList::default(); - dyn_list.append(l0.clone()); - dyn_list.append(l1.clone()); - dyn_list.append(l2.clone()); - - let mut buf = vec![]; - savers[0]( - &mut SaveWriter(Box::new(&mut buf)), - Some(Box::new(dyn_list)), - ) - .unwrap(); - - let dyn_list = loaders[0](&mut LoadReader(Box::new(Cursor::new(&buf)))).unwrap(); - assert_eq!( - downcast_any!(dyn_list.unwrap(), AggDynList) - .unwrap() - .values(), - &[l0.clone(), l1.clone(), l2.clone(),] - ); - } - #[test] fn test_dyn_set() { - let fields_b = Fields::from(vec![ - Field::new("ba", DataType::UInt64, true), - Field::new("bb", DataType::UInt64, true), - ]); - let fields = Fields::from(vec![ - Field::new("a", DataType::UInt64, true), - Field::new("b", DataType::Struct(fields_b.clone()), true), - ]); - let scalars = vec![ - ScalarValue::Struct(None, fields.clone()), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(None), - ScalarValue::Struct(None, fields_b.clone()), - ]), - fields.clone(), - ), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(None), - ScalarValue::Struct( - Some(vec![ScalarValue::UInt64(None), ScalarValue::UInt64(None)]), - fields_b.clone(), - ), - ]), - fields.clone(), - ), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(Some(1)), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(Some(2)), - ScalarValue::UInt64(Some(3)), - ]), - fields_b, - ), - ]), - fields.clone(), - ), - ]; - - let loaders = create_dyn_loaders_from_initial_value(&[AccumInitialValue::DynSet( - DataType::Struct(fields.clone()), - )]) - .unwrap(); - let savers = create_dyn_savers_from_initial_value(&[AccumInitialValue::DynSet( - DataType::Struct(fields.clone()), - )]) - .unwrap(); let mut dyn_set = AggDynSet::default(); - dyn_set.append(scalars[0].clone()); - dyn_set.append(scalars[1].clone()); - dyn_set.append(scalars[3].clone()); - + dyn_set.append(&ScalarValue::from("Hello"), false); + dyn_set.append(&ScalarValue::from("Wooden"), false); + dyn_set.append(&ScalarValue::from("Bird"), false); + dyn_set.append(&ScalarValue::from("Snake"), false); + dyn_set.append(&ScalarValue::from("Wooden"), false); + dyn_set.append(&ScalarValue::from("Bird"), false); + + // test merge + let mut dyn_set2 = AggDynSet::default(); + dyn_set2.append(&ScalarValue::from("Hello"), false); + dyn_set2.append(&ScalarValue::from("Batman"), false); + dyn_set2.append(&ScalarValue::from("Candy"), false); + dyn_set.merge(&mut dyn_set2); + + // test save let mut buf = vec![]; - savers[0](&mut SaveWriter(Box::new(&mut buf)), Some(Box::new(dyn_set))).unwrap(); - - let dyn_set = loaders[0](&mut LoadReader(Box::new(Cursor::new(&buf)))).unwrap(); - - let right_set: SmallVec<[ScalarValue; 4]> = SmallVec::from_iter( - vec![scalars[0].clone(), scalars[1].clone(), scalars[3].clone()].into_iter(), - ); - let right = OptimizedSet::SmallVec(right_set); - assert_eq!( - downcast_any!(dyn_set.unwrap(), AggDynSet).unwrap().values(), - &right - ); + let mut save_writer = SaveWriter(Box::new(Cursor::new(&mut buf))); + let savers = + create_dyn_savers_from_initial_value(&[AccumInitialValue::DynSet(DataType::Utf8)]) + .unwrap(); + savers[0](&mut save_writer, Some(Box::new(dyn_set))).unwrap(); + drop(save_writer); + + // test load + let mut load_reader = LoadReader(Box::new(Cursor::new(&buf))); + let loaders = + create_dyn_loaders_from_initial_value(&[AccumInitialValue::DynSet(DataType::Utf8)]) + .unwrap(); + let dyn_set = loaders[0](&mut load_reader) + .unwrap() + .unwrap() + .as_any_boxed() + .downcast::() + .unwrap(); + drop(load_reader); + + let actual_set: HashSet = dyn_set.into_values(DataType::Utf8, false).collect(); + assert_eq!(actual_set.len(), 6); + assert!(actual_set.contains(&ScalarValue::from("Hello"))); + assert!(actual_set.contains(&ScalarValue::from("Wooden"))); + assert!(actual_set.contains(&ScalarValue::from("Bird"))); + assert!(actual_set.contains(&ScalarValue::from("Snake"))); + assert!(actual_set.contains(&ScalarValue::from("Batman"))); + assert!(actual_set.contains(&ScalarValue::from("Candy"))); } #[test] diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_table.rs b/native-engine/datafusion-ext-plans/src/agg/agg_table.rs index 79a5cdfd..3df34b6d 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg_table.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg_table.rs @@ -35,7 +35,7 @@ use datafusion::{ }; use datafusion_ext_commons::{ array_size::ArraySize, - bytes_arena::BytesArena, + bytes_arena::{BytesArena, BytesArenaAddr}, downcast_any, ds::rdx_tournament_tree::{KeyForRadixTournamentTree, RadixTournamentTree}, io::{read_bytes_slice, read_len, write_len}, @@ -532,7 +532,7 @@ pub struct HashingData { task_ctx: Arc, acc_store: AccStore, map_key_store: BytesArena, - map: RawTable<(u64, u32)>, // keys addr to accs store addr + map: RawTable<(BytesArenaAddr, u32)>, // keys addr to accs store addr num_input_records: usize, spill_metrics: SpillMetrics, } @@ -593,7 +593,10 @@ impl HashingData { .map .find_or_find_insert_slot( hash, - |v| self.map_key_store.get(v.0) == row.as_ref(), + |v| { + v.0.unpack().len == row.as_ref().len() + && self.map_key_store.get(v.0) == row.as_ref() + }, |v| gx_hash::(self.map_key_store.get(v.0)), ) .unwrap_or_else(|slot| { diff --git a/native-engine/datafusion-ext-plans/src/agg/collect_list.rs b/native-engine/datafusion-ext-plans/src/agg/collect_list.rs index 28ef87bd..27e2112b 100644 --- a/native-engine/datafusion-ext-plans/src/agg/collect_list.rs +++ b/native-engine/datafusion-ext-plans/src/agg/collect_list.rs @@ -124,12 +124,12 @@ impl Agg for AggCollectList { let list = downcast_any!(dyn_list, mut AggDynList)?; self.sub_mem_used(list.mem_size()); - list.append(ScalarValue::try_from_array(&values[0], row_idx)?); + list.append(&ScalarValue::try_from_array(&values[0], row_idx)?, false); self.add_mem_used(list.mem_size()); } w => { let mut new_list = AggDynList::default(); - new_list.append(ScalarValue::try_from_array(&values[0], row_idx)?); + new_list.append(&ScalarValue::try_from_array(&values[0], row_idx)?, false); self.add_mem_used(new_list.mem_size()); *w = Some(Box::new(new_list)); } @@ -153,7 +153,7 @@ impl Agg for AggCollectList { for i in 0..values[0].len() { if values[0].is_valid(i) { - list.append(ScalarValue::try_from_array(&values[0], i)?); + list.append(&ScalarValue::try_from_array(&values[0], i)?, false); } } self.add_mem_used(list.mem_size()); @@ -195,7 +195,7 @@ impl Agg for AggCollectList { .or_else(|_| df_execution_err!("error downcasting to AggDynList"))?; self.sub_mem_used(list.mem_size()); ScalarValue::new_list( - Some(list.into_values().into_vec()), + Some(list.into_values(self.arg_type.clone(), false).collect()), self.arg_type.clone(), ) } diff --git a/native-engine/datafusion-ext-plans/src/agg/collect_set.rs b/native-engine/datafusion-ext-plans/src/agg/collect_set.rs index db0e6cad..d36f57e0 100644 --- a/native-engine/datafusion-ext-plans/src/agg/collect_set.rs +++ b/native-engine/datafusion-ext-plans/src/agg/collect_set.rs @@ -24,11 +24,10 @@ use datafusion::{ physical_expr::PhysicalExpr, }; use datafusion_ext_commons::{df_execution_err, downcast_any}; -use hashbrown::HashSet; use crate::agg::{ acc::{ - AccumInitialValue, AccumStateRow, AccumStateValAddr, AggDynSet, AggDynValue, OptimizedSet, + AccumInitialValue, AccumStateRow, AccumStateValAddr, AggDynSet, AggDynValue, RefAccumStateRow, }, Agg, WithAggBufAddrs, WithMemTracking, @@ -125,12 +124,12 @@ impl Agg for AggCollectSet { let set = downcast_any!(dyn_set, mut AggDynSet)?; self.sub_mem_used(set.mem_size()); - set.append(ScalarValue::try_from_array(&values[0], row_idx)?); + set.append(&ScalarValue::try_from_array(&values[0], row_idx)?, false); self.add_mem_used(set.mem_size()); } w => { let mut new_set = AggDynSet::default(); - new_set.append(ScalarValue::try_from_array(&values[0], row_idx)?); + new_set.append(&ScalarValue::try_from_array(&values[0], row_idx)?, false); self.add_mem_used(new_set.mem_size()); *w = Some(Box::new(new_set)); } @@ -154,7 +153,7 @@ impl Agg for AggCollectSet { for i in 0..values[0].len() { if values[0].is_valid(i) { - set.append(ScalarValue::try_from_array(&values[0], i)?); + set.append(&ScalarValue::try_from_array(&values[0], i)?, false); } } self.add_mem_used(set.mem_size()); @@ -190,24 +189,12 @@ impl Agg for AggCollectSet { match std::mem::take(acc.dyn_value_mut(self.accum_state_val_addr)) { Some(w) => { self.sub_mem_used(w.mem_size()); - let mut dyn_set = w + let set = w .as_any_boxed() .downcast::() .or_else(|_| df_execution_err!("error downcasting to AggDynSet"))? - .into_values(); - let scalar_list = match &mut dyn_set { - OptimizedSet::SmallVec(vec) => { - let convert_set: HashSet = - HashSet::from_iter(std::mem::take(vec).into_iter()); - Some(convert_set.into_iter().collect::>()) - } - OptimizedSet::Set(set) => Some( - std::mem::take(set) - .into_iter() - .collect::>(), - ), - }; - ScalarValue::new_list(scalar_list, self.arg_type.clone()) + .into_values(self.arg_type.clone(), false); + ScalarValue::new_list(Some(set.into_iter().collect()), self.arg_type.clone()) } None => ScalarValue::new_list(None, self.arg_type.clone()), },