[COST-3082] Add bucket_region as an additional parameter for AWS sources (#4183)

* Add region_name to relevant calls
* Add test cases for AWS bucket region
* Set default region in method signature
  For consistency, it's better to be explicit. The S3 API defaults to
  us-east-1 if region_name is not supplied, so this is not a change in behavior.

* Only pass in region_name if it exists in the data source
  Rely on the default value in the function definition rather than specifying a
  default value in the calling code (a sketch of this pattern follows below).
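
The pattern the last two bullets describe — forward region_name only when the data source provides one, and let the callee's signature own the default — can be sketched in isolation as follows. The function and variable names below are illustrative, not the project's own:

def check_bucket(bucket, region_name="us-east-1"):
    # The callee owns the default; callers never restate it.
    return f"checking {bucket} in {region_name}"

def reachability_check(data_source):
    kwargs = {}
    if region := data_source.get("bucket_region"):
        # Forward the region only when the source actually specifies one.
        kwargs["region_name"] = region
    return check_bucket(data_source["bucket"], **kwargs)

print(reachability_check({"bucket": "b"}))                                # checking b in us-east-1
print(reachability_check({"bucket": "b", "bucket_region": "eu-west-1"}))  # checking b in eu-west-1
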
samdoran authored Mar 16, 2023
1 parent 4924314 commit 5063a5d
Showing 5 changed files with 88 additions and 11 deletions.
13 changes: 10 additions & 3 deletions dev/scripts/create_sources.py
@@ -33,6 +33,7 @@ def create_parser():
     parser.add_argument(
         "--s3_bucket", dest="s3_bucket", required=False, help="AWS S3 bucket with cost and usage report"
     )
+    parser.add_argument("--s3_region", required=False, help="AWS S3 region")
     parser.add_argument("--resource_group", dest="resource_group", required=False, help="AZURE Storage Resource Group")
     parser.add_argument("--storage_account", dest="storage_account", required=False, help="AZURE Storage Account")
     parser.add_argument("--scope", dest="scope", required=False, help="AZURE Cost Export Scope")
@@ -70,8 +71,13 @@ def __init__(self, auth_header):
         header = {"x-rh-identity": auth_header}
         self._identity_header = header

-    def create_s3_bucket(self, parameters, billing_source):
-        json_data = {"billing_source": {"bucket": billing_source}}
+    def create_s3_bucket(self, parameters, billing_source, s3_region=None):
+        json_data = {
+            "billing_source": {
+                "bucket": billing_source,
+                "bucket_region": s3_region,
+            }
+        }

         url = "{}/{}/".format(self._base_url, parameters.get("source_id"))
         response = requests.patch(url, headers=self._identity_header, json=json_data)
@@ -220,11 +226,12 @@ def main(args): # noqa
     if parameters.get("aws"):
         role_arn = parameters.get("role_arn")
         s3_bucket = parameters.get("s3_bucket")
+        s3_region = parameters.get("s3_region")
         source_id_param = parameters.get("source_id")

         if s3_bucket and source_id_param:
             sources_client = SourcesClientDataGenerator(identity_header)
-            billing_source_response = sources_client.create_s3_bucket(parameters, s3_bucket)
+            billing_source_response = sources_client.create_s3_bucket(parameters, s3_bucket, s3_region)
             print(f"Associating S3 bucket: {billing_source_response.content}")
             return

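One consequence of the payload construction in create_s3_bucket above is that bucket_region is always present in the PATCH body, so it is sent as null when --s3_region is omitted. A minimal, self-contained sketch of the payload shape (bucket name and region are made up):

def build_billing_source_payload(billing_source, s3_region=None):
    # Mirrors the shape of json_data in create_s3_bucket; illustrative only.
    return {"billing_source": {"bucket": billing_source, "bucket_region": s3_region}}

print(build_billing_source_payload("my-cur-bucket", "eu-central-1"))
# {'billing_source': {'bucket': 'my-cur-bucket', 'bucket_region': 'eu-central-1'}}
print(build_billing_source_payload("my-cur-bucket"))
# {'billing_source': {'bucket': 'my-cur-bucket', 'bucket_region': None}} -> serialized as null in the request body
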
16 changes: 10 additions & 6 deletions koku/providers/aws/provider.py
@@ -50,10 +50,10 @@ def _get_sts_access(role_arn):
     )


-def _check_s3_access(bucket, credentials):
+def _check_s3_access(bucket, credentials, region_name="us-east-1"):
     """Check for access to s3 bucket."""
     s3_exists = True
-    s3_resource = boto3.resource("s3", **credentials)
+    s3_resource = boto3.resource("s3", region_name=region_name, **credentials)
     try:
         s3_resource.meta.client.head_bucket(Bucket=bucket)
     except (ClientError, BotoConnectionError) as boto_error:
@@ -63,9 +63,9 @@ def _check_s3_access(bucket, credentials):
     return s3_exists


-def _check_cost_report_access(credential_name, credentials, region="us-east-1", bucket=None):
+def _check_cost_report_access(credential_name, credentials, region_name="us-east-1", bucket=None):
     """Check for provider cost and usage report access."""
-    cur_client = boto3.client("cur", region_name=region, **credentials)
+    cur_client = boto3.client("cur", region_name=region_name, **credentials)
     reports = None

     try:
@@ -136,13 +136,17 @@ def cost_usage_source_is_reachable(self, credentials, data_source):
             internal_message = f"Unable to access account resources with ARN {credential_name}."
             raise serializers.ValidationError(error_obj(key, internal_message))

-        s3_exists = _check_s3_access(storage_resource_name, creds)
+        region_kwargs = {}
+        if region_name := data_source.get("bucket_region"):
+            region_kwargs["region_name"] = region_name
+
+        s3_exists = _check_s3_access(storage_resource_name, creds, **region_kwargs)
         if not s3_exists:
             key = ProviderErrors.AWS_BILLING_SOURCE_NOT_FOUND
             internal_message = f"Bucket {storage_resource_name} could not be found with {credential_name}."
             raise serializers.ValidationError(error_obj(key, internal_message))

-        _check_cost_report_access(credential_name, creds, bucket=storage_resource_name)
+        _check_cost_report_access(credential_name, creds, bucket=storage_resource_name, **region_kwargs)

         return True

61 changes: 61 additions & 0 deletions koku/providers/test/aws/test_aws_provider.py
@@ -126,6 +126,26 @@ def test_check_s3_access_fail(self, mock_boto3_resource):
         s3_exists = _check_s3_access("bucket", {})
         self.assertFalse(s3_exists)

+    @patch("providers.aws.provider.boto3.resource", side_effect=AttributeError("Raised intentionally"))
+    def test_check_s3_access_default_region(self, mock_boto3_resource):
+        """Test that the default region value is used."""
+        expected_region_name = "us-east-1"
+        with self.assertRaisesRegex(AttributeError, "Raised intentionally"):
+            _check_s3_access("bucket", {})
+
+        self.assertIn("region_name", mock_boto3_resource.call_args.kwargs)
+        self.assertEqual(expected_region_name, mock_boto3_resource.call_args.kwargs.get("region_name"))
+
+    @patch("providers.aws.provider.boto3.resource", side_effect=AttributeError("Raised intentionally"))
+    def test_check_s3_access_with_region(self, mock_boto3_resource):
+        """Test that the provided region value is used."""
+        expected_region_name = "eu-south-2"
+        with self.assertRaisesRegex(AttributeError, "Raised intentionally"):
+            _check_s3_access("bucket", {}, region_name=expected_region_name)
+
+        self.assertIn("region_name", mock_boto3_resource.call_args.kwargs)
+        self.assertEqual(expected_region_name, mock_boto3_resource.call_args.kwargs.get("region_name"))
+
     @patch("providers.aws.provider.boto3.client")
     def test_check_cost_report_access(self, mock_boto3_client):
         """Test _check_cost_report_access success."""
@@ -266,6 +286,47 @@ def test_cost_usage_source_is_reachable(
         except Exception:
             self.fail("Unexpected Error")

+    @patch(
+        "providers.aws.provider._get_sts_access",
+        return_value=dict(
+            aws_access_key_id=FAKE.md5(), aws_secret_access_key=FAKE.md5(), aws_session_token=FAKE.md5()
+        ),
+    )
+    @patch("providers.aws.provider._check_s3_access", return_value=True)
+    @patch("providers.aws.provider._check_cost_report_access")
+    def test_cost_usage_source_is_reachable_with_region(
+        self, mock_check_cost_report_access, mock_check_s3_access, mock_get_sts_access
+    ):
+        """Verify that the bucket region is used when available."""
+        provider_interface = AWSProvider()
+        credentials = {"role_arn": "arn:aws:s3:::my_s3_bucket"}
+        data_source = {"bucket": "bucket_name", "bucket_region": "me-south-1"}
+        provider_interface.cost_usage_source_is_reachable(credentials, data_source)
+
+        self.assertIn("region_name", mock_check_s3_access.call_args.kwargs)
+        self.assertIn("region_name", mock_check_cost_report_access.call_args.kwargs)
+
+    @patch(
+        "providers.aws.provider._get_sts_access",
+        return_value=dict(
+            aws_access_key_id=FAKE.md5(), aws_secret_access_key=FAKE.md5(), aws_session_token=FAKE.md5()
+        ),
+    )
+    @patch("providers.aws.provider._check_s3_access", return_value=True)
+    @patch("providers.aws.provider._check_cost_report_access")
+    def test_cost_usage_source_is_call_no_region(
+        self, mock_check_cost_report_access, mock_check_s3_access, mock_get_sts_access
+    ):
+        """Verify that the bucket region is not passed in when it is not available in the data source,
+        so that the default value in the function definition is used."""
+        provider_interface = AWSProvider()
+        credentials = {"role_arn": "arn:aws:s3:::my_s3_bucket"}
+        data_source = {"bucket": "bucket_name"}
+        provider_interface.cost_usage_source_is_reachable(credentials, data_source)
+
+        self.assertNotIn("region_name", mock_check_s3_access.call_args.kwargs)
+        self.assertNotIn("region_name", mock_check_cost_report_access.call_args.kwargs)
+
     def test_cost_usage_source_is_reachable_no_arn(self):
        """Verify that the cost usage source is authenticated and created."""
        provider_interface = AWSProvider()
4 changes: 2 additions & 2 deletions koku/sources/sources_http_client.py
@@ -43,8 +43,8 @@
 ENDPOINT_SOURCE_TYPES = "source_types"
 APP_OPT_EXTRA_FEILD_MAP = {
     Provider.PROVIDER_OCP: [],
-    Provider.PROVIDER_AWS: ["storage_only"],
-    Provider.PROVIDER_AWS_LOCAL: ["storage_only"],
+    Provider.PROVIDER_AWS: ["storage_only", "bucket_region"],
+    Provider.PROVIDER_AWS_LOCAL: ["storage_only", "bucket_region"],
     Provider.PROVIDER_AZURE: ["scope", "export_name", "storage_only"],
     Provider.PROVIDER_AZURE_LOCAL: ["scope", "export_name", "storage_only"],
     Provider.PROVIDER_GCP: ["dataset", "bucket", "storage_only"],
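
A rough sketch of how a map like this can drive extraction of the optional fields from a source's "extra" block — an illustration of the idea only, not the project's actual get_data_source implementation (the map and names below are trimmed-down stand-ins):

EXTRA_FIELD_MAP = {"AWS": ["storage_only", "bucket_region"]}

def extract_optional_fields(source_type, extra):
    # Copy only the optional fields registered for the given source type.
    return {name: extra[name] for name in EXTRA_FIELD_MAP.get(source_type, []) if name in extra}

print(extract_optional_fields("AWS", {"bucket": "my-bucket", "bucket_region": "me-south-1"}))
# {'bucket_region': 'me-south-1'} -- storage_only is absent, so it is skipped
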
5 changes: 5 additions & 0 deletions koku/sources/test/test_sources_http_client.py
@@ -242,6 +242,11 @@ def test_get_data_source(self):
                 "json": {"extra": {"bucket": bucket, "storage_only": True}},
                 "expected": {"bucket": bucket, "storage_only": True},
             },
+            {
+                "source-type": Provider.PROVIDER_AWS,
+                "json": {"extra": {"bucket": bucket, "bucket_region": "me-south-1"}},
+                "expected": {"bucket": bucket, "bucket_region": "me-south-1"},
+            },
             {
                 "source-type": Provider.PROVIDER_AZURE,
                 "json": {