Use base classes for AWS Lambda Operators/Sensors #34890

Merged · 3 commits · Oct 20, 2023
37 changes: 18 additions & 19 deletions airflow/providers/amazon/aws/operators/lambda_function.py
@@ -19,20 +19,19 @@

import json
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context


class LambdaCreateFunctionOperator(BaseOperator):
class LambdaCreateFunctionOperator(AwsBaseOperator[LambdaHook]):
"""
Creates an AWS Lambda function.

@@ -62,13 +61,15 @@ class LambdaCreateFunctionOperator(BaseOperator):
:param aws_conn_id: The AWS connection ID to use
"""

aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
"function_name",
"runtime",
"role",
"handler",
"code",
"config",
*AwsBaseOperator.template_fields,
Contributor:

Rather than adding this to every operator class, can we do it the other way around, where the base class already has this field with these defaults and we update it with anything new that the concrete class wants to add (if anything)? It would need to be a mutable type, of course; I'm not sure whether it being immutable is a requirement or not.
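
One hypothetical way to express that idea (the merge hook below is only an illustration, not anything in this PR):

class AwsBaseOperator:
    # The base class owns the common defaults.
    template_fields: tuple[str, ...] = ("aws_conn_id", "region_name", "verify")

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Merge whatever the concrete class declares on top of the defaults;
        # each subclass gets its own tuple, so nothing is shared or mutated.
        extra = cls.__dict__.get("extra_template_fields", ())
        cls.template_fields = (*AwsBaseOperator.template_fields, *extra)


class LambdaInvokeFunctionOperator(AwsBaseOperator):
    extra_template_fields = ("function_name", "payload")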

Contributor Author:

I think we could only update it in the init method; at least I don't know another way to change a class attribute during creation, and I'm not sure whether it works in all cases (especially for mapped tasks).
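
For illustration, an init-time update might look like this hypothetical sketch (per-instance, which is exactly why serialization and mapped tasks are the open question):

class LambdaInvokeFunctionOperator(AwsBaseOperator):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Shadow the class attribute with an instance attribute; it is
        # unclear whether serialized/mapped tasks would pick this up.
        self.template_fields = (*self.template_fields, "function_name", "payload")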

Another way is to make the parameters a set, but this has a side effect: a set doesn't preserve order, so it might generate a new serialized DAG each time the file is parsed (maybe that is not so bad)

class Base:
    templates_fields: set[str] = {"a", "b", "c"}

class Child1(Base):
    ...

class Child2(Base):
    templates_fields = Base.templates_fields | {"d", "e"}

class Child3(Base):
    ...

print(" Step 1 ".center(72, "="))
print(f"Base.templates_fields={Base.templates_fields}")
print(f"Child1.templates_fields={Child1.templates_fields}")
print(f"Child2.templates_fields={Child2.templates_fields} - the order would change from run to run")
print(f"Child3.templates_fields={Child3.templates_fields}")

print(" Step 2: Update one of the child ".center(72, "="))
Child3.templates_fields.update({"f", "g"})

print(f"Base.templates_fields={Base.templates_fields}")
print(f"Child1.templates_fields={Child1.templates_fields}")
print(f"Child2.templates_fields={Child2.templates_fields} - the order would change from run to run")
print(f"Child3.templates_fields={Child3.templates_fields}")

print(" Step 3: Invalid operation ".center(72, "="))
class Child5(Base):
    templates_fields.add("h")  # We can't do that
================================ Step 1 ================================
Base.templates_fields={'b', 'c', 'a'}
Child1.templates_fields={'b', 'c', 'a'}
Child2.templates_fields={'e', 'b', 'd', 'c', 'a'} - the order would change from run to run
Child3.templates_fields={'b', 'c', 'a'}
=================== Step 2: Update one of the child ====================
Base.templates_fields={'b', 'f', 'g', 'c', 'a'}
Child1.templates_fields={'b', 'f', 'g', 'c', 'a'}
Child2.templates_fields={'e', 'b', 'd', 'c', 'a'} - the order would change from run to run
Child3.templates_fields={'b', 'f', 'g', 'c', 'a'}
====================== Step 3: Invalid operation =======================
Traceback (most recent call last):
...
NameError: name 'templates_fields' is not defined

Contributor Author:

I just thought about another option: create a helper function that applies the default parameters

from airflow.compat.functools import cache


@cache
def aws_template_fields(*templated_fields: str) -> tuple[str, ...]:
    return tuple(sorted({"aws_conn_id", "region_name", "verify"} | set(templated_fields)))


class SomeOperator:
    template_fields: Sequence[str] = aws_template_fields(
        "function_name",
        "runtime",
        "role",
        "handler",
        "code",
        "config",
    )
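
For illustration, the result is deterministic: sorted and independent of input order, which sidesteps the set-ordering concern above:

assert aws_template_fields("function_name", "payload") == (
    "aws_conn_id",
    "function_name",
    "payload",
    "region_name",
    "verify",
)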

WDYT?

Contributor (@ferruzzi, Oct 16, 2023):

A variation on that might be to have each operator define "op_template_fields" and have BaseOperator's template_fields return base+op? Basically:

class BaseOperator:
    base_template_fields = ("aws_conn_id", "region_name", "verify")

    @cached_property
    def template_fields(self):
        return self.base_template_fields + self.op_template_fields


class MyOperator(BaseOperator):
    op_template_fields = ("foo", "bar")

Contributor Author:

It would also be nice to test this within a mapped task; otherwise we could bump into the same problem we have in BatchOperator.

Contributor:

> I just thought about another option: create a helper function that applies the default parameters

This is certainly nicer! Less duplication, and the defaults' complexity is hidden from the users. It would be nice to have a fully backwards-compatible approach, but I'm happy to settle on this one.

> A variation on that might be to have each operator define "op_template_fields"

This would make the existing operators incompatible, since the field name would need to change; but we need to make other changes to them anyway to convert them to the base class approach, so maybe that's perfectly fine?

Contributor Author:

I've added this helper in 28aef3b

)
ui_color = "#ff7300"

@@ -82,12 +83,11 @@ def __init__(
code: dict,
description: str | None = None,
timeout: int | None = None,
config: dict = {},
config: dict | None = None,
wait_for_completion: bool = False,
waiter_max_attempts: int = 60,
waiter_delay: int = 15,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -98,16 +98,11 @@ def __init__(
self.code = code
self.description = description
self.timeout = timeout
self.config = config
self.config = config or {}
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)

def execute(self, context: Context):
self.log.info("Creating AWS Lambda function: %s", self.function_name)
@@ -131,6 +126,9 @@ def execute(self, context: Context):
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
@@ -152,7 +150,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
return event["function_arn"]


class LambdaInvokeFunctionOperator(BaseOperator):
class LambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]):
Contributor:

I'll admit this format for the inheritance is new to me and I am still reading up on how it does what it does, but it appears that the value inside the square brackets here is always the same as the value being assigned to aws_hook_class within the operator. I don't suppose there is a way to simplify that so you don't need both?

Contributor Author:

In the square brackets it is just typing.Generic, which helps type-annotate the hook property. The whole chain is:

1. Define a TypeVar:

AwsHookType = TypeVar("AwsHookType", bound=AwsGenericHook)

2. Use it as the Generic parameter in the class definition:

class AwsBaseHookMixin(Generic[AwsHookType]):
    """Mixin class for AWS Operators, Sensors, etc."""

3. Use the defined TypeVar as an annotation:

# Should be assigned in child class
aws_hook_class: type[AwsHookType]

@cached_property
@final
def hook(self) -> AwsHookType:
    ...

In general, class LambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]): helps IDEs and static checkers understand that hook will return a LambdaHook object (in our case we use it in only one place), and additionally validates that aws_hook_class is a subclass of LambdaHook.

However, information about generics isn't available at runtime and we can't extract it from the class definition, so we have to assign the actual class to aws_hook_class in order to construct the hook in the hook property.
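
Putting those pieces together, a minimal self-contained sketch of the pattern (simplified; the hook classes here are stand-ins for the real ones):

from __future__ import annotations

from functools import cached_property
from typing import Generic, TypeVar


class AwsGenericHook:
    def __init__(self, aws_conn_id: str | None = None) -> None:
        self.aws_conn_id = aws_conn_id


class LambdaHook(AwsGenericHook):
    pass


AwsHookType = TypeVar("AwsHookType", bound=AwsGenericHook)


class AwsBaseOperator(Generic[AwsHookType]):
    # The generic parameter is erased at runtime, so the concrete hook
    # class must still be assigned here by every child class.
    aws_hook_class: type[AwsHookType]
    aws_conn_id: str | None = "aws_default"

    @cached_property
    def hook(self) -> AwsHookType:
        # Static checkers infer LambdaHook for the subclass below.
        return self.aws_hook_class(aws_conn_id=self.aws_conn_id)


class LambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]):
    aws_hook_class = LambdaHook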

"""
Invokes an AWS Lambda function.

@@ -176,7 +174,14 @@ class LambdaInvokeFunctionOperator(BaseOperator):
:param aws_conn_id: The AWS connection ID to use
"""

template_fields: Sequence[str] = ("function_name", "payload", "qualifier", "invocation_type")
aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
"function_name",
"payload",
"qualifier",
"invocation_type",
*AwsBaseOperator.template_fields,
)
ui_color = "#ff7300"

def __init__(
@@ -189,7 +194,6 @@ def __init__(
invocation_type: str | None = None,
client_context: str | None = None,
payload: bytes | str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -200,11 +204,6 @@ def __init__(
self.qualifier = qualifier
self.invocation_type = invocation_type
self.client_context = client_context
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)

def execute(self, context: Context):
"""
16 changes: 5 additions & 11 deletions airflow/providers/amazon/aws/sensors/lambda_function.py
@@ -17,20 +17,18 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils import trim_none_values

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator


class LambdaFunctionStateSensor(BaseSensorOperator):
class LambdaFunctionStateSensor(AwsBaseSensor[LambdaHook]):
"""
Poll the deployment state of the AWS Lambda function until it reaches a target state.

@@ -48,9 +46,11 @@ class LambdaFunctionStateSensor(BaseSensorOperator):

FAILURE_STATES = ("Failed",)

aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
"function_name",
"qualifier",
*AwsBaseSensor.template_fields,
)

def __init__(
@@ -59,11 +59,9 @@
function_name: str,
qualifier: str | None = None,
target_states: list = ["Active"],
aws_conn_id: str = "aws_default",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.function_name = function_name
self.qualifier = qualifier
self.target_states = target_states
@@ -83,7 +81,3 @@ def poke(self, context: Context) -> bool:
raise AirflowException(message)

return state in self.target_states

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/triggers/lambda_function.py
@@ -44,6 +44,7 @@ def __init__(
waiter_delay: int = 60,
waiter_max_attempts: int = 30,
aws_conn_id: str | None = None,
**kwargs,
) -> None:

super().__init__(
@@ -62,7 +63,13 @@
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)
return LambdaHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
69 changes: 69 additions & 0 deletions docs/apache-airflow-providers-amazon/_partials/generic_parameters.rst
@@ -0,0 +1,69 @@
Contributor: Nice addition

.. Licensed to the Apache Software Foundation (ASF) under one

or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

.. http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.


aws_conn_id
Reference to the :ref:`Amazon Web Services Connection <howto/connection:aws>` ID.
If this parameter is set to ``None``, the default boto3 behaviour is used without a connection lookup.
Otherwise the credentials stored in the Connection are used. Default: ``aws_default``

region_name
AWS Region Name. If this parameter is set to ``None`` or omitted, **region_name** from the
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will be used.
Otherwise the specified value is used instead of the connection value. Default: ``None``

verify
Whether or not to verify SSL certificates.

* ``False`` - do not validate SSL certificates.
* **path/to/cert/bundle.pem** - A filename of the CA cert bundle to use. You can specify this argument
  if you want to use a different CA cert bundle than the one used by botocore.

If this parameter is set to ``None`` or omitted, **verify** from the
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will be used.
Otherwise the specified value is used instead of the connection value. Default: ``None``

botocore_config
The provided dictionary is used to construct a `botocore.config.Config`_.
This configuration can be used to :ref:`avoid throttling exceptions <howto/connection:aws:avoid-throttling-exceptions>`, configure timeouts, and more.

.. code-block:: python
    :caption: Example; for more detail about the parameters, see `botocore.config.Config`_

{
"signature_version": "unsigned",
"s3": {
"us_east_1_regional_endpoint": True,
},
"retries": {
"mode": "standard",
"max_attempts": 10,
},
"connect_timeout": 300,
"read_timeout": 300,
"tcp_keepalive": True,
}

If this parameter is set to ``None`` or omitted, **config_kwargs** from the
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will be used.
Otherwise the specified value is used instead of the connection value. Default: ``None``

.. note::
    An empty dictionary ``{}`` is used to construct the default botocore Config
    and will overwrite the connection configuration for `botocore.config.Config`_
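
As an illustration, the generic parameters might all be passed together to an operator (the operator and values below are only an example):

.. code-block:: python

    from airflow.providers.amazon.aws.operators.lambda_function import (
        LambdaInvokeFunctionOperator,
    )

    invoke = LambdaInvokeFunctionOperator(
        task_id="invoke_lambda",  # illustrative task id
        function_name="my-function",  # illustrative function name
        aws_conn_id="aws_default",
        region_name="eu-west-1",
        verify="path/to/cert/bundle.pem",
        botocore_config={"retries": {"mode": "standard", "max_attempts": 10}},
    )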

.. _botocore.config.Config: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
6 changes: 6 additions & 0 deletions docs/apache-airflow-providers-amazon/connections/aws.rst
@@ -127,6 +127,10 @@ Extra (optional)

* ``verify``: Whether or not to verify SSL certificates.

* ``False`` - do not validate SSL certificates.
* **path/to/cert/bundle.pem** - A filename of the CA cert bundle to use. You can specify this argument
  if you want to use a different CA cert bundle than the one used by botocore.

The following extra parameters are used for specific AWS services:

* ``service_config``: json used to specify configuration/parameters per AWS service / Amazon provider hook,
@@ -396,6 +400,8 @@ provide selected options in the connection's extra field.
}


.. _howto/connection:aws:avoid-throttling-exceptions:

Avoid Throttling exceptions
---------------------------

23 changes: 20 additions & 3 deletions docs/apache-airflow-providers-amazon/operators/lambda.rst
@@ -30,6 +30,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

@@ -68,10 +73,22 @@ To invoke an AWS lambda function you can use

urllib3.exceptions.ReadTimeoutError: AWSHTTPSConnectionPool(host='lambda.us-east-1.amazonaws.com', port=443): Read timed out. (read timeout=60)

If you encounter this issue, configure :ref:`howto/connection:aws` to allow ``botocore`` / ``boto3`` to use long connections with
timeout or keep-alive settings in the Connection Extras field
If you encounter this issue, you need to provide a `botocore.config.Config <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html>`__
to use long connections with timeout or keep-alive settings.

One option is to provide **botocore_config** as an operator parameter:

.. code-block:: python

{
"connect_timeout": 900,
"read_timeout": 900,
"tcp_keepalive": True,
}
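
For example, a sketch using the invoke operator (the task and function names are illustrative):

.. code-block:: python

    invoke = LambdaInvokeFunctionOperator(
        task_id="invoke_lambda",
        function_name="my-function",
        botocore_config={
            "connect_timeout": 900,
            "read_timeout": 900,
            "tcp_keepalive": True,
        },
    )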

Or specify **config_kwargs** in the associated :ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>`:

.. code-block:: json

{
"config_kwargs": {
Expand Down