Skip to content

Commit

Permalink
Handle executors in the providers for Airflow <3 support
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil committed Nov 4, 2024
1 parent d55524f commit f998628
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class CeleryKubernetesExecutor(BaseExecutor):
"""

supports_ad_hoc_ti_run: bool = True
# TODO: Remove this flag once providers depend on Airflow 3.0
supports_pickling: bool = True
supports_sentry: bool = False

Expand Down Expand Up @@ -159,14 +160,14 @@ def queue_task_instance(
self,
task_instance: TaskInstance,
mark_success: bool = False,
pickle_id: int | None = None,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
wait_for_past_depends_before_skipping: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
pool: str | None = None,
cfg_path: str | None = None,
**kwargs,
) -> None:
"""Queues task instance via celery or kubernetes executor."""
from airflow.models.taskinstance import SimpleTaskInstance
Expand All @@ -175,17 +176,22 @@ def queue_task_instance(
self.log.debug(
"Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
)

# TODO: Remove this once providers depend on Airflow 3.0
if not hasattr(task_instance, "pickle_id"):
del kwargs["pickle_id"]

executor.queue_task_instance(
task_instance=task_instance,
mark_success=mark_success,
pickle_id=pickle_id,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state,
pool=pool,
cfg_path=cfg_path,
**kwargs,
)

def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class LocalKubernetesExecutor(BaseExecutor):
"""

supports_ad_hoc_ti_run: bool = True
# TODO: Remove this attribute once providers rely on Airflow >=3.0.0
supports_pickling: bool = False
supports_sentry: bool = False

Expand Down Expand Up @@ -146,14 +147,14 @@ def queue_task_instance(
self,
task_instance: TaskInstance,
mark_success: bool = False,
pickle_id: int | None = None,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
wait_for_past_depends_before_skipping: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
pool: str | None = None,
cfg_path: str | None = None,
**kwargs,
) -> None:
"""Queues task instance via local or kubernetes executor."""
from airflow.models.taskinstance import SimpleTaskInstance
Expand All @@ -162,17 +163,21 @@ def queue_task_instance(
self.log.debug(
"Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
)

if not hasattr(task_instance, "pickle_id"):
del kwargs["pickle_id"]

executor.queue_task_instance(
task_instance=task_instance,
mark_success=mark_success,
pickle_id=pickle_id,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state,
pool=pool,
cfg_path=cfg_path,
**kwargs,
)

def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
Expand Down
3 changes: 0 additions & 3 deletions providers/tests/celery/executors/test_celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ def teardown_method(self) -> None:
db.clear_db_runs()
db.clear_db_jobs()

def test_supports_pickling(self):
assert CeleryExecutor.supports_pickling

def test_supports_sentry(self):
assert CeleryExecutor.supports_sentry

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1750,9 +1750,6 @@ def test_get_task_log(self, mock_get_kube_client, create_task_instance_of_operat
"Reading from k8s pod logs failed: error_fetching_pod_log",
]

def test_supports_pickling(self):
assert KubernetesExecutor.supports_pickling

def test_supports_sentry(self):
assert not KubernetesExecutor.supports_sentry

Expand Down

0 comments on commit f998628

Please sign in to comment.