Skip to content

Commit

Permalink
[AIRFLOW-1104] Update jobs.py so Airflow does not over schedule tasks (
Browse files Browse the repository at this point in the history
…apache#3568)

This change will prevent tasks from getting scheduled and queued over
the concurrency limits set for the dag
  • Loading branch information
dan-sf authored and Alice Berard committed Jan 3, 2019
1 parent c85f2a1 commit 20ad14f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
5 changes: 1 addition & 4 deletions airflow/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,9 +1075,6 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
:type states: Tuple[State]
:return: List[TaskInstance]
"""
# TODO(saguziel): Change this to include QUEUED, for concurrency
# purposes we may want to count queued tasks
states_to_count_as_running = [State.RUNNING]
executable_tis = []

# Get all the queued task instances from associated with scheduled
Expand Down Expand Up @@ -1123,6 +1120,7 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
for task_instance in task_instances_to_examine:
pool_to_task_instances[task_instance.pool].append(task_instance)

states_to_count_as_running = [State.RUNNING, State.QUEUED]
task_concurrency_map = self.__get_task_concurrency_map(
states=states_to_count_as_running, session=session)

Expand Down Expand Up @@ -1173,7 +1171,6 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
simple_dag = simple_dag_bag.get_dag(dag_id)

if dag_id not in dag_id_to_possibly_running_task_count:
# TODO(saguziel): also check against QUEUED state, see AIRFLOW-1104
dag_id_to_possibly_running_task_count[dag_id] = \
DAG.get_num_task_instances(
dag_id,
Expand Down
33 changes: 33 additions & 0 deletions tests/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,39 @@ def test_find_executable_task_instances_concurrency(self):

self.assertEqual(0, len(res))

def test_find_executable_task_instances_concurrency_queued(self):
dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency_queued'
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3)
task1 = DummyOperator(dag=dag, task_id='dummy1')
task2 = DummyOperator(dag=dag, task_id='dummy2')
task3 = DummyOperator(dag=dag, task_id='dummy3')
dagbag = self._make_simple_dag_bag([dag])

scheduler = SchedulerJob()
session = settings.Session()
dag_run = scheduler.create_dag_run(dag)

ti1 = TI(task1, dag_run.execution_date)
ti2 = TI(task2, dag_run.execution_date)
ti3 = TI(task3, dag_run.execution_date)
ti1.state = State.RUNNING
ti2.state = State.QUEUED
ti3.state = State.SCHEDULED

session.merge(ti1)
session.merge(ti2)
session.merge(ti3)

session.commit()

res = scheduler._find_executable_task_instances(
dagbag,
states=[State.SCHEDULED],
session=session)

self.assertEqual(1, len(res))
self.assertEqual(res[0].key, ti3.key)

def test_find_executable_task_instances_task_concurrency(self):
dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency'
task_id_1 = 'dummy'
Expand Down

0 comments on commit 20ad14f

Please sign in to comment.