diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 0bba628c89..b392a7f8c9 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -51,7 +51,7 @@ jobs: run: | coverage run -m pytest tests/flytekit/unit -m "not sandbox_test" - name: Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v1.5.2 with: fail_ci_if_error: true # optional (default = false) @@ -77,16 +77,33 @@ jobs: - flytekit-kf-pytorch - flytekit-kf-tensorflow - flytekit-modin + - flytekit-onnx-pytorch + - flytekit-onnx-scikitlearn + - flytekit-onnx-tensorflow - flytekit-pandera - flytekit-papermill + - flytekit-polars - flytekit-snowflake - flytekit-spark - flytekit-sqlalchemy + - flytekit-whylogs exclude: # flytekit-modin depends on ray which does not have a 3.10 wheel yet. # Issue tracked in https://github.com/ray-project/ray/issues/19116. - python-version: 3.10 plugin-names: "flytekit-modin" + # Great-expectations does not support python 3.10 yet + # https://github.com/great-expectations/great_expectations/blob/develop/setup.py#L87-L89 + - python-version: 3.10 + plugin-names: "flytekit-greatexpectations" + # onnxruntime does not support python 3.10 yet + # https://github.com/microsoft/onnxruntime/issues/9782 + - python-version: 3.10 + plugin-names: "flytekit-onnx-pytorch" + - python-version: 3.10 + plugin-names: "flytekit-onnx-scikitlearn" + - python-version: 3.10 + plugin-names: "flytekit-onnx-tensorflow" steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -104,7 +121,7 @@ jobs: run: | make setup cd plugins/${{ matrix.plugin-names }} - pip install -e . + pip install -r requirements.txt if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi pip install --no-deps -U https://github.com/flyteorg/flytekit/archive/${{ github.sha }}.zip#egg=flytekit pip freeze diff --git a/.gitignore b/.gitignore index b3566d8032..12931f955b 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,6 @@ docs/source/plugins/generated/ .pytest_flyte htmlcov *.ipynb +.env +*dat +.env/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 316f1b0ee2..39470b7370 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 4.0.1 hooks: - id: flake8 - repo: https://github.com/psf/black @@ -8,17 +8,17 @@ repos: hooks: - id: black - repo: https://github.com/PyCQA/isort - rev: 5.9.3 + rev: 5.10.1 hooks: - id: isort args: ["--profile", "black"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.2.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.7.2.1 + rev: v0.8.0.4 hooks: - id: shellcheck diff --git a/CODEOWNERS b/CODEOWNERS index 96d7d9d004..9389524869 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,3 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence. -* @wild-endeavor @kumare3 @eapolinario +* @wild-endeavor @kumare3 @eapolinario @pingsutw diff --git a/Makefile b/Makefile index b45e8f3f8b..4b3278bec0 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ help: .PHONY: install-piptools install-piptools: # pip 22.1 broke pip-tools: https://github.com/jazzband/pip-tools/issues/1617 - pip install -U pip-tools setuptools wheel "pip>=22.0.3,!=22.1" + python -m pip install -U pip-tools setuptools wheel "pip>=22.0.3,!=22.1" .PHONY: update_boilerplate update_boilerplate: diff --git a/README.md b/README.md index 7144575a37..67b6d12297 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,17 @@ - -

- Flyte Logo -

-

- Flytekit Python -

-

- Flytekit Python is the Python SDK built on top of Flyte -

-

- Plugins - · - Contribution Guide -

- +

+ Flyte Logo +

+

+ Flytekit Python +

+

+ Flytekit Python is the Python SDK built on top of Flyte +

+

+ Plugins + · + Contribution Guide +

[![PyPI version fury.io](https://badge.fury.io/py/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) [![PyPI download day](https://img.shields.io/pypi/dd/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) diff --git a/dev-requirements.in b/dev-requirements.in index 0c29d2fe8a..b2e6cf74a1 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -10,3 +10,5 @@ pre-commit codespell google-cloud-bigquery google-cloud-bigquery-storage +IPython +torch diff --git a/dev-requirements.txt b/dev-requirements.txt index a410225b32..3705a578c0 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # make dev-requirements.txt @@ -18,56 +18,65 @@ attrs==20.3.0 # jsonschema # pytest # pytest-docker -bcrypt==3.2.0 +backcall==0.2.0 + # via ipython +bcrypt==3.2.2 # via paramiko binaryornot==0.4.4 # via # -c requirements.txt # cookiecutter -cachetools==5.0.0 +cached-property==1.5.2 + # via docker-compose +cachetools==5.2.0 # via google-auth -certifi==2021.10.8 +certifi==2022.6.15 # via # -c requirements.txt # requests -cffi==1.15.0 +cffi==1.15.1 # via + # -c requirements.txt # bcrypt # cryptography # pynacl cfgv==3.3.1 # via pre-commit -chardet==4.0.0 +chardet==5.0.0 # via # -c requirements.txt # binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via # -c requirements.txt # requests -click==8.1.2 +click==8.1.3 # via # -c requirements.txt # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via # -c requirements.txt # flytekit codespell==2.1.0 # via -r dev-requirements.in -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via # -c requirements.txt # flytekit -coverage[toml]==6.3.2 +coverage[toml]==6.4.1 # via -r dev-requirements.in -croniter==1.3.4 +croniter==1.3.5 # via # -c requirements.txt # flytekit -cryptography==36.0.2 - # via paramiko +cryptography==37.0.4 + # via + # -c requirements.txt + # paramiko + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via # -c requirements.txt @@ -75,6 +84,7 @@ dataclasses-json==0.5.7 decorator==5.1.1 # via # -c requirements.txt + # ipython # retry deprecated==1.2.13 # via @@ -105,68 +115,84 @@ dockerpty==0.4.1 # via docker-compose docopt==0.6.2 # via docker-compose -docstring-parser==0.14 +docstring-parser==0.14.1 # via # -c requirements.txt # flytekit -filelock==3.6.0 +filelock==3.7.1 # via virtualenv -flyteidl==1.0.0.post1 +flyteidl==1.1.8 # via # -c requirements.txt # flytekit -google-api-core[grpc]==2.7.2 +google-api-core[grpc]==2.8.2 # via # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.6.6 +google-auth==2.9.0 # via # google-api-core # google-cloud-core -google-cloud-bigquery==3.0.1 +google-cloud-bigquery==3.2.0 # via -r dev-requirements.in -google-cloud-bigquery-storage==2.13.1 +google-cloud-bigquery-storage==2.13.2 # via # -r dev-requirements.in # google-cloud-bigquery -google-cloud-core==2.3.0 +google-cloud-core==2.3.1 # via google-cloud-bigquery google-crc32c==1.3.0 # via google-resumable-media -google-resumable-media==2.3.2 +google-resumable-media==2.3.3 # via google-cloud-bigquery -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # -c requirements.txt # flyteidl # google-api-core # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # -c requirements.txt # flytekit # google-api-core # google-cloud-bigquery # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via # -c requirements.txt # flytekit # google-api-core -identify==2.4.12 +identify==2.5.1 # via pre-commit idna==3.3 # via # -c requirements.txt # requests -importlib-metadata==4.11.3 +importlib-metadata==4.12.0 # via # -c requirements.txt + # click + # flytekit + # jsonschema # keyring + # pluggy + # pre-commit + # pytest + # virtualenv iniconfig==1.1.1 # via pytest -jinja2==3.1.1 +ipython==7.34.0 + # via -r dev-requirements.in +jedi==0.18.1 + # via ipython +jeepney==0.8.0 + # via + # -c requirements.txt + # keyring + # secretstorage +jinja2==3.1.2 # via # -c requirements.txt # cookiecutter @@ -182,7 +208,7 @@ jsonschema==3.2.0 # via # -c requirements.txt # docker-compose -keyring==23.5.0 +keyring==23.6.0 # via # -c requirements.txt # flytekit @@ -190,7 +216,7 @@ markupsafe==2.1.1 # via # -c requirements.txt # jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # -c requirements.txt # dataclasses-json @@ -204,9 +230,11 @@ marshmallow-jsonschema==0.13.0 # via # -c requirements.txt # flytekit +matplotlib-inline==0.1.3 + # via ipython mock==4.0.3 # via -r dev-requirements.in -mypy==0.942 +mypy==0.961 # via -r dev-requirements.in mypy-extensions==0.4.3 # via @@ -217,11 +245,12 @@ natsort==8.1.0 # via # -c requirements.txt # flytekit -nodeenv==1.6.0 +nodeenv==1.7.0 # via pre-commit numpy==1.21.6 # via # -c requirements.txt + # flytekit # pandas # pyarrow packaging==21.3 @@ -234,19 +263,23 @@ pandas==1.3.5 # via # -c requirements.txt # flytekit -paramiko==2.10.4 +paramiko==2.11.0 # via docker +parso==0.8.3 + # via jedi +pexpect==4.8.0 + # via ipython +pickleshare==0.7.5 + # via ipython platformdirs==2.5.2 # via virtualenv pluggy==1.0.0 # via pytest -poyo==0.5.0 - # via - # -c requirements.txt - # cookiecutter -pre-commit==2.18.1 +pre-commit==2.19.0 # via -r dev-requirements.in -proto-plus==1.20.3 +prompt-toolkit==3.0.30 + # via ipython +proto-plus==1.20.6 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -257,6 +290,7 @@ protobuf==3.20.1 # flytekit # google-api-core # google-cloud-bigquery + # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status # proto-plus @@ -265,6 +299,8 @@ protoc-gen-swagger==0.1.0 # via # -c requirements.txt # flyteidl +ptyprocess==0.7.0 + # via pexpect py==1.11.0 # via # -c requirements.txt @@ -282,10 +318,18 @@ pyasn1==0.4.8 pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 - # via cffi + # via + # -c requirements.txt + # cffi +pygments==2.12.0 + # via ipython pynacl==1.5.0 # via paramiko -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via + # -c requirements.txt + # flytekit +pyparsing==3.0.9 # via # -c requirements.txt # packaging @@ -316,7 +360,7 @@ python-json-logger==2.0.2 # via # -c requirements.txt # flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via # -c requirements.txt # cookiecutter @@ -332,14 +376,15 @@ pytz==2022.1 pyyaml==5.4.1 # via # -c requirements.txt + # cookiecutter # docker-compose # flytekit # pre-commit -regex==2022.4.24 +regex==2022.6.2 # via # -c requirements.txt # docker-image-py -requests==2.27.1 +requests==2.28.1 # via # -c requirements.txt # cookiecutter @@ -349,7 +394,7 @@ requests==2.27.1 # google-api-core # google-cloud-bigquery # responses -responses==0.20.0 +responses==0.21.0 # via # -c requirements.txt # flytekit @@ -359,11 +404,17 @@ retry==0.9.2 # flytekit rsa==4.8 # via google-auth +secretstorage==3.3.2 + # via + # -c requirements.txt + # keyring +singledispatchmethod==1.0 + # via + # -c requirements.txt + # flytekit six==1.16.0 # via # -c requirements.txt - # bcrypt - # cookiecutter # dockerpty # google-auth # grpcio @@ -393,11 +444,23 @@ tomli==2.0.1 # coverage # mypy # pytest -typing-extensions==4.2.0 +torch==1.11.0 + # via -r dev-requirements.in +traitlets==5.3.0 + # via + # ipython + # matplotlib-inline +typed-ast==1.5.4 + # via mypy +typing-extensions==4.3.0 # via # -c requirements.txt + # arrow # flytekit + # importlib-metadata # mypy + # responses + # torch # typing-inspect typing-inspect==0.7.1 # via @@ -409,8 +472,10 @@ urllib3==1.26.9 # flytekit # requests # responses -virtualenv==20.14.1 +virtualenv==20.15.1 # via pre-commit +wcwidth==0.2.5 + # via prompt-toolkit websocket-client==0.59.0 # via # -c requirements.txt @@ -420,7 +485,7 @@ wheel==0.37.1 # via # -c requirements.txt # flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # -c requirements.txt # deprecated diff --git a/doc-requirements.in b/doc-requirements.in index 4d60a6919b..760d3903dc 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -33,3 +33,4 @@ papermill # papermill jupyter # papermill pyspark # spark sqlalchemy # sqlalchemy +torch # pytorch diff --git a/doc-requirements.txt b/doc-requirements.txt index 3eee9655cb..e9f16c89a0 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # make doc-requirements.txt @@ -12,28 +12,26 @@ altair==4.2.0 # via great-expectations ansiwrap==0.8.4 # via papermill -appnope==0.1.3 - # via - # ipykernel - # ipython argon2-cffi==21.3.0 # via notebook argon2-cffi-bindings==21.2.0 # via argon2-cffi arrow==1.2.2 # via jinja2-time -astroid==2.11.3 +astroid==2.11.6 # via sphinx-autoapi -asttokens==2.0.5 - # via stack-data attrs==21.4.0 # via # jsonschema # visions -babel==2.10.1 +babel==2.10.3 # via sphinx backcall==0.2.0 # via ipython +backports-zoneinfo==0.2.1 + # via + # pytz-deprecation-shim + # tzlocal beautifulsoup4==4.11.1 # via # furo @@ -42,42 +40,44 @@ beautifulsoup4==4.11.1 # sphinx-material binaryornot==0.4.4 # via cookiecutter -bleach==5.0.0 +bleach==5.0.1 # via nbconvert -botocore==1.25.0 +botocore==1.27.22 # via -r doc-requirements.in -cachetools==5.0.0 +cachetools==5.2.0 # via google-auth -certifi==2021.10.8 +certifi==2022.6.15 # via # kubernetes # requests -cffi==1.15.0 +cffi==1.15.1 # via # argon2-cffi-bindings # cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit # great-expectations # papermill -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -colorama==0.4.4 +colorama==0.4.5 # via great-expectations -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 +cryptography==37.0.4 # via # -r doc-requirements.in # great-expectations + # pyopenssl + # secretstorage css-html-js-minify==2.5.5 # via sphinx-material cycler==0.11.0 @@ -102,7 +102,7 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14 +docstring-parser==0.14.1 # via flytekit docutils==0.17.1 # via @@ -118,59 +118,57 @@ entrypoints==0.4 # jupyter-client # nbconvert # papermill -executing==0.8.3 - # via stack-data fastjsonschema==2.15.3 # via nbformat -flyteidl==1.0.0.post1 +flyteidl==1.1.8 # via flytekit -fonttools==4.33.2 +fonttools==4.33.3 # via matplotlib -fsspec==2022.3.0 +fsspec==2022.5.0 # via # -r doc-requirements.in # modin furo @ git+https://github.com/flyteorg/furo@main # via -r doc-requirements.in -google-api-core[grpc]==2.7.2 +google-api-core[grpc]==2.8.2 # via # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.6.6 +google-auth==2.9.0 # via # google-api-core # google-cloud-core # kubernetes google-cloud==0.34.0 # via -r doc-requirements.in -google-cloud-bigquery==3.0.1 +google-cloud-bigquery==3.2.0 # via -r doc-requirements.in -google-cloud-bigquery-storage==2.13.1 +google-cloud-bigquery-storage==2.13.2 # via google-cloud-bigquery -google-cloud-core==2.3.0 +google-cloud-core==2.3.1 # via google-cloud-bigquery google-crc32c==1.3.0 # via google-resumable-media -google-resumable-media==2.3.2 +google-resumable-media==2.3.3 # via google-cloud-bigquery -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # google-api-core # grpcio-status -great-expectations==0.15.2 +great-expectations==0.15.12 # via -r doc-requirements.in greenlet==1.1.2 # via sqlalchemy -grpcio==1.44.0 +grpcio==1.47.0 # via # -r doc-requirements.in # flytekit # google-api-core # google-cloud-bigquery # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via # flytekit # google-api-core @@ -180,22 +178,28 @@ idna==3.3 # via requests imagehash==4.2.1 # via visions -imagesize==1.3.0 +imagesize==1.4.1 # via sphinx -importlib-metadata==4.11.3 +importlib-metadata==4.12.0 # via + # click + # flytekit # great-expectations + # jsonschema # keyring # markdown # sphinx -ipykernel==6.13.0 + # sqlalchemy +importlib-resources==5.8.0 + # via jsonschema +ipykernel==6.15.0 # via # ipywidgets # jupyter # jupyter-console # notebook # qtconsole -ipython==8.2.0 +ipython==7.34.0 # via # great-expectations # ipykernel @@ -206,11 +210,15 @@ ipython-genutils==0.2.0 # ipywidgets # notebook # qtconsole -ipywidgets==7.7.0 +ipywidgets==7.7.1 # via jupyter jedi==0.18.1 # via ipython -jinja2==3.0.3 +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 # via # altair # cookiecutter @@ -223,9 +231,9 @@ jinja2==3.0.3 # sphinx-autoapi jinja2-time==0.2.0 # via cookiecutter -jmespath==1.0.0 +jmespath==1.0.1 # via botocore -joblib==1.0.1 +joblib==1.1.0 # via # pandas-profiling # phik @@ -233,21 +241,21 @@ jsonpatch==1.32 # via great-expectations jsonpointer==2.3 # via jsonpatch -jsonschema==4.4.0 +jsonschema==4.6.1 # via # altair # great-expectations # nbformat jupyter==1.0.0 # via -r doc-requirements.in -jupyter-client==7.3.0 +jupyter-client==7.3.4 # via # ipykernel # jupyter-console # nbclient # notebook # qtconsole -jupyter-console==6.4.3 +jupyter-console==6.4.4 # via jupyter jupyter-core==4.10.0 # via @@ -258,26 +266,26 @@ jupyter-core==4.10.0 # qtconsole jupyterlab-pygments==0.2.2 # via nbconvert -jupyterlab-widgets==1.1.0 +jupyterlab-widgets==1.1.1 # via ipywidgets -keyring==23.5.0 +keyring==23.6.0 # via flytekit -kiwisolver==1.4.2 +kiwisolver==1.4.3 # via matplotlib -kubernetes==23.3.0 +kubernetes==24.2.0 # via -r doc-requirements.in lazy-object-proxy==1.7.1 # via astroid -lxml==4.8.0 +lxml==4.9.1 # via sphinx-material -markdown==3.3.6 +markdown==3.3.7 # via -r doc-requirements.in -markupsafe==2.0.1 +markupsafe==2.1.1 # via # jinja2 # nbconvert # pandas-profiling -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -286,7 +294,7 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.5.1 +matplotlib==3.5.2 # via # missingno # pandas-profiling @@ -302,7 +310,7 @@ mistune==0.8.4 # via # great-expectations # nbconvert -modin==0.14.0 +modin==0.12.1 # via -r doc-requirements.in multimethod==1.8 # via @@ -312,7 +320,7 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -nbclient==0.6.0 +nbclient==0.6.6 # via # nbconvert # papermill @@ -320,10 +328,9 @@ nbconvert==6.5.0 # via # jupyter # notebook -nbformat==5.3.0 +nbformat==5.4.0 # via # great-expectations - # ipywidgets # nbclient # nbconvert # notebook @@ -334,16 +341,17 @@ nest-asyncio==1.5.5 # jupyter-client # nbclient # notebook -networkx==2.8 +networkx==2.6.3 # via visions -notebook==6.4.11 +notebook==6.4.12 # via # great-expectations # jupyter # widgetsnbextension -numpy==1.22.3 +numpy==1.21.6 # via # altair + # flytekit # great-expectations # imagehash # matplotlib @@ -372,7 +380,7 @@ packaging==21.3 # pandera # qtpy # sphinx -pandas==1.4.1 +pandas==1.3.5 # via # altair # dolt-integrations @@ -384,9 +392,9 @@ pandas==1.4.1 # phik # seaborn # visions -pandas-profiling==3.1.0 +pandas-profiling==3.2.0 # via -r doc-requirements.in -pandera==0.10.1 +pandera==0.9.0 # via -r doc-requirements.in pandocfilters==1.5.0 # via nbconvert @@ -400,22 +408,20 @@ phik==0.12.2 # via pandas-profiling pickleshare==0.7.5 # via ipython -pillow==9.1.0 +pillow==9.2.0 # via # imagehash # matplotlib # visions -plotly==5.7.0 +plotly==5.9.0 # via -r doc-requirements.in -poyo==0.5.0 - # via cookiecutter prometheus-client==0.14.1 # via notebook -prompt-toolkit==3.0.29 +prompt-toolkit==3.0.30 # via # ipython # jupyter-console -proto-plus==1.20.3 +proto-plus==1.20.6 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -425,23 +431,22 @@ protobuf==3.20.1 # flytekit # google-api-core # google-cloud-bigquery + # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status # proto-plus # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -psutil==5.9.0 +psutil==5.9.1 # via ipykernel ptyprocess==0.7.0 # via # pexpect # terminado -pure-eval==0.2.2 - # via stack-data py==1.11.0 # via retry -py4j==0.10.9.3 +py4j==0.10.9.5 # via pyspark pyarrow==6.0.1 # via @@ -456,18 +461,21 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pydantic==1.9.0 +pydantic==1.9.1 # via # pandas-profiling # pandera pygments==2.12.0 # via + # furo # ipython # jupyter-console # nbconvert # qtconsole # sphinx # sphinx-prompt +pyopenssl==22.0.0 + # via flytekit pyparsing==2.4.7 # via # great-expectations @@ -475,7 +483,7 @@ pyparsing==2.4.7 # packaging pyrsistent==0.18.1 # via jsonschema -pyspark==3.2.1 +pyspark==3.3.0 # via -r doc-requirements.in python-dateutil==2.8.2 # via @@ -491,7 +499,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify[unidecode]==6.1.1 +python-slugify[unidecode]==6.1.2 # via # cookiecutter # sphinx-material @@ -509,23 +517,25 @@ pywavelets==1.3.0 # via imagehash pyyaml==6.0 # via + # cookiecutter # flytekit # kubernetes # pandas-profiling # papermill # sphinx-autoapi -pyzmq==22.3.0 +pyzmq==23.2.0 # via + # ipykernel # jupyter-client # notebook # qtconsole -qtconsole==5.3.0 +qtconsole==5.3.1 # via jupyter -qtpy==2.0.1 +qtpy==2.1.0 # via qtconsole -regex==2022.4.24 +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker @@ -541,7 +551,7 @@ requests==2.27.1 # sphinx requests-oauthlib==1.3.1 # via kubernetes -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit @@ -551,7 +561,7 @@ ruamel-yaml==0.17.17 # via great-expectations ruamel-yaml-clib==0.2.6 # via ruamel-yaml -scipy==1.8.0 +scipy==1.7.3 # via # great-expectations # imagehash @@ -563,18 +573,19 @@ seaborn==0.11.2 # via # missingno # pandas-profiling +secretstorage==3.3.2 + # via keyring send2trash==1.8.0 # via notebook +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # asttokens # bleach - # cookiecutter # google-auth # grpcio # imagehash # kubernetes - # plotly # python-dateutil # sphinx-code-include snowballstemmer==2.2.0 @@ -588,6 +599,7 @@ sphinx==4.5.0 # -r doc-requirements.in # furo # sphinx-autoapi + # sphinx-basic-ng # sphinx-code-include # sphinx-copybutton # sphinx-fontawesome @@ -598,6 +610,8 @@ sphinx==4.5.0 # sphinxcontrib-yt sphinx-autoapi==1.8.4 # via -r doc-requirements.in +sphinx-basic-ng==0.0.1a12 + # via furo sphinx-code-include==1.1.1 # via -r doc-requirements.in sphinx-copybutton==0.5.0 @@ -626,13 +640,11 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-yt==0.2.2 # via -r doc-requirements.in -sqlalchemy==1.4.35 +sqlalchemy==1.4.39 # via -r doc-requirements.in -stack-data==0.2.0 - # via ipython statsd==3.3.0 # via flytekit -tangled-up-in-unicode==0.1.0 +tangled-up-in-unicode==0.2.0 # via # pandas-profiling # visions @@ -642,7 +654,7 @@ tenacity==8.0.1 # plotly termcolor==1.1.0 # via great-expectations -terminado==0.13.3 +terminado==0.15.0 # via notebook text-unidecode==1.3 # via python-slugify @@ -652,7 +664,9 @@ tinycss2==1.1.1 # via nbconvert toolz==0.11.2 # via altair -tornado==6.1 +torch==1.11.0 + # via -r doc-requirements.in +tornado==6.2 # via # ipykernel # jupyter-client @@ -663,7 +677,7 @@ tqdm==4.64.0 # great-expectations # pandas-profiling # papermill -traitlets==5.1.1 +traitlets==5.3.0 # via # ipykernel # ipython @@ -676,12 +690,22 @@ traitlets==5.1.1 # nbformat # notebook # qtconsole -typing-extensions==4.2.0 +typed-ast==1.5.4 + # via astroid +typing-extensions==4.3.0 # via + # argon2-cffi + # arrow # astroid # flytekit # great-expectations + # importlib-metadata + # jsonschema + # kiwisolver + # pandera # pydantic + # responses + # torch # typing-inspect typing-inspect==0.7.1 # via @@ -711,22 +735,24 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.3.2 +websocket-client==1.3.3 # via # docker # kubernetes wheel==0.37.1 # via flytekit -widgetsnbextension==3.6.0 +widgetsnbextension==3.6.1 # via ipywidgets -wrapt==1.14.0 +wrapt==1.14.1 # via # astroid # deprecated # flytekit # pandera zipp==3.8.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/docs/source/data.extend.rst b/docs/source/data.extend.rst index 1ed84de1b4..3f06961022 100644 --- a/docs/source/data.extend.rst +++ b/docs/source/data.extend.rst @@ -1,8 +1,11 @@ ############################## Extend Data Persistence layer ############################## -Flytekit provides a data persistence layer, which is used for recording metadata that is shared with backend Flyte. This persistence layer is also available for various types to store raw user data and is designed to be cross-cloud compatible. -Moreover, it is design to be extensible and users can bring their own data persistence plugins by following the persistence interface. NOTE, this is bound to get more extensive for variety of use-cases, but the core set of apis are battle tested. +Flytekit provides a data persistence layer, which is used for recording metadata that is shared with the Flyte backend. This persistence layer is available for various types to store raw user data and is designed to be cross-cloud compatible. +Moreover, it is designed to be extensible and users can bring their own data persistence plugins by following the persistence interface. + +.. note:: + This will become extensive for a variety of use-cases, but the core set of APIs have been battle tested. .. automodule:: flytekit.core.data_persistence :no-members: @@ -13,3 +16,22 @@ Moreover, it is design to be extensible and users can bring their own data persi :no-members: :no-inherited-members: :no-special-members: + +The ``fsspec`` Data Plugin +-------------------------- + +Flytekit ships with a default storage driver that uses aws-cli on AWS and gsutil on GCP. By default, Flyte uploads the task outputs to S3 or GCS using these storage drivers. + +Why ``fsspec``? +^^^^^^^^^^^^^^^ + +You can use the fsspec plugin implementation to utilize all its available plugins with flytekit. The `fsspec `_ plugin provides an implementation of the data persistence layer in Flytekit. For example: HDFS, FTP are supported in fsspec, so you can use them with flytekit too. +The data persistence layer helps store logs of metadata and raw user data. +As a consequence of the implementation, an S3 driver can be installed using ``pip install s3fs``. + +`Here `_ is a code snippet that shows protocols mapped to the class it implements. + +Once you install the plugin, it overrides all default implementations of the `DataPersistencePlugins `_ and provides the ones supported by fsspec. + +.. note:: + This plugin installs fsspec core only. To install all the fsspec plugins, see `here `_. diff --git a/docs/source/extras.pytorch.rst b/docs/source/extras.pytorch.rst new file mode 100644 index 0000000000..12fd3d62d9 --- /dev/null +++ b/docs/source/extras.pytorch.rst @@ -0,0 +1,7 @@ +############ +PyTorch Type +############ +.. automodule:: flytekit.extras.pytorch + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst index f1b15455dd..f0cdff28dc 100644 --- a/docs/source/types.extend.rst +++ b/docs/source/types.extend.rst @@ -11,3 +11,4 @@ Feel free to follow the pattern of the built-in types. types.builtins.structured types.builtins.file types.builtins.directory + extras.pytorch diff --git a/flytekit/__init__.py b/flytekit/__init__.py index c0ea9ebe3d..c67a8a04b4 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -154,6 +154,7 @@ """ import sys +from typing import Generator if sys.version_info < (3, 10): from importlib_metadata import entry_points @@ -181,6 +182,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck +from flytekit.extras import pytorch from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels @@ -188,7 +190,7 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types import directory, file, schema +from flytekit.types import directory, file, numpy, schema from flytekit.types.structured.structured_dataset import ( StructuredDataset, StructuredDatasetFormat, @@ -215,6 +217,10 @@ def current_context() -> ExecutionParameters: return FlyteContextManager.current_context().execution_state.user_space_params +def new_context() -> Generator[FlyteContext, None, None]: + return FlyteContextManager.with_context(FlyteContextManager.current_context().new_builder()) + + def load_implicit_plugins(): """ This method allows loading all plugins that have the entrypoint specification. This uses the plugin loading diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 68bd2578e5..1f8dd78ef0 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -243,6 +243,7 @@ def setup_execution( tmp_dir=user_workspace_dir, raw_output_prefix=raw_output_data_prefix, checkpoint=checkpointer, + task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version), ) try: diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 8db7c93a98..d542af5f7e 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -1004,3 +1004,17 @@ def get_upload_signed_url( expires_in=expires_in_pb, ) ) + + def get_download_signed_url( + self, native_url: str, expires_in: datetime.timedelta = None + ) -> _data_proxy_pb2.CreateUploadLocationResponse: + expires_in_pb = None + if expires_in: + expires_in_pb = Duration() + expires_in_pb.FromTimedelta(expires_in) + return super(SynchronousFlyteClient, self).create_download_location( + _data_proxy_pb2.CreateDownloadLocationRequest( + native_url=native_url, + expires_in=expires_in_pb, + ) + ) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 9aba78888a..9bdf85a178 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -1,12 +1,14 @@ from __future__ import annotations import base64 as _base64 +import ssl import subprocess import time import typing from typing import Optional import grpc +import OpenSSL import requests as _requests from flyteidl.admin.project_pb2 import ProjectListRequest from flyteidl.service import admin_pb2_grpc as _admin_service @@ -110,6 +112,26 @@ def __init__(self, cfg: PlatformConfig, **kwargs): self._cfg = cfg if cfg.insecure: self._channel = grpc.insecure_channel(cfg.endpoint, **kwargs) + elif cfg.insecure_skip_verify: + # Get port from endpoint or use 443 + endpoint_parts = cfg.endpoint.rsplit(":", 1) + if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit(): + server_address = tuple(endpoint_parts) + else: + server_address = (cfg.endpoint, "443") + + cert = ssl.get_server_certificate(server_address) + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + cn = x509.get_subject().CN + credentials = grpc.ssl_channel_credentials(str.encode(cert)) + options = kwargs.get("options", []) + options.append(("grpc.ssl_target_name_override", cn)) + self._channel = grpc.secure_channel( + target=cfg.endpoint, + credentials=credentials, + options=options, + compression=kwargs.get("compression", None), + ) else: if "credentials" not in kwargs: credentials = grpc.ssl_channel_credentials( @@ -251,7 +273,10 @@ def _refresh_credentials_from_command(self): except subprocess.CalledProcessError as e: cli_logger.error("Failed to generate token from command {}".format(command)) raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e)) - self.set_access_token(output.stdout.strip()) + authorization_header_key = self.public_client_config.authorization_metadata_key or None + if not authorization_header_key: + self.set_access_token(output.stdout.strip()) + self.set_access_token(output.stdout.strip(), authorization_header_key) def _refresh_credentials_noop(self): pass @@ -833,6 +858,17 @@ def create_upload_location( """ return self._dataproxy_stub.CreateUploadLocation(create_upload_location_request, metadata=self._metadata) + @_handle_rpc_error(retry=True) + def create_download_location( + self, create_download_location_request: _dataproxy_pb2.CreateDownloadLocationRequest + ) -> _dataproxy_pb2.CreateDownloadLocationResponse: + """ + Get a signed url to be used during fast registration + :param flyteidl.service.dataproxy_pb2.CreateDownloadLocationRequest create_download_location_request: + :rtype: flyteidl.service.dataproxy_pb2.CreateDownloadLocationResponse + """ + return self._dataproxy_stub.CreateDownloadLocation(create_download_location_request, metadata=self._metadata) + def get_token(token_endpoint, authorization_header, scope): """ diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 3628af7beb..21aec1c4ad 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -4,6 +4,7 @@ import os as _os import stat as _stat import sys as _sys +from dataclasses import replace from typing import Callable, Dict, List, Tuple, Union import click as _click @@ -276,7 +277,7 @@ def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteC if parent_ctx.obj["cacert"]: kwargs["root_certificates"] = parent_ctx.obj["cacert"] cfg = parent_ctx.obj["config"] - cfg = cfg.with_parameters(endpoint=host, insecure=insecure) + cfg = replace(cfg, endpoint=host, insecure=insecure) return _friendly_client.SynchronousFlyteClient(cfg, **kwargs) diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index 73274e972e..d922f5e3c1 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -1,5 +1,7 @@ +import sys from typing import Tuple, Union +import click from flyteidl.admin.launch_plan_pb2 import LaunchPlan from flyteidl.admin.task_pb2 import TaskSpec from flyteidl.admin.workflow_pb2 import WorkflowSpec @@ -125,3 +127,9 @@ def hydrate_registration_parameters( del entity.sub_workflows[:] entity.sub_workflows.extend(refreshed_sub_workflows) return identifier, entity + + +def display_help_with_error(ctx: click.Context, message: str): + click.echo(f"{ctx.get_help()}\n") + click.secho(message, fg="red") + sys.exit(1) diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py new file mode 100644 index 0000000000..a9a9c4900d --- /dev/null +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -0,0 +1,32 @@ +import click + +from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE +from flytekit.configuration import Config +from flytekit.loggers import cli_logger +from flytekit.remote.remote import FlyteRemote + +FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" + + +def get_and_save_remote_with_click_context( + ctx: click.Context, project: str, domain: str, save: bool = True +) -> FlyteRemote: + """ + NB: This function will by default mutate the click Context.obj dictionary, adding a remote key with value + of the created FlyteRemote object. + + :param ctx: the click context object + :param project: default project for the remote instance + :param domain: default domain + :param save: If false, will not mutate the context.obj dict + :return: FlyteRemote instance + """ + cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) + cfg_obj = Config.auto(cfg_file_location) + cli_logger.info( + f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") + ) + r = FlyteRemote(cfg_obj, default_project=project, default_domain=domain) + if save: + ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r + return r diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py index 71efeab576..2a884e29da 100644 --- a/flytekit/clis/sdk_in_container/package.py +++ b/flytekit/clis/sdk_in_container/package.py @@ -1,8 +1,8 @@ import os -import sys import click +from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants from flytekit.configuration import ( DEFAULT_RUNTIME_PYTHON_INTERPRETER, @@ -100,8 +100,7 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_ pkgs = ctx.obj[constants.CTX_PACKAGES] if not pkgs: - click.secho("No packages to scan for flyte entities. Aborting!", fg="red") - sys.exit(-1) + display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!") try: serialize_and_package(pkgs, serialization_settings, source, output, fast) diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index c2b5f2b045..76777c5663 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -5,6 +5,7 @@ from flytekit.clis.sdk_in_container.init import init from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.package import package +from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize from flytekit.configuration.internal import LocalSDK @@ -68,6 +69,7 @@ def main(ctx, pkgs=None, config=None): main.add_command(local_cache) main.add_command(init) main.add_command(run) +main.add_command(register) if __name__ == "__main__": main() diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py new file mode 100644 index 0000000000..03e00d7896 --- /dev/null +++ b/flytekit/clis/sdk_in_container/register.py @@ -0,0 +1,179 @@ +import os +import pathlib +import typing + +import click + +from flytekit.clis.helpers import display_help_with_error +from flytekit.clis.sdk_in_container import constants +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings +from flytekit.configuration.default_images import DefaultImages +from flytekit.loggers import cli_logger +from flytekit.tools.fast_registration import fast_package +from flytekit.tools.repo import find_common_root, load_packages_and_modules +from flytekit.tools.repo import register as repo_register +from flytekit.tools.translator import Options + +_register_help = """ +This command is similar to package but instead of producing a zip file, all your Flyte entities are compiled, +and then sent to the backend specified by your config file. Think of this as combining the pyflyte package +and the flytectl register step in one command. This is why you see switches you'd normally use with flytectl +like service account here. + +Note: This command runs "fast" register by default. Future work to come to add a non-fast version. +This means that a zip is created from the detected root of the packages given, and uploaded. Just like with +pyflyte run, tasks registered from this command will download and unzip that code package before running. + +Note: This command only works on regular Python packages, not namespace packages. When determining + the root of your project, it finds the first folder that does not have an __init__.py file. +""" + + +@click.command("register", help=_register_help) +@click.option( + "-p", + "--project", + required=False, + type=str, + default="flytesnacks", + help="Project to register and run this workflow in", +) +@click.option( + "-d", + "--domain", + required=False, + type=str, + default="development", + help="Domain to register and run this workflow in", +) +@click.option( + "-i", + "--image", + "image_config", + required=False, + multiple=True, + type=click.UNPROCESSED, + callback=ImageConfig.validate_image, + default=[DefaultImages.default_image()], + help="A fully qualified tag for an docker image, e.g. somedocker.com/myimage:someversion123. This is a " + "multi-option and can be of the form --image xyz.io/docker:latest " + "--image my_image=xyz.io/docker2:latest. Note, the `name=image_uri`. The name is optional, if not " + "provided the image will be used as the default image. All the names have to be unique, and thus " + "there can only be one --image option with no name.", +) +@click.option( + "-o", + "--output", + required=False, + type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True), + default=None, + help="Directory to write the output zip file containing the protobuf definitions", +) +@click.option( + "-d", + "--destination-dir", + required=False, + type=str, + default="/root", + help="Directory inside the image where the tar file containing the code will be copied to", +) +@click.option( + "--service-account", + required=False, + type=str, + default="", + help="Service account used when creating launch plans", +) +@click.option( + "--raw-data-prefix", + required=False, + type=str, + default="", + help="Raw output data prefix when creating launch plans, where offloaded data will be stored", +) +@click.option( + "-v", + "--version", + required=False, + type=str, + help="Version the package or module is registered with", +) +@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) +@click.pass_context +def register( + ctx: click.Context, + project: str, + domain: str, + image_config: ImageConfig, + output: str, + destination_dir: str, + service_account: str, + raw_data_prefix: str, + version: typing.Optional[str], + package_or_module: typing.Tuple[str], +): + """ + see help + """ + pkgs = ctx.obj[constants.CTX_PACKAGES] + if not pkgs: + cli_logger.debug("No pkgs") + if pkgs: + raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command") + + if len(package_or_module) == 0: + display_help_with_error( + ctx, + "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed", + ) + + cli_logger.debug( + f"Running pyflyte register from {os.getcwd()} " + f"with images {image_config} " + f"and image destinationfolder {destination_dir} " + f"on {len(package_or_module)} package(s) {package_or_module}" + ) + + # Create and save FlyteRemote, + remote = get_and_save_remote_with_click_context(ctx, project, domain) + + # Todo: add switch for non-fast - skip the zipping and uploading and no fastserializationsettings + # Create a zip file containing all the entries. + detected_root = find_common_root(package_or_module) + cli_logger.debug(f"Using {detected_root} as root folder for project") + zip_file = fast_package(detected_root, output) + + # Upload zip file to Admin using FlyteRemote. + md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file)) + cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}") + + # Create serialization settings + # Todo: Rely on default Python interpreter for now, this will break custom Spark containers + serialization_settings = SerializationSettings( + project=project, + domain=domain, + image_config=image_config, + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir=destination_dir, + distribution_location=native_url, + ), + ) + + options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) + + # Load all the entities + registerable_entities = load_packages_and_modules( + serialization_settings, detected_root, list(package_or_module), options + ) + if len(registerable_entities) == 0: + display_help_with_error(ctx, "No Flyte entities were detected. Aborting!") + cli_logger.info(f"Found and serialized {len(registerable_entities)} entities") + + if not version: + version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa + cli_logger.info(f"Computed version is {version}") + + # Register using repo code + repo_register(registerable_entities, project, domain, version, remote.client) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index b6f723fe0c..bd7b8ee20b 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -2,6 +2,7 @@ import functools import importlib import json +import logging import os import pathlib import typing @@ -11,10 +12,12 @@ import click from dataclasses_json import DataClassJsonMixin from pytimeparse import parse +from typing_extensions import get_args from flytekit import BlobType, Literal, Scalar -from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT -from flytekit.configuration import Config, ImageConfig, SerializationSettings +from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PROJECT +from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context +from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.core import context_manager, tracker from flytekit.core.base_task import PythonTask @@ -22,19 +25,17 @@ from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase -from flytekit.loggers import cli_logger from flytekit.models import literals from flytekit.models.interface import Variable from flytekit.models.literals import Blob, BlobMetadata, Primitive from flytekit.models.types import LiteralType, SimpleType from flytekit.remote.executions import FlyteWorkflowExecution -from flytekit.remote.remote import FlyteRemote from flytekit.tools import module_loader, script_mode +from flytekit.tools.script_mode import _find_project_root from flytekit.tools.translator import Options REMOTE_FLAG_KEY = "remote" RUN_LEVEL_PARAMS_KEY = "run_level_params" -FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" DATA_PROXY_CALLBACK_KEY = "data_proxy" @@ -244,6 +245,31 @@ def convert_to_blob( return lit + def convert_to_union( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any + ) -> Literal: + lt = self._literal_type + for i in range(len(self._literal_type.union_type.variants)): + variant = self._literal_type.union_type.variants[i] + python_type = get_args(self._python_type)[i] + converter = FlyteLiteralConverter( + ctx, + self._flyte_ctx, + variant, + python_type, + self._create_upload_fn, + ) + try: + # Here we use click converter to convert the input in command line to native python type, + # and then use flyte converter to convert it to literal. + python_val = converter._click_type.convert(value, param, ctx) + literal = converter.convert_to_literal(ctx, param, python_val) + self._python_type = python_type + return literal + except (Exception or AttributeError) as e: + logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) + raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") + def convert_to_literal( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any ) -> Literal: @@ -255,7 +281,7 @@ def convert_to_literal( if self._literal_type.collection_type or self._literal_type.map_value_type: # TODO Does not support nested flytefile, flyteschema types - v = json.loads(value) + v = json.loads(value) if isinstance(value, str) else value if self._literal_type.collection_type and not isinstance(v, list): raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(v)}") if self._literal_type.map_value_type and not isinstance(v, dict): @@ -263,11 +289,14 @@ def convert_to_literal( return TypeEngine.to_literal(self._flyte_ctx, v, self._python_type, self._literal_type) if self._literal_type.union_type: - raise NotImplementedError("Union type is not yet implemented for pyflyte run") + return self.convert_to_union(ctx, param, value) if self._literal_type.simple or self._literal_type.enum_type: if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT: - o = cast(DataClassJsonMixin, self._python_type).from_json(value) + if type(value) != self._python_type: + o = cast(DataClassJsonMixin, self._python_type).from_json(value) + else: + o = value return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) return Literal(scalar=self._converter.convert(value, self._python_type)) @@ -396,16 +425,14 @@ def get_workflow_command_base_params() -> typing.List[click.Option]: ] -def load_naive_entity(module_name: str, entity_name: str) -> typing.Union[WorkflowBase, PythonTask]: +def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]: """ Load the workflow of a the script file. N.B.: it assumes that the file is self-contained, in other words, there are no relative imports. """ - flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( - SerializationSettings(None) - ) - with context_manager.FlyteContextManager.with_context(flyte_ctx): - with module_loader.add_sys_path(os.getcwd()): + flyte_ctx_builder = context_manager.FlyteContextManager.current_context().new_builder() + with context_manager.FlyteContextManager.with_context(flyte_ctx_builder): + with module_loader.add_sys_path(project_root): importlib.import_module(module_name) return module_loader.load_object_from_module(f"{module_name}.{entity_name}") @@ -444,9 +471,7 @@ def get_entities_in_file(filename: str) -> Entities: """ Returns a list of flyte workflow names and list of Flyte tasks in a file. """ - flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( - SerializationSettings(None) - ) + flyte_ctx = context_manager.FlyteContextManager.current_context().new_builder() module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".") with context_manager.FlyteContextManager.with_context(flyte_ctx): with module_loader.add_sys_path(os.getcwd()): @@ -473,6 +498,8 @@ def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, """ def _run(*args, **kwargs): + # By the time we get to this function, all the loading has already happened + run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY] project, domain = run_level_params.get("project"), run_level_params.get("domain") inputs = {} @@ -486,10 +513,6 @@ def _run(*args, **kwargs): remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] - # StructuredDatasetTransformerEngine.register( - # PandasToParquetDataProxyEncodingHandler(get_upload_url_fn), default_for_type=True - # ) - remote_entity = remote.register_script( entity, project=project, @@ -532,32 +555,41 @@ class WorkflowCommand(click.MultiCommand): def __init__(self, filename: str, *args, **kwargs): super().__init__(*args, **kwargs) - self._filename = filename + self._filename = pathlib.Path(filename).resolve() def list_commands(self, ctx): entities = get_entities_in_file(self._filename) return entities.all() def get_command(self, ctx, exe_entity): + """ + This command uses the filename with which this command was created, and the string name of the entity passed + after the Python filename on the command line, to load the Python object, and then return the Command that + click should run. + :param ctx: The click Context object. + :param exe_entity: string of the flyte entity provided by the user. Should be the name of a workflow, or task + function. + :return: + """ + rel_path = os.path.relpath(self._filename) if rel_path.startswith(".."): raise ValueError( f"You must call pyflyte from the same or parent dir, {self._filename} not under {os.getcwd()}" ) + project_root = _find_project_root(self._filename) + # Find the relative path for the filename relative to the root of the project. + # N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as + # a parameter. + rel_path = self._filename.relative_to(project_root) module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") - entity = load_naive_entity(module, exe_entity) + entity = load_naive_entity(module, exe_entity, project_root) # If this is a remote execution, which we should know at this point, then create the remote object p = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT) d = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_DOMAIN) - cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) - cfg_obj = Config.auto(cfg_file_location) - cli_logger.info( - f"Run is using config object {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") - ) - r = FlyteRemote(cfg_obj, default_project=p, default_domain=d) - ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r + r = get_and_save_remote_with_click_context(ctx, p, d) get_upload_url_fn = functools.partial(r.client.get_upload_signed_url, project=p, domain=d) flyte_ctx = context_manager.FlyteContextManager.current_context() @@ -596,8 +628,16 @@ def get_command(self, ctx, filename): return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode") +_run_help = """ +This command can execute either a workflow or a task from the commandline, for fully self-contained scripts. +Tasks and workflows cannot be imported from other files currently. Please use `pyflyte package` or +`pyflyte register` to handle those and then launch from the Flyte UI or `flytectl` + +Note: This command only works on regular Python packages, not namespace packages. When determining + the root of your project, it finds the first folder that does not have an __init__.py file. +""" + run = RunCommand( name="run", - help="Run command: This command can execute either a workflow or a task from the commandline, for " - "fully self-contained scripts. Tasks and workflows cannot be imported from other files currently.", + help=_run_help, ) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 188f6ebeee..5354abde4f 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -297,6 +297,7 @@ class PlatformConfig(object): :param endpoint: DNS for Flyte backend :param insecure: Whether or not to use SSL + :param insecure_skip_verify: Wether to skip SSL certificate verification :param command: This command is executed to return a token using an external process. :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. @@ -309,32 +310,13 @@ class PlatformConfig(object): endpoint: str = "localhost:30081" insecure: bool = False + insecure_skip_verify: bool = False command: typing.Optional[typing.List[str]] = None client_id: typing.Optional[str] = None client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) auth_mode: AuthType = AuthType.STANDARD - def with_parameters( - self, - endpoint: str = "localhost:30081", - insecure: bool = False, - command: typing.Optional[typing.List[str]] = None, - client_id: typing.Optional[str] = None, - client_credentials_secret: typing.Optional[str] = None, - scopes: List[str] = None, - auth_mode: AuthType = AuthType.STANDARD, - ) -> PlatformConfig: - return PlatformConfig( - endpoint=endpoint, - insecure=insecure, - command=command, - client_id=client_id, - client_credentials_secret=client_credentials_secret, - scopes=scopes if scopes else [], - auth_mode=auth_mode, - ) - @classmethod def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig: """ @@ -345,6 +327,9 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None config_file = get_config_file(config_file) kwargs = {} kwargs = set_if_exists(kwargs, "insecure", _internal.Platform.INSECURE.read(config_file)) + kwargs = set_if_exists( + kwargs, "insecure_skip_verify", _internal.Platform.INSECURE_SKIP_VERIFY.read(config_file) + ) kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file)) kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file)) kwargs = set_if_exists( @@ -562,7 +547,7 @@ def for_sandbox(cls) -> Config: :return: Config """ return Config( - platform=PlatformConfig(insecure=True), + platform=PlatformConfig(endpoint="localhost:30081", auth_mode="Pkce", insecure=True), data_config=DataConfig( s3=S3Config(endpoint="http://localhost:30084", access_key_id="minio", secret_access_key="miniostorage") ), diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py index b329a16112..467f660d42 100644 --- a/flytekit/configuration/file.py +++ b/flytekit/configuration/file.py @@ -32,13 +32,16 @@ class LegacyConfigEntry(object): option: str type_: typing.Type = str + def get_env_name(self): + return f"FLYTE_{self.section.upper()}_{self.option.upper()}" + def read_from_env(self, transform: typing.Optional[typing.Callable] = None) -> typing.Optional[typing.Any]: """ Reads the config entry from environment variable, the structure of the env var is current ``FLYTE_{SECTION}_{OPTION}`` all upper cased. We will change this in the future. :return: """ - env = f"FLYTE_{self.section.upper()}_{self.option.upper()}" + env = self.get_env_name() v = os.environ.get(env, None) if v is None: return None @@ -159,7 +162,7 @@ def __init__(self, location: str): Load the config from this location """ self._location = location - if location.endswith("yaml"): + if location.endswith("yaml") or location.endswith("yml"): self._legacy_config = None self._yaml_config = self._read_yaml_config(location) else: diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 55018c1caf..fb09015b6f 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -19,7 +19,7 @@ def get_specified_images(cfg: ConfigFile) -> typing.Dict[str, str]: :returns a dictionary of name: image Version is optional """ images: typing.Dict[str, str] = {} - if cfg is None: + if cfg is None or not cfg.legacy_config: return images try: image_names = cfg.legacy_config.options("images") @@ -105,6 +105,9 @@ class Platform(object): LegacyConfigEntry(SECTION, "url"), YamlConfigEntry("admin.endpoint"), lambda x: x.replace("dns:///", "") ) INSECURE = ConfigEntry(LegacyConfigEntry(SECTION, "insecure", bool), YamlConfigEntry("admin.insecure", bool)) + INSECURE_SKIP_VERIFY = ConfigEntry( + LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool) + ) class LocalSDK(object): diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index cc3319e359..6aef36305e 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -463,7 +463,6 @@ def dispatch_execute( new_user_params = self.pre_execute(ctx.user_space_params) from flytekit.deck.deck import _output_deck - new_user_params._decks = [ctx.user_space_params.default_deck] # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) @@ -539,11 +538,9 @@ def dispatch_execute( for k, v in native_outputs_as_map.items(): output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v))) - new_user_params.decks.append(input_deck) - new_user_params.decks.append(output_deck) - if _internal.Deck.DISABLE_DECK.read() is not True and self.disable_deck is False: _output_deck(self.name.split(".")[-1], new_user_params) + outputs_literal_map = _literal_models.LiteralMap(literals=literals) # After the execute has been successfully completed return outputs_literal_map diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 2d16b87aab..a17a9148b6 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -21,12 +21,12 @@ import traceback import typing from contextlib import contextmanager +from contextvars import ContextVar from dataclasses import dataclass, field from datetime import datetime from enum import Enum from typing import Generator, List, Optional, Union -import flytekit from flytekit.clients import friendly as friendly_client # noqa from flytekit.configuration import Config, SecretsConfig, SerializationSettings from flytekit.core import mock_stats, utils @@ -38,10 +38,14 @@ from flytekit.loggers import logger, user_space_logger from flytekit.models.core import identifier as _identifier +if typing.TYPE_CHECKING: + from flytekit.deck.deck import Deck + # TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin # Enables static type checking https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING +flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[]) if typing.TYPE_CHECKING: from flytekit.core.base_task import TaskResolverMixin @@ -78,12 +82,13 @@ class Builder(object): stats: taggable.TaggableStats execution_date: datetime logging: _logging.Logger - execution_id: str + execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier] attrs: typing.Dict[str, typing.Any] working_dir: typing.Union[os.PathLike, utils.AutoDeletingTempDir] checkpoint: typing.Optional[Checkpoint] - decks: List[flytekit.Deck] + decks: List[Deck] raw_output_prefix: str + task_id: typing.Optional[_identifier.Identifier] def __init__(self, current: typing.Optional[ExecutionParameters] = None): self.stats = current.stats if current else None @@ -95,6 +100,7 @@ def __init__(self, current: typing.Optional[ExecutionParameters] = None): self.decks = current._decks if current else [] self.attrs = current._attrs if current else {} self.raw_output_prefix = current.raw_output_prefix if current else None + self.task_id = current.task_id if current else None def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: self.attrs[key] = v @@ -112,6 +118,7 @@ def build(self) -> ExecutionParameters: checkpoint=self.checkpoint, decks=self.decks, raw_output_prefix=self.raw_output_prefix, + task_id=self.task_id, **self.attrs, ) @@ -141,11 +148,12 @@ def __init__( execution_date, tmp_dir, stats, - execution_id, + execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier], logging, raw_output_prefix, checkpoint=None, decks=None, + task_id: typing.Optional[_identifier.Identifier] = None, **kwargs, ): """ @@ -171,6 +179,7 @@ def __init__( self._secrets_manager = SecretsManager() self._checkpoint = checkpoint self._decks = decks + self._task_id = task_id @property def stats(self) -> taggable.TaggableStats: @@ -228,6 +237,14 @@ def execution_id(self) -> _identifier.WorkflowExecutionIdentifier: """ return self._execution_id + @property + def task_id(self) -> typing.Optional[_identifier.Identifier]: + """ + At production run-time, this will be generated by reading environment variables that are set + by the backend. + """ + return self._task_id + @property def secrets(self) -> SecretsManager: return self._secrets_manager @@ -246,8 +263,8 @@ def decks(self) -> typing.List: return self._decks @property - def default_deck(self) -> "Deck": - from flytekit import Deck + def default_deck(self) -> Deck: + from flytekit.deck.deck import Deck return Deck("default") @@ -600,6 +617,11 @@ def current_context() -> Optional[FlyteContext]: """ return FlyteContextManager.current_context() + def get_deck(self) -> str: + from flytekit.deck.deck import _get_deck + + return _get_deck(self.execution_state.user_space_params) + @dataclass class Builder(object): file_access: FileAccessProvider @@ -701,8 +723,6 @@ class FlyteContextManager(object): FlyteContextManager.pop_context() """ - _OBJS: typing.List[FlyteContext] = [] - @staticmethod def get_origin_stackframe(limit=2) -> traceback.FrameSummary: ss = traceback.extract_stack(limit=limit + 1) @@ -711,31 +731,36 @@ def get_origin_stackframe(limit=2) -> traceback.FrameSummary: return ss[0] @staticmethod - def current_context() -> Optional[FlyteContext]: - if FlyteContextManager._OBJS: - return FlyteContextManager._OBJS[-1] - return None + def current_context() -> FlyteContext: + if not flyte_context_Var.get(): + # we will lost the default flyte context in the new thread. Therefore, reinitialize the context when running in the new thread. + FlyteContextManager.initialize() + return flyte_context_Var.get()[-1] @staticmethod def push_context(ctx: FlyteContext, f: Optional[traceback.FrameSummary] = None) -> FlyteContext: if not f: f = FlyteContextManager.get_origin_stackframe(limit=2) ctx.set_stackframe(f) - FlyteContextManager._OBJS.append(ctx) + context_list = flyte_context_Var.get() + context_list.append(ctx) + flyte_context_Var.set(context_list) t = "\t" logger.debug( - f"{t * ctx.level}[{len(FlyteContextManager._OBJS)}] Pushing context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}" + f"{t * ctx.level}[{len(flyte_context_Var.get())}] Pushing context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}" ) return ctx @staticmethod def pop_context() -> FlyteContext: - ctx = FlyteContextManager._OBJS.pop() + context_list = flyte_context_Var.get() + ctx = context_list.pop() + flyte_context_Var.set(context_list) t = "\t" logger.debug( - f"{t * ctx.level}[{len(FlyteContextManager._OBJS) + 1}] Popping context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}" + f"{t * ctx.level}[{len(flyte_context_Var.get()) + 1}] Popping context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}" ) - if len(FlyteContextManager._OBJS) == 0: + if len(flyte_context_Var.get()) == 0: raise AssertionError(f"Illegal Context state! Popped, {ctx}") return ctx @@ -766,7 +791,7 @@ def with_context(b: FlyteContext.Builder) -> Generator[FlyteContext, None, None] @staticmethod def size() -> int: - return len(FlyteContextManager._OBJS) + return len(flyte_context_Var.get()) @staticmethod def initialize(): @@ -786,6 +811,7 @@ def initialize(): default_context = FlyteContext(file_access=default_local_file_access_provider) default_user_space_params = ExecutionParameters( execution_id=WorkflowExecutionIdentifier.promote_from_model(default_execution_id), + task_id=_identifier.Identifier(_identifier.ResourceType.TASK, "local", "local", "local", "local"), execution_date=_datetime.datetime.utcnow(), stats=mock_stats.MockStats(), logging=user_space_logger, @@ -798,7 +824,7 @@ def initialize(): default_context.new_execution_state().with_params(user_space_params=default_user_space_params) ).build() default_context.set_stackframe(s=FlyteContextManager.get_origin_stackframe()) - FlyteContextManager._OBJS = [default_context] + flyte_context_Var.set([default_context]) class FlyteEntities(object): diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index fdd47d1741..0f651410bf 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -7,12 +7,15 @@ from collections import OrderedDict from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing_extensions import get_args, get_origin, get_type_hints + from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteValidationException from flytekit.loggers import logger from flytekit.models import interface as _interface_models +from flytekit.models.literals import Void from flytekit.types.pickle import FlytePickle T = typing.TypeVar("T") @@ -182,11 +185,17 @@ def transform_inputs_to_parameters( inputs_with_def = interface.inputs_with_defaults for k, v in inputs_vars.items(): val, _default = inputs_with_def[k] - required = _default is None - default_lv = None - if _default is not None: - default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) - params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) + if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): + from flytekit import Literal, Scalar + + literal = Literal(scalar=Scalar(none_type=Void())) + params[k] = _interface_models.Parameter(var=v, default=literal, required=False) + else: + required = _default is None + default_lv = None + if _default is not None: + default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) + params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) return _interface_models.ParameterMap(params) @@ -274,11 +283,8 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc For now the fancy object, maybe in the future a dumb object. """ - try: - # include_extras can only be used in python >= 3.9 - type_hints = typing.get_type_hints(fn, include_extras=True) - except TypeError: - type_hints = typing.get_type_hints(fn) + + type_hints = get_type_hints(fn, include_extras=True) signature = inspect.signature(fn) return_annotation = type_hints.get("return", None) @@ -386,7 +392,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(typing.get_type_hints(return_annotation)) + return dict(get_type_hints(return_annotation, include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index e1f53dc5d2..e8f5bd4aa3 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -292,7 +292,6 @@ def get_or_create( LaunchPlan.CACHE[name or workflow.name] = lp return lp - # TODO: Add QoS after it's done def __init__( self, name: str, @@ -359,6 +358,10 @@ def clone_with( def python_interface(self) -> Interface: return self.workflow.python_interface + @property + def interface(self) -> _interface_models.TypedInterface: + return self.workflow.interface + @property def name(self) -> str: return self._name diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 6eb4f51d81..1694af6d70 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -1,7 +1,7 @@ from __future__ import annotations import collections -from typing import Type, Union +from typing import TYPE_CHECKING, Type, Union from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext @@ -12,11 +12,15 @@ from flytekit.exceptions import user as _user_exceptions from flytekit.loggers import logger +if TYPE_CHECKING: + from flytekit.remote.remote_callable import RemoteEntity + + # This file exists instead of moving to node.py because it needs Task/Workflow/LaunchPlan and those depend on Node def create_node( - entity: Union[PythonTask, LaunchPlan, WorkflowBase], *args, **kwargs + entity: Union[PythonTask, LaunchPlan, WorkflowBase, RemoteEntity], *args, **kwargs ) -> Union[Node, VoidPromise, Type[collections.namedtuple]]: """ This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or @@ -65,6 +69,8 @@ def sub_wf(): t2(t1_node.o0) """ + from flytekit.remote.remote_callable import RemoteEntity + if len(args) > 0: raise _user_exceptions.FlyteAssertion( f"Only keyword args are supported to pass inputs to workflows and tasks." @@ -75,8 +81,9 @@ def sub_wf(): not isinstance(entity, PythonTask) and not isinstance(entity, WorkflowBase) and not isinstance(entity, LaunchPlan) + and not isinstance(entity, RemoteEntity) ): - raise AssertionError("Should be but it's not") + raise AssertionError(f"Should be a callable Flyte entity (either local or fetched) but is {type(entity)}") # This function is only called from inside workflows and dynamic tasks. # That means there are two scenarios we need to take care of, compilation and local workflow execution. @@ -84,7 +91,6 @@ def sub_wf(): # When compiling, calling the entity will create a node. ctx = FlyteContext.current_context() if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: - outputs = entity(**kwargs) # This is always the output of create_and_link_node which returns create_task_output, which can be # VoidPromise, Promise, or our custom namedtuple of Promises. @@ -105,9 +111,11 @@ def sub_wf(): return node # If a Promise or custom namedtuple of Promises, we need to attach each output as an attribute to the node. - if entity.python_interface.outputs: + # todo: fix the noqas below somehow... can't add abstract property to RemoteEntity because it has to come + # before the model Template classes in FlyteTask/Workflow/LaunchPlan + if entity.interface.outputs: # noqa if isinstance(outputs, tuple): - for output_name in entity.python_interface.output_names: + for output_name in entity.interface.outputs.keys(): # noqa attr = getattr(outputs, output_name) if attr is None: raise _user_exceptions.FlyteAssertion( @@ -120,7 +128,7 @@ def sub_wf(): setattr(node, output_name, attr) node.outputs[output_name] = attr else: - output_names = entity.python_interface.output_names + output_names = [k for k in entity.interface.outputs.keys()] # noqa if len(output_names) != 1: raise _user_exceptions.FlyteAssertion(f"Output of length 1 expected but {len(output_names)} found") @@ -136,6 +144,9 @@ def sub_wf(): # Handling local execution elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + if isinstance(entity, RemoteEntity): + raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}") + if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED: logger.warning(f"Manual node creation cannot be used in branch logic {entity.name}") raise Exception("Being more restrictive for now and disallowing manual node creation in branch logic") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 99ead16b2c..4c9150881d 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union, cast -from typing_extensions import Protocol +from typing_extensions import Protocol, get_args from flytekit.core import constants as _common_constants from flytekit.core import context_manager as _flyte_context @@ -23,6 +23,7 @@ from flytekit.models import types as type_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.literals import Primitive +from flytekit.models.types import SimpleType def translate_inputs_to_literals( @@ -68,29 +69,43 @@ def extract_value( ) -> _literal_models.Literal: if isinstance(input_val, list): - if flyte_literal_type.collection_type is None: + lt = flyte_literal_type + python_type = val_type + if flyte_literal_type.union_type: + for i in range(len(flyte_literal_type.union_type.variants)): + variant = flyte_literal_type.union_type.variants[i] + if variant.collection_type: + lt = variant + python_type = get_args(val_type)[i] + if lt.collection_type is None: raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}") try: - sub_type = ListTransformer.get_sub_type(val_type) + sub_type = ListTransformer.get_sub_type(python_type) except ValueError: if len(input_val) == 0: raise sub_type = type(input_val[0]) - literal_list = [extract_value(ctx, v, sub_type, flyte_literal_type.collection_type) for v in input_val] + literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val] return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list)) elif isinstance(input_val, dict): - if ( - flyte_literal_type.map_value_type is None - and flyte_literal_type.simple != _type_models.SimpleType.STRUCT - ): - raise TypeError(f"Not a map type {flyte_literal_type} but got a map {input_val}") - k_type, sub_type = DictTransformer.get_dict_types(val_type) # type: ignore - if flyte_literal_type.simple == _type_models.SimpleType.STRUCT: - return TypeEngine.to_literal(ctx, input_val, type(input_val), flyte_literal_type) + lt = flyte_literal_type + python_type = val_type + if flyte_literal_type.union_type: + for i in range(len(flyte_literal_type.union_type.variants)): + variant = flyte_literal_type.union_type.variants[i] + if variant.map_value_type: + lt = variant + python_type = get_args(val_type)[i] + if variant.simple == _type_models.SimpleType.STRUCT: + lt = variant + python_type = get_args(val_type)[i] + if lt.map_value_type is None and lt.simple != _type_models.SimpleType.STRUCT: + raise TypeError(f"Not a map type {lt} but got a map {input_val}") + if lt.simple == _type_models.SimpleType.STRUCT: + return TypeEngine.to_literal(ctx, input_val, type(input_val), lt) else: - literal_map = { - k: extract_value(ctx, v, sub_type, flyte_literal_type.map_value_type) for k, v in input_val.items() - } + k_type, sub_type = DictTransformer.get_dict_types(python_type) # type: ignore + literal_map = {k: extract_value(ctx, v, sub_type, lt.map_value_type) for k, v in input_val.items()} return _literal_models.Literal(map=_literal_models.LiteralMap(literals=literal_map)) elif isinstance(input_val, Promise): # In the example above, this handles the "in2=a" type of argument @@ -863,7 +878,7 @@ def create_and_link_node( ctx: FlyteContext, entity: SupportsNodeCreation, **kwargs, -): +) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: """ This method is used to generate a node with bindings. This is not used in the execution path. """ @@ -881,7 +896,20 @@ def create_and_link_node( for k in sorted(interface.inputs): var = typed_interface.inputs[k] if k not in kwargs: - raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + is_optional = False + if var.type.union_type: + for variant in var.type.union_type.variants: + if variant.simple == SimpleType.NONE: + val, _default = interface.inputs_with_defaults[k] + if _default is not None: + raise ValueError( + f"The default value for the optional type must be None, but got {_default}" + ) + is_optional = True + if not is_optional: + raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + else: + continue v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed @@ -999,6 +1027,7 @@ def flyte_entity_call_handler(entity: Union[SupportsNodeCreation], *args, **kwar ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) ) ) as child_ctx: + cast(FlyteContext, child_ctx).user_space_params._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 0fff171f36..4b6743afd3 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -156,7 +156,10 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return self._get_command_fn(settings) def get_container(self, settings: SerializationSettings) -> _task_model.Container: - env = {**settings.env, **self.environment} if self.environment else settings.env + env = {} + for elem in (settings.env, self.environment): + if elem: + env.update(elem) return _get_container_definition( image=get_registerable_container_image(self.container_image, settings.image_config), command=[], @@ -253,4 +256,4 @@ def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> st # fqn will access the fully qualified name of the image (e.g. registry/imagename:version -> registry/imagename) # version will access the version part of the image (e.g. registry/imagename:version -> version) # With empty attribute, it'll access the full image path (e.g. registry/imagename:version -> registry/imagename:version) -_IMAGE_REPLACE_REGEX = re.compile(r"({{\s*\.image[s]?(?:\.([a-zA-Z]+))(?:\.([a-zA-Z]+))?\s*}})", re.IGNORECASE) +_IMAGE_REPLACE_REGEX = re.compile(r"({{\s*\.image[s]?(?:\.([a-zA-Z0-9_]+))(?:\.([a-zA-Z0-9_]+))?\s*}})", re.IGNORECASE) diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 0c5fe786ae..7addc89197 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -15,13 +15,14 @@ # Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass. class CronSchedule(_schedule_models.Schedule): """ - Use this when you have a launch plan that you want to run on a cron expression. The syntax currently used for this - follows the `AWS convention `__ + Use this when you have a launch plan that you want to run on a cron expression. + This uses standard `cron format `__ + in case where you are using default native scheduler using the schedule attribute. .. code-block:: CronSchedule( - cron_expression="0 10 * * ? *", + schedule="*/1 * * * *", # Following schedule runs every min ) See the :std:ref:`User Guide ` for further examples. @@ -54,9 +55,10 @@ def __init__( self, cron_expression: str = None, schedule: str = None, offset: str = None, kickoff_time_input_arg: str = None ): """ - :param str cron_expression: This should be a cron expression in AWS style. + :param str cron_expression: This should be a cron expression in AWS style.Shouldn't be used in case of native scheduler. :param str schedule: This takes a cron alias (see ``_VALID_CRON_ALIASES``) or a croniter parseable schedule. - Only one of this or ``cron_expression`` can be set, not both. + Only one of this or ``cron_expression`` can be set, not both. This uses standard `cron format `_ + and is supported by native scheduler :param str offset: :param str kickoff_time_input_arg: This is a convenient argument to use when your code needs to know what time a run was kicked off. Supply the name of the input argument of your workflow to this argument here. Note @@ -67,7 +69,7 @@ def __init__( def my_wf(kickoff_time: datetime): ... schedule = CronSchedule( - cron_expression="0 10 * * ? *", + schedule="*/1 * * * *" kickoff_time_input_arg="kickoff_time") """ diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 60c42ac339..b3851e77ce 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -319,7 +319,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) ) - def _serialize_flyte_type(self, python_val: T, python_type: Type[T]): + def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any: """ If any field inside the dataclass is flyte type, we should use flyte type transformer for that field. """ @@ -328,36 +328,42 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]): from flytekit.types.schema.types import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset - for f in dataclasses.fields(python_type): - v = python_val.__getattribute__(f.name) - field_type = f.type - if inspect.isclass(field_type) and ( - issubclass(field_type, FlyteSchema) - or issubclass(field_type, FlyteFile) - or issubclass(field_type, FlyteDirectory) - or issubclass(field_type, StructuredDataset) - ): - lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None) - # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a - # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the - # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here, - # so that dataclass_json can always get a remote path. - # In other words, the file transformer has special code that handles the fact that if remote_source is - # set, then the real uri in the literal should be the remote source, not the path (which may be an - # auto-generated random local path). To be sure we're writing the right path to the json, use the uri - # as determined by the transformer. - if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory): - python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri)) - elif issubclass(field_type, StructuredDataset): - python_val.__setattr__( - f.name, - field_type( - uri=lv.scalar.structured_dataset.uri, - ), - ) + if hasattr(python_type, "__origin__") and python_type.__origin__ is list: + return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] + + if hasattr(python_type, "__origin__") and python_type.__origin__ is dict: + return {k: self._serialize_flyte_type(v, python_type.__args__[1]) for k, v in python_val.items()} - elif dataclasses.is_dataclass(field_type): - self._serialize_flyte_type(v, field_type) + if not dataclasses.is_dataclass(python_type): + return python_val + + if inspect.isclass(python_type) and ( + issubclass(python_type, FlyteSchema) + or issubclass(python_type, FlyteFile) + or issubclass(python_type, FlyteDirectory) + or issubclass(python_type, StructuredDataset) + ): + lv = TypeEngine.to_literal(FlyteContext.current_context(), python_val, python_type, None) + # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a + # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the + # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here, + # so that dataclass_json can always get a remote path. + # In other words, the file transformer has special code that handles the fact that if remote_source is + # set, then the real uri in the literal should be the remote source, not the path (which may be an + # auto-generated random local path). To be sure we're writing the right path to the json, use the uri + # as determined by the transformer. + if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory): + return python_type(path=lv.scalar.blob.uri) + elif issubclass(python_type, StructuredDataset): + return python_type(uri=lv.scalar.structured_dataset.uri) + else: + return python_val + else: + for v in dataclasses.fields(python_type): + val = python_val.__getattribute__(v.name) + field_type = v.type + python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type)) + return python_val def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T: from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer @@ -365,6 +371,12 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine + if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list: + return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] + + if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is dict: + return {k: self._deserialize_flyte_type(v, expected_python_type.__args__[1]) for k, v in python_val.items()} + if not dataclasses.is_dataclass(expected_python_type): return python_val @@ -737,11 +749,11 @@ def literal_map_to_kwargs( """ Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task """ - if len(lm.literals) != len(python_types): + if len(lm.literals) > len(python_types): raise ValueError( f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" ) - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()} + return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()} @classmethod def dict_to_literal_map( @@ -993,7 +1005,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp # Should really never happen, sanity check raise TypeError("Ambiguous choice of variant for union type") found_res = True - except TypeTransformerFailedError as e: + except (TypeTransformerFailedError, AttributeError) as e: logger.debug(f"Failed to convert from {python_val} to {t}", e) continue @@ -1047,7 +1059,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: ) res_tag = trans.name found_res = True - except TypeTransformerFailedError as e: + except (TypeTransformerFailedError, AttributeError) as e: logger.debug(f"Failed to convert from {lv} to {v}", e) if found_res: diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index 435416e0ff..599b886ab0 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -1,9 +1,13 @@ import os -from typing import Dict, Optional +from typing import Optional from jinja2 import Environment, FileSystemLoader -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager +from flytekit.loggers import logger + +OUTPUT_DIR_JUPYTER_PREFIX = "jupyter" +DECK_FILE_NAME = "deck.html" class Deck: @@ -69,30 +73,43 @@ def html(self) -> str: return self._html -def _output_deck(task_name: str, new_user_params: ExecutionParameters): - deck_map: Dict[str, str] = {} - decks = new_user_params.decks - ctx = FlyteContext.current_context() +def _ipython_check() -> bool: + """ + Check if interface is launching from iPython (not colab) + :return is_ipython (bool): True or False + """ + is_ipython = False + try: # Check if running interactively using ipython. + from IPython import get_ipython - # TODO: upload deck file to remote filesystems (s3, gcs) - output_dir = ctx.file_access.get_random_local_directory() + if get_ipython() is not None: + is_ipython = True + except (ImportError, NameError): + pass + return is_ipython - for deck in decks: - _deck_to_html_file(deck, deck_map, output_dir) - root = os.path.dirname(os.path.abspath(__file__)) - templates_dir = os.path.join(root, "html") - env = Environment(loader=FileSystemLoader(templates_dir)) - template = env.get_template("template.html") +def _get_deck(new_user_params: ExecutionParameters) -> str: + """ + Get flyte deck html string + """ + deck_map = {deck.name: deck.html for deck in new_user_params.decks} + return template.render(metadata=deck_map) + - deck_path = os.path.join(output_dir, "deck.html") +def _output_deck(task_name: str, new_user_params: ExecutionParameters): + ctx = FlyteContext.current_context() + if ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: + output_dir = ctx.execution_state.engine_dir + else: + output_dir = ctx.file_access.get_random_local_directory() + deck_path = os.path.join(output_dir, DECK_FILE_NAME) with open(deck_path, "w") as f: - f.write(template.render(metadata=deck_map)) + f.write(_get_deck(new_user_params)) + logger.info(f"{task_name} task creates flyte deck html to file://{deck_path}") -def _deck_to_html_file(deck: Deck, deck_map: Dict[str, str], output_dir: str): - file_name = deck.name + ".html" - path = os.path.join(output_dir, file_name) - with open(path, "w") as output: - deck_map[deck.name] = file_name - output.write(deck.html) +root = os.path.dirname(os.path.abspath(__file__)) +templates_dir = os.path.join(root, "html") +env = Environment(loader=FileSystemLoader(templates_dir)) +template = env.get_template("template.html") diff --git a/flytekit/deck/html/template.html b/flytekit/deck/html/template.html index 1a429a2ab6..3992ab9c0f 100644 --- a/flytekit/deck/html/template.html +++ b/flytekit/deck/html/template.html @@ -4,56 +4,12 @@ User Content + - - -
-
+ +
+ {% for key, value in metadata.items() %} +
{{value}}
+ {% endfor %} +
diff --git a/flytekit/extras/pytorch/__init__.py b/flytekit/extras/pytorch/__init__.py new file mode 100644 index 0000000000..ae077d9755 --- /dev/null +++ b/flytekit/extras/pytorch/__init__.py @@ -0,0 +1,20 @@ +""" +Flytekit PyTorch +========================================= +.. currentmodule:: flytekit.extras.pytorch + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + PyTorchCheckpoint +""" +from flytekit.loggers import logger + +try: + from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer + from .native import PyTorchModuleTransformer, PyTorchTensorTransformer +except ImportError: + logger.info( + "We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed." + ) diff --git a/flytekit/extras/pytorch/checkpoint.py b/flytekit/extras/pytorch/checkpoint.py new file mode 100644 index 0000000000..c7561f13f4 --- /dev/null +++ b/flytekit/extras/pytorch/checkpoint.py @@ -0,0 +1,137 @@ +import pathlib +import typing +from dataclasses import asdict, dataclass, fields, is_dataclass +from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union + +import torch +from dataclasses_json import dataclass_json +from typing_extensions import Protocol + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + + +class IsDataclass(Protocol): + __dataclass_fields__: Dict + __dataclass_params__: Dict + __post_init__: Optional[Callable] + + +@dataclass_json +@dataclass +class PyTorchCheckpoint: + """ + This class is helpful to save a checkpoint. + """ + + module: Optional[torch.nn.Module] = None + hyperparameters: Optional[Union[Dict[str, Any], NamedTuple, IsDataclass]] = None + optimizer: Optional[torch.optim.Optimizer] = None + + def __post_init__(self): + if not ( + isinstance(self.hyperparameters, dict) + or (is_dataclass(self.hyperparameters) and not isinstance(self.hyperparameters, type)) + or (isinstance(self.hyperparameters, tuple) and hasattr(self.hyperparameters, "_fields")) + or (self.hyperparameters is None) + ): + raise TypeTransformerFailedError( + f"hyperparameters must be a dict, dataclass, or NamedTuple. Got {type(self.hyperparameters)}" + ) + + if not (self.module or self.hyperparameters or self.optimizer): + raise TypeTransformerFailedError("Must have at least one of module, hyperparameters, or optimizer") + + +class PyTorchCheckpointTransformer(TypeTransformer[PyTorchCheckpoint]): + """ + TypeTransformer that supports serializing and deserializing checkpoint. + """ + + PYTORCH_CHECKPOINT_FORMAT = "PyTorchCheckpoint" + + def __init__(self): + super().__init__(name="PyTorch Checkpoint", t=PyTorchCheckpoint) + + def get_literal_type(self, t: Type[PyTorchCheckpoint]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: PyTorchCheckpoint, + python_type: Type[PyTorchCheckpoint], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ) + + local_path = ctx.file_access.get_random_local_path() + ".pt" + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + to_save = {} + for field in fields(python_val): + value = getattr(python_val, field.name) + + if value and field.name in ["module", "optimizer"]: + to_save[field.name + "_state_dict"] = getattr(value, "state_dict")() + elif value and field.name == "hyperparameters": + if isinstance(value, dict): + to_save.update(value) + elif isinstance(value, tuple): + to_save.update(value._asdict()) + elif is_dataclass(value): + to_save.update(asdict(value)) + + if not to_save: + raise TypeTransformerFailedError(f"Cannot save empty {python_val}") + + # save checkpoint to a file + torch.save(to_save, local_path) + + remote_path = ctx.file_access.get_random_remote_path(local_path) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PyTorchCheckpoint] + ) -> PyTorchCheckpoint: + try: + uri = lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + + # cpu <-> gpu conversion + if torch.cuda.is_available(): + map_location = "cuda:0" + else: + map_location = torch.device("cpu") + + # load checkpoint from a file + return typing.cast(PyTorchCheckpoint, torch.load(local_path, map_location=map_location)) + + def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorchCheckpoint]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.PYTORCH_CHECKPOINT_FORMAT + ): + return PyTorchCheckpoint + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(PyTorchCheckpointTransformer()) diff --git a/flytekit/extras/pytorch/native.py b/flytekit/extras/pytorch/native.py new file mode 100644 index 0000000000..4cf37871fb --- /dev/null +++ b/flytekit/extras/pytorch/native.py @@ -0,0 +1,92 @@ +import pathlib +from typing import Generic, Type, TypeVar + +import torch + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + +T = TypeVar("T") + + +class PyTorchTypeTransformer(TypeTransformer, Generic[T]): + def get_literal_type(self, t: Type[T]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.PYTORCH_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: T, + python_type: Type[T], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.PYTORCH_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + local_path = ctx.file_access.get_random_local_path() + ".pt" + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + # save pytorch tensor/module to a file + torch.save(python_val, local_path) + + remote_path = ctx.file_access.get_random_remote_path(local_path) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + try: + uri = lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + + # cpu <-> gpu conversion + if torch.cuda.is_available(): + map_location = "cuda:0" + else: + map_location = torch.device("cpu") + + # load pytorch tensor/module from a file + return torch.load(local_path, map_location=map_location) + + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.PYTORCH_FORMAT + ): + return T + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]): + PYTORCH_FORMAT = "PyTorchTensor" + + def __init__(self): + super().__init__(name="PyTorch Tensor", t=torch.Tensor) + + +class PyTorchModuleTransformer(PyTorchTypeTransformer[torch.nn.Module]): + PYTORCH_FORMAT = "PyTorchModule" + + def __init__(self): + super().__init__(name="PyTorch Module", t=torch.nn.Module) + + +TypeEngine.register(PyTorchTensorTransformer()) +TypeEngine.register(PyTorchModuleTransformer()) diff --git a/flytekit/models/node_execution.py b/flytekit/models/node_execution.py index 762dfd196f..220db5cc5f 100644 --- a/flytekit/models/node_execution.py +++ b/flytekit/models/node_execution.py @@ -92,6 +92,7 @@ def __init__( started_at, duration, output_uri=None, + deck_uri=None, error=None, workflow_node_metadata: typing.Optional[WorkflowNodeMetadata] = None, task_node_metadata: typing.Optional[TaskNodeMetadata] = None, @@ -107,6 +108,7 @@ def __init__( self._started_at = started_at self._duration = duration self._output_uri = output_uri + self._deck_uri = deck_uri self._error = error self._workflow_node_metadata = workflow_node_metadata self._task_node_metadata = task_node_metadata @@ -140,6 +142,13 @@ def output_uri(self): """ return self._output_uri + @property + def deck_uri(self): + """ + :rtype: str + """ + return self._deck_uri + @property def error(self): """ @@ -166,6 +175,7 @@ def to_flyte_idl(self): obj = _node_execution_pb2.NodeExecutionClosure( phase=self.phase, output_uri=self.output_uri, + deck_uri=self.deck_uri, error=self.error.to_flyte_idl() if self.error is not None else None, workflow_node_metadata=self.workflow_node_metadata.to_flyte_idl() if self.workflow_node_metadata is not None @@ -185,6 +195,7 @@ def from_flyte_idl(cls, p): return cls( phase=p.phase, output_uri=p.output_uri if p.HasField("output_uri") else None, + deck_uri=p.deck_uri, error=_core_execution.ExecutionError.from_flyte_idl(p.error) if p.HasField("error") else None, started_at=p.started_at.ToDatetime().replace(tzinfo=_pytz.UTC), duration=p.duration.ToTimedelta(), diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index a00ac97222..99f54b7933 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -9,6 +9,7 @@ import functools import hashlib import os +import pathlib import time import typing import uuid @@ -22,7 +23,7 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings -from flytekit.core import constants, tracker, utils +from flytekit.core import constants, utils from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider @@ -59,7 +60,7 @@ from flytekit.remote.remote_callable import RemoteEntity from flytekit.remote.task import FlyteTask from flytekit.remote.workflow import FlyteWorkflow -from flytekit.tools.script_mode import fast_register_single_script +from flytekit.tools.script_mode import fast_register_single_script, hash_file from flytekit.tools.translator import FlyteLocalEntity, Options, get_serializable, get_serializable_launch_plan ExecutionDataResponse = typing.Union[WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse] @@ -507,6 +508,65 @@ def register_workflow( fwf._python_interface = entity.python_interface return fwf + def _upload_file( + self, to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None + ) -> typing.Tuple[bytes, str]: + """ + Function will use remote's client to hash and then upload the file using Admin's data proxy service. + + :param to_upload: Must be a single file + :param project: Project to upload under, if not supplied will use the remote's default + :param domain: Domain to upload under, if not specified will use the remote's default + :return: The uploaded location. + """ + if not to_upload.is_file(): + raise ValueError(f"{to_upload} is not a single file, upload arg must be a single file.") + md5_bytes, str_digest = hash_file(to_upload) + remote_logger.debug(f"Text hash of file to upload is {str_digest}") + + upload_location = self.client.get_upload_signed_url( + project=project or self.default_project, + domain=domain or self.default_domain, + content_md5=md5_bytes, + filename=to_upload.name, + ) + self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url) + remote_logger.warning( + f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}" + ) + + return md5_bytes, upload_location.native_url + + @staticmethod + def _version_from_hash( + md5_bytes: bytes, + serialization_settings: SerializationSettings, + *additional_context: str, + ) -> str: + """ + The md5 version that we send to S3/GCS has to match the file contents exactly, + but we don't have to use it when registering with the Flyte backend. + To avoid changes in the For that add the hash of the compilation settings to hash of file + + :param md5_bytes: + :param serialization_settings: + :param additional_context: This is for additional context to factor into the version computation, + meant for objects (like Options for instance) that don't easily consistently stringify. + :return: + """ + from flytekit import __version__ + + additional_context = additional_context or [] + + h = hashlib.md5(md5_bytes) + h.update(bytes(serialization_settings.to_json(), "utf-8")) + h.update(bytes(__version__, "utf-8")) + + for s in additional_context: + h.update(bytes(s, "utf-8")) + + return base64.urlsafe_b64encode(h.digest()).decode("ascii") + def register_script( self, entity: typing.Union[WorkflowBase, PythonTask], @@ -530,8 +590,6 @@ def register_script( :param options: Additional execution options that can be configured for the default launchplan :return: """ - _, _, _, fname = tracker.extract_task_module(entity) - if image_config is None: image_config = ImageConfig.auto_default_image() @@ -560,12 +618,7 @@ def register_script( # The md5 version that we send to S3/GCS has to match the file contents exactly, # but we don't have to use it when registering with the Flyte backend. # For that add the hash of the compilation settings to hash of file - from flytekit import __version__ - - h = hashlib.md5(md5_bytes) - h.update(bytes(serialization_settings.to_json(), "utf-8")) - h.update(bytes(__version__, "utf-8")) - version = base64.urlsafe_b64encode(h.digest()) + version = self._version_from_hash(md5_bytes, serialization_settings) if isinstance(entity, PythonTask): return self.register_task(entity, serialization_settings, version) diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index ade2903a26..c4ac31a01a 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -1,14 +1,19 @@ from __future__ import annotations +import gzip import hashlib import os import posixpath import subprocess as _subprocess import tarfile +import tempfile from typing import Optional +import click + from flytekit.core.context_manager import FlyteContextManager from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore +from flytekit.tools.script_mode import tar_strip_file_attributes FAST_PREFIX = "fast" FAST_FILEENDING = ".tar.gz" @@ -28,11 +33,19 @@ def fast_package(source: os.PathLike, output_dir: os.PathLike) -> os.PathLike: digest = compute_digest(source, ignore.is_ignored) archive_fname = f"{FAST_PREFIX}{digest}{FAST_FILEENDING}" - if output_dir: - archive_fname = os.path.join(output_dir, archive_fname) + if output_dir is None: + output_dir = tempfile.mkdtemp() + click.secho(f"Output given as {None}, using a temporary directory at {output_dir} instead", fg="yellow") + + archive_fname = os.path.join(output_dir, archive_fname) - with tarfile.open(archive_fname, "w:gz") as tar: - tar.add(source, arcname="", filter=ignore.tar_filter) + with tempfile.TemporaryDirectory() as tmp_dir: + tar_path = os.path.join(tmp_dir, "tmp.tar") + with tarfile.open(tar_path, "w") as tar: + tar.add(source, arcname="", filter=lambda x: ignore.tar_filter(tar_strip_file_attributes(x))) + with gzip.GzipFile(filename=archive_fname, mode="wb", mtime=0) as gzipped: + with open(tar_path, "rb") as tar_file: + gzipped.write(tar_file.read()) return archive_fname diff --git a/flytekit/tools/ignore.py b/flytekit/tools/ignore.py index 769727c4dd..49fab06d02 100644 --- a/flytekit/tools/ignore.py +++ b/flytekit/tools/ignore.py @@ -5,7 +5,7 @@ from fnmatch import fnmatch from pathlib import Path from shutil import which -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from docker.utils.build import PatternMatcher @@ -105,7 +105,7 @@ class IgnoreGroup(Ignore): """Groups multiple Ignores and checks a path against them. A file is ignored if any Ignore considers it ignored.""" - def __init__(self, root: str, ignores: List[Ignore]): + def __init__(self, root: str, ignores: List[Type[Ignore]]): super().__init__(root) self.ignores = [ignore(root) for ignore in ignores] diff --git a/flytekit/tools/module_loader.py b/flytekit/tools/module_loader.py index bc0c46bbbf..dc3a6bb9f4 100644 --- a/flytekit/tools/module_loader.py +++ b/flytekit/tools/module_loader.py @@ -6,19 +6,6 @@ from typing import Any, Iterator, List, Union -def iterate_modules(pkgs): - for package_name in pkgs: - package = importlib.import_module(package_name) - yield package - - # Check if package is a python file. If so, there is no reason to walk. - if not hasattr(package, "__path__"): - continue - - for _, name, _ in pkgutil.walk_packages(package.__path__, prefix="{}.".format(package_name)): - yield importlib.import_module(name) - - @contextlib.contextmanager def add_sys_path(path: Union[str, os.PathLike]) -> Iterator[None]: """Temporarily add given path to `sys.path`.""" @@ -36,37 +23,14 @@ def just_load_modules(pkgs: List[str]): """ for package_name in pkgs: package = importlib.import_module(package_name) - for _, name, _ in pkgutil.walk_packages(package.__path__, prefix="{}.".format(package_name)): - importlib.import_module(name) - - -def load_workflow_modules(pkgs): - """ - Load all modules and packages at and under the given package. Used for finding workflows/tasks to register. - :param list[Text] pkgs: List of dot separated string containing paths folders (packages) containing - the modules (python files) - :raises ImportError - """ - for _ in iterate_modules(pkgs): - pass - - -def load_module_object_for_type(pkgs, t, additional_path=None): - def iterate(): - entity_to_module_key = {} - for m in iterate_modules(pkgs): - for k in dir(m): - o = m.__dict__[k] - if isinstance(o, t): - entity_to_module_key[o] = (m.__name__, k) - return entity_to_module_key + # If it doesn't have a __path__ field, that means it's not a package, just a module + if not hasattr(package, "__path__"): + continue - if additional_path is not None: - with add_sys_path(additional_path): - return iterate() - else: - return iterate() + # Note that walk_packages takes an onerror arg and swallows import errors silently otherwise + for _, name, _ in pkgutil.walk_packages(package.__path__, prefix=f"{package_name}."): + importlib.import_module(name) def load_object_from_module(object_location: str) -> Any: @@ -78,24 +42,3 @@ def load_object_from_module(object_location: str) -> Any: class_obj_key = class_obj[-1] # e.g. 'default_task_class_obj' class_obj_mod = importlib.import_module(".".join(class_obj_mod)) return getattr(class_obj_mod, class_obj_key) - - -def trigger_loading( - pkgs, - local_source_root=None, -): - """ - This function will iterate all discovered entities in the given package list. It will then attempt to - topologically sort such that any entity with a dependency on another comes later in the list. Note that workflows - can reference other workflows and launch plans. - - :param list[Text] pkgs: - :param Text local_source_root: - """ - if local_source_root is not None: - with add_sys_path(local_source_root): - for _ in iterate_modules(pkgs): - ... - else: - for _ in iterate_modules(pkgs): - ... diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 63b21324a5..167c772184 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -1,18 +1,26 @@ +import os import tarfile import tempfile import typing +from pathlib import Path import click +from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan +from flyteidl.admin.launch_plan_pb2 import LaunchPlanCreateRequest +from flyteidl.admin.task_pb2 import TaskCreateRequest +from flyteidl.admin.task_pb2 import TaskSpec as _idl_admin_TaskSpec +from flyteidl.admin.workflow_pb2 import WorkflowCreateRequest +from flyteidl.admin.workflow_pb2 import WorkflowSpec as _idl_admin_WorkflowSpec +from flyteidl.core import identifier_pb2 -from flytekit import FlyteContextManager, logger from flytekit.clients.friendly import SynchronousFlyteClient +from flytekit.clis.helpers import hydrate_registration_parameters from flytekit.configuration import SerializationSettings +from flytekit.core.context_manager import FlyteContextManager from flytekit.exceptions.user import FlyteEntityAlreadyExistsException -from flytekit.models import launch_plan as launch_plan_models -from flytekit.models import task as task_models -from flytekit.models.admin import workflow as admin_workflow_models -from flytekit.models.core.identifier import Identifier, ResourceType +from flytekit.loggers import logger from flytekit.tools import fast_registration, module_loader +from flytekit.tools.script_mode import _find_project_root from flytekit.tools.serialize_helpers import RegistrableEntity, get_registrable_entities, persist_registrable_entities from flytekit.tools.translator import Options @@ -83,6 +91,10 @@ def package( # If Fast serialization is enabled, then an archive is also created and packaged if fast: + # If output exists and is a path within source, delete it so as to not re-bundle it again. + if os.path.abspath(output).startswith(os.path.abspath(source)) and os.path.exists(output): + click.secho(f"{output} already exists within {source}, deleting and re-creating it", fg="yellow") + os.remove(output) archive_fname = fast_registration.fast_package(source, output_tmpdir) click.secho(f"Fast mode enabled: compressed archive {archive_fname}", dim=True) @@ -113,41 +125,100 @@ def register( domain: str, version: str, client: SynchronousFlyteClient, - source: str = ".", - fast: bool = False, ): - if fast: - # TODO handle fast - raise AssertionError("Fast not handled yet!") - for entity, cp_entity in registrable_entities: + # The incoming registrable entities are already in base protobuf form, not model form, so we use the + # raw client's methods instead of the friendly client's methods by calling super + for admin_entity in registrable_entities: try: - if isinstance(cp_entity, task_models.TaskSpec): - ident = Identifier( - resource_type=ResourceType.TASK, project=project, domain=domain, name=entity.name, version=version + if isinstance(admin_entity, _idl_admin_TaskSpec): + ident, task_spec = hydrate_registration_parameters( + identifier_pb2.TASK, project, domain, version, admin_entity + ) + logger.debug(f"Creating task {ident}") + super(SynchronousFlyteClient, client).create_task(TaskCreateRequest(id=ident, spec=task_spec)) + elif isinstance(admin_entity, _idl_admin_WorkflowSpec): + ident, wf_spec = hydrate_registration_parameters( + identifier_pb2.WORKFLOW, project, domain, version, admin_entity ) - client.create_task(task_identifer=ident, task_spec=cp_entity) - elif isinstance(cp_entity, admin_workflow_models.WorkflowSpec): - ident = Identifier( - resource_type=ResourceType.WORKFLOW, - project=project, - domain=domain, - name=entity.name, - version=version, + logger.debug(f"Creating workflow {ident}") + super(SynchronousFlyteClient, client).create_workflow(WorkflowCreateRequest(id=ident, spec=wf_spec)) + elif isinstance(admin_entity, _idl_admin_LaunchPlan): + ident, admin_lp = hydrate_registration_parameters( + identifier_pb2.LAUNCH_PLAN, project, domain, version, admin_entity ) - client.create_workflow(workflow_identifier=ident, workflow_spec=cp_entity) - elif isinstance(cp_entity, launch_plan_models.LaunchPlanSpec): - ident = Identifier( - resource_type=ResourceType.LAUNCH_PLAN, - project=project, - domain=domain, - name=entity.name, - version=version, + logger.debug(f"Creating launch plan {ident}") + super(SynchronousFlyteClient, client).create_launch_plan( + LaunchPlanCreateRequest(id=ident, spec=admin_lp.spec) ) - client.create_launch_plan(launch_plan_identifer=ident, launch_plan_spec=cp_entity) else: - raise AssertionError(f"Unknown entity of type {type(cp_entity)}") + raise AssertionError(f"Unknown entity of type {type(admin_entity)}") except FlyteEntityAlreadyExistsException: - logger.info(f"{entity.name} already exists") + logger.info(f"{admin_entity} already exists") except Exception as e: - logger.info(f"Failed to register entity {entity.name} with error {e}") + logger.info(f"Failed to register entity {admin_entity} with error {e}") raise e + + +def find_common_root( + pkgs_or_mods: typing.Union[typing.Tuple[str], typing.List[str]], +) -> Path: + """ + Given an arbitrary list of folders and files, this function will use the script mode function to walk up + the filesystem to find the first folder without an init file. If all the folders and files resolve to + the same root folder, then that Path is returned. Otherwise an error is raised. + + :param pkgs_or_mods: + :return: The common detected root path, the output of _find_project_root + """ + project_root = None + for pm in pkgs_or_mods: + root = _find_project_root(pm) + if project_root is None: + project_root = root + else: + if project_root != root: + raise ValueError(f"Specified module {pm} has root {root} but {project_root} already specified") + + logger.debug(f"Common root folder detected as {str(project_root)}") + + return project_root + + +def load_packages_and_modules( + ss: SerializationSettings, + project_root: Path, + pkgs_or_mods: typing.List[str], + options: typing.Optional[Options] = None, +) -> typing.List[RegistrableEntity]: + """ + The project root is added as the first entry to sys.path, and then all the specified packages and modules + given are loaded with all submodules. The reason for prepending the entry is to ensure that the name that + the various modules are loaded under are the fully-resolved name. + + For example, using flytesnacks cookbook, if you are in core/ and you call this function with + ``flyte_basics/hello_world.py control_flow/``, the ``hello_world`` module would be loaded + as ``core.flyte_basics.hello_world`` even though you're already in the core/ folder. + + :param ss: + :param project_root: + :param pkgs_or_mods: + :param options: + :return: The common detected root path, the output of _find_project_root + """ + + pkgs_and_modules = [] + for pm in pkgs_or_mods: + p = Path(pm).resolve() + rel_path_from_root = p.relative_to(project_root) + # One day we should learn how to do this right. This is not the right way to load a python module + # from a file. See pydoc.importfile for inspiration + dot_delineated = os.path.splitext(rel_path_from_root)[0].replace(os.path.sep, ".") # noqa + + logger.debug( + f"User specified arg {pm} has {str(rel_path_from_root)} relative path loading it as {dot_delineated}" + ) + pkgs_and_modules.append(dot_delineated) + + registrable_entities = serialize(pkgs_and_modules, ss, str(project_root), options) + + return registrable_entities diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 7c8821da77..f43ecdf5fd 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -14,7 +14,7 @@ from flytekit.core.workflow import WorkflowBase -def compress_single_script(absolute_project_path: str, destination: str, full_module_name: str): +def compress_single_script(source_path: str, destination: str, full_module_name: str): """ Compresses the single script while maintaining the folder structure for that file. @@ -42,7 +42,6 @@ def compress_single_script(absolute_project_path: str, destination: str, full_mo Note how `another_example.py` and `yet_another_example.py` were not copied to the destination. """ with tempfile.TemporaryDirectory() as tmp_dir: - source_path = os.path.join(absolute_project_path) destination_path = os.path.join(tmp_dir, "code") # This is the script relative path to the root of the project script_relative_path = Path() @@ -55,7 +54,7 @@ def compress_single_script(absolute_project_path: str, destination: str, full_mo destination_path = os.path.join(destination_path, p) script_relative_path = Path(script_relative_path, p) init_file = Path(os.path.join(source_path, "__init__.py")) - if init_file.exists: + if init_file.exists(): shutil.copy(init_file, Path(os.path.join(tmp_dir, "code", script_relative_path, "__init__.py"))) # Ensure destination path exists to cover the case of a single file and no modules. @@ -139,10 +138,14 @@ def hash_file(file_path: typing.Union[os.PathLike, str]) -> (bytes, str): def _find_project_root(source_path) -> Path: """ - Traverse from current working directory until it can no longer find __init__.py files + Find the root of the project. + The root of the project is considered to be the first ancestor from source_path that does + not contain a __init__.py file. + + N.B.: This assumption only holds for regular packages (as opposed to namespace packages) """ # Start from the directory right above source_path - path = Path(source_path).parents[0] + path = Path(source_path).parent.resolve() while os.path.exists(os.path.join(path, "__init__.py")): path = path.parent return path diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py index 7c1969afa7..9fd7f05e89 100644 --- a/flytekit/tools/serialize_helpers.py +++ b/flytekit/tools/serialize_helpers.py @@ -20,9 +20,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.tools.translator import Options, get_serializable -RegistrableEntity = typing.Union[ - task_models.TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec -] +RegistrableEntity = typing.Union[_idl_admin_TaskSpec, _idl_admin_LaunchPlan, _idl_admin_WorkflowSpec] def _determine_text_chars(length): @@ -78,7 +76,7 @@ def get_registrable_entities( if isinstance(entity, WorkflowBase): lp = LaunchPlan.get_default_launch_plan(ctx, entity) - get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp) + get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp, options) new_api_model_values = list(new_api_serializable_entities.values()) entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values)) @@ -114,7 +112,6 @@ def persist_registrable_entities(entities: typing.List[RegistrableEntity], folde """ zero_padded_length = _determine_text_chars(len(entities)) for i, entity in enumerate(entities): - name = "" fname_index = str(i).zfill(zero_padded_length) if isinstance(entity, _idl_admin_TaskSpec): name = entity.template.id.name diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 045ab7970a..57f863f055 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -289,7 +289,9 @@ def get_serializable_workflow( nodes=upstream_node_models, outputs=entity.output_bindings, ) - return admin_workflow_models.WorkflowSpec(template=wf_t, sub_workflows=list(set(sub_wfs))) + return admin_workflow_models.WorkflowSpec( + template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()) + ) def get_serializable_launch_plan( @@ -322,7 +324,10 @@ def get_serializable_launch_plan( if not options: options = Options() - raw = None + if options and options.raw_output_data_config: + raw_prefix_config = options.raw_output_data_config + else: + raw_prefix_config = entity.raw_output_data_config or _common_models.RawOutputDataConfig("") lps = _launch_plan_models.LaunchPlanSpec( workflow_id=wf_id, @@ -335,7 +340,7 @@ def get_serializable_launch_plan( labels=options.labels or entity.labels or _common_models.Labels({}), annotations=options.annotations or entity.annotations or _common_models.Annotations({}), auth_role=None, - raw_output_data_config=raw or entity.raw_output_data_config or _common_models.RawOutputDataConfig(""), + raw_output_data_config=raw_prefix_config, max_parallelism=options.max_parallelism or entity.max_parallelism, security_context=options.security_context or entity.security_context, ) diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py index 44841d7e35..9e8fca1971 100644 --- a/flytekit/types/file/__init__.py +++ b/flytekit/types/file/__init__.py @@ -75,3 +75,8 @@ #: Can be used to receive or return a CSVFile. The underlying type is a FlyteFile type. This is just a #: decoration and useful for attaching content type information with the file and automatically documenting code. CSVFile = FlyteFile[csv] + +onnx = typing.TypeVar("onnx") +#: Can be used to receive or return an ONNXFile. The underlying type is a FlyteFile type. This is just a +#: decoration and useful for attaching content type information with the file and automatically documenting code. +ONNXFile = FlyteFile[onnx] diff --git a/flytekit/types/numpy/__init__.py b/flytekit/types/numpy/__init__.py new file mode 100644 index 0000000000..ec20e87970 --- /dev/null +++ b/flytekit/types/numpy/__init__.py @@ -0,0 +1 @@ +from .ndarray import NumpyArrayTransformer diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py new file mode 100644 index 0000000000..cb1cf2a900 --- /dev/null +++ b/flytekit/types/numpy/ndarray.py @@ -0,0 +1,74 @@ +import pathlib +import typing +from typing import Type + +import numpy as np + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + + +class NumpyArrayTransformer(TypeTransformer[np.ndarray]): + """ + TypeTransformer that supports np.ndarray as a native type. + """ + + NUMPY_ARRAY_FORMAT = "NumpyArray" + + def __init__(self): + super().__init__(name="Numpy Array", t=np.ndarray) + + def get_literal_type(self, t: Type[np.ndarray]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.NUMPY_ARRAY_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ) + + def to_literal( + self, ctx: FlyteContext, python_val: np.ndarray, python_type: Type[np.ndarray], expected: LiteralType + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.NUMPY_ARRAY_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ) + + local_path = ctx.file_access.get_random_local_path() + ".npy" + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + # save numpy array to a file + # allow_pickle=False prevents numpy from trying to save object arrays (dtype=object) using pickle + np.save(file=local_path, arr=python_val, allow_pickle=False) + + remote_path = ctx.file_access.get_random_remote_path(local_path) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[np.ndarray]) -> np.ndarray: + try: + uri = lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + + # load numpy array from a file + return np.load(file=local_path) + + def guess_python_type(self, literal_type: LiteralType) -> typing.Type[np.ndarray]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.NUMPY_ARRAY_FORMAT + ): + return np.ndarray + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(NumpyArrayTransformer()) diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 3dd8b06235..52577a650d 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -10,7 +10,6 @@ StructuredDataset StructuredDatasetEncoder StructuredDatasetDecoder - StructuredDatasetTransformerEngine """ diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index a4488fbc1e..af60599f0c 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -19,6 +19,8 @@ if importlib.util.find_spec("pyspark") is not None: import pyspark +if importlib.util.find_spec("polars") is not None: + import polars as pl from dataclasses_json import config, dataclass_json from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin @@ -647,6 +649,9 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ return pd.DataFrame(df).describe().to_html() elif importlib.util.find_spec("pyspark") is not None and isinstance(df, pyspark.sql.DataFrame): return pd.DataFrame(df.schema, columns=["StructField"]).to_html() + elif importlib.util.find_spec("polars") is not None and isinstance(df, pl.DataFrame): + describe_df = df.describe() + return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) else: raise NotImplementedError("Conversion to html string should be implemented") diff --git a/plugins/flytekit-aws-athena/requirements.txt b/plugins/flytekit-aws-athena/requirements.txt index 1df58a22b4..41c8b68f54 100644 --- a/plugins/flytekit-aws-athena/requirements.txt +++ b/plugins/flytekit-aws-athena/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,41 +44,44 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-athena -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -89,16 +94,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -114,7 +118,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -124,7 +130,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -133,24 +139,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -159,9 +168,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -170,11 +182,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-aws-batch/requirements.txt b/plugins/flytekit-aws-batch/requirements.txt index e5969b85ea..14947ba486 100644 --- a/plugins/flytekit-aws-batch/requirements.txt +++ b/plugins/flytekit-aws-batch/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,41 +44,44 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-awsbatch -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -89,16 +94,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -114,7 +118,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -124,7 +130,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -133,24 +139,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -159,9 +168,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -170,11 +182,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-aws-sagemaker/requirements.txt b/plugins/flytekit-aws-sagemaker/requirements.txt index d2ffac5f8b..a4806c45c3 100644 --- a/plugins/flytekit-aws-sagemaker/requirements.txt +++ b/plugins/flytekit-aws-sagemaker/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -8,40 +8,41 @@ # via -r requirements.in arrow==1.2.2 # via jinja2-time -bcrypt==3.2.0 +bcrypt==3.2.2 # via paramiko binaryornot==0.4.4 # via cookiecutter -boto3==1.21.46 +boto3==1.24.22 # via sagemaker-training -botocore==1.24.46 +botocore==1.27.22 # via # boto3 # s3transfer -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via # bcrypt # cryptography # pynacl -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 +cryptography==37.0.4 # via # paramiko + # pyopenssl # secretstorage dataclasses-json==0.5.7 # via flytekit @@ -55,51 +56,54 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-awssagemaker gevent==21.12.0 # via sagemaker-training -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status greenlet==1.1.2 # via gevent -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring inotify-simple==1.2.1 # via sagemaker-training jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -jmespath==1.0.0 +jmespath==1.0.1 # via # boto3 # botocore -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -112,20 +116,19 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow # sagemaker-training # scipy packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -paramiko==2.10.3 +paramiko==2.11.0 # via sagemaker-training -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -136,7 +139,7 @@ protobuf==3.20.1 # sagemaker-training protoc-gen-swagger==0.1.0 # via flyteidl -psutil==5.9.0 +psutil==5.9.1 # via sagemaker-training py==1.11.0 # via retry @@ -146,7 +149,9 @@ pycparser==2.21 # via cffi pynacl==1.5.0 # via paramiko -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -157,7 +162,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -166,33 +171,35 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.5.2 +s3transfer==0.6.0 # via boto3 sagemaker-training==3.9.2 # via flytekitplugins-awssagemaker -scipy==1.8.0 +scipy==1.7.3 # via sagemaker-training secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # bcrypt - # cookiecutter # grpcio # paramiko # python-dateutil @@ -204,9 +211,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -216,13 +226,13 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker -werkzeug==2.1.1 +werkzeug==2.1.2 # via sagemaker-training wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-bigquery/requirements.txt b/plugins/flytekit-bigquery/requirements.txt index ea5a1ed0ae..43f6e6c8ba 100644 --- a/plugins/flytekit-bigquery/requirements.txt +++ b/plugins/flytekit-bigquery/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,28 +10,30 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -cachetools==5.0.0 +cachetools==5.2.0 # via google-auth -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -44,65 +46,68 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-bigquery -google-api-core[grpc]==2.7.2 +google-api-core[grpc]==2.8.2 # via # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.6.6 +google-auth==2.9.0 # via # google-api-core # google-cloud-core -google-cloud-bigquery==3.0.1 +google-cloud-bigquery==3.2.0 # via flytekitplugins-bigquery -google-cloud-bigquery-storage==2.13.1 +google-cloud-bigquery-storage==2.13.2 # via google-cloud-bigquery -google-cloud-core==2.3.0 +google-cloud-core==2.3.1 # via google-cloud-bigquery google-crc32c==1.3.0 # via google-resumable-media -google-resumable-media==2.3.2 +google-resumable-media==2.3.3 # via google-cloud-bigquery -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # google-api-core # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # google-api-core # google-cloud-bigquery # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via # flytekit # google-api-core idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -115,19 +120,18 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via # google-cloud-bigquery # marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter -proto-plus==1.20.3 +proto-plus==1.20.6 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -137,6 +141,7 @@ protobuf==3.20.1 # flytekit # google-api-core # google-cloud-bigquery + # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status # proto-plus @@ -157,7 +162,9 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -168,7 +175,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -177,10 +184,12 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker @@ -188,7 +197,7 @@ requests==2.27.1 # google-api-core # google-cloud-bigquery # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit @@ -196,9 +205,10 @@ rsa==4.8 # via google-auth secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # google-auth # grpcio # python-dateutil @@ -208,9 +218,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -219,11 +232,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-data-fsspec/requirements.txt b/plugins/flytekit-data-fsspec/requirements.txt index 9d20a425fb..b29beebd95 100644 --- a/plugins/flytekit-data-fsspec/requirements.txt +++ b/plugins/flytekit-data-fsspec/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,28 +10,30 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -botocore==1.24.46 +botocore==1.27.22 # via flytekitplugins-data-fsspec -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -44,45 +46,48 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-data-fsspec -fsspec==2022.3.0 +fsspec==2022.5.0 # via flytekitplugins-data-fsspec -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -jmespath==1.0.0 +jmespath==1.0.1 # via botocore -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -95,16 +100,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -120,7 +124,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -131,7 +137,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -140,24 +146,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -166,9 +175,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -178,11 +190,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-deck-standard/README.md b/plugins/flytekit-deck-standard/README.md index 22fbd3ea7c..719a2e77a8 100644 --- a/plugins/flytekit-deck-standard/README.md +++ b/plugins/flytekit-deck-standard/README.md @@ -5,5 +5,5 @@ This Plugin provides more renderers to improve task visibility. To install the plugin, run the following command: ```bash -pip install flytekitplugins-deck +pip install flytekitplugins-deck-standard ``` diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py index d2b44e0b65..7e5aa30029 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py @@ -4,6 +4,7 @@ from pandas_profiling import ProfileReport + class FrameProfilingRenderer: """ Generate a ProfileReport based on a pandas DataFrame diff --git a/plugins/flytekit-deck-standard/requirements.in b/plugins/flytekit-deck-standard/requirements.in index 0b3a7eeb42..cfd03dbe82 100644 --- a/plugins/flytekit-deck-standard/requirements.in +++ b/plugins/flytekit-deck-standard/requirements.in @@ -1,2 +1,2 @@ . --e file:.#egg=flytekitplugins-deck +-e file:.#egg=flytekitplugins-deck-standard diff --git a/plugins/flytekit-deck-standard/requirements.txt b/plugins/flytekit-deck-standard/requirements.txt index b33fe58837..368fa073ea 100644 --- a/plugins/flytekit-deck-standard/requirements.txt +++ b/plugins/flytekit-deck-standard/requirements.txt @@ -1,8 +1,290 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in # --e file:.#egg=flytekitplugins-deck +-e file:.#egg=flytekitplugins-deck-standard # via -r requirements.in +arrow==1.2.2 + # via jinja2-time +attrs==21.4.0 + # via visions +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage +cycler==0.11.0 + # via matplotlib +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +flyteidl==1.1.8 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-deck-standard +fonttools==4.33.3 + # via matplotlib +googleapis-common-protos==1.56.3 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.47.0 + # via flytekit +htmlmin==0.1.12 + # via pandas-profiling +idna==3.3 + # via requests +imagehash==4.2.1 + # via visions +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring + # markdown +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time + # pandas-profiling +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.1.0 + # via + # pandas-profiling + # phik +keyring==23.6.0 + # via flytekit +kiwisolver==1.4.3 + # via matplotlib +markdown==3.3.7 + # via flytekitplugins-deck-standard +markupsafe==2.1.1 + # via + # jinja2 + # pandas-profiling +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +matplotlib==3.5.2 + # via + # missingno + # pandas-profiling + # phik + # seaborn +missingno==0.5.1 + # via pandas-profiling +multimethod==1.8 + # via + # pandas-profiling + # visions +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +networkx==2.6.3 + # via visions +numpy==1.21.6 + # via + # flytekit + # imagehash + # matplotlib + # missingno + # pandas + # pandas-profiling + # phik + # pyarrow + # pywavelets + # scipy + # seaborn + # visions +packaging==21.3 + # via + # marshmallow + # matplotlib +pandas==1.3.5 + # via + # flytekit + # pandas-profiling + # phik + # seaborn + # visions +pandas-profiling==3.2.0 + # via flytekitplugins-deck-standard +phik==0.12.2 + # via pandas-profiling +pillow==9.2.0 + # via + # imagehash + # matplotlib + # visions +plotly==5.9.0 + # via flytekitplugins-deck-standard +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pydantic==1.9.1 + # via pandas-profiling +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via + # matplotlib + # packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # matplotlib + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pywavelets==1.3.0 + # via imagehash +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # pandas-profiling +regex==2022.6.2 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # pandas-profiling + # responses +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +scipy==1.7.3 + # via + # imagehash + # missingno + # pandas-profiling + # phik + # seaborn +seaborn==0.11.2 + # via + # missingno + # pandas-profiling +secretstorage==3.3.2 + # via keyring +singledispatchmethod==1.0 + # via flytekit +six==1.16.0 + # via + # grpcio + # imagehash + # python-dateutil +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +tangled-up-in-unicode==0.2.0 + # via + # pandas-profiling + # visions +tenacity==8.0.1 + # via plotly +text-unidecode==1.3 + # via python-slugify +tqdm==4.64.0 + # via pandas-profiling +typing-extensions==4.3.0 + # via + # arrow + # flytekit + # importlib-metadata + # kiwisolver + # pydantic + # responses + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.9 + # via + # flytekit + # requests + # responses +visions[type_image_path]==0.7.4 + # via pandas-profiling +websocket-client==1.3.3 + # via docker +wheel==0.37.1 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.8.0 + # via importlib-metadata diff --git a/plugins/flytekit-dolt/requirements.txt b/plugins/flytekit-dolt/requirements.txt index d3f61895a2..09d78c7c78 100644 --- a/plugins/flytekit-dolt/requirements.txt +++ b/plugins/flytekit-dolt/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via # dolt-integrations @@ -44,45 +46,48 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit dolt-integrations==0.1.5 # via flytekitplugins-dolt doltcli==0.1.17 # via dolt-integrations -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-dolt -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -95,18 +100,17 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via # dolt-integrations # flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -122,7 +126,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -132,7 +138,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -141,24 +147,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -167,9 +176,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -178,11 +190,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-greatexpectations/requirements.txt b/plugins/flytekit-greatexpectations/requirements.txt index 62e839f8fb..e2f32a907f 100644 --- a/plugins/flytekit-greatexpectations/requirements.txt +++ b/plugins/flytekit-greatexpectations/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -14,44 +14,47 @@ argon2-cffi-bindings==21.2.0 # via argon2-cffi arrow==1.2.2 # via jinja2-time -asttokens==2.0.5 - # via stack-data attrs==21.4.0 # via jsonschema backcall==0.2.0 # via ipython +backports-zoneinfo==0.2.1 + # via + # pytz-deprecation-shim + # tzlocal beautifulsoup4==4.11.1 # via nbconvert binaryornot==0.4.4 # via cookiecutter -bleach==5.0.0 +bleach==5.0.1 # via nbconvert -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via # argon2-cffi-bindings # cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit # great-expectations -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -colorama==0.4.4 +colorama==0.4.5 # via great-expectations -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 +cryptography==37.0.4 # via # great-expectations + # pyopenssl # secretstorage dataclasses-json==0.5.7 # via flytekit @@ -71,44 +74,48 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit entrypoints==0.4 # via # altair # jupyter-client # nbconvert -executing==0.8.3 - # via stack-data fastjsonschema==2.15.3 # via nbformat -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-great-expectations -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -great-expectations==0.15.2 +great-expectations==0.15.12 # via flytekitplugins-great-expectations greenlet==1.1.2 # via sqlalchemy -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 +importlib-metadata==4.12.0 # via + # click + # flytekit # great-expectations + # jsonschema # keyring -ipykernel==6.13.0 + # sqlalchemy +importlib-resources==5.8.0 + # via jsonschema +ipykernel==6.15.0 # via notebook -ipython==8.2.0 +ipython==7.34.0 # via # great-expectations # ipykernel @@ -120,7 +127,7 @@ jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.0.3 +jinja2==3.1.2 # via # altair # cookiecutter @@ -134,12 +141,12 @@ jsonpatch==1.32 # via great-expectations jsonpointer==2.3 # via jsonpatch -jsonschema==4.4.0 +jsonschema==4.6.1 # via # altair # great-expectations # nbformat -jupyter-client==7.2.2 +jupyter-client==7.3.4 # via # ipykernel # nbclient @@ -152,13 +159,13 @@ jupyter-core==4.10.0 # notebook jupyterlab-pygments==0.2.2 # via nbconvert -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via # jinja2 # nbconvert -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -171,7 +178,7 @@ matplotlib-inline==0.1.3 # via # ipykernel # ipython -mistune==0.8.4 +mistune==2.0.3 # via # great-expectations # nbconvert @@ -179,11 +186,11 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -nbclient==0.6.0 +nbclient==0.6.6 # via nbconvert -nbconvert==6.5.0 +nbconvert==7.0.0rc2 # via notebook -nbformat==5.3.0 +nbformat==5.4.0 # via # great-expectations # nbclient @@ -195,11 +202,12 @@ nest-asyncio==1.5.5 # jupyter-client # nbclient # notebook -notebook==6.4.11 +notebook==6.4.12 # via great-expectations -numpy==1.22.3 +numpy==1.21.6 # via # altair + # flytekit # great-expectations # pandas # pyarrow @@ -210,7 +218,7 @@ packaging==21.3 # ipykernel # marshmallow # nbconvert -pandas==1.4.2 +pandas==1.3.5 # via # altair # flytekit @@ -223,11 +231,9 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -poyo==0.5.0 - # via cookiecutter prometheus-client==0.14.1 # via notebook -prompt-toolkit==3.0.29 +prompt-toolkit==3.0.30 # via ipython protobuf==3.20.1 # via @@ -238,24 +244,24 @@ protobuf==3.20.1 # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -psutil==5.9.0 +psutil==5.9.1 # via ipykernel ptyprocess==0.7.0 # via # pexpect # terminado -pure-eval==0.2.2 - # via stack-data py==1.11.0 # via retry pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pygments==2.11.2 +pygments==2.12.0 # via # ipython # nbconvert +pyopenssl==22.0.0 + # via flytekit pyparsing==2.4.7 # via # great-expectations @@ -272,7 +278,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -284,21 +290,24 @@ pytz==2022.1 pytz-deprecation-shim==0.1.0.post0 # via tzlocal pyyaml==6.0 - # via flytekit -pyzmq==22.3.0 # via + # cookiecutter + # flytekit +pyzmq==23.2.0 + # via + # ipykernel # jupyter-client # notebook -regex==2022.3.15 +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # great-expectations # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit @@ -306,34 +315,32 @@ ruamel-yaml==0.17.17 # via great-expectations ruamel-yaml-clib==0.2.6 # via ruamel-yaml -scipy==1.8.0 +scipy==1.7.3 # via great-expectations secretstorage==3.3.2 # via keyring send2trash==1.8.0 # via notebook +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # asttokens # bleach - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 # via flytekit soupsieve==2.3.2.post1 # via beautifulsoup4 -sqlalchemy==1.4.35 +sqlalchemy==1.4.39 # via # -r requirements.in # flytekitplugins-great-expectations -stack-data==0.2.0 - # via ipython statsd==3.3.0 # via flytekit termcolor==1.1.0 # via great-expectations -terminado==0.13.3 +terminado==0.15.0 # via notebook text-unidecode==1.3 # via python-slugify @@ -341,7 +348,7 @@ tinycss2==1.1.1 # via nbconvert toolz==0.11.2 # via altair -tornado==6.1 +tornado==6.2 # via # ipykernel # jupyter-client @@ -349,7 +356,7 @@ tornado==6.1 # terminado tqdm==4.64.0 # via great-expectations -traitlets==5.1.1 +traitlets==5.3.0 # via # ipykernel # ipython @@ -360,10 +367,15 @@ traitlets==5.1.1 # nbconvert # nbformat # notebook -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # argon2-cffi + # arrow # flytekit # great-expectations + # importlib-metadata + # jsonschema + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -383,16 +395,18 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit zipp==3.8.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-greatexpectations/tests/test_schema.py b/plugins/flytekit-greatexpectations/tests/test_schema.py index 734357ba44..da9ab4f8e3 100644 --- a/plugins/flytekit-greatexpectations/tests/test_schema.py +++ b/plugins/flytekit-greatexpectations/tests/test_schema.py @@ -6,7 +6,7 @@ import pandas as pd import pytest from flytekitplugins.great_expectations import BatchRequestConfig, GreatExpectationsFlyteConfig, GreatExpectationsType -from great_expectations.exceptions import ValidationError +from great_expectations.exceptions import InvalidBatchRequestError, ValidationError from flytekit import task, workflow from flytekit.types.file import CSVFile @@ -144,7 +144,7 @@ def my_wf(): my_task(directory="my_assets") # Capture IndexError - with pytest.raises(IndexError): + with pytest.raises(InvalidBatchRequestError): my_wf() diff --git a/plugins/flytekit-greatexpectations/tests/test_task.py b/plugins/flytekit-greatexpectations/tests/test_task.py index 3d1af6f817..545f2f4cff 100644 --- a/plugins/flytekit-greatexpectations/tests/test_task.py +++ b/plugins/flytekit-greatexpectations/tests/test_task.py @@ -6,7 +6,7 @@ import pandas as pd import pytest from flytekitplugins.great_expectations import BatchRequestConfig, GreatExpectationsTask -from great_expectations.exceptions import ValidationError +from great_expectations.exceptions import InvalidBatchRequestError, ValidationError from flytekit import kwtypes, task, workflow from flytekit.types.file import CSVFile, FlyteFile @@ -82,7 +82,7 @@ def test_invalid_ge_batchrequest_pandas_config(): ) # Capture IndexError - with pytest.raises(IndexError): + with pytest.raises(InvalidBatchRequestError): task_object(data="my_assets") diff --git a/plugins/flytekit-hive/requirements.txt b/plugins/flytekit-hive/requirements.txt index 2597c52147..73e7a6cdb6 100644 --- a/plugins/flytekit-hive/requirements.txt +++ b/plugins/flytekit-hive/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,41 +44,44 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-hive -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -89,16 +94,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -114,7 +118,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -124,7 +130,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -133,24 +139,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -159,9 +168,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -170,11 +182,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-k8s-pod/requirements.txt b/plugins/flytekit-k8s-pod/requirements.txt index c8ee0b67f3..4d8026a973 100644 --- a/plugins/flytekit-k8s-pod/requirements.txt +++ b/plugins/flytekit-k8s-pod/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,30 +10,32 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -cachetools==5.0.0 +cachetools==5.2.0 # via google-auth -certifi==2021.10.8 +certifi==2022.6.15 # via # kubernetes # requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -46,45 +48,48 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-pod -google-auth==2.6.6 +google-auth==2.9.0 # via kubernetes -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit -kubernetes==23.3.0 +kubernetes==24.2.0 # via flytekitplugins-pod markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -97,18 +102,17 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow oauthlib==3.2.0 # via requests-oauthlib packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -130,7 +134,9 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -141,7 +147,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -151,11 +157,12 @@ pytz==2022.1 # pandas pyyaml==6.0 # via + # cookiecutter # flytekit # kubernetes -regex==2022.3.15 +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker @@ -165,7 +172,7 @@ requests==2.27.1 # responses requests-oauthlib==1.3.1 # via kubernetes -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit @@ -173,9 +180,10 @@ rsa==4.8 # via google-auth secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # google-auth # grpcio # kubernetes @@ -186,9 +194,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -198,13 +209,13 @@ urllib3==1.26.9 # kubernetes # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via # docker # kubernetes wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-kf-mpi/requirements.txt b/plugins/flytekit-kf-mpi/requirements.txt index 7f7c67a52f..36bd1f2895 100644 --- a/plugins/flytekit-kf-mpi/requirements.txt +++ b/plugins/flytekit-kf-mpi/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,43 +44,46 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via # flytekit # flytekitplugins-kfmpi -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-kfmpi -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -91,16 +96,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -116,7 +120,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -126,7 +132,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -135,24 +141,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -161,9 +170,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -172,11 +184,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index f1bd20567c..7a879873c0 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,41 +44,44 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-kfpytorch -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -89,16 +94,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -114,7 +118,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -124,7 +130,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -133,24 +139,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -159,9 +168,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -170,11 +182,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-kf-tensorflow/requirements.txt b/plugins/flytekit-kf-tensorflow/requirements.txt index ce6cd23463..89a7e29bd1 100644 --- a/plugins/flytekit-kf-tensorflow/requirements.txt +++ b/plugins/flytekit-kf-tensorflow/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,41 +44,44 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-kftensorflow -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -89,16 +94,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -114,7 +118,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -124,7 +130,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -133,24 +139,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -159,9 +168,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -170,11 +182,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-modin/requirements.in b/plugins/flytekit-modin/requirements.in index 4620f4beab..0248a83ad9 100644 --- a/plugins/flytekit-modin/requirements.in +++ b/plugins/flytekit-modin/requirements.in @@ -1,4 +1,2 @@ -grpcio<=1.43.0 -grpcio-status<=1.43.0 . -e file:.#egg=flytekitplugins-modin diff --git a/plugins/flytekit-modin/requirements.txt b/plugins/flytekit-modin/requirements.txt index 61e0034d7b..089d6f5678 100644 --- a/plugins/flytekit-modin/requirements.txt +++ b/plugins/flytekit-modin/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.8 # To update, run: # # pip-compile requirements.in @@ -57,9 +57,9 @@ filelock==3.6.0 # via # ray # virtualenv -flyteidl==0.24.21 +flyteidl==1.0.1 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0b2 # via flytekitplugins-modin frozenlist==1.3.0 # via @@ -75,20 +75,20 @@ googleapis-common-protos==1.56.0 # grpcio-status grpcio==1.43.0 # via - # -r requirements.in # flytekit # flytekitplugins-modin # grpcio-status # ray grpcio-status==1.43.0 # via - # -r requirements.in # flytekit # flytekitplugins-modin idna==3.3 # via requests importlib-metadata==4.11.3 # via keyring +importlib-resources==5.7.1 + # via jsonschema jeepney==0.8.0 # via # keyring @@ -231,4 +231,6 @@ wrapt==1.14.0 # deprecated # flytekit zipp==3.8.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-onnx-pytorch/README.md b/plugins/flytekit-onnx-pytorch/README.md new file mode 100644 index 0000000000..48bc736854 --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/README.md @@ -0,0 +1,9 @@ +# Flytekit ONNX PyTorch Plugin + +This plugin allows you to generate ONNX models from your PyTorch models. + +To install the plugin, run the following command: + +``` +pip install flytekitplugins-onnxpytorch +``` diff --git a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/__init__.py b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/__init__.py new file mode 100644 index 0000000000..384fb1cab3 --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/__init__.py @@ -0,0 +1 @@ +from .schema import PyTorch2ONNX, PyTorch2ONNXConfig diff --git a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py new file mode 100644 index 0000000000..7031867c8d --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +from dataclasses_json import dataclass_json +from torch.onnx import OperatorExportTypes, TrainingMode +from typing_extensions import Annotated, get_args, get_origin + +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core.types import BlobType +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType +from flytekit.types.file import ONNXFile + + +@dataclass_json +@dataclass +class PyTorch2ONNXConfig: + args: Union[Tuple, torch.Tensor] + export_params: bool = True + verbose: bool = False + training: TrainingMode = TrainingMode.EVAL + opset_version: int = 9 + input_names: List[str] = field(default_factory=list) + output_names: List[str] = field(default_factory=list) + operator_export_type: Optional[OperatorExportTypes] = None + do_constant_folding: bool = False + dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = field(default_factory=dict) + keep_initializers_as_inputs: Optional[bool] = None + custom_opsets: Dict[str, int] = field(default_factory=dict) + export_modules_as_functions: Union[bool, set[Type]] = False + + +@dataclass_json +@dataclass +class PyTorch2ONNX: + model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction] = field(default=None) + + +def extract_config(t: Type[PyTorch2ONNX]) -> Tuple[Type[PyTorch2ONNX], PyTorch2ONNXConfig]: + config = None + if get_origin(t) is Annotated: + base_type, config = get_args(t) + if isinstance(config, PyTorch2ONNXConfig): + return base_type, config + else: + raise TypeTransformerFailedError(f"{t}'s config isn't of type PyTorch2ONNXConfig") + return t, config + + +def to_onnx(ctx, model, config): + local_path = ctx.file_access.get_random_local_path() + + torch.onnx.export( + model, + **config, + f=local_path, + ) + + return local_path + + +class PyTorch2ONNXTransformer(TypeTransformer[PyTorch2ONNX]): + ONNX_FORMAT = "onnx" + + def __init__(self): + super().__init__(name="PyTorch ONNX", t=PyTorch2ONNX) + + def get_literal_type(self, t: Type[PyTorch2ONNX]) -> LiteralType: + return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)) + + def to_literal( + self, + ctx: FlyteContext, + python_val: PyTorch2ONNX, + python_type: Type[PyTorch2ONNX], + expected: LiteralType, + ) -> Literal: + python_type, config = extract_config(python_type) + + if config: + local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) + remote_path = ctx.file_access.get_random_remote_path() + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + else: + raise TypeTransformerFailedError(f"{python_type}'s config is None") + + return Literal( + scalar=Scalar( + blob=Blob( + uri=remote_path, + metadata=BlobMetadata( + type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE) + ), + ) + ) + ) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: Type[ONNXFile], + ) -> ONNXFile: + if not (lv.scalar.blob.uri and lv.scalar.blob.metadata.format == self.ONNX_FORMAT): + raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}") + + return ONNXFile(path=lv.scalar.blob.uri) + + def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorch2ONNX]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.ONNX_FORMAT + ): + return PyTorch2ONNX + + raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(PyTorch2ONNXTransformer()) diff --git a/plugins/flytekit-onnx-pytorch/requirements.in b/plugins/flytekit-onnx-pytorch/requirements.in new file mode 100644 index 0000000000..7632919db4 --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/requirements.in @@ -0,0 +1,5 @@ +. +-e file:.#egg=flytekitplugins-onnxpytorch +onnxruntime +pillow +torchvision>=0.12.0 diff --git a/plugins/flytekit-onnx-pytorch/requirements.txt b/plugins/flytekit-onnx-pytorch/requirements.txt new file mode 100644 index 0000000000..fa8ac445af --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/requirements.txt @@ -0,0 +1,199 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-onnxpytorch + # via -r requirements.in +arrow==1.2.2 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +flatbuffers==2.0 + # via onnxruntime +flyteidl==1.1.8 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-onnxpytorch +googleapis-common-protos==1.56.3 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.47.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.12.0 + # via + # flytekit + # keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +keyring==23.6.0 + # via flytekit +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.23.0 + # via + # onnxruntime + # pandas + # pyarrow + # torchvision +onnxruntime==1.11.1 + # via -r requirements.in +packaging==21.3 + # via marshmallow +pandas==1.4.3 + # via flytekit +pillow==9.2.0 + # via + # -r requirements.in + # torchvision +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # onnxruntime + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit +regex==2022.6.2 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # responses + # torchvision +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +six==1.16.0 + # via + # grpcio + # python-dateutil +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +torch==1.12.0 + # via + # flytekitplugins-onnxpytorch + # torchvision +torchvision==0.13.0 + # via -r requirements.in +typing-extensions==4.3.0 + # via + # flytekit + # torch + # torchvision + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.10 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +wheel==0.37.1 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.8.0 + # via importlib-metadata diff --git a/plugins/flytekit-onnx-pytorch/setup.py b/plugins/flytekit-onnx-pytorch/setup.py new file mode 100644 index 0000000000..74e3b940ec --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup + +PLUGIN_NAME = "onnxpytorch" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "torch>=1.11.0"] + +__version__ = "0.0.0+develop" + +setup( + name=f"flytekitplugins-{PLUGIN_NAME}", + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="ONNX PyTorch Plugin for Flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-deck-standard/__init__.py b/plugins/flytekit-onnx-pytorch/tests/__init__.py similarity index 100% rename from plugins/flytekit-deck-standard/__init__.py rename to plugins/flytekit-onnx-pytorch/tests/__init__.py diff --git a/plugins/flytekit-onnx-pytorch/tests/test_onnx_pytorch.py b/plugins/flytekit-onnx-pytorch/tests/test_onnx_pytorch.py new file mode 100644 index 0000000000..3a704a4780 --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/tests/test_onnx_pytorch.py @@ -0,0 +1,128 @@ +# Some standard imports +from pathlib import Path + +import numpy as np +import onnxruntime +import requests +import torch.nn.init as init +import torch.onnx +import torch.utils.model_zoo as model_zoo +import torchvision.transforms as transforms +from flytekitplugins.onnxpytorch import PyTorch2ONNX, PyTorch2ONNXConfig +from PIL import Image +from torch import nn +from typing_extensions import Annotated + +import flytekit +from flytekit import task, workflow +from flytekit.types.file import JPEGImageFile, ONNXFile + + +class SuperResolutionNet(nn.Module): + def __init__(self, upscale_factor, inplace=False): + super(SuperResolutionNet, self).__init__() + + self.relu = nn.ReLU(inplace=inplace) + self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) + self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) + self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) + self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1)) + self.pixel_shuffle = nn.PixelShuffle(upscale_factor) + + self._initialize_weights() + + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.relu(self.conv2(x)) + x = self.relu(self.conv3(x)) + x = self.pixel_shuffle(self.conv4(x)) + return x + + def _initialize_weights(self): + init.orthogonal_(self.conv1.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv2.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv3.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv4.weight) + + +def test_onnx_pytorch(): + @task + def train() -> Annotated[ + PyTorch2ONNX, + PyTorch2ONNXConfig( + args=torch.randn(1, 1, 224, 224, requires_grad=True), + export_params=True, # store the trained parameter weights inside + opset_version=10, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=["input"], # the model's input names + output_names=["output"], # the model's output names + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, # variable length axes + ), + ]: + # Create the super-resolution model by using the above model definition. + torch_model = SuperResolutionNet(upscale_factor=3) + + # Load pretrained model weights + model_url = "https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth" + + # Initialize model with the pretrained weights + map_location = lambda storage, loc: storage # noqa: E731 + if torch.cuda.is_available(): + map_location = None + torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location)) + + return PyTorch2ONNX(model=torch_model) + + @task + def onnx_predict(model_file: ONNXFile) -> JPEGImageFile: + ort_session = onnxruntime.InferenceSession(model_file.download()) + + img = Image.open( + requests.get( + "https://raw.githubusercontent.com/flyteorg/static-resources/main/flytekit/onnx/cat.jpg", stream=True + ).raw + ) + + resize = transforms.Resize([224, 224]) + img = resize(img) + + img_ycbcr = img.convert("YCbCr") + img_y, img_cb, img_cr = img_ycbcr.split() + + to_tensor = transforms.ToTensor() + img_y = to_tensor(img_y) + img_y.unsqueeze_(0) + + # compute ONNX Runtime output prediction + ort_inputs = { + ort_session.get_inputs()[0].name: img_y.detach().cpu().numpy() + if img_y.requires_grad + else img_y.cpu().numpy() + } + ort_outs = ort_session.run(None, ort_inputs) + img_out_y = ort_outs[0] + + img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L") + + # get the output image follow post-processing step from PyTorch implementation + final_img = Image.merge( + "YCbCr", + [ + img_out_y, + img_cb.resize(img_out_y.size, Image.BICUBIC), + img_cr.resize(img_out_y.size, Image.BICUBIC), + ], + ).convert("RGB") + + img_path = Path(flytekit.current_context().working_directory) / "cat_superres_with_ort.jpg" + final_img.save(img_path) + + # Save the image, we will compare this with the output image from mobile device + return JPEGImageFile(path=str(img_path)) + + @workflow + def wf() -> JPEGImageFile: + model = train() + return onnx_predict(model_file=model) + + print(wf()) diff --git a/plugins/flytekit-onnx-scikitlearn/README.md b/plugins/flytekit-onnx-scikitlearn/README.md new file mode 100644 index 0000000000..220a157090 --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/README.md @@ -0,0 +1,9 @@ +# Flytekit ONNX ScikitLearn Plugin + +This plugin allows you to generate ONNX models from your ScikitLearn models. + +To install the plugin, run the following command: + +``` +pip install flytekitplugins-onnxscikitlearn +``` diff --git a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/__init__.py b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/__init__.py new file mode 100644 index 0000000000..d09c317c07 --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/__init__.py @@ -0,0 +1 @@ +from .schema import ScikitLearn2ONNX, ScikitLearn2ONNXConfig diff --git a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py new file mode 100644 index 0000000000..db50986b5e --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import inspect +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import skl2onnx.common.data_types +from dataclasses_json import dataclass_json +from skl2onnx import convert_sklearn +from sklearn.base import BaseEstimator +from typing_extensions import Annotated, get_args, get_origin + +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core.types import BlobType +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType +from flytekit.types.file import ONNXFile + + +@dataclass_json +@dataclass +class ScikitLearn2ONNXConfig: + initial_types: List[Tuple[str, Type]] + name: Optional[str] = None + doc_string: str = "" + target_opset: Optional[int] = None + custom_conversion_functions: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict) + custom_shape_calculators: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict) + custom_parsers: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict) + options: Dict[Any, Any] = field(default_factory=dict) + intermediate: bool = False + naming: Union[str, Callable[..., Any]] = None + white_op: Optional[Set[str]] = None + black_op: Optional[Set[str]] = None + verbose: int = 0 + final_types: Optional[List[Tuple[str, Type]]] = None + + def __post_init__(self): + validate_initial_types = [ + True for item in self.initial_types if item in inspect.getmembers(skl2onnx.common.data_types) + ] + if not all(validate_initial_types): + raise ValueError("All types in initial_types must be in skl2onnx.common.data_types") + + if self.final_types: + validate_final_types = [ + True for item in self.final_types if item in inspect.getmembers(skl2onnx.common.data_types) + ] + if not all(validate_final_types): + raise ValueError("All types in final_types must be in skl2onnx.common.data_types") + + +@dataclass_json +@dataclass +class ScikitLearn2ONNX: + model: BaseEstimator = field(default=None) + + +def extract_config(t: Type[ScikitLearn2ONNX]) -> Tuple[Type[ScikitLearn2ONNX], ScikitLearn2ONNXConfig]: + config = None + + if get_origin(t) is Annotated: + base_type, config = get_args(t) + if isinstance(config, ScikitLearn2ONNXConfig): + return base_type, config + else: + raise TypeTransformerFailedError(f"{t}'s config isn't of type ScikitLearn2ONNXConfig") + return t, config + + +def to_onnx(ctx, model, config): + local_path = ctx.file_access.get_random_local_path() + + onx = convert_sklearn(model, **config) + + with open(local_path, "wb") as f: + f.write(onx.SerializeToString()) + + return local_path + + +class ScikitLearn2ONNXTransformer(TypeTransformer[ScikitLearn2ONNX]): + ONNX_FORMAT = "onnx" + + def __init__(self): + super().__init__(name="ScikitLearn ONNX", t=ScikitLearn2ONNX) + + def get_literal_type(self, t: Type[ScikitLearn2ONNX]) -> LiteralType: + return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)) + + def to_literal( + self, + ctx: FlyteContext, + python_val: ScikitLearn2ONNX, + python_type: Type[ScikitLearn2ONNX], + expected: LiteralType, + ) -> Literal: + python_type, config = extract_config(python_type) + + if config: + remote_path = ctx.file_access.get_random_remote_path() + local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + else: + raise TypeTransformerFailedError(f"{python_type}'s config is None") + + return Literal( + scalar=Scalar( + blob=Blob( + uri=remote_path, + metadata=BlobMetadata( + type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE) + ), + ) + ) + ) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: Type[ONNXFile], + ) -> ONNXFile: + if not lv.scalar.blob.uri: + raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}") + + return ONNXFile(path=lv.scalar.blob.uri) + + def guess_python_type(self, literal_type: LiteralType) -> Type[ScikitLearn2ONNX]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.ONNX_FORMAT + ): + return ScikitLearn2ONNX + + raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(ScikitLearn2ONNXTransformer()) diff --git a/plugins/flytekit-onnx-scikitlearn/requirements.in b/plugins/flytekit-onnx-scikitlearn/requirements.in new file mode 100644 index 0000000000..bdf7848c4a --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/requirements.in @@ -0,0 +1,3 @@ +. +-e file:.#egg=flytekitplugins-onnxscikitlearn +onnxruntime diff --git a/plugins/flytekit-onnx-scikitlearn/requirements.txt b/plugins/flytekit-onnx-scikitlearn/requirements.txt new file mode 100644 index 0000000000..7ff826a50b --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/requirements.txt @@ -0,0 +1,212 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-onnxscikitlearn + # via -r requirements.in +arrow==1.2.2 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +flatbuffers==2.0 + # via onnxruntime +flyteidl==1.1.8 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-onnxscikitlearn +googleapis-common-protos==1.56.3 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.47.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.12.0 + # via + # flytekit + # keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.1.0 + # via scikit-learn +keyring==23.6.0 + # via flytekit +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.23.0 + # via + # onnx + # onnxconverter-common + # onnxruntime + # pandas + # pyarrow + # scikit-learn + # scipy + # skl2onnx +onnx==1.12.0 + # via + # onnxconverter-common + # skl2onnx +onnxconverter-common==1.9.0 + # via skl2onnx +onnxruntime==1.11.1 + # via -r requirements.in +packaging==21.3 + # via marshmallow +pandas==1.4.3 + # via flytekit +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # onnx + # onnxconverter-common + # onnxruntime + # protoc-gen-swagger + # skl2onnx +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit +regex==2022.6.2 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # responses +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +scikit-learn==1.1.1 + # via skl2onnx +scipy==1.8.1 + # via + # scikit-learn + # skl2onnx +six==1.16.0 + # via + # grpcio + # python-dateutil +skl2onnx==1.11.2 + # via flytekitplugins-onnxscikitlearn +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +threadpoolctl==3.1.0 + # via scikit-learn +typing-extensions==4.3.0 + # via + # flytekit + # onnx + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.10 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +wheel==0.37.1 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.8.0 + # via importlib-metadata diff --git a/plugins/flytekit-onnx-scikitlearn/setup.py b/plugins/flytekit-onnx-scikitlearn/setup.py new file mode 100644 index 0000000000..9815bedaf2 --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "onnxscikitlearn" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "skl2onnx>=1.10.3"] + +__version__ = "0.0.0+develop" + +setup( + name=f"flytekitplugins-{PLUGIN_NAME}", + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="ONNX ScikitLearn Plugin for Flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/tests/flytekit/unit/core/functools/__init__.py b/plugins/flytekit-onnx-scikitlearn/tests/__init__.py similarity index 100% rename from tests/flytekit/unit/core/functools/__init__.py rename to plugins/flytekit-onnx-scikitlearn/tests/__init__.py diff --git a/plugins/flytekit-onnx-scikitlearn/tests/test_onnx_scikitlearn.py b/plugins/flytekit-onnx-scikitlearn/tests/test_onnx_scikitlearn.py new file mode 100644 index 0000000000..d6f1617ece --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/tests/test_onnx_scikitlearn.py @@ -0,0 +1,113 @@ +from typing import List, NamedTuple + +import numpy +import onnxruntime as rt +import pandas as pd +from flytekitplugins.onnxscikitlearn import ScikitLearn2ONNX, ScikitLearn2ONNXConfig +from skl2onnx.common._apply_operation import apply_mul +from skl2onnx.common.data_types import FloatTensorType +from skl2onnx.proto import onnx_proto +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split +from typing_extensions import Annotated + +from flytekit import task, workflow +from flytekit.types.file import ONNXFile + + +def test_onnx_scikitlearn_simple(): + TrainOutput = NamedTuple( + "TrainOutput", + [ + ( + "model", + Annotated[ + ScikitLearn2ONNX, + ScikitLearn2ONNXConfig( + initial_types=[("float_input", FloatTensorType([None, 4]))], + target_opset=12, + ), + ], + ), + ("test", pd.DataFrame), + ], + ) + + @task + def train() -> TrainOutput: + iris = load_iris(as_frame=True) + X, y = iris.data, iris.target + X_train, X_test, y_train, _ = train_test_split(X, y) + model = RandomForestClassifier() + model.fit(X_train, y_train) + + return TrainOutput(test=X_test, model=ScikitLearn2ONNX(model)) + + @task + def predict( + model: ONNXFile, + X_test: pd.DataFrame, + ) -> List[int]: + sess = rt.InferenceSession(model.download()) + input_name = sess.get_inputs()[0].name + label_name = sess.get_outputs()[0].name + pred_onx = sess.run([label_name], {input_name: X_test.to_numpy(dtype=numpy.float32)})[0] + return pred_onx.tolist() + + @workflow + def wf() -> List[int]: + train_output = train() + return predict(model=train_output.model, X_test=train_output.test) + + print(wf()) + + +class CustomTransform(BaseEstimator, TransformerMixin): + def __init__(self): + TransformerMixin.__init__(self) + BaseEstimator.__init__(self) + + def fit(self, X, y, sample_weight=None): + pass + + def transform(self, X): + return X * numpy.array([[0.5, 0.1, 10], [0.5, 0.1, 10]]).T + + +def custom_transform_shape_calculator(operator): + operator.outputs[0].type = FloatTensorType([3, 2]) + + +def custom_tranform_converter(scope, operator, container): + input = operator.inputs[0] + output = operator.outputs[0] + + weights_name = scope.get_unique_variable_name("weights") + atype = onnx_proto.TensorProto.FLOAT + weights = [0.5, 0.1, 10] + shape = [len(weights), 1] + container.add_initializer(weights_name, atype, shape, weights) + apply_mul(scope, [input.full_name, weights_name], output.full_name, container) + + +def test_onnx_scikitlearn(): + @task + def get_model() -> Annotated[ + ScikitLearn2ONNX, + ScikitLearn2ONNXConfig( + initial_types=[("input", FloatTensorType([None, numpy.array([[1, 2], [3, 4], [4, 5]]).shape[1]]))], + custom_shape_calculators={CustomTransform: custom_transform_shape_calculator}, + custom_conversion_functions={CustomTransform: custom_tranform_converter}, + target_opset=12, + ), + ]: + model = CustomTransform() + return ScikitLearn2ONNX(model) + + @workflow + def wf() -> ONNXFile: + return get_model() + + print(wf()) diff --git a/plugins/flytekit-onnx-tensorflow/README.md b/plugins/flytekit-onnx-tensorflow/README.md new file mode 100644 index 0000000000..cd29ede0e1 --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/README.md @@ -0,0 +1,9 @@ +# Flytekit ONNX TensorFlow Plugin + +This plugin allows you to generate ONNX models from your TensorFlow Keras models. + +To install the plugin, run the following command: + +``` +pip install flytekitplugins-onnxtensorflow +``` diff --git a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/__init__.py b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/__init__.py new file mode 100644 index 0000000000..c359a7893d --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/__init__.py @@ -0,0 +1 @@ +from .schema import TensorFlow2ONNX, TensorFlow2ONNXConfig diff --git a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py new file mode 100644 index 0000000000..184083f90a --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import tensorflow as tf +import tf2onnx +from dataclasses_json import dataclass_json +from typing_extensions import Annotated, get_args, get_origin + +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core.types import BlobType +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType +from flytekit.types.file import ONNXFile + + +@dataclass_json +@dataclass +class TensorFlow2ONNXConfig: + input_signature: Union[tf.TensorSpec, np.ndarray] + custom_ops: Optional[Dict[str, Any]] = None + target: Optional[List[Any]] = None + custom_op_handlers: Optional[Dict[Any, Tuple]] = None + custom_rewriter: Optional[List[Any]] = None + opset: Optional[int] = None + extra_opset: Optional[List[int]] = None + shape_override: Optional[Dict[str, List[Any]]] = None + inputs_as_nchw: Optional[List[str]] = None + large_model: bool = False + + +@dataclass_json +@dataclass +class TensorFlow2ONNX: + model: tf.keras = field(default=None) + + +def extract_config(t: Type[TensorFlow2ONNX]) -> Tuple[Type[TensorFlow2ONNX], TensorFlow2ONNXConfig]: + config = None + if get_origin(t) is Annotated: + base_type, config = get_args(t) + if isinstance(config, TensorFlow2ONNXConfig): + return base_type, config + else: + raise TypeTransformerFailedError(f"{t}'s config isn't of type TensorFlow2ONNX") + return t, config + + +def to_onnx(ctx, model, config): + local_path = ctx.file_access.get_random_local_path() + + tf2onnx.convert.from_keras(model, **config, output_path=local_path) + + return local_path + + +class TensorFlow2ONNXTransformer(TypeTransformer[TensorFlow2ONNX]): + ONNX_FORMAT = "onnx" + + def __init__(self): + super().__init__(name="TensorFlow ONNX", t=TensorFlow2ONNX) + + def get_literal_type(self, t: Type[TensorFlow2ONNX]) -> LiteralType: + return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)) + + def to_literal( + self, + ctx: FlyteContext, + python_val: TensorFlow2ONNX, + python_type: Type[TensorFlow2ONNX], + expected: LiteralType, + ) -> Literal: + python_type, config = extract_config(python_type) + + if config: + remote_path = ctx.file_access.get_random_remote_path() + local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + else: + raise TypeTransformerFailedError(f"{python_type}'s config is None") + + return Literal( + scalar=Scalar( + blob=Blob( + uri=remote_path, + metadata=BlobMetadata( + type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE) + ), + ) + ) + ) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: Type[ONNXFile], + ) -> ONNXFile: + if not lv.scalar.blob.uri: + raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}") + + return ONNXFile(path=lv.scalar.blob.uri) + + def guess_python_type(self, literal_type: LiteralType) -> Type[TensorFlow2ONNX]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.ONNX_FORMAT + ): + return TensorFlow2ONNX + + raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(TensorFlow2ONNXTransformer()) diff --git a/plugins/flytekit-onnx-tensorflow/requirements.in b/plugins/flytekit-onnx-tensorflow/requirements.in new file mode 100644 index 0000000000..0752c85f88 --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/requirements.in @@ -0,0 +1,4 @@ +. +-e file:.#egg=flytekitplugins-onnxtensorflow +onnxruntime +pillow diff --git a/plugins/flytekit-onnx-tensorflow/requirements.txt b/plugins/flytekit-onnx-tensorflow/requirements.txt new file mode 100644 index 0000000000..51f649a4ed --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/requirements.txt @@ -0,0 +1,285 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-onnxtensorflow + # via -r requirements.in +absl-py==1.1.0 + # via + # tensorboard + # tensorflow +arrow==1.2.2 + # via jinja2-time +astunparse==1.6.3 + # via tensorflow +binaryornot==0.4.4 + # via cookiecutter +cachetools==5.2.0 + # via google-auth +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +flatbuffers==1.12 + # via + # onnxruntime + # tensorflow + # tf2onnx +flyteidl==1.1.8 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-onnxtensorflow +gast==0.4.0 + # via tensorflow +google-auth==2.9.0 + # via + # google-auth-oauthlib + # tensorboard +google-auth-oauthlib==0.4.6 + # via tensorboard +google-pasta==0.2.0 + # via tensorflow +googleapis-common-protos==1.56.3 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status + # tensorboard + # tensorflow +grpcio-status==1.47.0 + # via flytekit +h5py==3.7.0 + # via tensorflow +idna==3.3 + # via requests +importlib-metadata==4.12.0 + # via + # flytekit + # keyring + # markdown +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +keras==2.9.0 + # via tensorflow +keras-preprocessing==1.1.2 + # via tensorflow +keyring==23.6.0 + # via flytekit +libclang==14.0.1 + # via tensorflow +markdown==3.3.7 + # via tensorboard +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.23.0 + # via + # h5py + # keras-preprocessing + # onnx + # onnxruntime + # opt-einsum + # pandas + # pyarrow + # tensorboard + # tensorflow + # tf2onnx +oauthlib==3.2.0 + # via requests-oauthlib +onnx==1.12.0 + # via tf2onnx +onnxruntime==1.11.1 + # via -r requirements.in +opt-einsum==3.3.0 + # via tensorflow +packaging==21.3 + # via + # marshmallow + # tensorflow +pandas==1.4.3 + # via flytekit +pillow==9.2.0 + # via -r requirements.in +protobuf==3.19.4 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # onnx + # onnxruntime + # protoc-gen-swagger + # tensorboard + # tensorflow +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit +regex==2022.6.2 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # requests-oauthlib + # responses + # tensorboard + # tf2onnx +requests-oauthlib==1.3.1 + # via google-auth-oauthlib +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +rsa==4.8 + # via google-auth +six==1.16.0 + # via + # astunparse + # google-auth + # google-pasta + # grpcio + # keras-preprocessing + # python-dateutil + # tensorflow + # tf2onnx +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +tensorboard==2.9.1 + # via tensorflow +tensorboard-data-server==0.6.1 + # via tensorboard +tensorboard-plugin-wit==1.8.1 + # via tensorboard +tensorflow==2.9.1 + # via flytekitplugins-onnxtensorflow +tensorflow-estimator==2.9.0 + # via tensorflow +tensorflow-io-gcs-filesystem==0.26.0 + # via tensorflow +termcolor==1.1.0 + # via tensorflow +text-unidecode==1.3 + # via python-slugify +tf2onnx==1.11.1 + # via flytekitplugins-onnxtensorflow +typing-extensions==4.3.0 + # via + # flytekit + # onnx + # tensorflow + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.9 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +werkzeug==2.1.2 + # via tensorboard +wheel==0.37.1 + # via + # astunparse + # flytekit + # tensorboard +wrapt==1.14.1 + # via + # deprecated + # flytekit + # tensorflow +zipp==3.8.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-onnx-tensorflow/setup.py b/plugins/flytekit-onnx-tensorflow/setup.py new file mode 100644 index 0000000000..d2865b083d --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "onnxtensorflow" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "tf2onnx>=1.9.3", "tensorflow>=2.7.0"] + +__version__ = "0.0.0+develop" + +setup( + name=f"flytekitplugins-{PLUGIN_NAME}", + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="ONNX TensorFlow Plugin for Flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-onnx-tensorflow/tests/__init__.py b/plugins/flytekit-onnx-tensorflow/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-onnx-tensorflow/tests/test_onnx_tf.py b/plugins/flytekit-onnx-tensorflow/tests/test_onnx_tf.py new file mode 100644 index 0000000000..259113828a --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/tests/test_onnx_tf.py @@ -0,0 +1,78 @@ +import urllib +from io import BytesIO +from typing import List, NamedTuple + +import numpy as np +import onnxruntime as rt +import tensorflow as tf +from flytekitplugins.onnxtensorflow import TensorFlow2ONNX, TensorFlow2ONNXConfig +from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input +from tensorflow.keras.preprocessing import image +from typing_extensions import Annotated + +from flytekit import task, workflow +from flytekit.types.file import ONNXFile + + +def test_tf_onnx(): + @task + def load_test_img() -> np.ndarray: + with urllib.request.urlopen( + "https://raw.githubusercontent.com/flyteorg/static-resources/main/flytekit/onnx/ade20k.jpg" + ) as url: + img = image.load_img( + BytesIO(url.read()), + target_size=(224, 224), + ) + + x = image.img_to_array(img) + x = np.expand_dims(x, axis=0) + x = preprocess_input(x) + return x + + TrainPredictOutput = NamedTuple( + "TrainPredictOutput", + [ + ("predictions", np.ndarray), + ( + "model", + Annotated[ + TensorFlow2ONNX, + TensorFlow2ONNXConfig( + input_signature=(tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),), opset=13 + ), + ], + ), + ], + ) + + @task + def train_and_predict(img: np.ndarray) -> TrainPredictOutput: + model = ResNet50(weights="imagenet") + + preds = model.predict(img) + return TrainPredictOutput(predictions=preds, model=TensorFlow2ONNX(model)) + + @task + def onnx_predict( + model: ONNXFile, + img: np.ndarray, + ) -> List[np.ndarray]: + m = rt.InferenceSession(model.download(), providers=["CPUExecutionProvider"]) + onnx_pred = m.run([n.name for n in m.get_outputs()], {"input": img}) + + return onnx_pred + + WorkflowOutput = NamedTuple( + "WorkflowOutput", [("keras_predictions", np.ndarray), ("onnx_predictions", List[np.ndarray])] + ) + + @workflow + def wf() -> WorkflowOutput: + img = load_test_img() + train_predict_output = train_and_predict(img=img) + onnx_preds = onnx_predict(model=train_predict_output.model, img=img) + return WorkflowOutput(keras_predictions=train_predict_output.predictions, onnx_predictions=onnx_preds) + + predictions = wf() + np.testing.assert_allclose(predictions.keras_predictions, predictions.onnx_predictions[0], rtol=1e-5) diff --git a/plugins/flytekit-pandera/requirements.txt b/plugins/flytekit-pandera/requirements.txt index 2e2474669b..82660ea15a 100644 --- a/plugins/flytekit-pandera/requirements.txt +++ b/plugins/flytekit-pandera/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,41 +44,44 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-pandera -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -89,8 +94,9 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pandera # pyarrow @@ -98,14 +104,12 @@ packaging==21.3 # via # marshmallow # pandera -pandas==1.4.2 +pandas==1.3.5 # via # flytekit # pandera -pandera==0.10.1 +pandera==0.9.0 # via flytekitplugins-pandera -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -123,9 +127,11 @@ pyarrow==6.0.1 # pandera pycparser==2.21 # via cffi -pydantic==1.9.0 +pydantic==1.9.1 # via pandera -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -135,7 +141,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -144,24 +150,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -170,10 +179,14 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # pandera # pydantic + # responses # typing-inspect typing-inspect==0.7.1 # via @@ -184,11 +197,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 422ae7c9f8..4b5cde2509 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -50,11 +50,11 @@ googleapis-common-protos==1.55.0 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests diff --git a/plugins/flytekit-papermill/requirements.txt b/plugins/flytekit-papermill/requirements.txt index a2cb568159..10c137989e 100644 --- a/plugins/flytekit-papermill/requirements.txt +++ b/plugins/flytekit-papermill/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,8 +10,6 @@ ansiwrap==0.8.4 # via papermill arrow==1.2.2 # via jinja2-time -asttokens==2.0.5 - # via stack-data attrs==21.4.0 # via jsonschema backcall==0.2.0 @@ -20,29 +18,31 @@ beautifulsoup4==4.11.1 # via nbconvert binaryornot==0.4.4 # via cookiecutter -bleach==5.0.0 +bleach==5.0.1 # via nbconvert -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit # papermill -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit debugpy==1.6.0 @@ -61,38 +61,42 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit entrypoints==0.4 # via # jupyter-client # nbconvert # papermill -executing==0.8.3 - # via stack-data fastjsonschema==2.15.3 # via nbformat -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-papermill -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring -ipykernel==6.13.0 +importlib-metadata==4.12.0 + # via + # click + # flytekit + # jsonschema + # keyring +importlib-resources==5.8.0 + # via jsonschema +ipykernel==6.15.0 # via flytekitplugins-papermill -ipython==8.2.0 +ipython==7.34.0 # via ipykernel jedi==0.18.1 # via ipython @@ -100,16 +104,16 @@ jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time # nbconvert jinja2-time==0.2.0 # via cookiecutter -jsonschema==4.4.0 +jsonschema==4.6.1 # via nbformat -jupyter-client==7.2.2 +jupyter-client==7.3.4 # via # ipykernel # nbclient @@ -120,13 +124,13 @@ jupyter-core==4.10.0 # nbformat jupyterlab-pygments==0.2.2 # via nbconvert -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via # jinja2 # nbconvert -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -139,19 +143,19 @@ matplotlib-inline==0.1.3 # via # ipykernel # ipython -mistune==0.8.4 +mistune==2.0.3 # via nbconvert mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -nbclient==0.6.0 +nbclient==0.6.6 # via # nbconvert # papermill -nbconvert==6.5.0 +nbconvert==7.0.0rc2 # via flytekitplugins-papermill -nbformat==5.3.0 +nbformat==5.4.0 # via # nbclient # nbconvert @@ -161,8 +165,9 @@ nest-asyncio==1.5.5 # ipykernel # jupyter-client # nbclient -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 @@ -170,7 +175,7 @@ packaging==21.3 # ipykernel # marshmallow # nbconvert -pandas==1.4.2 +pandas==1.3.5 # via flytekit pandocfilters==1.5.0 # via nbconvert @@ -182,9 +187,7 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -poyo==0.5.0 - # via cookiecutter -prompt-toolkit==3.0.29 +prompt-toolkit==3.0.30 # via ipython protobuf==3.20.1 # via @@ -195,23 +198,23 @@ protobuf==3.20.1 # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -psutil==5.9.0 +psutil==5.9.1 # via ipykernel ptyprocess==0.7.0 # via pexpect -pure-eval==0.2.2 - # via stack-data py==1.11.0 # via retry pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pygments==2.11.2 +pygments==2.12.0 # via # ipython # nbconvert -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging pyrsistent==0.18.1 # via jsonschema @@ -224,7 +227,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -234,38 +237,39 @@ pytz==2022.1 # pandas pyyaml==6.0 # via + # cookiecutter # flytekit # papermill -pyzmq==22.3.0 - # via jupyter-client -regex==2022.3.15 +pyzmq==23.2.0 + # via + # ipykernel + # jupyter-client +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # papermill # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # asttokens # bleach - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 # via flytekit soupsieve==2.3.2.post1 # via beautifulsoup4 -stack-data==0.2.0 - # via ipython statsd==3.3.0 # via flytekit tenacity==8.0.1 @@ -276,13 +280,13 @@ textwrap3==0.9.2 # via ansiwrap tinycss2==1.1.1 # via nbconvert -tornado==6.1 +tornado==6.2 # via # ipykernel # jupyter-client tqdm==4.64.0 # via papermill -traitlets==5.1.1 +traitlets==5.3.0 # via # ipykernel # ipython @@ -292,9 +296,13 @@ traitlets==5.1.1 # nbclient # nbconvert # nbformat -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # jsonschema + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -309,16 +317,18 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit zipp==3.8.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-polars/README.md b/plugins/flytekit-polars/README.md new file mode 100644 index 0000000000..011a447582 --- /dev/null +++ b/plugins/flytekit-polars/README.md @@ -0,0 +1,10 @@ +# Flytekit Polars Plugin +[Polars](https://github.com/pola-rs/polars) is a blazingly fast DataFrames library implemented in Rust using Apache Arrow Columnar Format as memory model. + +This plugin supports `polars.DataFrame` as a data type with [StructuredDataset](https://docs.flyte.org/projects/cookbook/en/latest/auto/core/type_system/structured_dataset.html). + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-polars +``` diff --git a/plugins/flytekit-polars/flytekitplugins/polars/__init__.py b/plugins/flytekit-polars/flytekitplugins/polars/__init__.py new file mode 100644 index 0000000000..85948bed73 --- /dev/null +++ b/plugins/flytekit-polars/flytekitplugins/polars/__init__.py @@ -0,0 +1,14 @@ +""" +.. currentmodule:: flytekitplugins.polars + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + PolarsDataFrameToParquetEncodingHandler + ParquetToPolarsDataFrameDecodingHandler +""" + +from .sd_transformers import ParquetToPolarsDataFrameDecodingHandler, PolarsDataFrameToParquetEncodingHandler diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py new file mode 100644 index 0000000000..1a667fe699 --- /dev/null +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -0,0 +1,73 @@ +import typing + +import polars as pl + +from flytekit import FlyteContext +from flytekit.models import literals +from flytekit.models.literals import StructuredDatasetMetadata +from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.structured_dataset import ( + GCS, + LOCAL, + PARQUET, + S3, + StructuredDataset, + StructuredDatasetDecoder, + StructuredDatasetEncoder, + StructuredDatasetTransformerEngine, +) + + +class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder): + def __init__(self, protocol: str): + super().__init__(pl.DataFrame, protocol, PARQUET) + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + df = typing.cast(pl.DataFrame, structured_dataset.dataframe) + + local_dir = ctx.file_access.get_random_local_directory() + local_path = f"{local_dir}/00000" + + # Polars 0.13.12 deprecated to_parquet in favor of write_parquet + if hasattr(df, "write_parquet"): + df.write_parquet(local_path) + else: + df.to_parquet(local_path) + remote_dir = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + ctx.file_access.upload_directory(local_dir, remote_dir) + return literals.StructuredDataset(uri=remote_dir, metadata=StructuredDatasetMetadata(structured_dataset_type)) + + +class ParquetToPolarsDataFrameDecodingHandler(StructuredDatasetDecoder): + def __init__(self, protocol: str): + super().__init__(pl.DataFrame, protocol, PARQUET) + + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> pl.DataFrame: + local_dir = ctx.file_access.get_random_local_directory() + ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) + path = f"{local_dir}/00000" + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] + return pl.read_parquet(path, columns=columns) + return pl.read_parquet(path) + + +for protocol in [LOCAL, S3]: + StructuredDatasetTransformerEngine.register( + PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=True + ) + StructuredDatasetTransformerEngine.register( + ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=True + ) +StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler(GCS), default_for_type=False) +StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler(GCS), default_for_type=False) diff --git a/plugins/flytekit-polars/requirements.in b/plugins/flytekit-polars/requirements.in new file mode 100644 index 0000000000..8425c5645f --- /dev/null +++ b/plugins/flytekit-polars/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-polars diff --git a/plugins/flytekit-polars/requirements.txt b/plugins/flytekit-polars/requirements.txt new file mode 100644 index 0000000000..a4fa78640e --- /dev/null +++ b/plugins/flytekit-polars/requirements.txt @@ -0,0 +1,198 @@ +# +# This file is autogenerated by pip-compile with python 3.7 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-polars + # via -r requirements.in +arrow==1.2.2 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +flyteidl==1.1.8 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-polars +googleapis-common-protos==1.56.3 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.47.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +keyring==23.6.0 + # via flytekit +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.21.6 + # via + # flytekit + # pandas + # polars + # pyarrow +packaging==21.3 + # via marshmallow +pandas==1.3.5 + # via flytekit +polars==0.13.51 + # via flytekitplugins-polars +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit +regex==2022.6.2 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # responses +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +secretstorage==3.3.2 + # via keyring +singledispatchmethod==1.0 + # via flytekit +six==1.16.0 + # via + # grpcio + # python-dateutil +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +typing-extensions==4.3.0 + # via + # arrow + # flytekit + # importlib-metadata + # polars + # responses + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.9 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +wheel==0.37.1 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.8.0 + # via importlib-metadata diff --git a/plugins/flytekit-polars/setup.py b/plugins/flytekit-polars/setup.py new file mode 100644 index 0000000000..ea3feb8582 --- /dev/null +++ b/plugins/flytekit-polars/setup.py @@ -0,0 +1,38 @@ +from setuptools import setup + +PLUGIN_NAME = "polars" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = [ + "flytekit>=1.1.0b0,<1.2.0", + "polars>=0.8.27", +] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="Robin Kahlow", + description="Polars plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-polars/tests/__init__.py b/plugins/flytekit-polars/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py new file mode 100644 index 0000000000..3c9c2613ae --- /dev/null +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -0,0 +1,64 @@ +import flytekitplugins.polars # noqa F401 +import polars as pl + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + +from flytekit import kwtypes, task, workflow +from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset + +subset_schema = Annotated[StructuredDataset, kwtypes(col2=str), PARQUET] +full_schema = Annotated[StructuredDataset, PARQUET] + + +def test_polars_workflow_subset(): + @task + def generate() -> subset_schema: + df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + return StructuredDataset(dataframe=df) + + @task + def consume(df: subset_schema) -> subset_schema: + df = df.open(pl.DataFrame).all() + + assert df["col2"][0] == "a" + assert df["col2"][1] == "b" + assert df["col2"][2] == "c" + + return StructuredDataset(dataframe=df) + + @workflow + def wf() -> subset_schema: + return consume(df=generate()) + + result = wf() + assert result is not None + + +def test_polars_workflow_full(): + @task + def generate() -> full_schema: + df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + return StructuredDataset(dataframe=df) + + @task + def consume(df: full_schema) -> full_schema: + df = df.open(pl.DataFrame).all() + + assert df["col1"][0] == 1 + assert df["col1"][1] == 3 + assert df["col1"][2] == 2 + assert df["col2"][0] == "a" + assert df["col2"][1] == "b" + assert df["col2"][2] == "c" + + return StructuredDataset(dataframe=df.sort("col1")) + + @workflow + def wf() -> full_schema: + return consume(df=generate()) + + result = wf() + assert result is not None diff --git a/plugins/flytekit-snowflake/requirements.txt b/plugins/flytekit-snowflake/requirements.txt index dff8846bb4..969a62a435 100644 --- a/plugins/flytekit-snowflake/requirements.txt +++ b/plugins/flytekit-snowflake/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,41 +44,44 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-snowflake -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -89,16 +94,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -114,7 +118,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -124,7 +130,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -133,24 +139,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -159,9 +168,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -170,11 +182,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-spark/requirements.txt b/plugins/flytekit-spark/requirements.txt index 023e6102f4..979feb79ef 100644 --- a/plugins/flytekit-spark/requirements.txt +++ b/plugins/flytekit-spark/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,41 +44,44 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-spark -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -89,16 +94,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -110,15 +114,17 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -py4j==0.10.9.3 +py4j==0.10.9.5 # via pyspark pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging -pyspark==3.2.1 +pyspark==3.3.0 # via flytekitplugins-spark python-dateutil==2.8.2 # via @@ -128,7 +134,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -137,24 +143,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -163,9 +172,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -174,11 +186,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-spark/tests/test_wf.py b/plugins/flytekit-spark/tests/test_wf.py index 8c42a6162f..b551d5cdff 100644 --- a/plugins/flytekit-spark/tests/test_wf.py +++ b/plugins/flytekit-spark/tests/test_wf.py @@ -4,7 +4,7 @@ import flytekit from flytekit import kwtypes, task, workflow -from flytekit.types.schema import FlyteSchema + try: from typing import Annotated diff --git a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py index 1150c8a941..88e4ef41c0 100644 --- a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py +++ b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import pandas as pd +from pandas.io.sql import pandasSQL_builder from sqlalchemy import create_engine # type: ignore from flytekit import current_context, kwtypes @@ -82,12 +83,14 @@ def __init__( query_template: str, task_config: SQLAlchemyConfig, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, - output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = FlyteSchema, container_image: str = SQLAlchemyDefaultImages.default_image(), **kwargs, ): - output_schema = output_schema_type if output_schema_type else FlyteSchema - outputs = kwtypes(results=output_schema) + if output_schema_type: + outputs = kwtypes(results=output_schema_type) + else: + outputs = None super().__init__( name=name, @@ -128,5 +131,9 @@ def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.A interpolated_query = SQLAlchemyTask.interpolate_query(tt.custom["query_template"], **kwargs) print(f"Interpolated query {interpolated_query}") with engine.begin() as connection: - df = pd.read_sql_query(interpolated_query, connection) + df = None + if tt.interface.outputs: + df = pd.read_sql_query(interpolated_query, connection) + else: + pandasSQL_builder(connection).execute(interpolated_query) return df diff --git a/plugins/flytekit-sqlalchemy/requirements.txt b/plugins/flytekit-sqlalchemy/requirements.txt index a858b8b789..6d4bbbb40d 100644 --- a/plugins/flytekit-sqlalchemy/requirements.txt +++ b/plugins/flytekit-sqlalchemy/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # pip-compile requirements.in @@ -10,26 +10,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,43 +44,47 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.14.1 # via flytekit -flyteidl==0.24.21 +flyteidl==1.1.8 # via flytekit -flytekit==1.0.0b3 +flytekit==1.1.0 # via flytekitplugins-sqlalchemy -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status greenlet==1.1.2 # via sqlalchemy -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring + # sqlalchemy jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -91,16 +97,15 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # pandas # pyarrow packaging==21.3 # via marshmallow -pandas==1.4.2 +pandas==1.3.5 # via flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -116,7 +121,9 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging python-dateutil==2.8.2 # via @@ -126,7 +133,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -135,37 +142,43 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.3.15 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit secretstorage==3.3.2 # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 # via flytekit -sqlalchemy==1.4.35 +sqlalchemy==1.4.39 # via flytekitplugins-sqlalchemy statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -174,11 +187,11 @@ urllib3==1.26.9 # flytekit # requests # responses -websocket-client==1.3.2 +websocket-client==1.3.3 # via docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index 167c8e796d..6d20027b2a 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -75,6 +75,13 @@ def test_workflow(sql_server): def my_task(df: pandas.DataFrame) -> int: return len(df[df.columns[0]]) + insert_task = SQLAlchemyTask( + "test", + query_template="insert into tracks values (5, 'flyte')", + output_schema_type=None, + task_config=SQLAlchemyConfig(uri=sql_server), + ) + sql_task = SQLAlchemyTask( "test", query_template="select * from tracks limit {{.inputs.limit}}", @@ -84,9 +91,10 @@ def my_task(df: pandas.DataFrame) -> int: @workflow def wf(limit: int) -> int: + insert_task() return my_task(df=sql_task(limit=limit)) - assert wf(limit=5) == 5 + assert wf(limit=10) == 6 def test_task_serialization(sql_server): diff --git a/plugins/flytekit-whylogs/README.md b/plugins/flytekit-whylogs/README.md new file mode 100644 index 0000000000..aeaff969e5 --- /dev/null +++ b/plugins/flytekit-whylogs/README.md @@ -0,0 +1,57 @@ +# Flytekit whylogs Plugin + +whylogs is an open source library for logging any kind of data. With whylogs, +you are able to generate summaries of datasets (called whylogs profiles) which +can be used to: + +- Create data constraints to know whether your data looks the way it should +- Quickly visualize key summary statistics about a dataset +- Track changes in a dataset over time + +```bash +pip install flytekitplugins-whylogs +``` + +To generate profiles, you can add a task like the following: + +```python +from whylogs.core import DatasetProfileView +import whylogs as ylog + +import pandas as pd + +@task +def profile(df: pd.DataFrame) -> DatasetProfileView: + result = ylog.log(df) # Various overloads for different common data types exist + profile = result.view() + return profile +``` + +>**NOTE:** You'll be passing around `DatasetProfileView` from tasks, not `DatasetProfile`. + +## Validating Data + +A common step in data pipelines is data validation. This can be done in +`whylogs` through the constraint feature. You'll be able to create failure tasks +if the data in the workflow doesn't conform to some configured constraints, like +min/max values on features, data types on features, etc. + +```python +@task +def validate_data(profile: DatasetProfileView): + column = profile.get_column("my_column") + print(column.to_summary_dict()) # To see available things you can validate against + builder = ConstraintsBuilder(profile) + numConstraint = MetricConstraint( + name='numbers between 0 and 4 only', + condition=lambda x: x.min > 0 and x.max < 4, + metric_selector=MetricsSelector(metric_name='distribution', column_name='my_column')) + builder.add_constraint(numConstraint) + constraint = builder.build() + valid = constraint.validate() + + if(not valid): + raise Exception("Invalid data found") +``` + +Check out our [constraints notebook](https://github.com/whylabs/whylogs/blob/1.0.x/python/examples/basic/MetricConstraints.ipynb) for more examples. diff --git a/plugins/flytekit-whylogs/flytekitplugins/whylogs/__init__.py b/plugins/flytekit-whylogs/flytekitplugins/whylogs/__init__.py new file mode 100644 index 0000000000..cbf7623796 --- /dev/null +++ b/plugins/flytekit-whylogs/flytekitplugins/whylogs/__init__.py @@ -0,0 +1,9 @@ +from .schema import WhylogsDatasetProfileTransformer +from .renderer import WhylogsConstraintsRenderer, WhylogsSummaryDriftRenderer + + +__all__ = [ + "WhylogsDatasetProfileTransformer", + "WhylogsConstraintsRenderer", + "WhylogsSummaryDriftRenderer" +] diff --git a/plugins/flytekit-whylogs/flytekitplugins/whylogs/renderer.py b/plugins/flytekit-whylogs/flytekitplugins/whylogs/renderer.py new file mode 100644 index 0000000000..69e9718c15 --- /dev/null +++ b/plugins/flytekit-whylogs/flytekitplugins/whylogs/renderer.py @@ -0,0 +1,41 @@ +import pandas as pd +import whylogs as why +from whylogs.viz import NotebookProfileVisualizer +from whylogs.core.constraints import Constraints + + +class WhylogsSummaryDriftRenderer: + """ + Creates a whylogs' Summary Drift report from two pandas DataFrames. One of them + is the reference and the other one is the target data, meaning that this is what + the report will compare it against. + """ + @staticmethod + def to_html( + reference_data: pd.DataFrame, + target_data: pd.DataFrame + ) -> str: + """ + This static method will profile the input data and then generate an HTML report + with the Summary Drift calculations for all of the dataframe's columns + + :param reference_data: The DataFrame that will be the reference for the drift report + :type: pandas.DataFrame + + :param target_data: The data to compare against and create the Summary Drift report + :type target_data: pandas.DataFrame + """ + + target_view = why.log(target_data).view() + reference_view = why.log(reference_data).view() + viz = NotebookProfileVisualizer() + viz.set_profiles(target_profile_view=target_view, reference_profile_view=reference_view) + return viz.summary_drift_report().data + + +class WhylogsConstraintsRenderer: + @staticmethod + def to_html(constraints: Constraints) -> str: + viz = NotebookProfileVisualizer() + report = viz.constraints_report(constraints=constraints) + return report.data diff --git a/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py b/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py new file mode 100644 index 0000000000..5a22ec968b --- /dev/null +++ b/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py @@ -0,0 +1,53 @@ +from typing import Type + +from whylogs.core import DatasetProfileView + +from flytekit import FlyteContext, BlobType +from flytekit.extend import T, TypeTransformer, TypeEngine +from flytekit.models.literals import Literal, Scalar, Blob, BlobMetadata +from flytekit.models.types import LiteralType + + +class WhylogsDatasetProfileTransformer(TypeTransformer[DatasetProfileView]): + """ + Transforms whylogs Dataset Profile Views to and from a Schema (typed/untyped) + """ + + _TYPE_INFO = BlobType(format="binary", dimensionality=BlobType.BlobDimensionality.SINGLE) + + def __init__(self): + super(WhylogsDatasetProfileTransformer, self).__init__("whylogs-profile-transformer", t=DatasetProfileView) + + def get_literal_type(self, t: Type[DatasetProfileView]) -> LiteralType: + return LiteralType(blob=self._TYPE_INFO) + + def to_literal( + self, + ctx: FlyteContext, + python_val: DatasetProfileView, + python_type: Type[DatasetProfileView], + expected: LiteralType, + ) -> Literal: + remote_path = ctx.file_access.get_random_remote_directory() + local_dir = ctx.file_access.get_random_local_path() + python_val.write(local_dir) + ctx.file_access.upload(local_dir, remote_path) + return Literal(scalar=Scalar(blob=Blob(uri=remote_path, metadata=BlobMetadata(type=self._TYPE_INFO)))) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DatasetProfileView]) -> T: + local_dir = ctx.file_access.get_random_local_path() + ctx.file_access.download(lv.scalar.blob.uri, local_dir) + return DatasetProfileView.read(local_dir) + + def to_html( + self, + ctx: FlyteContext, + python_val: DatasetProfileView, + expected_python_type: Type[DatasetProfileView] + ) -> str: + pandas_profile = str(python_val.to_pandas().to_html()) + header = str("

Profile View

\n") + return header + pandas_profile + + +TypeEngine.register(WhylogsDatasetProfileTransformer()) diff --git a/plugins/flytekit-whylogs/setup.py b/plugins/flytekit-whylogs/setup.py new file mode 100644 index 0000000000..d70118655c --- /dev/null +++ b/plugins/flytekit-whylogs/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "whylogs" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["whylogs", "whylogs[viz]"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="whylabs", + author_email="support@whylabs.ai", + description="Enable the use of whylogs profiles to be used in flyte tasks to get aggregate statistics about data.", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-whylogs", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-whylogs/tests/__init__.py b/plugins/flytekit-whylogs/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-whylogs/tests/test_schema.py b/plugins/flytekit-whylogs/tests/test_schema.py new file mode 100644 index 0000000000..294573bbed --- /dev/null +++ b/plugins/flytekit-whylogs/tests/test_schema.py @@ -0,0 +1,134 @@ +from typing import Optional + +import plotly.express as px +import numpy as np +import pandas as pd +import pyspark +import whylogs as why +from whylogs.core import DatasetProfileView +from whylogs.core.constraints import ConstraintsBuilder, MetricsSelector, MetricConstraint +from whylogs.api.pyspark.experimental.profiler import collect_dataset_profile_view + +import flytekit +from flytekitplugins.spark import Spark +from flytekit import task, workflow +from flytekitplugins.whylogs import WhylogsSummaryDriftRenderer, WhylogsConstraintsRenderer + + +@task +def read_data() -> pd.DataFrame: + return px.data.iris() + + +@task +def make_data(n_rows: int = 3) -> pd.DataFrame: + data = { + 'sepal_length': np.random.random_sample(n_rows), + 'sepal_width': np.random.random_sample(n_rows), + 'petal_length': np.random.random_sample(n_rows), + 'petal_width': np.random.random_sample(n_rows), + 'species': np.random.choice(['virginica', 'setosa', 'versicolor'], n_rows), + 'species_id': np.random.choice([1, 2, 3], n_rows) + } + return pd.DataFrame(data) + + +@task +def summary_drift_report(target_data: pd.DataFrame, reference_data: pd.DataFrame) -> None: + renderer = WhylogsSummaryDriftRenderer() + flytekit.Deck("summary drift", renderer.to_html(target_data=target_data, reference_data=reference_data)) + + +@task +def run_constraints(df: pd.DataFrame, + min_value: Optional[float] = 0.0, + max_value: Optional[float] = 4.0 + ) -> bool: + # This API constraints workflow is very flexible but a bit cumbersome. + # It will be simplified in the future, so for now we'll stick with injecting + # a Constraints object to the renderer. + profile_view = why.log(df).view() + builder = ConstraintsBuilder(profile_view) + num_constraint = MetricConstraint( + name=f'numbers between {min_value} and {max_value} only', + condition=lambda x: x.min > min_value and x.max < max_value, + metric_selector=MetricsSelector(metric_name='distribution', column_name='sepal_length')) + + builder.add_constraint(num_constraint) + constraints = builder.build() + + renderer = WhylogsConstraintsRenderer() + flytekit.Deck("constraints", renderer.to_html(constraints=constraints)) + + return constraints.validate() + + +@workflow +def whylogs_workflow(min_value: float, max_value: float) -> bool: + new_data = make_data(n_rows=10) + reference_data = read_data() + + summary_drift_report(target_data=new_data, reference_data=reference_data) + validated = run_constraints(df=new_data, min_value=min_value, max_value=max_value) + return validated + + +@task( + task_config=Spark( + spark_conf={ + "spark.driver.memory": "1000M", + "spark.executor.instances": "1", + "spark.driver.cores": "1", + } + ) +) +def make_spark_dataframe(n_rows: int = 10) -> pyspark.sql.DataFrame: + spark = flytekit.current_context().spark_session + data = { + 'sepal_length': np.random.random_sample(n_rows), + 'sepal_width': np.random.random_sample(n_rows), + 'petal_length': np.random.random_sample(n_rows), + 'petal_width': np.random.random_sample(n_rows), + 'species': np.random.choice(['virginica', 'setosa', 'versicolor'], n_rows), + 'species_id': np.random.choice([1, 2, 3], n_rows) + } + pandas_df = pd.DataFrame(data) + spark_df = spark.createDataFrame(pandas_df) + return spark_df + + +@task( + task_config=Spark( + spark_conf={ + "spark.driver.memory": "1000M", + "spark.executor.instances": "1", + "spark.driver.cores": "1", + } + ) +) +def profile_spark_df(df: pyspark.sql.DataFrame) -> DatasetProfileView: + profile_view = collect_dataset_profile_view(df) + return profile_view + + +@workflow +def spark_whylogs_wf() -> DatasetProfileView: + spark_df = make_spark_dataframe(n_rows=10) + profile_view = profile_spark_df(df=spark_df) + return profile_view + + +def test_workflow_with_whylogs(): + validated = whylogs_workflow(min_value=0.0, max_value=1.0) + assert validated is True + + +def test_constraints(): + validated = whylogs_workflow(min_value=-1.0, max_value=0.0) + assert validated is False + + +def test_pyspark_wf(): + profile_view = spark_whylogs_wf() + assert profile_view is not None + assert isinstance(profile_view, DatasetProfileView) diff --git a/plugins/setup.py b/plugins/setup.py index 8f3cc5c299..bc15144ee1 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -21,6 +21,9 @@ "flytekitplugins-kfpytorch": "flytekit-kf-pytorch", "flytekitplugins-kftensorflow": "flytekit-kf-tensorflow", "flytekitplugins-modin": "flytekit-modin", + "flytekitplugins-onnxscikitlearn": "flytekit-onnx-scikitlearn", + "flytekitplugins-onnxtensorflow": "flytekit-onnx-tensorflow", + "flytekitplugins-onnxpytorch": "flytekit-onnx-pytorch", "flytekitplugins-pandera": "flytekit-pandera", "flytekitplugins-papermill": "flytekit-papermill", "flytekitplugins-snowflake": "flytekit-snowflake", diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 4b24f292a8..9b15318529 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # make requirements-spark2.txt @@ -16,22 +16,28 @@ attrs==20.3.0 # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -chardet==4.0.0 +cffi==1.15.1 + # via cryptography +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -44,25 +50,33 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14 +docstring-parser==0.14.1 # via flytekit -flyteidl==1.0.0.post1 +flyteidl==1.1.8 # via flytekit -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring -jinja2==3.1.1 +importlib-metadata==4.12.0 + # via + # click + # flytekit + # jsonschema + # keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 # via # cookiecutter # jinja2-time @@ -70,11 +84,11 @@ jinja2-time==0.2.0 # via cookiecutter jsonschema==3.2.0 # via -r requirements.in -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -90,6 +104,7 @@ natsort==8.1.0 numpy==1.21.6 # via # -r requirements.in + # flytekit # pandas # pyarrow packaging==21.3 @@ -98,8 +113,6 @@ pandas==1.3.5 # via # -r requirements.in # flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -113,7 +126,11 @@ py==1.11.0 # via retry pyarrow==6.0.1 # via flytekit -pyparsing==3.0.8 +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging pyrsistent==0.18.1 # via jsonschema @@ -125,7 +142,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -136,22 +153,26 @@ pytz==2022.1 pyyaml==5.4.1 # via # -r requirements.in + # cookiecutter # flytekit -regex==2022.4.24 +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit +secretstorage==3.3.2 + # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # jsonschema # python-dateutil @@ -162,9 +183,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -179,7 +203,7 @@ websocket-client==0.59.0 # docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/requirements.txt b/requirements.txt index bc09bf4121..7bc335f046 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # make requirements.txt @@ -14,22 +14,28 @@ attrs==20.3.0 # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -chardet==4.0.0 +cffi==1.15.1 + # via cryptography +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,25 +48,33 @@ docker==5.0.3 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14 +docstring-parser==0.14.1 # via flytekit -flyteidl==1.0.0.post1 +flyteidl==1.1.8 # via flytekit -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring -jinja2==3.1.1 +importlib-metadata==4.12.0 + # via + # click + # flytekit + # jsonschema + # keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 # via # cookiecutter # jinja2-time @@ -68,11 +82,11 @@ jinja2-time==0.2.0 # via cookiecutter jsonschema==3.2.0 # via -r requirements.in -keyring==23.5.0 +keyring==23.6.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -88,6 +102,7 @@ natsort==8.1.0 numpy==1.21.6 # via # -r requirements.in + # flytekit # pandas # pyarrow packaging==21.3 @@ -96,8 +111,6 @@ pandas==1.3.5 # via # -r requirements.in # flytekit -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -111,7 +124,11 @@ py==1.11.0 # via retry pyarrow==6.0.1 # via flytekit -pyparsing==3.0.8 +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via packaging pyrsistent==0.18.1 # via jsonschema @@ -123,7 +140,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -134,22 +151,26 @@ pytz==2022.1 pyyaml==5.4.1 # via # -r requirements.in + # cookiecutter # flytekit -regex==2022.4.24 +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit +secretstorage==3.3.2 + # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # jsonschema # python-dateutil @@ -160,9 +181,12 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -177,7 +201,7 @@ websocket-client==0.59.0 # docker wheel==0.37.1 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/setup.py b/setup.py index c645639ce5..6a7b25ae0c 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ ] }, install_requires=[ - "flyteidl>=1.0.0,<1.1.0", + "flyteidl>=1.1.3,<1.2.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<7.0.0", @@ -48,6 +48,8 @@ "python-dateutil>=2.1", "grpcio>=1.43.0,!=1.45.0,<2.0", "grpcio-status>=1.43,!=1.45.0", + "importlib-metadata", + "pyopenssl", "protobuf>=3.6.1,<4", "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index 0d1f1a37be..c93c56435c 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # # make tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -8,24 +8,28 @@ arrow==1.2.2 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.6.15 # via requests -chardet==4.0.0 +cffi==1.15.1 + # via cryptography +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests -checksumdir==1.2.0 - # via flytekit -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.0.0 +cloudpickle==2.1.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.5 # via flytekit +cryptography==37.0.4 + # via + # pyopenssl + # secretstorage cycler==0.11.0 # via matplotlib dataclasses-json==0.5.7 @@ -36,31 +40,40 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit +docker==5.0.3 + # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14 +docstring-parser==0.14.1 # via flytekit -flyteidl==1.0.0.post1 +flyteidl==1.1.8 # via flytekit -flytekit==0.32.6 +flytekit==1.1.0 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -fonttools==4.33.2 +fonttools==4.33.3 # via matplotlib -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.56.3 # via # flyteidl # grpcio-status -grpcio==1.44.0 +grpcio==1.47.0 # via # flytekit # grpcio-status -grpcio-status==1.44.0 +grpcio-status==1.47.0 # via flytekit idna==3.3 # via requests -importlib-metadata==4.11.3 - # via keyring -jinja2==3.1.1 +importlib-metadata==4.12.0 + # via + # click + # flytekit + # keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 # via # cookiecutter # jinja2-time @@ -68,13 +81,13 @@ jinja2-time==0.2.0 # via cookiecutter joblib==1.1.0 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -keyring==23.5.0 +keyring==23.6.0 # via flytekit -kiwisolver==1.4.2 +kiwisolver==1.4.3 # via matplotlib markupsafe==2.1.1 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -83,30 +96,29 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.5.1 +matplotlib==3.5.2 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -numpy==1.22.3 +numpy==1.21.6 # via + # flytekit # matplotlib # opencv-python # pandas # pyarrow -opencv-python==4.5.5.64 +opencv-python==4.6.0.66 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in packaging==21.3 # via # marshmallow # matplotlib -pandas==1.4.2 +pandas==1.3.5 # via flytekit -pillow==9.1.0 +pillow==9.2.0 # via matplotlib -poyo==0.5.0 - # via cookiecutter protobuf==3.20.1 # via # flyteidl @@ -120,7 +132,11 @@ py==1.11.0 # via retry pyarrow==6.0.1 # via flytekit -pyparsing==3.0.8 +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 # via # matplotlib # packaging @@ -132,7 +148,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.2 # via flytekit -python-slugify==6.1.1 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -141,21 +157,27 @@ pytz==2022.1 # flytekit # pandas pyyaml==6.0 - # via flytekit -regex==2022.4.24 + # via + # cookiecutter + # flytekit +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter + # docker # flytekit # responses -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit +secretstorage==3.3.2 + # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil sortedcontainers==2.4.0 @@ -164,9 +186,13 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via + # arrow # flytekit + # importlib-metadata + # kiwisolver + # responses # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -175,11 +201,13 @@ urllib3==1.26.9 # flytekit # requests # responses +websocket-client==1.3.3 + # via docker wheel==0.37.1 # via # -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in # flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 0341e2cfc8..c8d1d93cbd 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -299,3 +299,22 @@ def test_normalize_inputs(): assert normalize_inputs("/raw", "/cp1", '""') == ("/raw", "/cp1", None) assert normalize_inputs("/raw", "/cp1", "") == ("/raw", "/cp1", None) assert normalize_inputs("/raw", "/cp1", "/prev") == ("/raw", "/cp1", "/prev") + + +@mock.patch("flytekit.bin.entrypoint.os") +def test_env_reading(mock_os): + mock_env = { + "FLYTE_INTERNAL_EXECUTION_PROJECT": "exec_proj", + "FLYTE_INTERNAL_EXECUTION_DOMAIN": "exec_dom", + "FLYTE_INTERNAL_EXECUTION_ID": "exec_name", + "FLYTE_INTERNAL_TASK_PROJECT": "task_proj", + "FLYTE_INTERNAL_TASK_DOMAIN": "task_dom", + "FLYTE_INTERNAL_TASK_NAME": "task_name", + "FLYTE_INTERNAL_TASK_VERSION": "task_ver", + } + mock_os.environ = mock_env + + with setup_execution("qwerty") as ctx: + assert ctx.execution_state.user_space_params.task_id.name == "task_name" + assert ctx.execution_state.user_space_params.task_id.version == "task_ver" + assert ctx.execution_state.user_space_params.execution_id.name == "exec_name" diff --git a/tests/flytekit/unit/cli/pyflyte/conftest.py b/tests/flytekit/unit/cli/pyflyte/conftest.py index 723fb4878b..6ce51bd4e1 100644 --- a/tests/flytekit/unit/cli/pyflyte/conftest.py +++ b/tests/flytekit/unit/cli/pyflyte/conftest.py @@ -18,7 +18,7 @@ def _fake_module_load(names): yield simple -@pytest.yield_fixture( +@pytest.fixture( scope="function", params=[ os.path.join( diff --git a/tests/flytekit/unit/cli/pyflyte/default_arguments/collection_wf.py b/tests/flytekit/unit/cli/pyflyte/default_arguments/collection_wf.py new file mode 100644 index 0000000000..20049a3cb2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/default_arguments/collection_wf.py @@ -0,0 +1,13 @@ +from typing import List + +from flytekit import task, workflow + + +@task +def t1(x: List[int]) -> int: + return sum(x) + + +@workflow +def wf(x: List[int] = [1, 2, 3]) -> int: + return t1(x=x) diff --git a/tests/flytekit/unit/cli/pyflyte/default_arguments/dataclass_wf.py b/tests/flytekit/unit/cli/pyflyte/default_arguments/dataclass_wf.py new file mode 100644 index 0000000000..d9ba207cf2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/default_arguments/dataclass_wf.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +from dataclasses_json import dataclass_json + +from flytekit import task, workflow + + +@dataclass_json +@dataclass +class DataclassA: + a: str + b: int + + +@task +def t(dca: DataclassA): + print(dca) + + +@workflow +def wf(dca: DataclassA = DataclassA("hello", 42)): + t(dca=dca) diff --git a/tests/flytekit/unit/cli/pyflyte/default_arguments/map_wf.py b/tests/flytekit/unit/cli/pyflyte/default_arguments/map_wf.py new file mode 100644 index 0000000000..79f955f4f2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/default_arguments/map_wf.py @@ -0,0 +1,13 @@ +from typing import Dict + +from flytekit import task, workflow + + +@task +def t1(x: Dict[str, int]) -> int: + return sum(x.values()) + + +@workflow +def wf(x: Dict[str, int] = {"a": 1, "b": 2, "c": 3}) -> int: + return t1(x=x) diff --git a/tests/flytekit/unit/cli/pyflyte/test_main.py b/tests/flytekit/unit/cli/pyflyte/test_main.py new file mode 100644 index 0000000000..598d4c0096 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_main.py @@ -0,0 +1,34 @@ +import mock + +from flytekit.clis.flyte_cli.main import _get_client +from flytekit.configuration import AuthType, PlatformConfig + + +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") +@mock.patch("click.get_current_context") +def test_get_client(click_current_ctx, mock_flyte_client): + # This class helps in the process of overriding __getitem__ in an object. + class FlexiMock(mock.MagicMock): + def __init__(self, *args, **kwargs): + super(FlexiMock, self).__init__(*args, **kwargs) + self.__getitem__ = lambda obj, item: getattr(obj, item) # set the mock for `[...]` + + def get(self, x, default=None): + return getattr(self, x, default) + + click_current_ctx = mock.MagicMock + obj_mock = FlexiMock( + config=PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS), + cacert=None, + ) + click_current_ctx.obj = obj_mock + + _ = _get_client(host="some-host:12345", insecure=False) + + expected_platform_config = PlatformConfig( + endpoint="some-host:12345", + insecure=False, + auth_mode=AuthType.EXTERNAL_PROCESS, + ) + + mock_flyte_client.assert_called_with(expected_platform_config) diff --git a/tests/flytekit/unit/cli/pyflyte/test_nested_wf/a/b/c/__init__.py b/tests/flytekit/unit/cli/pyflyte/test_nested_wf/a/b/c/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/cli/pyflyte/test_nested_wf/a/b/c/d/__init__.py b/tests/flytekit/unit/cli/pyflyte/test_nested_wf/a/b/c/d/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/cli/pyflyte/test_nested_wf/a/b/c/d/wf.py b/tests/flytekit/unit/cli/pyflyte/test_nested_wf/a/b/c/d/wf.py new file mode 100644 index 0000000000..8304880cda --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_nested_wf/a/b/c/d/wf.py @@ -0,0 +1,11 @@ +from flytekit import task, workflow + + +@task +def t(m: str) -> str: + return m + + +@workflow +def wf_id(m: str) -> str: + return t(m=m) diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index d31b760608..364b6b14d9 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -1,3 +1,6 @@ +import os +import shutil + import pytest from click.testing import CliRunner from flyteidl.admin.launch_plan_pb2 import LaunchPlan @@ -11,6 +14,22 @@ from flytekit.core import context_manager from flytekit.exceptions.user import FlyteValidationException +sample_file_contents = """ +from flytekit import task, workflow + +@task(cache=True, cache_version="1", retries=3) +def sum(x: int, y: int) -> int: + return x + y + +@task(cache=True, cache_version="1", retries=3) +def square(z: int) -> int: + return z*z + +@workflow +def my_workflow(x: int, y: int) -> int: + return sum(x=square(z=x), y=square(z=y)) +""" + @flytekit.task def foo(): @@ -44,6 +63,29 @@ def test_get_registrable_entities(): assert False, f"found unknown entity {type(e)}" +def test_package_with_fast_registration(): + runner = CliRunner() + with runner.isolated_filesystem(): + os.makedirs("core", exist_ok=True) + with open(os.path.join("core", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["--pkgs", "core", "package", "--image", "core:v1", "--fast"]) + assert result.exit_code == 0 + assert "Successfully serialized" in result.output + assert "Successfully packaged" in result.output + result = runner.invoke(pyflyte.main, ["--pkgs", "core", "package", "--image", "core:v1", "--fast"]) + assert result.exit_code == 2 + assert "flyte-package.tgz already exists, specify -f to override" in result.output + result = runner.invoke( + pyflyte.main, + ["--pkgs", "core", "package", "--image", "core:v1", "--fast", "--force"], + ) + assert result.exit_code == 0 + assert "deleting and re-creating it" in result.output + shutil.rmtree("core") + + def test_duplicate_registrable_entities(): @flytekit.task def t_1(): @@ -114,3 +156,11 @@ def test_package(): def test_pkgs(): pp = pyflyte.validate_package(None, None, ["a.b", "a.c,b.a", "cc.a"]) assert pp == ["a.b", "a.c", "b.a", "cc.a"] + + +def test_package_with_no_pkgs(): + runner = CliRunner() + with runner.isolated_filesystem(): + result = runner.invoke(pyflyte.main, ["package"]) + assert result.exit_code == 1 + assert "No packages to scan for flyte entities. Aborting!" in result.output diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py new file mode 100644 index 0000000000..d078851e1b --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -0,0 +1,65 @@ +import os +import shutil +import subprocess + +import mock +from click.testing import CliRunner + +from flytekit.clients.friendly import SynchronousFlyteClient +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.remote.remote import FlyteRemote + +sample_file_contents = """ +from flytekit import task, workflow + +@task(cache=True, cache_version="1", retries=3) +def sum(x: int, y: int) -> int: + return x + y + +@task(cache=True, cache_version="1", retries=3) +def square(z: int) -> int: + return z*z + +@workflow +def my_workflow(x: int, y: int) -> int: + return sum(x=square(z=x), y=square(z=y)) +""" + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote") +def test_saving_remote(mock_remote): + mock_context = mock.MagicMock + mock_context.obj = {} + get_and_save_remote_with_click_context(mock_context, "p", "d") + assert mock_context.obj["flyte_remote"] is not None + + +def test_register_with_no_package_or_module_argument(): + runner = CliRunner() + with runner.isolated_filesystem(): + result = runner.invoke(pyflyte.main, ["register"]) + assert result.exit_code == 1 + assert ( + "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed" + in result.output + ) + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_with_no_output_dir_passed(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core", exist_ok=True) + with open(os.path.join("core", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["register", "core"]) + assert "Output given as None, using a temporary directory at" in result.output + shutil.rmtree("core") diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 8fcbd3667c..ec35f5362d 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,11 +1,14 @@ import os +import pathlib +import pytest from click.testing import CliRunner from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.run import get_entities_in_file WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") +DIR_NAME = os.path.dirname(os.path.realpath(__file__)) def test_pyflyte_run_wf(): @@ -18,12 +21,11 @@ def test_pyflyte_run_wf(): def test_pyflyte_run_cli(): runner = CliRunner() - dir_name = os.path.dirname(os.path.realpath(__file__)) result = runner.invoke( pyflyte.main, [ "run", - os.path.join(dir_name, "workflow.py"), + WORKFLOW_FILE, "my_wf", "--a", "1", @@ -38,7 +40,7 @@ def test_pyflyte_run_cli(): "--f", '{"x":1.0, "y":2.0}', "--g", - os.path.join(dir_name, "testdata/df.parquet"), + os.path.join(DIR_NAME, "testdata/df.parquet"), "--i", "2020-05-01", "--j", @@ -46,9 +48,9 @@ def test_pyflyte_run_cli(): "--k", "RED", "--remote", - os.path.join(dir_name, "testdata"), + os.path.join(DIR_NAME, "testdata"), "--image", - os.path.join(dir_name, "testdata"), + os.path.join(DIR_NAME, "testdata"), "--h", ], catch_exceptions=False, @@ -57,8 +59,116 @@ def test_pyflyte_run_cli(): assert result.exit_code == 0 +@pytest.mark.parametrize( + "input", + ["1", os.path.join(DIR_NAME, "testdata/df.parquet"), '{"x":1.0, "y":2.0}', "2020-05-01", "RED"], +) +def test_union_type1(input): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "test_union1", + "--a", + input, + ], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "input", + [2.0, '{"i":1,"a":["h","e"]}', "[1, 2, 3]"], +) +def test_union_type2(input): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "test_union2", + "--a", + input, + ], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0 + + +def test_union_type_with_invalid_input(): + runner = CliRunner() + with pytest.raises(ValueError, match="Failed to convert python type typing.Union"): + runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "test_union2", + "--a", + "hello", + ], + catch_exceptions=False, + ) + + def test_get_entities_in_file(): e = get_entities_in_file(WORKFLOW_FILE) assert e.workflows == ["my_wf"] - assert e.tasks == ["get_subset_df", "print_all", "show_sd"] - assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd"] + assert e.tasks == ["get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"] + assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"] + + +@pytest.mark.parametrize( + "working_dir, wf_path", + [ + (pathlib.Path("test_nested_wf"), os.path.join("a", "b", "c", "d", "wf.py")), + (pathlib.Path("test_nested_wf", "a"), os.path.join("b", "c", "d", "wf.py")), + (pathlib.Path("test_nested_wf", "a", "b"), os.path.join("c", "d", "wf.py")), + (pathlib.Path("test_nested_wf", "a", "b", "c"), os.path.join("d", "wf.py")), + (pathlib.Path("test_nested_wf", "a", "b", "c", "d"), os.path.join("wf.py")), + ], +) +def test_nested_workflow(working_dir, wf_path, monkeypatch: pytest.MonkeyPatch): + runner = CliRunner() + base_path = os.path.dirname(os.path.realpath(__file__)) + # Change working directory without side-effects (i.e. just for this test) + monkeypatch.chdir(os.path.join(base_path, working_dir)) + result = runner.invoke( + pyflyte.main, + [ + "run", + wf_path, + "wf_id", + "--m", + "wow", + ], + catch_exceptions=False, + ) + assert result.stdout.strip() == "wow" + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "wf_path", + [("collection_wf.py"), ("map_wf.py"), ("dataclass_wf.py")], +) +def test_list_default_arguments(wf_path): + runner = CliRunner() + dir_name = os.path.dirname(os.path.realpath(__file__)) + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(dir_name, "default_arguments", wf_path), + "wf", + ], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index cf9f87d37e..18d29f648d 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -58,6 +58,16 @@ def print_all( print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}") +@task +def test_union1(a: typing.Union[int, FlyteFile, typing.Dict[str, float], datetime.datetime, Color]): + print(a) + + +@task +def test_union2(a: typing.Union[float, typing.List[int], MyDataclass]): + print(a) + + @workflow def my_wf( a: int, diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index 4899051169..6636c1e0c8 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -9,6 +9,7 @@ from flytekit.clients.raw import RawSynchronousFlyteClient, get_basic_authorization_header, get_token from flytekit.configuration import AuthType, PlatformConfig +from flytekit.configuration.internal import Credentials def get_admin_stub_mock() -> mock.MagicMock: @@ -49,17 +50,21 @@ def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_ad assert client.check_access_token("abc") +@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.set_access_token") +@mock.patch("flytekit.clients.raw.auth_service") @mock.patch("subprocess.run") -def test_refresh_credentials_from_command(mock_call_to_external_process): - command = ["command", "generating", "token"] +def test_refresh_credentials_from_command(mock_call_to_external_process, mock_admin_auth, mock_set_access_token): token = "token" + command = ["command", "generating", "token"] - mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token) + mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() + client = RawSynchronousFlyteClient(PlatformConfig(command=command)) - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS, command=command)) - cc._refresh_credentials_from_command() + mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token) + client._refresh_credentials_from_command() mock_call_to_external_process.assert_called_with(command, capture_output=True, text=True, check=True) + mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key) @mock.patch("flytekit.clients.raw.dataproxy_service") @@ -195,6 +200,14 @@ def test_basic_strings(mocked_method): @patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command") def test_refresh_command(mocked_method): - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS)) + cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNALCOMMAND)) + cc.refresh_credentials() + assert mocked_method.called + + +@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command") +def test_refresh_from_environment_variable(mocked_method, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv(Credentials.AUTH_MODE.legacy.get_env_name(), AuthType.EXTERNAL_PROCESS.name, prepend=False) + cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=None).auto(None)) cc.refresh_credentials() assert mocked_method.called diff --git a/tests/flytekit/unit/configuration/configs/sample.yml b/tests/flytekit/unit/configuration/configs/sample.yml new file mode 100644 index 0000000000..ae51ab2ece --- /dev/null +++ b/tests/flytekit/unit/configuration/configs/sample.yml @@ -0,0 +1,13 @@ +admin: + # For GRPC endpoints you might want to use dns:///flyte.myexample.com + endpoint: dns:///flyte.mycorp.io + authType: Pkce + insecure: true + clientId: propeller + scopes: + - all +storage: + connection: + access-key: minio + endpoint: http://localhost:30084 + secret-key: miniostorage diff --git a/tests/flytekit/unit/configuration/test_yaml_file.py b/tests/flytekit/unit/configuration/test_yaml_file.py index dcf65f2404..443c107120 100644 --- a/tests/flytekit/unit/configuration/test_yaml_file.py +++ b/tests/flytekit/unit/configuration/test_yaml_file.py @@ -23,7 +23,8 @@ def test_config_entry_file(): @mock.patch("flytekit.configuration.file.getenv") def test_config_entry_file_2(mock_get): # Test reading of the environment variable that flytectl asks users to set. - sample_yaml_file_name = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/sample.yaml") + # Can take both extensions + sample_yaml_file_name = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/sample.yml") mock_get.return_value = sample_yaml_file_name diff --git a/tests/flytekit/unit/core/configs/images.config b/tests/flytekit/unit/core/configs/images.config index ea6f31212e..b9fd3b5783 100644 --- a/tests/flytekit/unit/core/configs/images.config +++ b/tests/flytekit/unit/core/configs/images.config @@ -1,3 +1,4 @@ [images] xyz=docker.io/xyz:latest abc=docker.io/abc +xyz_123=docker.io/xyz_123:v1 diff --git a/tests/flytekit/unit/core/flyte_functools/__init__.py b/tests/flytekit/unit/core/flyte_functools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/core/functools/decorator_source.py b/tests/flytekit/unit/core/flyte_functools/decorator_source.py similarity index 100% rename from tests/flytekit/unit/core/functools/decorator_source.py rename to tests/flytekit/unit/core/flyte_functools/decorator_source.py diff --git a/tests/flytekit/unit/core/functools/decorator_usage.py b/tests/flytekit/unit/core/flyte_functools/decorator_usage.py similarity index 100% rename from tests/flytekit/unit/core/functools/decorator_usage.py rename to tests/flytekit/unit/core/flyte_functools/decorator_usage.py diff --git a/tests/flytekit/unit/core/functools/nested_function.py b/tests/flytekit/unit/core/flyte_functools/nested_function.py similarity index 100% rename from tests/flytekit/unit/core/functools/nested_function.py rename to tests/flytekit/unit/core/flyte_functools/nested_function.py diff --git a/tests/flytekit/unit/core/functools/nested_wrapped_function.py b/tests/flytekit/unit/core/flyte_functools/nested_wrapped_function.py similarity index 100% rename from tests/flytekit/unit/core/functools/nested_wrapped_function.py rename to tests/flytekit/unit/core/flyte_functools/nested_wrapped_function.py diff --git a/tests/flytekit/unit/core/functools/simple_decorator.py b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py similarity index 100% rename from tests/flytekit/unit/core/functools/simple_decorator.py rename to tests/flytekit/unit/core/flyte_functools/simple_decorator.py diff --git a/tests/flytekit/unit/core/functools/stacked_decorators.py b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py similarity index 100% rename from tests/flytekit/unit/core/functools/stacked_decorators.py rename to tests/flytekit/unit/core/flyte_functools/stacked_decorators.py diff --git a/tests/flytekit/unit/core/functools/test_decorator_location.py b/tests/flytekit/unit/core/flyte_functools/test_decorator_location.py similarity index 60% rename from tests/flytekit/unit/core/functools/test_decorator_location.py rename to tests/flytekit/unit/core/flyte_functools/test_decorator_location.py index d896d09592..7103eb14f2 100644 --- a/tests/flytekit/unit/core/functools/test_decorator_location.py +++ b/tests/flytekit/unit/core/flyte_functools/test_decorator_location.py @@ -4,14 +4,14 @@ def test_dont_use_wrapper_location(): - m = importlib.import_module("tests.flytekit.unit.core.functools.decorator_usage") + m = importlib.import_module("tests.flytekit.unit.core.flyte_functools.decorator_usage") get_data_task = getattr(m, "get_data") assert "decorator_source" not in get_data_task.name assert "decorator_usage" in get_data_task.name a, b, c, _ = extract_task_module(get_data_task) assert (a, b, c) == ( - "tests.flytekit.unit.core.functools.decorator_usage.get_data", - "tests.flytekit.unit.core.functools.decorator_usage", + "tests.flytekit.unit.core.flyte_functools.decorator_usage.get_data", + "tests.flytekit.unit.core.flyte_functools.decorator_usage", "get_data", ) diff --git a/tests/flytekit/unit/core/functools/test_decorators.py b/tests/flytekit/unit/core/flyte_functools/test_decorators.py similarity index 100% rename from tests/flytekit/unit/core/functools/test_decorators.py rename to tests/flytekit/unit/core/flyte_functools/test_decorators.py diff --git a/tests/flytekit/unit/core/functools/unwrapped_decorator.py b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py similarity index 100% rename from tests/flytekit/unit/core/functools/unwrapped_decorator.py rename to tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 37f5d10195..3963c77c8d 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -1,4 +1,6 @@ -import typing +from typing import Dict, List, NamedTuple, Optional, Union + +import pytest from flytekit.core import launch_plan from flytekit.core.task import task @@ -8,7 +10,7 @@ def test_wf1_with_subwf(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a) @@ -33,7 +35,7 @@ def my_wf(a: int, b: str) -> (int, str, str): def test_single_named_output_subwf(): - nt = typing.NamedTuple("SubWfOutput", sub_int=int) + nt = NamedTuple("SubWfOutput", sub_int=int) @task def t1(a: int) -> nt: @@ -68,7 +70,7 @@ def my_wf(a: int) -> int: def test_lp_default_handling(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a) @@ -130,7 +132,7 @@ def my_wf2(a: int, b: int = 42) -> (str, str, int, int): def test_wf1_with_lp_node(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a) @@ -169,3 +171,30 @@ def my_wf3(a: int = 42) -> (int, str, str, str): return x, y, u, v assert my_wf2() == (44, "world-44", "world-5", "world-7") + + +def test_optional_input(): + @task() + def t1(a: Optional[int] = None, b: Optional[List[int]] = None, c: Optional[Dict[str, int]] = None) -> Optional[int]: + ... + + @task() + def t2(a: Union[int, Optional[List[int]], None] = None) -> Union[int, Optional[List[int]], None]: + ... + + @workflow + def wf(a: Optional[int] = 1) -> Optional[int]: + t1() + return t2(a=a) + + assert wf() is None + + with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"): + + @task() + def t3(c: Optional[int] = 3) -> Optional[int]: + ... + + @workflow + def wf(): + return t3() diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 36a7153b8a..98af80638a 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -1,4 +1,5 @@ import os +from datetime import datetime import py import pytest @@ -11,7 +12,9 @@ SecretsConfig, SerializationSettings, ) -from flytekit.core.context_manager import FlyteContext, FlyteContextManager, SecretsManager +from flytekit.core import mock_stats +from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, SecretsManager +from flytekit.models.core import identifier as id_models class SampleTestClass(object): @@ -205,3 +208,18 @@ def test_serialization_settings_transport(): assert ss is not None assert ss == serialization_settings assert len(tp) == 376 + + +def test_exec_params(): + ep = ExecutionParameters( + execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), + task_id=id_models.Identifier(id_models.ResourceType.TASK, "local", "local", "local", "local"), + execution_date=datetime.utcnow(), + stats=mock_stats.MockStats(), + logging=None, + tmp_dir="/tmp", + raw_output_prefix="", + decks=[], + ) + + assert ep.task_id.name == "local" diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 2317bab1c2..62c45c346e 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -12,6 +12,7 @@ transform_variable_map, ) from flytekit.models.core import types as _core_types +from flytekit.models.literals import Void from flytekit.types.file import FlyteFile from flytekit.types.pickle import FlytePickle @@ -199,6 +200,20 @@ def z(a: Annotated[int, "some annotation"]) -> Annotated[int, "some annotation"] assert our_interface.inputs == {"a": Annotated[int, "some annotation"]} assert our_interface.outputs == {"o0": Annotated[int, "some annotation"]} + def z( + a: typing.Optional[int] = None, b: typing.Optional[str] = None, c: typing.Union[typing.List[int], None] = None + ) -> typing.Tuple[int, str]: + ... + + our_interface = transform_function_to_interface(z) + params = transform_inputs_to_parameters(ctx, our_interface) + assert not params.parameters["a"].required + assert params.parameters["a"].default.scalar.none_type == Void() + assert not params.parameters["b"].required + assert params.parameters["b"].default.scalar.none_type == Void() + assert not params.parameters["c"].required + assert params.parameters["c"].default.scalar.none_type == Void() + def test_parameters_with_docstring(): ctx = context_manager.FlyteContext.current_context() diff --git a/tests/flytekit/unit/core/test_numpy.py b/tests/flytekit/unit/core/test_numpy.py new file mode 100644 index 0000000000..2045a1229a --- /dev/null +++ b/tests/flytekit/unit/core/test_numpy.py @@ -0,0 +1,61 @@ +from collections import OrderedDict + +import numpy as np +from numpy.testing import assert_array_equal + +import flytekit +from flytekit import task +from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager +from flytekit.models.core.types import BlobType +from flytekit.models.literals import BlobMetadata +from flytekit.models.types import LiteralType +from flytekit.tools.translator import get_serializable +from flytekit.types.numpy import NumpyArrayTransformer + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +def test_get_literal_type(): + tf = NumpyArrayTransformer() + lt = tf.get_literal_type(np.ndarray) + assert lt == LiteralType( + blob=BlobType( + format=NumpyArrayTransformer.NUMPY_ARRAY_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE + ) + ) + + +def test_to_python_value_and_literal(): + ctx = context_manager.FlyteContext.current_context() + tf = NumpyArrayTransformer() + python_val = np.array([1, 2, 3]) + lt = tf.get_literal_type(np.ndarray) + + lv = tf.to_literal(ctx, python_val, type(python_val), lt) # type: ignore + assert lv.scalar.blob.metadata == BlobMetadata( + type=BlobType( + format=NumpyArrayTransformer.NUMPY_ARRAY_FORMAT, + dimensionality=BlobType.BlobDimensionality.SINGLE, + ) + ) + assert lv.scalar.blob.uri is not None + + output = tf.to_python_value(ctx, lv, np.ndarray) + assert_array_equal(output, python_val) + + +def test_example(): + @task + def t1(array: np.ndarray) -> np.ndarray: + return array.flatten() + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.interface.outputs["o0"].type.blob.format is NumpyArrayTransformer.NUMPY_ARRAY_FORMAT diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py new file mode 100644 index 0000000000..6354ca3247 --- /dev/null +++ b/tests/flytekit/unit/core/test_promise.py @@ -0,0 +1,83 @@ +import typing +from dataclasses import dataclass + +import pytest +from dataclasses_json import dataclass_json + +from flytekit import task +from flytekit.core import context_manager +from flytekit.core.context_manager import CompilationState +from flytekit.core.promise import VoidPromise, create_and_link_node, translate_inputs_to_literals +from flytekit.exceptions.user import FlyteAssertion + + +def test_create_and_link_node(): + @task + def t1(a: typing.Union[int, typing.List[int]]) -> typing.Union[int, typing.List[int]]: + return a + + with pytest.raises(FlyteAssertion, match="Cannot create node when not compiling..."): + ctx = context_manager.FlyteContext.current_context() + create_and_link_node(ctx, t1, a=3) + + ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix="")) + p = create_and_link_node(ctx, t1, a=3) + assert p.ref.node_id == "n0" + assert p.ref.var == "o0" + assert len(p.ref.node.bindings) == 1 + + @task + def t2(a: typing.Optional[int] = None) -> typing.Union[int]: + return a + + p = create_and_link_node(ctx, t2) + assert p.ref.var == "o0" + assert len(p.ref.node.bindings) == 0 + + +@pytest.mark.parametrize( + "input", + [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]], +) +def test_translate_inputs_to_literals(input): + @dataclass_json + @dataclass + class MyDataclass(object): + i: int + a: typing.List[str] + + @task + def t1(a: typing.Union[float, typing.List[int], MyDataclass]): + print(a) + + ctx = context_manager.FlyteContext.current_context() + translate_inputs_to_literals(ctx, {"a": input}, t1.interface.inputs, t1.python_interface.inputs) + + +def test_translate_inputs_to_literals_with_wrong_types(): + ctx = context_manager.FlyteContext.current_context() + with pytest.raises(TypeError, match="Not a map type union_type"): + + @task + def t1(a: typing.Union[float, typing.List[int]]): + print(a) + + translate_inputs_to_literals(ctx, {"a": {"a": 3}}, t1.interface.inputs, t1.python_interface.inputs) + + with pytest.raises(TypeError, match="Not a collection type union_type"): + + @task + def t1(a: typing.Union[float, typing.Dict[str, int]]): + print(a) + + translate_inputs_to_literals(ctx, {"a": [1, 2, 3]}, t1.interface.inputs, t1.python_interface.inputs) + + with pytest.raises( + AssertionError, match="Outputs of a non-output producing task n0 cannot be passed to another task" + ): + + @task + def t1(a: typing.Union[float, typing.Dict[str, int]]): + print(a) + + translate_inputs_to_literals(ctx, {"a": VoidPromise("n0")}, t1.interface.inputs, t1.python_interface.inputs) diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 2ddf4670c1..108b42ebf7 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -19,6 +19,11 @@ def default_serialization_settings(default_image_config): ) +@pytest.fixture +def minimal_serialization_settings(default_image_config): + return SerializationSettings(project="p", domain="d", version="v", image_config=default_image_config) + + def test_image_name_interpolation(default_image_config): img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special" img = get_registerable_container_image(img=img_to_interpolate, cfg=default_image_config) @@ -31,6 +36,7 @@ def execute(self, **kwargs) -> Any: task = DummyAutoContainerTask(name="x", task_config=None, task_type="t") +task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t") def test_default_command(default_serialization_settings): @@ -55,3 +61,21 @@ def test_default_command(default_serialization_settings): "task-name", "task", ] + + +def test_get_container(default_serialization_settings): + c = task.get_container(default_serialization_settings) + assert c.image == "docker.io/xyz:some-git-hash" + assert c.env == {"FOO": "bar"} + + +def test_get_container_with_task_envvars(default_serialization_settings): + c = task_with_env_vars.get_container(default_serialization_settings) + assert c.image == "docker.io/xyz:some-git-hash" + assert c.env == {"FOO": "bar", "HAM": "spam"} + + +def test_get_container_without_serialization_settings_envvars(minimal_serialization_settings): + c = task_with_env_vars.get_container(minimal_serialization_settings) + assert c.image == "docker.io/xyz:some-git-hash" + assert c.env == {"HAM": "spam"} diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index efcd6cd823..a69ec6665b 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -250,6 +250,10 @@ def t4(): def t5(a: int) -> int: return a + @task(container_image="{{.image.xyz_123.fqn}}:{{.image.xyz_123.version}}") + def t6(a: int) -> int: + return a + os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version" imgs = ImageConfig.auto( config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config") @@ -274,6 +278,9 @@ def t5(a: int) -> int: t5_spec = get_serializable(OrderedDict(), rs, t5) assert t5_spec.template.container.image == "docker.io/org/myimage:latest" + t5_spec = get_serializable(OrderedDict(), rs, t6) + assert t5_spec.template.container.image == "docker.io/xyz_123:v1" + def test_serialization_command1(): @task diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index d99efe9342..df8e14d7cb 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -539,6 +539,8 @@ class TestInnerFileStruct(object): a: JPEGImageFile b: typing.List[FlyteFile] c: typing.Dict[str, FlyteFile] + d: typing.List[FlyteFile] + e: typing.Dict[str, FlyteFile] @dataclass_json @dataclass @@ -546,8 +548,14 @@ class TestFileStruct(object): a: FlyteFile b: TestInnerFileStruct - f = FlyteFile("s3://tmp/file") - o = TestFileStruct(a=f, b=TestInnerFileStruct(a=JPEGImageFile("s3://tmp/file.jpeg"), b=[f], c={"hello": f})) + remote_path = "s3://tmp/file" + f1 = FlyteFile(remote_path) + f2 = FlyteFile("/tmp/file") + f2._remote_source = remote_path + o = TestFileStruct( + a=f1, + b=TestInnerFileStruct(a=JPEGImageFile("s3://tmp/file.jpeg"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2}), + ) ctx = FlyteContext.current_context() tf = DataclassTransformer() @@ -563,6 +571,10 @@ class TestFileStruct(object): assert o.b.a.path == ot.b.a.remote_source assert o.b.b[0].path == ot.b.b[0].remote_source assert o.b.c["hello"].path == ot.b.c["hello"].remote_source + assert ot.b.d[0].remote_source == remote_path + assert not ctx.file_access.is_remote(ot.b.d[0].path) + assert ot.b.e["hello"].remote_source == remote_path + assert not ctx.file_access.is_remote(ot.b.e["hello"].path) def test_flyte_directory_in_dataclass(): @@ -572,6 +584,8 @@ class TestInnerFileStruct(object): a: TensorboardLogs b: typing.List[FlyteDirectory] c: typing.Dict[str, FlyteDirectory] + d: typing.List[FlyteDirectory] + e: typing.Dict[str, FlyteDirectory] @dataclass_json @dataclass @@ -579,9 +593,15 @@ class TestFileStruct(object): a: FlyteDirectory b: TestInnerFileStruct + remote_path = "s3://tmp/file" tempdir = tempfile.mkdtemp(prefix="flyte-") - f = FlyteDirectory(tempdir) - o = TestFileStruct(a=f, b=TestInnerFileStruct(a=TensorboardLogs("s3://tensorboard"), b=[f], c={"hello": f})) + f1 = FlyteDirectory(tempdir) + f1._remote_source = remote_path + f2 = FlyteDirectory(remote_path) + o = TestFileStruct( + a=f1, + b=TestInnerFileStruct(a=TensorboardLogs("s3://tensorboard"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2}), + ) ctx = FlyteContext.current_context() tf = DataclassTransformer() @@ -594,10 +614,15 @@ class TestFileStruct(object): assert ot.b.b[0]._downloader is not noop assert ot.b.c["hello"]._downloader is not noop - assert o.a.path == ot.a.path + assert o.a.remote_directory == ot.a.remote_directory + assert not ctx.file_access.is_remote(ot.a.path) assert o.b.a.path == ot.b.a.remote_source - assert o.b.b[0].path == ot.b.b[0].path - assert o.b.c["hello"].path == ot.b.c["hello"].path + assert o.b.b[0].remote_directory == ot.b.b[0].remote_directory + assert not ctx.file_access.is_remote(ot.b.b[0].path) + assert o.b.c["hello"].remote_directory == ot.b.c["hello"].remote_directory + assert not ctx.file_access.is_remote(ot.b.c["hello"].path) + assert o.b.d[0].path == ot.b.d[0].remote_source + assert o.b.e["hello"].path == ot.b.e["hello"].remote_source def test_structured_dataset_in_dataclass(): @@ -607,6 +632,8 @@ def test_structured_dataset_in_dataclass(): @dataclass class InnerDatasetStruct(object): a: StructuredDataset + b: typing.List[StructuredDataset] + c: typing.Dict[str, StructuredDataset] @dataclass_json @dataclass @@ -615,7 +642,7 @@ class DatasetStruct(object): b: InnerDatasetStruct sd = StructuredDataset(dataframe=df, file_format="parquet") - o = DatasetStruct(a=sd, b=InnerDatasetStruct(a=sd)) + o = DatasetStruct(a=sd, b=InnerDatasetStruct(a=sd, b=[sd], c={"hello": sd})) ctx = FlyteContext.current_context() tf = DataclassTransformer() @@ -625,8 +652,12 @@ class DatasetStruct(object): assert_frame_equal(df, ot.a.open(pd.DataFrame).all()) assert_frame_equal(df, ot.b.a.open(pd.DataFrame).all()) + assert_frame_equal(df, ot.b.b[0].open(pd.DataFrame).all()) + assert_frame_equal(df, ot.b.c["hello"].open(pd.DataFrame).all()) assert "parquet" == ot.a.file_format assert "parquet" == ot.b.a.file_format + assert "parquet" == ot.b.b[0].file_format + assert "parquet" == ot.b.c["hello"].file_format # Enums should have string values diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 077b19f8b5..abdf69f5b0 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -16,7 +16,7 @@ from dataclasses_json import dataclass_json from google.protobuf.struct_pb2 import Struct from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, get_origin import flytekit import flytekit.configuration @@ -89,6 +89,15 @@ def my_task(a: int) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int") assert context_manager.FlyteContextManager.size() == 1 +def test_annotated_namedtuple_output(): + @task + def my_task(a: int) -> typing.NamedTuple("OutputA", a=Annotated[int, "metadata-a"]): + return a + 2 + + assert my_task(a=9) == (11,) + assert get_origin(my_task.python_interface.outputs["a"]) is Annotated + + def test_simple_input_no_output(): @task def my_task(a: int): diff --git a/tests/flytekit/unit/deck/test_deck.py b/tests/flytekit/unit/deck/test_deck.py index 59218c5b43..3db311653c 100644 --- a/tests/flytekit/unit/deck/test_deck.py +++ b/tests/flytekit/unit/deck/test_deck.py @@ -1,6 +1,8 @@ import pandas as pd +from mock import mock -from flytekit import Deck, FlyteContextManager +import flytekit +from flytekit import Deck, FlyteContextManager, task from flytekit.deck import TopFrameRenderer from flytekit.deck.deck import _output_deck @@ -18,3 +20,28 @@ def test_deck(): assert len(ctx.user_space_params.decks) == 2 _output_deck("test_task", ctx.user_space_params) + + @task() + def t1(a: int) -> str: + return str(a) + + t1(a=3) + assert len(ctx.user_space_params.decks) == 2 # input, output decks + + +@mock.patch("flytekit.deck.deck._ipython_check") +def test_deck_in_jupyter(mock_ipython_check): + mock_ipython_check.return_value = True + + ctx = FlyteContextManager.current_context() + ctx.user_space_params._decks = [ctx.user_space_params.default_deck] + _output_deck("test_task", ctx.user_space_params) + + @task() + def t1(a: int) -> str: + return str(a) + + with flytekit.new_context() as ctx: + t1(a=3) + deck = ctx.get_deck() + assert deck is not None diff --git a/tests/flytekit/unit/extras/pytorch/__init__.py b/tests/flytekit/unit/extras/pytorch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/pytorch/test_checkpoint.py b/tests/flytekit/unit/extras/pytorch/test_checkpoint.py new file mode 100644 index 0000000000..49ad083285 --- /dev/null +++ b/tests/flytekit/unit/extras/pytorch/test_checkpoint.py @@ -0,0 +1,105 @@ +from dataclasses import asdict, dataclass +from typing import NamedTuple + +import pytest +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from dataclasses_json import dataclass_json + +from flytekit import task, workflow +from flytekit.core.type_engine import TypeTransformerFailedError +from flytekit.extras.pytorch import PyTorchCheckpoint + + +@dataclass_json +@dataclass +class Hyperparameters: + epochs: int + loss: float + + +class TupleHyperparameters(NamedTuple): + epochs: int + loss: float + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +@task +def generate_model_dict(hyperparameters: Hyperparameters) -> PyTorchCheckpoint: + bn = Net() + optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9) + return PyTorchCheckpoint(module=bn, hyperparameters=asdict(hyperparameters), optimizer=optimizer) + + +@task +def generate_model_tuple() -> PyTorchCheckpoint: + bn = Net() + optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9) + return PyTorchCheckpoint(module=bn, hyperparameters=TupleHyperparameters(epochs=5, loss=0.4), optimizer=optimizer) + + +@task +def generate_model_dataclass(hyperparameters: Hyperparameters) -> PyTorchCheckpoint: + bn = Net() + optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9) + return PyTorchCheckpoint(module=bn, hyperparameters=hyperparameters, optimizer=optimizer) + + +@task +def generate_model_only_module() -> PyTorchCheckpoint: + bn = Net() + return PyTorchCheckpoint(module=bn) + + +@task +def empty_checkpoint(): + with pytest.raises(TypeTransformerFailedError): + return PyTorchCheckpoint() + + +@task +def t1(checkpoint: PyTorchCheckpoint): + new_bn = Net() + new_bn.load_state_dict(checkpoint["module_state_dict"]) + optimizer = optim.SGD(new_bn.parameters(), lr=0.001, momentum=0.9) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + assert checkpoint["epochs"] == 5 + assert checkpoint["loss"] == 0.4 + + +@workflow +def wf(): + checkpoint_dict = generate_model_dict(hyperparameters=Hyperparameters(epochs=5, loss=0.4)) + checkpoint_tuple = generate_model_tuple() + checkpoint_dataclass = generate_model_dataclass(hyperparameters=Hyperparameters(epochs=5, loss=0.4)) + t1(checkpoint=checkpoint_dict) + t1(checkpoint=checkpoint_tuple) + t1(checkpoint=checkpoint_dataclass) + generate_model_only_module() + empty_checkpoint() + + +@workflow +def test_wf(): + wf() diff --git a/tests/flytekit/unit/extras/pytorch/test_native.py b/tests/flytekit/unit/extras/pytorch/test_native.py new file mode 100644 index 0000000000..9d44ed1c1f --- /dev/null +++ b/tests/flytekit/unit/extras/pytorch/test_native.py @@ -0,0 +1,73 @@ +import torch + +from flytekit import task, workflow + + +@task +def generate_tensor_1d() -> torch.Tensor: + return torch.zeros(5, dtype=torch.int32) + + +@task +def generate_tensor_2d() -> torch.Tensor: + return torch.tensor([[1.0, -1.0, 2], [1.0, -1.0, 9], [0, 7.0, 3]]) + + +@task +def generate_module() -> torch.nn.Module: + bn = torch.nn.BatchNorm1d(3, track_running_stats=True) + return bn + + +class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.l0 = torch.nn.Linear(4, 2) + self.l1 = torch.nn.Linear(2, 1) + + def forward(self, input): + out0 = self.l0(input) + out0_relu = torch.nn.functional.relu(out0) + return self.l1(out0_relu) + + +@task +def generate_model() -> torch.nn.Module: + return MyModel() + + +@task +def t1(tensor: torch.Tensor) -> torch.Tensor: + assert tensor.dtype == torch.int32 + tensor[0] = 1 + return tensor + + +@task +def t2(tensor: torch.Tensor) -> torch.Tensor: + # convert 2D to 3D + tensor.unsqueeze_(-1) + return tensor.expand(3, 3, 2) + + +@task +def t3(model: torch.nn.Module) -> torch.Tensor: + return model.weight + + +@task +def t4(model: torch.nn.Module) -> torch.nn.Module: + return model.l1 + + +@workflow +def wf(): + t1(tensor=generate_tensor_1d()) + t2(tensor=generate_tensor_2d()) + t3(model=generate_module()) + t4(model=MyModel()) + + +@workflow +def test_wf(): + wf() diff --git a/tests/flytekit/unit/extras/pytorch/test_transformations.py b/tests/flytekit/unit/extras/pytorch/test_transformations.py new file mode 100644 index 0000000000..1a3a83ab93 --- /dev/null +++ b/tests/flytekit/unit/extras/pytorch/test_transformations.py @@ -0,0 +1,130 @@ +from collections import OrderedDict + +import pytest +import torch + +import flytekit +from flytekit import task +from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager +from flytekit.extras.pytorch import ( + PyTorchCheckpoint, + PyTorchCheckpointTransformer, + PyTorchModuleTransformer, + PyTorchTensorTransformer, +) +from flytekit.models.core.types import BlobType +from flytekit.models.literals import BlobMetadata +from flytekit.models.types import LiteralType +from flytekit.tools.translator import get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +@pytest.mark.parametrize( + "transformer,python_type,format", + [ + (PyTorchTensorTransformer(), torch.Tensor, PyTorchTensorTransformer.PYTORCH_FORMAT), + (PyTorchModuleTransformer(), torch.nn.Module, PyTorchModuleTransformer.PYTORCH_FORMAT), + (PyTorchCheckpointTransformer(), PyTorchCheckpoint, PyTorchCheckpointTransformer.PYTORCH_CHECKPOINT_FORMAT), + ], +) +def test_get_literal_type(transformer, python_type, format): + tf = transformer + lt = tf.get_literal_type(python_type) + assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)) + + +@pytest.mark.parametrize( + "transformer,python_type,format,python_val", + [ + ( + PyTorchTensorTransformer(), + torch.Tensor, + PyTorchTensorTransformer.PYTORCH_FORMAT, + torch.tensor([[1, 2], [3, 4]]), + ), + ( + PyTorchModuleTransformer(), + torch.nn.Module, + PyTorchModuleTransformer.PYTORCH_FORMAT, + torch.nn.Linear(2, 2), + ), + ( + PyTorchCheckpointTransformer(), + PyTorchCheckpoint, + PyTorchCheckpointTransformer.PYTORCH_CHECKPOINT_FORMAT, + PyTorchCheckpoint( + module=torch.nn.Linear(2, 2), + hyperparameters={"epochs": 10, "batch_size": 32}, + optimizer=torch.optim.Adam(torch.nn.Linear(2, 2).parameters()), + ), + ), + ], +) +def test_to_python_value_and_literal(transformer, python_type, format, python_val): + ctx = context_manager.FlyteContext.current_context() + tf = transformer + python_val = python_val + lt = tf.get_literal_type(python_type) + + lv = tf.to_literal(ctx, python_val, type(python_val), lt) # type: ignore + assert lv.scalar.blob.metadata == BlobMetadata( + type=BlobType( + format=format, + dimensionality=BlobType.BlobDimensionality.SINGLE, + ) + ) + assert lv.scalar.blob.uri is not None + + output = tf.to_python_value(ctx, lv, python_type) + if isinstance(python_val, torch.Tensor): + assert torch.equal(output, python_val) + elif isinstance(python_val, torch.nn.Module): + for p1, p2 in zip(output.parameters(), python_val.parameters()): + if p1.data.ne(p2.data).sum() > 0: + assert False + assert True + else: + assert isinstance(output, dict) + + +def test_example_tensor(): + @task + def t1(array: torch.Tensor) -> torch.Tensor: + return torch.flatten(array) + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.interface.outputs["o0"].type.blob.format is PyTorchTensorTransformer.PYTORCH_FORMAT + + +def test_example_module(): + @task + def t1() -> torch.nn.Module: + return torch.nn.BatchNorm1d(3, track_running_stats=True) + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.interface.outputs["o0"].type.blob.format is PyTorchModuleTransformer.PYTORCH_FORMAT + + +def test_example_checkpoint(): + @task + def t1() -> PyTorchCheckpoint: + return PyTorchCheckpoint( + module=torch.nn.Linear(2, 2), + hyperparameters={"epochs": 10, "batch_size": 32}, + optimizer=torch.optim.Adam(torch.nn.Linear(2, 2).parameters()), + ) + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert ( + task_spec.template.interface.outputs["o0"].type.blob.format + is PyTorchCheckpointTransformer.PYTORCH_CHECKPOINT_FORMAT + ) diff --git a/tests/flytekit/unit/remote/responses/admin.task_pb2.Task.pb b/tests/flytekit/unit/remote/responses/admin.task_pb2.Task.pb new file mode 100644 index 0000000000..35acecdeb1 Binary files /dev/null and b/tests/flytekit/unit/remote/responses/admin.task_pb2.Task.pb differ diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index b2d1644d97..00258cb85e 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -1,7 +1,12 @@ +import os +import pathlib +import tempfile + import pytest from mock import MagicMock, patch -from flytekit.configuration import Config +import flytekit.configuration +from flytekit.configuration import Config, DefaultImages, ImageConfig from flytekit.exceptions import user as user_exceptions from flytekit.models import common as common_models from flytekit.models import security @@ -150,3 +155,59 @@ def test_passing_of_kwargs(mock_client): FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args) assert mock_client.called assert mock_client.call_args[1] == additional_args + + +@patch("flytekit.remote.remote.SynchronousFlyteClient") +def test_more_stuff(mock_client): + r = FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain") + + # Can't upload a folder + with pytest.raises(ValueError): + with tempfile.TemporaryDirectory() as tmp_dir: + r._upload_file(pathlib.Path(tmp_dir)) + + # Test that this copies the file. + with tempfile.TemporaryDirectory() as tmp_dir: + mm = MagicMock() + mm.signed_url = os.path.join(tmp_dir, "tmp_file") + mock_client.return_value.get_upload_signed_url.return_value = mm + + r._upload_file(pathlib.Path(__file__)) + + serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig.auto(img_name=DefaultImages.default_image()), + ) + + # gives a thing + computed_v = r._version_from_hash(b"", serialization_settings) + assert len(computed_v) > 0 + + # gives the same thing + computed_v2 = r._version_from_hash(b"", serialization_settings) + assert computed_v2 == computed_v2 + + # should give a different thing + computed_v3 = r._version_from_hash(b"", serialization_settings, "hi") + assert computed_v2 != computed_v3 + + +@patch("flytekit.remote.remote.SynchronousFlyteClient") +def test_generate_http_domain_sandbox_rewrite(mock_client): + _, temp_filename = tempfile.mkstemp(suffix=".yaml") + with open(temp_filename, "w") as f: + # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox. + flytectl_config_file = """admin: + endpoint: localhost:30081 + authType: Pkce + insecure: true + """ + f.write(flytectl_config_file) + + remote = FlyteRemote( + config=Config.auto(config_file=temp_filename), default_project="project", default_domain="domain" + ) + assert remote.generate_http_domain() == "http://localhost:30080" diff --git a/tests/flytekit/unit/remote/test_with_responses.py b/tests/flytekit/unit/remote/test_with_responses.py index 3a4534a11a..ee3fbb4d8a 100644 --- a/tests/flytekit/unit/remote/test_with_responses.py +++ b/tests/flytekit/unit/remote/test_with_responses.py @@ -1,13 +1,21 @@ import os +import typing +from collections import OrderedDict import mock -from flyteidl.admin import launch_plan_pb2, workflow_pb2 +from flyteidl.admin import launch_plan_pb2, task_pb2, workflow_pb2 -from flytekit.configuration import Config +import flytekit.configuration +from flytekit.configuration import Config, ImageConfig +from flytekit.configuration.default_images import DefaultImages +from flytekit.core.node_creation import create_node from flytekit.core.utils import load_proto_from_file +from flytekit.core.workflow import workflow from flytekit.models import launch_plan as launch_plan_models +from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.remote.remote import FlyteRemote +from flytekit.tools.translator import get_serializable rr = FlyteRemote( Config.for_sandbox(), @@ -35,3 +43,42 @@ def test_fetch_wf_wf_lp_pattern(mock_client): mock_client.get_launch_plan.return_value = leaf_lp fwf = rr.fetch_workflow(name="core.control_flow.subworkflows.root_level_wf", version="JiepXcXB3SiEJ8pwYDy-7g==") assert len(fwf.sub_workflows) == 2 + + +@mock.patch("flytekit.remote.remote.FlyteRemote.client") +def test_task(mock_client): + merge_sort_remotely = load_proto_from_file( + task_pb2.Task, + os.path.join(responses_dir, "admin.task_pb2.Task.pb"), + ) + admin_task = task_models.Task.from_flyte_idl(merge_sort_remotely) + mock_client.get_task.return_value = admin_task + ft = rr.fetch_task(name="merge_sort_remotely", version="tst") + assert len(ft.interface.inputs) == 2 + assert len(ft.interface.outputs) == 1 + + +@mock.patch("flytekit.remote.remote.FlyteRemote.client") +def test_normal_task(mock_client): + merge_sort_remotely = load_proto_from_file( + task_pb2.Task, + os.path.join(responses_dir, "admin.task_pb2.Task.pb"), + ) + admin_task = task_models.Task.from_flyte_idl(merge_sort_remotely) + mock_client.get_task.return_value = admin_task + ft = rr.fetch_task(name="merge_sort_remotely", version="tst") + + @workflow + def my_wf(numbers: typing.List[int], run_local_at_count: int) -> typing.List[int]: + t1_node = create_node(ft, numbers=numbers, run_local_at_count=run_local_at_count) + return t1_node.o0 + + serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig.auto(img_name=DefaultImages.default_image()), + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert wf_spec.template.nodes[0].task_node.reference_id.name == "merge_sort_remotely" diff --git a/tests/flytekit/unit/tools/test_module_loader.py b/tests/flytekit/unit/tools/test_module_loader.py index 9a568f0260..1fc2836232 100644 --- a/tests/flytekit/unit/tools/test_module_loader.py +++ b/tests/flytekit/unit/tools/test_module_loader.py @@ -1,44 +1,6 @@ -import os -import sys - -from flytekit.core import utils from flytekit.tools import module_loader -def test_module_loading(): - with utils.AutoDeletingTempDir("mypackage") as pkg: - path = pkg.name - # Create directories - top_level = os.path.join(path, "top") - middle_level = os.path.join(top_level, "middle") - bottom_level = os.path.join(middle_level, "bottom") - os.makedirs(bottom_level) - - # Create init files - with open(os.path.join(path, "__init__.py"), "w"): - pass - with open(os.path.join(top_level, "__init__.py"), "w"): - pass - with open(os.path.join(top_level, "a.py"), "w"): - pass - with open(os.path.join(middle_level, "__init__.py"), "w"): - pass - with open(os.path.join(middle_level, "a.py"), "w"): - pass - with open(os.path.join(bottom_level, "__init__.py"), "w"): - pass - with open(os.path.join(bottom_level, "a.py"), "w"): - pass - - sys.path.append(path) - - # Not a sufficient test but passes for now - assert sum(1 for _ in module_loader.iterate_modules(["top"])) == 6 - assert [ - pkg.__file__ for pkg in module_loader.iterate_modules(["top.a", "top.middle.a", "top.middle.bottom.a"]) - ] == [os.path.join(lvl, "a.py") for lvl in (top_level, middle_level, bottom_level)] - - def test_load_object(): loader_self = module_loader.load_object_from_module(f"{module_loader.__name__}.load_object_from_module") assert loader_self.__module__ == f"{module_loader.__name__}" diff --git a/tests/flytekit/unit/tools/test_repo.py b/tests/flytekit/unit/tools/test_repo.py new file mode 100644 index 0000000000..8bb6bd773a --- /dev/null +++ b/tests/flytekit/unit/tools/test_repo.py @@ -0,0 +1,70 @@ +import os +import pathlib +import tempfile + +import mock +import pytest + +import flytekit.configuration +from flytekit.configuration import DefaultImages, ImageConfig +from flytekit.tools.repo import find_common_root, load_packages_and_modules + +task_text = """ +from flytekit import task +@task +def t1(a: int): + ... +""" + + +# Mock out the entities so the load function doesn't try to load everything +@mock.patch("flytekit.core.context_manager.FlyteEntities") +@mock.patch("flytekit.core.base_task.FlyteEntities") +def test_module_loading(mock_entities, mock_entities_2): + entities = [] + mock_entities.entities = entities + mock_entities_2.entities = entities + with tempfile.TemporaryDirectory() as tmp_dir: + # Create directories + top_level = os.path.join(tmp_dir, "top") + middle_level = os.path.join(top_level, "middle") + bottom_level = os.path.join(middle_level, "bottom") + os.makedirs(bottom_level) + + top_level_2 = os.path.join(tmp_dir, "top2") + middle_level_2 = os.path.join(top_level_2, "middle") + os.makedirs(middle_level_2) + + # Create init files + pathlib.Path(os.path.join(top_level, "__init__.py")).touch() + pathlib.Path(os.path.join(top_level, "a.py")).touch() + pathlib.Path(os.path.join(middle_level, "__init__.py")).touch() + pathlib.Path(os.path.join(middle_level, "a.py")).touch() + pathlib.Path(os.path.join(bottom_level, "__init__.py")).touch() + pathlib.Path(os.path.join(bottom_level, "a.py")).touch() + with open(os.path.join(bottom_level, "a.py"), "w") as fh: + fh.write(task_text) + pathlib.Path(os.path.join(middle_level_2, "__init__.py")).touch() + + # Because they have different roots + with pytest.raises(ValueError): + find_common_root([middle_level_2, bottom_level]) + + # But now add one more init file + pathlib.Path(os.path.join(top_level_2, "__init__.py")).touch() + + # Now it should pass + root = find_common_root([middle_level_2, bottom_level]) + assert pathlib.Path(root).resolve() == pathlib.Path(tmp_dir).resolve() + + # Now load them + serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig.auto(img_name=DefaultImages.default_image()), + ) + + x = load_packages_and_modules(serialization_settings, pathlib.Path(root), [bottom_level]) + assert len(x) == 1 diff --git a/tests/flytekit/unit/types/numpy/test_ndarray.py b/tests/flytekit/unit/types/numpy/test_ndarray.py new file mode 100644 index 0000000000..761cb74812 --- /dev/null +++ b/tests/flytekit/unit/types/numpy/test_ndarray.py @@ -0,0 +1,71 @@ +import numpy as np + +from flytekit import task, workflow + + +@task +def generate_numpy_1d() -> np.ndarray: + return np.array([1, 2, 3, 4, 5, 6], dtype=int) + + +@task +def generate_numpy_2d() -> np.ndarray: + return np.array([[1.8, 2.9, 3.1], [5.4, 6.0, 7.7]]) + + +@task +def generate_numpy_dtype_object() -> np.ndarray: + # dtype=object cannot be serialized + return np.array( + [ + [ + 405, + 162, + 414, + 0, + np.array([list([1, 9, 2]), 18, (405, 18, 207), 64, "Universal"], dtype=object), + 0, + 0, + 0, + ] + ], + dtype=object, + ) + + +@task +def t1(array: np.ndarray) -> np.ndarray: + assert array.dtype == int + output = np.empty(len(array)) + for i in range(len(array)): + output[i] = 1.0 / array[i] + return output + + +@task +def t2(array: np.ndarray) -> np.ndarray: + return array.flatten() + + +@task +def t3(array: np.ndarray) -> np.ndarray: + # convert 1D numpy array to 3D + return array.reshape(2, 3) + + +@workflow +def wf(): + array_1d = generate_numpy_1d() + array_2d = generate_numpy_2d() + try: + generate_numpy_dtype_object() + except Exception as e: + assert isinstance(e, TypeError) + t1(array=array_1d) + t2(array=array_2d) + t3(array=array_1d) + + +@workflow +def test_wf(): + wf() diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index 765d10b538..f0d58eb36d 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -1,6 +1,8 @@ import os import typing +import pytest + try: from typing import Annotated except ImportError: @@ -69,42 +71,42 @@ def decode( StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(), False, True) -class NumpyEncodingHandlers(StructuredDatasetEncoder): - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - df = typing.cast(np.ndarray, structured_dataset.dataframe) - name = ["col" + str(i) for i in range(len(df))] - table = pa.Table.from_arrays(df, name) - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - pq.write_table(table, local_path) - ctx.file_access.upload_directory(local_dir, path) - structured_dataset_type.format = PARQUET - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - -class NumpyDecodingHandlers(StructuredDatasetDecoder): - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> typing.Union[DF, typing.Generator[DF, None, None]]: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) - table = pq.read_table(local_dir) - return table.to_pandas().to_numpy() - - -for protocol in [LOCAL, S3]: - StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET)) - StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET)) +@pytest.fixture(autouse=True) +def numpy_type(): + class NumpyEncodingHandlers(StructuredDatasetEncoder): + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + df = typing.cast(np.ndarray, structured_dataset.dataframe) + name = ["col" + str(i) for i in range(len(df))] + table = pa.Table.from_arrays(df, name) + local_dir = ctx.file_access.get_random_local_directory() + local_path = os.path.join(local_dir, f"{0:05}") + pq.write_table(table, local_path) + ctx.file_access.upload_directory(local_dir, path) + structured_dataset_type.format = PARQUET + return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) + + class NumpyDecodingHandlers(StructuredDatasetDecoder): + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> typing.Union[DF, typing.Generator[DF, None, None]]: + path = flyte_value.uri + local_dir = ctx.file_access.get_random_local_directory() + ctx.file_access.get_data(path, local_dir, is_multipart=True) + table = pq.read_table(local_dir) + return table.to_pandas().to_numpy() + + for protocol in [LOCAL, S3]: + StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET)) + StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET)) @task