Skip to content

Commit

Permalink
add output_prefix back as an arg to agent create
Browse files Browse the repository at this point in the history
Signed-off-by: noahjax <[email protected]>
  • Loading branch information
noahjax committed Mar 15, 2024
1 parent 4208da2 commit 8c048ac
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 8 deletions.
7 changes: 6 additions & 1 deletion flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,12 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon
agent = AgentRegistry.get_agent(template.type, template.task_type_version)

logger.info(f"{agent.name} start creating the job")
resource_mata = await mirror_async_methods(agent.create, task_template=template, inputs=inputs)
resource_mata = await mirror_async_methods(
agent.create,
task_template=template,
output_prefix=request.output_prefix,
inputs=inputs,
)
return CreateTaskResponse(resource_meta=resource_mata.encode())

@record_agent_metrics
Expand Down
5 changes: 4 additions & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def metadata_type(self) -> ResourceMeta:
return self._metadata_type

@abstractmethod
def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> ResourceMeta:
def create(
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap], **kwargs
) -> ResourceMeta:
"""
Return a resource meta that can be used to get the status of the task.
"""
Expand Down Expand Up @@ -311,6 +313,7 @@ async def _create(
resource_meta = await mirror_async_methods(
self._agent.create,
task_template=task_template,
output_prefix=output_prefix,
inputs=literal_map,
)

Expand Down
4 changes: 3 additions & 1 deletion flytekit/sensor/sensor_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ class SensorEngine(AsyncAgentBase):
def __init__(self):
super().__init__(task_type_name="sensor", metadata_type=SensorMetadata)

async def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwarg) -> SensorMetadata:
async def create(
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap], **kwarg
) -> SensorMetadata:
sensor_metadata = SensorMetadata(**task_template.custom)

if inputs:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-airflow/flytekitplugins/airflow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self):
super().__init__(task_type_name="airflow", metadata_type=AirflowMetadata)

async def create(
self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs
self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, output_prefix: str = "", **kwargs
) -> AirflowMetadata:
airflow_obj = jsonpickle.decode(task_template.custom["task_config_pkl"])
airflow_instance = _get_airflow_instance(airflow_obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self):
def create(
self,
task_template: TaskTemplate,
output_prefix: str,
inputs: Optional[LiteralMap] = None,
**kwargs,
) -> BigQueryMetadata:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def async_login(self):
logger.info("Logged in to OpCenter")

async def create(
self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
) -> MMCloudMetadata:
"""
Submit a Flyte task as MMCloud job to the OpCenter, and return the job UID for the task.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self):
super().__init__(task_type_name=TASK_TYPE, metadata_type=SnowflakeJobMetadata)

async def create(
self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs
self, task_template: TaskTemplate, output_prefix: str = "", inputs: Optional[LiteralMap] = None, **kwargs
) -> SnowflakeJobMetadata:
ctx = FlyteContextManager.current_context()
literal_types = task_template.interface.inputs
Expand Down
6 changes: 4 additions & 2 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ class DummyAgent(AsyncAgentBase):
def __init__(self):
super().__init__(task_type_name="dummy", metadata_type=DummyMetadata)

def create(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap], **kwargs) -> DummyMetadata:
def create(
self, task_template: TaskTemplate, output_prefix: str, inputs: typing.Optional[LiteralMap], **kwargs
) -> DummyMetadata:
return DummyMetadata(job_id=dummy_id)

def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
Expand All @@ -72,7 +74,7 @@ def __init__(self):
super().__init__(task_type_name="async_dummy", metadata_type=DummyMetadata)

async def create(
self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs
self, task_template: TaskTemplate, output_prefix: str, inputs: typing.Optional[LiteralMap] = None, **kwargs
) -> DummyMetadata:
return DummyMetadata(job_id=dummy_id)

Expand Down

0 comments on commit 8c048ac

Please sign in to comment.