fix: account for memory in RepartitionExec
#4820
```diff
@@ -24,6 +24,7 @@ use std::task::{Context, Poll};
 use std::{any::Any, vec};
 
 use crate::error::{DataFusionError, Result};
+use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use crate::physical_plan::hash_utils::create_hashes;
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
@@ -50,14 +51,21 @@ use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;
 
 type MaybeBatch = Option<ArrowResult<RecordBatch>>;
+type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
 
 /// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
 struct RepartitionExecState {
     /// Channels for sending batches from input partitions to output partitions.
     /// Key is the partition number.
-    channels:
-        HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
+    channels: HashMap<
+        usize,
+        (
+            UnboundedSender<MaybeBatch>,
+            UnboundedReceiver<MaybeBatch>,
+            SharedMemoryReservation,
+        ),
+    >,
 
     /// Helper that ensures that that background job is killed once it is no longer needed.
     abort_helper: Arc<AbortOnDropMany<()>>,
```
```diff
@@ -338,7 +346,13 @@ impl ExecutionPlan for RepartitionExec {
                 // for this would be to add spill-to-disk capabilities.
                 let (sender, receiver) =
                     mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
-                state.channels.insert(partition, (sender, receiver));
+                let reservation = Arc::new(Mutex::new(
+                    MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+                        .register(context.memory_pool()),
+                ));
+                state
+                    .channels
+                    .insert(partition, (sender, receiver, reservation));
             }
 
             // launch one async task per *input* partition
```
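As background, here is a minimal standalone sketch of the reservation protocol the hunk above hooks into, assuming DataFusion's `GreedyMemoryPool` and arbitrary numbers (1 KiB limit, 512-byte batch); `RepartitionExec` itself obtains the pool from the `TaskContext` rather than creating one:

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() -> Result<()> {
    // Hypothetical pool with a 1 KiB limit (the PR uses the pool from the TaskContext).
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));

    // One named consumer per output partition, as in the hunk above.
    let mut reservation = MemoryConsumer::new("RepartitionExec[0]").register(&pool);

    // Sender side: grow the reservation before handing a batch to the channel.
    reservation.try_grow(512)?;

    // Receiver side: shrink once the batch has been pulled from the channel.
    reservation.shrink(512);

    // Growing past the pool limit yields a ResourcesExhausted error.
    assert!(reservation.try_grow(4096).is_err());

    Ok(())
}
```

The PR wraps each reservation in `Arc<Mutex<…>>` (the `SharedMemoryReservation` alias) because the sending task grows it while the output stream shrinks it.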
```diff
@@ -347,7 +361,9 @@ impl ExecutionPlan for RepartitionExec {
                 let txs: HashMap<_, _> = state
                     .channels
                     .iter()
-                    .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
+                    .map(|(partition, (tx, _rx, reservation))| {
+                        (*partition, (tx.clone(), Arc::clone(reservation)))
+                    })
                     .collect();
 
                 let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
@@ -366,7 +382,9 @@ impl ExecutionPlan for RepartitionExec {
                 // (and pass along any errors, including panic!s)
                 let join_handle = tokio::spawn(Self::wait_for_task(
                     AbortOnDropSingle::new(input_task),
-                    txs,
+                    txs.into_iter()
+                        .map(|(partition, (tx, _reservation))| (partition, tx))
+                        .collect(),
                 ));
                 join_handles.push(join_handle);
             }
@@ -381,14 +399,17 @@ impl ExecutionPlan for RepartitionExec {
 
         // now return stream for the specified *output* partition which will
         // read from the channel
+        let (_tx, rx, reservation) = state
+            .channels
+            .remove(&partition)
+            .expect("partition not used yet");
         Ok(Box::pin(RepartitionStream {
             num_input_partitions,
             num_input_partitions_processed: 0,
             schema: self.input.schema(),
-            input: UnboundedReceiverStream::new(
-                state.channels.remove(&partition).unwrap().1,
-            ),
+            input: UnboundedReceiverStream::new(rx),
             drop_helper: Arc::clone(&state.abort_helper),
+            reservation,
         }))
     }
 
@@ -439,7 +460,7 @@ impl RepartitionExec {
     async fn pull_from_input(
         input: Arc<dyn ExecutionPlan>,
         i: usize,
-        mut txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
+        mut txs: HashMap<usize, (UnboundedSender<MaybeBatch>, SharedMemoryReservation)>,
         partitioning: Partitioning,
        r_metrics: RepartitionMetrics,
         context: Arc<TaskContext>,
```
```diff
@@ -467,11 +488,16 @@ impl RepartitionExec {
             };
 
             partitioner.partition(batch, |partition, partitioned| {
+                let size = partitioned.get_array_memory_size();
+
                 let timer = r_metrics.send_time.timer();
                 // if there is still a receiver, send to it
-                if let Some(tx) = txs.get_mut(&partition) {
+                if let Some((tx, reservation)) = txs.get_mut(&partition) {
+                    reservation.lock().try_grow(size)?;
+
                     if tx.send(Some(Ok(partitioned))).is_err() {
                         // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+                        reservation.lock().shrink(size);
                         txs.remove(&partition);
                     }
                 }
```

Review comment on the `get_array_memory_size()` line: this doesn't account for sliced data (so if a batch of 1M gets cut up into 1000 pieces, each of the 1000 pieces will be charged the entire underlying size). That being said, it is a conservative estimate, so that is good 👍
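The reviewer's point is that `get_array_memory_size()` reports the size of the underlying buffers, which sliced batches share. A small illustrative sketch, assuming the `arrow` crate with arbitrary array length and slice bounds:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let column: ArrayRef = Arc::new(Int64Array::from_iter_values(0..1_000_000));
    let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

    // Slicing is zero-copy: the slice still references the full 1M-value buffer,
    // so its reported memory size is roughly that of the whole batch.
    let slice = batch.slice(0, 1_000);

    println!("full batch: {} bytes", batch.get_array_memory_size());
    println!("1k slice:   {} bytes", slice.get_array_memory_size());
}
```

Charging every slice for the full buffer can overestimate usage (the exact figures depend on the arrow version), but it never underestimates, which is the conservative behaviour the comment accepts.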
```diff
@@ -546,6 +572,9 @@ struct RepartitionStream {
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
     drop_helper: Arc<AbortOnDropMany<()>>,
+
+    /// Memory reservation.
+    reservation: SharedMemoryReservation,
 }
 
 impl Stream for RepartitionStream {
```
```diff
@@ -555,20 +584,35 @@ impl Stream for RepartitionStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        match self.input.poll_next_unpin(cx) {
-            Poll::Ready(Some(Some(v))) => Poll::Ready(Some(v)),
-            Poll::Ready(Some(None)) => {
-                self.num_input_partitions_processed += 1;
-                if self.num_input_partitions == self.num_input_partitions_processed {
-                    // all input partitions have finished sending batches
-                    Poll::Ready(None)
-                } else {
-                    // other partitions still have data to send
-                    self.poll_next(cx)
-                }
-            }
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Pending => Poll::Pending,
-        }
+        loop {
+            match self.input.poll_next_unpin(cx) {
+                Poll::Ready(Some(Some(v))) => {
+                    if let Ok(batch) = &v {
+                        self.reservation
+                            .lock()
+                            .shrink(batch.get_array_memory_size());
+                    }
+
+                    return Poll::Ready(Some(v));
+                }
+                Poll::Ready(Some(None)) => {
+                    self.num_input_partitions_processed += 1;
+
+                    if self.num_input_partitions == self.num_input_partitions_processed {
+                        // all input partitions have finished sending batches
+                        return Poll::Ready(None);
+                    } else {
+                        // other partitions still have data to send
+                        continue;
+                    }
+                }
+                Poll::Ready(None) => {
+                    return Poll::Ready(None);
+                }
+                Poll::Pending => {
+                    return Poll::Pending;
+                }
+            }
+        }
     }
 }
```

Author's note on this hunk: the change here isn't as large as it seems. I've just converted the implicit loop via (tail) recursion into a proper loop, since I'm somewhat afraid that the recursion may explode (overflow the stack) under certain circumstances; see one of the early commits in this PR. The actual accounting change is in a separate commit.
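The pattern in the author's note is a general one for `Stream::poll_next`: re-entering via `self.poll_next(cx)` adds a stack frame for every skipped item, whereas a loop does not. A minimal sketch using a hypothetical `SkipMarkers` wrapper (not part of the PR) that drops `None` markers the way `RepartitionStream` skips end-of-partition markers:

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{Stream, StreamExt};

/// Hypothetical wrapper over a stream of `Option<T>` that drops the `None`
/// markers, mirroring the shape of `RepartitionStream`'s input channel.
struct SkipMarkers<S> {
    inner: S,
}

impl<S, T> Stream for SkipMarkers<S>
where
    S: Stream<Item = Option<T>> + Unpin,
{
    type Item = T;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // A loop instead of `return self.poll_next(cx)`: the recursive form
        // pushes one stack frame per skipped marker, which can overflow the
        // stack if many markers arrive back to back; the loop stays flat.
        loop {
            match self.inner.poll_next_unpin(cx) {
                Poll::Ready(Some(Some(item))) => return Poll::Ready(Some(item)),
                Poll::Ready(Some(None)) => continue, // marker: poll the inner stream again
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
```

In the PR, the same loop additionally shrinks the shared reservation before returning each batch, so memory is released as soon as the receiving side takes over the batch.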
```diff
@@ -583,6 +627,8 @@ impl RecordBatchStream for RepartitionStream {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::execution::context::SessionConfig;
+    use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use crate::from_slice::FromSlice;
     use crate::prelude::SessionContext;
     use crate::test::create_vec_batches;
```
```diff
@@ -1078,4 +1124,41 @@ mod tests {
         assert!(batch0.is_empty() || batch1.is_empty());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn oom() -> Result<()> {
+        // define input partitions
+        let schema = test_schema();
+        let partition = create_vec_batches(&schema, 50);
+        let input_partitions = vec![partition];
+        let partitioning = Partitioning::RoundRobinBatch(4);
+
+        // setup up context
+        let session_ctx = SessionContext::with_config_rt(
+            SessionConfig::default(),
+            Arc::new(
+                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
+                    .unwrap(),
+            ),
+        );
+        let task_ctx = session_ctx.task_ctx();
+
+        // create physical plan
+        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
+        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
+
+        // pull partitions
+        for i in 0..exec.partitioning.partition_count() {
+            let mut stream = exec.execute(i, task_ctx.clone())?;
+            let err =
+                DataFusionError::ArrowError(stream.next().await.unwrap().unwrap_err());
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
+        }
+
+        Ok(())
+    }
 }
```
Review comment: drive-by clean-up.