Skip to content

Commit

Permalink
Add aws_template_field helper
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Oct 17, 2023
1 parent 0a1598f commit 28aef3b
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 18 deletions.
13 changes: 7 additions & 6 deletions airflow/providers/amazon/aws/operators/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from typing import Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, AwsHookParams, AwsHookType
from airflow.providers.amazon.aws.utils.mixins import (
AwsBaseHookMixin,
AwsHookParams,
AwsHookType,
aws_template_fields,
)


class AwsBaseOperator(BaseOperator, AwsBaseHookMixin[AwsHookType]):
Expand Down Expand Up @@ -71,11 +76,7 @@ def execute(self, context):
:meta private:
"""

template_fields: Sequence[str] = (
"aws_conn_id",
"region_name",
"botocore_config",
)
template_fields: Sequence[str] = aws_template_fields()

def __init__(
self,
Expand Down
7 changes: 3 additions & 4 deletions airflow/providers/amazon/aws/operators/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -62,14 +63,13 @@ class LambdaCreateFunctionOperator(AwsBaseOperator[LambdaHook]):
"""

aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"function_name",
"runtime",
"role",
"handler",
"code",
"config",
*AwsBaseOperator.template_fields,
)
ui_color = "#ff7300"

Expand Down Expand Up @@ -175,12 +175,11 @@ class LambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]):
"""

aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"function_name",
"payload",
"qualifier",
"invocation_type",
*AwsBaseOperator.template_fields,
)
ui_color = "#ff7300"

Expand Down
13 changes: 7 additions & 6 deletions airflow/providers/amazon/aws/sensors/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@

from typing import Sequence

from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, AwsHookParams, AwsHookType
from airflow.providers.amazon.aws.utils.mixins import (
AwsBaseHookMixin,
AwsHookParams,
AwsHookType,
aws_template_fields,
)
from airflow.sensors.base import BaseSensorOperator


Expand Down Expand Up @@ -70,11 +75,7 @@ def poke(self, context):
:meta private:
"""

template_fields: Sequence[str] = (
"aws_conn_id",
"region_name",
"botocore_config",
)
template_fields: Sequence[str] = aws_template_fields()

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/sensors/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand All @@ -47,10 +48,9 @@ class LambdaFunctionStateSensor(AwsBaseSensor[LambdaHook]):
FAILURE_STATES = ("Failed",)

aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"function_name",
"qualifier",
*AwsBaseSensor.template_fields,
)

def __init__(
Expand Down
13 changes: 13 additions & 0 deletions airflow/providers/amazon/aws/utils/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from typing_extensions import final

from airflow.compat.functools import cache
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

Expand Down Expand Up @@ -163,3 +164,15 @@ def region(self) -> str | None:
"""Alias for ``region_name``, used for compatibility (deprecated)."""
warnings.warn(REGION_MSG, AirflowProviderDeprecationWarning, stacklevel=3)
return self.region_name


@cache
def aws_template_fields(*template_fields: str) -> tuple[str, ...]:
"""Merge provided template_fields with generic one and return in alphabetical order."""
if not all(isinstance(tf, str) for tf in template_fields):
msg = (
"Expected that all provided arguments are strings, but got "
f"{', '.join(map(repr, template_fields))}."
)
raise TypeError(msg)
return tuple(sorted(list({"aws_conn_id", "region_name", "verify"} | set(template_fields))))
40 changes: 40 additions & 0 deletions tests/providers/amazon/aws/utils/test_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

from airflow.providers.amazon.aws.utils.mixins import aws_template_fields


def test_aws_template_fields():
result = aws_template_fields()
assert result == ("aws_conn_id", "region_name", "verify")
assert aws_template_fields() is result, "aws_template_fields expect to return cached result"

assert aws_template_fields("foo", "aws_conn_id", "bar") == (
"aws_conn_id",
"bar",
"foo",
"region_name",
"verify",
)


def test_aws_template_fields_incorrect_types():
with pytest.raises(TypeError, match="Expected that all provided arguments are strings"):
aws_template_fields(1, None, "test")

0 comments on commit 28aef3b

Please sign in to comment.