Skip to content

Commit

Permalink
feat(task_sdk): make airflow.sdk.definitions.decoratos a package
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Nov 15, 2024
1 parent 757f6a6 commit 324efc0
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 30 deletions.
2 changes: 1 addition & 1 deletion airflow/example_dags/example_asset_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from airflow.decorators import dag, task
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.decorators import asset
from airflow.sdk.definitions.decorators.asset import asset


@asset(uri="s3://bucket/asset1_producer", schedule=None)
Expand Down
42 changes: 42 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/decorators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import sys
from types import FunctionType


class _autostacklevel_warn:
def __init__(self):
self.warnings = __import__("warnings")

def __getattr__(self, name: str):
return getattr(self.warnings, name)

def __dir__(self):
return dir(self.warnings)

def warn(self, message, category=None, stacklevel=1, source=None):
self.warnings.warn(message, category, stacklevel + 2, source)


def fixup_decorator_warning_stack(func: FunctionType):
if func.__globals__.get("warnings") is sys.modules["warnings"]:
# Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to
# `warnings.warn` to ignore the decorator.
func.__globals__["warnings"] = _autostacklevel_warn()
Original file line number Diff line number Diff line change
Expand Up @@ -24,40 +24,15 @@
import attrs

from airflow.models.asset import _fetch_active_assets_by_name
from airflow.models.dag import DAG, ScheduleArg
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetRef
from airflow.sdk.definitions.dag import DAG, ScheduleArg
from airflow.utils.session import create_session

if TYPE_CHECKING:
from airflow.io.path import ObjectStoragePath


import sys
from types import FunctionType


class _autostacklevel_warn:
def __init__(self):
self.warnings = __import__("warnings")

def __getattr__(self, name: str):
return getattr(self.warnings, name)

def __dir__(self):
return dir(self.warnings)

def warn(self, message, category=None, stacklevel=1, source=None):
self.warnings.warn(message, category, stacklevel + 2, source)


def fixup_decorator_warning_stack(func: FunctionType):
if func.__globals__.get("warnings") is sys.modules["warnings"]:
# Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to
# `warnings.warn` to ignore the decorator.
func.__globals__["warnings"] = _autostacklevel_warn()


class _AssetMainOperator(PythonOperator):
def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
super().__init__(**kwargs)
Expand Down
6 changes: 3 additions & 3 deletions task_sdk/tests/defintions/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from airflow.models.asset import AssetActive, AssetModel
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.decorators import AssetRef, _AssetMainOperator, asset
from airflow.sdk.definitions.decorators.asset import AssetRef, _AssetMainOperator, asset

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -119,8 +119,8 @@ def test_serialzie(self, example_asset_definition):
"uri": "s3://bucket/object",
}

@mock.patch("airflow.sdk.definitions.decorators._AssetMainOperator")
@mock.patch("airflow.sdk.definitions.decorators.DAG")
@mock.patch("airflow.sdk.definitions.decorators.asset._AssetMainOperator")
@mock.patch("airflow.sdk.definitions.decorators.asset.DAG")
def test__attrs_post_init__(
self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset
):
Expand Down

0 comments on commit 324efc0

Please sign in to comment.