From 993ff59df409922b5f3f1bd41bd4eb0e71b54d32 Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Fri, 12 Jul 2024 15:11:27 +0800
Subject: [PATCH] Rename databricks task type (#2574)

Signed-off-by: Kevin Su
---
 flytekit/extend/backend/base_agent.py          | 20 +++++++----
 .../flytekitplugins/spark/__init__.py          |  2 +-
 .../flytekitplugins/spark/agent.py             | 33 ++++++++++++++-----
 .../flytekitplugins/spark/task.py              | 29 +++++++++++++---
 4 files changed, 63 insertions(+), 21 deletions(-)

diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py
index 9d42910070..214feed892 100644
--- a/flytekit/extend/backend/base_agent.py
+++ b/flytekit/extend/backend/base_agent.py
@@ -22,6 +22,7 @@
 from flytekit.configuration import ImageConfig, SerializationSettings
 from flytekit.core import utils
 from flytekit.core.base_task import PythonTask
+from flytekit.core.context_manager import ExecutionState, FlyteContextManager
 from flytekit.core.type_engine import TypeEngine, dataclass_from_dict
 from flytekit.exceptions.system import FlyteAgentNotFound
 from flytekit.exceptions.user import FlyteUserException
@@ -319,14 +320,19 @@ async def _create(
         self: PythonTask, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None
     ) -> ResourceMeta:
         ctx = FlyteContext.current_context()
-
-        literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
         if isinstance(self, PythonFunctionTask):
-            # Write the inputs to a remote file, so that the remote task can read the inputs from this file.
-            path = ctx.file_access.get_random_local_path()
-            utils.write_proto_to_file(literal_map.to_flyte_idl(), path)
-            ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb")
-            task_template = render_task_template(task_template, output_prefix)
+            es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
+            cb = ctx.new_builder().with_execution_state(es)
+
+            with FlyteContextManager.with_context(cb) as ctx:
+                # Write the inputs to a remote file, so that the remote task can read the inputs from this file.
+                literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
+                path = ctx.file_access.get_random_local_path()
+                utils.write_proto_to_file(literal_map.to_flyte_idl(), path)
+                ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb")
+                task_template = render_task_template(task_template, output_prefix)
+        else:
+            literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
 
         resource_meta = await mirror_async_methods(
             self._agent.create,
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
index 72c9f37c9f..1deeceec6b 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
@@ -21,4 +21,4 @@
 from .pyspark_transformers import PySparkPipelineModelTransformer
 from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer  # noqa
 from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler
-from .task import Databricks, Spark, new_spark_session  # noqa
+from .task import Databricks, DatabricksV2, Spark, new_spark_session  # noqa
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py
index d367f3f04a..0f1788288a 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py
@@ -39,22 +39,22 @@ async def create(
         if databricks_job.get("existing_cluster_id") is None:
             new_cluster = databricks_job.get("new_cluster")
             if new_cluster is None:
-                raise Exception("Either existing_cluster_id or new_cluster must be specified")
+                raise ValueError("Either existing_cluster_id or new_cluster must be specified")
             if not new_cluster.get("docker_image"):
                 new_cluster["docker_image"] = {"url": container.image}
             if not new_cluster.get("spark_conf"):
                 new_cluster["spark_conf"] = custom["sparkConf"]
         # https://docs.databricks.com/api/workspace/jobs/submit
         databricks_job["spark_python_task"] = {
-            "python_file": "flytekitplugins/spark/entrypoint.py",
+            "python_file": "flytekitplugins/databricks/entrypoint.py",
             "source": "GIT",
             "parameters": container.args,
         }
         databricks_job["git_source"] = {
             "git_url": "https://github.com/flyteorg/flytetools",
             "git_provider": "gitHub",
-            # https://github.com/flyteorg/flytetools/commit/aff8a9f2adbf5deda81d36d59a0b8fa3b1fc3679
-            "git_commit": "aff8a9f2adbf5deda81d36d59a0b8fa3b1fc3679",
+            # https://github.com/flyteorg/flytetools/commit/572298df1f971fb58c258398bd70a6372f811c96
+            "git_commit": "572298df1f971fb58c258398bd70a6372f811c96",
         }
 
         databricks_instance = custom["databricksInstance"]
@@ -65,7 +65,7 @@ async def create(
             async with session.post(databricks_url, headers=get_header(), data=data) as resp:
                 response = await resp.json()
                 if resp.status != http.HTTPStatus.OK:
-                    raise Exception(f"Failed to create databricks job with error: {response}")
+                    raise RuntimeError(f"Failed to create databricks job with error: {response}")
 
         return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"]))
 
@@ -78,14 +78,15 @@ async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource:
         async with aiohttp.ClientSession() as session:
             async with session.get(databricks_url, headers=get_header()) as resp:
                 if resp.status != http.HTTPStatus.OK:
-                    raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() cur_phase = TaskExecution.UNDEFINED message = "" state = response.get("state") - # The databricks job's state is determined by life_cycle_state and result_state. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate + # The databricks job's state is determined by life_cycle_state and result_state. + # https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate if state: life_cycle_state = state.get("life_cycle_state") if result_state_is_available(life_cycle_state): @@ -109,10 +110,25 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): async with aiohttp.ClientSession() as session: async with session.post(databricks_url, headers=get_header(), data=data) as resp: if resp.status != http.HTTPStatus.OK: - raise Exception(f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}") + raise RuntimeError( + f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}" + ) await resp.json() +class DatabricksAgentV2(DatabricksAgent): + """ + Add DatabricksAgentV2 to support running the k8s spark and databricks spark together in the same workflow. + This is necessary because one task type can only be handled by a single backend plugin. + + spark -> k8s spark plugin + databricks -> databricks agent + """ + + def __init__(self): + super(AsyncAgentBase, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata) + + def get_header() -> typing.Dict[str, str]: token = get_agent_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") return {"Authorization": f"Bearer {token}", "content-type": "application/json"} @@ -123,3 +139,4 @@ def result_state_is_available(life_cycle_state: str) -> bool: AgentRegistry.register(DatabricksAgent()) +AgentRegistry.register(DatabricksAgentV2()) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 8a8c3b2b5b..2ed01c2cd7 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union, cast +import click from google.protobuf.json_format import MessageToDict from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger @@ -46,6 +47,22 @@ def __post_init__(self): @dataclass class Databricks(Spark): + """ + Deprecated. Use DatabricksV2 instead. + """ + + databricks_conf: Optional[Dict[str, Union[str, dict]]] = None + databricks_instance: Optional[str] = None + + def __post_init__(self): + logger.warn( + "Databricks is deprecated. Use 'from flytekitplugins.spark import Databricks' instead," + "and make sure to upgrade the version of flyteagent deployment to >v1.13.0.", + ) + + +@dataclass +class DatabricksV2(Spark): """ Use this to configure a Databricks task. 
     Task's marked with this will automatically execute natively onto databricks platform as a distributed execution of spark
@@ -127,6 +144,7 @@ def __init__(
             self._default_applications_path = (
                 self._default_applications_path or "local:///usr/local/bin/entrypoint.py"
             )
+
         super(PysparkFunctionTask, self).__init__(
             task_config=task_config,
             task_type=self._SPARK_TASK_TYPE,
@@ -151,8 +169,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
             main_class="",
             spark_type=SparkType.PYTHON,
         )
-        if isinstance(self.task_config, Databricks):
-            cfg = cast(Databricks, self.task_config)
+        if isinstance(self.task_config, (Databricks, DatabricksV2)):
+            cfg = cast(DatabricksV2, self.task_config)
             job._databricks_conf = cfg.databricks_conf
             job._databricks_instance = cfg.databricks_instance
 
@@ -181,7 +199,7 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
         return user_params.builder().add_attr("SPARK_SESSION", self.sess).build()
 
     def execute(self, **kwargs) -> Any:
-        if isinstance(self.task_config, Databricks):
+        if isinstance(self.task_config, (Databricks, DatabricksV2)):
             # Use the Databricks agent to run it by default.
             try:
                 ctx = FlyteContextManager.current_context()
@@ -193,11 +211,12 @@ def execute(self, **kwargs) -> Any:
                 if ctx.execution_state and ctx.execution_state.is_local_execution():
                     return AsyncAgentExecutorMixin.execute(self, **kwargs)
             except Exception as e:
-                logger.error(f"Agent failed to run the task with error: {e}")
-                logger.info("Falling back to local execution")
+                click.secho(f"❌ Agent failed to run the task with error: {e}", fg="red")
+                click.secho("Falling back to local execution", fg="red")
         return PythonFunctionTask.execute(self, **kwargs)
 
 
 # Inject the Spark plugin into flytekits dynamic plugin loading system
 TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)
 TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask)
+TaskPlugins.register_pythontask_plugin(DatabricksV2, PysparkFunctionTask)
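
Usage sketch (not part of the patch above): a minimal example of how the renamed task type might look from user code, assuming the standard flytekit decorator API. DatabricksV2, Spark, spark_conf, databricks_conf, databricks_instance, and the new_cluster/existing_cluster_id requirement come from the diff; the cluster settings, instance URL, and task/workflow names are illustrative placeholders.

    from flytekit import task, workflow
    from flytekitplugins.spark import DatabricksV2, Spark


    @task(
        task_config=DatabricksV2(
            spark_conf={"spark.driver.memory": "1g"},
            databricks_conf={
                "run_name": "flytekit databricks job",  # illustrative
                # Either new_cluster or existing_cluster_id must be provided (see agent.py above).
                "new_cluster": {
                    "spark_version": "13.3.x-scala2.12",  # illustrative
                    "node_type_id": "i3.xlarge",  # illustrative
                    "num_workers": 2,  # illustrative
                },
            },
            databricks_instance="dbc-xxxxxxxx-yyyy.cloud.databricks.com",  # illustrative
        ),
    )
    def on_databricks(n: int) -> int:
        # Routed to the Databricks agent (the "databricks" task type registered above).
        return n * 2


    @task(task_config=Spark(spark_conf={"spark.executor.instances": "2"}))
    def on_k8s_spark(n: int) -> int:
        # Keeps the "spark" task type and is handled by the k8s Spark plugin.
        return n + 1


    @workflow
    def wf(n: int = 3) -> int:
        # The rename lets both flavours coexist in a single workflow.
        return on_k8s_spark(n=on_databricks(n=n))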
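
For reference, also not part of the patch: roughly the request body that create() in agent.py ends up POSTing to the Databricks Jobs submit endpoint once the defaults are filled in, written as a Python dict. The spark_python_task and git_source blocks mirror the values hard-coded in the diff; every other value is an illustrative placeholder.

    # Sketch of the submitted payload, assuming the user supplied a new_cluster
    # in databricks_conf rather than an existing_cluster_id.
    databricks_job = {
        "run_name": "flytekit databricks job",  # illustrative, from the user's databricks_conf
        "new_cluster": {
            "spark_version": "13.3.x-scala2.12",  # illustrative
            "num_workers": 2,  # illustrative
            # Defaulted by the agent when missing:
            "docker_image": {"url": "<task container image>"},
            "spark_conf": {"spark.driver.memory": "1g"},  # illustrative value for custom["sparkConf"]
        },
        # Added by the agent in create():
        "spark_python_task": {
            "python_file": "flytekitplugins/databricks/entrypoint.py",
            "source": "GIT",
            "parameters": ["<task container args>"],
        },
        "git_source": {
            "git_url": "https://github.com/flyteorg/flytetools",
            "git_provider": "gitHub",
            "git_commit": "572298df1f971fb58c258398bd70a6372f811c96",
        },
    }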