diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 4c16153be9..2344b15391 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -17,14 +17,7 @@ jobs: matrix: os: [ubuntu-latest, windows-latest] python-version: ["3.7", "3.8", "3.9", "3.10"] - spark-version-suffix: ["", "-spark2"] exclude: - - python-version: 3.8 - spark-version-suffix: "-spark2" - - python-version: 3.9 - spark-version-suffix: "-spark2" - - python-version: 3.10 - spark-version-suffix: "-spark2" # Ignore this test because we failed to install docker-py # docker-py will install pywin32==227, whereas pywin only added support for python 3.10 in version 301. # For more detail, see https://github.com/flyteorg/flytekit/pull/856#issuecomment-1067152855 @@ -42,10 +35,10 @@ jobs: # This path is specific to Ubuntu path: ~/.cache/pip # Look to see if there is a cache hit for the corresponding requirements files - key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.txt', format('requirements{0}.txt', matrix.spark-version-suffix))) }} + key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }} - name: Install dependencies run: | - make setup${{ matrix.spark-version-suffix }} + make setup pip freeze - name: Test with coverage run: | @@ -67,6 +60,7 @@ jobs: - flytekit-aws-batch - flytekit-aws-sagemaker - flytekit-bigquery + - flytekit-dask - flytekit-data-fsspec - flytekit-dbt - flytekit-deck-standard @@ -156,13 +150,13 @@ jobs: - uses: actions/cache@v2 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('**/dev-requirements.txt') }} + key: ${{ runner.os }}-pip-${{ hashFiles('**/dev-requirements.in') }} restore-keys: | ${{ runner.os }}-pip- - name: Install dependencies run: | - python -m pip install --upgrade pip==21.2.4 - pip install -r dev-requirements.txt + python -m pip install --upgrade pip + pip install -r dev-requirements.in - name: Lint run: | make lint diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 39470b7370..1fd6e6b648 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: hooks: - id: black - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: ["--profile", "black"] diff --git a/.readthedocs.yml b/.readthedocs.yml index 86a85609d7..19b1898e94 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -18,3 +18,4 @@ sphinx: python: install: - requirements: doc-requirements.txt + - requirements: docs/requirements.txt diff --git a/CODEOWNERS b/CODEOWNERS index 9389524869..a9aab29ffd 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,3 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence. 
-* @wild-endeavor @kumare3 @eapolinario @pingsutw +* @wild-endeavor @kumare3 @eapolinario @pingsutw @cosmicBboy diff --git a/Dockerfile b/Dockerfile index 82f4fe5366..6c3228ad2f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,10 @@ FROM python:${PYTHON_VERSION}-slim-buster MAINTAINER Flyte Team LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit +RUN useradd -u 1000 flytekit +RUN chown flytekit: /root +USER flytekit + WORKDIR /root ENV PYTHONPATH /root diff --git a/Makefile b/Makefile index 8254904469..76a790d7b5 100644 --- a/Makefile +++ b/Makefile @@ -22,11 +22,7 @@ update_boilerplate: .PHONY: setup setup: install-piptools ## Install requirements - pip-sync requirements.txt dev-requirements.txt - -.PHONY: setup-spark2 -setup-spark2: install-piptools ## Install requirements - pip-sync requirements-spark2.txt dev-requirements.txt + pip install -r dev-requirements.in .PHONY: fmt fmt: ## Format code with black and isort @@ -35,11 +31,12 @@ fmt: ## Format code with black and isort .PHONY: lint lint: ## Run linters - mypy flytekit/core || true - mypy flytekit/types || true - mypy tests/flytekit/unit/core || true - # Exclude setup.py to fix error: Duplicate module named "setup" - mypy plugins --exclude setup.py || true + mypy flytekit/core + mypy flytekit/types + # allow-empty-bodies: Allow empty body in function. + # disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked". + # Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass. + mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core pre-commit run --all-files .PHONY: spellcheck @@ -62,18 +59,6 @@ unit_test: pytest -m "not sandbox_test" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS} && \ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS} -requirements-spark2.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark2.txt -requirements-spark2.txt: requirements-spark2.in install-piptools - $(PIP_COMPILE) $< - -requirements.txt: export CUSTOM_COMPILE_COMMAND := make requirements.txt -requirements.txt: requirements.in install-piptools - $(PIP_COMPILE) $< - -dev-requirements.txt: export CUSTOM_COMPILE_COMMAND := make dev-requirements.txt -dev-requirements.txt: dev-requirements.in requirements.txt install-piptools - $(PIP_COMPILE) $< - doc-requirements.txt: export CUSTOM_COMPILE_COMMAND := make doc-requirements.txt doc-requirements.txt: doc-requirements.in install-piptools $(PIP_COMPILE) $< @@ -83,7 +68,7 @@ ${MOCK_FLYTE_REPO}/requirements.txt: ${MOCK_FLYTE_REPO}/requirements.in install- $(PIP_COMPILE) $< .PHONY: requirements -requirements: requirements.txt dev-requirements.txt requirements-spark2.txt doc-requirements.txt ${MOCK_FLYTE_REPO}/requirements.txt ## Compile requirements +requirements: doc-requirements.txt ${MOCK_FLYTE_REPO}/requirements.txt ## Compile requirements # TODO: Change this in the future to be all of flytekit .PHONY: coverage diff --git a/dev-requirements.in b/dev-requirements.in index a02c8fa144..755231ed71 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,4 +1,4 @@ --c requirements.txt +-r requirements.in git+https://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte coverage[toml] @@ -12,8 +12,13 @@ codespell google-cloud-bigquery google-cloud-bigquery-storage IPython -tensorflow==2.8.1 + +# 
Only install tensorflow if not running on an arm Mac. +tensorflow==2.8.1; platform_machine!='arm64' or platform_system!='Darwin' # Newer versions of torch bring in nvidia dependencies that are not present in windows, so # we put this constraint while we do not have per-environment requirements files torch<=1.12.1 scikit-learn +types-protobuf +types-croniter +types-mock diff --git a/dev-requirements.txt b/dev-requirements.txt deleted file mode 100644 index ea0e19354a..0000000000 --- a/dev-requirements.txt +++ /dev/null @@ -1,597 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.7 -# by the following command: -# -# make dev-requirements.txt -# --e file:.#egg=flytekit - # via - # -c requirements.txt - # pytest-flyte -absl-py==1.3.0 - # via - # tensorboard - # tensorflow -arrow==1.2.3 - # via - # -c requirements.txt - # jinja2-time -astunparse==1.6.3 - # via tensorflow -attrs==20.3.0 - # via - # -c requirements.txt - # jsonschema - # pytest - # pytest-docker -backcall==0.2.0 - # via ipython -bcrypt==4.0.1 - # via paramiko -binaryornot==0.4.4 - # via - # -c requirements.txt - # cookiecutter -cached-property==1.5.2 - # via docker-compose -cachetools==5.2.0 - # via google-auth -certifi==2022.12.7 - # via - # -c requirements.txt - # requests -cffi==1.15.1 - # via - # -c requirements.txt - # cryptography - # pynacl -cfgv==3.3.1 - # via pre-commit -chardet==5.1.0 - # via - # -c requirements.txt - # binaryornot -charset-normalizer==2.1.1 - # via - # -c requirements.txt - # requests -click==8.1.3 - # via - # -c requirements.txt - # cookiecutter - # flytekit -cloudpickle==2.2.0 - # via - # -c requirements.txt - # flytekit -codespell==2.2.2 - # via -r dev-requirements.in -cookiecutter==2.1.1 - # via - # -c requirements.txt - # flytekit -coverage[toml]==6.5.0 - # via - # -r dev-requirements.in - # pytest-cov -croniter==1.3.8 - # via - # -c requirements.txt - # flytekit -cryptography==38.0.4 - # via - # -c requirements.txt - # paramiko - # pyopenssl - # secretstorage -dataclasses-json==0.5.7 - # via - # -c requirements.txt - # flytekit -decorator==5.1.1 - # via - # -c requirements.txt - # ipython - # retry -deprecated==1.2.13 - # via - # -c requirements.txt - # flytekit -diskcache==5.4.0 - # via - # -c requirements.txt - # flytekit -distlib==0.3.6 - # via virtualenv -distro==1.8.0 - # via docker-compose -docker[ssh]==6.0.1 - # via - # -c requirements.txt - # docker-compose - # flytekit -docker-compose==1.29.2 - # via pytest-flyte -docker-image-py==0.1.12 - # via - # -c requirements.txt - # flytekit -dockerpty==0.4.1 - # via docker-compose -docopt==0.6.2 - # via docker-compose -docstring-parser==0.15 - # via - # -c requirements.txt - # flytekit -exceptiongroup==1.0.4 - # via pytest -filelock==3.8.2 - # via virtualenv -flatbuffers==22.12.6 - # via tensorflow -flyteidl==1.3.1 - # via - # -c requirements.txt - # flytekit -gast==0.5.3 - # via tensorflow -google-api-core[grpc]==2.11.0 - # via - # google-cloud-bigquery - # google-cloud-bigquery-storage - # google-cloud-core -google-auth==2.15.0 - # via - # google-api-core - # google-auth-oauthlib - # google-cloud-core - # tensorboard -google-auth-oauthlib==0.4.6 - # via tensorboard -google-cloud-bigquery==3.4.0 - # via -r dev-requirements.in -google-cloud-bigquery-storage==2.16.2 - # via - # -r dev-requirements.in - # google-cloud-bigquery -google-cloud-core==2.3.2 - # via google-cloud-bigquery -google-crc32c==1.5.0 - # via google-resumable-media -google-pasta==0.2.0 - # via tensorflow -google-resumable-media==2.4.0 - # via 
google-cloud-bigquery -googleapis-common-protos==1.57.0 - # via - # -c requirements.txt - # flyteidl - # google-api-core - # grpcio-status -grpcio==1.51.1 - # via - # -c requirements.txt - # flytekit - # google-api-core - # google-cloud-bigquery - # grpcio-status - # tensorboard - # tensorflow -grpcio-status==1.51.1 - # via - # -c requirements.txt - # flytekit - # google-api-core -h5py==3.7.0 - # via tensorflow -identify==2.5.9 - # via pre-commit -idna==3.4 - # via - # -c requirements.txt - # requests -importlib-metadata==5.1.0 - # via - # -c requirements.txt - # click - # flytekit - # jsonschema - # keyring - # markdown - # pluggy - # pre-commit - # pytest - # virtualenv -iniconfig==1.1.1 - # via pytest -ipython==7.34.0 - # via -r dev-requirements.in -jaraco-classes==3.2.3 - # via - # -c requirements.txt - # keyring -jedi==0.18.2 - # via ipython -jeepney==0.8.0 - # via - # -c requirements.txt - # keyring - # secretstorage -jinja2==3.1.2 - # via - # -c requirements.txt - # cookiecutter - # jinja2-time - # pytest-flyte -jinja2-time==0.2.0 - # via - # -c requirements.txt - # cookiecutter -joblib==1.2.0 - # via - # -c requirements.txt - # -r dev-requirements.in - # flytekit - # scikit-learn -jsonschema==3.2.0 - # via - # -c requirements.txt - # docker-compose -keras==2.8.0 - # via tensorflow -keras-preprocessing==1.1.2 - # via tensorflow -keyring==23.11.0 - # via - # -c requirements.txt - # flytekit -libclang==14.0.6 - # via tensorflow -markdown==3.4.1 - # via tensorboard -markupsafe==2.1.1 - # via - # -c requirements.txt - # jinja2 - # werkzeug -marshmallow==3.19.0 - # via - # -c requirements.txt - # dataclasses-json - # marshmallow-enum - # marshmallow-jsonschema -marshmallow-enum==1.5.1 - # via - # -c requirements.txt - # dataclasses-json -marshmallow-jsonschema==0.13.0 - # via - # -c requirements.txt - # flytekit -matplotlib-inline==0.1.6 - # via ipython -mock==4.0.3 - # via -r dev-requirements.in -more-itertools==9.0.0 - # via - # -c requirements.txt - # jaraco-classes -mypy==0.991 - # via -r dev-requirements.in -mypy-extensions==0.4.3 - # via - # -c requirements.txt - # mypy - # typing-inspect -natsort==8.2.0 - # via - # -c requirements.txt - # flytekit -nodeenv==1.7.0 - # via pre-commit -numpy==1.21.6 - # via - # -c requirements.txt - # flytekit - # h5py - # keras-preprocessing - # opt-einsum - # pandas - # pyarrow - # scikit-learn - # scipy - # tensorboard - # tensorflow -oauthlib==3.2.2 - # via requests-oauthlib -opt-einsum==3.3.0 - # via tensorflow -packaging==21.3 - # via - # -c requirements.txt - # docker - # google-cloud-bigquery - # marshmallow - # pytest -pandas==1.3.5 - # via - # -c requirements.txt - # flytekit -paramiko==2.12.0 - # via docker -parso==0.8.3 - # via jedi -pexpect==4.8.0 - # via ipython -pickleshare==0.7.5 - # via ipython -platformdirs==2.6.0 - # via virtualenv -pluggy==1.0.0 - # via pytest -pre-commit==2.20.0 - # via -r dev-requirements.in -prompt-toolkit==3.0.36 - # via ipython -proto-plus==1.22.1 - # via - # google-cloud-bigquery - # google-cloud-bigquery-storage -protobuf==4.21.10 - # via - # -c requirements.txt - # flyteidl - # google-api-core - # google-cloud-bigquery - # google-cloud-bigquery-storage - # googleapis-common-protos - # grpcio-status - # proto-plus - # protoc-gen-swagger - # tensorboard - # tensorflow -protoc-gen-swagger==0.1.0 - # via - # -c requirements.txt - # flyteidl -ptyprocess==0.7.0 - # via pexpect -py==1.11.0 - # via - # -c requirements.txt - # retry -pyarrow==10.0.1 - # via - # -c requirements.txt - # flytekit - # 
google-cloud-bigquery -pyasn1==0.4.8 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.2.8 - # via google-auth -pycparser==2.21 - # via - # -c requirements.txt - # cffi -pygments==2.13.0 - # via ipython -pynacl==1.5.0 - # via paramiko -pyopenssl==22.1.0 - # via - # -c requirements.txt - # flytekit -pyparsing==3.0.9 - # via - # -c requirements.txt - # packaging -pyrsistent==0.19.2 - # via - # -c requirements.txt - # jsonschema -pytest==7.2.0 - # via - # -r dev-requirements.in - # pytest-cov - # pytest-docker - # pytest-flyte -pytest-cov==4.0.0 - # via -r dev-requirements.in -pytest-docker==1.0.1 - # via pytest-flyte -pytest-flyte @ git+https://github.com/flyteorg/pytest-flyte@main - # via -r dev-requirements.in -python-dateutil==2.8.2 - # via - # -c requirements.txt - # arrow - # croniter - # flytekit - # google-cloud-bigquery - # pandas -python-dotenv==0.21.0 - # via docker-compose -python-json-logger==2.0.4 - # via - # -c requirements.txt - # flytekit -python-slugify==7.0.0 - # via - # -c requirements.txt - # cookiecutter -pytimeparse==1.1.8 - # via - # -c requirements.txt - # flytekit -pytz==2022.6 - # via - # -c requirements.txt - # flytekit - # pandas -pyyaml==5.4.1 - # via - # -c requirements.txt - # cookiecutter - # docker-compose - # flytekit - # pre-commit -regex==2022.10.31 - # via - # -c requirements.txt - # docker-image-py -requests==2.28.1 - # via - # -c requirements.txt - # cookiecutter - # docker - # docker-compose - # flytekit - # google-api-core - # google-cloud-bigquery - # requests-oauthlib - # responses - # tensorboard -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -responses==0.22.0 - # via - # -c requirements.txt - # flytekit -retry==0.9.2 - # via - # -c requirements.txt - # flytekit -rsa==4.9 - # via google-auth -scikit-learn==1.0.2 - # via -r dev-requirements.in -scipy==1.7.3 - # via scikit-learn -secretstorage==3.3.3 - # via - # -c requirements.txt - # keyring -singledispatchmethod==1.0 - # via - # -c requirements.txt - # flytekit -six==1.16.0 - # via - # -c requirements.txt - # astunparse - # dockerpty - # google-auth - # google-pasta - # jsonschema - # keras-preprocessing - # paramiko - # python-dateutil - # tensorflow - # websocket-client -sortedcontainers==2.4.0 - # via - # -c requirements.txt - # flytekit -statsd==3.3.0 - # via - # -c requirements.txt - # flytekit -tensorboard==2.8.0 - # via tensorflow -tensorboard-data-server==0.6.1 - # via tensorboard -tensorboard-plugin-wit==1.8.1 - # via tensorboard -tensorflow==2.8.1 - # via -r dev-requirements.in -tensorflow-estimator==2.8.0 - # via tensorflow -tensorflow-io-gcs-filesystem==0.28.0 - # via tensorflow -termcolor==2.1.1 - # via tensorflow -text-unidecode==1.3 - # via - # -c requirements.txt - # python-slugify -texttable==1.6.7 - # via docker-compose -threadpoolctl==3.1.0 - # via scikit-learn -toml==0.10.2 - # via - # -c requirements.txt - # pre-commit - # responses -tomli==2.0.1 - # via - # coverage - # mypy - # pytest -torch==1.13.1 - # via -r dev-requirements.in -traitlets==5.6.0 - # via - # ipython - # matplotlib-inline -typed-ast==1.5.4 - # via mypy -types-toml==0.10.8.1 - # via - # -c requirements.txt - # responses -typing-extensions==4.4.0 - # via - # -c requirements.txt - # arrow - # flytekit - # importlib-metadata - # mypy - # responses - # tensorflow - # torch - # typing-inspect -typing-inspect==0.8.0 - # via - # -c requirements.txt - # dataclasses-json -urllib3==1.26.13 - # via - # -c requirements.txt - # docker - # flytekit - # requests - # responses -virtualenv==20.17.1 - # 
via pre-commit -wcwidth==0.2.5 - # via prompt-toolkit -websocket-client==0.59.0 - # via - # -c requirements.txt - # docker - # docker-compose -werkzeug==2.2.2 - # via tensorboard -wheel==0.38.4 - # via - # -c requirements.txt - # astunparse - # flytekit - # tensorboard -wrapt==1.14.1 - # via - # -c requirements.txt - # deprecated - # flytekit - # tensorflow -zipp==3.11.0 - # via - # -c requirements.txt - # importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/doc-requirements.in b/doc-requirements.in index 2850232418..a5b921481c 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -45,4 +45,6 @@ whylogs # whylogs whylabs-client # whylogs ray # ray scikit-learn # scikit-learn +dask[distributed] # dask vaex # vaex +mlflow # mlflow diff --git a/doc-requirements.txt b/doc-requirements.txt index 7c92fcb018..2eb0532253 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -6,15 +6,17 @@ # -e file:.#egg=flytekit # via -r doc-requirements.in -absl-py==1.3.0 +absl-py==1.4.0 # via # tensorboard # tensorflow aiosignal==1.3.1 # via ray -alabaster==0.7.12 +alabaster==0.7.13 # via sphinx -altair==4.2.0 +alembic==1.9.2 + # via mlflow +altair==4.2.2 # via great-expectations ansiwrap==0.8.4 # via papermill @@ -25,10 +27,6 @@ anyio==3.6.2 # watchfiles aplus==0.11.0 # via vaex-core -appnope==0.1.3 - # via - # ipykernel - # ipython argon2-cffi==21.3.0 # via # jupyter-server @@ -40,15 +38,15 @@ arrow==1.2.3 # via # isoduration # jinja2-time -astroid==2.12.13 +astroid==2.14.1 # via sphinx-autoapi -astropy==5.1.1 +astropy==5.2.1 # via vaex-astro asttokens==2.2.1 # via stack-data astunparse==1.6.3 # via tensorflow -attrs==22.1.0 +attrs==22.2.0 # via # jsonschema # ray @@ -57,7 +55,7 @@ babel==2.11.0 # via sphinx backcall==0.2.0 # via ipython -beautifulsoup4==4.11.1 +beautifulsoup4==4.11.2 # via # furo # nbconvert @@ -65,17 +63,19 @@ beautifulsoup4==4.11.1 # sphinx-material binaryornot==0.4.4 # via cookiecutter -blake3==0.3.1 +blake3==0.3.3 # via vaex-core -bleach==5.0.1 +bleach==6.0.0 # via nbconvert -botocore==1.29.26 +botocore==1.29.61 # via -r doc-requirements.in bqplot==0.12.36 - # via vaex-jupyter + # via + # ipyvolume + # vaex-jupyter branca==0.6.0 # via ipyleaflet -cachetools==5.2.0 +cachetools==5.3.0 # via # google-auth # vaex-server @@ -91,51 +91,62 @@ cfgv==3.3.1 # via pre-commit chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests -click==8.0.4 +click==8.1.3 # via # cookiecutter # dask + # databricks-cli + # distributed + # flask # flytekit # great-expectations + # mlflow # papermill # ray # sphinx-click # uvicorn -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via # dask + # distributed # flytekit + # mlflow + # shap # vaex-core colorama==0.4.6 # via great-expectations comm==0.1.2 # via ipykernel -commonmark==0.9.1 - # via rich -contourpy==1.0.6 +contourpy==1.0.7 # via matplotlib cookiecutter==2.1.1 # via flytekit croniter==1.3.8 # via flytekit -cryptography==38.0.4 +cryptography==39.0.0 # via # -r doc-requirements.in # great-expectations # pyopenssl + # secretstorage css-html-js-minify==2.5.5 # via sphinx-material cycler==0.11.0 # via matplotlib -dask==2022.12.0 - # via vaex-core +dask[distributed]==2023.1.1 + # via + # -r doc-requirements.in + # distributed + # vaex-core +databricks-cli==0.17.4 + # via mlflow dataclasses-json==0.5.7 # via # dolt-integrations # flytekit -debugpy==1.6.4 +debugpy==1.6.6 # via ipykernel decorator==5.1.1 # via @@ -149,8 +160,12 @@ 
diskcache==5.4.0 # via flytekit distlib==0.3.6 # via virtualenv +distributed==2023.1.1 + # via dask docker==6.0.1 - # via flytekit + # via + # flytekit + # mlflow docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 @@ -167,22 +182,24 @@ doltcli==0.1.17 entrypoints==0.4 # via # altair - # jupyter-client + # mlflow # papermill executing==1.2.0 # via stack-data -fastapi==0.88.0 +fastapi==0.89.1 # via vaex-server fastjsonschema==2.16.2 # via nbformat -filelock==3.8.2 +filelock==3.9.0 # via # ray # vaex-core # virtualenv -flatbuffers==22.12.6 +flask==2.2.2 + # via mlflow +flatbuffers==23.1.21 # via tensorflow -flyteidl==1.3.1 +flyteidl==1.3.5 # via flytekit fonttools==4.38.0 # via matplotlib @@ -194,24 +211,29 @@ frozenlist==1.3.3 # via # aiosignal # ray -fsspec==2022.11.0 +fsspec==2023.1.0 # via # -r doc-requirements.in # dask # modin furo @ git+https://github.com/flyteorg/furo@main # via -r doc-requirements.in -future==0.18.2 +future==0.18.3 # via vaex-core gast==0.5.3 # via tensorflow +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via + # flytekit + # mlflow google-api-core[grpc]==2.11.0 # via # -r doc-requirements.in # google-cloud-bigquery - # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.15.0 +google-auth==2.16.0 # via # google-api-core # google-auth-oauthlib @@ -222,26 +244,25 @@ google-auth-oauthlib==0.4.6 # via tensorboard google-cloud==0.34.0 # via -r doc-requirements.in -google-cloud-bigquery==3.4.0 +google-cloud-bigquery==3.5.0 # via -r doc-requirements.in -google-cloud-bigquery-storage==2.16.2 - # via google-cloud-bigquery google-cloud-core==2.3.2 # via google-cloud-bigquery google-crc32c==1.5.0 # via google-resumable-media google-pasta==0.2.0 # via tensorflow -google-resumable-media==2.4.0 +google-resumable-media==2.4.1 # via google-cloud-bigquery -googleapis-common-protos==1.57.0 +googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # google-api-core # grpcio-status -great-expectations==0.15.37 +great-expectations==0.15.46 # via -r doc-requirements.in -greenlet==2.0.1 +greenlet==2.0.2 # via sqlalchemy grpcio==1.51.1 # via @@ -256,17 +277,21 @@ grpcio-status==1.51.1 # via # flytekit # google-api-core +gunicorn==20.1.0 + # via mlflow h11==0.14.0 # via uvicorn -h5py==3.7.0 +h5py==3.8.0 # via # tensorflow # vaex-hdf5 +heapdict==1.0.1 + # via zict htmlmin==0.1.12 - # via pandas-profiling + # via ydata-profiling httptools==0.5.0 # via uvicorn -identify==2.5.9 +identify==2.5.17 # via pre-commit idna==3.4 # via @@ -277,17 +302,20 @@ imagehash==4.3.1 # via visions imagesize==1.4.1 # via sphinx -importlib-metadata==5.1.0 +importlib-metadata==5.2.0 # via + # flask # flytekit # great-expectations + # jupyter-client # keyring # markdown + # mlflow # nbconvert # sphinx ipydatawidgets==4.3.2 # via pythreejs -ipykernel==6.19.2 +ipykernel==6.20.2 # via # ipywidgets # jupyter @@ -299,7 +327,7 @@ ipyleaflet==0.17.2 # via vaex-jupyter ipympl==0.9.2 # via vaex-jupyter -ipython==8.7.0 +ipython==8.9.0 # via # great-expectations # ipykernel @@ -312,15 +340,19 @@ ipython-genutils==0.2.0 # nbclassic # notebook # qtconsole -ipyvolume==0.5.2 +ipyvolume==0.6.0 # via vaex-jupyter ipyvue==1.8.0 - # via ipyvuetify + # via + # ipyvolume + # ipyvuetify ipyvuetify==1.8.4 - # via vaex-jupyter + # via + # ipyvolume + # vaex-jupyter ipywebrtc==0.6.0 # via ipyvolume -ipywidgets==8.0.3 +ipywidgets==8.0.4 # via # bqplot # great-expectations @@ -333,25 +365,34 @@ ipywidgets==8.0.3 # pythreejs isoduration==20.11.0 # via jsonschema +itsdangerous==2.1.2 + # via flask 
jaraco-classes==3.2.3 # via keyring jedi==0.18.2 # via ipython +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # altair # branca # cookiecutter + # distributed + # flask # great-expectations # jinja2-time # jupyter-server + # mlflow # nbclassic # nbconvert # notebook - # pandas-profiling # sphinx # sphinx-autoapi # vaex-ml + # ydata-profiling jinja2-time==0.2.0 # via cookiecutter jmespath==1.0.1 @@ -376,7 +417,7 @@ jsonschema[format-nongpl]==4.17.3 # ray jupyter==1.0.0 # via -r doc-requirements.in -jupyter-client==7.4.8 +jupyter-client==8.0.2 # via # ipykernel # jupyter-console @@ -387,7 +428,7 @@ jupyter-client==7.4.8 # qtconsole jupyter-console==6.4.4 # via jupyter -jupyter-core==5.1.0 +jupyter-core==5.2.0 # via # jupyter-client # jupyter-server @@ -397,47 +438,57 @@ jupyter-core==5.1.0 # nbformat # notebook # qtconsole -jupyter-events==0.5.0 +jupyter-events==0.6.3 # via jupyter-server -jupyter-server==2.0.1 +jupyter-server==2.2.0 # via # nbclassic # notebook-shim -jupyter-server-terminals==0.4.2 +jupyter-server-terminals==0.4.4 # via jupyter-server jupyterlab-pygments==0.2.2 # via nbconvert -jupyterlab-widgets==3.0.4 +jupyterlab-widgets==3.0.5 # via ipywidgets keras==2.8.0 # via tensorflow keras-preprocessing==1.1.2 # via tensorflow -keyring==23.11.0 +keyring==23.13.1 # via flytekit kiwisolver==1.4.4 # via matplotlib kubernetes==25.3.0 - # via -r doc-requirements.in -lazy-object-proxy==1.8.0 + # via + # -r doc-requirements.in + # flytekit +lazy-object-proxy==1.9.0 # via astroid -libclang==14.0.6 +libclang==15.0.6.1 # via tensorflow llvmlite==0.39.1 # via numba locket==1.0.0 - # via partd -lxml==4.9.1 + # via + # distributed + # partd +lxml==4.9.2 # via sphinx-material makefun==1.15.0 # via great-expectations +mako==1.2.4 + # via alembic markdown==3.4.1 # via # -r doc-requirements.in + # mlflow # tensorboard -markupsafe==2.1.1 +markdown-it-py==2.1.0 + # via rich +markupsafe==2.1.2 # via # jinja2 + # mako # nbconvert # werkzeug marshmallow==3.19.0 @@ -450,48 +501,56 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.6.2 +matplotlib==3.6.3 # via # ipympl - # pandas-profiling + # ipyvolume + # mlflow # phik # seaborn # vaex-viz + # ydata-profiling matplotlib-inline==0.1.6 # via # ipykernel # ipython +mdurl==0.1.2 + # via markdown-it-py mistune==2.0.4 # via # great-expectations # nbconvert -modin==0.17.1 +mlflow==2.1.1 + # via -r doc-requirements.in +modin==0.18.1 # via -r doc-requirements.in more-itertools==9.0.0 # via jaraco-classes msgpack==1.0.4 - # via ray -multimethod==1.9 # via - # pandas-profiling + # distributed + # ray +multimethod==1.9.1 + # via # visions + # ydata-profiling mypy-extensions==0.4.3 # via typing-inspect natsort==8.2.0 # via flytekit -nbclassic==0.4.8 +nbclassic==0.5.1 # via notebook nbclient==0.7.2 # via # nbconvert # papermill -nbconvert==7.2.6 +nbconvert==7.2.9 # via # jupyter # jupyter-server # nbclassic # notebook -nbformat==5.7.0 +nbformat==5.7.3 # via # great-expectations # jupyter-server @@ -503,11 +562,10 @@ nbformat==5.7.0 nest-asyncio==1.5.6 # via # ipykernel - # jupyter-client # nbclassic # notebook # vaex-core -networkx==2.8.8 +networkx==3.0 # via visions nodeenv==1.7.0 # via pre-commit @@ -518,13 +576,16 @@ notebook==6.5.2 notebook-shim==0.2.2 # via nbclassic numba==0.56.4 - # via vaex-ml + # via + # shap + # vaex-ml numpy==1.23.5 # via # altair # astropy # bqplot # contourpy + # flytekit # great-expectations # h5py # imagehash @@ -533,11 +594,11 @@ numpy==1.23.5 # 
ipyvolume # keras-preprocessing # matplotlib + # mlflow # modin # numba # opt-einsum # pandas - # pandas-profiling # pandera # patsy # phik @@ -549,20 +610,35 @@ numpy==1.23.5 # scikit-learn # scipy # seaborn + # shap # statsmodels # tensorboard # tensorflow # vaex-core # visions # xarray + # ydata-profiling +nvidia-cublas-cu11==11.10.3.66 + # via + # nvidia-cudnn-cu11 + # torch +nvidia-cuda-nvrtc-cu11==11.7.99 + # via torch +nvidia-cuda-runtime-cu11==11.7.99 + # via torch +nvidia-cudnn-cu11==8.5.0.96 + # via torch oauthlib==3.2.2 - # via requests-oauthlib + # via + # databricks-cli + # requests-oauthlib opt-einsum==3.3.0 # via tensorflow -packaging==21.3 +packaging==22.0 # via # astropy # dask + # distributed # docker # google-cloud-bigquery # great-expectations @@ -570,30 +646,34 @@ packaging==21.3 # jupyter-server # marshmallow # matplotlib + # mlflow # modin # nbconvert # pandera # qtpy + # shap # sphinx # statsmodels # xarray -pandas==1.5.2 +pandas==1.5.3 # via # altair # bqplot # dolt-integrations # flytekit # great-expectations + # mlflow # modin - # pandas-profiling # pandera # phik # seaborn + # shap # statsmodels # vaex-core # visions # xarray -pandas-profiling==3.5.0 + # ydata-profiling +pandas-profiling==3.6.6 # via -r doc-requirements.in pandera==0.13.4 # via -r doc-requirements.in @@ -610,10 +690,10 @@ patsy==0.5.3 pexpect==4.8.0 # via ipython phik==0.12.3 - # via pandas-profiling + # via ydata-profiling pickleshare==0.7.5 # via ipython -pillow==9.3.0 +pillow==9.4.0 # via # imagehash # ipympl @@ -621,17 +701,17 @@ pillow==9.3.0 # matplotlib # vaex-viz # visions -platformdirs==2.6.0 +platformdirs==2.6.2 # via # jupyter-core # virtualenv -plotly==5.11.0 +plotly==5.13.0 # via -r doc-requirements.in -pre-commit==2.20.0 +pre-commit==3.0.2 # via sphinx-tags progressbar2==4.2.0 # via vaex-core -prometheus-client==0.15.0 +prometheus-client==0.16.0 # via # jupyter-server # nbclassic @@ -640,18 +720,16 @@ prompt-toolkit==3.0.36 # via # ipython # jupyter-console -proto-plus==1.22.1 - # via - # google-cloud-bigquery - # google-cloud-bigquery-storage -protobuf==4.21.11 +proto-plus==1.22.2 + # via google-cloud-bigquery +protobuf==4.21.12 # via # flyteidl # google-api-core # google-cloud-bigquery - # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status + # mlflow # proto-plus # protoc-gen-swagger # ray @@ -662,6 +740,7 @@ protoc-gen-swagger==0.1.0 # via flyteidl psutil==5.9.4 # via + # distributed # ipykernel # modin ptyprocess==0.7.0 @@ -677,7 +756,7 @@ py4j==0.10.9.5 pyarrow==10.0.1 # via # flytekit - # google-cloud-bigquery + # mlflow # vaex-core pyasn1==0.4.8 # via @@ -687,16 +766,16 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pydantic==1.10.2 +pydantic==1.10.4 # via # fastapi # great-expectations - # pandas-profiling # pandera # vaex-core + # ydata-profiling pyerfa==2.0.0.1 # via astropy -pygments==2.13.0 +pygments==2.14.0 # via # furo # ipython @@ -706,14 +785,15 @@ pygments==2.13.0 # rich # sphinx # sphinx-prompt -pyopenssl==22.1.0 +pyjwt==2.6.0 + # via databricks-cli +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 # via # great-expectations # matplotlib - # packaging -pyrsistent==0.19.2 +pyrsistent==0.19.3 # via jsonschema pyspark==3.3.1 # via -r doc-requirements.in @@ -730,13 +810,13 @@ python-dateutil==2.8.2 # matplotlib # pandas # whylabs-client -python-dotenv==0.21.0 +python-dotenv==0.21.1 # via uvicorn python-json-logger==2.0.4 # via # flytekit # jupyter-events -python-slugify[unidecode]==7.0.0 +python-slugify[unidecode]==8.0.0 
# via # cookiecutter # sphinx-material @@ -746,11 +826,12 @@ pythreejs==2.4.1 # via ipyvolume pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7.1 # via # babel # flytekit # great-expectations + # mlflow # pandas pytz-deprecation-shim==0.1.0.post0 # via tzlocal @@ -761,17 +842,19 @@ pyyaml==6.0 # astropy # cookiecutter # dask + # distributed # flytekit # jupyter-events # kubernetes - # pandas-profiling + # mlflow # papermill # pre-commit # ray # sphinx-autoapi # uvicorn # vaex-core -pyzmq==24.0.1 + # ydata-profiling +pyzmq==25.0.0 # via # ipykernel # jupyter-client @@ -783,13 +866,16 @@ qtconsole==5.4.0 # via jupyter qtpy==2.3.0 # via qtconsole -ray==2.1.0 +querystring-parser==1.2.4 + # via mlflow +ray==2.2.0 # via -r doc-requirements.in regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter + # databricks-cli # docker # flytekit # google-api-core @@ -797,7 +883,7 @@ requests==2.28.1 # great-expectations # ipyvolume # kubernetes - # pandas-profiling + # mlflow # papermill # ray # requests-oauthlib @@ -805,6 +891,7 @@ requests==2.28.1 # sphinx # tensorboard # vaex-core + # ydata-profiling requests-oauthlib==1.3.1 # via # google-auth-oauthlib @@ -814,10 +901,14 @@ responses==0.22.0 retry==0.9.2 # via flytekit rfc3339-validator==0.1.4 - # via jsonschema + # via + # jsonschema + # jupyter-events rfc3986-validator==0.1.1 - # via jsonschema -rich==12.6.0 + # via + # jsonschema + # jupyter-events +rich==13.3.1 # via vaex-core rsa==4.9 # via google-auth @@ -825,44 +916,61 @@ ruamel-yaml==0.17.17 # via great-expectations ruamel-yaml-clib==0.2.7 # via ruamel-yaml -scikit-learn==1.2.0 - # via -r doc-requirements.in +scikit-learn==1.2.1 + # via + # -r doc-requirements.in + # mlflow + # shap scipy==1.9.3 # via # great-expectations # imagehash - # pandas-profiling + # mlflow # phik # scikit-learn + # shap # statsmodels -seaborn==0.12.1 - # via pandas-profiling + # ydata-profiling +seaborn==0.12.2 + # via ydata-profiling +secretstorage==3.3.3 + # via keyring send2trash==1.8.0 # via # jupyter-server # nbclassic # notebook +shap==0.41.0 + # via mlflow six==1.16.0 # via # asttokens # astunparse # bleach + # databricks-cli # google-auth # google-pasta # keras-preprocessing # kubernetes # patsy # python-dateutil + # querystring-parser # rfc3339-validator # sphinx-code-include # tensorflow # vaex-core +slicer==0.0.7 + # via shap +smmap==5.0.0 + # via gitdb sniffio==1.3.0 # via anyio snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 - # via flytekit + # via + # distributed + # flytekit soupsieve==2.3.2.post1 # via beautifulsoup4 sphinx==4.5.0 @@ -881,7 +989,7 @@ sphinx==4.5.0 # sphinx-prompt # sphinx-tags # sphinxcontrib-yt -sphinx-autoapi==2.0.0 +sphinx-autoapi==2.0.1 # via -r doc-requirements.in sphinx-basic-ng==1.0.0b1 # via furo @@ -901,13 +1009,13 @@ sphinx-panels==0.6.0 # via -r doc-requirements.in sphinx-prompt==1.5.0 # via -r doc-requirements.in -sphinx-tags==0.1.6 +sphinx-tags==0.2.0 # via -r doc-requirements.in -sphinxcontrib-applehelp==1.0.2 +sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==2.0.0 +sphinxcontrib-htmlhelp==2.0.1 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx @@ -917,8 +1025,13 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-yt==0.2.2 # via -r doc-requirements.in -sqlalchemy==1.4.44 - # via -r doc-requirements.in +sqlalchemy==1.4.46 + # via + # -r doc-requirements.in + # alembic + # mlflow +sqlparse==0.4.3 + # via mlflow stack-data==0.6.2 
# via ipython starlette==0.22.0 @@ -926,11 +1039,15 @@ starlette==0.22.0 statsd==3.3.0 # via flytekit statsmodels==0.13.5 - # via pandas-profiling + # via ydata-profiling tabulate==0.9.0 - # via vaex-core + # via + # databricks-cli + # vaex-core tangled-up-in-unicode==0.2.0 # via visions +tblib==1.7.0 + # via distributed tenacity==8.1.0 # via # papermill @@ -945,9 +1062,9 @@ tensorflow==2.8.1 # via -r doc-requirements.in tensorflow-estimator==2.8.0 # via tensorflow -tensorflow-io-gcs-filesystem==0.28.0 +tensorflow-io-gcs-filesystem==0.30.0 # via tensorflow -termcolor==2.1.1 +termcolor==2.2.0 # via tensorflow terminado==0.17.1 # via @@ -964,18 +1081,18 @@ threadpoolctl==3.1.0 tinycss2==1.2.1 # via nbconvert toml==0.10.2 - # via - # pre-commit - # responses + # via responses toolz==0.12.0 # via # altair # dask + # distributed # partd -torch==1.13.0 +torch==1.13.1 # via -r doc-requirements.in tornado==6.2 # via + # distributed # ipykernel # jupyter-client # jupyter-server @@ -986,9 +1103,10 @@ tornado==6.2 tqdm==4.64.1 # via # great-expectations - # pandas-profiling # papermill -traitlets==5.7.0 + # shap + # ydata-profiling +traitlets==5.9.0 # via # bqplot # comm @@ -1018,7 +1136,7 @@ traittypes==0.2.1 # ipyleaflet # ipyvolume typeguard==2.13.3 - # via pandas-profiling + # via ydata-profiling types-toml==0.10.8.1 # via responses typing-extensions==4.4.0 @@ -1046,9 +1164,10 @@ unidecode==1.3.6 # sphinx-autoapi uri-template==1.2.0 # via jsonschema -urllib3==1.26.13 +urllib3==1.26.14 # via # botocore + # distributed # docker # flytekit # great-expectations @@ -1090,10 +1209,10 @@ virtualenv==20.17.1 # pre-commit # ray visions[type_image_path]==0.7.5 - # via pandas-profiling + # via ydata-profiling watchfiles==0.18.1 # via uvicorn -wcwidth==0.2.5 +wcwidth==0.2.6 # via prompt-toolkit webcolors==1.12 # via jsonschema @@ -1101,7 +1220,7 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.4.2 +websocket-client==1.5.0 # via # docker # jupyter-server @@ -1109,19 +1228,23 @@ websocket-client==1.4.2 websockets==10.4 # via uvicorn werkzeug==2.2.2 - # via tensorboard + # via + # flask + # tensorboard wheel==0.38.4 # via # astunparse # flytekit + # nvidia-cublas-cu11 + # nvidia-cuda-runtime-cu11 # tensorboard -whylabs-client==0.4.2 +whylabs-client==0.4.3 # via -r doc-requirements.in -whylogs==1.1.16 +whylogs==1.1.24 # via -r doc-requirements.in whylogs-sketching==3.4.1.dev3 # via whylogs -widgetsnbextension==4.0.4 +widgetsnbextension==4.0.5 # via ipywidgets wrapt==1.14.1 # via @@ -1130,11 +1253,15 @@ wrapt==1.14.1 # flytekit # pandera # tensorflow -xarray==2022.12.0 +xarray==2023.1.0 # via vaex-jupyter xyzservices==2022.9.0 # via ipyleaflet -zipp==3.11.0 +ydata-profiling==4.0.0 + # via pandas-profiling +zict==2.2.0 + # via distributed +zipp==3.12.0 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/docs/Makefile b/docs/Makefile index e61723ad76..afa73807cb 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -18,3 +18,7 @@ help: # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
%: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + + +clean: + rm -rf ./build ./source/generated diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000..1fb1b91359 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,7 @@ +# TODO: Remove after buf migration is done and packages updated, see doc-requirements.in +# skl2onnx and tf2onnx added here so that the plugin API reference is rendered, +# with the caveat that the docs build environment has the environment variable +# PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set so that protobuf can be parsed +# using Python, which is acceptable for docs building. +skl2onnx +tf2onnx diff --git a/docs/source/clients.rst b/docs/source/clients.rst new file mode 100644 index 0000000000..f67ebf6a3a --- /dev/null +++ b/docs/source/clients.rst @@ -0,0 +1,4 @@ +.. automodule:: flytekit.clients + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/index.rst b/docs/source/index.rst index b0d46866fa..db5902391b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -76,6 +76,7 @@ Expected output: flytekit configuration remote + clients testing extend deck diff --git a/docs/source/plugins/dask.rst b/docs/source/plugins/dask.rst new file mode 100644 index 0000000000..53e9f11fcb --- /dev/null +++ b/docs/source/plugins/dask.rst @@ -0,0 +1,12 @@ +.. _dask: + +################################################### +Dask API reference +################################################### + +.. tags:: Integration, DistributedComputing, KubernetesOperator + +.. automodule:: flytekitplugins.dask + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst index bf0b03fb95..693587192e 100644 --- a/docs/source/plugins/index.rst +++ b/docs/source/plugins/index.rst @@ -9,6 +9,7 @@ Plugin API reference * :ref:`AWS Sagemaker ` - AWS Sagemaker plugin reference * :ref:`Google Bigquery ` - Google Bigquery plugin reference * :ref:`FS Spec ` - FS Spec API reference +* :ref:`Dask ` - Dask standard API reference * :ref:`Deck standard ` - Deck standard API reference * :ref:`Dolt standard ` - Dolt standard API reference * :ref:`Great expectations ` - Great expectations API reference @@ -29,6 +30,7 @@ Plugin API reference * :ref:`Ray ` - Ray API reference * :ref:`DBT ` - DBT API reference * :ref:`Vaex ` - Vaex API reference +* :ref:`MLflow ` - MLflow API reference .. toctree:: :maxdepth: 2 @@ -39,6 +41,7 @@ Plugin API reference AWS Sagemaker Google Bigquery FS Spec + Dask Deck standard Dolt standard Great expectations @@ -59,3 +62,4 @@ Plugin API reference Ray DBT Vaex + MLflow diff --git a/docs/source/plugins/mlflow.rst b/docs/source/plugins/mlflow.rst new file mode 100644 index 0000000000..60d1a7c66b --- /dev/null +++ b/docs/source/plugins/mlflow.rst @@ -0,0 +1,9 @@ +.. _mlflow: + +################################################### +MLflow API reference +################################################### + +.. tags:: Integration, MachineLearning, Tracking + +.. automodule:: flytekitplugins.mlflow diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 1a9f74f114..c2fc11816c 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -58,7 +58,7 @@ TaskMetadata - Wrapper object that allows users to specify Task Resources - Things like CPUs/Memory, etc. WorkflowFailurePolicy - Customizes what happens when a workflow fails. - + PodTemplate - Custom PodTemplate for a task. 
Dynamic and Nested Workflows ============================== @@ -71,6 +71,18 @@ dynamic +Signaling +========= + +.. autosummary:: + :nosignatures: + :template: custom.rst + :toctree: generated/ + + approve + sleep + wait_for_input + Scheduling ============================ @@ -108,6 +120,7 @@ WorkflowReference reference_task reference_workflow + reference_launch_plan Core Task Types ================= @@ -153,6 +166,30 @@ Scalar LiteralType BlobType + +Task Utilities +============== + +.. autosummary:: + :nosignatures: + :template: custom.rst + :toctree: generated/ + + Deck + HashMethod + +Documentation +============= + +.. autosummary:: + :nosignatures: + :template: custom.rst + :toctree: generated/ + + Description + Documentation + SourceCode + """ import sys from typing import Generator @@ -168,13 +205,13 @@ from flytekit.core.condition import conditional from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.gate import approve, sleep, wait_for_input from flytekit.core.hash import HashMethod -from flytekit.core.launch_plan import LaunchPlan +from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan from flytekit.core.map_task import map_task from flytekit.core.notification import Email, PagerDuty, Slack +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask from flytekit.core.reference import get_reference_entity from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference @@ -184,12 +221,12 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck -from flytekit.extras import pytorch, tensorflow -from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence +from flytekit.extras import pytorch, sklearn, tensorflow from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels from flytekit.models.core.execution import WorkflowExecutionPhase from flytekit.models.core.types import BlobType +from flytekit.models.documentation import Description, Documentation, SourceCode from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.types import directory, file, numpy, schema @@ -237,7 +274,7 @@ def load_implicit_plugins(): # note the group is always ``flytekit.plugins`` setup( ... - entry_points={'flytekit.plugins’: 'fsspec=flytekitplugins.fsspec'}, + entry_points={'flytekit.plugins': 'fsspec=flytekitplugins.fsspec'}, ... 
) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 3d5017675e..4f4962309d 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -161,6 +161,12 @@ def _dispatch_execute( logger.info(f"Engine folder written successfully to the output prefix {output_prefix}") logger.debug("Finished _dispatch_execute") + if os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" and _constants.ERROR_FILE_NAME in output_file_dict: + # This env is set by the flytepropeller + # AWS batch job get the status from the exit code, so once we catch the error, + # we should return the error code here + exit(1) + def get_one_of(*args) -> str: """ @@ -264,6 +270,8 @@ def setup_execution( if compressed_serialization_settings: ss = SerializationSettings.from_transport(compressed_serialization_settings) ssb = ss.new_builder() + ssb.project = exe_project + ssb.domain = exe_domain ssb.version = tk_version if dynamic_addl_distro: ssb.fast_serialization_settings = FastSerializationSettings( diff --git a/flytekit/clients/__init__.py b/flytekit/clients/__init__.py index e69de29bb2..1b08e1c567 100644 --- a/flytekit/clients/__init__.py +++ b/flytekit/clients/__init__.py @@ -0,0 +1,19 @@ +""" +===================== +Clients +===================== + +.. currentmodule:: flytekit.clients + +This module provides lower level access to a Flyte backend. + +.. _clients_module: + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + :nosignatures: + + ~friendly.SynchronousFlyteClient + ~raw.RawSynchronousFlyteClient +""" diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 7c4439d83d..6c8f54e9ce 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -10,11 +10,13 @@ import grpc import requests as _requests from flyteidl.admin.project_pb2 import ProjectListRequest +from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse from flyteidl.service import admin_pb2_grpc as _admin_service from flyteidl.service import auth_pb2 from flyteidl.service import auth_pb2_grpc as auth_service from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2 from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service +from flyteidl.service import signal_pb2_grpc as signal_service from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub from google.protobuf.json_format import MessageToJson as _MessageToJson @@ -145,6 +147,7 @@ def __init__(self, cfg: PlatformConfig, **kwargs): ) self._stub = _admin_service.AdminServiceStub(self._channel) self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel) + self._signal = signal_service.SignalServiceStub(self._channel) try: resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest()) self._public_client_config = resp @@ -406,6 +409,20 @@ def get_task(self, get_object_request): """ return self._stub.GetTask(get_object_request, metadata=self._metadata) + @_handle_rpc_error(retry=True) + def set_signal(self, signal_set_request: SignalSetRequest) -> SignalSetResponse: + """ + This sets a signal + """ + return self._signal.SetSignal(signal_set_request, metadata=self._metadata) + + @_handle_rpc_error(retry=True) + def list_signals(self, signal_list_request: SignalListRequest) -> SignalList: + """ + This lists signals + """ + return self._signal.ListSignals(signal_list_request, metadata=self._metadata) + #################################################################################################################### # # Workflow 
Endpoints diff --git a/flytekit/clis/sdk_in_container/backfill.py b/flytekit/clis/sdk_in_container/backfill.py new file mode 100644 index 0000000000..80a799b600 --- /dev/null +++ b/flytekit/clis/sdk_in_container/backfill.py @@ -0,0 +1,178 @@ +import typing +from datetime import datetime, timedelta + +import click + +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.clis.sdk_in_container.run import DateTimeType, DurationParamType + +_backfill_help = """ +The backfill command generates and registers a new workflow based on the input launchplan to run an +automated backfill. The workflow can be managed using the Flyte UI and can be canceled, relaunched, and recovered. + +- launchplan refers to the name of the launchplan +- launchplan_version is optional and should be a valid version for a launchplan version. +""" + + +def resolve_backfill_window( + from_date: datetime = None, + to_date: datetime = None, + backfill_window: timedelta = None, +) -> typing.Tuple[datetime, datetime]: + """ + Resolves the from_date -> to_date + """ + if from_date and to_date and backfill_window: + raise click.BadParameter("Setting from-date, to-date and backfill_window at the same time is not allowed.") + if not (from_date or to_date): + raise click.BadParameter( + "One of following pairs are required -> (from-date, to-date) | (from-date, backfill_window) |" + " (to-date, backfill_window)" + ) + if from_date and to_date: + pass + elif not backfill_window: + raise click.BadParameter("One of start-date and end-date are needed with duration") + elif from_date: + to_date = from_date + backfill_window + else: + from_date = to_date - backfill_window + return from_date, to_date + + +@click.command("backfill", help=_backfill_help) +@click.option( + "-p", + "--project", + required=False, + type=str, + default="flytesnacks", + help="Project to register and run this workflow in", +) +@click.option( + "-d", + "--domain", + required=False, + type=str, + default="development", + help="Domain to register and run this workflow in", +) +@click.option( + "-v", + "--version", + required=False, + type=str, + default=None, + help="Version for the registered workflow. If not specified it is auto-derived using the start and end date", +) +@click.option( + "-n", + "--execution-name", + required=False, + type=str, + default=None, + help="Create a named execution for the backfill. This can prevent launching multiple executions.", +) +@click.option( + "--dry-run", + required=False, + type=bool, + is_flag=True, + default=False, + show_default=True, + help="Just generate the workflow - do not register or execute", +) +@click.option( + "--parallel/--serial", + required=False, + type=bool, + is_flag=True, + default=False, + show_default=True, + help="All backfill steps can be run in parallel (limited by max-parallelism), if using --parallel." + " Else all steps will be run sequentially [--serial].", +) +@click.option( + "--execute/--do-not-execute", + required=False, + type=bool, + is_flag=True, + default=True, + show_default=True, + help="Generate the workflow and register, do not execute", +) +@click.option( + "--from-date", + required=False, + type=DateTimeType(), + default=None, + help="Date from which the backfill should begin. Start date is inclusive.", +) +@click.option( + "--to-date", + required=False, + type=DateTimeType(), + default=None, + help="Date to which the backfill should run_until. 
End date is inclusive", +) +@click.option( + "--backfill-window", + required=False, + type=DurationParamType(), + default=None, + help="Timedelta for number of days, minutes hours after the from-date or before the to-date to compute the " + "backfills between. This is needed with from-date / to-date. Optional if both from-date and to-date are " + "provided", +) +@click.argument( + "launchplan", + required=True, + type=str, +) +@click.argument( + "launchplan-version", + required=False, + type=str, + default=None, +) +@click.pass_context +def backfill( + ctx: click.Context, + project: str, + domain: str, + from_date: datetime, + to_date: datetime, + backfill_window: timedelta, + launchplan: str, + launchplan_version: str, + dry_run: bool, + execute: bool, + parallel: bool, + execution_name: str, + version: str, +): + from_date, to_date = resolve_backfill_window(from_date, to_date, backfill_window) + remote = get_and_save_remote_with_click_context(ctx, project, domain) + try: + entity = remote.launch_backfill( + project=project, + domain=domain, + from_date=from_date, + to_date=to_date, + launchplan=launchplan, + launchplan_version=launchplan_version, + execution_name=execution_name, + version=version, + dry_run=dry_run, + execute=execute, + parallel=parallel, + ) + if entity: + console_url = remote.generate_console_url(entity) + if execute: + click.secho(f"\n Execution launched {console_url} to see execution in the console.", fg="green") + return + click.secho(f"\n Workflow registered at {console_url}", fg="green") + except StopIteration as e: + click.secho(f"{e.value}", fg="red") diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index 72246bcba4..6ac451be92 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -4,7 +4,7 @@ import click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE -from flytekit.configuration import Config, ImageConfig +from flytekit.configuration import Config, ImageConfig, get_config_file from flytekit.loggers import cli_logger from flytekit.remote.remote import FlyteRemote @@ -25,10 +25,15 @@ def get_and_save_remote_with_click_context( :return: FlyteRemote instance """ cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) - cfg_obj = Config.auto(cfg_file_location) - cli_logger.info( - f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") - ) + cfg_file = get_config_file(cfg_file_location) + if cfg_file is None: + cfg_obj = Config.for_sandbox() + cli_logger.info("No config files found, creating remote with sandbox config") + else: + cfg_obj = Config.auto(cfg_file_location) + cli_logger.info( + f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") + ) r = FlyteRemote(cfg_obj, default_project=project, default_domain=domain) if save: ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 76777c5663..1f843450ed 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -1,6 +1,7 @@ import click from flytekit import configuration +from flytekit.clis.sdk_in_container.backfill import backfill from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES from flytekit.clis.sdk_in_container.init import init from flytekit.clis.sdk_in_container.local_cache import local_cache @@ -70,6 +71,7 @@ def main(ctx, 
pkgs=None, config=None): main.add_command(init) main.add_command(run) main.add_command(register) +main.add_command(backfill) if __name__ == "__main__": main() diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index e1bf4eb5c3..2a167e9d0e 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -17,7 +17,7 @@ and the flytectl register step in one command. This is why you see switches you'd normally use with flytectl like service account here. -Note: This command runs "fast" register by default. Future work to come to add a non-fast version. +Note: This command runs "fast" register by default. This means that a zip is created from the detected root of the packages given, and uploaded. Just like with pyflyte run, tasks registered from this command will download and unzip that code package before running. diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 615be3ba14..793c15c911 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -105,12 +105,32 @@ def convert( raise click.BadParameter(f"parameter should be a valid file path, {value}") +class DateTimeType(click.DateTime): + + _NOW_FMT = "now" + _ADDITONAL_FORMATS = [_NOW_FMT] + + def __init__(self): + super().__init__() + self.formats.extend(self._ADDITONAL_FORMATS) + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + if value in self._ADDITONAL_FORMATS: + if value == self._NOW_FMT: + return datetime.datetime.now() + return super().convert(value, param, ctx) + + class DurationParamType(click.ParamType): - name = "timedelta" + name = "[1:24 | :22 | 1 minute | 10 days | ...]" def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: + if value is None: + raise click.BadParameter("None value cannot be converted to a Duration type.") return datetime.timedelta(seconds=parse(value)) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index cb228b788a..220f9209ea 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -300,7 +300,7 @@ class PlatformConfig(object): :param endpoint: DNS for Flyte backend :param insecure: Whether or not to use SSL :param insecure_skip_verify: Wether to skip SSL certificate verification - :param console_endpoint: endpoint for console if differenet than Flyte backend + :param console_endpoint: endpoint for console if different than Flyte backend :param command: This command is executed to return a token using an external process. :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. @@ -311,7 +311,7 @@ class PlatformConfig(object): :param auth_mode: The OAuth mode to use. Defaults to pkce flow. 
""" - endpoint: str = "localhost:30081" + endpoint: str = "localhost:30080" insecure: bool = False insecure_skip_verify: bool = False console_endpoint: typing.Optional[str] = None @@ -463,7 +463,7 @@ class GCSConfig(object): gsutil_parallelism: bool = False @classmethod - def auto(self, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: config_file = get_config_file(config_file) kwargs = {} kwargs = set_if_exists(kwargs, "gsutil_parallelism", _internal.GCP.GSUTIL_PARALLELISM.read(config_file)) @@ -529,7 +529,7 @@ def with_params( ) @classmethod - def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> Config: + def auto(cls, config_file: typing.Union[str, ConfigFile, None] = None) -> Config: """ Automatically constructs the Config Object. The order of precedence is as follows 1. first try to find any env vars that match the config vars specified in the FLYTE_CONFIG format. @@ -558,9 +558,9 @@ def for_sandbox(cls) -> Config: :return: Config """ return Config( - platform=PlatformConfig(endpoint="localhost:30081", auth_mode="Pkce", insecure=True), + platform=PlatformConfig(endpoint="localhost:30080", auth_mode="Pkce", insecure=True), data_config=DataConfig( - s3=S3Config(endpoint="http://localhost:30084", access_key_id="minio", secret_access_key="miniostorage") + s3=S3Config(endpoint="http://localhost:30002", access_key_id="minio", secret_access_key="miniostorage") ), ) @@ -647,6 +647,7 @@ class SerializationSettings(object): domain: typing.Optional[str] = None version: typing.Optional[str] = None env: Optional[Dict[str, str]] = None + git_repo: Optional[str] = None python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER flytekit_virtualenv_root: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None @@ -719,6 +720,7 @@ def new_builder(self) -> Builder: version=self.version, image_config=self.image_config, env=self.env.copy() if self.env else None, + git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, @@ -768,6 +770,7 @@ class Builder(object): version: str image_config: ImageConfig env: Optional[Dict[str, str]] = None + git_repo: Optional[str] = None flytekit_virtualenv_root: Optional[str] = None python_interpreter: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None @@ -783,6 +786,7 @@ def build(self) -> SerializationSettings: version=self.version, image_config=self.image_config, env=self.env, + git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, diff --git a/flytekit/configuration/default_images.py b/flytekit/configuration/default_images.py index 33520e544f..8c01041eed 100644 --- a/flytekit/configuration/default_images.py +++ b/flytekit/configuration/default_images.py @@ -16,10 +16,10 @@ class DefaultImages(object): """ _DEFAULT_IMAGE_PREFIXES = { - PythonVersion.PYTHON_3_7: "ghcr.io/flyteorg/flytekit:py3.7-", - PythonVersion.PYTHON_3_8: "ghcr.io/flyteorg/flytekit:py3.8-", - PythonVersion.PYTHON_3_9: "ghcr.io/flyteorg/flytekit:py3.9-", - PythonVersion.PYTHON_3_10: "ghcr.io/flyteorg/flytekit:py3.10-", + PythonVersion.PYTHON_3_7: "cr.flyte.org/flyteorg/flytekit:py3.7-", + PythonVersion.PYTHON_3_8: "cr.flyte.org/flyteorg/flytekit:py3.8-", + 
PythonVersion.PYTHON_3_9: "cr.flyte.org/flyteorg/flytekit:py3.9-", + PythonVersion.PYTHON_3_10: "cr.flyte.org/flyteorg/flytekit:py3.10-", } @classmethod diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py index 467f660d42..23210e95f1 100644 --- a/flytekit/configuration/file.py +++ b/flytekit/configuration/file.py @@ -18,6 +18,11 @@ FLYTECTL_CONFIG_ENV_VAR = "FLYTECTL_CONFIG" +def _exists(val: typing.Any) -> bool: + """Check if a value is defined.""" + return isinstance(val, bool) or bool(val is not None and val) + + @dataclass class LegacyConfigEntry(object): """ @@ -63,7 +68,7 @@ def read_from_file( @dataclass class YamlConfigEntry(object): """ - Creates a record for the config entry. contains + Creates a record for the config entry. Args: switch: dot-delimited string that should match flytectl args. Leaving it as dot-delimited instead of a list of strings because it's easier to maintain alignment with flytectl. @@ -80,10 +85,11 @@ def read_from_file( return None try: v = cfg.get(self) - if v: + if _exists(v): return transform(v) if transform else v except Exception: ... + return None @@ -224,7 +230,7 @@ def legacy_config(self) -> _configparser.ConfigParser: return self._legacy_config @property - def yaml_config(self) -> typing.Dict[str, Any]: + def yaml_config(self) -> typing.Dict[str, typing.Any]: return self._yaml_config @@ -273,7 +279,7 @@ def set_if_exists(d: dict, k: str, v: typing.Any) -> dict: The input dictionary ``d`` will be mutated. """ - if v: + if _exists(v): d[k] = v return d diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index d2e4838ed8..30b73223a9 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, Optional, Tuple, Type, TypeVar from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface @@ -22,11 +22,11 @@ def __init__( self, name: str, query_template: str, + task_config: Optional[T] = None, task_type="sql_task", - inputs: Optional[Dict[str, Type]] = None, + inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - task_config: Optional[T] = None, - outputs: Dict[str, Type] = None, + outputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ @@ -41,7 +41,7 @@ def __init__( task_config=task_config, **kwargs, ) - self._query_template = query_template.replace("\n", "\\n").replace("\t", "\\t") + self._query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip() @property def query_template(self) -> str: diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index dccbaec803..f163e891e1 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -21,10 +21,16 @@ import datetime from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast from flytekit.configuration import SerializationSettings -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities +from flytekit.core.context_manager import ( + ExecutionParameters, + ExecutionState, + FlyteContext, + FlyteContextManager, + FlyteEntities, +) from flytekit.core.interface import Interface, transform_interface_to_typed_interface 
from flytekit.core.local_cache import LocalTaskCache from flytekit.core.promise import ( @@ -45,6 +51,7 @@ from flytekit.models import literals as _literal_models from flytekit.models import task as _task_model from flytekit.models.core import workflow as _workflow_model +from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext @@ -84,6 +91,7 @@ class TaskMetadata(object): timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task should be executed for. The execution will be terminated if the runtime exceeds the given timeout (approximately) + pod_template_name (Optional[str]): the name of existing PodTemplate resource in the cluster which will be used in this task. """ cache: bool = False @@ -93,6 +101,7 @@ class TaskMetadata(object): deprecated: str = "" retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None + pod_template_name: Optional[str] = None def __post_init__(self): if self.timeout: @@ -126,6 +135,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: discovery_version=self.cache_version, deprecated_error_message=self.deprecated, cache_serializable=self.cache_serialize, + pod_template_name=self.pod_template_name, ) @@ -152,10 +162,11 @@ def __init__( self, task_type: str, name: str, - interface: Optional[_interface_models.TypedInterface] = None, + interface: _interface_models.TypedInterface, metadata: Optional[TaskMetadata] = None, task_type_version=0, security_ctx: Optional[SecurityContext] = None, + docs: Optional[Documentation] = None, **kwargs, ): self._task_type = task_type @@ -164,11 +175,12 @@ def __init__( self._metadata = metadata if metadata else TaskMetadata() self._task_type_version = task_type_version self._security_ctx = security_ctx + self._docs = docs FlyteEntities.entities.append(self) @property - def interface(self) -> Optional[_interface_models.TypedInterface]: + def interface(self) -> _interface_models.TypedInterface: return self._interface @property @@ -195,6 +207,10 @@ def task_type_version(self) -> int: def security_context(self) -> SecurityContext: return self._security_ctx + @property + def docs(self) -> Documentation: + return self._docs + def get_type_for_input_var(self, k: str, v: Any) -> type: """ Returns the python native type for the given input variable @@ -232,8 +248,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr kwargs = translate_inputs_to_literals( ctx, incoming_values=kwargs, - flyte_interface_types=self.interface.inputs, # type: ignore - native_types=self.get_input_types(), + flyte_interface_types=self.interface.inputs, + native_types=self.get_input_types(), # type: ignore ) input_literal_map = _literal_models.LiteralMap(literals=kwargs) @@ -248,7 +264,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # The cache returns None iff the key does not exist in the cache if outputs_literal_map is None: logger.info("Cache miss, task will be executed now") - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) # TODO: need `native_inputs` LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map) logger.info( @@ -258,10 +274,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr else: logger.info("Cache hit") else: - es = 
ctx.execution_state - b = es.user_space_params.with_task_sandbox() - ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + # This code should mirror the call to `sandbox_execute` in the above cache case. + # Code is simpler with duplication and less metaprogramming, but introduces regressions + # if one is changed and not the other. + outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) outputs_literals = outputs_literal_map.literals # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special @@ -279,8 +295,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr vals = [Promise(var, outputs_literals[var]) for var in output_names] return create_task_output(vals, self.python_interface) - def __call__(self, *args, **kwargs): - return flyte_entity_call_handler(self, *args, **kwargs) + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore def compile(self, ctx: FlyteContext, *args, **kwargs): raise Exception("not implemented") @@ -316,6 +332,19 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] """ return None + def sandbox_execute( + self, + ctx: FlyteContext, + input_literal_map: _literal_models.LiteralMap, + ) -> _literal_models.LiteralMap: + """ + Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime. + """ + es = cast(ExecutionState, ctx.execution_state) + b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox() + ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() + return self.dispatch_execute(ctx, input_literal_map) + @abstractmethod def dispatch_execute( self, @@ -361,7 +390,7 @@ def __init__( self, task_type: str, name: str, - task_config: T, + task_config: Optional[T], interface: Optional[Interface] = None, environment: Optional[Dict[str, str]] = None, disable_deck: bool = True, @@ -390,6 +419,21 @@ def __init__( self._environment = environment if environment else {} self._task_config = task_config self._disable_deck = disable_deck + if self._python_interface.docstring: + if self.docs is None: + self._docs = Documentation( + short_description=self._python_interface.docstring.short_description, + long_description=Description(value=self._python_interface.docstring.long_description), + ) + else: + if self._python_interface.docstring.short_description: + cast( + Documentation, self._docs + ).short_description = self._python_interface.docstring.short_description + if self._python_interface.docstring.long_description: + cast(Documentation, self._docs).long_description = Description( + value=self._python_interface.docstring.long_description + ) # TODO lets call this interface and the other as flyte_interface? @property @@ -400,25 +444,25 @@ def python_interface(self) -> Interface: return self._python_interface @property - def task_config(self) -> T: + def task_config(self) -> Optional[T]: """ Returns the user-specified task config which is used for plugin-specific handling of the task. """ return self._task_config - def get_type_for_input_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_input_var(self, k: str, v: Any) -> Type[Any]: """ Returns the python type for an input variable by name. 
""" return self._python_interface.inputs[k] - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]: """ Returns the python type for the specified output variable by name. """ return self._python_interface.outputs[k] - def get_input_types(self) -> Optional[Dict[str, type]]: + def get_input_types(self) -> Dict[str, type]: """ Returns the names and python types as a dictionary for the inputs of this task. """ @@ -464,7 +508,9 @@ def dispatch_execute( # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) + ctx.with_execution_state( + cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params) + ) # type: ignore ) as exec_ctx: # TODO We could support default values here too - but not part of the plan right now @@ -545,7 +591,7 @@ def dispatch_execute( # After the execute has been successfully completed return outputs_literal_map - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore """ This is the method that will be invoked directly before executing the task method and before all the inputs are converted. One particular case where this is useful is if the context is to be modified for the user process @@ -563,7 +609,7 @@ def execute(self, **kwargs) -> Any: """ pass - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any: """ Post execute is called after the execution has completed, with the user_params and can be used to clean-up, or alter the outputs to match the intended tasks outputs. 
If not overridden, then this function is a No-op diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index c1eb933ec6..4b4cfd16f3 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -126,7 +126,7 @@ def save(self, cp: typing.Union[Path, str, io.BufferedReader]): fa.upload_directory(str(cp), self._checkpoint_dest) else: fname = cp.stem + cp.suffix - rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, fname) + rpath = fa._default_remote.sep.join([str(self._checkpoint_dest), fname]) fa.upload(str(cp), rpath) return @@ -138,7 +138,7 @@ def save(self, cp: typing.Union[Path, str, io.BufferedReader]): with dest_cp.open("wb") as f: f.write(cp.read()) - rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, self.TMP_DST_PATH) + rpath = fa._default_remote.sep.join([str(self._checkpoint_dest), self.TMP_DST_PATH]) fa.upload(str(dest_cp), rpath) def read(self) -> typing.Optional[bytes]: diff --git a/flytekit/core/class_based_resolver.py b/flytekit/core/class_based_resolver.py index d47820f811..49970d5623 100644 --- a/flytekit/core/class_based_resolver.py +++ b/flytekit/core/class_based_resolver.py @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs): def name(self) -> str: return "ClassStorageTaskResolver" - def get_all_tasks(self) -> List[PythonAutoContainerTask]: + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type:ignore return self.mapping def add(self, t: PythonAutoContainerTask): @@ -33,7 +33,7 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: idx = int(loader_args[0]) return self.mapping[idx] - def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: # type: ignore """ This is responsible for turning an instance of a task into args that the load_task function can reconstitute. 
""" diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index b5cae86923..76553db702 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -111,7 +111,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP return self._compute_outputs(n) return self._condition - def if_(self, expr: bool) -> Case: + def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case: return self._condition._if(expr) def compute_output_vars(self) -> typing.Optional[typing.List[str]]: @@ -360,7 +360,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str: return f"{node_id}.{var}" -def merge_promises(*args: Promise) -> typing.List[Promise]: +def merge_promises(*args: Optional[Promise]) -> typing.List[Promise]: node_vars: typing.Set[typing.Tuple[str, str]] = set() merged_promises: typing.List[Promise] = [] for p in args: @@ -414,7 +414,7 @@ def transform_to_boolexpr( def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]: - expr, promises = transform_to_boolexpr(c.expr) + expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr)) n = c.output_promise.ref.node # type: ignore return _core_wf.IfBlock(condition=expr, then_node=n), promises diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 848c1d2524..677142736c 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Tuple, Type from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata @@ -10,6 +10,7 @@ from flytekit.models.security import Secret, SecurityContext +# TODO: do we need pod_template here? Seems that it is a raw container not running in pods class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast @@ -35,16 +36,16 @@ def __init__( name: str, image: str, command: List[str], - inputs: Optional[Dict[str, Type]] = None, + inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - arguments: List[str] = None, - outputs: Dict[str, Type] = None, + arguments: Optional[List[str]] = None, + outputs: Optional[Dict[str, Type]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, - input_data_dir: str = None, - output_data_dir: str = None, + input_data_dir: Optional[str] = None, + output_data_dir: Optional[str] = None, metadata_format: MetadataFormat = MetadataFormat.JSON, - io_strategy: IOStrategy = None, + io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, **kwargs, ): diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 3a23bb3eea..f3ed8a6026 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -48,7 +48,7 @@ flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[]) if typing.TYPE_CHECKING: - from flytekit.core.base_task import TaskResolverMixin + from flytekit.core.base_task import Task, TaskResolverMixin # Identifier fields use placeholders for registration-time substitution. 
@@ -108,7 +108,7 @@ def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: def build(self) -> ExecutionParameters: if not isinstance(self.working_dir, utils.AutoDeletingTempDir): - pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(typing.cast(str, self.working_dir)).mkdir(parents=True, exist_ok=True) return ExecutionParameters( execution_date=self.execution_date, stats=self.stats, @@ -123,12 +123,14 @@ def build(self) -> ExecutionParameters: ) @staticmethod - def new_builder(current: ExecutionParameters = None) -> Builder: + def new_builder(current: Optional[ExecutionParameters] = None) -> Builder: return ExecutionParameters.Builder(current=current) def with_task_sandbox(self) -> Builder: prefix = self.working_directory - task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) + if isinstance(self.working_directory, utils.AutoDeletingTempDir): + prefix = self.working_directory.name + task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) # type: ignore p = pathlib.Path(task_sandbox_dir) cp_dir = p.joinpath("__cp") cp_dir.mkdir(exist_ok=True) @@ -285,7 +287,7 @@ def get(self, key: str) -> typing.Any: """ Returns task specific context if present else raise an error. The returned context will match the key """ - return self.__getattr__(attr_name=key) + return self.__getattr__(attr_name=key) # type: ignore class SecretsManager(object): @@ -465,14 +467,14 @@ class Mode(Enum): LOCAL_TASK_EXECUTION = 3 mode: Optional[ExecutionState.Mode] - working_dir: os.PathLike + working_dir: Union[os.PathLike, str] engine_dir: Optional[Union[os.PathLike, str]] branch_eval_mode: Optional[BranchEvalMode] user_space_params: Optional[ExecutionParameters] def __init__( self, - working_dir: os.PathLike, + working_dir: Union[os.PathLike, str], mode: Optional[ExecutionState.Mode] = None, engine_dir: Optional[Union[os.PathLike, str]] = None, branch_eval_mode: Optional[BranchEvalMode] = None, @@ -605,7 +607,7 @@ def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> Exec return ExecutionState(working_dir=working_dir, user_space_params=self.user_space_params) @staticmethod - def current_context() -> Optional[FlyteContext]: + def current_context() -> FlyteContext: """ This method exists only to maintain backwards compatibility. Please use ``FlyteContextManager.current_context()`` instead. 
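A small illustrative snippet of what the `current_context()` annotation change above means for callers: the context itself is no longer Optional, though nested fields such as `execution_state` still are. This is a sketch, not code from the diff.

```python
from flytekit.core.context_manager import FlyteContextManager

ctx = FlyteContextManager.current_context()  # typed as FlyteContext, never None
if ctx.execution_state is not None:
    # working_dir is now annotated as Union[os.PathLike, str]
    print(ctx.execution_state.working_dir)
```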
@@ -637,7 +639,7 @@ def get_deck(self) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ig """ from flytekit.deck.deck import _get_deck - return _get_deck(self.execution_state.user_space_params) + return _get_deck(typing.cast(ExecutionState, self.execution_state).user_space_params) @dataclass class Builder(object): @@ -850,7 +852,7 @@ class FlyteEntities(object): registration process """ - entities = [] + entities: List[Union["LaunchPlan", Task, "WorkflowBase"]] = [] # type: ignore FlyteContextManager.initialize() diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index e05e8c009d..1d7c359c8a 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -21,26 +21,22 @@ UnsupportedPersistenceOp """ - import os import pathlib -import re -import shutil -import sys import tempfile import typing -from abc import abstractmethod -from shutil import copyfile -from typing import Dict, Union +from typing import Union from uuid import UUID +import shutil import fsspec +from fsspec.core import strip_protocol from fsspec.utils import get_protocol from flytekit import configuration from flytekit.configuration import DataConfig from flytekit.core.utils import PerformanceTimer -from flytekit.exceptions.user import FlyteAssertion, FlyteValueException +from flytekit.exceptions.user import FlyteAssertion from flytekit.interfaces.random import random from flytekit.loggers import logger @@ -51,9 +47,10 @@ # for key and secret _FSSPEC_S3_KEY_ID = "key" _FSSPEC_S3_SECRET = "secret" +_ANON = "anon" -def s3_setup_args(s3_cfg: configuration.S3Config): +def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): kwargs = {} if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ: if s3_cfg.access_key_id: @@ -67,6 +64,9 @@ def s3_setup_args(s3_cfg: configuration.S3Config): if s3_cfg.endpoint is not None: kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} + if anonymous: + kwargs[_ANON] = True + return kwargs @@ -94,10 +94,14 @@ def __init__( self._local_sandbox_dir.mkdir(parents=True, exist_ok=True) self._local = fsspec.filesystem(None) - self._raw_output_prefix = raw_output_prefix - self._default_protocol = self.get_protocol(self._raw_output_prefix) - self._default_remote = self.get_filesystem(self._default_protocol) self._data_config = data_config if data_config else DataConfig.auto() + self._default_protocol = get_protocol(raw_output_prefix) + self._default_remote = self.get_filesystem(self._default_protocol) + self._raw_output_prefix = ( + raw_output_prefix + if raw_output_prefix.endswith(self._default_remote.sep) + else raw_output_prefix + self._default_remote.sep + ) @property def raw_output_prefix(self) -> str: @@ -107,39 +111,36 @@ def raw_output_prefix(self) -> str: def data_config(self) -> DataConfig: return self._data_config - @staticmethod - def get_protocol(path: typing.Optional[str] = None): - if path: - return get_protocol(path) - logger.info("Setting protocol to file") - return "file" - - def get_filesystem(self, protocol: str = None) -> fsspec.AbstractFileSystem: + def get_filesystem( + self, protocol: str = None, anonymous: bool = False + ) -> typing.Optional[fsspec.AbstractFileSystem]: if not protocol: return self._default_remote kwargs = {} if protocol == "file": kwargs = {"auto_mkdir": True} elif protocol == "s3": - kwargs = s3_setup_args(self._data_cfg.s3) + kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) + return fsspec.filesystem(protocol, **kwargs) # type: ignore + elif protocol == "gs": + 
if anonymous: + kwargs["token"] = _ANON + return fsspec.filesystem(protocol, **kwargs) # type: ignore + + # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option. + if anonymous: + return None + return fsspec.filesystem(protocol, **kwargs) # type: ignore - def get_filesystem_for_path(self, path: str) -> fsspec.AbstractFileSystem: - protocol = self.get_protocol(path) + def get_filesystem_for_path(self, path: str = "") -> fsspec.AbstractFileSystem: + protocol = get_protocol(path) return self.get_filesystem(protocol) - def get_anonymous_filesystem(self, path: str) -> typing.Optional[fsspec.AbstractFileSystem]: - protocol = self.get_protocol(path) - if protocol == "s3": - kwargs = s3_setup_args(self._data_cfg.s3) - anonymous_fs = fsspec.filesystem(protocol, anon=True, **kwargs) # type: ignore - return anonymous_fs - return None - @staticmethod def is_remote(path: Union[str, os.PathLike]) -> bool: """ - Deprecated. Lets find a replacement + Deprecated. Let's find a replacement """ protocol = get_protocol(path) if protocol is None: @@ -157,67 +158,63 @@ def local_sandbox_dir(self) -> os.PathLike: def local_access(self) -> fsspec.AbstractFileSystem: return self._local + @staticmethod + def strip_file_header(path: str) -> str: + """ + Drops file:// if it exists from the file + """ + if path.startswith("file://"): + return path.replace("file://", "", 1) + return path + @staticmethod def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: - if not f.endswith("*"): - f = os.path.join(f, "*") - if not t.endswith("/"): - t += "/" + f = os.path.join(f, "") + t = os.path.join(t, "") return f, t def exists(self, path: str) -> bool: try: - fs = self.get_filesystem(path) - return fs.exists(path) + file_system = self.get_filesystem_for_path(path) + return file_system.exists(path) except OSError as oe: logger.debug(f"Error in exists checking {path} {oe}") - fs = self.get_anonymous_filesystem(path) - if fs is not None: - logger.debug("S3 source detected, attempting anonymous S3 exists check") - return fs.exists(path) + anon_fs = self.get_filesystem(get_protocol(path), anonymous=True) + if anon_fs is not None: + logger.debug(f"Attempting anonymous exists with {anon_fs}") + return anon_fs.exists(path) raise oe def get(self, from_path: str, to_path: str, recursive: bool = False): - fs = self.get_filesystem(from_path) + file_system = self.get_filesystem_for_path(from_path) if recursive: from_path, to_path = self.recursive_paths(from_path, to_path) try: - return fs.get(from_path, to_path, recursive=recursive) + # Special case system level behavior because of inconsistencies in local implementation copy + # Don't want to use ls to check empty-ness because it can be extremely expensive if not empty + # TODO: Fix after https://github.com/fsspec/filesystem_spec/issues/1198 + if file_system.protocol == "file" and recursive: + return shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True) + return file_system.get(from_path, to_path, recursive=recursive) except OSError as oe: logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") - fs = self.get_anonymous_filesystem(from_path) - if fs is not None: - logger.debug("S3 source detected, attempting anonymous S3 access") - return fs.get(from_path, to_path, recursive=recursive) + file_system = self.get_filesystem(get_protocol(from_path), anonymous=True) + if file_system is not None: + logger.debug(f"Attempting anonymous get with {file_system}") 
+ return file_system.get(from_path, to_path, recursive=recursive) raise oe def put(self, from_path: str, to_path: str, recursive: bool = False): - fs = self.get_filesystem(to_path) + file_system = self.get_filesystem_for_path(to_path) + if from_path.startswith("file"): + # The localFs system doesn't know how to handle source files with file:// so remove it + from_path = from_path.replace("file://", "") if recursive: + # Only check this for the local filesystem + if file_system.protocol == "file" and not file_system.isdir(from_path): + raise FlyteAssertion(f"Source path {from_path} is not a directory") from_path, to_path = self.recursive_paths(from_path, to_path) - return fs.put(from_path, to_path, recursive=recursive) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - path_list = list(paths) # make type check happy - if add_prefix: - path_list.insert(0, self.default_prefix) # type: ignore - path = "/".join(path_list) - if add_protocol: - return f"{self._default_protocol}://{path}" - return typing.cast(str, path) - - def construct_random_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: - """ - Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name - """ - key = UUID(int=random.getrandbits(128)).hex - if file_path_or_file_name: - _, tail = os.path.split(file_path_or_file_name) - if tail: - return self.construct_path(False, True, key, tail) - else: - logger.warning(f"No filename detected in {file_path_or_file_name}, generating random path") - return self.construct_path(False, True, key) + return file_system.put(from_path, to_path, recursive=recursive) def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: """ @@ -226,7 +223,20 @@ def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ - return self.construct_random_path(self._default_remote, file_path_or_file_name) + default_protocol = self._default_remote.protocol + if type(default_protocol) == list: + default_protocol = default_protocol[0] + key = UUID(int=random.getrandbits(128)).hex + tail = "" + if file_path_or_file_name: + _, tail = os.path.split(file_path_or_file_name) + sep = self._default_remote.sep + tail = sep + tail if tail else tail + if default_protocol == "file": + # Special case the local case, users will not expect to see a file:// prefix + return strip_protocol(self.raw_output_prefix) + sep + key + tail + + return self._default_remote.unstrip_protocol(self.raw_output_prefix + key + tail) def get_random_remote_directory(self): return self.get_random_remote_path(None) @@ -235,19 +245,19 @@ def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = N """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ - return self.construct_random_path(self._local, file_path_or_file_name) + key = UUID(int=random.getrandbits(128)).hex + tail = "" + if file_path_or_file_name: + _, tail = os.path.split(file_path_or_file_name) + if tail: + return os.path.join(self._local_sandbox_dir, key, tail) + return os.path.join(self._local_sandbox_dir, key) def get_random_local_directory(self) -> str: _dir = self.get_random_local_path(None) pathlib.Path(_dir).mkdir(parents=True, exist_ok=True) return _dir - def exists(self, path: str) -> bool: - """ - checks if the given path exists - """ - return self.exists(path) 
- def download_directory(self, remote_path: str, local_path: str): """ Downloads directory from given remote to local path @@ -274,11 +284,11 @@ def upload_directory(self, local_path: str, remote_path: str): """ return self.put_data(local_path, remote_path, is_multipart=True) - def get_data(self, remote_path: str, local_path: str, is_multipart=False): + def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False): """ - :param Text remote_path: - :param Text local_path: - :param bool is_multipart: + :param remote_path: + :param local_path: + :param is_multipart: """ try: with PerformanceTimer(f"Copying ({remote_path} -> {local_path})"): @@ -290,14 +300,14 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False): f"Original exception: {str(ex)}" ) - def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart=False): + def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False): """ The implication here is that we're always going to put data to the remote location, so we .remote to ensure we don't use the true local proxy if the remote path is a file:// - :param Text local_path: - :param Text remote_path: - :param bool is_multipart: + :param local_path: + :param remote_path: + :param is_multipart: """ try: with PerformanceTimer(f"Writing ({local_path} -> {remote_path})"): @@ -309,10 +319,6 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul ) from ex -fsspec.register_implementation("/", ) -DataPersistencePlugins.register_plugin("file://", DiskPersistence) -DataPersistencePlugins.register_plugin("/", DiskPersistence) - flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-") default_local_file_access_provider = FileAccessProvider( local_sandbox_dir=os.path.join(flyte_tmp_dir, "sandbox"), diff --git a/flytekit/core/docstring.py b/flytekit/core/docstring.py index 420f26f8f5..fa9d9caec2 100644 --- a/flytekit/core/docstring.py +++ b/flytekit/core/docstring.py @@ -4,7 +4,7 @@ class Docstring(object): - def __init__(self, docstring: str = None, callable_: Callable = None): + def __init__(self, docstring: Optional[str] = None, callable_: Optional[Callable] = None): if docstring is not None: self._parsed_docstring = parse(docstring) else: diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py index f3d90ebef8..bc3ab1d3fd 100644 --- a/flytekit/core/gate.py +++ b/flytekit/core/gate.py @@ -53,7 +53,7 @@ def __init__( ) else: # We don't know how to find the python interface here, approve() sets it below, See the code. - self._python_interface = None + self._python_interface = None # type: ignore @property def name(self) -> str: @@ -105,7 +105,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr return p # Assume this is an approval operation since that's the only remaining option. - msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?" + msg = f"Pausing execution for {self.name}, literal value is:\n{typing.cast(Promise, self._upstream_item).val}\nContinue?" proceed = click.confirm(msg, default=True) if proceed: # We need to return a promise here, and a promise is what should've been passed in by the call in approve() @@ -118,7 +118,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing.Type): - """ + """Create a Gate object that waits for user input of the specified type. 
+ Create a Gate object. This object will function like a task. Note that unlike a task, each time this function is called, a new Python object is created. If a workflow calls a subworkflow twice, and the subworkflow has a signal, then two Gate @@ -136,7 +137,8 @@ def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing def sleep(duration: datetime.timedelta): - """ + """Create a sleep Gate object. + :param duration: How long to sleep for :return: """ @@ -146,7 +148,8 @@ def sleep(duration: datetime.timedelta): def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: str, timeout: datetime.timedelta): - """ + """Create a Gate object for binary approval. + Create a Gate object. This object will function like a task. Note that unlike a task, each time this function is called, a new Python object is created. If a workflow calls a subworkflow twice, and the subworkflow has a signal, then two Gate @@ -164,6 +167,7 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st raise ValueError("You can't use approval on a task that doesn't return anything.") ctx = FlyteContextManager.current_context() + upstream_item = typing.cast(Promise, upstream_item) if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: if not upstream_item.ref.node.flyte_entity.python_interface: raise ValueError( diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 63d7c8106f..3c24e65db2 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -5,7 +5,7 @@ import inspect import typing from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast from typing_extensions import Annotated, get_args, get_origin, get_type_hints @@ -28,8 +28,8 @@ class Interface(object): def __init__( self, - inputs: typing.Optional[typing.Dict[str, Union[Type, Tuple[Type, Any]], None]] = None, - outputs: typing.Optional[typing.Dict[str, Type]] = None, + inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None, + outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None, output_tuple_name: Optional[str] = None, docstring: Optional[Docstring] = None, ): @@ -43,21 +43,21 @@ def __init__( primarily used when handling one-element NamedTuples. :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from. """ - self._inputs = {} + self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {} # type: ignore if inputs: for k, v in inputs.items(): - if isinstance(v, Tuple) and len(v) > 1: - self._inputs[k] = v + if type(v) is tuple and len(cast(Tuple, v)) > 1: + self._inputs[k] = v # type: ignore else: - self._inputs[k] = (v, None) - self._outputs = outputs if outputs else {} + self._inputs[k] = (v, None) # type: ignore + self._outputs = outputs if outputs else {} # type: ignore self._output_tuple_name = output_tuple_name if outputs: variables = [k for k in outputs.keys()] # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one. - class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): + class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): # type: ignore """ This class can be used in two different places. 
For multivariate-return entities this class is used to rewrap the outputs so that our with_overrides function can work. @@ -90,7 +90,7 @@ def __rshift__(self, *args, **kwargs): self._docstring = docstring @property - def output_tuple(self) -> Optional[Type[collections.namedtuple]]: + def output_tuple(self) -> Type[collections.namedtuple]: # type: ignore return self._output_tuple_class @property @@ -98,7 +98,7 @@ def output_tuple_name(self) -> Optional[str]: return self._output_tuple_name @property - def inputs(self) -> typing.Dict[str, Type]: + def inputs(self) -> Dict[str, type]: r = {} for k, v in self._inputs.items(): r[k] = v[0] @@ -111,8 +111,8 @@ def output_names(self) -> Optional[List[str]]: return None @property - def inputs_with_defaults(self) -> typing.Dict[str, Tuple[Type, Any]]: - return self._inputs + def inputs_with_defaults(self) -> Dict[str, Tuple[Type, Any]]: + return cast(Dict[str, Tuple[Type, Any]], self._inputs) @property def default_inputs_as_kwargs(self) -> Dict[str, Any]: @@ -120,13 +120,13 @@ def default_inputs_as_kwargs(self) -> Dict[str, Any]: @property def outputs(self) -> typing.Dict[str, type]: - return self._outputs + return self._outputs # type: ignore @property def docstring(self) -> Optional[Docstring]: return self._docstring - def remove_inputs(self, vars: List[str]) -> Interface: + def remove_inputs(self, vars: Optional[List[str]]) -> Interface: """ This method is useful in removing some variables from the Flyte backend inputs specification, as these are implicit local only inputs or will be supplied by the library at runtime. For example, spark-session etc @@ -151,7 +151,7 @@ def with_inputs(self, extra_inputs: Dict[str, Type]) -> Interface: for k, v in extra_inputs.items(): if k in new_inputs: raise ValueError(f"Input {k} cannot be added as it already exists in the interface") - new_inputs[k] = v + cast(Dict[str, Type], new_inputs)[k] = v return Interface(new_inputs, self._outputs, docstring=self.docstring) def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface: @@ -207,7 +207,6 @@ def transform_interface_to_typed_interface( """ if interface is None: return None - if interface.docstring is None: input_descriptions = output_descriptions = {} else: @@ -241,7 +240,7 @@ def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: om = {} for k, v in m.items(): - om[k] = typing.List[v] + om[k] = typing.List[v] # type: ignore return om # type: ignore @@ -256,18 +255,20 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface: return Interface(inputs=map_inputs, outputs=map_outputs) -def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T], Annotated]: +def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T]]: try: if hasattr(t, "__origin__") and hasattr(t, "__args__"): - if get_origin(t) is list: - return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])] - elif get_origin(t) is dict and t.__args__[0] == str: - return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])] - elif get_origin(t) is typing.Union: - return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] - elif get_origin(t) is Annotated: + ot = get_origin(t) + args = getattr(t, "__args__") + if ot is list: + return typing.List[_change_unrecognized_type_to_pickle(args[0])] # type: ignore + elif ot is dict and args[0] == str: + return typing.Dict[str, _change_unrecognized_type_to_pickle(args[1])] # type: ignore + elif ot 
is typing.Union: + return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] # type: ignore + elif ot is Annotated: base_type, *config = get_args(t) - return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] + return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] # type: ignore TypeEngine.get_transformer(t) except ValueError: logger.warning( @@ -295,12 +296,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc outputs = extract_return_annotation(return_annotation) for k, v in outputs.items(): outputs[k] = _change_unrecognized_type_to_pickle(v) # type: ignore - inputs = OrderedDict() + inputs: Dict[str, Tuple[Type, Any]] = OrderedDict() for k, v in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) default = v.default if v.default is not inspect.Parameter.empty else None # Inputs with default values are currently ignored, we may want to look into that in the future - inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) + inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) # type: ignore # This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We # would like to preserve that name in our custom collections.namedtuple. @@ -326,23 +327,24 @@ def transform_variable_map( if variable_map: for k, v in variable_map.items(): res[k] = transform_type(v, descriptions.get(k, k)) - sub_type: Type[T] = v + sub_type: type = v if hasattr(v, "__origin__") and hasattr(v, "__args__"): - if v.__origin__ is list: - sub_type = v.__args__[0] - elif v.__origin__ is dict: - sub_type = v.__args__[1] - if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle: - if hasattr(sub_type.python_type(), "__name__"): - res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__} - elif hasattr(sub_type.python_type(), "_name"): + if getattr(v, "__origin__") is list: + sub_type = getattr(v, "__args__")[0] + elif getattr(v, "__origin__") is dict: + sub_type = getattr(v, "__args__")[1] + if hasattr(sub_type, "__origin__") and getattr(sub_type, "__origin__") is FlytePickle: + original_type = cast(FlytePickle, sub_type).python_type() + if hasattr(original_type, "__name__"): + res[k].type.metadata = {"python_class_name": original_type.__name__} + elif hasattr(original_type, "_name"): # If the class doesn't have the __name__ attribute, like typing.Sequence, use _name instead. - res[k].type.metadata = {"python_class_name": sub_type.python_type()._name} + res[k].type.metadata = {"python_class_name": original_type._name} return res -def transform_type(x: type, description: str = None) -> _interface_models.Variable: +def transform_type(x: type, description: Optional[str] = None) -> _interface_models.Variable: return _interface_models.Variable(type=TypeEngine.to_literal_type(x), description=description) @@ -394,13 +396,13 @@ def t(a: int, b: str) -> Dict[str, int]: ... # This statement results in true for typing.Namedtuple, single and void return types, so this # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python - if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): + if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): # type: ignore # isinstance / issubclass does not work for Namedtuple. 
# Options 1 and 2 bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(get_type_hints(return_annotation, include_extras=True)) + return dict(get_type_hints(cast(Type, return_annotation), include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 @@ -420,7 +422,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... else: # Handle all other single return types logger.debug(f"Task returns unnamed native tuple {return_annotation}") - return {default_output_name(): return_annotation} + return {default_output_name(): cast(Type, return_annotation)} def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]: diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 0d143e5fe8..86011f1253 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -74,7 +74,7 @@ def wf(a: int, c: str) -> str: # The reason we cache is simply because users may get the default launch plan twice for a single Workflow. We # don't want to create two defaults, could be confusing. - CACHE = {} + CACHE: typing.Dict[str, LaunchPlan] = {} @staticmethod def get_default_launch_plan(ctx: FlyteContext, workflow: _annotated_workflow.WorkflowBase) -> LaunchPlan: @@ -107,16 +107,16 @@ def create( cls, name: str, workflow: _annotated_workflow.WorkflowBase, - default_inputs: Dict[str, Any] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, - auth_role: _common_models.AuthRole = None, + default_inputs: Optional[Dict[str, Any]] = None, + fixed_inputs: Optional[Dict[str, Any]] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, + auth_role: Optional[_common_models.AuthRole] = None, ) -> LaunchPlan: ctx = FlyteContextManager.current_context() default_inputs = default_inputs or {} @@ -130,7 +130,7 @@ def create( temp_inputs = {} for k, v in default_inputs.items(): temp_inputs[k] = (workflow.python_interface.inputs[k], v) - temp_interface = Interface(inputs=temp_inputs, outputs={}) + temp_interface = Interface(inputs=temp_inputs, outputs={}) # type: ignore temp_signature = transform_inputs_to_parameters(ctx, temp_interface) wf_signature_parameters._parameters.update(temp_signature.parameters) @@ -185,16 +185,16 @@ def get_or_create( cls, workflow: _annotated_workflow.WorkflowBase, name: Optional[str] = None, - default_inputs: Dict[str, Any] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - 
raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, - auth_role: _common_models.AuthRole = None, + default_inputs: Optional[Dict[str, Any]] = None, + fixed_inputs: Optional[Dict[str, Any]] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, + auth_role: Optional[_common_models.AuthRole] = None, ) -> LaunchPlan: """ This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not @@ -298,13 +298,13 @@ def __init__( workflow: _annotated_workflow.WorkflowBase, parameters: _interface_models.ParameterMap, fixed_inputs: _literal_models.LiteralMap, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: typing.Optional[int] = None, - security_context: typing.Optional[security.SecurityContext] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, ): self._name = name self._workflow = workflow @@ -313,7 +313,7 @@ def __init__( self._parameters = _interface_models.ParameterMap(parameters=parameters) self._fixed_inputs = fixed_inputs # See create() for additional information - self._saved_inputs = {} + self._saved_inputs: Dict[str, Any] = {} self._schedule = schedule self._notifications = notifications or [] @@ -328,16 +328,15 @@ def __init__( def clone_with( self, name: str, - parameters: _interface_models.ParameterMap = None, - fixed_inputs: _literal_models.LiteralMap = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - auth_role: _common_models.AuthRole = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, + parameters: Optional[_interface_models.ParameterMap] = None, + fixed_inputs: Optional[_literal_models.LiteralMap] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, ) -> LaunchPlan: return LaunchPlan( name=name, @@ -349,7 +348,6 @@ def clone_with( labels=labels or self.labels, annotations=annotations or self.annotations, 
raw_output_data_config=raw_output_data_config or self.raw_output_data_config, - auth_role=auth_role or self._auth_role, max_parallelism=max_parallelism or self.max_parallelism, security_context=security_context or self.security_context, ) @@ -407,11 +405,11 @@ def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig] return self._raw_output_data_config @property - def max_parallelism(self) -> typing.Optional[int]: + def max_parallelism(self) -> Optional[int]: return self._max_parallelism @property - def security_context(self) -> typing.Optional[security.SecurityContext]: + def security_context(self) -> Optional[security.SecurityContext]: return self._security_context def construct_node_metadata(self) -> _workflow_model.NodeMetadata: @@ -455,8 +453,15 @@ def reference_launch_plan( ) -> Callable[[Callable[..., Any]], ReferenceLaunchPlan]: """ A reference launch plan is a pointer to a launch plan that already exists on your Flyte installation. This - object will not initiate a network call to Admin, which is why the user is asked to provide the expected interface. + object will not initiate a network call to Admin, which is why the user is asked to provide the expected interface + via the function definition. + If at registration time the interface provided causes an issue with compilation, an error will be returned. + + :param project: Flyte project name of the launch plan + :param domain: Flyte domain name of the launch plan + :param name: launch plan name + :param version: specific version of the launch plan to use """ def wrapper(fn) -> ReferenceLaunchPlan: diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 3b5c0a09ca..48d0f0b335 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -7,7 +7,7 @@ import typing from contextlib import contextmanager from itertools import count -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional from flytekit.configuration import SerializationSettings from flytekit.core import tracker @@ -36,8 +36,8 @@ class MapPythonTask(PythonTask): def __init__( self, python_function_task: PythonFunctionTask, - concurrency: int = None, - min_success_ratio: float = None, + concurrency: Optional[int] = None, + min_success_ratio: Optional[float] = None, **kwargs, ): """ @@ -149,8 +149,8 @@ def _compute_array_job_index() -> int: environment variable and the offset (if one's set). The offset will be set and used when the user request that the job runs in a number of slots less than the size of the input. """ - return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", 0)) + int( - os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")) + return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", "0")) + int( + os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "0"), "0") ) @property @@ -168,7 +168,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]: return self.interface.outputs return self._run_task.interface.outputs - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> type: """ We override this method from flytekit.core.base_task Task because the dispatch_execute method uses this interface to construct outputs. 
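The hardened array-index lookup above can be read as the standalone snippet below (a restatement for clarity, not new code in the diff): every `os.environ.get` call now carries a string default, so an unset `BATCH_JOB_ARRAY_INDEX_VAR_NAME` or offset no longer feeds `None` into the environment lookup or into `int()`.

```python
import os


def batch_array_index() -> int:
    # Equivalent to the new MapPythonTask._compute_array_job_index body,
    # pulled out as a plain function. The helper name is illustrative.
    offset = int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", "0"))
    index_var = os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "0")
    return offset + int(os.environ.get(index_var, "0"))
```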
Each instance of an container_array task will however produce outputs @@ -181,7 +181,7 @@ def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: return self._python_interface.outputs[k] return self._run_task._python_interface.outputs[k] - def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any: + def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any: """ This is called during ExecutionState.Mode.TASK_EXECUTION executions, that is executions orchestrated by the Flyte platform. Individual instances of the map task, aka array task jobs are passed the full set of inputs but diff --git a/flytekit/core/node.py b/flytekit/core/node.py index d8b43f2728..73f951d721 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -4,8 +4,9 @@ import typing from typing import Any, List -from flytekit.core.resources import Resources +from flytekit.core.resources import Resources, convert_resources_to_resource_model from flytekit.core.utils import _dnsify +from flytekit.loggers import logger from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.task import Resources as _resources_model @@ -92,9 +93,14 @@ def with_overrides(self, *args, **kwargs): for k, v in alias_dict.items(): self._aliases.append(_workflow_model.Alias(var=k, alias=v)) if "requests" in kwargs or "limits" in kwargs: - requests = _convert_resource_overrides(kwargs.get("requests"), "requests") - limits = _convert_resource_overrides(kwargs.get("limits"), "limits") - self._resources = _resources_model(requests=requests, limits=limits) + requests = kwargs.get("requests") + if requests and not isinstance(requests, Resources): + raise AssertionError("requests should be specified as flytekit.Resources") + limits = kwargs.get("limits") + if limits and not isinstance(limits, Resources): + raise AssertionError("limits should be specified as flytekit.Resources") + + self._resources = convert_resources_to_resource_model(requests=requests, limits=limits) if "timeout" in kwargs: timeout = kwargs["timeout"] if timeout is None: @@ -114,16 +120,21 @@ def with_overrides(self, *args, **kwargs): self._metadata._interruptible = kwargs["interruptible"] if "name" in kwargs: self._metadata._name = kwargs["name"] + if "task_config" in kwargs: + logger.warning("This override is beta. 
We may want to revisit this in the future.") + new_task_config = kwargs["task_config"] + if not isinstance(new_task_config, type(self.flyte_entity._task_config)): + raise ValueError("can't change the type of the task config") + self.flyte_entity._task_config = new_task_config return self def _convert_resource_overrides( resources: typing.Optional[Resources], resource_name: str -) -> [_resources_model.ResourceEntry]: +) -> typing.List[_resources_model.ResourceEntry]: if resources is None: return [] - if not isinstance(resources, Resources): - raise AssertionError(f"{resource_name} should be specified as flytekit.Resources") + resource_entries = [] if resources.cpu is not None: resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.CPU, resources.cpu)) diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index de33393c13..62065f6869 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -1,7 +1,6 @@ from __future__ import annotations -import collections -from typing import TYPE_CHECKING, Type, Union +from typing import TYPE_CHECKING, Union from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext @@ -21,7 +20,7 @@ def create_node( entity: Union[PythonTask, LaunchPlan, WorkflowBase, RemoteEntity], *args, **kwargs -) -> Union[Node, VoidPromise, Type[collections.namedtuple]]: +) -> Union[Node, VoidPromise]: """ This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or don't produce outputs. For example, if you have t1() and t2(), both of which do not take in nor produce any @@ -173,9 +172,9 @@ def sub_wf(): if len(output_names) == 1: # See explanation above for why we still tupletize a single element. 
- return entity.python_interface.output_tuple(results) + return entity.python_interface.output_tuple(results) # type: ignore - return entity.python_interface.output_tuple(*results) + return entity.python_interface.output_tuple(*results) # type: ignore else: raise Exception(f"Cannot use explicit run to call Flyte entities {entity.name}") diff --git a/flytekit/core/pod_template.py b/flytekit/core/pod_template.py new file mode 100644 index 0000000000..5e9c746911 --- /dev/null +++ b/flytekit/core/pod_template.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Dict, Optional + +from kubernetes.client.models import V1PodSpec + +from flytekit.exceptions import user as _user_exceptions + +PRIMARY_CONTAINER_DEFAULT_NAME = "primary" + + +@dataclass +class PodTemplate(object): + """Custom PodTemplate specification for a Task.""" + + pod_spec: V1PodSpec = V1PodSpec(containers=[]) + primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME + labels: Optional[Dict[str, str]] = None + annotations: Optional[Dict[str, str]] = None + + def __post_init__(self): + if not self.primary_container_name: + raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 53048cb03f..8be9d8ccae 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1,9 +1,8 @@ from __future__ import annotations import collections -import typing from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from typing_extensions import Protocol, get_args @@ -11,7 +10,13 @@ from flytekit.core import context_manager as _flyte_context from flytekit.core import interface as flyte_interface from flytekit.core import type_engine -from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ( + BranchEvalMode, + ExecutionParameters, + ExecutionState, + FlyteContext, + FlyteContextManager, +) from flytekit.core.interface import Interface from flytekit.core.node import Node from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine @@ -70,7 +75,6 @@ def extract_value( val_type: type, flyte_literal_type: _type_models.LiteralType, ) -> _literal_models.Literal: - if isinstance(input_val, list): lt = flyte_literal_type python_type = val_type @@ -83,7 +87,7 @@ def extract_value( if lt.collection_type is None: raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}") try: - sub_type = ListTransformer.get_sub_type(python_type) + sub_type: type = ListTransformer.get_sub_type(python_type) except ValueError: if len(input_val) == 0: raise @@ -143,17 +147,16 @@ def extract_value( def get_primitive_val(prim: Primitive) -> Any: - if prim.integer: - return prim.integer - if prim.datetime: - return prim.datetime - if prim.boolean: - return prim.boolean - if prim.duration: - return prim.duration - if prim.string_value: - return prim.string_value - return prim.float_value + for value in [ + prim.integer, + prim.float_value, + prim.string_value, + prim.boolean, + prim.datetime, + prim.duration, + ]: + if value is not None: + return value class ConjunctionOps(Enum): @@ -350,8 +353,8 @@ def __init__(self, var: str, val: Union[NodeOutput, _literal_models.Literal]): def __hash__(self): return hash(id(self)) - def __rshift__(self, other: typing.Union[Promise, VoidPromise]): - if not 
self.is_ready: + def __rshift__(self, other: Union[Promise, VoidPromise]): + if not self.is_ready and other.ref: self.ref.node.runs_before(other.ref.node) return other @@ -411,10 +414,10 @@ def is_false(self) -> ComparisonExpression: def is_true(self): return self.is_(True) - def __eq__(self, other) -> ComparisonExpression: + def __eq__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.EQ, other) - def __ne__(self, other) -> ComparisonExpression: + def __ne__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.NE, other) def __gt__(self, other) -> ComparisonExpression: @@ -458,7 +461,7 @@ def __str__(self): def create_native_named_tuple( ctx: FlyteContext, - promises: Optional[Union[Promise, typing.List[Promise]]], + promises: Union[Tuple[Promise], Promise, VoidPromise, None], entity_interface: Interface, ) -> Optional[Tuple]: """ @@ -479,7 +482,7 @@ def create_native_named_tuple( except Exception as e: raise AssertionError(f"Failed to convert value of output {k}, expected type {v}.") from e - if len(promises) == 0: + if len(cast(Tuple[Promise], promises)) == 0: return None named_tuple_name = "DefaultNamedTupleOutput" @@ -487,7 +490,7 @@ def create_native_named_tuple( named_tuple_name = entity_interface.output_tuple_name outputs = {} - for p in promises: + for p in cast(Tuple[Promise], promises): if not isinstance(p, Promise): raise AssertionError( "Workflow outputs can only be promises that are returned by tasks. Found a value of" @@ -500,8 +503,8 @@ def create_native_named_tuple( raise AssertionError(f"Failed to convert value of output {p.var}, expected type {t}.") from e # Should this class be part of the Interface? - t = collections.namedtuple(named_tuple_name, list(outputs.keys())) - return t(**outputs) + nt = collections.namedtuple(named_tuple_name, list(outputs.keys())) # type: ignore + return nt(**outputs) # To create a class that is a named tuple, we might have to create namedtuplemeta and manipulate the tuple @@ -545,7 +548,7 @@ def create_task_output( named_tuple_name = entity_interface.output_tuple_name # Should this class be part of the Interface? 
- class Output(collections.namedtuple(named_tuple_name, variables)): + class Output(collections.namedtuple(named_tuple_name, variables)): # type: ignore def with_overrides(self, *args, **kwargs): val = self.__getattribute__(self._fields[0]) val.with_overrides(*args, **kwargs) @@ -578,7 +581,7 @@ def binding_from_flyte_std( ctx: _flyte_context.FlyteContext, var_name: str, expected_literal_type: _type_models.LiteralType, - t_value: typing.Any, + t_value: Any, ) -> _literals_models.Binding: binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type=None) return _literals_models.Binding(var=var_name, binding=binding_data) @@ -587,7 +590,7 @@ def binding_from_flyte_std( def binding_data_from_python_std( ctx: _flyte_context.FlyteContext, expected_literal_type: _type_models.LiteralType, - t_value: typing.Any, + t_value: Any, t_value_type: Optional[type] = None, ) -> _literals_models.BindingData: # This handles the case where the given value is the output of another task @@ -604,7 +607,7 @@ def binding_data_from_python_std( if expected_literal_type.collection_type is None: raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}") - sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None + sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None collection = _literals_models.BindingDataCollection( bindings=[ binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value @@ -654,7 +657,7 @@ def binding_from_python_std( ctx: _flyte_context.FlyteContext, var_name: str, expected_literal_type: _type_models.LiteralType, - t_value: typing.Any, + t_value: Any, t_value_type: type, ) -> _literals_models.Binding: binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type) @@ -671,7 +674,7 @@ class VoidPromise(object): VoidPromise cannot be interacted with and does not allow comparisons or any operations """ - def __init__(self, task_name: str, ref: typing.Optional[NodeOutput] = None): + def __init__(self, task_name: str, ref: Optional[NodeOutput] = None): self._task_name = task_name self._ref = ref @@ -682,11 +685,11 @@ def runs_before(self, *args, **kwargs): """ @property - def ref(self) -> typing.Optional[NodeOutput]: + def ref(self) -> Optional[NodeOutput]: return self._ref - def __rshift__(self, other: typing.Union[Promise, VoidPromise]): - if self.ref: + def __rshift__(self, other: Union[Promise, VoidPromise]): + if self.ref and other.ref: self.ref.node.runs_before(other.ref.node) return other @@ -811,10 +814,26 @@ def extract_obj_name(name: str) -> str: def create_and_link_node_from_remote( ctx: FlyteContext, entity: HasFlyteInterface, + _inputs_not_allowed: Optional[Set[str]] = None, + _ignorable_inputs: Optional[Set[str]] = None, **kwargs, -): +) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: """ - This method is used to generate a node with bindings. This is not used in the execution path. + This method is used to generate a node with bindings especially when using remote entities, like FlyteWorkflow, + FlyteTask and FlyteLaunchplan. + + This method is kept separate from the similar named method `create_and_link_node` as remote entities have to be + handled differently. The major difference arises from the fact that the remote entities do not have a python + interface, so all comparisons need to happen using the Literals. 
+ + :param ctx: FlyteContext + :param entity: RemoteEntity + :param _inputs_not_allowed: Set of all variable names that should not be provided when using this entity. + Useful for Launchplans with `fixed` inputs + :param _ignorable_inputs: Set of all variable names that are optional, but if provided will be overridden. Useful + for launchplans with `default` inputs + :param kwargs: Dict[str, Any] default inputs passed from the user to this entity. Can be promises. + :return: Optional[Union[Tuple[Promise], Promise, VoidPromise]] """ if ctx.compilation_state is None: raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...") @@ -824,9 +843,19 @@ def create_and_link_node_from_remote( typed_interface = entity.interface + if _inputs_not_allowed: + inputs_not_allowed_specified = _inputs_not_allowed.intersection(kwargs.keys()) + if inputs_not_allowed_specified: + raise _user_exceptions.FlyteAssertion( + f"Fixed inputs cannot be specified. Please remove the following inputs - {inputs_not_allowed_specified}" + ) + for k in sorted(typed_interface.inputs): var = typed_interface.inputs[k] if k not in kwargs: + if _inputs_not_allowed and _ignorable_inputs: + if k in _ignorable_inputs or k in _inputs_not_allowed: + continue # TODO to improve the error message, should we show python equivalent types for var.type? raise _user_exceptions.FlyteAssertion("Missing input `{}` type `{}`".format(k, var.type)) v = kwargs[k] @@ -854,7 +883,8 @@ def create_and_link_node_from_remote( extra_inputs = used_inputs ^ set(kwargs.keys()) if len(extra_inputs) > 0: raise _user_exceptions.FlyteAssertion( - "Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs) + f"Too many inputs for [{entity.name}] Expected inputs: {typed_interface.inputs.keys()} " + f"- extra inputs: {extra_inputs}" ) # Detect upstream nodes @@ -895,7 +925,13 @@ def create_and_link_node( **kwargs, ) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: """ - This method is used to generate a node with bindings. This is not used in the execution path. + This method is used to generate a node with bindings within a flytekit workflow. This is useful to traverse the + workflow using the regular Python interpreter and generate nodes and promises whenever an execution is encountered. + + :param ctx: FlyteContext + :param entity: RemoteEntity + :param kwargs: Dict[str, Any] default inputs passed from the user to this entity. Can be promises. + :return: Optional[Union[Tuple[Promise], Promise, VoidPromise]] """ if ctx.compilation_state is None: raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...") @@ -989,11 +1025,13 @@ def create_and_link_node( class LocallyExecutable(Protocol): - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: ... -def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): +def flyte_entity_call_handler( + entity: SupportsNodeCreation, *args, **kwargs +) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying workflow).
The logic is the same for all three, but we did not want to create base class, hence this separate @@ -1023,7 +1061,6 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): ) ctx = FlyteContextManager.current_context() - if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: return create_and_link_node(ctx, entity=entity, **kwargs) elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: @@ -1046,7 +1083,7 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) ) ) as child_ctx: - cast(FlyteContext, child_ctx).user_space_params._decks = [] + cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) @@ -1056,7 +1093,9 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): else: raise Exception(f"Received an output when workflow local execution expected None. Received: {result}") - if (1 < expected_outputs == len(result)) or (result is not None and expected_outputs == 1): + if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( + result is not None and expected_outputs == 1 + ): return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface) raise ValueError( diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 06133d9784..113f94a998 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,12 +3,16 @@ import importlib import re from abc import ABC -from types import ModuleType -from typing import Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, cast + +from flyteidl.core import tasks_pb2 as _core_task +from kubernetes.client import ApiClient +from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements from flytekit.configuration import ImageConfig, SerializationSettings -from flytekit.core.base_task import PythonTask, TaskResolverMixin +from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module @@ -18,6 +22,11 @@ from flytekit.models.security import Secret, SecurityContext T = TypeVar("T") +_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" + + +def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): @@ -40,6 +49,8 @@ def __init__( environment: Optional[Dict[str, str]] = None, task_resolver: Optional[TaskResolverMixin] = None, secret_requests: Optional[List[Secret]] = None, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, **kwargs, ): """ @@ -64,6 +75,8 @@ def __init__( - `Confidant `__ - `Kube secrets `__ - `AWS Parameter store `__ + :param pod_template: Custom PodTemplate for this task. 
+ :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ sec_ctx = None if secret_requests: @@ -71,6 +84,11 @@ def __init__( if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) + + # pod_template_name overwrites the metedata.pod_template_name + kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata() + kwargs["metadata"].pod_template_name = pod_template_name + super().__init__( task_type=task_type, name=name, @@ -98,8 +116,10 @@ def __init__( self._task_resolver = task_resolver or default_task_resolver self._get_command_fn = self.get_default_command + self.pod_template = pod_template + @property - def task_resolver(self) -> Optional[TaskResolverMixin]: + def task_resolver(self) -> TaskResolverMixin: return self._task_resolver @property @@ -157,6 +177,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return self._get_command_fn(settings) def get_container(self, settings: SerializationSettings) -> _task_model.Container: + # if pod_template is not None, return None here but in get_k8s_pod, return pod_template merged with container + if self.pod_template is not None: + return None + else: + return self._get_container(settings) + + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = {} for elem in (settings.env, self.environment): if elem: @@ -179,6 +206,64 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe memory_limit=self.resources.limits.mem, ) + def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: + containers = cast(PodTemplate, self.pod_template).pod_spec.containers + primary_exists = False + + for container in containers: + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=cast(PodTemplate, self.pod_template).primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the default values used in the regular Python task. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: + sdk_default_container = self._get_container(settings) + container.image = sdk_default_container.image + # clear existing commands + container.command = sdk_default_container.command + # also clear existing args + container.args = sdk_default_container.args + limits, requests = {}, {} + for resource in sdk_default_container.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in sdk_default_container.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. 
+ container.resources = resource_requirements + container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + ( + container.env or [] + ) + final_containers.append(container) + cast(PodTemplate, self.pod_template).pod_spec.containers = final_containers + + return ApiClient().sanitize_for_serialization(cast(PodTemplate, self.pod_template).pod_spec) + + def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + if self.pod_template is None: + return None + return _task_model.K8sPod( + pod_spec=self._serialize_pod_spec(settings), + metadata=_task_model.K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ), + ) + + # need to call super in all its children tasks + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: + if self.pod_template is None: + return {} + return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} + class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): """ @@ -188,14 +273,14 @@ class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): def name(self) -> str: return "DefaultTaskResolver" - def load_task(self, loader_args: List[Union[T, ModuleType]]) -> PythonAutoContainerTask: + def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: _, task_module, _, task_name, *_ = loader_args - task_module = importlib.import_module(task_module) + task_module = importlib.import_module(name=task_module) # type: ignore task_def = getattr(task_module, task_name) return task_def - def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore from flytekit.core.python_function_task import PythonFunctionTask if isinstance(task, PythonFunctionTask): @@ -205,7 +290,7 @@ def loader_args(self, settings: SerializationSettings, task: PythonAutoContainer _, m, t, _ = extract_task_module(task) return ["task-module", m, "task-name", t] - def get_all_tasks(self) -> List[PythonAutoContainerTask]: + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore raise Exception("should not be needed") diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index eee0dce9b8..07493886a2 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -21,7 +21,7 @@ TC = TypeVar("TC") -class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): +class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): # type: ignore """ Please take a look at the comments for :py:class`flytekit.extend.ExecutableTemplateShimTask` as well. This class should be subclassed and a custom Executor provided as a default to this parent class constructor @@ -229,7 +229,7 @@ def name(self) -> str: # The return type of this function is different, it should be a Task, but it's not because it doesn't make # sense for ExecutableTemplateShimTask to inherit from Task. 
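Editor's note: the pod_template support added above (flytekit/core/pod_template.py plus the new pod_template/pod_template_name parameters and _serialize_pod_spec on PythonAutoContainerTask) can be exercised roughly as in the sketch below. This example is not part of the patch; the container names, image, and label are illustrative assumptions.

# Hypothetical usage sketch, not part of this diff.
from kubernetes.client.models import V1Container, V1PodSpec

from flytekit import Resources, task
from flytekit.core.pod_template import PodTemplate

# An extra sidecar container supplied by the user; name and image are made up.
sidecar = V1Container(name="log-shipper", image="busybox:1.36")

# The primary container only needs to exist by name; its image, command, args and
# resources are overwritten (and env unioned) with the SDK defaults at serialization
# time, per _serialize_pod_spec above.
pod_template = PodTemplate(
    pod_spec=V1PodSpec(containers=[V1Container(name="primary"), sidecar]),
    primary_container_name="primary",
    labels={"team": "data"},
)

@task(pod_template=pod_template, requests=Resources(cpu="1", mem="500Mi"))
def t1(a: int) -> int:
    return a + 1

When a pod_template is set, get_container returns None and the task is serialized through get_k8s_pod instead, which is why get_config also reports the primary container name.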
- def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: + def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: # type: ignore logger.info(f"Task template loader args: {loader_args}") ctx = FlyteContext.current_context() task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb") # type: ignore @@ -240,7 +240,7 @@ def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: executor_class = load_object_from_module(loader_args[1]) return ExecutableTemplateShimTask(task_template_model, executor_class) - def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: # type: ignore return ["{{.taskTemplatePath}}", f"{t.executor_type.__module__}.{t.executor_type.__name__}"] def get_all_tasks(self) -> List[Task]: diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 81f6739a39..90b10cbc36 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -17,7 +17,7 @@ from abc import ABC from collections import OrderedDict from enum import Enum -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar, Union, cast from flytekit.core.base_task import Task, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager @@ -43,7 +43,7 @@ T = TypeVar("T") -class PythonInstanceTask(PythonAutoContainerTask[T], ABC): +class PythonInstanceTask(PythonAutoContainerTask[T], ABC): # type: ignore """ This class should be used as the base class for all Tasks that do not have a user defined function body, but have a platform defined execute method. (Execute needs to be overridden). This base class ensures that the module loader @@ -72,7 +72,7 @@ def __init__( super().__init__(name=name, task_config=task_config, task_type=task_type, task_resolver=task_resolver, **kwargs) -class PythonFunctionTask(PythonAutoContainerTask[T]): +class PythonFunctionTask(PythonAutoContainerTask[T]): # type: ignore """ A Python Function task should be used as the base for all extensions that have a python function. It will automatically detect interface of the python function and when serialized on the hosted Flyte platform handles the @@ -193,10 +193,10 @@ def compile_into_workflow( from flytekit.tools.translator import get_serializable self._create_and_cache_dynamic_workflow() - self._wf.compile(**kwargs) + cast(PythonFunctionWorkflow, self._wf).compile(**kwargs) wf = self._wf - model_entities = OrderedDict() + model_entities: OrderedDict = OrderedDict() # See comment on reference entity checking a bit down below in this function. # This is the only circular dependency between the translator.py module and the rest of the flytekit # authoring experience. @@ -263,12 +263,12 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: # local_execute directly though since that converts inputs into Promises. 
logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") self._create_and_cache_dynamic_workflow() - function_outputs = self._wf.execute(**kwargs) + function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) if isinstance(function_outputs, VoidPromise) or function_outputs is None: return VoidPromise(self.name) - if len(self._wf.python_interface.outputs) == 0: + if len(cast(PythonFunctionWorkflow, self._wf).python_interface.outputs) == 0: raise FlyteValueException(function_outputs, "Interface output should've been VoidPromise or None.") # TODO: This will need to be cleaned up when we revisit top-level tuple support. diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 77b96e6892..de386fa159 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -21,7 +21,7 @@ from flytekit.models.core import workflow as _workflow_model -@dataclass +@dataclass # type: ignore class Reference(ABC): project: str domain: str @@ -43,6 +43,8 @@ def resource_type(self) -> int: @dataclass class TaskReference(Reference): + """A reference object containing metadata that points to a remote task.""" + @property def resource_type(self) -> int: return _identifier_model.ResourceType.TASK @@ -50,6 +52,8 @@ def resource_type(self) -> int: @dataclass class LaunchPlanReference(Reference): + """A reference object containing metadata that points to a remote launch plan.""" + @property def resource_type(self) -> int: return _identifier_model.ResourceType.LAUNCH_PLAN @@ -57,6 +61,8 @@ def resource_type(self) -> int: @dataclass class WorkflowReference(Reference): + """A reference object containing metadata that points to a remote workflow.""" + @property def resource_type(self) -> int: return _identifier_model.ResourceType.WORKFLOW @@ -66,7 +72,7 @@ class ReferenceEntity(object): def __init__( self, reference: Union[WorkflowReference, TaskReference, LaunchPlanReference], - inputs: Optional[Dict[str, Union[Type[Any], Tuple[Type[Any], Any]]]], + inputs: Dict[str, Type], outputs: Dict[str, Type], ): if ( diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 7b46cbe05c..4cf2523f6a 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import Optional +from typing import List, Optional + +from flytekit.models import task as task_models @dataclass @@ -33,5 +35,44 @@ class Resources(object): @dataclass class ResourceSpec(object): - requests: Optional[Resources] = None - limits: Optional[Resources] = None + requests: Resources + limits: Resources + + +_ResourceName = task_models.Resources.ResourceName +_ResourceEntry = task_models.Resources.ResourceEntry + + +def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: # type: ignore + resource_entries = [] + if resources.cpu is not None: + resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=resources.cpu)) + if resources.mem is not None: + resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=resources.mem)) + if resources.gpu is not None: + resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=resources.gpu)) + if resources.storage is not None: + resource_entries.append(_ResourceEntry(name=_ResourceName.STORAGE, value=resources.storage)) + if resources.ephemeral_storage is not None: + resource_entries.append(_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=resources.ephemeral_storage)) + return 
resource_entries + + +def convert_resources_to_resource_model( + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, +) -> task_models.Resources: + """ + Convert flytekit ``Resources`` objects to a Resources model + + :param requests: Resource requests. Optional, defaults to ``None`` + :param limits: Resource limits. Optional, defaults to ``None`` + :return: The given resources as requests and limits + """ + request_entries = [] + limit_entries = [] + if requests is not None: + request_entries = _convert_resources_to_resource_entries(requests) + if limits is not None: + limit_entries = _convert_resources_to_resource_entries(limits) + return task_models.Resources(requests=request_entries, limits=limit_entries) diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 7addc89197..93116d0720 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -6,6 +6,7 @@ import datetime import re as _re +from typing import Optional import croniter as _croniter @@ -52,7 +53,11 @@ class CronSchedule(_schedule_models.Schedule): _OFFSET_PATTERN = _re.compile("([-+]?)P([-+0-9YMWD]+)?(T([-+0-9HMS.,]+)?)?") def __init__( - self, cron_expression: str = None, schedule: str = None, offset: str = None, kickoff_time_input_arg: str = None + self, + cron_expression: Optional[str] = None, + schedule: Optional[str] = None, + offset: Optional[str] = None, + kickoff_time_input_arg: Optional[str] = None, ): """ :param str cron_expression: This should be a cron expression in AWS style.Shouldn't be used in case of native scheduler. @@ -161,7 +166,7 @@ class FixedRate(_schedule_models.Schedule): See the :std:ref:`fixed rate intervals` chapter in the cookbook for additional usage examples. """ - def __init__(self, duration: datetime.timedelta, kickoff_time_input_arg: str = None): + def __init__(self, duration: datetime.timedelta, kickoff_time_input_arg: Optional[str] = None): """ :param datetime.timedelta duration: :param str kickoff_time_input_arg: diff --git a/flytekit/core/shim_task.py b/flytekit/core/shim_task.py index d8d18293c5..f96db3e49c 100644 --- a/flytekit/core/shim_task.py +++ b/flytekit/core/shim_task.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Any, Generic, Type, TypeVar, Union +from typing import Any, Generic, Optional, Type, TypeVar, Union, cast -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger @@ -47,7 +47,7 @@ def name(self) -> str: if self._task_template is not None: return self._task_template.id.name # if not access the subclass's name - return self._name + return self._name # type: ignore @property def task_template(self) -> _task_model.TaskTemplate: @@ -67,13 +67,13 @@ def execute(self, **kwargs) -> Any: """ return self.executor.execute_from_model(self.task_template, **kwargs) - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: """ This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. 
""" return user_params - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + def post_execute(self, _: Optional[ExecutionParameters], rval: Any) -> Any: """ This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. """ @@ -92,7 +92,9 @@ def dispatch_execute( # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) + ctx.with_execution_state( + cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params) + ) ) as exec_ctx: # Added: Have to reverse the Python interface from the task template Flyte interface # See docstring for more details. diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 6e5b0a6b6a..b107aafe12 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -4,9 +4,11 @@ from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.models.documentation import Documentation from flytekit.models.security import Secret @@ -87,9 +89,12 @@ def task( requests: Optional[Resources] = None, limits: Optional[Resources] = None, secret_requests: Optional[List[Secret]] = None, - execution_mode: Optional[PythonFunctionTask.ExecutionBehavior] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, + docs: Optional[Documentation] = None, disable_deck: bool = True, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, ) -> Union[Callable, PythonFunctionTask]: """ This is the core decorator to use for any task type in flytekit. @@ -179,6 +184,9 @@ def foo2(): :param execution_mode: This is mainly for internal use. Please ignore. It is filled in automatically. :param task_resolver: Provide a custom task resolver. :param disable_deck: If true, this task will not output deck html file + :param docs: Documentation about this task + :param pod_template: Custom PodTemplate for this task. + :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ def wrapper(fn) -> PythonFunctionTask: @@ -204,6 +212,9 @@ def wrapper(fn) -> PythonFunctionTask: execution_mode=execution_mode, task_resolver=task_resolver, disable_deck=disable_deck, + docs=docs, + pod_template=pod_template, + pod_template_name=pod_template_name, ) update_wrapper(task_instance, fn) return task_instance @@ -214,7 +225,7 @@ def wrapper(fn) -> PythonFunctionTask: return wrapper -class ReferenceTask(ReferenceEntity, PythonFunctionTask): +class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore """ This is a reference task, the body of the function passed in through the constructor will never be used, only the signature of the function will be. 
The signature should also match the signature of the task you're referencing, @@ -222,7 +233,7 @@ class ReferenceTask(ReferenceEntity, PythonFunctionTask): """ def __init__( - self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type] + self, project: str, domain: str, name: str, version: str, inputs: Dict[str, type], outputs: Dict[str, Type] ): super().__init__(TaskReference(project, domain, name, version), inputs, outputs) diff --git a/flytekit/core/testing.py b/flytekit/core/testing.py index 772a4b6df6..f1a0fec7de 100644 --- a/flytekit/core/testing.py +++ b/flytekit/core/testing.py @@ -1,3 +1,4 @@ +import typing from contextlib import contextmanager from typing import Union from unittest.mock import MagicMock @@ -9,7 +10,7 @@ @contextmanager -def task_mock(t: PythonTask) -> MagicMock: +def task_mock(t: PythonTask) -> typing.Generator[MagicMock, None, None]: """ Use this method to mock a task declaration. It can mock any Task in Flytekit as long as it has a python native interface associated with it. @@ -41,9 +42,9 @@ def _log(*args, **kwargs): return m(*args, **kwargs) _captured_fn = t.execute - t.execute = _log + t.execute = _log # type: ignore yield m - t.execute = _captured_fn + t.execute = _captured_fn # type: ignore def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]): diff --git a/flytekit/core/tracked_abc.py b/flytekit/core/tracked_abc.py index bad4f8c555..3c39d3725c 100644 --- a/flytekit/core/tracked_abc.py +++ b/flytekit/core/tracked_abc.py @@ -3,7 +3,7 @@ from flytekit.core.tracker import TrackedInstance -class FlyteTrackedABC(type(TrackedInstance), type(ABC)): +class FlyteTrackedABC(type(TrackedInstance), type(ABC)): # type: ignore """ This class exists because if you try to inherit from abc.ABC and TrackedInstance by itself, you'll get the well-known ``TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 2a203d4861..23ff7c9222 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -179,7 +179,7 @@ class _ModuleSanitizer(object): def __init__(self): self._module_cache = {} - def _resolve_abs_module_name(self, path: str, package_root: str) -> str: + def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str] = None) -> str: """ Recursively finds the root python package under-which basename exists """ diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 6ddeb5c58c..3d9b64a2bf 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections +import copy import dataclasses import datetime as _datetime import enum @@ -117,7 +118,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp raise NotImplementedError(f"Conversion to Literal for python type {python_type} not implemented") @abstractmethod - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: """ Converts the given Literal to a Python Type. 
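Editor's note: returning to the task_mock change in flytekit/core/testing.py above (execute is temporarily swapped for a MagicMock and then restored), a rough usage sketch follows; the task and workflow names are made up for illustration.

# Hypothetical usage of flytekit.core.testing.task_mock.
from flytekit import task, workflow
from flytekit.core.testing import task_mock

@task
def t1(a: int) -> int:
    return a + 1

@workflow
def wf(a: int) -> int:
    return t1(a=a)

def test_wf_with_mocked_task():
    # Inside the context manager, t1.execute delegates to the MagicMock,
    # so local workflow execution sees the mocked return value.
    with task_mock(t1) as m:
        m.return_value = 100
        assert wf(a=3) == 100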
If the conversion cannot be done an AssertionError should be raised :param ctx: FlyteContext @@ -161,7 +162,7 @@ def __init__( self._to_literal_transformer = to_literal_transformer self._from_literal_transformer = from_literal_transformer - def get_literal_type(self, t: Type[T] = None) -> LiteralType: + def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: return LiteralType.from_flyte_idl(self._lt.to_flyte_idl()) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: @@ -206,7 +207,7 @@ class RestrictedTypeTransformer(TypeTransformer[T], ABC): def __init__(self, name: str, t: Type[T]): super().__init__(name, t) - def get_literal_type(self, t: Type[T] = None) -> LiteralType: + def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently") def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: @@ -320,7 +321,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: ) schema = None try: - s = cast(DataClassJsonMixin, t).schema() + s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() for _, v in s.fields.items(): # marshmallow-jsonschema only supports enums loaded by name. # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 @@ -352,6 +353,46 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) ) + def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: + # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is + # not hashable, such as Annotated[StructuredDataset, kwtypes(...)]. Therefore, we should just extract the origin + # type from annotated. + if get_origin(python_type) is list: + return typing.List[self._get_origin_type_in_annotation(get_args(python_type)[0])] # type: ignore + elif get_origin(python_type) is dict: + return typing.Dict[ # type: ignore + self._get_origin_type_in_annotation(get_args(python_type)[0]), + self._get_origin_type_in_annotation(get_args(python_type)[1]), + ] + elif get_origin(python_type) is Annotated: + return get_args(python_type)[0] + elif dataclasses.is_dataclass(python_type): + for field in dataclasses.fields(copy.deepcopy(python_type)): + field.type = self._get_origin_type_in_annotation(field.type) + return python_type + + def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: + # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, + # so here we convert it back to the Structured Dataset. 
+ from flytekit import StructuredDataset + + if python_type == StructuredDataset and type(python_val) == dict: + return StructuredDataset(**python_val) + elif get_origin(python_type) is list: + return [self._fix_structured_dataset_type(get_args(python_type)[0], v) for v in python_val] # type: ignore + elif get_origin(python_type) is dict: + return { # type: ignore + self._fix_structured_dataset_type(get_args(python_type)[0], k): self._fix_structured_dataset_type( + get_args(python_type)[1], v + ) + for k, v in python_val.items() + } + elif dataclasses.is_dataclass(python_type): + for field in dataclasses.fields(python_type): + val = python_val.__getattribute__(field.name) + python_val.__setattr__(field.name, self._fix_structured_dataset_type(field.type, val)) + return python_val + def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any: """ If any field inside the dataclass is flyte type, we should use flyte type transformer for that field. @@ -367,11 +408,13 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A return None return self._serialize_flyte_type(python_val, get_args(python_type)[0]) - if hasattr(python_type, "__origin__") and python_type.__origin__ is list: - return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] + if hasattr(python_type, "__origin__") and get_origin(python_type) is list: + return [self._serialize_flyte_type(v, get_args(python_type)[0]) for v in cast(list, python_val)] - if hasattr(python_type, "__origin__") and python_type.__origin__ is dict: - return {k: self._serialize_flyte_type(v, python_type.__args__[1]) for k, v in python_val.items()} + if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: + return { + k: self._serialize_flyte_type(v, get_args(python_type)[1]) for k, v in cast(dict, python_val).items() + } if not dataclasses.is_dataclass(python_type): return python_val @@ -431,7 +474,13 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> t = FlyteSchemaTransformer() return t.to_python_value( FlyteContext.current_context(), - Literal(scalar=Scalar(schema=Schema(python_val.remote_path, t._get_schema_type(expected_python_type)))), + Literal( + scalar=Scalar( + schema=Schema( + cast(FlyteSchema, python_val).remote_path, t._get_schema_type(expected_python_type) + ) + ) + ), expected_python_type, ) elif issubclass(expected_python_type, FlyteFile): @@ -445,7 +494,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE ) ), - uri=python_val.path, + uri=cast(FlyteFile, python_val).path, ) ) ), @@ -462,7 +511,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART ) ), - uri=python_val.path, + uri=cast(FlyteDirectory, python_val).path, ) ) ), @@ -475,9 +524,11 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> scalar=Scalar( structured_dataset=StructuredDataset( metadata=StructuredDatasetMetadata( - structured_dataset_type=StructuredDatasetType(format=python_val.file_format) + structured_dataset_type=StructuredDatasetType( + format=cast(StructuredDataset, python_val).file_format + ) ), - uri=python_val.uri, + uri=cast(StructuredDataset, python_val).uri, ) ) ), @@ -516,7 +567,9 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if isinstance(val, dict): ktype, vtype = 
DictTransformer.get_dict_types(t) # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}}) - return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()} + return { + self._fix_val_int(cast(type, ktype), k): self._fix_val_int(cast(type, vtype), v) for k, v in val.items() + } if dataclasses.is_dataclass(t): return self._fix_dataclass_int(t, val) # type: ignore @@ -547,9 +600,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Dataclass {expected_python_type} should be decorated with @dataclass_json to be " f"serialized correctly" ) - json_str = _json_format.MessageToJson(lv.scalar.generic) dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str) + dc = self._fix_structured_dataset_type(expected_python_type, dc) return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run`` @@ -557,7 +610,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # calls to guess_python_type would result in a logically equivalent (but new) dataclass, which # TypeEngine.assert_type would not be happy about. @lru_cache(typed=True) - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: ignore if literal_type.simple == SimpleType.STRUCT: if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata: schema_name = literal_type.metadata["$ref"].split("/")[-1] @@ -582,7 +635,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: struct = Struct() try: - struct.update(_MessageToDict(python_val)) + struct.update(_MessageToDict(cast(Message, python_val))) except Exception: raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") return Literal(scalar=Scalar(generic=struct)) @@ -593,7 +646,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: pb_obj = expected_python_type() dictionary = _MessageToDict(lv.scalar.generic) - pb_obj = _ParseDict(dictionary, pb_obj) + pb_obj = _ParseDict(dictionary, pb_obj) # type: ignore return pb_obj def guess_python_type(self, literal_type: LiteralType) -> Type[T]: @@ -616,7 +669,7 @@ class TypeEngine(typing.Generic[T]): _REGISTRY: typing.Dict[type, TypeTransformer[T]] = {} _RESTRICTED_TYPES: typing.List[type] = [] - _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() + _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore @classmethod def register( @@ -641,10 +694,10 @@ def register( def register_restricted_type( cls, name: str, - type: Type, + type: Type[T], ): cls._RESTRICTED_TYPES.append(type) - cls.register(RestrictedTypeTransformer(name, type)) + cls.register(RestrictedTypeTransformer(name, type)) # type: ignore @classmethod def register_additional_type(cls, transformer: TypeTransformer, additional_type: Type, override=False): @@ -901,8 +954,8 @@ def get_sub_type(t: Type[T]) -> Type[T]: if get_origin(t) is Annotated: return ListTransformer.get_sub_type(get_args(t)[0]) - if t.__origin__ is list and hasattr(t, "__args__"): - return t.__args__[0] + if getattr(t, "__origin__") is list and hasattr(t, "__args__"): + return getattr(t, "__args__")[0] raise ValueError("Only generic univariate typing.List[T] type is supported.") 
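Editor's note: the _get_origin_type_in_annotation and _fix_structured_dataset_type helpers above exist for dataclasses that embed an annotated StructuredDataset. A hypothetical dataclass of that shape is sketched below; the field and column names are illustrative, not taken from the patch.

# Illustrative only: the kind of field the new helpers normalize.
from dataclasses import dataclass

from dataclasses_json import dataclass_json
from typing_extensions import Annotated

from flytekit import StructuredDataset, kwtypes, task


@dataclass_json
@dataclass
class Result:
    # The Annotated column metadata is stripped by _get_origin_type_in_annotation before
    # schema generation, and on Python 3.7/3.8 the dict produced during deserialization is
    # converted back into a StructuredDataset by _fix_structured_dataset_type.
    data: Annotated[StructuredDataset, kwtypes(age=int)]
    score: float = 0.0


@task
def consume(r: Result) -> float:
    return r.score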
@@ -924,7 +977,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[T]: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore try: lits = lv.collection.literals except AttributeError: @@ -933,10 +986,10 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: st = self.get_sub_type(expected_python_type) return [TypeEngine.to_python_value(ctx, x, st) for x in lits] - def guess_python_type(self, literal_type: LiteralType) -> Type[list]: + def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: - ct = TypeEngine.guess_python_type(literal_type.collection_type) - return typing.List[ct] + ct: Type = TypeEngine.guess_python_type(literal_type.collection_type) + return typing.List[ct] # type: ignore raise ValueError(f"List transformer cannot reverse {literal_type}") @@ -1049,7 +1102,9 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: t = get_args(t)[0] try: - trans = [(TypeEngine.get_transformer(x), x) for x in get_args(t)] + trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [ + (TypeEngine.get_transformer(x), x) for x in get_args(t) + ] # must go through TypeEngine.to_literal_type instead of trans.get_literal_type # to handle Annotated variants = [_add_tag_to_type(TypeEngine.to_literal_type(x), t.name) for (t, x) in trans] @@ -1066,7 +1121,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp res_type = None for t in get_args(python_type): try: - trans = TypeEngine.get_transformer(t) + trans: TypeTransformer[T] = TypeEngine.get_transformer(t) res = trans.to_literal(ctx, python_val, t, expected) res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) @@ -1099,7 +1154,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: res_tag = None for v in get_args(expected_python_type): try: - trans = TypeEngine.get_transformer(v) + trans: TypeTransformer[T] = TypeEngine.get_transformer(v) if union_tag is not None: if trans.name != union_tag: continue @@ -1138,7 +1193,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def guess_python_type(self, literal_type: LiteralType) -> type: if literal_type.union_type is not None: - return typing.Union[tuple(TypeEngine.guess_python_type(v.type) for v in literal_type.union_type.variants)] # type: ignore + return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] # type: ignore raise ValueError(f"Union transformer cannot reverse {literal_type}") @@ -1185,7 +1240,7 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: if tp: if tp[0] == str: try: - sub_type = TypeEngine.to_literal_type(tp[1]) + sub_type = TypeEngine.to_literal_type(cast(type, tp[1])) return _type_models.LiteralType(map_value_type=sub_type) except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") @@ -1206,7 +1261,7 @@ def to_literal( raise ValueError("Flyte MapType expects all keys to be strings") # TODO: log a warning for Annotated objects that contain HashMethod k_type, v_type = self.get_dict_types(python_type) - 
lit_map[k] = TypeEngine.to_literal(ctx, v, v_type, expected.map_value_type) + lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type) return Literal(map=LiteralMap(literals=lit_map)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: @@ -1222,7 +1277,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key") py_map = {} for k, v in lv.map.literals.items(): - py_map[k] = TypeEngine.to_python_value(ctx, v, tp[1]) + py_map[k] = TypeEngine.to_python_value(ctx, v, cast(Type, tp[1])) return py_map # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict @@ -1260,10 +1315,8 @@ def _blob_type(self) -> _core_types.BlobType: dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) - def get_literal_type(self, t: typing.TextIO) -> LiteralType: - return _type_models.LiteralType( - blob=self._blob_type(), - ) + def get_literal_type(self, t: typing.TextIO) -> LiteralType: # type: ignore + return _type_models.LiteralType(blob=self._blob_type()) def to_literal( self, ctx: FlyteContext, python_val: typing.TextIO, python_type: Type[typing.TextIO], expected: LiteralType @@ -1334,7 +1387,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: raise TypeTransformerFailedError("Only EnumTypes with value of string are supported") return LiteralType(enum_type=_core_types.EnumType(values=values)) - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + def to_literal( + self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType + ) -> Literal: if type(python_val).__class__ != enum.EnumMeta: raise TypeTransformerFailedError("Expected an enum") if type(python_val.value) != str: @@ -1343,11 +1398,12 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - return expected_python_type(lv.scalar.primitive.string_value) + return expected_python_type(lv.scalar.primitive.string_value) # type: ignore -def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: - """Generate a model class based on the provided JSON Schema +def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema :param schema_name: dataclass name of return type """ @@ -1356,7 +1412,7 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac property_type = property_val["type"] # Handle list if property_val["type"] == "array": - attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore # Handle dataclass and dict elif property_type == "object": if property_val.get("$ref"): @@ -1364,13 +1420,13 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name))) elif property_val.get("additionalProperties"): attribute_list.append( - 
(property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore ) else: - attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) + attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) # type: ignore # Handle int, float, bool or str else: - attribute_list.append([property_key, _get_element_type(property_val)]) + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) @@ -1544,8 +1600,8 @@ def __init__( raise ValueError("Cannot instantiate LiteralsResolver without a map of Literals.") self._literals = literals self._variable_map = variable_map - self._native_values = {} - self._type_hints = {} + self._native_values: Dict[str, type] = {} + self._type_hints: Dict[str, type] = {} self._ctx = ctx def __str__(self) -> str: @@ -1598,7 +1654,7 @@ def __getitem__(self, key: str): return self.get(key) - def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: + def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: # type: ignore """ This will get the ``attr`` value from the Literal map, and invoke the TypeEngine to convert it into a Python native value. A Python type can optionally be supplied. If successful, the native value will be cached and @@ -1625,7 +1681,9 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: raise e else: ValueError("as_type argument not supplied and Variable map not specified in LiteralsResolver") - val = TypeEngine.to_python_value(self._ctx or FlyteContext.current_context(), self._literals[attr], as_type) + val = TypeEngine.to_python_value( + self._ctx or FlyteContext.current_context(), self._literals[attr], cast(Type, as_type) + ) self._native_values[attr] = val return val diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index d23aae3fbb..ee2c841465 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional from flytekit.loggers import logger -from flytekit.models import task as _task_models +from flytekit.models import task as task_models def _dnsify(value: str) -> str: @@ -51,8 +51,8 @@ def _dnsify(value: str) -> str: def _get_container_definition( image: str, command: List[str], - args: List[str], - data_loading_config: Optional[_task_models.DataLoadingConfig] = None, + args: Optional[List[str]] = None, + data_loading_config: Optional[task_models.DataLoadingConfig] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, cpu_request: Optional[str] = None, @@ -64,7 +64,7 @@ def _get_container_definition( gpu_limit: Optional[str] = None, memory_limit: Optional[str] = None, environment: Optional[Dict[str, str]] = None, -) -> _task_models.Container: +) -> task_models.Container: storage_limit = storage_limit storage_request = storage_request ephemeral_storage_limit = ephemeral_storage_limit @@ -76,50 +76,49 @@ def _get_container_definition( memory_limit = memory_limit memory_request = memory_request + # TODO: Use convert_resources_to_resource_model instead of manually fixing the resources. 
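The typed LiteralsResolver above is what callers see as execution outputs on the remote path; a minimal usage sketch, assuming a finished execution (project, domain, and the execution name are placeholders):

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
    execution = remote.fetch_execution(name="f8c2efabc123")  # hypothetical execution name
    execution = remote.wait(execution)

    # execution.outputs is a LiteralsResolver; get() converts the stored Literal into a
    # Python value using the variable map when available, or the explicit as_type hint.
    raw = execution.outputs["o0"]                  # "o0" is the conventional single-output name
    typed = execution.outputs.get("o0", as_type=int)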
requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) + task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.STORAGE, storage_request) ) if ephemeral_storage_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request + task_models.Resources.ResourceEntry( + task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request ) ) if cpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.CPU, cpu_request)) if gpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_request)) if memory_request: - requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) - ) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_request)) limits = [] if storage_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.STORAGE, storage_limit)) if ephemeral_storage_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit + task_models.Resources.ResourceEntry( + task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit ) ) if cpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.CPU, cpu_limit)) if gpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_limit)) if memory_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_limit)) if environment is None: environment = {} - return _task_models.Container( + return task_models.Container( image=image, command=command, args=args, - resources=_task_models.Resources(limits=limits, requests=requests), + resources=task_models.Resources(limits=limits, requests=requests), env=environment, config={}, data_loading_config=data_loading_config, diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 1ced1f7b34..b716eb7114 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -39,6 +39,7 @@ from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models.core 
import workflow as _workflow_model +from flytekit.models.documentation import Description, Documentation GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -168,6 +169,7 @@ def __init__( workflow_metadata: WorkflowMetadata, workflow_metadata_defaults: WorkflowMetadataDefaults, python_interface: Interface, + docs: Optional[Documentation] = None, **kwargs, ): self._name = name @@ -175,10 +177,26 @@ def __init__( self._workflow_metadata_defaults = workflow_metadata_defaults self._python_interface = python_interface self._interface = transform_interface_to_typed_interface(python_interface) - self._inputs = {} - self._unbound_inputs = set() - self._nodes = [] + self._inputs: Dict[str, Promise] = {} + self._unbound_inputs: set = set() + self._nodes: List[Node] = [] self._output_bindings: List[_literal_models.Binding] = [] + self._docs = docs + + if self._python_interface.docstring: + if self.docs is None: + self._docs = Documentation( + short_description=self._python_interface.docstring.short_description, + long_description=Description(value=self._python_interface.docstring.long_description), + ) + else: + if self._python_interface.docstring.short_description: + cast( + Documentation, self._docs + ).short_description = self._python_interface.docstring.short_description + if self._python_interface.docstring.long_description: + self._docs = Description(value=self._python_interface.docstring.long_description) + FlyteEntities.entities.append(self) super().__init__(**kwargs) @@ -186,16 +204,20 @@ def __init__( def name(self) -> str: return self._name + @property + def docs(self): + return self._docs + @property def short_name(self) -> str: return extract_obj_name(self._name) @property - def workflow_metadata(self) -> Optional[WorkflowMetadata]: + def workflow_metadata(self) -> WorkflowMetadata: return self._workflow_metadata @property - def workflow_metadata_defaults(self): + def workflow_metadata_defaults(self) -> WorkflowMetadataDefaults: return self._workflow_metadata_defaults @property @@ -208,10 +230,12 @@ def interface(self) -> _interface_models.TypedInterface: @property def output_bindings(self) -> List[_literal_models.Binding]: + self.compile() return self._output_bindings @property def nodes(self) -> List[Node]: + self.compile() return self._nodes def __repr__(self): @@ -228,18 +252,22 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: interruptible=self.workflow_metadata_defaults.interruptible, ) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ Workflow needs to fill in default arguments before invoking the call handler. """ # Get default arguments and override with kwargs passed in input_kwargs = self.python_interface.default_inputs_as_kwargs input_kwargs.update(kwargs) + self.compile() return flyte_entity_call_handler(self, *args, **input_kwargs) def execute(self, **kwargs): raise Exception("Should not be called") + def compile(self, **kwargs): + pass + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. 
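With the docs plumbing added above, a workflow can carry an explicit Documentation entity (or fall back to its docstring). A hedged sketch using the new models from this change; the workflow itself is illustrative:

    from flytekit import task, workflow
    from flytekit.models.documentation import Description, Documentation, SourceCode

    docs = Documentation(
        short_description="Nightly ETL pipeline",
        long_description=Description(value="Loads raw events and produces daily aggregates."),
        # Note: the registration path in this change appears to overwrite source_code
        # with the git repo detected from serialization settings.
        source_code=SourceCode(link="https://github.com/example/repo"),
    )


    @task
    def noop() -> None:
        ...


    @workflow(docs=docs)
    def nightly_etl() -> None:
        noop()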
@@ -250,6 +278,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # The output of this will always be a combination of Python native values and Promises containing Flyte # Literals. + self.compile() function_outputs = self.execute(**kwargs) # First handle the empty return case. @@ -392,7 +421,7 @@ def execute(self, **kwargs): raise FlyteValidationException(f"Workflow not ready, wf is currently {self}") # Create a map that holds the outputs of each node. - intermediate_node_outputs = {GLOBAL_START_NODE: {}} # type: Dict[Node, Dict[str, Promise]] + intermediate_node_outputs: Dict[Node, Dict[str, Promise]] = {GLOBAL_START_NODE: {}} # Start things off with the outputs of the global input node, i.e. the inputs to the workflow. # local_execute should've already ensured that all the values in kwargs are Promise objects @@ -489,7 +518,7 @@ def get_input_values(input_value): self._unbound_inputs.remove(input_value) return n # type: ignore - def add_workflow_input(self, input_name: str, python_type: Type) -> Interface: + def add_workflow_input(self, input_name: str, python_type: Type) -> Promise: """ Adds an input to the workflow. """ @@ -516,7 +545,8 @@ def add_workflow_output( f"If specifying a list or dict of Promises, you must specify the python_type type for {output_name}" f" starting with the container type (e.g. List[int]" ) - python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var] + promise = cast(Promise, p) + python_type = promise.ref.node.flyte_entity.python_interface.outputs[promise.var] logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}") flyte_type = TypeEngine.to_literal_type(python_type=python_type) @@ -569,9 +599,10 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver): def __init__( self, workflow_function: Callable, - metadata: Optional[WorkflowMetadata], - default_metadata: Optional[WorkflowMetadataDefaults], - docstring: Docstring = None, + metadata: WorkflowMetadata, + default_metadata: WorkflowMetadataDefaults, + docstring: Optional[Docstring] = None, + docs: Optional[Documentation] = None, ): name, _, _, _ = extract_task_module(workflow_function) self._workflow_function = workflow_function @@ -586,13 +617,15 @@ def __init__( workflow_metadata=metadata, workflow_metadata_defaults=default_metadata, python_interface=native_interface, + docs=docs, ) + self.compiled = False @property def function(self): return self._workflow_function - def task_name(self, t: PythonAutoContainerTask) -> str: + def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore return f"{self.name}.{t.__module__}.{t.name}" def compile(self, **kwargs): @@ -600,6 +633,9 @@ def compile(self, **kwargs): Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics a 'closure' in the traditional sense of the word. """ + if self.compiled: + return + self.compiled = True ctx = FlyteContextManager.current_context() self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface) all_nodes = [] @@ -690,7 +726,8 @@ def workflow( _workflow_function=None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, -) -> PythonFunctionWorkflow: + docs: Optional[Documentation] = None, +) -> WorkflowBase: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. 
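The add_workflow_input/add_workflow_output methods touched above belong to the imperative workflow API; a minimal sketch, assuming the default output name "o0" (names here are illustrative):

    from flytekit import task
    from flytekit.core.workflow import ImperativeWorkflow


    @task
    def double(x: int) -> int:
        return x * 2


    wb = ImperativeWorkflow(name="example.imperative.wf")
    x = wb.add_workflow_input("x", int)        # now typed to return a Promise
    node = wb.add_entity(double, x=x)
    wb.add_workflow_output("out", node.outputs["o0"])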
@@ -718,6 +755,7 @@ def workflow( :param _workflow_function: This argument is implicitly passed and represents the decorated function. :param failure_policy: Use the options in flytekit.WorkflowFailurePolicy :param interruptible: Whether or not tasks launched from this workflow are by default interruptible + :param docs: Description entity for the workflow """ def wrapper(fn): @@ -730,18 +768,18 @@ def wrapper(fn): metadata=workflow_metadata, default_metadata=workflow_metadata_defaults, docstring=Docstring(callable_=fn), + docs=docs, ) - workflow_instance.compile() update_wrapper(workflow_instance, fn) return workflow_instance if _workflow_function: return wrapper(_workflow_function) else: - return wrapper + return wrapper # type: ignore -class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): +class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignore """ A reference workflow is a pointer to a workflow that already exists on your Flyte installation. This object will not initiate a network call to Admin, which is why the user is asked to provide the expected interface. diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index cec59e7318..45ee4efa51 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -10,6 +10,11 @@ OUTPUT_DIR_JUPYTER_PREFIX = "jupyter" DECK_FILE_NAME = "deck.html" +try: + from IPython.core.display import HTML +except ImportError: + ... + class Deck: """ @@ -100,8 +105,6 @@ def _get_deck( deck_map = {deck.name: deck.html for deck in new_user_params.decks} raw_html = template.render(metadata=deck_map) if not ignore_jupyter and _ipython_check(): - from IPython.core.display import HTML - return HTML(raw_html) return raw_html diff --git a/flytekit/extras/pytorch/__init__.py b/flytekit/extras/pytorch/__init__.py index 770fe11b73..a29d8e89e6 100644 --- a/flytekit/extras/pytorch/__init__.py +++ b/flytekit/extras/pytorch/__init__.py @@ -1,6 +1,4 @@ """ -Flytekit PyTorch -========================================= .. currentmodule:: flytekit.extras.pytorch .. 
autosummary:: @@ -8,6 +6,9 @@ :toctree: generated/ PyTorchCheckpoint + PyTorchCheckpointTransformer + PyTorchModuleTransformer + PyTorchTensorTransformer """ from flytekit.loggers import logger diff --git a/flytekit/extras/pytorch/native.py b/flytekit/extras/pytorch/native.py index 4cf37871fb..cbaa0c80f0 100644 --- a/flytekit/extras/pytorch/native.py +++ b/flytekit/extras/pytorch/native.py @@ -1,5 +1,5 @@ import pathlib -from typing import Generic, Type, TypeVar +from typing import Type, TypeVar import torch @@ -12,7 +12,7 @@ T = TypeVar("T") -class PyTorchTypeTransformer(TypeTransformer, Generic[T]): +class PyTorchTypeTransformer(TypeTransformer[T]): def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( @@ -63,30 +63,40 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # load pytorch tensor/module from a file return torch.load(local_path, map_location=map_location) - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + +class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]): + PYTORCH_FORMAT = "PyTorchTensor" + + def __init__(self): + super().__init__(name="PyTorch Tensor", t=torch.Tensor) + + def guess_python_type(self, literal_type: LiteralType) -> Type[torch.Tensor]: if ( literal_type.blob is not None and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE and literal_type.blob.format == self.PYTORCH_FORMAT ): - return T + return torch.Tensor raise ValueError(f"Transformer {self} cannot reverse {literal_type}") -class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]): - PYTORCH_FORMAT = "PyTorchTensor" - - def __init__(self): - super().__init__(name="PyTorch Tensor", t=torch.Tensor) - - class PyTorchModuleTransformer(PyTorchTypeTransformer[torch.nn.Module]): PYTORCH_FORMAT = "PyTorchModule" def __init__(self): super().__init__(name="PyTorch Module", t=torch.nn.Module) + def guess_python_type(self, literal_type: LiteralType) -> Type[torch.nn.Module]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.PYTORCH_FORMAT + ): + return torch.nn.Module + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + TypeEngine.register(PyTorchTensorTransformer()) TypeEngine.register(PyTorchModuleTransformer()) diff --git a/flytekit/extras/sklearn/__init__.py b/flytekit/extras/sklearn/__init__.py index 0a1bf2dda5..1d16f6080f 100644 --- a/flytekit/extras/sklearn/__init__.py +++ b/flytekit/extras/sklearn/__init__.py @@ -1,11 +1,11 @@ """ -Flytekit Sklearn -========================================= .. currentmodule:: flytekit.extras.sklearn .. autosummary:: :template: custom.rst :toctree: generated/ + + SklearnEstimatorTransformer """ from flytekit.loggers import logger diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 0284440da3..8e7d8b3b29 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -92,14 +92,13 @@ def __init__( container_image=container_image or DefaultImages.default_image(), executor_type=SQLite3TaskExecutor, task_type=self._SQLITE_TASK_TYPE, + # Sanitize query by removing the newlines at the end of the query. Keep in mind + # that the query can be a multiline string. query_template=query_template, inputs=inputs, outputs=outputs, **kwargs, ) - # Sanitize query by removing the newlines at the end of the query. 
Keep in mind - # that the query can be a multiline string. - self._query_template = query_template.replace("\n", " ") @property def output_columns(self) -> typing.Optional[typing.List[str]]: diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py index 4db9b428ea..fe10c9024b 100644 --- a/flytekit/extras/tensorflow/__init__.py +++ b/flytekit/extras/tensorflow/__init__.py @@ -1,3 +1,14 @@ +""" +.. currentmodule:: flytekit.extras.tensorflow + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + TensorFlowRecordFileTransformer + TensorFlowRecordsDirTransformer +""" + from flytekit.loggers import logger # TODO: abstract this out so that there's an established pattern for registering plugins diff --git a/flytekit/loggers.py b/flytekit/loggers.py index 0c8c2e035a..f047348de0 100644 --- a/flytekit/loggers.py +++ b/flytekit/loggers.py @@ -13,12 +13,6 @@ # By default, the root flytekit logger to debug so everything is logged, but enable fine-tuning logger = logging.getLogger("flytekit") -# Root logger control -flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT" -if os.getenv(flytekit_root_env_var) is not None: - logger.setLevel(int(os.getenv(flytekit_root_env_var))) -else: - logger.setLevel(logging.DEBUG) # Stop propagation so that configuration is isolated to this file (so that it doesn't matter what the # global Python root logger is set to). @@ -40,22 +34,33 @@ # create console handler ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) +# Root logger control # Don't want to import the configuration library since that will cause all sorts of circular imports, let's # just use the environment variable if it's defined. Decide in the future when we implement better controls # if we should control with the channel or with the logger level. 
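The reshuffled logger configuration below is driven entirely by environment variables resolved at import time. A hedged sketch, assuming LOGGING_ENV_VAR resolves to "FLYTE_SDK_LOGGING_LEVEL" (check flytekit.loggers for the actual constant):

    import os

    os.environ["FLYTE_SDK_LOGGING_LEVEL"] = "10"  # 10 == DEBUG for the flytekit logger
    # A "<LOGGING_ENV_VAR>_ROOT" variable, if set, takes precedence over the plain one,
    # and "<LOGGING_ENV_VAR>_<CHILD>" variables override individual child loggers.

    from flytekit.loggers import logger  # levels are resolved when this module is imported

    logger.debug("emitted because the resolved level is DEBUG")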
# The handler log level controls whether log statements will actually print to the screen +flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT" level_from_env = os.getenv(LOGGING_ENV_VAR) -if level_from_env is not None: - ch.setLevel(int(level_from_env)) +root_level_from_env = os.getenv(flytekit_root_env_var) +if root_level_from_env is not None: + logger.setLevel(int(root_level_from_env)) +elif level_from_env is not None: + logger.setLevel(int(level_from_env)) else: - ch.setLevel(logging.WARNING) + logger.setLevel(logging.WARNING) for log_name, child_logger in child_loggers.items(): env_var = f"{LOGGING_ENV_VAR}_{log_name.upper()}" level_from_env = os.getenv(env_var) if level_from_env is not None: child_logger.setLevel(int(level_from_env)) + else: + if child_logger is user_space_logger: + child_logger.setLevel(logging.INFO) + else: + child_logger.setLevel(logging.WARNING) # create formatter formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s") diff --git a/flytekit/models/admin/workflow.py b/flytekit/models/admin/workflow.py index f34e692123..e40307b6ba 100644 --- a/flytekit/models/admin/workflow.py +++ b/flytekit/models/admin/workflow.py @@ -1,13 +1,21 @@ +import typing + from flyteidl.admin import workflow_pb2 as _admin_workflow from flytekit.models import common as _common from flytekit.models.core import compiler as _compiler_models from flytekit.models.core import identifier as _identifier from flytekit.models.core import workflow as _core_workflow +from flytekit.models.documentation import Documentation class WorkflowSpec(_common.FlyteIdlEntity): - def __init__(self, template, sub_workflows): + def __init__( + self, + template: _core_workflow.WorkflowTemplate, + sub_workflows: typing.List[_core_workflow.WorkflowTemplate], + docs: typing.Optional[Documentation] = None, + ): """ This object fully encapsulates the specification of a workflow :param flytekit.models.core.workflow.WorkflowTemplate template: @@ -15,6 +23,7 @@ def __init__(self, template, sub_workflows): """ self._template = template self._sub_workflows = sub_workflows + self._docs = docs @property def template(self): @@ -30,6 +39,13 @@ def sub_workflows(self): """ return self._sub_workflows + @property + def docs(self): + """ + :rtype: Description entity for the workflow + """ + return self._docs + def to_flyte_idl(self): """ :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec @@ -37,6 +53,7 @@ def to_flyte_idl(self): return _admin_workflow.WorkflowSpec( template=self._template.to_flyte_idl(), sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], + description=self._docs.to_flyte_idl() if self._docs else None, ) @classmethod @@ -48,6 +65,7 @@ def from_flyte_idl(cls, pb2_object): return cls( _core_workflow.WorkflowTemplate.from_flyte_idl(pb2_object.template), [_core_workflow.WorkflowTemplate.from_flyte_idl(s) for s in pb2_object.sub_workflows], + Documentation.from_flyte_idl(pb2_object.description) if pb2_object.description else None, ) diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 7236dd15ce..62018c1eef 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -414,8 +414,10 @@ def from_flyte_idl(cls, pb): class AuthRole(FlyteIdlEntity): def __init__(self, assumable_iam_role=None, kubernetes_service_account=None): - """ + """Auth configuration for IAM or K8s service account. + Either one or both of the assumable IAM role and/or the K8s service account can be set. + :param Text assumable_iam_role: IAM identity with set permissions policies. 
:param Text kubernetes_service_account: Provides an identity for workflow execution resources. Flyte deployment administrators are responsible for handling permissions as they diff --git a/flytekit/models/documentation.py b/flytekit/models/documentation.py new file mode 100644 index 0000000000..e1bae8122e --- /dev/null +++ b/flytekit/models/documentation.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from flyteidl.admin import description_entity_pb2 + +from flytekit.models import common as _common_models + + +@dataclass +class Description(_common_models.FlyteIdlEntity): + """ + Full user description with formatting preserved. This can be rendered + by clients, such as the console or command line tools with in-tact + formatting. + """ + + class DescriptionFormat(Enum): + UNKNOWN = 0 + MARKDOWN = 1 + HTML = 2 + RST = 3 + + value: Optional[str] = None + uri: Optional[str] = None + icon_link: Optional[str] = None + format: DescriptionFormat = DescriptionFormat.RST + + def to_flyte_idl(self): + return description_entity_pb2.Description( + value=self.value if self.value else None, + uri=self.uri if self.uri else None, + format=self.format.value, + icon_link=self.icon_link, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.Description) -> "Description": + return cls( + value=pb2_object.value if pb2_object.value else None, + uri=pb2_object.uri if pb2_object.uri else None, + format=Description.DescriptionFormat(pb2_object.format), + icon_link=pb2_object.icon_link if pb2_object.icon_link else None, + ) + + +@dataclass +class SourceCode(_common_models.FlyteIdlEntity): + """ + Link to source code used to define this task or workflow. + """ + + link: Optional[str] = None + + def to_flyte_idl(self): + return description_entity_pb2.SourceCode(link=self.link) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.SourceCode) -> "SourceCode": + return cls(link=pb2_object.link) if pb2_object.link else None + + +@dataclass +class Documentation(_common_models.FlyteIdlEntity): + """ + DescriptionEntity contains detailed description for the task/workflow/launch plan. + Documentation could provide insight into the algorithms, business use case, etc. + Args: + short_description (str): One-liner overview of the entity. + long_description (Optional[Description]): Full user description with formatting preserved. 
+ source_code (Optional[SourceCode]): link to source code used to define this entity + """ + + short_description: Optional[str] = None + long_description: Optional[Description] = None + source_code: Optional[SourceCode] = None + + def to_flyte_idl(self): + return description_entity_pb2.DescriptionEntity( + short_description=self.short_description, + long_description=self.long_description.to_flyte_idl() if self.long_description else None, + source_code=self.source_code.to_flyte_idl() if self.source_code else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.DescriptionEntity) -> "Documentation": + return cls( + short_description=pb2_object.short_description, + long_description=Description.from_flyte_idl(pb2_object.long_description) + if pb2_object.long_description + else None, + source_code=SourceCode.from_flyte_idl(pb2_object.source_code) if pb2_object.source_code else None, + ) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 4f06c3d3c6..e0a864e31e 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -628,7 +628,7 @@ def uri(self) -> str: return self._uri @property - def metadata(self) -> StructuredDatasetMetadata: + def metadata(self) -> Optional[StructuredDatasetMetadata]: return self._metadata def to_flyte_idl(self) -> _literals_pb2.StructuredDataset: diff --git a/flytekit/models/task.py b/flytekit/models/task.py index f2ff5efd89..fc79c87a2d 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -13,6 +13,7 @@ from flytekit.models import literals as _literals from flytekit.models import security as _sec from flytekit.models.core import identifier as _identifier +from flytekit.models.documentation import Documentation class Resources(_common.FlyteIdlEntity): @@ -176,6 +177,7 @@ def __init__( discovery_version, deprecated_error_message, cache_serializable, + pod_template_name, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -195,6 +197,7 @@ def __init__( receive deprecation warnings. :param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a single instance over identical inputs is executed, other concurrent executions wait for the cached results. + :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ self._discoverable = discoverable self._runtime = runtime @@ -204,6 +207,7 @@ def __init__( self._discovery_version = discovery_version self._deprecated_error_message = deprecated_error_message self._cache_serializable = cache_serializable + self._pod_template_name = pod_template_name @property def discoverable(self): @@ -273,6 +277,14 @@ def cache_serializable(self): """ return self._cache_serializable + @property + def pod_template_name(self): + """ + The name of the existing PodTemplate resource which will be used in this task. 
+ :rtype: Text + """ + return self._pod_template_name + def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.TaskMetadata @@ -285,6 +297,7 @@ def to_flyte_idl(self): discovery_version=self.discovery_version, deprecated_error_message=self.deprecated_error_message, cache_serializable=self.cache_serializable, + pod_template_name=self.pod_template_name, ) if self.timeout: tm.timeout.FromTimedelta(self.timeout) @@ -305,6 +318,7 @@ def from_flyte_idl(cls, pb2_object): discovery_version=pb2_object.discovery_version, deprecated_error_message=pb2_object.deprecated_error_message, cache_serializable=pb2_object.cache_serializable, + pod_template_name=pb2_object.pod_template_name, ) @@ -480,11 +494,13 @@ def from_flyte_idl(cls, pb2_object): class TaskSpec(_common.FlyteIdlEntity): - def __init__(self, template): + def __init__(self, template: TaskTemplate, docs: typing.Optional[Documentation] = None): """ :param TaskTemplate template: + :param Documentation docs: """ self._template = template + self._docs = docs @property def template(self): @@ -493,11 +509,20 @@ def template(self): """ return self._template + @property + def docs(self): + """ + :rtype: Description entity for the task + """ + return self._docs + def to_flyte_idl(self): """ :rtype: flyteidl.admin.tasks_pb2.TaskSpec """ - return _admin_task.TaskSpec(template=self.template.to_flyte_idl()) + return _admin_task.TaskSpec( + template=self.template.to_flyte_idl(), description=self.docs.to_flyte_idl() if self.docs else None + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -505,7 +530,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.tasks_pb2.TaskSpec pb2_object: :rtype: TaskSpec """ - return cls(TaskTemplate.from_flyte_idl(pb2_object.template)) + return cls( + TaskTemplate.from_flyte_idl(pb2_object.template), + Documentation.from_flyte_idl(pb2_object.description) if pb2_object.description else None, + ) class Task(_common.FlyteIdlEntity): diff --git a/flytekit/remote/__init__.py b/flytekit/remote/__init__.py index 174928a5b4..4d6f172586 100644 --- a/flytekit/remote/__init__.py +++ b/flytekit/remote/__init__.py @@ -51,9 +51,9 @@ :toctree: generated/ :nosignatures: - ~task.FlyteTask - ~workflow.FlyteWorkflow - ~launch_plan.FlyteLaunchPlan + ~entities.FlyteTask + ~entities.FlyteWorkflow + ~entities.FlyteLaunchPlan .. _remote-flyte-entity-components: @@ -65,9 +65,9 @@ :toctree: generated/ :nosignatures: - ~nodes.FlyteNode - ~component_nodes.FlyteTaskNode - ~component_nodes.FlyteWorkflowNode + ~entities.FlyteNode + ~entities.FlyteTaskNode + ~entities.FlyteWorkflowNode .. _remote-flyte-execution-objects: diff --git a/flytekit/remote/backfill.py b/flytekit/remote/backfill.py new file mode 100644 index 0000000000..154bf4d1b4 --- /dev/null +++ b/flytekit/remote/backfill.py @@ -0,0 +1,97 @@ +import logging +import typing +from datetime import datetime, timedelta + +from croniter import croniter + +from flytekit import LaunchPlan +from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase +from flytekit.remote.entities import FlyteLaunchPlan + + +def create_backfill_workflow( + start_date: datetime, + end_date: datetime, + for_lp: typing.Union[LaunchPlan, FlyteLaunchPlan], + parallel: bool = False, + per_node_timeout: timedelta = None, + per_node_retries: int = 0, +) -> typing.Tuple[WorkflowBase, datetime, datetime]: + """ + Generates a new imperative workflow for the launchplan that can be used to backfill the given launchplan. 
+ This can only be used to generate backfilling workflow only for schedulable launchplans + + the Backfill plan is generated as (start_date - exclusive, end_date inclusive) + + .. code-block:: python + :caption: Correct usage for dates example + + lp = Launchplan.get_or_create(...) + start_date = datetime.datetime(2023, 1, 1) + end_date = start_date + datetime.timedelta(days=10) + wf = create_backfill_workflow(start_date, end_date, for_lp=lp) + + + .. code-block:: python + :caption: Incorrect date example + + wf = create_backfill_workflow(end_date, start_date, for_lp=lp) # end_date is before start_date + # OR + wf = create_backfill_workflow(start_date, start_date, for_lp=lp) # start and end date are same + + + :param start_date: datetime generate a backfill starting at this datetime (exclusive) + :param end_date: datetime generate a backfill ending at this datetime (inclusive) + :param for_lp: typing.Union[LaunchPlan, FlyteLaunchPlan] the backfill is generatd for this launchplan + :param parallel: if the backfill should be run in parallel. False (default) will run each bacfill sequentially + :param per_node_timeout: timedelta Timeout to use per node + :param per_node_retries: int Retries to user per node + :return: WorkflowBase, datetime datetime -> New generated workflow, datetime for first instance of backfill, datetime for last instance of backfill + """ + if not for_lp: + raise ValueError("Launch plan is required!") + + if start_date >= end_date: + raise ValueError( + f"for a backfill start date should be earlier than end date. Received {start_date} -> {end_date}" + ) + + schedule = for_lp.entity_metadata.schedule if isinstance(for_lp, FlyteLaunchPlan) else for_lp.schedule + + if schedule is None: + raise ValueError("Backfill can only be created for scheduled launch plans") + + if schedule.cron_schedule is not None: + cron_schedule = schedule.cron_schedule + else: + raise NotImplementedError("Currently backfilling only supports cron schedules.") + + logging.info(f"Generating backfill from {start_date} -> {end_date}. 
Parallel?[{parallel}]") + wf = ImperativeWorkflow(name=f"backfill-{for_lp.name}") + date_iter = croniter(cron_schedule.schedule, start_time=start_date, ret_type=datetime) + prev_node = None + actual_start = None + actual_end = None + while True: + next_start_date = date_iter.get_next() + if not actual_start: + actual_start = next_start_date + if next_start_date >= end_date: + break + actual_end = next_start_date + next_node = wf.add_launch_plan(for_lp, t=next_start_date) + next_node = next_node.with_overrides( + name=f"b-{next_start_date}", retries=per_node_retries, timeout=per_node_timeout + ) + if not parallel: + if prev_node: + prev_node.runs_before(next_node) + prev_node = next_node + + if actual_end is None: + raise StopIteration( + f"The time window is too small for any backfill instances, first instance after start" + f" date is {actual_start}" + ) + + return wf, actual_start, actual_end diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index 0c745c11bb..624e01661d 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -4,9 +4,11 @@ from typing import Dict, List, Optional, Tuple, Union +from flytekit import FlyteContext from flytekit.core import constants as _constants from flytekit.core import hash as _hash_mixin from flytekit.core import hash as hash_mixin +from flytekit.core.promise import create_and_link_node_from_remote from flytekit.exceptions import system as _system_exceptions from flytekit.exceptions import user as _user_exceptions from flytekit.loggers import remote_logger @@ -334,6 +336,12 @@ def promote_from_model( return cls(new_if_else_block), converted_sub_workflows +class FlyteGateNode(_workflow_model.GateNode): + @classmethod + def promote_from_model(cls, model: _workflow_model.GateNode): + return cls(model.signal, model.sleep, model.approve) + + class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): """A class encapsulating a remote Flyte node.""" @@ -343,22 +351,23 @@ def __init__( upstream_nodes, bindings, metadata, - task_node: FlyteTaskNode = None, - workflow_node: FlyteWorkflowNode = None, - branch_node: FlyteBranchNode = None, + task_node: Optional[FlyteTaskNode] = None, + workflow_node: Optional[FlyteWorkflowNode] = None, + branch_node: Optional[FlyteBranchNode] = None, + gate_node: Optional[FlyteGateNode] = None, ): - if not task_node and not workflow_node and not branch_node: + if not task_node and not workflow_node and not branch_node and not gate_node: raise _user_exceptions.FlyteAssertion( - "An Flyte node must have one of task|workflow|branch entity specified at once" + "An Flyte node must have one of task|workflow|branch|gate entity specified at once" ) - # todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from - # the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it. + # TODO: Revisit flyte_branch_node and flyte_gate_node, should they be another type like Condition instead + # of a node? 
if task_node: self._flyte_entity = task_node.flyte_task elif workflow_node: self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan else: - self._flyte_entity = branch_node + self._flyte_entity = branch_node or gate_node super(FlyteNode, self).__init__( id=id, @@ -369,6 +378,7 @@ def __init__( task_node=task_node, workflow_node=workflow_node, branch_node=branch_node, + gate_node=gate_node, ) self._upstream = upstream_nodes @@ -412,7 +422,7 @@ def promote_from_model( remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") return None, converted_sub_workflows - flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None + flyte_task_node, flyte_workflow_node, flyte_branch_node, flyte_gate_node = None, None, None, None if model.task_node is not None: if model.task_node.reference_id not in tasks: raise RuntimeError( @@ -435,6 +445,8 @@ def promote_from_model( tasks, converted_sub_workflows, ) + elif model.gate_node is not None: + flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node) else: raise _system_exceptions.FlyteSystemException( f"Bad Node model, neither task nor workflow detected, node: {model}" @@ -459,6 +471,7 @@ def promote_from_model( task_node=flyte_task_node, workflow_node=flyte_workflow_node, branch_node=flyte_branch_node, + gate_node=flyte_gate_node, ), converted_sub_workflows, ) @@ -787,5 +800,16 @@ def resource_type(self) -> id_models.ResourceType: def entity_type_text(self) -> str: return "Launch Plan" + def compile(self, ctx: FlyteContext, *args, **kwargs): + fixed_input_lits = self.fixed_inputs.literals or {} + default_input_params = self.default_inputs.parameters or {} + return create_and_link_node_from_remote( + ctx, + entity=self, + _inputs_not_allowed=set(fixed_input_lits.keys()), + _ignorable_inputs=set(default_input_params.keys()), + **kwargs, + ) # noqa + def __repr__(self) -> str: return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface}) - Spec {super().__repr__()})" diff --git a/flytekit/remote/lazy_entity.py b/flytekit/remote/lazy_entity.py index b40c6e3ff7..4755aad99d 100644 --- a/flytekit/remote/lazy_entity.py +++ b/flytekit/remote/lazy_entity.py @@ -37,7 +37,12 @@ def entity(self) -> T: """ with self._mutex: if self._entity is None: - self._entity = self._getter() + try: + self._entity = self._getter() + except AttributeError as e: + raise RuntimeError( + f"Error downloading the entity {self._name}, (check original exception...)" + ) from e return self._entity def __getattr__(self, item: str) -> typing.Any: diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 6473d46ec9..93badd5374 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -17,6 +17,7 @@ from dataclasses import asdict, dataclass from datetime import datetime, timedelta +from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 from flytekit import Literal @@ -40,11 +41,12 @@ from flytekit.models import launch_plan as launch_plan_models from flytekit.models import literals as literal_models from flytekit.models import task as task_models +from flytekit.models import types as type_models from flytekit.models.admin import common as admin_common_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.admin.common import Sort from flytekit.models.core import workflow as workflow_model -from flytekit.models.core.identifier import Identifier, 
ResourceType, WorkflowExecutionIdentifier +from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier from flytekit.models.core.workflow import NodeMetadata from flytekit.models.execution import ( ExecutionMetadata, @@ -53,7 +55,8 @@ NotificationList, WorkflowExecutionGetDataResponse, ) -from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteWorkflow +from flytekit.remote.backfill import create_backfill_workflow +from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution from flytekit.remote.interface import TypedInterface from flytekit.remote.lazy_entity import LazyEntity @@ -119,6 +122,22 @@ def _get_entity_identifier( ) +def _get_git_repo_url(source_path): + """ + Get git repo URL from remote.origin.url + """ + try: + from git import Repo + + return "github.com/" + Repo(source_path).remotes.origin.url.split(".git")[0].split(":")[-1] + except ImportError: + remote_logger.warning("Could not import git. is the git executable installed?") + except Exception: + # If the file isn't in the git repo, we can't get the url from git config + remote_logger.debug(f"{source_path} is not a git repo.") + return "" + + class FlyteRemote(object): """Main entrypoint for programmatically accessing a Flyte remote backend. @@ -146,7 +165,8 @@ def __init__( if config is None or config.platform is None or config.platform.endpoint is None: raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.") - self._client = SynchronousFlyteClient(config.platform, **kwargs) + self._kwargs = kwargs + self._client_initialized = False self._config = config # read config files, env vars, host, ssl options for admin client self._default_project = default_project @@ -168,6 +188,9 @@ def context(self) -> FlyteContext: @property def client(self) -> SynchronousFlyteClient: """Return a SynchronousFlyteClient for additional operations.""" + if not self._client_initialized: + self._client = SynchronousFlyteClient(self.config.platform, **self._kwargs) + self._client_initialized = True return self._client @property @@ -350,6 +373,69 @@ def fetch_execution(self, project: str = None, domain: str = None, name: str = N # Listing Entities # ###################### + def list_signals( + self, + execution_name: str, + project: typing.Optional[str] = None, + domain: typing.Optional[str] = None, + limit: int = 100, + filters: typing.Optional[typing.List[filter_models.Filter]] = None, + ) -> typing.List[Signal]: + """ + :param execution_name: The name of the execution. This is the tailend of the URL when looking at the workflow execution. + :param project: The execution project, will default to the Remote's default project. + :param domain: The execution domain, will default to the Remote's default domain. 
+ :param limit: The number of signals to fetch + :param filters: Optional list of filters + """ + wf_exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) + req = SignalListRequest(workflow_execution_id=wf_exec_id.to_flyte_idl(), limit=limit, filters=filters) + resp = self.client.list_signals(req) + s = resp.signals + return s + + def set_signal( + self, + signal_id: str, + execution_name: str, + value: typing.Union[literal_models.Literal, typing.Any], + project: typing.Optional[str] = None, + domain: typing.Optional[str] = None, + python_type: typing.Optional[typing.Type] = None, + literal_type: typing.Optional[type_models.LiteralType] = None, + ): + """ + :param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call. + :param execution_name: The name of the execution. This is the tail-end of the URL when looking + at the workflow execution. + :param value: This is either a Literal or a Python value which FlyteRemote will invoke the TypeEngine to + convert into a Literal. This argument is only value for wait_for_input type signals. + :param project: The execution project, will default to the Remote's default project. + :param domain: The execution domain, will default to the Remote's default domain. + :param python_type: Provide a python type to help with conversion if the value you provided is not a Literal. + :param literal_type: Provide a Flyte literal type to help with conversion if the value you provided + is not a Literal + """ + wf_exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) + if isinstance(value, Literal): + remote_logger.debug(f"Using provided {value} as existing Literal value") + lit = value + else: + lt = literal_type or ( + TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value)) + ) + lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt) + remote_logger.debug(f"Converted {value} to literal {lit} using literal type {lt}") + + req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl()) + + # Response is empty currently, nothing to give back to the user. 
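A short, hedged sketch of the new signal APIs from the caller's side; the execution name and signal id are placeholders:

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")

    # List outstanding signals for a running execution (the name is the URL tail of the execution).
    signals = remote.list_signals("f8c2efabc123", limit=10)

    # Approve a wait_for_input()/approve() gate by setting the signal's value.
    remote.set_signal("approve-deploy", "f8c2efabc123", True, python_type=bool)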
+ self.client.set_signal(req) + def recent_executions( self, project: typing.Optional[str] = None, @@ -725,11 +811,11 @@ def register_script( filename="scriptmode.tar.gz", ), ) - serialization_settings = SerializationSettings( project=project, domain=domain, image_config=image_config, + git_repo=_get_git_repo_url(source_path), fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, @@ -1379,7 +1465,7 @@ def sync_execution( upstream_nodes=[], bindings=[], metadata=NodeMetadata(name=""), - flyte_task=flyte_entity, + task_node=FlyteTaskNode(flyte_entity), ) } if len(task_node_exec) >= 1 @@ -1636,6 +1722,88 @@ def generate_console_http_domain(self) -> str: return protocol + f"://{endpoint}" def generate_console_url( - self, execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution] + self, + entity: typing.Union[ + FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflow, FlyteTask, FlyteLaunchPlan + ], ): - return f"{self.generate_console_http_domain()}/console/projects/{execution.id.project}/domains/{execution.id.domain}/executions/{execution.id.name}" + """ + Generate a Flyteconsole URL for the given Flyte remote endpoint. + This will automatically determine if this is an execution or an entity and change the type automatically + """ + if isinstance(entity, (FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution)): + return f"{self.generate_console_http_domain()}/console/projects/{entity.id.project}/domains/{entity.id.domain}/executions/{entity.id.name}" # noqa + + if not isinstance(entity, (FlyteWorkflow, FlyteTask, FlyteLaunchPlan)): + raise ValueError(f"Only remote entities can be looked at in the console, got type {type(entity)}") + rt = "workflow" + if entity.id.resource_type == ResourceType.TASK: + rt = "task" + elif entity.id.resource_type == ResourceType.LAUNCH_PLAN: + rt = "launch_plan" + return f"{self.generate_console_http_domain()}/console/projects/{entity.id.project}/domains/{entity.id.domain}/{rt}/{entity.name}/version/{entity.id.version}" # noqa + + def launch_backfill( + self, + project: str, + domain: str, + from_date: datetime, + to_date: datetime, + launchplan: str, + launchplan_version: str = None, + execution_name: str = None, + version: str = None, + dry_run: bool = False, + execute: bool = True, + parallel: bool = False, + ) -> typing.Optional[FlyteWorkflowExecution, FlyteWorkflow, WorkflowBase]: + """ + Creates and launches a backfill workflow for the given launchplan. If launchplan version is not specified, + then the latest launchplan is retrieved. + The from_date is exclusive and end_date is inclusive and backfill run for all instances in between. + -> (start_date - exclusive, end_date inclusive) + If dry_run is specified, the workflow is created and returned + if execute==False is specified then the workflow is created and registered + in the last case, the workflow is created, registered and executed. + + The `parallel` flag can be used to generate a workflow where all launchplans can be run in parallel. Default + is that execute backfill is run sequentially + + :param project: str project name + :param domain: str domain name + :param from_date: datetime generate a backfill starting at this datetime (exclusive) + :param to_date: datetime generate a backfill ending at this datetime (inclusive) + :param launchplan: str launchplan name in the flyte backend + :param launchplan_version: str (optional) version for the launchplan. 
If not specified the most recent will be retrieved + :param execution_name: str (optional) the generated execution will be named so. this can help in ensuring idempotency + :param version: str (optional) version to be used for the newly created workflow. + :param dry_run: bool do not register or execute the workflow + :param execute: bool Register and execute the wwkflow. + :param parallel: if the backfill should be run in parallel. False (default) will run each bacfill sequentially + :return: In case of dry-run, return WorkflowBase, else if no_execute return FlyteWorkflow else in the default + case return a FlyteWorkflowExecution + """ + lp = self.fetch_launch_plan(project=project, domain=domain, name=launchplan, version=launchplan_version) + wf, start, end = create_backfill_workflow(start_date=from_date, end_date=to_date, for_lp=lp, parallel=parallel) + if dry_run: + remote_logger.warning("Dry Run enabled. Workflow will not be registered and or executed.") + return wf + + unique_fingerprint = f"{start}-{end}-{launchplan}-{launchplan_version}" + h = hashlib.md5() + h.update(unique_fingerprint.encode("utf-8")) + unique_fingerprint_encoded = base64.urlsafe_b64encode(h.digest()).decode("ascii") + if not version: + version = unique_fingerprint_encoded + ss = SerializationSettings( + image_config=ImageConfig.auto(), + project=project, + domain=domain, + version=version, + ) + remote_wf = self.register_workflow(wf, serialization_settings=ss) + + if not execute: + return remote_wf + + return self.execute(remote_wf, inputs={}, project=project, domain=domain, execution_name=execution_name) diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 50bac67844..3c9fe64068 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -12,7 +12,7 @@ from flytekit.models import launch_plan from flytekit.models.core.identifier import Identifier from flytekit.remote import FlyteRemote -from flytekit.remote.remote import RegistrationSkipped +from flytekit.remote.remote import RegistrationSkipped, _get_git_repo_url from flytekit.tools import fast_registration, module_loader from flytekit.tools.script_mode import _find_project_root from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities @@ -162,7 +162,7 @@ def load_packages_and_modules( :param options: :return: The common detected root path, the output of _find_project_root """ - + ss.git_repo = _get_git_repo_url(project_root) pkgs_and_modules = [] for pm in pkgs_or_mods: p = Path(pm).resolve() diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index f0ad5e96c6..5ec249fa4b 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -1,9 +1,10 @@ +import sys import typing from collections import OrderedDict from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, SourceCode from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -23,6 +24,7 @@ from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import security from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _core_wf from flytekit.models.core import workflow as 
workflow_model @@ -211,7 +213,8 @@ def get_serializable_task( ) if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask): entity.reset_command_fn() - return TaskSpec(template=tt) + + return TaskSpec(template=tt, docs=entity.docs) def get_serializable_workflow( @@ -295,8 +298,9 @@ def get_serializable_workflow( nodes=serialized_nodes, outputs=entity.output_bindings, ) + return admin_workflow_models.WorkflowSpec( - template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()) + template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()), docs=entity.docs ) @@ -658,6 +662,11 @@ def get_serializable( elif isinstance(entity, BranchNode): cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options) + elif isinstance(entity, GateNode): + import ipdb + + ipdb.set_trace() + elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow): if entity.should_register: if isinstance(entity, FlyteTask): @@ -678,6 +687,16 @@ def get_serializable( else: raise Exception(f"Non serializable type found {type(entity)} Entity {entity}") + if isinstance(entity, TaskSpec) or isinstance(entity, WorkflowSpec): + # 1. Check if the size of long description exceeds 16KB + # 2. Extract the repo URL from the git config, and assign it to the link of the source code of the description entity + if entity.docs and entity.docs.long_description: + if entity.docs.long_description.value: + if sys.getsizeof(entity.docs.long_description.value) > 16 * 1024 * 1024: + raise ValueError( + "Long Description of the flyte entity exceeds the 16KB size limit. Please specify the uri in the long description instead." + ) + entity.docs.source_code = SourceCode(link=settings.git_repo) # This needs to be at the bottom not the top - i.e. dependent tasks get added before the workflow containing it entity_mapping[entity] = cp_entity return cp_entity diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index c2ab8fd438..87b494d0ae 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -28,7 +28,7 @@ TensorBoard. """ -tfrecords_dir = typing.TypeVar("tfrecord") +tfrecords_dir = typing.TypeVar("tfrecords_dir") TFRecordsDirectory = FlyteDirectory[tfrecords_dir] """ This type can be used to denote that the output is a folder that contains tensorflow record files. diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index afb59d58d0..7d576f9353 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -115,7 +115,12 @@ def t1(in1: FlyteDirectory["svg"]): field in the ``BlobType``. """ - def __init__(self, path: typing.Union[str, os.PathLike], downloader: typing.Callable = None, remote_directory=None): + def __init__( + self, + path: typing.Union[str, os.PathLike], + downloader: typing.Optional[typing.Callable] = None, + remote_directory: typing.Optional[str] = None, + ): """ :param path: The source path that users are expected to call open() on :param downloader: Optional function that can be passed that used to delay downloading of the actual fil diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 9fc55f76ce..6537f85cae 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -346,13 +346,13 @@ def to_python_value( return FlyteFile(uri) # The rest of the logic is only for FlyteFile types. 
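For readers skimming this hunk, here is a minimal sketch (not taken from this patch) of how the local-versus-remote handling above surfaces to task authors. It assumes only the public `FlyteFile` task API; the task names and the `/tmp` path are illustrative.

```python
from flytekit import task, workflow
from flytekit.types.file import FlyteFile


@task
def write_report() -> FlyteFile:
    # Returning a local path: flytekit uploads it and records the remote URI in the literal.
    path = "/tmp/report.txt"  # illustrative path
    with open(path, "w") as f:
        f.write("hello")
    return FlyteFile(path)


@task
def read_report(f: FlyteFile) -> str:
    # Consuming a remote URI: opening the FlyteFile triggers the lazy download
    # that to_python_value wires up above.
    with open(f, "r") as fh:
        return fh.read()


@workflow
def report_wf() -> str:
    return read_report(f=write_report())
```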
- if not issubclass(expected_python_type, FlyteFile): + if not issubclass(expected_python_type, FlyteFile): # type: ignore raise TypeError(f"Neither os.PathLike nor FlyteFile specified {expected_python_type}") # This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't # make any sense. if not ctx.file_access.is_remote(uri): - return expected_python_type(uri) + return expected_python_type(uri) # type: ignore # For the remote case, return an FlyteFile object that can download local_path = ctx.file_access.get_random_local_path(uri) diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index 38fedfacca..d766818bfd 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -77,7 +77,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return np.load( file=local_path, allow_pickle=metadata.get("allow_pickle", False), - mmap_mode=metadata.get("mmap_mode"), + mmap_mode=metadata.get("mmap_mode"), # type: ignore ) def guess_python_type(self, literal_type: LiteralType) -> typing.Type[np.ndarray]: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index f486d1012e..c380bcc481 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -100,32 +100,38 @@ def write(self, *dfs, **kwargs): class LocalIOSchemaReader(SchemaReader[T]): - def __init__(self, from_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(str(from_path), cols, fmt) + def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + super().__init__(from_path, cols, fmt) @abstractmethod def _read(self, *path: os.PathLike, **kwargs) -> T: pass def iter(self, **kwargs) -> typing.Generator[T, None, None]: - with os.scandir(self._from_path) as it: + with os.scandir(self._from_path) as it: # type: ignore for entry in it: - if not entry.name.startswith(".") and entry.is_file(): - yield self._read(Path(entry.path), **kwargs) + if ( + not typing.cast(os.DirEntry, entry).name.startswith(".") + and typing.cast(os.DirEntry, entry).is_file() + ): + yield self._read(Path(typing.cast(os.DirEntry, entry).path), **kwargs) def all(self, **kwargs) -> T: files: typing.List[os.PathLike] = [] - with os.scandir(self._from_path) as it: + with os.scandir(self._from_path) as it: # type: ignore for entry in it: - if not entry.name.startswith(".") and entry.is_file(): - files.append(Path(entry.path)) + if ( + not typing.cast(os.DirEntry, entry).name.startswith(".") + and typing.cast(os.DirEntry, entry).is_file() + ): + files.append(Path(typing.cast(os.DirEntry, entry).path)) return self._read(*files, **kwargs) class LocalIOSchemaWriter(SchemaWriter[T]): - def __init__(self, to_local_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(str(to_local_path), cols, fmt) + def __init__(self, to_local_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + super().__init__(to_local_path, cols, fmt) @abstractmethod def _write(self, df: T, path: os.PathLike, **kwargs): @@ -176,10 +182,11 @@ def get_handler(cls, t: Type) -> SchemaHandler: @dataclass_json @dataclass class FlyteSchema(object): - remote_path: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) + remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) """ This is the main schema 
class that users should use. """ + logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: @@ -228,12 +235,11 @@ def format(cls) -> SchemaFormat: def __init__( self, - local_path: os.PathLike = None, - remote_path: os.PathLike = None, + local_path: typing.Optional[str] = None, + remote_path: typing.Optional[str] = None, supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE, - downloader: typing.Callable[[str, os.PathLike], None] = None, + downloader: typing.Optional[typing.Callable] = None, ): - if supported_mode == SchemaOpenMode.READ and remote_path is None: raise ValueError("To create a FlyteSchema in read mode, remote_path is required") if ( @@ -254,7 +260,7 @@ def __init__( self._downloader = downloader @property - def local_path(self) -> os.PathLike: + def local_path(self) -> str: return self._local_path @property @@ -262,7 +268,7 @@ def supported_mode(self) -> SchemaOpenMode: return self._supported_mode def open( - self, dataframe_fmt: type = pandas.DataFrame, override_mode: SchemaOpenMode = None + self, dataframe_fmt: type = pandas.DataFrame, override_mode: typing.Optional[SchemaOpenMode] = None ) -> typing.Union[SchemaReader, SchemaWriter]: """ Returns a reader or writer depending on the mode of the object when created. This mode can be @@ -290,13 +296,13 @@ def open( self._downloader(self.remote_path, self.local_path) self._downloaded = True if mode == SchemaOpenMode.WRITE: - return h.writer(typing.cast(str, self.local_path), self.columns(), self.format()) - return h.reader(typing.cast(str, self.local_path), self.columns(), self.format()) + return h.writer(self.local_path, self.columns(), self.format()) + return h.reader(self.local_path, self.columns(), self.format()) # Remote IO is handled. 
So we will just pass the remote reference to the object if mode == SchemaOpenMode.WRITE: - return h.writer(self.remote_path, self.columns(), self.format()) - return h.reader(self.remote_path, self.columns(), self.format()) + return h.writer(typing.cast(str, self.remote_path), self.columns(), self.format()) + return h.reader(typing.cast(str, self.remote_path), self.columns(), self.format()) def as_readonly(self) -> FlyteSchema: if self._supported_mode == SchemaOpenMode.READ: @@ -304,7 +310,7 @@ def as_readonly(self) -> FlyteSchema: s = FlyteSchema.__class_getitem__(self.columns(), self.format())( local_path=self.local_path, # Dummy path is ok, as we will assume data is already downloaded and will not download again - remote_path=self.remote_path if self.remote_path else "", + remote_path=typing.cast(str, self.remote_path) if self.remote_path else "", supported_mode=SchemaOpenMode.READ, ) s._downloaded = True diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index e4c6078e94..ca6cab8030 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -17,7 +17,9 @@ class ParquetIO(object): def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame: return pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, **kwargs) - def read(self, *files: os.PathLike, columns: typing.List[str] = None, **kwargs) -> pandas.DataFrame: + def read( + self, *files: os.PathLike, columns: typing.Optional[typing.List[str]] = None, **kwargs + ) -> pandas.DataFrame: frames = [self._read(chunk=f, columns=columns, **kwargs) for f in files if os.path.getsize(f) > 0] if len(frames) == 1: return frames[0] @@ -56,7 +58,7 @@ def write( class PandasSchemaReader(LocalIOSchemaReader[pandas.DataFrame]): - def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, local_dir: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) self._parquet_engine = ParquetIO() @@ -65,7 +67,7 @@ def _read(self, *path: os.PathLike, **kwargs) -> pandas.DataFrame: class PandasSchemaWriter(LocalIOSchemaWriter[pandas.DataFrame]): - def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, local_dir: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) self._parquet_engine = ParquetIO() diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index f0fd917340..1c89a908e6 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import os import types import typing from abc import ABC, abstractmethod @@ -13,11 +12,11 @@ import pandas as pd import pyarrow as pa from dataclasses_json import config, dataclass_json +from fsspec.utils import get_protocol from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.deck.renderer import Renderable from flytekit.loggers import logger @@ -35,6 +34,7 @@ # 
Storage formats PARQUET: StructuredDatasetFormat = "parquet" GENERIC_FORMAT: StructuredDatasetFormat = "" +GENERIC_PROTOCOL: str = "generic protocol" @dataclass_json @@ -45,7 +45,7 @@ class StructuredDataset(object): class (that is just a model, a Python class representation of the protobuf). """ - uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) + uri: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String())) @classmethod @@ -59,7 +59,7 @@ def column_names(cls) -> typing.List[str]: def __init__( self, dataframe: typing.Optional[typing.Any] = None, - uri: Optional[str, os.PathLike] = None, + uri: typing.Optional[str] = None, metadata: typing.Optional[literals.StructuredDatasetMetadata] = None, **kwargs, ): @@ -74,10 +74,11 @@ def __init__( # This is not for users to set, the transformer will set this. self._literal_sd: Optional[literals.StructuredDataset] = None # Not meant for users to set, will be set by an open() call - self._dataframe_type: Optional[Type[DF]] = None + self._dataframe_type: Optional[DF] = None # type: ignore + self._already_uploaded = False @property - def dataframe(self) -> Optional[Type[DF]]: + def dataframe(self) -> Optional[DF]: return self._dataframe @property @@ -92,7 +93,7 @@ def open(self, dataframe_type: Type[DF]): self._dataframe_type = dataframe_type return self - def all(self) -> DF: + def all(self) -> DF: # type: ignore if self._dataframe_type is None: raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.") ctx = FlyteContextManager.current_context() @@ -255,7 +256,7 @@ def decode( ctx: FlyteContext, flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, - ) -> Union[DF, Generator[DF, None, None]]: + ) -> Union[DF, typing.Iterator[DF]]: """ This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal value into a Python instance. @@ -271,11 +272,6 @@ def decode( raise NotImplementedError -def protocol_prefix(uri: str) -> str: - p = DataPersistencePlugins.get_protocol(uri) - return p - - def convert_schema_type_to_structured_dataset_type( column_type: int, ) -> int: @@ -437,11 +433,11 @@ def register( if h.protocol is None: if default_for_type: raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.") - for persistence_protocol in DataPersistencePlugins.supported_protocols(): - # TODO: Clean this up when we get to replacing the persistence layer. + for persistence_protocol in ["s3", "gs", "file", "http", "https"]: + # TODO: Clean this up after replacing the persistence layer. # The behavior of the protocols given in the supported_protocols and is_supported_protocol # is not actually the same as the one returned in get_protocol. - stripped = DataPersistencePlugins.get_protocol(persistence_protocol) + stripped = persistence_protocol logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}") try: cls.register_for_protocol( @@ -471,8 +467,7 @@ def register_for_protocol( See the main register function instead. 
""" if protocol == "/": - # TODO: Special fix again, because get_protocol returns file, instead of file:// - protocol = DataPersistencePlugins.get_protocol(DiskPersistence.PROTOCOL) + protocol = "file" lowest_level = cls._handler_finder(h, protocol) if h.supported_format in lowest_level and override is False: raise DuplicateHandlerError( @@ -483,9 +478,10 @@ def register_for_protocol( if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT: if h.python_type in cls.DEFAULT_FORMATS and not override: - logger.warning( - f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified." - ) + if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format: + logger.info( + f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified." + ) else: logger.debug( f"Setting format {h.supported_format} for dataframes of type {h.python_type} from handler {h}" @@ -493,7 +489,7 @@ def register_for_protocol( cls.DEFAULT_FORMATS[h.python_type] = h.supported_format if default_storage_for_type or default_for_type: if h.protocol in cls.DEFAULT_PROTOCOLS and not override: - logger.warning( + logger.debug( f"Not using handler {h} with storage protocol {h.protocol} as default for {h.python_type}, {cls.DEFAULT_PROTOCOLS[h.python_type]} already specified." ) else: @@ -542,6 +538,8 @@ def to_literal( # def t1(dataset: Annotated[StructuredDataset, my_cols]) -> Annotated[StructuredDataset, my_cols]: # return dataset if python_val._literal_sd is not None: + if python_val._already_uploaded: + return Literal(scalar=Scalar(structured_dataset=python_val._literal_sd)) if python_val.dataframe is not None: raise ValueError( f"Shouldn't have specified both literal {python_val._literal_sd} and dataframe {python_val.dataframe}" @@ -593,7 +591,7 @@ def _protocol_from_type_or_prefix(self, ctx: FlyteContext, df_type: Type, uri: O if df_type in self.DEFAULT_PROTOCOLS: return self.DEFAULT_PROTOCOLS[df_type] else: - protocol = protocol_prefix(uri or ctx.file_access.raw_output_prefix) + protocol = get_protocol(uri or ctx.file_access.raw_output_prefix) logger.debug( f"No default protocol for type {df_type} found, using {protocol} from output prefix {ctx.file_access.raw_output_prefix}" ) @@ -616,13 +614,16 @@ def encode( # least as good as the type of the interface. if sd_model.metadata is None: sd_model._metadata = StructuredDatasetMetadata(structured_literal_type) - if sd_model.metadata.structured_dataset_type is None: + if sd_model.metadata and sd_model.metadata.structured_dataset_type is None: sd_model.metadata._structured_dataset_type = structured_literal_type # Always set the format here to the format of the handler. # Note that this will always be the same as the incoming format except for when the fallback handler # with a format of "" is used. sd_model.metadata._structured_dataset_type.format = handler.supported_format - return Literal(scalar=Scalar(structured_dataset=sd_model)) + lit = Literal(scalar=Scalar(structured_dataset=sd_model)) + sd._literal_sd = sd_model + sd._already_uploaded = True + return lit def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset @@ -746,7 +747,7 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ # Here we only render column information by default instead of opening the structured dataset. 
col = typing.cast(StructuredDataset, python_val).columns() df = pd.DataFrame(col, ["column type"]) - return df.to_html() + return df.to_html() # type: ignore else: df = python_val @@ -769,7 +770,7 @@ def open_as( :param updated_metadata: New metadata type, since it might be different from the metadata in the literal. :return: dataframe. It could be pandas dataframe or arrow table, etc. """ - protocol = protocol_prefix(sd.uri) + protocol = get_protocol(sd.uri) decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format) result = decoder.decode(ctx, sd, updated_metadata) if isinstance(result, types.GeneratorType): @@ -782,10 +783,10 @@ def iter_as( sd: literals.StructuredDataset, df_type: Type[DF], updated_metadata: StructuredDatasetMetadata, - ) -> Generator[DF, None, None]: - protocol = protocol_prefix(sd.uri) + ) -> typing.Iterator[DF]: + protocol = get_protocol(sd.uri) decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format] - result = decoder.decode(ctx, sd, updated_metadata) + result: Union[DF, typing.Iterator[DF]] = decoder.decode(ctx, sd, updated_metadata) if not isinstance(result, types.GeneratorType): raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}") return result @@ -800,7 +801,7 @@ def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: raise AssertionError(f"type {t} is currently not supported by StructuredDataset") def _convert_ordered_dict_of_columns_to_list( - self, column_map: typing.OrderedDict[str, Type] + self, column_map: typing.Optional[typing.OrderedDict[str, Type]] ) -> typing.List[StructuredDatasetType.DatasetColumn]: converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = [] if column_map is None or len(column_map) == 0: @@ -811,10 +812,13 @@ def _convert_ordered_dict_of_columns_to_list( return converted_cols def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType: - original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) + original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore # Get the column information - converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map) + converted_cols: typing.List[ + StructuredDatasetType.DatasetColumn + ] = self._convert_ordered_dict_of_columns_to_list(column_map) + return StructuredDatasetType( columns=converted_cols, format=storage_format, diff --git a/plugins/README.md b/plugins/README.md index 447b91a37c..495ce91019 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -7,6 +7,7 @@ All the Flytekit plugins maintained by the core team are added here. 
It is not n | Plugin | Installation | Description | Version | Type | |------------------------------|-----------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| | AWS Sagemaker Training | ```bash pip install flytekitplugins-awssagemaker ``` | Installs SDK to author Sagemaker built-in and custom training jobs in python | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Backend | +| dask | ```bash pip install flytekitplugins-dask ``` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | | Hive Queries | ```bash pip install flytekitplugins-hive ``` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | | K8s distributed PyTorch Jobs | ```bash pip install flytekitplugins-kfpytorch ``` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | | K8s native tensorflow Jobs | ```bash pip install flytekitplugins-kftensorflow ``` | Installs SDK to author Distributed tensorflow Jobs in python using Kubeflow Tensorflow Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kftensorflow.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend | diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index 8787e70011..0e67b2e50b 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -53,7 +53,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: def get_config(self, settings: SerializationSettings) -> Dict[str, str]: # Parameters in taskTemplate config will be used to create aws job definition. # More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html - return {"platformCapabilities": self._task_config.platformCapabilities} + return {**super().get_config(settings), "platformCapabilities": self._task_config.platformCapabilities} def get_command(self, settings: SerializationSettings) -> List[str]: container_args = [ diff --git a/plugins/flytekit-dask/README.md b/plugins/flytekit-dask/README.md new file mode 100644 index 0000000000..9d645bcd27 --- /dev/null +++ b/plugins/flytekit-dask/README.md @@ -0,0 +1,21 @@ +# Flytekit Dask Plugin + +Flyte can execute `dask` jobs natively on a Kubernetes Cluster, which manages the virtual `dask` cluster's lifecycle +(spin-up and tear down). It leverages the open-source Kubernetes Dask Operator and can be enabled without signing up +for any service. 
This is like running a transient (ephemeral) `dask` cluster - a type of cluster spun up for a specific
+task and torn down after completion. This helps in making sure that the Python environment is the same on the job-runner
+(driver), scheduler and the workers.
+
+To install the plugin, run the following command:
+
+```bash
+pip install flytekitplugins-dask
+```
+
+To configure Dask in the Flyte deployment's backend, follow
+[step 1](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/k8s_dask/index.html#step-1-deploy-the-dask-plugin-in-the-flyte-backend)
+and
+[step 2](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/k8s_dask/index.html#step-2-environment-setup)
+
+An [example](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/k8s_dask/index.html)
+can be found in the documentation.
diff --git a/plugins/flytekit-dask/flytekitplugins/dask/__init__.py b/plugins/flytekit-dask/flytekitplugins/dask/__init__.py
new file mode 100644
index 0000000000..ccadf385fc
--- /dev/null
+++ b/plugins/flytekit-dask/flytekitplugins/dask/__init__.py
@@ -0,0 +1,15 @@
+"""
+.. currentmodule:: flytekitplugins.dask
+
+This package contains the Python related side of the Dask Plugin
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+
+   Dask
+   Scheduler
+   WorkerGroup
+"""
+
+from flytekitplugins.dask.task import Dask, Scheduler, WorkerGroup
diff --git a/plugins/flytekit-dask/flytekitplugins/dask/models.py b/plugins/flytekit-dask/flytekitplugins/dask/models.py
new file mode 100644
index 0000000000..b833ab660a
--- /dev/null
+++ b/plugins/flytekit-dask/flytekitplugins/dask/models.py
@@ -0,0 +1,134 @@
+from typing import Optional
+
+from flyteidl.plugins import dask_pb2 as dask_task
+
+from flytekit.models import common as common
+from flytekit.models import task as task
+
+
+class Scheduler(common.FlyteIdlEntity):
+    """
+    Configuration for the scheduler pod
+
+    :param image: Optional image to use.
+    :param resources: Optional resources to use.
+    """
+
+    def __init__(self, image: Optional[str] = None, resources: Optional[task.Resources] = None):
+        self._image = image
+        self._resources = resources
+
+    @property
+    def image(self) -> Optional[str]:
+        """
+        :return: The optional image for the scheduler pod
+        """
+        return self._image
+
+    @property
+    def resources(self) -> Optional[task.Resources]:
+        """
+        :return: Optional resources for the scheduler pod
+        """
+        return self._resources
+
+    def to_flyte_idl(self) -> dask_task.DaskScheduler:
+        """
+        :return: The scheduler spec serialized to protobuf
+        """
+        return dask_task.DaskScheduler(
+            image=self.image,
+            resources=self.resources.to_flyte_idl() if self.resources else None,
+        )
+
+
+class WorkerGroup(common.FlyteIdlEntity):
+    """
+    Configuration for a dask worker group
+
+    :param number_of_workers: Number of workers in the group
+    :param image: Optional image to use for the pods of the worker group
+    :param resources: Optional resources to use for the pods of the worker group
+    """
+
+    def __init__(
+        self,
+        number_of_workers: int,
+        image: Optional[str] = None,
+        resources: Optional[task.Resources] = None,
+    ):
+        if number_of_workers < 1:
+            raise ValueError(
+                f"Each worker group needs to have at least one worker, but {number_of_workers} have been specified."
+ ) + + self._number_of_workers = number_of_workers + self._image = image + self._resources = resources + + @property + def number_of_workers(self) -> Optional[int]: + """ + :return: Optional number of workers for the worker group + """ + return self._number_of_workers + + @property + def image(self) -> Optional[str]: + """ + :return: The optional image to use for the worker pods + """ + return self._image + + @property + def resources(self) -> Optional[task.Resources]: + """ + :return: Optional resources to use for the worker pods + """ + return self._resources + + def to_flyte_idl(self) -> dask_task.DaskWorkerGroup: + """ + :return: The dask cluster serialized to protobuf + """ + return dask_task.DaskWorkerGroup( + number_of_workers=self.number_of_workers, + image=self.image, + resources=self.resources.to_flyte_idl() if self.resources else None, + ) + + +class DaskJob(common.FlyteIdlEntity): + """ + Configuration for the custom dask job to run + + :param scheduler: Configuration for the scheduler + :param workers: Configuration of the default worker group + """ + + def __init__(self, scheduler: Scheduler, workers: WorkerGroup): + self._scheduler = scheduler + self._workers = workers + + @property + def scheduler(self) -> Scheduler: + """ + :return: Configuration for the scheduler pod + """ + return self._scheduler + + @property + def workers(self) -> WorkerGroup: + """ + :return: Configuration of the default worker group + """ + return self._workers + + def to_flyte_idl(self) -> dask_task.DaskJob: + """ + :return: The dask job serialized to protobuf + """ + return dask_task.DaskJob( + scheduler=self.scheduler.to_flyte_idl(), + workers=self.workers.to_flyte_idl(), + ) diff --git a/plugins/flytekit-dask/flytekitplugins/dask/task.py b/plugins/flytekit-dask/flytekitplugins/dask/task.py new file mode 100644 index 0000000000..830ede98ef --- /dev/null +++ b/plugins/flytekit-dask/flytekitplugins/dask/task.py @@ -0,0 +1,108 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional + +from flytekitplugins.dask import models +from google.protobuf.json_format import MessageToDict + +from flytekit import PythonFunctionTask, Resources +from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model +from flytekit.core.task import TaskPlugins + + +@dataclass +class Scheduler: + """ + Configuration for the scheduler pod + + :param image: Custom image to use. If ``None``, will use the same image the task was registered with. Optional, + defaults to ``None``. The image must have ``dask[distributed]`` installed and should have the same Python + environment as the rest of the cluster (job runner pod + worker pods). + :param requests: Resources to request for the scheduler pod. If ``None``, the requests passed into the task will be + used. Optional, defaults to ``None``. + :param limits: Resource limits for the scheduler pod. If ``None``, the limits passed into the task will be used. + Optional, defaults to ``None``. + """ + + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + + +@dataclass +class WorkerGroup: + """ + Configuration for a group of dask worker pods + + :param number_of_workers: Number of workers to use. Optional, defaults to 1. + :param image: Custom image to use. If ``None``, will use the same image the task was registered with. Optional, + defaults to ``None``. The image must have ``dask[distributed]`` installed. 
The provided image should have the + same Python environment as the job runner/driver as well as the scheduler. + :param requests: Resources to request for the worker pods. If ``None``, the requests passed into the task will be + used. Optional, defaults to ``None``. + :param limits: Resource limits for the worker pods. If ``None``, the limits passed into the task will be used. + Optional, defaults to ``None``. + """ + + number_of_workers: Optional[int] = 1 + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + + +@dataclass +class Dask: + """ + Configuration for the dask task + + :param scheduler: Configuration for the scheduler pod. Optional, defaults to ``Scheduler()``. + :param workers: Configuration for the pods of the default worker group. Optional, defaults to ``WorkerGroup()``. + """ + + scheduler: Scheduler = Scheduler() + workers: WorkerGroup = WorkerGroup() + + +class DaskTask(PythonFunctionTask[Dask]): + """ + Actual Plugin that transforms the local python code for execution within a dask cluster + """ + + _DASK_TASK_TYPE = "dask" + + def __init__(self, task_config: Dask, task_function: Callable, **kwargs): + super(DaskTask, self).__init__( + task_config=task_config, + task_type=self._DASK_TASK_TYPE, + task_function=task_function, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: + """ + Serialize the `dask` task config into a dict. + + :param settings: Current serialization settings + :return: Dictionary representation of the dask task config. + """ + scheduler = models.Scheduler( + image=self.task_config.scheduler.image, + resources=convert_resources_to_resource_model( + requests=self.task_config.scheduler.requests, + limits=self.task_config.scheduler.limits, + ), + ) + workers = models.WorkerGroup( + number_of_workers=self.task_config.workers.number_of_workers, + image=self.task_config.workers.image, + resources=convert_resources_to_resource_model( + requests=self.task_config.workers.requests, + limits=self.task_config.workers.limits, + ), + ) + job = models.DaskJob(scheduler=scheduler, workers=workers) + return MessageToDict(job.to_flyte_idl()) + + +# Inject the `dask` plugin into flytekits dynamic plugin loading system +TaskPlugins.register_pythontask_plugin(Dask, DaskTask) diff --git a/plugins/flytekit-dask/requirements.in b/plugins/flytekit-dask/requirements.in new file mode 100644 index 0000000000..310ade8617 --- /dev/null +++ b/plugins/flytekit-dask/requirements.in @@ -0,0 +1,2 @@ +. 
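As a usage sketch for the plugin defined above: the `Dask` config is attached via `task_config`, and the worker count and resource values are illustrative, not taken from this patch. The bare `Client()` relies on the assumption that the backend plugin injects the transient cluster's scheduler address into the pod environment.

```python
from flytekit import Resources, task
from flytekitplugins.dask import Dask, WorkerGroup


@task(
    task_config=Dask(
        workers=WorkerGroup(
            number_of_workers=4,                     # illustrative sizing
            requests=Resources(cpu="2", mem="4Gi"),  # illustrative resources
        ),
    ),
)
def dask_sum() -> int:
    # Assumes the dask backend plugin points the default scheduler address at the
    # transient cluster, so Client() with no arguments can connect to it.
    from distributed import Client

    client = Client()
    return client.submit(sum, [1, 2, 3]).result()
```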
+-e file:.#egg=flytekitplugins-dask diff --git a/requirements-spark2.txt b/plugins/flytekit-dask/requirements.txt similarity index 67% rename from requirements-spark2.txt rename to plugins/flytekit-dask/requirements.txt index c6d0ff7fc0..2ec017e46d 100644 --- a/requirements-spark2.txt +++ b/plugins/flytekit-dask/requirements.txt @@ -1,43 +1,47 @@ # -# This file is autogenerated by pip-compile with Python 3.7 +# This file is autogenerated by pip-compile with Python 3.8 # by the following command: # -# make requirements-spark2.txt +# pip-compile --output-file=requirements.txt requirements.in setup.py # --e file:.#egg=flytekit - # via - # -r requirements-spark2.in - # -r requirements.in +-e file:.#egg=flytekitplugins-dask + # via -r requirements.in arrow==1.2.3 # via jinja2-time -attrs==20.3.0 - # via - # -r requirements.in - # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2022.12.7 +certifi==2022.9.24 # via requests cffi==1.15.1 # via cryptography -chardet==5.1.0 +chardet==5.0.0 # via binaryornot charset-normalizer==2.1.1 # via requests click==8.1.3 # via # cookiecutter + # dask + # distributed # flytekit cloudpickle==2.2.0 - # via flytekit + # via + # dask + # distributed + # flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.8 +croniter==1.3.7 # via flytekit -cryptography==38.0.4 +cryptography==38.0.3 # via # pyopenssl # secretstorage +dask[distributed]==2022.10.2 + # via + # distributed + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -46,15 +50,26 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit +distributed==2022.10.2 + # via dask docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.1 - # via flytekit -googleapis-common-protos==1.57.0 +flyteidl==1.3.2 + # via + # flytekit + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) +flytekit==1.3.0b2 + # via + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) +fsspec==2022.10.0 + # via dask +googleapis-common-protos==1.56.4 # via # flyteidl # grpcio-status @@ -64,13 +79,13 @@ grpcio==1.51.1 # grpcio-status grpcio-status==1.51.1 # via flytekit +heapdict==1.0.1 + # via zict idna==3.4 # via requests -importlib-metadata==5.1.0 +importlib-metadata==5.0.0 # via - # click # flytekit - # jsonschema # keyring jaraco-classes==3.2.3 # via keyring @@ -81,18 +96,21 @@ jeepney==0.8.0 jinja2==3.1.2 # via # cookiecutter + # distributed # jinja2-time jinja2-time==0.2.0 # via cookiecutter joblib==1.2.0 # via flytekit -jsonschema==3.2.0 - # via -r requirements.in keyring==23.11.0 # via flytekit +locket==1.0.0 + # via + # distributed + # partd markupsafe==2.1.1 # via jinja2 -marshmallow==3.19.0 +marshmallow==3.18.0 # via # dataclasses-json # marshmallow-enum @@ -103,25 +121,27 @@ marshmallow-jsonschema==0.13.0 # via flytekit more-itertools==9.0.0 # via jaraco-classes +msgpack==1.0.4 + # via distributed mypy-extensions==0.4.3 # via typing-inspect natsort==8.2.0 # via flytekit -numpy==1.21.6 +numpy==1.23.4 # via - # -r requirements.in - # flytekit # pandas # pyarrow packaging==21.3 # via + # dask + # distributed # docker # marshmallow -pandas==1.3.5 - # via - # -r requirements.in - # flytekit -protobuf==4.21.10 +pandas==1.5.1 + # via flytekit +partd==1.3.0 + # via dask +protobuf==4.21.11 # via # flyteidl # googleapis-common-protos @@ -129,9 +149,11 @@ protobuf==4.21.10 # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl +psutil==5.9.3 + # via distributed 
py==1.11.0 # via retry -pyarrow==10.0.1 +pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi @@ -139,8 +161,6 @@ pyopenssl==22.1.0 # via flytekit pyparsing==3.0.9 # via packaging -pyrsistent==0.19.2 - # via jsonschema python-dateutil==2.8.2 # via # arrow @@ -149,7 +169,7 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit @@ -157,10 +177,11 @@ pytz==2022.6 # via # flytekit # pandas -pyyaml==5.4.1 +pyyaml==6.0 # via - # -r requirements.in # cookiecutter + # dask + # distributed # flytekit regex==2022.10.31 # via docker-image-py @@ -176,50 +197,51 @@ retry==0.9.2 # via flytekit secretstorage==3.3.3 # via keyring -singledispatchmethod==1.0 - # via flytekit six==1.16.0 - # via - # jsonschema - # python-dateutil - # websocket-client + # via python-dateutil sortedcontainers==2.4.0 - # via flytekit + # via + # distributed + # flytekit statsd==3.3.0 # via flytekit +tblib==1.7.0 + # via distributed text-unidecode==1.3 # via python-slugify toml==0.10.2 # via responses -types-toml==0.10.8.1 +toolz==0.12.0 + # via + # dask + # distributed + # partd +tornado==6.1 + # via distributed +types-toml==0.10.8 # via responses typing-extensions==4.4.0 # via - # arrow # flytekit - # importlib-metadata - # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.13 +urllib3==1.26.12 # via + # distributed # docker # flytekit # requests # responses -websocket-client==0.59.0 - # via - # -r requirements.in - # docker -wheel==0.38.4 +websocket-client==1.4.2 + # via docker +wheel==0.38.2 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.11.0 +zict==2.2.0 + # via distributed +zipp==3.10.0 # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/plugins/flytekit-dask/setup.py b/plugins/flytekit-dask/setup.py new file mode 100644 index 0000000000..440d7b47db --- /dev/null +++ b/plugins/flytekit-dask/setup.py @@ -0,0 +1,42 @@ +from setuptools import setup + +PLUGIN_NAME = "dask" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = [ + "flyteidl>=1.3.2", + "flytekit>=1.3.0b2,<2.0.0", + "dask[distributed]>=2022.10.2", +] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Dask plugin for flytekit", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-dask", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", # dask requires >= 3.8 + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/tests/flytekit/unit/extras/persistence/__init__.py b/plugins/flytekit-dask/tests/__init__.py similarity index 100% rename from 
tests/flytekit/unit/extras/persistence/__init__.py rename to plugins/flytekit-dask/tests/__init__.py diff --git a/plugins/flytekit-dask/tests/test_models.py b/plugins/flytekit-dask/tests/test_models.py new file mode 100644 index 0000000000..801a110fb1 --- /dev/null +++ b/plugins/flytekit-dask/tests/test_models.py @@ -0,0 +1,96 @@ +import pytest +from flytekitplugins.dask import models + +from flytekit.models import task as _task + + +@pytest.fixture +def image() -> str: + return "foo:latest" + + +@pytest.fixture +def resources() -> _task.Resources: + return _task.Resources( + requests=[ + _task.Resources.ResourceEntry(name=_task.Resources.ResourceName.CPU, value="3"), + ], + limits=[], + ) + + +@pytest.fixture +def default_resources() -> _task.Resources: + return _task.Resources(requests=[], limits=[]) + + +@pytest.fixture +def scheduler(image: str, resources: _task.Resources) -> models.Scheduler: + return models.Scheduler(image=image, resources=resources) + + +@pytest.fixture +def workers(image: str, resources: _task.Resources) -> models.WorkerGroup: + return models.WorkerGroup(number_of_workers=123, image=image, resources=resources) + + +def test_create_scheduler_to_flyte_idl_no_optional(image: str, resources: _task.Resources): + scheduler = models.Scheduler(image=image, resources=resources) + idl_object = scheduler.to_flyte_idl() + assert idl_object.image == image + assert idl_object.resources == resources.to_flyte_idl() + + +def test_create_scheduler_to_flyte_idl_all_optional(default_resources: _task.Resources): + scheduler = models.Scheduler(image=None, resources=None) + idl_object = scheduler.to_flyte_idl() + assert idl_object.image == "" + assert idl_object.resources == default_resources.to_flyte_idl() + + +def test_create_scheduler_spec_property_access(image: str, resources: _task.Resources): + scheduler = models.Scheduler(image=image, resources=resources) + assert scheduler.image == image + assert scheduler.resources == resources + + +def test_worker_group_to_flyte_idl_no_optional(image: str, resources: _task.Resources): + n_workers = 1234 + worker_group = models.WorkerGroup(number_of_workers=n_workers, image=image, resources=resources) + idl_object = worker_group.to_flyte_idl() + assert idl_object.number_of_workers == n_workers + assert idl_object.image == image + assert idl_object.resources == resources.to_flyte_idl() + + +def test_worker_group_to_flyte_idl_all_optional(default_resources: _task.Resources): + worker_group = models.WorkerGroup(number_of_workers=1, image=None, resources=None) + idl_object = worker_group.to_flyte_idl() + assert idl_object.image == "" + assert idl_object.resources == default_resources.to_flyte_idl() + + +def test_worker_group_property_access(image: str, resources: _task.Resources): + n_workers = 1234 + worker_group = models.WorkerGroup(number_of_workers=n_workers, image=image, resources=resources) + assert worker_group.image == image + assert worker_group.number_of_workers == n_workers + assert worker_group.resources == resources + + +def test_worker_group_fails_for_less_than_one_worker(): + with pytest.raises(ValueError, match=r"Each worker group needs to"): + models.WorkerGroup(number_of_workers=0, image=None, resources=None) + + +def test_dask_job_to_flyte_idl_no_optional(scheduler: models.Scheduler, workers: models.WorkerGroup): + job = models.DaskJob(scheduler=scheduler, workers=workers) + idl_object = job.to_flyte_idl() + assert idl_object.scheduler == scheduler.to_flyte_idl() + assert idl_object.workers == workers.to_flyte_idl() + + +def 
test_dask_job_property_access(scheduler: models.Scheduler, workers: models.WorkerGroup): + job = models.DaskJob(scheduler=scheduler, workers=workers) + assert job.scheduler == scheduler + assert job.workers == workers diff --git a/plugins/flytekit-dask/tests/test_task.py b/plugins/flytekit-dask/tests/test_task.py new file mode 100644 index 0000000000..76dbf9d048 --- /dev/null +++ b/plugins/flytekit-dask/tests/test_task.py @@ -0,0 +1,86 @@ +import pytest +from flytekitplugins.dask import Dask, Scheduler, WorkerGroup + +from flytekit import PythonFunctionTask, Resources, task +from flytekit.configuration import Image, ImageConfig, SerializationSettings + + +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + return settings + + +def test_dask_task_with_default_config(serialization_settings: SerializationSettings): + task_config = Dask() + + @task(task_config=task_config) + def dask_task(): + pass + + # Helping type completion in PyCharm + dask_task: PythonFunctionTask[Dask] + + assert dask_task.task_config == task_config + assert dask_task.task_type == "dask" + + expected_dict = { + "scheduler": { + "resources": {}, + }, + "workers": { + "numberOfWorkers": 1, + "resources": {}, + }, + } + assert dask_task.get_custom(serialization_settings) == expected_dict + + +def test_dask_task_get_custom(serialization_settings: SerializationSettings): + task_config = Dask( + scheduler=Scheduler( + image="scheduler:latest", + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + workers=WorkerGroup( + number_of_workers=123, + image="dask_cluster:latest", + requests=Resources(cpu="3"), + limits=Resources(cpu="4"), + ), + ) + + @task(task_config=task_config) + def dask_task(): + pass + + # Helping type completion in PyCharm + dask_task: PythonFunctionTask[Dask] + + expected_custom_dict = { + "scheduler": { + "image": "scheduler:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, + "workers": { + "numberOfWorkers": 123, + "image": "dask_cluster:latest", + "resources": { + "requests": [{"name": "CPU", "value": "3"}], + "limits": [{"name": "CPU", "value": "4"}], + }, + }, + } + custom_dict = dask_task.get_custom(serialization_settings) + assert custom_dict == expected_custom_dict diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py index 0690179bb1..579efd366c 100644 --- a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py +++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py @@ -1,3 +1,4 @@ +import os import typing import datasets @@ -59,12 +60,11 @@ def decode( ) -> datasets.Dataset: local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) - path = f"{local_dir}/00000" - + files = [item.path for item in os.scandir(local_dir)] if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return datasets.Dataset.from_parquet(path, columns=columns) - return 
datasets.Dataset.from_parquet(path) + return datasets.Dataset.from_parquet(files, columns=columns) + return datasets.Dataset.from_parquet(files) StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler()) diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py index 7ce3ac2c1a..22cb096ba8 100644 --- a/plugins/flytekit-huggingface/setup.py +++ b/plugins/flytekit-huggingface/setup.py @@ -38,4 +38,5 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py index 170fdc3789..5b65b2511c 100644 --- a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py +++ b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py @@ -68,3 +68,15 @@ def test_datasets_renderer(): df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) dataset = datasets.Dataset.from_pandas(df) assert HuggingFaceDatasetRenderer().to_html(dataset) == str(dataset).replace("\n", "
") + + +def test_parquet_to_datasets(): + df = pd.DataFrame({"name": ["Alice"], "age": [10]}) + + @task + def create_sd() -> StructuredDataset: + return StructuredDataset(dataframe=df) + + sd = create_sd() + dataset = sd.open(datasets.Dataset).all() + assert dataset.data == datasets.Dataset.from_pandas(df).data diff --git a/plugins/flytekit-mlflow/README.md b/plugins/flytekit-mlflow/README.md new file mode 100644 index 0000000000..6cbee9cf59 --- /dev/null +++ b/plugins/flytekit-mlflow/README.md @@ -0,0 +1,22 @@ +# Flytekit MLflow Plugin + +MLflow enables us to log parameters, code, and results in machine learning experiments and compare them using an interactive UI. +This MLflow plugin enables seamless use of MLFlow within Flyte, and render the metrics and parameters on Flyte Deck. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-mlflow +``` + +Example +```python +from flytekit import task, workflow +from flytekitplugins.mlflow import mlflow_autolog +import mlflow + +@task(disable_deck=False) +@mlflow_autolog(framework=mlflow.keras) +def train_model(): + ... +``` diff --git a/plugins/flytekit-mlflow/dev-requirements.in b/plugins/flytekit-mlflow/dev-requirements.in new file mode 100644 index 0000000000..0f57144081 --- /dev/null +++ b/plugins/flytekit-mlflow/dev-requirements.in @@ -0,0 +1 @@ +tensorflow diff --git a/plugins/flytekit-mlflow/dev-requirements.txt b/plugins/flytekit-mlflow/dev-requirements.txt new file mode 100644 index 0000000000..6ad9be49bb --- /dev/null +++ b/plugins/flytekit-mlflow/dev-requirements.txt @@ -0,0 +1,122 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile dev-requirements.in +# +absl-py==1.3.0 + # via + # tensorboard + # tensorflow +astunparse==1.6.3 + # via tensorflow +cachetools==5.2.0 + # via google-auth +certifi==2022.9.24 + # via requests +charset-normalizer==2.1.1 + # via requests +flatbuffers==22.10.26 + # via tensorflow +gast==0.4.0 + # via tensorflow +google-auth==2.14.1 + # via + # google-auth-oauthlib + # tensorboard +google-auth-oauthlib==0.4.6 + # via tensorboard +google-pasta==0.2.0 + # via tensorflow +grpcio==1.50.0 + # via + # tensorboard + # tensorflow +h5py==3.7.0 + # via tensorflow +idna==3.4 + # via requests +importlib-metadata==5.0.0 + # via markdown +keras==2.10.0 + # via tensorflow +keras-preprocessing==1.1.2 + # via tensorflow +libclang==14.0.6 + # via tensorflow +markdown==3.4.1 + # via tensorboard +markupsafe==2.1.1 + # via werkzeug +numpy==1.23.4 + # via + # h5py + # keras-preprocessing + # opt-einsum + # tensorboard + # tensorflow +oauthlib==3.2.2 + # via requests-oauthlib +opt-einsum==3.3.0 + # via tensorflow +packaging==21.3 + # via tensorflow +protobuf==3.19.6 + # via + # tensorboard + # tensorflow +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pyparsing==3.0.9 + # via packaging +requests==2.28.1 + # via + # requests-oauthlib + # tensorboard +requests-oauthlib==1.3.1 + # via google-auth-oauthlib +rsa==4.9 + # via google-auth +six==1.16.0 + # via + # astunparse + # google-auth + # google-pasta + # grpcio + # keras-preprocessing + # tensorflow +tensorboard==2.10.1 + # via tensorflow +tensorboard-data-server==0.6.1 + # via tensorboard +tensorboard-plugin-wit==1.8.1 + # via tensorboard +tensorflow==2.10.0 + # via -r dev-requirements.in +tensorflow-estimator==2.10.0 + # via tensorflow +tensorflow-io-gcs-filesystem==0.27.0 + # via tensorflow +termcolor==2.1.0 + # via tensorflow 
+typing-extensions==4.4.0 + # via tensorflow +urllib3==1.26.12 + # via requests +werkzeug==2.2.2 + # via tensorboard +wheel==0.38.3 + # via + # astunparse + # tensorboard +wrapt==1.14.1 + # via tensorflow +zipp==3.10.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py new file mode 100644 index 0000000000..98e84547e0 --- /dev/null +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.mlflow + +This plugin enables seamless integration between Flyte and mlflow. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + mlflow_autolog +""" + +from .tracking import mlflow_autolog diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py new file mode 100644 index 0000000000..b58aa4a120 --- /dev/null +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py @@ -0,0 +1,140 @@ +import typing +from functools import partial, wraps + +import mlflow +import pandas +import pandas as pd +import plotly.graph_objects as go +from mlflow import MlflowClient +from mlflow.entities.metric import Metric +from plotly.subplots import make_subplots + +import flytekit +from flytekit import FlyteContextManager +from flytekit.bin.entrypoint import get_one_of +from flytekit.core.context_manager import ExecutionState +from flytekit.deck import TopFrameRenderer + + +def metric_to_df(metrics: typing.List[Metric]) -> pd.DataFrame: + """ + Converts mlflow Metric object to a dataframe of 2 columns ['timestamp', 'value'] + """ + t = [] + v = [] + for m in metrics: + t.append(m.timestamp) + v.append(m.value) + return pd.DataFrame(list(zip(t, v)), columns=["timestamp", "value"]) + + +def get_run_metrics(c: MlflowClient, run_id: str) -> typing.Dict[str, pandas.DataFrame]: + """ + Extracts all metrics and returns a dictionary of metric name to the list of metric for the given run_id + """ + r = c.get_run(run_id) + metrics = {} + for k in r.data.metrics.keys(): + metrics[k] = metric_to_df(metrics=c.get_metric_history(run_id=run_id, key=k)) + return metrics + + +def get_run_params(c: MlflowClient, run_id: str) -> typing.Optional[pd.DataFrame]: + """ + Extracts all parameters and returns a dictionary of metric name to the list of metric for the given run_id + """ + r = c.get_run(run_id) + name = [] + value = [] + if r.data.params == {}: + return None + for k, v in r.data.params.items(): + name.append(k) + value.append(v) + return pd.DataFrame(list(zip(name, value)), columns=["name", "value"]) + + +def plot_metrics(metrics: typing.Dict[str, pandas.DataFrame]) -> typing.Optional[go.Figure]: + v = len(metrics) + if v == 0: + return None + + # Initialize figure with subplots + fig = make_subplots(rows=v, cols=1, subplot_titles=list(metrics.keys())) + + # Add traces + row = 1 + for k, v in metrics.items(): + v["timestamp"] = (v["timestamp"] - v["timestamp"][0]) / 1000 + fig.add_trace(go.Scatter(x=v["timestamp"], y=v["value"], name=k), row=row, col=1) + row = row + 1 + + fig.update_xaxes(title_text="Time (s)") + fig.update_layout(height=700, width=900) + return fig + + +def mlflow_autolog(fn=None, *, framework=mlflow.sklearn, experiment_name: typing.Optional[str] = None): + """MLFlow decorator to enable autologging of training metrics. 
+ + This decorator can be used as a nested decorator for a ``@task`` and it will automatically enable mlflow autologging, + for the given ``framework``. By default autologging is enabled for ``sklearn``. + + .. code-block:: python + + @task + @mlflow_autolog(framework=mlflow.tensorflow) + def my_tensorflow_trainer(): + ... + + One benefit of doing so is that the mlflow metrics are then rendered inline using FlyteDecks and can be viewed + in jupyter notebook, as well as in hosted Flyte environment: + + .. code-block:: python + + # jupyter notebook cell + with flytekit.new_context() as ctx: + my_tensorflow_trainer() + ctx.get_deck() # IPython.display + + When the task is called in a Flyte backend, the decorator starts a new MLFlow run using the Flyte execution name + by default, or a user-provided ``experiment_name`` in the decorator. + + :param fn: Function to generate autologs for. + :param framework: The mlflow module to use for autologging + :param experiment_name: The MLFlow experiment name. If not provided, uses the Flyte execution name. + """ + + @wraps(fn) + def wrapper(*args, **kwargs): + framework.autolog() + params = FlyteContextManager.current_context().user_space_params + ctx = FlyteContextManager.current_context() + + experiment = experiment_name or "local workflow" + run_name = None # MLflow will generate random name if value is None + + if ctx.execution_state.mode != ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + experiment = experiment_name and f"{get_one_of('FLYTE_INTERNAL_EXECUTION_WORKFLOW', '_F_WF')}" + run_name = f"{params.execution_id.name}.{params.task_id.name.split('.')[-1]}" + + mlflow.set_experiment(experiment) + with mlflow.start_run(run_name=run_name): + out = fn(*args, **kwargs) + run = mlflow.active_run() + if run is not None: + client = MlflowClient() + run_id = run.info.run_id + metrics = get_run_metrics(client, run_id) + figure = plot_metrics(metrics) + if figure: + flytekit.Deck("mlflow metrics", figure.to_html()) + params = get_run_params(client, run_id) + if params is not None: + flytekit.Deck("mlflow params", TopFrameRenderer(max_rows=10).to_html(params)) + return out + + if fn is None: + return partial(mlflow_autolog, framework=framework, experiment_name=experiment_name) + + return wrapper diff --git a/plugins/flytekit-mlflow/requirements.in b/plugins/flytekit-mlflow/requirements.in new file mode 100644 index 0000000000..cbe58e3885 --- /dev/null +++ b/plugins/flytekit-mlflow/requirements.in @@ -0,0 +1,3 @@ +. 
+-e file:.#egg=flytekitplugins-mlflow +grpcio-status<1.49.0 diff --git a/requirements.txt b/plugins/flytekit-mlflow/requirements.txt similarity index 59% rename from requirements.txt rename to plugins/flytekit-mlflow/requirements.txt index 5623078a25..03873c05f5 100644 --- a/requirements.txt +++ b/plugins/flytekit-mlflow/requirements.txt @@ -1,41 +1,44 @@ # -# This file is autogenerated by pip-compile with Python 3.7 -# by the following command: +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: # -# make requirements.txt +# pip-compile requirements.in # --e file:.#egg=flytekit +-e file:.#egg=flytekitplugins-mlflow # via -r requirements.in +alembic==1.8.1 + # via mlflow arrow==1.2.3 # via jinja2-time -attrs==20.3.0 - # via - # -r requirements.in - # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2022.12.7 +certifi==2022.9.24 # via requests cffi==1.15.1 # via cryptography -chardet==5.1.0 +chardet==5.0.0 # via binaryornot charset-normalizer==2.1.1 # via requests click==8.1.3 # via # cookiecutter + # databricks-cli + # flask # flytekit + # mlflow cloudpickle==2.2.0 - # via flytekit + # via + # flytekit + # mlflow cookiecutter==2.1.1 # via flytekit -croniter==1.3.8 +croniter==1.3.7 # via flytekit -cryptography==38.0.4 - # via - # pyopenssl - # secretstorage +cryptography==38.0.3 + # via pyopenssl +databricks-cli==0.17.3 + # via mlflow dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -45,52 +48,74 @@ deprecated==1.2.13 diskcache==5.4.0 # via flytekit docker==6.0.1 - # via flytekit + # via + # flytekit + # mlflow docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.1 +entrypoints==0.4 + # via mlflow +flask==2.2.2 + # via + # mlflow + # prometheus-flask-exporter +flyteidl==1.1.22 # via flytekit -googleapis-common-protos==1.57.0 +flytekit==1.2.3 + # via flytekitplugins-mlflow +gitdb==4.0.9 + # via gitpython +gitpython==3.1.29 + # via mlflow +googleapis-common-protos==1.56.4 # via # flyteidl # grpcio-status -grpcio==1.51.1 +greenlet==2.0.1 + # via sqlalchemy +grpcio==1.50.0 # via # flytekit # grpcio-status -grpcio-status==1.51.1 - # via flytekit +grpcio-status==1.48.2 + # via + # -r requirements.in + # flytekit +gunicorn==20.1.0 + # via mlflow idna==3.4 # via requests -importlib-metadata==5.1.0 +importlib-metadata==5.0.0 # via - # click + # flask # flytekit - # jsonschema # keyring + # mlflow +itsdangerous==2.1.2 + # via flask jaraco-classes==3.2.3 # via keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter + # flask # jinja2-time jinja2-time==0.2.0 # via cookiecutter joblib==1.2.0 # via flytekit -jsonschema==3.2.0 - # via -r requirements.in keyring==23.11.0 # via flytekit +mako==1.2.3 + # via alembic markupsafe==2.1.1 - # via jinja2 -marshmallow==3.19.0 + # via + # jinja2 + # mako + # werkzeug +marshmallow==3.18.0 # via # dataclasses-json # marshmallow-enum @@ -99,47 +124,59 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +mlflow==1.30.0 + # via flytekitplugins-mlflow more-itertools==9.0.0 # via jaraco-classes mypy-extensions==0.4.3 # via typing-inspect natsort==8.2.0 # via flytekit -numpy==1.21.6 +numpy==1.23.4 # via - # -r requirements.in - # flytekit + # mlflow # pandas # pyarrow + # scipy +oauthlib==3.2.2 + # via databricks-cli packaging==21.3 # via - # -r requirements.in # docker # marshmallow -pandas==1.3.5 + # mlflow +pandas==1.5.1 # via - # -r requirements.in # flytekit -protobuf==4.21.10 + # mlflow 
+plotly==5.11.0 + # via flytekitplugins-mlflow +prometheus-client==0.15.0 + # via prometheus-flask-exporter +prometheus-flask-exporter==0.20.3 + # via mlflow +protobuf==3.20.3 # via # flyteidl + # flytekit # googleapis-common-protos # grpcio-status + # mlflow # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==10.0.1 +pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi +pyjwt==2.6.0 + # via databricks-cli pyopenssl==22.1.0 # via flytekit pyparsing==3.0.9 # via packaging -pyrsistent==0.19.2 - # via jsonschema python-dateutil==2.8.2 # via # arrow @@ -148,76 +185,89 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit pytz==2022.6 # via # flytekit + # mlflow # pandas -pyyaml==5.4.1 +pyyaml==6.0 # via - # -r requirements.in # cookiecutter # flytekit + # mlflow +querystring-parser==1.2.4 + # via mlflow regex==2022.10.31 # via docker-image-py requests==2.28.1 # via # cookiecutter + # databricks-cli # docker # flytekit + # mlflow # responses responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.3 - # via keyring -singledispatchmethod==1.0 - # via flytekit +scipy==1.9.3 + # via mlflow six==1.16.0 # via - # jsonschema + # databricks-cli + # grpcio # python-dateutil - # websocket-client + # querystring-parser +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit +sqlalchemy==1.4.43 + # via + # alembic + # mlflow +sqlparse==0.4.3 + # via mlflow statsd==3.3.0 # via flytekit +tabulate==0.9.0 + # via databricks-cli +tenacity==8.1.0 + # via plotly text-unidecode==1.3 # via python-slugify toml==0.10.2 # via responses -types-toml==0.10.8.1 +types-toml==0.10.8 # via responses typing-extensions==4.4.0 # via - # arrow # flytekit - # importlib-metadata - # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.13 +urllib3==1.26.12 # via # docker # flytekit # requests # responses -websocket-client==0.59.0 - # via - # -r requirements.in - # docker -wheel==0.38.4 +websocket-client==1.4.2 + # via docker +werkzeug==2.2.2 + # via flask +wheel==0.38.3 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.11.0 +zipp==3.10.0 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/plugins/flytekit-mlflow/setup.py b/plugins/flytekit-mlflow/setup.py new file mode 100644 index 0000000000..2033ce5d27 --- /dev/null +++ b/plugins/flytekit-mlflow/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "mlflow" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.1.0,<2.0.0", "plotly", "mlflow"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package enables seamless use of MLFlow within Flyte", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering 
:: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-mlflow/tests/__init__.py b/plugins/flytekit-mlflow/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py new file mode 100644 index 0000000000..b196327d8d --- /dev/null +++ b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py @@ -0,0 +1,32 @@ +import mlflow +import tensorflow as tf +from flytekitplugins.mlflow import mlflow_autolog + +import flytekit +from flytekit import task + + +@task(disable_deck=False) +@mlflow_autolog(framework=mlflow.keras) +def train_model(epochs: int): + fashion_mnist = tf.keras.datasets.fashion_mnist + (train_images, train_labels), (_, _) = fashion_mnist.load_data() + train_images = train_images / 255.0 + + model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dense(10), + ] + ) + + model.compile( + optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] + ) + model.fit(train_images, train_labels, epochs=epochs) + + +def test_local_exec(): + train_model(epochs=1) + assert len(flytekit.current_context().decks) == 4 # mlflow metrics, params, input, and output diff --git a/plugins/flytekit-pandera/tests/test_plugin.py b/plugins/flytekit-pandera/tests/test_plugin.py index a16d80d781..cc9b26c4fa 100644 --- a/plugins/flytekit-pandera/tests/test_plugin.py +++ b/plugins/flytekit-pandera/tests/test_plugin.py @@ -48,6 +48,8 @@ def my_wf() -> pandera.typing.DataFrame[OutSchema]: def invalid_wf() -> pandera.typing.DataFrame[OutSchema]: return transform2(df=transform1(df=invalid_df)) + invalid_wf() + # raise error when executing workflow with invalid input @workflow def wf_with_df_input(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]: diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 04f821ccf3..6c160c2690 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -186,7 +186,7 @@ def fn(settings: SerializationSettings) -> typing.List[str]: return self._config_task_instance.get_k8s_pod(settings) def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]: - return self._config_task_instance.get_config(settings) + return {**super().get_config(settings), **self._config_task_instance.get_config(settings)} def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: return self._config_task_instance.pre_execute(user_params) diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 0dfd0c6516..0b5bf8e577 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -64,11 +64,10 @@ def decode( ) -> pl.DataFrame: local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) - path = f"{local_dir}/00000" if current_task_metadata.structured_dataset_type and 
current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pl.read_parquet(path, columns=columns) - return pl.read_parquet(path) + return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True) + return pl.read_parquet(local_dir, use_pyarrow=True) StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index b991cd5d13..15a195e5d5 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -66,3 +66,16 @@ def test_polars_renderer(): assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame( df.describe().transpose(), columns=df.describe().columns ).to_html(index=False) + + +def test_parquet_to_polars(): + data = {"name": ["Alice"], "age": [5]} + + @task + def create_sd() -> StructuredDataset: + df = pd.DataFrame(data=data) + return StructuredDataset(dataframe=df) + + sd = create_sd() + polars_df = sd.open(pl.DataFrame).all() + assert pl.DataFrame(data).frame_equal(polars_df) diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index a012e38d99..672f4a19ad 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -70,7 +70,7 @@ def test_local_exec(): ) assert len(snowflake_task.interface.inputs) == 1 - assert snowflake_task.query_template == "select 1\\n" + assert snowflake_task.query_template == "select 1" assert len(snowflake_task.interface.outputs) == 1 # will not run locally @@ -86,4 +86,4 @@ def test_sql_template(): custom where column = 1""", output_schema_type=FlyteSchema, ) - assert snowflake_task.query_template == "select 1 from\\t\\n custom where column = 1" + assert snowflake_task.query_template == "select 1 from custom where column = 1" diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index 46079f40dd..386570be5c 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -1,6 +1,7 @@ import typing import pandas as pd +import pyspark from pyspark.sql.dataframe import DataFrame from flytekit import FlyteContext @@ -38,7 +39,10 @@ def encode( ) -> literals.StructuredDataset: path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() df = typing.cast(DataFrame, structured_dataset.dataframe) - df.write.mode("overwrite").parquet(path) + ss = pyspark.sql.SparkSession.builder.getOrCreate() + # Avoid generating SUCCESS files + ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false") + df.write.mode("overwrite").parquet(path=path) return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index d344eaa2ba..67d47cf6b1 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -34,4 +34,5 @@ "Topic :: Software Development :: Libraries :: Python Modules", ], scripts=["scripts/flytekit_install_spark3.sh"], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-spark/tests/test_remote_register.py 
b/plugins/flytekit-spark/tests/test_remote_register.py index 8eaf8a0794..3bb65d09bc 100644 --- a/plugins/flytekit-spark/tests/test_remote_register.py +++ b/plugins/flytekit-spark/tests/test_remote_register.py @@ -21,6 +21,7 @@ def my_python_task(a: str) -> int: mock_client = MagicMock() remote._client = mock_client + remote._client_initialized = True remote.register_task( my_spark, diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index 6d20027b2a..7537a3a1de 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -70,7 +70,23 @@ def test_task_schema(sql_server): assert df is not None -def test_workflow(sql_server): +@pytest.mark.parametrize( + "query_template", + [ + "select * from tracks limit {{.inputs.limit}}", + """ + select * from tracks + limit {{.inputs.limit}} + """, + """select * from tracks + limit {{.inputs.limit}} + """, + """ + select * from tracks + limit {{.inputs.limit}}""", + ], +) +def test_workflow(sql_server, query_template): @task def my_task(df: pandas.DataFrame) -> int: return len(df[df.columns[0]]) @@ -84,7 +100,7 @@ def my_task(df: pandas.DataFrame) -> int: sql_task = SQLAlchemyTask( "test", - query_template="select * from tracks limit {{.inputs.limit}}", + query_template=query_template, inputs=kwtypes(limit=int), task_config=SQLAlchemyConfig(uri=sql_server), ) diff --git a/requirements-spark2.in b/requirements-spark2.in deleted file mode 100644 index 1e47380935..0000000000 --- a/requirements-spark2.in +++ /dev/null @@ -1,2 +0,0 @@ -.[all-spark2.4] --r requirements.in diff --git a/requirements.in b/requirements.in index 09828957fe..1f6d40aaf0 100644 --- a/requirements.in +++ b/requirements.in @@ -1,4 +1,3 @@ -. -e file:.#egg=flytekit attrs<21 # We need to restrict constrain the versions of both jsonschema and pyyaml because of docker-compose (which is diff --git a/setup.py b/setup.py index 3d9710004f..329a4149c4 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,5 @@ -import sys - from setuptools import find_packages, setup # noqa -MIN_PYTHON_VERSION = (3, 7) -CURRENT_PYTHON = sys.version_info[:2] -if CURRENT_PYTHON < MIN_PYTHON_VERSION: - print( - f"Flytekit API is only supported for Python version is {MIN_PYTHON_VERSION}+. Detected you are on" - f" version {CURRENT_PYTHON}, installation will not proceed!" - ) - sys.exit(-1) - extras_require = {} __version__ = "0.0.0+develop" @@ -40,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.3.0,<1.4.0", + "flyteidl>=1.3.5,<1.4.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", @@ -54,6 +43,7 @@ "grpcio>=1.50.0,<2.0", "grpcio-status>=1.50.0,<2.0", "importlib-metadata", + "fsspec", "pyopenssl", "joblib", "python-json-logger>=2.0.0", @@ -82,6 +72,8 @@ # TODO: We should remove mentions to the deprecated numpy # aliases. 
More details in https://github.com/flyteorg/flyte/issues/3166 "numpy<1.24.0", + "gitpython", + "kubernetes>=12.0.1", ], extras_require=extras_require, scripts=[ @@ -90,7 +82,7 @@ "flytekit/bin/entrypoint.py", ], license="apache2", - python_requires=">=3.7", + python_requires=">=3.7,<3.11", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index 4b48fbcfb9..d5b07fe420 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -121,8 +121,9 @@ discovery_version, deprecated, cache_serializable, + pod_template_name, ) - for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable in product( + for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable, pod_template_name in product( [True, False], LIST_OF_RUNTIME_METADATA, [timedelta(days=i) for i in range(3)], @@ -131,6 +132,7 @@ ["1.0"], ["deprecated"], [True, False], + ["A", "B"], ) ] diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index b8a781224a..7bd27f438b 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -14,19 +14,19 @@ cffi==1.15.1 # via cryptography chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit croniter==1.3.8 # via flytekit -cryptography==38.0.4 +cryptography==39.0.0 # via # pyopenssl # secretstorage @@ -46,29 +46,36 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.5 +flyteidl==1.3.5 # via flytekit -flytekit==1.2.5 +flytekit==1.3.1 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in fonttools==4.38.0 # via matplotlib -googleapis-common-protos==1.57.0 +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via flytekit +googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # grpcio-status -grpcio==1.48.2 +grpcio==1.51.1 # via # flytekit # grpcio-status -grpcio-status==1.48.2 +grpcio-status==1.51.1 # via flytekit idna==3.4 # via requests -importlib-metadata==5.1.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.10.2 + # via keyring jaraco-classes==3.2.3 # via keyring jeepney==0.8.0 @@ -85,11 +92,11 @@ joblib==1.2.0 # via # -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in # flytekit -keyring==23.11.0 +keyring==23.13.1 # via flytekit kiwisolver==1.4.4 # via matplotlib -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 # via @@ -108,28 +115,27 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.2.0 # via flytekit -numpy==1.22.0 +numpy==1.21.6 # via # flytekit # matplotlib # opencv-python # pandas # pyarrow -opencv-python==4.6.0.66 +opencv-python==4.7.0.68 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -packaging==21.3 +packaging==23.0 # via # docker # marshmallow # matplotlib pandas==1.3.5 # via flytekit -pillow==9.3.0 +pillow==9.4.0 # via matplotlib -protobuf==3.20.3 
+protobuf==4.21.12 # via # flyteidl - # flytekit # googleapis-common-protos # grpcio-status # protoc-gen-swagger @@ -141,12 +147,10 @@ pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.1.0 +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 - # via - # matplotlib - # packaging + # via matplotlib python-dateutil==2.8.2 # via # arrow @@ -156,11 +160,11 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7.1 # via # flytekit # pandas @@ -170,7 +174,7 @@ pyyaml==6.0 # flytekit regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker @@ -185,9 +189,9 @@ secretstorage==3.3.3 singledispatchmethod==1.0 # via flytekit six==1.16.0 - # via - # grpcio - # python-dateutil + # via python-dateutil +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -202,19 +206,20 @@ typing-extensions==4.4.0 # via # arrow # flytekit + # gitpython # importlib-metadata # kiwisolver # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.13 +urllib3==1.26.14 # via # docker # flytekit # requests # responses -websocket-client==1.4.2 +websocket-client==1.5.0 # via docker wheel==0.38.4 # via @@ -224,5 +229,7 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.11.0 - # via importlib-metadata +zipp==3.12.0 + # via + # importlib-metadata + # importlib-resources diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index dd021eb3be..78a828a507 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -8,9 +8,8 @@ import joblib import pytest -from flytekit import kwtypes +from flytekit import LaunchPlan, kwtypes from flytekit.configuration import Config -from flytekit.core.launch_plan import LaunchPlan from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task from flytekit.remote.remote import FlyteRemote @@ -221,6 +220,7 @@ def test_fetch_execute_task_convert_dict(flyteclient, flyte_workflows_register): flyte_task = remote.fetch_task(name="workflows.basic.dict_str_wf.convert_to_string", version=f"v{VERSION}") d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"} execution = remote.execute(flyte_task, {"d": d}, wait=True) + remote.sync_execution(execution, sync_nodes=True) assert json.loads(execution.outputs["o0"]) == {"key1": "value1", "key2": "value2"} diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 479ad9e7bd..73a3c1996b 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -2,22 +2,21 @@ import typing from collections import OrderedDict +import fsspec import mock +import pytest from flyteidl.core.errors_pb2 import ErrorDocument +from fsspec.implementations.arrow import ArrowFSWrapper from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution -from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs -from flytekit.core.data_persistence import DiskPersistence from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.promise import 
VoidPromise from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.scopes import system_entry_point -from flytekit.extras.persistence.gcs_gsutil import GCSPersistence -from flytekit.extras.persistence.s3_awscli import S3Persistence from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models from flytekit.models.core import execution as execution_models @@ -110,6 +109,37 @@ def verify_output(*args, **kwargs): assert mock_write_to_file.call_count == 1 +@mock.patch.dict(os.environ, {"FLYTE_FAIL_ON_ERROR": "True"}) +@mock.patch("flytekit.core.utils.load_proto_from_file") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.utils.write_proto_to_file") +def test_dispatch_execute_return_error_code(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + python_task = mock.MagicMock() + python_task.dispatch_execute.side_effect = Exception("random") + + empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl() + mock_load_proto.return_value = empty_literal_map + + def verify_output(*args, **kwargs): + assert isinstance(args[0], ErrorDocument) + + mock_write_to_file.side_effect = verify_output + + with pytest.raises(SystemExit) as cm: + _dispatch_execute(ctx, python_task, "inputs path", "outputs prefix") + pytest.assertEqual(cm.value.code, 1) + + # This function collects outputs instead of writing them to a file. 
# See flytekit.core.utils.write_proto_to_file for the original def get_output_collector(results: OrderedDict): @@ -281,31 +311,16 @@ def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock def test_setup_disk_prefix(): with setup_execution("qwerty") as ctx: - assert isinstance(ctx.file_access._default_remote, DiskPersistence) + assert isinstance(ctx.file_access._default_remote, fsspec.AbstractFileSystem) + assert ctx.file_access._default_remote.protocol == "file" def test_setup_cloud_prefix(): with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert isinstance(ctx.file_access._default_remote, S3Persistence) + assert ctx.file_access._default_remote.protocol[0] == "s3" with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert isinstance(ctx.file_access._default_remote, GCSPersistence) - - -def test_persist_ss(): - default_img = Image(name="default", fqn="test", tag="tag") - ss = SerializationSettings( - project="proj1", - domain="dom", - version="version123", - env=None, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) - ss_txt = ss.serialized_context - os.environ["_F_SS_C"] = ss_txt - with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert ctx.serialization_settings.project == "proj1" - assert ctx.serialization_settings.domain == "dom" + assert "gs" in ctx.file_access._default_remote.protocol def test_normalize_inputs(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_backfill.py b/tests/flytekit/unit/cli/pyflyte/test_backfill.py new file mode 100644 index 0000000000..8389295af2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_backfill.py @@ -0,0 +1,46 @@ +from datetime import datetime, timedelta + +import click +import pytest +from click.testing import CliRunner +from mock import mock + +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.clis.sdk_in_container.backfill import resolve_backfill_window +from flytekit.remote import FlyteRemote + + +def test_resolve_backfill_window(): + dt = datetime(2022, 12, 1, 8) + window = timedelta(days=10) + assert resolve_backfill_window(None, dt + window, window) == (dt, dt + window) + assert resolve_backfill_window(dt, None, window) == (dt, dt + window) + assert resolve_backfill_window(dt, dt + window) == (dt, dt + window) + with pytest.raises(click.BadParameter): + resolve_backfill_window() + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +def test_pyflyte_backfill(mock_remote): + mock_remote.generate_console_url.return_value = "ex" + runner = CliRunner() + with runner.isolated_filesystem(): + result = runner.invoke( + pyflyte.main, + [ + "backfill", + "--parallel", + "-p", + "flytesnacks", + "-d", + "development", + "--from-date", + "now", + "--backfill-window", + "5 day", + "daily", + "--dry-run", + ], + ) + assert result.exit_code == 0 + assert "Execution launched" in result.output diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index 4951d4be46..a6c0bb91d8 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -8,6 +8,7 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.configuration import Config from flytekit.core import 
context_manager from flytekit.remote.remote import FlyteRemote @@ -34,6 +35,7 @@ def test_saving_remote(mock_remote): mock_context.obj = {} get_and_save_remote_with_click_context(mock_context, "p", "d") assert mock_context.obj["flyte_remote"] is not None + mock_remote.assert_called_once_with(Config.for_sandbox(), default_project="p", default_domain="d") def test_register_with_no_package_or_module_argument(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index d5db7296b9..b211153f44 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -2,6 +2,7 @@ import os import pathlib import typing +from datetime import datetime, timedelta from enum import Enum import click @@ -16,6 +17,8 @@ from flytekit.clis.sdk_in_container.run import ( REMOTE_FLAG_KEY, RUN_LEVEL_PARAMS_KEY, + DateTimeType, + DurationParamType, FileParamType, FlyteLiteralConverter, get_entities_in_file, @@ -32,12 +35,21 @@ DIR_NAME = os.path.dirname(os.path.realpath(__file__)) -def test_pyflyte_run_wf(): - runner = CliRunner() - module_path = WORKFLOW_FILE - result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) +@pytest.fixture +def remote(): + with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client: + flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote._client = mock_client + return flyte_remote - assert result.exit_code == 0 + +def test_pyflyte_run_wf(remote): + with mock.patch("flytekit.clis.sdk_in_container.helpers.get_and_save_remote_with_click_context"): + runner = CliRunner() + module_path = WORKFLOW_FILE + result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) + + assert result.exit_code == 0 def test_imperative_wf(): @@ -330,3 +342,22 @@ def test_enum_converter(): assert union_lt.stored_type.simple is None assert union_lt.stored_type.enum_type.values == ["red", "green", "blue"] + + +def test_duration_type(): + t = DurationParamType() + assert t.convert(value="1 day", param=None, ctx=None) == timedelta(days=1) + + with pytest.raises(click.BadParameter): + t.convert(None, None, None) + + +def test_datetime_type(): + t = DateTimeType() + + assert t.convert("2020-01-01", None, None) == datetime(2020, 1, 1) + + now = datetime.now() + v = t.convert("now", None, None) + assert v.day == now.day + assert v.month == now.month diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index b3f1807b96..10a7e09333 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -40,12 +40,13 @@ def get_admin_stub_mock() -> mock.MagicMock: return auth_stub_mock +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") @mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy): +def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True @@ -73,6 +74,7 @@ def 
test_refresh_credentials_from_command(mock_call_to_external_process, mock_ad mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key) +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.get_basic_authorization_header") @mock.patch("flytekit.clients.raw.get_token") @@ -88,6 +90,7 @@ def test_refresh_client_credentials_aka_basic( mock_get_token, mock_get_basic_header, mock_dataproxy, + mock_signal, ): mock_secure_channel.return_value = True mock_channel.return_value = True @@ -112,12 +115,13 @@ def test_refresh_client_credentials_aka_basic( assert client._metadata[0][0] == "authorization" +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") @mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy): +def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True diff --git a/tests/flytekit/unit/configuration/configs/good.config b/tests/flytekit/unit/configuration/configs/good.config index 56bb837b00..06c2579d42 100644 --- a/tests/flytekit/unit/configuration/configs/good.config +++ b/tests/flytekit/unit/configuration/configs/good.config @@ -7,8 +7,8 @@ assumable_iam_role=some_role [platform] - url=fakeflyte.com +insecure=false [madeup] diff --git a/tests/flytekit/unit/configuration/configs/nossl.yaml b/tests/flytekit/unit/configuration/configs/nossl.yaml new file mode 100644 index 0000000000..f7acdde5a5 --- /dev/null +++ b/tests/flytekit/unit/configuration/configs/nossl.yaml @@ -0,0 +1,4 @@ +admin: + endpoint: dns:///flyte.mycorp.io + authType: Pkce + insecure: false diff --git a/tests/flytekit/unit/configuration/test_file.py b/tests/flytekit/unit/configuration/test_file.py index cb10bf42c0..3ce03f9c50 100644 --- a/tests/flytekit/unit/configuration/test_file.py +++ b/tests/flytekit/unit/configuration/test_file.py @@ -7,7 +7,8 @@ from pytimeparse.timeparse import timeparse from flytekit.configuration import ConfigEntry, get_config_file, set_if_exists -from flytekit.configuration.file import LegacyConfigEntry +from flytekit.configuration.file import LegacyConfigEntry, _exists +from flytekit.configuration.internal import Platform def test_set_if_exists(): @@ -21,6 +22,25 @@ def test_set_if_exists(): assert d["k"] == "x" +@pytest.mark.parametrize( + "data, expected", + [ + [1, True], + [1.0, True], + ["foo", True], + [True, True], + [False, True], + [[1], True], + [{"k": "v"}, True], + [None, False], + [[], False], + [{}, False], + ], +) +def test_exists(data, expected): + assert _exists(data) is expected + + def test_get_config_file(): c = get_config_file(None) assert c is None @@ -118,3 +138,9 @@ def test_env_var_bool_transformer(mock_file_read): # The last read should've triggered the file read since now the env var is no longer set. 
assert mock_file_read.call_count == 1 + + +def test_use_ssl(): + config_file = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) + res = Platform.INSECURE.read(config_file) + assert res is False diff --git a/tests/flytekit/unit/configuration/test_image_config.py b/tests/flytekit/unit/configuration/test_image_config.py index be59b883af..84c767f8fb 100644 --- a/tests/flytekit/unit/configuration/test_image_config.py +++ b/tests/flytekit/unit/configuration/test_image_config.py @@ -11,10 +11,10 @@ @pytest.mark.parametrize( "python_version_enum, expected_image_string", [ - (PythonVersion.PYTHON_3_7, "ghcr.io/flyteorg/flytekit:py3.7-latest"), - (PythonVersion.PYTHON_3_8, "ghcr.io/flyteorg/flytekit:py3.8-latest"), - (PythonVersion.PYTHON_3_9, "ghcr.io/flyteorg/flytekit:py3.9-latest"), - (PythonVersion.PYTHON_3_10, "ghcr.io/flyteorg/flytekit:py3.10-latest"), + (PythonVersion.PYTHON_3_7, "cr.flyte.org/flyteorg/flytekit:py3.7-latest"), + (PythonVersion.PYTHON_3_8, "cr.flyte.org/flyteorg/flytekit:py3.8-latest"), + (PythonVersion.PYTHON_3_9, "cr.flyte.org/flyteorg/flytekit:py3.9-latest"), + (PythonVersion.PYTHON_3_10, "cr.flyte.org/flyteorg/flytekit:py3.10-latest"), ], ) def test_defaults(python_version_enum, expected_image_string): @@ -24,8 +24,8 @@ def test_defaults(python_version_enum, expected_image_string): @pytest.mark.parametrize( "python_version_enum, flytekit_version, expected_image_string", [ - (PythonVersion.PYTHON_3_7, "v0.32.0", "ghcr.io/flyteorg/flytekit:py3.7-0.32.0"), - (PythonVersion.PYTHON_3_8, "1.31.3", "ghcr.io/flyteorg/flytekit:py3.8-1.31.3"), + (PythonVersion.PYTHON_3_7, "v0.32.0", "cr.flyte.org/flyteorg/flytekit:py3.7-0.32.0"), + (PythonVersion.PYTHON_3_8, "1.31.3", "cr.flyte.org/flyteorg/flytekit:py3.8-1.31.3"), ], ) def test_set_both(python_version_enum, flytekit_version, expected_image_string): @@ -36,7 +36,7 @@ def test_image_config_auto(): x = ImageConfig.auto_default_image() assert x.images[0].name == "default" version_str = f"{sys.version_info.major}.{sys.version_info.minor}" - assert x.images[0].full == f"ghcr.io/flyteorg/flytekit:py{version_str}-latest" + assert x.images[0].full == f"cr.flyte.org/flyteorg/flytekit:py{version_str}-latest" def test_image_from_flytectl_config(): @@ -56,7 +56,7 @@ def test_not_version(mock_sys): def test_image_create(): with pytest.raises(ValueError): - ImageConfig.create_from("ghcr.io/im/g:latest") + ImageConfig.create_from("cr.flyte.org/im/g:latest") - ic = ImageConfig.from_images("ghcr.io/im/g:latest") - assert ic.default_image.fqn == "ghcr.io/im/g" + ic = ImageConfig.from_images("cr.flyte.org/im/g:latest") + assert ic.default_image.fqn == "cr.flyte.org/im/g" diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 7f6be53a55..97e30b5612 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -77,3 +77,9 @@ def test_some_int(mocked): res = AWS.RETRIES.read(cfg) assert type(res) is int assert res == 5 + + +def test_default_platform_config_endpoint_insecure(): + platform_config = PlatformConfig() + assert platform_config.endpoint == "localhost:30080" + assert platform_config.insecure is False diff --git a/tests/flytekit/unit/configuration/test_yaml_file.py b/tests/flytekit/unit/configuration/test_yaml_file.py index 7e1c3eee98..ba2c61e158 100644 --- a/tests/flytekit/unit/configuration/test_yaml_file.py +++ b/tests/flytekit/unit/configuration/test_yaml_file.py @@ 
-14,6 +14,7 @@ def test_config_entry_file(): assert c.read() is None cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/sample.yaml")) + assert cfg.yaml_config is not None assert c.read(cfg) == "flyte.mycorp.io" c = ConfigEntry(LegacyConfigEntry("platform", "url2", str)) # Does not exist @@ -26,6 +27,7 @@ def test_config_entry_file_normal(): cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/no_images.yaml")) images_dict = Images.get_specified_images(cfg) assert images_dict == {} + assert cfg.yaml_config is not None @mock.patch("flytekit.configuration.file.getenv") @@ -43,6 +45,7 @@ def test_config_entry_file_2(mock_get): cfg = get_config_file(sample_yaml_file_name) assert c.read(cfg) == "flyte.mycorp.io" + assert cfg.yaml_config is not None c = ConfigEntry(LegacyConfigEntry("platform", "url2", str)) # Does not exist assert c.read(cfg) is None @@ -67,3 +70,9 @@ def test_real_config(): res = Credentials.SCOPES.read(config_file) assert res == ["all"] + + +def test_use_ssl(): + config_file = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/nossl.yaml")) + res = Platform.INSECURE.read(config_file) + assert res is False diff --git a/tests/flytekit/unit/core/flyte_functools/decorator_source.py b/tests/flytekit/unit/core/flyte_functools/decorator_source.py index 9c92364649..5790d5d358 100644 --- a/tests/flytekit/unit/core/flyte_functools/decorator_source.py +++ b/tests/flytekit/unit/core/flyte_functools/decorator_source.py @@ -1,10 +1,11 @@ """Script used for testing local execution of functool.wraps-wrapped tasks for stacked decorators""" - +import functools +import typing from functools import wraps from typing import List -def task_setup(function: callable = None, *, integration_requests: List = None) -> None: +def task_setup(function: typing.Callable, *, integration_requests: typing.Optional[List] = None) -> typing.Callable: integration_requests = integration_requests or [] @wraps(function) diff --git a/tests/flytekit/unit/core/flyte_functools/nested_function.py b/tests/flytekit/unit/core/flyte_functools/nested_function.py index 6a3ccfd9e1..98a39e497a 100644 --- a/tests/flytekit/unit/core/flyte_functools/nested_function.py +++ b/tests/flytekit/unit/core/flyte_functools/nested_function.py @@ -32,4 +32,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/simple_decorator.py b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py index a51a283be5..3278af1bb0 100644 --- a/tests/flytekit/unit/core/flyte_functools/simple_decorator.py +++ b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py @@ -38,4 +38,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py index 07c46cd46a..dd445a6fb3 100644 --- a/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py +++ b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py @@ -48,4 +48,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git 
a/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py index 9f7e6599c6..6e22ca9840 100644 --- a/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py +++ b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py @@ -26,4 +26,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index f1dbbbd5ef..b5fa46fe54 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -1,9 +1,11 @@ +import os from pathlib import Path import pytest import flytekit from flytekit.core.checkpointer import SyncCheckpoint +from flytekit.core.local_cache import LocalTaskCache def test_sync_checkpoint_write(tmpdir): @@ -35,11 +37,14 @@ def test_sync_checkpoint_save_file(tmpdir): def test_sync_checkpoint_save_filepath(tmpdir): - td_path = Path(tmpdir) - cp = SyncCheckpoint(checkpoint_dest=tmpdir) - dst_path = td_path.joinpath("test") + src_path = Path(os.path.join(tmpdir, "src")) + src_path.mkdir(parents=True, exist_ok=True) + chkpnt_path = Path(os.path.join(tmpdir, "dest")) + chkpnt_path.mkdir() + cp = SyncCheckpoint(checkpoint_dest=str(chkpnt_path)) + dst_path = chkpnt_path.joinpath("test") assert not dst_path.exists() - inp = td_path.joinpath("test") + inp = src_path.joinpath("test") with inp.open("wb") as f: f.write(b"blah") cp.save(inp) @@ -123,5 +128,23 @@ def t1(n: int) -> int: return n + 1 +@flytekit.task(cache=True, cache_version="v0") +def t2(n: int) -> int: + ctx = flytekit.current_context() + cp = ctx.checkpoint + cp.write(bytes(n + 1)) + return n + 1 + + +@pytest.fixture(scope="function", autouse=True) +def setup(): + LocalTaskCache.initialize() + LocalTaskCache.clear() + + def test_checkpoint_task(): assert t1(n=5) == 6 + + +def test_checkpoint_cached_task(): + assert t2(n=5) == 6 diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 3963c77c8d..6fe2b01e61 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -35,14 +35,12 @@ def my_wf(a: int, b: str) -> (int, str, str): def test_single_named_output_subwf(): - nt = NamedTuple("SubWfOutput", sub_int=int) + nt = NamedTuple("SubWfOutput", [("sub_int", int)]) @task def t1(a: int) -> nt: a = a + 2 - return nt( - a, - ) # returns a named tuple + return nt(a) @task def t2(a: int, b: int) -> nt: @@ -198,3 +196,5 @@ def t3(c: Optional[int] = 3) -> Optional[int]: @workflow def wf(): return t3() + + wf() diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index be85918b74..7b0b292baa 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -71,6 +71,22 @@ def multiplier_2(my_input: float) -> float: multiplier_2(my_input=10.0) +def test_condition_else_int(): + @workflow + def multiplier_3(my_input: int) -> float: + return ( + conditional("fractions") + .if_((my_input >= 0) & (my_input < 1.0)) + .then(double(n=my_input)) + .elif_((my_input > 1.0) & (my_input < 10.0)) + .then(square(n=my_input)) + .else_() + .fail("The input must be between 0 and 10") + ) + + assert multiplier_3(my_input=0) == 0 + + def test_condition_sub_workflows(): @task def sum_div_sub(a: int, b: int) -> 
typing.NamedTuple("Outputs", sum=int, div=int, sub=int): @@ -151,12 +167,16 @@ def decompose_unary() -> int: result = return_true() return conditional("test").if_(result).then(success()).else_().then(failed()) + decompose_unary() + with pytest.raises(AssertionError): @workflow def decompose_none() -> int: return conditional("test").if_(None).then(success()).else_().then(failed()) + decompose_none() + with pytest.raises(AssertionError): @workflow @@ -164,6 +184,8 @@ def decompose_is() -> int: result = return_true() return conditional("test").if_(result is True).then(success()).else_().then(failed()) + decompose_is() + @workflow def decompose() -> int: result = return_true() @@ -283,7 +305,7 @@ def branching(x: int): def test_subworkflow_condition_named_tuple(): - nt = typing.NamedTuple("SampleNamedTuple", b=int, c=str) + nt = typing.NamedTuple("SampleNamedTuple", [("b", int), ("c", str)]) @task def t() -> nt: @@ -302,13 +324,11 @@ def branching(x: int) -> nt: def test_subworkflow_condition_single_named_tuple(): - nt = typing.NamedTuple("SampleNamedTuple", b=int) + nt = typing.NamedTuple("SampleNamedTuple", [("b", int)]) @task def t() -> nt: - return nt( - 5, - ) + return nt(5) @workflow def wf1() -> nt: diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 98af80638a..6e68c9d4be 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -207,7 +207,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 376 + assert len(tp) == 388 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py new file mode 100644 index 0000000000..874b99e363 --- /dev/null +++ b/tests/flytekit/unit/core/test_data.py @@ -0,0 +1,150 @@ +import os +import shutil +import tempfile + +import fsspec +import mock +import pytest +from fsspec.implementations.arrow import ArrowFSWrapper +from pyarrow import fs + +from flytekit.configuration import Config +from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider + +local = fsspec.filesystem("file") + +# def test_mlje(): +# # pyarrow stuff +# local = fs.LocalFileSystem() +# local_fsspec = ArrowFSWrapper(local) +# +# s3, path = fs.FileSystem.from_uri("s3://flyte-demo/datasets/sddemo/small.parquet") +# print(s3, path) +# f = s3.open_input_stream(path) +# f.readall() +# ws3 = ArrowFSWrapper(s3) +# +# ss3 = fs.S3FileSystem(region="us-east-2") +# +# # base fsspec stuff +# fs3 = fsspec.filesystem("s3") +# fs3.cat_file("s3://flyte-demo/datasets/sddemo/small.parquet") +# +# # Does doing this work with minio without the thing? +# s3, path = fs.FileSystem.from_uri( +# "s3://my-s3-bucket/metadata/flytesnacks/development/am9s9q2dfrkrfnc7x9nd/user_inputs" +# ) +# # If you don't have http, it will try to use SSL. +# # TODO: check the sandbox configuration to see what it uses. 
+# local_s3 = fs.S3FileSystem( +# access_key="minio", secret_key="miniostorage", endpoint_override="http://localhost:30002" +# ) +# wr_s3 = ArrowFSWrapper(local_s3) + + +@mock.patch("google.auth.compute_engine._metadata") # to prevent network calls +@mock.patch("flytekit.core.data_persistence.UUID") +def test_path_getting(mock_uuid_class, mock_gcs): + mock_uuid_class.return_value.hex = "abcdef123" + + # Testing with raw output prefix pointing to a local path + local_raw_fp = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix="/tmp/unittestdata") + assert local_raw_fp.get_random_remote_path() == "/tmp/unittestdata/abcdef123" + assert local_raw_fp.get_random_remote_path("/fsa/blah.csv") == "/tmp/unittestdata/abcdef123/blah.csv" + assert local_raw_fp.get_random_remote_directory() == "/tmp/unittestdata/abcdef123" + + # Test local path and directory + assert local_raw_fp.get_random_local_path() == "/tmp/unittest/local_flytekit/abcdef123" + assert local_raw_fp.get_random_local_path("xjiosa/blah.txt") == "/tmp/unittest/local_flytekit/abcdef123/blah.txt" + assert local_raw_fp.get_random_local_directory() == "/tmp/unittest/local_flytekit/abcdef123" + + # Test with remote pointed to s3. + s3_fa = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket") + assert s3_fa.get_random_remote_path() == "s3://my-s3-bucket/abcdef123" + assert s3_fa.get_random_remote_directory() == "s3://my-s3-bucket/abcdef123" + # trailing slash should make no difference + s3_fa = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/") + assert s3_fa.get_random_remote_path() == "s3://my-s3-bucket/abcdef123" + assert s3_fa.get_random_remote_directory() == "s3://my-s3-bucket/abcdef123" + + # Testing with raw output prefix pointing to file:// + file_raw_fp = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix="file:///tmp/unittestdata") + assert file_raw_fp.get_random_remote_path() == "/tmp/unittestdata/abcdef123" + assert file_raw_fp.get_random_remote_path("/fsa/blah.csv") == "/tmp/unittestdata/abcdef123/blah.csv" + assert file_raw_fp.get_random_remote_directory() == "/tmp/unittestdata/abcdef123" + + g_fa = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix="gs://my-s3-bucket/") + assert g_fa.get_random_remote_path() == "gs://my-s3-bucket/abcdef123" + + +@mock.patch("flytekit.core.data_persistence.UUID") +def test_default_file_access_instance(mock_uuid_class): + mock_uuid_class.return_value.hex = "abcdef123" + + assert default_local_file_access_provider.get_random_local_path().endswith("/sandbox/local_flytekit/abcdef123") + assert default_local_file_access_provider.get_random_local_path("bob.txt").endswith("abcdef123/bob.txt") + + assert default_local_file_access_provider.get_random_local_directory().endswith("sandbox/local_flytekit/abcdef123") + + x = default_local_file_access_provider.get_random_remote_path() + assert x.endswith("raw/abcdef123") + x = default_local_file_access_provider.get_random_remote_path("eve.txt") + assert x.endswith("raw/abcdef123/eve.txt") + x = default_local_file_access_provider.get_random_remote_directory() + assert x.endswith("raw/abcdef123") + + +@pytest.fixture +def source_folder(): + # Set up source directory for testing + parent_temp = tempfile.mkdtemp() + src_dir = os.path.join(parent_temp, "source", "") + nested_dir = os.path.join(src_dir, "nested") + local.mkdir(nested_dir) + local.touch(os.path.join(src_dir, "original.txt")) + 
local.touch(os.path.join(nested_dir, "more.txt")) + yield src_dir + shutil.rmtree(parent_temp) + + +# Add some assertions +def test_local_raw_fsspec(source_folder): + with tempfile.TemporaryDirectory() as dest_tmpdir: + local.put(source_folder, dest_tmpdir, recursive=True) + + new_temp_dir_2 = tempfile.mkdtemp() + new_temp_dir_2 = os.path.join(new_temp_dir_2, "doesnotexist") + local.put(source_folder, new_temp_dir_2, recursive=True) + + +# Add some assertions +def test_local_provider(source_folder): + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as dest_tmpdir: + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=dest_tmpdir, data_config=dc) + doesnotexist = provider.get_random_remote_directory() + provider.put_data(source_folder, doesnotexist, is_multipart=True) + + exists = provider.get_random_remote_directory() + provider._default_remote.mkdir(exists) + provider.put_data(source_folder, exists, is_multipart=True) + + +@pytest.mark.needs_local_sandbox +def test_s3_provider(source_folder): + # Running mkdir on s3 filesystem doesn't do anything so leaving out for now + dc = Config.for_sandbox().data_config + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + doesnotexist = provider.get_random_remote_directory() + provider.put_data(source_folder, doesnotexist, is_multipart=True) + + +# Add some assertions +def test_local_provider_get_empty(): + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as empty_source: + with tempfile.TemporaryDirectory() as dest_folder: + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=empty_source, data_config=dc) + provider.get_data(empty_source, dest_folder, is_multipart=True) diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index af39e9e852..27b407c1ce 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,11 +1,11 @@ -from flytekit.core.data_persistence import DataPersistencePlugins, FileAccessProvider +from flytekit.core.data_persistence import FileAccessProvider def test_get_random_remote_path(): fp = FileAccessProvider("/tmp", "s3://my-bucket") path = fp.get_random_remote_path() assert path.startswith("s3://my-bucket") - assert fp.raw_output_prefix == "s3://my-bucket" + assert fp.raw_output_prefix == "s3://my-bucket/" def test_is_remote(): @@ -14,10 +14,3 @@ def test_is_remote(): assert fp.is_remote("/tmp/foo/bar") is False assert fp.is_remote("file://foo/bar") is False assert fp.is_remote("s3://my-bucket/foo/bar") is True - - -def test_lister(): - x = DataPersistencePlugins.supported_protocols() - main_protocols = {"file", "/", "gs", "http", "https", "s3"} - all_protocols = set([y.replace("://", "") for y in x]) - assert main_protocols.issubset(all_protocols) diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index cccf406c71..b9b0ebd3fa 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -34,11 +34,16 @@ def t1(a: int) -> str: a = a + 2 return "fast-" + str(a) + @workflow + def subwf(a: int): + t1(a=a) + @dynamic def my_subwf(a: int) -> typing.List[str]: s = [] for i in range(a): s.append(t1(a=i)) + subwf(a=a) return s @workflow @@ -58,7 +63,7 @@ def my_wf(a: int) -> typing.List[str]: ) as ctx: input_literal_map = 
TypeEngine.dict_to_literal_map(ctx, {"a": 5}) dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map) - assert len(dynamic_job_spec._nodes) == 5 + assert len(dynamic_job_spec._nodes) == 6 assert len(dynamic_job_spec.tasks) == 1 args = " ".join(dynamic_job_spec.tasks[0].container.args) assert args.startswith( diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index 0cb4f524f9..bd20c39c53 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -49,7 +49,6 @@ def test_engine(): def test_transformer_to_literal_local(): - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) ctx = context_manager.FlyteContext.current_context() @@ -86,6 +85,15 @@ def test_transformer_to_literal_local(): with pytest.raises(TypeError, match="No automatic conversion from "): TypeEngine.to_literal(ctx, 3, FlyteDirectory, lt) + +def test_transformer_to_literal_localss(): + random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + + tf = FlyteDirToMultipartBlobTransformer() + lt = tf.get_literal_type(FlyteDirectory) # Can't use if it's not a directory with pytest.raises(FlyteAssertion): p = "/tmp/flyte/xyz" diff --git a/tests/flytekit/unit/core/test_gate.py b/tests/flytekit/unit/core/test_gate.py index a4689ed814..bb245ad594 100644 --- a/tests/flytekit/unit/core/test_gate.py +++ b/tests/flytekit/unit/core/test_gate.py @@ -13,7 +13,8 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow -from flytekit.tools.translator import get_serializable +from flytekit.remote.entities import FlyteWorkflow +from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = SerializationSettings( @@ -218,7 +219,7 @@ def wf_dyn(a: int) -> typing.Tuple[int, int]: def test_subwf(): - nt = typing.NamedTuple("Multi", named1=int, named2=int) + nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)]) @task def nt1(a: int) -> nt: @@ -290,3 +291,35 @@ def cond_wf(a: int) -> float: x = cond_wf(a=3) assert x == 6 assert stdin.read() == "" + + +def test_promote(): + @task + def t1(a: int) -> int: + return a + 5 + + @task + def t2(a: int) -> int: + return a + 6 + + @workflow + def wf(a: int) -> typing.Tuple[int, int, int]: + zzz = sleep(timedelta(seconds=10)) + x = t1(a=a) + s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool) + s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int) + z = t1(a=5) + y = t2(a=s2) + q = t2(a=approve(y, "approvalfory", timeout=timedelta(hours=2))) + zzz >> x + x >> s1 + s1 >> z + + return y, z, q + + entries = OrderedDict() + wf_spec = get_serializable(entries, serialization_settings, wf) + tts, wf_specs, lp_specs = gather_dependent_entities(entries) + + fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=tts) + assert fwf.template.nodes[2].gate_node is not None diff --git 
a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index db4b32f6a9..ead5358316 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -67,15 +67,13 @@ def t2(): assert len(wf_spec.template.interface.outputs) == 1 # docs_equivalent_start - nt = typing.NamedTuple("wf_output", from_n0t1=str) + nt = typing.NamedTuple("wf_output", [("from_n0t1", str)]) @workflow def my_workflow(in1: str) -> nt: x = t1(a=in1) t2() - return nt( - x, - ) + return nt(x) # docs_equivalent_end diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 442851a8a2..26b43f2ef5 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -4,6 +4,7 @@ from typing_extensions import Annotated # type: ignore +from flytekit import task from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.interface import ( @@ -101,7 +102,7 @@ def x(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): return ("hello world", 5) def y(a: int, b: str) -> nt1: - return nt1("hello world", 5) + return nt1("hello world", 5) # type: ignore result = transform_variable_map(extract_return_annotation(typing.get_type_hints(x).get("return", None))) assert result["x_str"].type.simple == 3 @@ -320,3 +321,20 @@ def z(a: Foo) -> Foo: assert params.parameters["a"].default is None assert our_interface.outputs["o0"].__origin__ == FlytePickle assert our_interface.inputs["a"].__origin__ == FlytePickle + + +def test_doc_string(): + @task + def t1(a: int) -> int: + """Set the temperature value. + + The value of the temp parameter is stored as a value in + the class variable temperature. + """ + return a + + assert t1.docs.short_description == "Set the temperature value." + assert ( + t1.docs.long_description.value + == "The value of the temp parameter is stored as a value in\nthe class variable temperature." 
+ ) diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index ffaff8daad..3addd13e42 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -292,7 +292,7 @@ def wf(a: int, c: str) -> (int, str): def test_lp_all_parameters(): - nt = typing.NamedTuple("OutputsBC", t1_int_output=int, c=str) + nt = typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]) @task def t1(a: int) -> nt: diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 14c9620ae6..95927873d0 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -159,6 +159,8 @@ def wf1(a: int): def wf2(a: typing.List[int]): return map_task(wf1)(a=a) + wf2() + lp = LaunchPlan.create("test", wf1) with pytest.raises(ValueError): @@ -167,6 +169,8 @@ def wf2(a: typing.List[int]): def wf3(a: typing.List[int]): return map_task(lp)(a=a) + wf3() + def test_inputs_outputs_length(): @task diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 47c8af9830..b8d32e4c8d 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -1,6 +1,7 @@ import datetime import typing from collections import OrderedDict +from dataclasses import dataclass import pytest @@ -95,6 +96,8 @@ def empty_wf2(): def empty_wf2(): create_node(t2, "foo") + empty_wf2() + def test_more_normal_task(): nt = typing.NamedTuple("OneOutput", t1_str_output=str) @@ -102,14 +105,12 @@ def test_more_normal_task(): @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt( - f"{a + 2}", - ) + return nt(f"{a + 2}") # type: ignore @task def t1_nt(a: int) -> nt: # This one returns an instance of the named tuple. 
- return nt(f"{a + 2}") + return nt(f"{a + 2}") # type: ignore @task def t2(a: typing.List[str]) -> str: @@ -132,9 +133,7 @@ def test_reserved_keyword(): @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt( - f"{a + 2}", - ) + return nt(f"{a + 2}") # type: ignore # Test that you can't name an output "outputs" with pytest.raises(FlyteAssertion): @@ -144,6 +143,8 @@ def my_wf(a: int) -> str: t1_node = create_node(t1, a=a) return t1_node.outputs + my_wf() + def test_runs_before(): @task @@ -333,6 +334,8 @@ def t1(a: str) -> str: def my_wf(a: str) -> str: return t1(a=a).with_overrides(timeout="foo") + my_wf() + @pytest.mark.parametrize( "retries,expected", @@ -424,3 +427,27 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.name == "foo" + + +def test_config_override(): + @dataclass + class DummyConfig: + name: str + + @task(task_config=DummyConfig(name="hello")) + def t1(a: str) -> str: + return f"*~*~*~{a}*~*~*~" + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(task_config=DummyConfig("flyte")) + + assert my_wf.nodes[0].flyte_entity.task_config.name == "flyte" + + with pytest.raises(ValueError): + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(task_config=None) + + my_wf() diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 23b3de4573..d8b043116e 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -4,7 +4,7 @@ import pytest from dataclasses_json import dataclass_json -from flytekit import task +from flytekit import LaunchPlan, task, workflow from flytekit.core import context_manager from flytekit.core.context_manager import CompilationState from flytekit.core.promise import ( @@ -64,6 +64,32 @@ def t2(a: int) -> int: assert len(p.ref.node.bindings) == 1 +def test_create_and_link_node_from_remote_ignore(): + @workflow + def wf(i: int, j: int): + ... + + lp = LaunchPlan.get_or_create(wf, name="promise-test", fixed_inputs={"i": 1}, default_inputs={"j": 10}) + ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix="")) + + # without providing the _inputs_not_allowed or _ignorable_inputs, all inputs to lp become required, + # which is incorrect + with pytest.raises(FlyteAssertion, match="Missing input `i` type `simple: INTEGER"): + create_and_link_node_from_remote(ctx, lp) + + # Even if j is not provided it will default + create_and_link_node_from_remote(ctx, lp, _inputs_not_allowed={"i"}, _ignorable_inputs={"j"}) + + # value of `i` cannot be overriden + with pytest.raises( + FlyteAssertion, match="ixed inputs cannot be specified. 
Please remove the following inputs - {'i'}" + ): + create_and_link_node_from_remote(ctx, lp, _inputs_not_allowed={"i"}, _ignorable_inputs={"j"}, i=15) + + # It is ok to override `j` which is a default input + create_and_link_node_from_remote(ctx, lp, _inputs_not_allowed={"i"}, _ignorable_inputs={"j"}, j=15) + + @pytest.mark.parametrize( "input", [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]], diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 108b42ebf7..24856432b4 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -1,9 +1,14 @@ from typing import Any import pytest +from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1Volume from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.base_task import TaskMetadata +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image +from flytekit.core.resources import Resources +from flytekit.tools.translator import get_serializable_task @pytest.fixture @@ -36,7 +41,6 @@ def execute(self, **kwargs) -> Any: task = DummyAutoContainerTask(name="x", task_config=None, task_type="t") -task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t") def test_default_command(default_serialization_settings): @@ -68,14 +72,227 @@ def test_get_container(default_serialization_settings): assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"FOO": "bar"} + ts = get_serializable_task(default_serialization_settings, task) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"FOO": "bar"} + + +task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t") + def test_get_container_with_task_envvars(default_serialization_settings): c = task_with_env_vars.get_container(default_serialization_settings) assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"FOO": "bar", "HAM": "spam"} + ts = get_serializable_task(default_serialization_settings, task_with_env_vars) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"FOO": "bar", "HAM": "spam"} + def test_get_container_without_serialization_settings_envvars(minimal_serialization_settings): c = task_with_env_vars.get_container(minimal_serialization_settings) assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"HAM": "spam"} + + ts = get_serializable_task(minimal_serialization_settings, task_with_env_vars) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"HAM": "spam"} + + +task_with_pod_template = DummyAutoContainerTask( + name="x", + metadata=TaskMetadata( + pod_template_name="podTemplateB", # should be overwritten + retries=3, # ensure other fields still exists + ), + task_config=None, + task_type="t", + container_image="repo/image:0.0.0", + requests=Resources(cpu="3", gpu="1"), + limits=Resources(cpu="6", gpu="2"), + environment={"eKeyA": "eValA", "eKeyB": "vKeyB"}, + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA", "lKeyB": "lValB"}, + annotations={"aKeyA": "aValA", "aKeyB": "aValB"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + 
name="notPrimary", + ), + V1Container( + name="primary", + image="repo/placeholderImage:0.0.0", + command="placeholderCommand", + args="placeholderArgs", + resources=V1ResourceRequirements(limits={"cpu": "999", "gpu": "999"}), + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + ], + volumes=[V1Volume(name="volume")], + ), + ), + pod_template_name="podTemplateA", +) + + +def test_pod_template(default_serialization_settings): + ################# + # Test get_k8s_pod + ################# + + container = task_with_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = task_with_pod_template.get_k8s_pod(default_serialization_settings) + + # labels/annotations should be passed + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA", "lKeyB": "lValB"} + assert metadata.annotations == {"aKeyA": "aValA", "aKeyB": "aValB"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][1] + + # To test overwritten attributes + + # image + assert primary_container["image"] == "repo/image:0.0.0" + # command + assert primary_container["command"] == [] + # args + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_auto_container", + "task-name", + "task_with_pod_template", + ] + # resource + assert primary_container["resources"]["requests"] == {"cpu": "3", "gpu": "1"} + assert primary_container["resources"]["limits"] == {"cpu": "6", "gpu": "2"} + + # To test union attributes + assert primary_container["env"] == [ + {"name": "FOO", "value": "bar"}, + {"name": "eKeyA", "value": "eValA"}, + {"name": "eKeyB", "value": "vKeyB"}, + {"name": "eKeyC", "value": "eValC"}, + {"name": "eKeyD", "value": "eValD"}, + ] + + # To test not overwritten attributes + assert pod_spec["volumes"][0] == {"name": "volume"} + + ################# + # Test pod_template_name + ################# + assert task_with_pod_template.metadata.pod_template_name == "podTemplateA" + assert task_with_pod_template.metadata.retries == 3 + + config = task_with_minimum_pod_template.get_config(default_serialization_settings) + + ################# + # Test config + ################# + assert config == {"primary_container_name": "primary"} + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, task_with_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + + assert ts.template.metadata.pod_template_name == "podTemplateA" + assert ts.template.metadata.retries.retries == 3 + assert ts.template.config is not None + + +task_with_minimum_pod_template = DummyAutoContainerTask( + name="x", + task_config=None, + task_type="t", + container_image="repo/image:0.0.0", + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA"}, + annotations={"aKeyA": "aValA"}, + ), + pod_template_name="A", +) + + +def test_minimum_pod_template(default_serialization_settings): + + ################# + # Test get_k8s_pod + ################# + + container = 
task_with_minimum_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = task_with_minimum_pod_template.get_k8s_pod(default_serialization_settings) + + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA"} + assert metadata.annotations == {"aKeyA": "aValA"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][0] + + assert primary_container["image"] == "repo/image:0.0.0" + assert primary_container["command"] == [] + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_auto_container", + "task-name", + "task_with_minimum_pod_template", + ] + + config = task_with_minimum_pod_template.get_config(default_serialization_settings) + assert config == {"primary_container_name": "primary"} + + ################# + # Test pod_template_name + ################# + assert task_with_minimum_pod_template.metadata.pod_template_name == "A" + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, task_with_minimum_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + assert ts.template.metadata.pod_template_name == "A" + assert ts.template.config is not None diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 34aaefaeb3..7bbdd23a21 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -1,10 +1,13 @@ import pytest +from kubernetes.client.models import V1Container, V1PodSpec from flytekit import task from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.tracker import isnested, istestfunction +from flytekit.tools.translator import get_serializable_task from tests.flytekit.unit.core import tasks @@ -122,3 +125,83 @@ def foo_missing_cache_version(i: str): @task(cache_serialize=True) def foo_missing_cache(i: str): print(f"{i}") + + +def test_pod_template(): + @task( + container_image="repo/image:0.0.0", + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA"}, + annotations={"aKeyA": "aValA"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="primary", + ), + ] + ), + ), + pod_template_name="A", + ) + def func_with_pod_template(i: str): + print(i + "a") + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + ################# + # Test get_k8s_pod + ################# + + container = func_with_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = 
func_with_pod_template.get_k8s_pod(default_serialization_settings) + + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA"} + assert metadata.annotations == {"aKeyA": "aValA"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][0] + + assert primary_container["image"] == "repo/image:0.0.0" + assert primary_container["command"] == [] + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_function_task", + "task-name", + "func_with_pod_template", + ] + + ################# + # Test pod_template_name + ################# + assert func_with_pod_template.metadata.pod_template_name == "A" + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, func_with_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + assert ts.template.metadata.pod_template_name == "A" diff --git a/tests/flytekit/unit/core/test_realworld_examples.py b/tests/flytekit/unit/core/test_realworld_examples.py index 83e859c1da..779ba3334c 100644 --- a/tests/flytekit/unit/core/test_realworld_examples.py +++ b/tests/flytekit/unit/core/test_realworld_examples.py @@ -126,7 +126,7 @@ def fit(x: FlyteSchema[FEATURE_COLUMNS], y: FlyteSchema[CLASSES_COLUMNS], hyperp fname = "model.joblib.dat" with open(fname, "w") as f: f.write("Some binary data") - return nt(model=fname) + return nt(model=fname) # type: ignore @task(cache_version="1.0", cache=True, limits=Resources(mem="200Mi")) def predict(x: FlyteSchema[FEATURE_COLUMNS], model_ser: FlyteFile[MODELSER_JOBLIB]) -> FlyteSchema[CLASSES_COLUMNS]: diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index df6e093b55..7486422fd9 100644 --- a/tests/flytekit/unit/core/test_references.py +++ b/tests/flytekit/unit/core/test_references.py @@ -160,7 +160,7 @@ def inner_test(ref_mock): @task def t1(a: int) -> nt1: a = a + 2 - return nt1(a, "world-" + str(a)) + return nt1(a, "world-" + str(a)) # type: ignore @workflow def wf2(a: int): diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py new file mode 100644 index 0000000000..1a3bf64dee --- /dev/null +++ b/tests/flytekit/unit/core/test_resources.py @@ -0,0 +1,68 @@ +from typing import Dict + +import pytest + +import flytekit.models.task as _task_models +from flytekit import Resources +from flytekit.core.resources import convert_resources_to_resource_model + +_ResourceName = _task_models.Resources.ResourceName + + +def test_convert_no_requests_no_limits(): + resource_model = convert_resources_to_resource_model(requests=None, limits=None) + assert isinstance(resource_model, _task_models.Resources) + assert resource_model.requests == [] + assert resource_model.limits == [] + + +@pytest.mark.parametrize( + argnames=("resource_dict", "expected_resource_name"), + argvalues=( + ({"cpu": "2"}, _ResourceName.CPU), + ({"mem": "1Gi"}, _ResourceName.MEMORY), + ({"gpu": "1"}, _ResourceName.GPU), + ({"storage": "100Mb"}, _ResourceName.STORAGE), + 
({"ephemeral_storage": "123Mb"}, _ResourceName.EPHEMERAL_STORAGE), + ), + ids=("CPU", "MEMORY", "GPU", "STORAGE", "EPHEMERAL_STORAGE"), +) +def test_convert_requests(resource_dict: Dict[str, str], expected_resource_name: _task_models.Resources): + assert len(resource_dict) == 1 + expected_resource_value = list(resource_dict.values())[0] + + requests = Resources(**resource_dict) + resources_model = convert_resources_to_resource_model(requests=requests) + + assert len(resources_model.requests) == 1 + request = resources_model.requests[0] + assert isinstance(request, _task_models.Resources.ResourceEntry) + assert request.name == expected_resource_name + assert request.value == expected_resource_value + assert len(resources_model.limits) == 0 + + +@pytest.mark.parametrize( + argnames=("resource_dict", "expected_resource_name"), + argvalues=( + ({"cpu": "2"}, _ResourceName.CPU), + ({"mem": "1Gi"}, _ResourceName.MEMORY), + ({"gpu": "1"}, _ResourceName.GPU), + ({"storage": "100Mb"}, _ResourceName.STORAGE), + ({"ephemeral_storage": "123Mb"}, _ResourceName.EPHEMERAL_STORAGE), + ), + ids=("CPU", "MEMORY", "GPU", "STORAGE", "EPHEMERAL_STORAGE"), +) +def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _task_models.Resources): + assert len(resource_dict) == 1 + expected_resource_value = list(resource_dict.values())[0] + + requests = Resources(**resource_dict) + resources_model = convert_resources_to_resource_model(limits=requests) + + assert len(resources_model.limits) == 1 + limit = resources_model.limits[0] + assert isinstance(limit, _task_models.Resources.ResourceEntry) + assert limit.name == expected_resource_name + assert limit.value == expected_resource_value + assert len(resources_model.requests) == 0 diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index a96a94843b..d47d57969c 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -406,17 +406,17 @@ def wf() -> typing.NamedTuple("OP", a=str): def test_named_outputs_nested(): - nm = typing.NamedTuple("OP", greet=str) + nm = typing.NamedTuple("OP", [("greet", str)]) @task def say_hello() -> nm: return nm("hello world") - wf_outputs = typing.NamedTuple("OP2", greet1=str, greet2=str) + wf_outputs = typing.NamedTuple("OP2", [("greet1", str), ("greet2", str)]) @workflow def my_wf() -> wf_outputs: - # Note only Namedtuples can be created like this + # Note only Namedtuple can be created like this return wf_outputs(say_hello().greet, say_hello().greet) x, y = my_wf() @@ -425,21 +425,23 @@ def my_wf() -> wf_outputs: def test_named_outputs_nested_fail(): - nm = typing.NamedTuple("OP", greet=str) + nm = typing.NamedTuple("OP", [("greet", str)]) @task def say_hello() -> nm: return nm("hello world") - wf_outputs = typing.NamedTuple("OP2", greet1=str, greet2=str) + wf_outputs = typing.NamedTuple("OP2", [("greet1", str), ("greet2", str)]) with pytest.raises(AssertionError): # this should fail because say_hello returns a tuple, but we do not de-reference it @workflow def my_wf() -> wf_outputs: - # Note only Namedtuples can be created like this + # Note only Namedtuple can be created like this return wf_outputs(say_hello(), say_hello()) + my_wf() + def test_serialized_docstrings(): @task diff --git a/tests/flytekit/unit/core/test_signal.py b/tests/flytekit/unit/core/test_signal.py new file mode 100644 index 0000000000..a3bee2e4c7 --- /dev/null +++ b/tests/flytekit/unit/core/test_signal.py @@ -0,0 +1,48 @@ +import 
pytest +from flyteidl.admin.signal_pb2 import Signal, SignalList +from mock import MagicMock + +from flytekit.configuration import Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.core.identifier import SignalIdentifier, WorkflowExecutionIdentifier +from flytekit.remote.remote import FlyteRemote + + +@pytest.fixture +def remote(): + flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote._client_initialized = True + return flyte_remote + + +def test_remote_list_signals(remote): + ctx = FlyteContextManager.current_context() + wfeid = WorkflowExecutionIdentifier("p", "d", "execid") + signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl() + lt = TypeEngine.to_literal_type(int) + signal = Signal( + id=signal_id, + type=lt.to_flyte_idl(), + value=TypeEngine.to_literal(ctx, 3, int, lt).to_flyte_idl(), + ) + + mock_client = MagicMock() + mock_client.list_signals.return_value = SignalList(signals=[signal], token="") + + remote._client = mock_client + res = remote.list_signals("execid", "p", "d", limit=10) + assert len(res) == 1 + + +def test_remote_set_signal(remote): + mock_client = MagicMock() + + def checker(request): + assert request.id.signal_id == "sigid" + assert request.value.scalar.primitive.integer == 3 + + mock_client.set_signal.side_effect = checker + + remote._client = mock_client + remote.set_signal("sigid", "execid", 3) diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index bfb41d0fef..a20a175dd3 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -1,9 +1,11 @@ +import os import tempfile import typing import pandas as pd import pyarrow as pa import pytest +from fsspec.utils import get_protocol from typing_extensions import Annotated import flytekit.configuration @@ -25,7 +27,6 @@ StructuredDatasetTransformerEngine, convert_schema_type_to_structured_dataset_type, extract_cols_and_format, - protocol_prefix, ) my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str) @@ -44,8 +45,8 @@ def test_protocol(): - assert protocol_prefix("s3://my-s3-bucket/file") == "s3" - assert protocol_prefix("/file") == "file" + assert get_protocol("s3://my-s3-bucket/file") == "s3" + assert get_protocol("/file") == "file" def generate_pandas() -> pd.DataFrame: @@ -74,7 +75,6 @@ def t1(a: pd.DataFrame) -> pd.DataFrame: def test_setting_of_unset_formats(): - custom = Annotated[StructuredDataset, "parquet"] example = custom(dataframe=df, uri="/path") # It's okay that the annotation is not used here yet. @@ -89,7 +89,9 @@ def t2(path: str) -> StructuredDataset: def wf(path: str) -> StructuredDataset: return t2(path=path) - res = wf(path="/tmp/somewhere") + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "somewhere") + res = wf(path=fname) # Now that it's passed through an encoder however, it should be set. 
assert res.file_format == "parquet" diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index bbe46845fd..842ae7a98c 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -45,7 +45,7 @@ from flytekit.models.annotation import TypeAnnotation from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Void -from flytekit.models.types import LiteralType, SimpleType, TypeStructure +from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.types.directory import TensorboardLogs from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import FileExt, JPEGImageFile @@ -749,18 +749,19 @@ class TestFileStruct(object): def test_structured_dataset_in_dataclass(): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + People = Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)] @dataclass_json @dataclass class InnerDatasetStruct(object): a: StructuredDataset - b: typing.List[StructuredDataset] - c: typing.Dict[str, StructuredDataset] + b: typing.List[Annotated[StructuredDataset, "parquet"]] + c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]] @dataclass_json @dataclass class DatasetStruct(object): - a: StructuredDataset + a: People b: InnerDatasetStruct sd = StructuredDataset(dataframe=df, file_format="parquet") @@ -941,6 +942,18 @@ def test_union_transformer(): assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int +def test_union_guess_type(): + ut = UnionTransformer() + t = ut.guess_python_type( + LiteralType( + union_type=UnionType( + variants=[LiteralType(simple=SimpleType.STRING), LiteralType(simple=SimpleType.INTEGER)] + ) + ) + ) + assert t == typing.Union[str, int] + + def test_union_type_with_annotated(): pt = typing.Union[ Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})] @@ -1459,21 +1472,21 @@ def test_multiple_annotations(): TypeEngine.to_literal_type(t) -TestSchema = FlyteSchema[kwtypes(some_str=str)] +TestSchema = FlyteSchema[kwtypes(some_str=str)] # type: ignore @dataclass_json @dataclass class InnerResult: number: int - schema: TestSchema + schema: TestSchema # type: ignore @dataclass_json @dataclass class Result: result: InnerResult - schema: TestSchema + schema: TestSchema # type: ignore def test_schema_in_dataclass(): diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index b6d2d77ae5..532c500d26 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -176,21 +176,19 @@ def my_wf(a: int, b: str) -> (int, str): d = t2(a=y, b=b) return x, d - assert len(my_wf._nodes) == 2 + assert len(my_wf.nodes) == 2 assert my_wf._nodes[0].id == "n0" assert my_wf._nodes[1]._upstream_nodes[0] is my_wf._nodes[0] - assert len(my_wf._output_bindings) == 2 + assert len(my_wf.output_bindings) == 2 assert my_wf._output_bindings[0].var == "o0" assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output" - nt = typing.NamedTuple("SingleNT", t1_int_output=float) + nt = typing.NamedTuple("SingleNT", [("t1_int_output", float)]) @task def t3(a: int) -> nt: - return nt( - a + 2, - ) + return nt(a + 2) assert t3.python_interface.output_tuple_name == "SingleNT" assert 
t3.interface.outputs["t1_int_output"] is not None @@ -282,18 +280,24 @@ def test_wf_output_mismatch(): def my_wf(a: int, b: str) -> (int, str): return a + my_wf() + with pytest.raises(AssertionError): @workflow def my_wf2(a: int, b: str) -> int: return a, b # type: ignore + my_wf2() + with pytest.raises(AssertionError): @workflow def my_wf3(a: int, b: str) -> int: return (a,) # type: ignore + my_wf3() + assert context_manager.FlyteContextManager.size() == 1 @@ -486,11 +490,16 @@ def t1(path: str) -> DatasetStruct: def wf(path: str) -> DatasetStruct: return t1(path=path) - res = wf(path="/tmp/somewhere") - assert "parquet" == res.a.file_format - assert "parquet" == res.b.a.file_format - assert_frame_equal(df, res.a.open(pd.DataFrame).all()) - assert_frame_equal(df, res.b.a.open(pd.DataFrame).all()) + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "df_file") + res = wf(path=fname) + assert "parquet" == res.a.file_format + assert "parquet" == res.b.a.file_format + print("--") + print(f"A: {res.a.open(pd.DataFrame).all()}") + print("==============") + assert_frame_equal(df, res.a.open(pd.DataFrame).all()) + assert_frame_equal(df, res.b.a.open(pd.DataFrame).all()) def test_wf1_with_map(): @@ -678,7 +687,7 @@ def lister() -> typing.List[str]: return s assert len(lister.interface.outputs) == 1 - binding_data = lister._output_bindings[0].binding # the property should be named binding_data + binding_data = lister.output_bindings[0].binding # the property should be named binding_data assert binding_data.collection is not None assert len(binding_data.collection.bindings) == 10 @@ -802,6 +811,8 @@ def my_wf(a: int, b: str) -> (int, str): conditional("test2").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then(t2(a=y)).else_().fail("blah") return x, d + my_wf() + assert context_manager.FlyteContextManager.size() == 1 @@ -882,7 +893,7 @@ def t2(a: str, b: str) -> str: return b + a @workflow - def my_subwf(a: int) -> (str, str): + def my_subwf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) u, v = t1(a=x) return y, v @@ -1406,7 +1417,7 @@ def t2(a: str, b: str) -> str: return b + a @workflow - def my_wf(a: int, b: str) -> (str, typing.List[str]): + def my_wf(a: int, b: str) -> typing.Tuple[str, typing.List[str]]: @dynamic def my_subwf(a: int) -> typing.List[str]: s = [] @@ -1446,7 +1457,7 @@ def t1() -> str: return "Hello" @workflow - def wf() -> typing.NamedTuple("OP", a=str, b=str): + def wf() -> typing.NamedTuple("OP", [("a", str), ("b", str)]): # type: ignore return t1(), t1() assert wf() == ("Hello", "Hello") diff --git a/tests/flytekit/unit/core/test_typing_annotation.py b/tests/flytekit/unit/core/test_typing_annotation.py index 9c2d09c145..2937d9f978 100644 --- a/tests/flytekit/unit/core/test_typing_annotation.py +++ b/tests/flytekit/unit/core/test_typing_annotation.py @@ -18,7 +18,7 @@ env=None, image_config=ImageConfig(default_image=default_img, images=[default_img]), ) -entity_mapping = OrderedDict() +entity_mapping: OrderedDict = OrderedDict() @task diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 46389daed2..90a8c712e6 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -7,8 +7,9 @@ from typing_extensions import Annotated # type: ignore import flytekit.configuration -from flytekit import StructuredDataset, kwtypes +from flytekit import FlyteContextManager, StructuredDataset, kwtypes from flytekit.configuration import Image, ImageConfig +from 
flytekit.core import context_manager from flytekit.core.condition import conditional from flytekit.core.task import task from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow @@ -44,12 +45,12 @@ def test_default_metadata_values(): def test_workflow_values(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]): a = a + 2 return a, "world-" + str(a) @workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) - def wf(a: int) -> (str, str): + def wf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) u, v = t1(a=x) return y, v @@ -94,7 +95,7 @@ def list_output_wf() -> typing.List[int]: def test_sub_wf_single_named_tuple(): - nt = typing.NamedTuple("SingleNamedOutput", named1=int) + nt = typing.NamedTuple("SingleNamedOutput", [("named1", int)]) @task def t1(a: int) -> nt: @@ -115,7 +116,7 @@ def wf(b: int) -> nt: def test_sub_wf_multi_named_tuple(): - nt = typing.NamedTuple("Multi", named1=int, named2=int) + nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)]) @task def t1(a: int) -> nt: @@ -153,9 +154,11 @@ def no_outputs_wf(): with pytest.raises(AssertionError): @workflow - def one_output_wf() -> int: # noqa + def one_output_wf() -> int: # type: ignore t1(a=3) + one_output_wf() + def test_wf_no_output(): @task @@ -309,10 +312,10 @@ def sd_to_schema_wf() -> pd.DataFrame: @workflow -def schema_to_sd_wf() -> (pd.DataFrame, pd.DataFrame): +def schema_to_sd_wf() -> typing.Tuple[pd.DataFrame, pd.DataFrame]: # schema -> StructuredDataset df = t4() - return t2(df=df), t5(sd=df) + return t2(df=df), t5(sd=df) # type: ignore def test_structured_dataset_wf(): @@ -320,3 +323,18 @@ def test_structured_dataset_wf(): assert_frame_equal(sd_to_schema_wf(), superset_df) assert_frame_equal(schema_to_sd_wf()[0], subset_df) assert_frame_equal(schema_to_sd_wf()[1], subset_df) + + +def test_compile_wf_at_compile_time(): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.new_execution_state().with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ): + + @workflow + def wf(): + t4() + + assert ctx.compilation_state is None diff --git a/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py b/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py deleted file mode 100644 index d2c50cc4a9..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py +++ /dev/null @@ -1,35 +0,0 @@ -import mock - -from flytekit import GCSPersistence - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_put(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.put("/test", "gs://my-bucket/k1") - mock_exec.assert_called_with(["gsutil", "cp", "/test", "gs://my-bucket/k1"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_put_recursive(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.put("/test", "gs://my-bucket/k1", True) - mock_exec.assert_called_with(["gsutil", "cp", "-r", "/test/*", "gs://my-bucket/k1/"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") 
-@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_get(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.get("gs://my-bucket/k1", "/test") - mock_exec.assert_called_with(["gsutil", "cp", "gs://my-bucket/k1", "/test"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_get_recursive(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.get("gs://my-bucket/k1", "/test", True) - mock_exec.assert_called_with(["gsutil", "cp", "-r", "gs://my-bucket/k1/*", "/test"]) diff --git a/tests/flytekit/unit/extras/persistence/test_http.py b/tests/flytekit/unit/extras/persistence/test_http.py deleted file mode 100644 index 893b43f364..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_http.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest - -from flytekit import HttpPersistence - - -def test_put(): - proxy = HttpPersistence() - with pytest.raises(AssertionError): - proxy.put("", "", recursive=True) - - -def test_construct_path(): - proxy = HttpPersistence() - with pytest.raises(AssertionError): - proxy.construct_path(True, False, "", "") - - -def test_exists(): - proxy = HttpPersistence() - assert proxy.exists("https://flyte.org") diff --git a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py b/tests/flytekit/unit/extras/persistence/test_s3_awscli.py deleted file mode 100644 index a6f29f36d6..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py +++ /dev/null @@ -1,80 +0,0 @@ -from datetime import timedelta - -import mock - -from flytekit import S3Persistence -from flytekit.configuration import DataConfig, S3Config -from flytekit.extras.persistence import s3_awscli - - -def test_property(): - aws = S3Persistence("s3://raw-output") - assert aws.default_prefix == "s3://raw-output" - - -def test_construct_path(): - aws = S3Persistence() - p = aws.construct_path(True, False, "xyz") - assert p == "s3://xyz" - - -@mock.patch("flytekit.extras.persistence.s3_awscli.S3Persistence._check_binary") -@mock.patch("flytekit.extras.persistence.s3_awscli.subprocess") -def test_retries(mock_subprocess, mock_check): - mock_subprocess.check_call.side_effect = Exception("test exception (404)") - mock_check.return_value = True - - proxy = S3Persistence(data_config=DataConfig(s3=S3Config(backoff=timedelta(seconds=0)))) - assert proxy.exists("s3://test/fdsa/fdsa") is False - assert mock_subprocess.check_call.call_count == 8 - - -def test_extra_args(): - assert s3_awscli._extra_args({}) == [] - assert s3_awscli._extra_args({"ContentType": "ct"}) == ["--content-type", "ct"] - assert s3_awscli._extra_args({"ContentEncoding": "ec"}) == ["--content-encoding", "ec"] - assert s3_awscli._extra_args({"ACL": "acl"}) == ["--acl", "acl"] - assert s3_awscli._extra_args({"ContentType": "ct", "ContentEncoding": "ec", "ACL": "acl"}) == [ - "--content-type", - "ct", - "--content-encoding", - "ec", - "--acl", - "acl", - ] - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_put(mock_exec): - proxy = S3Persistence() - proxy.put("/test", "s3://my-bucket/k1") - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], - s3_cfg=S3Config.auto(), - ) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_put_recursive(mock_exec): - proxy = S3Persistence() - proxy.put("/test", 
"s3://my-bucket/k1", True) - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--recursive", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], - s3_cfg=S3Config.auto(), - ) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_get(mock_exec): - proxy = S3Persistence() - proxy.get("s3://my-bucket/k1", "/test") - mock_exec.assert_called_with(cmd=["aws", "s3", "cp", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto()) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_get_recursive(mock_exec): - proxy = S3Persistence() - proxy.get("s3://my-bucket/k1", "/test", True) - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--recursive", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto() - ) diff --git a/tests/flytekit/unit/extras/pytorch/test_transformations.py b/tests/flytekit/unit/extras/pytorch/test_transformations.py index 9724a01182..a470b646d4 100644 --- a/tests/flytekit/unit/extras/pytorch/test_transformations.py +++ b/tests/flytekit/unit/extras/pytorch/test_transformations.py @@ -40,6 +40,7 @@ def test_get_literal_type(transformer, python_type, format): tf = transformer lt = tf.get_literal_type(python_type) assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)) + assert tf.guess_python_type(lt) == python_type @pytest.mark.parametrize( diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index ef7ea491e6..40fc94a3d2 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -119,14 +119,14 @@ def test_task_serialization(): select * from tracks limit {{.inputs.limit}}""", - " select * from tracks limit {{.inputs.limit}}", + "select * from tracks limit {{.inputs.limit}}", ), ( """ \ select * \ from tracks \ limit {{.inputs.limit}}""", - " select * from tracks limit {{.inputs.limit}}", + "select * from tracks limit {{.inputs.limit}}", ), ("select * from abc", "select * from abc"), ], diff --git a/tests/flytekit/unit/models/test_documentation.py b/tests/flytekit/unit/models/test_documentation.py new file mode 100644 index 0000000000..7702df0452 --- /dev/null +++ b/tests/flytekit/unit/models/test_documentation.py @@ -0,0 +1,29 @@ +from flytekit.models.documentation import Description, Documentation, SourceCode + + +def test_long_description(): + value = "long" + icon_link = "http://icon" + obj = Description(value=value, icon_link=icon_link) + assert Description.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.value == value + assert obj.icon_link == icon_link + assert obj.format == Description.DescriptionFormat.RST + + +def test_source_code(): + link = "https://github.com/flyteorg/flytekit" + obj = SourceCode(link=link) + assert SourceCode.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.link == link + + +def test_documentation(): + short_description = "short" + long_description = Description(value="long", icon_link="http://icon") + source_code = SourceCode(link="https://github.com/flyteorg/flytekit") + obj = Documentation(short_description=short_description, long_description=long_description, source_code=source_code) + assert Documentation.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.short_description == short_description + assert obj.long_description == long_description + assert obj.source_code == source_code diff --git a/tests/flytekit/unit/models/test_tasks.py 
b/tests/flytekit/unit/models/test_tasks.py index fcebf465f9..a979a39b66 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -7,6 +7,7 @@ import flytekit.models.interface as interface_models import flytekit.models.literals as literal_models +from flytekit import Description, Documentation, SourceCode from flytekit.models import literals, task, types from flytekit.models.core import identifier from tests.flytekit.common import parameterizers @@ -70,6 +71,7 @@ def test_task_metadata(): "0.1.1b0", "This is deprecated!", True, + "A", ) assert obj.discoverable is True @@ -81,6 +83,7 @@ def test_task_metadata(): assert obj.runtime.version == "1.0.0" assert obj.deprecated_error_message == "This is deprecated!" assert obj.discovery_version == "0.1.1b0" + assert obj.pod_template_name == "A" assert obj == task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl()) @@ -123,7 +126,62 @@ def test_task_template(in_tuple): assert obj.config == {"a": "b"} -def test_task_template__k8s_pod_target(): +def test_task_spec(): + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + {"a": interface_models.Variable(int_type, "description1")}, + { + "b": interface_models.Variable(int_type, "description2"), + "c": interface_models.Variable(int_type, "description3"), + }, + ) + + resource = [task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1")] + resources = task.Resources(resource, resource) + + template = task.TaskTemplate( + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), + "python", + task_metadata, + interfaces, + {"a": 1, "b": {"c": 2, "d": 3}}, + container=task.Container( + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, + ), + config={"a": "b"}, + ) + + short_description = "short" + long_description = Description(value="long", icon_link="http://icon") + source_code = SourceCode(link="https://github.com/flyteorg/flytekit") + docs = Documentation( + short_description=short_description, long_description=long_description, source_code=source_code + ) + + obj = task.TaskSpec(template, docs) + assert task.TaskSpec.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.docs == docs + assert obj.template == template + + +def test_task_template_k8s_pod_target(): int_type = types.LiteralType(types.SimpleType.INTEGER) obj = task.TaskTemplate( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), @@ -137,6 +195,7 @@ def test_task_template__k8s_pod_target(): "1.0", "deprecated", False, + "A", ), interface_models.TypedInterface( # inputs diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index 3a42f5af81..2b5b06696b 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -5,8 +5,10 @@ from flytekit.models import task as _task from flytekit.models import types as _types from flytekit.models import workflow_closure as _workflow_closure +from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier from flytekit.models.core import 
workflow as _workflow
+from flytekit.models.documentation import Description, Documentation, SourceCode
 
 
 def test_workflow_closure():
@@ -36,6 +38,7 @@ def test_workflow_closure():
         "0.1.1b0",
         "This is deprecated!",
         True,
+        "A",
     )
 
     cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
@@ -81,3 +84,16 @@ def test_workflow_closure():
     obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
 
     assert obj == obj2
+
+    short_description = "short"
+    long_description = Description(value="long", icon_link="http://icon")
+    source_code = SourceCode(link="https://github.com/flyteorg/flytekit")
+    docs = Documentation(
+        short_description=short_description, long_description=long_description, source_code=source_code
+    )
+
+    workflow_spec = WorkflowSpec(template=template, sub_workflows=[], docs=docs)
+    assert WorkflowSpec.from_flyte_idl(workflow_spec.to_flyte_idl()) == workflow_spec
+    assert workflow_spec.docs.short_description == short_description
+    assert workflow_spec.docs.long_description == long_description
+    assert workflow_spec.docs.source_code == source_code
diff --git a/tests/flytekit/unit/remote/test_backfill.py b/tests/flytekit/unit/remote/test_backfill.py
new file mode 100644
index 0000000000..1d4884115d
--- /dev/null
+++ b/tests/flytekit/unit/remote/test_backfill.py
@@ -0,0 +1,95 @@
+from datetime import datetime, timedelta
+
+import pytest
+
+from flytekit import CronSchedule, FixedRate, LaunchPlan, task, workflow
+from flytekit.remote.backfill import create_backfill_workflow
+
+
+@task
+def tk(t: datetime, v: int):
+    print(f"Invoked at {t} with v {v}")
+
+
+@workflow
+def example_wf(t: datetime, v: int):
+    tk(t=t, v=v)
+
+
+def test_create_backfiller_error():
+    no_schedule = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="nos",
+        fixed_inputs={"v": 10},
+    )
+    rate_schedule = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="rate",
+        fixed_inputs={"v": 10},
+        schedule=FixedRate(duration=timedelta(days=1)),
+    )
+    start_date = datetime(2022, 12, 1, 8)
+    end_date = start_date + timedelta(days=10)
+
+    with pytest.raises(ValueError):
+        create_backfill_workflow(start_date, end_date, no_schedule)
+
+    with pytest.raises(ValueError):
+        create_backfill_workflow(end_date, start_date, no_schedule)
+
+    with pytest.raises(ValueError):
+        create_backfill_workflow(end_date, start_date, None)
+
+    with pytest.raises(NotImplementedError):
+        create_backfill_workflow(start_date, end_date, rate_schedule)
+
+
+def test_create_backfiller():
+    daily_lp = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="daily",
+        fixed_inputs={"v": 10},
+        schedule=CronSchedule(schedule="0 8 * * *", kickoff_time_input_arg="t"),
+    )
+
+    start_date = datetime(2022, 12, 1, 8)
+    end_date = start_date + timedelta(days=10)
+
+    wf, start, end = create_backfill_workflow(start_date, end_date, daily_lp)
+    assert isinstance(wf.nodes[0].flyte_entity, LaunchPlan)
+    b0, b1 = wf.nodes[0].bindings[0], wf.nodes[0].bindings[1]
+    assert b0.var == "t"
+    assert b0.binding.scalar.primitive.datetime.day == 2
+    assert b1.var == "v"
+    assert b1.binding.scalar.primitive.integer == 10
+    assert len(wf.nodes) == 9
+    assert len(wf.nodes[0].upstream_nodes) == 0
+    assert len(wf.nodes[1].upstream_nodes) == 1
+    assert wf.nodes[1].upstream_nodes[0] == wf.nodes[0]
+    assert start
+    assert end
+
+
+def test_create_backfiller_parallel():
+    daily_lp = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="daily",
+        fixed_inputs={"v": 10},
+        schedule=CronSchedule(schedule="0 8 * * *", kickoff_time_input_arg="t"),
+    )
+
+    start_date = datetime(2022, 12, 1, 8)
+    end_date = start_date + timedelta(days=10)
+
+    wf, start, end = create_backfill_workflow(start_date, end_date, daily_lp, parallel=True)
+    assert isinstance(wf.nodes[0].flyte_entity, LaunchPlan)
+    b0, b1 = wf.nodes[0].bindings[0], wf.nodes[0].bindings[1]
+    assert b0.var == "t"
+    assert b0.binding.scalar.primitive.datetime.day == 2
+    assert b1.var == "v"
+    assert b1.binding.scalar.primitive.integer == 10
+    assert len(wf.nodes) == 9
+    assert len(wf.nodes[0].upstream_nodes) == 0
+    assert len(wf.nodes[1].upstream_nodes) == 0
+    assert start
+    assert end
diff --git a/tests/flytekit/unit/remote/test_calling.py b/tests/flytekit/unit/remote/test_calling.py
index 34e4f8e8b8..289fba37d7 100644
--- a/tests/flytekit/unit/remote/test_calling.py
+++ b/tests/flytekit/unit/remote/test_calling.py
@@ -75,6 +75,8 @@ def test_misnamed():
         def wf(a: int) -> int:
             return ft(b=a)
 
+        wf()
+
 
 def test_calling_lp():
     sub_wf_lp = LaunchPlan.get_or_create(sub_wf)
diff --git a/tests/flytekit/unit/remote/test_lazy_entity.py b/tests/flytekit/unit/remote/test_lazy_entity.py
index 1ed191aea4..5328a2caf0 100644
--- a/tests/flytekit/unit/remote/test_lazy_entity.py
+++ b/tests/flytekit/unit/remote/test_lazy_entity.py
@@ -63,3 +63,16 @@ def _getter():
     e.compile(ctx)
     assert e._entity is not None
     assert e.entity == dummy_task
+
+
+def test_lazy_loading_exception():
+    def _getter():
+        raise AttributeError("Error")
+
+    e = LazyEntity("x", _getter)
+    assert e.name == "x"
+    assert e._entity is None
+    with pytest.raises(RuntimeError) as exc:
+        assert e.blah
+
+    assert isinstance(exc.value.__cause__, AttributeError)
diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py
index 01688ea825..4b8f82fb7e 100644
--- a/tests/flytekit/unit/remote/test_remote.py
+++ b/tests/flytekit/unit/remote/test_remote.py
@@ -1,13 +1,17 @@
 import os
 import pathlib
 import tempfile
+from collections import OrderedDict
+from datetime import datetime, timedelta
 
 import pytest
 from flyteidl.core import compiler_pb2 as _compiler_pb2
 from mock import MagicMock, patch
 
 import flytekit.configuration
+from flytekit import CronSchedule, LaunchPlan, task, workflow
 from flytekit.configuration import Config, DefaultImages, ImageConfig
+from flytekit.core.base_task import PythonTask
 from flytekit.exceptions import user as user_exceptions
 from flytekit.models import common as common_models
 from flytekit.models import security
@@ -18,7 +22,7 @@
 from flytekit.models.task import Task
 from flytekit.remote.lazy_entity import LazyEntity
 from flytekit.remote.remote import FlyteRemote
-from flytekit.tools.translator import Options
+from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan
 from tests.flytekit.common.parameterizers import LIST_OF_TASK_CLOSURES
 
 CLIENT_METHODS = {
@@ -40,29 +44,36 @@
 }
 
 
-@patch("flytekit.clients.friendly.SynchronousFlyteClient")
-def test_remote_fetch_execution(mock_client_manager):
+@pytest.fixture
+def remote():
+    with patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
+        flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
+        flyte_remote._client_initialized = True
+        flyte_remote._client = mock_client
+        return flyte_remote
+
+
+def test_remote_fetch_execution(remote):
     admin_workflow_execution = Execution(
         id=WorkflowExecutionIdentifier("p1", "d1", "n1"),
         spec=MagicMock(),
         closure=MagicMock(),
     )
 
-    mock_client = MagicMock()
     mock_client.get_execution.return_value = admin_workflow_execution
-
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     remote._client = mock_client
     flyte_workflow_execution = remote.fetch_execution(name="n1")
     assert flyte_workflow_execution.id == admin_workflow_execution.id
 
 
-@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
-def test_underscore_execute_uses_launch_plan_attributes(mock_wf_exec):
+@pytest.fixture
+def mock_wf_exec():
+    return patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
+
+
+def test_underscore_execute_uses_launch_plan_attributes(remote, mock_wf_exec):
     mock_wf_exec.return_value = True
     mock_client = MagicMock()
-
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     remote._client = mock_client
 
     def local_assertions(*args, **kwargs):
@@ -89,12 +100,9 @@ def local_assertions(*args, **kwargs):
     )
 
 
-@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
-def test_underscore_execute_fall_back_remote_attributes(mock_wf_exec):
+def test_underscore_execute_fall_back_remote_attributes(remote, mock_wf_exec):
     mock_wf_exec.return_value = True
     mock_client = MagicMock()
-
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     remote._client = mock_client
 
     options = Options(
@@ -120,14 +128,11 @@ def local_assertions(*args, **kwargs):
     )
 
 
-@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
-def test_execute_with_wrong_input_key(mock_wf_exec):
+def test_execute_with_wrong_input_key(remote, mock_wf_exec):
     # mock_url.get.return_value = "localhost"
     # mock_insecure.get.return_value = True
     mock_wf_exec.return_value = True
     mock_client = MagicMock()
-
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     remote._client = mock_client
 
     mock_entity = MagicMock()
@@ -158,7 +163,7 @@ def test_passing_of_kwargs(mock_client):
         "root_certificates": 5,
         "certificate_chain": 6,
     }
-    FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args)
+    FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args).client
     assert mock_client.called
     assert mock_client.call_args[1] == additional_args
 
@@ -231,7 +236,7 @@ def test_generate_console_http_domain_sandbox_rewrite(mock_client):
         remote = FlyteRemote(
             config=Config.auto(config_file=temp_filename), default_project="project", default_domain="domain"
         )
-        assert remote.generate_console_http_domain() == "http://localhost:30080"
+        assert remote.generate_console_http_domain() == "http://localhost:30081"
 
         with open(temp_filename, "w") as f:
             # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox.
@@ -269,8 +274,8 @@ def get_compiled_workflow_closure():
     return CompiledWorkflowClosure.from_flyte_idl(cwc_pb)
 
 
-@patch("flytekit.remote.remote.SynchronousFlyteClient")
-def test_fetch_lazy(mock_client):
+def test_fetch_lazy(remote):
+    mock_client = remote._client
     mock_client.get_task.return_value = Task(
         id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), closure=LIST_OF_TASK_CLOSURES[0]
     )
@@ -280,7 +285,6 @@ def test_fetch_lazy(mock_client):
         closure=WorkflowClosure(compiled_workflow=get_compiled_workflow_closure()),
     )
 
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     lw = remote.fetch_workflow_lazy(name="wn", version="v")
     assert isinstance(lw, LazyEntity)
    assert lw._getter
@@ -293,3 +297,52 @@ def test_fetch_lazy(mock_client):
     assert lt._entity is None
     tk = lt.entity
     assert tk.name == "n"
+
+
+@task
+def tk(t: datetime, v: int):
+    print(f"Invoked at {t} with v {v}")
+
+
+@workflow
+def example_wf(t: datetime, v: int):
+    tk(t=t, v=v)
+
+
+def test_launch_backfill(remote):
+    daily_lp = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="daily2",
+        fixed_inputs={"v": 10},
+        schedule=CronSchedule(schedule="0 8 * * *", kickoff_time_input_arg="t"),
+    )
+
+    serialization_settings = flytekit.configuration.SerializationSettings(
+        project="project",
+        domain="domain",
+        version="version",
+        env=None,
+        image_config=ImageConfig.auto(img_name=DefaultImages.default_image()),
+    )
+
+    start_date = datetime(2022, 12, 1, 8)
+    end_date = start_date + timedelta(days=10)
+
+    ser_lp = get_serializable_launch_plan(OrderedDict(), serialization_settings, daily_lp, recurse_downstream=False)
+    m = OrderedDict()
+    ser_wf = get_serializable(m, serialization_settings, example_wf)
+    tasks = []
+    for k, v in m.items():
+        if isinstance(k, PythonTask):
+            tasks.append(v)
+    mock_client = remote._client
+    mock_client.get_launch_plan.return_value = ser_lp
+    mock_client.get_workflow.return_value = Workflow(
+        id=Identifier(ResourceType.WORKFLOW, "p", "d", "daily2", "v"),
+        closure=WorkflowClosure(
+            compiled_workflow=CompiledWorkflowClosure(primary=ser_wf, sub_workflows=[], tasks=tasks)
+        ),
+    )
+
+    wf = remote.launch_backfill("p", "d", start_date, end_date, "daily2", "v1", dry_run=True)
+    assert wf
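
Note: the new `tests/flytekit/unit/remote/test_backfill.py` above exercises `create_backfill_workflow` from `flytekit.remote.backfill`. The following is a minimal usage sketch distilled from those tests; the task, workflow, and launch-plan names simply mirror the test fixtures, and no calls beyond those already shown in the tests are assumed:

```python
from datetime import datetime, timedelta

from flytekit import CronSchedule, LaunchPlan, task, workflow
from flytekit.remote.backfill import create_backfill_workflow


@task
def tk(t: datetime, v: int):
    print(f"Invoked at {t} with v {v}")


@workflow
def example_wf(t: datetime, v: int):
    tk(t=t, v=v)


# A cron-scheduled launch plan is required: per test_create_backfiller_error, an
# unscheduled launch plan raises ValueError and a FixedRate schedule raises NotImplementedError.
daily_lp = LaunchPlan.get_or_create(
    workflow=example_wf,
    name="daily",
    fixed_inputs={"v": 10},
    schedule=CronSchedule(schedule="0 8 * * *", kickoff_time_input_arg="t"),
)

start_date = datetime(2022, 12, 1, 8)
end_date = start_date + timedelta(days=10)

# Sequential backfill: each generated node depends on the previous one.
wf, start, end = create_backfill_workflow(start_date, end_date, daily_lp)

# Parallel backfill: nodes carry no upstream dependencies and may run concurrently.
wf_parallel, _, _ = create_backfill_workflow(start_date, end_date, daily_lp, parallel=True)
```

The returned workflow invokes the launch plan once per schedule tick inside the window, which is what the node-count and upstream-node assertions in `test_create_backfiller` and `test_create_backfiller_parallel` check.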
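
Similarly, `test_launch_backfill` drives the same machinery through `FlyteRemote.launch_backfill`. A hedged sketch of that call, using only the arguments that appear in the test (the remote configuration itself is illustrative):

```python
from datetime import datetime, timedelta

from flytekit.configuration import Config
from flytekit.remote.remote import FlyteRemote

remote = FlyteRemote(config=Config.auto(), default_project="p", default_domain="d")

start_date = datetime(2022, 12, 1, 8)
end_date = start_date + timedelta(days=10)

# The launch plan named "daily2" at version "v1" is fetched from the backend to build the
# backfill; with dry_run=True the test only asserts that a workflow object comes back.
backfill_wf = remote.launch_backfill("p", "d", start_date, end_date, "daily2", "v1", dry_run=True)
assert backfill_wf is not None
```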