[![PyPI version fury.io](https://badge.fury.io/py/flytekit.svg)](https://pypi.python.org/pypi/flytekit/)
[![PyPI download day](https://img.shields.io/pypi/dd/flytekit.svg)](https://pypi.python.org/pypi/flytekit/)
diff --git a/dev-requirements.in b/dev-requirements.in
index 0c29d2fe8a..b2e6cf74a1 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -10,3 +10,5 @@ pre-commit
codespell
google-cloud-bigquery
google-cloud-bigquery-storage
+IPython
+torch
diff --git a/dev-requirements.txt b/dev-requirements.txt
index a410225b32..3705a578c0 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,5 +1,5 @@
#
-# This file is autogenerated by pip-compile with python 3.9
+# This file is autogenerated by pip-compile with python 3.7
# To update, run:
#
# make dev-requirements.txt
@@ -18,56 +18,65 @@ attrs==20.3.0
# jsonschema
# pytest
# pytest-docker
-bcrypt==3.2.0
+backcall==0.2.0
+ # via ipython
+bcrypt==3.2.2
# via paramiko
binaryornot==0.4.4
# via
# -c requirements.txt
# cookiecutter
-cachetools==5.0.0
+cached-property==1.5.2
+ # via docker-compose
+cachetools==5.2.0
# via google-auth
-certifi==2021.10.8
+certifi==2022.6.15
# via
# -c requirements.txt
# requests
-cffi==1.15.0
+cffi==1.15.1
# via
+ # -c requirements.txt
# bcrypt
# cryptography
# pynacl
cfgv==3.3.1
# via pre-commit
-chardet==4.0.0
+chardet==5.0.0
# via
# -c requirements.txt
# binaryornot
-charset-normalizer==2.0.12
+charset-normalizer==2.1.0
# via
# -c requirements.txt
# requests
-click==8.1.2
+click==8.1.3
# via
# -c requirements.txt
# cookiecutter
# flytekit
-cloudpickle==2.0.0
+cloudpickle==2.1.0
# via
# -c requirements.txt
# flytekit
codespell==2.1.0
# via -r dev-requirements.in
-cookiecutter==1.7.3
+cookiecutter==2.1.1
# via
# -c requirements.txt
# flytekit
-coverage[toml]==6.3.2
+coverage[toml]==6.4.1
# via -r dev-requirements.in
-croniter==1.3.4
+croniter==1.3.5
# via
# -c requirements.txt
# flytekit
-cryptography==36.0.2
- # via paramiko
+cryptography==37.0.4
+ # via
+ # -c requirements.txt
+ # paramiko
+ # pyopenssl
+ # secretstorage
dataclasses-json==0.5.7
# via
# -c requirements.txt
@@ -75,6 +84,7 @@ dataclasses-json==0.5.7
decorator==5.1.1
# via
# -c requirements.txt
+ # ipython
# retry
deprecated==1.2.13
# via
@@ -105,68 +115,84 @@ dockerpty==0.4.1
# via docker-compose
docopt==0.6.2
# via docker-compose
-docstring-parser==0.14
+docstring-parser==0.14.1
# via
# -c requirements.txt
# flytekit
-filelock==3.6.0
+filelock==3.7.1
# via virtualenv
-flyteidl==1.0.0.post1
+flyteidl==1.1.8
# via
# -c requirements.txt
# flytekit
-google-api-core[grpc]==2.7.2
+google-api-core[grpc]==2.8.2
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
# google-cloud-core
-google-auth==2.6.6
+google-auth==2.9.0
# via
# google-api-core
# google-cloud-core
-google-cloud-bigquery==3.0.1
+google-cloud-bigquery==3.2.0
# via -r dev-requirements.in
-google-cloud-bigquery-storage==2.13.1
+google-cloud-bigquery-storage==2.13.2
# via
# -r dev-requirements.in
# google-cloud-bigquery
-google-cloud-core==2.3.0
+google-cloud-core==2.3.1
# via google-cloud-bigquery
google-crc32c==1.3.0
# via google-resumable-media
-google-resumable-media==2.3.2
+google-resumable-media==2.3.3
# via google-cloud-bigquery
-googleapis-common-protos==1.56.0
+googleapis-common-protos==1.56.3
# via
# -c requirements.txt
# flyteidl
# google-api-core
# grpcio-status
-grpcio==1.44.0
+grpcio==1.47.0
# via
# -c requirements.txt
# flytekit
# google-api-core
# google-cloud-bigquery
# grpcio-status
-grpcio-status==1.44.0
+grpcio-status==1.47.0
# via
# -c requirements.txt
# flytekit
# google-api-core
-identify==2.4.12
+identify==2.5.1
# via pre-commit
idna==3.3
# via
# -c requirements.txt
# requests
-importlib-metadata==4.11.3
+importlib-metadata==4.12.0
# via
# -c requirements.txt
+ # click
+ # flytekit
+ # jsonschema
# keyring
+ # pluggy
+ # pre-commit
+ # pytest
+ # virtualenv
iniconfig==1.1.1
# via pytest
-jinja2==3.1.1
+ipython==7.34.0
+ # via -r dev-requirements.in
+jedi==0.18.1
+ # via ipython
+jeepney==0.8.0
+ # via
+ # -c requirements.txt
+ # keyring
+ # secretstorage
+jinja2==3.1.2
# via
# -c requirements.txt
# cookiecutter
@@ -182,7 +208,7 @@ jsonschema==3.2.0
# via
# -c requirements.txt
# docker-compose
-keyring==23.5.0
+keyring==23.6.0
# via
# -c requirements.txt
# flytekit
@@ -190,7 +216,7 @@ markupsafe==2.1.1
# via
# -c requirements.txt
# jinja2
-marshmallow==3.15.0
+marshmallow==3.17.0
# via
# -c requirements.txt
# dataclasses-json
@@ -204,9 +230,11 @@ marshmallow-jsonschema==0.13.0
# via
# -c requirements.txt
# flytekit
+matplotlib-inline==0.1.3
+ # via ipython
mock==4.0.3
# via -r dev-requirements.in
-mypy==0.942
+mypy==0.961
# via -r dev-requirements.in
mypy-extensions==0.4.3
# via
@@ -217,11 +245,12 @@ natsort==8.1.0
# via
# -c requirements.txt
# flytekit
-nodeenv==1.6.0
+nodeenv==1.7.0
# via pre-commit
numpy==1.21.6
# via
# -c requirements.txt
+ # flytekit
# pandas
# pyarrow
packaging==21.3
@@ -234,19 +263,23 @@ pandas==1.3.5
# via
# -c requirements.txt
# flytekit
-paramiko==2.10.4
+paramiko==2.11.0
# via docker
+parso==0.8.3
+ # via jedi
+pexpect==4.8.0
+ # via ipython
+pickleshare==0.7.5
+ # via ipython
platformdirs==2.5.2
# via virtualenv
pluggy==1.0.0
# via pytest
-poyo==0.5.0
- # via
- # -c requirements.txt
- # cookiecutter
-pre-commit==2.18.1
+pre-commit==2.19.0
# via -r dev-requirements.in
-proto-plus==1.20.3
+prompt-toolkit==3.0.30
+ # via ipython
+proto-plus==1.20.6
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
@@ -257,6 +290,7 @@ protobuf==3.20.1
# flytekit
# google-api-core
# google-cloud-bigquery
+ # google-cloud-bigquery-storage
# googleapis-common-protos
# grpcio-status
# proto-plus
@@ -265,6 +299,8 @@ protoc-gen-swagger==0.1.0
# via
# -c requirements.txt
# flyteidl
+ptyprocess==0.7.0
+ # via pexpect
py==1.11.0
# via
# -c requirements.txt
@@ -282,10 +318,18 @@ pyasn1==0.4.8
pyasn1-modules==0.2.8
# via google-auth
pycparser==2.21
- # via cffi
+ # via
+ # -c requirements.txt
+ # cffi
+pygments==2.12.0
+ # via ipython
pynacl==1.5.0
# via paramiko
-pyparsing==3.0.8
+pyopenssl==22.0.0
+ # via
+ # -c requirements.txt
+ # flytekit
+pyparsing==3.0.9
# via
# -c requirements.txt
# packaging
@@ -316,7 +360,7 @@ python-json-logger==2.0.2
# via
# -c requirements.txt
# flytekit
-python-slugify==6.1.1
+python-slugify==6.1.2
# via
# -c requirements.txt
# cookiecutter
@@ -332,14 +376,15 @@ pytz==2022.1
pyyaml==5.4.1
# via
# -c requirements.txt
+ # cookiecutter
# docker-compose
# flytekit
# pre-commit
-regex==2022.4.24
+regex==2022.6.2
# via
# -c requirements.txt
# docker-image-py
-requests==2.27.1
+requests==2.28.1
# via
# -c requirements.txt
# cookiecutter
@@ -349,7 +394,7 @@ requests==2.27.1
# google-api-core
# google-cloud-bigquery
# responses
-responses==0.20.0
+responses==0.21.0
# via
# -c requirements.txt
# flytekit
@@ -359,11 +404,17 @@ retry==0.9.2
# flytekit
rsa==4.8
# via google-auth
+secretstorage==3.3.2
+ # via
+ # -c requirements.txt
+ # keyring
+singledispatchmethod==1.0
+ # via
+ # -c requirements.txt
+ # flytekit
six==1.16.0
# via
# -c requirements.txt
- # bcrypt
- # cookiecutter
# dockerpty
# google-auth
# grpcio
@@ -393,11 +444,23 @@ tomli==2.0.1
# coverage
# mypy
# pytest
-typing-extensions==4.2.0
+torch==1.11.0
+ # via -r dev-requirements.in
+traitlets==5.3.0
+ # via
+ # ipython
+ # matplotlib-inline
+typed-ast==1.5.4
+ # via mypy
+typing-extensions==4.3.0
# via
# -c requirements.txt
+ # arrow
# flytekit
+ # importlib-metadata
# mypy
+ # responses
+ # torch
# typing-inspect
typing-inspect==0.7.1
# via
@@ -409,8 +472,10 @@ urllib3==1.26.9
# flytekit
# requests
# responses
-virtualenv==20.14.1
+virtualenv==20.15.1
# via pre-commit
+wcwidth==0.2.5
+ # via prompt-toolkit
websocket-client==0.59.0
# via
# -c requirements.txt
@@ -420,7 +485,7 @@ wheel==0.37.1
# via
# -c requirements.txt
# flytekit
-wrapt==1.14.0
+wrapt==1.14.1
# via
# -c requirements.txt
# deprecated
diff --git a/doc-requirements.in b/doc-requirements.in
index 4d60a6919b..760d3903dc 100644
--- a/doc-requirements.in
+++ b/doc-requirements.in
@@ -33,3 +33,4 @@ papermill # papermill
jupyter # papermill
pyspark # spark
sqlalchemy # sqlalchemy
+torch # pytorch
diff --git a/doc-requirements.txt b/doc-requirements.txt
index 3eee9655cb..e9f16c89a0 100644
--- a/doc-requirements.txt
+++ b/doc-requirements.txt
@@ -1,5 +1,5 @@
#
-# This file is autogenerated by pip-compile with python 3.9
+# This file is autogenerated by pip-compile with python 3.7
# To update, run:
#
# make doc-requirements.txt
@@ -12,28 +12,26 @@ altair==4.2.0
# via great-expectations
ansiwrap==0.8.4
# via papermill
-appnope==0.1.3
- # via
- # ipykernel
- # ipython
argon2-cffi==21.3.0
# via notebook
argon2-cffi-bindings==21.2.0
# via argon2-cffi
arrow==1.2.2
# via jinja2-time
-astroid==2.11.3
+astroid==2.11.6
# via sphinx-autoapi
-asttokens==2.0.5
- # via stack-data
attrs==21.4.0
# via
# jsonschema
# visions
-babel==2.10.1
+babel==2.10.3
# via sphinx
backcall==0.2.0
# via ipython
+backports-zoneinfo==0.2.1
+ # via
+ # pytz-deprecation-shim
+ # tzlocal
beautifulsoup4==4.11.1
# via
# furo
@@ -42,42 +40,44 @@ beautifulsoup4==4.11.1
# sphinx-material
binaryornot==0.4.4
# via cookiecutter
-bleach==5.0.0
+bleach==5.0.1
# via nbconvert
-botocore==1.25.0
+botocore==1.27.22
# via -r doc-requirements.in
-cachetools==5.0.0
+cachetools==5.2.0
# via google-auth
-certifi==2021.10.8
+certifi==2022.6.15
# via
# kubernetes
# requests
-cffi==1.15.0
+cffi==1.15.1
# via
# argon2-cffi-bindings
# cryptography
-chardet==4.0.0
+chardet==5.0.0
# via binaryornot
-charset-normalizer==2.0.12
+charset-normalizer==2.1.0
# via requests
-click==8.1.2
+click==8.1.3
# via
# cookiecutter
# flytekit
# great-expectations
# papermill
-cloudpickle==2.0.0
+cloudpickle==2.1.0
# via flytekit
-colorama==0.4.4
+colorama==0.4.5
# via great-expectations
-cookiecutter==1.7.3
+cookiecutter==2.1.1
# via flytekit
-croniter==1.3.4
+croniter==1.3.5
# via flytekit
-cryptography==36.0.2
+cryptography==37.0.4
# via
# -r doc-requirements.in
# great-expectations
+ # pyopenssl
+ # secretstorage
css-html-js-minify==2.5.5
# via sphinx-material
cycler==0.11.0
@@ -102,7 +102,7 @@ docker==5.0.3
# via flytekit
docker-image-py==0.1.12
# via flytekit
-docstring-parser==0.14
+docstring-parser==0.14.1
# via flytekit
docutils==0.17.1
# via
@@ -118,59 +118,57 @@ entrypoints==0.4
# jupyter-client
# nbconvert
# papermill
-executing==0.8.3
- # via stack-data
fastjsonschema==2.15.3
# via nbformat
-flyteidl==1.0.0.post1
+flyteidl==1.1.8
# via flytekit
-fonttools==4.33.2
+fonttools==4.33.3
# via matplotlib
-fsspec==2022.3.0
+fsspec==2022.5.0
# via
# -r doc-requirements.in
# modin
furo @ git+https://github.com/flyteorg/furo@main
# via -r doc-requirements.in
-google-api-core[grpc]==2.7.2
+google-api-core[grpc]==2.8.2
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
# google-cloud-core
-google-auth==2.6.6
+google-auth==2.9.0
# via
# google-api-core
# google-cloud-core
# kubernetes
google-cloud==0.34.0
# via -r doc-requirements.in
-google-cloud-bigquery==3.0.1
+google-cloud-bigquery==3.2.0
# via -r doc-requirements.in
-google-cloud-bigquery-storage==2.13.1
+google-cloud-bigquery-storage==2.13.2
# via google-cloud-bigquery
-google-cloud-core==2.3.0
+google-cloud-core==2.3.1
# via google-cloud-bigquery
google-crc32c==1.3.0
# via google-resumable-media
-google-resumable-media==2.3.2
+google-resumable-media==2.3.3
# via google-cloud-bigquery
-googleapis-common-protos==1.56.0
+googleapis-common-protos==1.56.3
# via
# flyteidl
# google-api-core
# grpcio-status
-great-expectations==0.15.2
+great-expectations==0.15.12
# via -r doc-requirements.in
greenlet==1.1.2
# via sqlalchemy
-grpcio==1.44.0
+grpcio==1.47.0
# via
# -r doc-requirements.in
# flytekit
# google-api-core
# google-cloud-bigquery
# grpcio-status
-grpcio-status==1.44.0
+grpcio-status==1.47.0
# via
# flytekit
# google-api-core
@@ -180,22 +178,28 @@ idna==3.3
# via requests
imagehash==4.2.1
# via visions
-imagesize==1.3.0
+imagesize==1.4.1
# via sphinx
-importlib-metadata==4.11.3
+importlib-metadata==4.12.0
# via
+ # click
+ # flytekit
# great-expectations
+ # jsonschema
# keyring
# markdown
# sphinx
-ipykernel==6.13.0
+ # sqlalchemy
+importlib-resources==5.8.0
+ # via jsonschema
+ipykernel==6.15.0
# via
# ipywidgets
# jupyter
# jupyter-console
# notebook
# qtconsole
-ipython==8.2.0
+ipython==7.34.0
# via
# great-expectations
# ipykernel
@@ -206,11 +210,15 @@ ipython-genutils==0.2.0
# ipywidgets
# notebook
# qtconsole
-ipywidgets==7.7.0
+ipywidgets==7.7.1
# via jupyter
jedi==0.18.1
# via ipython
-jinja2==3.0.3
+jeepney==0.8.0
+ # via
+ # keyring
+ # secretstorage
+jinja2==3.1.2
# via
# altair
# cookiecutter
@@ -223,9 +231,9 @@ jinja2==3.0.3
# sphinx-autoapi
jinja2-time==0.2.0
# via cookiecutter
-jmespath==1.0.0
+jmespath==1.0.1
# via botocore
-joblib==1.0.1
+joblib==1.1.0
# via
# pandas-profiling
# phik
@@ -233,21 +241,21 @@ jsonpatch==1.32
# via great-expectations
jsonpointer==2.3
# via jsonpatch
-jsonschema==4.4.0
+jsonschema==4.6.1
# via
# altair
# great-expectations
# nbformat
jupyter==1.0.0
# via -r doc-requirements.in
-jupyter-client==7.3.0
+jupyter-client==7.3.4
# via
# ipykernel
# jupyter-console
# nbclient
# notebook
# qtconsole
-jupyter-console==6.4.3
+jupyter-console==6.4.4
# via jupyter
jupyter-core==4.10.0
# via
@@ -258,26 +266,26 @@ jupyter-core==4.10.0
# qtconsole
jupyterlab-pygments==0.2.2
# via nbconvert
-jupyterlab-widgets==1.1.0
+jupyterlab-widgets==1.1.1
# via ipywidgets
-keyring==23.5.0
+keyring==23.6.0
# via flytekit
-kiwisolver==1.4.2
+kiwisolver==1.4.3
# via matplotlib
-kubernetes==23.3.0
+kubernetes==24.2.0
# via -r doc-requirements.in
lazy-object-proxy==1.7.1
# via astroid
-lxml==4.8.0
+lxml==4.9.1
# via sphinx-material
-markdown==3.3.6
+markdown==3.3.7
# via -r doc-requirements.in
-markupsafe==2.0.1
+markupsafe==2.1.1
# via
# jinja2
# nbconvert
# pandas-profiling
-marshmallow==3.15.0
+marshmallow==3.17.0
# via
# dataclasses-json
# marshmallow-enum
@@ -286,7 +294,7 @@ marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.13.0
# via flytekit
-matplotlib==3.5.1
+matplotlib==3.5.2
# via
# missingno
# pandas-profiling
@@ -302,7 +310,7 @@ mistune==0.8.4
# via
# great-expectations
# nbconvert
-modin==0.14.0
+modin==0.12.1
# via -r doc-requirements.in
multimethod==1.8
# via
@@ -312,7 +320,7 @@ mypy-extensions==0.4.3
# via typing-inspect
natsort==8.1.0
# via flytekit
-nbclient==0.6.0
+nbclient==0.6.6
# via
# nbconvert
# papermill
@@ -320,10 +328,9 @@ nbconvert==6.5.0
# via
# jupyter
# notebook
-nbformat==5.3.0
+nbformat==5.4.0
# via
# great-expectations
- # ipywidgets
# nbclient
# nbconvert
# notebook
@@ -334,16 +341,17 @@ nest-asyncio==1.5.5
# jupyter-client
# nbclient
# notebook
-networkx==2.8
+networkx==2.6.3
# via visions
-notebook==6.4.11
+notebook==6.4.12
# via
# great-expectations
# jupyter
# widgetsnbextension
-numpy==1.22.3
+numpy==1.21.6
# via
# altair
+ # flytekit
# great-expectations
# imagehash
# matplotlib
@@ -372,7 +380,7 @@ packaging==21.3
# pandera
# qtpy
# sphinx
-pandas==1.4.1
+pandas==1.3.5
# via
# altair
# dolt-integrations
@@ -384,9 +392,9 @@ pandas==1.4.1
# phik
# seaborn
# visions
-pandas-profiling==3.1.0
+pandas-profiling==3.2.0
# via -r doc-requirements.in
-pandera==0.10.1
+pandera==0.9.0
# via -r doc-requirements.in
pandocfilters==1.5.0
# via nbconvert
@@ -400,22 +408,20 @@ phik==0.12.2
# via pandas-profiling
pickleshare==0.7.5
# via ipython
-pillow==9.1.0
+pillow==9.2.0
# via
# imagehash
# matplotlib
# visions
-plotly==5.7.0
+plotly==5.9.0
# via -r doc-requirements.in
-poyo==0.5.0
- # via cookiecutter
prometheus-client==0.14.1
# via notebook
-prompt-toolkit==3.0.29
+prompt-toolkit==3.0.30
# via
# ipython
# jupyter-console
-proto-plus==1.20.3
+proto-plus==1.20.6
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
@@ -425,23 +431,22 @@ protobuf==3.20.1
# flytekit
# google-api-core
# google-cloud-bigquery
+ # google-cloud-bigquery-storage
# googleapis-common-protos
# grpcio-status
# proto-plus
# protoc-gen-swagger
protoc-gen-swagger==0.1.0
# via flyteidl
-psutil==5.9.0
+psutil==5.9.1
# via ipykernel
ptyprocess==0.7.0
# via
# pexpect
# terminado
-pure-eval==0.2.2
- # via stack-data
py==1.11.0
# via retry
-py4j==0.10.9.3
+py4j==0.10.9.5
# via pyspark
pyarrow==6.0.1
# via
@@ -456,18 +461,21 @@ pyasn1-modules==0.2.8
# via google-auth
pycparser==2.21
# via cffi
-pydantic==1.9.0
+pydantic==1.9.1
# via
# pandas-profiling
# pandera
pygments==2.12.0
# via
+ # furo
# ipython
# jupyter-console
# nbconvert
# qtconsole
# sphinx
# sphinx-prompt
+pyopenssl==22.0.0
+ # via flytekit
pyparsing==2.4.7
# via
# great-expectations
@@ -475,7 +483,7 @@ pyparsing==2.4.7
# packaging
pyrsistent==0.18.1
# via jsonschema
-pyspark==3.2.1
+pyspark==3.3.0
# via -r doc-requirements.in
python-dateutil==2.8.2
# via
@@ -491,7 +499,7 @@ python-dateutil==2.8.2
# pandas
python-json-logger==2.0.2
# via flytekit
-python-slugify[unidecode]==6.1.1
+python-slugify[unidecode]==6.1.2
# via
# cookiecutter
# sphinx-material
@@ -509,23 +517,25 @@ pywavelets==1.3.0
# via imagehash
pyyaml==6.0
# via
+ # cookiecutter
# flytekit
# kubernetes
# pandas-profiling
# papermill
# sphinx-autoapi
-pyzmq==22.3.0
+pyzmq==23.2.0
# via
+ # ipykernel
# jupyter-client
# notebook
# qtconsole
-qtconsole==5.3.0
+qtconsole==5.3.1
# via jupyter
-qtpy==2.0.1
+qtpy==2.1.0
# via qtconsole
-regex==2022.4.24
+regex==2022.6.2
# via docker-image-py
-requests==2.27.1
+requests==2.28.1
# via
# cookiecutter
# docker
@@ -541,7 +551,7 @@ requests==2.27.1
# sphinx
requests-oauthlib==1.3.1
# via kubernetes
-responses==0.20.0
+responses==0.21.0
# via flytekit
retry==0.9.2
# via flytekit
@@ -551,7 +561,7 @@ ruamel-yaml==0.17.17
# via great-expectations
ruamel-yaml-clib==0.2.6
# via ruamel-yaml
-scipy==1.8.0
+scipy==1.7.3
# via
# great-expectations
# imagehash
@@ -563,18 +573,19 @@ seaborn==0.11.2
# via
# missingno
# pandas-profiling
+secretstorage==3.3.2
+ # via keyring
send2trash==1.8.0
# via notebook
+singledispatchmethod==1.0
+ # via flytekit
six==1.16.0
# via
- # asttokens
# bleach
- # cookiecutter
# google-auth
# grpcio
# imagehash
# kubernetes
- # plotly
# python-dateutil
# sphinx-code-include
snowballstemmer==2.2.0
@@ -588,6 +599,7 @@ sphinx==4.5.0
# -r doc-requirements.in
# furo
# sphinx-autoapi
+ # sphinx-basic-ng
# sphinx-code-include
# sphinx-copybutton
# sphinx-fontawesome
@@ -598,6 +610,8 @@ sphinx==4.5.0
# sphinxcontrib-yt
sphinx-autoapi==1.8.4
# via -r doc-requirements.in
+sphinx-basic-ng==0.0.1a12
+ # via furo
sphinx-code-include==1.1.1
# via -r doc-requirements.in
sphinx-copybutton==0.5.0
@@ -626,13 +640,11 @@ sphinxcontrib-serializinghtml==1.1.5
# via sphinx
sphinxcontrib-yt==0.2.2
# via -r doc-requirements.in
-sqlalchemy==1.4.35
+sqlalchemy==1.4.39
# via -r doc-requirements.in
-stack-data==0.2.0
- # via ipython
statsd==3.3.0
# via flytekit
-tangled-up-in-unicode==0.1.0
+tangled-up-in-unicode==0.2.0
# via
# pandas-profiling
# visions
@@ -642,7 +654,7 @@ tenacity==8.0.1
# plotly
termcolor==1.1.0
# via great-expectations
-terminado==0.13.3
+terminado==0.15.0
# via notebook
text-unidecode==1.3
# via python-slugify
@@ -652,7 +664,9 @@ tinycss2==1.1.1
# via nbconvert
toolz==0.11.2
# via altair
-tornado==6.1
+torch==1.11.0
+ # via -r doc-requirements.in
+tornado==6.2
# via
# ipykernel
# jupyter-client
@@ -663,7 +677,7 @@ tqdm==4.64.0
# great-expectations
# pandas-profiling
# papermill
-traitlets==5.1.1
+traitlets==5.3.0
# via
# ipykernel
# ipython
@@ -676,12 +690,22 @@ traitlets==5.1.1
# nbformat
# notebook
# qtconsole
-typing-extensions==4.2.0
+typed-ast==1.5.4
+ # via astroid
+typing-extensions==4.3.0
# via
+ # argon2-cffi
+ # arrow
# astroid
# flytekit
# great-expectations
+ # importlib-metadata
+ # jsonschema
+ # kiwisolver
+ # pandera
# pydantic
+ # responses
+ # torch
# typing-inspect
typing-inspect==0.7.1
# via
@@ -711,22 +735,24 @@ webencodings==0.5.1
# via
# bleach
# tinycss2
-websocket-client==1.3.2
+websocket-client==1.3.3
# via
# docker
# kubernetes
wheel==0.37.1
# via flytekit
-widgetsnbextension==3.6.0
+widgetsnbextension==3.6.1
# via ipywidgets
-wrapt==1.14.0
+wrapt==1.14.1
# via
# astroid
# deprecated
# flytekit
# pandera
zipp==3.8.0
- # via importlib-metadata
+ # via
+ # importlib-metadata
+ # importlib-resources
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/docs/source/data.extend.rst b/docs/source/data.extend.rst
index 1ed84de1b4..3f06961022 100644
--- a/docs/source/data.extend.rst
+++ b/docs/source/data.extend.rst
@@ -1,8 +1,11 @@
##############################
Extend Data Persistence layer
##############################
-Flytekit provides a data persistence layer, which is used for recording metadata that is shared with backend Flyte. This persistence layer is also available for various types to store raw user data and is designed to be cross-cloud compatible.
-Moreover, it is design to be extensible and users can bring their own data persistence plugins by following the persistence interface. NOTE, this is bound to get more extensive for variety of use-cases, but the core set of apis are battle tested.
+Flytekit provides a data persistence layer, which is used for recording metadata that is shared with the Flyte backend. This persistence layer is available for various types to store raw user data and is designed to be cross-cloud compatible.
+Moreover, it is designed to be extensible and users can bring their own data persistence plugins by following the persistence interface.
+
+.. note::
+ This will become extensive for a variety of use-cases, but the core set of APIs have been battle tested.
.. automodule:: flytekit.core.data_persistence
:no-members:
@@ -13,3 +16,22 @@ Moreover, it is design to be extensible and users can bring their own data persi
:no-members:
:no-inherited-members:
:no-special-members:
+
+The ``fsspec`` Data Plugin
+--------------------------
+
+Flytekit ships with a default storage driver that uses aws-cli on AWS and gsutil on GCP. By default, Flyte uploads the task outputs to S3 or GCS using these storage drivers.
+
+Why ``fsspec``?
+^^^^^^^^^^^^^^^
+
+You can use the fsspec plugin implementation to utilize all its available plugins with flytekit. The `fsspec `_ plugin provides an implementation of the data persistence layer in Flytekit. For example: HDFS, FTP are supported in fsspec, so you can use them with flytekit too.
+The data persistence layer helps store logs of metadata and raw user data.
+As a consequence of the implementation, an S3 driver can be installed using ``pip install s3fs``.
+
+`Here `_ is a code snippet that shows protocols mapped to the class it implements.
+
+Once you install the plugin, it overrides all default implementations of the `DataPersistencePlugins `_ and provides the ones supported by fsspec.
+
+.. note::
+ This plugin installs fsspec core only. To install all the fsspec plugins, see `here `_.
diff --git a/docs/source/extras.pytorch.rst b/docs/source/extras.pytorch.rst
new file mode 100644
index 0000000000..12fd3d62d9
--- /dev/null
+++ b/docs/source/extras.pytorch.rst
@@ -0,0 +1,7 @@
+############
+PyTorch Type
+############
+.. automodule:: flytekit.extras.pytorch
+ :no-members:
+ :no-inherited-members:
+ :no-special-members:
diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst
index f1b15455dd..f0cdff28dc 100644
--- a/docs/source/types.extend.rst
+++ b/docs/source/types.extend.rst
@@ -11,3 +11,4 @@ Feel free to follow the pattern of the built-in types.
types.builtins.structured
types.builtins.file
types.builtins.directory
+ extras.pytorch
diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index c0ea9ebe3d..c67a8a04b4 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -154,6 +154,7 @@
"""
import sys
+from typing import Generator
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
@@ -181,6 +182,7 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
+from flytekit.extras import pytorch
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
@@ -188,7 +190,7 @@
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
-from flytekit.types import directory, file, schema
+from flytekit.types import directory, file, numpy, schema
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetFormat,
@@ -215,6 +217,10 @@ def current_context() -> ExecutionParameters:
return FlyteContextManager.current_context().execution_state.user_space_params
+def new_context() -> Generator[FlyteContext, None, None]:
+ return FlyteContextManager.with_context(FlyteContextManager.current_context().new_builder())
+
+
def load_implicit_plugins():
"""
This method allows loading all plugins that have the entrypoint specification. This uses the plugin loading
diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py
index 68bd2578e5..1f8dd78ef0 100644
--- a/flytekit/bin/entrypoint.py
+++ b/flytekit/bin/entrypoint.py
@@ -243,6 +243,7 @@ def setup_execution(
tmp_dir=user_workspace_dir,
raw_output_prefix=raw_output_data_prefix,
checkpoint=checkpointer,
+ task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
)
try:
diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py
index 8db7c93a98..d542af5f7e 100644
--- a/flytekit/clients/friendly.py
+++ b/flytekit/clients/friendly.py
@@ -1004,3 +1004,17 @@ def get_upload_signed_url(
expires_in=expires_in_pb,
)
)
+
+ def get_download_signed_url(
+ self, native_url: str, expires_in: datetime.timedelta = None
+ ) -> _data_proxy_pb2.CreateUploadLocationResponse:
+ expires_in_pb = None
+ if expires_in:
+ expires_in_pb = Duration()
+ expires_in_pb.FromTimedelta(expires_in)
+ return super(SynchronousFlyteClient, self).create_download_location(
+ _data_proxy_pb2.CreateDownloadLocationRequest(
+ native_url=native_url,
+ expires_in=expires_in_pb,
+ )
+ )
diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py
index 9aba78888a..9bdf85a178 100644
--- a/flytekit/clients/raw.py
+++ b/flytekit/clients/raw.py
@@ -1,12 +1,14 @@
from __future__ import annotations
import base64 as _base64
+import ssl
import subprocess
import time
import typing
from typing import Optional
import grpc
+import OpenSSL
import requests as _requests
from flyteidl.admin.project_pb2 import ProjectListRequest
from flyteidl.service import admin_pb2_grpc as _admin_service
@@ -110,6 +112,26 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
self._cfg = cfg
if cfg.insecure:
self._channel = grpc.insecure_channel(cfg.endpoint, **kwargs)
+ elif cfg.insecure_skip_verify:
+ # Get port from endpoint or use 443
+ endpoint_parts = cfg.endpoint.rsplit(":", 1)
+ if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit():
+ server_address = tuple(endpoint_parts)
+ else:
+ server_address = (cfg.endpoint, "443")
+
+ cert = ssl.get_server_certificate(server_address)
+ x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
+ cn = x509.get_subject().CN
+ credentials = grpc.ssl_channel_credentials(str.encode(cert))
+ options = kwargs.get("options", [])
+ options.append(("grpc.ssl_target_name_override", cn))
+ self._channel = grpc.secure_channel(
+ target=cfg.endpoint,
+ credentials=credentials,
+ options=options,
+ compression=kwargs.get("compression", None),
+ )
else:
if "credentials" not in kwargs:
credentials = grpc.ssl_channel_credentials(
@@ -251,7 +273,10 @@ def _refresh_credentials_from_command(self):
except subprocess.CalledProcessError as e:
cli_logger.error("Failed to generate token from command {}".format(command))
raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e))
- self.set_access_token(output.stdout.strip())
+ authorization_header_key = self.public_client_config.authorization_metadata_key or None
+ if not authorization_header_key:
+ self.set_access_token(output.stdout.strip())
+ self.set_access_token(output.stdout.strip(), authorization_header_key)
def _refresh_credentials_noop(self):
pass
@@ -833,6 +858,17 @@ def create_upload_location(
"""
return self._dataproxy_stub.CreateUploadLocation(create_upload_location_request, metadata=self._metadata)
+ @_handle_rpc_error(retry=True)
+ def create_download_location(
+ self, create_download_location_request: _dataproxy_pb2.CreateDownloadLocationRequest
+ ) -> _dataproxy_pb2.CreateDownloadLocationResponse:
+ """
+ Get a signed url to be used during fast registration
+ :param flyteidl.service.dataproxy_pb2.CreateDownloadLocationRequest create_download_location_request:
+ :rtype: flyteidl.service.dataproxy_pb2.CreateDownloadLocationResponse
+ """
+ return self._dataproxy_stub.CreateDownloadLocation(create_download_location_request, metadata=self._metadata)
+
def get_token(token_endpoint, authorization_header, scope):
"""
diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py
index 3628af7beb..21aec1c4ad 100644
--- a/flytekit/clis/flyte_cli/main.py
+++ b/flytekit/clis/flyte_cli/main.py
@@ -4,6 +4,7 @@
import os as _os
import stat as _stat
import sys as _sys
+from dataclasses import replace
from typing import Callable, Dict, List, Tuple, Union
import click as _click
@@ -276,7 +277,7 @@ def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteC
if parent_ctx.obj["cacert"]:
kwargs["root_certificates"] = parent_ctx.obj["cacert"]
cfg = parent_ctx.obj["config"]
- cfg = cfg.with_parameters(endpoint=host, insecure=insecure)
+ cfg = replace(cfg, endpoint=host, insecure=insecure)
return _friendly_client.SynchronousFlyteClient(cfg, **kwargs)
diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py
index 73274e972e..d922f5e3c1 100644
--- a/flytekit/clis/helpers.py
+++ b/flytekit/clis/helpers.py
@@ -1,5 +1,7 @@
+import sys
from typing import Tuple, Union
+import click
from flyteidl.admin.launch_plan_pb2 import LaunchPlan
from flyteidl.admin.task_pb2 import TaskSpec
from flyteidl.admin.workflow_pb2 import WorkflowSpec
@@ -125,3 +127,9 @@ def hydrate_registration_parameters(
del entity.sub_workflows[:]
entity.sub_workflows.extend(refreshed_sub_workflows)
return identifier, entity
+
+
+def display_help_with_error(ctx: click.Context, message: str):
+ click.echo(f"{ctx.get_help()}\n")
+ click.secho(message, fg="red")
+ sys.exit(1)
diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py
new file mode 100644
index 0000000000..a9a9c4900d
--- /dev/null
+++ b/flytekit/clis/sdk_in_container/helpers.py
@@ -0,0 +1,32 @@
+import click
+
+from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE
+from flytekit.configuration import Config
+from flytekit.loggers import cli_logger
+from flytekit.remote.remote import FlyteRemote
+
+FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote"
+
+
+def get_and_save_remote_with_click_context(
+ ctx: click.Context, project: str, domain: str, save: bool = True
+) -> FlyteRemote:
+ """
+ NB: This function will by default mutate the click Context.obj dictionary, adding a remote key with value
+ of the created FlyteRemote object.
+
+ :param ctx: the click context object
+ :param project: default project for the remote instance
+ :param domain: default domain
+ :param save: If false, will not mutate the context.obj dict
+ :return: FlyteRemote instance
+ """
+ cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE)
+ cfg_obj = Config.auto(cfg_file_location)
+ cli_logger.info(
+ f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "")
+ )
+ r = FlyteRemote(cfg_obj, default_project=project, default_domain=domain)
+ if save:
+ ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r
+ return r
diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py
index 71efeab576..2a884e29da 100644
--- a/flytekit/clis/sdk_in_container/package.py
+++ b/flytekit/clis/sdk_in_container/package.py
@@ -1,8 +1,8 @@
import os
-import sys
import click
+from flytekit.clis.helpers import display_help_with_error
from flytekit.clis.sdk_in_container import constants
from flytekit.configuration import (
DEFAULT_RUNTIME_PYTHON_INTERPRETER,
@@ -100,8 +100,7 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_
pkgs = ctx.obj[constants.CTX_PACKAGES]
if not pkgs:
- click.secho("No packages to scan for flyte entities. Aborting!", fg="red")
- sys.exit(-1)
+ display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")
try:
serialize_and_package(pkgs, serialization_settings, source, output, fast)
diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py
index c2b5f2b045..76777c5663 100644
--- a/flytekit/clis/sdk_in_container/pyflyte.py
+++ b/flytekit/clis/sdk_in_container/pyflyte.py
@@ -5,6 +5,7 @@
from flytekit.clis.sdk_in_container.init import init
from flytekit.clis.sdk_in_container.local_cache import local_cache
from flytekit.clis.sdk_in_container.package import package
+from flytekit.clis.sdk_in_container.register import register
from flytekit.clis.sdk_in_container.run import run
from flytekit.clis.sdk_in_container.serialize import serialize
from flytekit.configuration.internal import LocalSDK
@@ -68,6 +69,7 @@ def main(ctx, pkgs=None, config=None):
main.add_command(local_cache)
main.add_command(init)
main.add_command(run)
+main.add_command(register)
if __name__ == "__main__":
main()
diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py
new file mode 100644
index 0000000000..03e00d7896
--- /dev/null
+++ b/flytekit/clis/sdk_in_container/register.py
@@ -0,0 +1,179 @@
+import os
+import pathlib
+import typing
+
+import click
+
+from flytekit.clis.helpers import display_help_with_error
+from flytekit.clis.sdk_in_container import constants
+from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context
+from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
+from flytekit.configuration.default_images import DefaultImages
+from flytekit.loggers import cli_logger
+from flytekit.tools.fast_registration import fast_package
+from flytekit.tools.repo import find_common_root, load_packages_and_modules
+from flytekit.tools.repo import register as repo_register
+from flytekit.tools.translator import Options
+
+_register_help = """
+This command is similar to package but instead of producing a zip file, all your Flyte entities are compiled,
+and then sent to the backend specified by your config file. Think of this as combining the pyflyte package
+and the flytectl register step in one command. This is why you see switches you'd normally use with flytectl
+like service account here.
+
+Note: This command runs "fast" register by default. Future work to come to add a non-fast version.
+This means that a zip is created from the detected root of the packages given, and uploaded. Just like with
+pyflyte run, tasks registered from this command will download and unzip that code package before running.
+
+Note: This command only works on regular Python packages, not namespace packages. When determining
+ the root of your project, it finds the first folder that does not have an __init__.py file.
+"""
+
+
+@click.command("register", help=_register_help)
+@click.option(
+ "-p",
+ "--project",
+ required=False,
+ type=str,
+ default="flytesnacks",
+ help="Project to register and run this workflow in",
+)
+@click.option(
+ "-d",
+ "--domain",
+ required=False,
+ type=str,
+ default="development",
+ help="Domain to register and run this workflow in",
+)
+@click.option(
+ "-i",
+ "--image",
+ "image_config",
+ required=False,
+ multiple=True,
+ type=click.UNPROCESSED,
+ callback=ImageConfig.validate_image,
+ default=[DefaultImages.default_image()],
+ help="A fully qualified tag for an docker image, e.g. somedocker.com/myimage:someversion123. This is a "
+ "multi-option and can be of the form --image xyz.io/docker:latest "
+ "--image my_image=xyz.io/docker2:latest. Note, the `name=image_uri`. The name is optional, if not "
+ "provided the image will be used as the default image. All the names have to be unique, and thus "
+ "there can only be one --image option with no name.",
+)
+@click.option(
+ "-o",
+ "--output",
+ required=False,
+ type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True),
+ default=None,
+ help="Directory to write the output zip file containing the protobuf definitions",
+)
+@click.option(
+ "-d",
+ "--destination-dir",
+ required=False,
+ type=str,
+ default="/root",
+ help="Directory inside the image where the tar file containing the code will be copied to",
+)
+@click.option(
+ "--service-account",
+ required=False,
+ type=str,
+ default="",
+ help="Service account used when creating launch plans",
+)
+@click.option(
+ "--raw-data-prefix",
+ required=False,
+ type=str,
+ default="",
+ help="Raw output data prefix when creating launch plans, where offloaded data will be stored",
+)
+@click.option(
+ "-v",
+ "--version",
+ required=False,
+ type=str,
+ help="Version the package or module is registered with",
+)
+@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
+@click.pass_context
+def register(
+ ctx: click.Context,
+ project: str,
+ domain: str,
+ image_config: ImageConfig,
+ output: str,
+ destination_dir: str,
+ service_account: str,
+ raw_data_prefix: str,
+ version: typing.Optional[str],
+ package_or_module: typing.Tuple[str],
+):
+ """
+ see help
+ """
+ pkgs = ctx.obj[constants.CTX_PACKAGES]
+ if not pkgs:
+ cli_logger.debug("No pkgs")
+ if pkgs:
+ raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command")
+
+ if len(package_or_module) == 0:
+ display_help_with_error(
+ ctx,
+ "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
+ )
+
+ cli_logger.debug(
+ f"Running pyflyte register from {os.getcwd()} "
+ f"with images {image_config} "
+ f"and image destinationfolder {destination_dir} "
+ f"on {len(package_or_module)} package(s) {package_or_module}"
+ )
+
+ # Create and save FlyteRemote,
+ remote = get_and_save_remote_with_click_context(ctx, project, domain)
+
+ # Todo: add switch for non-fast - skip the zipping and uploading and no fastserializationsettings
+ # Create a zip file containing all the entries.
+ detected_root = find_common_root(package_or_module)
+ cli_logger.debug(f"Using {detected_root} as root folder for project")
+ zip_file = fast_package(detected_root, output)
+
+ # Upload zip file to Admin using FlyteRemote.
+ md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file))
+ cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}")
+
+ # Create serialization settings
+ # Todo: Rely on default Python interpreter for now, this will break custom Spark containers
+ serialization_settings = SerializationSettings(
+ project=project,
+ domain=domain,
+ image_config=image_config,
+ fast_serialization_settings=FastSerializationSettings(
+ enabled=True,
+ destination_dir=destination_dir,
+ distribution_location=native_url,
+ ),
+ )
+
+ options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix)
+
+ # Load all the entities
+ registerable_entities = load_packages_and_modules(
+ serialization_settings, detected_root, list(package_or_module), options
+ )
+ if len(registerable_entities) == 0:
+ display_help_with_error(ctx, "No Flyte entities were detected. Aborting!")
+ cli_logger.info(f"Found and serialized {len(registerable_entities)} entities")
+
+ if not version:
+ version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa
+ cli_logger.info(f"Computed version is {version}")
+
+ # Register using repo code
+ repo_register(registerable_entities, project, domain, version, remote.client)
diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py
index b6f723fe0c..bd7b8ee20b 100644
--- a/flytekit/clis/sdk_in_container/run.py
+++ b/flytekit/clis/sdk_in_container/run.py
@@ -2,6 +2,7 @@
import functools
import importlib
import json
+import logging
import os
import pathlib
import typing
@@ -11,10 +12,12 @@
import click
from dataclasses_json import DataClassJsonMixin
from pytimeparse import parse
+from typing_extensions import get_args
from flytekit import BlobType, Literal, Scalar
-from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT
-from flytekit.configuration import Config, ImageConfig, SerializationSettings
+from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PROJECT
+from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context
+from flytekit.configuration import ImageConfig
from flytekit.configuration.default_images import DefaultImages
from flytekit.core import context_manager, tracker
from flytekit.core.base_task import PythonTask
@@ -22,19 +25,17 @@
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
-from flytekit.loggers import cli_logger
from flytekit.models import literals
from flytekit.models.interface import Variable
from flytekit.models.literals import Blob, BlobMetadata, Primitive
from flytekit.models.types import LiteralType, SimpleType
from flytekit.remote.executions import FlyteWorkflowExecution
-from flytekit.remote.remote import FlyteRemote
from flytekit.tools import module_loader, script_mode
+from flytekit.tools.script_mode import _find_project_root
from flytekit.tools.translator import Options
REMOTE_FLAG_KEY = "remote"
RUN_LEVEL_PARAMS_KEY = "run_level_params"
-FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote"
DATA_PROXY_CALLBACK_KEY = "data_proxy"
@@ -244,6 +245,31 @@ def convert_to_blob(
return lit
+ def convert_to_union(
+ self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any
+ ) -> Literal:
+ lt = self._literal_type
+ for i in range(len(self._literal_type.union_type.variants)):
+ variant = self._literal_type.union_type.variants[i]
+ python_type = get_args(self._python_type)[i]
+ converter = FlyteLiteralConverter(
+ ctx,
+ self._flyte_ctx,
+ variant,
+ python_type,
+ self._create_upload_fn,
+ )
+ try:
+ # Here we use click converter to convert the input in command line to native python type,
+ # and then use flyte converter to convert it to literal.
+ python_val = converter._click_type.convert(value, param, ctx)
+ literal = converter.convert_to_literal(ctx, param, python_val)
+ self._python_type = python_type
+ return literal
+ except (Exception or AttributeError) as e:
+ logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e)
+ raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}")
+
def convert_to_literal(
self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any
) -> Literal:
@@ -255,7 +281,7 @@ def convert_to_literal(
if self._literal_type.collection_type or self._literal_type.map_value_type:
# TODO Does not support nested flytefile, flyteschema types
- v = json.loads(value)
+ v = json.loads(value) if isinstance(value, str) else value
if self._literal_type.collection_type and not isinstance(v, list):
raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(v)}")
if self._literal_type.map_value_type and not isinstance(v, dict):
@@ -263,11 +289,14 @@ def convert_to_literal(
return TypeEngine.to_literal(self._flyte_ctx, v, self._python_type, self._literal_type)
if self._literal_type.union_type:
- raise NotImplementedError("Union type is not yet implemented for pyflyte run")
+ return self.convert_to_union(ctx, param, value)
if self._literal_type.simple or self._literal_type.enum_type:
if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT:
- o = cast(DataClassJsonMixin, self._python_type).from_json(value)
+ if type(value) != self._python_type:
+ o = cast(DataClassJsonMixin, self._python_type).from_json(value)
+ else:
+ o = value
return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type)
return Literal(scalar=self._converter.convert(value, self._python_type))
@@ -396,16 +425,14 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
]
-def load_naive_entity(module_name: str, entity_name: str) -> typing.Union[WorkflowBase, PythonTask]:
+def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]:
"""
Load the workflow of a the script file.
N.B.: it assumes that the file is self-contained, in other words, there are no relative imports.
"""
- flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings(
- SerializationSettings(None)
- )
- with context_manager.FlyteContextManager.with_context(flyte_ctx):
- with module_loader.add_sys_path(os.getcwd()):
+ flyte_ctx_builder = context_manager.FlyteContextManager.current_context().new_builder()
+ with context_manager.FlyteContextManager.with_context(flyte_ctx_builder):
+ with module_loader.add_sys_path(project_root):
importlib.import_module(module_name)
return module_loader.load_object_from_module(f"{module_name}.{entity_name}")
@@ -444,9 +471,7 @@ def get_entities_in_file(filename: str) -> Entities:
"""
Returns a list of flyte workflow names and list of Flyte tasks in a file.
"""
- flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings(
- SerializationSettings(None)
- )
+ flyte_ctx = context_manager.FlyteContextManager.current_context().new_builder()
module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".")
with context_manager.FlyteContextManager.with_context(flyte_ctx):
with module_loader.add_sys_path(os.getcwd()):
@@ -473,6 +498,8 @@ def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow,
"""
def _run(*args, **kwargs):
+ # By the time we get to this function, all the loading has already happened
+
run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY]
project, domain = run_level_params.get("project"), run_level_params.get("domain")
inputs = {}
@@ -486,10 +513,6 @@ def _run(*args, **kwargs):
remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY]
- # StructuredDatasetTransformerEngine.register(
- # PandasToParquetDataProxyEncodingHandler(get_upload_url_fn), default_for_type=True
- # )
-
remote_entity = remote.register_script(
entity,
project=project,
@@ -532,32 +555,41 @@ class WorkflowCommand(click.MultiCommand):
def __init__(self, filename: str, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._filename = filename
+ self._filename = pathlib.Path(filename).resolve()
def list_commands(self, ctx):
entities = get_entities_in_file(self._filename)
return entities.all()
def get_command(self, ctx, exe_entity):
+ """
+ This command uses the filename with which this command was created, and the string name of the entity passed
+ after the Python filename on the command line, to load the Python object, and then return the Command that
+ click should run.
+ :param ctx: The click Context object.
+ :param exe_entity: string of the flyte entity provided by the user. Should be the name of a workflow, or task
+ function.
+ :return:
+ """
+
rel_path = os.path.relpath(self._filename)
if rel_path.startswith(".."):
raise ValueError(
f"You must call pyflyte from the same or parent dir, {self._filename} not under {os.getcwd()}"
)
+ project_root = _find_project_root(self._filename)
+ # Find the relative path for the filename relative to the root of the project.
+ # N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as
+ # a parameter.
+ rel_path = self._filename.relative_to(project_root)
module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".")
- entity = load_naive_entity(module, exe_entity)
+ entity = load_naive_entity(module, exe_entity, project_root)
# If this is a remote execution, which we should know at this point, then create the remote object
p = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT)
d = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_DOMAIN)
- cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE)
- cfg_obj = Config.auto(cfg_file_location)
- cli_logger.info(
- f"Run is using config object {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "")
- )
- r = FlyteRemote(cfg_obj, default_project=p, default_domain=d)
- ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r
+ r = get_and_save_remote_with_click_context(ctx, p, d)
get_upload_url_fn = functools.partial(r.client.get_upload_signed_url, project=p, domain=d)
flyte_ctx = context_manager.FlyteContextManager.current_context()
@@ -596,8 +628,16 @@ def get_command(self, ctx, filename):
return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode")
+_run_help = """
+This command can execute either a workflow or a task from the commandline, for fully self-contained scripts.
+Tasks and workflows cannot be imported from other files currently. Please use `pyflyte package` or
+`pyflyte register` to handle those and then launch from the Flyte UI or `flytectl`
+
+Note: This command only works on regular Python packages, not namespace packages. When determining
+ the root of your project, it finds the first folder that does not have an __init__.py file.
+"""
+
run = RunCommand(
name="run",
- help="Run command: This command can execute either a workflow or a task from the commandline, for "
- "fully self-contained scripts. Tasks and workflows cannot be imported from other files currently.",
+ help=_run_help,
)
diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py
index 188f6ebeee..5354abde4f 100644
--- a/flytekit/configuration/__init__.py
+++ b/flytekit/configuration/__init__.py
@@ -297,6 +297,7 @@ class PlatformConfig(object):
:param endpoint: DNS for Flyte backend
:param insecure: Whether or not to use SSL
+ :param insecure_skip_verify: Wether to skip SSL certificate verification
:param command: This command is executed to return a token using an external process.
:param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment.
More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/.
@@ -309,32 +310,13 @@ class PlatformConfig(object):
endpoint: str = "localhost:30081"
insecure: bool = False
+ insecure_skip_verify: bool = False
command: typing.Optional[typing.List[str]] = None
client_id: typing.Optional[str] = None
client_credentials_secret: typing.Optional[str] = None
scopes: List[str] = field(default_factory=list)
auth_mode: AuthType = AuthType.STANDARD
- def with_parameters(
- self,
- endpoint: str = "localhost:30081",
- insecure: bool = False,
- command: typing.Optional[typing.List[str]] = None,
- client_id: typing.Optional[str] = None,
- client_credentials_secret: typing.Optional[str] = None,
- scopes: List[str] = None,
- auth_mode: AuthType = AuthType.STANDARD,
- ) -> PlatformConfig:
- return PlatformConfig(
- endpoint=endpoint,
- insecure=insecure,
- command=command,
- client_id=client_id,
- client_credentials_secret=client_credentials_secret,
- scopes=scopes if scopes else [],
- auth_mode=auth_mode,
- )
-
@classmethod
def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig:
"""
@@ -345,6 +327,9 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
config_file = get_config_file(config_file)
kwargs = {}
kwargs = set_if_exists(kwargs, "insecure", _internal.Platform.INSECURE.read(config_file))
+ kwargs = set_if_exists(
+ kwargs, "insecure_skip_verify", _internal.Platform.INSECURE_SKIP_VERIFY.read(config_file)
+ )
kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file))
kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file))
kwargs = set_if_exists(
@@ -562,7 +547,7 @@ def for_sandbox(cls) -> Config:
:return: Config
"""
return Config(
- platform=PlatformConfig(insecure=True),
+ platform=PlatformConfig(endpoint="localhost:30081", auth_mode="Pkce", insecure=True),
data_config=DataConfig(
s3=S3Config(endpoint="http://localhost:30084", access_key_id="minio", secret_access_key="miniostorage")
),
diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py
index b329a16112..467f660d42 100644
--- a/flytekit/configuration/file.py
+++ b/flytekit/configuration/file.py
@@ -32,13 +32,16 @@ class LegacyConfigEntry(object):
option: str
type_: typing.Type = str
+ def get_env_name(self):
+ return f"FLYTE_{self.section.upper()}_{self.option.upper()}"
+
def read_from_env(self, transform: typing.Optional[typing.Callable] = None) -> typing.Optional[typing.Any]:
"""
Reads the config entry from environment variable, the structure of the env var is current
``FLYTE_{SECTION}_{OPTION}`` all upper cased. We will change this in the future.
:return:
"""
- env = f"FLYTE_{self.section.upper()}_{self.option.upper()}"
+ env = self.get_env_name()
v = os.environ.get(env, None)
if v is None:
return None
@@ -159,7 +162,7 @@ def __init__(self, location: str):
Load the config from this location
"""
self._location = location
- if location.endswith("yaml"):
+ if location.endswith("yaml") or location.endswith("yml"):
self._legacy_config = None
self._yaml_config = self._read_yaml_config(location)
else:
diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py
index 55018c1caf..fb09015b6f 100644
--- a/flytekit/configuration/internal.py
+++ b/flytekit/configuration/internal.py
@@ -19,7 +19,7 @@ def get_specified_images(cfg: ConfigFile) -> typing.Dict[str, str]:
:returns a dictionary of name: image Version is optional
"""
images: typing.Dict[str, str] = {}
- if cfg is None:
+ if cfg is None or not cfg.legacy_config:
return images
try:
image_names = cfg.legacy_config.options("images")
@@ -105,6 +105,9 @@ class Platform(object):
LegacyConfigEntry(SECTION, "url"), YamlConfigEntry("admin.endpoint"), lambda x: x.replace("dns:///", "")
)
INSECURE = ConfigEntry(LegacyConfigEntry(SECTION, "insecure", bool), YamlConfigEntry("admin.insecure", bool))
+ INSECURE_SKIP_VERIFY = ConfigEntry(
+ LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool)
+ )
class LocalSDK(object):
diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py
index cc3319e359..6aef36305e 100644
--- a/flytekit/core/base_task.py
+++ b/flytekit/core/base_task.py
@@ -463,7 +463,6 @@ def dispatch_execute(
new_user_params = self.pre_execute(ctx.user_space_params)
from flytekit.deck.deck import _output_deck
- new_user_params._decks = [ctx.user_space_params.default_deck]
# Create another execution context with the new user params, but let's keep the same working dir
with FlyteContextManager.with_context(
ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params))
@@ -539,11 +538,9 @@ def dispatch_execute(
for k, v in native_outputs_as_map.items():
output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v)))
- new_user_params.decks.append(input_deck)
- new_user_params.decks.append(output_deck)
-
if _internal.Deck.DISABLE_DECK.read() is not True and self.disable_deck is False:
_output_deck(self.name.split(".")[-1], new_user_params)
+
outputs_literal_map = _literal_models.LiteralMap(literals=literals)
# After the execute has been successfully completed
return outputs_literal_map
diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py
index 2d16b87aab..a17a9148b6 100644
--- a/flytekit/core/context_manager.py
+++ b/flytekit/core/context_manager.py
@@ -21,12 +21,12 @@
import traceback
import typing
from contextlib import contextmanager
+from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Generator, List, Optional, Union
-import flytekit
from flytekit.clients import friendly as friendly_client # noqa
from flytekit.configuration import Config, SecretsConfig, SerializationSettings
from flytekit.core import mock_stats, utils
@@ -38,10 +38,14 @@
from flytekit.loggers import logger, user_space_logger
from flytekit.models.core import identifier as _identifier
+if typing.TYPE_CHECKING:
+ from flytekit.deck.deck import Deck
+
# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin
# Enables static type checking https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING
+flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[])
if typing.TYPE_CHECKING:
from flytekit.core.base_task import TaskResolverMixin
@@ -78,12 +82,13 @@ class Builder(object):
stats: taggable.TaggableStats
execution_date: datetime
logging: _logging.Logger
- execution_id: str
+ execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier]
attrs: typing.Dict[str, typing.Any]
working_dir: typing.Union[os.PathLike, utils.AutoDeletingTempDir]
checkpoint: typing.Optional[Checkpoint]
- decks: List[flytekit.Deck]
+ decks: List[Deck]
raw_output_prefix: str
+ task_id: typing.Optional[_identifier.Identifier]
def __init__(self, current: typing.Optional[ExecutionParameters] = None):
self.stats = current.stats if current else None
@@ -95,6 +100,7 @@ def __init__(self, current: typing.Optional[ExecutionParameters] = None):
self.decks = current._decks if current else []
self.attrs = current._attrs if current else {}
self.raw_output_prefix = current.raw_output_prefix if current else None
+ self.task_id = current.task_id if current else None
def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder:
self.attrs[key] = v
@@ -112,6 +118,7 @@ def build(self) -> ExecutionParameters:
checkpoint=self.checkpoint,
decks=self.decks,
raw_output_prefix=self.raw_output_prefix,
+ task_id=self.task_id,
**self.attrs,
)
@@ -141,11 +148,12 @@ def __init__(
execution_date,
tmp_dir,
stats,
- execution_id,
+ execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier],
logging,
raw_output_prefix,
checkpoint=None,
decks=None,
+ task_id: typing.Optional[_identifier.Identifier] = None,
**kwargs,
):
"""
@@ -171,6 +179,7 @@ def __init__(
self._secrets_manager = SecretsManager()
self._checkpoint = checkpoint
self._decks = decks
+ self._task_id = task_id
@property
def stats(self) -> taggable.TaggableStats:
@@ -228,6 +237,14 @@ def execution_id(self) -> _identifier.WorkflowExecutionIdentifier:
"""
return self._execution_id
+ @property
+ def task_id(self) -> typing.Optional[_identifier.Identifier]:
+ """
+ At production run-time, this will be generated by reading environment variables that are set
+ by the backend.
+ """
+ return self._task_id
+
@property
def secrets(self) -> SecretsManager:
return self._secrets_manager
@@ -246,8 +263,8 @@ def decks(self) -> typing.List:
return self._decks
@property
- def default_deck(self) -> "Deck":
- from flytekit import Deck
+ def default_deck(self) -> Deck:
+ from flytekit.deck.deck import Deck
return Deck("default")
@@ -600,6 +617,11 @@ def current_context() -> Optional[FlyteContext]:
"""
return FlyteContextManager.current_context()
+ def get_deck(self) -> str:
+ from flytekit.deck.deck import _get_deck
+
+ return _get_deck(self.execution_state.user_space_params)
+
@dataclass
class Builder(object):
file_access: FileAccessProvider
@@ -701,8 +723,6 @@ class FlyteContextManager(object):
FlyteContextManager.pop_context()
"""
- _OBJS: typing.List[FlyteContext] = []
-
@staticmethod
def get_origin_stackframe(limit=2) -> traceback.FrameSummary:
ss = traceback.extract_stack(limit=limit + 1)
@@ -711,31 +731,36 @@ def get_origin_stackframe(limit=2) -> traceback.FrameSummary:
return ss[0]
@staticmethod
- def current_context() -> Optional[FlyteContext]:
- if FlyteContextManager._OBJS:
- return FlyteContextManager._OBJS[-1]
- return None
+ def current_context() -> FlyteContext:
+ if not flyte_context_Var.get():
+ # we will lost the default flyte context in the new thread. Therefore, reinitialize the context when running in the new thread.
+ FlyteContextManager.initialize()
+ return flyte_context_Var.get()[-1]
@staticmethod
def push_context(ctx: FlyteContext, f: Optional[traceback.FrameSummary] = None) -> FlyteContext:
if not f:
f = FlyteContextManager.get_origin_stackframe(limit=2)
ctx.set_stackframe(f)
- FlyteContextManager._OBJS.append(ctx)
+ context_list = flyte_context_Var.get()
+ context_list.append(ctx)
+ flyte_context_Var.set(context_list)
t = "\t"
logger.debug(
- f"{t * ctx.level}[{len(FlyteContextManager._OBJS)}] Pushing context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}"
+ f"{t * ctx.level}[{len(flyte_context_Var.get())}] Pushing context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}"
)
return ctx
@staticmethod
def pop_context() -> FlyteContext:
- ctx = FlyteContextManager._OBJS.pop()
+ context_list = flyte_context_Var.get()
+ ctx = context_list.pop()
+ flyte_context_Var.set(context_list)
t = "\t"
logger.debug(
- f"{t * ctx.level}[{len(FlyteContextManager._OBJS) + 1}] Popping context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}"
+ f"{t * ctx.level}[{len(flyte_context_Var.get()) + 1}] Popping context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}"
)
- if len(FlyteContextManager._OBJS) == 0:
+ if len(flyte_context_Var.get()) == 0:
raise AssertionError(f"Illegal Context state! Popped, {ctx}")
return ctx
@@ -766,7 +791,7 @@ def with_context(b: FlyteContext.Builder) -> Generator[FlyteContext, None, None]
@staticmethod
def size() -> int:
- return len(FlyteContextManager._OBJS)
+ return len(flyte_context_Var.get())
@staticmethod
def initialize():
@@ -786,6 +811,7 @@ def initialize():
default_context = FlyteContext(file_access=default_local_file_access_provider)
default_user_space_params = ExecutionParameters(
execution_id=WorkflowExecutionIdentifier.promote_from_model(default_execution_id),
+ task_id=_identifier.Identifier(_identifier.ResourceType.TASK, "local", "local", "local", "local"),
execution_date=_datetime.datetime.utcnow(),
stats=mock_stats.MockStats(),
logging=user_space_logger,
@@ -798,7 +824,7 @@ def initialize():
default_context.new_execution_state().with_params(user_space_params=default_user_space_params)
).build()
default_context.set_stackframe(s=FlyteContextManager.get_origin_stackframe())
- FlyteContextManager._OBJS = [default_context]
+ flyte_context_Var.set([default_context])
class FlyteEntities(object):
diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py
index fdd47d1741..0f651410bf 100644
--- a/flytekit/core/interface.py
+++ b/flytekit/core/interface.py
@@ -7,12 +7,15 @@
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union
+from typing_extensions import get_args, get_origin, get_type_hints
+
from flytekit.core import context_manager
from flytekit.core.docstring import Docstring
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteValidationException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
+from flytekit.models.literals import Void
from flytekit.types.pickle import FlytePickle
T = typing.TypeVar("T")
@@ -182,11 +185,17 @@ def transform_inputs_to_parameters(
inputs_with_def = interface.inputs_with_defaults
for k, v in inputs_vars.items():
val, _default = inputs_with_def[k]
- required = _default is None
- default_lv = None
- if _default is not None:
- default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type)
- params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required)
+ if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val):
+ from flytekit import Literal, Scalar
+
+ literal = Literal(scalar=Scalar(none_type=Void()))
+ params[k] = _interface_models.Parameter(var=v, default=literal, required=False)
+ else:
+ required = _default is None
+ default_lv = None
+ if _default is not None:
+ default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type)
+ params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required)
return _interface_models.ParameterMap(params)
@@ -274,11 +283,8 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
For now the fancy object, maybe in the future a dumb object.
"""
- try:
- # include_extras can only be used in python >= 3.9
- type_hints = typing.get_type_hints(fn, include_extras=True)
- except TypeError:
- type_hints = typing.get_type_hints(fn)
+
+ type_hints = get_type_hints(fn, include_extras=True)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)
@@ -386,7 +392,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
bases = return_annotation.__bases__ # type: ignore
if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"):
logger.debug(f"Task returns named tuple {return_annotation}")
- return dict(typing.get_type_hints(return_annotation))
+ return dict(get_type_hints(return_annotation, include_extras=True))
if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore
# Handle option 3
diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py
index e1f53dc5d2..e8f5bd4aa3 100644
--- a/flytekit/core/launch_plan.py
+++ b/flytekit/core/launch_plan.py
@@ -292,7 +292,6 @@ def get_or_create(
LaunchPlan.CACHE[name or workflow.name] = lp
return lp
- # TODO: Add QoS after it's done
def __init__(
self,
name: str,
@@ -359,6 +358,10 @@ def clone_with(
def python_interface(self) -> Interface:
return self.workflow.python_interface
+ @property
+ def interface(self) -> _interface_models.TypedInterface:
+ return self.workflow.interface
+
@property
def name(self) -> str:
return self._name
diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py
index 6eb4f51d81..1694af6d70 100644
--- a/flytekit/core/node_creation.py
+++ b/flytekit/core/node_creation.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import collections
-from typing import Type, Union
+from typing import TYPE_CHECKING, Type, Union
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext
@@ -12,11 +12,15 @@
from flytekit.exceptions import user as _user_exceptions
from flytekit.loggers import logger
+if TYPE_CHECKING:
+ from flytekit.remote.remote_callable import RemoteEntity
+
+
# This file exists instead of moving to node.py because it needs Task/Workflow/LaunchPlan and those depend on Node
def create_node(
- entity: Union[PythonTask, LaunchPlan, WorkflowBase], *args, **kwargs
+ entity: Union[PythonTask, LaunchPlan, WorkflowBase, RemoteEntity], *args, **kwargs
) -> Union[Node, VoidPromise, Type[collections.namedtuple]]:
"""
This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or
@@ -65,6 +69,8 @@ def sub_wf():
t2(t1_node.o0)
"""
+ from flytekit.remote.remote_callable import RemoteEntity
+
if len(args) > 0:
raise _user_exceptions.FlyteAssertion(
f"Only keyword args are supported to pass inputs to workflows and tasks."
@@ -75,8 +81,9 @@ def sub_wf():
not isinstance(entity, PythonTask)
and not isinstance(entity, WorkflowBase)
and not isinstance(entity, LaunchPlan)
+ and not isinstance(entity, RemoteEntity)
):
- raise AssertionError("Should be but it's not")
+ raise AssertionError(f"Should be a callable Flyte entity (either local or fetched) but is {type(entity)}")
# This function is only called from inside workflows and dynamic tasks.
# That means there are two scenarios we need to take care of, compilation and local workflow execution.
@@ -84,7 +91,6 @@ def sub_wf():
# When compiling, calling the entity will create a node.
ctx = FlyteContext.current_context()
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
-
outputs = entity(**kwargs)
# This is always the output of create_and_link_node which returns create_task_output, which can be
# VoidPromise, Promise, or our custom namedtuple of Promises.
@@ -105,9 +111,11 @@ def sub_wf():
return node
# If a Promise or custom namedtuple of Promises, we need to attach each output as an attribute to the node.
- if entity.python_interface.outputs:
+ # todo: fix the noqas below somehow... can't add abstract property to RemoteEntity because it has to come
+ # before the model Template classes in FlyteTask/Workflow/LaunchPlan
+ if entity.interface.outputs: # noqa
if isinstance(outputs, tuple):
- for output_name in entity.python_interface.output_names:
+ for output_name in entity.interface.outputs.keys(): # noqa
attr = getattr(outputs, output_name)
if attr is None:
raise _user_exceptions.FlyteAssertion(
@@ -120,7 +128,7 @@ def sub_wf():
setattr(node, output_name, attr)
node.outputs[output_name] = attr
else:
- output_names = entity.python_interface.output_names
+ output_names = [k for k in entity.interface.outputs.keys()] # noqa
if len(output_names) != 1:
raise _user_exceptions.FlyteAssertion(f"Output of length 1 expected but {len(output_names)} found")
@@ -136,6 +144,9 @@ def sub_wf():
# Handling local execution
elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
+ if isinstance(entity, RemoteEntity):
+ raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}")
+
if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
logger.warning(f"Manual node creation cannot be used in branch logic {entity.name}")
raise Exception("Being more restrictive for now and disallowing manual node creation in branch logic")
diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py
index 99ead16b2c..4c9150881d 100644
--- a/flytekit/core/promise.py
+++ b/flytekit/core/promise.py
@@ -5,7 +5,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union, cast
-from typing_extensions import Protocol
+from typing_extensions import Protocol, get_args
from flytekit.core import constants as _common_constants
from flytekit.core import context_manager as _flyte_context
@@ -23,6 +23,7 @@
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
+from flytekit.models.types import SimpleType
def translate_inputs_to_literals(
@@ -68,29 +69,43 @@ def extract_value(
) -> _literal_models.Literal:
if isinstance(input_val, list):
- if flyte_literal_type.collection_type is None:
+ lt = flyte_literal_type
+ python_type = val_type
+ if flyte_literal_type.union_type:
+ for i in range(len(flyte_literal_type.union_type.variants)):
+ variant = flyte_literal_type.union_type.variants[i]
+ if variant.collection_type:
+ lt = variant
+ python_type = get_args(val_type)[i]
+ if lt.collection_type is None:
raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}")
try:
- sub_type = ListTransformer.get_sub_type(val_type)
+ sub_type = ListTransformer.get_sub_type(python_type)
except ValueError:
if len(input_val) == 0:
raise
sub_type = type(input_val[0])
- literal_list = [extract_value(ctx, v, sub_type, flyte_literal_type.collection_type) for v in input_val]
+ literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val]
return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list))
elif isinstance(input_val, dict):
- if (
- flyte_literal_type.map_value_type is None
- and flyte_literal_type.simple != _type_models.SimpleType.STRUCT
- ):
- raise TypeError(f"Not a map type {flyte_literal_type} but got a map {input_val}")
- k_type, sub_type = DictTransformer.get_dict_types(val_type) # type: ignore
- if flyte_literal_type.simple == _type_models.SimpleType.STRUCT:
- return TypeEngine.to_literal(ctx, input_val, type(input_val), flyte_literal_type)
+ lt = flyte_literal_type
+ python_type = val_type
+ if flyte_literal_type.union_type:
+ for i in range(len(flyte_literal_type.union_type.variants)):
+ variant = flyte_literal_type.union_type.variants[i]
+ if variant.map_value_type:
+ lt = variant
+ python_type = get_args(val_type)[i]
+ if variant.simple == _type_models.SimpleType.STRUCT:
+ lt = variant
+ python_type = get_args(val_type)[i]
+ if lt.map_value_type is None and lt.simple != _type_models.SimpleType.STRUCT:
+ raise TypeError(f"Not a map type {lt} but got a map {input_val}")
+ if lt.simple == _type_models.SimpleType.STRUCT:
+ return TypeEngine.to_literal(ctx, input_val, type(input_val), lt)
else:
- literal_map = {
- k: extract_value(ctx, v, sub_type, flyte_literal_type.map_value_type) for k, v in input_val.items()
- }
+ k_type, sub_type = DictTransformer.get_dict_types(python_type) # type: ignore
+ literal_map = {k: extract_value(ctx, v, sub_type, lt.map_value_type) for k, v in input_val.items()}
return _literal_models.Literal(map=_literal_models.LiteralMap(literals=literal_map))
elif isinstance(input_val, Promise):
# In the example above, this handles the "in2=a" type of argument
@@ -863,7 +878,7 @@ def create_and_link_node(
ctx: FlyteContext,
entity: SupportsNodeCreation,
**kwargs,
-):
+) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]:
"""
This method is used to generate a node with bindings. This is not used in the execution path.
"""
@@ -881,7 +896,20 @@ def create_and_link_node(
for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
if k not in kwargs:
- raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
+ is_optional = False
+ if var.type.union_type:
+ for variant in var.type.union_type.variants:
+ if variant.simple == SimpleType.NONE:
+ val, _default = interface.inputs_with_defaults[k]
+ if _default is not None:
+ raise ValueError(
+ f"The default value for the optional type must be None, but got {_default}"
+ )
+ is_optional = True
+ if not is_optional:
+ raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
+ else:
+ continue
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
# Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
@@ -999,6 +1027,7 @@ def flyte_entity_call_handler(entity: Union[SupportsNodeCreation], *args, **kwar
ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION)
)
) as child_ctx:
+ cast(FlyteContext, child_ctx).user_space_params._decks = []
result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)
expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs)
diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index 0fff171f36..4b6743afd3 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -156,7 +156,10 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
return self._get_command_fn(settings)
def get_container(self, settings: SerializationSettings) -> _task_model.Container:
- env = {**settings.env, **self.environment} if self.environment else settings.env
+ env = {}
+ for elem in (settings.env, self.environment):
+ if elem:
+ env.update(elem)
return _get_container_definition(
image=get_registerable_container_image(self.container_image, settings.image_config),
command=[],
@@ -253,4 +256,4 @@ def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> st
# fqn will access the fully qualified name of the image (e.g. registry/imagename:version -> registry/imagename)
# version will access the version part of the image (e.g. registry/imagename:version -> version)
# With empty attribute, it'll access the full image path (e.g. registry/imagename:version -> registry/imagename:version)
-_IMAGE_REPLACE_REGEX = re.compile(r"({{\s*\.image[s]?(?:\.([a-zA-Z]+))(?:\.([a-zA-Z]+))?\s*}})", re.IGNORECASE)
+_IMAGE_REPLACE_REGEX = re.compile(r"({{\s*\.image[s]?(?:\.([a-zA-Z0-9_]+))(?:\.([a-zA-Z0-9_]+))?\s*}})", re.IGNORECASE)
diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py
index 0c5fe786ae..7addc89197 100644
--- a/flytekit/core/schedule.py
+++ b/flytekit/core/schedule.py
@@ -15,13 +15,14 @@
# Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass.
class CronSchedule(_schedule_models.Schedule):
"""
- Use this when you have a launch plan that you want to run on a cron expression. The syntax currently used for this
- follows the `AWS convention `__
+ Use this when you have a launch plan that you want to run on a cron expression.
+ This uses standard `cron format `__
+ in case where you are using default native scheduler using the schedule attribute.
.. code-block::
CronSchedule(
- cron_expression="0 10 * * ? *",
+ schedule="*/1 * * * *", # Following schedule runs every min
)
See the :std:ref:`User Guide ` for further examples.
@@ -54,9 +55,10 @@ def __init__(
self, cron_expression: str = None, schedule: str = None, offset: str = None, kickoff_time_input_arg: str = None
):
"""
- :param str cron_expression: This should be a cron expression in AWS style.
+ :param str cron_expression: This should be a cron expression in AWS style.Shouldn't be used in case of native scheduler.
:param str schedule: This takes a cron alias (see ``_VALID_CRON_ALIASES``) or a croniter parseable schedule.
- Only one of this or ``cron_expression`` can be set, not both.
+ Only one of this or ``cron_expression`` can be set, not both. This uses standard `cron format `_
+ and is supported by native scheduler
:param str offset:
:param str kickoff_time_input_arg: This is a convenient argument to use when your code needs to know what time
a run was kicked off. Supply the name of the input argument of your workflow to this argument here. Note
@@ -67,7 +69,7 @@ def __init__(
def my_wf(kickoff_time: datetime): ...
schedule = CronSchedule(
- cron_expression="0 10 * * ? *",
+ schedule="*/1 * * * *"
kickoff_time_input_arg="kickoff_time")
"""
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 60c42ac339..b3851e77ce 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -319,7 +319,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct()))
)
- def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
+ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any:
"""
If any field inside the dataclass is flyte type, we should use flyte type transformer for that field.
"""
@@ -328,36 +328,42 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
from flytekit.types.schema.types import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset
- for f in dataclasses.fields(python_type):
- v = python_val.__getattribute__(f.name)
- field_type = f.type
- if inspect.isclass(field_type) and (
- issubclass(field_type, FlyteSchema)
- or issubclass(field_type, FlyteFile)
- or issubclass(field_type, FlyteDirectory)
- or issubclass(field_type, StructuredDataset)
- ):
- lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None)
- # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
- # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the
- # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here,
- # so that dataclass_json can always get a remote path.
- # In other words, the file transformer has special code that handles the fact that if remote_source is
- # set, then the real uri in the literal should be the remote source, not the path (which may be an
- # auto-generated random local path). To be sure we're writing the right path to the json, use the uri
- # as determined by the transformer.
- if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory):
- python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri))
- elif issubclass(field_type, StructuredDataset):
- python_val.__setattr__(
- f.name,
- field_type(
- uri=lv.scalar.structured_dataset.uri,
- ),
- )
+ if hasattr(python_type, "__origin__") and python_type.__origin__ is list:
+ return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val]
+
+ if hasattr(python_type, "__origin__") and python_type.__origin__ is dict:
+ return {k: self._serialize_flyte_type(v, python_type.__args__[1]) for k, v in python_val.items()}
- elif dataclasses.is_dataclass(field_type):
- self._serialize_flyte_type(v, field_type)
+ if not dataclasses.is_dataclass(python_type):
+ return python_val
+
+ if inspect.isclass(python_type) and (
+ issubclass(python_type, FlyteSchema)
+ or issubclass(python_type, FlyteFile)
+ or issubclass(python_type, FlyteDirectory)
+ or issubclass(python_type, StructuredDataset)
+ ):
+ lv = TypeEngine.to_literal(FlyteContext.current_context(), python_val, python_type, None)
+ # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
+ # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the
+ # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here,
+ # so that dataclass_json can always get a remote path.
+ # In other words, the file transformer has special code that handles the fact that if remote_source is
+ # set, then the real uri in the literal should be the remote source, not the path (which may be an
+ # auto-generated random local path). To be sure we're writing the right path to the json, use the uri
+ # as determined by the transformer.
+ if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory):
+ return python_type(path=lv.scalar.blob.uri)
+ elif issubclass(python_type, StructuredDataset):
+ return python_type(uri=lv.scalar.structured_dataset.uri)
+ else:
+ return python_val
+ else:
+ for v in dataclasses.fields(python_type):
+ val = python_val.__getattribute__(v.name)
+ field_type = v.type
+ python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type))
+ return python_val
def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T:
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
@@ -365,6 +371,12 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine
+ if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list:
+ return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val]
+
+ if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is dict:
+ return {k: self._deserialize_flyte_type(v, expected_python_type.__args__[1]) for k, v in python_val.items()}
+
if not dataclasses.is_dataclass(expected_python_type):
return python_val
@@ -737,11 +749,11 @@ def literal_map_to_kwargs(
"""
Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task
"""
- if len(lm.literals) != len(python_types):
+ if len(lm.literals) > len(python_types):
raise ValueError(
f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}"
)
- return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()}
+ return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()}
@classmethod
def dict_to_literal_map(
@@ -993,7 +1005,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
# Should really never happen, sanity check
raise TypeError("Ambiguous choice of variant for union type")
found_res = True
- except TypeTransformerFailedError as e:
+ except (TypeTransformerFailedError, AttributeError) as e:
logger.debug(f"Failed to convert from {python_val} to {t}", e)
continue
@@ -1047,7 +1059,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
)
res_tag = trans.name
found_res = True
- except TypeTransformerFailedError as e:
+ except (TypeTransformerFailedError, AttributeError) as e:
logger.debug(f"Failed to convert from {lv} to {v}", e)
if found_res:
diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py
index 435416e0ff..599b886ab0 100644
--- a/flytekit/deck/deck.py
+++ b/flytekit/deck/deck.py
@@ -1,9 +1,13 @@
import os
-from typing import Dict, Optional
+from typing import Optional
from jinja2 import Environment, FileSystemLoader
-from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager
+from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager
+from flytekit.loggers import logger
+
+OUTPUT_DIR_JUPYTER_PREFIX = "jupyter"
+DECK_FILE_NAME = "deck.html"
class Deck:
@@ -69,30 +73,43 @@ def html(self) -> str:
return self._html
-def _output_deck(task_name: str, new_user_params: ExecutionParameters):
- deck_map: Dict[str, str] = {}
- decks = new_user_params.decks
- ctx = FlyteContext.current_context()
+def _ipython_check() -> bool:
+ """
+ Check if interface is launching from iPython (not colab)
+ :return is_ipython (bool): True or False
+ """
+ is_ipython = False
+ try: # Check if running interactively using ipython.
+ from IPython import get_ipython
- # TODO: upload deck file to remote filesystems (s3, gcs)
- output_dir = ctx.file_access.get_random_local_directory()
+ if get_ipython() is not None:
+ is_ipython = True
+ except (ImportError, NameError):
+ pass
+ return is_ipython
- for deck in decks:
- _deck_to_html_file(deck, deck_map, output_dir)
- root = os.path.dirname(os.path.abspath(__file__))
- templates_dir = os.path.join(root, "html")
- env = Environment(loader=FileSystemLoader(templates_dir))
- template = env.get_template("template.html")
+def _get_deck(new_user_params: ExecutionParameters) -> str:
+ """
+ Get flyte deck html string
+ """
+ deck_map = {deck.name: deck.html for deck in new_user_params.decks}
+ return template.render(metadata=deck_map)
+
- deck_path = os.path.join(output_dir, "deck.html")
+def _output_deck(task_name: str, new_user_params: ExecutionParameters):
+ ctx = FlyteContext.current_context()
+ if ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
+ output_dir = ctx.execution_state.engine_dir
+ else:
+ output_dir = ctx.file_access.get_random_local_directory()
+ deck_path = os.path.join(output_dir, DECK_FILE_NAME)
with open(deck_path, "w") as f:
- f.write(template.render(metadata=deck_map))
+ f.write(_get_deck(new_user_params))
+ logger.info(f"{task_name} task creates flyte deck html to file://{deck_path}")
-def _deck_to_html_file(deck: Deck, deck_map: Dict[str, str], output_dir: str):
- file_name = deck.name + ".html"
- path = os.path.join(output_dir, file_name)
- with open(path, "w") as output:
- deck_map[deck.name] = file_name
- output.write(deck.html)
+root = os.path.dirname(os.path.abspath(__file__))
+templates_dir = os.path.join(root, "html")
+env = Environment(loader=FileSystemLoader(templates_dir))
+template = env.get_template("template.html")
diff --git a/flytekit/deck/html/template.html b/flytekit/deck/html/template.html
index 1a429a2ab6..3992ab9c0f 100644
--- a/flytekit/deck/html/template.html
+++ b/flytekit/deck/html/template.html
@@ -4,56 +4,12 @@
User Content
+
-
-
-