Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Agent State with Agent Phase #2123

Merged
merged 15 commits into from
Feb 5, 2024
39 changes: 18 additions & 21 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@

import grpc
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
State,
)
from flyteidl.core import literals_pb2
from flyteidl.core.execution_pb2 import TaskExecution
from flyteidl.core.tasks_pb2 import TaskTemplate
from rich.progress import Progress

Expand Down Expand Up @@ -133,26 +129,26 @@
return AgentRegistry._REGISTRY[task_type]


def convert_to_flyte_state(state: str) -> State:
def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
"""
Convert the state from the agent to the state in flyte.
"""
state = state.lower()
# timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
if state in ["failed", "timeout", "timedout", "canceled"]:
return RETRYABLE_FAILURE
return TaskExecution.FAILED

Check warning on line 139 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L139

Added line #L139 was not covered by tests
elif state in ["done", "succeeded", "success"]:
return SUCCEEDED
return TaskExecution.SUCCEEDED

Check warning on line 141 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L141

Added line #L141 was not covered by tests
elif state in ["running"]:
return RUNNING
return TaskExecution.RUNNING

Check warning on line 143 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L143

Added line #L143 was not covered by tests
raise ValueError(f"Unrecognized state: {state}")


def is_terminal_state(state: State) -> bool:
def is_terminal_phase(phase: TaskExecution.Phase) -> bool:
"""
Return true if the state is terminal.
Return true if the phase is terminal.
"""
return state in [SUCCEEDED, RETRYABLE_FAILURE, PERMANENT_FAILURE]
return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED]

Check warning on line 151 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L151

Added line #L151 was not covered by tests


def get_agent_secret(secret_key: str) -> str:
Expand Down Expand Up @@ -196,13 +192,13 @@

# If the task is synchronous, the agent will return the output from the resource literals.
if res.HasField("resource"):
if res.resource.state != SUCCEEDED:
if res.resource.phase != TaskExecution.SUCCEEDED:
raise FlyteUserException(f"Failed to run the task {self._entity.name}")
return LiteralMap.from_flyte_idl(res.resource.outputs)

res = asyncio.run(self._get(resource_meta=res.resource_meta))

if res.resource.state != SUCCEEDED:
if res.resource.phase != TaskExecution.SUCCEEDED:
raise FlyteUserException(f"Failed to run the task {self._entity.name}")

# Read the literals from a remote file, if agent doesn't return the output literals.
Expand Down Expand Up @@ -241,13 +237,13 @@
return res

async def _get(self, resource_meta: bytes) -> GetTaskResponse:
state = RUNNING
phase = TaskExecution.RUNNING

Check warning on line 240 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L240

Added line #L240 was not covered by tests
grpc_ctx = _get_grpc_context()

progress = Progress(transient=True)
task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None)
with progress:
while not is_terminal_state(state):
while not is_terminal_phase(phase):
progress.start_task(task)
time.sleep(1)
if self._agent.asynchronous:
Expand All @@ -257,11 +253,12 @@
sys.exit(1)
else:
res = self._agent.get(grpc_ctx, resource_meta)
state = res.resource.state
progress.print(f"Task state: {State.Name(state)}, State message: {res.resource.message}")
if hasattr(res.resource, "log_links"):
for link in res.resource.log_links:
progress.print(f"{link.name}: {link.uri}")
phase = res.resource.phase

Check warning on line 256 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L256

Added line #L256 was not covered by tests

progress.print(f"Task phase: {TaskExecution.Phase.Name(phase)}, Phase message: {res.resource.message}")

Check warning on line 258 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L258

Added line #L258 was not covered by tests
if hasattr(res.resource, "log_links"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need hasattr here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need hasattr here?

will remove it, thank you

for link in res.resource.log_links:
progress.print(f"{link.name}: {link.uri}")

Check warning on line 261 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L261

Added line #L261 was not covered by tests
return res

def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any:
Expand Down
11 changes: 7 additions & 4 deletions flytekit/sensor/sensor_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import grpc
import jsonpickle
from flyteidl.admin.agent_pb2 import (
RUNNING,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
Resource,
)
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
Expand Down Expand Up @@ -52,8 +51,12 @@
sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None

inputs = meta.get(INPUTS, {})
cur_state = SUCCEEDED if await sensor_def("sensor", config=sensor_config).poke(**inputs) else RUNNING
return GetTaskResponse(resource=Resource(state=cur_state, outputs=None))
cur_phase = (

Check warning on line 54 in flytekit/sensor/sensor_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/sensor/sensor_engine.py#L54

Added line #L54 was not covered by tests
TaskExecution.SUCCEEDED
if await sensor_def("sensor", config=sensor_config).poke(**inputs)
else TaskExecution.RUNNING
)
return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=None))

Check warning on line 59 in flytekit/sensor/sensor_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/sensor/sensor_engine.py#L59

Added line #L59 was not covered by tests

async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
return DeleteTaskResponse()
Expand Down
18 changes: 8 additions & 10 deletions plugins/flytekit-airflow/flytekitplugins/airflow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
import grpc
import jsonpickle
from flyteidl.admin.agent_pb2 import (
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
Resource,
)
from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.airflow.task import AirflowObj, _get_airflow_instance

from airflow.exceptions import AirflowException, TaskDeferred
Expand Down Expand Up @@ -99,11 +97,11 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -
airflow_trigger_instance = _get_airflow_instance(meta.airflow_trigger) if meta.airflow_trigger else None
airflow_ctx = Context()
message = None
cur_state = RUNNING
cur_phase = TaskExecution.RUNNING

if isinstance(airflow_operator_instance, BaseSensorOperator):
ok = airflow_operator_instance.poke(context=airflow_ctx)
cur_state = SUCCEEDED if ok else RUNNING
cur_phase = TaskExecution.SUCCEEDED if ok else TaskExecution.RUNNING
elif isinstance(airflow_operator_instance, BaseOperator):
if airflow_trigger_instance:
try:
Expand All @@ -118,26 +116,26 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -
# Trigger callback will check the status of the task in the payload, and raise AirflowException if failed.
trigger_callback = getattr(airflow_operator_instance, meta.airflow_trigger_callback)
trigger_callback(context=airflow_ctx, event=typing.cast(TriggerEvent, event).payload)
cur_state = SUCCEEDED
cur_phase = TaskExecution.SUCCEEDED
except AirflowException as e:
cur_state = RETRYABLE_FAILURE
cur_phase = TaskExecution.FAILED
message = e.__str__()
except asyncio.TimeoutError:
logger.debug("No event received from airflow trigger")
except AirflowException as e:
cur_state = RETRYABLE_FAILURE
cur_phase = TaskExecution.FAILED
message = e.__str__()
else:
# If there is no trigger, it means the operator is not deferrable. In this case, this operator will be
# executed in the creation step. Therefore, we can directly return SUCCEEDED here.
# For instance, SlackWebhookOperator is not deferrable. It sends a message to Slack in the creation step.
# If the message is sent successfully, agent will return SUCCEEDED here. Otherwise, it will raise an exception at creation step.
cur_state = SUCCEEDED
cur_phase = TaskExecution.SUCCEEDED

else:
raise FlyteUserException("Only sensor and operator are supported.")

return GetTaskResponse(resource=Resource(state=cur_state, message=message))
return GetTaskResponse(resource=Resource(phase=cur_phase, message=message))

async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
return DeleteTaskResponse()
Expand Down
5 changes: 3 additions & 2 deletions plugins/flytekit-airflow/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from airflow.operators.python import PythonOperator
from airflow.sensors.bash import BashSensor
from airflow.sensors.time_sensor import TimeSensor
from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse
from flyteidl.admin.agent_pb2 import DeleteTaskResponse
from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.airflow import AirflowObj
from flytekitplugins.airflow.agent import AirflowAgent, ResourceMetadata

Expand Down Expand Up @@ -94,7 +95,7 @@ async def test_airflow_agent():
res = await agent.async_create(grpc_ctx, "/tmp", dummy_template, None)
metadata = res.resource_meta
res = await agent.async_get(grpc_ctx, metadata)
assert res.resource.state == SUCCEEDED
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.resource.message == ""
res = await agent.async_delete(grpc_ctx, metadata)
assert res == DeleteTaskResponse()
13 changes: 6 additions & 7 deletions plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@

import grpc
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
Resource,
)
from flyteidl.core.execution_pb2 import TaskExecution
from google.cloud import bigquery

from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase
from flytekit.models import literals
from flytekit.models.core.execution import TaskLog
from flytekit.models.literals import LiteralMap
Expand Down Expand Up @@ -92,12 +91,12 @@ def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskRes
logger.error(job.errors.__str__())
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(job.errors.__str__())
return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE), log_links=log_links)
return GetTaskResponse(resource=Resource(phase=TaskExecution.FAILED), log_links=log_links)

cur_state = convert_to_flyte_state(str(job.state))
cur_phase = convert_to_flyte_phase(str(job.state))
res = None

if cur_state == SUCCEEDED:
if cur_phase == TaskExecution.SUCCEEDED:
ctx = FlyteContextManager.current_context()
if job.destination:
output_location = (
Expand All @@ -114,7 +113,7 @@ def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskRes
}
).to_flyte_idl()

return GetTaskResponse(resource=Resource(state=cur_state, outputs=res), log_links=log_links)
return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res), log_links=log_links)

def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
client = bigquery.Client()
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-bigquery/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from unittest.mock import MagicMock

import grpc
from flyteidl.admin.agent_pb2 import SUCCEEDED
from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.bigquery.agent import Metadata

import flytekit.models.interface as interface_models
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(self):
).encode("utf-8")
assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes
res = agent.get(ctx, metadata_bytes)
assert res.resource.state == SUCCEEDED
assert res.resource.phase == TaskExecution.SUCCEEDED
assert (
res.resource.outputs.literals["results"].scalar.structured_dataset.uri
== "bq://dummy_project:dummy_dataset.dummy_table"
Expand Down
6 changes: 3 additions & 3 deletions plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import grpc
from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource
from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_state
from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_phase

from flytekit import current_context
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry
Expand Down Expand Up @@ -171,12 +171,12 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -
logger.exception(f"Failed to obtain status for MMCloud job: {job_id}")
raise

task_state = mmcloud_status_to_flyte_state(job_status)
task_phase = mmcloud_status_to_flyte_phase(job_status)

logger.info(f"Obtained status for MMCloud job {job_id}: {job_status}")
logger.debug(f"OpCenter response: {show_response}")

return GetTaskResponse(resource=Resource(state=task_state))
return GetTaskResponse(resource=Resource(phase=task_phase))

async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
"""
Expand Down
48 changes: 24 additions & 24 deletions plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,39 @@
from decimal import ROUND_CEILING, Decimal
from typing import Optional, Tuple

from flyteidl.admin.agent_pb2 import PERMANENT_FAILURE, RETRYABLE_FAILURE, RUNNING, SUCCEEDED, State
from flyteidl.core.execution_pb2 import TaskExecution
from kubernetes.utils.quantity import parse_quantity

from flytekit.core.resources import Resources

MMCLOUD_STATUS_TO_FLYTE_STATE = {
"Submitted": RUNNING,
"Initializing": RUNNING,
"Starting": RUNNING,
"Executing": RUNNING,
"Capturing": RUNNING,
"Floating": RUNNING,
"Suspended": RUNNING,
"Suspending": RUNNING,
"Resuming": RUNNING,
"Completed": SUCCEEDED,
"Cancelled": PERMANENT_FAILURE,
"Cancelling": PERMANENT_FAILURE,
"FailToComplete": RETRYABLE_FAILURE,
"FailToExecute": RETRYABLE_FAILURE,
"CheckpointFailed": RETRYABLE_FAILURE,
"Timedout": RETRYABLE_FAILURE,
"NoAvailableHost": RETRYABLE_FAILURE,
"Unknown": RETRYABLE_FAILURE,
"WaitingForLicense": PERMANENT_FAILURE,
MMCLOUD_STATUS_TO_FLYTE_PHASE = {
"Submitted": TaskExecution.RUNNING,
"Initializing": TaskExecution.RUNNING,
"Starting": TaskExecution.RUNNING,
"Executing": TaskExecution.RUNNING,
"Capturing": TaskExecution.RUNNING,
"Floating": TaskExecution.RUNNING,
"Suspended": TaskExecution.RUNNING,
"Suspending": TaskExecution.RUNNING,
"Resuming": TaskExecution.RUNNING,
"Completed": TaskExecution.SUCCEEDED,
"Cancelled": TaskExecution.FAILED,
"Cancelling": TaskExecution.FAILED,
"FailToComplete": TaskExecution.FAILED,
"FailToExecute": TaskExecution.FAILED,
"CheckpointFailed": TaskExecution.FAILED,
"Timedout": TaskExecution.FAILED,
"NoAvailableHost": TaskExecution.FAILED,
"Unknown": TaskExecution.FAILED,
"WaitingForLicense": TaskExecution.FAILED,
}


def mmcloud_status_to_flyte_state(status: str) -> State:
def mmcloud_status_to_flyte_phase(status: str) -> TaskExecution.Phase:
"""
Map MMCloud status to Flyte state.
Map MMCloud status to Flyte phase.
"""
return MMCLOUD_STATUS_TO_FLYTE_STATE[status]
return MMCLOUD_STATUS_TO_FLYTE_PHASE[status]


def flyte_to_mmcloud_resources(
Expand Down
10 changes: 5 additions & 5 deletions plugins/flytekit-mmcloud/tests/test_mmcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import grpc
import pytest
from flyteidl.admin.agent_pb2 import PERMANENT_FAILURE, RUNNING, SUCCEEDED
from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.mmcloud import MMCloudAgent, MMCloudConfig, MMCloudTask
from flytekitplugins.mmcloud.utils import async_check_output, flyte_to_mmcloud_resources

Expand Down Expand Up @@ -125,14 +125,14 @@ def say_hello0(name: str) -> str:
resource_meta = create_task_response.resource_meta

get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta))
state = get_task_response.resource.state
assert state in (RUNNING, SUCCEEDED)
phase = get_task_response.resource.phase
assert phase in (TaskExecution.RUNNING, TaskExecution.SUCCEEDED)

asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta))

get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta))
state = get_task_response.resource.state
assert state == PERMANENT_FAILURE
phase = get_task_response.resource.phase
assert phase == TaskExecution.FAILED

@task(
task_config=MMCloudConfig(submit_extra="--nonexistent"),
Expand Down
Loading
Loading