diff --git a/docker-compose.yml b/docker-compose.yml index 0c8fea87e0..6af0f3ae74 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -353,6 +353,7 @@ services: - SOURCES_KAFKA_HOST=${SOURCES_KAFKA_HOST-kafka} - SOURCES_KAFKA_PORT=${SOURCES_KAFKA_PORT-29092} - KOKU_SOURCES_CLIENT_PORT=${KOKU_SOURCES_CLIENT_PORT-9000} + - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} - prometheus_multiproc_dir=/tmp privileged: true ports: diff --git a/koku/providers/gcp/provider.py b/koku/providers/gcp/provider.py index 30cafa1263..d34b23f11d 100644 --- a/koku/providers/gcp/provider.py +++ b/koku/providers/gcp/provider.py @@ -8,6 +8,7 @@ from googleapiclient.errors import HttpError from rest_framework import serializers +from ..provider_errors import ProviderErrors from ..provider_errors import SkipStatusPush from ..provider_interface import ProviderInterface from api.common import error_obj @@ -70,9 +71,11 @@ def cost_usage_source_is_reachable(self, credentials, data_source): for required_permission in REQUIRED_IAM_PERMISSIONS: if required_permission not in permissions: - key = "authentication.project_id" - err_msg = f"Improper IAM permissions: {permissions}." - raise serializers.ValidationError(error_obj(key, err_msg)) + key = ProviderErrors.GCP_INCORRECT_IAM_PERMISSIONS + internal_message = f"Improper IAM permissions: {permissions}." + LOG.warning(internal_message) + message = f"Incorrect IAM permissions for project {project}" + raise serializers.ValidationError(error_obj(key, message)) except GoogleCloudError as e: key = "authentication.project_id" diff --git a/koku/providers/provider_errors.py b/koku/providers/provider_errors.py index 286925be35..62a68f6418 100644 --- a/koku/providers/provider_errors.py +++ b/koku/providers/provider_errors.py @@ -43,6 +43,8 @@ class ProviderErrors: AZURE_CREDENTAL_UNREACHABLE = "authentication.credentials.unreachable" AZURE_CLIENT_ERROR = "azure.exception" + GCP_INCORRECT_IAM_PERMISSIONS = "gcp.iam.permissions" + # MESSAGES INVALID_SOURCE_TYPE_MESSAGE = "The given source type is not supported." diff --git a/koku/sources/api/serializers.py b/koku/sources/api/serializers.py index dbc8a6c6c6..73d4d1819e 100644 --- a/koku/sources/api/serializers.py +++ b/koku/sources/api/serializers.py @@ -15,7 +15,6 @@ # along with this program. If not, see . # """Sources Model Serializers.""" -import copy import logging from socket import gaierror from uuid import uuid4 @@ -35,9 +34,12 @@ from providers.provider_errors import SkipStatusPush from sources.api import get_account_from_header from sources.api import get_auth_header +from sources.storage import _update_authentication +from sources.storage import _update_billing_source from sources.storage import get_source_instance from sources.storage import SourcesStorageError + LOG = logging.getLogger(__name__) ALLOWED_BILLING_SOURCE_PROVIDERS = ( @@ -90,54 +92,6 @@ def get_source_uuid(self, obj): """Get the source_uuid.""" return obj.source_uuid - def _validate_billing_source(self, provider_type, billing_source): # noqa: C901 - """Validate billing source parameters.""" - if provider_type == Provider.PROVIDER_AWS: - # TODO: Remove `and not billing_source.get("bucket")` if UI is updated to send "data_source" field - if not billing_source.get("data_source", {}).get("bucket") and not billing_source.get("bucket"): - raise SourcesStorageError("Missing AWS bucket.") - elif provider_type == Provider.PROVIDER_AZURE: - data_source = billing_source.get("data_source") - if not data_source: - raise SourcesStorageError("Missing AZURE data_source.") - if not data_source.get("resource_group"): - raise SourcesStorageError("Missing AZURE resource_group") - if not data_source.get("storage_account"): - raise SourcesStorageError("Missing AZURE storage_account") - elif provider_type == Provider.PROVIDER_GCP: - data_source = billing_source.get("data_source") - if not data_source: - raise SourcesStorageError("Missing GCP data_source.") - if not data_source.get("dataset"): - raise SourcesStorageError("Missing GCP dataset") - - def _update_billing_source(self, instance, billing_source): - if instance.source_type not in ALLOWED_BILLING_SOURCE_PROVIDERS: - raise SourcesStorageError(f"Option not supported by source type {instance.source_type}.") - if instance.billing_source.get("data_source"): - billing_copy = copy.deepcopy(instance.billing_source.get("data_source")) - data_source = billing_source.get("data_source", {}) - if data_source.get("resource_group") or data_source.get("storage_account"): - billing_copy.update(billing_source.get("data_source")) - billing_source["data_source"] = billing_copy - self._validate_billing_source(instance.source_type, billing_source) - # This if statement can also be removed if UI is updated to send "data_source" field - if instance.source_type in (Provider.PROVIDER_AWS, Provider.PROVIDER_AWS_LOCAL) and not billing_source.get( - "data_source" - ): - billing_source = {"data_source": billing_source} - return billing_source - - def _update_authentication(self, instance, authentication): - if instance.source_type not in ALLOWED_AUTHENTICATION_PROVIDERS: - raise SourcesStorageError(f"Option not supported by source type {instance.source_type}.") - auth_dict = instance.authentication - if not auth_dict.get("credentials"): - auth_dict["credentials"] = {"subscription_id": None} - subscription_id = authentication.get("credentials", {}).get("subscription_id") - auth_dict["credentials"]["subscription_id"] = subscription_id - return auth_dict - def update(self, instance, validated_data): """Update a Provider instance from validated data.""" billing_source = validated_data.get("billing_source") @@ -146,10 +100,10 @@ def update(self, instance, validated_data): try: with ServerProxy(SOURCES_CLIENT_BASE_URL) as sources_client: if billing_source: - billing_source = self._update_billing_source(instance, billing_source) + billing_source = _update_billing_source(instance, billing_source) sources_client.update_billing_source(instance.source_id, billing_source) if authentication: - authentication = self._update_authentication(instance, authentication) + authentication = _update_authentication(instance, authentication) sources_client.update_authentication(instance.source_id, authentication) except Fault as error: LOG.error(f"Sources update error: {error}") diff --git a/koku/sources/kafka_listener.py b/koku/sources/kafka_listener.py index 4626747197..6d76868a79 100644 --- a/koku/sources/kafka_listener.py +++ b/koku/sources/kafka_listener.py @@ -60,6 +60,7 @@ PROCESS_QUEUE = queue.PriorityQueue() COUNT = itertools.count() # next(COUNT) returns next sequential number KAFKA_APPLICATION_CREATE = "Application.create" +KAFKA_APPLICATION_UPDATE = "Application.update" KAFKA_APPLICATION_DESTROY = "Application.destroy" KAFKA_AUTHENTICATION_CREATE = "Authentication.create" KAFKA_AUTHENTICATION_UPDATE = "Authentication.update" @@ -106,6 +107,7 @@ def __init__(self, auth_header, source_id): self.source_uuid = details.get("uid") self.source_type_name = sources_network.get_source_type_name(self.source_type_id) self.source_type = SOURCE_PROVIDER_MAP.get(self.source_type_name) + self.app_settings = sources_network.get_application_settings(self.source_type) def _extract_from_header(headers, header_type): @@ -214,7 +216,7 @@ def get_sources_msg_data(msg, app_type_id): LOG.debug(f"msg value: {str(value)}") event_type = _extract_from_header(msg.headers(), KAFKA_HDR_EVENT_TYPE) LOG.debug(f"event_type: {str(event_type)}") - if event_type in (KAFKA_APPLICATION_CREATE, KAFKA_APPLICATION_DESTROY): + if event_type in (KAFKA_APPLICATION_CREATE, KAFKA_APPLICATION_UPDATE, KAFKA_APPLICATION_DESTROY): if int(value.get("application_type_id")) == app_type_id: LOG.debug("Application Message: %s", str(msg)) msg_data["event_type"] = event_type @@ -352,6 +354,13 @@ def sources_network_info(source_id, auth_header): storage.add_provider_sources_network_info(src_details, source_id) save_auth_info(auth_header, source_id) + app_settings = src_details.app_settings + if app_settings: + try: + storage.update_application_settings(source_id, app_settings) + except storage.SourcesStorageError as error: + LOG.error(f"Unable to apply application settings. error: {str(error)}") + return def cost_mgmt_msg_filter(msg_data): @@ -421,7 +430,7 @@ def process_message(app_type_id, msg): # noqa: C901 save_auth_info(msg_data.get("auth_header"), msg_data.get("source_id")) - elif msg_data.get("event_type") in (KAFKA_SOURCE_UPDATE,): + elif msg_data.get("event_type") in (KAFKA_SOURCE_UPDATE, KAFKA_APPLICATION_UPDATE): if storage.is_known_source(msg_data.get("source_id")) is False: LOG.info("Update event for unknown source id, skipping...") return diff --git a/koku/sources/sources_http_client.py b/koku/sources/sources_http_client.py index d82598019b..49e919d7c7 100644 --- a/koku/sources/sources_http_client.py +++ b/koku/sources/sources_http_client.py @@ -25,6 +25,8 @@ import requests from requests.exceptions import RequestException +from api.provider.models import Provider +from sources import storage from sources.config import Config from sources.sources_error_message import SourcesErrorMessage @@ -144,6 +146,77 @@ def get_source_type_name(self, type_id): source_name = endpoint_response.get("data")[0].get("name") return source_name + def _build_app_settings_for_gcp(self, app_settings): + """Build settings structure for gcp.""" + billing_source = {} + dataset = app_settings.get("dataset") + if dataset: + billing_source = {"data_source": {}} + billing_source["data_source"]["dataset"] = dataset + return billing_source + + def _build_app_settings_for_aws(self, app_settings): + """Build settings structure for aws.""" + billing_source = {} + bucket = app_settings.get("bucket") + if bucket: + billing_source = {"data_source": {}} + billing_source["data_source"]["bucket"] = bucket + return billing_source + + def _build_app_settings_for_azure(self, app_settings): + """Build settings structure for azure.""" + billing_source = {} + authentication = {} + resource_group = app_settings.get("resource_group") + storage_account = app_settings.get("storage_account") + subscription_id = app_settings.get("subscription_id") + + if resource_group or storage_account: + billing_source = {"data_source": {}} + if resource_group: + billing_source["data_source"]["resource_group"] = resource_group + if storage_account: + billing_source["data_source"]["storage_account"] = storage_account + if subscription_id: + authentication = {"credentials": {}} + authentication["credentials"]["subscription_id"] = subscription_id + return billing_source, authentication + + def _update_app_settings_for_source_type(self, source_type, app_settings): + """Update application settings.""" + settings = {} + billing_source = {} + authentication = {} + + if source_type in (Provider.PROVIDER_GCP, Provider.PROVIDER_GCP_LOCAL): + billing_source = self._build_app_settings_for_gcp(app_settings) + elif source_type in (Provider.PROVIDER_AWS, Provider.PROVIDER_AWS_LOCAL): + billing_source = self._build_app_settings_for_aws(app_settings) + elif source_type in (Provider.PROVIDER_AZURE, Provider.PROVIDER_AZURE_LOCAL): + billing_source, authentication = self._build_app_settings_for_azure(app_settings) + + if billing_source: + settings["billing_source"] = billing_source + if authentication: + settings["authentication"] = authentication + + return settings + + def get_application_settings(self, source_type): + """Get the application settings from Sources.""" + application_url = "{}/applications?filter[source_id]={}".format(self._base_url, str(self._source_id)) + r = self._get_network_response(application_url, self._identity_header, "Unable to application settings") + applications_response = r.json() + if not applications_response.get("data"): + raise SourcesHTTPClientError(f"No application data for source: {self._source_id}") + app_settings = applications_response.get("data")[0].get("extra") + + updated_settings = None + if app_settings: + updated_settings = self._update_app_settings_for_source_type(source_type, app_settings) + return updated_settings + def get_aws_credentials(self): """Get the roleARN from Sources Authentication service.""" url = "{}/applications?filter[source_id]={}".format(self._base_url, str(self._source_id)) @@ -186,24 +259,15 @@ def get_gcp_credentials(self): else: raise SourcesHTTPClientError(f"Unable to get GCP credentials for Source: {self._source_id}") - authentications_str = "{}/authentications?[authtype]=project_id&[resource_id]={}" + authentications_str = "{}/authentications?[authtype]=project_id_service_account_json&[resource_id]={}" authentications_url = authentications_str.format(self._base_url, str(resource_id)) r = self._get_network_response(authentications_url, self._identity_header, "Unable to GCP credentials") authentications_response = r.json() if not authentications_response.get("data"): raise SourcesHTTPClientError(f"Unable to get GCP credentials for Source: {self._source_id}") - authentications_id = authentications_response.get("data")[0].get("id") - - authentications_internal_url = "{}/authentications/{}?expose_encrypted_attribute[]=password".format( - self._internal_url, str(authentications_id) - ) - r = self._get_network_response( - authentications_internal_url, self._identity_header, "Unable to GCP Credentials" - ) - authentications_internal_response = r.json() - password = authentications_internal_response.get("password") - if password: - return {"project_id": password} + project_id = authentications_response.get("data")[0].get("username") + if project_id: + return {"project_id": project_id} raise SourcesHTTPClientError(f"Unable to get GCP credentials for Source: {self._source_id}") @@ -318,12 +382,12 @@ def set_source_status(self, error_msg, cost_management_type_id=None): application_url = f"{self._base_url}/applications/{str(application_id)}" json_data = self.build_source_status(error_msg) - - application_response = requests.patch(application_url, json=json_data, headers=status_header) - if application_response.status_code != 204: - raise SourcesHTTPClientError( - f"Unable to set status for Source {self._source_id}. Reason: " - f"Status code: {application_response.status_code}. Response: {application_response.text}." - ) - return True + if storage.save_status(self._source_id, json_data): + application_response = requests.patch(application_url, json=json_data, headers=status_header) + if application_response.status_code != 204: + raise SourcesHTTPClientError( + f"Unable to set status for Source {self._source_id}. Reason: " + f"Status code: {application_response.status_code}. Response: {application_response.text}." + ) + return True return False diff --git a/koku/sources/storage.py b/koku/sources/storage.py index f91dada116..311c6dab34 100644 --- a/koku/sources/storage.py +++ b/koku/sources/storage.py @@ -16,6 +16,7 @@ # """Database accessors for Sources database table.""" import binascii +import copy import logging from base64 import b64decode from json import loads as json_loads @@ -31,6 +32,15 @@ LOG = logging.getLogger(__name__) REQUIRED_AZURE_AUTH_KEYS = {"client_id", "tenant_id", "client_secret", "subscription_id"} REQUIRED_AZURE_BILLING_KEYS = {"resource_group", "storage_account"} +ALLOWED_BILLING_SOURCE_PROVIDERS = ( + Provider.PROVIDER_AWS, + Provider.PROVIDER_AWS_LOCAL, + Provider.PROVIDER_AZURE, + Provider.PROVIDER_AZURE_LOCAL, + Provider.PROVIDER_GCP, + Provider.PROVIDER_GCP_LOCAL, +) +ALLOWED_AUTHENTICATION_PROVIDERS = (Provider.PROVIDER_AZURE, Provider.PROVIDER_AZURE_LOCAL) class SourcesStorageError(Exception): @@ -90,8 +100,8 @@ def _gcp_provider_ready_for_create(provider): provider.source_id and provider.name and provider.auth_header - and provider.billing_source - and provider.authentication + and provider.billing_source.get("data_source") + and provider.authentication.get("credentials") and not provider.koku_uuid ): return True @@ -415,6 +425,26 @@ def add_provider_koku_uuid(source_id, koku_uuid): source.save() +def save_status(source_id, status): + """ + Save source status. + + Args: + source_id (Integer) - Platform-Sources identifier + + Returns: + status (dict) - source status json + + """ + source = get_source(source_id, f"Source ID {source_id} does not exist.", LOG.error) + if source and source.status != status: + source.status = status + source.save() + return True + + return False + + def is_known_source(source_id): """ Check if source exists in database. @@ -435,3 +465,76 @@ def is_known_source(source_id): LOG.error(f"Accessing Sources resulting in {type(error).__name__}: {error}") raise error return source_exists + + +def _validate_billing_source(provider_type, billing_source): # noqa: C901 + """Validate billing source parameters.""" + if provider_type == Provider.PROVIDER_AWS: + # TODO: Remove `and not billing_source.get("bucket")` if UI is updated to send "data_source" field + if not billing_source.get("data_source", {}).get("bucket") and not billing_source.get("bucket"): + raise SourcesStorageError("Missing AWS bucket.") + elif provider_type == Provider.PROVIDER_AZURE: + data_source = billing_source.get("data_source") + if not data_source: + raise SourcesStorageError("Missing AZURE data_source.") + if not data_source.get("resource_group"): + raise SourcesStorageError("Missing AZURE resource_group") + if not data_source.get("storage_account"): + raise SourcesStorageError("Missing AZURE storage_account") + elif provider_type == Provider.PROVIDER_GCP: + data_source = billing_source.get("data_source") + if not data_source: + raise SourcesStorageError("Missing GCP data_source.") + if not data_source.get("dataset"): + raise SourcesStorageError("Missing GCP dataset") + + +def _update_billing_source(instance, billing_source): + if instance.source_type not in ALLOWED_BILLING_SOURCE_PROVIDERS: + raise SourcesStorageError(f"Option not supported by source type {instance.source_type}.") + if instance.billing_source.get("data_source"): + billing_copy = copy.deepcopy(instance.billing_source.get("data_source")) + data_source = billing_source.get("data_source", {}) + if data_source.get("resource_group") or data_source.get("storage_account"): + billing_copy.update(billing_source.get("data_source")) + billing_source["data_source"] = billing_copy + _validate_billing_source(instance.source_type, billing_source) + # This if statement can also be removed if UI is updated to send "data_source" field + if instance.source_type in (Provider.PROVIDER_AWS, Provider.PROVIDER_AWS_LOCAL) and not billing_source.get( + "data_source" + ): + billing_source = {"data_source": billing_source} + return billing_source + + +def _update_authentication(instance, authentication): + if instance.source_type not in ALLOWED_AUTHENTICATION_PROVIDERS: + raise SourcesStorageError(f"Option not supported by source type {instance.source_type}.") + auth_dict = instance.authentication + if not auth_dict.get("credentials"): + auth_dict["credentials"] = {"subscription_id": None} + subscription_id = authentication.get("credentials", {}).get("subscription_id") + auth_dict["credentials"]["subscription_id"] = subscription_id + return auth_dict + + +def update_application_settings(source_id, settings): + """Store billing source update.""" + LOG.info(f"Found settings: {str(settings)}") + billing_source = settings.get("billing_source") + authentication = settings.get("authentication") + if billing_source: + instance = get_source(source_id, "Unable to add billing source", LOG.error) + if instance.billing_source: + billing_source = _update_billing_source(instance, billing_source) + instance.billing_source = billing_source + instance.pending_update = True + instance.save() + + if authentication: + instance = get_source(source_id, "Unable to add authentication", LOG.error) + if instance.authentication: + authentication = _update_authentication(instance, authentication) + instance.authentication = authentication + instance.pending_update = True + instance.save() diff --git a/koku/sources/test/api/test_serializers.py b/koku/sources/test/api/test_serializers.py index ae9fd8281e..2c9ce8bcdc 100644 --- a/koku/sources/test/api/test_serializers.py +++ b/koku/sources/test/api/test_serializers.py @@ -28,7 +28,6 @@ from api.provider.models import Provider from api.provider.models import Sources from api.provider.provider_builder import ProviderBuilder -from api.provider.test import PROVIDERS from providers.provider_access import ProviderAccessor from providers.provider_errors import SkipStatusPush from sources.api import get_account_from_header @@ -390,148 +389,6 @@ def test_provider_create(self, mock_header, mock_request_info, _): self.assertEqual(instance2.billing_source.get("data_source", {}).get("bucket"), "second-bucket") - def test_validate_billing_source(self, _): - """Test to validate that the billing source dictionary is valid.""" - test_matrix = [ - {"provider_type": Provider.PROVIDER_AWS, "billing_source": {"bucket": "test-bucket"}, "exception": False}, - { - "provider_type": Provider.PROVIDER_AWS, - "billing_source": {"data_source": {"bucket": "test-bucket"}}, - "exception": False, - }, - { - "provider_type": Provider.PROVIDER_AZURE, - "billing_source": {"data_source": {"resource_group": "foo", "storage_account": "bar"}}, - "exception": False, - }, - { - "provider_type": Provider.PROVIDER_AWS, - "billing_source": {"data_source": {"nobucket": "test-bucket"}}, - "exception": True, - }, - {"provider_type": Provider.PROVIDER_AWS, "billing_source": {"nobucket": "test-bucket"}, "exception": True}, - {"provider_type": Provider.PROVIDER_AWS, "billing_source": {"data_source": {}}, "exception": True}, - {"provider_type": Provider.PROVIDER_AWS, "billing_source": {}, "exception": True}, - {"provider_type": Provider.PROVIDER_AZURE, "billing_source": {}, "exception": True}, - { - "provider_type": Provider.PROVIDER_AZURE, - "billing_source": {"nodata_source": {"resource_group": "foo", "storage_account": "bar"}}, - "exception": True, - }, - { - "provider_type": Provider.PROVIDER_AZURE, - "billing_source": {"data_source": {"noresource_group": "foo", "storage_account": "bar"}}, - "exception": True, - }, - { - "provider_type": Provider.PROVIDER_AZURE, - "billing_source": {"data_source": {"resource_group": "foo", "nostorage_account": "bar"}}, - "exception": True, - }, - { - "provider_type": Provider.PROVIDER_AZURE, - "billing_source": {"data_source": {"resource_group": "foo"}}, - "exception": True, - }, - { - "provider_type": Provider.PROVIDER_AZURE, - "billing_source": {"data_source": {"storage_account": "bar"}}, - "exception": True, - }, - { - "provider_type": Provider.PROVIDER_GCP, - "billing_source": {"data_source": {"dataset": "test_dataset", "table_id": "test_table_id"}}, - "exception": False, - }, - { - "provider_type": Provider.PROVIDER_GCP, - "billing_source": {"data_source": {"dataset": "test_dataset"}}, - "exception": False, - }, - { - "provider_type": Provider.PROVIDER_GCP, - "billing_source": {"data_source": {"table_id": "test_table_id"}}, - "exception": True, - }, - {"provider_type": Provider.PROVIDER_GCP, "billing_source": {}, "exception": True}, - ] - - for test in test_matrix: - with self.subTest(test=test): - if test.get("exception"): - with self.assertRaises(SourcesStorageError): - SourcesSerializer()._validate_billing_source( - test.get("provider_type"), test.get("billing_source") - ) - else: - try: - SourcesSerializer()._validate_billing_source( - test.get("provider_type"), test.get("billing_source") - ) - except Exception as error: - self.fail(str(error)) - - def test_update_aws_billing_source(self, _): - """Test to validate that the billing source dictionary is updated.""" - aws_instance = self.aws_obj - aws_instance.billing_source = PROVIDERS[Provider.PROVIDER_AWS].get("billing_source") - aws_instance.save() - test_matrix = [ - { - "instance": aws_instance, - "billing_source": {"bucket": "test-bucket"}, - "expected": {"data_source": {"bucket": "test-bucket"}}, - }, - { - "instance": aws_instance, - "billing_source": {"data_source": {"bucket": "test-bucket"}}, - "expected": {"data_source": {"bucket": "test-bucket"}}, - }, - ] - - for test in test_matrix: - with self.subTest(test=test): - try: - new_billing = SourcesSerializer()._update_billing_source(aws_instance, test.get("billing_source")) - self.assertEqual(new_billing, test.get("expected")) - except Exception as error: - self.fail(str(error)) - - def test_update_azure_billing_source(self, _): - """Test to validate that the billing source dictionary is updated.""" - azure_instance = self.azure_obj - azure_instance.billing_source = { - "data_source": {"resource_group": "original-1", "storage_account": "original-2"} - } - azure_instance.save() - test_matrix = [ - { - "instance": azure_instance, - "billing_source": {"data_source": {"resource_group": "foo", "storage_account": "bar"}}, - "expected": {"data_source": {"resource_group": "foo", "storage_account": "bar"}}, - }, - { - "instance": azure_instance, - "billing_source": {"data_source": {"resource_group": "foo"}}, - "expected": {"data_source": {"resource_group": "foo", "storage_account": "original-2"}}, - }, - { - "instance": azure_instance, - "billing_source": {"data_source": {"storage_account": "bar"}}, - "expected": {"data_source": {"resource_group": "original-1", "storage_account": "bar"}}, - }, - ] - - for test in test_matrix: - with self.subTest(test=test): - try: - new_billing = SourcesSerializer()._update_billing_source( - azure_instance, test.get("billing_source") - ) - self.assertEqual(new_billing, test.get("expected")) - except Exception as error: - self.fail(str(error)) - @patch("api.provider.serializers.ProviderSerializer.get_request_info") @patch("sources.api.serializers.get_auth_header", return_value=Config.SOURCES_FAKE_HEADER) def test_gcp_admin_add_table_not_ready(self, mock_header, mock_request_info, _): diff --git a/koku/sources/test/test_kafka_listener.py b/koku/sources/test/test_kafka_listener.py index 927f218c5d..efd20d62c4 100644 --- a/koku/sources/test/test_kafka_listener.py +++ b/koku/sources/test/test_kafka_listener.py @@ -612,7 +612,6 @@ def test_sources_network_info_sync_gcp(self): source_type_id = 1 mock_source_name = "google" resource_id = 2 - authentication_id = 3 with requests_mock.mock() as m: m.get( f"http://www.sources.com/api/v1.0/sources/{test_source_id}", @@ -624,31 +623,25 @@ def test_sources_network_info_sync_gcp(self): status_code=200, json={"data": [{"name": mock_source_name}]}, ) - m.get( - f"http://www.sources.com/api/v1.0/endpoints?filter[source_id]={test_source_id}", - status_code=200, - json={"data": [{"id": resource_id}]}, - ) m.get( f"http://www.sources.com/api/v1.0/applications?filter[source_id]={test_source_id}", status_code=200, - json={"data": [{"id": resource_id}]}, + json={ + "data": [ + { + "id": resource_id, + "extra": {"billing_source": {"data_source": {"dataset": "billing_datset"}}}, + } + ] + }, ) m.get( ( f"http://www.sources.com/api/v1.0/authentications?" - f"[authtype]=project_id&[resource_id]={resource_id}" - ), - status_code=200, - json={"data": [{"id": authentication_id}]}, - ) - m.get( - ( - f"http://www.sources.com/internal/v1.0/authentications/{authentication_id}" - f"?expose_encrypted_attribute[]=password" + f"[authtype]=project_id_service_account_json&[resource_id]={resource_id}" ), status_code=200, - json={"password": authentication}, + json={"data": [{"username": authentication}]}, ) source_integration.sources_network_info(test_source_id, test_auth_header) @@ -672,7 +665,7 @@ def test_sources_network_info_sync_aws_local(self): aws_source.save() source_type_id = 1 mock_source_name = "amazon-local" - resource_id = 2 + resource_id = 1 authentication_id = 3 with requests_mock.mock() as m: m.get( @@ -691,10 +684,15 @@ def test_sources_network_info_sync_aws_local(self): json={"data": [{"id": resource_id}]}, ) m.get( - (f"http://www.sources.com/api/v1.0/authentications?" f"[authtype]=arn&[resource_id]={resource_id}"), + (f"http://www.sources.com/api/v1.0/authentications?[authtype]=arn&[resource_id]={resource_id}"), status_code=200, json={"data": [{"id": authentication_id}]}, ) + m.get( + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={test_source_id}", + status_code=200, + json={"data": [{"id": resource_id, "extra": {"foo": "bar"}}]}, + ) m.get( ( f"http://www.sources.com/internal/v1.0/authentications/{authentication_id}" @@ -746,7 +744,7 @@ def test_sources_network_info_sync_ocp(self): json={"data": [{"name": mock_source_name}]}, ) m.get( - f"http://www.sources.com/api/v1.0/endpoints?filter[source_id]={test_source_id}", + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={test_source_id}", status_code=200, json={"data": [{"id": resource_id}]}, ) @@ -755,7 +753,7 @@ def test_sources_network_info_sync_ocp(self): status_code=200, json={"data": [{"id": authentication_id}]}, ) - m.patch(f"http://www.sources.com/api/v1.0/applications/{app_id}", status_code=204) + m.patch(f"http://www.sources.com/api/v1.0/applications/{resource_id}", status_code=204) source_integration.sources_network_info(test_source_id, test_auth_header) source_obj = Sources.objects.get(source_id=test_source_id) @@ -915,6 +913,7 @@ def test_sources_network_info_no_endpoint(self): application_type = 2 mock_source_name = "amazon" source_type_id = 1 + resource_id = 3 source_uid = faker.uuid4() test_auth_header = Config.SOURCES_FAKE_HEADER ocp_source = Sources(source_id=test_source_id, auth_header=test_auth_header, offset=1) @@ -939,6 +938,11 @@ def test_sources_network_info_no_endpoint(self): m.get( f"http://www.sources.com/api/v1.0/applications?filter[source_id]={test_source_id}", status_code=200, + json={"data": [{"id": resource_id, "extra": {}}]}, + ) + m.get( + f"http://www.sources.com/api/v1.0/authentications?[authtype]=arn&[resource_id]={resource_id}", + status_code=200, json={"data": []}, ) m.get( @@ -946,6 +950,8 @@ def test_sources_network_info_no_endpoint(self): status_code=200, json={"data": [{"id": application_type}]}, ) + m.patch(f"http://www.sources.com/api/v1.0/applications/{resource_id}", status_code=204) + source_integration.sources_network_info(test_source_id, test_auth_header) source_obj = Sources.objects.get(source_id=test_source_id) @@ -1010,8 +1016,9 @@ def test_process_message_application_unsupported_source_type(self): with patch.object( SourcesHTTPClient, "get_source_details", return_value={"name": "my ansible", "source_type_id": 2} ): - with patch.object(SourcesHTTPClient, "get_source_type_name", return_value="ansible-tower"): - self.assertIsNone(process_message(test_application_id, msg_data)) + with patch.object(SourcesHTTPClient, "get_application_settings", return_value={}): + with patch.object(SourcesHTTPClient, "get_source_type_name", return_value="ansible-tower"): + self.assertIsNone(process_message(test_application_id, msg_data)) @patch.object(Config, "SOURCES_API_URL", "http://www.sources.com") @patch("sources.kafka_listener.sources_network_info", returns=None) diff --git a/koku/sources/test/test_sources_http_client.py b/koku/sources/test/test_sources_http_client.py index 12e1a2f7f5..479ccb9af6 100644 --- a/koku/sources/test/test_sources_http_client.py +++ b/koku/sources/test/test_sources_http_client.py @@ -256,7 +256,6 @@ def test_get_aws_credentials_no_auth(self): def test_get_gcp_credentials_from_app_auth(self): """Test to get project id from authentication service for Application authentication.""" resource_id = 2 - authentication_id = 3 client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) with requests_mock.mock() as m: m.get( @@ -267,18 +266,10 @@ def test_get_gcp_credentials_from_app_auth(self): m.get( ( f"http://www.sources.com/api/v1.0/authentications?" - f"[authtype]=project_id&[resource_id]={resource_id}" + f"[authtype]=project_id_service_account_json&[resource_id]={resource_id}" ), status_code=200, - json={"data": [{"id": authentication_id}]}, - ) - m.get( - ( - f"http://www.sources.com/internal/v1.0/authentications/{authentication_id}" - f"?expose_encrypted_attribute[]=password" - ), - status_code=200, - json={"password": self.authentication}, + json={"data": [{"username": self.authentication}]}, ) response = client.get_gcp_credentials() self.assertEqual(response, {"project_id": self.authentication}) @@ -296,7 +287,10 @@ def test_get_gcp_credentials_no_auth(self): json={"data": []}, ) m.get( - (f"http://www.sources.com/api/v1.0/authentications?" f"[authtype]=arn&[resource_id]={resource_id}"), + ( + f"http://www.sources.com/api/v1.0/authentications?" + f"[authtype]=project_id_service_account_json&[resource_id]={resource_id}" + ), status_code=200, json={"data": []}, ) @@ -315,7 +309,6 @@ def test_get_gcp_credentials_no_auth(self): def test_get_gcp_credentials_no_password(self): """Test to get GCP project id from authentication service with auth not containing password.""" resource_id = 2 - authentication_id = 3 client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) with requests_mock.mock() as m: m.get( @@ -326,18 +319,10 @@ def test_get_gcp_credentials_no_password(self): m.get( ( f"http://www.sources.com/api/v1.0/authentications?" - f"[authtype]=project_id&[resource_id]={resource_id}" + f"[authtype]=project_id_service_account_json&[resource_id]={resource_id}" ), status_code=200, - json={"data": [{"id": authentication_id}]}, - ) - m.get( - ( - f"http://www.sources.com/internal/v1.0/authentications/{authentication_id}" - f"?expose_encrypted_attribute[]=password" - ), - status_code=200, - json={"other": self.authentication}, + json={"data": [{"other": self.authentication}]}, ) with self.assertRaises(SourcesHTTPClientError): client.get_gcp_credentials() @@ -696,6 +681,126 @@ def test_set_source_status_unexpected_header(self): response = client.set_source_status(error_msg, application_type_id) self.assertFalse(response) + @patch.object(Config, "SOURCES_API_URL", "http://www.sources.com") + def test_get_application_settings_aws(self): + """Test to get application settings for aws.""" + client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) + with requests_mock.mock() as m: + m.get( + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={self.source_id}", + status_code=200, + json={"data": [{"extra": {"bucket": "testbucket"}}]}, + ) + response = client.get_application_settings("AWS") + expected_settings = {"billing_source": {"data_source": {"bucket": "testbucket"}}} + self.assertEqual(response, expected_settings) + + @patch.object(Config, "SOURCES_API_URL", "http://www.sources.com") + def test_get_application_settings_azure(self): + """Test to get application settings for azure.""" + subscription_id = "subscription-uuid" + resource_group = "testrg" + storage_account = "testsa" + + client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) + with requests_mock.mock() as m: + m.get( + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={self.source_id}", + status_code=200, + json={ + "data": [ + { + "extra": { + "subscription_id": subscription_id, + "resource_group": resource_group, + "storage_account": storage_account, + } + } + ] + }, + ) + response = client.get_application_settings("Azure") + + self.assertEqual(response.get("billing_source").get("data_source").get("resource_group"), resource_group) + self.assertEqual(response.get("billing_source").get("data_source").get("storage_account"), storage_account) + self.assertEqual(response.get("authentication").get("credentials").get("subscription_id"), subscription_id) + + @patch.object(Config, "SOURCES_API_URL", "http://www.sources.com") + def test_get_application_settings_azure_only_billing(self): + """Test to get application settings for azure only billing_source.""" + resource_group = "testrg" + storage_account = "testsa" + + client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) + with requests_mock.mock() as m: + m.get( + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={self.source_id}", + status_code=200, + json={"data": [{"extra": {"resource_group": resource_group, "storage_account": storage_account}}]}, + ) + response = client.get_application_settings("Azure") + + self.assertEqual(response.get("billing_source").get("data_source").get("resource_group"), resource_group) + self.assertEqual(response.get("billing_source").get("data_source").get("storage_account"), storage_account) + self.assertIsNone(response.get("authentication")) + + @patch.object(Config, "SOURCES_API_URL", "http://www.sources.com") + def test_get_application_settings_azure_authentication(self): + """Test to get application settings for azure for authentications.""" + subscription_id = "subscription-uuid" + client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) + with requests_mock.mock() as m: + m.get( + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={self.source_id}", + status_code=200, + json={"data": [{"extra": {"subscription_id": subscription_id}}]}, + ) + response = client.get_application_settings("Azure") + + self.assertIsNone(response.get("billing_source")) + self.assertEqual(response.get("authentication").get("credentials").get("subscription_id"), subscription_id) + + @patch.object(Config, "SOURCES_API_URL", "http://www.sources.com") + def test_get_application_settings_gcp(self): + """Test to get application settings for gcp.""" + dataset = "testdataset" + client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) + with requests_mock.mock() as m: + m.get( + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={self.source_id}", + status_code=200, + json={"data": [{"extra": {"dataset": dataset}}]}, + ) + response = client.get_application_settings("GCP") + expected_settings = {"billing_source": {"data_source": {"dataset": dataset}}} + self.assertEqual(response, expected_settings) + + @patch.object(Config, "SOURCES_API_URL", "http://www.sources.com") + def test_get_application_settings_ocp(self): + """Test to get application settings for ocp.""" + client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) + with requests_mock.mock() as m: + m.get( + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={self.source_id}", + status_code=200, + json={"data": [{"extra": {}}]}, + ) + response = client.get_application_settings("OCP") + self.assertIsNone(response) + + @patch.object(Config, "SOURCES_API_URL", "http://www.sources.com") + def test_get_application_settings_malformed_response(self): + """Test to get application settings for a malformed repsonse.""" + client = SourcesHTTPClient(auth_header=Config.SOURCES_FAKE_HEADER, source_id=self.source_id) + with requests_mock.mock() as m: + m.get( + f"http://www.sources.com/api/v1.0/applications?filter[source_id]={self.source_id}", + status_code=200, + json={"foo": [{"extra": {"bucket": "testbucket"}}]}, + ) + with self.assertRaises(SourcesHTTPClientError): + _ = client.get_application_settings("AWS") + class SourcesHTTPClientCheckAppTypeTest(TestCase): def setUp(self): diff --git a/koku/sources/test/test_storage.py b/koku/sources/test/test_storage.py index df85abea8d..fcf7510c53 100644 --- a/koku/sources/test/test_storage.py +++ b/koku/sources/test/test_storage.py @@ -308,7 +308,7 @@ def test_screen_and_build_provider_sync_create_event(self): "GCP Provider", Provider.PROVIDER_GCP, {"project_id": "test-project"}, - None, + {"data_source": {}}, "authheader", 1, False, @@ -557,3 +557,242 @@ def test_load_providers_to_update(self): self.assertEquals(len(response), test.get("expected_list_length")) test_source_id += 1 aws_obj.delete() + + def test_validate_billing_source(self): + """Test to validate that the billing source dictionary is valid.""" + test_matrix = [ + {"provider_type": Provider.PROVIDER_AWS, "billing_source": {"bucket": "test-bucket"}, "exception": False}, + { + "provider_type": Provider.PROVIDER_AWS, + "billing_source": {"data_source": {"bucket": "test-bucket"}}, + "exception": False, + }, + { + "provider_type": Provider.PROVIDER_AZURE, + "billing_source": {"data_source": {"resource_group": "foo", "storage_account": "bar"}}, + "exception": False, + }, + { + "provider_type": Provider.PROVIDER_AWS, + "billing_source": {"data_source": {"nobucket": "test-bucket"}}, + "exception": True, + }, + {"provider_type": Provider.PROVIDER_AWS, "billing_source": {"nobucket": "test-bucket"}, "exception": True}, + {"provider_type": Provider.PROVIDER_AWS, "billing_source": {"data_source": {}}, "exception": True}, + {"provider_type": Provider.PROVIDER_AWS, "billing_source": {}, "exception": True}, + {"provider_type": Provider.PROVIDER_AZURE, "billing_source": {}, "exception": True}, + { + "provider_type": Provider.PROVIDER_AZURE, + "billing_source": {"nodata_source": {"resource_group": "foo", "storage_account": "bar"}}, + "exception": True, + }, + { + "provider_type": Provider.PROVIDER_AZURE, + "billing_source": {"data_source": {"noresource_group": "foo", "storage_account": "bar"}}, + "exception": True, + }, + { + "provider_type": Provider.PROVIDER_AZURE, + "billing_source": {"data_source": {"resource_group": "foo", "nostorage_account": "bar"}}, + "exception": True, + }, + { + "provider_type": Provider.PROVIDER_AZURE, + "billing_source": {"data_source": {"resource_group": "foo"}}, + "exception": True, + }, + { + "provider_type": Provider.PROVIDER_AZURE, + "billing_source": {"data_source": {"storage_account": "bar"}}, + "exception": True, + }, + { + "provider_type": Provider.PROVIDER_GCP, + "billing_source": {"data_source": {"dataset": "test_dataset", "table_id": "test_table_id"}}, + "exception": False, + }, + { + "provider_type": Provider.PROVIDER_GCP, + "billing_source": {"data_source": {"dataset": "test_dataset"}}, + "exception": False, + }, + { + "provider_type": Provider.PROVIDER_GCP, + "billing_source": {"data_source": {"table_id": "test_table_id"}}, + "exception": True, + }, + {"provider_type": Provider.PROVIDER_GCP, "billing_source": {}, "exception": True}, + ] + + for test in test_matrix: + with self.subTest(test=test): + if test.get("exception"): + with self.assertRaises(storage.SourcesStorageError): + storage._validate_billing_source(test.get("provider_type"), test.get("billing_source")) + else: + try: + storage._validate_billing_source(test.get("provider_type"), test.get("billing_source")) + except Exception as error: + self.fail(str(error)) + + def test_update_aws_billing_source(self): + """Test to validate that the billing source dictionary is updated.""" + aws_instance = Sources( + source_id=3, + auth_header=self.test_header, + offset=3, + source_type=Provider.PROVIDER_AWS, + name="Test AWS Source", + billing_source={"data_source": {"bucket": "my_s3_bucket"}}, + ) + aws_instance.save() + test_matrix = [ + { + "instance": aws_instance, + "billing_source": {"bucket": "test-bucket"}, + "expected": {"data_source": {"bucket": "test-bucket"}}, + }, + { + "instance": aws_instance, + "billing_source": {"data_source": {"bucket": "test-bucket"}}, + "expected": {"data_source": {"bucket": "test-bucket"}}, + }, + ] + + for test in test_matrix: + with self.subTest(test=test): + try: + new_billing = storage._update_billing_source(aws_instance, test.get("billing_source")) + self.assertEqual(new_billing, test.get("expected")) + except Exception as error: + self.fail(str(error)) + aws_instance.delete() + + def test_update_azure_billing_source(self): + """Test to validate that the billing source dictionary is updated.""" + azure_instance = Sources( + source_id=4, + auth_header=self.test_header, + offset=3, + source_type=Provider.PROVIDER_AZURE, + name="Test Azure Source", + billing_source={"data_source": {"resource_group": "original-1", "storage_account": "original-2"}}, + ) + + azure_instance.save() + test_matrix = [ + { + "instance": azure_instance, + "billing_source": {"data_source": {"resource_group": "foo", "storage_account": "bar"}}, + "expected": {"data_source": {"resource_group": "foo", "storage_account": "bar"}}, + }, + { + "instance": azure_instance, + "billing_source": {"data_source": {"resource_group": "foo"}}, + "expected": {"data_source": {"resource_group": "foo", "storage_account": "original-2"}}, + }, + { + "instance": azure_instance, + "billing_source": {"data_source": {"storage_account": "bar"}}, + "expected": {"data_source": {"resource_group": "original-1", "storage_account": "bar"}}, + }, + ] + + for test in test_matrix: + with self.subTest(test=test): + try: + new_billing = storage._update_billing_source(azure_instance, test.get("billing_source")) + self.assertEqual(new_billing, test.get("expected")) + except Exception as error: + self.fail(str(error)) + azure_instance.delete() + + def test_update_application_settings(self): + """Test to update application settings.""" + test_source_id = 3 + resource_group = "testrg" + subscription_id = "testsubid" + settings = { + "billing_source": {"data_source": {"resource_group": resource_group}}, + "authentication": {"subscription_id": subscription_id}, + } + azure_obj = Sources( + source_id=test_source_id, + auth_header=self.test_header, + offset=3, + source_type=Provider.PROVIDER_AZURE, + name="Test AZURE Source", + ) + azure_obj.save() + + storage.update_application_settings(test_source_id, settings) + db_obj = Sources.objects.get(source_id=test_source_id) + + self.assertEqual(db_obj.authentication.get("subscription_id"), subscription_id) + self.assertEqual(db_obj.billing_source.get("data_source").get("resource_group"), resource_group) + + def test_update_application_settings_only_billing_source(self): + """Test to update application settings (only billing_source).""" + test_source_id = 3 + resource_group = "testrg" + settings = {"billing_source": {"data_source": {"resource_group": resource_group}}} + azure_obj = Sources( + source_id=test_source_id, + auth_header=self.test_header, + offset=3, + source_type=Provider.PROVIDER_AZURE, + name="Test AZURE Source", + ) + azure_obj.save() + + storage.update_application_settings(test_source_id, settings) + db_obj = Sources.objects.get(source_id=test_source_id) + + self.assertIsNone(db_obj.authentication.get("subscription_id")) + self.assertEqual(db_obj.billing_source.get("data_source").get("resource_group"), resource_group) + + def test_update_application_settings_only_authentication(self): + """Test to update application settings (only authentication).""" + test_source_id = 3 + subscription_id = "testsubid" + settings = {"authentication": {"subscription_id": subscription_id}} + azure_obj = Sources( + source_id=test_source_id, + auth_header=self.test_header, + offset=3, + source_type=Provider.PROVIDER_AZURE, + name="Test AZURE Source", + ) + azure_obj.save() + + storage.update_application_settings(test_source_id, settings) + db_obj = Sources.objects.get(source_id=test_source_id) + + self.assertEqual(db_obj.authentication.get("subscription_id"), subscription_id) + self.assertIsNone(db_obj.billing_source.get("data_source")) + + def test_save_status(self): + """Test to verify source status is saved.""" + test_source_id = 3 + status = "unavailable" + user_facing_string = "Missing credential and billing source" + mock_status = {"availability_status": status, "availability_status_error": user_facing_string} + azure_obj = Sources( + source_id=test_source_id, + auth_header=self.test_header, + offset=3, + source_type=Provider.PROVIDER_AZURE, + name="Test AZURE Source", + ) + azure_obj.save() + + return_code = storage.save_status(test_source_id, mock_status) + db_obj = Sources.objects.get(source_id=test_source_id) + self.assertEqual(db_obj.status, mock_status) + self.assertTrue(return_code) + + # Save again and verify return_code is False + return_code = storage.save_status(test_source_id, mock_status) + db_obj = Sources.objects.get(source_id=test_source_id) + self.assertEqual(db_obj.status, mock_status) + self.assertFalse(return_code)