Use base classes for AWS Lambda Operators/Sensors #34890

Merged
merged 3 commits into from
Oct 20, 2023
13 changes: 7 additions & 6 deletions airflow/providers/amazon/aws/operators/base_aws.py
@@ -20,7 +20,12 @@
from typing import Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, AwsHookParams, AwsHookType
from airflow.providers.amazon.aws.utils.mixins import (
AwsBaseHookMixin,
AwsHookParams,
AwsHookType,
aws_template_fields,
)


class AwsBaseOperator(BaseOperator, AwsBaseHookMixin[AwsHookType]):
@@ -71,11 +76,7 @@ def execute(self, context):
:meta private:
"""

template_fields: Sequence[str] = (
"aws_conn_id",
"region_name",
"botocore_config",
)
template_fields: Sequence[str] = aws_template_fields()

def __init__(
self,
38 changes: 18 additions & 20 deletions airflow/providers/amazon/aws/operators/lambda_function.py
@@ -19,20 +19,20 @@

import json
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class LambdaCreateFunctionOperator(BaseOperator):
class LambdaCreateFunctionOperator(AwsBaseOperator[LambdaHook]):
"""
Creates an AWS Lambda function.

@@ -62,7 +62,8 @@ class LambdaCreateFunctionOperator(BaseOperator):
:param aws_conn_id: The AWS connection ID to use
"""

template_fields: Sequence[str] = (
aws_hook_class = LambdaHook
template_fields: Sequence[str] = aws_template_fields(
"function_name",
"runtime",
"role",
@@ -82,12 +83,11 @@ def __init__(
code: dict,
description: str | None = None,
timeout: int | None = None,
config: dict = {},
config: dict | None = None,
wait_for_completion: bool = False,
waiter_max_attempts: int = 60,
waiter_delay: int = 15,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -98,16 +98,11 @@ def __init__(
self.code = code
self.description = description
self.timeout = timeout
self.config = config
self.config = config or {}
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)

def execute(self, context: Context):
self.log.info("Creating AWS Lambda function: %s", self.function_name)
@@ -131,6 +126,9 @@ def execute(self, context: Context):
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
@@ -152,7 +150,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
return event["function_arn"]


class LambdaInvokeFunctionOperator(BaseOperator):
class LambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]):
Contributor

I'll admit this format for the inheritance is new to me, and I am still reading up on how it does what it does, but it appears that the value inside the square brackets here is always the same as the value being assigned to aws_hook_class within the Operator. I don't suppose there is a way to simplify that so you don't need both?

Contributor Author

The square brackets are just typing.Generic, which helps annotate the type of the hook property. The whole chain is:

  1. Define TypeVar

AwsHookType = TypeVar("AwsHookType", bound=AwsGenericHook)

  2. Define Generic in class definition

class AwsBaseHookMixin(Generic[AwsHookType]):
"""Mixin class for AWS Operators, Sensors, etc.

  3. Use defined typevar as annotation

# Should be assigned in child class
aws_hook_class: type[AwsHookType]

@cached_property
@final
def hook(self) -> AwsHookType:
"""

In general, class LambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]): helps IDEs and static checkers understand that hook returns a LambdaHook object (in our case we use it in only one place), and it additionally validates that aws_hook_class is a subclass of LambdaHook.

However, information about generics isn't available at runtime and we can't extract it from the class definition, so we have to assign the actual class to aws_hook_class in order to construct the hook in the hook property.
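To make the chain above concrete, here is a minimal, self-contained sketch of the same pattern. It is not the provider's code: `BaseHook`, `LambdaLikeHook`, `HookMixin` and `InvokeOperator` are hypothetical stand-ins for `AwsGenericHook`, `LambdaHook`, `AwsBaseHookMixin` and the operator, and the real mixin passes more parameters to the hook than shown here.

```python
from __future__ import annotations

from functools import cached_property
from typing import Generic, TypeVar


class BaseHook:
    """Hypothetical stand-in for AwsGenericHook."""

    def __init__(self, aws_conn_id: str | None = "aws_default") -> None:
        self.aws_conn_id = aws_conn_id


class LambdaLikeHook(BaseHook):
    """Hypothetical stand-in for LambdaHook."""

    def invoke(self) -> str:
        return "invoked"


# 1. Define the TypeVar, bound to the base hook class.
HookType = TypeVar("HookType", bound=BaseHook)


# 2. Make the mixin generic over that TypeVar.
class HookMixin(Generic[HookType]):
    # 3. Use the TypeVar in annotations. The attribute itself must be
    #    assigned in each subclass, because this sketch never reads the
    #    type parameter at runtime.
    aws_hook_class: type[HookType]
    aws_conn_id: str | None = "aws_default"

    @cached_property
    def hook(self) -> HookType:
        # Built from the class attribute, not from the square brackets.
        return self.aws_hook_class(aws_conn_id=self.aws_conn_id)


class InvokeOperator(HookMixin[LambdaLikeHook]):
    # The brackets tell type checkers what `hook` returns;
    # this assignment tells the runtime which class to instantiate.
    aws_hook_class = LambdaLikeHook


op = InvokeOperator()
print(type(op.hook).__name__)
```

With this layout a static checker should infer `op.hook` as `LambdaLikeHook`, so `op.hook.invoke()` type-checks, while the runtime only ever consults `aws_hook_class`. That is why both the bracketed type and the attribute appear in each operator.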

"""
Invokes an AWS Lambda function.

@@ -176,7 +174,13 @@ class LambdaInvokeFunctionOperator(BaseOperator):
:param aws_conn_id: The AWS connection ID to use
"""

template_fields: Sequence[str] = ("function_name", "payload", "qualifier", "invocation_type")
aws_hook_class = LambdaHook
template_fields: Sequence[str] = aws_template_fields(
"function_name",
"payload",
"qualifier",
"invocation_type",
)
ui_color = "#ff7300"

def __init__(
@@ -189,7 +193,6 @@ def __init__(
invocation_type: str | None = None,
client_context: str | None = None,
payload: bytes | str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -200,11 +203,6 @@ def __init__(
self.qualifier = qualifier
self.invocation_type = invocation_type
self.client_context = client_context
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)

def execute(self, context: Context):
"""
13 changes: 7 additions & 6 deletions airflow/providers/amazon/aws/sensors/base_aws.py
@@ -19,7 +19,12 @@

from typing import Sequence

from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, AwsHookParams, AwsHookType
from airflow.providers.amazon.aws.utils.mixins import (
AwsBaseHookMixin,
AwsHookParams,
AwsHookType,
aws_template_fields,
)
from airflow.sensors.base import BaseSensorOperator


@@ -70,11 +75,7 @@ def poke(self, context):
:meta private:
"""

template_fields: Sequence[str] = (
"aws_conn_id",
"region_name",
"botocore_config",
)
template_fields: Sequence[str] = aws_template_fields()

def __init__(
self,
18 changes: 6 additions & 12 deletions airflow/providers/amazon/aws/sensors/lambda_function.py
@@ -17,20 +17,19 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator


class LambdaFunctionStateSensor(BaseSensorOperator):
class LambdaFunctionStateSensor(AwsBaseSensor[LambdaHook]):
"""
Poll the deployment state of the AWS Lambda function until it reaches a target state.

@@ -48,7 +47,8 @@ class LambdaFunctionStateSensor(BaseSensorOperator):

FAILURE_STATES = ("Failed",)

template_fields: Sequence[str] = (
aws_hook_class = LambdaHook
template_fields: Sequence[str] = aws_template_fields(
"function_name",
"qualifier",
)
@@ -59,11 +59,9 @@ def __init__(
function_name: str,
qualifier: str | None = None,
target_states: list = ["Active"],
aws_conn_id: str = "aws_default",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.function_name = function_name
self.qualifier = qualifier
self.target_states = target_states
@@ -83,7 +81,3 @@ def poke(self, context: Context) -> bool:
raise AirflowException(message)

return state in self.target_states

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/triggers/lambda_function.py
@@ -44,6 +44,7 @@ def __init__(
waiter_delay: int = 60,
waiter_max_attempts: int = 30,
aws_conn_id: str | None = None,
**kwargs,
) -> None:

super().__init__(
@@ -62,7 +63,13 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)
return LambdaHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
13 changes: 13 additions & 0 deletions airflow/providers/amazon/aws/utils/mixins.py
@@ -33,6 +33,7 @@

from typing_extensions import final

from airflow.compat.functools import cache
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

@@ -163,3 +164,15 @@ def region(self) -> str | None:
"""Alias for ``region_name``, used for compatibility (deprecated)."""
warnings.warn(REGION_MSG, AirflowProviderDeprecationWarning, stacklevel=3)
return self.region_name


@cache
def aws_template_fields(*template_fields: str) -> tuple[str, ...]:
"""Merge provided template_fields with generic one and return in alphabetical order."""
if not all(isinstance(tf, str) for tf in template_fields):
msg = (
"Expected that all provided arguments are strings, but got "
f"{', '.join(map(repr, template_fields))}."
)
raise TypeError(msg)
return tuple(sorted(list({"aws_conn_id", "region_name", "verify"} | set(template_fields))))
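As a quick illustration of what the helper above produces, here is a standalone sketch that mirrors its logic. It uses `functools.lru_cache` in place of `airflow.compat.functools.cache` so that it runs without Airflow installed; the name `merge_template_fields` is hypothetical and chosen to avoid implying that this is the provider's function.

```python
from __future__ import annotations

from functools import lru_cache

# Fields that the sketch always includes, copied from the diff above.
GENERIC_FIELDS = {"aws_conn_id", "region_name", "verify"}


@lru_cache(maxsize=None)
def merge_template_fields(*template_fields: str) -> tuple[str, ...]:
    """Merge operator-specific fields with the generic ones, sorted alphabetically."""
    if not all(isinstance(tf, str) for tf in template_fields):
        raise TypeError(
            "Expected that all provided arguments are strings, but got "
            f"{', '.join(map(repr, template_fields))}."
        )
    # The set union removes duplicates; sorted() makes the order deterministic.
    return tuple(sorted(GENERIC_FIELDS | set(template_fields)))


print(merge_template_fields("function_name", "qualifier"))
# → ('aws_conn_id', 'function_name', 'qualifier', 'region_name', 'verify')
print(merge_template_fields())
# → ('aws_conn_id', 'region_name', 'verify')
```

Because the result is a sorted tuple built from a set, passing a field that is already generic (for example `"verify"`) does not duplicate it, and the order of arguments does not affect the output.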
@@ -0,0 +1,68 @@
.. Licensed to the Apache Software Foundation (ASF) under one
Contributor

Nice addition

or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

.. http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.


aws_conn_id
Reference to :ref:`Amazon Web Services Connection <howto/connection:aws>` ID.
If this parameter is set to ``None`` then the default boto3 behaviour is used without a connection lookup.
Otherwise use the credentials stored in the Connection. Default: ``aws_default``

region_name
AWS Region Name. If this parameter is set to ``None`` or omitted then **region_name** from
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will be used.
Otherwise use the specified value instead of the connection value. Default: ``None``

verify
Whether or not to verify SSL certificates.

* ``False`` - Do not validate SSL certificates.
* **path/to/cert/bundle.pem** - A filename of the CA cert bundle to use. You can specify this argument
if you want to use a different CA cert bundle than the one used by botocore.

If this parameter is set to ``None`` or is omitted then **verify** from
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will be used.
Otherwise use the specified value instead of the connection value. Default: ``None``

botocore_config
The provided dictionary is used to construct a `botocore.config.Config`_.
This configuration can be used to configure :ref:`howto/connection:aws:avoid-throttling-exceptions`, timeouts, etc.

.. code-block:: python
:caption: Example, for more detail about parameters please have a look `botocore.config.Config`_

{
"signature_version": "unsigned",
"s3": {
"us_east_1_regional_endpoint": True,
},
"retries": {
"mode": "standard",
"max_attempts": 10,
},
"connect_timeout": 300,
"read_timeout": 300,
"tcp_keepalive": True,
}

If this parameter is set to ``None`` or omitted then **config_kwargs** from
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will be used.
Otherwise use the specified value instead of the connection value. Default: ``None``

.. note::
Specifying an empty dictionary, ``{}``, will overwrite the connection configuration for `botocore.config.Config`_

.. _botocore.config.Config: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
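The precedence rule described above for ``region_name``, ``verify`` and ``botocore_config`` (a value of ``None`` falls back to the connection, anything else overrides it) can be sketched as follows. This is only an illustration of the documented behaviour, not the provider's implementation; `resolve_setting` is a hypothetical helper.

```python
from __future__ import annotations

from typing import Any


def resolve_setting(operator_value: Any, connection_value: Any) -> Any:
    """Hypothetical helper: prefer the operator-level value unless it is None."""
    # Only None triggers the fallback. Falsy values such as {} or False
    # are treated as explicit overrides, matching the note above.
    return connection_value if operator_value is None else operator_value


connection_config = {"retries": {"mode": "standard", "max_attempts": 10}}

# Omitted or None: the connection's config_kwargs would be used.
print(resolve_setting(None, connection_config))

# Empty dict: overrides the connection configuration with nothing.
print(resolve_setting({}, connection_config))

# verify=False is an explicit choice, not "unset".
print(resolve_setting(False, "path/to/cert/bundle.pem"))
```

The important design point is the `is None` check: a plain truthiness test would silently discard `{}` and `False`, which the documentation treats as deliberate overrides.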
6 changes: 6 additions & 0 deletions docs/apache-airflow-providers-amazon/connections/aws.rst
@@ -127,6 +127,10 @@ Extra (optional)

* ``verify``: Whether or not to verify SSL certificates.

* ``False`` - Do not validate SSL certificates.
* **path/to/cert/bundle.pem** - A filename of the CA cert bundle to use. You can specify this argument
if you want to use a different CA cert bundle than the one used by botocore.

The following extra parameters used for specific AWS services:

* ``service_config``: json used to specify configuration/parameters per AWS service / Amazon provider hook,
@@ -396,6 +400,8 @@ provide selected options in the connection's extra field.
}


.. _howto/connection:aws:avoid-throttling-exceptions:

Avoid Throttling exceptions
---------------------------
