refactor: add join_unwind to SpawnedTask (apache#9422)
* refactor: add `join_unwind` to `SpawnedTask`

To remove the duplication of these panic handlers across call sites, it seems logical to have such a method.

* I considered adding this logic to `join` itself, but some call sites apply additional logic to the join result, so the plain method has to stay available.

* docs: improve join_unwind comments

* docs: improve join_unwind comments
DDtKey authored Mar 4, 2024
1 parent 22255c2 commit 581fd98
Showing 6 changed files with 44 additions and 84 deletions.
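The net effect at each call site, visible throughout the diffs below, is that this repeated handler:

    match task.join().await {
        Ok(r) => r?,
        Err(e) => {
            if e.is_panic() {
                std::panic::resume_unwind(e.into_panic());
            } else {
                unreachable!();
            }
        }
    }

collapses to a single line:

    task.join_unwind().await?;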
2 changes: 1 addition & 1 deletion clippy.toml
@@ -1,6 +1,6 @@
 disallowed-methods = [
     { path = "tokio::task::spawn", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
-    { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
+    { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn_blocking` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
 ]

 disallowed-types = [
17 changes: 17 additions & 0 deletions datafusion/common_runtime/src/common.rs
@@ -51,10 +51,27 @@ impl<R: 'static> SpawnedTask<R> {
         Self { inner }
     }

+    /// Joins the task, returning the result of join (`Result<R, JoinError>`).
     pub async fn join(mut self) -> Result<R, JoinError> {
         self.inner
             .join_next()
             .await
             .expect("`SpawnedTask` instance always contains exactly 1 task")
     }
+
+    /// Joins the task and unwinds the panic if it happens.
+    pub async fn join_unwind(self) -> R {
+        self.join().await.unwrap_or_else(|e| {
+            // `JoinError` can be caused either by panic or cancellation. We have to handle panics:
+            if e.is_panic() {
+                std::panic::resume_unwind(e.into_panic());
+            } else {
+                // Cancellation may be caused by two reasons:
+                // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside).
+                // 2. The runtime is shutting down.
+                // So we consider this branch as unreachable.
+                unreachable!("SpawnedTask was cancelled unexpectedly");
+            }
+        })
+    }
 }
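For illustration, a minimal usage sketch of the new method (not part of the commit; assumes a small binary with `tokio`, `futures`, and the `datafusion-common-runtime` crate as dependencies):

    use datafusion_common_runtime::SpawnedTask;
    use futures::FutureExt;

    #[tokio::main]
    async fn main() {
        // On success, `join_unwind` yields the task's output directly.
        let task = SpawnedTask::spawn(async { 21 * 2 });
        assert_eq!(task.join_unwind().await, 42);

        // On panic, `join_unwind` re-raises the panic on the caller instead of
        // returning a `JoinError`; we catch it here only to demonstrate that.
        let task = SpawnedTask::spawn(async { panic!("boom") });
        let caught = std::panic::AssertUnwindSafe(task.join_unwind())
            .catch_unwind()
            .await;
        assert!(caught.is_err());
    }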
11 changes: 1 addition & 10 deletions datafusion/core/src/datasource/file_format/arrow.rs
@@ -295,16 +295,7 @@ impl DataSink for ArrowFileSink {
             }
         }

-        match demux_task.join().await {
-            Ok(r) => r?,
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
-            }
-        }
+        demux_task.join_unwind().await?;
         Ok(row_count as u64)
     }
 }
75 changes: 17 additions & 58 deletions datafusion/core/src/datasource/file_format/parquet.rs
@@ -729,16 +729,7 @@ impl DataSink for ParquetSink {
             }
         }

-        match demux_task.join().await {
-            Ok(r) => r?,
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
-            }
-        }
+        demux_task.join_unwind().await?;

         Ok(row_count as u64)
     }
@@ -831,19 +822,8 @@ fn spawn_rg_join_and_finalize_task(
         let num_cols = column_writer_tasks.len();
         let mut finalized_rg = Vec::with_capacity(num_cols);
         for task in column_writer_tasks.into_iter() {
-            match task.join().await {
-                Ok(r) => {
-                    let w = r?;
-                    finalized_rg.push(w.close()?);
-                }
-                Err(e) => {
-                    if e.is_panic() {
-                        std::panic::resume_unwind(e.into_panic())
-                    } else {
-                        unreachable!()
-                    }
-                }
-            }
+            let writer = task.join_unwind().await?;
+            finalized_rg.push(writer.close()?);
         }

         Ok((finalized_rg, rg_rows))
@@ -952,31 +932,21 @@ async fn concatenate_parallel_row_groups(
     let mut row_count = 0;

     while let Some(task) = serialize_rx.recv().await {
-        match task.join().await {
-            Ok(result) => {
-                let mut rg_out = parquet_writer.next_row_group()?;
-                let (serialized_columns, cnt) = result?;
-                row_count += cnt;
-                for chunk in serialized_columns {
-                    chunk.append_to_row_group(&mut rg_out)?;
-                    let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
-                    if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
-                        object_store_writer
-                            .write_all(buff_to_flush.as_slice())
-                            .await?;
-                        buff_to_flush.clear();
-                    }
-                }
-                rg_out.close()?;
-            }
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
-            }
-        }
+        let result = task.join_unwind().await;
+        let mut rg_out = parquet_writer.next_row_group()?;
+        let (serialized_columns, cnt) = result?;
+        row_count += cnt;
+        for chunk in serialized_columns {
+            chunk.append_to_row_group(&mut rg_out)?;
+            let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
+            if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
+                object_store_writer
+                    .write_all(buff_to_flush.as_slice())
+                    .await?;
+                buff_to_flush.clear();
+            }
+        }
+        rg_out.close()?;
     }

     let inner_writer = parquet_writer.into_inner()?;
@@ -1020,18 +990,7 @@ async fn output_single_parquet_file_parallelized(
     )
     .await?;

-    match launch_serialization_task.join().await {
-        Ok(Ok(_)) => (),
-        Ok(Err(e)) => return Err(e),
-        Err(e) => {
-            if e.is_panic() {
-                std::panic::resume_unwind(e.into_panic())
-            } else {
-                unreachable!()
-            }
-        }
-    }
-
+    launch_serialization_task.join_unwind().await?;
     Ok(row_count)
 }

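For background, a standalone sketch (not from this commit; assumes only the `tokio` crate) of the mechanism `join_unwind` builds on: Tokio captures a panic inside a spawned task as a `JoinError`, whose original payload can be recovered with `into_panic` and re-raised with `std::panic::resume_unwind`:

    use tokio::task;

    #[tokio::main]
    async fn main() {
        // Note: inside DataFusion this raw `tokio::task::spawn` is disallowed
        // by the clippy.toml above; it is used here only to show the mechanism.
        let handle = task::spawn(async { panic!("worker failed") });
        // The task's panic is reported as a `JoinError` rather than unwinding here.
        let err = handle.await.expect_err("the task panicked");
        assert!(err.is_panic());
        // `into_panic` recovers the original payload (here a `&'static str`),
        // which `resume_unwind` could re-raise on the current task.
        let payload = err.into_panic();
        assert_eq!(*payload.downcast_ref::<&str>().unwrap(), "worker failed");
    }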
21 changes: 7 additions & 14 deletions datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -34,7 +34,7 @@ use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::TaskContext;

 use bytes::Bytes;
-use futures::try_join;
+use futures::join;
 use tokio::io::{AsyncWrite, AsyncWriteExt};
 use tokio::sync::mpsc::{self, Receiver};
 use tokio::task::JoinSet;
@@ -264,19 +264,12 @@ pub(crate) async fn stateless_multipart_put(
     // Signal to the write coordinator that no more files are coming
     drop(tx_file_bundle);

-    match try_join!(write_coordinator_task.join(), demux_task.join()) {
-        Ok((r1, r2)) => {
-            r1?;
-            r2?;
-        }
-        Err(e) => {
-            if e.is_panic() {
-                std::panic::resume_unwind(e.into_panic());
-            } else {
-                unreachable!();
-            }
-        }
-    }
+    let (r1, r2) = join!(
+        write_coordinator_task.join_unwind(),
+        demux_task.join_unwind()
+    );
+    r1?;
+    r2?;

     let total_count = rx_row_cnt.await.map_err(|_| {
         internal_datafusion_err!("Did not receive row count from write coordinator")
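A note on the switch from `try_join!` to `join!` above: `join_unwind` returns the task's own output `R` (here a `Result`) rather than `Result<R, JoinError>`, so there is no outer `Result` for `try_join!` to short-circuit on. `join!` drives both futures to completion, and the inner results are checked afterwards with `?`. A standalone sketch of the pattern (hypothetical functions, assumes the `futures` and `tokio` crates):

    use futures::join;

    async fn coordinator() -> Result<(), String> {
        Ok(())
    }

    async fn demux() -> Result<(), String> {
        Err("demux failed".to_string())
    }

    #[tokio::main]
    async fn main() -> Result<(), String> {
        // Both futures run to completion before either error is inspected.
        let (r1, r2) = join!(coordinator(), demux());
        r1?;
        r2?;
        Ok(())
    }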
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/stream.rs
@@ -359,6 +359,6 @@ impl DataSink for StreamWrite {
             }
         }
         drop(sender);
-        write_task.join().await.unwrap()
+        write_task.join_unwind().await
     }
 }
