Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mypy): a few return types, type signatures, and untyped areas #718

Merged
merged 6 commits into from
Oct 1, 2021
8 changes: 4 additions & 4 deletions aws_lambda_powertools/logging/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class LambdaPowertoolsFormatter(BasePowertoolsFormatter):
def __init__(
self,
json_serializer: Optional[Callable[[Dict], str]] = None,
json_deserializer: Optional[Callable[[Dict], str]] = None,
json_deserializer: Optional[Callable[[Union[Dict, str, bool, int, float]], str]] = None,
json_default: Optional[Callable[[Any], Any]] = None,
datefmt: Optional[str] = None,
log_record_order: Optional[List[str]] = None,
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
self.update_formatter = self.append_keys # alias to old method

if self.utc:
self.converter = time.gmtime
self.converter = time.gmtime # type: ignore

super(LambdaPowertoolsFormatter, self).__init__(datefmt=self.datefmt)

Expand All @@ -128,7 +128,7 @@ def format(self, record: logging.LogRecord) -> str: # noqa: A003
return self.serialize(log=formatted_log)

def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str:
record_ts = self.converter(record.created)
record_ts = self.converter(record.created) # type: ignore
if datefmt:
return time.strftime(datefmt, record_ts)

Expand Down Expand Up @@ -201,7 +201,7 @@ def _extract_log_exception(self, log_record: logging.LogRecord) -> Union[Tuple[s
Log record with constant traceback info and exception name
"""
if log_record.exc_info:
return self.formatException(log_record.exc_info), log_record.exc_info[0].__name__
return self.formatException(log_record.exc_info), log_record.exc_info[0].__name__ # type: ignore

return None, None

Expand Down
6 changes: 4 additions & 2 deletions aws_lambda_powertools/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def registered_handler(self) -> logging.Handler:
return handlers[0]

@property
def registered_formatter(self) -> Optional[PowertoolsFormatter]:
def registered_formatter(self) -> PowertoolsFormatter:
"""Convenience property to access logger formatter"""
return self.registered_handler.formatter # type: ignore

Expand Down Expand Up @@ -405,7 +405,9 @@ def get_correlation_id(self) -> Optional[str]:
str, optional
Value for the correlation id
"""
return self.registered_formatter.log_format.get("correlation_id")
if isinstance(self.registered_formatter, LambdaPowertoolsFormatter):
return self.registered_formatter.log_format.get("correlation_id")
return None

@staticmethod
def _get_log_level(level: Union[str, int, None]) -> Union[str, int]:
Expand Down
6 changes: 3 additions & 3 deletions aws_lambda_powertools/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
self._metric_unit_options = list(MetricUnit.__members__)
self.metadata_set = metadata_set if metadata_set is not None else {}

def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float):
def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float) -> None:
"""Adds given metric

Example
Expand Down Expand Up @@ -215,7 +215,7 @@ def serialize_metric_set(
**metric_names_and_values, # "single_metric": 1.0
}

def add_dimension(self, name: str, value: str):
def add_dimension(self, name: str, value: str) -> None:
"""Adds given dimension to all metrics

Example
Expand All @@ -241,7 +241,7 @@ def add_dimension(self, name: str, value: str):
# checking before casting improves performance in most cases
self.dimension_set[name] = value if isinstance(value, str) else str(value)

def add_metadata(self, key: str, value: Any):
def add_metadata(self, key: str, value: Any) -> None:
"""Adds high cardinal metadata for metrics object

This will not be available during metrics visualization.
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SingleMetric(MetricManager):
Inherits from `aws_lambda_powertools.metrics.base.MetricManager`
"""

def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float):
def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float) -> None:
"""Method to prevent more than one metric being created

Parameters
Expand Down
30 changes: 17 additions & 13 deletions aws_lambda_powertools/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import json
import logging
import warnings
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union, cast

from ..shared.types import AnyCallableT
from .base import MetricManager, MetricUnit
from .metric import single_metric

Expand Down Expand Up @@ -87,7 +88,7 @@ def __init__(self, service: Optional[str] = None, namespace: Optional[str] = Non
service=self.service,
)

def set_default_dimensions(self, **dimensions):
def set_default_dimensions(self, **dimensions) -> None:
"""Persist dimensions across Lambda invocations

Parameters
Expand All @@ -113,10 +114,10 @@ def lambda_handler():

self.default_dimensions.update(**dimensions)

def clear_default_dimensions(self):
def clear_default_dimensions(self) -> None:
self.default_dimensions.clear()

def clear_metrics(self):
def clear_metrics(self) -> None:
logger.debug("Clearing out existing metric set from memory")
self.metric_set.clear()
self.dimension_set.clear()
Expand All @@ -125,11 +126,11 @@ def clear_metrics(self):

def log_metrics(
self,
lambda_handler: Optional[Callable[[Any, Any], Any]] = None,
lambda_handler: Union[Callable[[Dict, Any], Any], Optional[Callable[[Dict, Any, Optional[Dict]], Any]]] = None,
capture_cold_start_metric: bool = False,
raise_on_empty_metrics: bool = False,
default_dimensions: Optional[Dict[str, str]] = None,
):
) -> AnyCallableT:
"""Decorator to serialize and publish metrics at the end of a function execution.

Be aware that the log_metrics **does call* the decorated function (e.g. lambda_handler).
Expand Down Expand Up @@ -169,11 +170,14 @@ def handler(event, context):
# Return a partial function with args filled
if lambda_handler is None:
logger.debug("Decorator called with parameters")
return functools.partial(
self.log_metrics,
capture_cold_start_metric=capture_cold_start_metric,
raise_on_empty_metrics=raise_on_empty_metrics,
default_dimensions=default_dimensions,
return cast(
AnyCallableT,
functools.partial(
self.log_metrics,
capture_cold_start_metric=capture_cold_start_metric,
raise_on_empty_metrics=raise_on_empty_metrics,
default_dimensions=default_dimensions,
),
)

@functools.wraps(lambda_handler)
Expand All @@ -194,9 +198,9 @@ def decorate(event, context):

return response

return decorate
return cast(AnyCallableT, decorate)

def __add_cold_start_metric(self, context: Any):
def __add_cold_start_metric(self, context: Any) -> None:
"""Add cold start metric and function_name dimension

Parameters
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/middleware_factory/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def final_decorator(func: Optional[Callable] = None, **kwargs):
if not inspect.isfunction(func):
# @custom_middleware(True) vs @custom_middleware(log_event=True)
raise MiddlewareInvalidArgumentError(
f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}"
f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}" # type: ignore # noqa: E501
)

@functools.wraps(func)
Expand Down
9 changes: 5 additions & 4 deletions aws_lambda_powertools/shared/jmespath_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,23 @@

import jmespath
from jmespath.exceptions import LexerError
from jmespath.functions import Functions, signature

from aws_lambda_powertools.exceptions import InvalidEnvelopeExpressionError

logger = logging.getLogger(__name__)


class PowertoolsFunctions(jmespath.functions.Functions):
@jmespath.functions.signature({"types": ["string"]})
class PowertoolsFunctions(Functions):
@signature({"types": ["string"]})
def _func_powertools_json(self, value):
return json.loads(value)

@jmespath.functions.signature({"types": ["string"]})
@signature({"types": ["string"]})
def _func_powertools_base64(self, value):
return base64.b64decode(value).decode()

@jmespath.functions.signature({"types": ["string"]})
@signature({"types": ["string"]})
def _func_powertools_base64_gzip(self, value):
encoded = base64.b64decode(value)
uncompressed = gzip.decompress(encoded)
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
logger = logging.getLogger(__name__)

aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE)
aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE)
aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE) # type: ignore # noqa: E501


class Tracer:
Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/data_classes/sqs_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def data_type(self) -> str:


class SQSMessageAttributes(Dict[str, SQSMessageAttribute]):
def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]:
def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: # type: ignore
item = super(SQSMessageAttributes, self).get(key)
return None if item is None else SQSMessageAttribute(item)
return None if item is None else SQSMessageAttribute(item) # type: ignore


class SQSRecord(DictWrapper):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _update_record(self, data_record: DataRecord):
"ExpressionAttributeNames": expression_attr_names,
}

self.table.update_item(**kwargs) # type: ignore
self.table.update_item(**kwargs)

def _delete_record(self, data_record: DataRecord) -> None:
logger.debug(f"Deleting record for idempotency key: {data_record.idempotency_key}")
Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/validation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def validate_data_against_schema(data: Union[Dict, str], schema: Dict, formats:
except (TypeError, AttributeError, fastjsonschema.JsonSchemaDefinitionException) as e:
raise InvalidSchemaFormatError(f"Schema received: {schema}, Formats: {formats}. Error: {e}")
except fastjsonschema.JsonSchemaValueException as e:
message = f"Failed schema validation. Error: {e.message}, Path: {e.path}, Data: {e.value}"
message = f"Failed schema validation. Error: {e.message}, Path: {e.path}, Data: {e.value}" # noqa: B306
raise SchemaValidationError(
message,
validation_message=e.message,
validation_message=e.message, # noqa: B306
name=e.name,
path=e.path,
value=e.value,
Expand Down
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ show_error_context = True
[mypy-jmespath]
ignore_missing_imports=True

[mypy-jmespath.exceptions]
ignore_missing_imports=True

[mypy-jmespath.functions]
ignore_missing_imports=True

[mypy-boto3]
ignore_missing_imports = True

Expand Down