diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py
index 8a062570..6536a37c 100644
--- a/pyathena/filesystem/s3.py
+++ b/pyathena/filesystem/s3.py
@@ -173,6 +173,7 @@ def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
                 bucket=bucket,
                 key=None,
                 version_id=None,
+                delimiter=None,
             )
             self.dircache[bucket] = file
         else:
@@ -206,6 +207,7 @@ def _head_object(
                 bucket=bucket,
                 key=key,
                 version_id=version_id,
+                delimiter=None,
             )
             self.dircache[path] = file
         else:
@@ -230,6 +232,7 @@ def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
                     bucket=b["Name"],
                     key=None,
                     version_id=None,
+                    delimiter=None,
                 )
                 for b in response["Buckets"]
             ]
@@ -250,55 +253,60 @@ def _ls_dirs(
         bucket, key, version_id = self.parse_path(path)
         if key:
             prefix = f"{key}/{prefix if prefix else ''}"
-        if path not in self.dircache or refresh:
-            files: List[S3Object] = []
-            while True:
-                request: Dict[Any, Any] = {
-                    "Bucket": bucket,
-                    "Prefix": prefix,
-                    "Delimiter": delimiter,
-                }
-                if next_token:
-                    request.update({"ContinuationToken": next_token})
-                if max_keys:
-                    request.update({"MaxKeys": max_keys})
-                response = self._call(
-                    self._client.list_objects_v2,
-                    **request,
-                )
-                files.extend(
-                    S3Object(
-                        init={
-                            "ContentLength": 0,
-                            "ContentType": None,
-                            "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
-                            "ETag": None,
-                            "LastModified": None,
-                        },
-                        type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
-                        bucket=bucket,
-                        key=c["Prefix"][:-1].rstrip("/"),
-                        version_id=version_id,
-                    )
-                    for c in response.get("CommonPrefixes", [])
+
+        if path in self.dircache and not refresh:
+            cache = self.dircache[path]
+            caches = cache if isinstance(cache, list) else [cache]
+            if all(f.delimiter == delimiter for f in caches):
+                return caches
+
+        files: List[S3Object] = []
+        while True:
+            request: Dict[Any, Any] = {
+                "Bucket": bucket,
+                "Prefix": prefix,
+                "Delimiter": delimiter,
+            }
+            if next_token:
+                request.update({"ContinuationToken": next_token})
+            if max_keys:
+                request.update({"MaxKeys": max_keys})
+            response = self._call(
+                self._client.list_objects_v2,
+                **request,
+            )
+            files.extend(
+                S3Object(
+                    init={
+                        "ContentLength": 0,
+                        "ContentType": None,
+                        "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
+                        "ETag": None,
+                        "LastModified": None,
+                    },
+                    type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
+                    bucket=bucket,
+                    key=c["Prefix"][:-1].rstrip("/"),
+                    version_id=version_id,
+                    delimiter=delimiter,
                 )
-                files.extend(
-                    S3Object(
-                        init=c,
-                        type=S3ObjectType.S3_OBJECT_TYPE_FILE,
-                        bucket=bucket,
-                        key=c["Key"],
-                    )
-                    for c in response.get("Contents", [])
+                for c in response.get("CommonPrefixes", [])
+            )
+            files.extend(
+                S3Object(
+                    init=c,
+                    type=S3ObjectType.S3_OBJECT_TYPE_FILE,
+                    bucket=bucket,
+                    key=c["Key"],
+                    delimiter=delimiter,
                 )
-                next_token = response.get("NextContinuationToken")
-                if not next_token:
-                    break
-            if files:
-                self.dircache[path] = files
-        else:
-            cache = self.dircache[path]
-            files = cache if isinstance(cache, list) else [cache]
+                for c in response.get("Contents", [])
+            )
+            next_token = response.get("NextContinuationToken")
+            if not next_token:
+                break
+        if files:
+            self.dircache[path] = files
         return files
 
     def ls(
@@ -333,6 +341,7 @@ def info(self, path: str, **kwargs) -> S3Object:
                 bucket=bucket,
                 key=None,
                 version_id=None,
+                delimiter=None,
             )
         if not refresh:
             caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path)
@@ -358,6 +367,7 @@ def info(self, path: str, **kwargs) -> S3Object:
                         bucket=bucket,
                         key=key.rstrip("/") if key else None,
                         version_id=version_id,
+                        delimiter=None,
                     )
         if key:
             object_info = self._head_object(path, refresh=refresh, version_id=version_id)
@@ -393,23 +403,29 @@ def info(self, path: str, **kwargs) -> S3Object:
             bucket=bucket,
             key=key.rstrip("/") if key else None,
             version_id=version_id,
+            delimiter=None,
         )
         raise FileNotFoundError(path)
 
-    def find(
+    def _find(
         self,
         path: str,
         maxdepth: Optional[int] = None,
         withdirs: Optional[bool] = None,
-        detail: bool = False,
         **kwargs,
-    ) -> Union[Dict[str, S3Object], List[str]]:
-        # TODO: Support maxdepth and withdirs
+    ) -> List[S3Object]:
         path = self._strip_protocol(path)
         if path in ["", "/"]:
             raise ValueError("Cannot traverse all files in S3.")
         bucket, key, _ = self.parse_path(path)
         prefix = kwargs.pop("prefix", "")
+        if maxdepth:
+            return cast(
+                List[S3Object],
+                super()
+                .find(path=path, maxdepth=maxdepth, withdirs=withdirs, detail=True, **kwargs)
+                .values(),
+            )
 
         files = self._ls_dirs(path, prefix=prefix, delimiter="")
         if not files and key:
@@ -417,6 +433,18 @@ def find(
             try:
                 files = [self.info(path)]
             except FileNotFoundError:
                 files = []
+        return files
+
+    def find(
+        self,
+        path: str,
+        maxdepth: Optional[int] = None,
+        withdirs: Optional[bool] = None,
+        detail: bool = False,
+        **kwargs,
+    ) -> Union[Dict[str, S3Object], List[str]]:
+        # TODO: Support withdirs
+        files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
         if detail:
             return {f.name: f for f in files}
         return [f.name for f in files]
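
Note: a minimal sketch of the cache behavior this patch changes (the bucket/prefix names and the default-credential construction below are assumptions for illustration; S3FileSystem.ls and S3FileSystem.find are the methods touched above):

    from pyathena.filesystem.s3 import S3FileSystem

    fs = S3FileSystem()  # assumes default AWS credentials

    # ls() performs a shallow listing (Delimiter="/") and stores the
    # resulting S3Objects in dircache.
    fs.ls("s3://example-bucket/prefix")

    # find() needs a recursive listing (Delimiter=""). Previously it could be
    # served the cached shallow listing from ls(); with this patch each cached
    # S3Object records the delimiter it was listed with, and _ls_dirs()
    # re-queries S3 whenever the cached delimiter differs from the requested one.
    fs.find("s3://example-bucket/prefix")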