
Handle executors in the providers for Airflow <3 support
kaxil committed Nov 4, 2024
1 parent d55524f commit 4464668
Showing 6 changed files with 18 additions and 13 deletions.
dev/perf/scheduler_dag_execution_timing.py (4 changes: 2 additions & 2 deletions)
@@ -278,7 +278,7 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids):

     executor = ShortCircuitExecutor(dag_ids_to_watch=dag_ids, num_runs=num_runs)
     scheduler_job = Job(executor=executor)
-    job_runner = SchedulerJobRunner(job=scheduler_job, dag_ids=dag_ids, do_pickle=False)
+    job_runner = SchedulerJobRunner(job=scheduler_job, dag_ids=dag_ids)
     executor.job_runner = job_runner

     total_tasks = sum(len(dag.tasks) for dag in dags)
@@ -301,7 +301,7 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids):
         reset_dag(dag, session)
         executor.reset(dag_ids)
         scheduler_job = Job(executor=executor)
-        job_runner = SchedulerJobRunner(job=scheduler_job, dag_ids=dag_ids, do_pickle=False)
+        job_runner = SchedulerJobRunner(job=scheduler_job, dag_ids=dag_ids)
         executor.scheduler_job = scheduler_job

         gc.disable()
dev/perf/sql_queries.py (2 changes: 1 addition & 1 deletion)
@@ -123,7 +123,7 @@ def run_scheduler_job(with_db_reset=False) -> None:

     if with_db_reset:
         reset_db()
-    job_runner = SchedulerJobRunner(job=Job(), subdir=DAG_FOLDER, do_pickle=False, num_runs=3)
+    job_runner = SchedulerJobRunner(job=Job(), subdir=DAG_FOLDER, num_runs=3)
     run_job(job=job_runner.job, execute_callable=job_runner._execute)


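Both perf scripts above simply stop passing do_pickle, which SchedulerJobRunner no longer accepts. A script that still has to run against both Airflow 2 and Airflow 3 could feature-detect the parameter instead; the following is a minimal sketch under that assumption, not part of this commit:

import inspect

from airflow.jobs.job import Job
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner

# Pass do_pickle=False only where the runner's __init__ still accepts it
# (Airflow 2.x); on Airflow 3 the parameter is gone and must be omitted.
runner_kwargs = {"num_runs": 3}
if "do_pickle" in inspect.signature(SchedulerJobRunner.__init__).parameters:
    runner_kwargs["do_pickle"] = False

job_runner = SchedulerJobRunner(job=Job(), **runner_kwargs)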
@@ -56,6 +56,7 @@ class CeleryKubernetesExecutor(BaseExecutor):
"""

supports_ad_hoc_ti_run: bool = True
# TODO: Remove this flag once providers depend on Airflow 3.0
supports_pickling: bool = True
supports_sentry: bool = False

@@ -159,14 +160,14 @@ def queue_task_instance(
         self,
         task_instance: TaskInstance,
         mark_success: bool = False,
-        pickle_id: int | None = None,
         ignore_all_deps: bool = False,
         ignore_depends_on_past: bool = False,
         wait_for_past_depends_before_skipping: bool = False,
         ignore_task_deps: bool = False,
         ignore_ti_state: bool = False,
         pool: str | None = None,
         cfg_path: str | None = None,
+        **kwargs,
     ) -> None:
         """Queues task instance via celery or kubernetes executor."""
         from airflow.models.taskinstance import SimpleTaskInstance
@@ -175,17 +176,22 @@
         self.log.debug(
             "Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
         )
+
+        # TODO: Remove this once providers depend on Airflow 3.0
+        if not hasattr(task_instance, "pickle_id"):
+            del kwargs["pickle_id"]
+
         executor.queue_task_instance(
             task_instance=task_instance,
             mark_success=mark_success,
-            pickle_id=pickle_id,
             ignore_all_deps=ignore_all_deps,
             ignore_depends_on_past=ignore_depends_on_past,
             wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
             ignore_task_deps=ignore_task_deps,
             ignore_ti_state=ignore_ti_state,
             pool=pool,
             cfg_path=cfg_path,
+            **kwargs,
         )

     def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
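The shim above works by feature detection: on Airflow 2 the caller still passes pickle_id (now collected into **kwargs) and TaskInstance still carries a pickle_id attribute, while on Airflow 3 both are gone, so the key is dropped before delegating. A self-contained sketch of the same pattern, using hypothetical stand-ins for the wrapped executor and task instances (and a tolerant pop() where the diff uses del):

class _InnerExecutor:
    # Hypothetical stand-in for the wrapped Celery or Kubernetes executor.
    def queue_task_instance(self, task_instance, **kwargs):
        print("queued with:", sorted(kwargs))


class _CompatExecutor:
    def __init__(self, inner):
        self._inner = inner

    def queue_task_instance(self, task_instance, mark_success=False, **kwargs):
        # The absence of TaskInstance.pickle_id signals an Airflow 3 runtime,
        # so drop the key before delegating; pop(..., None) also tolerates
        # callers that never passed it.
        if not hasattr(task_instance, "pickle_id"):
            kwargs.pop("pickle_id", None)
        self._inner.queue_task_instance(
            task_instance=task_instance, mark_success=mark_success, **kwargs
        )


class _Airflow2TaskInstance:
    pickle_id = None  # the attribute still exists on Airflow 2


class _Airflow3TaskInstance:
    pass  # the attribute was removed in Airflow 3


executor = _CompatExecutor(_InnerExecutor())
executor.queue_task_instance(_Airflow2TaskInstance(), pickle_id=42, pool="default")
executor.queue_task_instance(_Airflow3TaskInstance(), pool="default")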
@@ -45,6 +45,7 @@ class LocalKubernetesExecutor(BaseExecutor):
"""

supports_ad_hoc_ti_run: bool = True
# TODO: Remove this attribute once providers rely on Airflow >=3.0.0
supports_pickling: bool = False
supports_sentry: bool = False

@@ -146,14 +147,14 @@ def queue_task_instance(
         self,
         task_instance: TaskInstance,
         mark_success: bool = False,
-        pickle_id: int | None = None,
         ignore_all_deps: bool = False,
         ignore_depends_on_past: bool = False,
         wait_for_past_depends_before_skipping: bool = False,
         ignore_task_deps: bool = False,
         ignore_ti_state: bool = False,
         pool: str | None = None,
         cfg_path: str | None = None,
+        **kwargs,
     ) -> None:
         """Queues task instance via local or kubernetes executor."""
         from airflow.models.taskinstance import SimpleTaskInstance
@@ -162,17 +163,21 @@
         self.log.debug(
             "Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
         )
+
+        if not hasattr(task_instance, "pickle_id"):
+            del kwargs["pickle_id"]
+
         executor.queue_task_instance(
             task_instance=task_instance,
             mark_success=mark_success,
-            pickle_id=pickle_id,
             ignore_all_deps=ignore_all_deps,
             ignore_depends_on_past=ignore_depends_on_past,
             wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
             ignore_task_deps=ignore_task_deps,
             ignore_ti_state=ignore_ti_state,
             pool=pool,
             cfg_path=cfg_path,
+            **kwargs,
         )

     def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
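LocalKubernetesExecutor gets the same treatment as CeleryKubernetesExecutor, minus the TODO comment. An alternative to probing with hasattr, commonly seen elsewhere in providers, is gating on the installed Airflow version; a sketch of that variant (the constant and helper here are assumptions, not part of this diff):

from packaging.version import Version

from airflow import __version__ as AIRFLOW_VERSION

# True on Airflow 3.0.0 and later, ignoring pre-release and dev suffixes.
AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")


def strip_pickle_id(kwargs: dict) -> dict:
    # On Airflow 3 the pickle_id argument no longer exists downstream.
    if AIRFLOW_V_3_0_PLUS:
        kwargs.pop("pickle_id", None)
    return kwargs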
providers/tests/celery/executors/test_celery_executor.py (3 changes: 0 additions & 3 deletions)
@@ -110,9 +110,6 @@ def teardown_method(self) -> None:
         db.clear_db_runs()
         db.clear_db_jobs()

-    def test_supports_pickling(self):
-        assert CeleryExecutor.supports_pickling
-
     def test_supports_sentry(self):
         assert CeleryExecutor.supports_sentry

@@ -1750,9 +1750,6 @@ def test_get_task_log(self, mock_get_kube_client, create_task_instance_of_operator):
"Reading from k8s pod logs failed: error_fetching_pod_log",
]

def test_supports_pickling(self):
assert KubernetesExecutor.supports_pickling

def test_supports_sentry(self):
assert not KubernetesExecutor.supports_sentry

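Both test files drop their test_supports_pickling checks outright, since the flag now exists only for backwards compatibility. Had the assertions needed to keep running on Airflow 2, a version-gated skip would have been a hypothetical alternative (a sketch; not what this commit does):

import pytest
from packaging.version import Version

from airflow import __version__ as AIRFLOW_VERSION


@pytest.mark.skipif(
    Version(AIRFLOW_VERSION) >= Version("3.0.0"),
    reason="Executor pickling support was removed in Airflow 3",
)
def test_supports_pickling():
    from airflow.providers.celery.executors.celery_executor import CeleryExecutor

    assert CeleryExecutor.supports_pickling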
