Skip to content

Commit

Permalink
Use S3CopyObjectOperator in example_comprehend_document_classifier (
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent 0613867 commit ad95a82
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
5 changes: 5 additions & 0 deletions providers/src/airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,7 @@ def copy_object(
dest_bucket_name: str | None = None,
source_version_id: str | None = None,
acl_policy: str | None = None,
meta_data_directive: str | None = None,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -1326,10 +1327,14 @@ def copy_object(
:param source_version_id: Version ID of the source object (OPTIONAL)
:param acl_policy: The string to specify the canned ACL policy for the
object to be copied which is private by default.
:param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it
with metadata that's provided in the request.
"""
acl_policy = acl_policy or "private"
if acl_policy != NO_ACL:
kwargs["ACL"] = acl_policy
if meta_data_directive:
kwargs["MetadataDirective"] = meta_data_directive

dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
Expand Down
5 changes: 5 additions & 0 deletions providers/src/airflow/providers/amazon/aws/operators/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ class S3CopyObjectOperator(BaseOperator):
CA cert bundle than the one used by botocore.
:param acl_policy: String specifying the canned ACL policy for the file being
uploaded to the S3 bucket.
:param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it with
metadata that's provided in the request.
"""

template_fields: Sequence[str] = (
Expand All @@ -302,6 +304,7 @@ def __init__(
aws_conn_id: str | None = "aws_default",
verify: str | bool | None = None,
acl_policy: str | None = None,
meta_data_directive: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -314,6 +317,7 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.verify = verify
self.acl_policy = acl_policy
self.meta_data_directive = meta_data_directive

def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
Expand All @@ -324,6 +328,7 @@ def execute(self, context: Context):
self.dest_bucket_name,
self.source_version_id,
self.acl_policy,
self.meta_data_directive,
)

def get_openlineage_facets_on_start(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ComprehendCreateDocumentClassifierOperator,
)
from airflow.providers.amazon.aws.operators.s3 import (
S3CopyObjectOperator,
S3CreateBucketOperator,
S3CreateObjectOperator,
S3DeleteBucketOperator,
Expand Down Expand Up @@ -140,7 +141,14 @@ def copy_data_to_s3(bucket: str, sources: list[dict], prefix: str, number_of_cop
http_to_s3_configs = [
{
"endpoint": source["endpoint"],
"s3_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-{counter}{os.path.splitext(os.path.basename(source['fileName']))[1]}",
"s3_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
}
for source in sources
]
copy_to_s3_configs = [
{
"source_bucket_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
"dest_bucket_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-{counter}{os.path.splitext(os.path.basename(source['fileName']))[1]}",
}
for counter in range(number_of_copies)
for source in sources
Expand Down Expand Up @@ -170,7 +178,14 @@ def delete_connection(conn_id):
s3_bucket=bucket,
).expand_kwargs(http_to_s3_configs)

chain(create_connection(http_conn_id), http_to_s3_task, delete_connection(http_conn_id))
s3_copy_task = S3CopyObjectOperator.partial(
task_id="s3_copy_task",
source_bucket_name=bucket,
dest_bucket_name=bucket,
meta_data_directive="REPLACE",
).expand_kwargs(copy_to_s3_configs)

chain(create_connection(http_conn_id), http_to_s3_task, s3_copy_task, delete_connection(http_conn_id))


with DAG(
Expand Down

0 comments on commit ad95a82

Please sign in to comment.