Support DAGS folder being in different location on scheduler and runners
There has been some vestigial support for this concept in Airflow for a
while (all the CLI commands already turn the literal `DAGS_FOLDER` into
the real value of the DAGs folder when loading DAGs), but sometime
around 1.10.1-1.10.3 it got fully broken and the scheduler only ever
passed full paths to DAG files.

This PR brings back that behaviour.
ashb committed Aug 4, 2021
1 parent ed99eaa commit 0da2b8c
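In practice the fix means the scheduler sends task commands that reference the DAG file by a path relative to the DAGs folder, prefixed with the literal `DAGS_FOLDER` token, and each runner expands that token against its own configured folder. A minimal sketch of that round trip, assuming hypothetical folder locations on each machine (a simplified stand-in, not the actual Airflow implementation):

import pathlib

SCHEDULER_DAGS_FOLDER = "/opt/scheduler/dags"  # hypothetical scheduler-side setting
WORKER_DAGS_FOLDER = "/home/airflow/dags"      # hypothetical worker-side setting

def to_portable(fileloc: str) -> str:
    """Scheduler side: swap the local DAGs folder for the DAGS_FOLDER token."""
    path = pathlib.Path(fileloc)
    try:
        # str / Path works because PurePath implements __rtruediv__
        return str("DAGS_FOLDER" / path.relative_to(SCHEDULER_DAGS_FOLDER))
    except ValueError:
        return fileloc  # not under the DAGs folder, so send the absolute path

def to_local(portable: str) -> str:
    """Runner side: substitute this machine's DAGs folder for the token."""
    return portable.replace("DAGS_FOLDER", WORKER_DAGS_FOLDER, 1)

assert to_portable("/opt/scheduler/dags/etl/daily.py") == "DAGS_FOLDER/etl/daily.py"
assert to_local("DAGS_FOLDER/etl/daily.py") == "/home/airflow/dags/etl/daily.py"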
Showing 19 changed files with 163 additions and 75 deletions.
12 changes: 1 addition & 11 deletions airflow/jobs/scheduler_job.py
@@ -496,18 +496,8 @@ def _enqueue_task_instances_with_queued_state(self, task_instances: List[TI]) ->
         """
         # actually enqueue them
         for ti in task_instances:
-            command = TI.generate_command(
-                ti.dag_id,
-                ti.task_id,
-                ti.execution_date,
+            command = ti.command_as_list(
                 local=True,
-                mark_success=False,
-                ignore_all_deps=False,
-                ignore_depends_on_past=False,
-                ignore_task_deps=False,
-                ignore_ti_state=False,
-                pool=ti.pool,
-                file_path=ti.dag_model.fileloc,
                 pickle_id=ti.dag_model.pickle_id,
             )

78 changes: 65 additions & 13 deletions airflow/models/dag.py
@@ -20,6 +20,7 @@
 import functools
 import logging
 import os
+import pathlib
 import pickle
 import re
 import sys
@@ -236,13 +237,21 @@ class DAG(LoggingMixin):
         'parent_dag',
         'start_date',
         'schedule_interval',
-        'full_filepath',
+        'fileloc',
         'template_searchpath',
         'last_loaded',
     }

     __serialized_fields: Optional[FrozenSet[str]] = None

+    fileloc: str
+    """
+    File path that needs to be imported to load this DAG or subdag.
+    This may not be an actual file on disk in the case when this DAG is loaded
+    from a ZIP file or other DAG distribution format.
+    """

     def __init__(
         self,
         dag_id: str,
@@ -286,10 +295,16 @@ def __init__(
             self.params.update(self.default_args['params'])
             del self.default_args['params']

+        if full_filepath:
+            warnings.warn(
+                "Passing full_filepath to DAG() is deprecated and has no effect",
+                DeprecationWarning,
+                stacklevel=2,
+            )

         validate_key(dag_id)

         self._dag_id = dag_id
-        self._full_filepath = full_filepath if full_filepath else ''
         if concurrency and not max_active_tasks:
             # TODO: Remove in Airflow 3.0
             warnings.warn(
@@ -655,11 +670,22 @@ def dag_id(self, value: str) -> None:

     @property
     def full_filepath(self) -> str:
-        return self._full_filepath
+        """:meta private:"""
+        warnings.warn(
+            "DAG.full_filepath is deprecated in favour of fileloc",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.fileloc

     @full_filepath.setter
     def full_filepath(self, value) -> None:
-        self._full_filepath = value
+        warnings.warn(
+            "DAG.full_filepath is deprecated in favour of fileloc",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        self.fileloc = value

     @property
     def concurrency(self) -> int:
@@ -735,15 +761,26 @@ def task_group(self) -> "TaskGroup":

     @property
     def filepath(self) -> str:
-        """File location of where the dag object is instantiated"""
-        fn = self.full_filepath.replace(settings.DAGS_FOLDER + '/', '')
-        fn = fn.replace(os.path.dirname(__file__) + '/', '')
-        return fn
+        """:meta private:"""
+        warnings.warn(
+            "filepath is deprecated, use relative_fileloc instead", DeprecationWarning, stacklevel=2
+        )
+        return str(self.relative_fileloc)

+    @property
+    def relative_fileloc(self) -> pathlib.Path:
+        """File location of the importable dag 'file' relative to the configured DAGs folder."""
+        path = pathlib.Path(self.fileloc)
+        try:
+            return path.relative_to(settings.DAGS_FOLDER)
+        except ValueError:
+            # Not relative to DAGS_FOLDER.
+            return path

     @property
     def folder(self) -> str:
         """Folder location of where the DAG object is instantiated."""
-        return os.path.dirname(self.full_filepath)
+        return os.path.dirname(self.fileloc)

     @property
     def owner(self) -> str:
@@ -2118,9 +2155,11 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
             .group_by(DagRun.dag_id)
             .all()
         )
+        filelocs = []

         for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
             dag = dag_by_ids[orm_dag.dag_id]
+            filelocs.append(dag.fileloc)
             if dag.is_subdag:
                 orm_dag.is_subdag = True
                 orm_dag.fileloc = dag.parent_dag.fileloc  # type: ignore
@@ -2157,7 +2196,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
                     session.add(dag_tag_orm)

         if settings.STORE_DAG_CODE:
-            DagCode.bulk_sync_to_db([dag.fileloc for dag in orm_dags])
+            DagCode.bulk_sync_to_db(filelocs)

         # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
         # decide when to commit
@@ -2274,7 +2313,6 @@ def get_serialized_fields(cls):
                 '_old_context_manager_dags',
                 'safe_dag_id',
                 'last_loaded',
-                '_full_filepath',
                 'user_defined_filters',
                 'user_defined_macros',
                 'partial',
@@ -2332,7 +2370,7 @@ class DagModel(Base):
     These items are stored in the database for state related information
     """
     dag_id = Column(String(ID_LEN), primary_key=True)
-    root_dag_id = Column(String(ID_LEN))
+    root_dag_id = Column(String(ID_LEN), ForeignKey("dag.dag_id"))
     # A DAG can be paused from the UI / DB
     # Set this default value of is_paused based on a configuration value!
     is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation')
@@ -2382,6 +2420,8 @@ class DagModel(Base):
         Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False),
     )

+    parent_dag = relationship("DagModel", remote_side=[dag_id])

     NUM_DAGS_PER_DAGRUN_QUERY = conf.getint('scheduler', 'max_dagruns_to_create_per_loop', fallback=10)

     def __init__(self, concurrency=None, **kwargs):
@@ -2410,7 +2450,7 @@ def timezone(self):
     @staticmethod
     @provide_session
     def get_dagmodel(dag_id, session=None):
-        return session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
+        return session.query(DagModel).options(joinedload(DagModel.parent_dag)).get(dag_id)

     @classmethod
     @provide_session
@@ -2455,6 +2495,18 @@ def get_default_view(self) -> str:
     def safe_dag_id(self):
         return self.dag_id.replace('.', '__dot__')

+    @property
+    def relative_fileloc(self) -> Optional[pathlib.Path]:
+        """File location of the importable dag 'file' relative to the configured DAGs folder."""
+        if self.fileloc is None:
+            return None
+        path = pathlib.Path(self.fileloc)
+        try:
+            return path.relative_to(settings.DAGS_FOLDER)
+        except ValueError:
+            # Not relative to DAGS_FOLDER.
+            return path

     @provide_session
     def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=None) -> None:
         """
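The `relative_fileloc` fallback above is easy to exercise in isolation: paths under the DAGs folder become relative, while anything outside it (for example, example DAGs shipped inside an installed package) stays absolute. A small standalone demo, with a hypothetical `/opt/airflow/dags` standing in for `settings.DAGS_FOLDER`:

import pathlib

DAGS_FOLDER = "/opt/airflow/dags"  # hypothetical stand-in for settings.DAGS_FOLDER

def relative_fileloc(fileloc: str) -> pathlib.Path:
    path = pathlib.Path(fileloc)
    try:
        return path.relative_to(DAGS_FOLDER)
    except ValueError:
        # Not under DAGS_FOLDER; keep the absolute path as-is.
        return path

print(relative_fileloc("/opt/airflow/dags/team_a/etl.py"))   # team_a/etl.py
print(relative_fileloc("/usr/lib/site-dags/example.py"))     # unchanged absolute path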
30 changes: 13 additions & 17 deletions airflow/models/dagbag.py
@@ -381,34 +381,30 @@ def _load_modules_from_zip(self, filepath, safe_mode):
     def _process_modules(self, filepath, mods, file_last_changed_on_disk):
         from airflow.models.dag import DAG  # Avoid circular import

-        is_zipfile = zipfile.is_zipfile(filepath)
-        top_level_dags = [o for m in mods for o in list(m.__dict__.values()) if isinstance(o, DAG)]
+        top_level_dags = ((o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG))

         found_dags = []

-        for dag in top_level_dags:
-            if not dag.full_filepath:
-                dag.full_filepath = filepath
-                if dag.fileloc != filepath and not is_zipfile:
-                    dag.fileloc = filepath
+        for (dag, mod) in top_level_dags:
+            dag.fileloc = mod.__file__
             try:
                 dag.is_subdag = False
                 dag.timetable.validate()
                 self.bag_dag(dag=dag, root_dag=dag)
                 found_dags.append(dag)
                 found_dags += dag.subdags
             except AirflowTimetableInvalid as exception:
-                self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
-                self.import_errors[dag.full_filepath] = f"Invalid timetable expression: {exception}"
-                self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
+                self.log.exception("Failed to bag_dag: %s", dag.fileloc)
+                self.import_errors[dag.fileloc] = f"Invalid timetable expression: {exception}"
+                self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
             except (
                 AirflowDagCycleException,
                 AirflowDagDuplicatedIdException,
                 AirflowClusterPolicyViolation,
             ) as exception:
-                self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
-                self.import_errors[dag.full_filepath] = str(exception)
-                self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
+                self.log.exception("Failed to bag_dag: %s", dag.fileloc)
+                self.import_errors[dag.fileloc] = str(exception)
+                self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
         return found_dags

     def bag_dag(self, dag, root_dag):
@@ -444,17 +440,17 @@ def _bag_dag(self, *, dag, root_dag, recursive):
             # into further _bag_dag() calls.
             if recursive:
                 for subdag in subdags:
-                    subdag.full_filepath = dag.full_filepath
+                    subdag.fileloc = dag.fileloc
                     subdag.parent_dag = dag
                     subdag.is_subdag = True
                     self._bag_dag(dag=subdag, root_dag=root_dag, recursive=False)

         prev_dag = self.dags.get(dag.dag_id)
-        if prev_dag and prev_dag.full_filepath != dag.full_filepath:
+        if prev_dag and prev_dag.fileloc != dag.fileloc:
             raise AirflowDagDuplicatedIdException(
                 dag_id=dag.dag_id,
-                incoming=dag.full_filepath,
-                existing=self.dags[dag.dag_id].full_filepath,
+                incoming=dag.fileloc,
+                existing=self.dags[dag.dag_id].fileloc,
             )
         self.dags[dag.dag_id] = dag
         self.log.debug('Loaded DAG %s', dag)
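The `_process_modules` change above now takes `fileloc` from the `__file__` of the module each DAG object was found in, rather than from the file path the DagBag happened to be processing; for zipimported modules, `__file__` includes the path inside the archive. A rough illustration with stand-in types (not the real Airflow classes):

import types

class FakeDAG:
    """Minimal stand-in for airflow.models.dag.DAG."""
    def __init__(self, dag_id: str):
        self.dag_id = dag_id
        self.fileloc = None

# A module roughly as zipimport would present it: __file__ points inside the archive.
mod = types.ModuleType("my_dags")
mod.__file__ = "/opt/airflow/dags/bundle.zip/my_dags.py"
mod.etl = FakeDAG("etl")

# Mirror of the new generator expression in _process_modules
top_level_dags = ((o, m) for m in [mod] for o in m.__dict__.values() if isinstance(o, FakeDAG))

for dag, module in top_level_dags:
    dag.fileloc = module.__file__

assert mod.etl.fileloc == "/opt/airflow/dags/bundle.zip/my_dags.py"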
2 changes: 1 addition & 1 deletion airflow/models/serialized_dag.py
@@ -90,7 +90,7 @@ class SerializedDagModel(Base):

     def __init__(self, dag: DAG):
         self.dag_id = dag.dag_id
-        self.fileloc = dag.full_filepath
+        self.fileloc = dag.fileloc
         self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
         self.data = SerializedDAG.to_dict(dag)
         self.last_updated = timezone.utcnow()
30 changes: 22 additions & 8 deletions airflow/models/taskinstance.py
@@ -26,7 +26,7 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 from tempfile import NamedTemporaryFile
-from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
+from typing import IO, TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
 from urllib.parse import quote

 import dill
@@ -91,6 +91,10 @@
 log = logging.getLogger(__name__)


+if TYPE_CHECKING:
+    from airflow.models.dag import DAG, DagModel


 @contextlib.contextmanager
 def set_current_context(context: Context):
     """
@@ -436,15 +440,25 @@ def command_as_list(
         installed. This command is part of the message sent to executors by
         the orchestrator.
         """
-        dag = self.task.dag
+        dag: Union["DAG", "DagModel"]
+        # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
+        if hasattr(self, 'task') and hasattr(self.task, 'dag'):
+            dag = self.task.dag
+        else:
+            dag = self.dag_model

         should_pass_filepath = not pickle_id and dag
-        if should_pass_filepath and dag.full_filepath != dag.filepath:
-            path = f"DAGS_FOLDER/{dag.filepath}"
-        elif should_pass_filepath and dag.full_filepath:
-            path = dag.full_filepath
-        else:
-            path = None
+        path = None
+        if should_pass_filepath:
+            if dag.is_subdag:
+                path = dag.parent_dag.relative_fileloc
+            else:
+                path = dag.relative_fileloc

+            if path:
+                if not path.is_absolute():
+                    path = 'DAGS_FOLDER' / path
+                path = str(path)

         return TaskInstance.generate_command(
             self.dag_id,
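One subtlety in the new `command_as_list` branch above: `'DAGS_FOLDER' / path` combines a plain string with a `pathlib.Path`, which works because `PurePath` implements `__rtruediv__`. A quick standalone check of that behaviour, using a hypothetical relative path:

import pathlib

rel = pathlib.Path("team_a/etl.py")  # hypothetical relative_fileloc result
assert not rel.is_absolute()

# str / Path is handled by PurePath.__rtruediv__ and yields a Path
portable = "DAGS_FOLDER" / rel
assert str(portable) == "DAGS_FOLDER/team_a/etl.py"

# A DAG outside the DAGs folder keeps its absolute path and skips the token
outside = pathlib.Path("/usr/lib/site-dags/example.py")
assert outside.is_absolute()
assert str(outside) == "/usr/lib/site-dags/example.py"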
1 change: 0 additions & 1 deletion airflow/serialization/serialized_objects.py
@@ -752,7 +752,6 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
         for k in keys_to_set_none:
             setattr(dag, k, None)

-        setattr(dag, 'full_filepath', dag.fileloc)
         for task in dag.task_dict.values():
             task.dag = dag
             serializable_task: BaseOperator = task
4 changes: 2 additions & 2 deletions airflow/www/templates/airflow/dag_details.html
@@ -88,8 +88,8 @@ <h3>{{ title }}</h3>
       <td>{{ dag.task_ids }}</td>
     </tr>
     <tr>
-      <th>Filepath</th>
-      <td>{{ dag.filepath }}</td>
+      <th>Relative file location</th>
+      <td>{{ dag.relative_fileloc }}</td>
     </tr>
     <tr>
       <th>Owner</th>
3 changes: 1 addition & 2 deletions airflow/www/views.py
@@ -1388,9 +1388,8 @@ def xcom(self, session=None):
         dttm = timezone.parse(execution_date)
         form = DateTimeForm(data={'execution_date': dttm})
         root = request.args.get('root', '')
-        dm_db = models.DagModel
         ti_db = models.TaskInstance
-        dag = session.query(dm_db).filter(dm_db.dag_id == dag_id).first()
+        dag = DagModel.get_dagmodel(dag_id)
         ti = session.query(ti_db).filter(and_(ti_db.dag_id == dag_id, ti_db.task_id == task_id)).first()

         if not ti:
4 changes: 2 additions & 2 deletions tests/cluster_policies/__init__.py
@@ -52,7 +52,7 @@ def _check_task_rules(current_task: BaseOperator):
     if notices:
         notices_list = " * " + "\n * ".join(notices)
         raise AirflowClusterPolicyViolation(
-            f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.filepath}):\n"
+            f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.fileloc}):\n"
             f"Notices:\n"
             f"{notices_list}"
         )
@@ -70,7 +70,7 @@ def dag_policy(dag: DAG):
     """Ensure that DAG has at least one tag"""
     if not dag.tags:
         raise AirflowClusterPolicyViolation(
-            f"DAG {dag.dag_id} has no tags. At least one tag required. File path: {dag.filepath}"
+            f"DAG {dag.dag_id} has no tags. At least one tag required. File path: {dag.fileloc}"
         )
)


6 changes: 3 additions & 3 deletions tests/dag_processing/test_manager.py
@@ -407,9 +407,9 @@ def test_find_zombies(self):
             seconds=manager._zombie_threshold_secs + 1
         )
         manager._find_zombies()
-        requests = manager._callback_to_execute[dag.full_filepath]
+        requests = manager._callback_to_execute[dag.fileloc]
         assert 1 == len(requests)
-        assert requests[0].full_filepath == dag.full_filepath
+        assert requests[0].full_filepath == dag.fileloc
         assert requests[0].msg == "Detected as zombie"
         assert requests[0].is_failure_callback is True
         assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance)
@@ -451,7 +451,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p

         expected_failure_callback_requests = [
             TaskCallbackRequest(
-                full_filepath=dag.full_filepath,
+                full_filepath=dag.fileloc,
                 simple_task_instance=SimpleTaskInstance(ti),
                 msg="Message",
             )