Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

S3 hooks filter options #19018

Closed
wants to merge 12 commits into from
80 changes: 74 additions & 6 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import fnmatch
import gzip as gz
import io
import operator
import re
import shutil
from collections import namedtuple
from functools import wraps
from inspect import signature
from io import BytesIO
Expand All @@ -41,6 +43,60 @@
T = TypeVar("T", bound=Callable)


class ResponseFilter:
    """
    Filters an S3 boto3 ``list_objects_v2`` response page by user-defined
    conditions on each object's metadata (e.g. ``LastModified``, ``Size``).

    A filter key is the metadata field name, optionally suffixed with a
    double-underscore operation, e.g. ``{"LastModified__gte": dt, "Size__lt": 100}``.
    A key with no suffix means an exact-equality test. The class can be
    extended with more operations by adding to :attr:`allowed_operations`.
    """

    # Mapping of the ``__op`` suffix to the comparison applied as op(field_value, filter_value).
    allowed_operations = {
        "lte": operator.le,
        "gte": operator.ge,
        "gt": operator.gt,
        "lt": operator.lt,
        "eq": operator.eq,
    }

    def __init__(self, data: dict) -> None:
        # ``data`` is one page of a boto3 list_objects_v2 response; object
        # entries, if any, live under the optional "Contents" key.
        self.data = data

    def filter(self, object_filter: Optional[dict] = None) -> list:
        """
        Return the keys of all objects in this response page that satisfy
        every condition in ``object_filter``.

        :param object_filter: mapping of ``field[__op]`` to the value to
            compare against; ``None`` returns every key unfiltered.
        :return: list of matching object keys (empty if the page has no
            "Contents", e.g. an empty bucket/prefix).
        :raises ValueError: if a filter key uses an unknown operation suffix.
        """
        # A page with no "Contents" (empty listing) yields no keys either way.
        contents = self.data.get("Contents", [])

        if object_filter is None:
            return [content.get("Key") for content in contents]

        operation = namedtuple("Q", "op key value")

        def parse_filter(item) -> operation:
            """Split ``field__op`` into a (callable, field, value) triple."""
            key, *op = item[0].split("__")
            # No suffix after "__" (or no "__" at all) means exact-value query.
            op = "".join(op).strip() or "eq"
            if op not in self.allowed_operations:
                # Raise instead of assert: asserts are stripped under -O.
                raise ValueError(f"{op!r} operation is not allowed")
            return operation(self.allowed_operations[op], key, item[1])

        # Validate and parse all filters up front so a bad suffix fails fast,
        # even when the page is empty.
        conditions = [parse_filter(item) for item in object_filter.items()]

        # Keep a key only when every condition holds (logical AND), returning
        # a list as declared — the original mixed list/set return types.
        return [
            content.get("Key")
            for content in contents
            if all(cond.op(content[cond.key], cond.value) for cond in conditions)
        ]


def provide_bucket_name(func: T) -> T:
"""
Function decorator that provides a bucket name taken from the connection
Expand Down Expand Up @@ -268,6 +324,8 @@ def list_keys(
delimiter: Optional[str] = None,
page_size: Optional[int] = None,
max_items: Optional[int] = None,
start_after_key: Optional[str] = None,
object_filter: Optional[dict] = None,
) -> list:
"""
Lists keys in a bucket under prefix and not containing delimiter
Expand All @@ -282,6 +340,10 @@ def list_keys(
:type page_size: int
:param max_items: maximum items to return
:type max_items: int
:param start_after_key: returns keys after this specified key in the bucket.
:type start_after_key: str
:param object_filter: returns keys based on object filter dict
:type object_filter: dict
:return: a list of matched keys
:rtype: list
"""
Expand All @@ -293,15 +355,21 @@ def list_keys(
}

paginator = self.get_conn().get_paginator('list_objects_v2')
response = paginator.paginate(
Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
)
operation_parameters = {
"Bucket": bucket_name,
"Prefix": prefix,
"Delimiter": delimiter,
"PaginationConfig": config,
}

if start_after_key:
operation_parameters["StartAfter"] = start_after_key

response = paginator.paginate(**operation_parameters)
keys = []
for page in response:
if 'Contents' in page:
for k in page['Contents']:
keys.append(k['Key'])
page_response = ResponseFilter(page)
keys.extend(page_response.filter(object_filter=object_filter))

return keys

Expand Down
10 changes: 10 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import gzip as gz
import os
import tempfile
from datetime import datetime
from unittest import mock
from unittest.mock import Mock

Expand Down Expand Up @@ -151,10 +152,19 @@ def test_list_keys(self, s3_bucket):
bucket.put_object(Key='a', Body=b'a')
bucket.put_object(Key='dir/b', Body=b'b')

LastModified__gte = datetime.strptime('1900-08-19T09:55:48+0000', '%Y-%m-%dT%H:%M:%S%z')
LastModified__lt = datetime.strptime('1901-08-19T09:55:48+0000', '%Y-%m-%dT%H:%M:%S%z')
object_filter = {
"LastModified__gte": LastModified__gte,
"LastModified__lt": LastModified__lt,
}

assert [] == hook.list_keys(s3_bucket, prefix='non-existent/')
assert ['a', 'dir/b'] == hook.list_keys(s3_bucket)
assert ['a'] == hook.list_keys(s3_bucket, delimiter='/')
assert ['dir/b'] == hook.list_keys(s3_bucket, prefix='dir/')
assert ['dir/b'] == hook.list_keys(s3_bucket, start_after_key='a')
assert [] == hook.list_keys(s3_bucket, object_filter=object_filter)

def test_list_keys_paged(self, s3_bucket):
hook = S3Hook()
Expand Down