diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 11c57d2..5738b90 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -29,7 +29,7 @@ jobs: poetry-version: 1.3.2 - name: Install package run: | - poetry install --no-interaction --without=notebook + poetry install --all-extras --without=notebook - name: Pytest run: | poetry run coverage run -m pytest diff --git a/.gitignore b/.gitignore index 2dc53ca..035260a 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ + +_tmp/ \ No newline at end of file diff --git a/README.md b/README.md index c748482..58f8fd0 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,54 @@ print(n3.results) # >>> 7.5 ``` +## Dask Support +ZnFlow comes with support for [Dask](https://www.dask.org/) to run your graph: +- in parallel. +- through e.g. SLURM (see https://jobqueue.dask.org/en/latest/api.html). +- with a nice GUI to track progress. + +All you need to do is install ZnFlow with the Dask extra: ``pip install znflow[dask]``. +We can then extend the example from above. This will run ``n1`` and ``n2`` in parallel. +You can investigate the graph on the Dask dashboard (typically http://127.0.0.1:8787/graph) or via the client object in Jupyter. 
+ +````python +import znflow +import dataclasses +from dask.distributed import Client + +@znflow.nodify +def compute_mean(x, y): + return (x + y) / 2 + +@dataclasses.dataclass +class ComputeMean(znflow.Node): + x: float + y: float + + results: float = None + + def run(self): + self.results = (self.x + self.y) / 2 + +with znflow.DiGraph() as graph: + n1 = ComputeMean(2, 8) + n2 = compute_mean(13, 7) + # connecting classes and functions to a Node + n3 = ComputeMean(n1.results, n2) + +client = Client() +deployment = znflow.deployment.Deployment(graph=graph, client=client) +deployment.submit_graph() + +n3 = deployment.get_results(n3) +print(n3) +# >>> ComputeMean(x=5.0, y=10.0, results=7.5) +```` + +We need to get the updated instance from the Dask worker via ``Deployment.get_results``. +Due to the way Dask works, an in-place update is not possible. +To retrieve the results for all nodes of the graph, you can use ``Deployment.get_results(graph.nodes)`` instead, which returns a dictionary mapping each node's UUID to its updated instance. + ### Working with lists ZnFlow supports some special features for working with lists. In the following example we want to ``combine`` two lists. 
diff --git a/poetry.lock b/poetry.lock index 51909ee..60801f0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -268,6 +268,27 @@ webencodings = "*" [package.extras] css = ["tinycss2 (>=1.1.0,<1.2)"] +[[package]] +name = "bokeh" +version = "2.4.3" +description = "Interactive plots and applications in the browser from Python" +category = "main" +optional = true +python-versions = ">=3.7" +files = [ + {file = "bokeh-2.4.3-py3-none-any.whl", hash = "sha256:104d2f0a4ca7774ee4b11e545aa34ff76bf3e2ad6de0d33944361981b65da420"}, + {file = "bokeh-2.4.3.tar.gz", hash = "sha256:ef33801161af379665ab7a34684f2209861e3aefd5c803a21fbbb99d94874b03"}, +] + +[package.dependencies] +Jinja2 = ">=2.9" +numpy = ">=1.11.3" +packaging = ">=16.8" +pillow = ">=7.1.0" +PyYAML = ">=3.10" +tornado = ">=5.1" +typing-extensions = ">=3.10.0" + [[package]] name = "certifi" version = "2022.12.7" @@ -446,7 +467,7 @@ files = [ name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -457,11 +478,23 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cloudpickle" +version = "2.2.1" +description = "Extended pickling support for Python objects" +category = "main" +optional = true +python-versions = ">=3.6" +files = [ + {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"}, + {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"}, +] + [[package]] name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." 
-category = "dev" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -639,6 +672,54 @@ files = [ {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"}, ] +[[package]] +name = "dask" +version = "2022.12.1" +description = "Parallel PyData with Task Scheduling" +category = "main" +optional = true +python-versions = ">=3.8" +files = [ + {file = "dask-2022.12.1-py3-none-any.whl", hash = "sha256:a833ee774bf702c08d22f31412358d12b007df36c6e8c107f32f17a4b20f1f68"}, + {file = "dask-2022.12.1.tar.gz", hash = "sha256:ef12c98a6681964494ddfee4ba8071ebc8895d3c4ea27f5c5160a14e29f01d92"}, +] + +[package.dependencies] +click = ">=7.0" +cloudpickle = ">=1.1.1" +fsspec = ">=0.6.0" +packaging = ">=20.0" +partd = ">=0.3.10" +pyyaml = ">=5.3.1" +toolz = ">=0.8.2" + +[package.extras] +array = ["numpy (>=1.18)"] +complete = ["bokeh (>=2.4.2,<3)", "distributed (==2022.12.1)", "jinja2", "numpy (>=1.18)", "pandas (>=1.0)"] +dataframe = ["numpy (>=1.18)", "pandas (>=1.0)"] +diagnostics = ["bokeh (>=2.4.2,<3)", "jinja2"] +distributed = ["distributed (==2022.12.1)"] +test = ["pandas[test]", "pre-commit", "pytest", "pytest-rerunfailures", "pytest-xdist"] + +[[package]] +name = "dask-jobqueue" +version = "0.8.1" +description = "Deploy Dask on job queuing systems like PBS, Slurm, SGE or LSF" +category = "main" +optional = true +python-versions = ">=3.8" +files = [ + {file = "dask-jobqueue-0.8.1.tar.gz", hash = "sha256:16fd1b646a073ad3de75dde12a0dfe529b836f21a3bdbcee2a88bef24e9112a7"}, + {file = "dask_jobqueue-0.8.1-py2.py3-none-any.whl", hash = "sha256:22f7435bbda34feb75cd7abc4b3175309cbdb9e8dadb02174d37aba09944abe9"}, +] + +[package.dependencies] +dask = ">=2022.02.0" +distributed = ">=2022.02.0" + +[package.extras] +test = ["cryptography", "pytest", "pytest-asyncio"] + [[package]] name = "debugpy" version = "1.6.6" @@ -690,6 +771,35 @@ files = [ {file = 
"defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] +[[package]] +name = "distributed" +version = "2022.12.1" +description = "Distributed scheduler for Dask" +category = "main" +optional = true +python-versions = ">=3.8" +files = [ + {file = "distributed-2022.12.1-py3-none-any.whl", hash = "sha256:d7abd29277c6b7af8df7fef68c1552100478d3da7bf6a4a3562142be8948c1e8"}, + {file = "distributed-2022.12.1.tar.gz", hash = "sha256:87fedcfbf2126d14c3865c8e445c41eccabd2b07eaff87528af2afd9e20ced6c"}, +] + +[package.dependencies] +click = ">=7.0" +cloudpickle = ">=1.5.0" +dask = "2022.12.1" +jinja2 = "*" +locket = ">=1.0.0" +msgpack = ">=0.6.0" +packaging = ">=20.0" +psutil = ">=5.0" +pyyaml = "*" +sortedcontainers = "<2.0.0 || >2.0.0,<2.0.1 || >2.0.1" +tblib = ">=1.6.0" +toolz = ">=0.10.0" +tornado = ">=6.0.3" +urllib3 = "*" +zict = ">=0.1.3" + [[package]] name = "exceptiongroup" version = "1.1.1" @@ -773,6 +883,52 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +[[package]] +name = "fsspec" +version = "2023.3.0" +description = "File-system specification" +category = "main" +optional = true +python-versions = ">=3.8" +files = [ + {file = "fsspec-2023.3.0-py3-none-any.whl", hash = "sha256:bf57215e19dbfa4fe7edae53040cc1deef825e3b1605cca9a8d2c2fadd2328a0"}, + {file = "fsspec-2023.3.0.tar.gz", hash = "sha256:24e635549a590d74c6c18274ddd3ffab4753341753e923408b1904eaabafe04d"}, +] + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = 
["paramiko"] +tqdm = ["tqdm"] + +[[package]] +name = "heapdict" +version = "1.0.1" +description = "a heap with decrease-key and increase-key operations" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "HeapDict-1.0.1-py3-none-any.whl", hash = "sha256:6065f90933ab1bb7e50db403b90cab653c853690c5992e69294c2de2b253fc92"}, + {file = "HeapDict-1.0.1.tar.gz", hash = "sha256:8495f57b3e03d8e46d5f1b2cc62ca881aca392fd5cc048dc0aa2e1a6d23ecdb6"}, +] + [[package]] name = "idna" version = "3.4" @@ -978,7 +1134,7 @@ testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1381,11 +1537,23 @@ files = [ {file = "kiwisolver-1.4.4.tar.gz", hash = "sha256:d41997519fcba4a1e46eb4a2fe31bc12f0ff957b2b81bac28db24744f333e955"}, ] +[[package]] +name = "locket" +version = "1.0.0" +description = "File-based locks for Python on Linux and Windows" +category = "main" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "locket-1.0.0-py2.py3-none-any.whl", hash = "sha256:b6c819a722f7b6bd955b80781788e4a66a55628b858d347536b7e81325a3a5e3"}, + {file = "locket-1.0.0.tar.gz", hash = "sha256:5c0d4c052a8bbbf750e056a8e65ccd309086f4f0f18a2eac306a8dfa4112a632"}, +] + [[package]] name = "markupsafe" version = "2.1.2" description = "Safely add untrusted strings to HTML/XML markup." 
-category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1531,6 +1699,79 @@ files = [ {file = "mistune-2.0.5.tar.gz", hash = "sha256:0246113cb2492db875c6be56974a7c893333bf26cd92891c85f63151cee09d34"}, ] +[[package]] +name = "msgpack" +version = "1.0.5" +description = "MessagePack serializer" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "msgpack-1.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:525228efd79bb831cf6830a732e2e80bc1b05436b086d4264814b4b2955b2fa9"}, + {file = "msgpack-1.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4f8d8b3bf1ff2672567d6b5c725a1b347fe838b912772aa8ae2bf70338d5a198"}, + {file = "msgpack-1.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cdc793c50be3f01106245a61b739328f7dccc2c648b501e237f0699fe1395b81"}, + {file = "msgpack-1.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cb47c21a8a65b165ce29f2bec852790cbc04936f502966768e4aae9fa763cb7"}, + {file = "msgpack-1.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e42b9594cc3bf4d838d67d6ed62b9e59e201862a25e9a157019e171fbe672dd3"}, + {file = "msgpack-1.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:55b56a24893105dc52c1253649b60f475f36b3aa0fc66115bffafb624d7cb30b"}, + {file = "msgpack-1.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:1967f6129fc50a43bfe0951c35acbb729be89a55d849fab7686004da85103f1c"}, + {file = "msgpack-1.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:20a97bf595a232c3ee6d57ddaadd5453d174a52594bf9c21d10407e2a2d9b3bd"}, + {file = "msgpack-1.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d25dd59bbbbb996eacf7be6b4ad082ed7eacc4e8f3d2df1ba43822da9bfa122a"}, + {file = "msgpack-1.0.5-cp310-cp310-win32.whl", hash = "sha256:382b2c77589331f2cb80b67cc058c00f225e19827dbc818d700f61513ab47bea"}, + {file = 
"msgpack-1.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:4867aa2df9e2a5fa5f76d7d5565d25ec76e84c106b55509e78c1ede0f152659a"}, + {file = "msgpack-1.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9f5ae84c5c8a857ec44dc180a8b0cc08238e021f57abdf51a8182e915e6299f0"}, + {file = "msgpack-1.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9e6ca5d5699bcd89ae605c150aee83b5321f2115695e741b99618f4856c50898"}, + {file = "msgpack-1.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5494ea30d517a3576749cad32fa27f7585c65f5f38309c88c6d137877fa28a5a"}, + {file = "msgpack-1.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ab2f3331cb1b54165976a9d976cb251a83183631c88076613c6c780f0d6e45a"}, + {file = "msgpack-1.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28592e20bbb1620848256ebc105fc420436af59515793ed27d5c77a217477705"}, + {file = "msgpack-1.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe5c63197c55bce6385d9aee16c4d0641684628f63ace85f73571e65ad1c1e8d"}, + {file = "msgpack-1.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ed40e926fa2f297e8a653c954b732f125ef97bdd4c889f243182299de27e2aa9"}, + {file = "msgpack-1.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b2de4c1c0538dcb7010902a2b97f4e00fc4ddf2c8cda9749af0e594d3b7fa3d7"}, + {file = "msgpack-1.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bf22a83f973b50f9d38e55c6aade04c41ddda19b00c4ebc558930d78eecc64ed"}, + {file = "msgpack-1.0.5-cp311-cp311-win32.whl", hash = "sha256:c396e2cc213d12ce017b686e0f53497f94f8ba2b24799c25d913d46c08ec422c"}, + {file = "msgpack-1.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c4c68d87497f66f96d50142a2b73b97972130d93677ce930718f68828b382e2"}, + {file = "msgpack-1.0.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a2b031c2e9b9af485d5e3c4520f4220d74f4d222a5b8dc8c1a3ab9448ca79c57"}, + {file = 
"msgpack-1.0.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f837b93669ce4336e24d08286c38761132bc7ab29782727f8557e1eb21b2080"}, + {file = "msgpack-1.0.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1d46dfe3832660f53b13b925d4e0fa1432b00f5f7210eb3ad3bb9a13c6204a6"}, + {file = "msgpack-1.0.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:366c9a7b9057e1547f4ad51d8facad8b406bab69c7d72c0eb6f529cf76d4b85f"}, + {file = "msgpack-1.0.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:4c075728a1095efd0634a7dccb06204919a2f67d1893b6aa8e00497258bf926c"}, + {file = "msgpack-1.0.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:f933bbda5a3ee63b8834179096923b094b76f0c7a73c1cfe8f07ad608c58844b"}, + {file = "msgpack-1.0.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:36961b0568c36027c76e2ae3ca1132e35123dcec0706c4b7992683cc26c1320c"}, + {file = "msgpack-1.0.5-cp36-cp36m-win32.whl", hash = "sha256:b5ef2f015b95f912c2fcab19c36814963b5463f1fb9049846994b007962743e9"}, + {file = "msgpack-1.0.5-cp36-cp36m-win_amd64.whl", hash = "sha256:288e32b47e67f7b171f86b030e527e302c91bd3f40fd9033483f2cacc37f327a"}, + {file = "msgpack-1.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:137850656634abddfb88236008339fdaba3178f4751b28f270d2ebe77a563b6c"}, + {file = "msgpack-1.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c05a4a96585525916b109bb85f8cb6511db1c6f5b9d9cbcbc940dc6b4be944b"}, + {file = "msgpack-1.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a62ec00b636583e5cb6ad313bbed36bb7ead5fa3a3e38938503142c72cba4f"}, + {file = "msgpack-1.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef8108f8dedf204bb7b42994abf93882da1159728a2d4c5e82012edd92c9da9f"}, + {file = "msgpack-1.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = 
"sha256:1835c84d65f46900920b3708f5ba829fb19b1096c1800ad60bae8418652a951d"}, + {file = "msgpack-1.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:e57916ef1bd0fee4f21c4600e9d1da352d8816b52a599c46460e93a6e9f17086"}, + {file = "msgpack-1.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:17358523b85973e5f242ad74aa4712b7ee560715562554aa2134d96e7aa4cbbf"}, + {file = "msgpack-1.0.5-cp37-cp37m-win32.whl", hash = "sha256:cb5aaa8c17760909ec6cb15e744c3ebc2ca8918e727216e79607b7bbce9c8f77"}, + {file = "msgpack-1.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:ab31e908d8424d55601ad7075e471b7d0140d4d3dd3272daf39c5c19d936bd82"}, + {file = "msgpack-1.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b72d0698f86e8d9ddf9442bdedec15b71df3598199ba33322d9711a19f08145c"}, + {file = "msgpack-1.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:379026812e49258016dd84ad79ac8446922234d498058ae1d415f04b522d5b2d"}, + {file = "msgpack-1.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:332360ff25469c346a1c5e47cbe2a725517919892eda5cfaffe6046656f0b7bb"}, + {file = "msgpack-1.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:476a8fe8fae289fdf273d6d2a6cb6e35b5a58541693e8f9f019bfe990a51e4ba"}, + {file = "msgpack-1.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9985b214f33311df47e274eb788a5893a761d025e2b92c723ba4c63936b69b1"}, + {file = "msgpack-1.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48296af57cdb1d885843afd73c4656be5c76c0c6328db3440c9601a98f303d87"}, + {file = "msgpack-1.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:addab7e2e1fcc04bd08e4eb631c2a90960c340e40dfc4a5e24d2ff0d5a3b3edb"}, + {file = "msgpack-1.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:916723458c25dfb77ff07f4c66aed34e47503b2eb3188b3adbec8d8aa6e00f48"}, + {file = "msgpack-1.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:821c7e677cc6acf0fd3f7ac664c98803827ae6de594a9f99563e48c5a2f27eb0"}, + {file = "msgpack-1.0.5-cp38-cp38-win32.whl", hash = "sha256:1c0f7c47f0087ffda62961d425e4407961a7ffd2aa004c81b9c07d9269512f6e"}, + {file = "msgpack-1.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:bae7de2026cbfe3782c8b78b0db9cbfc5455e079f1937cb0ab8d133496ac55e1"}, + {file = "msgpack-1.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:20c784e66b613c7f16f632e7b5e8a1651aa5702463d61394671ba07b2fc9e025"}, + {file = "msgpack-1.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:266fa4202c0eb94d26822d9bfd7af25d1e2c088927fe8de9033d929dd5ba24c5"}, + {file = "msgpack-1.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18334484eafc2b1aa47a6d42427da7fa8f2ab3d60b674120bce7a895a0a85bdd"}, + {file = "msgpack-1.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57e1f3528bd95cc44684beda696f74d3aaa8a5e58c816214b9046512240ef437"}, + {file = "msgpack-1.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:586d0d636f9a628ddc6a17bfd45aa5b5efaf1606d2b60fa5d87b8986326e933f"}, + {file = "msgpack-1.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a740fa0e4087a734455f0fc3abf5e746004c9da72fbd541e9b113013c8dc3282"}, + {file = "msgpack-1.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:3055b0455e45810820db1f29d900bf39466df96ddca11dfa6d074fa47054376d"}, + {file = "msgpack-1.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a61215eac016f391129a013c9e46f3ab308db5f5ec9f25811e811f96962599a8"}, + {file = "msgpack-1.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:362d9655cd369b08fda06b6657a303eb7172d5279997abe094512e919cf74b11"}, + {file = "msgpack-1.0.5-cp39-cp39-win32.whl", hash = "sha256:ac9dd47af78cae935901a9a500104e2dea2e253207c924cc95de149606dc43cc"}, + {file = "msgpack-1.0.5-cp39-cp39-win_amd64.whl", hash = 
"sha256:06f5174b5f8ed0ed919da0e62cbd4ffde676a374aba4020034da05fab67b9164"}, + {file = "msgpack-1.0.5.tar.gz", hash = "sha256:c075544284eadc5cddc70f4757331d99dcbc16b2bbd4849d15f8aae4cf36d31c"}, +] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1825,6 +2066,25 @@ files = [ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["docopt", "pytest (<6.0.0)"] +[[package]] +name = "partd" +version = "1.3.0" +description = "Appendable key-value storage" +category = "main" +optional = true +python-versions = ">=3.7" +files = [ + {file = "partd-1.3.0-py3-none-any.whl", hash = "sha256:6393a0c898a0ad945728e34e52de0df3ae295c5aff2e2926ba7cc3c60a734a15"}, + {file = "partd-1.3.0.tar.gz", hash = "sha256:ce91abcdc6178d668bcaa431791a5a917d902341cb193f543fe445d494660485"}, +] + +[package.dependencies] +locket = "*" +toolz = "*" + +[package.extras] +complete = ["blosc", "numpy (>=1.9.0)", "pandas (>=0.19.0)", "pyzmq"] + [[package]] name = "pathspec" version = "0.11.1" @@ -2033,7 +2293,7 @@ wcwidth = "*" name = "psutil" version = "5.9.4" description = "Cross-platform lib for process and system monitoring in Python." 
-category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2269,7 +2529,7 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2522,6 +2782,18 @@ files = [ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"}, + {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, +] + [[package]] name = "soupsieve" version = "2.4" @@ -2554,6 +2826,18 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "tblib" +version = "1.7.0" +description = "Traceback serialization library." 
+category = "main" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "tblib-1.7.0-py2.py3-none-any.whl", hash = "sha256:289fa7359e580950e7d9743eab36b0691f0310fce64dee7d9c31065b8f723e23"}, + {file = "tblib-1.7.0.tar.gz", hash = "sha256:059bd77306ea7b419d4f76016aef6d7027cc8a0785579b5aad198803435f882c"}, +] + [[package]] name = "terminado" version = "0.17.1" @@ -2606,11 +2890,23 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +[[package]] +name = "toolz" +version = "0.12.0" +description = "List processing tools and functional utilities" +category = "main" +optional = true +python-versions = ">=3.5" +files = [ + {file = "toolz-0.12.0-py3-none-any.whl", hash = "sha256:2059bd4148deb1884bb0eb770a3cde70e7f954cfbbdc2285f1f2de01fd21eb6f"}, + {file = "toolz-0.12.0.tar.gz", hash = "sha256:88c570861c440ee3f2f6037c4654613228ff40c93a6c25e0eba70d17282c6194"}, +] + [[package]] name = "tornado" version = "6.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." -category = "dev" +category = "main" optional = false python-versions = ">= 3.7" files = [ @@ -2647,7 +2943,7 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] name = "typing-extensions" version = "4.5.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2674,7 +2970,7 @@ dev = ["flake8 (<4.0.0)", "flake8-annotations", "flake8-bugbear", "flake8-commas name = "urllib3" version = "1.26.15" description = "HTTP library with thread-safe connection pooling, file post, and more." 
-category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -2836,6 +3132,21 @@ y-py = ">=0.5.3,<0.6.0" [package.extras] test = ["mypy", "pre-commit", "pytest", "pytest-asyncio", "websockets (>=10.0)"] +[[package]] +name = "zict" +version = "2.2.0" +description = "Mutable mapping tools" +category = "main" +optional = true +python-versions = ">=3.7" +files = [ + {file = "zict-2.2.0-py2.py3-none-any.whl", hash = "sha256:dabcc8c8b6833aa3b6602daad50f03da068322c1a90999ff78aed9eecc8fa92c"}, + {file = "zict-2.2.0.tar.gz", hash = "sha256:d7366c2e2293314112dcf2432108428a67b927b00005619feefc310d12d833f3"}, +] + +[package.dependencies] +heapdict = "*" + [[package]] name = "zipp" version = "3.15.0" @@ -2867,7 +3178,10 @@ files = [ [package.extras] typeguard = ["typeguard (>=2.13.3,<3.0.0)"] +[extras] +dask = ["bokeh", "dask", "dask-jobqueue", "distributed"] + [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "2bc4a9e9ab991c57cc61db03c1d0f367142899690b874c851098d370535be045" +content-hash = "4d9c97b5d6c9f69aaaf3480900246212ad64081063103e35550092a27b953197" diff --git a/pyproject.toml b/pyproject.toml index 24dc476..81dcd4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "znflow" -version = "0.1.9" +version = "0.1.10" description = "A general purpose framework for building and running computational graphs." 
authors = ["zincwarecode "] license = "Apache-2.0" @@ -11,6 +11,11 @@ python = "^3.8" networkx = "^3.0" matplotlib = "^3.6.3" +dask = { version = "^2022.12.1", optional = true } +distributed = { version = "^2022.12.1", optional = true } +dask-jobqueue = { version = "^0.8.1", optional = true } +bokeh = { version = "^2.4.2", optional = true } + [tool.poetry.group.lint.dependencies] black = "^22.10.0" isort = "^5.10.1" @@ -25,6 +30,10 @@ attrs = "^22.2.0" [tool.poetry.group.notebook.dependencies] jupyterlab = "^3.5.1" +[tool.poetry.extras] +dask = ["dask", "distributed", "dask-jobqueue", "bokeh"] + + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/tests/test_deployment.py b/tests/test_deployment.py new file mode 100644 index 0000000..c77786f --- /dev/null +++ b/tests/test_deployment.py @@ -0,0 +1,99 @@ +import dataclasses + +import znflow + + +@znflow.nodify +def compute_sum(*args): + return sum(args) + + +@dataclasses.dataclass +class ComputeSum(znflow.Node): + inputs: list + outputs: float = None + + def run(self): + # this will just call the function compute_sum and won't construct a graph! 
+ self.outputs = compute_sum(*self.inputs) + + +@znflow.nodify +def add_to_ComputeSum(instance: ComputeSum): + return instance.outputs + 1 + + +def test_single_nodify(): + with znflow.DiGraph() as graph: + node1 = compute_sum(1, 2, 3) + + depl = znflow.deployment.Deployment(graph=graph) + depl.submit_graph() + + node1 = depl.get_results(node1) + assert node1.result == 6 + + +def test_single_Node(): + with znflow.DiGraph() as graph: + node1 = ComputeSum(inputs=[1, 2, 3]) + + depl = znflow.deployment.Deployment(graph=graph) + depl.submit_graph() + + node1 = depl.get_results(node1) + assert node1.outputs == 6 + + +def test_multiple_nodify(): + with znflow.DiGraph() as graph: + node1 = compute_sum(1, 2, 3) + node2 = compute_sum(4, 5, 6) + node3 = compute_sum(node1, node2) + + depl = znflow.deployment.Deployment(graph=graph) + depl.submit_graph() + + node1 = depl.get_results(node1) + node2 = depl.get_results(node2) + node3 = depl.get_results(node3) + assert node1.result == 6 + assert node2.result == 15 + assert node3.result == 21 + + +def test_multiple_Node(): + with znflow.DiGraph() as graph: + node1 = ComputeSum(inputs=[1, 2, 3]) + node2 = ComputeSum(inputs=[4, 5, 6]) + node3 = ComputeSum(inputs=[node1.outputs, node2.outputs]) + + depl = znflow.deployment.Deployment(graph=graph) + depl.submit_graph() + + node1 = depl.get_results(node1) + node2 = depl.get_results(node2) + node3 = depl.get_results(node3) + assert node1.outputs == 6 + assert node2.outputs == 15 + assert node3.outputs == 21 + + +def test_multiple_nodify_and_Node(): + with znflow.DiGraph() as graph: + node1 = compute_sum(1, 2, 3) + node2 = ComputeSum(inputs=[4, 5, 6]) + node3 = compute_sum(node1, node2.outputs) + node4 = ComputeSum(inputs=[node1, node2.outputs, node3]) + node5 = add_to_ComputeSum(node4) + + depl = znflow.deployment.Deployment(graph=graph) + depl.submit_graph() + + results = depl.get_results(graph.nodes) + + assert results[node1.uuid].result == 6 + assert results[node2.uuid].outputs == 15 + 
assert results[node3.uuid].result == 21 + assert results[node4.uuid].outputs == 42 + assert results[node5.uuid].result == 43 diff --git a/tests/test_znflow.py b/tests/test_znflow.py index 5bcec00..ea31dba 100644 --- a/tests/test_znflow.py +++ b/tests/test_znflow.py @@ -4,4 +4,4 @@ def test_version(): """Test the version.""" - assert znflow.__version__ == "0.1.9" + assert znflow.__version__ == "0.1.10" diff --git a/znflow/__init__.py b/znflow/__init__.py index 601a750..533d87e 100644 --- a/znflow/__init__.py +++ b/znflow/__init__.py @@ -1,4 +1,5 @@ """The 'ZnFlow' package.""" +import contextlib import importlib.metadata import logging import sys @@ -32,6 +33,11 @@ "combine", ] +with contextlib.suppress(ImportError): + from znflow import deployment + + __all__ += ["deployment"] + logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) diff --git a/znflow/deployment.py b/znflow/deployment.py new file mode 100644 index 0000000..c236b3e --- /dev/null +++ b/znflow/deployment.py @@ -0,0 +1,168 @@ +"""ZnFlow deployment using Dask.""" + +import dataclasses +import typing +import uuid + +from dask.distributed import Client, Future +from networkx.classes.reportviews import NodeView + +from znflow.base import Connection, NodeBaseMixin +from znflow.graph import DiGraph +from znflow.utils import IterableHandler + + +class _LoadNode(IterableHandler): + """Iterable handler for loading nodes.""" + + def default(self, value, **kwargs): + """Default handler for loading nodes. + + Parameters + ---------- + value: NodeBaseMixin|any + If a NodeBaseMixin, the node will be loaded and returned. + kwargs: dict + results: results dictionary of {uuid: node} shape. + + Returns + ------- + any: + If a NodeBaseMixin, the node will be loaded and returned. + Otherwise, the input value is returned. 
+ + """ + results = kwargs["results"] + if isinstance(value, NodeBaseMixin): + return results[value.uuid].result() + + return value + + +class _UpdateConnections(IterableHandler): + """Iterable handler for replacing connections.""" + + def default(self, value, **kwargs): + """Replace connections by its values. + + Parameters + ---------- + value: Connection|any + If a Connection, the connection will be replaced by its result. + kwargs: dict + predecessors: dict of {uuid: Connection} shape. + + Returns + ------- + any: + If a Connection, the connection will be replaced by its result. + Otherwise, the input value is returned. + + """ + predecessors = kwargs["predecessors"] + if isinstance(value, Connection): + # We don't actually need the connection, we need the results. + return dataclasses.replace(value, instance=predecessors[value.uuid]).result + return value + + +def node_submit(node: NodeBaseMixin, **kwargs) -> NodeBaseMixin: + """Submit script for Dask worker. + + Parameters + ---------- + node: NodeBaseMixin + the Node class + kwargs: dict + predecessors: dict of {uuid: Connection} shape + + Returns + ------- + NodeBaseMixin: + the Node class with updated state (after calling "Node.run"). + + """ + predecessors = kwargs.get("predecessors", {}) + for item in dir(node): + # TODO this information is available in the graph, + # no need to expensively iterate over all attributes + if item.startswith("_"): + continue + updater = _UpdateConnections() + value = updater(getattr(node, item), predecessors=predecessors) + if updater.updated: + setattr(node, item, value) + + node.run() + return node + + +@dataclasses.dataclass +class Deployment: + """ZnFlow deployment using Dask. + + Attributes + ---------- + graph: DiGraph + the znflow graph containing the nodes. + client: Client, optional + the Dask client. + results: Dict[uuid, Future] + a dictionary of {uuid: Future} shape that is filled after the graph is submitted. 
+ + """ + + graph: DiGraph + client: Client = dataclasses.field(default_factory=Client) + results: typing.Dict[uuid.UUID, Future] = dataclasses.field( + default_factory=dict, init=False + ) + + def submit_graph(self): + """Submit the graph to Dask. + + When submitting to Dask, a Node is serialized, processed and a + copy can be returned. + + This requires: + - the connections to be updated to the respective Nodes coming from Dask futures. + - the Node to be returned from the workers and passed to all successors. + """ + for node_uuid in self.graph.reverse(): + node = self.graph.nodes[node_uuid]["value"] + predecessors = list(self.graph.predecessors(node.uuid)) + + if len(predecessors) == 0: + self.results[node.uuid] = self.client.submit( # TODO how to name + node_submit, node=node, pure=False + ) + else: + self.results[node.uuid] = self.client.submit( + node_submit, + node=node, + predecessors={ + x: self.results[x] for x in self.results if x in predecessors + }, + pure=False, + ) + + def get_results(self, obj: typing.Union[NodeBaseMixin, list, dict, NodeView], /): + """Get the results from Dask based on the original object. + + Parameters + ---------- + obj: NodeBaseMixin|list|dict|NodeView + either a single Node or multiple Nodes from the submitted graph. + + Returns + ------- + any: + Returns an instance of obj which is updated with the results from Dask. + + """ + if isinstance(obj, NodeView): + data = _LoadNode()(dict(obj), results=self.results) + return {x: v["value"] for x, v in data.items()} + elif isinstance(obj, DiGraph): + raise NotImplementedError + return _LoadNode()(obj, results=self.results)