Commit

Use tmp_path in amazon s3 test (#33705)
eumiro authored Aug 24, 2023
1 parent e8ba579 commit 9b8a093
Showing 1 changed file with 98 additions and 116 deletions.
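
The change applies one pattern throughout: every test that previously created scratch files by hand with the tempfile module now accepts pytest's built-in tmp_path fixture, which provides a fresh pathlib.Path directory per test and cleans it up afterwards. A minimal sketch of the before/after shape, assuming a stand-in upload() helper (hypothetical, not part of this diff) in place of the real S3Hook calls:

    import tempfile

    def upload(fileobj) -> bytes:
        # Hypothetical stand-in for hook.load_file_obj(...) in the real tests.
        return fileobj.read()

    def test_upload_old_style():
        # Before: manual tempfile bookkeeping (write, seek(0), cleanup).
        with tempfile.TemporaryFile() as temp_file:
            temp_file.write(b"Content")
            temp_file.seek(0)
            assert upload(temp_file) == b"Content"

    def test_upload_new_style(tmp_path):
        # After: pytest injects tmp_path, a per-test pathlib.Path directory.
        path = tmp_path / "testfile"
        path.write_text("Content")
        assert upload(path.open("rb")) == b"Content"

Dropping the with/seek/unlink boilerplate also dedents each test body one level, which accounts for much of the diff's churn.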
tests/providers/amazon/aws/hooks/test_s3.py (214 changes: 98 additions & 116 deletions)
@@ -21,9 +21,7 @@
 import inspect
 import os
 import re
-import tempfile
 import unittest
-from pathlib import Path
 from unittest import mock, mock as async_mock
 from unittest.mock import MagicMock, Mock, patch

@@ -827,63 +825,52 @@ def test_load_bytes_acl(self, s3_bucket):
             response["Grants"][0]["Permission"] == "FULL_CONTROL"
         )

-    def test_load_fileobj(self, s3_bucket):
+    def test_load_fileobj(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.TemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file_obj(temp_file, "my_key", s3_bucket)
-            resource = boto3.resource("s3").Object(s3_bucket, "my_key")
-            assert resource.get()["Body"].read() == b"Content"
-
-    def test_load_fileobj_acl(self, s3_bucket):
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file_obj(path.open("rb"), "my_key", s3_bucket)
+        resource = boto3.resource("s3").Object(s3_bucket, "my_key")
+        assert resource.get()["Body"].read() == b"Content"
+
+    def test_load_fileobj_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.TemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy="public-read")
-            response = boto3.client("s3").get_object_acl(
-                Bucket=s3_bucket, Key="my_key", RequestPayer="requester"
-            )
-            assert (response["Grants"][1]["Permission"] == "READ") and (
-                response["Grants"][0]["Permission"] == "FULL_CONTROL"
-            )
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file_obj(path.open("rb"), "my_key", s3_bucket, acl_policy="public-read")
+        response = boto3.client("s3").get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer="requester")
+        assert (response["Grants"][1]["Permission"] == "READ") and (
+            response["Grants"][0]["Permission"] == "FULL_CONTROL"
+        )

-    def test_load_file_gzip(self, s3_bucket):
+    def test_load_file_gzip(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True)
-            resource = boto3.resource("s3").Object(s3_bucket, "my_key")
-            assert gz.decompress(resource.get()["Body"].read()) == b"Content"
-            os.unlink(temp_file.name)
-
-    def test_load_file_acl(self, s3_bucket):
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file(path, "my_key", s3_bucket, gzip=True)
+        resource = boto3.resource("s3").Object(s3_bucket, "my_key")
+        assert gz.decompress(resource.get()["Body"].read()) == b"Content"
+
+    def test_load_file_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True, acl_policy="public-read")
-            response = boto3.client("s3").get_object_acl(
-                Bucket=s3_bucket, Key="my_key", RequestPayer="requester"
-            )
-            assert (response["Grants"][1]["Permission"] == "READ") and (
-                response["Grants"][0]["Permission"] == "FULL_CONTROL"
-            )
-            os.unlink(temp_file.name)
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file(path, "my_key", s3_bucket, gzip=True, acl_policy="public-read")
+        response = boto3.client("s3").get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer="requester")
+        assert (response["Grants"][1]["Permission"] == "READ") and (
+            response["Grants"][0]["Permission"] == "FULL_CONTROL"
+        )

-    def test_copy_object_acl(self, s3_bucket):
+    def test_copy_object_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.NamedTemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file_obj(temp_file, "my_key", s3_bucket)
-            hook.copy_object("my_key", "my_key2", s3_bucket, s3_bucket)
-            response = boto3.client("s3").get_object_acl(
-                Bucket=s3_bucket, Key="my_key2", RequestPayer="requester"
-            )
-            assert (response["Grants"][0]["Permission"] == "FULL_CONTROL") and (len(response["Grants"]) == 1)
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file_obj(path.open("rb"), "my_key", s3_bucket)
+        hook.copy_object("my_key", "my_key2", s3_bucket, s3_bucket)
+        response = boto3.client("s3").get_object_acl(
+            Bucket=s3_bucket, Key="my_key2", RequestPayer="requester"
+        )
+        assert (response["Grants"][0]["Permission"] == "FULL_CONTROL") and (len(response["Grants"]) == 1)

     @mock_s3
     def test_delete_bucket_if_bucket_exist(self, s3_bucket):
@@ -974,34 +961,33 @@ def test_function_with_test_key(self, test_key, bucket_name=None):
         assert isinstance(ctx.value, ValueError)

     @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
-    def test_download_file(self, mock_temp_file):
-        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as temp_file:
-            mock_temp_file.return_value = temp_file
-            s3_hook = S3Hook(aws_conn_id="s3_test")
-            s3_hook.check_for_key = Mock(return_value=True)
-            s3_obj = Mock()
-            s3_obj.download_fileobj = Mock(return_value=None)
-            s3_hook.get_key = Mock(return_value=s3_obj)
-            key = "test_key"
-            bucket = "test_bucket"
-
-            output_file = s3_hook.download_file(key=key, bucket_name=bucket)
-
-            s3_hook.get_key.assert_called_once_with(key, bucket)
-            s3_obj.download_fileobj.assert_called_once_with(
-                temp_file,
-                Config=s3_hook.transfer_config,
-                ExtraArgs=s3_hook.extra_args,
-            )
-
-            assert temp_file.name == output_file
+    def test_download_file(self, mock_temp_file, tmp_path):
+        path = tmp_path / "airflow_tmp_test_s3_hook"
+        mock_temp_file.return_value = path
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        key = "test_key"
+        bucket = "test_bucket"
+
+        output_file = s3_hook.download_file(key=key, bucket_name=bucket)
+
+        s3_hook.get_key.assert_called_once_with(key, bucket)
+        s3_obj.download_fileobj.assert_called_once_with(
+            path,
+            Config=s3_hook.transfer_config,
+            ExtraArgs=s3_hook.extra_args,
+        )
+
+        assert path.name == output_file

     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
-    def test_download_file_with_preserve_name(self, mock_open):
-        file_name = "test.log"
+    def test_download_file_with_preserve_name(self, mock_open, tmp_path):
+        path = tmp_path / "test.log"
         bucket = "test_bucket"
-        key = f"test_key/{file_name}"
-        local_folder = "/tmp"
+        key = f"test_key/{path.name}"

         s3_hook = S3Hook(aws_conn_id="s3_test")
         s3_hook.check_for_key = Mock(return_value=True)
@@ -1012,19 +998,18 @@ def test_download_file_with_preserve_name(self, mock_open):
         s3_hook.download_file(
             key=key,
             bucket_name=bucket,
-            local_path=local_folder,
+            local_path=os.fspath(path.parent),
             preserve_file_name=True,
             use_autogenerated_subdir=False,
         )

-        mock_open.assert_called_once_with(Path(local_folder, file_name), "wb")
+        mock_open.assert_called_once_with(path, "wb")

     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
-    def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open):
-        file_name = "test.log"
+    def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open, tmp_path):
+        path = tmp_path / "test.log"
         bucket = "test_bucket"
-        key = f"test_key/{file_name}"
-        local_folder = "/tmp"
+        key = f"test_key/{path.name}"

         s3_hook = S3Hook(aws_conn_id="s3_test")
         s3_hook.check_for_key = Mock(return_value=True)
@@ -1035,33 +1020,32 @@ def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open):
         result_file = s3_hook.download_file(
             key=key,
             bucket_name=bucket,
-            local_path=local_folder,
+            local_path=os.fspath(path.parent),
             preserve_file_name=True,
             use_autogenerated_subdir=True,
         )

         assert result_file.rsplit("/", 1)[-2].startswith("airflow_tmp_dir_")

-    def test_download_file_with_preserve_name_file_already_exists(self):
-        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as file:
-            file_name = file.name.rsplit("/", 1)[-1]
-            bucket = "test_bucket"
-            key = f"test_key/{file_name}"
-            local_folder = "/tmp"
-            s3_hook = S3Hook(aws_conn_id="s3_test")
-            s3_hook.check_for_key = Mock(return_value=True)
-            s3_obj = Mock()
-            s3_obj.key = f"s3://{bucket}/{key}"
-            s3_obj.download_fileobj = Mock(return_value=None)
-            s3_hook.get_key = Mock(return_value=s3_obj)
-            with pytest.raises(FileExistsError):
-                s3_hook.download_file(
-                    key=key,
-                    bucket_name=bucket,
-                    local_path=local_folder,
-                    preserve_file_name=True,
-                    use_autogenerated_subdir=False,
-                )
+    def test_download_file_with_preserve_name_file_already_exists(self, tmp_path):
+        path = tmp_path / "airflow_tmp_test_s3_hook"
+        path.write_text("")
+        bucket = "test_bucket"
+        key = f"test_key/{path.name}"
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.key = f"s3://{bucket}/{key}"
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        with pytest.raises(FileExistsError):
+            s3_hook.download_file(
+                key=key,
+                bucket_name=bucket,
+                local_path=os.fspath(path.parent),
+                preserve_file_name=True,
+                use_autogenerated_subdir=False,
+            )

     def test_generate_presigned_url(self, s3_bucket):
         hook = S3Hook()
@@ -1078,22 +1062,20 @@ def test_should_throw_error_if_extra_args_is_not_dict(self):
         with pytest.raises(TypeError, match="extra_args expected dict, got .*"):
             S3Hook(extra_args=1)

-    def test_should_throw_error_if_extra_args_contains_unknown_arg(self, s3_bucket):
+    def test_should_throw_error_if_extra_args_contains_unknown_arg(self, s3_bucket, tmp_path):
         hook = S3Hook(extra_args={"unknown_s3_args": "value"})
-        with tempfile.TemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            with pytest.raises(ValueError):
-                hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy="public-read")
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        with pytest.raises(ValueError):
+            hook.load_file_obj(path.open("rb"), "my_key", s3_bucket, acl_policy="public-read")

-    def test_should_pass_extra_args(self, s3_bucket):
+    def test_should_pass_extra_args(self, s3_bucket, tmp_path):
         hook = S3Hook(extra_args={"ContentLanguage": "value"})
-        with tempfile.TemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy="public-read")
-            resource = boto3.resource("s3").Object(s3_bucket, "my_key")
-            assert resource.get()["ContentLanguage"] == "value"
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file_obj(path.open("rb"), "my_key", s3_bucket, acl_policy="public-read")
+        resource = boto3.resource("s3").Object(s3_bucket, "my_key")
+        assert resource.get()["ContentLanguage"] == "value"

     def test_that_extra_args_not_changed_between_calls(self, s3_bucket):
         original = {
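For reference, tmp_path is a standard pytest fixture: every test receives its own empty directory under the session's base temporary directory, so reruns and parallel tests cannot collide on fixed names the way the old hard-coded /tmp paths could. A small self-contained illustration of the two idioms this diff leans on (the file name is arbitrary):

    import os

    def test_tmp_path_is_isolated(tmp_path):
        # tmp_path is unique to this test invocation.
        path = tmp_path / "airflow_tmp_test_s3_hook"
        path.write_text("")  # create an empty file, as the diff does
        assert path.exists()
        # os.fspath() turns the Path into a str for APIs that expect string
        # paths, which is why the diff passes os.fspath(path.parent) as
        # local_path instead of the Path object itself.
        assert isinstance(os.fspath(path.parent), str)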
