diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 60660ce0fbe54..1a6b9906b5225 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -584,6 +584,11 @@ def __init__(
                 f"Bad formatted links are: {wrong_links}"
             )
 
+        # this will only be set at serialization time
+        # its only use is for determining the relative
+        # fileloc based only on the serialized dag
+        self._processor_dags_folder = None
+
     def get_doc_md(self, doc_md: str | None) -> str | None:
         if doc_md is None:
             return doc_md
@@ -1189,7 +1194,11 @@ def relative_fileloc(self) -> pathlib.Path:
         """File location of the importable dag 'file' relative to the configured DAGs folder."""
         path = pathlib.Path(self.fileloc)
         try:
-            return path.relative_to(settings.DAGS_FOLDER)
+            rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER)
+            if rel_path == pathlib.Path('.'):
+                return path
+            else:
+                return rel_path
         except ValueError:
             # Not relative to DAGS_FOLDER.
             return path
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index ddbedad42ccd9..13e91b33d6ee7 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -133,7 +133,13 @@
       "catchup": { "type": "boolean" },
       "is_subdag": { "type": "boolean" },
       "fileloc": { "type" : "string"},
-      "orientation": { "type" : "string"},
+      "_processor_dags_folder": {
+        "anyOf": [
+          { "type": "null" },
+          {"type": "string"}
+        ]
+      },
+      "orientation": { "type" : "string"},
       "_description": { "type" : "string"},
       "_concurrency": { "type" : "number"},
       "_max_active_tasks": { "type" : "number"},
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 969b6014db5f6..542573fbcc738 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -50,7 +50,7 @@
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
 from airflow.serialization.helpers import serialize_template_field
 from airflow.serialization.json_schema import Validator, load_dag_schema
-from airflow.settings import json
+from airflow.settings import DAGS_FOLDER, json
 from airflow.timetables.base import Timetable
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.docs import get_docs_url
@@ -1120,6 +1120,8 @@ def serialize_dag(cls, dag: DAG) -> dict:
         try:
             serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)
 
+            serialized_dag['_processor_dags_folder'] = DAGS_FOLDER
+
             # If schedule_interval is backed by timetable, serialize only
             # timetable; vice versa for a timetable backed by schedule_interval.
             if dag.timetable.summary == dag.schedule_interval:
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index 37fe0d270278e..87313f46f5561 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -29,6 +29,7 @@
 import warnings
 from argparse import Namespace
 from datetime import datetime
+from pathlib import Path
 from typing import TYPE_CHECKING, Callable, TypeVar, cast
 
 from airflow import settings
@@ -43,6 +44,8 @@
 if TYPE_CHECKING:
     from airflow.models.dag import DAG
 
+logger = logging.getLogger(__name__)
+
 
 def _check_cli_args(args):
     if not args:
@@ -181,15 +184,47 @@ def get_dag_by_file_location(dag_id: str):
     return dagbag.dags[dag_id]
 
 
+def _search_for_dag_file(val: str | None) -> str | None:
+    """
+    Search for the file referenced at fileloc.
+
+    By the time we get to this function, we've already run this `val` through `process_subdir`
+    and loaded the DagBag there and came up empty.
+    So here, if `val` is a file path, we make a last-ditch effort to find a dag
+    file with the same name in our dags folder. (This avoids the unnecessary dag
+    parsing that would occur if we just parsed the whole dags folder.)
+
+    If `val` is a path to a file, this likely means that the serializing process had a dags_folder
+    equal to only the dag file in question. This prevents us from determining the relative location.
+    And if the paths are different between worker and dag processor / scheduler, then we won't find
+    the dag at the given location.
+    """
+    if val and Path(val).suffix in ('.zip', '.py'):
+        matches = list(Path(settings.DAGS_FOLDER).rglob(Path(val).name))
+        if len(matches) == 1:
+            return matches[0].as_posix()
+    return None
+
+
 def get_dag(subdir: str | None, dag_id: str) -> DAG:
-    """Returns DAG of a given dag_id"""
+    """
+    Returns DAG of a given dag_id.
+
+    First we'll try to use the given subdir. If that doesn't work, we'll try to
+    find the correct path (assuming it's a file) and, failing that, use the configured
+    dags folder.
+    """
     from airflow.models import DagBag
 
-    dagbag = DagBag(process_subdir(subdir))
+    first_path = process_subdir(subdir)
+    dagbag = DagBag(first_path)
     if dag_id not in dagbag.dags:
-        raise AirflowException(
-            f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
-        )
+        fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
+        logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
+        dagbag = DagBag(dag_folder=fallback_path)
+        if dag_id not in dagbag.dags:
+            raise AirflowException(
+                f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
+            )
     return dagbag.dags[dag_id]
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 56a955b0ab452..03b9259f8db28 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -22,9 +22,10 @@
 import logging
 import os
 import re
+import tempfile
 import unittest
 from argparse import ArgumentParser
-from contextlib import redirect_stdout
+from contextlib import contextmanager, redirect_stdout
 from pathlib import Path
 from unittest import mock
 
@@ -60,6 +61,13 @@ def reset(dag_id):
     runs.delete()
 
 
+@contextmanager
+def move_back(old_path, new_path):
+    os.rename(old_path, new_path)
+    yield
+    os.rename(new_path, old_path)
+
+
 # TODO: Check if tests needs side effects - locally there's missing DAG
 class TestCliTasks:
     run_id = 'TEST_RUN_ID'
@@ -183,6 +191,80 @@ def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserializ
         )
         mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)
 
+    def test_cli_test_different_path(self, session):
+        """
+        When the dag processor has a different dags folder
+        from the worker, ``airflow tasks run --local`` should still work.
+ """ + repo_root = Path(__file__).parent.parent.parent.parent + orig_file_path = repo_root / 'tests/dags/test_dags_folder.py' + orig_dags_folder = orig_file_path.parent + + # parse dag in original path + with conf_vars({('core', 'dags_folder'): orig_dags_folder.as_posix()}): + dagbag = DagBag(include_examples=False) + dag = dagbag.get_dag('test_dags_folder') + dagbag.sync_to_db(session=session) + + dag.create_dagrun( + state=State.NONE, + run_id='abc123', + run_type=DagRunType.MANUAL, + execution_date=pendulum.now('UTC'), + session=session, + ) + session.commit() + + # now let's move the file + # additionally let's update the dags folder to be the new path + # ideally since dags_folder points correctly to the file, airflow + # should be able to find the dag. + with tempfile.TemporaryDirectory() as td: + new_file_path = Path(td) / Path(orig_file_path).name + new_dags_folder = new_file_path.parent + with move_back(orig_file_path, new_file_path), conf_vars( + {('core', 'dags_folder'): new_dags_folder.as_posix()} + ): + ser_dag = ( + session.query(SerializedDagModel) + .filter(SerializedDagModel.dag_id == 'test_dags_folder') + .one() + ) + # confirm that the serialized dag location has not been updated + assert ser_dag.fileloc == orig_file_path.as_posix() + assert ser_dag.data['dag']['_processor_dags_folder'] == orig_dags_folder.as_posix() + assert ser_dag.data['dag']['fileloc'] == orig_file_path.as_posix() + assert ser_dag.dag._processor_dags_folder == orig_dags_folder.as_posix() + from airflow.settings import DAGS_FOLDER + + assert DAGS_FOLDER == new_dags_folder.as_posix() != orig_dags_folder.as_posix() + task_command.task_run( + self.parser.parse_args( + [ + 'tasks', + 'run', + '--ignore-all-dependencies', + '--local', + 'test_dags_folder', + 'task', + 'abc123', + ] + ) + ) + ti = ( + session.query(TaskInstance) + .filter( + TaskInstance.task_id == 'task', + TaskInstance.dag_id == 'test_dags_folder', + TaskInstance.run_id == 'abc123', + TaskInstance.map_index == -1, + ) + .one() + ) + assert ti.state == 'success' + # verify that the file was in different location when run + assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix() + @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization") @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization): diff --git a/tests/dags/test_dags_folder.py b/tests/dags/test_dags_folder.py new file mode 100644 index 0000000000000..e4b15a0857640 --- /dev/null +++ b/tests/dags/test_dags_folder.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from __future__ import annotations
+
+import pendulum
+
+from airflow import DAG
+from airflow.decorators import task
+
+with DAG(
+    dag_id='test_dags_folder',
+    schedule=None,
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+) as dag:
+
+    @task(task_id="task")
+    def return_file_path():
+        """Return this dag file's location, so the test can verify where the dag ran from."""
+        print(f"dag file location: {__file__}")
+        return __file__
+
+    return_file_path()
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 62f057c376b5c..54634e463f96b 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -37,6 +37,7 @@
 from freezegun import freeze_time
 from sqlalchemy import inspect
 
+import airflow
 from airflow import models, settings
 from airflow.configuration import conf
 from airflow.datasets import Dataset
@@ -47,6 +48,7 @@
 from airflow.models.dag import DagOwnerAttributes, dag as dag_decorator, get_dataset_triggered_next_run_info
 from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel, TaskOutletDatasetReference
 from airflow.models.param import DagParam, Param, ParamsDict
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import PythonOperator
@@ -65,12 +67,24 @@
 from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
 from tests.test_utils.asserts import assert_queries_count
-from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs, clear_db_serialized_dags
 from tests.test_utils.mapping import expand_mapped_task
 from tests.test_utils.timetables import cron_timetable, delta_timetable
 
 TEST_DATE = datetime_tz(2015, 1, 2, 0, 0)
 
+repo_root = Path(airflow.__file__).parent.parent
+
+
+@pytest.fixture
+def clear_dags():
+    clear_db_dags()
+    clear_db_serialized_dags()
+    yield
+    clear_db_dags()
+    clear_db_serialized_dags()
+
 
 class TestDag:
     def setup_method(self) -> None:
@@ -2273,6 +2287,47 @@ def test_relative_fileloc(self, fileloc, expected_relative):
 
         assert dag.relative_fileloc == expected_relative
 
+    @pytest.mark.parametrize(
+        'reader_dags_folder', [settings.DAGS_FOLDER, str(repo_root / 'airflow/example_dags')]
+    )
+    @pytest.mark.parametrize(
+        ('fileloc', 'expected_relative'),
+        [
+            (str(Path(settings.DAGS_FOLDER, 'a.py')), Path('a.py')),
+            ('/tmp/foo.py', Path('/tmp/foo.py')),
+        ],
+    )
+    def test_relative_fileloc_serialized(
+        self, fileloc, expected_relative, session, clear_dags, reader_dags_folder
+    ):
+        """
+        The serialized dag model includes the dags folder as configured on the process that
+        serialized the dag. The process deserializing the dag should use that processor's
+        dags folder when determining the relative fileloc. So even if the dags folder of
+        the deserializer is different (meaning that the full path is no longer relative to
+        its own dags folder), we should still get the relative fileloc as it existed on the
+        serializer process. When the full path is not relative to the configured dags folder,
+        the relative fileloc should just be the full path.
+ """ + dag = DAG(dag_id='test') + dag.fileloc = fileloc + sdm = SerializedDagModel(dag) + session.add(sdm) + session.commit() + session.expunge_all() + sdm = SerializedDagModel.get(dag.dag_id, session) + dag = sdm.dag + with conf_vars({('core', 'dags_folder'): reader_dags_folder}): + assert dag.relative_fileloc == expected_relative + + def test__processor_dags_folder(self, session): + """Only populated after deserializtion""" + dag = DAG(dag_id='test') + dag.fileloc = '/abc/test.py' + assert dag._processor_dags_folder is None + sdm = SerializedDagModel(dag) + assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER + @pytest.mark.need_serialized_dag def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, session, dag_maker): dataset1 = Dataset(uri="ds1") diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 44a218290ea53..6b485f74654cd 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -27,6 +27,7 @@ import pickle from datetime import datetime, timedelta from glob import glob +from pathlib import Path from unittest import mock import pendulum @@ -34,6 +35,7 @@ from dateutil.relativedelta import FR, relativedelta from kubernetes.client import models as k8s +import airflow from airflow.datasets import Dataset from airflow.exceptions import SerializationError from airflow.hooks.base import BaseHook @@ -63,6 +65,8 @@ from tests.test_utils.mock_operators import CustomOperator, GoogleLink, MockOperator from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable +repo_root = Path(airflow.__file__).parent.parent + class CustomDepOperator(BashOperator): """ @@ -133,6 +137,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i "_dag_id": "simple_dag", "doc_md": "### DAG Tutorial Documentation", "fileloc": None, + "_processor_dags_folder": f"{repo_root}/tests/dags", "tasks": [ { "task_id": "bash_task", @@ -494,13 +499,17 @@ def validate_deserialized_dag(self, serialized_dag, dag): 'default_args', "_task_group", 'params', + '_processor_dags_folder', } fields_to_check = dag.get_serialized_fields() - exclusion_list for field in fields_to_check: assert getattr(serialized_dag, field) == getattr( dag, field ), f'{dag.dag_id}.{field} does not match' - + # _processor_dags_folder is only populated at serialization time + # it's only used when relying on serialized dag to determine a dag's relative path + assert dag._processor_dags_folder is None + assert serialized_dag._processor_dags_folder == str(repo_root / 'tests/dags') if dag.default_args: for k, v in dag.default_args.items(): if callable(v): diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py index 126f25d90de14..e814d6bfd225e 100644 --- a/tests/utils/test_cli_util.py +++ b/tests/utils/test_cli_util.py @@ -24,14 +24,19 @@ from argparse import Namespace from contextlib import contextmanager from datetime import datetime +from pathlib import Path from unittest import mock import pytest +import airflow from airflow import settings from airflow.exceptions import AirflowException from airflow.models.log import Log from airflow.utils import cli, cli_action_loggers, timezone +from airflow.utils.cli import _search_for_dag_file + +repo_root = Path(airflow.__file__).parent.parent class TestCliUtil: @@ -189,3 +194,19 @@ def fail_func(_): @cli.action_cli(check_db=False) def success_func(_): pass + + +def 
+def test__search_for_dags_file():
+    dags_folder = settings.DAGS_FOLDER
+    assert _search_for_dag_file('') is None
+    assert _search_for_dag_file(None) is None
+    # if it's a file, and exactly one match can be found in the dags folder, return the full path
+    assert _search_for_dag_file('any/hi/test_dags_folder.py') == str(
+        Path(dags_folder) / 'test_dags_folder.py'
+    )
+    # if it's a folder, even one that exists, return None
+    existing_folder = Path(settings.DAGS_FOLDER, 'subdir1')
+    assert existing_folder.exists()
+    assert _search_for_dag_file(existing_folder.as_posix()) is None
+    # when multiple files are found, return None (the caller then falls back to the dags folder)
+    assert _search_for_dag_file('any/hi/__init__.py') is None
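
Reviewer note (not part of the patch): a minimal standalone sketch of why `_processor_dags_folder` gets serialized, mirroring the new `relative_fileloc` logic above. The paths (`/opt/airflow/dags` on the scheduler, `/home/worker/dags` on the worker) and the helper name are hypothetical.

from __future__ import annotations

import pathlib


def relative_fileloc(fileloc: str, processor_dags_folder: str | None, local_dags_folder: str) -> pathlib.Path:
    # Mirror of DAG.relative_fileloc: prefer the dags folder recorded at
    # serialization time, falling back to the local setting.
    path = pathlib.Path(fileloc)
    try:
        rel_path = path.relative_to(processor_dags_folder or local_dags_folder)
        # If the dags folder *was* the dag file itself, relative_to() yields '.',
        # which tells us nothing; return the full path instead.
        if rel_path == pathlib.Path('.'):
            return path
        return rel_path
    except ValueError:
        # Not relative to either dags folder; return the full path.
        return path


# The serializer ran with dags_folder=/opt/airflow/dags; the worker uses /home/worker/dags.
# Without the serialized folder, relative_to() raises and the relative location is lost;
# with it, 'etl/my_dag.py' is still recovered.
assert relative_fileloc(
    '/opt/airflow/dags/etl/my_dag.py', '/opt/airflow/dags', '/home/worker/dags'
) == pathlib.Path('etl/my_dag.py')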
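
Similarly, a hedged sketch of the lookup order that `get_dag` now follows; the helper name `resolve_dag_folder` is illustrative only and does not exist in the patch.

from __future__ import annotations

from pathlib import Path


def resolve_dag_folder(subdir: str | None, dags_folder: str) -> str:
    # Illustrative fallback only: the real get_dag() first parses the DagBag at
    # process_subdir(subdir); only when the dag is missing does it run this search.
    if subdir and Path(subdir).suffix in ('.zip', '.py'):
        matches = list(Path(dags_folder).rglob(Path(subdir).name))
        if len(matches) == 1:
            # Exactly one file with the same basename in the dags folder: use it.
            return matches[0].as_posix()
    # Zero or multiple matches (or a directory): fall back to the configured dags folder.
    return dags_folder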