diff --git a/airflow/models.py b/airflow/models.py index 40c466c9d967a..80715853239ea 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -4358,6 +4358,7 @@ def deactivate_unknown_dags(active_dag_ids, session=None): DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all(): dag.is_active = False session.merge(dag) + session.commit() @staticmethod @provide_session diff --git a/tests/models.py b/tests/models.py index f2d36a263b3f7..e7ca343ca150b 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1691,6 +1691,31 @@ def test_kill_zombies(self, mock_ti): configuration.getboolean('core', 'unit_test_mode'), ANY) + def test_deactivate_unknown_dags(self): + """ + Test that dag_ids not passed into deactivate_unknown_dags + are deactivated when function is invoked + """ + dagbag = models.DagBag(include_examples=True) + expected_active_dags = dagbag.dags.keys() + + session = settings.Session + session.add(DagModel(dag_id='test_deactivate_unknown_dags', is_active=True)) + session.commit() + + models.DAG.deactivate_unknown_dags(expected_active_dags) + + for dag in session.query(DagModel).all(): + if dag.dag_id in expected_active_dags: + self.assertTrue(dag.is_active) + else: + self.assertEquals(dag.dag_id, 'test_deactivate_unknown_dags') + self.assertFalse(dag.is_active) + + # clean up + session.query(DagModel).filter(DagModel.dag_id == 'test_deactivate_unknown_dags').delete() + session.commit() + class TaskInstanceTest(unittest.TestCase):