diff --git a/python/src/utils.rs b/python/src/utils.rs index 3b0a0b0041..b063b64d08 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -15,9 +15,10 @@ pub fn rt() -> &'static Runtime { Some(pid) if pid == &std::process::id() => {} // Reuse the static runtime. Some(pid) => { panic!( - "Forked process detected - current PID is {} but the tokio runtime was by {}. The tokio runtime - does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. If you are seeing this - message while using Python multithreading make sure to use the `spawn` or `forkserver` mode.", + "Forked process detected - current PID is {} but the tokio runtime was created by {}. The tokio \ + runtime does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. If you are \ + seeing this message while using Python multithreading make sure to use the `spawn` or `forkserver` \ + mode.", pid, std::process::id() ); } diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index 6ab5030c8a..cc36fc0274 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -1,10 +1,9 @@ -from itertools import product import os from datetime import date, datetime, timezone from pathlib import Path from random import random from threading import Barrier, Thread -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Type from unittest.mock import Mock from deltalake._util import encode_partition_value @@ -19,14 +18,15 @@ else: _has_pandas = True +import multiprocessing +from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor + import pyarrow as pa import pyarrow.dataset as ds import pytest from pyarrow.dataset import ParquetReadOptions from pyarrow.fs import LocalFileSystem, SubTreeFileSystem -import multiprocessing -import threading from deltalake import DeltaTable @@ -59,22 +59,51 @@ def test_read_simple_table_to_dict(): dt = DeltaTable(table_path) assert 
dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]} -def recursively_read_simple_table(thread_or_process_class, depth): - print(thread_or_process_class, depth) - test_read_simple_table_to_dict() + +class _SerializableException(BaseException): + pass + + +def _recursively_read_simple_table(executor_class: Type[Executor], depth): + try: + test_read_simple_table_to_dict() + except BaseException as e: # Ideally this would catch `pyo3_runtime.PanicException` but it seems that is not possible. + # Re-raise as something that can be serialized and therefore sent back to parent processes. + raise _SerializableException(f"Seraializatble exception: {e}") from e + if depth == 0: return - - process_or_thread = thread_or_process_class(target=recursively_read_simple_table, args=(thread_or_process_class, depth - 1)) - process_or_thread.start() - process_or_thread.join() + # We use concurrent.futures.Executors instead of `threading.Thread` or `multiprocessing.Process` so that errors + # are re-raised in the parent process/thread when we call `future.result()`. 
+ with executor_class(max_workers=1) as executor: + future = executor.submit( + _recursively_read_simple_table, executor_class, depth - 1 + ) + future.result() -@pytest.mark.parametrize("thread_or_process_class, multiprocessing_start_method", [(threading.Thread, None), (multiprocessing.Process, "forkserver"), (multiprocessing.Process, "spawn"), (multiprocessing.Process, "fork")]) -def test_read_simple_in_threads_and_processes(thread_or_process_class, multiprocessing_start_method): +@pytest.mark.parametrize( + "executor_class,multiprocessing_start_method,expect_panic", + [ + (ThreadPoolExecutor, None, False), + (ProcessPoolExecutor, "forkserver", False), + (ProcessPoolExecutor, "spawn", False), + (ProcessPoolExecutor, "fork", True), + ], +) +def test_read_simple_in_threads_and_processes( + executor_class, multiprocessing_start_method, expect_panic +): if multiprocessing_start_method is not None: multiprocessing.set_start_method(multiprocessing_start_method, force=True) - recursively_read_simple_table(thread_or_process_class=thread_or_process_class, depth=10) + if expect_panic: + with pytest.raises( + _SerializableException, + match="The tokio runtime does not support forked processes", + ): + _recursively_read_simple_table(executor_class=executor_class, depth=5) + else: + _recursively_read_simple_table(executor_class=executor_class, depth=5) def test_read_simple_table_by_version_to_dict():