Skip to content

Commit

Permalink
add simple test to check that AysncAgent subclass can consume output_…
Browse files Browse the repository at this point in the history
…prefix

Signed-off-by: noahjax <[email protected]>
  • Loading branch information
noahjax committed Mar 18, 2024
1 parent 059cdb7 commit cdf11c8
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
@dataclass
class DummyMetadata(ResourceMeta):
job_id: str
output_path: typing.Optional[str] = None


class DummyAgent(AsyncAgentBase):
Expand All @@ -72,9 +73,14 @@ def __init__(self):
super().__init__(task_type_name="async_dummy", metadata_type=DummyMetadata)

async def create(
self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs
self,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
output_prefix: typing.Optional[str] = None,
**kwargs,
) -> DummyMetadata:
return DummyMetadata(job_id=dummy_id)
output_path = f"{output_prefix}/{dummy_id}" if output_prefix else None
return DummyMetadata(job_id=dummy_id, output_path=output_path)

async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")])
Expand Down Expand Up @@ -164,7 +170,7 @@ async def test_async_agent_service(agent):

inputs_proto = task_inputs.to_flyte_idl()
output_prefix = "/tmp"
metadata_bytes = DummyMetadata(job_id=dummy_id).encode()
metadata_bytes = DummyMetadata(job_id=dummy_id, output_path=f"{output_prefix/{dummy_id}}").encode()

tmp = get_task_template(agent.task_category.name).to_flyte_idl()
task_category = TaskCategory(name=agent.task_category.name, version=0)
Expand Down

0 comments on commit cdf11c8

Please sign in to comment.