Skip to content

Commit

Permalink
[COST-4915] Add unattributed distribution to cost model form (#5072)
Browse files Browse the repository at this point in the history
  • Loading branch information
myersCody authored May 10, 2024
1 parent aa0379c commit eb4ca82
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 69 deletions.
14 changes: 14 additions & 0 deletions koku/api/metrics/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,17 @@
"default_cost_type": "Infrastructure",
},
]

PLATFORM_COST = "platform_cost"
WORKER_UNALLOCATED = "worker_cost"
NETWORK_UNATTRIBUTED = "network_unattributed"
STORAGE_UNATTRIBUTED = "storage_unattributed"
DISTRIBUTION_TYPE = "distribution_type"

DEFAULT_DISTRIBUTION_INFO = {
DISTRIBUTION_TYPE: CPU_DISTRIBUTION,
PLATFORM_COST: True,
WORKER_UNALLOCATED: True,
NETWORK_UNATTRIBUTED: False,
STORAGE_UNATTRIBUTED: False,
}
30 changes: 14 additions & 16 deletions koku/cost_models/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,21 @@ class MarkupSerializer(serializers.Serializer):
class DistributionSerializer(BaseSerializer):
"""Serializer for distribution options"""

DISTRIBUTION_OPTIONS = {"distribution_type", "worker_cost", "platform_cost"}

distribution_type = serializers.ChoiceField(choices=metric_constants.DISTRIBUTION_CHOICES, required=False)
platform_cost = serializers.BooleanField(required=False)
worker_cost = serializers.BooleanField(required=False)
network_unattributed = serializers.BooleanField(required=False)
storage_unattributed = serializers.BooleanField(required=False)

def validate(self, data):
"""Run validation for distribution options."""

diff = self.DISTRIBUTION_OPTIONS.difference(data)
if diff == self.DISTRIBUTION_OPTIONS:
return {"distribution_type": metric_constants.CPU_DISTRIBUTION, "platform_cost": True, "worker_cost": True}
if diff:
distribution_info_str = ", ".join(diff)
error_msg = f"Missing distribution information: one of {distribution_info_str}"
raise serializers.ValidationError(error_msg)
default_to_true = [metric_constants.PLATFORM_COST, metric_constants.WORKER_UNALLOCATED]
distribution_keys = metric_constants.DEFAULT_DISTRIBUTION_INFO.keys()
diff = set(distribution_keys).difference(data)
if diff == distribution_keys:
return metric_constants.DEFAULT_DISTRIBUTION_INFO
for element in diff:
data[element] = element in default_to_true
return data


Expand Down Expand Up @@ -477,12 +476,11 @@ def validate(self, data):
data["currency"] = get_currency(self.context.get("request"))

if not data.get("distribution_info"):
data["distribution_info"] = {
"distribution_type": data.get("distribution", metric_constants.CPU_DISTRIBUTION),
"platform_cost": True,
"worker_cost": True,
}

# TODO: Have this return just the default distribution info after
# QE updates tests.
distribution_info = metric_constants.DEFAULT_DISTRIBUTION_INFO
distribution_info["distribution_type"] = data.get("distribution", metric_constants.CPU_DISTRIBUTION)
data["distribution_info"] = distribution_info
if (
data.get("markup")
and not data.get("rates")
Expand Down
35 changes: 14 additions & 21 deletions koku/cost_models/test/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@

from api.iam.test.iam_test_case import IamTestCase
from api.metrics import constants as metric_constants
from api.metrics.constants import DEFAULT_DISTRIBUTION_INFO
from api.metrics.constants import SOURCE_TYPE_MAP
from api.provider.models import Provider
from api.utils import get_currency
from cost_models.models import CostModel
from cost_models.models import CostModelMap
from cost_models.serializers import CostModelSerializer
from cost_models.serializers import DistributionSerializer
from cost_models.serializers import RateSerializer
from cost_models.serializers import UUIDKeyRelatedField

Expand Down Expand Up @@ -844,47 +846,38 @@ def test_valid_distribution_info_keys(self):
if serializer.is_valid(raise_exception=True):
instance = serializer.save()
self.assertIsNotNone(instance)
# Add in default options
valid_distrib_obj[metric_constants.NETWORK_UNATTRIBUTED] = False
valid_distrib_obj[metric_constants.STORAGE_UNATTRIBUTED] = False
self.assertEqual(instance.distribution_info, valid_distrib_obj)

def test_invalid_distribution_info_keys(self):
"""Test that source distribution_info object has invalid keys."""

invalid_distrib_info_keys = {"bad_key": "", "badder_key": True, "worker_cost": False}
self.ocp_data["distribution_info"] = invalid_distrib_info_keys
self.assertEqual(self.ocp_data["distribution_info"], invalid_distrib_info_keys)
bad_key1 = "bad_key"
bad_key2 = "worst_key"
invalid_distrib_info_keys = {bad_key1: "", bad_key2: True, "worker_cost": False}
with tenant_context(self.tenant):
serializer = CostModelSerializer(data=self.ocp_data, context=self.request_context)
with self.assertRaises(serializers.ValidationError):
serializer.is_valid(raise_exception=True)
serializer = DistributionSerializer(data=invalid_distrib_info_keys)
self.assertTrue(serializer.is_valid(raise_exception=True))
self.assertNotIn(bad_key1, serializer.data)
self.assertNotIn(bad_key2, serializer.data)

def test_none_distribution_info_returns_defaults(self):
"""Test that a none distribution_info object uses default options."""

default_distrib_info_obj = {
"distribution_type": metric_constants.CPU_DISTRIBUTION,
"platform_cost": True,
"worker_cost": True,
}
with tenant_context(self.tenant):
instance = None
serializer = CostModelSerializer(data=self.ocp_data, context=self.request_context)
if serializer.is_valid(raise_exception=True):
instance = serializer.save()
self.assertIsNotNone(instance)
self.assertEqual(instance.distribution_info, default_distrib_info_obj)
self.assertEqual(instance.distribution_info, DEFAULT_DISTRIBUTION_INFO)

def test_empty_distribution_info_returns_defaults(self):
"""Test that an empty distribution_info object returns default options."""

default_distrib_info_obj = {
"distribution_type": metric_constants.CPU_DISTRIBUTION,
"platform_cost": True,
"worker_cost": True,
}
self.ocp_data["distribution_info"] = {}
with tenant_context(self.tenant):
instance = None
serializer = CostModelSerializer(data=self.ocp_data, context=self.request_context)
if serializer.is_valid(raise_exception=True):
instance = serializer.save()
self.assertEqual(instance.distribution_info, default_distrib_info_obj)
self.assertEqual(instance.distribution_info, DEFAULT_DISTRIBUTION_INFO)
44 changes: 19 additions & 25 deletions koku/masu/database/ocp_report_db_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from trino.exceptions import TrinoExternalError

from api.common import log_json
from api.metrics import constants as metric_constants
from api.metrics.constants import DEFAULT_DISTRIBUTION_TYPE
from api.provider.models import Provider
from koku.database import SQLScriptAtomicExecutorMixin
Expand Down Expand Up @@ -392,19 +393,24 @@ def populate_markup_cost(self, markup, start_date, end_date, cluster_id):
),
)

def populate_platform_and_worker_distributed_cost_sql(
self, start_date, end_date, provider_uuid, distribution_info
):
def populate_distributed_cost_sql(self, start_date, end_date, provider_uuid, distribution_info):
"""
Populate the platform cost distribution of a customer.
Populate the distribution cost model options.
args:
start_date (datetime, str): The start_date to calculate monthly_cost.
end_date (datetime, str): The end_date to calculate monthly_cost.
distribution: Choice of monthly distribution ex. memory
provider_uuid (str): The str of the provider UUID
"""
distribute_mapping = {}

key_to_file_mapping = {
metric_constants.PLATFORM_COST: "distribute_platform_cost.sql",
metric_constants.WORKER_UNALLOCATED: "distribute_worker_cost.sql",
# metric_constants.STORAGE_UNATTRIBUTED: "distribute_unattributed_storage_cost.sql",
# metric_constants.NETWORK_UNATTRIBUTED: "distribute_unattributed_network_cost.sql",
}

distribution = distribution_info.get("distribution_type", DEFAULT_DISTRIBUTION_TYPE)
table_name = self._table_map["line_item_daily_summary"]
report_period = self.report_periods_for_provider_uuid(provider_uuid, start_date)
Expand All @@ -415,26 +421,14 @@ def populate_platform_and_worker_distributed_cost_sql(
return

report_period_id = report_period.id
distribute_mapping = {
"platform_cost": {
"sql_file": "distribute_platform_cost.sql",
"log_msg": {
True: "distributing platform cost",
False: "removing platform_distributed cost model rate type",
},
},
"worker_cost": {
"sql_file": "distribute_worker_cost.sql",
"log_msg": {
True: "distributing worker unallocated cost",
False: "removing worker_distributed cost model rate type",
},
},
}

for cost_model_key, metadata in distribute_mapping.items():
for cost_model_key, sql_file in key_to_file_mapping.items():
populate = distribution_info.get(cost_model_key, False)
# if populate is false we only execute the delete sql.
if populate:
log_msg = f"distributing {cost_model_key}"
else:
# if populate is false we only execute the delete sql.
log_msg = f"removing {cost_model_key} distribution"
sql_params = {
"start_date": start_date,
"end_date": end_date,
Expand All @@ -445,9 +439,9 @@ def populate_platform_and_worker_distributed_cost_sql(
"populate": populate,
}

sql = pkgutil.get_data("masu.database", f"sql/openshift/cost_model/{metadata['sql_file']}")
sql = pkgutil.get_data("masu.database", f"sql/openshift/cost_model/distribute_cost/{sql_file}")
sql = sql.decode("utf-8")
LOG.info(log_json(msg=metadata["log_msg"][populate], context=sql_params))
LOG.info(log_json(msg=log_msg, context=sql_params))
self._prepare_and_execute_raw_sql_query(table_name, sql, sql_params, operation="INSERT")

def populate_monthly_cost_sql(self, cost_type, rate_type, rate, start_date, end_date, distribution, provider_uuid):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ user_defined_project_sum as (
AND report_period_id = {{report_period_id}}
AND lids.namespace != 'Worker unallocated'
AND lids.namespace != 'Platform unallocated'
AND lids.namespace != 'Storage unattributed'
AND lids.namespace != 'Network unattributed'
AND (cost_category_id IS NULL OR cat.name != 'Platform')
GROUP BY usage_start, cluster_id, source_uuid
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ user_defined_project_sum as (
AND report_period_id = {{report_period_id}}
AND lids.namespace != 'Worker unallocated'
AND lids.namespace != 'Platform unallocated'
AND lids.namespace != 'Storage unattributed'
AND lids.namespace != 'Network unattributed'
AND (cost_category_id IS NULL OR cat.name != 'Platform')
GROUP BY usage_start, cluster_id, source_uuid
),
Expand Down
4 changes: 1 addition & 3 deletions koku/masu/processor/ocp/ocp_cost_model_cost_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,7 @@ def update_summary_cost_model_costs(self, start_date, end_date):

with OCPReportDBAccessor(self._schema) as accessor:

accessor.populate_platform_and_worker_distributed_cost_sql(
start_date, end_date, self._provider_uuid, self._distribution_info
)
accessor.populate_distributed_cost_sql(start_date, end_date, self._provider_uuid, self._distribution_info)
accessor.populate_ui_summary_tables(start_date, end_date, self._provider.uuid)
report_period = accessor.report_periods_for_provider_uuid(self._provider_uuid, start_date)
if report_period:
Expand Down
8 changes: 4 additions & 4 deletions koku/masu/test/database/test_ocp_report_db_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,19 +930,19 @@ def test_populate_usage_costs_new_columns_no_report_period(self):
acc.populate_usage_costs("", "", start_date, end_date, self.provider_uuid)
self.assertIn("no report period for OCP provider", logger.output[0])

def test_populate_platform_and_worker_distributed_cost_sql_no_report_period(self):
def test_populate_distributed_cost_sql_no_report_period(self):
"""Test that updating monthly costs without a matching report period no longer throws an error"""
start_date = "2000-01-01"
end_date = "2000-02-01"
with self.accessor as acc:
result = acc.populate_platform_and_worker_distributed_cost_sql(
result = acc.populate_distributed_cost_sql(
start_date, end_date, self.provider_uuid, {"platform_cost": True}
)
self.assertIsNone(result)

@patch("masu.database.ocp_report_db_accessor.pkgutil.get_data")
@patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor._execute_raw_sql_query")
def test_populate_platform_and_worker_distributed_cost_sql_called(self, mock_sql_execute, mock_data_get):
def test_populate_distributed_cost_sql_called(self, mock_sql_execute, mock_data_get):
"""Test that the platform distribution is called."""

def get_pkgutil_values(file):
Expand Down Expand Up @@ -972,7 +972,7 @@ def get_pkgutil_values(file):

with self.accessor as acc:
acc.prepare_query = mock_jinja
acc.populate_platform_and_worker_distributed_cost_sql(
acc.populate_distributed_cost_sql(
start_date, end_date, self.ocp_test_provider_uuid, {"worker_cost": True, "platform_cost": True}
)
expected_calls = [
Expand Down

0 comments on commit eb4ca82

Please sign in to comment.