fix: account for memory in RepartitionExec #4820

Merged · 4 commits · Jan 7, 2023
19 changes: 6 additions & 13 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -749,7 +749,7 @@ mod tests {
     use crate::{assert_batches_sorted_eq, physical_plan::common};
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-    use arrow::error::{ArrowError, Result as ArrowResult};
+    use arrow::error::Result as ArrowResult;
     use arrow::record_batch::RecordBatch;
     use datafusion_common::{DataFusionError, Result, ScalarValue};
     use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
@@ -1210,18 +1210,11 @@ mod tests {
             let err = common::collect(stream).await.unwrap_err();

             // error root cause traversal is a bit complicated, see #4172.
-            if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {

Review comment (Contributor Author): drive-by clean-up. (See the find_root sketch after this file's diff.)

-                if let Some(err) = err.downcast_ref::<DataFusionError>() {
-                    assert!(
-                        matches!(err, DataFusionError::ResourcesExhausted(_)),
-                        "Wrong inner error type: {err}",
-                    );
-                } else {
-                    panic!("Wrong arrow error type: {err}")
-                }
-            } else {
-                panic!("Wrong outer error type: {err}")
-            }
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
         }

         Ok(())
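The cleanup above replaces manual error unwrapping with `DataFusionError::find_root`, which walks the chain of wrapped errors (e.g. a `DataFusionError` boxed inside an `ArrowError::ExternalError`) and returns the innermost cause; this is the traversal that the comment calls "a bit complicated" (see #4172). A minimal sketch of the idea, using a hypothetical `MyError` enum rather than DataFusion's actual types:

```rust
use std::error::Error;
use std::fmt;

#[derive(Debug)]
enum MyError {
    ResourcesExhausted(String),
    // An error that bubbled up through another layer (e.g. Arrow).
    External(Box<dyn Error + Send + Sync>),
}

impl fmt::Display for MyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MyError::ResourcesExhausted(msg) => write!(f, "resources exhausted: {msg}"),
            MyError::External(e) => write!(f, "external error: {e}"),
        }
    }
}

impl Error for MyError {}

impl MyError {
    /// Keep unwrapping `External` layers while they contain another `MyError`.
    fn find_root(&self) -> &Self {
        if let MyError::External(inner) = self {
            if let Some(inner) = inner.downcast_ref::<MyError>() {
                return inner.find_root();
            }
        }
        self
    }
}

fn main() {
    // Two layers of wrapping, as produced by an operator failing inside Arrow.
    let err = MyError::External(Box::new(MyError::ResourcesExhausted("pool".into())));
    assert!(matches!(err.find_root(), MyError::ResourcesExhausted(_)));
}
```

With the real API the assertion shrinks to the single `err.find_root()` call shown in the diff and no longer depends on the exact nesting order of `ArrowError` and `DataFusionError`.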
127 changes: 105 additions & 22 deletions datafusion/core/src/physical_plan/repartition.rs
@@ -24,6 +24,7 @@ use std::task::{Context, Poll};
 use std::{any::Any, vec};

 use crate::error::{DataFusionError, Result};
+use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use crate::physical_plan::hash_utils::create_hashes;
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
@@ -50,14 +51,21 @@ use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;

 type MaybeBatch = Option<ArrowResult<RecordBatch>>;
+type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;

 /// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
 struct RepartitionExecState {
     /// Channels for sending batches from input partitions to output partitions.
     /// Key is the partition number.
-    channels:
-        HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
+    channels: HashMap<
+        usize,
+        (
+            UnboundedSender<MaybeBatch>,
+            UnboundedReceiver<MaybeBatch>,
+            SharedMemoryReservation,
+        ),
+    >,

     /// Helper that ensures that the background job is killed once it is no longer needed.
     abort_helper: Arc<AbortOnDropMany<()>>,
@@ -338,7 +346,13 @@ impl ExecutionPlan for RepartitionExec {
                 // for this would be to add spill-to-disk capabilities.
                 let (sender, receiver) =
                     mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
-                state.channels.insert(partition, (sender, receiver));
+                let reservation = Arc::new(Mutex::new(
+                    MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+                        .register(context.memory_pool()),
+                ));
+                state
+                    .channels
+                    .insert(partition, (sender, receiver, reservation));
             }

             // launch one async task per *input* partition
@@ -347,7 +361,9 @@
                 let txs: HashMap<_, _> = state
                     .channels
                     .iter()
-                    .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
+                    .map(|(partition, (tx, _rx, reservation))| {
+                        (*partition, (tx.clone(), Arc::clone(reservation)))
+                    })
                     .collect();

                 let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
@@ -366,7 +382,9 @@
                 // (and pass along any errors, including panic!s)
                 let join_handle = tokio::spawn(Self::wait_for_task(
                     AbortOnDropSingle::new(input_task),
-                    txs,
+                    txs.into_iter()
+                        .map(|(partition, (tx, _reservation))| (partition, tx))
+                        .collect(),
                 ));
                 join_handles.push(join_handle);
             }
@@ -381,14 +399,17 @@

         // now return stream for the specified *output* partition which will
         // read from the channel
+        let (_tx, rx, reservation) = state
+            .channels
+            .remove(&partition)
+            .expect("partition not used yet");
         Ok(Box::pin(RepartitionStream {
             num_input_partitions,
             num_input_partitions_processed: 0,
             schema: self.input.schema(),
-            input: UnboundedReceiverStream::new(
-                state.channels.remove(&partition).unwrap().1,
-            ),
+            input: UnboundedReceiverStream::new(rx),
             drop_helper: Arc::clone(&state.abort_helper),
+            reservation,
         }))
     }

@@ -439,7 +460,7 @@ impl RepartitionExec {
     async fn pull_from_input(
         input: Arc<dyn ExecutionPlan>,
         i: usize,
-        mut txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
+        mut txs: HashMap<usize, (UnboundedSender<MaybeBatch>, SharedMemoryReservation)>,
         partitioning: Partitioning,
         r_metrics: RepartitionMetrics,
         context: Arc<TaskContext>,
@@ -467,11 +488,16 @@ impl RepartitionExec {
             };

             partitioner.partition(batch, |partition, partitioned| {
+                let size = partitioned.get_array_memory_size();
+

Review comment (Contributor): This doesn't account for sliced data (so if a batch of 1M gets cut up into 1000 pieces, each of the 1000 pieces will be charged the entire underlying size). That being said, it is a conservative estimate, so that is good 👍 (See the sketch after this hunk.)

                 let timer = r_metrics.send_time.timer();
                 // if there is still a receiver, send to it
-                if let Some(tx) = txs.get_mut(&partition) {
+                if let Some((tx, reservation)) = txs.get_mut(&partition) {
+                    reservation.lock().try_grow(size)?;
+
                     if tx.send(Some(Ok(partitioned))).is_err() {
                         // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+                        reservation.lock().shrink(size);
                         txs.remove(&partition);
                     }
                 }
@@ -546,6 +572,9 @@ struct RepartitionStream {
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
     drop_helper: Arc<AbortOnDropMany<()>>,
+
+    /// Memory reservation.
+    reservation: SharedMemoryReservation,
 }

 impl Stream for RepartitionStream {
@@ -555,20 +584,35 @@
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        match self.input.poll_next_unpin(cx) {

Review comment (Contributor Author): The change here isn't as large as it seems. I've just converted the implicit loop via (tail) recursion into a proper one, since I'm kinda afraid that this may explode under certain circumstances; see one of the early commits in this PR. The actual accounting change is in a separate commit. (See the sketch after this hunk.)

-            Poll::Ready(Some(Some(v))) => Poll::Ready(Some(v)),
-            Poll::Ready(Some(None)) => {
-                self.num_input_partitions_processed += 1;
-                if self.num_input_partitions == self.num_input_partitions_processed {
-                    // all input partitions have finished sending batches
-                    Poll::Ready(None)
-                } else {
-                    // other partitions still have data to send
-                    self.poll_next(cx)
-                }
-            }
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Pending => Poll::Pending,
-        }
+        loop {
+            match self.input.poll_next_unpin(cx) {
+                Poll::Ready(Some(Some(v))) => {
+                    if let Ok(batch) = &v {
+                        self.reservation
+                            .lock()
+                            .shrink(batch.get_array_memory_size());
+                    }
+
+                    return Poll::Ready(Some(v));
+                }
+                Poll::Ready(Some(None)) => {
+                    self.num_input_partitions_processed += 1;
+
+                    if self.num_input_partitions == self.num_input_partitions_processed {
+                        // all input partitions have finished sending batches
+                        return Poll::Ready(None);
+                    } else {
+                        // other partitions still have data to send
+                        continue;
+                    }
+                }
+                Poll::Ready(None) => {
+                    return Poll::Ready(None);
+                }
+                Poll::Pending => {
+                    return Poll::Pending;
+                }
+            }
+        }
     }
 }
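For context on the comment above: every input partition that finishes sends a `Some(None)` marker, and the old code consumed each marker by recursively calling `self.poll_next(cx)`, one stack frame per marker. Rust does not guarantee tail-call elimination, so with enough back-to-back markers the recursion depth is unbounded. A stripped-down sketch of the two shapes, using a hypothetical `Drain` type instead of the actual stream:

```rust
/// Stand-in for the channel: `markers` counts pending end-of-partition
/// markers that must be skipped before returning.
struct Drain {
    markers: usize,
}

impl Drain {
    /// Old shape: tail recursion. Depth grows with `markers`.
    fn next_recursive(&mut self) -> Option<u32> {
        if self.markers > 0 {
            self.markers -= 1;
            return self.next_recursive();
        }
        None
    }

    /// New shape: the same control flow as an explicit loop, constant stack.
    fn next_loop(&mut self) -> Option<u32> {
        loop {
            if self.markers > 0 {
                self.markers -= 1;
                continue;
            }
            return None;
        }
    }
}

fn main() {
    assert_eq!(Drain { markers: 1_000 }.next_recursive(), None);
    // With tens of millions of markers the recursive version risks a stack
    // overflow; the loop is unaffected.
    assert_eq!(Drain { markers: 10_000_000 }.next_loop(), None);
}
```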
@@ -583,6 +627,8 @@ impl RecordBatchStream for RepartitionStream {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::execution::context::SessionConfig;
+    use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use crate::from_slice::FromSlice;
     use crate::prelude::SessionContext;
     use crate::test::create_vec_batches;
@@ -1078,4 +1124,41 @@ mod tests {
         assert!(batch0.is_empty() || batch1.is_empty());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn oom() -> Result<()> {
+        // define input partitions
+        let schema = test_schema();
+        let partition = create_vec_batches(&schema, 50);
+        let input_partitions = vec![partition];
+        let partitioning = Partitioning::RoundRobinBatch(4);
+
+        // set up context
+        let session_ctx = SessionContext::with_config_rt(
+            SessionConfig::default(),
+            Arc::new(
+                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
+                    .unwrap(),
+            ),
+        );
+        let task_ctx = session_ctx.task_ctx();
+
+        // create physical plan
+        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
+        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
+
+        // pull partitions
+        for i in 0..exec.partitioning.partition_count() {
+            let mut stream = exec.execute(i, task_ctx.clone())?;
+            let err =
+                DataFusionError::ArrowError(stream.next().await.unwrap().unwrap_err());
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
+        }
+
+        Ok(())
+    }
 }
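Taken together, the accounting protocol this diff introduces is: the producer calls `try_grow(size)` on the per-output-partition reservation before pushing a batch into the unbounded channel (the resulting `ResourcesExhausted` error propagates via `?`), while the consumer calls `shrink(size)` as each batch is polled back out; if the receiver has already hung up, the producer shrinks immediately. A self-contained sketch of that grow-before-send / shrink-after-receive discipline, with a hypothetical fixed-budget pool in place of DataFusion's `MemoryPool`:

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

/// Hypothetical fixed-budget pool standing in for the shared MemoryReservation.
struct Pool {
    used: usize,
    limit: usize,
}

impl Pool {
    fn try_grow(&mut self, bytes: usize) -> Result<(), String> {
        if self.used + bytes > self.limit {
            return Err(format!("resources exhausted: cannot grow by {bytes} bytes"));
        }
        self.used += bytes;
        Ok(())
    }

    fn shrink(&mut self, bytes: usize) {
        self.used -= bytes;
    }
}

fn main() {
    let pool = Arc::new(Mutex::new(Pool { used: 0, limit: 1024 }));
    let channel: Arc<Mutex<VecDeque<Vec<u8>>>> = Arc::new(Mutex::new(VecDeque::new()));

    // Producer: reserve BEFORE enqueueing, so batches parked in the unbounded
    // channel are charged to the pool (mirrors `reservation.lock().try_grow(size)?`).
    let batch = vec![0u8; 512];
    pool.lock().unwrap().try_grow(batch.len()).unwrap();
    channel.lock().unwrap().push_back(batch);

    // Consumer: release as the batch leaves the channel (mirrors the `shrink`
    // in `RepartitionStream::poll_next`).
    if let Some(batch) = channel.lock().unwrap().pop_front() {
        pool.lock().unwrap().shrink(batch.len());
    }

    assert_eq!(pool.lock().unwrap().used, 0);
}
```

This is why the channel can stay unbounded: it no longer hides memory from the pool, and the new `oom` test shows that a 1-byte limit now fails fast with `ResourcesExhausted` instead of buffering without bounds.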
6 changes: 5 additions & 1 deletion datafusion/core/tests/memory_limit.rs
@@ -95,7 +95,11 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize)

     let runtime = RuntimeEnv::new(rt_config).unwrap();

-    let ctx = SessionContext::with_config_rt(SessionConfig::new(), Arc::new(runtime));
+    let ctx = SessionContext::with_config_rt(
+        // do NOT re-partition (RepartitionExec also has a memory budget, which we'll likely hit first)
+        SessionConfig::new().with_target_partitions(1),
+        Arc::new(runtime),
+    );
     ctx.register_table("t", Arc::new(table))
         .expect("registering table");

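The configuration change above, pulled out as a standalone recipe: cap the memory pool and force a single target partition so the operator under test, not RepartitionExec, is the one that hits the limit. A sketch; the import paths are assumptions based on the APIs used elsewhere in this PR:

```rust
use std::sync::Arc;

use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};

/// Build a context with a memory pool of `limit` bytes (fraction 1.0 = all of
/// it usable) and no repartitioning.
fn memory_limited_ctx(limit: usize) -> SessionContext {
    let rt_config = RuntimeConfig::default().with_memory_limit(limit, 1.0);
    let runtime = RuntimeEnv::new(rt_config).unwrap();

    SessionContext::with_config_rt(
        // one partition => no RepartitionExec competing for the budget
        SessionConfig::new().with_target_partitions(1),
        Arc::new(runtime),
    )
}
```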