diff --git a/clippy.toml b/clippy.toml
index 6eb9906c89cf..62d8263085df 100644
--- a/clippy.toml
+++ b/clippy.toml
@@ -1,6 +1,6 @@
 disallowed-methods = [
     { path = "tokio::task::spawn", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
-    { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
+    { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn_blocking` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
 ]
 
 disallowed-types = [
diff --git a/datafusion/common_runtime/src/common.rs b/datafusion/common_runtime/src/common.rs
index 88b74448c7a8..2f7ddb972f42 100644
--- a/datafusion/common_runtime/src/common.rs
+++ b/datafusion/common_runtime/src/common.rs
@@ -51,10 +51,27 @@ impl<R: 'static> SpawnedTask<R> {
         Self { inner }
     }
 
+    /// Joins the task, returning the result of join (`Result<R, JoinError>`).
     pub async fn join(mut self) -> Result<R, JoinError> {
         self.inner
             .join_next()
             .await
             .expect("`SpawnedTask` instance always contains exactly 1 task")
     }
+
+    /// Joins the task and unwinds the panic if it happens.
+    pub async fn join_unwind(self) -> R {
+        self.join().await.unwrap_or_else(|e| {
+            // `JoinError` can be caused either by panic or cancellation. We have to handle panics:
+            if e.is_panic() {
+                std::panic::resume_unwind(e.into_panic());
+            } else {
+                // Cancellation may be caused by two reasons:
+                // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside).
+                // 2. The runtime is shutting down.
+                // So we consider this branch as unreachable.
+                unreachable!("SpawnedTask was cancelled unexpectedly");
+            }
+        })
+    }
 }
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs
index d5f07d11bee9..90417a978137 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -295,16 +295,7 @@ impl DataSink for ArrowFileSink {
             }
         }
 
-        match demux_task.join().await {
-            Ok(r) => r?,
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
-            }
-        }
+        demux_task.join_unwind().await?;
         Ok(row_count as u64)
     }
 }
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 4ea6c2a273f1..3824177cb363 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -729,16 +729,7 @@ impl DataSink for ParquetSink {
             }
         }
 
-        match demux_task.join().await {
-            Ok(r) => r?,
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
-            }
-        }
+        demux_task.join_unwind().await?;
 
         Ok(row_count as u64)
     }
@@ -831,19 +822,8 @@ fn spawn_rg_join_and_finalize_task(
         let num_cols = column_writer_tasks.len();
         let mut finalized_rg = Vec::with_capacity(num_cols);
         for task in column_writer_tasks.into_iter() {
-            match task.join().await {
-                Ok(r) => {
-                    let w = r?;
-                    finalized_rg.push(w.close()?);
-                }
-                Err(e) => {
-                    if e.is_panic() {
-                        std::panic::resume_unwind(e.into_panic())
-                    } else {
-                        unreachable!()
-                    }
-                }
-            }
+            let writer = task.join_unwind().await?;
+            finalized_rg.push(writer.close()?);
         }
 
         Ok((finalized_rg, rg_rows))
@@ -952,31 +932,21 @@ async fn concatenate_parallel_row_groups(
     let mut row_count = 0;
 
     while let Some(task) = serialize_rx.recv().await {
-        match task.join().await {
-            Ok(result) => {
-                let mut rg_out = parquet_writer.next_row_group()?;
-                let (serialized_columns, cnt) = result?;
-                row_count += cnt;
-                for chunk in serialized_columns {
-                    chunk.append_to_row_group(&mut rg_out)?;
-                    let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
-                    if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
-                        object_store_writer
-                            .write_all(buff_to_flush.as_slice())
-                            .await?;
-                        buff_to_flush.clear();
-                    }
-                }
-                rg_out.close()?;
-            }
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
+        let result = task.join_unwind().await;
+        let mut rg_out = parquet_writer.next_row_group()?;
+        let (serialized_columns, cnt) = result?;
+        row_count += cnt;
+        for chunk in serialized_columns {
+            chunk.append_to_row_group(&mut rg_out)?;
+            let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
+            if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
+                object_store_writer
+                    .write_all(buff_to_flush.as_slice())
+                    .await?;
+                buff_to_flush.clear();
             }
         }
+        rg_out.close()?;
     }
 
     let inner_writer = parquet_writer.into_inner()?;
@@ -1020,18 +990,7 @@ async fn output_single_parquet_file_parallelized(
     )
     .await?;
 
-    match launch_serialization_task.join().await {
-        Ok(Ok(_)) => (),
-        Ok(Err(e)) => return Err(e),
-        Err(e) => {
-            if e.is_panic() {
-                std::panic::resume_unwind(e.into_panic())
-            } else {
-                unreachable!()
-            }
-        }
-    }
-
+    launch_serialization_task.join_unwind().await?;
     Ok(row_count)
 }
 
diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs
index dd0e5ce6a40e..b7f268959311 100644
--- a/datafusion/core/src/datasource/file_format/write/orchestration.rs
+++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -34,7 +34,7 @@ use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::TaskContext;
 
 use bytes::Bytes;
-use futures::try_join;
+use futures::join;
 use tokio::io::{AsyncWrite, AsyncWriteExt};
 use tokio::sync::mpsc::{self, Receiver};
 use tokio::task::JoinSet;
@@ -264,19 +264,12 @@ pub(crate) async fn stateless_multipart_put(
     // Signal to the write coordinator that no more files are coming
     drop(tx_file_bundle);
 
-    match try_join!(write_coordinator_task.join(), demux_task.join()) {
-        Ok((r1, r2)) => {
-            r1?;
-            r2?;
-        }
-        Err(e) => {
-            if e.is_panic() {
-                std::panic::resume_unwind(e.into_panic());
-            } else {
-                unreachable!();
-            }
-        }
-    }
+    let (r1, r2) = join!(
+        write_coordinator_task.join_unwind(),
+        demux_task.join_unwind()
+    );
+    r1?;
+    r2?;
 
     let total_count = rx_row_cnt.await.map_err(|_| {
         internal_datafusion_err!("Did not receieve row count from write coordinater")
diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs
index 0d91b1cba34d..079c1a891d14 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -359,6 +359,6 @@ impl DataSink for StreamWrite {
             }
         }
         drop(sender);
-        write_task.join().await.unwrap()
+        write_task.join_unwind().await
     }
 }