diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index d9ab7a97c2..6d5b8be8e9 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -271,9 +271,9 @@ jobs: run: | make setup cd plugins/${{ matrix.plugin-names }} - pip install . + pip install --pre . if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi - pip install -U $GITHUB_WORKSPACE + pip install --pre -U $GITHUB_WORKSPACE pip freeze - name: Test with coverage run: | @@ -307,7 +307,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r dev-requirements.in + make setup + pip freeze - name: Lint run: | make lint @@ -329,6 +330,8 @@ jobs: run: | python -m pip install --upgrade pip==21.2.4 setuptools wheel pip install -r doc-requirements.txt + make setup + pip freeze - name: Build the documentation run: | # TODO: Remove after buf migration is done and packages updated diff --git a/Makefile b/Makefile index 95e6d538ca..7d28849741 100644 --- a/Makefile +++ b/Makefile @@ -24,8 +24,9 @@ update_boilerplate: .PHONY: setup setup: install-piptools ## Install requirements - pip install flyteidl --pre - pip install -r dev-requirements.in + pip install --pre -r dev-requirements.in + pip install git+https://github.com/flyteorg/flyte.git@7711df2cebaaa6a2dc8d7de2149859eed5ba0cc2#subdirectory=flyteidl + .PHONY: fmt fmt: diff --git a/doc-requirements.in b/doc-requirements.in index 3b602e7e52..4fe6bd570d 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -1,5 +1,3 @@ --e file:.#egg=flytekit - grpcio git+https://github.com/flyteorg/furo@main sphinx diff --git a/doc-requirements.txt b/doc-requirements.txt index 5c5c660b38..95be4948ee 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -1,36 +1,17 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# make doc-requirements.txt +# pip-compile doc-requirements.in # --e file:.#egg=flytekit - # via -r doc-requirements.in -adlfs==2023.9.0 - # via flytekit -aiobotocore==2.5.4 - # via s3fs -aiohttp==3.9.2 - # via - # adlfs - # aiobotocore - # gcsfs - # s3fs -aioitertools==0.11.0 - # via aiobotocore -aiosignal==1.3.1 - # via aiohttp alabaster==0.7.16 # via sphinx anyio==4.2.0 # via - # azure-core # starlette # watchfiles aplus==0.11.0 # via vaex-core -arrow==1.3.0 - # via cookiecutter astroid==3.0.2 # via sphinx-autoapi astropy==6.0.0 @@ -39,21 +20,6 @@ astropy-iers-data==0.2024.1.8.0.30.55 # via astropy asttokens==2.4.1 # via stack-data -async-timeout==4.0.3 - # via aiohttp -attrs==23.2.0 - # via aiohttp -azure-core==1.29.6 - # via - # adlfs - # azure-identity - # azure-storage-blob -azure-datalake-store==0.0.53 - # via adlfs -azure-identity==1.15.0 - # via adlfs -azure-storage-blob==12.19.0 - # via adlfs babel==2.14.0 # via sphinx beautifulsoup4==4.12.2 @@ -61,12 +27,8 @@ beautifulsoup4==4.12.2 # furo # sphinx-code-include # sphinx-material -binaryornot==0.4.4 - # via cookiecutter blake3==0.3.4 # via vaex-core -botocore==1.31.17 - # via aiobotocore bqplot==0.12.42 # via # ipyvolume @@ -78,69 +40,40 @@ cachetools==5.3.2 # google-auth # vaex-server certifi==2023.11.17 - # via - # kubernetes - # requests + # via requests cffi==1.16.0 - # via - # azure-datalake-store - # cryptography + # via cryptography cfgv==3.4.0 # via pre-commit -chardet==5.2.0 - # via binaryornot charset-normalizer==3.3.2 # via requests click==8.1.7 # via - # cookiecutter # dask - # flytekit - # rich-click # sphinx-click # uvicorn cloudpickle==3.0.0 # via # dask - # flytekit # vaex-core comm==0.2.1 # via ipywidgets contourpy==1.2.0 # via matplotlib -cookiecutter==2.5.0 - # via flytekit -croniter==2.0.1 - # via flytekit cryptography==41.0.7 - # via - # -r doc-requirements.in - # azure-identity - # azure-storage-blob - # msal - # pyjwt - # secretstorage + # via -r doc-requirements.in css-html-js-minify==2.5.5 # via sphinx-material cycler==0.12.1 # via matplotlib dask==2023.12.1 # via vaex-core -dataclasses-json==0.5.9 - # via flytekit decorator==5.1.1 # via - # gcsfs # ipython # retry -diskcache==5.6.3 - # via flytekit distlib==0.3.8 # via virtualenv -docker==6.1.3 - # via flytekit -docstring-parser==0.15 - # via flytekit docutils==0.17.1 # via # sphinx @@ -158,70 +91,31 @@ filelock==3.13.1 # via # vaex-core # virtualenv -flyteidl==1.10.6 - # via flytekit fonttools==4.47.0 # via matplotlib frozendict==2.4.0 # via vaex-core -frozenlist==1.4.1 - # via - # aiohttp - # aiosignal fsspec==2023.9.2 - # via - # adlfs - # dask - # flytekit - # gcsfs - # s3fs + # via dask furo @ git+https://github.com/flyteorg/furo@main # via -r doc-requirements.in future==0.18.3 # via vaex-core -gcsfs==2023.9.2 - # via flytekit google-api-core[grpc]==2.15.0 - # via - # -r doc-requirements.in - # google-cloud-core - # google-cloud-storage + # via -r doc-requirements.in google-auth==2.26.1 - # via - # gcsfs - # google-api-core - # google-auth-oauthlib - # google-cloud-core - # google-cloud-storage - # kubernetes -google-auth-oauthlib==1.2.0 - # via gcsfs -google-cloud-core==2.4.1 - # via google-cloud-storage -google-cloud-storage==2.14.0 - # via gcsfs -google-crc32c==1.5.0 - # via - # google-cloud-storage - # google-resumable-media -google-resumable-media==2.7.0 - # via google-cloud-storage + # via google-api-core googleapis-common-protos==1.62.0 # via - # flyteidl - # flytekit # google-api-core # grpcio-status grpcio==1.60.0 # via # -r doc-requirements.in - # flytekit # google-api-core # grpcio-status grpcio-status==1.60.0 - # via - # flytekit - # google-api-core + # via google-api-core h11==0.14.0 # via uvicorn h5py==3.10.0 @@ -234,14 +128,10 @@ idna==3.6 # via # anyio # requests - # yarl imagesize==1.4.1 # via sphinx importlib-metadata==7.0.1 - # via - # dask - # flytekit - # keyring + # via dask ipydatawidgets==4.3.5 # via pythreejs ipyleaflet==0.18.1 @@ -275,39 +165,20 @@ ipywidgets==8.1.1 # ipyvolume # ipyvue # pythreejs -isodate==0.6.1 - # via azure-storage-blob -jaraco-classes==3.3.0 - # via keyring jedi==0.19.1 # via ipython -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.3 # via # branca - # cookiecutter # sphinx # sphinx-autoapi # vaex-ml -jmespath==1.0.1 - # via botocore joblib==1.3.2 - # via - # flytekit - # scikit-learn -jsonpickle==3.0.2 - # via flytekit + # via scikit-learn jupyterlab-widgets==3.0.9 # via ipywidgets -keyring==24.3.0 - # via flytekit kiwisolver==1.4.5 # via matplotlib -kubernetes==29.0.0 - # via flytekit llvmlite==0.41.1 # via numba locket==1.0.0 @@ -318,21 +189,8 @@ markdown-it-py==3.0.0 # via rich markupsafe==2.1.3 # via jinja2 -marshmallow==3.20.2 - # via - # dataclasses-json - # marshmallow-enum - # marshmallow-jsonschema -marshmallow-enum==1.5.1 - # via - # dataclasses-json - # flytekit -marshmallow-jsonschema==0.13.0 - # via flytekit mashumaro==3.11 - # via - # -r doc-requirements.in - # flytekit + # via -r doc-requirements.in matplotlib==3.8.2 # via # ipympl @@ -342,21 +200,6 @@ matplotlib-inline==0.1.6 # via ipython mdurl==0.1.2 # via markdown-it-py -more-itertools==10.2.0 - # via jaraco-classes -msal==1.26.0 - # via - # azure-datalake-store - # azure-identity - # msal-extensions -msal-extensions==1.1.0 - # via azure-identity -multidict==6.0.4 - # via - # aiohttp - # yarl -mypy-extensions==1.0.0 - # via typing-inspect nest-asyncio==1.5.8 # via vaex-core nodeenv==1.8.0 @@ -382,18 +225,11 @@ numpy==1.26.3 # scipy # vaex-core # xarray -oauthlib==3.2.2 - # via - # kubernetes - # requests-oauthlib packaging==23.2 # via # astropy # dask - # docker - # marshmallow # matplotlib - # msal-extensions # sphinx # xarray pandas==2.1.4 @@ -416,8 +252,6 @@ pillow==10.2.0 # vaex-viz platformdirs==4.1.0 # via virtualenv -portalocker==2.8.2 - # via msal-extensions pre-commit==3.6.0 # via sphinx-tags progressbar2==4.3.2 @@ -426,14 +260,9 @@ prompt-toolkit==3.0.43 # via ipython protobuf==4.24.4 # via - # flyteidl - # flytekit # google-api-core # googleapis-common-protos # grpcio-status - # protoc-gen-swagger -protoc-gen-swagger==0.1.0 - # via flyteidl ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 @@ -441,9 +270,7 @@ pure-eval==0.2.2 py==1.11.0 # via retry pyarrow==14.0.2 - # via - # flytekit - # vaex-core + # via vaex-core pyasn1==0.5.1 # via # pyasn1-modules @@ -466,96 +293,50 @@ pygments==2.17.2 # rich # sphinx # sphinx-prompt -pyjwt[crypto]==2.8.0 - # via - # msal - # pyjwt pyparsing==3.1.1 # via matplotlib python-dateutil==2.8.2 # via - # arrow - # botocore - # croniter - # kubernetes # matplotlib # pandas python-dotenv==1.0.0 # via uvicorn -python-json-logger==2.0.7 - # via flytekit python-slugify[unidecode]==8.0.1 - # via - # cookiecutter - # sphinx-material + # via sphinx-material python-utils==3.8.1 # via progressbar2 pythreejs==2.4.2 # via ipyvolume -pytimeparse==1.1.8 - # via flytekit pytz==2023.3.post1 - # via - # croniter - # pandas + # via pandas pyyaml==6.0.1 # via # astropy - # cookiecutter # dask - # flytekit - # kubernetes # pre-commit # sphinx-autoapi # uvicorn # vaex-core requests==2.31.0 # via - # azure-core - # azure-datalake-store - # cookiecutter - # docker - # flytekit - # gcsfs # google-api-core - # google-cloud-storage # ipyvolume - # kubernetes - # msal - # requests-oauthlib # sphinx # sphinxcontrib-youtube # vaex-core -requests-oauthlib==1.3.1 - # via - # google-auth-oauthlib - # kubernetes retry==0.9.2 # via -r doc-requirements.in rich==13.7.0 - # via - # cookiecutter - # flytekit - # rich-click - # vaex-core -rich-click==1.7.3 - # via flytekit + # via vaex-core rsa==4.9 # via google-auth -s3fs==2023.9.2 - # via flytekit scikit-learn==1.3.2 # via -r doc-requirements.in scipy==1.11.4 # via scikit-learn -secretstorage==3.3.3 - # via keyring six==1.16.0 # via # asttokens - # azure-core - # isodate - # kubernetes # python-dateutil # sphinx-code-include # vaex-core @@ -624,8 +405,6 @@ stack-data==0.6.3 # via ipython starlette==0.32.0.post1 # via fastapi -statsd==3.3.0 - # via flytekit tabulate==0.9.0 # via vaex-core text-unidecode==1.3 @@ -656,24 +435,15 @@ traittypes==0.2.1 # ipydatawidgets # ipyleaflet # ipyvolume -types-python-dateutil==2.8.19.20240106 - # via arrow typing-extensions==4.9.0 # via # anyio # astroid - # azure-core - # azure-storage-blob # fastapi - # flytekit # mashumaro # pydantic # python-utils - # rich-click - # typing-inspect # uvicorn -typing-inspect==0.9.0 - # via dataclasses-json tzdata==2023.4 # via pandas unidecode==1.3.7 @@ -681,12 +451,7 @@ unidecode==1.3.7 # python-slugify # sphinx-autoapi urllib3==1.26.18 - # via - # botocore - # docker - # flytekit - # kubernetes - # requests + # via requests uvicorn[standard]==0.25.0 # via vaex-server uvloop==0.19.0 @@ -722,22 +487,14 @@ watchfiles==0.21.0 # via uvicorn wcwidth==0.2.13 # via prompt-toolkit -websocket-client==1.7.0 - # via - # docker - # kubernetes websockets==12.0 # via uvicorn widgetsnbextension==4.0.9 # via ipywidgets -wrapt==1.16.0 - # via aiobotocore xarray==2023.12.0 # via vaex-jupyter xyzservices==2023.10.1 # via ipyleaflet -yarl==1.9.4 - # via aiohttp zipp==3.17.0 # via importlib-metadata diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 28bc18ddd2..3bbdb8c49c 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -10,17 +10,13 @@ import grpc from flyteidl.admin.agent_pb2 import ( - PERMANENT_FAILURE, - RETRYABLE_FAILURE, - RUNNING, - SUCCEEDED, Agent, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, - State, ) from flyteidl.core import literals_pb2 +from flyteidl.core.execution_pb2 import TaskExecution from flyteidl.core.tasks_pb2 import TaskTemplate from rich.progress import Progress @@ -152,26 +148,26 @@ def get_agent_metadata(name: str) -> Agent: return AgentRegistry._METADATA[name] -def convert_to_flyte_state(state: str) -> State: +def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: """ Convert the state from the agent to the state in flyte. """ state = state.lower() # timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate if state in ["failed", "timeout", "timedout", "canceled"]: - return RETRYABLE_FAILURE + return TaskExecution.FAILED elif state in ["done", "succeeded", "success"]: - return SUCCEEDED + return TaskExecution.SUCCEEDED elif state in ["running"]: - return RUNNING + return TaskExecution.RUNNING raise ValueError(f"Unrecognized state: {state}") -def is_terminal_state(state: State) -> bool: +def is_terminal_phase(phase: TaskExecution.Phase) -> bool: """ - Return true if the state is terminal. + Return true if the phase is terminal. """ - return state in [SUCCEEDED, RETRYABLE_FAILURE, PERMANENT_FAILURE] + return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED] def get_agent_secret(secret_key: str) -> str: @@ -215,13 +211,13 @@ def execute(self, **kwargs) -> typing.Any: # If the task is synchronous, the agent will return the output from the resource literals. if res.HasField("resource"): - if res.resource.state != SUCCEEDED: + if res.resource.phase != TaskExecution.SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self._entity.name}") return LiteralMap.from_flyte_idl(res.resource.outputs) res = asyncio.run(self._get(resource_meta=res.resource_meta)) - if res.resource.state != SUCCEEDED: + if res.resource.phase != TaskExecution.SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self._entity.name}") # Read the literals from a remote file, if agent doesn't return the output literals. @@ -260,13 +256,13 @@ async def _create( return res async def _get(self, resource_meta: bytes) -> GetTaskResponse: - state = RUNNING + phase = TaskExecution.RUNNING grpc_ctx = _get_grpc_context() progress = Progress(transient=True) task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) with progress: - while not is_terminal_state(state): + while not is_terminal_phase(phase): progress.start_task(task) time.sleep(1) if self._agent.asynchronous: @@ -276,11 +272,12 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: sys.exit(1) else: res = self._agent.get(grpc_ctx, resource_meta) - state = res.resource.state - progress.print(f"Task state: {State.Name(state)}, State message: {res.resource.message}") - if hasattr(res.resource, "log_links"): - for link in res.resource.log_links: - progress.print(f"{link.name}: {link.uri}") + phase = res.resource.phase + + progress.print(f"Task phase: {TaskExecution.Phase.Name(phase)}, Phase message: {res.resource.message}") + for link in res.resource.log_links: + progress.print(f"{link.name}: {link.uri}") + return res def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py index 5fa16162a0..3007fd9df2 100644 --- a/flytekit/sensor/sensor_engine.py +++ b/flytekit/sensor/sensor_engine.py @@ -6,13 +6,12 @@ import grpc import jsonpickle from flyteidl.admin.agent_pb2 import ( - RUNNING, - SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource, ) +from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine @@ -54,8 +53,12 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None inputs = meta.get(INPUTS, {}) - cur_state = SUCCEEDED if await sensor_def("sensor", config=sensor_config).poke(**inputs) else RUNNING - return GetTaskResponse(resource=Resource(state=cur_state, outputs=None)) + cur_phase = ( + TaskExecution.SUCCEEDED + if await sensor_def("sensor", config=sensor_config).poke(**inputs) + else TaskExecution.RUNNING + ) + return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=None)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: return DeleteTaskResponse() diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py index b84deb009d..9ed3e9706e 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py @@ -7,14 +7,12 @@ import grpc import jsonpickle from flyteidl.admin.agent_pb2 import ( - RETRYABLE_FAILURE, - RUNNING, - SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource, ) +from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.airflow.task import AirflowObj, _get_airflow_instance from airflow.exceptions import AirflowException, TaskDeferred @@ -101,11 +99,11 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - airflow_trigger_instance = _get_airflow_instance(meta.airflow_trigger) if meta.airflow_trigger else None airflow_ctx = Context() message = None - cur_state = RUNNING + cur_phase = TaskExecution.RUNNING if isinstance(airflow_operator_instance, BaseSensorOperator): ok = airflow_operator_instance.poke(context=airflow_ctx) - cur_state = SUCCEEDED if ok else RUNNING + cur_phase = TaskExecution.SUCCEEDED if ok else TaskExecution.RUNNING elif isinstance(airflow_operator_instance, BaseOperator): if airflow_trigger_instance: try: @@ -120,26 +118,26 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - # Trigger callback will check the status of the task in the payload, and raise AirflowException if failed. trigger_callback = getattr(airflow_operator_instance, meta.airflow_trigger_callback) trigger_callback(context=airflow_ctx, event=typing.cast(TriggerEvent, event).payload) - cur_state = SUCCEEDED + cur_phase = TaskExecution.SUCCEEDED except AirflowException as e: - cur_state = RETRYABLE_FAILURE + cur_phase = TaskExecution.FAILED message = e.__str__() except asyncio.TimeoutError: logger.debug("No event received from airflow trigger") except AirflowException as e: - cur_state = RETRYABLE_FAILURE + cur_phase = TaskExecution.FAILED message = e.__str__() else: # If there is no trigger, it means the operator is not deferrable. In this case, this operator will be # executed in the creation step. Therefore, we can directly return SUCCEEDED here. # For instance, SlackWebhookOperator is not deferrable. It sends a message to Slack in the creation step. # If the message is sent successfully, agent will return SUCCEEDED here. Otherwise, it will raise an exception at creation step. - cur_state = SUCCEEDED + cur_phase = TaskExecution.SUCCEEDED else: raise FlyteUserException("Only sensor and operator are supported.") - return GetTaskResponse(resource=Resource(state=cur_state, message=message)) + return GetTaskResponse(resource=Resource(phase=cur_phase, message=message)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: return DeleteTaskResponse() diff --git a/plugins/flytekit-airflow/tests/test_agent.py b/plugins/flytekit-airflow/tests/test_agent.py index 7ab0c31729..af34cae44e 100644 --- a/plugins/flytekit-airflow/tests/test_agent.py +++ b/plugins/flytekit-airflow/tests/test_agent.py @@ -7,7 +7,8 @@ from airflow.operators.python import PythonOperator from airflow.sensors.bash import BashSensor from airflow.sensors.time_sensor import TimeSensor -from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse +from flyteidl.admin.agent_pb2 import DeleteTaskResponse +from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.airflow import AirflowObj from flytekitplugins.airflow.agent import AirflowAgent, ResourceMetadata @@ -94,7 +95,7 @@ async def test_airflow_agent(): res = await agent.async_create(grpc_ctx, "/tmp", dummy_template, None) metadata = res.resource_meta res = await agent.async_get(grpc_ctx, metadata) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED assert res.resource.message == "" res = await agent.async_delete(grpc_ctx, metadata) assert res == DeleteTaskResponse() diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index 54b70ea8e0..da397eb54f 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -5,18 +5,17 @@ import grpc from flyteidl.admin.agent_pb2 import ( - PERMANENT_FAILURE, - SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource, ) +from flyteidl.core.execution_pb2 import TaskExecution from google.cloud import bigquery from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase from flytekit.models import literals from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap @@ -94,12 +93,12 @@ def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskRes logger.error(job.errors.__str__()) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(job.errors.__str__()) - return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE), log_links=log_links) + return GetTaskResponse(resource=Resource(phase=TaskExecution.FAILED), log_links=log_links) - cur_state = convert_to_flyte_state(str(job.state)) + cur_phase = convert_to_flyte_phase(str(job.state)) res = None - if cur_state == SUCCEEDED: + if cur_phase == TaskExecution.SUCCEEDED: ctx = FlyteContextManager.current_context() if job.destination: output_location = ( @@ -116,7 +115,7 @@ def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskRes } ).to_flyte_idl() - return GetTaskResponse(resource=Resource(state=cur_state, outputs=res), log_links=log_links) + return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res), log_links=log_links) def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: client = bigquery.Client() diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index ae8091a3c9..075a13e905 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock import grpc -from flyteidl.admin.agent_pb2 import SUCCEEDED +from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.bigquery.agent import Metadata import flytekit.models.interface as interface_models @@ -94,7 +94,7 @@ def __init__(self): ).encode("utf-8") assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes res = agent.get(ctx, metadata_bytes) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED assert ( res.resource.outputs.literals["results"].scalar.structured_dataset.uri == "bq://dummy_project:dummy_dataset.dummy_table" diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py index 5c6ab06831..a927a1d021 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py @@ -7,7 +7,7 @@ import grpc from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource -from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_state +from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_phase from flytekit import current_context from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry @@ -173,12 +173,12 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - logger.exception(f"Failed to obtain status for MMCloud job: {job_id}") raise - task_state = mmcloud_status_to_flyte_state(job_status) + task_phase = mmcloud_status_to_flyte_phase(job_status) logger.info(f"Obtained status for MMCloud job {job_id}: {job_status}") logger.debug(f"OpCenter response: {show_response}") - return GetTaskResponse(resource=Resource(state=task_state)) + return GetTaskResponse(resource=Resource(phase=task_phase)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: """ diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py index 03696d6c45..7a081ba753 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py @@ -5,39 +5,39 @@ from decimal import ROUND_CEILING, Decimal from typing import Optional, Tuple -from flyteidl.admin.agent_pb2 import PERMANENT_FAILURE, RETRYABLE_FAILURE, RUNNING, SUCCEEDED, State +from flyteidl.core.execution_pb2 import TaskExecution from kubernetes.utils.quantity import parse_quantity from flytekit.core.resources import Resources -MMCLOUD_STATUS_TO_FLYTE_STATE = { - "Submitted": RUNNING, - "Initializing": RUNNING, - "Starting": RUNNING, - "Executing": RUNNING, - "Capturing": RUNNING, - "Floating": RUNNING, - "Suspended": RUNNING, - "Suspending": RUNNING, - "Resuming": RUNNING, - "Completed": SUCCEEDED, - "Cancelled": PERMANENT_FAILURE, - "Cancelling": PERMANENT_FAILURE, - "FailToComplete": RETRYABLE_FAILURE, - "FailToExecute": RETRYABLE_FAILURE, - "CheckpointFailed": RETRYABLE_FAILURE, - "Timedout": RETRYABLE_FAILURE, - "NoAvailableHost": RETRYABLE_FAILURE, - "Unknown": RETRYABLE_FAILURE, - "WaitingForLicense": PERMANENT_FAILURE, +MMCLOUD_STATUS_TO_FLYTE_PHASE = { + "Submitted": TaskExecution.RUNNING, + "Initializing": TaskExecution.RUNNING, + "Starting": TaskExecution.RUNNING, + "Executing": TaskExecution.RUNNING, + "Capturing": TaskExecution.RUNNING, + "Floating": TaskExecution.RUNNING, + "Suspended": TaskExecution.RUNNING, + "Suspending": TaskExecution.RUNNING, + "Resuming": TaskExecution.RUNNING, + "Completed": TaskExecution.SUCCEEDED, + "Cancelled": TaskExecution.FAILED, + "Cancelling": TaskExecution.FAILED, + "FailToComplete": TaskExecution.FAILED, + "FailToExecute": TaskExecution.FAILED, + "CheckpointFailed": TaskExecution.FAILED, + "Timedout": TaskExecution.FAILED, + "NoAvailableHost": TaskExecution.FAILED, + "Unknown": TaskExecution.FAILED, + "WaitingForLicense": TaskExecution.FAILED, } -def mmcloud_status_to_flyte_state(status: str) -> State: +def mmcloud_status_to_flyte_phase(status: str) -> TaskExecution.Phase: """ - Map MMCloud status to Flyte state. + Map MMCloud status to Flyte phase. """ - return MMCLOUD_STATUS_TO_FLYTE_STATE[status] + return MMCLOUD_STATUS_TO_FLYTE_PHASE[status] def flyte_to_mmcloud_resources( diff --git a/plugins/flytekit-mmcloud/tests/test_mmcloud.py b/plugins/flytekit-mmcloud/tests/test_mmcloud.py index e7f3fde7a3..eff4c4e63c 100644 --- a/plugins/flytekit-mmcloud/tests/test_mmcloud.py +++ b/plugins/flytekit-mmcloud/tests/test_mmcloud.py @@ -6,7 +6,7 @@ import grpc import pytest -from flyteidl.admin.agent_pb2 import PERMANENT_FAILURE, RUNNING, SUCCEEDED +from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.mmcloud import MMCloudAgent, MMCloudConfig, MMCloudTask from flytekitplugins.mmcloud.utils import async_check_output, flyte_to_mmcloud_resources @@ -125,14 +125,14 @@ def say_hello0(name: str) -> str: resource_meta = create_task_response.resource_meta get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta)) - state = get_task_response.resource.state - assert state in (RUNNING, SUCCEEDED) + phase = get_task_response.resource.phase + assert phase in (TaskExecution.RUNNING, TaskExecution.SUCCEEDED) asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta)) - state = get_task_response.resource.state - assert state == PERMANENT_FAILURE + phase = get_task_response.resource.phase + assert phase == TaskExecution.FAILED @task( task_config=MMCloudConfig(submit_extra="--nonexistent"), diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index e23a550153..7195e0bf8a 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -4,17 +4,16 @@ import grpc from flyteidl.admin.agent_pb2 import ( - PERMANENT_FAILURE, - SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource, ) +from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager, StructuredDataset, lazy_module, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase from flytekit.models import literals from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -120,11 +119,11 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - logger.error(err.msg) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(err.msg) - return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) - cur_state = convert_to_flyte_state(str(query_status.name)) + return GetTaskResponse(resource=Resource(phase=TaskExecution.FAILED)) + cur_phase = convert_to_flyte_phase(str(query_status.name)) res = None - if cur_state == SUCCEEDED: + if cur_phase == TaskExecution.SUCCEEDED: ctx = FlyteContextManager.current_context() output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.warehouse}/{metadata.database}/{metadata.schema}/{metadata.table}" res = literals.LiteralMap( @@ -138,7 +137,7 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - } ).to_flyte_idl() - return GetTaskResponse(resource=Resource(state=cur_state, outputs=res)) + return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index 50dc689bc1..a9928eb817 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -6,7 +6,8 @@ import grpc import pytest -from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse +from flyteidl.admin.agent_pb2 import DeleteTaskResponse +from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.snowflake.agent import Metadata import flytekit.models.interface as interface_models @@ -99,7 +100,7 @@ async def test_snowflake_agent(mock_get_private_key): assert res.resource_meta == metadata_bytes res = await agent.async_get(ctx, metadata_bytes) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED assert ( res.resource.outputs.literals["results"].scalar.structured_dataset.uri == "snowflake://dummy_user:dummy_account/dummy_warehouse/dummy_database/dummy_schema/dummy_table" diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index ff907d2a41..d178a4a893 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -6,10 +6,11 @@ from typing import Optional import grpc -from flyteidl.admin.agent_pb2 import PENDING, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource +from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource +from flyteidl.core.execution_pb2 import TaskExecution from flytekit import lazy_module -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state, get_agent_secret +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase, get_agent_secret from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -89,12 +90,12 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - raise Exception(f"Failed to get databricks job {metadata.run_id} with error: {resp.reason}") response = await resp.json() - cur_state = PENDING + cur_phase = TaskExecution.RUNNING message = "" state = response.get("state") if state: if state.get("result_state"): - cur_state = convert_to_flyte_state(state["result_state"]) + cur_phase = convert_to_flyte_phase(state["result_state"]) if state.get("state_message"): message = state["state_message"] @@ -102,7 +103,7 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}" log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()] - return GetTaskResponse(resource=Resource(state=cur_state, message=message), log_links=log_links) + return GetTaskResponse(resource=Resource(phase=cur_phase, message=message), log_links=log_links) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: metadata = pickle.loads(resource_meta) diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 8f3bf94756..fd62bc978e 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -7,7 +7,7 @@ import grpc import pytest from aioresponses import aioresponses -from flyteidl.admin.agent_pb2 import SUCCEEDED +from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.spark.agent import DATABRICKS_API_ENDPOINT, Metadata, get_header from flytekit.extend.backend.base_agent import AgentRegistry @@ -126,7 +126,7 @@ async def test_databricks_agent(): mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_get_response) res = await agent.async_get(ctx, metadata_bytes) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl() assert res.resource.message == "OK" assert res.log_links[0].name == "Databricks Console" diff --git a/pyproject.toml b/pyproject.toml index 008bd46ae5..a4e5372341 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0,<7.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.10.0", + "flyteidl>1.10.6", "fsspec>=2023.3.0,<=2023.9.2", "gcsfs>=2023.3.0,<=2023.9.2", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 3e1f096698..0e3b152794 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -8,10 +8,6 @@ import grpc import pytest from flyteidl.admin.agent_pb2 import ( - PERMANENT_FAILURE, - RETRYABLE_FAILURE, - RUNNING, - SUCCEEDED, CreateTaskRequest, CreateTaskResponse, DeleteTaskRequest, @@ -22,6 +18,7 @@ ListAgentsResponse, Resource, ) +from flyteidl.core.execution_pb2 import TaskExecution from flytekit import PythonFunctionTask, task from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings @@ -30,9 +27,9 @@ AgentBase, AgentRegistry, AsyncAgentExecutorMixin, - convert_to_flyte_state, + convert_to_flyte_phase, get_agent_secret, - is_terminal_state, + is_terminal_phase, render_task_template, ) from flytekit.models import literals @@ -67,7 +64,8 @@ def create( def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: return GetTaskResponse( - resource=Resource(state=SUCCEEDED), log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()] + resource=Resource(phase=TaskExecution.SUCCEEDED), + log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()], ) def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: @@ -90,7 +88,7 @@ async def async_create( return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - return GetTaskResponse(resource=Resource(state=SUCCEEDED)) + return GetTaskResponse(resource=Resource(phase=TaskExecution.SUCCEEDED)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: return DeleteTaskResponse() @@ -109,7 +107,9 @@ async def async_create( task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, ) -> CreateTaskResponse: - return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=LiteralMap({}).to_flyte_idl())) + return CreateTaskResponse( + resource=Resource(phase=TaskExecution.SUCCEEDED, outputs=LiteralMap({}).to_flyte_idl()) + ) def get_task_template(task_type: str) -> TaskTemplate: @@ -150,7 +150,7 @@ def test_dummy_agent(): metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes res = agent.get(ctx, metadata_bytes) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED assert res.log_links[0].name == "console" assert res.log_links[0].uri == "localhost:3000" assert agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() @@ -183,7 +183,7 @@ async def test_async_dummy_agent(): res = await agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs) assert res.resource_meta == metadata_bytes res = await agent.async_get(ctx, metadata_bytes) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED res = await agent.async_delete(ctx, metadata_bytes) assert res == DeleteTaskResponse() @@ -198,7 +198,7 @@ async def test_sync_dummy_agent(): ctx = MagicMock(spec=grpc.ServicerContext) agent = AgentRegistry.get_agent("sync_dummy") res = await agent.async_create(ctx, "/tmp", sync_dummy_template, task_inputs) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED assert res.resource.outputs == LiteralMap({}).to_flyte_idl() agent_metadata = AgentRegistry.get_agent_metadata("Sync Dummy Agent") @@ -225,19 +225,19 @@ async def run_agent_server(): res = await service.CreateTask(request, ctx) assert res.resource_meta == metadata_bytes res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) assert isinstance(res, DeleteTaskResponse) res = await service.CreateTask(async_request, ctx) assert res.resource_meta == metadata_bytes res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) assert isinstance(res, DeleteTaskResponse) res = await service.CreateTask(sync_request, ctx) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED assert res.resource.outputs == LiteralMap({}).to_flyte_idl() res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) @@ -252,27 +252,28 @@ def test_agent_server(): loop.run_in_executor(None, run_agent_server) -def test_is_terminal_state(): - assert is_terminal_state(SUCCEEDED) - assert is_terminal_state(RETRYABLE_FAILURE) - assert is_terminal_state(PERMANENT_FAILURE) - assert not is_terminal_state(RUNNING) +def test_is_terminal_phase(): + assert is_terminal_phase(TaskExecution.SUCCEEDED) + assert is_terminal_phase(TaskExecution.ABORTED) + assert is_terminal_phase(TaskExecution.FAILED) + assert not is_terminal_phase(TaskExecution.RUNNING) -def test_convert_to_flyte_state(): - assert convert_to_flyte_state("FAILED") == RETRYABLE_FAILURE - assert convert_to_flyte_state("TIMEDOUT") == RETRYABLE_FAILURE - assert convert_to_flyte_state("CANCELED") == RETRYABLE_FAILURE +def test_convert_to_flyte_phase(): + assert convert_to_flyte_phase("FAILED") == TaskExecution.FAILED + assert convert_to_flyte_phase("TIMEOUT") == TaskExecution.FAILED + assert convert_to_flyte_phase("TIMEDOUT") == TaskExecution.FAILED + assert convert_to_flyte_phase("CANCELED") == TaskExecution.FAILED - assert convert_to_flyte_state("DONE") == SUCCEEDED - assert convert_to_flyte_state("SUCCEEDED") == SUCCEEDED - assert convert_to_flyte_state("SUCCESS") == SUCCEEDED + assert convert_to_flyte_phase("DONE") == TaskExecution.SUCCEEDED + assert convert_to_flyte_phase("SUCCEEDED") == TaskExecution.SUCCEEDED + assert convert_to_flyte_phase("SUCCESS") == TaskExecution.SUCCEEDED - assert convert_to_flyte_state("RUNNING") == RUNNING + assert convert_to_flyte_phase("RUNNING") == TaskExecution.RUNNING invalid_state = "INVALID_STATE" with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"): - convert_to_flyte_state(invalid_state) + convert_to_flyte_phase(invalid_state) @patch("flytekit.current_context") diff --git a/tests/flytekit/unit/sensor/test_sensor_engine.py b/tests/flytekit/unit/sensor/test_sensor_engine.py index dbb81c3f47..e654f522a6 100644 --- a/tests/flytekit/unit/sensor/test_sensor_engine.py +++ b/tests/flytekit/unit/sensor/test_sensor_engine.py @@ -4,7 +4,8 @@ import cloudpickle import grpc import pytest -from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse +from flyteidl.admin.agent_pb2 import DeleteTaskResponse +from flyteidl.core.execution_pb2 import TaskExecution import flytekit.models.interface as interface_models from flytekit.extend.backend.base_agent import AgentRegistry @@ -44,6 +45,6 @@ async def test_sensor_engine(): metadata_bytes = cloudpickle.dumps(tmp.custom) assert res.resource_meta == metadata_bytes res = await agent.async_get(ctx, metadata_bytes) - assert res.resource.state == SUCCEEDED + assert res.resource.phase == TaskExecution.SUCCEEDED res = await agent.async_delete(ctx, metadata_bytes) assert res == DeleteTaskResponse()