Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Remove usage of boto3 resources #2525

Merged
merged 16 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions awswrangler/dynamodb/_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from typing import Any, Dict, List, Optional

import boto3
from boto3.dynamodb.types import TypeSerializer

from awswrangler import _utils
from awswrangler._config import apply_configs

from ._utils import _validate_items, get_table
from ._utils import _TableBatchWriter, _validate_items

_logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,9 +48,16 @@ def delete_items(
"""
_logger.debug("Deleting items from DynamoDB table %s", table_name)

dynamodb_table = get_table(table_name=table_name, boto3_session=boto3_session)
_validate_items(items=items, dynamodb_table=dynamodb_table)
table_keys = [schema["AttributeName"] for schema in dynamodb_table.key_schema]
with dynamodb_table.batch_writer() as writer:
dynamodb_client = _utils.client(service_name="dynamodb", session=boto3_session)
serializer = TypeSerializer()

key_schema = dynamodb_client.describe_table(TableName=table_name)["Table"]["KeySchema"]
_validate_items(items=items, key_schema=key_schema)

table_keys = [schema["AttributeName"] for schema in key_schema]

with _TableBatchWriter(table_name, dynamodb_client) as writer:
for item in items:
writer.delete_item(Key={key: item[key] for key in table_keys})
writer.delete_item(
key={key: serializer.serialize(item[key]) for key in table_keys},
)
128 changes: 86 additions & 42 deletions awswrangler/dynamodb/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Dict,
Iterator,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Expand All @@ -20,8 +21,8 @@

import boto3
import pyarrow as pa
from boto3.dynamodb.conditions import ConditionBase
from boto3.dynamodb.types import Binary
from boto3.dynamodb.conditions import ConditionBase, ConditionExpressionBuilder
from boto3.dynamodb.types import Binary, TypeDeserializer, TypeSerializer
from botocore.exceptions import ClientError
from typing_extensions import Literal

Expand All @@ -30,7 +31,7 @@
from awswrangler._distributed import engine
from awswrangler._executor import _BaseExecutor, _get_executor
from awswrangler.distributed.ray import ray_get
from awswrangler.dynamodb._utils import _serialize_kwargs, execute_statement, get_table
from awswrangler.dynamodb._utils import _deserialize_item, _serialize_item, execute_statement

if TYPE_CHECKING:
from mypy_boto3_dynamodb.client import DynamoDBClient
Expand Down Expand Up @@ -195,8 +196,8 @@ def _read_scan_chunked(
# SEE: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Scan.html#Scan.ParallelScan
client_dynamodb = dynamodb_client if dynamodb_client else _utils.client(service_name="dynamodb")

deserializer = boto3.dynamodb.types.TypeDeserializer()
next_token = "init_token" # Dummy token
deserializer = TypeDeserializer()
next_token: Optional[str] = "init_token" # Dummy token
total_items = 0

kwargs = dict(kwargs)
Expand All @@ -218,7 +219,7 @@ def _read_scan_chunked(
if ("Limit" in kwargs) and (total_items >= kwargs["Limit"]):
break

next_token = response.get("LastEvaluatedKey", None) # type: ignore[assignment]
next_token = response.get("LastEvaluatedKey", None)
if next_token:
kwargs["ExclusiveStartKey"] = next_token

Expand All @@ -242,33 +243,30 @@ def _read_scan(
return _utils.list_to_arrow_table(mapping=items, schema=schema) if as_dataframe else items


def _read_query_chunked(table_name: str, dynamodb_client: "DynamoDBClient", **kwargs: Any) -> Iterator[_ItemsListType]:
    """Yield pages of items from a DynamoDB Query, following pagination.

    Parameters
    ----------
    table_name
        Name of the DynamoDB table to query.
    dynamodb_client
        Low-level boto3 DynamoDB client used to issue the ``Query`` calls.
    **kwargs
        Additional arguments forwarded to ``DynamoDBClient.query``
        (e.g. ``KeyConditionExpression``, ``Limit``, ``ConsistentRead``).

    Yields
    ------
    One list of items per DynamoDB response page.
    """
    # Dummy token so the pagination loop runs at least once.
    # NOTE(review): LastEvaluatedKey is a key dict, not a str — the
    # Optional[str] annotation looks too narrow; confirm against mypy-boto3 stubs.
    next_token: Optional[str] = "init_token"
    total_items = 0

    # Handle pagination
    while next_token:
        response = dynamodb_client.query(TableName=table_name, **kwargs)
        items = response.get("Items", [])
        total_items += len(items)
        yield items

        # Stop early once the caller-supplied Limit has been reached,
        # even if DynamoDB reports further pages.
        if ("Limit" in kwargs) and (total_items >= kwargs["Limit"]):
            break

        next_token = response.get("LastEvaluatedKey", None)
        if next_token:
            kwargs["ExclusiveStartKey"] = next_token


@_handle_reserved_keyword_error
def _read_query(
table_name: str, chunked: bool, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
table_name: str, dynamodb_client: "DynamoDBClient", chunked: bool, **kwargs: Any
) -> Union[_ItemsListType, Iterator[_ItemsListType]]:
items_iterator = _read_query_chunked(table_name, boto3_session, **kwargs)
items_iterator = _read_query_chunked(table_name, dynamodb_client, **kwargs)

if chunked:
return items_iterator
Expand All @@ -277,12 +275,13 @@ def _read_query(


def _read_batch_items_chunked(
table_name: str, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
table_name: str, dynamodb_client: Optional["DynamoDBClient"], **kwargs: Any
) -> Iterator[_ItemsListType]:
resource = _utils.resource(service_name="dynamodb", session=boto3_session)
dynamodb_client = dynamodb_client if dynamodb_client else _utils.client("dynamodb")
deserializer = TypeDeserializer()

response = resource.batch_get_item(RequestItems={table_name: kwargs}) # type: ignore[dict-item]
yield response.get("Responses", {table_name: []}).get(table_name, []) # type: ignore[arg-type]
response = dynamodb_client.batch_get_item(RequestItems={table_name: kwargs})
yield [_deserialize_item(d, deserializer) for d in response.get("Responses", {table_name: []}).get(table_name, [])]

# SEE: handle possible unprocessed keys. As suggested in Boto3 docs,
# this approach should involve exponential backoff, but this should be
Expand All @@ -291,15 +290,17 @@ def _read_batch_items_chunked(
while response["UnprocessedKeys"]:
kwargs["Keys"] = response["UnprocessedKeys"][table_name]["Keys"]

response = resource.batch_get_item(RequestItems={table_name: kwargs}) # type: ignore[dict-item]
yield response.get("Responses", {table_name: []}).get(table_name, []) # type: ignore[arg-type]
response = dynamodb_client.batch_get_item(RequestItems={table_name: kwargs})
yield [
_deserialize_item(d, deserializer) for d in response.get("Responses", {table_name: []}).get(table_name, [])
]


@_handle_reserved_keyword_error
def _read_batch_items(
table_name: str, chunked: bool, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
table_name: str, dynamodb_client: Optional["DynamoDBClient"], chunked: bool, **kwargs: Any
) -> Union[_ItemsListType, Iterator[_ItemsListType]]:
items_iterator = _read_batch_items_chunked(table_name, boto3_session, **kwargs)
items_iterator = _read_batch_items_chunked(table_name, dynamodb_client, **kwargs)

if chunked:
return items_iterator
Expand All @@ -309,10 +310,13 @@ def _read_batch_items(

@_handle_reserved_keyword_error
def _read_item(
    table_name: str,
    dynamodb_client: "DynamoDBClient",
    chunked: bool = False,
    **kwargs: Any,
) -> Union[_ItemsListType, Iterator[_ItemsListType]]:
    """Fetch a single item with ``GetItem`` and return it deserialized.

    Parameters
    ----------
    table_name
        Name of the DynamoDB table to read from.
    dynamodb_client
        Low-level boto3 DynamoDB client used for the ``GetItem`` call.
    chunked
        When True, wrap the single-item list in an outer list so the
        result shape matches the chunked readers.
    **kwargs
        Additional arguments forwarded to ``DynamoDBClient.get_item``
        (must include ``Key``).

    Returns
    -------
    ``[item]`` (an empty dict when the key does not exist), or
    ``[[item]]`` when ``chunked`` is True.
    """
    item = dynamodb_client.get_item(TableName=table_name, **kwargs).get("Item", {})
    # NOTE(review): _deserialize_item is called here without an explicit
    # TypeDeserializer — presumably it builds one by default; confirm.
    item_list: _ItemsListType = [_deserialize_item(item)]

    return [item_list] if chunked else item_list

Expand All @@ -322,13 +326,10 @@ def _read_items_scan(
as_dataframe: bool,
arrow_kwargs: Dict[str, Any],
use_threads: Union[bool, int],
dynamodb_client: "DynamoDBClient",
chunked: bool,
boto3_session: Optional[boto3.Session] = None,
**kwargs: Any,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame], _ItemsListType, Iterator[_ItemsListType]]:
dynamodb_client = _utils.client(service_name="dynamodb", session=boto3_session)

kwargs = _serialize_kwargs(kwargs)
kwargs["TableName"] = table_name
schema = arrow_kwargs.pop("schema", None)

Expand Down Expand Up @@ -368,7 +369,7 @@ def _read_items(
arrow_kwargs: Dict[str, Any],
use_threads: Union[bool, int],
chunked: bool,
boto3_session: Optional[boto3.Session] = None,
dynamodb_client: "DynamoDBClient",
**kwargs: Any,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame], _ItemsListType, Iterator[_ItemsListType]]:
# Extract 'Keys', 'IndexName' and 'Limit' from provided kwargs: if needed, will be reinserted later on
Expand All @@ -384,12 +385,12 @@ def _read_items(
# Single Item
if use_get_item:
kwargs["Key"] = keys[0]
items = _read_item(table_name, chunked, boto3_session, **kwargs)
items = _read_item(table_name, dynamodb_client, chunked, **kwargs)

# Batch of Items
elif use_batch_get_item:
kwargs["Keys"] = keys
items = _read_batch_items(table_name, chunked, boto3_session, **kwargs)
items = _read_batch_items(table_name, dynamodb_client, chunked, **kwargs)

else:
if limit:
Expand All @@ -403,7 +404,7 @@ def _read_items(
if use_query:
# Query
_logger.debug("Query DynamoDB table %s", table_name)
items = _read_query(table_name, chunked, boto3_session, **kwargs)
items = _read_query(table_name, dynamodb_client, chunked, **kwargs)
else:
# Last resort use Scan
warnings.warn(
Expand All @@ -415,8 +416,8 @@ def _read_items(
as_dataframe=as_dataframe,
arrow_kwargs=arrow_kwargs,
use_threads=use_threads,
dynamodb_client=dynamodb_client,
chunked=chunked,
boto3_session=boto3_session,
**kwargs,
)

Expand All @@ -428,6 +429,25 @@ def _read_items(
return _convert_items(items=cast(_ItemsListType, items), as_dataframe=as_dataframe, arrow_kwargs=arrow_kwargs)


class _ExpressionTuple(NamedTuple):
    """Built condition expression plus its placeholder maps.

    Mirrors the tuple produced by boto3's ``ConditionExpressionBuilder``;
    see ``_convert_condition_base_to_expression``, which serializes the
    placeholder values to the low-level DynamoDB wire format.
    """

    # Expression string with placeholders, e.g. "#n0 = :v0"
    condition_expression: str
    # Placeholder -> real attribute name, e.g. {"#n0": "id"}
    attribute_name_placeholders: Dict[str, str]
    # Placeholder -> serialized AttributeValue, e.g. {":v0": {"S": "a"}}
    attribute_value_placeholders: Dict[str, Any]


def _convert_condition_base_to_expression(
    key_condition_expression: ConditionBase, is_key_condition: bool, serializer: TypeSerializer
) -> _ExpressionTuple:
    """Build a string condition expression from a boto3 ``ConditionBase``.

    Parameters
    ----------
    key_condition_expression
        High-level boto3 condition object (e.g. ``Key("id").eq(1)``).
    is_key_condition
        True when building a ``KeyConditionExpression`` (the builder
        restricts the allowed operators), False for a ``FilterExpression``.
    serializer
        ``TypeSerializer`` used to convert the placeholder values to the
        low-level DynamoDB wire format.

    Returns
    -------
    _ExpressionTuple
        Expression string plus attribute-name and serialized
        attribute-value placeholder maps.
    """
    builder = ConditionExpressionBuilder()
    expression = builder.build_expression(key_condition_expression, is_key_condition=is_key_condition)

    # Fix: the function returns an _ExpressionTuple (a NamedTuple), so the
    # previous Dict[str, Any] return annotation was incorrect.
    return _ExpressionTuple(
        condition_expression=expression.condition_expression,
        attribute_name_placeholders=expression.attribute_name_placeholders,
        attribute_value_placeholders=_serialize_item(expression.attribute_value_placeholders, serializer=serializer),
    )


@_utils.validate_distributed_kwargs(
unsupported_kwargs=["boto3_session", "dtype_backend"],
)
Expand Down Expand Up @@ -630,7 +650,9 @@ def read_items( # pylint: disable=too-many-branches
)

# Extract key schema
table_key_schema = get_table(table_name=table_name, boto3_session=boto3_session).key_schema
dynamodb_client = _utils.client(service_name="dynamodb", session=boto3_session)
serializer = TypeSerializer()
table_key_schema = dynamodb_client.describe_table(TableName=table_name)["Table"]["KeySchema"]

# Detect sort key, if any
if len(table_key_schema) == 1:
Expand All @@ -645,28 +667,50 @@ def read_items( # pylint: disable=too-many-branches
kwargs: Dict[str, Any] = {"ConsistentRead": consistent}
if partition_values:
if sort_key is None:
keys = [{partition_key: pv} for pv in partition_values]
keys = [{partition_key: serializer.serialize(pv)} for pv in partition_values]
else:
if not sort_values:
raise exceptions.InvalidArgumentType(
f"Kwarg sort_values must be specified: table {table_name} has {sort_key} as sort key."
)
if len(sort_values) != len(partition_values):
raise exceptions.InvalidArgumentCombination("Partition and sort values must have the same length.")
keys = [{partition_key: pv, sort_key: sv} for pv, sv in zip(partition_values, sort_values)]
keys = [
{partition_key: serializer.serialize(pv), sort_key: serializer.serialize(sv)}
for pv, sv in zip(partition_values, sort_values)
]
kwargs["Keys"] = keys
if index_name:
kwargs["IndexName"] = index_name

if key_condition_expression:
kwargs["KeyConditionExpression"] = key_condition_expression
if isinstance(key_condition_expression, str):
kwargs["KeyConditionExpression"] = key_condition_expression
else:
expression_tuple = _convert_condition_base_to_expression(
key_condition_expression, is_key_condition=True, serializer=serializer
)
kwargs["KeyConditionExpression"] = expression_tuple.condition_expression
kwargs["ExpressionAttributeNames"] = expression_tuple.attribute_name_placeholders
kwargs["ExpressionAttributeValues"] = expression_tuple.attribute_value_placeholders

if filter_expression:
kwargs["FilterExpression"] = filter_expression
if isinstance(filter_expression, str):
kwargs["FilterExpression"] = filter_expression
else:
expression_tuple = _convert_condition_base_to_expression(
filter_expression, is_key_condition=False, serializer=serializer
)
kwargs["FilterExpression"] = expression_tuple.condition_expression
kwargs["ExpressionAttributeNames"] = expression_tuple.attribute_name_placeholders
kwargs["ExpressionAttributeValues"] = expression_tuple.attribute_value_placeholders

if columns:
kwargs["ProjectionExpression"] = ", ".join(columns)
if expression_attribute_names:
kwargs["ExpressionAttributeNames"] = expression_attribute_names
if expression_attribute_values:
kwargs["ExpressionAttributeValues"] = expression_attribute_values
kwargs["ExpressionAttributeValues"] = _serialize_item(expression_attribute_values, serializer)
if max_items_evaluated:
kwargs["Limit"] = max_items_evaluated

Expand All @@ -678,8 +722,8 @@ def read_items( # pylint: disable=too-many-branches
as_dataframe=as_dataframe,
arrow_kwargs=arrow_kwargs,
use_threads=use_threads,
boto3_session=boto3_session,
chunked=chunked,
dynamodb_client=dynamodb_client,
**kwargs,
)
# Raise otherwise
Expand Down
1 change: 1 addition & 0 deletions awswrangler/dynamodb/_read.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _read_scan(
dynamodb_client: Optional["DynamoDBClient"],
as_dataframe: bool,
kwargs: Dict[str, Any],
schema: Optional[pa.Schema],
segment: int,
) -> Union[pa.Table, _ItemsListType]: ...
@overload
Expand Down
Loading
Loading