Commit

improve shuffle write performance
zhangli20 committed Sep 9, 2024
1 parent 23d102b commit 6ef86b6
Showing 22 changed files with 281 additions and 328 deletions.
85 changes: 0 additions & 85 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion native-engine/datafusion-ext-commons/Cargo.toml
@@ -24,7 +24,6 @@ log = "0.4.22"
 num = "0.4.2"
 once_cell = "1.19.0"
 paste = "1.0.15"
-postcard = { version = "1.0.10", features = ["alloc"]}
 radsort = "0.1.1"
 slimmer_box = "0.6.5"
 tempfile = "3"
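Context: `postcard` was pulled in for the `write_data_type`/`read_data_type` helpers in `batch_serde.rs`, which appear to be its only users in this crate; with those helpers removed below, the dependency can be dropped.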
67 changes: 14 additions & 53 deletions native-engine/datafusion-ext-commons/src/io/batch_serde.rs
@@ -27,34 +27,30 @@ use datafusion::common::Result;
 use unchecked_index::unchecked_index;
 
 use crate::{
-    df_execution_err, df_unimplemented_err,
+    df_unimplemented_err,
     io::{read_bytes_slice, read_len, write_len},
 };
 
 pub fn write_batch(num_rows: usize, cols: &[ArrayRef], mut output: impl Write) -> Result<()> {
     // write number of columns and rows
-    write_len(cols.len(), &mut output)?;
     write_len(num_rows, &mut output)?;
 
     // write columns
     for col in cols {
-        write_data_type(col.data_type(), &mut output)?;
         write_array(col, &mut output)?;
     }
     Ok(())
 }
 
-pub fn read_batch(mut input: impl Read) -> Result<(usize, Vec<ArrayRef>)> {
+pub fn read_batch(mut input: impl Read, schema: &SchemaRef) -> Result<(usize, Vec<ArrayRef>)> {
     // read number of columns and rows
-    let num_cols = read_len(&mut input)?;
     let num_rows = read_len(&mut input)?;
 
     // read columns
-    let cols = (0..num_cols)
-        .map(|_| {
-            let dt = read_data_type(&mut input)?;
-            read_array(&mut input, &dt, num_rows)
-        })
+    let cols = schema
+        .fields()
+        .into_iter()
+        .map(|field| read_array(&mut input, &field.data_type(), num_rows))
         .collect::<Result<_>>()?;
     Ok((num_rows, cols))
 }
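For context, here is a minimal sketch of the new calling convention, assuming this module's `write_batch`/`read_batch` are in scope (the single-column `Int64` schema and values are illustrative, not taken from the commit). Since the stream no longer self-describes its column types, the reader must supply the same schema the writer used:

use std::{io::Cursor, sync::Arc};

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

fn round_trip_sketch() -> datafusion::common::Result<()> {
    let schema: SchemaRef = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int64, false)]));
    let col: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));

    let mut buf = vec![];
    write_batch(3, &[col], &mut buf)?; // writes row count + raw column data, no type bytes
    let (num_rows, cols) = read_batch(Cursor::new(buf), &schema)?; // types come from `schema`
    assert_eq!((num_rows, cols.len()), (3, 1));
    Ok(())
}

Both sides of a shuffle already agree on the plan's output schema, so dropping the per-batch type information loses nothing while shrinking every serialized block.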
@@ -169,41 +165,6 @@ fn read_bits_buffer<R: Read>(input: &mut R, bits_len: usize) -> Result<Buffer> {
     Ok(Buffer::from(buf))
 }
 
-fn nameless_field(field: &Field) -> Field {
-    Field::new(
-        "",
-        nameless_data_type(field.data_type()),
-        field.is_nullable(),
-    )
-}
-
-fn nameless_data_type(data_type: &DataType) -> DataType {
-    match data_type {
-        DataType::List(field) => DataType::List(Arc::new(nameless_field(field))),
-        DataType::Map(field, sorted) => DataType::Map(Arc::new(nameless_field(field)), *sorted),
-        DataType::Struct(fields) => {
-            DataType::Struct(fields.iter().map(|field| nameless_field(field)).collect())
-        }
-        others => others.clone(),
-    }
-}
-
-pub fn write_data_type<W: Write>(data_type: &DataType, output: &mut W) -> Result<()> {
-    let buf = postcard::to_allocvec(&nameless_data_type(data_type))
-        .or_else(|err| df_execution_err!("serialize data type error: {err}"))?;
-    write_len(buf.len(), output)?;
-    output.write_all(&buf)?;
-    Ok(())
-}
-
-pub fn read_data_type<R: Read>(input: &mut R) -> Result<DataType> {
-    let buf_len = read_len(input)?;
-    let buf = read_bytes_slice(input, buf_len)?;
-    let data_type = postcard::from_bytes(&buf)
-        .or_else(|err| df_execution_err!("deserialize data type error: {err}"))?;
-    Ok(data_type)
-}
-
 fn write_primitive_array<W: Write, PT: ArrowPrimitiveType>(
     array: &PrimitiveArray<PT>,
     output: &mut W,
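These helpers existed only to make each serialized batch self-describing: `nameless_field`/`nameless_data_type` stripped field names before encoding (presumably so equal types always produced identical bytes), and `write_data_type`/`read_data_type` round-tripped the result through `postcard`. With the schema now supplied by the caller of `read_batch`, they have no remaining users.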
@@ -665,7 +626,7 @@ mod test {
         let mut buf = vec![];
         write_batch(batch.num_rows(), batch.columns(), &mut buf).unwrap();
         let mut cursor = Cursor::new(buf);
-        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor).unwrap();
+        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor, &batch.schema()).unwrap();
         assert_eq!(
             recover_named_batch(decoded_num_rows, &decoded_cols, batch.schema()).unwrap(),
             batch
@@ -676,7 +637,7 @@
         let mut buf = vec![];
         write_batch(sliced.num_rows(), sliced.columns(), &mut buf).unwrap();
         let mut cursor = Cursor::new(buf);
-        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor).unwrap();
+        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor, &batch.schema()).unwrap();
         assert_eq!(
             recover_named_batch(decoded_num_rows, &decoded_cols, batch.schema()).unwrap(),
             sliced
@@ -717,7 +678,7 @@ mod test {
         let mut buf = vec![];
         write_batch(batch.num_rows(), batch.columns(), &mut buf).unwrap();
         let mut cursor = Cursor::new(buf);
-        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor).unwrap();
+        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor, &batch.schema()).unwrap();
         assert_batches_eq!(
             vec![
                 "+-----------+-----------+",
@@ -737,7 +698,7 @@
         let mut buf = vec![];
         write_batch(sliced.num_rows(), sliced.columns(), &mut buf).unwrap();
         let mut cursor = Cursor::new(buf);
-        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor).unwrap();
+        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor, &batch.schema()).unwrap();
         assert_batches_eq!(
             vec![
                 "+----------+----------+",
@@ -781,7 +742,7 @@
         let mut buf = vec![];
         write_batch(batch.num_rows(), batch.columns(), &mut buf).unwrap();
         let mut cursor = Cursor::new(buf);
-        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor).unwrap();
+        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor, &batch.schema()).unwrap();
         assert_eq!(
             recover_named_batch(decoded_num_rows, &decoded_cols, batch.schema()).unwrap(),
             batch
@@ -792,7 +753,7 @@
         let mut buf = vec![];
         write_batch(sliced.num_rows(), sliced.columns(), &mut buf).unwrap();
         let mut cursor = Cursor::new(buf);
-        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor).unwrap();
+        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor, &batch.schema()).unwrap();
         assert_eq!(
             recover_named_batch(decoded_num_rows, &decoded_cols, sliced.schema()).unwrap(),
             sliced
@@ -819,7 +780,7 @@
         let mut buf = vec![];
         write_batch(batch.num_rows(), batch.columns(), &mut buf).unwrap();
         let mut cursor = Cursor::new(buf);
-        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor).unwrap();
+        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor, &batch.schema()).unwrap();
         assert_eq!(
             recover_named_batch(decoded_num_rows, &decoded_cols, batch.schema()).unwrap(),
             batch
@@ -830,7 +791,7 @@
         let mut buf = vec![];
         write_batch(sliced.num_rows(), sliced.columns(), &mut buf).unwrap();
         let mut cursor = Cursor::new(buf);
-        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor).unwrap();
+        let (decoded_num_rows, decoded_cols) = read_batch(&mut cursor, &batch.schema()).unwrap();
         assert_eq!(
             recover_named_batch(decoded_num_rows, &decoded_cols, batch.schema()).unwrap(),
             sliced
9 changes: 6 additions & 3 deletions native-engine/datafusion-ext-commons/src/io/mod.rs
@@ -19,7 +19,7 @@ use arrow::{
     datatypes::SchemaRef,
     record_batch::RecordBatch,
 };
-pub use batch_serde::{read_array, read_data_type, write_array, write_data_type};
+pub use batch_serde::{read_array, write_array};
 use datafusion::common::Result;
 pub use scalar_serde::{read_scalar, write_scalar};
 
@@ -62,7 +62,10 @@ pub fn write_one_batch(num_rows: usize, cols: &[ArrayRef], mut output: impl Write
     Ok(())
 }
 
-pub fn read_one_batch(mut input: impl Read) -> Result<Option<(usize, Vec<ArrayRef>)>> {
+pub fn read_one_batch(
+    mut input: impl Read,
+    schema: &SchemaRef,
+) -> Result<Option<(usize, Vec<ArrayRef>)>> {
     let batch_data_len = match read_len(&mut input) {
         Ok(len) => len,
         Err(e) => {
@@ -73,7 +76,7 @@
         }
     };
     let mut input = input.take(batch_data_len as u64);
-    let (num_rows, cols) = batch_serde::read_batch(&mut input)?;
+    let (num_rows, cols) = batch_serde::read_batch(&mut input, schema)?;
 
     // consume trailing bytes
     std::io::copy(&mut input, &mut std::io::sink())?;
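A sketch of how a caller can drain a stream of length-prefixed batches under the new signature (`read_all_batches` is a hypothetical helper, not part of this commit; given the `Option` in the return type, `read_one_batch` presumably yields `Ok(None)` once `read_len` hits end of input):

use std::io::Read;

use arrow::array::ArrayRef;
use arrow::datatypes::SchemaRef;
use datafusion::common::Result;

fn read_all_batches(
    mut input: impl Read,
    schema: &SchemaRef,
) -> Result<Vec<(usize, Vec<ArrayRef>)>> {
    let mut batches = vec![];
    // each iteration consumes exactly one length-prefixed batch
    while let Some((num_rows, cols)) = read_one_batch(&mut input, schema)? {
        batches.push((num_rows, cols));
    }
    Ok(batches)
}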
2 changes: 1 addition & 1 deletion native-engine/datafusion-ext-commons/src/lib.rs
@@ -89,7 +89,7 @@ pub fn batch_size() -> usize {
 
 // bigger for better radix sort performance
 pub const fn staging_mem_size_for_partial_sort() -> usize {
-    8388608
+    1048576
 }
 
 // use bigger batch memory size writing shuffling data
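For scale: 8388608 bytes is 8 MiB and 1048576 bytes is 1 MiB, so the staging buffer for partial sorting shrinks to an eighth of its previous size. Read together with the comment above, this looks like a deliberate trade of some radix-sort throughput for a smaller per-task memory footprint during shuffle writes.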
26 changes: 0 additions & 26 deletions native-engine/datafusion-ext-commons/src/spark_hash.rs
@@ -412,13 +412,6 @@ fn hash_one<T: num::PrimInt>(
     }
 }
 
-pub fn pmod(hash: i32, n: usize) -> usize {
-    let n = n as i32;
-    let r = hash % n;
-    let result = if r < 0 { (r + n) % n } else { r };
-    result as usize
-}
-
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -549,25 +542,6 @@ mod tests {
         assert_eq!(hashes, expected);
     }
 
-    #[test]
-    fn test_pmod() {
-        let i: Vec<i32> = [
-            0x99f0149d_u32,
-            0x9c67b85d,
-            0xc8008529,
-            0xa05b5d7b,
-            0xcd1e64fb,
-        ]
-        .into_iter()
-        .map(|v| v as i32)
-        .collect();
-        let result = i.into_iter().map(|i| pmod(i, 200)).collect::<Vec<usize>>();
-
-        // expected partition from Spark with n=200
-        let expected = vec![69, 5, 193, 171, 115];
-        assert_eq!(result, expected);
-    }
-
     #[test]
     fn test_map_array() {
         // Construct key and values
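The removed `pmod` implemented Spark's non-negative modulo for mapping a signed hash to a partition index. For reference, `i32::rem_euclid` from the Rust standard library computes the same thing for positive `n`; this is an illustrative sketch, not code from the commit:

fn pmod(hash: i32, n: usize) -> usize {
    // rem_euclid returns a value in 0..n for positive n, matching the
    // removed `if r < 0 { (r + n) % n } else { r }` branch exactly
    hash.rem_euclid(n as i32) as usize
}

fn main() {
    // first case from the deleted test: Spark maps this hash to partition 69 of 200
    assert_eq!(pmod(0x99f0149d_u32 as i32, 200), 69);
}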
(Diffs for the remaining 16 changed files are not shown.)
