Support SortMerge spilling
comphead committed Jul 9, 2024
1 parent 9c16696 commit 0717597
Showing 2 changed files with 70 additions and 114 deletions.
180 changes: 68 additions & 112 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -643,7 +643,7 @@ impl BufferedBatch {
) -> Result<()> {
let batch = std::mem::replace(
&mut self.batch,
RecordBatch::new_empty(buffered_schema.clone()),
RecordBatch::new_empty(Arc::clone(&buffered_schema)),
);
let _ = spill_record_batch_by_size(
batch,
@@ -652,7 +652,6 @@
batch_size,
);
self.spill_file = Some(path);
dbg!(&self.spill_file);
Ok(())
}
}
@@ -886,7 +885,6 @@ impl SMJStream {
self.streamed_state = StreamedState::Exhausted;
}
Poll::Ready(Some(batch)) => {
println!("\nstreamed rows {}", batch.num_rows());
if batch.num_rows() > 0 {
self.freeze_streamed()?;
self.join_metrics.input_batches.add(1);
@@ -907,6 +905,43 @@
}
}

fn mem_allocate_batch(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> {
match self.reservation.try_grow(buffered_batch.size_estimation) {
Ok(_) => {
self.join_metrics
.peak_mem_used
.set_max(self.reservation.size());
Ok(())
}
Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
// spill buffered batch to disk
let spill_file = self
.runtime_env
.disk_manager
.create_tmp_file("SortMergeJoinBuffered")?;

buffered_batch.spill_to_disk(
spill_file,
Arc::clone(&self.buffered_schema),
self.batch_size,
)?;

// update metrics to display spill
self.join_metrics.spill_count.add(1);
self.join_metrics
.spilled_bytes
.add(buffered_batch.size_estimation);
self.join_metrics.spilled_rows.add(buffered_batch.num_rows);

Ok(())
}
Err(e) => Err(e),
}?;

self.buffered_data.batches.push_back(buffered_batch);
Ok(())
}
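
For readers skimming the diff, the control flow that the new `mem_allocate_batch` helper follows — try to grow the join's memory reservation, and only if that fails while a disk manager is available, spill the buffered batch to a temp file and record spill metrics — can be sketched without any DataFusion dependencies. This is a minimal toy model: `Reservation`, `Batch`, `allocate_or_spill`, and the temp-file path are invented for illustration and are not the real types.

```rust
// Toy model of the reserve-or-spill pattern used by `mem_allocate_batch`.
// All names here are stand-ins for this sketch, not DataFusion APIs.
use std::fs::File;
use std::io::{self, Write};
use std::path::PathBuf;

struct Reservation {
    used: usize,
    limit: usize,
}

impl Reservation {
    fn try_grow(&mut self, bytes: usize) -> Result<(), String> {
        if self.used + bytes > self.limit {
            Err(format!("cannot reserve {bytes} more bytes"))
        } else {
            self.used += bytes;
            Ok(())
        }
    }
}

struct Batch {
    rows: Vec<u64>,
    size_estimation: usize,
    spill_file: Option<PathBuf>,
}

impl Batch {
    // Write the rows to `path` and keep the handle, freeing the in-memory copy.
    fn spill_to_disk(&mut self, path: PathBuf) -> io::Result<()> {
        let mut f = File::create(&path)?;
        for r in &self.rows {
            writeln!(f, "{r}")?;
        }
        self.rows.clear();
        self.spill_file = Some(path);
        Ok(())
    }
}

fn allocate_or_spill(
    reservation: &mut Reservation,
    batch: &mut Batch,
    spilling_enabled: bool,
) -> Result<(), String> {
    match reservation.try_grow(batch.size_estimation) {
        // Fits in the memory budget: keep the batch buffered in memory.
        Ok(()) => Ok(()),
        // Budget exceeded but spilling is allowed: write to disk instead of failing.
        Err(_) if spilling_enabled => {
            let path = std::env::temp_dir().join("smj_buffered.spill");
            batch.spill_to_disk(path).map_err(|e| e.to_string())
            // the real code also bumps spill_count / spilled_bytes / spilled_rows here
        }
        // No disk manager: surface the memory error to the caller.
        Err(e) => Err(e),
    }
}

fn main() -> Result<(), String> {
    let mut reservation = Reservation { used: 0, limit: 64 };
    let mut batch = Batch {
        rows: (0..100).collect(),
        size_estimation: 800,
        spill_file: None,
    };
    allocate_or_spill(&mut reservation, &mut batch, true)?;
    println!("spilled to {:?}", batch.spill_file);
    Ok(())
}
```

Because a spilled batch never grew the reservation in the first place, the change to `poll_buffered_batches` below only shrinks the reservation for batches that have no `spill_file`.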

/// Poll next buffered batches
fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
loop {
@@ -921,11 +956,12 @@
if let Some(buffered_batch) =
self.buffered_data.batches.pop_front()
{
// Noop on shrink complaints, this might happen
// on spilled batches
self.reservation
.try_shrink(buffered_batch.size_estimation)
.unwrap_or(());
// Shrink mem usage for non spilled batches only
if buffered_batch.spill_file.is_none() {
self.reservation
.try_shrink(buffered_batch.size_estimation)
.unwrap_or(());
}
}
} else {
// If the head batch is not fully processed, break the loop.
@@ -953,64 +989,16 @@ impl SMJStream {
Poll::Ready(Some(batch)) => {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
println!(
"\nbatch rows {} mem {}",
batch.num_rows(),
self.reservation.size()
);
if batch.num_rows() > 0 {
let mut buffered_batch =
let buffered_batch =
BufferedBatch::new(batch, 0..1, &self.on_buffered);

match self
.reservation
.try_grow(buffered_batch.size_estimation)
{
Ok(_) => {
self.join_metrics
.peak_mem_used
.set_max(self.reservation.size());
Ok(())
}
Err(_)
if self
.runtime_env
.disk_manager
.tmp_files_enabled() =>
{
// spill buffered batch to disk
let spill_file = self
.runtime_env
.disk_manager
.create_tmp_file("SortMergeJoinBuffered")?;

buffered_batch.spill_to_disk(
spill_file,
self.buffered_schema.clone(),
self.batch_size,
)?;

// update metrics to display spill
self.join_metrics.spill_count.add(1);
self.join_metrics
.spilled_bytes
.add(buffered_batch.size_estimation);
self.join_metrics
.spilled_rows
.add(buffered_batch.num_rows);

Ok(())
}
Err(e) => Err(e),
}?;

self.buffered_data.batches.push_back(buffered_batch);
self.mem_allocate_batch(buffered_batch)?;
self.buffered_state = BufferedState::PollingRest;
}
}
},
BufferedState::PollingRest => {
println!("Polling Rest");
if self.buffered_data.tail_batch().range.end
< self.buffered_data.tail_batch().num_rows
{
@@ -1038,7 +1026,7 @@ impl SMJStream {
self.buffered_state = BufferedState::Ready;
}
Poll::Ready(Some(batch)) => {
// This code is unreachable! Think about dropping it
// Multi batch
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
if batch.num_rows() > 0 {
Expand All @@ -1047,12 +1035,7 @@ impl SMJStream {
0..0,
&self.on_buffered,
);
self.reservation
.try_grow(buffered_batch.size_estimation)?;
self.join_metrics
.peak_mem_used
.set_max(self.reservation.size());
self.buffered_data.batches.push_back(buffered_batch);
self.mem_allocate_batch(buffered_batch)?;
}
}
}
@@ -1090,7 +1073,6 @@ impl SMJStream {
/// Produce join and fill output buffer until reaching target batch size
/// or the join is finished
fn join_partial(&mut self) -> Result<()> {
println!("join_partial");
// Whether to join streamed rows
let mut join_streamed = false;
// Whether to join buffered rows
@@ -1159,13 +1141,10 @@ impl SMJStream {
}

if join_buffered {
//println!("join_partial: join_buffered");

// joining streamed/nulls and buffered
while !self.buffered_data.scanning_finished()
&& self.output_size < self.batch_size
{
//println!("join_partial: while join_buffered");
let scanning_idx = self.buffered_data.scanning_idx();
if join_streamed {
// Join streamed row and buffered row
@@ -1296,7 +1275,7 @@ impl SMJStream {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
get_buffered_columns(
&mut self.buffered_data,
&self.buffered_data,
buffered_idx,
&buffered_indices,
)?
@@ -1310,8 +1289,6 @@ impl SMJStream {
.collect::<Vec<_>>()
};

dbg!(&buffered_columns);

let streamed_columns_length = streamed_columns.len();
let buffered_columns_length = buffered_columns.len();

Expand All @@ -1326,7 +1303,7 @@ impl SMJStream {
) {
// unwrap is safe here as we check is_some on top of if statement
let buffered_columns = get_buffered_columns(
&mut self.buffered_data,
&self.buffered_data,
chunk.buffered_batch_idx.unwrap(),
&buffered_indices,
)?;
@@ -1570,7 +1547,7 @@ fn produce_buffered_null_batch(
schema: &SchemaRef,
streamed_schema: &SchemaRef,
buffered_indices: &PrimitiveArray<UInt64Type>,
buffered_batch: &mut BufferedBatch,
buffered_batch: &BufferedBatch,
) -> Result<Option<RecordBatch>> {
if buffered_indices.is_empty() {
return Ok(None);
@@ -1599,53 +1576,34 @@
/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
#[inline(always)]
fn get_buffered_columns(
buffered_data: &mut BufferedData,
buffered_data: &BufferedData,
buffered_batch_idx: usize,
buffered_indices: &UInt64Array,
) -> Result<Vec<ArrayRef>, ArrowError> {
get_buffered_columns_from_batch(
&mut buffered_data.batches[buffered_batch_idx],
&buffered_data.batches[buffered_batch_idx],
buffered_indices,
)
}

#[inline(always)]
fn get_buffered_columns_from_batch(
buffered_batch: &mut BufferedBatch,
buffered_batch: &BufferedBatch,
buffered_indices: &UInt64Array,
) -> Result<Vec<ArrayRef>, ArrowError> {
if let Some(spill_file) = mem::take(&mut buffered_batch.spill_file) {
// if spilled read as a stream
if let Some(spill_file) = &buffered_batch.spill_file {
// if spilled read from disk in smaller sub batches
let mut buffered_cols: Vec<ArrayRef> = Vec::with_capacity(buffered_indices.len());
// let mut stream =
// read_spill_as_stream(spill_file, buffered_batch.batch.schema(), 2)?;

let file = BufReader::new(File::open(spill_file.path())?);
let reader = FileReader::try_new(file, None)?;

for batch in reader {
let batch = batch?;
batch.columns().iter().for_each(|column| {
batch?.columns().iter().for_each(|column| {
buffered_cols.extend(take(column, &buffered_indices, None))
});
}

// let _ = futures::stream::once(async {
// dbg!("in");
// while let Some(batch) = stream.next().await {
// dbg!("stream spilled batch");
//
// let batch = batch?;
// batch.columns().iter().for_each(|column| {
// buffered_cols.extend(take(column, &buffered_indices, None))
// });
// }
//
// Ok::<(), ArrowError>(())
// });

dbg!(&buffered_cols);

Ok(buffered_cols)
} else {
buffered_batch
@@ -2933,7 +2891,7 @@ mod tests {
for join_type in join_types {
let task_ctx = TaskContext::default()
.with_session_config(session_config.clone())
.with_runtime(runtime.clone());
.with_runtime(Arc::clone(&runtime));
let task_ctx = Arc::new(task_ctx);

let join = join_with_options(
@@ -3016,7 +2974,7 @@ mod tests {
for join_type in join_types {
let task_ctx = TaskContext::default()
.with_session_config(session_config.clone())
.with_runtime(runtime.clone());
.with_runtime(Arc::clone(&runtime));
let task_ctx = Arc::new(task_ctx);
let join = join_with_options(
Arc::clone(&left),
@@ -3062,7 +3020,7 @@ mod tests {
JoinType::Inner,
JoinType::Left,
JoinType::Right,
//JoinType::Full,
JoinType::Full,
JoinType::LeftSemi,
JoinType::LeftAnti,
];
@@ -3077,14 +3035,12 @@ mod tests {
for join_type in join_types {
let task_ctx = TaskContext::default()
.with_session_config(session_config.clone())
.with_runtime(runtime.clone());
.with_runtime(Arc::clone(&runtime));
let task_ctx = Arc::new(task_ctx);

println!("{join_type}");

let join = join_with_options(
left.clone(),
right.clone(),
Arc::clone(&left),
Arc::clone(&right),
on.clone(),
join_type,
sort_options.clone(),
@@ -3141,26 +3097,26 @@ mod tests {
JoinType::Inner,
JoinType::Left,
JoinType::Right,
//JoinType::Full,
JoinType::Full,
JoinType::LeftSemi,
JoinType::LeftAnti,
];

// Enable DiskManager to allow spilling
let runtime_config = RuntimeConfig::new()
.with_memory_limit(100, 1.0)
.with_memory_limit(500, 1.0)
.with_disk_manager(DiskManagerConfig::NewOs);
let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
let session_config = SessionConfig::default().with_batch_size(50);

for join_type in join_types {
let task_ctx = TaskContext::default()
.with_session_config(session_config.clone())
.with_runtime(runtime.clone());
.with_runtime(Arc::clone(&runtime));
let task_ctx = Arc::new(task_ctx);
let join = join_with_options(
left.clone(),
right.clone(),
Arc::clone(&left),
Arc::clone(&right),
on.clone(),
join_type,
sort_options.clone(),
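
The spilled-read path above in `get_buffered_columns_from_batch` drops the commented-out stream-based draft in favor of a synchronous Arrow IPC `FileReader` combined with the `take` kernel. That write-then-gather round trip can be sketched with the `arrow` crate alone; the schema, data, row indices, and file name below are invented for illustration.

```rust
// Standalone sketch of the spill round trip: write a RecordBatch to an Arrow
// IPC file, read it back one sub-batch at a time, and gather selected rows
// with `take`. Schema, values, indices, and the file name are made up.
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, UInt64Array};
use arrow::compute::take;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::FileWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![10, 20, 30, 40])) as ArrayRef],
    )?;

    // "Spill": persist the batch as an Arrow IPC file on disk.
    let path = std::env::temp_dir().join("smj_spill_sketch.arrow");
    let mut writer = FileWriter::try_new(File::create(&path)?, &schema)?;
    writer.write(&batch)?;
    writer.finish()?;

    // Read the file back batch by batch and gather the requested row indices,
    // mirroring the shape of `get_buffered_columns_from_batch`.
    let indices = UInt64Array::from(vec![0u64, 3]);
    let mut cols: Vec<ArrayRef> = Vec::new();
    let reader = FileReader::try_new(BufReader::new(File::open(&path)?), None)?;
    for maybe_batch in reader {
        let b = maybe_batch?;
        for column in b.columns() {
            cols.push(take(column, &indices, None)?);
        }
    }
    println!("gathered {} column(s) of {} rows", cols.len(), cols[0].len());
    Ok(())
}
```

Writing the spill in `batch_size` slices (`spill_record_batch_by_size` in `BufferedBatch::spill_to_disk`) is what keeps this read path from materializing a whole spilled batch at once: each iteration of the reader only holds one sub-batch in memory.
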
4 changes: 2 additions & 2 deletions datafusion/physical-plan/src/lib.rs
@@ -1057,7 +1057,7 @@ mod tests {
let cnt = spill_record_batches(
vec![batch1, batch2],
spill_file.path().into(),
schema.clone(),
Arc::clone(&schema),
);
assert_eq!(cnt.unwrap(), num_rows);

@@ -1086,7 +1086,7 @@ mod tests {
let cnt = spill_record_batch_by_size(
batch1,
spill_file.path().into(),
schema.clone(),
Arc::clone(&schema),
1,
);
assert_eq!(cnt.unwrap(), num_rows);
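
The updated tests exercise the new path by re-enabling `JoinType::Full` and pairing a small memory limit with an OS-backed disk manager, so the join is forced to spill instead of failing with a memory error. A consolidated sketch of just that setup follows; the module paths are assumptions inferred from the identifiers in the diff, and the join construction itself is omitted.

```rust
// Sketch of the runtime/session setup the spilling tests rely on.
// NOTE: the `use` paths are assumptions based on the names in the diff.
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::disk_manager::DiskManagerConfig;
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_execution::TaskContext;

fn spilling_task_ctx() -> Result<Arc<TaskContext>> {
    // A tiny memory budget forces the buffered side out of memory, and
    // DiskManagerConfig::NewOs gives the join somewhere to spill to.
    let runtime_config = RuntimeConfig::new()
        .with_memory_limit(500, 1.0)
        .with_disk_manager(DiskManagerConfig::NewOs);
    let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
    let session_config = SessionConfig::default().with_batch_size(50);

    // In the tests this pair is reused for every join type under test.
    let task_ctx = TaskContext::default()
        .with_session_config(session_config)
        .with_runtime(runtime);
    Ok(Arc::new(task_ctx))
}
```

Whether spilling actually kicks in is gated by `disk_manager.tmp_files_enabled()` in `mem_allocate_batch`; without temp files the reservation error is returned unchanged.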
