Skip to content

Commit

Permalink
[flyteagent] All agents return dict instead of literal map (#2762)
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier authored and kumare3 committed Nov 8, 2024
1 parent 9423f8b commit 6dbb113
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,20 @@ async def do(
)
)
with context_manager.FlyteContextManager.with_context(builder) as new_ctx:
outputs = LiteralMap(
literals={
"result": TypeEngine.to_literal(
new_ctx,
truncated_result if truncated_result else result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
),
"idempotence_token": TypeEngine.to_literal(
new_ctx,
idempotence_token,
str,
TypeEngine.to_literal_type(str),
),
}
)
outputs = {
"result": TypeEngine.to_literal(
new_ctx,
truncated_result if truncated_result else result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
),
"idempotence_token": TypeEngine.to_literal(
new_ctx,
idempotence_token,
str,
TypeEngine.to_literal_type(str),
),
}

return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs)

Expand Down
3 changes: 1 addition & 2 deletions plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource:
if cur_phase == TaskExecution.SUCCEEDED:
dst = job.destination
if dst:
ctx = FlyteContextManager.current_context()
output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}"
res = TypeEngine.dict_to_literal_map(ctx, {"results": StructuredDataset(uri=output_location)})
res = {"results": StructuredDataset(uri=output_location)}

return Resource(phase=cur_phase, message=str(job.state), log_links=[log_link], outputs=res)

Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-bigquery/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self):
resource = agent.get(metadata)
assert resource.phase == TaskExecution.SUCCEEDED
assert (
resource.outputs.literals["results"].scalar.structured_dataset.uri
resource.outputs["results"].uri
== "bq://dummy_project:dummy_dataset.dummy_table"
)
assert resource.log_links[0].name == "BigQuery Console"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ async def get(
result = retrieved_result.to_dict()

ctx = FlyteContextManager.current_context()
outputs = LiteralMap(
literals={"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))}
)
outputs = {"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))}

return Resource(phase=flyte_phase, outputs=outputs, message=message)

Expand Down
19 changes: 8 additions & 11 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret
from flytekit.models import literals
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.types import LiteralType, StructuredDatasetType
Expand Down Expand Up @@ -114,16 +113,14 @@ async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource:
if cur_phase == TaskExecution.SUCCEEDED and resource_meta.has_output:
ctx = FlyteContextManager.current_context()
uri = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.query_id}"
res = literals.LiteralMap(
{
"results": TypeEngine.to_literal(
ctx,
StructuredDataset(uri=uri),
StructuredDataset,
LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
)
}
)
res = {
"results": TypeEngine.to_literal(
ctx,
StructuredDataset(uri=uri),
StructuredDataset,
LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
)
}

return Resource(phase=cur_phase, outputs=res, log_links=[log_link])

Expand Down

0 comments on commit 6dbb113

Please sign in to comment.