Rename databricks task type (#2574)
Signed-off-by: Kevin Su <[email protected]>
pingsutw authored Jul 12, 2024
1 parent eb0fc61 commit 993ff59
Showing 4 changed files with 63 additions and 21 deletions.
flytekit/extend/backend/base_agent.py (13 additions, 7 deletions)
@@ -22,6 +22,7 @@
 from flytekit.configuration import ImageConfig, SerializationSettings
 from flytekit.core import utils
 from flytekit.core.base_task import PythonTask
+from flytekit.core.context_manager import ExecutionState, FlyteContextManager
 from flytekit.core.type_engine import TypeEngine, dataclass_from_dict
 from flytekit.exceptions.system import FlyteAgentNotFound
 from flytekit.exceptions.user import FlyteUserException
@@ -319,14 +320,19 @@ async def _create(
     self: PythonTask, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None
 ) -> ResourceMeta:
     ctx = FlyteContext.current_context()
-
-    literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
     if isinstance(self, PythonFunctionTask):
-        # Write the inputs to a remote file, so that the remote task can read the inputs from this file.
-        path = ctx.file_access.get_random_local_path()
-        utils.write_proto_to_file(literal_map.to_flyte_idl(), path)
-        ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb")
-        task_template = render_task_template(task_template, output_prefix)
+        es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
+        cb = ctx.new_builder().with_execution_state(es)
+
+        with FlyteContextManager.with_context(cb) as ctx:
+            # Write the inputs to a remote file, so that the remote task can read the inputs from this file.
+            literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
+            path = ctx.file_access.get_random_local_path()
+            utils.write_proto_to_file(literal_map.to_flyte_idl(), path)
+            ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb")
+            task_template = render_task_template(task_template, output_prefix)
+    else:
+        literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
 
     resource_meta = await mirror_async_methods(
         self._agent.create,
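
The substantive change here is that input serialization for function tasks now happens inside a context whose execution state is TASK_EXECUTION. A minimal sketch of that context-switch pattern, built only from the calls visible in the diff above (this sketch is not part of the commit):

    from flytekit.core.context_manager import ExecutionState, FlyteContextManager

    ctx = FlyteContextManager.current_context()
    # Enter a context whose execution state is TASK_EXECUTION, so code run
    # inside the block (e.g. type transformers) behaves as in a real task run.
    es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
    with FlyteContextManager.with_context(ctx.new_builder().with_execution_state(es)) as ctx:
        assert ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION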
plugins/flytekit-spark/flytekitplugins/spark/__init__.py (1 addition, 1 deletion)
@@ -21,4 +21,4 @@
 from .pyspark_transformers import PySparkPipelineModelTransformer
 from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer  # noqa
 from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler
-from .task import Databricks, Spark, new_spark_session  # noqa
+from .task import Databricks, DatabricksV2, Spark, new_spark_session  # noqa
plugins/flytekit-spark/flytekitplugins/spark/agent.py (25 additions, 8 deletions)
@@ -39,22 +39,22 @@ async def create(
         if databricks_job.get("existing_cluster_id") is None:
             new_cluster = databricks_job.get("new_cluster")
             if new_cluster is None:
-                raise Exception("Either existing_cluster_id or new_cluster must be specified")
+                raise ValueError("Either existing_cluster_id or new_cluster must be specified")
             if not new_cluster.get("docker_image"):
                 new_cluster["docker_image"] = {"url": container.image}
             if not new_cluster.get("spark_conf"):
                 new_cluster["spark_conf"] = custom["sparkConf"]
         # https://docs.databricks.com/api/workspace/jobs/submit
         databricks_job["spark_python_task"] = {
-            "python_file": "flytekitplugins/spark/entrypoint.py",
+            "python_file": "flytekitplugins/databricks/entrypoint.py",
             "source": "GIT",
             "parameters": container.args,
         }
         databricks_job["git_source"] = {
             "git_url": "https://github.com/flyteorg/flytetools",
             "git_provider": "gitHub",
-            # https://github.com/flyteorg/flytetools/commit/aff8a9f2adbf5deda81d36d59a0b8fa3b1fc3679
-            "git_commit": "aff8a9f2adbf5deda81d36d59a0b8fa3b1fc3679",
+            # https://github.com/flyteorg/flytetools/commit/572298df1f971fb58c258398bd70a6372f811c96
+            "git_commit": "572298df1f971fb58c258398bd70a6372f811c96",
         }
 
         databricks_instance = custom["databricksInstance"]
@@ -65,7 +65,7 @@
             async with session.post(databricks_url, headers=get_header(), data=data) as resp:
                 response = await resp.json()
                 if resp.status != http.HTTPStatus.OK:
-                    raise Exception(f"Failed to create databricks job with error: {response}")
+                    raise RuntimeError(f"Failed to create databricks job with error: {response}")
 
         return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"]))
 
@@ -78,14 +78,15 @@ async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource:
         async with aiohttp.ClientSession() as session:
             async with session.get(databricks_url, headers=get_header()) as resp:
                 if resp.status != http.HTTPStatus.OK:
-                    raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
+                    raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
                 response = await resp.json()
 
         cur_phase = TaskExecution.UNDEFINED
         message = ""
         state = response.get("state")
 
-        # The databricks job's state is determined by life_cycle_state and result_state. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
+        # The databricks job's state is determined by life_cycle_state and result_state.
+        # https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
         if state:
             life_cycle_state = state.get("life_cycle_state")
             if result_state_is_available(life_cycle_state):
@@ -109,10 +110,25 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs):
         async with aiohttp.ClientSession() as session:
             async with session.post(databricks_url, headers=get_header(), data=data) as resp:
                 if resp.status != http.HTTPStatus.OK:
-                    raise Exception(f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}")
+                    raise RuntimeError(
+                        f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}"
+                    )
                 await resp.json()
 
 
+class DatabricksAgentV2(DatabricksAgent):
+    """
+    Add DatabricksAgentV2 to support running the k8s spark and databricks spark together in the same workflow.
+    This is necessary because one task type can only be handled by a single backend plugin.
+    spark -> k8s spark plugin
+    databricks -> databricks agent
+    """
+
+    def __init__(self):
+        super(AsyncAgentBase, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata)
+
+
 def get_header() -> typing.Dict[str, str]:
     token = get_agent_secret("FLYTE_DATABRICKS_ACCESS_TOKEN")
     return {"Authorization": f"Bearer {token}", "content-type": "application/json"}
@@ -123,3 +139,4 @@ def result_state_is_available(life_cycle_state: str) -> bool:
 
 
 AgentRegistry.register(DatabricksAgent())
+AgentRegistry.register(DatabricksAgentV2())
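
Note the super(AsyncAgentBase, self).__init__(...) call in DatabricksAgentV2 above: it starts the method lookup after AsyncAgentBase in the MRO, bypassing the parent initializer that hard-codes the old task type, so the V2 agent can register under "databricks" while reusing everything else from DatabricksAgent. A self-contained illustration of that super() mechanic, flattened to two levels and using hypothetical class names rather than the real flytekit hierarchy:

    class AgentBase:
        def __init__(self, task_type_name: str):
            self.task_type_name = task_type_name

    class SparkLikeAgent(AgentBase):
        def __init__(self):
            super().__init__(task_type_name="spark")  # hard-codes the old name

    class DatabricksLikeAgentV2(SparkLikeAgent):
        def __init__(self):
            # super(SparkLikeAgent, self) starts the MRO lookup *after*
            # SparkLikeAgent, skipping its __init__ and reaching AgentBase
            # directly, so a different task type name can be passed through.
            super(SparkLikeAgent, self).__init__(task_type_name="databricks")

    assert DatabricksLikeAgentV2().task_type_name == "databricks"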
plugins/flytekit-spark/flytekitplugins/spark/task.py (24 additions, 5 deletions)
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Optional, Union, cast
 
+import click
 from google.protobuf.json_format import MessageToDict
 
 from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
@@ -46,6 +47,22 @@ def __post_init__(self):
 
 @dataclass
 class Databricks(Spark):
+    """
+    Deprecated. Use DatabricksV2 instead.
+    """
+
+    databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
+    databricks_instance: Optional[str] = None
+
+    def __post_init__(self):
+        logger.warn(
+            "Databricks is deprecated. Use 'from flytekitplugins.spark import DatabricksV2' instead, "
+            "and make sure to upgrade the version of flyteagent deployment to >v1.13.0.",
+        )
+
+
+@dataclass
+class DatabricksV2(Spark):
     """
     Use this to configure a Databricks task. Task's marked with this will automatically execute
     natively onto databricks platform as a distributed execution of spark
@@ -127,6 +144,7 @@ def __init__(
         self._default_applications_path = (
             self._default_applications_path or "local:///usr/local/bin/entrypoint.py"
         )
+
         super(PysparkFunctionTask, self).__init__(
             task_config=task_config,
             task_type=self._SPARK_TASK_TYPE,
@@ -151,8 +169,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
             main_class="",
             spark_type=SparkType.PYTHON,
         )
-        if isinstance(self.task_config, Databricks):
-            cfg = cast(Databricks, self.task_config)
+        if isinstance(self.task_config, (Databricks, DatabricksV2)):
+            cfg = cast(DatabricksV2, self.task_config)
             job._databricks_conf = cfg.databricks_conf
             job._databricks_instance = cfg.databricks_instance
 
@@ -181,7 +199,7 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
         return user_params.builder().add_attr("SPARK_SESSION", self.sess).build()
 
     def execute(self, **kwargs) -> Any:
-        if isinstance(self.task_config, Databricks):
+        if isinstance(self.task_config, (Databricks, DatabricksV2)):
             # Use the Databricks agent to run it by default.
             try:
                 ctx = FlyteContextManager.current_context()
@@ -193,11 +211,12 @@
                 if ctx.execution_state and ctx.execution_state.is_local_execution():
                     return AsyncAgentExecutorMixin.execute(self, **kwargs)
             except Exception as e:
-                logger.error(f"Agent failed to run the task with error: {e}")
-                logger.info("Falling back to local execution")
+                click.secho(f"Agent failed to run the task with error: {e}", fg="red")
+                click.secho("Falling back to local execution", fg="red")
         return PythonFunctionTask.execute(self, **kwargs)
 
 
 # Inject the Spark plugin into flytekits dynamic plugin loading system
 TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)
 TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask)
+TaskPlugins.register_pythontask_plugin(DatabricksV2, PysparkFunctionTask)
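
For orientation, a user-facing task definition against the new config might look like the sketch below. It is not part of this commit: the cluster spec, workspace URL, and task body are placeholders, and only the spark_conf, databricks_conf, and databricks_instance fields that appear in this diff are assumed. The old Databricks config keeps working but now logs the deprecation warning added above.

    from flytekit import task
    from flytekitplugins.spark import DatabricksV2

    @task(
        task_config=DatabricksV2(
            spark_conf={"spark.driver.memory": "1g"},
            databricks_conf={
                "run_name": "flytekit databricks plugin example",  # placeholder
                "new_cluster": {
                    "spark_version": "12.2.x-scala2.12",  # placeholder cluster spec
                    "node_type_id": "n2-highmem-4",
                    "num_workers": 1,
                },
                "timeout_seconds": 3600,
            },
            databricks_instance="<workspace>.cloud.databricks.com",  # placeholder
        ),
    )
    def hello_spark(i: int) -> int:
        return i + 1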
