From fe0467c66949bb86c602c930485bca9f1a109c7b Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 22 May 2024 18:11:39 +0530 Subject: [PATCH] add openai batch API agent (#2353) * add openai batch endpoint agent Signed-off-by: Samhita Alla * add jsonl type transformer, modify agent Signed-off-by: Samhita Alla * revert format changes Signed-off-by: Samhita Alla * revert iterator edits Signed-off-by: Samhita Alla * typealias Signed-off-by: Samhita Alla * add jsonlines Signed-off-by: Samhita Alla * remove typealias Signed-off-by: Samhita Alla * update readthedocs python version Signed-off-by: Samhita Alla * update docs python version Signed-off-by: Samhita Alla * replace dict with enum Signed-off-by: Samhita Alla * modify json type; add a check to validate if iterator's empty Signed-off-by: Samhita Alla * ignore mypy check Signed-off-by: Samhita Alla * modify JSON type Signed-off-by: Samhita Alla * move to openai folder Signed-off-by: Samhita Alla * update readme Signed-off-by: Samhita Alla * update setup.py Signed-off-by: Samhita Alla * batch_api to batch, add plugin to setup.py Signed-off-by: Samhita Alla * modify return type Signed-off-by: Samhita Alla * modify return type Signed-off-by: Samhita Alla * fix lint error Signed-off-by: Samhita Alla * remove guess_python_type Signed-off-by: Samhita Alla * modify json type Signed-off-by: Samhita Alla * modify tests Signed-off-by: Samhita Alla * json to iterator[json] Signed-off-by: Samhita Alla * update plugin readme Signed-off-by: Samhita Alla * replace flytefile with jsonlfile Signed-off-by: Samhita Alla * modify output type of batch; add typealias Signed-off-by: Samhita Alla * nit Signed-off-by: Samhita Alla * nit Signed-off-by: Samhita Alla * replace openai-batch-endpoint with openai-batch Signed-off-by: Samhita Alla * update readme Signed-off-by: Samhita Alla * revert to secrets Signed-off-by: Samhita Alla * fix openai batch code; add json iterator click type Signed-off-by: Samhita Alla * update readme Signed-off-by: Samhita Alla * fix types Signed-off-by: Samhita Alla * add shim tasks Signed-off-by: Samhita Alla * dict to dataclass; add container image; add guess_python_type to json iterator Signed-off-by: Samhita Alla * json iterator check in type engine and flytefile Signed-off-by: Samhita Alla * update image version Signed-off-by: Samhita Alla * add secret; update json iterator Signed-off-by: Samhita Alla * add secret to shim task init method Signed-off-by: Samhita Alla * fix secret Signed-off-by: Samhita Alla * fix secret Signed-off-by: Samhita Alla * add secret to dict Signed-off-by: Samhita Alla * fix logging error; remove iterator copy; remove in flyte entity names Signed-off-by: Samhita Alla * openai tests Signed-off-by: Samhita Alla * lint and remove auto spec in openai tests Signed-off-by: Samhita Alla * fix test Signed-off-by: Samhita Alla * json key check Signed-off-by: Samhita Alla * modify input type of upload json data to jsonlfile only Signed-off-by: Samhita Alla * add jsonl to mime type Signed-off-by: Samhita Alla * change mime type Signed-off-by: Samhita Alla * change mime type and fix tests Signed-off-by: Samhita Alla * add json mime type Signed-off-by: Samhita Alla * lint Signed-off-by: Samhita Alla * lint Signed-off-by: Samhita Alla * incorporate kevin's suggestion Signed-off-by: Samhita Alla * requests 2.32.2 doesn't work either Signed-off-by: Samhita Alla --------- Signed-off-by: Samhita Alla --- .github/workflows/pythonpublish.yml | 23 ++ dev-requirements.txt | 102 ++------- docs/source/plugins/index.rst | 2 + docs/source/plugins/openai.rst | 12 ++ docs/source/types.builtins.iterator.rst | 4 + docs/source/types.extend.rst | 1 + flytekit/core/type_engine.py | 25 ++- flytekit/interaction/click_types.py | 12 ++ flytekit/types/file/__init__.py | 6 + flytekit/types/file/file.py | 8 +- flytekit/types/iterator/__init__.py | 5 +- flytekit/types/iterator/json_iterator.py | 112 ++++++++++ plugins/README.md | 4 +- plugins/flytekit-openai/Dockerfile.batch | 16 ++ plugins/flytekit-openai/README.md | 84 +++++++- plugins/flytekit-openai/dev-requirements.txt | 1 + .../flytekitplugins/openai/__init__.py | 13 +- .../flytekitplugins/openai/batch/__init__.py | 0 .../flytekitplugins/openai/batch/agent.py | 129 ++++++++++++ .../flytekitplugins/openai/batch/task.py | 198 ++++++++++++++++++ .../flytekitplugins/openai/batch/workflow.py | 69 ++++++ plugins/flytekit-openai/setup.py | 6 +- .../flytekit-openai/tests/chatgpt/__init__.py | 0 .../tests/{ => chatgpt}/test_agent.py | 0 .../tests/{ => chatgpt}/test_chatgpt.py | 0 .../tests/openai_batch/__init__.py | 0 .../tests/openai_batch/data.jsonl | 2 + .../tests/openai_batch/test_agent.py | 181 ++++++++++++++++ .../tests/openai_batch/test_task.py | 141 +++++++++++++ .../tests/openai_batch/test_workflow.py | 15 ++ plugins/setup.py | 1 + pyproject.toml | 3 +- tests/flytekit/unit/types/iterator/data.jsonl | 3 + .../unit/types/iterator/test_json_iterator.py | 137 ++++++++++++ 34 files changed, 1209 insertions(+), 106 deletions(-) create mode 100644 docs/source/plugins/openai.rst create mode 100644 docs/source/types.builtins.iterator.rst create mode 100644 flytekit/types/iterator/json_iterator.py create mode 100644 plugins/flytekit-openai/Dockerfile.batch create mode 100644 plugins/flytekit-openai/dev-requirements.txt create mode 100644 plugins/flytekit-openai/flytekitplugins/openai/batch/__init__.py create mode 100644 plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py create mode 100644 plugins/flytekit-openai/flytekitplugins/openai/batch/task.py create mode 100644 plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py create mode 100644 plugins/flytekit-openai/tests/chatgpt/__init__.py rename plugins/flytekit-openai/tests/{ => chatgpt}/test_agent.py (100%) rename plugins/flytekit-openai/tests/{ => chatgpt}/test_chatgpt.py (100%) create mode 100644 plugins/flytekit-openai/tests/openai_batch/__init__.py create mode 100644 plugins/flytekit-openai/tests/openai_batch/data.jsonl create mode 100644 plugins/flytekit-openai/tests/openai_batch/test_agent.py create mode 100644 plugins/flytekit-openai/tests/openai_batch/test_task.py create mode 100644 plugins/flytekit-openai/tests/openai_batch/test_workflow.py create mode 100644 tests/flytekit/unit/types/iterator/data.jsonl create mode 100644 tests/flytekit/unit/types/iterator/test_json_iterator.py diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 35b098524a..2b4ba6c0d1 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -147,6 +147,29 @@ jobs: file: ./plugins/flytekit-sqlalchemy/Dockerfile cache-from: type=gha cache-to: type=gha,mode=max + - name: Prepare OpenAI Batch Image Names + id: openai-batch-names + uses: docker/metadata-action@v3 + with: + images: | + ghcr.io/${{ github.repository_owner }}/flytekit + tags: | + py${{ matrix.python-version }}-openai-batch-latest + py${{ matrix.python-version }}-openai-batch-${{ github.sha }} + py${{ matrix.python-version }}-openai-batch-${{ needs.deploy.outputs.version }} + - name: Push OpenAI Batch Image to GitHub Registry + uses: docker/build-push-action@v2 + with: + context: "./plugins/flytekit-openai/" + platforms: linux/arm64, linux/amd64 + push: ${{ github.event_name == 'release' }} + tags: ${{ steps.openai-batch-names.outputs.tags }} + build-args: | + VERSION=${{ needs.deploy.outputs.version }} + PYTHON_VERSION=${{ matrix.python-version }} + file: ./plugins/flytekit-openai/Dockerfile.batch + cache-from: type=gha + cache-to: type=gha,mode=max build-and-push-flyteagent-images: runs-on: ubuntu-latest diff --git a/dev-requirements.txt b/dev-requirements.txt index 301b6eae27..28fab54d4e 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,15 +1,11 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile dev-requirements.in # -e file:.#egg=flytekit # via -r dev-requirements.in -absl-py==2.1.0 - # via - # tensorboard - # tensorflow-macos adlfs==2023.9.0 # via flytekit aiobotocore==2.5.4 @@ -26,12 +22,11 @@ aiosignal==1.3.1 # via aiohttp asttokens==2.4.1 # via stack-data -astunparse==1.6.3 - # via tensorflow-macos attrs==23.2.0 # via # aiohttp # hypothesis + # jsonlines autoflake==2.2.1 # via -r dev-requirements.in azure-core==1.30.0 @@ -100,11 +95,7 @@ execnet==2.0.2 executing==2.0.1 # via stack-data filelock==3.13.1 - # via - # torch - # virtualenv -flatbuffers==23.5.26 - # via tensorflow-macos + # via virtualenv flyteidl @ git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl # via # -r dev-requirements.in @@ -119,9 +110,6 @@ fsspec==2023.9.2 # flytekit # gcsfs # s3fs - # torch -gast==0.5.4 - # via tensorflow-macos gcsfs==2023.9.2 # via flytekit google-api-core[grpc]==2.16.2 @@ -138,11 +126,8 @@ google-auth==2.27.0 # google-cloud-core # google-cloud-storage # kubernetes - # tensorboard google-auth-oauthlib==1.2.0 - # via - # gcsfs - # tensorboard + # via gcsfs google-cloud-bigquery==3.17.1 # via -r dev-requirements.in google-cloud-bigquery-storage==2.24.0 @@ -157,8 +142,6 @@ google-crc32c==1.5.0 # via # google-cloud-storage # google-resumable-media -google-pasta==0.2.0 - # via tensorflow-macos google-resumable-media==2.7.0 # via # google-cloud-bigquery @@ -175,14 +158,10 @@ grpcio==1.60.1 # flytekit # google-api-core # grpcio-status - # tensorboard - # tensorflow-macos grpcio-status==1.60.1 # via # flytekit # google-api-core -h5py==3.10.0 - # via tensorflow-macos hypothesis==6.98.2 # via -r dev-requirements.in icdiff==2.0.7 @@ -194,9 +173,7 @@ idna==3.6 # requests # yarl importlib-metadata==7.0.1 - # via - # flytekit - # keyring + # via flytekit iniconfig==2.0.0 # via pytest ipython==8.21.0 @@ -212,9 +189,7 @@ jaraco-classes==3.3.0 jedi==0.19.1 # via ipython jinja2==3.1.3 - # via - # -r dev-requirements.in - # torch + # via -r dev-requirements.in jmespath==1.0.1 # via botocore joblib==1.3.2 @@ -222,28 +197,22 @@ joblib==1.3.2 # -r dev-requirements.in # flytekit # scikit-learn +jsonlines==4.0.0 + # via flytekit jsonpickle==3.0.2 # via flytekit -keras==2.15.0 - # via tensorflow-macos keyring==24.3.0 # via flytekit keyrings-alt==5.0.0 # via -r dev-requirements.in kubernetes==29.0.0 # via -r dev-requirements.in -libclang==16.0.6 - # via tensorflow-macos -markdown==3.5.2 - # via tensorboard markdown-it-py==3.0.0 # via # flytekit # rich markupsafe==2.1.5 - # via - # jinja2 - # werkzeug + # via jinja2 marshmallow==3.20.2 # via # dataclasses-json @@ -261,14 +230,10 @@ matplotlib-inline==0.1.6 # via ipython mdurl==0.1.2 # via markdown-it-py -ml-dtypes==0.2.0 - # via tensorflow-macos mock==5.1.0 # via -r dev-requirements.in more-itertools==10.2.0 # via jaraco-classes -mpmath==1.3.0 - # via sympy msal==1.26.0 # via # azure-datalake-store @@ -286,28 +251,19 @@ mypy-extensions==1.0.0 # via # mypy # typing-inspect -networkx==3.2.1 - # via torch nodeenv==1.8.0 # via pre-commit numpy==1.26.4 # via # -r dev-requirements.in - # h5py - # ml-dtypes - # opt-einsum # pandas # pyarrow # scikit-learn # scipy - # tensorboard - # tensorflow-macos oauthlib==3.2.2 # via # kubernetes # requests-oauthlib -opt-einsum==3.3.0 - # via tensorflow-macos orjson==3.9.12 # via -r dev-requirements.in packaging==23.2 @@ -318,7 +274,6 @@ packaging==23.2 # msal-extensions # pytest # setuptools-scm - # tensorflow-macos pandas==2.2.0 # via -r dev-requirements.in parso==0.8.3 @@ -353,8 +308,6 @@ protobuf==4.23.4 # grpcio-status # proto-plus # protoc-gen-openapiv2 - # tensorboard - # tensorflow-macos protoc-gen-openapiv2==0.0.1 # via flyteidl ptyprocess==0.7.0 @@ -379,7 +332,9 @@ pygments==2.17.2 # ipython # rich pyjwt[crypto]==2.8.0 - # via msal + # via + # msal + # pyjwt pytest==7.4.4 # via # -r dev-requirements.in @@ -436,7 +391,6 @@ requests==2.31.0 # kubernetes # msal # requests-oauthlib - # tensorboard requests-oauthlib==1.3.1 # via # google-auth-oauthlib @@ -460,40 +414,18 @@ setuptools-scm==8.0.4 six==1.16.0 # via # asttokens - # astunparse # azure-core - # google-pasta # isodate # kubernetes # python-dateutil - # tensorboard - # tensorflow-macos sortedcontainers==2.4.0 # via hypothesis stack-data==0.6.3 # via ipython statsd==3.3.0 # via flytekit -sympy==1.12 - # via torch -tensorboard==2.15.1 - # via tensorflow-macos -tensorboard-data-server==0.7.2 - # via tensorboard -tensorflow==2.15.0 ; python_version < "3.12" - # via -r dev-requirements.in -tensorflow-estimator==2.15.0 - # via tensorflow-macos -tensorflow-io-gcs-filesystem==0.34.0 - # via tensorflow-macos -tensorflow-macos==2.15.0 - # via tensorflow -termcolor==2.4.0 - # via tensorflow-macos threadpoolctl==3.2.0 # via scikit-learn -torch==2.2.0 ; python_version < "3.12" - # via -r dev-requirements.in traitlets==5.14.1 # via # ipython @@ -519,8 +451,6 @@ typing-extensions==4.9.0 # mypy # rich-click # setuptools-scm - # tensorflow-macos - # torch # typing-inspect typing-inspect==0.9.0 # via dataclasses-json @@ -541,14 +471,8 @@ websocket-client==1.7.0 # via # docker # kubernetes -werkzeug==3.0.1 - # via tensorboard -wheel==0.42.0 - # via astunparse wrapt==1.14.1 - # via - # aiobotocore - # tensorflow-macos + # via aiobotocore yarl==1.9.4 # via aiohttp zipp==3.17.0 diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst index c2f6599e03..40e5d00ff9 100644 --- a/docs/source/plugins/index.rst +++ b/docs/source/plugins/index.rst @@ -31,6 +31,7 @@ Plugin API reference * :ref:`MLflow ` - MLflow API reference * :ref:`DuckDB ` - DuckDB API reference * :ref:`SageMaker Inference ` - SageMaker Inference API reference +* :ref:`OpenAI ` - OpenAI API reference .. toctree:: :maxdepth: 2 @@ -63,3 +64,4 @@ Plugin API reference MLflow DuckDB SageMaker Inference + OpenAI diff --git a/docs/source/plugins/openai.rst b/docs/source/plugins/openai.rst new file mode 100644 index 0000000000..169529c922 --- /dev/null +++ b/docs/source/plugins/openai.rst @@ -0,0 +1,12 @@ +.. _openai: + +################ +OpenAI reference +################ + +.. tags:: Integration, OpenAI + +.. automodule:: flytekitplugins.openai + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/types.builtins.iterator.rst b/docs/source/types.builtins.iterator.rst new file mode 100644 index 0000000000..560c13dd5f --- /dev/null +++ b/docs/source/types.builtins.iterator.rst @@ -0,0 +1,4 @@ +.. automodule:: flytekit.types.iterator + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst index db1cb8dfff..9848c3d4e2 100644 --- a/docs/source/types.extend.rst +++ b/docs/source/types.extend.rst @@ -14,6 +14,7 @@ Refer to the :ref:`extensibility contribution guide TypeTransformer[T]: # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used. return cls._ENUM_TRANSFORMER + from flytekit.types.iterator.json_iterator import JSONIterator + for base_type in cls._REGISTRY.keys(): if base_type is None: continue # None is actually one of the keys, but isinstance/issubclass doesn't work on it try: - if isinstance(python_type, base_type) or ( - inspect.isclass(python_type) and issubclass(python_type, base_type) + origin_type: Optional[typing.Any] = base_type + if hasattr(base_type, "__args__"): + origin_base_type = get_origin(base_type) + if isinstance(origin_base_type, type) and issubclass( + origin_base_type, typing.Iterator + ): # Iterator[JSON] + origin_type = origin_base_type + + if isinstance(python_type, origin_type) or ( # type: ignore[arg-type] + inspect.isclass(python_type) and issubclass(python_type, origin_type) # type: ignore[arg-type] ): + # Consider Iterator[JSON] but not vanilla Iterator when the value is a JSON iterator. + if ( + isinstance(python_type, type) + and issubclass(python_type, JSONIterator) + and not get_args(base_type) + ): + continue return cls._REGISTRY[base_type] except TypeError: # As of python 3.9, calls to isinstance raise a TypeError if the base type is not a valid type, which @@ -1579,8 +1596,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp if found_res: is_ambiguous = True found_res = True - except Exception as e: - logger.debug(f"Failed to convert from {python_val} to {t}", e) + except Exception: + logger.debug(f"Failed to convert from {python_val} to {t}", exc_info=True) continue if is_ambiguous: diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 85ce1c7d93..2c26ed0cbc 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -21,6 +21,7 @@ from flytekit.remote.remote_fs import FlytePathResolver from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile +from flytekit.types.iterator.json_iterator import JSONIteratorTransformer from flytekit.types.pickle.pickle import FlytePickleTransformer @@ -129,6 +130,15 @@ def convert( return FlyteFile(path=str(pathlib.Path(uri).resolve()), remote_path=remote_path) +class JSONIteratorParamType(click.ParamType): + name = "json iterator" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + return value + + class DateTimeType(click.DateTime): _NOW_FMT = "now" _ADDITONAL_FORMATS = [_NOW_FMT] @@ -332,6 +342,8 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli if lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE: if lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT: return PickleParamType() + elif lt.blob.format == JSONIteratorTransformer.JSON_ITERATOR_FORMAT: + return JSONIteratorParamType() return FileParamType() return DirParamType() diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py index 8a2fe50b6c..838516f33d 100644 --- a/flytekit/types/file/__init__.py +++ b/flytekit/types/file/__init__.py @@ -20,6 +20,7 @@ PythonNotebook SVGImageFile """ + import typing from typing_extensions import Annotated, get_args, get_origin @@ -114,3 +115,8 @@ def check_and_convert_to_str(item: typing.Union[typing.Type, str]) -> str: #: Can be used to receive or return an TFRecordFile. The underlying type is a FlyteFile type. This is just a #: decoration and useful for attaching content type information with the file and automatically documenting code. TFRecordFile = FlyteFile[tfrecords_file] + +jsonl_file = Annotated[str, FileExt("jsonl")] +#: Can be used to receive or return a JSONLFile. The underlying type is a FlyteFile type. This is just a +#: decoration and useful for attaching content type information with the file and automatically documenting code. +JSONLFile = FlyteFile[jsonl_file] diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 5c47bda998..2995bd82f7 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -341,7 +341,7 @@ def assert_type( def get_literal_type(self, t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> LiteralType: return LiteralType(blob=self._blob_type(format=FlyteFilePathTransformer.get_format(t))) - def get_mime_type_from_extension(self, extension: str) -> str: + def get_mime_type_from_extension(self, extension: str) -> typing.Union[str, typing.Sequence[str]]: extension_to_mime_type = { "hdf5": "text/plain", "joblib": "application/octet-stream", @@ -349,6 +349,7 @@ def get_mime_type_from_extension(self, extension: str) -> str: "ipynb": "application/json", "onnx": "application/json", "tfrecord": "application/octet-stream", + "jsonl": ["application/json", "application/x-ndjson"], } for ext, mimetype in mimetypes.types_map.items(): @@ -389,7 +390,7 @@ def validate_file_type( if FlyteFilePathTransformer.get_format(python_type): real_type = magic.from_file(source_path, mime=True) expected_type = self.get_mime_type_from_extension(FlyteFilePathTransformer.get_format(python_type)) - if real_type != expected_type: + if real_type not in expected_type: raise ValueError(f"Incorrect file type, expected {expected_type}, got {real_type}") def to_literal( @@ -525,10 +526,13 @@ def _downloader(): return ff def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: + from flytekit.types.iterator.json_iterator import JSONIteratorTransformer + if ( literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE and literal_type.blob.format != FlytePickleTransformer.PYTHON_PICKLE_FORMAT + and literal_type.blob.format != JSONIteratorTransformer.JSON_ITERATOR_FORMAT ): return FlyteFile.__class_getitem__(literal_type.blob.format) diff --git a/flytekit/types/iterator/__init__.py b/flytekit/types/iterator/__init__.py index 998a84b422..3c1911394f 100644 --- a/flytekit/types/iterator/__init__.py +++ b/flytekit/types/iterator/__init__.py @@ -1,6 +1,7 @@ """ Flytekit Iterator Type -========================================================== +====================== + .. currentmodule:: flytekit.types.iterator .. autosummary:: @@ -8,6 +9,8 @@ :toctree: generated/ FlyteIterator + JSON """ from .iterator import FlyteIterator +from .json_iterator import JSON diff --git a/flytekit/types/iterator/json_iterator.py b/flytekit/types/iterator/json_iterator.py new file mode 100644 index 0000000000..043c050f43 --- /dev/null +++ b/flytekit/types/iterator/json_iterator.py @@ -0,0 +1,112 @@ +from pathlib import Path +from typing import Any, Dict, Iterator, List, Type, Union + +import jsonlines +from typing_extensions import TypeAlias + +from flytekit import FlyteContext, Literal, LiteralType +from flytekit.core.type_engine import ( + TypeEngine, + TypeTransformer, + TypeTransformerFailedError, +) +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Scalar + +JSONCollection: TypeAlias = Union[Dict[str, Any], List[Any]] +JSONScalar: TypeAlias = Union[bool, float, int, str] +JSON: TypeAlias = Union[JSONCollection, JSONScalar] + + +class JSONIterator: + def __init__(self, reader: jsonlines.Reader): + self._reader = reader + self._reader_iter = reader.iter() + + def __iter__(self): + return self + + def __next__(self): + try: + return next(self._reader_iter) + except StopIteration: + self._reader.close() + raise StopIteration("File handler is exhausted") + + +class JSONIteratorTransformer(TypeTransformer[Iterator[JSON]]): + """ + A JSON iterator that handles conversion between an iterator/generator and a JSONL file. + """ + + JSON_ITERATOR_FORMAT = "JSONL" + + def __init__(self): + super().__init__("JSON Iterator", Iterator[JSON]) + + def get_literal_type(self, t: Type[Iterator[JSON]]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.JSON_ITERATOR_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: Iterator[JSON], + python_type: Type[Iterator[JSON]], + expected: LiteralType, + ) -> Literal: + local_dir = Path(ctx.file_access.get_random_local_directory()) + local_dir.mkdir(exist_ok=True) + local_path = ctx.file_access.get_random_local_path() + uri = str(Path(local_dir) / local_path) + + empty = True + with open(uri, "w") as fp: + with jsonlines.Writer(fp) as writer: + for json_val in python_val: + writer.write(json_val) + empty = False + + if empty: + raise ValueError("The iterator is empty.") + + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.JSON_ITERATOR_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=ctx.file_access.put_raw_data(uri)))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[Iterator[JSON]] + ) -> JSONIterator: + try: + uri = lv.scalar.blob.uri + except AttributeError: + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + fs = ctx.file_access.get_filesystem_for_path(uri) + + fp = fs.open(uri, "r") + reader = jsonlines.Reader(fp) + + return JSONIterator(reader) + + def guess_python_type(self, literal_type: LiteralType) -> Type[Iterator[JSON]]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.JSON_ITERATOR_FORMAT + ): + return Iterator[JSON] # type: ignore + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}.") + + +TypeEngine.register(JSONIteratorTransformer()) diff --git a/plugins/README.md b/plugins/README.md index 81d3ad9530..3eb4fae30c 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -6,7 +6,7 @@ All the Flytekit plugins maintained by the core team are added here. It is not n | Plugin | Installation | Description | Version | Type | | ---------------------------- | ----------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------- | -| AWS SageMaker | `bash pip install flytekitplugins-awssagemaker` | Deploy SageMaker models and manage inference endpoints with ease. | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Python | +| AWS SageMaker | `bash pip install flytekitplugins-awssagemaker` | Deploy SageMaker models and manage inference endpoints with ease. | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Flytekit-only | | dask | `bash pip install flytekitplugins-dask ` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dask.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | | Hive Queries | `bash pip install flytekitplugins-hive ` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | | K8s distributed PyTorch Jobs | `bash pip install flytekitplugins-kfpytorch ` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | @@ -24,6 +24,8 @@ All the Flytekit plugins maintained by the core team are added here. It is not n | dbt | `bash pip install flytekitplugins-dbt` | Run dbt within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dbt.svg)](https://pypi.python.org/pypi/flytekitplugins-dbt/) | Flytekit-only | | Huggingface | `bash pip install flytekitplugins-huggingface` | Read & write Hugginface Datasets as Flyte StructuredDatasets | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-huggingface.svg)](https://pypi.python.org/pypi/flytekitplugins-huggingface/) | Flytekit-only | | DuckDB | `bash pip install flytekitplugins-duckdb` | Run analytical workloads with ease using DuckDB | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-duckdb.svg)](https://pypi.python.org/pypi/flytekitplugins-duckdb/) | Flytekit-only | +| ChatGPT | `bash pip install flytekitplugins-openai` | Interact with OpenAI's ChatGPT. | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-openai.svg)](https://pypi.python.org/pypi/flytekitplugins-openai/) | Flytekit-only | +| OpenAI Batch | `bash pip install flytekitplugins-openai` | Submit requests to OpenAI for asynchronous batch processing. | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-openai.svg)](https://pypi.python.org/pypi/flytekitplugins-openai/) | Flytekit-only | ## Have a Plugin Idea? 💡 diff --git a/plugins/flytekit-openai/Dockerfile.batch b/plugins/flytekit-openai/Dockerfile.batch new file mode 100644 index 0000000000..2174a82543 --- /dev/null +++ b/plugins/flytekit-openai/Dockerfile.batch @@ -0,0 +1,16 @@ +ARG PYTHON_VERSION +FROM python:${PYTHON_VERSION}-slim-bookworm + +WORKDIR /root +ENV LANG C.UTF-8 +ENV LC_ALL C.UTF-8 +ENV PYTHONPATH /root + +ARG VERSION + +RUN pip install flytekitplugins-openai==$VERSION \ + flytekit==$VERSION + +RUN useradd -u 1000 flytekit +RUN chown flytekit: /root +USER flytekit diff --git a/plugins/flytekit-openai/README.md b/plugins/flytekit-openai/README.md index f93b634735..48ca3c10ef 100644 --- a/plugins/flytekit-openai/README.md +++ b/plugins/flytekit-openai/README.md @@ -1,7 +1,17 @@ -# Flytekit ChatGPT Plugin -ChatGPT plugin allows you to run ChatGPT tasks in the Flyte workflow without changing any code. +# OpenAI Plugins + +The plugin currently features ChatGPT and Batch API agents. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-openai +``` + +## ChatGPT + +The ChatGPT plugin allows you to run ChatGPT tasks within the Flyte workflow without requiring any code changes. -## Example ```python from flytekit import task, workflow from flytekitplugins.openai import ChatGPTTask, ChatGPTConfig @@ -36,9 +46,71 @@ if __name__ == "__main__": print(wf(message="hi")) ``` +## Batch API -To install the plugin, run the following command: +The Batch API agent allows you to submit requests for asynchronous batch processing on OpenAI. +You can provide either a JSONL file or a JSON iterator, and the agent handles the upload to OpenAI, +creation of the batch, and downloading of the output and error files. -```bash -pip install flytekitplugins-openai +```python +from typing import Iterator + +from flytekit import workflow, Secret +from flytekit.types.file import JSONLFile +from flytekit.types.iterator import JSON +from flytekitplugins.openai import create_batch, BatchResult + + +def jsons(): + for x in [ + { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + ], + }, + }, + { + "custom_id": "request-2", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + ], + }, + }, + ]: + yield x + + +it_batch = create_batch( + name="gpt-3.5-turbo", + openai_organization="your-org", + secret=Secret(group="openai-secret", key="api-key"), +) + +file_batch = create_batch( + name="gpt-3.5-turbo", + openai_organization="your-org", + secret=Secret(group="openai-secret", key="api-key"), + is_json_iterator=False, +) + + +@workflow +def json_iterator_wf(json_vals: Iterator[JSON] = jsons()) -> BatchResult: + return it_batch(jsonl_in=json_vals) + + +@workflow +def jsonl_wf(jsonl_file: JSONLFile = "data.jsonl") -> BatchResult: + return file_batch(jsonl_in=jsonl_file) ``` diff --git a/plugins/flytekit-openai/dev-requirements.txt b/plugins/flytekit-openai/dev-requirements.txt new file mode 100644 index 0000000000..2d73dba5b4 --- /dev/null +++ b/plugins/flytekit-openai/dev-requirements.txt @@ -0,0 +1 @@ +pytest-asyncio diff --git a/plugins/flytekit-openai/flytekitplugins/openai/__init__.py b/plugins/flytekit-openai/flytekitplugins/openai/__init__.py index 58e99f747e..263e3fe675 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/__init__.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/__init__.py @@ -1,12 +1,23 @@ """ .. currentmodule:: flytekitplugins.openai -This package contains things that are useful when extending Flytekit. + .. autosummary:: :template: custom.rst :toctree: generated/ + + BatchEndpointAgent + BatchEndpointTask + BatchResult + DownloadJSONFilesTask + UploadJSONLFileTask + OpenAIFileConfig + create_batch ChatGPTAgent ChatGPTTask """ +from .batch.agent import BatchEndpointAgent +from .batch.task import BatchEndpointTask, BatchResult, DownloadJSONFilesTask, OpenAIFileConfig, UploadJSONLFileTask +from .batch.workflow import create_batch from .chatgpt.agent import ChatGPTAgent from .chatgpt.task import ChatGPTTask diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/__init__.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py new file mode 100644 index 0000000000..2c9821b204 --- /dev/null +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py @@ -0,0 +1,129 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Optional + +import cloudpickle + +from flytekit import FlyteContextManager, lazy_module +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import ( + AgentRegistry, + AsyncAgentBase, + Resource, + ResourceMeta, +) +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +openai = lazy_module("openai") +OPENAI_API_KEY = "FLYTE_OPENAI_API_KEY" + + +class State(Enum): + Running = ["in_progress", "finalizing", "validating"] + Success = ["completed"] + Failed = ["failed", "cancelled", "cancelling", "expired"] + + @classmethod + def key_by_value(cls, value) -> str: + for member in cls: + if value in member.value: + return member.name + + +@dataclass +class BatchEndpointMetadata(ResourceMeta): + openai_org: str + batch_id: str + + def encode(self) -> bytes: + return cloudpickle.dumps(self) + + @classmethod + def decode(cls, data: bytes) -> "BatchEndpointMetadata": + return cloudpickle.loads(data) + + +class BatchEndpointAgent(AsyncAgentBase): + name = "OpenAI Batch Endpoint Agent" + + def __init__(self): + super().__init__(task_type_name="openai-batch", metadata_type=BatchEndpointMetadata) + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> BatchEndpointMetadata: + ctx = FlyteContextManager.current_context() + input_values = TypeEngine.literal_map_to_kwargs( + ctx, + inputs, + {"input_file_id": str}, + ) + custom = task_template.custom + + async_client = openai.AsyncOpenAI( + organization=custom.get("openai_organization"), + api_key=get_agent_secret(secret_key=OPENAI_API_KEY), + ) + + custom["config"].setdefault("completion_window", "24h") + custom["config"].setdefault("endpoint", "/v1/chat/completions") + + result = await async_client.batches.create( + **custom["config"], + input_file_id=input_values["input_file_id"], + ) + batch_id = result.id + + return BatchEndpointMetadata(batch_id=batch_id, openai_org=custom["openai_organization"]) + + async def get( + self, + resource_meta: BatchEndpointMetadata, + **kwargs, + ) -> Resource: + async_client = openai.AsyncOpenAI( + organization=resource_meta.openai_org, + api_key=get_agent_secret(secret_key=OPENAI_API_KEY), + ) + + retrieved_result = await async_client.batches.retrieve(resource_meta.batch_id) + current_state = retrieved_result.status + + flyte_phase = convert_to_flyte_phase(State.key_by_value(current_state)) + + message = None + if current_state in State.Failed.value and retrieved_result.errors: + data = retrieved_result.errors.data + if data and data[0].message: + message = data[0].message + + outputs = {"result": {"result": None}} + if current_state in State.Success.value: + result = retrieved_result.to_dict() + + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + literals={"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))} + ) + + return Resource(phase=flyte_phase, outputs=outputs, message=message) + + async def delete( + self, + resource_meta: BatchEndpointMetadata, + **kwargs, + ): + async_client = openai.AsyncOpenAI( + organization=resource_meta.openai_org, + api_key=get_agent_secret(secret_key=OPENAI_API_KEY), + ) + + await async_client.batches.cancel(resource_meta.batch_id) + + +AgentRegistry.register(BatchEndpointAgent()) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py new file mode 100644 index 0000000000..d5ff3af9a3 --- /dev/null +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional + +from mashumaro.mixins.json import DataClassJSONMixin + +import flytekit +from flytekit import Resources, kwtypes, lazy_module +from flytekit.configuration import SerializationSettings +from flytekit.configuration.default_images import DefaultImages, PythonVersion +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask +from flytekit.core.shim_task import ShimTaskExecutor +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.models.security import Secret +from flytekit.models.task import TaskTemplate +from flytekit.types.file import JSONLFile + +openai = lazy_module("openai") + + +@dataclass +class BatchResult(DataClassJSONMixin): + output_file: Optional[JSONLFile] = None + error_file: Optional[JSONLFile] = None + + +class BatchEndpointTask(AsyncAgentExecutorMixin, PythonTask): + _TASK_TYPE = "openai-batch" + + def __init__( + self, + name: str, + openai_organization: str, + config: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__( + name=name, + task_type=self._TASK_TYPE, + interface=Interface( + inputs=kwtypes(input_file_id=str), + outputs=kwtypes(result=Dict), + ), + **kwargs, + ) + + self._openai_organization = openai_organization + self._config = config + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "openai_organization": self._openai_organization, + "config": self._config, + } + + +class OpenAIFileDefaultImages(DefaultImages): + """Default images for the openai batch plugin.""" + + _DEFAULT_IMAGE_PREFIXES = { + PythonVersion.PYTHON_3_8: "cr.flyte.org/flyteorg/flytekit:py3.8-openai-batch-", + PythonVersion.PYTHON_3_9: "cr.flyte.org/flyteorg/flytekit:py3.9-openai-batch-", + PythonVersion.PYTHON_3_10: "cr.flyte.org/flyteorg/flytekit:py3.10-openai-batch-", + PythonVersion.PYTHON_3_11: "cr.flyte.org/flyteorg/flytekit:py3.11-openai-batch-", + PythonVersion.PYTHON_3_12: "cr.flyte.org/flyteorg/flytekit:py3.12-openai-batch-", + } + + +@dataclass +class OpenAIFileConfig: + openai_organization: str + secret: Secret + + def _secret_to_dict(self) -> Dict[str, Optional[str]]: + return { + "group": self.secret.group, + "key": self.secret.key, + "group_version": self.secret.group_version, + "mount_requirement": self.secret.mount_requirement.value, + } + + +class UploadJSONLFileTask(PythonCustomizedContainerTask[OpenAIFileConfig]): + _UPLOAD_JSONL_FILE_TASK_TYPE = "openai-batch-upload-file" + + def __init__( + self, + name: str, + task_config: OpenAIFileConfig, + container_image: str = OpenAIFileDefaultImages.default_image(), + **kwargs, + ): + super().__init__( + name=name, + task_config=task_config, + task_type=self._UPLOAD_JSONL_FILE_TASK_TYPE, + executor_type=UploadJSONLFileExecutor, + container_image=container_image, + requests=Resources(mem="700Mi"), + interface=Interface( + inputs=kwtypes( + jsonl_in=JSONLFile, + ), + outputs=kwtypes(result=str), + ), + secret_requests=[task_config.secret], + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "openai_organization": self.task_config.openai_organization, + "secret_arg": self.task_config._secret_to_dict(), + } + + +class UploadJSONLFileExecutor(ShimTaskExecutor[UploadJSONLFileTask]): + def execute_from_model(self, tt: TaskTemplate, **kwargs) -> Any: + secret = tt.custom["secret_arg"] + client = openai.OpenAI( + organization=tt.custom["openai_organization"], + api_key=flytekit.current_context().secrets.get( + group=secret["group"], + key=secret["key"], + group_version=secret["group_version"], + ), + ) + + local_jsonl_file = kwargs["jsonl_in"].download() + uploaded_file_obj = client.files.create(file=open(local_jsonl_file, "rb"), purpose="batch") + return uploaded_file_obj.id + + +class DownloadJSONFilesTask(PythonCustomizedContainerTask[OpenAIFileConfig]): + _DOWNLOAD_JSON_FILES_TASK_TYPE = "openai-batch-download-files" + + def __init__( + self, + name: str, + task_config: OpenAIFileConfig, + container_image: str = OpenAIFileDefaultImages.default_image(), + **kwargs, + ): + super().__init__( + name=name, + task_config=task_config, + task_type=self._DOWNLOAD_JSON_FILES_TASK_TYPE, + executor_type=DownloadJSONFilesExecutor, + container_image=container_image, + requests=Resources(mem="700Mi"), + interface=Interface( + inputs=kwtypes(batch_endpoint_result=Dict), + outputs=kwtypes(result=BatchResult), + ), + secret_requests=[task_config.secret], + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "openai_organization": self.task_config.openai_organization, + "secret_arg": self.task_config._secret_to_dict(), + } + + +class DownloadJSONFilesExecutor(ShimTaskExecutor[DownloadJSONFilesTask]): + def execute_from_model(self, tt: TaskTemplate, **kwargs) -> Any: + secret = tt.custom["secret_arg"] + client = openai.OpenAI( + organization=tt.custom["openai_organization"], + api_key=flytekit.current_context().secrets.get( + group=secret["group"], + key=secret["key"], + group_version=secret["group_version"], + ), + ) + + batch_result = BatchResult() + working_dir = flytekit.current_context().working_directory + + for file_name, file_id in zip( + ("output_file", "error_file"), + ( + kwargs["batch_endpoint_result"]["output_file_id"], + kwargs["batch_endpoint_result"]["error_file_id"], + ), + ): + if file_id: + file_path = str(Path(working_dir, file_name).with_suffix(".jsonl")) + + with client.files.with_streaming_response.content(file_id) as response: + response.stream_to_file(file_path) + + setattr(batch_result, file_name, JSONLFile(file_path)) + + return batch_result diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py new file mode 100644 index 0000000000..027f006b59 --- /dev/null +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py @@ -0,0 +1,69 @@ +from typing import Any, Dict, Iterator + +from flytekit import Workflow +from flytekit.models.security import Secret +from flytekit.types.file import JSONLFile +from flytekit.types.iterator import JSON + +from .task import ( + BatchEndpointTask, + BatchResult, + DownloadJSONFilesTask, + OpenAIFileConfig, + UploadJSONLFileTask, +) + + +def create_batch( + name: str, + openai_organization: str, + secret: Secret, + config: Dict[str, Any] = {}, + is_json_iterator: bool = True, +) -> Workflow: + """ + Uploads JSON data to a JSONL file, creates a batch, waits for it to complete, and downloads the output/error JSON files. + + :param name: The suffix to be added to workflow and task names. + :param openai_organization: Name of the OpenAI organization. + :param secret: Secret comprising the OpenAI API key. + :param config: Additional config for batch creation. + :param is_json_iterator: Set to True if you're sending an iterator/generator; if a JSONL file, set to False. + """ + wf = Workflow(name=f"openai-batch-{name.replace('.', '')}") + + if is_json_iterator: + wf.add_workflow_input("jsonl_in", Iterator[JSON]) + else: + wf.add_workflow_input("jsonl_in", JSONLFile) + + upload_jsonl_file_task_obj = UploadJSONLFileTask( + name=f"openai-file-upload-{name.replace('.', '')}", + task_config=OpenAIFileConfig(openai_organization=openai_organization, secret=secret), + ) + batch_endpoint_task_obj = BatchEndpointTask( + name=f"openai-batch-{name.replace('.', '')}", + openai_organization=openai_organization, + config=config, + ) + download_json_files_task_obj = DownloadJSONFilesTask( + name=f"openai-download-files-{name.replace('.', '')}", + task_config=OpenAIFileConfig(openai_organization=openai_organization, secret=secret), + ) + + node_1 = wf.add_entity( + upload_jsonl_file_task_obj, + jsonl_in=wf.inputs["jsonl_in"], + ) + node_2 = wf.add_entity( + batch_endpoint_task_obj, + input_file_id=node_1.outputs["result"], + ) + node_3 = wf.add_entity( + download_json_files_task_obj, + batch_endpoint_result=node_2.outputs["result"], + ) + + wf.add_workflow_output("batch_output", node_3.outputs["result"], BatchResult) + + return wf diff --git a/plugins/flytekit-openai/setup.py b/plugins/flytekit-openai/setup.py index 9a7fff284a..07db38c212 100644 --- a/plugins/flytekit-openai/setup.py +++ b/plugins/flytekit-openai/setup.py @@ -15,7 +15,11 @@ author_email="admin@flyte.org", description="This package holds the openai plugins for flytekit", namespace_packages=["flytekitplugins"], - packages=[f"flytekitplugins.{PLUGIN_NAME}", f"flytekitplugins.{PLUGIN_NAME}.chatgpt"], + packages=[ + f"flytekitplugins.{PLUGIN_NAME}", + f"flytekitplugins.{PLUGIN_NAME}.chatgpt", + f"flytekitplugins.{PLUGIN_NAME}.batch", + ], install_requires=plugin_requires, license="apache2", python_requires=">=3.8", diff --git a/plugins/flytekit-openai/tests/chatgpt/__init__.py b/plugins/flytekit-openai/tests/chatgpt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-openai/tests/test_agent.py b/plugins/flytekit-openai/tests/chatgpt/test_agent.py similarity index 100% rename from plugins/flytekit-openai/tests/test_agent.py rename to plugins/flytekit-openai/tests/chatgpt/test_agent.py diff --git a/plugins/flytekit-openai/tests/test_chatgpt.py b/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py similarity index 100% rename from plugins/flytekit-openai/tests/test_chatgpt.py rename to plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py diff --git a/plugins/flytekit-openai/tests/openai_batch/__init__.py b/plugins/flytekit-openai/tests/openai_batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-openai/tests/openai_batch/data.jsonl b/plugins/flytekit-openai/tests/openai_batch/data.jsonl new file mode 100644 index 0000000000..9701cc3a6a --- /dev/null +++ b/plugins/flytekit-openai/tests/openai_batch/data.jsonl @@ -0,0 +1,2 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2+2?"}]}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Who won the world series in 2020?"}]}} diff --git a/plugins/flytekit-openai/tests/openai_batch/test_agent.py b/plugins/flytekit-openai/tests/openai_batch/test_agent.py new file mode 100644 index 0000000000..3dde953741 --- /dev/null +++ b/plugins/flytekit-openai/tests/openai_batch/test_agent.py @@ -0,0 +1,181 @@ +from datetime import timedelta +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from flyteidl.core.execution_pb2 import TaskExecution +from flytekitplugins.openai.batch.agent import BatchEndpointMetadata +from openai.types import Batch, BatchError, BatchRequestCounts +from openai.types.batch import Errors + +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interaction.string_literals import literal_map_string_repr +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate + +batch_create_result = Batch( + id="batch_abc123", + object="batch", + endpoint="/v1/completions", + errors=None, + input_file_id="file-abc123", + completion_window="24h", + status="completed", + output_file_id="file-cvaTdG", + error_file_id="file-HOWS94", + created_at=1711471533, + in_progress_at=1711471538, + expires_at=1711557933, + finalizing_at=1711493133, + completed_at=1711493163, + failed_at=None, + expired_at=None, + cancelling_at=None, + cancelled_at=None, + request_counts=BatchRequestCounts(completed=95, failed=5, total=100), + metadata={ + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + }, +) + +batch_retrieve_result = Batch( + id="batch_abc123", + object="batch", + endpoint="/v1/completions", + errors=None, + input_file_id="file-abc123", + completion_window="24h", + status="completed", + output_file_id="file-cvaTdG", + error_file_id="file-HOWS94", + created_at=1711471533, + in_progress_at=1711471538, + expires_at=1711557933, + finalizing_at=1711493133, + completed_at=1711493163, + failed_at=None, + expired_at=None, + cancelling_at=None, + cancelled_at=None, + request_counts=BatchRequestCounts(completed=95, failed=5, total=100), + metadata={ + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + }, +) + +batch_retrieve_result_failure = Batch( + id="batch_JneJt99rNcZZncptC5Ec58hw", + object="batch", + endpoint="/v1/chat/completions", + errors=Errors( + data=[ + BatchError( + code="invalid_json_line", + line=1, + message="This line is not parseable as valid JSON.", + param=None, + ), + BatchError( + code="invalid_json_line", + line=10, + message="This line is not parseable as valid JSON.", + param=None, + ), + ], + object="list", + ), + input_file_id="file-3QV5EKbuUJjpACw0xPaVH6cV", + completion_window="24h", + status="failed", + output_file_id=None, + error_file_id=None, + created_at=1713779467, + in_progress_at=None, + expires_at=1713865867, + finalizing_at=None, + completed_at=None, + failed_at=1713779467, + expired_at=None, + cancelling_at=None, + cancelled_at=None, + request_counts=BatchRequestCounts(completed=0, failed=0, total=0), + metadata=None, +) + + +@pytest.mark.asyncio +@mock.patch("flytekit.current_context") +@mock.patch("openai.resources.batches.AsyncBatches.create", new_callable=AsyncMock) +@mock.patch("openai.resources.batches.AsyncBatches.retrieve", new_callable=AsyncMock) +async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context): + agent = AgentRegistry.get_agent("openai-batch") + task_id = Identifier( + resource_type=ResourceType.TASK, + project="project", + domain="domain", + name="name", + version="version", + ) + task_config = { + "openai_organization": "test-openai-orgnization-id", + "config": {"metadata": {"batch_description": "Nightly eval job"}}, + } + task_metadata = TaskMetadata( + discoverable=True, + runtime=RuntimeMetadata(RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timeout=timedelta(days=1), + retries=literals.RetryStrategy(3), + interruptible=True, + discovery_version="0.1.1b0", + deprecated_error_message="This is deprecated!", + cache_serializable=True, + pod_template_name="A", + cache_ignore_input_vars=(), + ) + + task_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=None, + type="openai-batch", + ) + + mocked_token = "mocked_openai_api_key" + mock_context.return_value.secrets.get.return_value = mocked_token + + metadata = BatchEndpointMetadata(openai_org="test-openai-orgnization-id", batch_id="batch_abc123") + + # GET + # Status: Completed + mock_retrieve.return_value = batch_retrieve_result + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED + + outputs = literal_map_string_repr(resource.outputs) + result = outputs["result"] + + assert result == batch_retrieve_result.to_dict() + + # Status: Failed + mock_retrieve.return_value = batch_retrieve_result_failure + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.FAILED + assert resource.outputs == {"result": {"result": None}} + assert resource.message == "This line is not parseable as valid JSON." + + # CREATE + mock_create.return_value = batch_create_result + task_inputs = literals.LiteralMap( + { + "input_file_id": literals.Literal( + scalar=literals.Scalar(primitive=literals.Primitive(string_value="file-xuefauew")) + ) + }, + ) + response = await agent.create(task_template, task_inputs) + assert response == metadata diff --git a/plugins/flytekit-openai/tests/openai_batch/test_task.py b/plugins/flytekit-openai/tests/openai_batch/test_task.py new file mode 100644 index 0000000000..b2564da6fc --- /dev/null +++ b/plugins/flytekit-openai/tests/openai_batch/test_task.py @@ -0,0 +1,141 @@ +import dataclasses +import os +import tempfile +from collections import OrderedDict +from unittest import mock + +import jsonlines +from flytekitplugins.openai import ( + BatchEndpointTask, + DownloadJSONFilesTask, + OpenAIFileConfig, + UploadJSONLFileTask, +) +from openai.types import FileObject + +from flytekit import Secret +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.extend import get_serializable +from flytekit.models.types import SimpleType +from flytekit.types.file import JSONLFile + +JSONL_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data.jsonl") + + +def test_openai_batch_endpoint_task(): + batch_endpoint_task = BatchEndpointTask( + name="gpt-3.5-turbo", + openai_organization="testorg", + config={"completion_window": "24h"}, + ) + + assert len(batch_endpoint_task.interface.inputs) == 1 + assert len(batch_endpoint_task.interface.outputs) == 1 + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig(default_image=default_img, images=[default_img]), + env={}, + ) + + batch_endpoint_task_spec = get_serializable(OrderedDict(), serialization_settings, batch_endpoint_task) + custom = batch_endpoint_task_spec.template.custom + + assert custom["openai_organization"] == "testorg" + assert custom["config"] == {"completion_window": "24h"} + + assert batch_endpoint_task_spec.template.interface.inputs["input_file_id"].type.simple == SimpleType.STRING + assert batch_endpoint_task_spec.template.interface.outputs["result"].type.simple == SimpleType.STRUCT + + +@mock.patch( + "openai.resources.files.Files.create", + return_value=FileObject( + id="file-abc123", + object="file", + bytes=120000, + created_at=1677610602, + filename="mydata.jsonl", + purpose="fine-tune", + status="uploaded", + ), +) +@mock.patch("flytekit.current_context") +def test_upload_jsonl_files_task(mock_context, mock_file_creation): + mocked_token = "mocked_openai_api_key" + mock_context.return_value.secrets.get.return_value = mocked_token + mock_context.return_value.working_directory = "/tmp" + + upload_jsonl_files_task_obj = UploadJSONLFileTask( + name="upload-jsonl-1", + task_config=OpenAIFileConfig( + openai_organization="testorg", + secret=Secret(group="test-openai", key="test-key"), + ), + ) + + jsonl_file_output = upload_jsonl_files_task_obj(jsonl_in=JSONLFile(JSONL_FILE)) + assert jsonl_file_output == "file-abc123" + + +@mock.patch("openai.resources.files.FilesWithStreamingResponse") +@mock.patch("flytekit.current_context") +@mock.patch("flytekitplugins.openai.batch.task.Path") +def test_download_files_task(mock_path, mock_context, mock_streaming): + mocked_token = "mocked_openai_api_key" + mock_context.return_value.secrets.get.return_value = mocked_token + + download_json_files_task_obj = DownloadJSONFilesTask( + name="download-json-files", + task_config=OpenAIFileConfig( + openai_organization="testorg", + secret=Secret(group="test-openai", key="test-key"), + ), + ) + + temp_dir = tempfile.TemporaryDirectory() + temp_file_path = os.path.join(temp_dir.name, "output.jsonl") + + with open(temp_file_path, "w") as f: + with jsonlines.Writer(f) as writer: + writer.write_all([{"id": ""}, {"id": ""}]) # dummy outputs + + mock_path.return_value.with_suffix.return_value = temp_file_path + + response_mock = mock.MagicMock() + mock_streaming.return_value.content.return_value.__enter__.return_value = response_mock + response_mock.stream_to_file.return_value = None + + output = download_json_files_task_obj( + batch_endpoint_result={ + "id": "batch_abc123", + "completion_window": "24h", + "created_at": 1711471533, + "endpoint": "/v1/completions", + "input_file_id": "file-abc123", + "object": "batch", + "status": "completed", + "cancelled_at": None, + "cancelling_at": None, + "completed_at": 1711493163, + "error_file_id": "file-HOWS94", + "errors": None, + "expired_at": None, + "expires_at": 1711557933, + "failed_at": None, + "finalizing_at": 1711493133, + "in_progress_at": 1711471538, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + }, + "output_file_id": "file-cvaTdG", + "request_counts": {"completed": 95, "failed": 5, "total": 100}, + } + ) + assert dataclasses.is_dataclass(output) + assert output.output_file is not None + assert output.error_file is not None diff --git a/plugins/flytekit-openai/tests/openai_batch/test_workflow.py b/plugins/flytekit-openai/tests/openai_batch/test_workflow.py new file mode 100644 index 0000000000..f7e56f4ce8 --- /dev/null +++ b/plugins/flytekit-openai/tests/openai_batch/test_workflow.py @@ -0,0 +1,15 @@ +from flytekitplugins.openai import create_batch + +from flytekit import Secret + + +def test_openai_batch_wf(): + openai_batch_wf = create_batch( + name="gpt-3.5-turbo", + openai_organization="testorg", + secret=Secret(group="test-group"), + ) + + assert len(openai_batch_wf.interface.inputs) == 1 + assert len(openai_batch_wf.interface.outputs) == 1 + assert len(openai_batch_wf.nodes) == 3 diff --git a/plugins/setup.py b/plugins/setup.py index 002514f400..ea35649ed7 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -35,6 +35,7 @@ "flytekitplugins-onnxscikitlearn": "flytekit-onnx-scikitlearn", "flytekitplugins-onnxtensorflow": "flytekit-onnx-tensorflow", "flytekitplugins-onnxpytorch": "flytekit-onnx-pytorch", + "flytekitplugins-openai": "flytekit-openai", "flytekitplugins-pandera": "flytekit-pandera", "flytekitplugins-papermill": "flytekit-papermill", "flytekitplugins-polars": "flytekit-polars", diff --git a/pyproject.toml b/pyproject.toml index 5d78177084..d3bcf722ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "isodate", "jinja2", "joblib", + "jsonlines", "jsonpickle", "keyring>=18.0.1", "markdown-it-py", @@ -42,7 +43,7 @@ dependencies = [ "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", "pyyaml!=6.0.0,!=5.4.0,!=5.4.1", # pyyaml is broken with cython 3: https://github.com/yaml/pyyaml/issues/601 - "requests>=2.18.4,<3.0.0,!=2.32.0,!=2.32.1", + "requests>=2.18.4,<3.0.0,!=2.32.0,!=2.32.1,!=2.32.2", "rich", "rich_click", "s3fs>=2023.3.0,!=2024.3.1", diff --git a/tests/flytekit/unit/types/iterator/data.jsonl b/tests/flytekit/unit/types/iterator/data.jsonl new file mode 100644 index 0000000000..ff6546b0b8 --- /dev/null +++ b/tests/flytekit/unit/types/iterator/data.jsonl @@ -0,0 +1,3 @@ +{"file_name": "0000.png", "text": "One chinhuahua"} +{"file_name": "0001.png", "text": "A german shepherd"} +{"file_name": "0002.png", "text": "This is a golden retriever playing with a ball"} diff --git a/tests/flytekit/unit/types/iterator/test_json_iterator.py b/tests/flytekit/unit/types/iterator/test_json_iterator.py new file mode 100644 index 0000000000..fbef86d791 --- /dev/null +++ b/tests/flytekit/unit/types/iterator/test_json_iterator.py @@ -0,0 +1,137 @@ +import os +from typing import Iterator + +import jsonlines +import pytest + +from flytekit import task, workflow +from flytekit.types.iterator import JSON + +JSONL_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data.jsonl") + + +def jsons(): + for index, text in enumerate(["One chinhuahua", "A german shepherd"]): + yield { + "file_name": f"000{index}.png", + "text": text, + } + + +def jsons_iter(): + return iter( + [ + x + for x in ( + {"file_name": "0000.png", "text": "One chinhuahua"}, + {"file_name": "0001.png", "text": "A german shepherd"}, + ) + ] + ) + + +@task +def jsons_task(x: Iterator[JSON]) -> Iterator[JSON]: + return x + + +@task +def jsons_loop_task(x: Iterator[JSON]) -> Iterator[JSON]: + for val in x: + print(val) + + return x + + +@task +def jsons_iter_task(x: Iterator[JSON]) -> Iterator[JSON]: + return x + + +@task +def jsonl_input(x: Iterator[JSON]): + for val in x: + print(val) + + +@task +def jsons_return_iter() -> Iterator[JSON]: + reader = jsonlines.Reader(open(JSONL_FILE)) + for obj in reader: + yield obj + + +def test_jsons_tasks(): + # 1 + iterator = jsons_task(x=jsons()) + assert isinstance(iterator, Iterator) + + x, y = next(iterator), next(iterator) + assert x == {"file_name": "0000.png", "text": "One chinhuahua"} + assert y == {"file_name": "0001.png", "text": "A german shepherd"} + + with pytest.raises(StopIteration): + next(iterator) + + # 2 + with pytest.raises(TypeError, match="The iterator is empty."): + jsons_loop_task(x=jsons()) + + # 3 + iter_iterator = jsons_iter_task(x=jsons_iter()) + assert isinstance(iter_iterator, Iterator) + + # 4 + jsonl_input(x=jsonlines.Reader(open(JSONL_FILE)).iter()) + + # 5 + return_iter_iterator = jsons_return_iter() + assert isinstance(return_iter_iterator, Iterator) + + x, y, z = ( + next(return_iter_iterator), + next(return_iter_iterator), + next(return_iter_iterator), + ) + assert x == {"file_name": "0000.png", "text": "One chinhuahua"} + assert y == {"file_name": "0001.png", "text": "A german shepherd"} + assert z == { + "file_name": "0002.png", + "text": "This is a golden retriever playing with a ball", + } + + with pytest.raises(StopIteration): + next(return_iter_iterator) + + +@workflow +def jsons_wf(x: Iterator[JSON] = jsons()) -> Iterator[JSON]: + return jsons_task(x=x) + + +@workflow +def jsons_iter_wf(x: Iterator[JSON] = jsons_iter()) -> Iterator[JSON]: + return jsons_iter_task(x=x) + + +@workflow +def jsons_multiple_tasks_wf() -> Iterator[JSON]: + return jsons_task(x=jsons_return_iter()) + + +def test_jsons_wf(): + # 1 + iterator = jsons_wf() + assert isinstance(iterator, Iterator) + + x, y = next(iterator), next(iterator) + assert x == {"file_name": "0000.png", "text": "One chinhuahua"} + assert y == {"file_name": "0001.png", "text": "A german shepherd"} + + # 2 + iter_iterator = jsons_iter_wf() + assert isinstance(iter_iterator, Iterator) + + # 3 + multiple_tasks = jsons_multiple_tasks_wf() + assert isinstance(multiple_tasks, Iterator)