Merge upstream
Also a chapter was added to recommend taking a backup before
the migration.

Based on discussions and user input from apache#25866, apache#24526

Closes: apache#24526

Improve cleanup of temporary files in CI (apache#25957)

After a recent change in parallel execution, we started to see
infrequent "no space left on device" messages - likely caused by
generated files in /tmp/ clogging the filesystem across multiple
runs. We could fix it by always running cleanup after a parallel
job, but that would hurt the diagnostics needed when debugging
parallel runs locally, so we need a way to skip the deletion of
/tmp files.

This PR fixes the problem in two ways:

* the breeze cleanup instruction, which runs at the beginning of
  every job, now also cleans /tmp files
* the parallel jobs clean up after themselves unless explicitly
  skipped (a rough sketch of the skip logic follows below).
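
For illustration only, a minimal Python sketch of the skip-aware cleanup idea described above (the `SKIP_TMP_CLEANUP` variable and `cleanup_tmp_files` helper are assumptions for this example, not the actual Breeze code):

```python
import os
import shutil
import tempfile
from pathlib import Path


def cleanup_tmp_files(skip: bool = False) -> None:
    """Best-effort removal of generated files in the temp directory.

    Cleanup can be skipped so the files stay around for local debugging
    of parallel runs.
    """
    # Hypothetical environment variable controlling the skip behaviour.
    if skip or os.environ.get("SKIP_TMP_CLEANUP", "false").lower() == "true":
        print("Skipping temp cleanup; files kept for diagnostics.")
        return
    for entry in Path(tempfile.gettempdir()).iterdir():
        try:
            if entry.is_dir():
                shutil.rmtree(entry, ignore_errors=True)
            else:
                entry.unlink()
        except OSError:
            # Files owned by other processes may not be removable.
            pass
```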

Properly check the existence of missing mapped TIs (apache#25788)

The previous implementation of missing-index checking was not correct: missing indexes
were being checked every time `task_instance_scheduling_decision` was called.
The missing task instances should only be revised after the last-resort expansion of mapped tasks has been done. If we find that a task is in a schedulable state and has already been expanded, we revise its indexes and ensure they are complete: task instances are created for missing indexes, while indexes that are no longer part of the resolved mapping are marked as removed.
This implementation allows the revision to be done in one place.

Co-authored-by: Tzu-ping Chung <[email protected]>
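
The core of the revision can be pictured as simple set arithmetic over map indexes; the sketch below is a simplified illustration of that idea, not the actual Airflow code (see the `airflow/models/dagrun.py` diff further down):

```python
from typing import Iterable, Set, Tuple


def revise_indexes(expected_length: int, existing_indexes: Iterable[int]) -> Tuple[Set[int], Set[int]]:
    """Return (missing, removed) map indexes for an expanded mapped task.

    ``missing`` indexes need new task instances created; ``removed`` indexes
    belong to task instances that must be marked as removed because the
    resolved mapping shrank.
    """
    existing = set(existing_indexes)
    expected = set(range(expected_length))
    return expected - existing, existing - expected


# The mapping now resolves to 3 entries, but TIs exist for indexes 0, 2, 3, 4.
missing, removed = revise_indexes(3, {0, 2, 3, 4})
assert missing == {1}      # create a TI for index 1
assert removed == {3, 4}   # mark TIs 3 and 4 as removed
```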

Fix dataset_event_manager resolution (apache#25943)

It appears `__init__` is not invoked as part of `_run_raw_task` due to the way the TI is refreshed from the DB. Centralize dataset manager instantiation instead.
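
For background, SQLAlchemy reconstructs persisted ORM objects without calling `__init__` (only `@orm.reconstructor` hooks run), so attributes assigned there can be missing on instances loaded from the database. A small standalone illustration (not Airflow code):

```python
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Widget(Base):
    __tablename__ = "widget"
    id = Column(Integer, primary_key=True)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Plain attribute set only in __init__, not a mapped column.
        self.helper = object()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Widget(id=1))
    session.commit()

with Session(engine) as session:
    loaded = session.get(Widget, 1)
    # __init__ was not invoked when the row was loaded, so the attribute
    # set there is absent - the same pitfall hit by the per-TI manager.
    assert not hasattr(loaded, "helper")
```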

Fix unhashable issue with secrets.backend_kwargs and caching (apache#25970)

Resolves apache#25968
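
The underlying issue: `functools.lru_cache` hashes its call arguments, so passing parsed `backend_kwargs` that contain unhashable values (nested dicts or lists) raises `TypeError`. Passing the raw JSON string keeps the arguments hashable and moves the parsing inside the cached function, which is the shape of the `airflow/configuration.py` diff below. A minimal standalone illustration (not Airflow code):

```python
import functools
import json


@functools.lru_cache(maxsize=2)
def build_backend_broken(**kwargs):
    return kwargs


@functools.lru_cache(maxsize=2)
def build_backend_fixed(raw_kwargs: str):
    return json.loads(raw_kwargs)


raw = '{"connections_prefix": "airflow/connections", "config": {"region": "eu-west-1"}}'

try:
    # The nested dict in the kwargs cannot be hashed into a cache key.
    build_backend_broken(**json.loads(raw))
except TypeError as err:
    print("broken:", err)  # unhashable type: 'dict'

# The raw string is hashable, so caching works and parsing happens inside.
print("fixed:", build_backend_fixed(raw))
```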

Fix response schema for list-mapped-task-instance (apache#25965)

update areActiveRuns, fix states (apache#25962)
pankajastro authored and anja-istenic committed Aug 29, 2022
1 parent dffa7be commit 5d5fe14
Showing 35 changed files with 1,534 additions and 698 deletions.
2 changes: 1 addition & 1 deletion airflow/api_connexion/openapi/v1.yaml
@@ -1210,7 +1210,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/TaskInstance'
$ref: '#/components/schemas/TaskInstanceCollection'
'401':
$ref: '#/components/responses/Unauthenticated'
'403':
12 changes: 12 additions & 0 deletions airflow/config_templates/config.yml
@@ -386,6 +386,18 @@
type: string
default: "0o077"
example: ~
- name: dataset_event_manager_class
description: Class to use as dataset event manager.
version_added: 2.4.0
type: string
default: ~
example: 'airflow.datasets.manager.DatasetEventManager'
- name: dataset_event_manager_kwargs
description: Kwargs to supply to dataset event manager.
version_added: 2.4.0
type: string
default: ~
example: '{"some_param": "some_value"}'

- name: database
description: ~
8 changes: 8 additions & 0 deletions airflow/config_templates/default_airflow.cfg
@@ -220,6 +220,14 @@ max_map_length = 1024
# This value is treated as an octal-integer.
daemon_umask = 0o077

# Class to use as dataset event manager.
# Example: dataset_event_manager_class = airflow.datasets.manager.DatasetEventManager
# dataset_event_manager_class =

# Kwargs to supply to dataset event manager.
# Example: dataset_event_manager_kwargs = {{"some_param": "some_value"}}
# dataset_event_manager_kwargs =

[database]
# The SqlAlchemy connection string to the metadata database.
# SqlAlchemy supports many different database engines.
15 changes: 7 additions & 8 deletions airflow/configuration.py
@@ -1545,19 +1545,18 @@ def get_custom_secret_backend() -> Optional[BaseSecretsBackend]:
"""Get Secret Backend if defined in airflow.cfg"""
secrets_backend_cls = conf.getimport(section='secrets', key='backend')
if secrets_backend_cls:
try:
backends: Any = conf.get(section='secrets', key='backend_kwargs', fallback='{}')
alternative_secrets_config_dict = json.loads(backends)
except JSONDecodeError:
alternative_secrets_config_dict = {}

return _custom_secrets_backend(secrets_backend_cls, **alternative_secrets_config_dict)
backends: Any = conf.get(section='secrets', key='backend_kwargs', fallback='{}')
return _custom_secrets_backend(secrets_backend_cls, backends)
return None


@functools.lru_cache(maxsize=2)
def _custom_secrets_backend(secrets_backend_cls, **alternative_secrets_config_dict):
def _custom_secrets_backend(secrets_backend_cls, backend_kwargs):
"""Separate function to create secrets backend instance to allow caching"""
try:
alternative_secrets_config_dict = json.loads(backend_kwargs)
except JSONDecodeError:
alternative_secrets_config_dict = {}
return secrets_backend_cls(**alternative_secrets_config_dict)


29 changes: 27 additions & 2 deletions airflow/datasets/manager.py
@@ -15,13 +15,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING

from sqlalchemy.orm.session import Session

from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.taskinstance import TaskInstance
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance


class DatasetEventManager(LoggingMixin):
"""
@@ -31,8 +36,11 @@ class DatasetEventManager(LoggingMixin):
Airflow deployments can use plugins that broadcast dataset events to each other.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

def register_dataset_change(
self, *, task_instance: TaskInstance, dataset: Dataset, extra=None, session: Session, **kwargs
self, *, task_instance: "TaskInstance", dataset: Dataset, extra=None, session: Session, **kwargs
) -> None:
"""
For local datasets, look them up, record the dataset event, queue dagruns, and broadcast
@@ -59,3 +67,20 @@ def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
self.log.debug("consuming dag ids %s", consuming_dag_ids)
for dag_id in consuming_dag_ids:
session.merge(DatasetDagRunQueue(dataset_id=dataset.id, target_dag_id=dag_id))


def resolve_dataset_event_manager():
_dataset_event_manager_class = conf.getimport(
section='core',
key='dataset_event_manager_class',
fallback='airflow.datasets.manager.DatasetEventManager',
)
_dataset_event_manager_kwargs = conf.getjson(
section='core',
key='dataset_event_manager_kwargs',
fallback={},
)
return _dataset_event_manager_class(**_dataset_event_manager_kwargs)


dataset_event_manager = resolve_dataset_event_manager()
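
For context, a deployment could point the new `core` options at a custom manager; the subclass, module path and `broadcast_url` parameter below are purely illustrative, not part of Airflow:

```python
# airflow.cfg (illustrative values):
#   [core]
#   dataset_event_manager_class = my_plugin.manager.BroadcastingDatasetEventManager
#   dataset_event_manager_kwargs = {"broadcast_url": "https://events.example.com"}
from airflow.datasets.manager import DatasetEventManager


class BroadcastingDatasetEventManager(DatasetEventManager):
    """Example subclass that forwards dataset events to an external endpoint."""

    def __init__(self, broadcast_url: str, **kwargs):
        super().__init__(**kwargs)
        self.broadcast_url = broadcast_url

    def register_dataset_change(self, *, task_instance, dataset, extra=None, session, **kwargs):
        super().register_dataset_change(
            task_instance=task_instance, dataset=dataset, extra=extra, session=session, **kwargs
        )
        # e.g. POST the event payload to self.broadcast_url here.
```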
131 changes: 54 additions & 77 deletions airflow/models/dagrun.py
@@ -656,9 +656,6 @@ def _filter_tis_and_exclude_removed(dag: "DAG", tis: List[TI]) -> Iterable[TI]:
yield ti

tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))
missing_indexes = self._revise_mapped_task_indexes(tis, session=session)
if missing_indexes:
self.verify_integrity(missing_indexes=missing_indexes, session=session)

unfinished_tis = [t for t in tis if t.state in State.unfinished]
finished_tis = [t for t in tis if t.state in State.finished]
@@ -730,6 +727,11 @@ def _get_ready_tis(
additional_tis.extend(expanded_tis[1:])
expansion_happened = True
if schedulable.state in SCHEDULEABLE_STATES:
task = schedulable.task
if isinstance(schedulable.task, MappedOperator):
# Ensure the task indexes are complete
created = self._revise_mapped_task_indexes(task, session=session)
ready_tis.extend(created)
ready_tis.append(schedulable)

# Check if any ti changed state
@@ -825,7 +827,6 @@ def _emit_duration_stats_for_finished_state(self):
def verify_integrity(
self,
*,
missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]] = None,
session: Session = NEW_SESSION,
):
"""
@@ -842,15 +843,10 @@ def _check_for_removed_or_restored_tasks(

dag = self.get_dag()
task_ids: Set[str] = set()
if missing_indexes:
tis = self.get_task_instances(session=session)
for ti in tis:
task_instance_mutation_hook(ti)
task_ids.add(ti.task_id)
else:
task_ids, missing_indexes = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)

task_ids = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)

def task_filter(task: "Operator") -> bool:
return task.task_id not in task_ids and (
@@ -865,29 +861,27 @@ def task_filter(task: "Operator") -> bool:
task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

# Create the missing tasks, including mapped tasks
tasks = self._create_missing_tasks(dag, task_creator, task_filter, missing_indexes, session=session)
tasks = self._create_tasks(dag, task_creator, task_filter, session=session)

self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)

def _check_for_removed_or_restored_tasks(
self, dag: "DAG", ti_mutation_hook, *, session: Session
) -> Tuple[Set[str], Dict["MappedOperator", Sequence[int]]]:
) -> Set[str]:
"""
Check for removed tasks/restored/missing tasks.
:param dag: DAG object corresponding to the dagrun
:param ti_mutation_hook: task_instance_mutation_hook function
:param session: Sqlalchemy ORM Session
:return: List of task_ids in the dagrun and missing task indexes
:return: Task IDs in the DAG run
"""
tis = self.get_task_instances(session=session)

# check for removed or restored tasks
task_ids = set()
existing_indexes: Dict["MappedOperator", List[int]] = defaultdict(list)
expected_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
for ti in tis:
ti_mutation_hook(ti)
task_ids.add(ti.task_id)
@@ -925,13 +919,9 @@ def _check_for_removed_or_restored_tasks(
elif ti.map_index < 0:
self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
ti.state = State.REMOVED
else:
self.log.info("Restoring mapped task '%s'", ti)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
existing_indexes[task].append(ti.map_index)
expected_indexes[task] = range(num_mapped_tis)
else:
# What if it is _now_ dynamically mapped, but wasn't before?
task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.run_time_mapped_ti_count(self.run_id, session=session)

if total_length is None:
@@ -950,16 +940,8 @@ def _check_for_removed_or_restored_tasks(
total_length,
)
ti.state = State.REMOVED
else:
self.log.info("Restoring mapped task '%s'", ti)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
existing_indexes[task].append(ti.map_index)
expected_indexes[task] = range(total_length)
# Check if we have some missing indexes to create ti for
missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
for k, v in existing_indexes.items():
missing_indexes.update({k: list(set(expected_indexes[k]).difference(v))})
return task_ids, missing_indexes

return task_ids

def _get_task_creator(
self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
@@ -995,12 +977,11 @@ def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
creator = create_ti
return creator

def _create_missing_tasks(
def _create_tasks(
self,
dag: "DAG",
task_creator: Callable,
task_filter: Callable,
missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]],
*,
session: Session,
) -> Iterable["Operator"]:
@@ -1031,12 +1012,7 @@ def expand_mapped_literals(
tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))

tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
if missing_indexes:
# If there are missing indexes, override the tasks to create
new_tasks_and_map_idxs = itertools.starmap(
expand_mapped_literals, [(k, v) for k, v in missing_indexes.items() if len(v) > 0]
)
tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, new_tasks_and_map_idxs))

return tasks

def _create_task_instances(
@@ -1082,44 +1058,45 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

def _revise_mapped_task_indexes(
self,
tis: Iterable[TI],
*,
session: Session,
) -> Dict["MappedOperator", Sequence[int]]:
"""Check if the length of the mapped task instances changed at runtime and find the missing indexes.
def _revise_mapped_task_indexes(self, task, session: Session):
"""Check if task increased or reduced in length and handle appropriately"""
from airflow.models.taskinstance import TaskInstance
from airflow.settings import task_instance_mutation_hook

:param tis: Task instances to check
:param session: The session to use
"""
from airflow.models.mappedoperator import MappedOperator
task.run_time_mapped_ti_count.cache_clear()
total_length = (
task.parse_time_mapped_ti_count
or task.run_time_mapped_ti_count(self.run_id, session=session)
or 0
)
query = session.query(TaskInstance.map_index).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == task.task_id,
TaskInstance.run_id == self.run_id,
)
existing_indexes = {i for (i,) in query}
missing_indexes = set(range(total_length)).difference(existing_indexes)
removed_indexes = existing_indexes.difference(range(total_length))
created_tis = []

existing_indexes: Dict[MappedOperator, List[int]] = defaultdict(list)
new_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
for ti in tis:
task = ti.task
if not isinstance(task, MappedOperator):
continue
# skip unexpanded tasks and also tasks that expands with literal arguments
if ti.map_index < 0 or task.parse_time_mapped_ti_count:
continue
existing_indexes[task].append(ti.map_index)
task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0

if ti.map_index >= new_length:
self.log.debug(
"Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
ti,
new_length,
)
ti.state = State.REMOVED
new_indexes[task] = range(new_length)
missing_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
for k, v in existing_indexes.items():
missing_indexes.update({k: list(set(new_indexes[k]).difference(v))})
return missing_indexes
if missing_indexes:
for index in missing_indexes:
ti = TaskInstance(task, run_id=self.run_id, map_index=index, state=None)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
ti.refresh_from_task(task)
session.flush()
created_tis.append(ti)
elif removed_indexes:
session.query(TaskInstance).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == task.task_id,
TaskInstance.run_id == self.run_id,
TaskInstance.map_index.in_(removed_indexes),
).update({TaskInstance.state: TaskInstanceState.REMOVED})
session.flush()
return created_tis

@staticmethod
def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
7 changes: 2 additions & 5 deletions airflow/models/taskinstance.py
@@ -81,6 +81,7 @@
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.datasets.manager import dataset_event_manager
from airflow.exceptions import (
AirflowException,
AirflowFailException,
@@ -585,10 +586,6 @@ def __init__(
# can be changed when calling 'run'
self.test_mode = False

self.dataset_event_manager = conf.getimport(
'core', 'dataset_event_manager_class', fallback='airflow.datasets.manager.DatasetEventManager'
)()

@staticmethod
def insert_mapping(run_id: str, task: "Operator", map_index: int) -> dict:
""":meta private:"""
@@ -1538,7 +1535,7 @@ def _register_dataset_changes(self, *, session: Session) -> None:
self.log.debug("outlet obj %s", obj)
# Lineage can have other types of objects besides datasets
if isinstance(obj, Dataset):
self.dataset_event_manager.register_dataset_change(
dataset_event_manager.register_dataset_change(
task_instance=self,
dataset=obj,
session=session,