diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py index 45b6686d059dc..1cb1ea4e31696 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -18,27 +18,43 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import attrs from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetRef +from airflow.sdk.definitions.asset import Asset, AssetRef, BaseAsset if TYPE_CHECKING: - from collections.abc import Collection, Iterator, Mapping + from collections.abc import Callable, Collection, Iterator, Mapping from airflow.io.path import ObjectStoragePath - from airflow.models.dag import DagStateChangeCallback, ScheduleArg from airflow.models.param import ParamsDict + from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey + from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, ScheduleArg + from airflow.serialization.dag_dependency import DagDependency from airflow.triggers.base import BaseTrigger + from airflow.typing_compat import Self class _AssetMainOperator(PythonOperator): def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None: super().__init__(**kwargs) self._definition_name = definition_name - self._uri = uri + + @classmethod + def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> Self: + return cls( + task_id="__main__", + inlets=[ + AssetRef(name=inlet_asset_name) + for inlet_asset_name in inspect.signature(definition._function).parameters + if inlet_asset_name not in ("self", "context") + ], + outlets=[v for _, v in definition.iter_assets()], + python_callable=definition._function, + definition_name=definition._function.__name__, + ) def _iter_kwargs( self, context: Mapping[str, Any], active_assets: dict[str, Asset] @@ -81,41 +97,53 @@ class AssetDefinition(Asset): _source: asset def __attrs_post_init__(self) -> None: - from airflow.models.dag import DAG - - with DAG( - dag_id=self.name, - schedule=self._source.schedule, - is_paused_upon_creation=self._source.is_paused_upon_creation, - dag_display_name=self._source.display_name or self.name, - description=self._source.description, - params=self._source.params, - on_success_callback=self._source.on_success_callback, - on_failure_callback=self._source.on_failure_callback, - auto_register=True, - ): - _AssetMainOperator( - task_id="__main__", - inlets=[ - AssetRef(name=inlet_asset_name) - for inlet_asset_name in inspect.signature(self._function).parameters - if inlet_asset_name not in ("self", "context") - ], - outlets=[self], - python_callable=self._function, - definition_name=self.name, - uri=self.uri, - ) + with self._source.create_dag(dag_id=self.name): + _AssetMainOperator.from_definition(self) @attrs.define(kw_only=True) -class asset: - """Create an asset by decorating a materialization function.""" +class MultiAssetDefinition(BaseAsset): + """ + Representation from decorating a function with ``@asset.multi``. + + This is implemented as an "asset-like" object that can be used in all places + that accept asset-ish things (e.g. normal assets, aliases, AssetAll, + AssetAny). + + :meta private: + """ + + _function: Callable + _source: asset.multi + + def __attrs_post_init__(self) -> None: + with self._source.create_dag(dag_id=self._function.__name__): + _AssetMainOperator.from_definition(self) + + def evaluate(self, statuses: dict[str, bool]) -> bool: + return all(o.evaluate(statuses=statuses) for o in self._source.outlets) + + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: + for o in self._source.outlets: + yield from o.iter_assets() + + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: + for o in self._source.outlets: + yield from o.iter_asset_aliases() + + def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: + for obj in self._source.outlets: + yield from obj.iter_dag_dependencies(source=source, target=target) - uri: str | ObjectStoragePath | None = None - group: str = Asset.asset_type - extra: dict[str, Any] = attrs.field(factory=dict) - watchers: list[BaseTrigger] = attrs.field(factory=list) + +@attrs.define(kw_only=True) +class _DAGFactory: + """ + Common class for things that take DAG-like arguments. + + This exists so we don't need to define these arguments separately for + ``@asset`` and ``@asset.multi``. + """ schedule: ScheduleArg is_paused_upon_creation: bool | None = None @@ -130,6 +158,46 @@ class asset: access_control: dict[str, dict[str, Collection[str]]] | None = None owner_links: dict[str, str] | None = None + def create_dag(self, *, dag_id: str) -> DAG: + from airflow.models.dag import DAG # TODO: Use the SDK DAG when it works. + + return DAG( + dag_id=dag_id, + schedule=self.schedule, + is_paused_upon_creation=self.is_paused_upon_creation, + dag_display_name=self.display_name or dag_id, + description=self.description, + params=self.params, + on_success_callback=self.on_success_callback, + on_failure_callback=self.on_failure_callback, + auto_register=True, + ) + + +@attrs.define(kw_only=True) +class asset(_DAGFactory): + """Create an asset by decorating a materialization function.""" + + uri: str | ObjectStoragePath | None = None + group: str = Asset.asset_type + extra: dict[str, Any] = attrs.field(factory=dict) + watchers: list[BaseTrigger] = attrs.field(factory=list) + + @attrs.define(kw_only=True) + class multi(_DAGFactory): + """Create a one-task DAG that emits multiple assets.""" + + outlets: Collection[BaseAsset] # TODO: Support non-asset outlets? + + def __call__(self, f: Callable) -> MultiAssetDefinition: + if self.schedule is not None: + raise NotImplementedError("asset scheduling not implemented yet") + if f.__name__ != f.__qualname__: + raise ValueError("nested function not supported") + if not self.outlets: + raise ValueError("no outlets provided") + return MultiAssetDefinition(function=f, source=self) + def __call__(self, f: Callable) -> AssetDefinition: if self.schedule is not None: raise NotImplementedError("asset scheduling not implemented yet") diff --git a/task_sdk/tests/defintions/test_asset_decorators.py b/task_sdk/tests/defintions/test_asset_decorators.py index aeaa8632901c4..2e714237b4193 100644 --- a/task_sdk/tests/defintions/test_asset_decorators.py +++ b/task_sdk/tests/defintions/test_asset_decorators.py @@ -17,7 +17,6 @@ from __future__ import annotations from unittest import mock -from unittest.mock import ANY import pytest @@ -108,12 +107,22 @@ def test_with_invalid_asset_name(self, example_asset_func): assert err.value.args[0].startswith("prohibited name for asset: ") +class TestAssetMultiDecorator: + def test_multi_asset(self, example_asset_func): + definition = asset.multi( + schedule=None, + outlets=[Asset(name="a"), Asset(name="b")], + )(example_asset_func) + + assert definition._function == example_asset_func + assert definition._source.schedule is None + assert definition._source.outlets == [Asset(name="a"), Asset(name="b")] + + class TestAssetDefinition: - @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator") + @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition") @mock.patch("airflow.models.dag.DAG") - def test__attrs_post_init__( - self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset - ): + def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_valid_arg_as_inlet_asset): asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})( example_asset_func_with_valid_arg_as_inlet_asset ) @@ -129,23 +138,56 @@ def test__attrs_post_init__( params=None, auto_register=True, ) - _AssetMainOperator.assert_called_once_with( - task_id="__main__", - inlets=[ - AssetRef(name="inlet_asset_1"), - AssetRef(name="inlet_asset_2"), - ], - outlets=[asset_definition], - python_callable=ANY, - definition_name="example_asset_func", - uri="s3://bucket/object", - ) + from_definition.assert_called_once_with(asset_definition) - python_callable = _AssetMainOperator.call_args.kwargs["python_callable"] - assert python_callable == example_asset_func_with_valid_arg_as_inlet_asset + +class TestMultiAssetDefinition: + @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition") + @mock.patch("airflow.models.dag.DAG") + def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_valid_arg_as_inlet_asset): + definition = asset.multi( + schedule=None, + outlets=[Asset(name="a"), Asset(name="b")], + )(example_asset_func_with_valid_arg_as_inlet_asset) + + DAG.assert_called_once_with( + dag_id="example_asset_func", + dag_display_name="example_asset_func", + description=None, + schedule=None, + is_paused_upon_creation=None, + on_failure_callback=None, + on_success_callback=None, + params=None, + auto_register=True, + ) + from_definition.assert_called_once_with(definition) class Test_AssetMainOperator: + def test_from_definition(self, example_asset_func_with_valid_arg_as_inlet_asset): + definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})( + example_asset_func_with_valid_arg_as_inlet_asset + ) + op = _AssetMainOperator.from_definition(definition) + assert op.task_id == "__main__" + assert op.inlets == [AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")] + assert op.outlets == [definition] + assert op.python_callable == example_asset_func_with_valid_arg_as_inlet_asset + assert op._definition_name == "example_asset_func" + + def test_from_definition_multi(self, example_asset_func_with_valid_arg_as_inlet_asset): + definition = asset.multi( + schedule=None, + outlets=[Asset(name="a"), Asset(name="b")], + )(example_asset_func_with_valid_arg_as_inlet_asset) + op = _AssetMainOperator.from_definition(definition) + assert op.task_id == "__main__" + assert op.inlets == [AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")] + assert op.outlets == [Asset(name="a"), Asset(name="b")] + assert op.python_callable == example_asset_func_with_valid_arg_as_inlet_asset + assert op._definition_name == "example_asset_func" + @mock.patch("airflow.models.asset.fetch_active_assets_by_name") @mock.patch("airflow.utils.session.create_session") def test_determine_kwargs(