Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dask support #21

Merged
merged 18 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
poetry-version: 1.3.2
- name: Install package
run: |
poetry install --no-interaction --without=notebook
poetry install --all-extras --without=notebook
- name: Pytest
run: |
poetry run coverage run -m pytest
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

_tmp/
48 changes: 48 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,54 @@ print(n3.results)
# >>> 7.5
```

## Dask Support
ZnFlow comes with support for [Dask](https://www.dask.org/) to run your graph:
- in parallel.
- through e.g. SLURM (see https://jobqueue.dask.org/en/latest/api.html).
- with a nice GUI to track progress.

All you need to do is install ZnFlow with Dask ``pip install znflow[dask]``.
We can then extend the example from above. This will run ``n1`` and ``n2`` in parallel.
You can investigate the graph on the Dask dashboard (typically http://127.0.0.1:8787/graph or via the client object in Jupyter.)

````python
import znflow
import dataclasses
from dask.distributed import Client

@znflow.nodify
def compute_mean(x, y):
return (x + y) / 2

@dataclasses.dataclass
class ComputeMean(znflow.Node):
x: float
y: float

results: float = None

def run(self):
self.results = (self.x + self.y) / 2

with znflow.DiGraph() as graph:
n1 = ComputeMean(2, 8)
n2 = compute_mean(13, 7)
# connecting classes and functions to a Node
n3 = ComputeMean(n1.results, n2)

client = Client()
deployment = znflow.deployment.Deployment(graph=graph, client=client)
deployment.submit_graph()

n3 = deployment.get_results(n3)
print(n3)
# >>> ComputeMean(x=5.0, y=10.0, results=7.5)
````

We need to get the updated instance from the Dask worker via ``Deployment.get_results``.
Due to the way Dask works, an inplace update is not possible.
To retrieve the full graph, you can use ``Deployment.get_results(graph.nodes)`` instead.

### Working with lists
ZnFlow supports some special features for working with lists.
In the following example we want to ``combine`` two lists.
Expand Down
334 changes: 324 additions & 10 deletions poetry.lock

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znflow"
version = "0.1.9"
version = "0.1.10"
description = "A general purpose framework for building and running computational graphs."
authors = ["zincwarecode <[email protected]>"]
license = "Apache-2.0"
Expand All @@ -11,6 +11,11 @@ python = "^3.8"
networkx = "^3.0"
matplotlib = "^3.6.3"

dask = { version = "^2022.12.1", optional = true }
distributed = { version = "^2022.12.1", optional = true }
dask-jobqueue = { version = "^0.8.1", optional = true }
bokeh = { version = "^2.4.2", optional = true }

[tool.poetry.group.lint.dependencies]
black = "^22.10.0"
isort = "^5.10.1"
Expand All @@ -25,6 +30,10 @@ attrs = "^22.2.0"
[tool.poetry.group.notebook.dependencies]
jupyterlab = "^3.5.1"

[tool.poetry.extras]
dask = ["dask", "distributed", "dask-jobqueue", "bokeh"]


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Expand Down
99 changes: 99 additions & 0 deletions tests/test_deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import dataclasses

import znflow


@znflow.nodify
def compute_sum(*args):
return sum(args)


@dataclasses.dataclass
class ComputeSum(znflow.Node):
inputs: list
outputs: float = None

def run(self):
# this will just call the function compute_sum and won't construct a graph!
self.outputs = compute_sum(*self.inputs)


@znflow.nodify
def add_to_ComputeSum(instance: ComputeSum):
return instance.outputs + 1


def test_single_nodify():
with znflow.DiGraph() as graph:
node1 = compute_sum(1, 2, 3)

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()

node1 = depl.get_results(node1)
assert node1.result == 6


def test_single_Node():
with znflow.DiGraph() as graph:
node1 = ComputeSum(inputs=[1, 2, 3])

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()

node1 = depl.get_results(node1)
assert node1.outputs == 6


def test_multiple_nodify():
with znflow.DiGraph() as graph:
node1 = compute_sum(1, 2, 3)
node2 = compute_sum(4, 5, 6)
node3 = compute_sum(node1, node2)

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()

node1 = depl.get_results(node1)
node2 = depl.get_results(node2)
node3 = depl.get_results(node3)
assert node1.result == 6
assert node2.result == 15
assert node3.result == 21


def test_multiple_Node():
with znflow.DiGraph() as graph:
node1 = ComputeSum(inputs=[1, 2, 3])
node2 = ComputeSum(inputs=[4, 5, 6])
node3 = ComputeSum(inputs=[node1.outputs, node2.outputs])

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()

node1 = depl.get_results(node1)
node2 = depl.get_results(node2)
node3 = depl.get_results(node3)
assert node1.outputs == 6
assert node2.outputs == 15
assert node3.outputs == 21


def test_multiple_nodify_and_Node():
with znflow.DiGraph() as graph:
node1 = compute_sum(1, 2, 3)
node2 = ComputeSum(inputs=[4, 5, 6])
node3 = compute_sum(node1, node2.outputs)
node4 = ComputeSum(inputs=[node1, node2.outputs, node3])
node5 = add_to_ComputeSum(node4)

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()

results = depl.get_results(graph.nodes)

assert results[node1.uuid].result == 6
assert results[node2.uuid].outputs == 15
assert results[node3.uuid].result == 21
assert results[node4.uuid].outputs == 42
assert results[node5.uuid].result == 43
2 changes: 1 addition & 1 deletion tests/test_znflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

def test_version():
"""Test the version."""
assert znflow.__version__ == "0.1.9"
assert znflow.__version__ == "0.1.10"
6 changes: 6 additions & 0 deletions znflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The 'ZnFlow' package."""
import contextlib
import importlib.metadata
import logging
import sys
Expand Down Expand Up @@ -32,6 +33,11 @@
"combine",
]

with contextlib.suppress(ImportError):
from znflow import deployment

__all__ += ["deployment"]

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

Expand Down
168 changes: 168 additions & 0 deletions znflow/deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""ZnFlow deployment using Dask."""

import dataclasses
import typing
import uuid

from dask.distributed import Client, Future
from networkx.classes.reportviews import NodeView

from znflow.base import Connection, NodeBaseMixin
from znflow.graph import DiGraph
from znflow.utils import IterableHandler


class _LoadNode(IterableHandler):
"""Iterable handler for loading nodes."""

def default(self, value, **kwargs):
"""Default handler for loading nodes.

Parameters
----------
value: NodeBaseMixin|any
If a NodeBaseMixin, the node will be loaded and returned.
kwargs: dict
results: results dictionary of {uuid: node} shape.

Returns
-------
any:
If a NodeBaseMixin, the node will be loaded and returned.
Otherwise, the input value is returned.

"""
results = kwargs["results"]
if isinstance(value, NodeBaseMixin):
return results[value.uuid].result()

return value


class _UpdateConnections(IterableHandler):
"""Iterable handler for replacing connections."""

def default(self, value, **kwargs):
"""Replace connections by its values.

Parameters
----------
value: Connection|any
If a Connection, the connection will be replaced by its result.
kwargs: dict
predecessors: dict of {uuid: Connection} shape.

Returns
-------
any:
If a Connection, the connection will be replaced by its result.
Otherwise, the input value is returned.

"""
predecessors = kwargs["predecessors"]
if isinstance(value, Connection):
# We don't actually need the connection, we need the results.
return dataclasses.replace(value, instance=predecessors[value.uuid]).result
return value


def node_submit(node: NodeBaseMixin, **kwargs) -> NodeBaseMixin:
"""Submit script for Dask worker.

Parameters
----------
node: NodeBaseMixin
the Node class
kwargs: dict
predecessors: dict of {uuid: Connection} shape

Returns
-------
NodeBaseMixin:
the Node class with updated state (after calling "Node.run").

"""
predecessors = kwargs.get("predecessors", {})
for item in dir(node):
# TODO this information is available in the graph,
# no need to expensively iterate over all attributes
if item.startswith("_"):
continue
updater = _UpdateConnections()
value = updater(getattr(node, item), predecessors=predecessors)
if updater.updated:
setattr(node, item, value)

node.run()
return node


@dataclasses.dataclass
class Deployment:
"""ZnFlow deployment using Dask.

Attributes
----------
graph: DiGraph
the znflow graph containing the nodes.
client: Client, optional
the Dask client.
results: Dict[uuid, Future]
a dictionary of {uuid: Future} shape that is filled after the graph is submitted.

"""

graph: DiGraph
client: Client = dataclasses.field(default_factory=Client)
results: typing.Dict[uuid.UUID, Future] = dataclasses.field(
default_factory=dict, init=False
)

def submit_graph(self):
"""Submit the graph to Dask.

When submitting to Dask, a Node is serialized, processed and a
copy can be returned.

This requires:
- the connections to be updated to the respective Nodes coming from Dask futures.
- the Node to be returned from the workers and passed to all successors.
"""
for node_uuid in self.graph.reverse():
node = self.graph.nodes[node_uuid]["value"]
predecessors = list(self.graph.predecessors(node.uuid))

if len(predecessors) == 0:
self.results[node.uuid] = self.client.submit( # TODO how to name
node_submit, node=node, pure=False
)
else:
self.results[node.uuid] = self.client.submit(
node_submit,
node=node,
predecessors={
x: self.results[x] for x in self.results if x in predecessors
},
pure=False,
)

def get_results(self, obj: typing.Union[NodeBaseMixin, list, dict, NodeView], /):
"""Get the results from Dask based on the original object.

Parameters
----------
obj: NodeBaseMixin|list|dict|NodeView
either a single Node or multiple Nodes from the submitted graph.

Returns
-------
any:
Returns an instance of obj which is updated with the results from Dask.

"""
if isinstance(obj, NodeView):
data = _LoadNode()(dict(obj), results=self.results)
return {x: v["value"] for x, v in data.items()}
elif isinstance(obj, DiGraph):
raise NotImplementedError
return _LoadNode()(obj, results=self.results)