diff --git a/clippy.toml b/clippy.toml
new file mode 100644
index 000000000000..c6c754e440c7
--- /dev/null
+++ b/clippy.toml
@@ -0,0 +1,4 @@
+disallowed-methods = [
+    { path = "tokio::task::spawn", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
+    { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn_blocking` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
+]
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 3a60d57f6685..c04247210d46 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2172,6 +2172,7 @@ mod tests {
     }
 
     #[tokio::test]
+    #[allow(clippy::disallowed_methods)]
     async fn sendable() {
         let df = test_table().await.unwrap();
         // dataframes should be sendable between threads/tasks
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs
index ead2db5a10c0..d5f07d11bee9 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -295,7 +295,7 @@ impl DataSink for ArrowFileSink {
             }
         }
 
-        match demux_task.await {
+        match demux_task.join().await {
            Ok(r) => r?,
            Err(e) => {
                if e.is_panic() {
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 89ec81630c1b..739850115370 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -32,7 +32,7 @@ use std::fmt::Debug;
 use std::sync::Arc;
 use tokio::io::{AsyncWrite, AsyncWriteExt};
 use tokio::sync::mpsc::{self, Receiver, Sender};
-use tokio::task::{JoinHandle, JoinSet};
+use tokio::task::JoinSet;
 
 use crate::datasource::file_format::file_compression_type::FileCompressionType;
 use crate::datasource::statistics::{create_max_min_accs, get_col_stats};
@@ -42,6 +42,7 @@ use bytes::{BufMut, BytesMut};
 use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
+use datafusion_physical_plan::common::SpawnedTask;
 use futures::{StreamExt, TryStreamExt};
 use hashbrown::HashMap;
 use object_store::path::Path;
@@ -728,7 +729,7 @@ impl DataSink for ParquetSink {
             }
         }
 
-        match demux_task.await {
+        match demux_task.join().await {
             Ok(r) => r?,
             Err(e) => {
                 if e.is_panic() {
@@ -738,6 +739,7 @@ impl DataSink for ParquetSink {
                 }
             }
         }
+
         Ok(row_count as u64)
     }
 }
@@ -754,8 +756,9 @@ async fn column_serializer_task(
     Ok(writer)
 }
 
-type ColumnJoinHandle = JoinHandle<Result<ArrowColumnWriter>>;
+type ColumnWriterTask = SpawnedTask<Result<ArrowColumnWriter>>;
 type ColSender = Sender<ArrowLeafColumn>;
+
 /// Spawns a parallel serialization task for each column
 /// Returns join handles for each columns serialization task along with a send channel
 /// to send arrow arrays to each serialization task.
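(Note: the parquet hunks above and below swap bare `tokio::spawn` join handles for the `SpawnedTask` wrapper this patch adds in `datafusion/physical-plan/src/common.rs`. Here is a minimal, self-contained sketch of the per-column fan-out pattern being modified — all names are hypothetical stand-ins: `column_task` for `column_serializer_task`, `Vec<u8>` buffers for `ArrowLeafColumn`, and a bare `JoinSet`, which is what `SpawnedTask` wraps:)

```rust
use tokio::sync::mpsc;
use tokio::task::JoinSet;

// One serializer task per column, each fed by its own bounded channel.
async fn column_task(mut rx: mpsc::Receiver<Vec<u8>>) -> usize {
    let mut bytes = 0;
    while let Some(buf) = rx.recv().await {
        bytes += buf.len(); // stand-in for column-level serialization work
    }
    bytes
}

#[tokio::main]
async fn main() {
    let num_columns = 3;
    let max_buffer_size = 8; // bounds how many arrays queue up per column
    let mut tasks = JoinSet::new(); // aborts all remaining tasks on drop
    let mut senders = Vec::with_capacity(num_columns);
    for _ in 0..num_columns {
        let (tx, rx) = mpsc::channel(max_buffer_size);
        senders.push(tx);
        tasks.spawn(column_task(rx));
    }
    for tx in &senders {
        tx.send(vec![0u8; 16]).await.unwrap();
    }
    drop(senders); // close the channels so each task's loop ends
    while let Some(res) = tasks.join_next().await {
        println!("column serialized {} bytes", res.unwrap());
    }
}
```

(Dropping the `JoinSet` — and hence any `SpawnedTask` built on it — aborts still-running column tasks, which is the cancel-safety property motivating this change.)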
@@ -763,23 +766,24 @@ fn spawn_column_parallel_row_group_writer(
     schema: Arc<Schema>,
     parquet_props: Arc<WriterProperties>,
     max_buffer_size: usize,
-) -> Result<(Vec<ColumnJoinHandle>, Vec<ColSender>)> {
+) -> Result<(Vec<ColumnWriterTask>, Vec<ColSender>)> {
     let schema_desc = arrow_to_parquet_schema(&schema)?;
     let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?;
     let num_columns = col_writers.len();
 
-    let mut col_writer_handles = Vec::with_capacity(num_columns);
+    let mut col_writer_tasks = Vec::with_capacity(num_columns);
     let mut col_array_channels = Vec::with_capacity(num_columns);
     for writer in col_writers.into_iter() {
         // Buffer size of this channel limits the number of arrays queued up for column level serialization
         let (send_array, recieve_array) =
             mpsc::channel::<ArrowLeafColumn>(max_buffer_size);
         col_array_channels.push(send_array);
-        col_writer_handles
-            .push(tokio::spawn(column_serializer_task(recieve_array, writer)))
+
+        let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer));
+        col_writer_tasks.push(task);
     }
 
-    Ok((col_writer_handles, col_array_channels))
+    Ok((col_writer_tasks, col_array_channels))
 }
 
 /// Settings related to writing parquet files in parallel
@@ -820,14 +824,14 @@ async fn send_arrays_to_col_writers(
 /// Spawns a tokio task which joins the parallel column writer tasks,
 /// and finalizes the row group
 fn spawn_rg_join_and_finalize_task(
-    column_writer_handles: Vec<JoinHandle<Result<ArrowColumnWriter>>>,
+    column_writer_tasks: Vec<ColumnWriterTask>,
     rg_rows: usize,
-) -> JoinHandle<RBStreamSerializeResult> {
-    tokio::spawn(async move {
-        let num_cols = column_writer_handles.len();
+) -> SpawnedTask<RBStreamSerializeResult> {
+    SpawnedTask::spawn(async move {
+        let num_cols = column_writer_tasks.len();
         let mut finalized_rg = Vec::with_capacity(num_cols);
-        for handle in column_writer_handles.into_iter() {
-            match handle.await {
+        for task in column_writer_tasks.into_iter() {
+            match task.join().await {
                 Ok(r) => {
                     let w = r?;
                     finalized_rg.push(w.close()?);
                 }
@@ -856,12 +860,12 @@
 /// given by n_columns * num_row_groups.
 fn spawn_parquet_parallel_serialization_task(
     mut data: Receiver<RecordBatch>,
-    serialize_tx: Sender<JoinHandle<RBStreamSerializeResult>>,
+    serialize_tx: Sender<SpawnedTask<RBStreamSerializeResult>>,
     schema: Arc<Schema>,
     writer_props: Arc<WriterProperties>,
     parallel_options: ParallelParquetWriterOptions,
-) -> JoinHandle<Result<(), DataFusionError>> {
-    tokio::spawn(async move {
+) -> SpawnedTask<Result<(), DataFusionError>> {
+    SpawnedTask::spawn(async move {
         let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream;
         let max_row_group_rows = writer_props.max_row_group_size();
         let (mut column_writer_handles, mut col_array_channels) =
@@ -931,7 +935,7 @@ fn spawn_parquet_parallel_serialization_task(
 /// Consume RowGroups serialized by other parallel tasks and concatenate them in
 /// to the final parquet file, while flushing finalized bytes to an [ObjectStore]
 async fn concatenate_parallel_row_groups(
-    mut serialize_rx: Receiver<JoinHandle<RBStreamSerializeResult>>,
+    mut serialize_rx: Receiver<SpawnedTask<RBStreamSerializeResult>>,
     schema: Arc<Schema>,
     writer_props: Arc<WriterProperties>,
     mut object_store_writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
@@ -947,9 +951,8 @@ async fn concatenate_parallel_row_groups(
 
     let mut row_count = 0;
 
-    while let Some(handle) = serialize_rx.recv().await {
-        let join_result = handle.await;
-        match join_result {
+    while let Some(task) = serialize_rx.recv().await {
+        match task.join().await {
             Ok(result) => {
                 let mut rg_out = parquet_writer.next_row_group()?;
                 let (serialized_columns, cnt) = result?;
@@ -999,7 +1002,7 @@ async fn output_single_parquet_file_parallelized(
     let max_rowgroups = parallel_options.max_parallel_row_groups;
     // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel
     let (serialize_tx, serialize_rx) =
-        mpsc::channel::<JoinHandle<RBStreamSerializeResult>>(max_rowgroups);
+        mpsc::channel::<SpawnedTask<RBStreamSerializeResult>>(max_rowgroups);
 
     let arc_props = Arc::new(parquet_props.clone());
     let launch_serialization_task = spawn_parquet_parallel_serialization_task(
@@ -1017,7 +1020,7 @@ async fn output_single_parquet_file_parallelized(
     )
     .await?;
 
-    match launch_serialization_task.await {
+    match launch_serialization_task.join().await {
         Ok(Ok(_)) => (),
         Ok(Err(e)) => return Err(e),
         Err(e) => {
@@ -1027,7 +1030,7 @@ async fn output_single_parquet_file_parallelized(
                 unreachable!()
             }
         }
-    };
+    }
 
     Ok(row_count)
 }
diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs
index 8bccf3d71cf9..d70b4811da5b 100644
--- a/datafusion/core/src/datasource/file_format/write/demux.rs
+++ b/datafusion/core/src/datasource/file_format/write/demux.rs
@@ -41,8 +41,8 @@ use object_store::path::Path;
 
 use rand::distributions::DistString;
 
+use datafusion_physical_plan::common::SpawnedTask;
 use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};
-use tokio::task::JoinHandle;
 
 type RecordBatchReceiver = Receiver<RecordBatch>;
 type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;
@@ -76,15 +76,15 @@ pub(crate) fn start_demuxer_task(
     partition_by: Option<Vec<(String, DataType)>>,
     base_output_path: ListingTableUrl,
     file_extension: String,
-) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) {
-    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
+) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
+    let (tx, rx) = mpsc::unbounded_channel();
     let context = context.clone();
     let single_file_output = !base_output_path.is_collection();
-    let task: JoinHandle<Result<()>> = match partition_by {
+    let task = match partition_by {
         Some(parts) => {
             // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot
             // bound this channel without risking a deadlock.
-            tokio::spawn(async move {
+            SpawnedTask::spawn(async move {
                 hive_style_partitions_demuxer(
                     tx,
                     input,
@@ -96,7 +96,7 @@ pub(crate) fn start_demuxer_task(
                 .await
             })
         }
-        None => tokio::spawn(async move {
+        None => SpawnedTask::spawn(async move {
             row_count_demuxer(
                 tx,
                 input,
diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs
index 1a3042cbc00b..05406d3751c9 100644
--- a/datafusion/core/src/datasource/file_format/write/orchestration.rs
+++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -33,10 +33,11 @@ use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError};
 use datafusion_execution::TaskContext;
 
 use bytes::Bytes;
+use datafusion_physical_plan::common::SpawnedTask;
+use futures::try_join;
 use tokio::io::{AsyncWrite, AsyncWriteExt};
 use tokio::sync::mpsc::{self, Receiver};
-use tokio::task::{JoinHandle, JoinSet};
-use tokio::try_join;
+use tokio::task::JoinSet;
 
 type WriterType = AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>;
 type SerializerType = Arc<dyn BatchSerializer>;
@@ -51,14 +52,14 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
     mut writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
 ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> {
     let (tx, mut rx) =
-        mpsc::channel::<JoinHandle<Result<(usize, Bytes), DataFusionError>>>(100);
-    let serialize_task = tokio::spawn(async move {
+        mpsc::channel::<SpawnedTask<Result<(usize, Bytes), DataFusionError>>>(100);
+    let serialize_task = SpawnedTask::spawn(async move {
         // Some serializers (like CSV) handle the first batch differently than
         // subsequent batches, so we track that here.
         let mut initial = true;
         while let Some(batch) = data_rx.recv().await {
             let serializer_clone = serializer.clone();
-            let handle = tokio::spawn(async move {
+            let task = SpawnedTask::spawn(async move {
                 let num_rows = batch.num_rows();
                 let bytes = serializer_clone.serialize(batch, initial)?;
                 Ok((num_rows, bytes))
@@ -66,7 +67,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
             if initial {
                 initial = false;
             }
-            tx.send(handle).await.map_err(|_| {
+            tx.send(task).await.map_err(|_| {
                 internal_datafusion_err!("Unknown error writing to object store")
             })?;
         }
@@ -74,8 +75,8 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
     });
 
     let mut row_count = 0;
-    while let Some(handle) = rx.recv().await {
-        match handle.await {
+    while let Some(task) = rx.recv().await {
+        match task.join().await {
             Ok(Ok((cnt, bytes))) => {
                 match writer.write_all(&bytes).await {
                     Ok(_) => (),
@@ -106,7 +107,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
         }
     }
 
-    match serialize_task.await {
+    match serialize_task.join().await {
         Ok(Ok(_)) => (),
         Ok(Err(e)) => return Err((writer, e)),
         Err(_) => {
@@ -115,7 +116,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
                 internal_datafusion_err!("Unknown error writing to object store"),
             ))
         }
-    };
+    }
 
     Ok((writer, row_count as u64))
 }
@@ -241,9 +242,9 @@ pub(crate) async fn stateless_multipart_put(
         .execution
         .max_buffered_batches_per_output_file;
 
-    let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(rb_buffer_size / 2);
+    let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
     let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
-    let write_coordinater_task = tokio::spawn(async move {
+    let write_coordinator_task = SpawnedTask::spawn(async move {
        stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
     });
     while let Some((location, rb_stream)) = file_stream_rx.recv().await {
@@ -260,10 +261,10 @@ pub(crate) async fn stateless_multipart_put(
         })?;
     }
 
-    // Signal to the write coordinater that no more files are coming
+    // Signal to the write coordinator that no more files are coming
     drop(tx_file_bundle);
 
-    match try_join!(write_coordinater_task, demux_task) {
+    match try_join!(write_coordinator_task.join(), demux_task.join()) {
         Ok((r1, r2)) => {
             r1?;
             r2?;
diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs
index 830cd7a07e46..6dc59e4a5c65 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -29,12 +29,11 @@ use arrow_array::{RecordBatch, RecordBatchReader, RecordBatchWriter};
 use arrow_schema::SchemaRef;
 use async_trait::async_trait;
 use futures::StreamExt;
-use tokio::task::spawn_blocking;
 
 use datafusion_common::{plan_err, Constraints, DataFusionError, Result};
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::{CreateExternalTable, Expr, TableType};
-use datafusion_physical_plan::common::AbortOnDropSingle;
+use datafusion_physical_plan::common::SpawnedTask;
 use datafusion_physical_plan::insert::{DataSink, FileSinkExec};
 use datafusion_physical_plan::metrics::MetricsSet;
 use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
@@ -344,7 +343,7 @@ impl DataSink for StreamWrite {
         let config = self.0.clone();
         let (sender, mut receiver) = tokio::sync::mpsc::channel::<RecordBatch>(2);
         // Note: FIFO Files support poll so this could use AsyncFd
-        let write = AbortOnDropSingle::new(spawn_blocking(move || {
+        let write_task = SpawnedTask::spawn_blocking(move || {
             let mut count = 0_u64;
             let mut writer = config.writer()?;
             while let Some(batch) = receiver.blocking_recv() {
@@ -352,7 +351,7 @@ impl DataSink for StreamWrite {
                 writer.write(&batch)?;
             }
             Ok(count)
-        }));
+        });
 
         while let Some(b) = data.next().await.transpose()? {
             if sender.send(b).await.is_err() {
@@ -360,6 +359,6 @@ impl DataSink for StreamWrite {
             }
         }
         drop(sender);
-        write.await.unwrap()
+        write_task.join().await.unwrap()
     }
 }
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index ffc4a4f717d7..453a00a1a5cf 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -2288,6 +2288,7 @@ mod tests {
     }
 
     #[tokio::test]
+    #[allow(clippy::disallowed_methods)]
     async fn send_context_to_threads() -> Result<()> {
         // ensure SessionContexts can be used in a multi-threaded
         // environment. Usecase is for concurrent planing.
diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs
index 93c7f7368065..c9ad95a3a042 100644
--- a/datafusion/core/tests/fifo.rs
+++ b/datafusion/core/tests/fifo.rs
@@ -103,6 +103,7 @@ mod unix_test {
         let broken_pipe_timeout = Duration::from_secs(10);
         let sa = file_path.clone();
         // Spawn a new thread to write to the FIFO file
+        #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
         spawn_blocking(move || {
             let file = OpenOptions::new().write(true).open(sa).unwrap();
             // Reference time to use when deciding to fail the test
@@ -357,6 +358,7 @@ mod unix_test {
             (sink_fifo_path.clone(), sink_fifo_path.display());
 
         // Spawn a new thread to read sink EXTERNAL TABLE.
+        #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
         tasks.push(spawn_blocking(move || {
             let file = File::open(sink_fifo_path_thread).unwrap();
             let schema = Arc::new(Schema::new(vec![
diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
index df6499e9b1e4..6c9c3359ebf4 100644
--- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
@@ -302,6 +302,7 @@ mod sp_repartition_fuzz_tests {
         let mut handles = Vec::new();
 
         for seed in seed_start..seed_end {
+            #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
             let job = tokio::spawn(run_sort_preserving_repartition_test(
                 make_staggered_batches::<true>(n_row, n_distinct, seed as u64),
                 is_first_roundrobin,
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 609d26c9c253..1cab4d5c2f98 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -123,6 +123,7 @@ async fn window_bounded_window_random_comparison() -> Result<()> {
     for i in 0..n {
         let idx = i % test_cases.len();
         let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone();
+        #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
         let job = tokio::spawn(run_window_test(
             make_staggered_batches::<true>(1000, n_distinct, i as u64),
             i as u64,
diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs
index e83dc2525b9f..5172bc9b2a3c 100644
--- a/datafusion/physical-plan/src/common.rs
+++ b/datafusion/physical-plan/src/common.rs
@@ -21,7 +21,6 @@ use std::fs;
 use std::fs::{metadata, File};
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
-use std::task::{Context, Poll};
 
 use super::SendableRecordBatchStream;
 use crate::stream::RecordBatchReceiverStream;
@@ -39,8 +38,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 
 use futures::{Future, StreamExt, TryStreamExt};
 use parking_lot::Mutex;
-use pin_project_lite::pin_project;
-use tokio::task::JoinHandle;
+use tokio::task::{JoinError, JoinSet};
 
 /// [`MemoryReservation`] used across query execution streams
 pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
@@ -174,50 +172,43 @@ pub fn compute_record_batch_statistics(
     }
 }
 
-pin_project! {
-    /// Helper that aborts the given join handle on drop.
-    ///
-    /// Useful to kill background tasks when the consumer is dropped.
-    #[derive(Debug)]
-    pub struct AbortOnDropSingle<T> {
-        #[pin]
-        join_handle: JoinHandle<T>,
-    }
-
-    impl<T> PinnedDrop for AbortOnDropSingle<T> {
-        fn drop(this: Pin<&mut Self>) {
-            this.join_handle.abort();
-        }
-    }
+/// Helper that provides a simple API to spawn a single task and join it.
+/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
+///
+/// Technically, it's just a wrapper around `JoinSet` (with size = 1).
+#[derive(Debug)]
+pub struct SpawnedTask<R> {
+    inner: JoinSet<R>,
 }
 
-impl<T> AbortOnDropSingle<T> {
-    /// Create new abort helper from join handle.
-    pub fn new(join_handle: JoinHandle<T>) -> Self {
-        Self { join_handle }
+impl<R: 'static> SpawnedTask<R> {
+    pub fn spawn<T>(task: T) -> Self
+    where
+        T: Future<Output = R>,
+        T: Send + 'static,
+        R: Send,
+    {
+        let mut inner = JoinSet::new();
+        inner.spawn(task);
+        Self { inner }
     }
-}
 
-impl<T> Future for AbortOnDropSingle<T> {
-    type Output = Result<T, tokio::task::JoinError>;
-
-    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let this = self.project();
-        this.join_handle.poll(cx)
+    pub fn spawn_blocking<T>(task: T) -> Self
+    where
+        T: FnOnce() -> R,
+        T: Send + 'static,
+        R: Send,
+    {
+        let mut inner = JoinSet::new();
+        inner.spawn_blocking(task);
+        Self { inner }
     }
-}
 
-/// Helper that aborts the given join handles on drop.
-///
-/// Useful to kill background tasks when the consumer is dropped.
-#[derive(Debug)]
-pub struct AbortOnDropMany<T>(pub Vec<JoinHandle<T>>);
-
-impl<T> Drop for AbortOnDropMany<T> {
-    fn drop(&mut self) {
-        for join_handle in &self.0 {
-            join_handle.abort();
-        }
+    pub async fn join(mut self) -> Result<R, JoinError> {
+        self.inner
+            .join_next()
+            .await
+            .expect("`SpawnedTask` instance always contains exactly 1 task")
     }
 }
diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs
index 1c4a6ac0ecaf..562e42a7da3b 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -298,14 +298,14 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
     /// "abort" such tasks, they may continue to consume resources even after
     /// the plan is dropped, generating intermediate results that are never
     /// used.
+    /// Thus, [`spawn`] is disallowed; use [`SpawnedTask`] instead.
     ///
-    /// See [`AbortOnDropSingle`], [`AbortOnDropMany`] and
-    /// [`RecordBatchReceiverStreamBuilder`] for structures to help ensure all
-    /// background tasks are cancelled.
+    /// See [`SpawnedTask`], [`JoinSet`] and [`RecordBatchReceiverStreamBuilder`]
+    /// for structures to help ensure all background tasks are cancelled.
     ///
     /// [`spawn`]: tokio::task::spawn
-    /// [`AbortOnDropSingle`]: crate::common::AbortOnDropSingle
-    /// [`AbortOnDropMany`]: crate::common::AbortOnDropMany
+    /// [`JoinSet`]: tokio::task::JoinSet
+    /// [`SpawnedTask`]: crate::common::SpawnedTask
     /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder
     ///
     /// # Implementation Examples
diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index 07693f747fee..a66a929796ab 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -32,21 +32,20 @@ use futures::{FutureExt, StreamExt};
 use hashbrown::HashMap;
 use log::trace;
 use parking_lot::Mutex;
-use tokio::task::JoinHandle;
 
 use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result};
 use datafusion_execution::memory_pool::MemoryConsumer;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
 
-use crate::common::transpose;
+use crate::common::{transpose, SpawnedTask};
 use crate::hash_utils::create_hashes;
 use crate::metrics::BaselineMetrics;
 use crate::repartition::distributor_channels::{channels, partition_aware_channels};
 use crate::sorts::streaming_merge;
 use crate::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics};
 
-use super::common::{AbortOnDropMany, AbortOnDropSingle, SharedMemoryReservation};
+use super::common::SharedMemoryReservation;
 use super::expressions::PhysicalSortExpr;
 use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream};
@@ -74,7 +73,7 @@ struct RepartitionExecState {
     >,
 
     /// Helper that ensures that that background job is killed once it is no longer needed.
-    abort_helper: Arc<AbortOnDropMany<()>>,
+    abort_helper: Arc<Vec<SpawnedTask<()>>>,
 }
 
 /// A utility that can be used to partition batches based on [`Partitioning`]
@@ -522,7 +521,7 @@ impl ExecutionPlan for RepartitionExec {
         }
 
         // launch one async task per *input* partition
-        let mut join_handles = Vec::with_capacity(num_input_partitions);
+        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
         for i in 0..num_input_partitions {
             let txs: HashMap<_, _> = state
                 .channels
@@ -534,28 +533,27 @@ impl ExecutionPlan for RepartitionExec {
 
             let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
 
-            let input_task: JoinHandle<Result<()>> =
-                tokio::spawn(Self::pull_from_input(
-                    self.input.clone(),
-                    i,
-                    txs.clone(),
-                    self.partitioning.clone(),
-                    r_metrics,
-                    context.clone(),
-                ));
+            let input_task = SpawnedTask::spawn(Self::pull_from_input(
+                self.input.clone(),
+                i,
+                txs.clone(),
+                self.partitioning.clone(),
+                r_metrics,
+                context.clone(),
+            ));
 
             // In a separate task, wait for each input to be done
             // (and pass along any errors, including panic!s)
-            let join_handle = tokio::spawn(Self::wait_for_task(
-                AbortOnDropSingle::new(input_task),
+            let wait_for_task = SpawnedTask::spawn(Self::wait_for_task(
+                input_task,
                 txs.into_iter()
                     .map(|(partition, (tx, _reservation))| (partition, tx))
                     .collect(),
             ));
-            join_handles.push(join_handle);
+            spawned_tasks.push(wait_for_task);
         }
 
-        state.abort_helper = Arc::new(AbortOnDropMany(join_handles))
+        state.abort_helper = Arc::new(spawned_tasks)
     }
 
     trace!(
@@ -638,7 +636,7 @@ impl RepartitionExec {
             partitioning,
             state: Arc::new(Mutex::new(RepartitionExecState {
                 channels: HashMap::new(),
-                abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
+                abort_helper: Arc::new(Vec::new()),
             })),
             metrics: ExecutionPlanMetricsSet::new(),
             preserve_order: false,
@@ -759,12 +757,13 @@ impl RepartitionExec {
     /// complete. Upon error, propagates the errors to all output tx
     /// channels.
     async fn wait_for_task(
-        input_task: AbortOnDropSingle<Result<()>>,
+        input_task: SpawnedTask<Result<()>>,
         txs: HashMap<usize, DistributionSender<MaybeBatch>>,
     ) {
         // wait for completion, and propagate error
         // note we ignore errors on send (.ok) as that means the receiver has already shutdown.
-        match input_task.await {
+
+        match input_task.join().await {
             // Error in joining task
             Err(e) => {
                 let e = Arc::new(e);
@@ -813,7 +812,7 @@ struct RepartitionStream {
 
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
-    drop_helper: Arc<AbortOnDropMany<()>>,
+    drop_helper: Arc<Vec<SpawnedTask<()>>>,
 
     /// Memory reservation.
     reservation: SharedMemoryReservation,
@@ -877,7 +876,7 @@ struct PerPartitionStream {
 
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
-    drop_helper: Arc<AbortOnDropMany<()>>,
+    drop_helper: Arc<Vec<SpawnedTask<()>>>,
 
     /// Memory reservation.
     reservation: SharedMemoryReservation,
@@ -1056,6 +1055,7 @@ mod tests {
     }
 
     #[tokio::test]
+    #[allow(clippy::disallowed_methods)]
     async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
         let join_handle: JoinHandle<Result<Vec<Vec<RecordBatch>>>> =
             tokio::spawn(async move {
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 2d8237011fff..84bf3ec415ef 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -27,7 +27,7 @@ use std::io::BufReader;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 
-use crate::common::{spawn_buffered, IPCWriter};
+use crate::common::{spawn_buffered, IPCWriter, SpawnedTask};
 use crate::expressions::PhysicalSortExpr;
 use crate::metrics::{
     BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -56,7 +56,6 @@ use datafusion_physical_expr::EquivalenceProperties;
 use futures::{StreamExt, TryStreamExt};
 use log::{debug, error, trace};
 use tokio::sync::mpsc::Sender;
-use tokio::task;
 
 struct ExternalSorterMetrics {
     /// metrics
@@ -604,8 +603,8 @@ async fn spill_sorted_batches(
     schema: SchemaRef,
 ) -> Result<()> {
     let path: PathBuf = path.into();
-    let handle = task::spawn_blocking(move || write_sorted(batches, path, schema));
-    match handle.await {
+    let task = SpawnedTask::spawn_blocking(move || write_sorted(batches, path, schema));
+    match task.join().await {
         Ok(r) => r,
         Err(e) => exec_err!("Error occurred while spilling {e}"),
     }
diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs
index ffae144eae84..41c33deec643 100644
--- a/datafusion/sqllogictest/bin/sqllogictests.rs
+++ b/datafusion/sqllogictest/bin/sqllogictests.rs
@@ -88,6 +88,7 @@ async fn run_tests() -> Result<()> {
     // modifying shared state like `/tmp/`)
     let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?)
         .map(|test_file| {
+            #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
             tokio::task::spawn(async move {
                 println!("Running {:?}", test_file.relative_path);
                 if options.complete {
diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt
index f99d6e15e869..ce5635b6daf4 100644
--- a/dev/release/rat_exclude_files.txt
+++ b/dev/release/rat_exclude_files.txt
@@ -136,4 +136,5 @@ datafusion/proto/src/generated/prost.rs
 .github/ISSUE_TEMPLATE/feature_request.yml
 .github/workflows/docs.yaml
 **/node_modules/*
-datafusion/wasmtest/pkg/*
\ No newline at end of file
+datafusion/wasmtest/pkg/*
+clippy.toml
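(Note: the core of this patch is the `SpawnedTask` type added to `datafusion/physical-plan/src/common.rs`. The sketch below condenses that type, with generics written out, and adds a hypothetical `main` — not part of the patch — to illustrate the contract the new clippy rule enforces: `join()` behaves like awaiting a `JoinHandle`, while dropping the task aborts it rather than leaking it.)

```rust
use std::time::Duration;
use tokio::task::{JoinError, JoinSet};

/// Condensed copy of the `SpawnedTask` from `common.rs` above.
struct SpawnedTask<R> {
    inner: JoinSet<R>,
}

impl<R: 'static> SpawnedTask<R> {
    fn spawn<T>(task: T) -> Self
    where
        T: std::future::Future<Output = R> + Send + 'static,
        R: Send,
    {
        let mut inner = JoinSet::new();
        inner.spawn(task);
        Self { inner }
    }

    async fn join(mut self) -> Result<R, JoinError> {
        self.inner
            .join_next()
            .await
            .expect("`SpawnedTask` instance always contains exactly 1 task")
    }
}

#[tokio::main]
async fn main() {
    // Joining surfaces the task's output, like awaiting a JoinHandle.
    let task = SpawnedTask::spawn(async { 21 * 2 });
    assert_eq!(task.join().await.unwrap(), 42);

    // Cancel-safety: dropping the SpawnedTask aborts the background task,
    // because JoinSet aborts its tasks on drop. A bare tokio::spawn would
    // leave the task running detached.
    let task = SpawnedTask::spawn(async {
        tokio::time::sleep(Duration::from_secs(3600)).await;
    });
    drop(task);
    tokio::time::sleep(Duration::from_millis(10)).await; // let the abort land
}
```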