fix: use JoinSet to make spawned tasks cancel-safe (apache#9318)
* fix: use `JoinSet` to make spawned tasks cancel-safe

* feat: drop `AbortOnDropSingle` and `AbortOnDropMany`

* style: doc lint

* fix: ordering of the tasks in `RepartitionExec`

* fix: replace spawn_blocking with JoinSet

* style: disallow spawn methods

* fixes: preserve ordering of tasks

* style: allow spawning in tests

* chore: exclude clippy.toml from rat

* chore: typo

* feat: introduce `SpawnedTask`

* revert outdated comment

* switch a missed, outdated part over to `SpawnedTask`

* doc: improve reason for disallowed-method
DDtKey authored Feb 27, 2024
1 parent 372204e commit 14264d2
Showing 17 changed files with 129 additions and 124 deletions.
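
Before the per-file diffs, a quick orientation: the change replaces raw `tokio::spawn`/`spawn_blocking` handles (previously wrapped in `AbortOnDropSingle`/`AbortOnDropMany`) with a `SpawnedTask` that owns its task through a `JoinSet`, so dropping the handle aborts the task and a cancelled caller can no longer leak detached work. The sketch below is a minimal illustration of that idea, not the code merged here; only the `spawn`, `spawn_blocking`, and `join` names and the `JoinSet` backing are taken from this commit, everything else is an assumption.

```rust
// Minimal sketch of a JoinSet-backed task handle (illustrative only; the
// real `SpawnedTask` in `datafusion_physical_plan::common` may differ).
use tokio::task::{JoinError, JoinSet};

pub struct SpawnedTask<R> {
    inner: JoinSet<R>,
}

impl<R: Send + 'static> SpawnedTask<R> {
    /// Spawn a future. Dropping the returned handle aborts the task,
    /// because dropping a `JoinSet` aborts every task spawned on it.
    pub fn spawn<T>(task: T) -> Self
    where
        T: std::future::Future<Output = R> + Send + 'static,
    {
        let mut inner = JoinSet::new();
        inner.spawn(task);
        Self { inner }
    }

    /// Same idea for blocking closures, run on tokio's blocking pool.
    pub fn spawn_blocking<T>(task: T) -> Self
    where
        T: FnOnce() -> R + Send + 'static,
    {
        let mut inner = JoinSet::new();
        inner.spawn_blocking(task);
        Self { inner }
    }

    /// Wait for the task, surfacing panics and aborts as `JoinError`.
    pub async fn join(mut self) -> Result<R, JoinError> {
        self.inner
            .join_next()
            .await
            .expect("a SpawnedTask always holds exactly one task")
    }
}
```

This is why the call sites below switch from `handle.await` to `task.join().await`: the handle is no longer a bare `JoinHandle`, and awaiting completion is an explicit operation that still reports panics via `JoinError`.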
4 changes: 4 additions & 0 deletions clippy.toml
@@ -0,0 +1,4 @@
disallowed-methods = [
{ path = "tokio::task::spawn", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
{ path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
]
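
At call sites, the effect of this lint configuration is that production code goes through `SpawnedTask` while tests may opt out with an `#[allow]`, as the later hunks show. A hedged example of that pattern (the `run` function and values here are illustrative, not part of the commit):

```rust
// Production code: use the cancel-safe wrapper instead of tokio::task::spawn.
use datafusion_physical_plan::common::SpawnedTask;

async fn run() -> u64 {
    // Previously: tokio::task::spawn(async { ... }) — now rejected by clippy.
    let task = SpawnedTask::spawn(async { 21u64 * 2 });
    // Dropping `task` would abort it; `join()` awaits its completion.
    task.join().await.expect("task panicked or was cancelled")
}

// Tests keep using plain spawns, with the lint explicitly allowed.
#[tokio::test]
#[allow(clippy::disallowed_methods)] // spawn allowed only in tests
async fn spawns_directly_in_tests() {
    let handle = tokio::task::spawn(async { 1 + 1 });
    assert_eq!(handle.await.unwrap(), 2);
}
```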
1 change: 1 addition & 0 deletions datafusion/core/src/dataframe/mod.rs
@@ -2172,6 +2172,7 @@ mod tests {
}

#[tokio::test]
#[allow(clippy::disallowed_methods)]
async fn sendable() {
let df = test_table().await.unwrap();
// dataframes should be sendable between threads/tasks
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/arrow.rs
@@ -295,7 +295,7 @@ impl DataSink for ArrowFileSink {
}
}

match demux_task.await {
match demux_task.join().await {
Ok(r) => r?,
Err(e) => {
if e.is_panic() {
51 changes: 27 additions & 24 deletions datafusion/core/src/datasource/file_format/parquet.rs
@@ -32,7 +32,7 @@ use std::fmt::Debug;
use std::sync::Arc;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::task::{JoinHandle, JoinSet};
use tokio::task::JoinSet;

use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::statistics::{create_max_min_accs, get_col_stats};
@@ -42,6 +42,7 @@ use bytes::{BufMut, BytesMut};
use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
use datafusion_physical_plan::common::SpawnedTask;
use futures::{StreamExt, TryStreamExt};
use hashbrown::HashMap;
use object_store::path::Path;
@@ -728,7 +729,7 @@ impl DataSink for ParquetSink {
}
}

match demux_task.await {
match demux_task.join().await {
Ok(r) => r?,
Err(e) => {
if e.is_panic() {
@@ -738,6 +739,7 @@ }
}
}
}

Ok(row_count as u64)
}
}
@@ -754,32 +756,34 @@ async fn column_serializer_task(
Ok(writer)
}

type ColumnJoinHandle = JoinHandle<Result<ArrowColumnWriter>>;
type ColumnWriterTask = SpawnedTask<Result<ArrowColumnWriter>>;
type ColSender = Sender<ArrowLeafColumn>;

/// Spawns a parallel serialization task for each column
/// Returns join handles for each columns serialization task along with a send channel
/// to send arrow arrays to each serialization task.
fn spawn_column_parallel_row_group_writer(
schema: Arc<Schema>,
parquet_props: Arc<WriterProperties>,
max_buffer_size: usize,
) -> Result<(Vec<ColumnJoinHandle>, Vec<ColSender>)> {
) -> Result<(Vec<ColumnWriterTask>, Vec<ColSender>)> {
let schema_desc = arrow_to_parquet_schema(&schema)?;
let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?;
let num_columns = col_writers.len();

let mut col_writer_handles = Vec::with_capacity(num_columns);
let mut col_writer_tasks = Vec::with_capacity(num_columns);
let mut col_array_channels = Vec::with_capacity(num_columns);
for writer in col_writers.into_iter() {
// Buffer size of this channel limits the number of arrays queued up for column level serialization
let (send_array, recieve_array) =
mpsc::channel::<ArrowLeafColumn>(max_buffer_size);
col_array_channels.push(send_array);
col_writer_handles
.push(tokio::spawn(column_serializer_task(recieve_array, writer)))

let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer));
col_writer_tasks.push(task);
}

Ok((col_writer_handles, col_array_channels))
Ok((col_writer_tasks, col_array_channels))
}

/// Settings related to writing parquet files in parallel
@@ -820,14 +824,14 @@ async fn send_arrays_to_col_writers(
/// Spawns a tokio task which joins the parallel column writer tasks,
/// and finalizes the row group
fn spawn_rg_join_and_finalize_task(
column_writer_handles: Vec<JoinHandle<Result<ArrowColumnWriter>>>,
column_writer_tasks: Vec<ColumnWriterTask>,
rg_rows: usize,
) -> JoinHandle<RBStreamSerializeResult> {
tokio::spawn(async move {
let num_cols = column_writer_handles.len();
) -> SpawnedTask<RBStreamSerializeResult> {
SpawnedTask::spawn(async move {
let num_cols = column_writer_tasks.len();
let mut finalized_rg = Vec::with_capacity(num_cols);
for handle in column_writer_handles.into_iter() {
match handle.await {
for task in column_writer_tasks.into_iter() {
match task.join().await {
Ok(r) => {
let w = r?;
finalized_rg.push(w.close()?);
@@ -856,12 +860,12 @@ fn spawn_rg_join_and_finalize_task(
/// given by n_columns * num_row_groups.
fn spawn_parquet_parallel_serialization_task(
mut data: Receiver<RecordBatch>,
serialize_tx: Sender<JoinHandle<RBStreamSerializeResult>>,
serialize_tx: Sender<SpawnedTask<RBStreamSerializeResult>>,
schema: Arc<Schema>,
writer_props: Arc<WriterProperties>,
parallel_options: ParallelParquetWriterOptions,
) -> JoinHandle<Result<(), DataFusionError>> {
tokio::spawn(async move {
) -> SpawnedTask<Result<(), DataFusionError>> {
SpawnedTask::spawn(async move {
let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream;
let max_row_group_rows = writer_props.max_row_group_size();
let (mut column_writer_handles, mut col_array_channels) =
@@ -931,7 +935,7 @@ fn spawn_parquet_parallel_serialization_task(
/// Consume RowGroups serialized by other parallel tasks and concatenate them in
/// to the final parquet file, while flushing finalized bytes to an [ObjectStore]
async fn concatenate_parallel_row_groups(
mut serialize_rx: Receiver<JoinHandle<RBStreamSerializeResult>>,
mut serialize_rx: Receiver<SpawnedTask<RBStreamSerializeResult>>,
schema: Arc<Schema>,
writer_props: Arc<WriterProperties>,
mut object_store_writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
@@ -947,9 +951,8 @@ async fn concatenate_parallel_row_groups(

let mut row_count = 0;

while let Some(handle) = serialize_rx.recv().await {
let join_result = handle.await;
match join_result {
while let Some(task) = serialize_rx.recv().await {
match task.join().await {
Ok(result) => {
let mut rg_out = parquet_writer.next_row_group()?;
let (serialized_columns, cnt) = result?;
@@ -999,7 +1002,7 @@ async fn output_single_parquet_file_parallelized(
let max_rowgroups = parallel_options.max_parallel_row_groups;
// Buffer size of this channel limits maximum number of RowGroups being worked on in parallel
let (serialize_tx, serialize_rx) =
mpsc::channel::<JoinHandle<RBStreamSerializeResult>>(max_rowgroups);
mpsc::channel::<SpawnedTask<RBStreamSerializeResult>>(max_rowgroups);

let arc_props = Arc::new(parquet_props.clone());
let launch_serialization_task = spawn_parquet_parallel_serialization_task(
@@ -1017,7 +1020,7 @@
)
.await?;

match launch_serialization_task.await {
match launch_serialization_task.join().await {
Ok(Ok(_)) => (),
Ok(Err(e)) => return Err(e),
Err(e) => {
Expand All @@ -1027,7 +1030,7 @@ async fn output_single_parquet_file_parallelized(
unreachable!()
}
}
};
}

Ok(row_count)
}
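
One pattern worth calling out in the parquet changes above: parallelism is bounded and ordering preserved by sending the `SpawnedTask` handles themselves through a bounded channel and joining them in receive order, matching the comment "Buffer size of this channel limits maximum number of RowGroups being worked on in parallel". A stripped-down sketch of that pattern (function and values are illustrative, not from the commit):

```rust
// Hedged sketch of the "channel of task handles" pattern: the bounded channel
// caps how many tasks are in flight, and joining handles in receive order
// keeps results in the order the work was submitted.
use datafusion_physical_plan::common::SpawnedTask;
use tokio::sync::mpsc;

async fn ordered_bounded_parallelism() {
    // At most ~4 unjoined tasks queued at any time.
    let (tx, mut rx) = mpsc::channel::<SpawnedTask<u64>>(4);

    let producer = SpawnedTask::spawn(async move {
        for i in 0..16u64 {
            let task = SpawnedTask::spawn(async move { i * i });
            // Backpressure: waits here once 4 handles are already queued.
            if tx.send(task).await.is_err() {
                break; // receiver dropped, stop producing
            }
        }
    });

    let mut results = Vec::new();
    while let Some(task) = rx.recv().await {
        // Joining in receive order preserves submission order.
        results.push(task.join().await.expect("worker task panicked"));
    }
    producer.join().await.expect("producer task panicked");
    assert_eq!(results, (0..16u64).map(|i| i * i).collect::<Vec<_>>());
}
```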
12 changes: 6 additions & 6 deletions datafusion/core/src/datasource/file_format/write/demux.rs
@@ -41,8 +41,8 @@ use object_store::path::Path;

use rand::distributions::DistString;

use datafusion_physical_plan::common::SpawnedTask;
use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;

type RecordBatchReceiver = Receiver<RecordBatch>;
type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;
@@ -76,15 +76,15 @@ pub(crate) fn start_demuxer_task(
partition_by: Option<Vec<(String, DataType)>>,
base_output_path: ListingTableUrl,
file_extension: String,
) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
let (tx, rx) = mpsc::unbounded_channel();
let context = context.clone();
let single_file_output = !base_output_path.is_collection();
let task: JoinHandle<std::result::Result<(), DataFusionError>> = match partition_by {
let task = match partition_by {
Some(parts) => {
// There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot
// bound this channel without risking a deadlock.
tokio::spawn(async move {
SpawnedTask::spawn(async move {
hive_style_partitions_demuxer(
tx,
input,
Expand All @@ -96,7 +96,7 @@ pub(crate) fn start_demuxer_task(
.await
})
}
None => tokio::spawn(async move {
None => SpawnedTask::spawn(async move {
row_count_demuxer(
tx,
input,
29 changes: 15 additions & 14 deletions datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -33,10 +33,11 @@ use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError};
use datafusion_execution::TaskContext;

use bytes::Bytes;
use datafusion_physical_plan::common::SpawnedTask;
use futures::try_join;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};
use tokio::task::{JoinHandle, JoinSet};
use tokio::try_join;
use tokio::task::JoinSet;

type WriterType = AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>;
type SerializerType = Arc<dyn BatchSerializer>;
@@ -51,31 +52,31 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
mut writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> {
let (tx, mut rx) =
mpsc::channel::<JoinHandle<Result<(usize, Bytes), DataFusionError>>>(100);
let serialize_task = tokio::spawn(async move {
mpsc::channel::<SpawnedTask<Result<(usize, Bytes), DataFusionError>>>(100);
let serialize_task = SpawnedTask::spawn(async move {
// Some serializers (like CSV) handle the first batch differently than
// subsequent batches, so we track that here.
let mut initial = true;
while let Some(batch) = data_rx.recv().await {
let serializer_clone = serializer.clone();
let handle = tokio::spawn(async move {
let task = SpawnedTask::spawn(async move {
let num_rows = batch.num_rows();
let bytes = serializer_clone.serialize(batch, initial)?;
Ok((num_rows, bytes))
});
if initial {
initial = false;
}
tx.send(handle).await.map_err(|_| {
tx.send(task).await.map_err(|_| {
internal_datafusion_err!("Unknown error writing to object store")
})?;
}
Ok(())
});

let mut row_count = 0;
while let Some(handle) = rx.recv().await {
match handle.await {
while let Some(task) = rx.recv().await {
match task.join().await {
Ok(Ok((cnt, bytes))) => {
match writer.write_all(&bytes).await {
Ok(_) => (),
@@ -106,7 +107,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
}
}

match serialize_task.await {
match serialize_task.join().await {
Ok(Ok(_)) => (),
Ok(Err(e)) => return Err((writer, e)),
Err(_) => {
Expand All @@ -115,7 +116,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
internal_datafusion_err!("Unknown error writing to object store"),
))
}
};
}
Ok((writer, row_count as u64))
}

@@ -241,9 +242,9 @@ pub(crate) async fn stateless_multipart_put(
.execution
.max_buffered_batches_per_output_file;

let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(rb_buffer_size / 2);
let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
let write_coordinater_task = tokio::spawn(async move {
let write_coordinator_task = SpawnedTask::spawn(async move {
stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
});
while let Some((location, rb_stream)) = file_stream_rx.recv().await {
@@ -260,10 +261,10 @@
})?;
}

// Signal to the write coordinater that no more files are coming
// Signal to the write coordinator that no more files are coming
drop(tx_file_bundle);

match try_join!(write_coordinater_task, demux_task) {
match try_join!(write_coordinator_task.join(), demux_task.join()) {
Ok((r1, r2)) => {
r1?;
r2?;
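
A small consequence visible in this file: `SpawnedTask` appears to expose completion through `join()` rather than being awaited directly, so combining several tasks moves from `tokio::try_join!` on raw handles to `futures::try_join!` on the futures returned by `join()`. A hedged sketch of that combination (names and values are illustrative, not part of the commit):

```rust
// Joining two SpawnedTasks concurrently: the outer Result carries panics or
// aborts (JoinError), the inner Result carries each task's own error type.
use datafusion_physical_plan::common::SpawnedTask;

async fn join_both() -> Result<(), Box<dyn std::error::Error>> {
    let a = SpawnedTask::spawn(async { Ok::<u64, std::io::Error>(1) });
    let b = SpawnedTask::spawn(async { Ok::<u64, std::io::Error>(2) });

    // Short-circuits if either task panics or is aborted.
    let (ra, rb) = futures::try_join!(a.join(), b.join())?;
    // Each task's own result still needs to be unwrapped separately.
    let (x, y) = (ra?, rb?);
    assert_eq!(x + y, 3);
    Ok(())
}
```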
9 changes: 4 additions & 5 deletions datafusion/core/src/datasource/stream.rs
@@ -29,12 +29,11 @@ use arrow_array::{RecordBatch, RecordBatchReader, RecordBatchWriter};
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use futures::StreamExt;
use tokio::task::spawn_blocking;

use datafusion_common::{plan_err, Constraints, DataFusionError, Result};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{CreateExternalTable, Expr, TableType};
use datafusion_physical_plan::common::AbortOnDropSingle;
use datafusion_physical_plan::common::SpawnedTask;
use datafusion_physical_plan::insert::{DataSink, FileSinkExec};
use datafusion_physical_plan::metrics::MetricsSet;
use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
@@ -344,22 +343,22 @@ impl DataSink for StreamWrite {
let config = self.0.clone();
let (sender, mut receiver) = tokio::sync::mpsc::channel::<RecordBatch>(2);
// Note: FIFO Files support poll so this could use AsyncFd
let write = AbortOnDropSingle::new(spawn_blocking(move || {
let write_task = SpawnedTask::spawn_blocking(move || {
let mut count = 0_u64;
let mut writer = config.writer()?;
while let Some(batch) = receiver.blocking_recv() {
count += batch.num_rows() as u64;
writer.write(&batch)?;
}
Ok(count)
}));
});

while let Some(b) = data.next().await.transpose()? {
if sender.send(b).await.is_err() {
break;
}
}
drop(sender);
write.await.unwrap()
write_task.join().await.unwrap()
}
}
1 change: 1 addition & 0 deletions datafusion/core/src/execution/context/mod.rs
@@ -2288,6 +2288,7 @@ mod tests {
}

#[tokio::test]
#[allow(clippy::disallowed_methods)]
async fn send_context_to_threads() -> Result<()> {
// ensure SessionContexts can be used in a multi-threaded
// environment. Usecase is for concurrent planing.
2 changes: 2 additions & 0 deletions datafusion/core/tests/fifo.rs
@@ -103,6 +103,7 @@ mod unix_test {
let broken_pipe_timeout = Duration::from_secs(10);
let sa = file_path.clone();
// Spawn a new thread to write to the FIFO file
#[allow(clippy::disallowed_methods)] // spawn allowed only in tests
spawn_blocking(move || {
let file = OpenOptions::new().write(true).open(sa).unwrap();
// Reference time to use when deciding to fail the test
Expand Down Expand Up @@ -357,6 +358,7 @@ mod unix_test {
(sink_fifo_path.clone(), sink_fifo_path.display());

// Spawn a new thread to read sink EXTERNAL TABLE.
#[allow(clippy::disallowed_methods)] // spawn allowed only in tests
tasks.push(spawn_blocking(move || {
let file = File::open(sink_fifo_path_thread).unwrap();
let schema = Arc::new(Schema::new(vec![
@@ -302,6 +302,7 @@ mod sp_repartition_fuzz_tests {
let mut handles = Vec::new();

for seed in seed_start..seed_end {
#[allow(clippy::disallowed_methods)] // spawn allowed only in tests
let job = tokio::spawn(run_sort_preserving_repartition_test(
make_staggered_batches::<true>(n_row, n_distinct, seed as u64),
is_first_roundrobin,
1 change: 1 addition & 0 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -123,6 +123,7 @@ async fn window_bounded_window_random_comparison() -> Result<()> {
for i in 0..n {
let idx = i % test_cases.len();
let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone();
#[allow(clippy::disallowed_methods)] // spawn allowed only in tests
let job = tokio::spawn(run_window_test(
make_staggered_batches::<true>(1000, n_distinct, i as u64),
i as u64,