Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix airflow sensor #2169

Merged
merged 5 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions plugins/flytekit-airflow/flytekitplugins/airflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class AirflowContainerTask(PythonAutoContainerTask[AirflowObj]):
The airflow task module, name and parameters are stored in the task config.
Some of the Airflow operators are not deferrable, For example, BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator.
These tasks don't have async method to get the job status, so cannot be used in the Flyte agent. We run these tasks in a container.
These tasks don't have an async method to get the job status, so cannot be used in the Flyte agent. We run these tasks in a container.
"""

def __init__(
Expand Down Expand Up @@ -164,7 +164,7 @@ def _is_deferrable(cls: Type) -> bool:
# Only Airflow operators are deferrable.
if not issubclass(cls, airflow_models.BaseOperator):
return False
# Airflow sensors are not deferrable. Sensor is a subclass of BaseOperator.
# Airflow sensors are not deferrable. The Sensor is a subclass of BaseOperator.
if issubclass(cls, airflow_sensors.BaseSensorOperator):
return False
try:
Expand All @@ -186,7 +186,7 @@ def _flyte_operator(*args, **kwargs):
cls = args[0]
try:
if FlyteContextManager.current_context().user_space_params.get_original_task:
# Return original task when running in the agent.
# Return an original task when running in the agent.
return object.__new__(cls)
except AssertionError:
# This happens when the task is created in the dynamic workflow.
Expand All @@ -197,7 +197,7 @@ def _flyte_operator(*args, **kwargs):
task_id = kwargs["task_id"] or cls.__name__
config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs)

if not _is_deferrable(cls):
if not issubclass(cls, airflow_sensors.BaseSensorOperator) and not _is_deferrable(cls):
# Dataflow operators are not deferrable, so we run them in a container.
return AirflowContainerTask(name=task_id, task_config=config, container_image=container_image)()
return AirflowTask(name=task_id, task_config=config)()
Expand Down
35 changes: 34 additions & 1 deletion plugins/flytekit-airflow/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import jsonpickle
from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator
from airflow.operators.bash import BashOperator
from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateClusterOperator
from airflow.sensors.bash import BashSensor
from airflow.sensors.time_sensor import TimeSensor
from airflow.utils.context import Context
from flytekitplugins.airflow.task import (
AirflowContainerTask,
AirflowObj,
AirflowTask,
_flyte_operator,
_is_deferrable,
airflow_task_resolver,
)
from mock import mock

from flytekit import FlyteContextManager
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core import context_manager


def test_xcom_push():
Expand Down Expand Up @@ -87,3 +92,31 @@ def test_airflow_container_task():
assert isinstance(
airflow_task_resolver.load_task(t.task_resolver.loader_args(serialization_settings, t)), AirflowContainerTask
)


@mock.patch("flytekitplugins.airflow.task.AirflowContainerTask")
@mock.patch("flytekitplugins.airflow.task.AirflowTask")
def test_flyte_operator(airflow_task, airflow_container_task):
ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(ctx.new_builder()):
params = FlyteContextManager.current_context().user_space_params
params.builder().add_attr("GET_ORIGINAL_TASK", False).add_attr("XCOM_DATA", {}).build()
_flyte_operator(BashOperator, task_id="BashOperator")
airflow_task.assert_called_once()
_flyte_operator(BeamRunJavaPipelineOperator, task_id="BeamRunJavaPipelineOperator")
airflow_container_task.assert_called_once()

airflow_task.reset_mock()
airflow_container_task.reset_mock()

_flyte_operator(TimeSensor, task_id="TimeSensor")
airflow_task.assert_called_once()

_flyte_operator(BeamRunPythonPipelineOperator, task_id="BeamRunPythonPipelineOperator")
airflow_container_task.assert_called_once()

airflow_task.reset_mock()
airflow_container_task.reset_mock()

_flyte_operator(DataprocCreateClusterOperator, task_id="DataprocCreateClusterOperator")
airflow_task.assert_called_once()
Loading