Skip to content

Commit

Permalink
Use base classes for AWS Lambda Operators/Sensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Oct 12, 2023
1 parent 545e4d5 commit 7db910a
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 35 deletions.
37 changes: 18 additions & 19 deletions airflow/providers/amazon/aws/operators/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,19 @@

import json
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context


class LambdaCreateFunctionOperator(BaseOperator):
class LambdaCreateFunctionOperator(AwsBaseOperator[LambdaHook]):
"""
Creates an AWS Lambda function.
Expand Down Expand Up @@ -62,13 +61,15 @@ class LambdaCreateFunctionOperator(BaseOperator):
:param aws_conn_id: The AWS connection ID to use
"""

aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
"function_name",
"runtime",
"role",
"handler",
"code",
"config",
*AwsBaseOperator.template_fields,
)
ui_color = "#ff7300"

Expand All @@ -82,12 +83,11 @@ def __init__(
code: dict,
description: str | None = None,
timeout: int | None = None,
config: dict = {},
config: dict | None = None,
wait_for_completion: bool = False,
waiter_max_attempts: int = 60,
waiter_delay: int = 15,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -98,16 +98,11 @@ def __init__(
self.code = code
self.description = description
self.timeout = timeout
self.config = config
self.config = config or {}
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)

def execute(self, context: Context):
self.log.info("Creating AWS Lambda function: %s", self.function_name)
Expand All @@ -131,6 +126,9 @@ def execute(self, context: Context):
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
Expand All @@ -152,7 +150,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
return event["function_arn"]


class LambdaInvokeFunctionOperator(BaseOperator):
class LambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]):
"""
Invokes an AWS Lambda function.
Expand All @@ -176,7 +174,14 @@ class LambdaInvokeFunctionOperator(BaseOperator):
:param aws_conn_id: The AWS connection ID to use
"""

template_fields: Sequence[str] = ("function_name", "payload", "qualifier", "invocation_type")
aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
"function_name",
"payload",
"qualifier",
"invocation_type",
*AwsBaseOperator.template_fields,
)
ui_color = "#ff7300"

def __init__(
Expand All @@ -189,7 +194,6 @@ def __init__(
invocation_type: str | None = None,
client_context: str | None = None,
payload: bytes | str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -200,11 +204,6 @@ def __init__(
self.qualifier = qualifier
self.invocation_type = invocation_type
self.client_context = client_context
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)

def execute(self, context: Context):
"""
Expand Down
16 changes: 5 additions & 11 deletions airflow/providers/amazon/aws/sensors/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils import trim_none_values

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator


class LambdaFunctionStateSensor(BaseSensorOperator):
class LambdaFunctionStateSensor(AwsBaseSensor[LambdaHook]):
"""
Poll the deployment state of the AWS Lambda function until it reaches a target state.
Expand All @@ -48,9 +46,11 @@ class LambdaFunctionStateSensor(BaseSensorOperator):

FAILURE_STATES = ("Failed",)

aws_hook_class = LambdaHook
template_fields: Sequence[str] = (
"function_name",
"qualifier",
*AwsBaseSensor.template_fields,
)

def __init__(
Expand All @@ -59,11 +59,9 @@ def __init__(
function_name: str,
qualifier: str | None = None,
target_states: list = ["Active"],
aws_conn_id: str = "aws_default",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.function_name = function_name
self.qualifier = qualifier
self.target_states = target_states
Expand All @@ -83,7 +81,3 @@ def poke(self, context: Context) -> bool:
raise AirflowException(message)

return state in self.target_states

@cached_property
def hook(self) -> LambdaHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/triggers/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
waiter_delay: int = 60,
waiter_max_attempts: int = 30,
aws_conn_id: str | None = None,
**kwargs,
) -> None:

super().__init__(
Expand All @@ -62,7 +63,13 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return LambdaHook(aws_conn_id=self.aws_conn_id)
return LambdaHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
aws_conn_id
Reference to :ref:`Amazon Web Services Connection <howto/connection:aws>` ID.
If this parameter is set to ``None`` then the default boto3 behaviour is used without lookup connection.
Otherwise use credentials stored into the Connection. Default: ``aws_default``

region_name
AWS Region Name. If this parameter is set to ``None`` or omitted then **region_name** from
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will use.
Otherwise use specified value instead of connection value. Default: ``None``

verify
Whether or not to verify SSL certificates.

* ``False`` - do not validate SSL certificates.
* **path/to/cert/bundle.pem** - A filename of the CA cert bundle to uses. You can specify this argument
if you want to use a different CA cert bundle than the one used by botocore.

If this parameter is set to ``None`` or omitted then **verify** from
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will use.
Otherwise use specified value instead of from connection value. Default: ``None``

botocore_config
Use provided dictionary to construct a `botocore.config.Config`_.
This configuration will able to use for :ref:`howto/connection:aws:avoid-throttling-exceptions`, configure timeouts and etc.

.. code-block:: python
:caption: Example, for more detail about parameters please have a look `botocore.config.Config`_
{
"signature_version": "unsigned",
"s3": {
"us_east_1_regional_endpoint": True,
},
"retries": {
"mode": "standard",
"max_attempts": 10,
},
"connect_timeout": 300,
"read_timeout": 300,
"tcp_keepalive": True,
}
If this parameter is set to ``None`` or omitted then **config_kwargs** from
:ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>` will use.
Otherwise use specified value instead of connection value. Default: ``None``
.. note::
Empty dictionary ``{}`` uses for construct default botocore Config
and will overwrite connection configuration for `botocore.config.Config`_
.. _botocore.config.Config: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
6 changes: 6 additions & 0 deletions docs/apache-airflow-providers-amazon/connections/aws.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ Extra (optional)

* ``verify``: Whether or not to verify SSL certificates.

* ``False`` - do not validate SSL certificates.
* **path/to/cert/bundle.pem** - A filename of the CA cert bundle to uses. You can specify this argument
if you want to use a different CA cert bundle than the one used by botocore.

The following extra parameters used for specific AWS services:

* ``service_config``: json used to specify configuration/parameters per AWS service / Amazon provider hook,
Expand Down Expand Up @@ -396,6 +400,8 @@ provide selected options in the connection's extra field.
}
.. _howto/connection:aws:avoid-throttling-exceptions:

Avoid Throttling exceptions
---------------------------

Expand Down
23 changes: 20 additions & 3 deletions docs/apache-airflow-providers-amazon/operators/lambda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down Expand Up @@ -68,10 +73,22 @@ To invoke an AWS lambda function you can use
urllib3.exceptions.ReadTimeoutError: AWSHTTPSConnectionPool(host='lambda.us-east-1.amazonaws.com', port=443): Read timed out. (read timeout=60)
If you encounter this issue, configure :ref:`howto/connection:aws` to allow ``botocore`` / ``boto3`` to use long connections with
timeout or keep-alive settings in the Connection Extras field
If you encounter this issue, you need to provide `botocore.config.Config <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html>`__
for use long connections with timeout or keep-alive settings

By provide **botocore_config** as operator parameter

.. code-block:: python
{
"connect_timeout": 900,
"read_timeout": 900,
"tcp_keepalive": True,
}
Or specify **config_kwargs** in associated :ref:`AWS Connection Extra Parameter <howto/connection:aws:configuring-the-connection>`

.. code-block:: json
.. code-block:: json
{
"config_kwargs": {
Expand Down
Loading

0 comments on commit 7db910a

Please sign in to comment.