Skip to content

Commit

Permalink
add stronger type hints for execute_statement
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonLuttenberger committed Nov 20, 2023
1 parent f6df257 commit a3ef0df
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions awswrangler/dynamodb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import logging
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional, Type, TypedDict, Union

import boto3
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.exceptions import ClientError
from typing_extensions import NotRequired, Required

from awswrangler import _utils, exceptions
from awswrangler._config import apply_configs
Expand Down Expand Up @@ -53,12 +54,33 @@ def get_table(
return dynamodb_table


def _serialize_item(
item: Mapping[str, Any], serializer: Optional[TypeSerializer] = None
) -> Dict[str, "AttributeValueTypeDef"]:
serializer = serializer if serializer else TypeSerializer()
return {k: serializer.serialize(v) for k, v in item.items()}


def _deserialize_item(
item: Mapping[str, "AttributeValueTypeDef"], deserializer: Optional[TypeDeserializer] = None
) -> Dict[str, Any]:
deserializer = deserializer if deserializer else TypeDeserializer()
return {k: deserializer.deserialize(v) for k, v in item.items()}


class _ReadExecuteStatementKwargs(TypedDict):
Statement: Required[str]
ConsistentRead: Required[bool]
Parameters: NotRequired[List["AttributeValueTypeDef"]]
NextToken: NotRequired[str]


def _execute_statement(
kwargs: Dict[str, Union[str, bool, List[Any]]],
kwargs: _ReadExecuteStatementKwargs,
dynamodb_client: "DynamoDBClient",
) -> "ExecuteStatementOutputTypeDef":
try:
response = dynamodb_client.execute_statement(**kwargs) # type: ignore[arg-type]
response = dynamodb_client.execute_statement(**kwargs)
except ClientError as err:
if err.response["Error"]["Code"] == "ResourceNotFoundException":
_logger.error("Couldn't execute PartiQL: '%s' because the table does not exist.", kwargs["Statement"])
Expand All @@ -73,33 +95,21 @@ def _execute_statement(
return response


def _serialize_item(
item: Mapping[str, Any], serializer: Optional[TypeSerializer] = None
) -> Dict[str, "AttributeValueTypeDef"]:
serializer = serializer if serializer else TypeSerializer()
return {k: serializer.serialize(v) for k, v in item.items()}


def _deserialize_item(
item: Mapping[str, "AttributeValueTypeDef"], deserializer: Optional[TypeDeserializer] = None
) -> Dict[str, Any]:
deserializer = deserializer if deserializer else TypeDeserializer()
return {k: deserializer.deserialize(v) for k, v in item.items()}


def _read_execute_statement(
kwargs: Dict[str, Union[str, bool, List[Any]]],
kwargs: _ReadExecuteStatementKwargs,
dynamodb_client: "DynamoDBClient",
) -> Iterator[List[Dict[str, Any]]]:
next_token: Optional[str] = "init_token" # Dummy token
deserializer = TypeDeserializer()

while next_token:
response = _execute_statement(kwargs=kwargs, dynamodb_client=dynamodb_client)
next_token = response.get("NextToken", None)
kwargs["NextToken"] = next_token # type: ignore[assignment]
yield [_deserialize_item(item, deserializer) for item in response["Items"]]

next_token = response.get("NextToken", None)
if next_token:
kwargs["NextToken"] = next_token


def execute_statement(
statement: str,
Expand Down Expand Up @@ -156,7 +166,7 @@ def execute_statement(
... parameters=[title, year],
... )
"""
kwargs: Dict[str, Union[str, bool, List[Any]]] = {"Statement": statement, "ConsistentRead": consistent_read}
kwargs: _ReadExecuteStatementKwargs = {"Statement": statement, "ConsistentRead": consistent_read}
if parameters:
serializer = TypeSerializer()
kwargs["Parameters"] = [serializer.serialize(p) for p in parameters]
Expand Down

0 comments on commit a3ef0df

Please sign in to comment.