From 1a29e057003e95d992d0ddc83f4819beff496426 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Mon, 26 Sep 2022 20:20:33 -0700 Subject: [PATCH 01/21] Add get_dask_client --- README.md | 27 +++++++++++++++++++ prefect_dask/__init__.py | 1 + prefect_dask/utils.py | 56 ++++++++++++++++++++++++++++++++++++++++ tests/test_utils.py | 26 +++++++++++++++++++ 4 files changed, 110 insertions(+) create mode 100644 prefect_dask/utils.py create mode 100644 tests/test_utils.py diff --git a/README.md b/README.md index bea610d..ed99b34 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,33 @@ DaskTaskRunner( ) ``` +### Distributing work across workers + +If your task contains a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use the `get_dask_client` context manager within your task. + +Be mindful of the futures upon `submit` and `compute`. To resolve these futures, call `result`. + +```python +import dask +from prefect import flow, task +from prefect_dask import DaskTaskRunner, get_dask_client + +@task +def compute_task(): + with get_dask_client(timeout="120s") as client: + df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") + summary_df = client.compute(df.describe()) + return summary_df.result() + +@flow(task_runner=DaskTaskRunner()) +def dask_flow(): + prefect_future = compute_task.submit() + return prefect_future.result() + +dask_flow() +``` + + ### Using a temporary cluster The `DaskTaskRunner` is capable of creating a temporary cluster using any of [Dask's cluster-manager options](https://docs.dask.org/en/latest/setup.html). This can be useful when you want each flow run to have its own Dask cluster, allowing for per-flow adaptive scaling. diff --git a/prefect_dask/__init__.py b/prefect_dask/__init__.py index bfd452d..dd3ee25 100644 --- a/prefect_dask/__init__.py +++ b/prefect_dask/__init__.py @@ -1,4 +1,5 @@ from . import _version from .task_runners import DaskTaskRunner # noqa +from .utils import get_dask_client # noqa __version__ = _version.get_versions()["version"] diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py new file mode 100644 index 0000000..88ab296 --- /dev/null +++ b/prefect_dask/utils.py @@ -0,0 +1,56 @@ +""" +Utils to use alongside prefect-dask. +""" + +from contextlib import contextmanager +from datetime import timedelta +from typing import Optional, Union + +from distributed import worker_client + + +@contextmanager +def get_dask_client(timeout: Optional[Union[int, float, str, timedelta]] = None): + """ + This is intended to be called within tasks that run on workers, and is + useful for operating on dask collections, such as a `dask.DataFrame`. + + Without invoking this, workers in a task do not automatically + get a client to connect to the full cluster. Therefore, it will attempt + perform work within the worker itself serially, and potentially overwhelming + the single worker. + + Internally, this context manager is a simple wrapper around + `distributed.worker_client` with `separate_thread=False` fixed. + + Args: + timeout: Timeout after which to error out. Defaults to the + `distributed.comm.timeouts.connect` configuration value. + + Returns: + The dask worker client. + + Examples: + Use `get_dask_client` to distribute work across workers. + ```python + import dask + from prefect import flow, task + from prefect_dask import DaskTaskRunner, get_dask_client + + @task + def compute_task(): + with get_dask_client() as client: + df = dask.datasets.timeseries("2000", "2005", partition_freq="2w") + summary_df = df.describe() + client.compute(summary_df) + + @flow(task_runner=DaskTaskRunner()) + def dask_flow(): + compute_task.submit() + + if __name__ == "__main__": + dask_flow() + ``` + """ + with worker_client(timeout=timeout, separate_thread=False) as client: + yield client diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b4f595b --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,26 @@ +from datetime import timedelta + +import dask +import pytest +from distributed import Client +from prefect import flow, task + +from prefect_dask import DaskTaskRunner, get_dask_client + + +@pytest.mark.parametrize("timeout", [10, 10.0, "10s", timedelta(seconds=10)]) +def test_get_dask_client_integration(timeout): + @task + def test_task(): + delayed_num = dask.delayed(42) + with get_dask_client(timeout=timeout) as client: + assert isinstance(client, Client) + future = client.compute(delayed_num) + return future.result() + + @flow(task_runner=DaskTaskRunner) + def test_flow(): + future = test_task.submit() + return future.result() + + assert test_flow() == 42 From bbc9fe1fe1838d91586827ca05af0ef2ab45adc9 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Mon, 26 Sep 2022 20:26:02 -0700 Subject: [PATCH 02/21] Add changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 66560c4..7ac618f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- `get_dask_client` allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33) + ### Changed ### Deprecated From 30799312aa37c9c1e00d2b5f6b1bfc762b85c281 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Mon, 26 Sep 2022 20:28:15 -0700 Subject: [PATCH 03/21] Add docs --- docs/utils.md | 1 + mkdocs.yml | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 docs/utils.md diff --git a/docs/utils.md b/docs/utils.md new file mode 100644 index 0000000..1e9fc7b --- /dev/null +++ b/docs/utils.md @@ -0,0 +1 @@ +::: prefect_dask.utils diff --git a/mkdocs.yml b/mkdocs.yml index 23e8476..7700e56 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -47,3 +47,5 @@ plugins: nav: - Home: index.md - Task Runners: task_runners.md + - Utils: utils.md + From fe135a8d6df6db4eac6e671a609282c2e612469a Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Wed, 28 Sep 2022 12:06:29 -0700 Subject: [PATCH 04/21] Add flow run context --- README.md | 33 +++++++++++++++-- prefect_dask/task_runners.py | 31 +++++++++------- prefect_dask/utils.py | 72 +++++++++++++++++++++++++----------- tests/test_utils.py | 45 +++++++++++++++++++--- 4 files changed, 136 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index ed99b34..ad439a3 100644 --- a/README.md +++ b/README.md @@ -109,9 +109,9 @@ DaskTaskRunner( ### Distributing work across workers -If your task contains a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use the `get_dask_client` context manager within your task. +If your task contains a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use the `get_dask_client` context manager. -Be mindful of the futures upon `submit` and `compute`. To resolve these futures, call `result`. +Within task run contexts: ```python import dask @@ -122,8 +122,8 @@ from prefect_dask import DaskTaskRunner, get_dask_client def compute_task(): with get_dask_client(timeout="120s") as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") - summary_df = client.compute(df.describe()) - return summary_df.result() + summary_df = client.compute(df.describe()).result() + return summary_df @flow(task_runner=DaskTaskRunner()) def dask_flow(): @@ -133,6 +133,31 @@ def dask_flow(): dask_flow() ``` +Within flow run contexts; not `timeout` must be set in `DaskTaskRunner`: +```python +import dask +from prefect import flow +from prefect_dask import DaskTaskRunner, get_dask_client + +@flow(task_runner=DaskTaskRunner(client_kwargs=dict(timeout="120s"))) +def dask_flow(): + with get_dask_client() as client: + df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") + summary_df = client.compute(df.describe()).result() + return summary_df + +dask_flow() +``` + +Be mindful of the futures upon `compute`; without resolving the futures, you may encounter: +`Future: finalize status: cancelled`. To resolve these futures, call `result`. + +To resolve multiple Dask futures together, use `sync=True`: +```python +summary_df = client.compute(futures, sync=True) +``` + +For more information, visit the docs on [Waiting on Futures](https://docs.dask.org/en/stable/futures.html#waiting-on-futures). ### Using a temporary cluster diff --git a/prefect_dask/task_runners.py b/prefect_dask/task_runners.py index e5ac80b..7fcc82a 100644 --- a/prefect_dask/task_runners.py +++ b/prefect_dask/task_runners.py @@ -115,22 +115,27 @@ class name (e.g. `"distributed.LocalCluster"`), or the class itself. [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client). Examples: Using a temporary local dask cluster: - >>> from prefect import flow - >>> from prefect_dask.task_runners import DaskTaskRunner - >>> @flow(task_runner=DaskTaskRunner) - >>> def my_flow(): - >>> ... + ```python + from prefect import flow + from prefect_dask.task_runners import DaskTaskRunner + @flow(task_runner=DaskTaskRunner) + def my_flow(): + ... + ``` Using a temporary cluster running elsewhere. Any Dask cluster class should work, here we use [dask-cloudprovider](https://cloudprovider.dask.org): - >>> DaskTaskRunner( - >>> cluster_class="dask_cloudprovider.FargateCluster", - >>> cluster_kwargs={ - >>> "image": "prefecthq/prefect:latest", - >>> "n_workers": 5, - >>> }, - >>> ) + + ```python + DaskTaskRunner( + cluster_class="dask_cloudprovider.FargateCluster", + cluster_kwargs={ + "image": "prefecthq/prefect:latest", + "n_workers": 5, + }, + ) Connecting to an existing dask cluster: - >>> DaskTaskRunner(address="192.0.2.255:8786") + DaskTaskRunner(address="192.0.2.255:8786") + ``` """ def __init__( diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index 88ab296..3530b8d 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -6,32 +6,39 @@ from datetime import timedelta from typing import Optional, Union -from distributed import worker_client +from distributed import Client, worker_client +from prefect.context import FlowRunContext, TaskRunContext @contextmanager def get_dask_client(timeout: Optional[Union[int, float, str, timedelta]] = None): """ - This is intended to be called within tasks that run on workers, and is - useful for operating on dask collections, such as a `dask.DataFrame`. + This is useful for parallelizing operations on dask collections, + such as a `dask.DataFrame`. - Without invoking this, workers in a task do not automatically - get a client to connect to the full cluster. Therefore, it will attempt - perform work within the worker itself serially, and potentially overwhelming - the single worker. + Without invoking this, workers do not automatically get a client to connect + to the full cluster. Therefore, it will attempt perform work within the + worker itself serially, and potentially overwhelming the single worker. - Internally, this context manager is a simple wrapper around - `distributed.worker_client` with `separate_thread=False` fixed. + Within the task run context, this context manager is a simple + wrapper around `distributed.worker_client` with `separate_thread=False` fixed + Within the flow run context, this context manager simply returns + the existing client used in `DaskTaskRunner`. Args: - timeout: Timeout after which to error out. Defaults to the - `distributed.comm.timeouts.connect` configuration value. + timeout: Timeout after which to error out; has no effect in + flow run contexts because the client has already started; + Defaults to the `distributed.comm.timeouts.connect` + configuration value. Returns: - The dask worker client. + Within task run contexts, the dask worker client, and within flow run contexts, + the existing client used in `DaskTaskRunner`. Examples: - Use `get_dask_client` to distribute work across workers. + Use `get_dask_client` to distribute work across workers within task run context. + Be mindful of the futures upon `submit` and `compute`. To resolve these futures, + call `result` on the future. ```python import dask from prefect import flow, task @@ -39,18 +46,39 @@ def get_dask_client(timeout: Optional[Union[int, float, str, timedelta]] = None) @task def compute_task(): - with get_dask_client() as client: - df = dask.datasets.timeseries("2000", "2005", partition_freq="2w") - summary_df = df.describe() - client.compute(summary_df) + with get_dask_client(timeout="120s") as client: + df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") + summary_df = client.compute(df.describe()).result() + return summary_df @flow(task_runner=DaskTaskRunner()) def dask_flow(): - compute_task.submit() + prefect_future = compute_task.submit() + return prefect_future.result() - if __name__ == "__main__": - dask_flow() + dask_flow() ``` """ - with worker_client(timeout=timeout, separate_thread=False) as client: - yield client + task_run_context = TaskRunContext.get() + flow_run_context = FlowRunContext.get() + + if task_run_context: + with worker_client(timeout=timeout, separate_thread=False) as client: + yield client + elif flow_run_context: + if timeout is not None: + raise ValueError( + "Passing `timeout` to `get_dask_client` has no " + "effect within the flow run context; instead, pass `timeout` " + "to `client_kwargs` when instantiating `DaskTaskRunner`." + ) + task_runner = flow_run_context.task_runner + yield task_runner._client + else: + # this else clause allows users to debug or test + # without much change to code + client_kwargs = {} + if timeout is not None: # dask errors if timeout=None here + client_kwargs["timeout"] = timeout + with Client(**client_kwargs) as client: + yield client diff --git a/tests/test_utils.py b/tests/test_utils.py index b4f595b..2faecfe 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,3 @@ -from datetime import timedelta - import dask import pytest from distributed import Client @@ -8,15 +6,17 @@ from prefect_dask import DaskTaskRunner, get_dask_client -@pytest.mark.parametrize("timeout", [10, 10.0, "10s", timedelta(seconds=10)]) -def test_get_dask_client_integration(timeout): +@pytest.mark.parametrize("timeout", [None, 10]) +def test_get_dask_client_task_run_context_integration(timeout): @task def test_task(): delayed_num = dask.delayed(42) with get_dask_client(timeout=timeout) as client: assert isinstance(client, Client) - future = client.compute(delayed_num) - return future.result() + if timeout is not None: + assert client._timeout == timeout + result = client.compute(delayed_num).result() + return result @flow(task_runner=DaskTaskRunner) def test_flow(): @@ -24,3 +24,36 @@ def test_flow(): return future.result() assert test_flow() == 42 + + +def test_get_dask_client_flow_run_context_integration(): + @flow(task_runner=DaskTaskRunner) + def test_flow(): + delayed_num = dask.delayed(42) + with get_dask_client() as client: + assert isinstance(client, Client) + result = client.compute(delayed_num, sync=True) + return result + + assert test_flow() == 42 + + +def test_get_dask_client_flow_run_context_catch_timeout_error(): + @flow(task_runner=DaskTaskRunner) + def test_flow(): + with get_dask_client(timeout=42): + pass + + with pytest.raises(ValueError, match="Passing `timeout`"): + test_flow() + + +@pytest.mark.parametrize("timeout", [None, 10]) +def test_get_dask_client_no_context(timeout): + delayed_num = dask.delayed(42) + with get_dask_client(timeout=timeout) as client: + assert isinstance(client, Client) + if timeout is not None: + assert client._timeout == timeout + result = client.compute(delayed_num).result() + assert result == 42 From d9e4e6f957b6bc88e9390f2526e7762d9e97ffad Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Wed, 28 Sep 2022 12:08:08 -0700 Subject: [PATCH 05/21] minor tweak to readme --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ad439a3..340ba6b 100644 --- a/README.md +++ b/README.md @@ -107,12 +107,11 @@ DaskTaskRunner( ) ``` -### Distributing work across workers +### Distributing Dask collections across workers If your task contains a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use the `get_dask_client` context manager. Within task run contexts: - ```python import dask from prefect import flow, task @@ -133,7 +132,7 @@ def dask_flow(): dask_flow() ``` -Within flow run contexts; not `timeout` must be set in `DaskTaskRunner`: +Within flow run contexts; `timeout` must be set in `DaskTaskRunner`: ```python import dask from prefect import flow From d9355b0bdf3ff03ef0bc3cc7ffb32cbae0e70a00 Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Wed, 28 Sep 2022 12:46:50 -0700 Subject: [PATCH 06/21] Apply suggestions from code review Co-authored-by: Michael Adkins --- tests/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2faecfe..4570c0c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("timeout", [None, 10]) -def test_get_dask_client_task_run_context_integration(timeout): +def test_get_dask_client_from_task(timeout): @task def test_task(): delayed_num = dask.delayed(42) @@ -26,7 +26,7 @@ def test_flow(): assert test_flow() == 42 -def test_get_dask_client_flow_run_context_integration(): +def test_get_dask_client_from_flow(): @flow(task_runner=DaskTaskRunner) def test_flow(): delayed_num = dask.delayed(42) @@ -49,7 +49,7 @@ def test_flow(): @pytest.mark.parametrize("timeout", [None, 10]) -def test_get_dask_client_no_context(timeout): +def test_get_dask_client_outside_run_context(timeout): delayed_num = dask.delayed(42) with get_dask_client(timeout=timeout) as client: assert isinstance(client, Client) From de8ba7dd954bdf6326220f74828448f686237731 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Wed, 28 Sep 2022 13:17:02 -0700 Subject: [PATCH 07/21] Start new client --- prefect_dask/task_runners.py | 9 ++++++--- prefect_dask/utils.py | 11 ++++------- tests/test_utils.py | 10 ---------- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/prefect_dask/task_runners.py b/prefect_dask/task_runners.py index 7fcc82a..79da1de 100644 --- a/prefect_dask/task_runners.py +++ b/prefect_dask/task_runners.py @@ -189,6 +189,7 @@ def __init__( # Runtime attributes self._client: "distributed.Client" = None self._cluster: "distributed.deploy.Cluster" = None + self._connect_to: Union[str, "distributed.deploy.Cluster"] = None self._dask_futures: Dict[str, "distributed.Future"] = {} super().__init__() @@ -276,7 +277,7 @@ async def _start(self, exit_stack: AsyncExitStack): self.logger.info( f"Connecting to an existing Dask cluster at {self.address}" ) - connect_to = self.address + self._connect_to = self.address else: self.cluster_class = self.cluster_class or distributed.LocalCluster @@ -284,14 +285,16 @@ async def _start(self, exit_stack: AsyncExitStack): f"Creating a new Dask cluster with " f"`{to_qualified_name(self.cluster_class)}`" ) - connect_to = self._cluster = await exit_stack.enter_async_context( + self._connect_to = self._cluster = await exit_stack.enter_async_context( self.cluster_class(asynchronous=True, **self.cluster_kwargs) ) if self.adapt_kwargs: self._cluster.adapt(**self.adapt_kwargs) self._client = await exit_stack.enter_async_context( - distributed.Client(connect_to, asynchronous=True, **self.client_kwargs) + distributed.Client( + self._connect_to, asynchronous=True, **self.client_kwargs + ) ) if self._client.dashboard_link: diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index 3530b8d..ad57055 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -66,14 +66,11 @@ def dask_flow(): with worker_client(timeout=timeout, separate_thread=False) as client: yield client elif flow_run_context: - if timeout is not None: - raise ValueError( - "Passing `timeout` to `get_dask_client` has no " - "effect within the flow run context; instead, pass `timeout` " - "to `client_kwargs` when instantiating `DaskTaskRunner`." - ) task_runner = flow_run_context.task_runner - yield task_runner._client + connect_to = task_runner._connect_to + client_kwargs = task_runner.client_kwargs + with Client(connect_to, **client_kwargs) as client: + yield client else: # this else clause allows users to debug or test # without much change to code diff --git a/tests/test_utils.py b/tests/test_utils.py index 4570c0c..ea4c7b9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -38,16 +38,6 @@ def test_flow(): assert test_flow() == 42 -def test_get_dask_client_flow_run_context_catch_timeout_error(): - @flow(task_runner=DaskTaskRunner) - def test_flow(): - with get_dask_client(timeout=42): - pass - - with pytest.raises(ValueError, match="Passing `timeout`"): - test_flow() - - @pytest.mark.parametrize("timeout", [None, 10]) def test_get_dask_client_outside_run_context(timeout): delayed_num = dask.delayed(42) From 3058223fe37b2cb17e1ee364a9cfa4b3d8d2baab Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Wed, 28 Sep 2022 15:08:19 -0700 Subject: [PATCH 08/21] Refactor and update examples --- README.md | 41 ++++++++++-------------- prefect_dask/task_runners.py | 4 +-- prefect_dask/utils.py | 62 ++++++++++++++++++++---------------- tests/test_utils.py | 11 ++++--- 4 files changed, 59 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index 340ba6b..c11ff8f 100644 --- a/README.md +++ b/README.md @@ -109,9 +109,8 @@ DaskTaskRunner( ### Distributing Dask collections across workers -If your task contains a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use the `get_dask_client` context manager. +If you use a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use the `get_dask_client` context manager: -Within task run contexts: ```python import dask from prefect import flow, task @@ -119,9 +118,9 @@ from prefect_dask import DaskTaskRunner, get_dask_client @task def compute_task(): - with get_dask_client(timeout="120s") as client: + with get_dask_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") - summary_df = client.compute(df.describe()).result() + summary_df = df.describe().compute() return summary_df @flow(task_runner=DaskTaskRunner()) @@ -132,31 +131,23 @@ def dask_flow(): dask_flow() ``` -Within flow run contexts; `timeout` must be set in `DaskTaskRunner`: -```python -import dask -from prefect import flow -from prefect_dask import DaskTaskRunner, get_dask_client +The util, `get_dask_client`, can be used the same way in both `flow` run contexts and `task` run contexts. -@flow(task_runner=DaskTaskRunner(client_kwargs=dict(timeout="120s"))) -def dask_flow(): +!!! warning "Resolving futures" + Note, by default, `client.compute(dask_collection)` returns Dask Futures while `dask_collection.compute()` returns concrete values. Therefore, if you call `client.compute`, you must resolve all futures before exiting out of the context manager by either setting `sync=True` or calling `result`! + + ```python with get_dask_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") - summary_df = client.compute(df.describe()).result() - return summary_df + summary_df = client.compute(df.describe(), sync=True) + ``` -dask_flow() -``` - -Be mindful of the futures upon `compute`; without resolving the futures, you may encounter: -`Future: finalize status: cancelled`. To resolve these futures, call `result`. - -To resolve multiple Dask futures together, use `sync=True`: -```python -summary_df = client.compute(futures, sync=True) -``` - -For more information, visit the docs on [Waiting on Futures](https://docs.dask.org/en/stable/futures.html#waiting-on-futures). + ```python + with get_dask_client() as client: + df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") + summary_df = client.compute(df.describe()).result() + ``` + For more information, visit the docs on [Waiting on Futures](https://docs.dask.org/en/stable/futures.html#waiting-on-futures). ### Using a temporary cluster diff --git a/prefect_dask/task_runners.py b/prefect_dask/task_runners.py index 79da1de..c65111e 100644 --- a/prefect_dask/task_runners.py +++ b/prefect_dask/task_runners.py @@ -189,7 +189,6 @@ def __init__( # Runtime attributes self._client: "distributed.Client" = None self._cluster: "distributed.deploy.Cluster" = None - self._connect_to: Union[str, "distributed.deploy.Cluster"] = None self._dask_futures: Dict[str, "distributed.Future"] = {} super().__init__() @@ -285,9 +284,10 @@ async def _start(self, exit_stack: AsyncExitStack): f"Creating a new Dask cluster with " f"`{to_qualified_name(self.cluster_class)}`" ) - self._connect_to = self._cluster = await exit_stack.enter_async_context( + self._cluster = await exit_stack.enter_async_context( self.cluster_class(asynchronous=True, **self.cluster_kwargs) ) + self._connect_to = self._cluster.scheduler_address if self.adapt_kwargs: self._cluster.adapt(**self.adapt_kwargs) diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index ad57055..3abaa0d 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -4,32 +4,33 @@ from contextlib import contextmanager from datetime import timedelta -from typing import Optional, Union +from typing import Any, Dict, Optional, Union -from distributed import Client, worker_client +from distributed import Client, get_client from prefect.context import FlowRunContext, TaskRunContext @contextmanager -def get_dask_client(timeout: Optional[Union[int, float, str, timedelta]] = None): +def get_dask_client( + timeout: Optional[Union[int, float, str, timedelta]] = None, + **client_kwargs: Dict[str, Any] +) -> Client: """ - This is useful for parallelizing operations on dask collections, - such as a `dask.DataFrame`. + Yields a temporary client; this is useful for parallelizing operations + on dask collections, such as a `dask.DataFrame` or `dask.Bag`. Without invoking this, workers do not automatically get a client to connect to the full cluster. Therefore, it will attempt perform work within the worker itself serially, and potentially overwhelming the single worker. - Within the task run context, this context manager is a simple - wrapper around `distributed.worker_client` with `separate_thread=False` fixed - Within the flow run context, this context manager simply returns - the existing client used in `DaskTaskRunner`. - Args: timeout: Timeout after which to error out; has no effect in flow run contexts because the client has already started; Defaults to the `distributed.comm.timeouts.connect` configuration value. + client_kwargs: Additional keyword arguments to pass to + `distributed.Client`, and overwrites inherited keyword arguments + from the task runner, if any. Returns: Within task run contexts, the dask worker client, and within flow run contexts, @@ -37,8 +38,6 @@ def get_dask_client(timeout: Optional[Union[int, float, str, timedelta]] = None) Examples: Use `get_dask_client` to distribute work across workers within task run context. - Be mindful of the futures upon `submit` and `compute`. To resolve these futures, - call `result` on the future. ```python import dask from prefect import flow, task @@ -48,7 +47,7 @@ def get_dask_client(timeout: Optional[Union[int, float, str, timedelta]] = None) def compute_task(): with get_dask_client(timeout="120s") as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") - summary_df = client.compute(df.describe()).result() + summary_df = df.describe().compute() return summary_df @flow(task_runner=DaskTaskRunner()) @@ -59,23 +58,32 @@ def dask_flow(): dask_flow() ``` """ - task_run_context = TaskRunContext.get() flow_run_context = FlowRunContext.get() + task_run_context = TaskRunContext.get() - if task_run_context: - with worker_client(timeout=timeout, separate_thread=False) as client: - yield client - elif flow_run_context: + if flow_run_context: task_runner = flow_run_context.task_runner - connect_to = task_runner._connect_to - client_kwargs = task_runner.client_kwargs - with Client(connect_to, **client_kwargs) as client: - yield client + input_client_kwargs = task_runner.client_kwargs + address = task_runner._connect_to + asynchronous = flow_run_context.flow.isasync + elif task_run_context: + # copies functionality of worker_client(separate_thread=False) + # because this allows us to set asynchronous based on user's task + input_client_kwargs = {} + address = get_client(timeout=timeout).scheduler.address + asynchronous = task_run_context.task.isasync else: # this else clause allows users to debug or test # without much change to code - client_kwargs = {} - if timeout is not None: # dask errors if timeout=None here - client_kwargs["timeout"] = timeout - with Client(**client_kwargs) as client: - yield client + input_client_kwargs = {} + address = None + asynchronous = False + + input_client_kwargs["address"] = address + input_client_kwargs["asynchronous"] = asynchronous + if timeout is not None: + input_client_kwargs["timeout"] = timeout + input_client_kwargs.update(**client_kwargs) + + with Client(**input_client_kwargs) as client: + yield client diff --git a/tests/test_utils.py b/tests/test_utils.py index ea4c7b9..122f3c1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,7 +15,7 @@ def test_task(): assert isinstance(client, Client) if timeout is not None: assert client._timeout == timeout - result = client.compute(delayed_num).result() + result = client.compute(delayed_num) return result @flow(task_runner=DaskTaskRunner) @@ -26,13 +26,14 @@ def test_flow(): assert test_flow() == 42 -def test_get_dask_client_from_flow(): +@pytest.mark.parametrize("timeout", [None, 10]) +def test_get_dask_client_from_flow(timeout): @flow(task_runner=DaskTaskRunner) def test_flow(): delayed_num = dask.delayed(42) - with get_dask_client() as client: + with get_dask_client(timeout=timeout) as client: assert isinstance(client, Client) - result = client.compute(delayed_num, sync=True) + result = client.compute(delayed_num) return result assert test_flow() == 42 @@ -45,5 +46,5 @@ def test_get_dask_client_outside_run_context(timeout): assert isinstance(client, Client) if timeout is not None: assert client._timeout == timeout - result = client.compute(delayed_num).result() + result = client.compute(delayed_num) assert result == 42 From 847c7edddb5de42fb88c771871294fd0c7242a7f Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Wed, 28 Sep 2022 15:16:38 -0700 Subject: [PATCH 09/21] Fix tests --- tests/test_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 122f3c1..0060181 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,7 +15,7 @@ def test_task(): assert isinstance(client, Client) if timeout is not None: assert client._timeout == timeout - result = client.compute(delayed_num) + result = delayed_num.compute() return result @flow(task_runner=DaskTaskRunner) @@ -33,7 +33,9 @@ def test_flow(): delayed_num = dask.delayed(42) with get_dask_client(timeout=timeout) as client: assert isinstance(client, Client) - result = client.compute(delayed_num) + if timeout is not None: + assert client._timeout == timeout + result = delayed_num.compute() return result assert test_flow() == 42 @@ -46,5 +48,5 @@ def test_get_dask_client_outside_run_context(timeout): assert isinstance(client, Client) if timeout is not None: assert client._timeout == timeout - result = client.compute(delayed_num) + result = delayed_num.compute() assert result == 42 From 84336f7e9e2fb2e2ad62bd308c53d1be8d3b69fb Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 29 Sep 2022 14:53:10 -0700 Subject: [PATCH 10/21] Refactor for two separate sync/async clients --- README.md | 48 +++++++++-- prefect_dask/__init__.py | 2 +- prefect_dask/exceptions.py | 9 ++ prefect_dask/task_runners.py | 9 +- prefect_dask/utils.py | 152 +++++++++++++++++++++++++-------- tests/test_utils.py | 161 ++++++++++++++++++++++++++++------- 6 files changed, 301 insertions(+), 80 deletions(-) create mode 100644 prefect_dask/exceptions.py diff --git a/README.md b/README.md index c11ff8f..e526719 100644 --- a/README.md +++ b/README.md @@ -109,16 +109,16 @@ DaskTaskRunner( ### Distributing Dask collections across workers -If you use a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use the `get_dask_client` context manager: +If you use a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use one of the context managers `get_dask_sync_client` or `get_dask_async_client`: ```python import dask from prefect import flow, task -from prefect_dask import DaskTaskRunner, get_dask_client +from prefect_dask import DaskTaskRunner, get_dask_sync_client @task def compute_task(): - with get_dask_client() as client: + with get_dask_sync_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = df.describe().compute() return summary_df @@ -131,24 +131,54 @@ def dask_flow(): dask_flow() ``` -The util, `get_dask_client`, can be used the same way in both `flow` run contexts and `task` run contexts. - -!!! warning "Resolving futures" - Note, by default, `client.compute(dask_collection)` returns Dask Futures while `dask_collection.compute()` returns concrete values. Therefore, if you call `client.compute`, you must resolve all futures before exiting out of the context manager by either setting `sync=True` or calling `result`! +The context managers can be used the same way in both `flow` run contexts and `task` run contexts. +!!! warning "Resolving futures in sync client" + Note, by default, `dask_collection.compute()` returns concrete values while `client.compute(dask_collection)` returns Dask Futures. Therefore, if you call `client.compute`, you must resolve all futures before exiting out of the context manager by either: + + 1. setting `sync=True` ```python - with get_dask_client() as client: + with get_dask_sync_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = client.compute(df.describe(), sync=True) ``` + 2. calling `result()` ```python - with get_dask_client() as client: + with get_dask_sync_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = client.compute(df.describe()).result() ``` For more information, visit the docs on [Waiting on Futures](https://docs.dask.org/en/stable/futures.html#waiting-on-futures). +There is also an equivalent `async` version, namely `get_dask_async_client`. + +```python +import dask +from prefect import flow, task +from prefect_dask import DaskTaskRunner, get_dask_async_client + +@task +async def compute_task(): + async with get_dask_async_client() as client: + df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") + summary_df = await client.compute(df.describe()) + return summary_df + +@flow(task_runner=DaskTaskRunner()) +async def dask_flow(): + prefect_future = await compute_task.submit() + return await prefect_future.result() + +asyncio.run(dask_flow()) +``` +!!! warning "Resolving futures in async client" + With the async client, you do not need to set `sync=True` or call `result()`. + + However you must `await client.compute` before exiting out of the context manager. + + Running `await dask_collection.compute()` will result in an error: `TypeError: 'coroutine' object is not iterable`. + ### Using a temporary cluster The `DaskTaskRunner` is capable of creating a temporary cluster using any of [Dask's cluster-manager options](https://docs.dask.org/en/latest/setup.html). This can be useful when you want each flow run to have its own Dask cluster, allowing for per-flow adaptive scaling. diff --git a/prefect_dask/__init__.py b/prefect_dask/__init__.py index dd3ee25..fa2ba6f 100644 --- a/prefect_dask/__init__.py +++ b/prefect_dask/__init__.py @@ -1,5 +1,5 @@ from . import _version from .task_runners import DaskTaskRunner # noqa -from .utils import get_dask_client # noqa +from .utils import get_dask_sync_client, get_dask_async_client # noqa __version__ = _version.get_versions()["version"] diff --git a/prefect_dask/exceptions.py b/prefect_dask/exceptions.py new file mode 100644 index 0000000..42c56f3 --- /dev/null +++ b/prefect_dask/exceptions.py @@ -0,0 +1,9 @@ +""" +Exceptions specific to prefect-dask. +""" + + +class ImproperClientError(Exception): + """ + Raised when the flow/task is async but the client is sync, or vice versa. + """ diff --git a/prefect_dask/task_runners.py b/prefect_dask/task_runners.py index c65111e..ec58052 100644 --- a/prefect_dask/task_runners.py +++ b/prefect_dask/task_runners.py @@ -93,12 +93,15 @@ class DaskTaskRunner(BaseTaskRunner): different cluster class (e.g. [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can specify `cluster_class`/`cluster_kwargs`. + Alternatively, if you already have a dask cluster running, you can provide the address of the scheduler via the `address` kwarg. + !!! warning "Multiprocessing safety" Note that, because the `DaskTaskRunner` uses multiprocessing, calls to flows in scripts must be guarded with `if __name__ == "__main__":` or warnings will be displayed. + Args: address (string, optional): Address of a currently running dask scheduler; if one is not provided, a temporary cluster will be @@ -113,6 +116,7 @@ class name (e.g. `"distributed.LocalCluster"`), or the class itself. is only enabled if `adapt_kwargs` are provided. client_kwargs (dict, optional): Additional kwargs to use when creating a [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client). + Examples: Using a temporary local dask cluster: ```python @@ -122,9 +126,9 @@ class name (e.g. `"distributed.LocalCluster"`), or the class itself. def my_flow(): ... ``` + Using a temporary cluster running elsewhere. Any Dask cluster class should work, here we use [dask-cloudprovider](https://cloudprovider.dask.org): - ```python DaskTaskRunner( cluster_class="dask_cloudprovider.FargateCluster", @@ -133,7 +137,10 @@ def my_flow(): "n_workers": 5, }, ) + ``` + Connecting to an existing dask cluster: + ```python DaskTaskRunner(address="192.0.2.255:8786") ``` """ diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index 3abaa0d..c5c0b2e 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -2,28 +2,79 @@ Utils to use alongside prefect-dask. """ -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from datetime import timedelta from typing import Any, Dict, Optional, Union from distributed import Client, get_client from prefect.context import FlowRunContext, TaskRunContext +from prefect_dask.exceptions import ImproperClientError + + +def _populate_client_kwargs( + async_client: bool, + timeout: Optional[Union[int, float, str, timedelta]] = None, + **client_kwargs: Dict[str, Any], +) -> Dict[str, Any]: + """ + Helper method to populate keyword arguments for `distributed.Client`. + """ + flow_run_context = FlowRunContext.get() + task_run_context = TaskRunContext.get() + + if flow_run_context: + context = "flow" + task_runner = flow_run_context.task_runner + input_client_kwargs = task_runner.client_kwargs + address = task_runner._connect_to + asynchronous = flow_run_context.flow.isasync + elif task_run_context: + # copies functionality of worker_client(separate_thread=False) + # because this allows us to set asynchronous based on user's task + context = "task" + input_client_kwargs = {} + address = get_client().scheduler.address + asynchronous = task_run_context.task.isasync + else: + # this else clause allows users to debug or test + # without much change to code + context = "" + input_client_kwargs = {} + address = None + asynchronous = async_client + + if not async_client and asynchronous: + raise ImproperClientError( + f"The {context} run is async; use `get_dask_async_client` instead" + ) + elif async_client and not asynchronous: + raise ImproperClientError( + f"The {context} run is not async; use `get_dask_sync_client` instead" + ) + + input_client_kwargs["address"] = address + input_client_kwargs["asynchronous"] = asynchronous + if timeout is not None: + input_client_kwargs["timeout"] = timeout + input_client_kwargs.update(**client_kwargs) + return input_client_kwargs + @contextmanager -def get_dask_client( +def get_dask_sync_client( timeout: Optional[Union[int, float, str, timedelta]] = None, - **client_kwargs: Dict[str, Any] + **client_kwargs: Dict[str, Any], ) -> Client: """ - Yields a temporary client; this is useful for parallelizing operations + Yields a temporary dask sync client; this is useful for parallelizing operations on dask collections, such as a `dask.DataFrame` or `dask.Bag`. Without invoking this, workers do not automatically get a client to connect to the full cluster. Therefore, it will attempt perform work within the worker itself serially, and potentially overwhelming the single worker. - Args: + Yields: timeout: Timeout after which to error out; has no effect in flow run contexts because the client has already started; Defaults to the `distributed.comm.timeouts.connect` @@ -33,21 +84,20 @@ def get_dask_client( from the task runner, if any. Returns: - Within task run contexts, the dask worker client, and within flow run contexts, - the existing client used in `DaskTaskRunner`. + A temporary dask sync client. Examples: - Use `get_dask_client` to distribute work across workers within task run context. + Use `get_dask_sync_client` to distribute work across workers. ```python import dask from prefect import flow, task - from prefect_dask import DaskTaskRunner, get_dask_client + from prefect_dask import DaskTaskRunner, get_dask_sync_client @task def compute_task(): - with get_dask_client(timeout="120s") as client: + with get_dask_sync_client(timeout="120s") as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") - summary_df = df.describe().compute() + summary_df = client.compute(df.describe()).result() return summary_df @flow(task_runner=DaskTaskRunner()) @@ -58,32 +108,62 @@ def dask_flow(): dask_flow() ``` """ - flow_run_context = FlowRunContext.get() - task_run_context = TaskRunContext.get() + client_kwargs = _populate_client_kwargs( + async_client=False, timeout=timeout, **client_kwargs + ) + with Client(**client_kwargs) as client: + yield client - if flow_run_context: - task_runner = flow_run_context.task_runner - input_client_kwargs = task_runner.client_kwargs - address = task_runner._connect_to - asynchronous = flow_run_context.flow.isasync - elif task_run_context: - # copies functionality of worker_client(separate_thread=False) - # because this allows us to set asynchronous based on user's task - input_client_kwargs = {} - address = get_client(timeout=timeout).scheduler.address - asynchronous = task_run_context.task.isasync - else: - # this else clause allows users to debug or test - # without much change to code - input_client_kwargs = {} - address = None - asynchronous = False - input_client_kwargs["address"] = address - input_client_kwargs["asynchronous"] = asynchronous - if timeout is not None: - input_client_kwargs["timeout"] = timeout - input_client_kwargs.update(**client_kwargs) +@asynccontextmanager +async def get_dask_async_client( + timeout: Optional[Union[int, float, str, timedelta]] = None, + **client_kwargs: Dict[str, Any], +) -> Client: + """ + Yields a temporary async client; this is useful for parallelizing operations + on dask collections, such as a `dask.DataFrame` or `dask.Bag`. - with Client(**input_client_kwargs) as client: + Without invoking this, workers do not automatically get a client to connect + to the full cluster. Therefore, it will attempt perform work within the + worker itself serially, and potentially overwhelming the single worker. + + Args: + timeout: Timeout after which to error out; has no effect in + flow run contexts because the client has already started; + Defaults to the `distributed.comm.timeouts.connect` + configuration value. + client_kwargs: Additional keyword arguments to pass to + `distributed.Client`, and overwrites inherited keyword arguments + from the task runner, if any. + + Yields: + A temporary dask async client. + + Examples: + Use `get_dask_async_client` to distribute work across workers. + ```python + import dask + from prefect import flow, task + from prefect_dask import DaskTaskRunner, get_dask_async_client + + @task + async def compute_task(): + async with get_dask_async_client(timeout="120s") as client: + df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") + summary_df = await client.compute(df.describe()) + return summary_df + + @flow(task_runner=DaskTaskRunner()) + async def dask_flow(): + prefect_future = await compute_task.submit() + return await prefect_future.result() + + asyncio.run(dask_flow()) + ``` + """ + client_kwargs = _populate_client_kwargs( + async_client=True, timeout=timeout, **client_kwargs + ) + async with Client(**client_kwargs) as client: yield client diff --git a/tests/test_utils.py b/tests/test_utils.py index 0060181..6eb581a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,50 +3,145 @@ from distributed import Client from prefect import flow, task -from prefect_dask import DaskTaskRunner, get_dask_client +from prefect_dask import DaskTaskRunner, get_dask_async_client, get_dask_sync_client +from prefect_dask.exceptions import ImproperClientError -@pytest.mark.parametrize("timeout", [None, 10]) -def test_get_dask_client_from_task(timeout): - @task - def test_task(): +class TestDaskSyncClient: + def test_from_task(self): + @task + def test_task(): + delayed_num = dask.delayed(42) + with get_dask_sync_client() as client: + assert isinstance(client, Client) + result = client.compute(delayed_num).result() + return result + + @flow(task_runner=DaskTaskRunner) + def test_flow(): + future = test_task.submit() + return future.result() + + assert test_flow() == 42 + + def test_from_async_task_error(self): + @task + async def test_task(): + with get_dask_sync_client(): + pass + + @flow(task_runner=DaskTaskRunner) + def test_flow(): + test_task.submit() + + match = "The task run is async" + with pytest.raises(ImproperClientError, match=match): + test_flow() + + def test_from_flow(self): + @flow(task_runner=DaskTaskRunner) + def test_flow(): + delayed_num = dask.delayed(42) + with get_dask_sync_client() as client: + assert isinstance(client, Client) + result = client.compute(delayed_num).result() + return result + + assert test_flow() == 42 + + async def test_from_async_flow_error(self): + @flow(task_runner=DaskTaskRunner) + async def test_flow(): + with get_dask_sync_client(): + pass + + match = "The flow run is async" + with pytest.raises(ImproperClientError, match=match): + await test_flow() + + def test_outside_run_context(self): delayed_num = dask.delayed(42) - with get_dask_client(timeout=timeout) as client: + with get_dask_sync_client() as client: + assert isinstance(client, Client) + result = client.compute(delayed_num).result() + assert result == 42 + + @pytest.mark.parametrize("timeout", [None, 8]) + def test_include_timeout(self, timeout): + delayed_num = dask.delayed(42) + with get_dask_sync_client(timeout=timeout) as client: assert isinstance(client, Client) if timeout is not None: assert client._timeout == timeout - result = delayed_num.compute() - return result + result = client.compute(delayed_num).result() + assert result == 42 + - @flow(task_runner=DaskTaskRunner) - def test_flow(): - future = test_task.submit() - return future.result() +class TestDaskAsyncClient: + async def test_from_task(self): + @task + async def test_task(): + delayed_num = dask.delayed(42) + async with get_dask_async_client() as client: + assert isinstance(client, Client) + result = await client.compute(delayed_num).result() + return result - assert test_flow() == 42 + @flow(task_runner=DaskTaskRunner) + async def test_flow(): + future = await test_task.submit() + return await future.result() + assert (await test_flow()) == 42 -@pytest.mark.parametrize("timeout", [None, 10]) -def test_get_dask_client_from_flow(timeout): - @flow(task_runner=DaskTaskRunner) - def test_flow(): + async def test_from_async_task_error(self): + @task + async def test_task(): + with get_dask_sync_client(): + pass + + @flow(task_runner=DaskTaskRunner) + async def test_flow(): + await test_task.submit() + + match = "The task run is sync" + with pytest.raises(ImproperClientError, match=match): + await test_flow() + + async def test_from_flow(self): + @flow(task_runner=DaskTaskRunner) + async def test_flow(): + delayed_num = dask.delayed(42) + async with get_dask_async_client() as client: + assert isinstance(client, Client) + result = await client.compute(delayed_num).result() + return result + + assert (await test_flow()) == 42 + + async def test_from_async_flow_error(self): + @flow(task_runner=DaskTaskRunner) + async def test_flow(): + with get_dask_sync_client(): + pass + + match = "The flow run is async" + with pytest.raises(ImproperClientError, match=match): + await test_flow() + + async def test_outside_run_context(self): delayed_num = dask.delayed(42) - with get_dask_client(timeout=timeout) as client: + async with get_dask_async_client() as client: + assert isinstance(client, Client) + result = await client.compute(delayed_num).result() + assert result == 42 + + @pytest.mark.parametrize("timeout", [None, 8]) + async def test_include_timeout(self, timeout): + delayed_num = dask.delayed(42) + async with get_dask_async_client(timeout=timeout) as client: assert isinstance(client, Client) if timeout is not None: assert client._timeout == timeout - result = delayed_num.compute() - return result - - assert test_flow() == 42 - - -@pytest.mark.parametrize("timeout", [None, 10]) -def test_get_dask_client_outside_run_context(timeout): - delayed_num = dask.delayed(42) - with get_dask_client(timeout=timeout) as client: - assert isinstance(client, Client) - if timeout is not None: - assert client._timeout == timeout - result = delayed_num.compute() - assert result == 42 + result = await client.compute(delayed_num).result() + assert result == 42 From 0fb3eb58c2ccefe18e9dfcf997b16aa93906d724 Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Thu, 29 Sep 2022 14:53:50 -0700 Subject: [PATCH 11/21] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ac618f..1e90c8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- `get_dask_client` allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33) +- `get_dask_sync_client` and `get_dask_async_client` allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33) ### Changed From 78b4682ef59e308648fc32fe75574899cfa4dfdf Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Thu, 29 Sep 2022 14:55:32 -0700 Subject: [PATCH 12/21] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index e526719..88ae8c4 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,8 @@ The context managers can be used the same way in both `flow` run contexts and `t There is also an equivalent `async` version, namely `get_dask_async_client`. ```python +import asyncio + import dask from prefect import flow, task from prefect_dask import DaskTaskRunner, get_dask_async_client From a7f09bcc313a7c40ba259f99445f731d25e83b57 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 29 Sep 2022 15:06:30 -0700 Subject: [PATCH 13/21] Fix tests --- tests/test_utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6eb581a..62e6b33 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -94,20 +94,6 @@ async def test_flow(): assert (await test_flow()) == 42 - async def test_from_async_task_error(self): - @task - async def test_task(): - with get_dask_sync_client(): - pass - - @flow(task_runner=DaskTaskRunner) - async def test_flow(): - await test_task.submit() - - match = "The task run is sync" - with pytest.raises(ImproperClientError, match=match): - await test_flow() - async def test_from_flow(self): @flow(task_runner=DaskTaskRunner) async def test_flow(): From 24d3ad83b7d6809199babc25f17c34ce9b71b706 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 29 Sep 2022 15:12:52 -0700 Subject: [PATCH 14/21] Fix tests and logic --- prefect_dask/utils.py | 4 ---- tests/test_utils.py | 24 ++++++++++++++++++------ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index c5c0b2e..53b1060 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -48,10 +48,6 @@ def _populate_client_kwargs( raise ImproperClientError( f"The {context} run is async; use `get_dask_async_client` instead" ) - elif async_client and not asynchronous: - raise ImproperClientError( - f"The {context} run is not async; use `get_dask_sync_client` instead" - ) input_client_kwargs["address"] = address input_client_kwargs["asynchronous"] = asynchronous diff --git a/tests/test_utils.py b/tests/test_utils.py index 62e6b33..91dc227 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -94,6 +94,19 @@ async def test_flow(): assert (await test_flow()) == 42 + def test_from_sync_task_error(self): + @task + def test_task(): + with get_dask_async_client(): + pass + + @flow(task_runner=DaskTaskRunner) + def test_flow(): + test_task.submit() + + with pytest.raises(AttributeError, match="__enter__"): + test_flow() + async def test_from_flow(self): @flow(task_runner=DaskTaskRunner) async def test_flow(): @@ -105,15 +118,14 @@ async def test_flow(): assert (await test_flow()) == 42 - async def test_from_async_flow_error(self): + def test_from_sync_flow_error(self): @flow(task_runner=DaskTaskRunner) - async def test_flow(): - with get_dask_sync_client(): + def test_flow(): + with get_dask_async_client(): pass - match = "The flow run is async" - with pytest.raises(ImproperClientError, match=match): - await test_flow() + with pytest.raises(AttributeError, match="__enter__"): + test_flow() async def test_outside_run_context(self): delayed_num = dask.delayed(42) From 9578caf3a598dccc678fb4301a64f684922d1aec Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 29 Sep 2022 15:13:14 -0700 Subject: [PATCH 15/21] Update exception docstring --- prefect_dask/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prefect_dask/exceptions.py b/prefect_dask/exceptions.py index 42c56f3..100b735 100644 --- a/prefect_dask/exceptions.py +++ b/prefect_dask/exceptions.py @@ -5,5 +5,5 @@ class ImproperClientError(Exception): """ - Raised when the flow/task is async but the client is sync, or vice versa. + Raised when the flow/task is async but the client is sync. """ From 62b09263f89a218c4afa555482e38a951a9153e8 Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Thu, 29 Sep 2022 15:28:06 -0700 Subject: [PATCH 16/21] Update utils.py --- prefect_dask/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index 53b1060..323b982 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -70,7 +70,7 @@ def get_dask_sync_client( to the full cluster. Therefore, it will attempt perform work within the worker itself serially, and potentially overwhelming the single worker. - Yields: + Args: timeout: Timeout after which to error out; has no effect in flow run contexts because the client has already started; Defaults to the `distributed.comm.timeouts.connect` @@ -79,7 +79,7 @@ def get_dask_sync_client( `distributed.Client`, and overwrites inherited keyword arguments from the task runner, if any. - Returns: + Yields: A temporary dask sync client. Examples: From 53b55602e44c9779505236890ca82fc1ff1d8b5f Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Fri, 30 Sep 2022 10:14:26 -0700 Subject: [PATCH 17/21] Update names --- CHANGELOG.md | 2 +- README.md | 10 +++++----- prefect_dask/__init__.py | 2 +- prefect_dask/utils.py | 22 ++++++++++++---------- tests/test_utils.py | 14 +++++++------- 5 files changed, 26 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e90c8e..a500ebd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- `get_dask_sync_client` and `get_dask_async_client` allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33) +- `get_dask_client` and `get_dask_async_client` allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33) ### Changed diff --git a/README.md b/README.md index 88ae8c4..6356024 100644 --- a/README.md +++ b/README.md @@ -109,16 +109,16 @@ DaskTaskRunner( ### Distributing Dask collections across workers -If you use a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use one of the context managers `get_dask_sync_client` or `get_dask_async_client`: +If you use a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use one of the context managers `get_dask_client` or `get_dask_async_client`: ```python import dask from prefect import flow, task -from prefect_dask import DaskTaskRunner, get_dask_sync_client +from prefect_dask import DaskTaskRunner, get_dask_client @task def compute_task(): - with get_dask_sync_client() as client: + with get_dask_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = df.describe().compute() return summary_df @@ -138,14 +138,14 @@ The context managers can be used the same way in both `flow` run contexts and `t 1. setting `sync=True` ```python - with get_dask_sync_client() as client: + with get_dask_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = client.compute(df.describe(), sync=True) ``` 2. calling `result()` ```python - with get_dask_sync_client() as client: + with get_dask_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = client.compute(df.describe()).result() ``` diff --git a/prefect_dask/__init__.py b/prefect_dask/__init__.py index fa2ba6f..225087e 100644 --- a/prefect_dask/__init__.py +++ b/prefect_dask/__init__.py @@ -1,5 +1,5 @@ from . import _version from .task_runners import DaskTaskRunner # noqa -from .utils import get_dask_sync_client, get_dask_async_client # noqa +from .utils import get_dask_client, get_dask_async_client # noqa __version__ = _version.get_versions()["version"] diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index 53b1060..50315cd 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -58,13 +58,14 @@ def _populate_client_kwargs( @contextmanager -def get_dask_sync_client( +def get_dask_client( timeout: Optional[Union[int, float, str, timedelta]] = None, **client_kwargs: Dict[str, Any], ) -> Client: """ - Yields a temporary dask sync client; this is useful for parallelizing operations - on dask collections, such as a `dask.DataFrame` or `dask.Bag`. + Yields a temporary synchronous dask client; this is useful + for parallelizing operations on dask collections, + such as a `dask.DataFrame` or `dask.Bag`. Without invoking this, workers do not automatically get a client to connect to the full cluster. Therefore, it will attempt perform work within the @@ -80,18 +81,18 @@ def get_dask_sync_client( from the task runner, if any. Returns: - A temporary dask sync client. + A temporary synchronous dask client. Examples: - Use `get_dask_sync_client` to distribute work across workers. + Use `get_dask_client` to distribute work across workers. ```python import dask from prefect import flow, task - from prefect_dask import DaskTaskRunner, get_dask_sync_client + from prefect_dask import DaskTaskRunner, get_dask_client @task def compute_task(): - with get_dask_sync_client(timeout="120s") as client: + with get_dask_client(timeout="120s") as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = client.compute(df.describe()).result() return summary_df @@ -117,8 +118,9 @@ async def get_dask_async_client( **client_kwargs: Dict[str, Any], ) -> Client: """ - Yields a temporary async client; this is useful for parallelizing operations - on dask collections, such as a `dask.DataFrame` or `dask.Bag`. + Yields a temporary asynchronous dask client; this is useful + for parallelizing operations on dask collections, + such as a `dask.DataFrame` or `dask.Bag`. Without invoking this, workers do not automatically get a client to connect to the full cluster. Therefore, it will attempt perform work within the @@ -134,7 +136,7 @@ async def get_dask_async_client( from the task runner, if any. Yields: - A temporary dask async client. + A temporary asynchronous dask client. Examples: Use `get_dask_async_client` to distribute work across workers. diff --git a/tests/test_utils.py b/tests/test_utils.py index 91dc227..704ad68 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ from distributed import Client from prefect import flow, task -from prefect_dask import DaskTaskRunner, get_dask_async_client, get_dask_sync_client +from prefect_dask import DaskTaskRunner, get_dask_async_client, get_dask_client from prefect_dask.exceptions import ImproperClientError @@ -12,7 +12,7 @@ def test_from_task(self): @task def test_task(): delayed_num = dask.delayed(42) - with get_dask_sync_client() as client: + with get_dask_client() as client: assert isinstance(client, Client) result = client.compute(delayed_num).result() return result @@ -27,7 +27,7 @@ def test_flow(): def test_from_async_task_error(self): @task async def test_task(): - with get_dask_sync_client(): + with get_dask_client(): pass @flow(task_runner=DaskTaskRunner) @@ -42,7 +42,7 @@ def test_from_flow(self): @flow(task_runner=DaskTaskRunner) def test_flow(): delayed_num = dask.delayed(42) - with get_dask_sync_client() as client: + with get_dask_client() as client: assert isinstance(client, Client) result = client.compute(delayed_num).result() return result @@ -52,7 +52,7 @@ def test_flow(): async def test_from_async_flow_error(self): @flow(task_runner=DaskTaskRunner) async def test_flow(): - with get_dask_sync_client(): + with get_dask_client(): pass match = "The flow run is async" @@ -61,7 +61,7 @@ async def test_flow(): def test_outside_run_context(self): delayed_num = dask.delayed(42) - with get_dask_sync_client() as client: + with get_dask_client() as client: assert isinstance(client, Client) result = client.compute(delayed_num).result() assert result == 42 @@ -69,7 +69,7 @@ def test_outside_run_context(self): @pytest.mark.parametrize("timeout", [None, 8]) def test_include_timeout(self, timeout): delayed_num = dask.delayed(42) - with get_dask_sync_client(timeout=timeout) as client: + with get_dask_client(timeout=timeout) as client: assert isinstance(client, Client) if timeout is not None: assert client._timeout == timeout From 4221fdae6e3cd46eeb6463d0af63fd36a3b4dff5 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Fri, 30 Sep 2022 17:06:25 -0700 Subject: [PATCH 18/21] Review comments --- CHANGELOG.md | 2 +- README.md | 8 ++++---- prefect_dask/__init__.py | 2 +- prefect_dask/task_runners.py | 5 ++--- prefect_dask/utils.py | 30 +++++++++++++++--------------- tests/test_utils.py | 14 +++++++------- 6 files changed, 30 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a500ebd..4fc2454 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- `get_dask_client` and `get_dask_async_client` allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33) +- `get_dask_client` and `get_async_dask_client` allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33) ### Changed diff --git a/README.md b/README.md index 6356024..7c75e2c 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ DaskTaskRunner( ### Distributing Dask collections across workers -If you use a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use one of the context managers `get_dask_client` or `get_dask_async_client`: +If you use a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, to distribute the work across workers and achieve parallel computations, use one of the context managers `get_dask_client` or `get_async_dask_client`: ```python import dask @@ -151,18 +151,18 @@ The context managers can be used the same way in both `flow` run contexts and `t ``` For more information, visit the docs on [Waiting on Futures](https://docs.dask.org/en/stable/futures.html#waiting-on-futures). -There is also an equivalent `async` version, namely `get_dask_async_client`. +There is also an equivalent `async` version, namely `get_async_dask_client`. ```python import asyncio import dask from prefect import flow, task -from prefect_dask import DaskTaskRunner, get_dask_async_client +from prefect_dask import DaskTaskRunner, get_async_dask_client @task async def compute_task(): - async with get_dask_async_client() as client: + async with get_async_dask_client() as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = await client.compute(df.describe()) return summary_df diff --git a/prefect_dask/__init__.py b/prefect_dask/__init__.py index 225087e..e01e06a 100644 --- a/prefect_dask/__init__.py +++ b/prefect_dask/__init__.py @@ -1,5 +1,5 @@ from . import _version from .task_runners import DaskTaskRunner # noqa -from .utils import get_dask_client, get_dask_async_client # noqa +from .utils import get_dask_client, get_async_dask_client # noqa __version__ = _version.get_versions()["version"] diff --git a/prefect_dask/task_runners.py b/prefect_dask/task_runners.py index ec58052..fbe0b1b 100644 --- a/prefect_dask/task_runners.py +++ b/prefect_dask/task_runners.py @@ -291,10 +291,9 @@ async def _start(self, exit_stack: AsyncExitStack): f"Creating a new Dask cluster with " f"`{to_qualified_name(self.cluster_class)}`" ) - self._cluster = await exit_stack.enter_async_context( + self._connect_to = self._cluster = await exit_stack.enter_async_context( self.cluster_class(asynchronous=True, **self.cluster_kwargs) ) - self._connect_to = self._cluster.scheduler_address if self.adapt_kwargs: self._cluster.adapt(**self.adapt_kwargs) @@ -316,7 +315,7 @@ def __getstate__(self): Must be deserialized on a dask worker. """ data = self.__dict__.copy() - data.update({k: None for k in {"_client", "_cluster"}}) + data.update({k: None for k in {"_client", "_cluster", "_connect_to"}}) return data def __setstate__(self, data: dict): diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index afcc5a5..68566ad 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -12,7 +12,7 @@ from prefect_dask.exceptions import ImproperClientError -def _populate_client_kwargs( +def _generate_client_kwargs( async_client: bool, timeout: Optional[Union[int, float, str, timedelta]] = None, **client_kwargs: Dict[str, Any], @@ -23,19 +23,19 @@ def _populate_client_kwargs( flow_run_context = FlowRunContext.get() task_run_context = TaskRunContext.get() - if flow_run_context: - context = "flow" - task_runner = flow_run_context.task_runner - input_client_kwargs = task_runner.client_kwargs - address = task_runner._connect_to - asynchronous = flow_run_context.flow.isasync - elif task_run_context: + if task_run_context: # copies functionality of worker_client(separate_thread=False) # because this allows us to set asynchronous based on user's task context = "task" input_client_kwargs = {} address = get_client().scheduler.address asynchronous = task_run_context.task.isasync + elif flow_run_context: + context = "flow" + task_runner = flow_run_context.task_runner + input_client_kwargs = task_runner.client_kwargs + address = task_runner._connect_to + asynchronous = flow_run_context.flow.isasync else: # this else clause allows users to debug or test # without much change to code @@ -46,7 +46,7 @@ def _populate_client_kwargs( if not async_client and asynchronous: raise ImproperClientError( - f"The {context} run is async; use `get_dask_async_client` instead" + f"The {context} run is async; use `get_async_dask_client` instead" ) input_client_kwargs["address"] = address @@ -105,7 +105,7 @@ def dask_flow(): dask_flow() ``` """ - client_kwargs = _populate_client_kwargs( + client_kwargs = _generate_client_kwargs( async_client=False, timeout=timeout, **client_kwargs ) with Client(**client_kwargs) as client: @@ -113,7 +113,7 @@ def dask_flow(): @asynccontextmanager -async def get_dask_async_client( +async def get_async_dask_client( timeout: Optional[Union[int, float, str, timedelta]] = None, **client_kwargs: Dict[str, Any], ) -> Client: @@ -139,15 +139,15 @@ async def get_dask_async_client( A temporary asynchronous dask client. Examples: - Use `get_dask_async_client` to distribute work across workers. + Use `get_async_dask_client` to distribute work across workers. ```python import dask from prefect import flow, task - from prefect_dask import DaskTaskRunner, get_dask_async_client + from prefect_dask import DaskTaskRunner, get_async_dask_client @task async def compute_task(): - async with get_dask_async_client(timeout="120s") as client: + async with get_async_dask_client(timeout="120s") as client: df = dask.datasets.timeseries("2000", "2001", partition_freq="4w") summary_df = await client.compute(df.describe()) return summary_df @@ -160,7 +160,7 @@ async def dask_flow(): asyncio.run(dask_flow()) ``` """ - client_kwargs = _populate_client_kwargs( + client_kwargs = _generate_client_kwargs( async_client=True, timeout=timeout, **client_kwargs ) async with Client(**client_kwargs) as client: diff --git a/tests/test_utils.py b/tests/test_utils.py index 704ad68..82e41e9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ from distributed import Client from prefect import flow, task -from prefect_dask import DaskTaskRunner, get_dask_async_client, get_dask_client +from prefect_dask import DaskTaskRunner, get_async_dask_client, get_dask_client from prefect_dask.exceptions import ImproperClientError @@ -82,7 +82,7 @@ async def test_from_task(self): @task async def test_task(): delayed_num = dask.delayed(42) - async with get_dask_async_client() as client: + async with get_async_dask_client() as client: assert isinstance(client, Client) result = await client.compute(delayed_num).result() return result @@ -97,7 +97,7 @@ async def test_flow(): def test_from_sync_task_error(self): @task def test_task(): - with get_dask_async_client(): + with get_async_dask_client(): pass @flow(task_runner=DaskTaskRunner) @@ -111,7 +111,7 @@ async def test_from_flow(self): @flow(task_runner=DaskTaskRunner) async def test_flow(): delayed_num = dask.delayed(42) - async with get_dask_async_client() as client: + async with get_async_dask_client() as client: assert isinstance(client, Client) result = await client.compute(delayed_num).result() return result @@ -121,7 +121,7 @@ async def test_flow(): def test_from_sync_flow_error(self): @flow(task_runner=DaskTaskRunner) def test_flow(): - with get_dask_async_client(): + with get_async_dask_client(): pass with pytest.raises(AttributeError, match="__enter__"): @@ -129,7 +129,7 @@ def test_flow(): async def test_outside_run_context(self): delayed_num = dask.delayed(42) - async with get_dask_async_client() as client: + async with get_async_dask_client() as client: assert isinstance(client, Client) result = await client.compute(delayed_num).result() assert result == 42 @@ -137,7 +137,7 @@ async def test_outside_run_context(self): @pytest.mark.parametrize("timeout", [None, 8]) async def test_include_timeout(self, timeout): delayed_num = dask.delayed(42) - async with get_dask_async_client(timeout=timeout) as client: + async with get_async_dask_client(timeout=timeout) as client: assert isinstance(client, Client) if timeout is not None: assert client._timeout == timeout From 2f6c17579bedbfde1419046c0201dfba7caad4c1 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Mon, 3 Oct 2022 08:44:23 -0700 Subject: [PATCH 19/21] Remove improper client error --- prefect_dask/exceptions.py | 9 --------- prefect_dask/utils.py | 12 ++---------- tests/test_utils.py | 25 ------------------------- 3 files changed, 2 insertions(+), 44 deletions(-) delete mode 100644 prefect_dask/exceptions.py diff --git a/prefect_dask/exceptions.py b/prefect_dask/exceptions.py deleted file mode 100644 index 100b735..0000000 --- a/prefect_dask/exceptions.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -Exceptions specific to prefect-dask. -""" - - -class ImproperClientError(Exception): - """ - Raised when the flow/task is async but the client is sync. - """ diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index 68566ad..f3fbb06 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -9,8 +9,6 @@ from distributed import Client, get_client from prefect.context import FlowRunContext, TaskRunContext -from prefect_dask.exceptions import ImproperClientError - def _generate_client_kwargs( async_client: bool, @@ -26,12 +24,10 @@ def _generate_client_kwargs( if task_run_context: # copies functionality of worker_client(separate_thread=False) # because this allows us to set asynchronous based on user's task - context = "task" input_client_kwargs = {} address = get_client().scheduler.address asynchronous = task_run_context.task.isasync elif flow_run_context: - context = "flow" task_runner = flow_run_context.task_runner input_client_kwargs = task_runner.client_kwargs address = task_runner._connect_to @@ -39,16 +35,10 @@ def _generate_client_kwargs( else: # this else clause allows users to debug or test # without much change to code - context = "" input_client_kwargs = {} address = None asynchronous = async_client - if not async_client and asynchronous: - raise ImproperClientError( - f"The {context} run is async; use `get_async_dask_client` instead" - ) - input_client_kwargs["address"] = address input_client_kwargs["asynchronous"] = asynchronous if timeout is not None: @@ -71,6 +61,8 @@ def get_dask_client( to the full cluster. Therefore, it will attempt perform work within the worker itself serially, and potentially overwhelming the single worker. + For async, there is `get_async_dask_client`. + Args: timeout: Timeout after which to error out; has no effect in flow run contexts because the client has already started; diff --git a/tests/test_utils.py b/tests/test_utils.py index 82e41e9..81c4091 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,6 @@ from prefect import flow, task from prefect_dask import DaskTaskRunner, get_async_dask_client, get_dask_client -from prefect_dask.exceptions import ImproperClientError class TestDaskSyncClient: @@ -24,20 +23,6 @@ def test_flow(): assert test_flow() == 42 - def test_from_async_task_error(self): - @task - async def test_task(): - with get_dask_client(): - pass - - @flow(task_runner=DaskTaskRunner) - def test_flow(): - test_task.submit() - - match = "The task run is async" - with pytest.raises(ImproperClientError, match=match): - test_flow() - def test_from_flow(self): @flow(task_runner=DaskTaskRunner) def test_flow(): @@ -49,16 +34,6 @@ def test_flow(): assert test_flow() == 42 - async def test_from_async_flow_error(self): - @flow(task_runner=DaskTaskRunner) - async def test_flow(): - with get_dask_client(): - pass - - match = "The flow run is async" - with pytest.raises(ImproperClientError, match=match): - await test_flow() - def test_outside_run_context(self): delayed_num = dask.delayed(42) with get_dask_client() as client: From b63e451b861036a210968c5860956c0fee6c60d9 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Mon, 3 Oct 2022 08:55:33 -0700 Subject: [PATCH 20/21] Update warning --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7c75e2c..a899fcf 100644 --- a/README.md +++ b/README.md @@ -177,9 +177,9 @@ asyncio.run(dask_flow()) !!! warning "Resolving futures in async client" With the async client, you do not need to set `sync=True` or call `result()`. - However you must `await client.compute` before exiting out of the context manager. + However you must `await client.compute(dask_collection)` before exiting out of the context manager. - Running `await dask_collection.compute()` will result in an error: `TypeError: 'coroutine' object is not iterable`. + To invoke `compute` from the Dask collection, set `sync=False` and call `result()` before exiting out of the context manager: `await dask_collection.compute(sync=False)`. ### Using a temporary cluster From 661f19f45cd1c27ce6e9dcc3786baf62485b7f3a Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Mon, 3 Oct 2022 09:51:11 -0700 Subject: [PATCH 21/21] Apply suggestions from code review Co-authored-by: Michael Adkins --- README.md | 2 +- prefect_dask/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a899fcf..84bbe73 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ The context managers can be used the same way in both `flow` run contexts and `t ``` For more information, visit the docs on [Waiting on Futures](https://docs.dask.org/en/stable/futures.html#waiting-on-futures). -There is also an equivalent `async` version, namely `get_async_dask_client`. +There is also an equivalent context manager for asynchronous tasks and flows: `get_async_dask_client`. ```python import asyncio diff --git a/prefect_dask/utils.py b/prefect_dask/utils.py index f3fbb06..39aaad2 100644 --- a/prefect_dask/utils.py +++ b/prefect_dask/utils.py @@ -61,7 +61,7 @@ def get_dask_client( to the full cluster. Therefore, it will attempt perform work within the worker itself serially, and potentially overwhelming the single worker. - For async, there is `get_async_dask_client`. + When in an async context, we recommend using `get_async_dask_client` instead. Args: timeout: Timeout after which to error out; has no effect in