From 6dbb113cc593722a38f61c567a9034f504a8aaf7 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Mon, 23 Sep 2024 03:33:41 +0800 Subject: [PATCH] [flyteagent] All agents return dict instead of literal map (#2762) Signed-off-by: Future-Outlier --- .../awssagemaker_inference/boto3_agent.py | 30 +++++++++---------- .../flytekitplugins/bigquery/agent.py | 3 +- plugins/flytekit-bigquery/tests/test_agent.py | 2 +- .../flytekitplugins/openai/batch/agent.py | 4 +-- .../flytekitplugins/snowflake/agent.py | 19 +++++------- 5 files changed, 25 insertions(+), 33 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py index 5e34557e40..d254ec5960 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -122,22 +122,20 @@ async def do( ) ) with context_manager.FlyteContextManager.with_context(builder) as new_ctx: - outputs = LiteralMap( - literals={ - "result": TypeEngine.to_literal( - new_ctx, - truncated_result if truncated_result else result, - Annotated[dict, kwtypes(allow_pickle=True)], - TypeEngine.to_literal_type(dict), - ), - "idempotence_token": TypeEngine.to_literal( - new_ctx, - idempotence_token, - str, - TypeEngine.to_literal_type(str), - ), - } - ) + outputs = { + "result": TypeEngine.to_literal( + new_ctx, + truncated_result if truncated_result else result, + Annotated[dict, kwtypes(allow_pickle=True)], + TypeEngine.to_literal_type(dict), + ), + "idempotence_token": TypeEngine.to_literal( + new_ctx, + idempotence_token, + str, + TypeEngine.to_literal_type(str), + ), + } return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index 813cc1794a..ff34f7a580 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -84,9 +84,8 @@ def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource: if cur_phase == TaskExecution.SUCCEEDED: dst = job.destination if dst: - ctx = FlyteContextManager.current_context() output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}" - res = TypeEngine.dict_to_literal_map(ctx, {"results": StructuredDataset(uri=output_location)}) + res = {"results": StructuredDataset(uri=output_location)} return Resource(phase=cur_phase, message=str(job.state), log_links=[log_link], outputs=res) diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index 57d4b747cd..e376d18216 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -90,7 +90,7 @@ def __init__(self): resource = agent.get(metadata) assert resource.phase == TaskExecution.SUCCEEDED assert ( - resource.outputs.literals["results"].scalar.structured_dataset.uri + resource.outputs["results"].uri == "bq://dummy_project:dummy_dataset.dummy_table" ) assert resource.log_links[0].name == "BigQuery Console" diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py index fa01383ca0..8daf236828 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py @@ -105,9 +105,7 @@ async def get( result = retrieved_result.to_dict() ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - literals={"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))} - ) + outputs = {"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))} return Resource(phase=flyte_phase, outputs=outputs, message=message) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index 831b431afa..e4318f8cfb 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -7,7 +7,6 @@ from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret -from flytekit.models import literals from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate from flytekit.models.types import LiteralType, StructuredDatasetType @@ -114,16 +113,14 @@ async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource: if cur_phase == TaskExecution.SUCCEEDED and resource_meta.has_output: ctx = FlyteContextManager.current_context() uri = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.query_id}" - res = literals.LiteralMap( - { - "results": TypeEngine.to_literal( - ctx, - StructuredDataset(uri=uri), - StructuredDataset, - LiteralType(structured_dataset_type=StructuredDatasetType(format="")), - ) - } - ) + res = { + "results": TypeEngine.to_literal( + ctx, + StructuredDataset(uri=uri), + StructuredDataset, + LiteralType(structured_dataset_type=StructuredDatasetType(format="")), + ) + } return Resource(phase=cur_phase, outputs=res, log_links=[log_link])