diff --git a/airflow/jobs.py b/airflow/jobs.py index 69471fd64f890..f2caeda137721 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -1073,9 +1073,6 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None): :type states: Tuple[State] :return: List[TaskInstance] """ - # TODO(saguziel): Change this to include QUEUED, for concurrency - # purposes we may want to count queued tasks - states_to_count_as_running = [State.RUNNING] executable_tis = [] # Get all the queued task instances from associated with scheduled @@ -1121,6 +1118,7 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None): for task_instance in task_instances_to_examine: pool_to_task_instances[task_instance.pool].append(task_instance) + states_to_count_as_running = [State.RUNNING, State.QUEUED] task_concurrency_map = self.__get_task_concurrency_map( states=states_to_count_as_running, session=session) @@ -1171,7 +1169,6 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None): simple_dag = simple_dag_bag.get_dag(dag_id) if dag_id not in dag_id_to_possibly_running_task_count: - # TODO(saguziel): also check against QUEUED state, see AIRFLOW-1104 dag_id_to_possibly_running_task_count[dag_id] = \ DAG.get_num_task_instances( dag_id, diff --git a/tests/jobs.py b/tests/jobs.py index 67717acc2b654..c6297af2459f8 100644 --- a/tests/jobs.py +++ b/tests/jobs.py @@ -1575,6 +1575,39 @@ def test_find_executable_task_instances_concurrency(self): self.assertEqual(0, len(res)) + def test_find_executable_task_instances_concurrency_queued(self): + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency_queued' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3) + task1 = DummyOperator(dag=dag, task_id='dummy1') + task2 = DummyOperator(dag=dag, task_id='dummy2') + task3 = DummyOperator(dag=dag, task_id='dummy3') + dagbag = self._make_simple_dag_bag([dag]) + + scheduler = SchedulerJob() + session = settings.Session() + dag_run = scheduler.create_dag_run(dag) + + ti1 = TI(task1, dag_run.execution_date) + ti2 = TI(task2, dag_run.execution_date) + ti3 = TI(task3, dag_run.execution_date) + ti1.state = State.RUNNING + ti2.state = State.QUEUED + ti3.state = State.SCHEDULED + + session.merge(ti1) + session.merge(ti2) + session.merge(ti3) + + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(1, len(res)) + self.assertEqual(res[0].key, ti3.key) + def test_find_executable_task_instances_task_concurrency(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency' task_id_1 = 'dummy'