Skip to content

Commit

Permalink
Fix airflow sensor (#2169)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Feb 9, 2024
1 parent a1e7153 commit 54ab9cf
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
8 changes: 4 additions & 4 deletions plugins/flytekit-airflow/flytekitplugins/airflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class AirflowContainerTask(PythonAutoContainerTask[AirflowObj]):
The airflow task module, name and parameters are stored in the task config.
Some of the Airflow operators are not deferrable, For example, BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator.
These tasks don't have async method to get the job status, so cannot be used in the Flyte agent. We run these tasks in a container.
These tasks don't have an async method to get the job status, so cannot be used in the Flyte agent. We run these tasks in a container.
"""

def __init__(
Expand Down Expand Up @@ -164,7 +164,7 @@ def _is_deferrable(cls: Type) -> bool:
# Only Airflow operators are deferrable.
if not issubclass(cls, airflow_models.BaseOperator):
return False
# Airflow sensors are not deferrable. Sensor is a subclass of BaseOperator.
# Airflow sensors are not deferrable. The Sensor is a subclass of BaseOperator.
if issubclass(cls, airflow_sensors.BaseSensorOperator):
return False
try:
Expand All @@ -186,7 +186,7 @@ def _flyte_operator(*args, **kwargs):
cls = args[0]
try:
if FlyteContextManager.current_context().user_space_params.get_original_task:
# Return original task when running in the agent.
# Return an original task when running in the agent.
return object.__new__(cls)
except AssertionError:
# This happens when the task is created in the dynamic workflow.
Expand All @@ -197,7 +197,7 @@ def _flyte_operator(*args, **kwargs):
task_id = kwargs["task_id"] or cls.__name__
config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs)

if not _is_deferrable(cls):
if not issubclass(cls, airflow_sensors.BaseSensorOperator) and not _is_deferrable(cls):
# Dataflow operators are not deferrable, so we run them in a container.
return AirflowContainerTask(name=task_id, task_config=config, container_image=container_image)()
return AirflowTask(name=task_id, task_config=config)()
Expand Down
35 changes: 34 additions & 1 deletion plugins/flytekit-airflow/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import jsonpickle
from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator
from airflow.operators.bash import BashOperator
from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateClusterOperator
from airflow.sensors.bash import BashSensor
from airflow.sensors.time_sensor import TimeSensor
from airflow.utils.context import Context
from flytekitplugins.airflow.task import (
AirflowContainerTask,
AirflowObj,
AirflowTask,
_flyte_operator,
_is_deferrable,
airflow_task_resolver,
)
from mock import mock

from flytekit import FlyteContextManager
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core import context_manager


def test_xcom_push():
Expand Down Expand Up @@ -87,3 +92,31 @@ def test_airflow_container_task():
assert isinstance(
airflow_task_resolver.load_task(t.task_resolver.loader_args(serialization_settings, t)), AirflowContainerTask
)


@mock.patch("flytekitplugins.airflow.task.AirflowContainerTask")
@mock.patch("flytekitplugins.airflow.task.AirflowTask")
def test_flyte_operator(airflow_task, airflow_container_task):
ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(ctx.new_builder()):
params = FlyteContextManager.current_context().user_space_params
params.builder().add_attr("GET_ORIGINAL_TASK", False).add_attr("XCOM_DATA", {}).build()
_flyte_operator(BashOperator, task_id="BashOperator")
airflow_task.assert_called_once()
_flyte_operator(BeamRunJavaPipelineOperator, task_id="BeamRunJavaPipelineOperator")
airflow_container_task.assert_called_once()

airflow_task.reset_mock()
airflow_container_task.reset_mock()

_flyte_operator(TimeSensor, task_id="TimeSensor")
airflow_task.assert_called_once()

_flyte_operator(BeamRunPythonPipelineOperator, task_id="BeamRunPythonPipelineOperator")
airflow_container_task.assert_called_once()

airflow_task.reset_mock()
airflow_container_task.reset_mock()

_flyte_operator(DataprocCreateClusterOperator, task_id="DataprocCreateClusterOperator")
airflow_task.assert_called_once()

0 comments on commit 54ab9cf

Please sign in to comment.