Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support DAGS folder being in different location on scheduler and runners #16860

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions airflow/jobs/scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,18 +496,8 @@ def _enqueue_task_instances_with_queued_state(self, task_instances: List[TI]) ->
"""
# actually enqueue them
for ti in task_instances:
command = TI.generate_command(
ti.dag_id,
ti.task_id,
ti.execution_date,
command = ti.command_as_list(
local=True,
mark_success=False,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
Comment on lines -504 to -508
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the defaults, we don't need to pass them again

pool=ti.pool,
file_path=ti.dag_model.fileloc,
pickle_id=ti.dag_model.pickle_id,
)

Expand Down
78 changes: 66 additions & 12 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import functools
import logging
import os
import pathlib
import pickle
import re
import sys
Expand Down Expand Up @@ -236,13 +237,21 @@ class DAG(LoggingMixin):
'parent_dag',
'start_date',
'schedule_interval',
'full_filepath',
'fileloc',
'template_searchpath',
'last_loaded',
}

__serialized_fields: Optional[FrozenSet[str]] = None

fileloc: str
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reminds me I was wondering whether fileloc is guaranteed to be absolute or not and had to trace a lot of code. Maybe it’s worthwhile to add this to the docstring (and the one on DagModel).

"""
File path that needs to be imported to load this DAG or subdag.

This may not be an actual file on disk in the case when this DAG is loaded
from a ZIP file or other DAG distribution format.
"""

def __init__(
self,
dag_id: str,
Expand Down Expand Up @@ -286,10 +295,16 @@ def __init__(
self.params.update(self.default_args['params'])
del self.default_args['params']

if full_filepath:
warnings.warn(
"Passing full_filepath to DAG() is deprecated and has no effect",
DeprecationWarning,
stacklevel=2,
)

validate_key(dag_id)

self._dag_id = dag_id
self._full_filepath = full_filepath if full_filepath else ''
if concurrency and not max_active_tasks:
# TODO: Remove in Airflow 3.0
warnings.warn(
Expand Down Expand Up @@ -655,11 +670,22 @@ def dag_id(self, value: str) -> None:

@property
def full_filepath(self) -> str:
return self._full_filepath
""":meta private:"""
warnings.warn(
"DAG.full_filepath is deprecated in favour of fileloc",
DeprecationWarning,
stacklevel=2,
)
return self.fileloc

@full_filepath.setter
def full_filepath(self, value) -> None:
self._full_filepath = value
warnings.warn(
"DAG.full_filepath is deprecated in favour of fileloc",
DeprecationWarning,
stacklevel=2,
)
self.fileloc = value

@property
def concurrency(self) -> int:
Expand Down Expand Up @@ -735,15 +761,26 @@ def task_group(self) -> "TaskGroup":

@property
def filepath(self) -> str:
"""File location of where the dag object is instantiated"""
fn = self.full_filepath.replace(settings.DAGS_FOLDER + '/', '')
fn = fn.replace(os.path.dirname(__file__) + '/', '')
return fn
""":meta private:"""
warnings.warn(
"filepath is deprecated, use relative_fileloc instead", DeprecationWarning, stacklevel=2
)
return str(self.relative_fileloc)

@property
def relative_fileloc(self) -> pathlib.Path:
"""File location of the importable dag 'file' relative to the configured DAGs folder."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How will this field behave for example DAGs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case it would be /usr/lib/python3.7/.../airflow/example_dags/example_bash_operator.py -- so not relative at all.

I'll update/expand the docstring to say how it deals with DAGs outside of the dags folder.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we are also able to handle examples DAGs by adding a suffix? This will allow us to install workers on shared machines without any problems
For samples:
[exaample_dags_folder]/example_bash_operator.py => /usr/lib/python3.7/.../airflow/example_dags/example_bash_operator.py
[dags_folder]/example_bash_operator.py => ~/home/airflow/dags

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change would be better done along with something like https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-20+DAG+manifest, so I've not done this for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we shouldn’t call this relative fileloc, but just fileloc…? Not sure about this.

path = pathlib.Path(self.fileloc)
try:
return path.relative_to(settings.DAGS_FOLDER)
except ValueError:
# Not relative to DAGS_FOLDER.
return path

@property
def folder(self) -> str:
"""Folder location of where the DAG object is instantiated."""
return os.path.dirname(self.full_filepath)
return os.path.dirname(self.fileloc)

@property
def owner(self) -> str:
Expand Down Expand Up @@ -2118,9 +2155,11 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
.group_by(DagRun.dag_id)
.all()
)
filelocs = []

for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
dag = dag_by_ids[orm_dag.dag_id]
filelocs.append(dag.fileloc)
if dag.is_subdag:
orm_dag.is_subdag = True
orm_dag.fileloc = dag.parent_dag.fileloc # type: ignore
Expand Down Expand Up @@ -2157,7 +2196,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
session.add(dag_tag_orm)

if settings.STORE_DAG_CODE:
DagCode.bulk_sync_to_db([dag.fileloc for dag in orm_dags])
DagCode.bulk_sync_to_db(filelocs)

# Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
# decide when to commit
Expand Down Expand Up @@ -2274,7 +2313,6 @@ def get_serialized_fields(cls):
'_old_context_manager_dags',
'safe_dag_id',
'last_loaded',
'_full_filepath',
'user_defined_filters',
'user_defined_macros',
'partial',
Expand Down Expand Up @@ -2382,6 +2420,10 @@ class DagModel(Base):
Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False),
)

parent_dag = relationship(
"DagModel", remote_side=[dag_id], primaryjoin=root_dag_id == dag_id, foreign_keys=[root_dag_id]
)

NUM_DAGS_PER_DAGRUN_QUERY = conf.getint('scheduler', 'max_dagruns_to_create_per_loop', fallback=10)

def __init__(self, concurrency=None, **kwargs):
Expand Down Expand Up @@ -2410,7 +2452,7 @@ def timezone(self):
@staticmethod
@provide_session
def get_dagmodel(dag_id, session=None):
return session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
return session.query(DagModel).options(joinedload(DagModel.parent_dag)).get(dag_id)

@classmethod
@provide_session
Expand Down Expand Up @@ -2455,6 +2497,18 @@ def get_default_view(self) -> str:
def safe_dag_id(self):
return self.dag_id.replace('.', '__dot__')

@property
def relative_fileloc(self) -> Optional[pathlib.Path]:
"""File location of the importable dag 'file' relative to the configured DAGs folder."""
if self.fileloc is None:
return None
path = pathlib.Path(self.fileloc)
try:
return path.relative_to(settings.DAGS_FOLDER)
except ValueError:
# Not relative to DAGS_FOLDER.
return path

@provide_session
def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=None) -> None:
"""
Expand Down
30 changes: 13 additions & 17 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,34 +381,30 @@ def _load_modules_from_zip(self, filepath, safe_mode):
def _process_modules(self, filepath, mods, file_last_changed_on_disk):
from airflow.models.dag import DAG # Avoid circular import

is_zipfile = zipfile.is_zipfile(filepath)
top_level_dags = [o for m in mods for o in list(m.__dict__.values()) if isinstance(o, DAG)]
top_level_dags = ((o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG))

found_dags = []

for dag in top_level_dags:
if not dag.full_filepath:
dag.full_filepath = filepath
if dag.fileloc != filepath and not is_zipfile:
dag.fileloc = filepath
for (dag, mod) in top_level_dags:
dag.fileloc = mod.__file__
try:
dag.is_subdag = False
dag.timetable.validate()
self.bag_dag(dag=dag, root_dag=dag)
found_dags.append(dag)
found_dags += dag.subdags
except AirflowTimetableInvalid as exception:
self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
self.import_errors[dag.full_filepath] = f"Invalid timetable expression: {exception}"
self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
self.log.exception("Failed to bag_dag: %s", dag.fileloc)
self.import_errors[dag.fileloc] = f"Invalid timetable expression: {exception}"
self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
except (
AirflowDagCycleException,
AirflowDagDuplicatedIdException,
AirflowClusterPolicyViolation,
) as exception:
self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
self.import_errors[dag.full_filepath] = str(exception)
self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
self.log.exception("Failed to bag_dag: %s", dag.fileloc)
self.import_errors[dag.fileloc] = str(exception)
self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
return found_dags

def bag_dag(self, dag, root_dag):
Expand Down Expand Up @@ -444,17 +440,17 @@ def _bag_dag(self, *, dag, root_dag, recursive):
# into further _bag_dag() calls.
if recursive:
for subdag in subdags:
subdag.full_filepath = dag.full_filepath
subdag.fileloc = dag.fileloc
subdag.parent_dag = dag
subdag.is_subdag = True
self._bag_dag(dag=subdag, root_dag=root_dag, recursive=False)

prev_dag = self.dags.get(dag.dag_id)
if prev_dag and prev_dag.full_filepath != dag.full_filepath:
if prev_dag and prev_dag.fileloc != dag.fileloc:
raise AirflowDagDuplicatedIdException(
dag_id=dag.dag_id,
incoming=dag.full_filepath,
existing=self.dags[dag.dag_id].full_filepath,
incoming=dag.fileloc,
existing=self.dags[dag.dag_id].fileloc,
)
self.dags[dag.dag_id] = dag
self.log.debug('Loaded DAG %s', dag)
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class SerializedDagModel(Base):

def __init__(self, dag: DAG):
self.dag_id = dag.dag_id
self.fileloc = dag.full_filepath
self.fileloc = dag.fileloc
self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
self.data = SerializedDAG.to_dict(dag)
self.last_updated = timezone.utcnow()
Expand Down
30 changes: 22 additions & 8 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from collections import defaultdict
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from typing import IO, TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from urllib.parse import quote

import dill
Expand Down Expand Up @@ -91,6 +91,10 @@
log = logging.getLogger(__name__)


if TYPE_CHECKING:
from airflow.models.dag import DAG, DagModel


@contextlib.contextmanager
def set_current_context(context: Context):
"""
Expand Down Expand Up @@ -436,15 +440,25 @@ def command_as_list(
installed. This command is part of the message sent to executors by
the orchestrator.
"""
dag = self.task.dag
dag: Union["DAG", "DagModel"]
# Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
if hasattr(self, 'task') and hasattr(self.task, 'dag'):
dag = self.task.dag
else:
dag = self.dag_model

should_pass_filepath = not pickle_id and dag
if should_pass_filepath and dag.full_filepath != dag.filepath:
path = f"DAGS_FOLDER/{dag.filepath}"
elif should_pass_filepath and dag.full_filepath:
path = dag.full_filepath
else:
path = None
path = None
if should_pass_filepath:
if dag.is_subdag:
path = dag.parent_dag.relative_fileloc
else:
path = dag.relative_fileloc

if path:
if not path.is_absolute():
path = 'DAGS_FOLDER' / path
path = str(path)

return TaskInstance.generate_command(
self.dag_id,
Expand Down
1 change: 0 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,6 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
for k in keys_to_set_none:
setattr(dag, k, None)

setattr(dag, 'full_filepath', dag.fileloc)
for task in dag.task_dict.values():
task.dag = dag
serializable_task: BaseOperator = task
Expand Down
4 changes: 2 additions & 2 deletions airflow/www/templates/airflow/dag_details.html
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ <h3>{{ title }}</h3>
<td>{{ dag.task_ids }}</td>
</tr>
<tr>
<th>Filepath</th>
<td>{{ dag.filepath }}</td>
<th>Relative file location</th>
<td>{{ dag.relative_fileloc }}</td>
</tr>
<tr>
<th>Owner</th>
Expand Down
3 changes: 1 addition & 2 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,9 +1388,8 @@ def xcom(self, session=None):
dttm = timezone.parse(execution_date)
form = DateTimeForm(data={'execution_date': dttm})
root = request.args.get('root', '')
dm_db = models.DagModel
ti_db = models.TaskInstance
dag = session.query(dm_db).filter(dm_db.dag_id == dag_id).first()
dag = DagModel.get_dagmodel(dag_id)
ti = session.query(ti_db).filter(and_(ti_db.dag_id == dag_id, ti_db.task_id == task_id)).first()

if not ti:
Expand Down
4 changes: 2 additions & 2 deletions tests/cluster_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _check_task_rules(current_task: BaseOperator):
if notices:
notices_list = " * " + "\n * ".join(notices)
raise AirflowClusterPolicyViolation(
f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.filepath}):\n"
f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.fileloc}):\n"
f"Notices:\n"
f"{notices_list}"
)
Expand All @@ -70,7 +70,7 @@ def dag_policy(dag: DAG):
"""Ensure that DAG has at least one tag"""
if not dag.tags:
raise AirflowClusterPolicyViolation(
f"DAG {dag.dag_id} has no tags. At least one tag required. File path: {dag.filepath}"
f"DAG {dag.dag_id} has no tags. At least one tag required. File path: {dag.fileloc}"
)


Expand Down
6 changes: 3 additions & 3 deletions tests/dag_processing/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ def test_find_zombies(self):
seconds=manager._zombie_threshold_secs + 1
)
manager._find_zombies()
requests = manager._callback_to_execute[dag.full_filepath]
requests = manager._callback_to_execute[dag.fileloc]
assert 1 == len(requests)
assert requests[0].full_filepath == dag.full_filepath
assert requests[0].full_filepath == dag.fileloc
assert requests[0].msg == "Detected as zombie"
assert requests[0].is_failure_callback is True
assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance)
Expand Down Expand Up @@ -451,7 +451,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p

expected_failure_callback_requests = [
TaskCallbackRequest(
full_filepath=dag.full_filepath,
full_filepath=dag.fileloc,
simple_task_instance=SimpleTaskInstance(ti),
msg="Message",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def test_process_file_should_failure_callback(self):

requests = [
TaskCallbackRequest(
full_filepath=dag.full_filepath,
full_filepath=dag.fileloc,
simple_task_instance=SimpleTaskInstance(ti),
msg="Message",
)
Expand Down
Loading