Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update openai batch test and workflow #2440

Merged
merged 5 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,10 @@ def _version_from_hash(
h.update(bytes(s, "utf-8"))

if default_inputs:
h.update(cloudpickle.dumps(default_inputs))
try:
h.update(cloudpickle.dumps(default_inputs))
except TypeError: # cannot pickle errors
logger.info("Skip pickling default inputs.")

# Omit the character '=' from the version as that's essentially padding used by the base64 encoding
# and does not increase entropy of the hash while making it very inconvenient to copy-and-paste.
Expand Down
3 changes: 0 additions & 3 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,13 +526,10 @@ def _downloader():
return ff

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
from flytekit.types.iterator.json_iterator import JSONIteratorTransformer

if (
literal_type.blob is not None
and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format != FlytePickleTransformer.PYTHON_PICKLE_FORMAT
and literal_type.blob.format != JSONIteratorTransformer.JSON_ITERATOR_FORMAT
):
return FlyteFile.__class_getitem__(literal_type.blob.format)

Expand Down
5 changes: 4 additions & 1 deletion flytekit/types/iterator/json_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class JSONIteratorTransformer(TypeTransformer[Iterator[JSON]]):
"""

JSON_ITERATOR_FORMAT = "jsonl"
JSON_ITERATOR_METADATA = "json iterator"

def __init__(self):
super().__init__("JSON Iterator", Iterator[JSON])
Expand All @@ -49,7 +50,8 @@ def get_literal_type(self, t: Type[Iterator[JSON]]) -> LiteralType:
blob=_core_types.BlobType(
format=self.JSON_ITERATOR_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
),
metadata={"format": self.JSON_ITERATOR_METADATA},
)

def to_literal(
Expand Down Expand Up @@ -103,6 +105,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[Iterator[JSON]]:
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == self.JSON_ITERATOR_FORMAT
and literal_type.metadata == {"format": self.JSON_ITERATOR_METADATA}
):
return Iterator[JSON] # type: ignore

Expand Down
14 changes: 6 additions & 8 deletions plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,12 @@ async def get(
if data and data[0].message:
message = data[0].message

outputs = {"result": {"result": None}}
if current_state in State.Success.value:
result = retrieved_result.to_dict()

ctx = FlyteContextManager.current_context()
outputs = LiteralMap(
literals={"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))}
)
result = retrieved_result.to_dict()

ctx = FlyteContextManager.current_context()
outputs = LiteralMap(
literals={"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))}
)

return Resource(phase=flyte_phase, outputs=outputs, message=message)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Iterator

from flytekit import Workflow
from flytekit import Resources, Workflow
from flytekit.models.security import Secret
from flytekit.types.file import JSONLFile
from flytekit.types.iterator import JSON
Expand All @@ -20,6 +20,8 @@ def create_batch(
secret: Secret,
config: Dict[str, Any] = {},
is_json_iterator: bool = True,
file_upload_mem: str = "700Mi",
file_download_mem: str = "700Mi",
) -> Workflow:
"""
Uploads JSON data to a JSONL file, creates a batch, waits for it to complete, and downloads the output/error JSON files.
Expand All @@ -29,6 +31,8 @@ def create_batch(
:param secret: Secret comprising the OpenAI API key.
:param config: Additional config for batch creation.
:param is_json_iterator: Set to True if you're sending an iterator/generator; if a JSONL file, set to False.
:param file_upload_mem: Memory to allocate to the upload file task.
:param file_download_mem: Memory to allocate to the download file task.
"""
wf = Workflow(name=f"openai-batch-{name.replace('.', '')}")

Expand Down Expand Up @@ -64,6 +68,9 @@ def create_batch(
batch_endpoint_result=node_2.outputs["result"],
)

node_1.with_overrides(requests=Resources(mem=file_upload_mem), limits=Resources(mem=file_upload_mem))
node_3.with_overrides(requests=Resources(mem=file_download_mem), limits=Resources(mem=file_download_mem))

wf.add_workflow_output("batch_output", node_3.outputs["result"], BatchResult)

return wf
1 change: 0 additions & 1 deletion plugins/flytekit-openai/tests/openai_batch/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context):
mock_retrieve.return_value = batch_retrieve_result_failure
resource = await agent.get(metadata)
assert resource.phase == TaskExecution.FAILED
assert resource.outputs == {"result": {"result": None}}
assert resource.message == "This line is not parseable as valid JSON."

# CREATE
Expand Down
Loading