Implement Metadata to emit runtime extra #38650

Merged: 12 commits, Apr 8, 2024
24 changes: 20 additions & 4 deletions airflow/datasets/__init__.py
@@ -31,6 +31,10 @@


def normalize_noop(parts: SplitResult) -> SplitResult:
"""Place-hold a :class:`~urllib.parse.SplitResult`` normalizer.

:meta private:
"""
return parts


@@ -42,13 +46,11 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | N
return ProvidersManager().dataset_uri_handlers.get(scheme)


def sanitize_uri(uri: str) -> str:
def _sanitize_uri(uri: str) -> str:
"""Sanitize a dataset URI.

This checks for URI validity, and normalizes the URI if needed. A fully
normalized URI is returned.

:meta private:
"""
if not uri:
raise ValueError("Dataset URI cannot be empty")
@@ -89,6 +91,20 @@ def sanitize_uri(uri: str) -> str:
return urllib.parse.urlunsplit(parsed)


def coerce_to_uri(value: str | Dataset) -> str:
"""Coerce a user input into a sanitized URI.

If the input value is a string, it is treated as a URI and sanitized. If the
input is a :class:`Dataset`, the URI it contains is considered sanitized and
returned directly.

:meta private:
"""
if isinstance(value, Dataset):
return value.uri
return _sanitize_uri(str(value))


class BaseDatasetEventInput:
"""Protocol for all dataset triggers to use in ``DAG(schedule=...)``.

@@ -117,7 +133,7 @@ class Dataset(os.PathLike, BaseDatasetEventInput):
"""A representation of data dependencies between workflows."""

uri: str = attr.field(
converter=sanitize_uri,
converter=_sanitize_uri,
validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
)
extra: dict[str, Any] | None = None
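For reference, a minimal sketch of how the renamed helpers behave, based only on the diff above (the URI is illustrative): a string passed to coerce_to_uri is run through _sanitize_uri, while a Dataset is assumed to already carry a sanitized URI.

from airflow.datasets import Dataset, coerce_to_uri

# A plain string is treated as a URI and sanitized first.
uri = coerce_to_uri("s3://bucket/my-data.csv")
# A Dataset's uri is returned directly; the attrs converter already sanitized it.
same_uri = coerce_to_uri(Dataset("s3://bucket/my-data.csv"))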
39 changes: 39 additions & 0 deletions airflow/datasets/metadata.py
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import attrs

from airflow.datasets import coerce_to_uri

if TYPE_CHECKING:
from airflow.datasets import Dataset


@attrs.define(init=False)
class Metadata:
"""Metadata to attach to a DatasetEvent."""

uri: str
extra: dict[str, Any]

def __init__(self, target: str | Dataset, extra: dict[str, Any]) -> None:
self.uri = coerce_to_uri(target)
self.extra = extra
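The Metadata class above is the user-facing piece of this change: a task callable can yield it at runtime to attach an "extra" payload to the DatasetEvent emitted for one of its outlets. A minimal sketch of the intended usage, with an illustrative dataset URI and payload (collection of the yielded values is handled by ExecutionCallableRunner, wired up in the operator changes below):

from airflow.datasets import Dataset
from airflow.datasets.metadata import Metadata
from airflow.decorators import task

example_dataset = Dataset("s3://bucket/my-data.csv")

@task(outlets=[example_dataset])
def write_data():
    # ... produce the data, then attach runtime extra to the dataset event ...
    yield Metadata(example_dataset, {"row_count": 42})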
25 changes: 18 additions & 7 deletions airflow/models/baseoperator.py
@@ -28,12 +28,12 @@
import contextlib
import copy
import functools
import inspect
import logging
import sys
import warnings
from datetime import datetime, timedelta
from functools import total_ordering, wraps
from inspect import signature
from types import FunctionType
from typing import (
TYPE_CHECKING,
@@ -91,10 +91,11 @@
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.context import Context, context_get_dataset_events
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.helpers import validate_key
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
@@ -423,7 +424,7 @@ def _apply_defaults(cls, func: T) -> T:
# at every decorated invocation. This is separate sig_cache created
# per decoration, i.e. each function decorated using apply_defaults will
# have a different sig_cache.
sig_cache = signature(func)
sig_cache = inspect.signature(func)
non_variadic_params = {
name: param
for (name, param) in sig_cache.parameters.items()
@@ -1269,8 +1270,13 @@ def set_xcomargs_dependencies(self) -> None:
@prepare_lineage
def pre_execute(self, context: Any):
"""Execute right before self.execute() is called."""
if self._pre_execute_hook is not None:
self._pre_execute_hook(context)
if self._pre_execute_hook is None:
return
ExecutionCallableRunner(
self._pre_execute_hook,
context_get_dataset_events(context),
logger=self.log,
).run(context)

def execute(self, context: Context) -> Any:
"""
@@ -1289,8 +1295,13 @@ def post_execute(self, context: Any, result: Any = None):

It is passed the execution context and any results returned by the operator.
"""
if self._post_execute_hook is not None:
self._post_execute_hook(context, result)
if self._post_execute_hook is None:
return
ExecutionCallableRunner(
self._post_execute_hook,
context_get_dataset_events(context),
logger=self.log,
).run(context, result)

def on_kill(self) -> None:
"""
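ExecutionCallableRunner itself lives in airflow/utils/operator_helpers.py and is not part of this diff. Based only on how it is invoked here (the callable, the context's dataset-event accessors, a logger, then .run(*args, **kwargs)), a hedged sketch of its plausible behavior follows; names and details are assumptions, not the actual implementation.

import inspect

class ExecutionCallableRunnerSketch:
    """Hedged sketch: run a callable and fold yielded Metadata into dataset events."""

    def __init__(self, func, dataset_events, *, logger):
        self.func = func
        self.dataset_events = dataset_events  # DatasetEventAccessors from the task context
        self.logger = logger

    def run(self, *args, **kwargs):
        from airflow.datasets.metadata import Metadata

        # Plain callables run unchanged; only generator callables can yield Metadata.
        if not inspect.isgeneratorfunction(self.func):
            return self.func(*args, **kwargs)
        gen = self.func(*args, **kwargs)
        while True:
            try:
                yielded = next(gen)
            except StopIteration as stop:
                return stop.value  # the callable's return value, if any
            if isinstance(yielded, Metadata):
                # Assumes DatasetEventAccessor exposes an `extra` dict.
                self.dataset_events[yielded.uri].extra.update(yielded.extra)
            else:
                self.logger.warning("Ignoring unrecognized yielded value: %r", yielded)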
40 changes: 23 additions & 17 deletions airflow/models/taskinstance.py
@@ -111,14 +111,15 @@
Context,
DatasetEventAccessors,
VariableAccessor,
context_get_dataset_events,
context_merge,
)
from airflow.utils.email import send_email
from airflow.utils.helpers import prune_dict, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import qualname
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
@@ -432,12 +433,16 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
if execute_callable.__name__ == "execute":
execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel

def _execute_callable(context, **execute_callable_kwargs):
def _execute_callable(context: Context, **execute_callable_kwargs):
try:
# Print a marker for log grouping of details before task execution
log.info("::endgroup::")

return execute_callable(context=context, **execute_callable_kwargs)
return ExecutionCallableRunner(
execute_callable,
context_get_dataset_events(context),
logger=log,
).run(context=context, **execute_callable_kwargs)
except SystemExit as e:
# Handle only successful cases here. Failure cases will be handled upper
# in the exception chain.
@@ -2678,6 +2683,10 @@ def signal_handler(signum, frame):
jinja_env = None
task_orig = self.render_templates(context=context, jinja_env=jinja_env)

# The task is never MappedOperator at this point.
if TYPE_CHECKING:
assert isinstance(self.task, BaseOperator)

if not test_mode:
rendered_fields = get_serialized_template_fields(task=self.task)
_update_rtif(ti=self, rendered_fields=rendered_fields)
@@ -2695,8 +2704,7 @@ def signal_handler(signum, frame):
)

# Run pre_execute callback
# Is never MappedOperator at this point
self.task.pre_execute(context=context) # type: ignore[union-attr]
self.task.pre_execute(context=context)

# Run on_execute callback
self._run_execute_callback(context, self.task)
@@ -2711,8 +2719,7 @@ def signal_handler(signum, frame):
result = self._execute_task(context, task_orig)

# Run post_execute callback
# Is never MappedOperator at this point
self.task.post_execute(context=context, result=result) # type: ignore[union-attr]
self.task.post_execute(context=context, result=result)

# DAG authors define map_index_template at the task level
if jinja_env is not None and (template := context.get("map_index_template")) is not None:
@@ -2724,7 +2731,7 @@ def signal_handler(signum, frame):
Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
Stats.incr("ti_successes", tags=self.stats_tags)

def _execute_task(self, context, task_orig):
def _execute_task(self, context: Context, task_orig: Operator):
"""
Execute Task (optionally with a Timeout) and push Xcom results.

@@ -2775,16 +2782,15 @@ def defer_task(self, session: Session, defer: TaskDeferred) -> None:
else:
self.trigger_timeout = self.start_date + execution_timeout

def _run_execute_callback(self, context: Context, task: Operator) -> None:
def _run_execute_callback(self, context: Context, task: BaseOperator) -> None:
"""Functions that need to be run before a Task is executed."""
callbacks = task.on_execute_callback
if callbacks:
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
for callback in callbacks:
try:
callback(context)
except Exception:
self.log.exception("Failed when executing execute callback")
if not (callbacks := task.on_execute_callback):
return
for callback in callbacks if isinstance(callbacks, list) else [callbacks]:
try:
callback(context)
except Exception:
self.log.exception("Failed when executing execute callback")

@provide_session
def run(
12 changes: 8 additions & 4 deletions airflow/operators/python.py
@@ -52,9 +52,9 @@
from airflow.models.variable import Variable
from airflow.operators.branch import BranchMixIn
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.context import context_copy_partial, context_get_dataset_events, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script

@@ -231,6 +231,7 @@ def __init__(
def execute(self, context: Context) -> Any:
context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
self.op_kwargs = self.determine_kwargs(context)
self._dataset_events = context_get_dataset_events(context)

return_value = self.execute_callable()
if self.show_return_value_in_logs:
@@ -249,7 +250,8 @@ def execute_callable(self) -> Any:

:return: the return value of the call.
"""
return self.python_callable(*self.op_args, **self.op_kwargs)
runner = ExecutionCallableRunner(self.python_callable, self._dataset_events, logger=self.log)
return runner.run(*self.op_args, **self.op_kwargs)


class BranchPythonOperator(PythonOperator, BranchMixIn):
@@ -406,7 +408,7 @@ def __init__(
or isinstance(python_callable, types.LambdaType)
and python_callable.__name__ == "<lambda>"
):
raise AirflowException("PythonVirtualenvOperator only supports functions for python_callable arg")
raise ValueError(f"{type(self).__name__} only supports functions for python_callable arg")
if inspect.isgeneratorfunction(python_callable):
raise ValueError(f"{type(self).__name__} does not support using 'yield' in python_callable")
super().__init__(
python_callable=python_callable,
op_args=op_args,
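The last hunk above tightens PythonVirtualenvOperator validation: generator callables are now rejected with a ValueError, presumably because a callable that runs in a separate virtualenv process cannot have its yielded Metadata collected by the runner in the parent process (that rationale is an inference, not stated in the diff). An illustrative example, with a hypothetical task id and dataset URI:

from airflow.datasets import Dataset
from airflow.datasets.metadata import Metadata
from airflow.operators.python import PythonVirtualenvOperator

def produce():
    # A generator callable: fine for PythonOperator, now rejected here.
    yield Metadata(Dataset("s3://bucket/my-data.csv"), {"row_count": 42})

PythonVirtualenvOperator(task_id="produce", python_callable=produce)
# ValueError: PythonVirtualenvOperator does not support using 'yield' in python_callable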
17 changes: 9 additions & 8 deletions airflow/utils/context.py
@@ -39,7 +39,7 @@
import attrs
import lazy_object_proxy

from airflow.datasets import Dataset, sanitize_uri
from airflow.datasets import Dataset, coerce_to_uri
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.utils.types import NOTSET

@@ -169,13 +169,7 @@ def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor:
if isinstance(key, str):
uri = sanitize_uri(key)
elif isinstance(key, Dataset):
uri = key.uri
else:
return NotImplemented
if uri not in self._dict:
if (uri := coerce_to_uri(key)) not in self._dict:
self._dict[uri] = DatasetEventAccessor({})
return self._dict[uri]

@@ -361,3 +355,10 @@ def _create_value(k: str, v: Any) -> Any:
return lazy_object_proxy.Proxy(factory)

return {k: _create_value(k, v) for k, v in source._context.items()}


def context_get_dataset_events(context: Context) -> DatasetEventAccessors:
try:
return context["dataset_events"]
except KeyError:
return DatasetEventAccessors()
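context_get_dataset_events simply looks up the "dataset_events" accessors from the context, falling back to an empty collection. Combined with the DatasetEventAccessors.__getitem__ change above (which now accepts either a str URI or a Dataset via coerce_to_uri), task code could also populate extras imperatively. A hedged sketch, assuming DatasetEventAccessor exposes the extra dict it is constructed with (only its construction with an empty dict is visible in this diff), and using an illustrative URI:

from airflow.utils.context import Context, context_get_dataset_events

def my_pre_execute(context: Context):
    dataset_events = context_get_dataset_events(context)
    # __getitem__ accepts a str URI or a Dataset and creates an accessor on first use.
    dataset_events["s3://bucket/my-data.csv"].extra["row_count"] = 42  # assumes .extra is a dict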
1 change: 1 addition & 0 deletions airflow/utils/context.pyi
@@ -127,3 +127,4 @@ def context_merge(context: Context, **kwargs: Any) -> None: ...
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: ...
def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ...
def context_get_dataset_events(context: Context) -> DatasetEventAccessors: ...