diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py index b49bc061f5..4f009f91c1 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py @@ -85,7 +85,7 @@ class AirflowContainerTask(PythonAutoContainerTask[AirflowObj]): The airflow task module, name and parameters are stored in the task config. Some of the Airflow operators are not deferrable, For example, BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator. - These tasks don't have async method to get the job status, so cannot be used in the Flyte agent. We run these tasks in a container. + These tasks don't have an async method to get the job status, so cannot be used in the Flyte agent. We run these tasks in a container. """ def __init__( @@ -164,7 +164,7 @@ def _is_deferrable(cls: Type) -> bool: # Only Airflow operators are deferrable. if not issubclass(cls, airflow_models.BaseOperator): return False - # Airflow sensors are not deferrable. Sensor is a subclass of BaseOperator. + # Airflow sensors are not deferrable. The Sensor is a subclass of BaseOperator. if issubclass(cls, airflow_sensors.BaseSensorOperator): return False try: @@ -186,7 +186,7 @@ def _flyte_operator(*args, **kwargs): cls = args[0] try: if FlyteContextManager.current_context().user_space_params.get_original_task: - # Return original task when running in the agent. + # Return an original task when running in the agent. return object.__new__(cls) except AssertionError: # This happens when the task is created in the dynamic workflow. @@ -197,7 +197,7 @@ def _flyte_operator(*args, **kwargs): task_id = kwargs["task_id"] or cls.__name__ config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs) - if not _is_deferrable(cls): + if not issubclass(cls, airflow_sensors.BaseSensorOperator) and not _is_deferrable(cls): # Dataflow operators are not deferrable, so we run them in a container. return AirflowContainerTask(name=task_id, task_config=config, container_image=container_image)() return AirflowTask(name=task_id, task_config=config)() diff --git a/plugins/flytekit-airflow/tests/test_task.py b/plugins/flytekit-airflow/tests/test_task.py index f55bcef5dd..81399e63a9 100644 --- a/plugins/flytekit-airflow/tests/test_task.py +++ b/plugins/flytekit-airflow/tests/test_task.py @@ -1,18 +1,23 @@ import jsonpickle -from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator +from airflow.operators.bash import BashOperator +from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator from airflow.providers.google.cloud.operators.dataproc import DataprocCreateClusterOperator from airflow.sensors.bash import BashSensor +from airflow.sensors.time_sensor import TimeSensor from airflow.utils.context import Context from flytekitplugins.airflow.task import ( AirflowContainerTask, AirflowObj, AirflowTask, + _flyte_operator, _is_deferrable, airflow_task_resolver, ) +from mock import mock from flytekit import FlyteContextManager from flytekit.configuration import ImageConfig, SerializationSettings +from flytekit.core import context_manager def test_xcom_push(): @@ -87,3 +92,31 @@ def test_airflow_container_task(): assert isinstance( airflow_task_resolver.load_task(t.task_resolver.loader_args(serialization_settings, t)), AirflowContainerTask ) + + +@mock.patch("flytekitplugins.airflow.task.AirflowContainerTask") +@mock.patch("flytekitplugins.airflow.task.AirflowTask") +def test_flyte_operator(airflow_task, airflow_container_task): + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context(ctx.new_builder()): + params = FlyteContextManager.current_context().user_space_params + params.builder().add_attr("GET_ORIGINAL_TASK", False).add_attr("XCOM_DATA", {}).build() + _flyte_operator(BashOperator, task_id="BashOperator") + airflow_task.assert_called_once() + _flyte_operator(BeamRunJavaPipelineOperator, task_id="BeamRunJavaPipelineOperator") + airflow_container_task.assert_called_once() + + airflow_task.reset_mock() + airflow_container_task.reset_mock() + + _flyte_operator(TimeSensor, task_id="TimeSensor") + airflow_task.assert_called_once() + + _flyte_operator(BeamRunPythonPipelineOperator, task_id="BeamRunPythonPipelineOperator") + airflow_container_task.assert_called_once() + + airflow_task.reset_mock() + airflow_container_task.reset_mock() + + _flyte_operator(DataprocCreateClusterOperator, task_id="DataprocCreateClusterOperator") + airflow_task.assert_called_once()