Skip to content

Commit

Permalink
Implement asset.multi (#44711)
Browse files Browse the repository at this point in the history
This allows a single function to emit multiple assets. In this case, you are
responsible for giving each asset a proper name yourself, but it works.

Also includes a refactoring of the existing decorator mechanism so we
don't need to repeat code (especially arguments).
  • Loading branch information
uranusjr authored Dec 6, 2024
1 parent 3d421f7 commit da3bdbf
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 54 deletions.
140 changes: 104 additions & 36 deletions task_sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,43 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

import attrs

from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetRef
from airflow.sdk.definitions.asset import Asset, AssetRef, BaseAsset

if TYPE_CHECKING:
from collections.abc import Collection, Iterator, Mapping
from collections.abc import Callable, Collection, Iterator, Mapping

from airflow.io.path import ObjectStoragePath
from airflow.models.dag import DagStateChangeCallback, ScheduleArg
from airflow.models.param import ParamsDict
from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey
from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, ScheduleArg
from airflow.serialization.dag_dependency import DagDependency
from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self


class _AssetMainOperator(PythonOperator):
def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
    """
    Set up the operator and remember which asset definition it belongs to.

    :param definition_name: Stored as ``_definition_name``.
    :param uri: Stored as ``_uri``; defaults to ``None``.
    :param kwargs: Forwarded unchanged to the parent operator's constructor.
    """
    super().__init__(**kwargs)
    self._definition_name = definition_name
    self._uri = uri

@classmethod
def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> Self:
    """
    Build the ``__main__`` task for a decorated asset function.

    Every parameter of the decorated function other than ``self`` and
    ``context`` becomes an ``AssetRef`` inlet named after that parameter.
    The outlets are the assets yielded by ``definition.iter_assets()``.

    :param definition: Definition object wrapping the decorated function.
    :return: Operator whose ``python_callable`` is the decorated function.
    """
    # NOTE(review): ``uri`` is not forwarded here, so the operator keeps the
    # constructor default (``None``) -- confirm this is intended.
    return cls(
        task_id="__main__",
        inlets=[
            AssetRef(name=inlet_asset_name)
            for inlet_asset_name in inspect.signature(definition._function).parameters
            if inlet_asset_name not in ("self", "context")
        ],
        outlets=[v for _, v in definition.iter_assets()],
        python_callable=definition._function,
        definition_name=definition._function.__name__,
    )

def _iter_kwargs(
self, context: Mapping[str, Any], active_assets: dict[str, Asset]
Expand Down Expand Up @@ -81,41 +97,53 @@ class AssetDefinition(Asset):
_source: asset

def __attrs_post_init__(self) -> None:
from airflow.models.dag import DAG

with DAG(
dag_id=self.name,
schedule=self._source.schedule,
is_paused_upon_creation=self._source.is_paused_upon_creation,
dag_display_name=self._source.display_name or self.name,
description=self._source.description,
params=self._source.params,
on_success_callback=self._source.on_success_callback,
on_failure_callback=self._source.on_failure_callback,
auto_register=True,
):
_AssetMainOperator(
task_id="__main__",
inlets=[
AssetRef(name=inlet_asset_name)
for inlet_asset_name in inspect.signature(self._function).parameters
if inlet_asset_name not in ("self", "context")
],
outlets=[self],
python_callable=self._function,
definition_name=self.name,
uri=self.uri,
)
with self._source.create_dag(dag_id=self.name):
_AssetMainOperator.from_definition(self)


@attrs.define(kw_only=True)
class asset:
"""Create an asset by decorating a materialization function."""
class MultiAssetDefinition(BaseAsset):
"""
Representation from decorating a function with ``@asset.multi``.
This is implemented as an "asset-like" object that can be used in all places
that accept asset-ish things (e.g. normal assets, aliases, AssetAll,
AssetAny).
:meta private:
"""

_function: Callable
_source: asset.multi

def __attrs_post_init__(self) -> None:
with self._source.create_dag(dag_id=self._function.__name__):
_AssetMainOperator.from_definition(self)

def evaluate(self, statuses: dict[str, bool]) -> bool:
    """
    Tell whether every outlet of the source decorator evaluates truthy.

    :param statuses: Passed as-is to each outlet's ``evaluate``.
    :return: ``True`` if all outlets evaluate truthy (also when there are
        no outlets), ``False`` as soon as one does not.
    """
    return all(outlet.evaluate(statuses=statuses) for outlet in self._source.outlets)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
    """
    Yield everything each outlet's ``iter_assets()`` yields, in outlet order.

    :return: Iterator over the items produced by the source's outlets.
    """
    for outlet in self._source.outlets:
        yield from outlet.iter_assets()

def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
    """
    Yield everything each outlet's ``iter_asset_aliases()`` yields, in outlet order.

    :return: Iterator over the items produced by the source's outlets.
    """
    for outlet in self._source.outlets:
        yield from outlet.iter_asset_aliases()

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
    """
    Yield the DAG dependencies reported by each outlet, in outlet order.

    :param source: Forwarded to each outlet's ``iter_dag_dependencies``.
    :param target: Forwarded to each outlet's ``iter_dag_dependencies``.
    :return: Iterator over the items produced by the source's outlets.
    """
    for outlet in self._source.outlets:
        yield from outlet.iter_dag_dependencies(source=source, target=target)

uri: str | ObjectStoragePath | None = None
group: str = Asset.asset_type
extra: dict[str, Any] = attrs.field(factory=dict)
watchers: list[BaseTrigger] = attrs.field(factory=list)

@attrs.define(kw_only=True)
class _DAGFactory:
"""
Common class for things that take DAG-like arguments.
This exists so we don't need to define these arguments separately for
``@asset`` and ``@asset.multi``.
"""

schedule: ScheduleArg
is_paused_upon_creation: bool | None = None
Expand All @@ -130,6 +158,46 @@ class asset:
access_control: dict[str, dict[str, Collection[str]]] | None = None
owner_links: dict[str, str] | None = None

def create_dag(self, *, dag_id: str) -> DAG:
    """
    Create a DAG from the DAG-like arguments stored on this object.

    :param dag_id: ID of the new DAG; also used as its display name when
        ``display_name`` is falsy.
    :return: The DAG, created with ``auto_register=True``.
    """
    from airflow.models.dag import DAG  # TODO: Use the SDK DAG when it works.

    return DAG(
        dag_id=dag_id,
        schedule=self.schedule,
        is_paused_upon_creation=self.is_paused_upon_creation,
        dag_display_name=self.display_name or dag_id,
        description=self.description,
        params=self.params,
        on_success_callback=self.on_success_callback,
        on_failure_callback=self.on_failure_callback,
        auto_register=True,
    )


@attrs.define(kw_only=True)
class asset(_DAGFactory):
"""Create an asset by decorating a materialization function."""

uri: str | ObjectStoragePath | None = None
group: str = Asset.asset_type
extra: dict[str, Any] = attrs.field(factory=dict)
watchers: list[BaseTrigger] = attrs.field(factory=list)

@attrs.define(kw_only=True)
class multi(_DAGFactory):
    """Create a one-task DAG that emits multiple assets."""

    outlets: Collection[BaseAsset]  # TODO: Support non-asset outlets?

    def __call__(self, f: Callable) -> MultiAssetDefinition:
        """
        Wrap the decorated function in a ``MultiAssetDefinition``.

        :param f: Function being decorated.
        :raises NotImplementedError: If ``schedule`` is not ``None``.
        :raises ValueError: If ``f`` is a nested function (its ``__name__``
            differs from its ``__qualname__``) or ``outlets`` is empty.
        :return: Definition built from ``f`` and this decorator instance.
        """
        if self.schedule is not None:
            raise NotImplementedError("asset scheduling not implemented yet")
        if f.__name__ != f.__qualname__:
            raise ValueError("nested function not supported")
        if not self.outlets:
            raise ValueError("no outlets provided")
        return MultiAssetDefinition(function=f, source=self)

def __call__(self, f: Callable) -> AssetDefinition:
if self.schedule is not None:
raise NotImplementedError("asset scheduling not implemented yet")
Expand Down
78 changes: 60 additions & 18 deletions task_sdk/tests/defintions/test_asset_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import ANY

import pytest

Expand Down Expand Up @@ -108,12 +107,22 @@ def test_with_invalid_asset_name(self, example_asset_func):
assert err.value.args[0].startswith("prohibited name for asset: ")


class TestAssetMultiDecorator:
    def test_multi_asset(self, example_asset_func):
        """The definition keeps the decorated function and the decorator arguments."""
        definition = asset.multi(
            schedule=None,
            outlets=[Asset(name="a"), Asset(name="b")],
        )(example_asset_func)

        assert definition._function == example_asset_func
        assert definition._source.schedule is None
        assert definition._source.outlets == [Asset(name="a"), Asset(name="b")]


class TestAssetDefinition:
@mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator")
@mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
@mock.patch("airflow.models.dag.DAG")
def test__attrs_post_init__(
self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset
):
def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_valid_arg_as_inlet_asset):
asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
example_asset_func_with_valid_arg_as_inlet_asset
)
Expand All @@ -129,23 +138,56 @@ def test__attrs_post_init__(
params=None,
auto_register=True,
)
_AssetMainOperator.assert_called_once_with(
task_id="__main__",
inlets=[
AssetRef(name="inlet_asset_1"),
AssetRef(name="inlet_asset_2"),
],
outlets=[asset_definition],
python_callable=ANY,
definition_name="example_asset_func",
uri="s3://bucket/object",
)
from_definition.assert_called_once_with(asset_definition)

python_callable = _AssetMainOperator.call_args.kwargs["python_callable"]
assert python_callable == example_asset_func_with_valid_arg_as_inlet_asset

class TestMultiAssetDefinition:
    @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
    @mock.patch("airflow.models.dag.DAG")
    def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_valid_arg_as_inlet_asset):
        """Decorating creates one DAG named after the function and one main operator."""
        definition = asset.multi(
            schedule=None,
            outlets=[Asset(name="a"), Asset(name="b")],
        )(example_asset_func_with_valid_arg_as_inlet_asset)

        DAG.assert_called_once_with(
            dag_id="example_asset_func",
            dag_display_name="example_asset_func",
            description=None,
            schedule=None,
            is_paused_upon_creation=None,
            on_failure_callback=None,
            on_success_callback=None,
            params=None,
            auto_register=True,
        )
        from_definition.assert_called_once_with(definition)


class Test_AssetMainOperator:
def test_from_definition(self, example_asset_func_with_valid_arg_as_inlet_asset):
    """For ``@asset``, the operator's only outlet is the definition itself."""
    definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
        example_asset_func_with_valid_arg_as_inlet_asset
    )
    op = _AssetMainOperator.from_definition(definition)
    assert op.task_id == "__main__"
    assert op.inlets == [AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")]
    assert op.outlets == [definition]
    assert op.python_callable == example_asset_func_with_valid_arg_as_inlet_asset
    assert op._definition_name == "example_asset_func"

def test_from_definition_multi(self, example_asset_func_with_valid_arg_as_inlet_asset):
    """For ``@asset.multi``, the operator's outlets are the assets given to the decorator."""
    definition = asset.multi(
        schedule=None,
        outlets=[Asset(name="a"), Asset(name="b")],
    )(example_asset_func_with_valid_arg_as_inlet_asset)
    op = _AssetMainOperator.from_definition(definition)
    assert op.task_id == "__main__"
    assert op.inlets == [AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")]
    assert op.outlets == [Asset(name="a"), Asset(name="b")]
    assert op.python_callable == example_asset_func_with_valid_arg_as_inlet_asset
    assert op._definition_name == "example_asset_func"

@mock.patch("airflow.models.asset.fetch_active_assets_by_name")
@mock.patch("airflow.utils.session.create_session")
def test_determine_kwargs(
Expand Down

0 comments on commit da3bdbf

Please sign in to comment.