diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66560c4..4fc2454 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- `get_dask_client` and `get_async_dask_client`, allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33)
+
 ### Changed
 
 ### Deprecated
diff --git a/README.md b/README.md
index bea610d..84bbe73 100644
--- a/README.md
+++ b/README.md
@@ -107,6 +107,80 @@ DaskTaskRunner(
 )
 ```
 
+### Distributing Dask collections across workers
+
+To distribute work across workers and achieve parallel computations with a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, use one of the context managers `get_dask_client` or `get_async_dask_client`:
+
+```python
+import dask
+from prefect import flow, task
+from prefect_dask import DaskTaskRunner, get_dask_client
+
+@task
+def compute_task():
+    with get_dask_client() as client:
+        df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
+        summary_df = df.describe().compute()
+    return summary_df
+
+@flow(task_runner=DaskTaskRunner())
+def dask_flow():
+    prefect_future = compute_task.submit()
+    return prefect_future.result()
+
+dask_flow()
+```
+
+The context managers can be used the same way in both `flow` run contexts and `task` run contexts.
+
+!!! warning "Resolving futures in sync client"
+    Note that, by default, `dask_collection.compute()` returns concrete values, while `client.compute(dask_collection)` returns Dask Futures. Therefore, if you call `client.compute`, you must resolve all futures before exiting the context manager by either:
+
+    1. setting `sync=True`
+    ```python
+    with get_dask_client() as client:
+        df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
+        summary_df = client.compute(df.describe(), sync=True)
+    ```
+
+    2. calling `result()`
+    ```python
+    with get_dask_client() as client:
+        df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
+        summary_df = client.compute(df.describe()).result()
+    ```
+
+    For more information, visit the docs on [Waiting on Futures](https://docs.dask.org/en/stable/futures.html#waiting-on-futures).
+
+There is also an equivalent context manager for asynchronous tasks and flows: `get_async_dask_client`.
+
+```python
+import asyncio
+
+import dask
+from prefect import flow, task
+from prefect_dask import DaskTaskRunner, get_async_dask_client
+
+@task
+async def compute_task():
+    async with get_async_dask_client() as client:
+        df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
+        summary_df = await client.compute(df.describe())
+    return summary_df
+
+@flow(task_runner=DaskTaskRunner())
+async def dask_flow():
+    prefect_future = await compute_task.submit()
+    return await prefect_future.result()
+
+asyncio.run(dask_flow())
+```
+
+!!! warning "Resolving futures in async client"
+    With the async client, you do not need to set `sync=True` or call `result()`.
+
+    However, you must `await client.compute(dask_collection)` before exiting the context manager.
+
+    To invoke `compute` from the Dask collection instead, set `sync=False` and resolve the resulting future before exiting the context manager: `await dask_collection.compute(sync=False)`.
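+
+    For example, a sketch of this `sync=False` pattern, mirroring the sync examples above:
+    ```python
+    async with get_async_dask_client() as client:
+        df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
+        summary_df = await df.describe().compute(sync=False)
+    ```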
+
 ### Using a temporary cluster
 
 The `DaskTaskRunner` is capable of creating a temporary cluster using any of [Dask's cluster-manager options](https://docs.dask.org/en/latest/setup.html).
 This can be useful when you want each flow run to have its own Dask cluster, allowing for per-flow adaptive scaling.
diff --git a/docs/utils.md b/docs/utils.md
new file mode 100644
index 0000000..1e9fc7b
--- /dev/null
+++ b/docs/utils.md
@@ -0,0 +1 @@
+::: prefect_dask.utils
diff --git a/mkdocs.yml b/mkdocs.yml
index 23e8476..7700e56 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -47,3 +47,5 @@ plugins:
 nav:
     - Home: index.md
     - Task Runners: task_runners.md
+    - Utils: utils.md
+
diff --git a/prefect_dask/__init__.py b/prefect_dask/__init__.py
index bfd452d..e01e06a 100644
--- a/prefect_dask/__init__.py
+++ b/prefect_dask/__init__.py
@@ -1,4 +1,5 @@
 from . import _version
 from .task_runners import DaskTaskRunner  # noqa
+from .utils import get_dask_client, get_async_dask_client  # noqa
 
 __version__ = _version.get_versions()["version"]
diff --git a/prefect_dask/task_runners.py b/prefect_dask/task_runners.py
index e5ac80b..fbe0b1b 100644
--- a/prefect_dask/task_runners.py
+++ b/prefect_dask/task_runners.py
@@ -93,12 +93,15 @@ class DaskTaskRunner(BaseTaskRunner):
     different cluster class (e.g.
     [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can specify
     `cluster_class`/`cluster_kwargs`.
 
+    Alternatively, if you already have a dask cluster running, you can provide
+    the address of the scheduler via the `address` kwarg.
+
     !!! warning "Multiprocessing safety"
         Note that, because the `DaskTaskRunner` uses multiprocessing, calls to flows
         in scripts must be guarded with `if __name__ == "__main__":` or warnings will
         be displayed.
 
     Args:
         address (string, optional): Address of a currently running dask
             scheduler; if one is not provided, a temporary cluster will be
@@ -113,24 +116,33 @@ class
             name (e.g. `"distributed.LocalCluster"`), or the class itself.
             is only enabled if `adapt_kwargs` are provided.
         client_kwargs (dict, optional): Additional kwargs to use when creating a
             [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).
+
     Examples:
         Using a temporary local dask cluster:
-            >>> from prefect import flow
-            >>> from prefect_dask.task_runners import DaskTaskRunner
-            >>> @flow(task_runner=DaskTaskRunner)
-            >>> def my_flow():
-            >>>     ...
+            ```python
+            from prefect import flow
+            from prefect_dask.task_runners import DaskTaskRunner
+
+            @flow(task_runner=DaskTaskRunner)
+            def my_flow():
+                ...
+            ```
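+
+        Using a temporary local cluster with adaptive scaling (an illustrative
+        sketch of the `adapt_kwargs` option described above; the bounds are
+        placeholders to tune per workload):
+            ```python
+            DaskTaskRunner(
+                cluster_class="distributed.LocalCluster",
+                adapt_kwargs={"minimum": 1, "maximum": 10},
+            )
+            ```
+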
         Using a temporary cluster running elsewhere. Any Dask cluster class
         should work; here we use [dask-cloudprovider](https://cloudprovider.dask.org):
-            >>> DaskTaskRunner(
-            >>>     cluster_class="dask_cloudprovider.FargateCluster",
-            >>>     cluster_kwargs={
-            >>>         "image": "prefecthq/prefect:latest",
-            >>>         "n_workers": 5,
-            >>>     },
-            >>> )
+            ```python
+            DaskTaskRunner(
+                cluster_class="dask_cloudprovider.FargateCluster",
+                cluster_kwargs={
+                    "image": "prefecthq/prefect:latest",
+                    "n_workers": 5,
+                },
+            )
+            ```
+
         Connecting to an existing dask cluster:
-            >>> DaskTaskRunner(address="192.0.2.255:8786")
+            ```python
+            DaskTaskRunner(address="192.0.2.255:8786")
+            ```
     """
 
     def __init__(
@@ -271,7 +283,7 @@ async def _start(self, exit_stack: AsyncExitStack):
             self.logger.info(
                 f"Connecting to an existing Dask cluster at {self.address}"
             )
-            connect_to = self.address
+            # store the connection target so `get_dask_client` can reuse it
+            self._connect_to = self.address
         else:
             self.cluster_class = self.cluster_class or distributed.LocalCluster
@@ -279,14 +291,16 @@
                 f"Creating a new Dask cluster with "
                 f"`{to_qualified_name(self.cluster_class)}`"
             )
-            connect_to = self._cluster = await exit_stack.enter_async_context(
+            self._connect_to = self._cluster = await exit_stack.enter_async_context(
                 self.cluster_class(asynchronous=True, **self.cluster_kwargs)
             )
             if self.adapt_kwargs:
                 self._cluster.adapt(**self.adapt_kwargs)
 
         self._client = await exit_stack.enter_async_context(
-            distributed.Client(connect_to, asynchronous=True, **self.client_kwargs)
+            distributed.Client(
+                self._connect_to, asynchronous=True, **self.client_kwargs
+            )
         )
 
         if self._client.dashboard_link:
@@ -301,7 +315,7 @@ def __getstate__(self):
         Must be deserialized on a dask worker.
         """
         data = self.__dict__.copy()
-        data.update({k: None for k in {"_client", "_cluster"}})
+        data.update({k: None for k in {"_client", "_cluster", "_connect_to"}})
         return data
 
     def __setstate__(self, data: dict):
diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py
new file mode 100644
index 0000000..39aaad2
--- /dev/null
+++ b/prefect_dask/utils.py
@@ -0,0 +1,159 @@
+"""
+Utils to use alongside prefect-dask.
+"""
+
+from contextlib import asynccontextmanager, contextmanager
+from datetime import timedelta
+from typing import Any, Dict, Optional, Union
+
+from distributed import Client, get_client
+from prefect.context import FlowRunContext, TaskRunContext
+
+
+def _generate_client_kwargs(
+    async_client: bool,
+    timeout: Optional[Union[int, float, str, timedelta]] = None,
+    **client_kwargs: Dict[str, Any],
+) -> Dict[str, Any]:
+    """
+    Helper method to populate keyword arguments for `distributed.Client`.
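+
+    Args:
+        async_client: Fallback for whether the client should be asynchronous
+            when no flow or task run context is active.
+        timeout: Timeout to pass through to `distributed.Client`, if given.
+        client_kwargs: Additional keyword arguments for `distributed.Client`,
+            overriding any inherited from the run context.
+
+    Returns:
+        The populated keyword arguments for `distributed.Client`.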
+    """
+    flow_run_context = FlowRunContext.get()
+    task_run_context = TaskRunContext.get()
+
+    if task_run_context:
+        # copies functionality of worker_client(separate_thread=False),
+        # because this allows us to set asynchronous based on the user's task
+        input_client_kwargs = {}
+        address = get_client().scheduler.address
+        asynchronous = task_run_context.task.isasync
+    elif flow_run_context:
+        task_runner = flow_run_context.task_runner
+        # copy so the updates below do not mutate the task runner's kwargs
+        input_client_kwargs = task_runner.client_kwargs.copy()
+        address = task_runner._connect_to
+        asynchronous = flow_run_context.flow.isasync
+    else:
+        # this else clause allows users to debug or test
+        # without much change to code
+        input_client_kwargs = {}
+        address = None
+        asynchronous = async_client
+
+    input_client_kwargs["address"] = address
+    input_client_kwargs["asynchronous"] = asynchronous
+    if timeout is not None:
+        input_client_kwargs["timeout"] = timeout
+    input_client_kwargs.update(**client_kwargs)
+    return input_client_kwargs
+
+
+@contextmanager
+def get_dask_client(
+    timeout: Optional[Union[int, float, str, timedelta]] = None,
+    **client_kwargs: Dict[str, Any],
+) -> Client:
+    """
+    Yields a temporary synchronous dask client; this is useful
+    for parallelizing operations on dask collections,
+    such as a `dask.DataFrame` or `dask.Bag`.
+
+    Without invoking this, workers do not automatically get a client to connect
+    to the full cluster. Therefore, they will attempt to perform work serially
+    within the worker itself, potentially overwhelming the single worker.
+
+    When in an async context, we recommend using `get_async_dask_client` instead.
+
+    Args:
+        timeout: Timeout after which to error out; has no effect in
+            flow run contexts because the client has already started;
+            defaults to the `distributed.comm.timeouts.connect`
+            configuration value.
+        client_kwargs: Additional keyword arguments to pass to
+            `distributed.Client`, overriding any keyword arguments
+            inherited from the task runner, if any.
+
+    Yields:
+        A temporary synchronous dask client.
+
+    Examples:
+        Use `get_dask_client` to distribute work across workers.
+        ```python
+        import dask
+        from prefect import flow, task
+        from prefect_dask import DaskTaskRunner, get_dask_client
+
+        @task
+        def compute_task():
+            with get_dask_client(timeout="120s") as client:
+                df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
+                summary_df = client.compute(df.describe()).result()
+            return summary_df
+
+        @flow(task_runner=DaskTaskRunner())
+        def dask_flow():
+            prefect_future = compute_task.submit()
+            return prefect_future.result()
+
+        dask_flow()
+        ```
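+
+        `get_dask_client` can also be used directly in a flow run context
+        (a sketch mirroring this PR's tests):
+        ```python
+        import dask
+        from prefect import flow
+        from prefect_dask import DaskTaskRunner, get_dask_client
+
+        @flow(task_runner=DaskTaskRunner())
+        def dask_flow():
+            with get_dask_client() as client:
+                return client.compute(dask.delayed(42)).result()
+
+        dask_flow()  # returns 42
+        ```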
+    """
+    client_kwargs = _generate_client_kwargs(
+        async_client=False, timeout=timeout, **client_kwargs
+    )
+    with Client(**client_kwargs) as client:
+        yield client
+
+
+@asynccontextmanager
+async def get_async_dask_client(
+    timeout: Optional[Union[int, float, str, timedelta]] = None,
+    **client_kwargs: Dict[str, Any],
+) -> Client:
+    """
+    Yields a temporary asynchronous dask client; this is useful
+    for parallelizing operations on dask collections,
+    such as a `dask.DataFrame` or `dask.Bag`.
+
+    Without invoking this, workers do not automatically get a client to connect
+    to the full cluster. Therefore, they will attempt to perform work serially
+    within the worker itself, potentially overwhelming the single worker.
+
+    Args:
+        timeout: Timeout after which to error out; has no effect in
+            flow run contexts because the client has already started;
+            defaults to the `distributed.comm.timeouts.connect`
+            configuration value.
+        client_kwargs: Additional keyword arguments to pass to
+            `distributed.Client`, overriding any keyword arguments
+            inherited from the task runner, if any.
+
+    Yields:
+        A temporary asynchronous dask client.
+
+    Examples:
+        Use `get_async_dask_client` to distribute work across workers.
+        ```python
+        import asyncio
+
+        import dask
+        from prefect import flow, task
+        from prefect_dask import DaskTaskRunner, get_async_dask_client
+
+        @task
+        async def compute_task():
+            async with get_async_dask_client(timeout="120s") as client:
+                df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
+                summary_df = await client.compute(df.describe())
+            return summary_df
+
+        @flow(task_runner=DaskTaskRunner())
+        async def dask_flow():
+            prefect_future = await compute_task.submit()
+            return await prefect_future.result()
+
+        asyncio.run(dask_flow())
+        ```
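+
+        `get_async_dask_client` also works outside of a run context, which is
+        handy for debugging (a sketch; with no flow or task run active and no
+        address configured, `distributed.Client` starts a temporary local
+        cluster):
+        ```python
+        import asyncio
+
+        import dask
+        from prefect_dask import get_async_dask_client
+
+        async def main():
+            async with get_async_dask_client() as client:
+                return await client.compute(dask.delayed(42)).result()
+
+        asyncio.run(main())  # returns 42
+        ```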
+    """
+    client_kwargs = _generate_client_kwargs(
+        async_client=True, timeout=timeout, **client_kwargs
+    )
+    async with Client(**client_kwargs) as client:
+        yield client
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..81c4091
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,120 @@
+import dask
+import pytest
+from distributed import Client
+from prefect import flow, task
+
+from prefect_dask import DaskTaskRunner, get_async_dask_client, get_dask_client
+
+
+class TestDaskSyncClient:
+    def test_from_task(self):
+        @task
+        def test_task():
+            delayed_num = dask.delayed(42)
+            with get_dask_client() as client:
+                assert isinstance(client, Client)
+                result = client.compute(delayed_num).result()
+            return result
+
+        @flow(task_runner=DaskTaskRunner)
+        def test_flow():
+            future = test_task.submit()
+            return future.result()
+
+        assert test_flow() == 42
+
+    def test_from_flow(self):
+        @flow(task_runner=DaskTaskRunner)
+        def test_flow():
+            delayed_num = dask.delayed(42)
+            with get_dask_client() as client:
+                assert isinstance(client, Client)
+                result = client.compute(delayed_num).result()
+            return result
+
+        assert test_flow() == 42
+
+    def test_outside_run_context(self):
+        delayed_num = dask.delayed(42)
+        with get_dask_client() as client:
+            assert isinstance(client, Client)
+            result = client.compute(delayed_num).result()
+        assert result == 42
+
+    @pytest.mark.parametrize("timeout", [None, 8])
+    def test_include_timeout(self, timeout):
+        delayed_num = dask.delayed(42)
+        with get_dask_client(timeout=timeout) as client:
+            assert isinstance(client, Client)
+            if timeout is not None:
+                assert client._timeout == timeout
+            result = client.compute(delayed_num).result()
+        assert result == 42
+
+
+class TestDaskAsyncClient:
+    async def test_from_task(self):
+        @task
+        async def test_task():
+            delayed_num = dask.delayed(42)
+            async with get_async_dask_client() as client:
+                assert isinstance(client, Client)
+                result = await client.compute(delayed_num).result()
+            return result
+
+        @flow(task_runner=DaskTaskRunner)
+        async def test_flow():
+            future = await test_task.submit()
+            return await future.result()
+
+        assert (await test_flow()) == 42
+
+    def test_from_sync_task_error(self):
+        @task
+        def test_task():
+            # an async context manager has no __enter__, so this should raise
+            with get_async_dask_client():
+                pass
+
+        @flow(task_runner=DaskTaskRunner)
+        def test_flow():
+            test_task.submit()
+
+        with pytest.raises(AttributeError, match="__enter__"):
+            test_flow()
+
+    async def test_from_flow(self):
+        @flow(task_runner=DaskTaskRunner)
+        async def test_flow():
+            delayed_num = dask.delayed(42)
+            async with get_async_dask_client() as client:
+                assert isinstance(client, Client)
+                result = await client.compute(delayed_num).result()
+            return result
+
+        assert (await test_flow()) == 42
+
+    def test_from_sync_flow_error(self):
+        @flow(task_runner=DaskTaskRunner)
+        def test_flow():
+            with get_async_dask_client():
+                pass
+
+        with pytest.raises(AttributeError, match="__enter__"):
+            test_flow()
+
+    async def test_outside_run_context(self):
+        delayed_num = dask.delayed(42)
+        async with get_async_dask_client() as client:
+            assert isinstance(client, Client)
+            result = await client.compute(delayed_num).result()
+        assert result == 42
+
+    @pytest.mark.parametrize("timeout", [None, 8])
+    async def test_include_timeout(self, timeout):
+        delayed_num = dask.delayed(42)
+        async with get_async_dask_client(timeout=timeout) as client:
+            assert isinstance(client, Client)
+            if timeout is not None:
+                assert client._timeout == timeout
+            result = await client.compute(delayed_num).result()
+        assert result == 42