From 0da2b8c88960f1506ee1361abb6b5011c5aa8d0d Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Wed, 7 Jul 2021 16:58:05 +0100
Subject: [PATCH] Support DAGS folder being in different location on scheduler and runners

There has been some vestigial support for this concept in Airflow for a
while (all the CLI commands already turn the literal `DAGS_FOLDER` into
the real value of the DAGS folder when loading DAGs), but sometime
around 1.10.1-1.10.3 it got fully broken and the scheduler only ever
passed full paths to DAG files.

This PR brings that behaviour back.
---
 airflow/jobs/scheduler_job.py                  | 12 +--
 airflow/models/dag.py                          | 78 +++++++++++++++----
 airflow/models/dagbag.py                       | 30 ++++---
 airflow/models/serialized_dag.py               |  2 +-
 airflow/models/taskinstance.py                 | 30 +++++--
 airflow/serialization/serialized_objects.py    |  1 -
 .../www/templates/airflow/dag_details.html     |  4 +-
 airflow/www/views.py                           |  3 +-
 tests/cluster_policies/__init__.py             |  4 +-
 tests/dag_processing/test_manager.py           |  6 +-
 tests/dag_processing/test_processor.py         |  2 +-
 tests/jobs/test_scheduler_job.py               |  4 +-
 tests/models/__init__.py                       |  2 +-
 tests/models/test_dag.py                       | 14 ++++
 tests/models/test_dagbag.py                    |  7 +-
 tests/models/test_renderedtifields.py          |  2 +
 tests/models/test_serialized_dag.py            |  6 +-
 tests/models/test_taskinstance.py              | 27 ++++++-
 tests/serialization/test_dag_serialization.py  |  4 -
 19 files changed, 163 insertions(+), 75 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 6c2a9fce0d6b6..d2a19cae93f3c 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -496,18 +496,8 @@ def _enqueue_task_instances_with_queued_state(self, task_instances: List[TI]) ->
         """
         # actually enqueue them
         for ti in task_instances:
-            command = TI.generate_command(
-                ti.dag_id,
-                ti.task_id,
-                ti.execution_date,
+            command = ti.command_as_list(
                 local=True,
-                mark_success=False,
-                ignore_all_deps=False,
-                ignore_depends_on_past=False,
-                ignore_task_deps=False,
-                ignore_ti_state=False,
-                pool=ti.pool,
-                file_path=ti.dag_model.fileloc,
                 pickle_id=ti.dag_model.pickle_id,
             )
 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 22e781c23f14e..b609b226cd42f 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -20,6 +20,7 @@
 import functools
 import logging
 import os
+import pathlib
 import pickle
 import re
 import sys
@@ -236,13 +237,21 @@ class DAG(LoggingMixin):
         'parent_dag',
         'start_date',
         'schedule_interval',
-        'full_filepath',
+        'fileloc',
         'template_searchpath',
         'last_loaded',
     }
 
     __serialized_fields: Optional[FrozenSet[str]] = None
 
+    fileloc: str
+    """
+    File path that needs to be imported to load this DAG or subdag.
+
+    This may not be an actual file on disk in the case when this DAG is loaded
+    from a ZIP file or other DAG distribution format.
+    """
+
     def __init__(
         self,
         dag_id: str,
@@ -286,10 +295,16 @@ def __init__(
             self.params.update(self.default_args['params'])
             del self.default_args['params']
 
+        if full_filepath:
+            warnings.warn(
+                "Passing full_filepath to DAG() is deprecated and has no effect",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         validate_key(dag_id)
 
         self._dag_id = dag_id
-        self._full_filepath = full_filepath if full_filepath else ''
         if concurrency and not max_active_tasks:
             # TODO: Remove in Airflow 3.0
             warnings.warn(
@@ -655,11 +670,22 @@ def dag_id(self, value: str) -> None:
 
     @property
     def full_filepath(self) -> str:
-        return self._full_filepath
+        """:meta private:"""
+        warnings.warn(
+            "DAG.full_filepath is deprecated in favour of fileloc",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.fileloc
 
     @full_filepath.setter
     def full_filepath(self, value) -> None:
-        self._full_filepath = value
+        warnings.warn(
+            "DAG.full_filepath is deprecated in favour of fileloc",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        self.fileloc = value
 
     @property
     def concurrency(self) -> int:
@@ -735,15 +761,26 @@ def task_group(self) -> "TaskGroup":
 
     @property
     def filepath(self) -> str:
-        """File location of where the dag object is instantiated"""
-        fn = self.full_filepath.replace(settings.DAGS_FOLDER + '/', '')
-        fn = fn.replace(os.path.dirname(__file__) + '/', '')
-        return fn
+        """:meta private:"""
+        warnings.warn(
+            "filepath is deprecated, use relative_fileloc instead", DeprecationWarning, stacklevel=2
+        )
+        return str(self.relative_fileloc)
+
+    @property
+    def relative_fileloc(self) -> pathlib.Path:
+        """File location of the importable dag 'file' relative to the configured DAGs folder."""
+        path = pathlib.Path(self.fileloc)
+        try:
+            return path.relative_to(settings.DAGS_FOLDER)
+        except ValueError:
+            # Not relative to DAGS_FOLDER.
+            return path
 
     @property
     def folder(self) -> str:
         """Folder location of where the DAG object is instantiated."""
-        return os.path.dirname(self.full_filepath)
+        return os.path.dirname(self.fileloc)
 
     @property
     def owner(self) -> str:
@@ -2118,9 +2155,11 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
             .group_by(DagRun.dag_id)
             .all()
         )
+        filelocs = []
 
         for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
             dag = dag_by_ids[orm_dag.dag_id]
+            filelocs.append(dag.fileloc)
             if dag.is_subdag:
                 orm_dag.is_subdag = True
                 orm_dag.fileloc = dag.parent_dag.fileloc  # type: ignore
@@ -2157,7 +2196,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
                     session.add(dag_tag_orm)
 
         if settings.STORE_DAG_CODE:
-            DagCode.bulk_sync_to_db([dag.fileloc for dag in orm_dags])
+            DagCode.bulk_sync_to_db(filelocs)
 
         # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
         # decide when to commit
@@ -2274,7 +2313,6 @@ def get_serialized_fields(cls):
                 '_old_context_manager_dags',
                 'safe_dag_id',
                 'last_loaded',
-                '_full_filepath',
                 'user_defined_filters',
                 'user_defined_macros',
                 'partial',
@@ -2332,7 +2370,7 @@ class DagModel(Base):
     These items are stored in the database for state related information
     """
     dag_id = Column(String(ID_LEN), primary_key=True)
-    root_dag_id = Column(String(ID_LEN))
+    root_dag_id = Column(String(ID_LEN), ForeignKey("dag.dag_id"))
     # A DAG can be paused from the UI / DB
     # Set this default value of is_paused based on a configuration value!
     is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation')
@@ -2382,6 +2420,8 @@ class DagModel(Base):
         Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False),
     )
 
+    parent_dag = relationship("DagModel", remote_side=[dag_id])
+
     NUM_DAGS_PER_DAGRUN_QUERY = conf.getint('scheduler', 'max_dagruns_to_create_per_loop', fallback=10)
 
     def __init__(self, concurrency=None, **kwargs):
@@ -2410,7 +2450,7 @@ def timezone(self):
     @staticmethod
     @provide_session
     def get_dagmodel(dag_id, session=None):
-        return session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
+        return session.query(DagModel).options(joinedload(DagModel.parent_dag)).get(dag_id)
 
     @classmethod
     @provide_session
@@ -2455,6 +2495,18 @@ def get_default_view(self) -> str:
     def safe_dag_id(self):
         return self.dag_id.replace('.', '__dot__')
 
+    @property
+    def relative_fileloc(self) -> Optional[pathlib.Path]:
+        """File location of the importable dag 'file' relative to the configured DAGs folder."""
+        if self.fileloc is None:
+            return None
+        path = pathlib.Path(self.fileloc)
+        try:
+            return path.relative_to(settings.DAGS_FOLDER)
+        except ValueError:
+            # Not relative to DAGS_FOLDER.
+            return path
+
     @provide_session
     def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=None) -> None:
         """
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index e553b21a4075c..b1249b518b9bd 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -381,16 +381,12 @@ def _load_modules_from_zip(self, filepath, safe_mode):
     def _process_modules(self, filepath, mods, file_last_changed_on_disk):
         from airflow.models.dag import DAG  # Avoid circular import
 
-        is_zipfile = zipfile.is_zipfile(filepath)
-        top_level_dags = [o for m in mods for o in list(m.__dict__.values()) if isinstance(o, DAG)]
+        top_level_dags = ((o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG))
 
         found_dags = []
 
-        for dag in top_level_dags:
-            if not dag.full_filepath:
-                dag.full_filepath = filepath
-            if dag.fileloc != filepath and not is_zipfile:
-                dag.fileloc = filepath
+        for (dag, mod) in top_level_dags:
+            dag.fileloc = mod.__file__
             try:
                 dag.is_subdag = False
                 dag.timetable.validate()
@@ -398,17 +394,17 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk):
                 found_dags.append(dag)
                 found_dags += dag.subdags
             except AirflowTimetableInvalid as exception:
-                self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
-                self.import_errors[dag.full_filepath] = f"Invalid timetable expression: {exception}"
-                self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
+                self.log.exception("Failed to bag_dag: %s", dag.fileloc)
+                self.import_errors[dag.fileloc] = f"Invalid timetable expression: {exception}"
+                self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
             except (
                 AirflowDagCycleException,
                 AirflowDagDuplicatedIdException,
                 AirflowClusterPolicyViolation,
             ) as exception:
-                self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
-                self.import_errors[dag.full_filepath] = str(exception)
-                self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
+                self.log.exception("Failed to bag_dag: %s", dag.fileloc)
+                self.import_errors[dag.fileloc] = str(exception)
+                self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
         return found_dags
 
     def bag_dag(self, dag, root_dag):
@@ -444,17 +440,17 @@ def _bag_dag(self, *, dag, root_dag, recursive):
         # into further _bag_dag() calls.
         if recursive:
             for subdag in subdags:
-                subdag.full_filepath = dag.full_filepath
+                subdag.fileloc = dag.fileloc
                 subdag.parent_dag = dag
                 subdag.is_subdag = True
                 self._bag_dag(dag=subdag, root_dag=root_dag, recursive=False)
 
         prev_dag = self.dags.get(dag.dag_id)
-        if prev_dag and prev_dag.full_filepath != dag.full_filepath:
+        if prev_dag and prev_dag.fileloc != dag.fileloc:
             raise AirflowDagDuplicatedIdException(
                 dag_id=dag.dag_id,
-                incoming=dag.full_filepath,
-                existing=self.dags[dag.dag_id].full_filepath,
+                incoming=dag.fileloc,
+                existing=self.dags[dag.dag_id].fileloc,
             )
         self.dags[dag.dag_id] = dag
         self.log.debug('Loaded DAG %s', dag)
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 1dabc0ed53d6b..98d933f29425f 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -90,7 +90,7 @@ class SerializedDagModel(Base):
 
     def __init__(self, dag: DAG):
         self.dag_id = dag.dag_id
-        self.fileloc = dag.full_filepath
+        self.fileloc = dag.fileloc
         self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
         self.data = SerializedDAG.to_dict(dag)
         self.last_updated = timezone.utcnow()
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 25d5c4a750b4c..1b76145a469e2 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -26,7 +26,7 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 from tempfile import NamedTemporaryFile
-from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
+from typing import IO, TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
 from urllib.parse import quote
 
 import dill
@@ -91,6 +91,10 @@
 
 log = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from airflow.models.dag import DAG, DagModel
+
+
 @contextlib.contextmanager
 def set_current_context(context: Context):
     """
@@ -436,15 +440,25 @@ def command_as_list(
         installed. This command is part of the message sent to executors by
         the orchestrator.
""" - dag = self.task.dag + dag: Union["DAG", "DagModel"] + # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded + if hasattr(self, 'task') and hasattr(self.task, 'dag'): + dag = self.task.dag + else: + dag = self.dag_model should_pass_filepath = not pickle_id and dag - if should_pass_filepath and dag.full_filepath != dag.filepath: - path = f"DAGS_FOLDER/{dag.filepath}" - elif should_pass_filepath and dag.full_filepath: - path = dag.full_filepath - else: - path = None + path = None + if should_pass_filepath: + if dag.is_subdag: + path = dag.parent_dag.relative_fileloc + else: + path = dag.relative_fileloc + + if path: + if not path.is_absolute(): + path = 'DAGS_FOLDER' / path + path = str(path) return TaskInstance.generate_command( self.dag_id, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index a62a66d930fc8..e1fd91341802d 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -752,7 +752,6 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': for k in keys_to_set_none: setattr(dag, k, None) - setattr(dag, 'full_filepath', dag.fileloc) for task in dag.task_dict.values(): task.dag = dag serializable_task: BaseOperator = task diff --git a/airflow/www/templates/airflow/dag_details.html b/airflow/www/templates/airflow/dag_details.html index af03e3dae8574..4d8fe0eb4b264 100644 --- a/airflow/www/templates/airflow/dag_details.html +++ b/airflow/www/templates/airflow/dag_details.html @@ -88,8 +88,8 @@

{{ title }}

{{ dag.task_ids }} - Filepath - {{ dag.filepath }} + Relative file location + {{ dag.relative_fileloc }} Owner diff --git a/airflow/www/views.py b/airflow/www/views.py index 8c9bf8f0f71c5..9d1d4ad4f1f82 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1388,9 +1388,8 @@ def xcom(self, session=None): dttm = timezone.parse(execution_date) form = DateTimeForm(data={'execution_date': dttm}) root = request.args.get('root', '') - dm_db = models.DagModel ti_db = models.TaskInstance - dag = session.query(dm_db).filter(dm_db.dag_id == dag_id).first() + dag = DagModel.get_dagmodel(dag_id) ti = session.query(ti_db).filter(and_(ti_db.dag_id == dag_id, ti_db.task_id == task_id)).first() if not ti: diff --git a/tests/cluster_policies/__init__.py b/tests/cluster_policies/__init__.py index d395ec0982c04..521a8522908b8 100644 --- a/tests/cluster_policies/__init__.py +++ b/tests/cluster_policies/__init__.py @@ -52,7 +52,7 @@ def _check_task_rules(current_task: BaseOperator): if notices: notices_list = " * " + "\n * ".join(notices) raise AirflowClusterPolicyViolation( - f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.filepath}):\n" + f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.fileloc}):\n" f"Notices:\n" f"{notices_list}" ) @@ -70,7 +70,7 @@ def dag_policy(dag: DAG): """Ensure that DAG has at least one tag""" if not dag.tags: raise AirflowClusterPolicyViolation( - f"DAG {dag.dag_id} has no tags. At least one tag required. File path: {dag.filepath}" + f"DAG {dag.dag_id} has no tags. At least one tag required. File path: {dag.fileloc}" ) diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index d34eca23585b3..52d7487cd9f57 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -407,9 +407,9 @@ def test_find_zombies(self): seconds=manager._zombie_threshold_secs + 1 ) manager._find_zombies() - requests = manager._callback_to_execute[dag.full_filepath] + requests = manager._callback_to_execute[dag.fileloc] assert 1 == len(requests) - assert requests[0].full_filepath == dag.full_filepath + assert requests[0].full_filepath == dag.fileloc assert requests[0].msg == "Detected as zombie" assert requests[0].is_failure_callback is True assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance) @@ -451,7 +451,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p expected_failure_callback_requests = [ TaskCallbackRequest( - full_filepath=dag.full_filepath, + full_filepath=dag.fileloc, simple_task_instance=SimpleTaskInstance(ti), msg="Message", ) diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index b6b95893acccc..3de9952d155bf 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -685,7 +685,7 @@ def test_process_file_should_failure_callback(self): requests = [ TaskCallbackRequest( - full_filepath=dag.full_filepath, + full_filepath=dag.fileloc, simple_task_instance=SimpleTaskInstance(ti), msg="Message", ) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 87a624015aef3..7b1f1fbbcb028 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -200,8 +200,8 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback): dag_id2 = "test_process_executor_events_2" task_id_1 = 'dummy_task' - dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, 
full_filepath="/test_path1/") - dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, full_filepath="/test_path1/") + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) + dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) DummyOperator(dag=dag2, task_id=task_id_1) dag.fileloc = "/test_path1/" diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 7fba62e42e300..2d4a0d9a430de 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -21,4 +21,4 @@ from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2016, 1, 1) -TEST_DAGS_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../dags') +TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags')) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 726c07bfa1d9d..0b673b9d95c6c 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -25,6 +25,7 @@ import unittest from contextlib import redirect_stdout from datetime import timedelta +from pathlib import Path from tempfile import NamedTemporaryFile from typing import Optional from unittest import mock @@ -1770,6 +1771,19 @@ def test_dags_needing_dagruns_only_unpaused(self): session.rollback() session.close() + @pytest.mark.parametrize( + ('fileloc', 'expected_relative'), + [ + (os.path.join(settings.DAGS_FOLDER, 'a.py'), Path('a.py')), + ('/tmp/foo.py', Path('/tmp/foo.py')), + ], + ) + def test_relative_fileloc(self, fileloc, expected_relative): + dag = DAG(dag_id='test') + dag.fileloc = fileloc + + assert dag.relative_fileloc == expected_relative + class TestQueries(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index a065318152e28..37bb0106da5a3 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -246,7 +246,7 @@ def test_get_dag_fileloc(self): expected = { 'example_bash_operator': 'airflow/example_dags/example_bash_operator.py', 'example_subdag_operator': 'airflow/example_dags/example_subdag_operator.py', - 'example_subdag_operator.section-1': 'airflow/example_dags/subdags/subdag.py', + 'example_subdag_operator.section-1': 'airflow/example_dags/example_subdag_operator.py', 'test_zip_dag': 'dags/test_zip.zip/test_zip.py', } @@ -507,12 +507,15 @@ def subdag_1(): assert len(test_dag.subdags) == 6 # Perform processing dag - dagbag, found_dags, _ = self.process_dag(nested_subdags) + dagbag, found_dags, filename = self.process_dag(nested_subdags) # Validate correctness # all dags from test_dag should be listed self.validate_dags(test_dag, found_dags, dagbag) + for dag in dagbag.dags.values(): + assert dag.fileloc == filename + def test_skip_cycle_dags(self): """ Don't crash when loading an invalid (contains a cycle) DAG file. 
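Reviewer note (illustrative, not part of the patch): the new `relative_fileloc` property on DAG and DagModel is the core of this change, and the test above exercises both branches. A minimal sketch of its semantics, assuming a hypothetical configured DAGS_FOLDER of /opt/airflow/dags:

import pathlib

DAGS_FOLDER = '/opt/airflow/dags'  # hypothetical configured dags folder

def relative_fileloc(fileloc: str) -> pathlib.Path:
    # Mirrors DAG.relative_fileloc: return the path relative to DAGS_FOLDER
    # when possible; otherwise return the original path unchanged.
    path = pathlib.Path(fileloc)
    try:
        return path.relative_to(DAGS_FOLDER)
    except ValueError:
        # Not under DAGS_FOLDER (e.g. a one-off absolute path).
        return path

assert str(relative_fileloc('/opt/airflow/dags/a.py')) == 'a.py'
assert str(relative_fileloc('/tmp/foo.py')) == '/tmp/foo.py'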
diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py
index f76078c008cb5..02c72be994cb0 100644
--- a/tests/models/test_renderedtifields.py
+++ b/tests/models/test_renderedtifields.py
@@ -26,6 +26,7 @@
 from parameterized import parameterized
 
 from airflow import settings
+from airflow.configuration import TEST_DAGS_FOLDER
 from airflow.models import Variable
 from airflow.models.dag import DAG
 from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
@@ -244,6 +245,7 @@ def test_get_k8s_pod_yaml(self, redact):
         dag = DAG("test_get_k8s_pod_yaml", start_date=START_DATE)
         with dag:
             task = BashOperator(task_id="test", bash_command="echo hi")
+        dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py'
 
         ti = TI(task=task, execution_date=EXECUTION_DATE)
 
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index 3e68ddc2cc413..aea143fc6a24a 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -71,7 +71,7 @@ def test_write_dag(self):
             assert SDM.has_dag(dag.dag_id)
             result = session.query(SDM.fileloc, SDM.data).filter(SDM.dag_id == dag.dag_id).one()
 
-            assert result.fileloc == dag.full_filepath
+            assert result.fileloc == dag.fileloc
             # Verifies JSON schema.
             SerializedDAG.validate_schema(result.data)
 
@@ -138,8 +138,8 @@ def test_remove_dags_by_filepath(self):
         # Tests removing by file path.
         dag_removed_by_file = filtered_example_dags_list[0]
         # remove repeated files for those DAGs that define multiple dags in the same file (set comprehension)
-        example_dag_files = list({dag.full_filepath for dag in filtered_example_dags_list})
-        example_dag_files.remove(dag_removed_by_file.full_filepath)
+        example_dag_files = list({dag.fileloc for dag in filtered_example_dags_list})
+        example_dag_files.remove(dag_removed_by_file.fileloc)
         SDM.remove_deleted_dags(example_dag_files)
         assert not SDM.has_dag(dag_removed_by_file.dag_id)
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 5c86a221788ce..7d64c76c5a92f 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -68,7 +68,7 @@
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 from airflow.version import version
-from tests.models import DEFAULT_DATE
+from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER
 from tests.test_utils import db
 from tests.test_utils.asserts import assert_queries_count
 from tests.test_utils.config import conf_vars
@@ -1899,6 +1899,26 @@ def test_task_stats(self, stats_mock):
         assert call(f'ti.start.{dag.dag_id}.{op.task_id}') in stats_mock.mock_calls
         assert stats_mock.call_count == 5
 
+    def test_command_as_list(self):
+        dag = DAG(
+            'test_dag',
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+        )
+        dag.fileloc = os.path.join(TEST_DAGS_FOLDER, 'x.py')
+        op = DummyOperator(task_id='dummy_op', dag=dag)
+        ti = TI(task=op, execution_date=DEFAULT_DATE)
+        assert ti.command_as_list() == [
+            'airflow',
+            'tasks',
+            'run',
+            dag.dag_id,
+            op.task_id,
+            DEFAULT_DATE.isoformat(),
+            '--subdir',
+            'DAGS_FOLDER/x.py',
+        ]
+
     def test_generate_command_default_param(self):
         dag_id = 'test_generate_command_default_param'
         task_id = 'task'
@@ -1925,8 +1945,9 @@ def test_generate_command_specific_param(self):
 
     def test_get_rendered_template_fields(self):
 
-        with DAG('test-dag', start_date=DEFAULT_DATE):
+        with DAG('test-dag', start_date=DEFAULT_DATE) as dag:
             task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
+        dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py'
 
         ti = TI(task=task, execution_date=DEFAULT_DATE)
 
@@ -1984,6 +2005,8 @@ def test_render_k8s_pod_yaml(self, pod_mutation_hook):
                     'test_get_rendered_k8s_spec',
                     'op1',
                     '2016-01-01T00:00:00+00:00',
+                    '--subdir',
+                    __file__,
                 ],
                 'image': ':',
                 'name': 'base',
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 40d9bc077c441..7e2798c4cbda4 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -380,10 +380,6 @@ def validate_deserialized_dag(self, serialized_dag, dag):
         for task_id in dag.task_ids:
            self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id))
 
-        # Verify that the DAG object has 'full_filepath' attribute
-        # and is equal to fileloc
-        assert serialized_dag.full_filepath == dag.fileloc
-
     def validate_deserialized_task(
         self,
         serialized_task,
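Reviewer note (illustrative, not part of the patch): the end-to-end effect on the generated task command. Paths under the scheduler's DAGS_FOLDER are sent to runners as the literal DAGS_FOLDER token, which each runner expands against its own configured dags folder, so the two folders may live in different locations. A condensed, hypothetical sketch of the path handling that the new command_as_list() performs, under the same assumptions as the test_command_as_list test above:

import pathlib

def subdir_arg(relative_fileloc: pathlib.Path) -> str:
    # Condensed from the new command_as_list() logic: relative locations are
    # prefixed with the literal DAGS_FOLDER token; absolute ones pass through.
    path = relative_fileloc
    if not path.is_absolute():
        # pathlib supports str / Path via PurePath.__rtruediv__
        path = 'DAGS_FOLDER' / path
    return str(path)

# A scheduler with DAGS_FOLDER=/opt/airflow/dags sends 'DAGS_FOLDER/x.py' for
# /opt/airflow/dags/x.py; a runner with DAGS_FOLDER=/home/worker/dags then
# resolves it to /home/worker/dags/x.py when loading the DAG.
assert subdir_arg(pathlib.Path('x.py')) == 'DAGS_FOLDER/x.py'
assert subdir_arg(pathlib.Path('/tmp/foo.py')) == '/tmp/foo.py'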