-
Notifications
You must be signed in to change notification settings - Fork 14.5k
/
mixins.py
165 lines (138 loc) · 5.78 KB
/
mixins.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains different mixin classes for internal use within the Amazon provider.
.. warning::
Only for internal usage, this module and all classes might be changed, renamed or removed in the future
without any further notice.
:meta: private
"""
from __future__ import annotations
import warnings
from functools import cached_property
from typing import Any, Generic, NamedTuple, TypeVar
from typing_extensions import final
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
# Type variable for the concrete hook class used by ``AwsBaseHookMixin``; bounded by ``AwsGenericHook``.
AwsHookType = TypeVar("AwsHookType", bound=AwsGenericHook)
# Deprecation text emitted wherever the legacy ``region`` name is used in place of ``region_name``.
REGION_MSG = "`region` is deprecated and will be removed in the future. Please use `region_name` instead."
class AwsHookParams(NamedTuple):
    """
    Immutable record of the generic parameters shared by AWS hooks.

    :meta private:
    """

    aws_conn_id: str | None
    region_name: str | None
    verify: bool | str | None
    botocore_config: dict[str, Any] | None

    @classmethod
    def from_constructor(
        cls,
        aws_conn_id: str | None,
        region_name: str | None,
        verify: bool | str | None,
        botocore_config: dict[str, Any] | None,
        additional_params: dict,
    ):
        """
        Build the parameter record from the arguments of an operator/sensor constructor.

        The deprecated ``region`` key is always removed from ``additional_params``
        (the mapping is mutated in place). When it carries a truthy value, a
        deprecation warning is emitted and that value is used as ``region_name``;
        a falsy value is silently discarded.

        Example:

        .. code-block:: python

            class AwsFooBarOperator(BaseOperator):
                def __init__(
                    self,
                    *,
                    aws_conn_id: str | None = "aws_default",
                    region_name: str | None = None,
                    verify: bool | str | None = None,
                    botocore_config: dict | None = None,
                    foo: str = "bar",
                    **kwargs,
                ):
                    params = AwsHookParams.from_constructor(
                        aws_conn_id, region_name, verify, botocore_config, additional_params=kwargs
                    )
                    super().__init__(**kwargs)
                    self.aws_conn_id = params.aws_conn_id
                    self.region_name = params.region_name
                    self.verify = params.verify
                    self.botocore_config = params.botocore_config
                    self.foo = foo

        :param aws_conn_id: Value stored as ``aws_conn_id``.
        :param region_name: Value stored as ``region_name`` unless overridden by ``region``.
        :param verify: Value stored as ``verify``.
        :param botocore_config: Value stored as ``botocore_config``.
        :param additional_params: Remaining constructor keyword arguments; ``region`` is popped from it.
        :raises ValueError: If both ``region_name`` and ``region`` are given and they differ.
        """
        legacy_region = additional_params.pop("region", None)
        if not legacy_region:
            # Nothing deprecated was supplied: take the arguments as they are.
            return cls(aws_conn_id, region_name, verify, botocore_config)

        # stacklevel=3 skips this method and the constructor that called it.
        warnings.warn(REGION_MSG, AirflowProviderDeprecationWarning, stacklevel=3)
        if region_name and region_name != legacy_region:
            raise ValueError(
                f"Conflicting `region_name` provided, region_name={region_name!r}, region={legacy_region!r}."
            )
        return cls(aws_conn_id, legacy_region, verify, botocore_config)
class AwsBaseHookMixin(Generic[AwsHookType]):
    """Mixin class for AWS Operators, Sensors, etc.

    Provides a lazily-created ``hook`` built from ``aws_hook_class`` and the
    generic connection attributes declared below.

    .. warning::
        Only for internal usage, this class might be changed, renamed or removed in the future
        without any further notice.

    :meta private:
    """

    # Should be assigned in child class
    aws_hook_class: type[AwsHookType]
    # Generic connection attributes; passed to the hook through ``_hook_parameters``.
    aws_conn_id: str | None
    region_name: str | None
    verify: bool | str | None
    botocore_config: dict[str, Any] | None

    def validate_attributes(self) -> None:
        """Validate class attributes.

        :raises AttributeError: If ``aws_hook_class`` is not set, is not a class,
            or is not a subclass of ``AwsGenericHook``.
        """
        # ``aws_hook_class`` is only annotated above, so it exists only if a child class assigned it.
        if not hasattr(self, "aws_hook_class"):
            raise AttributeError(f"Class attribute '{type(self).__name__}.aws_hook_class' should be set.")

        try:
            is_aws_hook = issubclass(self.aws_hook_class, AwsGenericHook)
        except TypeError:
            # ``issubclass`` raises TypeError when its first argument is not a class at all.
            is_aws_hook = False

        if not is_aws_hook:
            raise AttributeError(
                f"Class attribute '{type(self).__name__}.aws_hook_class' "
                f"is not a subclass of AwsGenericHook."
            )

    @property
    def _hook_parameters(self) -> dict[str, Any]:
        """
        Mapping parameters to build boto3-related hooks.

        Only required to be overwritten for thick-wrapped Hooks.
        Note that ``botocore_config`` is passed to the hook under the ``config`` key.
        """
        return {
            "aws_conn_id": self.aws_conn_id,
            "region_name": self.region_name,
            "verify": self.verify,
            "config": self.botocore_config,
        }

    @cached_property
    @final
    def hook(self) -> AwsHookType:
        """
        Return AWS Provider's hook based on ``aws_hook_class``.

        The hook is created on first access from ``_hook_parameters`` and cached afterwards.

        This method implementation should be taken as a final for
        thin-wrapped Hooks around boto3. For thick-wrapped Hooks developer
        should consider to overwrite ``_hook_parameters`` method instead.
        """
        return self.aws_hook_class(**self._hook_parameters)

    @property
    @final
    def region(self) -> str | None:
        """Alias for ``region_name``, used for compatibility (deprecated)."""
        # stacklevel=2 attributes the warning to the code reading ``.region``:
        # a property getter has only one frame between ``warn`` and the access site.
        warnings.warn(REGION_MSG, AirflowProviderDeprecationWarning, stacklevel=2)
        return self.region_name