
Commit

Fix airflow tasks run --local when dags_folder differs from that of processor (#26509)

Previously the code used the dags_folder of the "current" process (e.g. the celery worker, or the k8s executor worker pod) to calculate the relative fileloc from the full fileloc stored in the serialized dag. But if the worker's dags_folder differs from the dags_folder configured on the dag processor, Airflow can't calculate the relative path and falls back to the full path, which on the worker is a bad path. We fix this by recording the dags_folder of the dag processor that serialized the dag and using it to compute the relative path.
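
To make the failure mode concrete, here is a minimal sketch (plain pathlib, not Airflow code; all paths are hypothetical) of the relative-path calculation before and after the processor's dags_folder is kept alongside the serialized dag:

import pathlib

fileloc = pathlib.Path("/opt/processor/dags/example.py")  # full path stored in the serialized dag
worker_dags_folder = "/opt/worker/dags"                    # dags_folder configured on this worker
processor_dags_folder = "/opt/processor/dags"              # dags_folder now stored alongside the dag

# Old behaviour: relative_to() the worker's dags_folder raises, so the full path is used,
# and that path may not exist on the worker at all.
try:
    rel = fileloc.relative_to(worker_dags_folder)
except ValueError:
    rel = fileloc
print(rel)                                                 # /opt/processor/dags/example.py

# New behaviour: relative_to() the processor's dags_folder succeeds.
print(fileloc.relative_to(processor_dags_folder))          # example.py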

(cherry picked from commit c94f978)
dstandish authored and ephraimbuddy committed Oct 18, 2022
1 parent a643184 commit 804abc7
Showing 9 changed files with 268 additions and 11 deletions.
11 changes: 10 additions & 1 deletion airflow/models/dag.py
@@ -584,6 +584,11 @@ def __init__(
f"Bad formatted links are: {wrong_links}"
)

# this will only be set at serialization time
# its only use is for determining the relative
# fileloc based only on the serialized dag
self._processor_dags_folder = None

def get_doc_md(self, doc_md: str | None) -> str | None:
if doc_md is None:
return doc_md
@@ -1189,7 +1194,11 @@ def relative_fileloc(self) -> pathlib.Path:
"""File location of the importable dag 'file' relative to the configured DAGs folder."""
path = pathlib.Path(self.fileloc)
try:
return path.relative_to(settings.DAGS_FOLDER)
rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER)
if rel_path == pathlib.Path('.'):
return path
else:
return rel_path
except ValueError:
# Not relative to DAGS_FOLDER.
return path
8 changes: 7 additions & 1 deletion airflow/serialization/schema.json
@@ -133,7 +133,13 @@
"catchup": { "type": "boolean" },
"is_subdag": { "type": "boolean" },
"fileloc": { "type" : "string"},
"orientation": { "type" : "string"},
"_processor_dags_folder": {
"anyOf": [
{ "type": "null" },
{"type": "string"}
]
},
"orientation": { "type" : "string"},
"_description": { "type" : "string"},
"_concurrency": { "type" : "number"},
"_max_active_tasks": { "type" : "number"},
4 changes: 3 additions & 1 deletion airflow/serialization/serialized_objects.py
@@ -50,7 +50,7 @@
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
from airflow.serialization.json_schema import Validator, load_dag_schema
from airflow.settings import json
from airflow.settings import DAGS_FOLDER, json
from airflow.timetables.base import Timetable
from airflow.utils.code_utils import get_python_source
from airflow.utils.docs import get_docs_url
@@ -1120,6 +1120,8 @@ def serialize_dag(cls, dag: DAG) -> dict:
try:
serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)

serialized_dag['_processor_dags_folder'] = DAGS_FOLDER

# If schedule_interval is backed by timetable, serialize only
# timetable; vice versa for a timetable backed by schedule_interval.
if dag.timetable.summary == dag.schedule_interval:
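A hedged sketch of the relevant slice of the serialized payload after this change (the values are hypothetical; the two keys are the ones asserted on in the tests below):

serialized_data = {
    "dag": {
        "fileloc": "/opt/processor/dags/example.py",
        "_processor_dags_folder": "/opt/processor/dags",
        # ... remaining serialized dag fields ...
    },
}
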
45 changes: 40 additions & 5 deletions airflow/utils/cli.py
@@ -29,6 +29,7 @@
import warnings
from argparse import Namespace
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Callable, TypeVar, cast

from airflow import settings
@@ -43,6 +44,8 @@
if TYPE_CHECKING:
from airflow.models.dag import DAG

logger = logging.getLogger(__name__)


def _check_cli_args(args):
if not args:
@@ -181,15 +184,47 @@ def get_dag_by_file_location(dag_id: str):
return dagbag.dags[dag_id]


def _search_for_dag_file(val: str | None) -> str | None:
"""
Search for the file referenced at fileloc.
By the time we get to this function, we've already run this `val` through `process_subdir`
and loaded the DagBag there and came up empty. So here, if `val` is a file path, we make
a last ditch effort to try and find a dag file with the same name in our dags folder. (This
avoids the unnecessary dag parsing that would occur if we just parsed the dags folder).
If `val` is a path to a file, this likely means that the serializing process had a dags_folder
equal to only the dag file in question. This prevents us from determining the relative location.
And if the paths are different between worker and dag processor / scheduler, then we won't find
the dag at the given location.
"""
if val and Path(val).suffix in ('.zip', '.py'):
matches = list(Path(settings.DAGS_FOLDER).rglob(Path(val).name))
if len(matches) == 1:
return matches[0].as_posix()
return None


def get_dag(subdir: str | None, dag_id: str) -> DAG:
"""Returns DAG of a given dag_id"""
"""
Returns DAG of a given dag_id
First it we'll try to use the given subdir. If that doesn't work, we'll try to
find the correct path (assuming it's a file) and failing that, use the configured
dags folder.
"""
from airflow.models import DagBag

dagbag = DagBag(process_subdir(subdir))
first_path = process_subdir(subdir)
dagbag = DagBag(first_path)
if dag_id not in dagbag.dags:
raise AirflowException(
f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
)
fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
dagbag = DagBag(dag_folder=fallback_path)
if dag_id not in dagbag.dags:
raise AirflowException(
f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
)
return dagbag.dags[dag_id]


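A minimal sketch of the worker-side fallback this enables (paths hypothetical): when the serialized fileloc doesn't resolve on this host, we look for a same-named dag file under the local dags folder, mirroring what _search_for_dag_file does:

from pathlib import Path

serialized_fileloc = "/opt/processor/dags/example.py"   # path from the dag processor's host
worker_dags_folder = "/opt/worker/dags"                  # this host's configured dags folder

# rglob for a file with the same name; only a single unambiguous match is accepted
matches = list(Path(worker_dags_folder).rglob(Path(serialized_fileloc).name))
fallback = matches[0].as_posix() if len(matches) == 1 else None
print(fallback)  # e.g. /opt/worker/dags/example.py when exactly one match exists
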
84 changes: 83 additions & 1 deletion tests/cli/commands/test_task_command.py
@@ -22,9 +22,10 @@
import logging
import os
import re
import tempfile
import unittest
from argparse import ArgumentParser
from contextlib import redirect_stdout
from contextlib import contextmanager, redirect_stdout
from pathlib import Path
from unittest import mock

@@ -60,6 +61,13 @@ def reset(dag_id):
runs.delete()


@contextmanager
def move_back(old_path, new_path):
os.rename(old_path, new_path)
yield
os.rename(new_path, old_path)


# TODO: Check if tests needs side effects - locally there's missing DAG
class TestCliTasks:
run_id = 'TEST_RUN_ID'
@@ -183,6 +191,80 @@ def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserializ
)
mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)

def test_cli_test_different_path(self, session):
"""
When the dag processor has a different dags folder
from the worker, ``airflow tasks run --local`` should still work.
"""
repo_root = Path(__file__).parent.parent.parent.parent
orig_file_path = repo_root / 'tests/dags/test_dags_folder.py'
orig_dags_folder = orig_file_path.parent

# parse dag in original path
with conf_vars({('core', 'dags_folder'): orig_dags_folder.as_posix()}):
dagbag = DagBag(include_examples=False)
dag = dagbag.get_dag('test_dags_folder')
dagbag.sync_to_db(session=session)

dag.create_dagrun(
state=State.NONE,
run_id='abc123',
run_type=DagRunType.MANUAL,
execution_date=pendulum.now('UTC'),
session=session,
)
session.commit()

# now let's move the file
# additionally let's update the dags folder to be the new path
# since dags_folder points correctly to the file, airflow
# should still be able to find the dag
with tempfile.TemporaryDirectory() as td:
new_file_path = Path(td) / Path(orig_file_path).name
new_dags_folder = new_file_path.parent
with move_back(orig_file_path, new_file_path), conf_vars(
{('core', 'dags_folder'): new_dags_folder.as_posix()}
):
ser_dag = (
session.query(SerializedDagModel)
.filter(SerializedDagModel.dag_id == 'test_dags_folder')
.one()
)
# confirm that the serialized dag location has not been updated
assert ser_dag.fileloc == orig_file_path.as_posix()
assert ser_dag.data['dag']['_processor_dags_folder'] == orig_dags_folder.as_posix()
assert ser_dag.data['dag']['fileloc'] == orig_file_path.as_posix()
assert ser_dag.dag._processor_dags_folder == orig_dags_folder.as_posix()
from airflow.settings import DAGS_FOLDER

assert DAGS_FOLDER == new_dags_folder.as_posix() != orig_dags_folder.as_posix()
task_command.task_run(
self.parser.parse_args(
[
'tasks',
'run',
'--ignore-all-dependencies',
'--local',
'test_dags_folder',
'task',
'abc123',
]
)
)
ti = (
session.query(TaskInstance)
.filter(
TaskInstance.task_id == 'task',
TaskInstance.dag_id == 'test_dags_folder',
TaskInstance.run_id == 'abc123',
TaskInstance.map_index == -1,
)
.one()
)
assert ti.state == 'success'
# verify that the file was in different location when run
assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix()

@mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization):
38 changes: 38 additions & 0 deletions tests/dags/test_dags_folder.py
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pendulum

from airflow import DAG
from airflow.decorators import task

with DAG(
dag_id='test_dags_folder',
schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
) as dag:

@task(task_id="task")
def return_file_path():
"""Print the Airflow context and ds variable from the context."""
print(f"dag file location: {__file__}")
return __file__

return_file_path()
57 changes: 56 additions & 1 deletion tests/models/test_dag.py
@@ -37,6 +37,7 @@
from freezegun import freeze_time
from sqlalchemy import inspect

import airflow
from airflow import models, settings
from airflow.configuration import conf
from airflow.datasets import Dataset
@@ -47,6 +48,7 @@
from airflow.models.dag import DagOwnerAttributes, dag as dag_decorator, get_dataset_triggered_next_run_info
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel, TaskOutletDatasetReference
from airflow.models.param import DagParam, Param, ParamsDict
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
@@ -65,12 +67,24 @@
from airflow.utils.weight_rule import WeightRule
from tests.models import DEFAULT_DATE
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs, clear_db_serialized_dags
from tests.test_utils.mapping import expand_mapped_task
from tests.test_utils.timetables import cron_timetable, delta_timetable

TEST_DATE = datetime_tz(2015, 1, 2, 0, 0)

repo_root = Path(airflow.__file__).parent.parent


@pytest.fixture
def clear_dags():
clear_db_dags()
clear_db_serialized_dags()
yield
clear_db_dags()
clear_db_serialized_dags()


class TestDag:
def setup_method(self) -> None:
@@ -2273,6 +2287,47 @@ def test_relative_fileloc(self, fileloc, expected_relative):

assert dag.relative_fileloc == expected_relative

@pytest.mark.parametrize(
'reader_dags_folder', [settings.DAGS_FOLDER, str(repo_root / 'airflow/example_dags')]
)
@pytest.mark.parametrize(
('fileloc', 'expected_relative'),
[
(str(Path(settings.DAGS_FOLDER, 'a.py')), Path('a.py')),
('/tmp/foo.py', Path('/tmp/foo.py')),
],
)
def test_relative_fileloc_serialized(
self, fileloc, expected_relative, session, clear_dags, reader_dags_folder
):
"""
The serialized dag model records the dags folder as configured on the process that
serialized the dag. When the deserializing process determines the relative fileloc,
it should use the dags folder of the dag processor. So even if the deserializer's
dags folder is different (meaning the full path is no longer relative to its dags
folder), we should still get the relative fileloc as it existed on the serializer
process. When the full path is not relative to the configured dags folder, the
relative fileloc should simply be the full path.
"""
dag = DAG(dag_id='test')
dag.fileloc = fileloc
sdm = SerializedDagModel(dag)
session.add(sdm)
session.commit()
session.expunge_all()
sdm = SerializedDagModel.get(dag.dag_id, session)
dag = sdm.dag
with conf_vars({('core', 'dags_folder'): reader_dags_folder}):
assert dag.relative_fileloc == expected_relative

def test__processor_dags_folder(self, session):
"""Only populated after deserializtion"""
dag = DAG(dag_id='test')
dag.fileloc = '/abc/test.py'
assert dag._processor_dags_folder is None
sdm = SerializedDagModel(dag)
assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER

@pytest.mark.need_serialized_dag
def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, session, dag_maker):
dataset1 = Dataset(uri="ds1")