From 11421dd07d1ca7014ddbcfebb3088141197c61a3 Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Wed, 16 Mar 2022 00:05:47 +0545
Subject: [PATCH] Add arguments to filter list: start_after_key, from_datetime,
 to_datetime, object_filter callable (#22231)

Implemented as discussed in [closed PR](https://github.com/apache/airflow/pull/19018).
Add more filter options to list_keys of S3Hook:

- `start_after_key`: should return only keys greater than this key
- `from_datetime`: should return only keys with LastModified attr greater than or equal to this `from_datetime`.
- `to_datetime`: should return only keys with LastModified attr less than this `to_datetime`.
- `object_filter`: Callable that receives the list of S3 objects, `from_datetime` and `to_datetime`, and returns the list of matched keys.

Add tests for the added arguments to `list_keys`.

closes: #16627

GitOrigin-RevId: 926f6d1894ce9d097ef2256d14a99968638da9c0
---
 airflow/providers/amazon/aws/hooks/s3.py    | 59 +++++++++++++++++++--
 tests/providers/amazon/aws/hooks/test_s3.py | 12 +++++
 2 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 976f8e9a4c2..fc23255a78d 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -23,6 +23,7 @@
 import io
 import re
 import shutil
+from datetime import datetime
 from functools import wraps
 from inspect import signature
 from io import BytesIO
@@ -255,6 +256,18 @@ def list_prefixes(
 
         return prefixes
 
+    def _list_key_object_filter(
+        self, keys: list, from_datetime: Optional[datetime] = None, to_datetime: Optional[datetime] = None
+    ) -> list:
+        def _is_in_period(input_date: datetime) -> bool:
+            if from_datetime is not None and input_date <= from_datetime:
+                return False
+            if to_datetime is not None and input_date > to_datetime:
+                return False
+            return True
+
+        return [k['Key'] for k in keys if _is_in_period(k['LastModified'])]
+
     @provide_bucket_name
     def list_keys(
         self,
@@ -263,6 +276,10 @@ def list_keys(
         delimiter: Optional[str] = None,
         page_size: Optional[int] = None,
         max_items: Optional[int] = None,
+        start_after_key: Optional[str] = None,
+        from_datetime: Optional[datetime] = None,
+        to_datetime: Optional[datetime] = None,
+        object_filter: Optional[Callable[..., list]] = None,
     ) -> list:
         """
         Lists keys in a bucket under prefix and not containing delimiter
@@ -272,11 +289,40 @@ def list_keys(
         :param delimiter: the delimiter marks key hierarchy.
         :param page_size: pagination size
         :param max_items: maximum items to return
+        :param start_after_key: should return only keys greater than this key
+        :param from_datetime: should return only keys with LastModified attr greater than or equal to
+            this from_datetime
+        :param to_datetime: should return only keys with LastModified attr less than this to_datetime
+        :param object_filter: Function that receives the list of the S3 objects, from_datetime and
+            to_datetime and returns the list of matched keys.
+
+        **Example**: Returns the list of S3 objects with LastModified attr greater than from_datetime
+            and less than to_datetime:
+
+        .. code-block:: python
+
+            def object_filter(
+                keys: list,
+                from_datetime: Optional[datetime] = None,
+                to_datetime: Optional[datetime] = None,
+            ) -> list:
+                def _is_in_period(input_date: datetime) -> bool:
+                    if from_datetime is not None and input_date < from_datetime:
+                        return False
+
+                    if to_datetime is not None and input_date > to_datetime:
+                        return False
+                    return True
+
+                return [k["Key"] for k in keys if _is_in_period(k["LastModified"])]
+
         :return: a list of matched keys
         :rtype: list
         """
         prefix = prefix or ''
         delimiter = delimiter or ''
+        start_after_key = start_after_key or ''
+        self.object_filter_usr = object_filter
         config = {
             'PageSize': page_size,
             'MaxItems': max_items,
@@ -284,16 +330,23 @@
 
         paginator = self.get_conn().get_paginator('list_objects_v2')
         response = paginator.paginate(
-            Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
+            Bucket=bucket_name,
+            Prefix=prefix,
+            Delimiter=delimiter,
+            PaginationConfig=config,
+            StartAfter=start_after_key,
         )
 
         keys = []
         for page in response:
             if 'Contents' in page:
                 for k in page['Contents']:
-                    keys.append(k['Key'])
+                    keys.append(k)
+
+        if self.object_filter_usr is not None:
+            return self.object_filter_usr(keys, from_datetime, to_datetime)
 
-        return keys
+        return self._list_key_object_filter(keys, from_datetime, to_datetime)
 
     @provide_bucket_name
     @unify_bucket_name_and_key
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 934993c0405..a9cbbe09f7c 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -28,6 +28,7 @@
 
 from airflow.models import Connection
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name, unify_bucket_name_and_key
+from airflow.utils.timezone import datetime
 
 try:
     from moto import mock_s3
@@ -151,10 +152,21 @@ def test_list_keys(self, s3_bucket):
         bucket.put_object(Key='a', Body=b'a')
         bucket.put_object(Key='dir/b', Body=b'b')
 
+        from_datetime = datetime(1992, 3, 8, 18, 52, 51)
+        to_datetime = datetime(1993, 3, 14, 21, 52, 42)
+
+        def dummy_object_filter(keys, from_datetime=None, to_datetime=None):
+            return []
+
         assert [] == hook.list_keys(s3_bucket, prefix='non-existent/')
         assert ['a', 'dir/b'] == hook.list_keys(s3_bucket)
         assert ['a'] == hook.list_keys(s3_bucket, delimiter='/')
         assert ['dir/b'] == hook.list_keys(s3_bucket, prefix='dir/')
+        assert ['dir/b'] == hook.list_keys(s3_bucket, start_after_key='a')
+        assert [] == hook.list_keys(s3_bucket, from_datetime=from_datetime, to_datetime=to_datetime)
+        assert [] == hook.list_keys(
+            s3_bucket, from_datetime=from_datetime, to_datetime=to_datetime, object_filter=dummy_object_filter
+        )
 
     def test_list_keys_paged(self, s3_bucket):
         hook = S3Hook()
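
A minimal usage sketch for the new `list_keys` filters. The connection id, bucket name, prefix, key, dates, and the `large_objects_only` helper below are hypothetical values chosen for illustration; the datetimes are timezone-aware because boto3 returns `LastModified` as an aware UTC datetime.

.. code-block:: python

    from datetime import datetime, timezone

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")  # hypothetical connection id

    # Built-in filtering: keys lexicographically after start_after_key whose
    # LastModified falls inside the from_datetime/to_datetime window.
    march_keys = hook.list_keys(
        bucket_name="my-example-bucket",  # hypothetical bucket
        prefix="data/",
        start_after_key="data/2022-03-01",
        from_datetime=datetime(2022, 3, 1, tzinfo=timezone.utc),
        to_datetime=datetime(2022, 4, 1, tzinfo=timezone.utc),
    )


    # A custom object_filter receives the raw S3 object dicts plus both datetimes
    # and must itself return the list of keys (here, objects larger than 1 MiB).
    def large_objects_only(keys, from_datetime=None, to_datetime=None):
        return [k["Key"] for k in keys if k["Size"] > 1024 * 1024]


    big_keys = hook.list_keys(bucket_name="my-example-bucket", object_filter=large_objects_only)

Both calls return plain key strings; when `object_filter` is supplied it replaces the built-in LastModified check entirely, so any datetime handling is up to the callable.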