This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add get_dask_client #33

Merged
merged 24 commits on Oct 3, 2022
Changes from 18 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- `get_dask_sync_client` and `get_dask_async_client` allowing for distributed computation within a task - [#33](https://github.com/PrefectHQ/prefect-dask/pull/33)

### Changed

### Deprecated
74 changes: 74 additions & 0 deletions README.md
@@ -107,6 +107,80 @@ DaskTaskRunner(
)
```

### Distributing Dask collections across workers

To distribute work across workers and parallelize computations on a Dask collection, such as a `dask.DataFrame` or `dask.Bag`, use one of the context managers `get_dask_sync_client` or `get_dask_async_client`:

```python
import dask
from prefect import flow, task
from prefect_dask import DaskTaskRunner, get_dask_sync_client

@task
def compute_task():
with get_dask_sync_client() as client:
df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
summary_df = df.describe().compute()
return summary_df

@flow(task_runner=DaskTaskRunner())
def dask_flow():
prefect_future = compute_task.submit()
return prefect_future.result()

dask_flow()
```

The context managers can be used the same way in both `flow` run contexts and `task` run contexts.
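For instance, a minimal sketch (illustrative, not part of this diff) of calling `get_dask_sync_client` within the flow itself rather than a task:

```python
import dask
from prefect import flow
from prefect_dask import DaskTaskRunner, get_dask_sync_client

@flow(task_runner=DaskTaskRunner())
def dask_flow():
    # In a flow run context, the client connects to the address
    # the task runner is already using.
    with get_dask_sync_client() as client:
        df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
        return df.describe().compute()

dask_flow()
```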

!!! warning "Resolving futures in sync client"
    Note that, by default, `dask_collection.compute()` returns concrete values while `client.compute(dask_collection)` returns Dask Futures. Therefore, if you call `client.compute`, you must resolve all futures before exiting the context manager by either:

1. setting `sync=True`
```python
with get_dask_sync_client() as client:
df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
summary_df = client.compute(df.describe(), sync=True)
```

2. calling `result()`
```python
with get_dask_sync_client() as client:
df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
summary_df = client.compute(df.describe()).result()
```
For more information, visit the docs on [Waiting on Futures](https://docs.dask.org/en/stable/futures.html#waiting-on-futures).

There is also an equivalent `async` version, namely `get_dask_async_client`.

```python
import asyncio

import dask
from prefect import flow, task
from prefect_dask import DaskTaskRunner, get_dask_async_client

@task
async def compute_task():
async with get_dask_async_client() as client:
df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
summary_df = await client.compute(df.describe())
return summary_df

@flow(task_runner=DaskTaskRunner())
async def dask_flow():
prefect_future = await compute_task.submit()
return await prefect_future.result()

asyncio.run(dask_flow())
```
!!! warning "Resolving futures in async client"
With the async client, you do not need to set `sync=True` or call `result()`.

    However, you must `await client.compute` before exiting the context manager.

Running `await dask_collection.compute()` will result in an error: `TypeError: 'coroutine' object is not iterable`.
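
    A minimal sketch contrasting the two patterns (same example data as above, for illustration):

    ```python
    async with get_dask_async_client() as client:
        df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
        # Correct: submit through the client and await the returned future
        summary_df = await client.compute(df.describe())
        # Incorrect: `await df.describe().compute()` raises
        # TypeError: 'coroutine' object is not iterable
    ```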
Contributor:
This seems contrary, can you clarify? Is the second bit a Dask bug?

Contributor Author:
Will have to try with dask alone.

Contributor Author:
```python
import dask
from dask.distributed import Client

async with Client(asynchronous=True) as client:
    df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
    print(type(df))
    print(type(df.describe()))
    print(type(df.describe().compute()))  # errors on this line here
    summary_df = df.describe().compute()
```

Contributor Author:
I guess this works

```python
import dask
from dask.distributed import Client, wait

async with Client(asynchronous=True) as client:
    df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
    summary_df = await df.describe().compute(sync=False)[0].result()
```


### Using a temporary cluster

The `DaskTaskRunner` is capable of creating a temporary cluster using any of [Dask's cluster-manager options](https://docs.dask.org/en/latest/setup.html). This can be useful when you want each flow run to have its own Dask cluster, allowing for per-flow adaptive scaling.
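
For instance, a minimal sketch (assumed values, for illustration) of a per-flow-run temporary cluster with adaptive scaling:

```python
from prefect import flow
from prefect_dask import DaskTaskRunner

# Each run of this flow creates its own temporary local cluster
# (distributed.LocalCluster is the default cluster_class) and
# scales between 1 and 4 workers based on load.
@flow(
    task_runner=DaskTaskRunner(
        cluster_kwargs={"n_workers": 2},
        adapt_kwargs={"minimum": 1, "maximum": 4},
    )
)
def my_flow():
    ...
```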
1 change: 1 addition & 0 deletions docs/utils.md
@@ -0,0 +1 @@
::: prefect_dask.utils
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -47,3 +47,5 @@ plugins:
nav:
- Home: index.md
- Task Runners: task_runners.md
- Utils: utils.md

1 change: 1 addition & 0 deletions prefect_dask/__init__.py
@@ -1,4 +1,5 @@
from . import _version
from .task_runners import DaskTaskRunner # noqa
from .utils import get_dask_sync_client, get_dask_async_client # noqa

__version__ = _version.get_versions()["version"]
9 changes: 9 additions & 0 deletions prefect_dask/exceptions.py
@@ -0,0 +1,9 @@
"""
Exceptions specific to prefect-dask.
"""


class ImproperClientError(Exception):
"""
Raised when the flow/task is async but the client is sync.
"""
47 changes: 31 additions & 16 deletions prefect_dask/task_runners.py
@@ -93,12 +93,15 @@ class DaskTaskRunner(BaseTaskRunner):
different cluster class (e.g.
[`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
specify `cluster_class`/`cluster_kwargs`.

Alternatively, if you already have a dask cluster running, you can provide
the address of the scheduler via the `address` kwarg.

!!! warning "Multiprocessing safety"
Note that, because the `DaskTaskRunner` uses multiprocessing, calls to flows
in scripts must be guarded with `if __name__ == "__main__":` or warnings will
be displayed.
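
        For example, a guard sketch (assumed flow name, for illustration):
        ```python
        if __name__ == "__main__":
            my_flow()
        ```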

Args:
address (string, optional): Address of a currently running dask
scheduler; if one is not provided, a temporary cluster will be
@@ -113,24 +116,33 @@ class name (e.g. `"distributed.LocalCluster"`), or the class itself.
is only enabled if `adapt_kwargs` are provided.
client_kwargs (dict, optional): Additional kwargs to use when creating a
[`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).

Examples:
Using a temporary local dask cluster:
>>> from prefect import flow
>>> from prefect_dask.task_runners import DaskTaskRunner
>>> @flow(task_runner=DaskTaskRunner)
>>> def my_flow():
>>> ...
```python
from prefect import flow
from prefect_dask.task_runners import DaskTaskRunner
@flow(task_runner=DaskTaskRunner)
def my_flow():
...
```

Using a temporary cluster running elsewhere. Any Dask cluster class should
work, here we use [dask-cloudprovider](https://cloudprovider.dask.org):
>>> DaskTaskRunner(
>>> cluster_class="dask_cloudprovider.FargateCluster",
>>> cluster_kwargs={
>>> "image": "prefecthq/prefect:latest",
>>> "n_workers": 5,
>>> },
>>> )
```python
DaskTaskRunner(
cluster_class="dask_cloudprovider.FargateCluster",
cluster_kwargs={
"image": "prefecthq/prefect:latest",
"n_workers": 5,
},
)
```

Connecting to an existing dask cluster:
>>> DaskTaskRunner(address="192.0.2.255:8786")
```python
DaskTaskRunner(address="192.0.2.255:8786")
```
"""

def __init__(
@@ -271,22 +283,25 @@ async def _start(self, exit_stack: AsyncExitStack):
self.logger.info(
f"Connecting to an existing Dask cluster at {self.address}"
)
connect_to = self.address
self._connect_to = self.address
else:
self.cluster_class = self.cluster_class or distributed.LocalCluster

self.logger.info(
f"Creating a new Dask cluster with "
f"`{to_qualified_name(self.cluster_class)}`"
)
connect_to = self._cluster = await exit_stack.enter_async_context(
self._cluster = await exit_stack.enter_async_context(
self.cluster_class(asynchronous=True, **self.cluster_kwargs)
)
self._connect_to = self._cluster.scheduler_address
if self.adapt_kwargs:
self._cluster.adapt(**self.adapt_kwargs)

self._client = await exit_stack.enter_async_context(
distributed.Client(connect_to, asynchronous=True, **self.client_kwargs)
distributed.Client(
self._connect_to, asynchronous=True, **self.client_kwargs
)
)

if self._client.dashboard_link:
165 changes: 165 additions & 0 deletions prefect_dask/utils.py
@@ -0,0 +1,165 @@
"""
Utils to use alongside prefect-dask.
"""

from contextlib import asynccontextmanager, contextmanager
from datetime import timedelta
from typing import Any, Dict, Optional, Union

from distributed import Client, get_client
from prefect.context import FlowRunContext, TaskRunContext

from prefect_dask.exceptions import ImproperClientError


def _populate_client_kwargs(
async_client: bool,
timeout: Optional[Union[int, float, str, timedelta]] = None,
**client_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""
Helper method to populate keyword arguments for `distributed.Client`.
"""
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

if flow_run_context:
context = "flow"
task_runner = flow_run_context.task_runner
input_client_kwargs = task_runner.client_kwargs
address = task_runner._connect_to
asynchronous = flow_run_context.flow.isasync
elif task_run_context:
# copies functionality of worker_client(separate_thread=False)
# because this allows us to set asynchronous based on user's task
context = "task"
input_client_kwargs = {}
address = get_client().scheduler.address
asynchronous = task_run_context.task.isasync
else:
# this else clause allows users to debug or test
# without much change to code
context = ""
input_client_kwargs = {}
address = None
asynchronous = async_client

if not async_client and asynchronous:
raise ImproperClientError(
f"The {context} run is async; use `get_dask_async_client` instead"
)

input_client_kwargs["address"] = address
input_client_kwargs["asynchronous"] = asynchronous
if timeout is not None:
input_client_kwargs["timeout"] = timeout
input_client_kwargs.update(**client_kwargs)
return input_client_kwargs


@contextmanager
def get_dask_sync_client(
timeout: Optional[Union[int, float, str, timedelta]] = None,
**client_kwargs: Dict[str, Any],
) -> Client:
"""
Yields a temporary dask sync client; this is useful for parallelizing operations
on dask collections, such as a `dask.DataFrame` or `dask.Bag`.

Without invoking this, workers do not automatically get a client to connect
to the full cluster. Instead, the work is attempted serially within the
worker itself, potentially overwhelming that single worker.

Args:
timeout: Timeout after which to error out; has no effect in
flow run contexts because the client has already started;
Defaults to the `distributed.comm.timeouts.connect`
configuration value.
client_kwargs: Additional keyword arguments to pass to
`distributed.Client`; these override any keyword arguments
inherited from the task runner.

Yields:
A temporary dask sync client.

Examples:
Use `get_dask_sync_client` to distribute work across workers.
```python
import dask
from prefect import flow, task
from prefect_dask import DaskTaskRunner, get_dask_sync_client

@task
def compute_task():
with get_dask_sync_client(timeout="120s") as client:
df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
summary_df = client.compute(df.describe()).result()
return summary_df

@flow(task_runner=DaskTaskRunner())
def dask_flow():
prefect_future = compute_task.submit()
return prefect_future.result()

dask_flow()
```
"""
client_kwargs = _populate_client_kwargs(
async_client=False, timeout=timeout, **client_kwargs
)
with Client(**client_kwargs) as client:
yield client


@asynccontextmanager
async def get_dask_async_client(
timeout: Optional[Union[int, float, str, timedelta]] = None,
**client_kwargs: Dict[str, Any],
) -> Client:
"""
Yields a temporary async client; this is useful for parallelizing operations
on dask collections, such as a `dask.DataFrame` or `dask.Bag`.

Without invoking this, workers do not automatically get a client to connect
to the full cluster. Instead, the work is attempted serially within the
worker itself, potentially overwhelming that single worker.

Args:
timeout: Timeout after which to error out; has no effect in
flow run contexts because the client has already started;
Defaults to the `distributed.comm.timeouts.connect`
configuration value.
client_kwargs: Additional keyword arguments to pass to
`distributed.Client`; these override any keyword arguments
inherited from the task runner.

Yields:
A temporary dask async client.

Examples:
Use `get_dask_async_client` to distribute work across workers.
```python
import asyncio

import dask
from prefect import flow, task
from prefect_dask import DaskTaskRunner, get_dask_async_client

@task
async def compute_task():
async with get_dask_async_client(timeout="120s") as client:
df = dask.datasets.timeseries("2000", "2001", partition_freq="4w")
summary_df = await client.compute(df.describe())
return summary_df

@flow(task_runner=DaskTaskRunner())
async def dask_flow():
prefect_future = await compute_task.submit()
return await prefect_future.result()

asyncio.run(dask_flow())
```
"""
client_kwargs = _populate_client_kwargs(
async_client=True, timeout=timeout, **client_kwargs
)
async with Client(**client_kwargs) as client:
yield client