Add arguments to filter list: start_after_key, from_datetime, to_datetime, object_filter callable (#22231)

Implemented as discussed in [closed PR](apache/airflow#19018).

Add more filter options to `list_keys` of S3Hook (a usage sketch follows the list):
- `start_after_key`: should return only keys greater than this key.
- `from_datetime`: should return only keys with a LastModified attr greater than this `from_datetime`.
- `to_datetime`: should return only keys with a LastModified attr less than or equal to this `to_datetime`.
- `object_filter`: a callable that receives the list of S3 objects, `from_datetime` and `to_datetime`, and returns the list of matched keys.
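
A usage sketch of the new arguments (the hook instance, connection id, bucket and key names are illustrative, not part of this commit; the bounds are timezone-aware because S3 LastModified values are):

    from datetime import datetime, timezone

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')

    # Keys lexicographically after 'logs/2022-01-01' whose LastModified
    # falls inside the window applied by the default filter.
    keys = hook.list_keys(
        bucket_name='my-bucket',
        prefix='logs/',
        start_after_key='logs/2022-01-01',
        from_datetime=datetime(2022, 1, 1, tzinfo=timezone.utc),
        to_datetime=datetime(2022, 2, 1, tzinfo=timezone.utc),
    )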

Add tests for the added arguments to `list_keys`.

closes: #16627
GitOrigin-RevId: 926f6d1894ce9d097ef2256d14a99968638da9c0
sunank200 authored and Cloud Composer Team committed Oct 7, 2022
1 parent d5c14e4 commit 11421dd
Showing 2 changed files with 68 additions and 3 deletions.
59 changes: 56 additions & 3 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -23,6 +23,7 @@
 import io
 import re
 import shutil
+from datetime import datetime
 from functools import wraps
 from inspect import signature
 from io import BytesIO
@@ -255,6 +256,18 @@ def list_prefixes(

         return prefixes

+    def _list_key_object_filter(
+        self, keys: list, from_datetime: Optional[datetime] = None, to_datetime: Optional[datetime] = None
+    ) -> list:
+        def _is_in_period(input_date: datetime) -> bool:
+            if from_datetime is not None and input_date <= from_datetime:
+                return False
+            if to_datetime is not None and input_date > to_datetime:
+                return False
+            return True
+
+        return [k['Key'] for k in keys if _is_in_period(k['LastModified'])]
+
     @provide_bucket_name
     def list_keys(
         self,
@@ -263,6 +276,10 @@ def list_keys(
         delimiter: Optional[str] = None,
         page_size: Optional[int] = None,
         max_items: Optional[int] = None,
+        start_after_key: Optional[str] = None,
+        from_datetime: Optional[datetime] = None,
+        to_datetime: Optional[datetime] = None,
+        object_filter: Optional[Callable[..., list]] = None,
     ) -> list:
         """
         Lists keys in a bucket under prefix and not containing delimiter
@@ -272,28 +289,64 @@ def list_keys(
         :param delimiter: the delimiter marks key hierarchy.
         :param page_size: pagination size
         :param max_items: maximum items to return
+        :param start_after_key: should return only keys greater than this key
+        :param from_datetime: should return only keys with a LastModified attr greater than this
+            from_datetime
+        :param to_datetime: should return only keys with a LastModified attr less than or equal to
+            this to_datetime
+        :param object_filter: Function that receives the list of S3 objects, from_datetime and
+            to_datetime, and returns the list of matched keys.
+
+            **Example**: Returns the list of S3 objects with a LastModified attr between
+            from_datetime and to_datetime (both inclusive):
+
+            .. code-block:: python
+
+                def object_filter(
+                    keys: list,
+                    from_datetime: Optional[datetime] = None,
+                    to_datetime: Optional[datetime] = None,
+                ) -> list:
+                    def _is_in_period(input_date: datetime) -> bool:
+                        if from_datetime is not None and input_date < from_datetime:
+                            return False
+                        if to_datetime is not None and input_date > to_datetime:
+                            return False
+                        return True
+
+                    return [k["Key"] for k in keys if _is_in_period(k["LastModified"])]
+
         :return: a list of matched keys
         :rtype: list
         """
         prefix = prefix or ''
         delimiter = delimiter or ''
+        start_after_key = start_after_key or ''
+        self.object_filter_usr = object_filter
         config = {
             'PageSize': page_size,
             'MaxItems': max_items,
         }

         paginator = self.get_conn().get_paginator('list_objects_v2')
         response = paginator.paginate(
-            Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
+            Bucket=bucket_name,
+            Prefix=prefix,
+            Delimiter=delimiter,
+            PaginationConfig=config,
+            StartAfter=start_after_key,
         )

         keys = []
         for page in response:
             if 'Contents' in page:
                 for k in page['Contents']:
-                    keys.append(k['Key'])
+                    keys.append(k)

-        return keys
+        if self.object_filter_usr is not None:
+            return self.object_filter_usr(keys, from_datetime, to_datetime)
+
+        return self._list_key_object_filter(keys, from_datetime, to_datetime)

     @provide_bucket_name
     @unify_bucket_name_and_key
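For context, a minimal custom `object_filter` compatible with the contract above might look like this (a hypothetical sketch, not part of this commit; each element of `keys` is the raw S3 object dict collected from a page's `Contents`):

    def csv_only_filter(keys, from_datetime=None, to_datetime=None):
        # Ignore the datetime bounds and keep only keys ending in '.csv'.
        return [k['Key'] for k in keys if k['Key'].endswith('.csv')]

    csv_keys = hook.list_keys(bucket_name='my-bucket', object_filter=csv_only_filter)

When `object_filter` is passed, it fully replaces the default `_list_key_object_filter`, so the datetime bounds apply only if the callable chooses to honor them.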
12 changes: 12 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -28,6 +28,7 @@

 from airflow.models import Connection
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name, unify_bucket_name_and_key
+from airflow.utils.timezone import datetime

 try:
     from moto import mock_s3
@@ -151,10 +152,21 @@ def test_list_keys(self, s3_bucket):
         bucket.put_object(Key='a', Body=b'a')
         bucket.put_object(Key='dir/b', Body=b'b')

+        from_datetime = datetime(1992, 3, 8, 18, 52, 51)
+        to_datetime = datetime(1993, 3, 14, 21, 52, 42)
+
+        def dummy_object_filter(keys, from_datetime=None, to_datetime=None):
+            return []
+
         assert [] == hook.list_keys(s3_bucket, prefix='non-existent/')
         assert ['a', 'dir/b'] == hook.list_keys(s3_bucket)
         assert ['a'] == hook.list_keys(s3_bucket, delimiter='/')
         assert ['dir/b'] == hook.list_keys(s3_bucket, prefix='dir/')
+        assert ['dir/b'] == hook.list_keys(s3_bucket, start_after_key='a')
+        assert [] == hook.list_keys(s3_bucket, from_datetime=from_datetime, to_datetime=to_datetime)
+        assert [] == hook.list_keys(
+            s3_bucket, from_datetime=from_datetime, to_datetime=to_datetime, object_filter=dummy_object_filter
+        )

     def test_list_keys_paged(self, s3_bucket):
         hook = S3Hook()
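A positive-case companion test could look like the following sketch (hypothetical, not part of this commit). `airflow.utils.timezone.datetime` returns timezone-aware UTC datetimes, so they compare cleanly against the timezone-aware LastModified values moto reports:

    def test_list_keys_in_datetime_window(self, s3_bucket):
        hook = S3Hook()
        bucket = hook.get_bucket(s3_bucket)
        bucket.put_object(Key='a', Body=b'a')

        # The object was just written, so it lies strictly inside this window.
        assert ['a'] == hook.list_keys(
            s3_bucket,
            from_datetime=datetime(2000, 1, 1),
            to_datetime=datetime(2100, 1, 1),
        )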
