From 644b071e9ea293852b67d428b2b0173fdf78ffd3 Mon Sep 17 00:00:00 2001 From: myersCody Date: Tue, 12 Dec 2023 17:36:32 -0500 Subject: [PATCH 01/30] [COST-4389] An internal masu endpoint to fix parquet files. --- koku/masu/api/upgrade_trino/__init__.py | 2 + koku/masu/api/upgrade_trino/task_handler.py | 45 +++ koku/masu/api/upgrade_trino/util/__init__.py | 0 .../api/upgrade_trino/util/state_tracker.py | 104 ++++++ .../util/verify_parquet_files.py | 300 ++++++++++++++++++ koku/masu/api/upgrade_trino/view.py | 65 ++++ koku/masu/api/urls.py | 2 + koku/masu/api/views.py | 1 + koku/masu/celery/tasks.py | 8 + 9 files changed, 527 insertions(+) create mode 100644 koku/masu/api/upgrade_trino/__init__.py create mode 100644 koku/masu/api/upgrade_trino/task_handler.py create mode 100644 koku/masu/api/upgrade_trino/util/__init__.py create mode 100644 koku/masu/api/upgrade_trino/util/state_tracker.py create mode 100644 koku/masu/api/upgrade_trino/util/verify_parquet_files.py create mode 100644 koku/masu/api/upgrade_trino/view.py diff --git a/koku/masu/api/upgrade_trino/__init__.py b/koku/masu/api/upgrade_trino/__init__.py new file mode 100644 index 0000000000..f7b39ea693 --- /dev/null +++ b/koku/masu/api/upgrade_trino/__init__.py @@ -0,0 +1,2 @@ +# Everything in this directory will become +# dead code after the trino upgrade. diff --git a/koku/masu/api/upgrade_trino/task_handler.py b/koku/masu/api/upgrade_trino/task_handler.py new file mode 100644 index 0000000000..ca685d60c6 --- /dev/null +++ b/koku/masu/api/upgrade_trino/task_handler.py @@ -0,0 +1,45 @@ +import copy +import logging + +from api.common import log_json +from api.provider.models import Provider +from masu.celery.tasks import fix_parquet_data_types +from masu.processor.orchestrator import get_billing_month_start + +LOG = logging.getLogger(__name__) + + +def fix_parquet_data_types_task_builder(bill_date, provider_uuid=None, provider_type=None, simulate=False): + """ + Fixes the parquet file data type for each account. + Args: + simulate (Boolean) simulate the parquet file fixing. + Returns: + (celery.result.AsyncResult) Async result for deletion request. 
+ """ + async_results = [] + if provider_type: + providers = Provider.objects.filter(active=True, paused=False, type=provider_type) + else: + providers = Provider.objects.filter(uuid=provider_uuid) + for provider in providers: + account = copy.deepcopy(provider.account) + report_month = get_billing_month_start(bill_date) + async_result = fix_parquet_data_types.delay( + schema_name=account.get("schema_name"), + provider_type=account.get("provider_type"), + provider_uuid=account.get("provider_uuid"), + simulate=simulate, + bill_date=report_month, + ) + LOG.info( + log_json( + provider.uuid, + msg="Calling fix_parquet_data_types", + schema=account.get("schema_name"), + provider_uuid=provider.uuid, + task_id=str(async_result), + ) + ) + async_results.append(str(async_result)) + return async_results diff --git a/koku/masu/api/upgrade_trino/util/__init__.py b/koku/masu/api/upgrade_trino/util/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py new file mode 100644 index 0000000000..123909a173 --- /dev/null +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -0,0 +1,104 @@ +import logging +import os + +from api.common import log_json +from api.provider.provider_manager import ProviderManager +from api.provider.provider_manager import ProviderManagerError + + +LOG = logging.getLogger(__name__) + + +class StateTracker: + FOUND_S3_FILE = "found_s3_file" + DOWNLOADED_LOCALLY = "downloaded_locally" + NO_CHANGES_NEEDED = "no_changes_needed" + COERCE_REQUIRED = "coerce_required" + SENT_TO_S3_COMPLETE = "sent_to_s3_complete" + SENT_TO_S3_FAILED = "sent_to_s3_failed" + FAILED_DTYPE_CONVERSION = "failed_data_type_conversion" + + def __init__(self, provider_uuid): + self.files = [] + self.tracker = {} + self.local_files = {} + self.provider_uuid = provider_uuid + self.context_key = "dtype_conversion" + self.failed_files_key = "dtype_failed_files" + + def set_state(self, s3_obj_key, state): + self.tracker[s3_obj_key] = state + + def add_local_file(self, s3_obj_key, local_path): + self.local_files[s3_obj_key] = local_path + self.tracker[s3_obj_key] = self.DOWNLOADED_LOCALLY + + def get_files_that_need_updated(self): + """Returns a mapping of files in the s3 needs + updating state. + + {s3_object_key: local_file_path} for + """ + mapping = {} + for s3_obj_key, state in self.tracker.items(): + if state == self.COERCE_REQUIRED: + mapping[s3_obj_key] = self.local_files.get(s3_obj_key) + return mapping + + def generate_simulate_messages(self): + """ + Generates the simulate messages. 
+ """ + files_count = 0 + files_failed = [] + files_need_updated = [] + files_correct = [] + for s3_obj_key, state in self.tracker.items(): + files_count += 1 + if state == self.COERCE_REQUIRED: + files_need_updated.append(s3_obj_key) + elif state == self.NO_CHANGES_NEEDED: + files_correct.append(s3_obj_key) + else: + files_failed.append(s3_obj_key) + simulate_info = { + "already correct.": files_correct, + "need updated.": files_need_updated, + "failed to convert.": files_failed, + } + for substring, files_list in simulate_info.items(): + LOG.info(f"{len(files_list)} out of {files_count} {substring}") + if files_list: + LOG.info(f"File list: {files_list}") + + def _clean_local_files(self): + for file_path in self.local_files.values(): + os.remove(file_path) + + def _check_for_incomplete_files(self): + incomplete_files = [] + for file_prefix, state in self.tracker.items(): + if state not in [self.SENT_TO_S3_COMPLETE, self.NO_CHANGES_NEEDED]: + file_metadata = {"key": file_prefix, "state": state} + incomplete_files.append(file_metadata) + return incomplete_files + + def _check_if_complete(self): + incomplete_files = self._check_for_incomplete_files() + try: + manager = ProviderManager(self.provider_uuid) + context = manager.get_additional_context() + context[self.context_key] = True + if incomplete_files: + context[self.context_key] = False + context[self.failed_files_key] = incomplete_files + manager.model.set_additional_context(context) + LOG.info(self.provider_uuid, log_json(msg="setting dtype states", context=context)) + except ProviderManagerError: + pass + + def finalize_and_clean_up(self): + self._check_if_complete() + self._clean_local_files() + # We can decide if we want to record + # failed parquet conversion diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py new file mode 100644 index 0000000000..7c63d4862f --- /dev/null +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -0,0 +1,300 @@ +import logging +import os +import uuid +from pathlib import Path + +import ciso8601 +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from botocore.exceptions import ClientError +from django.conf import settings + +from api.common import log_json +from api.provider.models import Provider +from masu.api.upgrade_trino.util.state_tracker import StateTracker +from masu.config import Config +from masu.processor.parquet.parquet_report_processor import OPENSHIFT_REPORT_TYPE +from masu.util.aws.common import get_s3_resource +from masu.util.common import get_path_prefix +from masu.util.common import strip_characters_from_column_name +from reporting.provider.aws.models import TRINO_REQUIRED_COLUMNS as AWS_TRINO_REQUIRED_COLUMNS +from reporting.provider.azure.models import TRINO_REQUIRED_COLUMNS as AZURE_TRINO_REQUIRED_COLUMNS +from reporting.provider.oci.models import TRINO_REQUIRED_COLUMNS as OCI_TRINO_REQUIRED_COLUMNS + +# TODO: Move the Trino required columns up to the task handler. 
+ + +# Node role is the only column we add manually for OCP +# Therefore, it is the only column that can be incorrect +OCP_TRINO_REQUIRED_COLUMNS = {"node_role": ""} + +LOG = logging.getLogger(__name__) + + +class VerifyParquetFiles: + CONVERTER_VERSION = 1.0 + + def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_date): + self.schema_name = schema_name + self.provider_uuid = uuid.UUID(provider_uuid) + self.provider_type = provider_type.replace("-local", "") + self.simulate = True + # self.simulate = simulate + self.bill_date = bill_date + self.file_tracker = StateTracker(provider_uuid) + # Provider specific vars + self.openshift_data = False + self.report_types = [None] + self.required_columns = self.set_required_columns() + self.logging_context = { + "provider_type": self.provider_type, + "provider_uuid": self.provider_uuid, + "schema": self.schema_name, + "simulate": self.simulate, + "bill_date": self.bill_date, + } + + def _get_bill_dates(self): + # However far back we want to fix. + return [ciso8601.parse_datetime(self.bill_date)] + + def _find_pyarrow_value(self, default_value): + """Our mapping contains a default value, but + we need the pyarrow value for that default value. + """ + if pd.isnull(default_value): + # TODO: Azure saves datetime as pa.timestamp("ms") + # TODO: AWS saves datetime as timestamp[ms, tz=UTC] + # Should we be storing in a standard type here? + return pa.timestamp("ms") + if default_value == "": + return pa.string() + if default_value == 0.0: + return pa.float64() + + def _clean_mapping(self, mapping): + """ + Our required mapping stores the raw column name; however, + the parquet files will contain the cleaned column name. + """ + scrubbed_mapping = {} + for raw_column_name, default_value in mapping.items(): + scrubbed_column_name = strip_characters_from_column_name(raw_column_name) + scrubbed_mapping[scrubbed_column_name] = self._find_pyarrow_value(default_value) + return scrubbed_mapping + + def set_required_columns(self): + """Grabs the mapping of column_name to data type.""" + if self.provider_type == Provider.PROVIDER_OCI: + self.report_types = ["cost", "usage"] + return self._clean_mapping(OCI_TRINO_REQUIRED_COLUMNS) + if self.provider_type == Provider.PROVIDER_OCP: + self.report_types = ["namespace_labels", "node_labels", "pod_usage", "storage_usage"] + return self._clean_mapping(OCP_TRINO_REQUIRED_COLUMNS) + if self.provider_type == Provider.PROVIDER_AWS: + return self._clean_mapping(AWS_TRINO_REQUIRED_COLUMNS) + if self.provider_type == Provider.PROVIDER_AZURE: + return self._clean_mapping(AZURE_TRINO_REQUIRED_COLUMNS) + + # Stolen from parquet_report_processor + def _parquet_path_s3(self, bill_date, report_type): + """The path in the S3 bucket where Parquet files are loaded.""" + return get_path_prefix( + self.schema_name, + self.provider_type, + self.provider_uuid, + bill_date, + Config.PARQUET_DATA_TYPE, + report_type=report_type, + ) + + # Stolen from parquet_report_processor + def _parquet_daily_path_s3(self, bill_date, report_type): + """The path in the S3 bucket where Parquet files are loaded.""" + if report_type is None: + report_type = "raw" + return get_path_prefix( + self.schema_name, + self.provider_type, + self.provider_uuid, + bill_date, + Config.PARQUET_DATA_TYPE, + report_type=report_type, + daily=True, + ) + + # Stolen from parquet_report_processor + def _parquet_ocp_on_cloud_path_s3(self, bill_date): + """The path in the S3 bucket where Parquet files are loaded.""" + return get_path_prefix( + 
self.schema_name, + self.provider_type, + self.provider_uuid, + bill_date, + Config.PARQUET_DATA_TYPE, + report_type=OPENSHIFT_REPORT_TYPE, + daily=True, + ) + + # Stolen from parquet_report_processor + def _generate_s3_path_prefixes(self, bill_date): + """ + generates the s3 path prefixes. + """ + path_prefixes = set() + for report_type in self.report_types: + path_prefixes.add(self._parquet_path_s3(bill_date, report_type)) + path_prefixes.add(self._parquet_daily_path_s3(bill_date, report_type)) + if self.openshift_data: + path_prefixes.add(self._parquet_ocp_on_cloud_path_s3(bill_date)) + return path_prefixes + + # Stolen from parquet_report_processor + @property + def local_path(self): + local_path = Path(Config.TMP_DIR, self.schema_name, str(self.provider_uuid)) + local_path.mkdir(parents=True, exist_ok=True) + return local_path + + # New logic to download the parquet files locally, coerce them, + # then upload the files that need updated back to s3 + def retrieve_verify_reload_S3_parquet(self): + """Retrieves the s3 files from s3""" + s3_resource = get_s3_resource(settings.S3_ACCESS_KEY, settings.S3_SECRET, settings.S3_REGION) + s3_bucket = s3_resource.Bucket(settings.S3_BUCKET_NAME) + bill_dates = self._get_bill_dates() + for bill_date in bill_dates: + for prefix in self._generate_s3_path_prefixes(bill_date): + LOG.info( + log_json( + self.provider_uuid, + msg="Retrieving files from S3.", + context=self.logging_context, + prefix=prefix, + ) + ) + for s3_object in s3_bucket.objects.filter(Prefix=prefix): + s3_object_key = s3_object.key + self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE) + local_file_path = os.path.join(self.local_path, os.path.basename(s3_object_key)) + LOG.info( + log_json( + self.provider_uuid, + msg="Downloading file locally", + context=self.logging_context, + prefix=prefix, + local_file_path=local_file_path, + ) + ) + s3_bucket.download_file(s3_object_key, local_file_path) + self.file_tracker.add_local_file(s3_object_key, local_file_path) + self.file_tracker.set_state(s3_object_key, self._coerce_parquet_data_type(local_file_path)) + if self.simulate: + self.file_tracker.generate_simulate_messages() + return False + else: + files_need_updated = self.file_tracker.get_files_that_need_updated() + for s3_obj_key, converted_local_file_path in files_need_updated.items(): + try: + s3_bucket.Object(s3_obj_key).delete() + LOG.info(f"Deleted current parquet file: {s3_obj_key}") + except ClientError as e: + LOG.info(f"Failed to delete {s3_object_key}: {str(e)}") + self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED) + continue + + # An error here would cause a data gap. + with open(converted_local_file_path, "rb") as new_file: + s3_bucket.upload_fileobj(new_file, s3_obj_key) + LOG.info(f"Uploaded revised parquet: {s3_object_key}") + self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE) + self.file_tracker.finalize_and_clean_up() + + # Same logic as last time, but combined into one method & added state tracking + def _coerce_parquet_data_type(self, parquet_file_path): + """If a parquet file has an incorrect dtype we can attempt to coerce + it to the correct type it. + + Returns a boolean indicating if the update parquet file should be sent + to s3. 
+ """ + LOG.info( + log_json( + self.provider_uuid, + msg="Checking local parquet_file", + context=self.logging_context, + local_file_path=parquet_file_path, + ) + ) + corrected_fields = {} + try: + table = pq.read_table(parquet_file_path) + schema = table.schema + fields = [] + for field in schema: + if correct_data_type := self.required_columns.get(field.name): + # Check if the field's type matches the desired type + if field.type != correct_data_type: + # State update: Needs to be replaced. + LOG.info( + log_json( + self.provider_uuid, + msg="Incorrect data type.", + context=self.logging_context, + column_name=field.name, + current_dtype=field.type, + expected_data_type=correct_data_type, + ) + ) + LOG.info( + log_json( + self.provider_uuid, + msg="Building new parquet schema.", + context=self.logging_context, + column_name=field.name, + expected_data_type=correct_data_type, + ) + ) + field = pa.field(field.name, correct_data_type) + corrected_fields[field.name] = correct_data_type + fields.append(field) + + if not corrected_fields: + # Final State: No changes needed. + LOG.info( + log_json( + self.provider_uuid, + msg="All data types correct", + context=self.logging_context, + local_file_path=parquet_file_path, + ) + ) + return self.file_tracker.NO_CHANGES_NEEDED + + new_schema = pa.schema(fields) + LOG.info( + log_json( + self.provider_uuid, + msg="Applying new parquet schema to local parquet file.", + context=self.logging_context, + local_file_path=parquet_file_path, + updated_columns=corrected_fields, + ) + ) + table = table.cast(new_schema) + LOG.info( + log_json( + self.provider_uuid, + msg="Saving updated schema to the local parquet_file", + local_file_path=parquet_file_path, + ) + ) + # Write the table back to the Parquet file + pa.parquet.write_table(table, parquet_file_path) + # Signal that we need to send this update to S3. + return self.file_tracker.COERCE_REQUIRED + + except Exception as e: + LOG.info(log_json(self.provider_uuid, msg="Failed to coerce data.", context=self.logging_context, error=e)) + return self.file_tracker.FAILED_DTYPE_CONVERSION diff --git a/koku/masu/api/upgrade_trino/view.py b/koku/masu/api/upgrade_trino/view.py new file mode 100644 index 0000000000..9279562e7d --- /dev/null +++ b/koku/masu/api/upgrade_trino/view.py @@ -0,0 +1,65 @@ +# +# Copyright 2023 Red Hat Inc. 
+# SPDX-License-Identifier: Apache-2.0 +# +"""View for fixing parquet files endpoint.""" +import logging + +from django.views.decorators.cache import never_cache +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.decorators import permission_classes +from rest_framework.decorators import renderer_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework.settings import api_settings + +from api.provider.models import Provider +from masu.api.upgrade_trino.task_handler import fix_parquet_data_types_task_builder + +LOG = logging.getLogger(__name__) + + +class RequiredParametersError(Exception): + """Handle require parameters error.""" + + +def build_task_handler_kwargs(query_params): + """Validates the expected query parameters.""" + uuid_or_type_provided = False # used to check if provider uuid or type supplied + reprocess_kwargs = {} + if start_date := query_params.get("start_date"): + reprocess_kwargs["bill_date"] = start_date + else: + raise RequiredParametersError("start_date must be supplied as a parameter.") + if provider_uuid := query_params.get("provider_uuid"): + uuid_or_type_provided = True + provider = Provider.objects.filter(uuid=provider_uuid).first() + if not provider: + raise RequiredParametersError(f"The provider_uuid {provider_uuid} does not exist.") + reprocess_kwargs["provider_uuid"] = provider_uuid + if provider_type := query_params.get("provider_type"): + uuid_or_type_provided = True + reprocess_kwargs["provider_type"] = provider_type + if not uuid_or_type_provided: + raise RequiredParametersError("provider_uuid or provider_type must be supplied") + return reprocess_kwargs + + +@never_cache +@api_view(http_method_names=["GET", "DELETE"]) +@permission_classes((AllowAny,)) +@renderer_classes(tuple(api_settings.DEFAULT_RENDERER_CLASSES)) +def fix_parquet(request): + """Return expired data.""" + simulate = False + params = request.query_params + try: + task_handler_kwargs = build_task_handler_kwargs(params) + async_fix_results = fix_parquet_data_types_task_builder(**task_handler_kwargs) + except RequiredParametersError as errmsg: + return Response({"Error": str(errmsg)}, status=status.HTTP_400_BAD_REQUEST) + response_key = "Async jobs for fix parquet files" + if simulate: + response_key = response_key + " (simulated)" + return Response({response_key: str(async_fix_results)}) diff --git a/koku/masu/api/urls.py b/koku/masu/api/urls.py index 54513f4575..e9d7ec1510 100644 --- a/koku/masu/api/urls.py +++ b/koku/masu/api/urls.py @@ -23,6 +23,7 @@ from masu.api.views import EnabledTagView from masu.api.views import expired_data from masu.api.views import explain_query +from masu.api.views import fix_parquet from masu.api.views import get_status from masu.api.views import hcs_report_data from masu.api.views import hcs_report_finalization @@ -48,6 +49,7 @@ urlpatterns = [ + path("fix_parquet/", fix_parquet, name="fix_parquet"), path("status/", get_status, name="server-status"), path("download/", download_report, name="report_download"), path("ingress_reports/", ingress_reports, name="ingress_reports"), diff --git a/koku/masu/api/views.py b/koku/masu/api/views.py index 124c26e00f..aa97e2b284 100644 --- a/koku/masu/api/views.py +++ b/koku/masu/api/views.py @@ -39,3 +39,4 @@ from masu.api.update_cost_model_costs import update_cost_model_costs from masu.api.update_exchange_rates import update_exchange_rates from masu.api.update_openshift_on_cloud import 
update_openshift_on_cloud +from masu.api.upgrade_trino.view import fix_parquet diff --git a/koku/masu/celery/tasks.py b/koku/masu/celery/tasks.py index 3e0159c4d2..bc61d1c39a 100644 --- a/koku/masu/celery/tasks.py +++ b/koku/masu/celery/tasks.py @@ -30,6 +30,7 @@ from api.utils import DateHelper from koku import celery_app from koku.notifications import NotificationService +from masu.api.upgrade_trino.util.verify_parquet_files import VerifyParquetFiles from masu.config import Config from masu.database.cost_model_db_accessor import CostModelDBAccessor from masu.database.ocp_report_db_accessor import OCPReportDBAccessor @@ -57,6 +58,13 @@ } +# TODO: Change the queue from the default queue +@celery_app.task(name="masu.celery.tasks.fix_parquet_data_types", queue=DEFAULT) +def fix_parquet_data_types(*args, **kwargs): + verify_parquet = VerifyParquetFiles(*args, **kwargs) + verify_parquet.retrieve_verify_reload_S3_parquet() + + @celery_app.task(name="masu.celery.tasks.check_report_updates", queue=DEFAULT) def check_report_updates(*args, **kwargs): """Scheduled task to initiate scanning process on a regular interval.""" From 1ddf095472c95face5193ecaa04f44efbdc8bb42 Mon Sep 17 00:00:00 2001 From: myersCody Date: Wed, 13 Dec 2023 13:10:33 -0500 Subject: [PATCH 02/30] Update the task handler logic. --- koku/masu/api/upgrade_trino/task_handler.py | 45 ------- .../api/upgrade_trino/util/state_tracker.py | 11 +- .../api/upgrade_trino/util/task_handler.py | 116 +++++++++++++++++ .../util/verify_parquet_files.py | 117 +++++++----------- koku/masu/api/upgrade_trino/view.py | 40 +----- 5 files changed, 170 insertions(+), 159 deletions(-) delete mode 100644 koku/masu/api/upgrade_trino/task_handler.py create mode 100644 koku/masu/api/upgrade_trino/util/task_handler.py diff --git a/koku/masu/api/upgrade_trino/task_handler.py b/koku/masu/api/upgrade_trino/task_handler.py deleted file mode 100644 index ca685d60c6..0000000000 --- a/koku/masu/api/upgrade_trino/task_handler.py +++ /dev/null @@ -1,45 +0,0 @@ -import copy -import logging - -from api.common import log_json -from api.provider.models import Provider -from masu.celery.tasks import fix_parquet_data_types -from masu.processor.orchestrator import get_billing_month_start - -LOG = logging.getLogger(__name__) - - -def fix_parquet_data_types_task_builder(bill_date, provider_uuid=None, provider_type=None, simulate=False): - """ - Fixes the parquet file data type for each account. - Args: - simulate (Boolean) simulate the parquet file fixing. - Returns: - (celery.result.AsyncResult) Async result for deletion request. 
- """ - async_results = [] - if provider_type: - providers = Provider.objects.filter(active=True, paused=False, type=provider_type) - else: - providers = Provider.objects.filter(uuid=provider_uuid) - for provider in providers: - account = copy.deepcopy(provider.account) - report_month = get_billing_month_start(bill_date) - async_result = fix_parquet_data_types.delay( - schema_name=account.get("schema_name"), - provider_type=account.get("provider_type"), - provider_uuid=account.get("provider_uuid"), - simulate=simulate, - bill_date=report_month, - ) - LOG.info( - log_json( - provider.uuid, - msg="Calling fix_parquet_data_types", - schema=account.get("schema_name"), - provider_uuid=provider.uuid, - task_id=str(async_result), - ) - ) - async_results.append(str(async_result)) - return async_results diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 123909a173..85ef12d4d9 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -15,6 +15,7 @@ class StateTracker: NO_CHANGES_NEEDED = "no_changes_needed" COERCE_REQUIRED = "coerce_required" SENT_TO_S3_COMPLETE = "sent_to_s3_complete" + S3_FILE_DELETED = "s3_file_deleted" SENT_TO_S3_FAILED = "sent_to_s3_failed" FAILED_DTYPE_CONVERSION = "failed_data_type_conversion" @@ -62,14 +63,12 @@ def generate_simulate_messages(self): else: files_failed.append(s3_obj_key) simulate_info = { - "already correct.": files_correct, - "need updated.": files_need_updated, - "failed to convert.": files_failed, + "Files that have all correct data_types.": files_correct, + "Files that need to be updated.": files_need_updated, + "Files that failed to convert.": files_failed, } for substring, files_list in simulate_info.items(): - LOG.info(f"{len(files_list)} out of {files_count} {substring}") - if files_list: - LOG.info(f"File list: {files_list}") + LOG.info(log_json(self.provider_uuid, msg=substring, file_count=len(files_list), total_count=files_count)) def _clean_local_files(self): for file_path in self.local_files.values(): diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py new file mode 100644 index 0000000000..cf092d6ac8 --- /dev/null +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -0,0 +1,116 @@ +import copy +import logging +from dataclasses import dataclass +from dataclasses import field +from typing import Optional + +from django.http import QueryDict + +from api.common import log_json +from api.provider.models import Provider +from masu.celery.tasks import fix_parquet_data_types +from masu.processor.orchestrator import get_billing_month_start +from masu.util.common import strip_characters_from_column_name +from reporting.provider.aws.models import TRINO_REQUIRED_COLUMNS as AWS_TRINO_REQUIRED_COLUMNS +from reporting.provider.azure.models import TRINO_REQUIRED_COLUMNS as AZURE_TRINO_REQUIRED_COLUMNS +from reporting.provider.oci.models import TRINO_REQUIRED_COLUMNS as OCI_TRINO_REQUIRED_COLUMNS + +LOG = logging.getLogger(__name__) + + +class RequiredParametersError(Exception): + """Handle require parameters error.""" + + +@dataclass +class FixParquetTaskHandler: + bill_date: Optional[str] = field(default=None) + provider_uuid: Optional[str] = field(default=None) + provider_type: Optional[str] = field(default=None) + simulate: Optional[bool] = field(default=False) + cleaned_column_mapping: Optional[dict] = field(default=None) + + # Node role is the only column we 
add manually for OCP + # Therefore, it is the only column that can be incorrect + REQUIRED_COLUMNS_MAPPING = { + Provider.PROVIDER_OCI: OCI_TRINO_REQUIRED_COLUMNS, + Provider.PROVIDER_OCP: {"node_role": ""}, + Provider.PROVIDER_AWS: AWS_TRINO_REQUIRED_COLUMNS, + Provider.PROVIDER_AZURE: AZURE_TRINO_REQUIRED_COLUMNS, + } + + @classmethod + def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": + """Create an instance from query parameters.""" + reprocess_kwargs = cls() + if start_date := query_params.get("start_date"): + reprocess_kwargs.bill_date = start_date + + if provider_uuid := query_params.get("provider_uuid"): + provider = Provider.objects.filter(uuid=provider_uuid).first() + if not provider: + raise RequiredParametersError(f"The provider_uuid {provider_uuid} does not exist.") + reprocess_kwargs.provider_uuid = provider_uuid + reprocess_kwargs.provider_type = provider.type + + if provider_type := query_params.get("provider_type"): + reprocess_kwargs.provider_type = provider_type + + if simulate := query_params.get("simulate"): + reprocess_kwargs.simulate = simulate + + if not reprocess_kwargs.provider_type and not reprocess_kwargs.provider_uuid: + raise RequiredParametersError("provider_uuid or provider_type must be supplied") + if not reprocess_kwargs.bill_date: + raise RequiredParametersError("start_date must be supplied as a parameter.") + + reprocess_kwargs.cleaned_column_mapping = reprocess_kwargs.clean_column_names() + return reprocess_kwargs + + def clean_column_names(self): + """Creates a mapping of columns to expected pyarrow values.""" + clean_column_names = {} + # provider_type_key = copy.deepcopy() + provider_mapping = self.REQUIRED_COLUMNS_MAPPING.get(self.provider_type.replace("local", "")) + # Our required mapping stores the raw column name; however, + # the parquet files will contain the cleaned column name. + for raw_col, default_val in provider_mapping.items(): + clean_column_names[strip_characters_from_column_name(raw_col)] = default_val + return clean_column_names + + def build_celery_tasks(self): + """ + Fixes the parquet file data type for each account. + Args: + simulate (Boolean) simulate the parquet file fixing. + Returns: + (celery.result.AsyncResult) Async result for deletion request. 
+ """ + async_results = [] + if self.provider_uuid: + providers = Provider.objects.filter(uuid=self.provider_uuid) + else: + providers = Provider.objects.filter(active=True, paused=False, type=self.provider_type) + + for provider in providers: + account = copy.deepcopy(provider.account) + report_month = get_billing_month_start(self.bill_date) + async_result = fix_parquet_data_types.delay( + schema_name=account.get("schema_name"), + provider_type=account.get("provider_type"), + provider_uuid=account.get("provider_uuid"), + simulate=self.simulate, + bill_date=report_month, + cleaned_column_mapping=self.cleaned_column_mapping, + ) + LOG.info( + log_json( + provider.uuid, + msg="Calling fix_parquet_data_types", + schema=account.get("schema_name"), + provider_uuid=provider.uuid, + task_id=str(async_result), + ) + ) + async_results.append(str(async_result)) + return async_results diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 7c63d4862f..5ef88e2bb2 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -17,36 +17,26 @@ from masu.processor.parquet.parquet_report_processor import OPENSHIFT_REPORT_TYPE from masu.util.aws.common import get_s3_resource from masu.util.common import get_path_prefix -from masu.util.common import strip_characters_from_column_name -from reporting.provider.aws.models import TRINO_REQUIRED_COLUMNS as AWS_TRINO_REQUIRED_COLUMNS -from reporting.provider.azure.models import TRINO_REQUIRED_COLUMNS as AZURE_TRINO_REQUIRED_COLUMNS -from reporting.provider.oci.models import TRINO_REQUIRED_COLUMNS as OCI_TRINO_REQUIRED_COLUMNS -# TODO: Move the Trino required columns up to the task handler. - - -# Node role is the only column we add manually for OCP -# Therefore, it is the only column that can be incorrect -OCP_TRINO_REQUIRED_COLUMNS = {"node_role": ""} LOG = logging.getLogger(__name__) class VerifyParquetFiles: CONVERTER_VERSION = 1.0 + S3_OBJ_LOG_KEY = "s3_object_key" + S3_PREFIX_LOG_KEY = "s3_prefix" - def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_date): + def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_date, cleaned_column_mapping): self.schema_name = schema_name self.provider_uuid = uuid.UUID(provider_uuid) self.provider_type = provider_type.replace("-local", "") - self.simulate = True - # self.simulate = simulate + self.simulate = simulate self.bill_date = bill_date self.file_tracker = StateTracker(provider_uuid) - # Provider specific vars - self.openshift_data = False - self.report_types = [None] - self.required_columns = self.set_required_columns() + self.openshift_data = False # Not sure if we need this + self.report_types = self._set_report_types() + self.required_columns = cleaned_column_mapping self.logging_context = { "provider_type": self.provider_type, "provider_uuid": self.provider_uuid, @@ -55,47 +45,28 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat "bill_date": self.bill_date, } - def _get_bill_dates(self): - # However far back we want to fix. 
- return [ciso8601.parse_datetime(self.bill_date)] + def _set_report_types(self): + if self.provider_type == Provider.PROVIDER_OCI: + return ["cost", "usage"] + if self.provider_type == Provider.PROVIDER_OCP: + return ["namespace_labels", "node_labels", "pod_usage", "storage_usage"] + return [None] - def _find_pyarrow_value(self, default_value): - """Our mapping contains a default value, but - we need the pyarrow value for that default value. - """ - if pd.isnull(default_value): + def _find_pyarrow_value(self, default_val): + """Converts our default value to a pyarrow dtype.""" + if pd.isnull(default_val): # TODO: Azure saves datetime as pa.timestamp("ms") # TODO: AWS saves datetime as timestamp[ms, tz=UTC] # Should we be storing in a standard type here? return pa.timestamp("ms") - if default_value == "": + elif default_val == "": return pa.string() - if default_value == 0.0: + elif default_val == 0.0: return pa.float64() - def _clean_mapping(self, mapping): - """ - Our required mapping stores the raw column name; however, - the parquet files will contain the cleaned column name. - """ - scrubbed_mapping = {} - for raw_column_name, default_value in mapping.items(): - scrubbed_column_name = strip_characters_from_column_name(raw_column_name) - scrubbed_mapping[scrubbed_column_name] = self._find_pyarrow_value(default_value) - return scrubbed_mapping - - def set_required_columns(self): - """Grabs the mapping of column_name to data type.""" - if self.provider_type == Provider.PROVIDER_OCI: - self.report_types = ["cost", "usage"] - return self._clean_mapping(OCI_TRINO_REQUIRED_COLUMNS) - if self.provider_type == Provider.PROVIDER_OCP: - self.report_types = ["namespace_labels", "node_labels", "pod_usage", "storage_usage"] - return self._clean_mapping(OCP_TRINO_REQUIRED_COLUMNS) - if self.provider_type == Provider.PROVIDER_AWS: - return self._clean_mapping(AWS_TRINO_REQUIRED_COLUMNS) - if self.provider_type == Provider.PROVIDER_AZURE: - return self._clean_mapping(AZURE_TRINO_REQUIRED_COLUMNS) + def _get_bill_dates(self): + # However far back we want to fix. 
+ return [ciso8601.parse_datetime(self.bill_date)] # Stolen from parquet_report_processor def _parquet_path_s3(self, bill_date, report_type): @@ -157,8 +128,6 @@ def local_path(self): local_path.mkdir(parents=True, exist_ok=True) return local_path - # New logic to download the parquet files locally, coerce them, - # then upload the files that need updated back to s3 def retrieve_verify_reload_S3_parquet(self): """Retrieves the s3 files from s3""" s3_resource = get_s3_resource(settings.S3_ACCESS_KEY, settings.S3_SECRET, settings.S3_REGION) @@ -166,6 +135,7 @@ def retrieve_verify_reload_S3_parquet(self): bill_dates = self._get_bill_dates() for bill_date in bill_dates: for prefix in self._generate_s3_path_prefixes(bill_date): + self.logging_context[self.S3_PREFIX_LOG_KEY] = prefix LOG.info( log_json( self.provider_uuid, @@ -176,6 +146,7 @@ def retrieve_verify_reload_S3_parquet(self): ) for s3_object in s3_bucket.objects.filter(Prefix=prefix): s3_object_key = s3_object.key + self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE) local_file_path = os.path.join(self.local_path, os.path.basename(s3_object_key)) LOG.info( @@ -183,22 +154,28 @@ def retrieve_verify_reload_S3_parquet(self): self.provider_uuid, msg="Downloading file locally", context=self.logging_context, - prefix=prefix, - local_file_path=local_file_path, ) ) s3_bucket.download_file(s3_object_key, local_file_path) self.file_tracker.add_local_file(s3_object_key, local_file_path) self.file_tracker.set_state(s3_object_key, self._coerce_parquet_data_type(local_file_path)) + del self.logging_context[self.S3_OBJ_LOG_KEY] + del self.logging_context[self.S3_PREFIX_LOG_KEY] + if self.simulate: self.file_tracker.generate_simulate_messages() return False else: files_need_updated = self.file_tracker.get_files_that_need_updated() for s3_obj_key, converted_local_file_path in files_need_updated.items(): + self.logging_context[self.S3_OBJ_LOG_KEY] = s3_obj_key try: + LOG.info(log_json(self.provider_uuid, "Deleting s3 parquet file.", context=self.logging_context)) s3_bucket.Object(s3_obj_key).delete() - LOG.info(f"Deleted current parquet file: {s3_obj_key}") + self.file_tracker.set_state(s3_object_key, self.file_tracker.S3_FILE_DELETED) + LOG.info( + log_json(self.provider_uuid, "Deletion of s3 parquet file.", context=self.logging_context) + ) except ClientError as e: LOG.info(f"Failed to delete {s3_object_key}: {str(e)}") self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED) @@ -206,8 +183,15 @@ def retrieve_verify_reload_S3_parquet(self): # An error here would cause a data gap. 
with open(converted_local_file_path, "rb") as new_file: + LOG.info( + log_json( + self.provider_uuid, + "Uploading revised parquet file.", + context=self.logging_context, + local_file_path=converted_local_file_path, + ) + ) s3_bucket.upload_fileobj(new_file, s3_obj_key) - LOG.info(f"Uploaded revised parquet: {s3_object_key}") self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE) self.file_tracker.finalize_and_clean_up() @@ -233,29 +217,21 @@ def _coerce_parquet_data_type(self, parquet_file_path): schema = table.schema fields = [] for field in schema: - if correct_data_type := self.required_columns.get(field.name): + if default_value := self.required_columns.get(field.name): + correct_data_type = self._find_pyarrow_value(default_value) # Check if the field's type matches the desired type if field.type != correct_data_type: # State update: Needs to be replaced. LOG.info( log_json( self.provider_uuid, - msg="Incorrect data type.", + msg="Incorrect data type, building new schema.", context=self.logging_context, column_name=field.name, current_dtype=field.type, expected_data_type=correct_data_type, ) ) - LOG.info( - log_json( - self.provider_uuid, - msg="Building new parquet schema.", - context=self.logging_context, - column_name=field.name, - expected_data_type=correct_data_type, - ) - ) field = pa.field(field.name, correct_data_type) corrected_fields[field.name] = correct_data_type fields.append(field) @@ -283,13 +259,6 @@ def _coerce_parquet_data_type(self, parquet_file_path): ) ) table = table.cast(new_schema) - LOG.info( - log_json( - self.provider_uuid, - msg="Saving updated schema to the local parquet_file", - local_file_path=parquet_file_path, - ) - ) # Write the table back to the Parquet file pa.parquet.write_table(table, parquet_file_path) # Signal that we need to send this update to S3. 
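
Stripped of logging and state tracking, the coercion above reduces to a small pyarrow pattern: read the table, swap the expected dtype into any tracked field whose type is wrong, cast, and rewrite the file in place. A minimal standalone sketch of that pattern follows; the column map and file path are illustrative stand-ins, not koku's required-columns helpers.

    import pyarrow as pa
    import pyarrow.parquet as pq

    # Hypothetical expected dtypes keyed by cleaned column name (a stand-in for
    # the mapping the task handler builds from TRINO_REQUIRED_COLUMNS).
    EXPECTED = {"node_role": pa.string(), "usage_amount": pa.float64()}

    def coerce_schema(path):
        """Return True if the file was rewritten with corrected column types."""
        table = pq.read_table(path)
        fields = []
        changed = False
        for field in table.schema:
            want = EXPECTED.get(field.name)
            if want is not None and field.type != want:
                field = pa.field(field.name, want)  # keep the name, swap the dtype
                changed = True
            fields.append(field)
        if changed:
            table = table.cast(pa.schema(fields))  # may raise if values cannot convert
            pq.write_table(table, path)            # overwrite the local copy; caller re-uploads it
        return changed
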
diff --git a/koku/masu/api/upgrade_trino/view.py b/koku/masu/api/upgrade_trino/view.py index 9279562e7d..d124cbf338 100644 --- a/koku/masu/api/upgrade_trino/view.py +++ b/koku/masu/api/upgrade_trino/view.py @@ -14,52 +14,24 @@ from rest_framework.response import Response from rest_framework.settings import api_settings -from api.provider.models import Provider -from masu.api.upgrade_trino.task_handler import fix_parquet_data_types_task_builder +from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler +from masu.api.upgrade_trino.util.task_handler import RequiredParametersError LOG = logging.getLogger(__name__) -class RequiredParametersError(Exception): - """Handle require parameters error.""" - - -def build_task_handler_kwargs(query_params): - """Validates the expected query parameters.""" - uuid_or_type_provided = False # used to check if provider uuid or type supplied - reprocess_kwargs = {} - if start_date := query_params.get("start_date"): - reprocess_kwargs["bill_date"] = start_date - else: - raise RequiredParametersError("start_date must be supplied as a parameter.") - if provider_uuid := query_params.get("provider_uuid"): - uuid_or_type_provided = True - provider = Provider.objects.filter(uuid=provider_uuid).first() - if not provider: - raise RequiredParametersError(f"The provider_uuid {provider_uuid} does not exist.") - reprocess_kwargs["provider_uuid"] = provider_uuid - if provider_type := query_params.get("provider_type"): - uuid_or_type_provided = True - reprocess_kwargs["provider_type"] = provider_type - if not uuid_or_type_provided: - raise RequiredParametersError("provider_uuid or provider_type must be supplied") - return reprocess_kwargs - - @never_cache @api_view(http_method_names=["GET", "DELETE"]) @permission_classes((AllowAny,)) @renderer_classes(tuple(api_settings.DEFAULT_RENDERER_CLASSES)) def fix_parquet(request): - """Return expired data.""" - simulate = False - params = request.query_params + """Fix parquet files so that we can upgrade Trino.""" try: - task_handler_kwargs = build_task_handler_kwargs(params) - async_fix_results = fix_parquet_data_types_task_builder(**task_handler_kwargs) + task_handler = FixParquetTaskHandler.from_query_params(request.query_params) + async_fix_results = task_handler.build_celery_tasks() except RequiredParametersError as errmsg: return Response({"Error": str(errmsg)}, status=status.HTTP_400_BAD_REQUEST) response_key = "Async jobs for fix parquet files" - if simulate: + if task_handler.simulate: response_key = response_key + " (simulated)" return Response({response_key: str(async_fix_results)}) From 3c316e518b5912d53e4037cbb7b92cbbb15c12fe Mon Sep 17 00:00:00 2001 From: myersCody Date: Wed, 13 Dec 2023 13:18:47 -0500 Subject: [PATCH 03/30] Clean up local tmp files when simulate is set to True. 
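
A simulate run walks the same download-and-inspect path as a real run, so it leaves a copy of every checked parquet file in TMP_DIR unless the tracker cleans up after printing its summary. A condensed sketch of that flow, reusing the tracker's state names and log messages but with made-up object keys and paths:

    import logging
    import os

    logging.basicConfig(level=logging.INFO)
    LOG = logging.getLogger(__name__)

    COERCE_REQUIRED = "coerce_required"
    NO_CHANGES_NEEDED = "no_changes_needed"

    def simulate_summary_and_cleanup(tracker, local_files):
        buckets = {
            "Files that have all correct data_types.": [],
            "Files that need to be updated.": [],
            "Files that failed to convert.": [],
        }
        for key, state in tracker.items():
            if state == NO_CHANGES_NEEDED:
                buckets["Files that have all correct data_types."].append(key)
            elif state == COERCE_REQUIRED:
                buckets["Files that need to be updated."].append(key)
            else:
                buckets["Files that failed to convert."].append(key)
        for msg, files in buckets.items():
            LOG.info("%s file_count=%d total_count=%d", msg, len(files), len(tracker))
        # The change in this commit: drop the downloaded copies even when only simulating.
        for path in local_files.values():
            os.remove(path)
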
--- koku/masu/api/upgrade_trino/util/state_tracker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 85ef12d4d9..81be50f0ca 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -69,6 +69,7 @@ def generate_simulate_messages(self): } for substring, files_list in simulate_info.items(): LOG.info(log_json(self.provider_uuid, msg=substring, file_count=len(files_list), total_count=files_count)) + self._clean_local_files() def _clean_local_files(self): for file_path in self.local_files.values(): From e27c432b4d413741a308f2d0233b16705a1aa462 Mon Sep 17 00:00:00 2001 From: myersCody Date: Thu, 14 Dec 2023 10:24:41 -0500 Subject: [PATCH 04/30] Add unittests and fix edge case. --- koku/masu/api/upgrade_trino/test/__init__.py | 0 .../test/test_verify_parquet_files.py | 89 +++++++++++++++++++ koku/masu/api/upgrade_trino/test/test_view.py | 63 +++++++++++++ .../api/upgrade_trino/util/task_handler.py | 11 +-- .../util/verify_parquet_files.py | 75 +++++++++++----- 5 files changed, 213 insertions(+), 25 deletions(-) create mode 100644 koku/masu/api/upgrade_trino/test/__init__.py create mode 100644 koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py create mode 100644 koku/masu/api/upgrade_trino/test/test_view.py diff --git a/koku/masu/api/upgrade_trino/test/__init__.py b/koku/masu/api/upgrade_trino/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py new file mode 100644 index 0000000000..dccf0a04ff --- /dev/null +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -0,0 +1,89 @@ +# +# Copyright 2023 Red Hat Inc. +# SPDX-License-Identifier: Apache-2.0 +# +"""Test the hcs_report_data endpoint view.""" +import os +import shutil +import tempfile +from datetime import datetime + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq + +from masu.api.upgrade_trino.util.state_tracker import StateTracker +from masu.api.upgrade_trino.util.verify_parquet_files import VerifyParquetFiles +from masu.test import MasuTestCase + + +class TestVerifyParquetFiles(MasuTestCase): + def setUp(self): + super().setUp() + # Experienced issues with pyarrow not + # playing nice with tempfiles. 
Therefore + # I opted for writing files to a tmp dir + self.temp_dir = tempfile.mkdtemp() + self.required_columns = {"float": 0.0, "string": "", "datetime": pd.NaT} + self.expected_pyarrow_dtypes = {"float": pa.float64(), "string": pa.string(), "datetime": pa.timestamp("ms")} + self.panda_kwargs = {"allow_truncated_timestamps": True, "coerce_timestamps": "ms", "index": False} + self.suffix = ".parquet" + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def create_default_verify_handler(self): + verify_parquet_file_kwargs = { + "schema_name": self.schema_name, + "provider_uuid": self.aws_provider_uuid, + "provider_type": self.aws_provider.type, + "simulate": True, + "bill_date": datetime(2023, 1, 1), + "cleaned_column_mapping": self.required_columns, + } + return VerifyParquetFiles(**verify_parquet_file_kwargs) + + def test_coerce_parquet_data_type_no_changes_needed(self): + """Test a parquet file with correct dtypes.""" + file_data = { + "float": [1.1, 2.2, 3.3], + "string": ["A", "B", "C"], + "datetime": [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 3)], + "unrequired_column": ["a", "b", "c"], + } + with tempfile.NamedTemporaryFile(suffix=self.suffix) as temp_file: + pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) + return_state = verify_handler._coerce_parquet_data_type(temp_file) + self.assertEqual(return_state, StateTracker.NO_CHANGES_NEEDED) + + def test_coerce_parquet_data_type_coerce_needed(self): + """Test that files created through reindex are fixed correctly.""" + data_frame = pd.DataFrame() + data_frame = data_frame.reindex(columns=self.required_columns.keys()) + temp_file = os.path.join(self.temp_dir, f"test{self.suffix}") + data_frame.to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler.file_tracker.add_local_file(temp_file, temp_file) + return_state = verify_handler._coerce_parquet_data_type(temp_file) + self.assertEqual(return_state, StateTracker.COERCE_REQUIRED) + table = pq.read_table(temp_file) + schema = table.schema + for field in schema: + self.assertEqual(field.type, self.expected_pyarrow_dtypes.get(field.name)) + os.remove(temp_file) + + def test_coerce_parquet_data_type_failed_to_coerce(self): + """Test a parquet file with correct dtypes.""" + file_data = { + "float": [datetime(2023, 1, 1), datetime(2023, 1, 1), datetime(2023, 1, 1)], + "string": ["A", "B", "C"], + "datetime": [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 3)], + } + with tempfile.NamedTemporaryFile(suffix=self.suffix) as temp_file: + pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) + return_state = verify_handler._coerce_parquet_data_type(temp_file) + self.assertEqual(return_state, StateTracker.FAILED_DTYPE_CONVERSION) diff --git a/koku/masu/api/upgrade_trino/test/test_view.py b/koku/masu/api/upgrade_trino/test/test_view.py new file mode 100644 index 0000000000..88339ea86a --- /dev/null +++ b/koku/masu/api/upgrade_trino/test/test_view.py @@ -0,0 +1,63 @@ +# +# Copyright 2023 Red Hat Inc. 
+# SPDX-License-Identifier: Apache-2.0 +# +"""Test the hcs_report_data endpoint view.""" +from unittest.mock import patch +from uuid import uuid4 + +from django.test.utils import override_settings +from django.urls import reverse + +from api.models import Provider +from api.utils import DateHelper +from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler +from masu.test import MasuTestCase + + +@override_settings(ROOT_URLCONF="masu.urls") +class TestUpgradeTrinoView(MasuTestCase): + ENDPOINT = "fix_parquet" + bill_date = DateHelper().month_start("2023-12-01") + + @patch("koku.middleware.MASU", return_value=True) + def test_required_parameters_failure(self, _): + """Test the hcs_report_finalization endpoint.""" + parameter_options = [{}, {"start_date": self.bill_date}, {"provider_uuid": self.aws_provider_uuid}] + for parameters in parameter_options: + with self.subTest(parameters=parameters): + response = self.client.get(reverse(self.ENDPOINT), parameters) + self.assertEqual(response.status_code, 400) + + @patch("koku.middleware.MASU", return_value=True) + def test_provider_uuid_does_not_exist(self, _): + """Test the hcs_report_finalization endpoint.""" + parameters = {"start_date": self.bill_date, "provider_uuid": str(uuid4())} + response = self.client.get(reverse(self.ENDPOINT), parameters) + self.assertEqual(response.status_code, 400) + + @patch("koku.middleware.MASU", return_value=True) + def test_acceptable_parameters(self, _): + """Test that the endpoint accepts""" + acceptable_parameters = [ + {"start_date": self.bill_date, "provider_type": self.aws_provider.type}, + {"start_date": self.bill_date, "provider_uuid": self.aws_provider_uuid, "simulate": True}, + {"start_date": self.bill_date, "provider_uuid": self.aws_provider_uuid, "simulate": "bad_value"}, + ] + cleaned_column_mapping = FixParquetTaskHandler.clean_column_names(self.aws_provider.type) + for parameters in acceptable_parameters: + with self.subTest(parameters=parameters): + with patch("masu.celery.tasks.fix_parquet_data_types.delay") as patch_celery: + response = self.client.get(reverse(self.ENDPOINT), parameters) + self.assertEqual(response.status_code, 200) + simulate = parameters.get("simulate", False) + if simulate == "bad_value": + simulate = False + patch_celery.assert_called_once_with( + schema_name=self.schema_name, + provider_type=Provider.PROVIDER_AWS_LOCAL, + provider_uuid=self.aws_provider.uuid, + simulate=simulate, + bill_date=self.bill_date, + cleaned_column_mapping=cleaned_column_mapping, + ) diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index cf092d6ac8..c9d8ad14bf 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -57,21 +57,22 @@ def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": reprocess_kwargs.provider_type = provider_type if simulate := query_params.get("simulate"): - reprocess_kwargs.simulate = simulate + if simulate.lower() == "true": + reprocess_kwargs.simulate = True if not reprocess_kwargs.provider_type and not reprocess_kwargs.provider_uuid: raise RequiredParametersError("provider_uuid or provider_type must be supplied") if not reprocess_kwargs.bill_date: raise RequiredParametersError("start_date must be supplied as a parameter.") - reprocess_kwargs.cleaned_column_mapping = reprocess_kwargs.clean_column_names() + reprocess_kwargs.cleaned_column_mapping = 
reprocess_kwargs.clean_column_names(reprocess_kwargs.provider_type) return reprocess_kwargs - def clean_column_names(self): + @classmethod + def clean_column_names(self, provider_type): """Creates a mapping of columns to expected pyarrow values.""" clean_column_names = {} - # provider_type_key = copy.deepcopy() - provider_mapping = self.REQUIRED_COLUMNS_MAPPING.get(self.provider_type.replace("local", "")) + provider_mapping = self.REQUIRED_COLUMNS_MAPPING.get(provider_type.replace("-local", "")) # Our required mapping stores the raw column name; however, # the parquet files will contain the cleaned column name. for raw_col, default_val in provider_mapping.items(): diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 5ef88e2bb2..d4eca91230 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -36,7 +36,7 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat self.file_tracker = StateTracker(provider_uuid) self.openshift_data = False # Not sure if we need this self.report_types = self._set_report_types() - self.required_columns = cleaned_column_mapping + self.required_columns = self._set_pyarrow_types(cleaned_column_mapping) self.logging_context = { "provider_type": self.provider_type, "provider_uuid": self.provider_uuid, @@ -45,6 +45,20 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat "bill_date": self.bill_date, } + def _set_pyarrow_types(self, cleaned_column_mapping): + mapping = {} + for key, default_val in cleaned_column_mapping.items(): + if pd.isnull(default_val): + # TODO: Azure saves datetime as pa.timestamp("ms") + # TODO: AWS saves datetime as timestamp[ms, tz=UTC] + # Should we be storing in a standard type here? + mapping[key] = pa.timestamp("ms") + elif default_val == "": + mapping[key] = pa.string() + elif default_val == 0.0: + mapping[key] = pa.float64() + return mapping + def _set_report_types(self): if self.provider_type == Provider.PROVIDER_OCI: return ["cost", "usage"] @@ -52,18 +66,6 @@ def _set_report_types(self): return ["namespace_labels", "node_labels", "pod_usage", "storage_usage"] return [None] - def _find_pyarrow_value(self, default_val): - """Converts our default value to a pyarrow dtype.""" - if pd.isnull(default_val): - # TODO: Azure saves datetime as pa.timestamp("ms") - # TODO: AWS saves datetime as timestamp[ms, tz=UTC] - # Should we be storing in a standard type here? - return pa.timestamp("ms") - elif default_val == "": - return pa.string() - elif default_val == 0.0: - return pa.float64() - def _get_bill_dates(self): # However far back we want to fix. 
return [ciso8601.parse_datetime(self.bill_date)] @@ -195,8 +197,37 @@ def retrieve_verify_reload_S3_parquet(self): self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE) self.file_tracker.finalize_and_clean_up() + def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names): + """Performs a transformation to change a double to a timestamp.""" + table = pq.read_table(parquet_file_path) + schema = table.schema + fields = [] + for field in schema: + if field.name in field_names: + # if len is 0 here we get an empty list, if it does + # have a value for the field, overwrite it with bill_date + replaced_values = [self.bill_date] * len(table[field.name]) + corrected_column = pa.array(replaced_values, type=pa.timestamp("ms")) + field = pa.field(field.name, corrected_column.type) + fields.append(field) + # Create a new schema + new_schema = pa.schema(fields) + # Create a DataFrame from the original PyArrow Table + original_df = table.to_pandas() + + # Update the DataFrame with corrected values + for field_name in field_names: + if field_name in original_df.columns: + original_df[field_name] = corrected_column.to_pandas() + + # Create a new PyArrow Table from the updated DataFrame + new_table = pa.Table.from_pandas(original_df, schema=new_schema) + + # Write the new table back to the Parquet file + pq.write_table(new_table, parquet_file_path) + # Same logic as last time, but combined into one method & added state tracking - def _coerce_parquet_data_type(self, parquet_file_path): + def _coerce_parquet_data_type(self, parquet_file_path, transformation_enabled=True): """If a parquet file has an incorrect dtype we can attempt to coerce it to the correct type it. @@ -212,16 +243,15 @@ def _coerce_parquet_data_type(self, parquet_file_path): ) ) corrected_fields = {} + double_to_timestamp_fields = [] try: table = pq.read_table(parquet_file_path) schema = table.schema fields = [] for field in schema: - if default_value := self.required_columns.get(field.name): - correct_data_type = self._find_pyarrow_value(default_value) + if correct_data_type := self.required_columns.get(field.name): # Check if the field's type matches the desired type if field.type != correct_data_type: - # State update: Needs to be replaced. LOG.info( log_json( self.provider_uuid, @@ -232,11 +262,14 @@ def _coerce_parquet_data_type(self, parquet_file_path): expected_data_type=correct_data_type, ) ) - field = pa.field(field.name, correct_data_type) - corrected_fields[field.name] = correct_data_type + if field.type == pa.float64() and correct_data_type == pa.timestamp("ms"): + double_to_timestamp_fields.append(field.name) + else: + field = pa.field(field.name, correct_data_type) + corrected_fields[field.name] = correct_data_type fields.append(field) - if not corrected_fields: + if not corrected_fields and not double_to_timestamp_fields: # Final State: No changes needed. LOG.info( log_json( @@ -261,6 +294,8 @@ def _coerce_parquet_data_type(self, parquet_file_path): table = table.cast(new_schema) # Write the table back to the Parquet file pa.parquet.write_table(table, parquet_file_path) + if double_to_timestamp_fields: + self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields) # Signal that we need to send this update to S3. return self.file_tracker.COERCE_REQUIRED From b91827e67f083249c16bb89b505290c3b8c3f01f Mon Sep 17 00:00:00 2001 From: myersCody Date: Thu, 14 Dec 2023 14:30:08 -0500 Subject: [PATCH 05/30] Improve test coverage. 
--- .../test/test_verify_parquet_files.py | 222 +++++++++++++++++- .../api/upgrade_trino/util/state_tracker.py | 3 +- .../util/verify_parquet_files.py | 10 +- 3 files changed, 218 insertions(+), 17 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index dccf0a04ff..64884514c4 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -6,15 +6,24 @@ import os import shutil import tempfile +from collections import namedtuple from datetime import datetime +from unittest.mock import patch import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +from api.utils import DateHelper from masu.api.upgrade_trino.util.state_tracker import StateTracker +from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler from masu.api.upgrade_trino.util.verify_parquet_files import VerifyParquetFiles +from masu.celery.tasks import PROVIDER_REPORT_TYPE_MAP +from masu.config import Config from masu.test import MasuTestCase +from masu.util.common import get_path_prefix + +DummyS3Object = namedtuple("DummyS3Object", "key") class TestVerifyParquetFiles(MasuTestCase): @@ -25,23 +34,74 @@ def setUp(self): # I opted for writing files to a tmp dir self.temp_dir = tempfile.mkdtemp() self.required_columns = {"float": 0.0, "string": "", "datetime": pd.NaT} - self.expected_pyarrow_dtypes = {"float": pa.float64(), "string": pa.string(), "datetime": pa.timestamp("ms")} - self.panda_kwargs = {"allow_truncated_timestamps": True, "coerce_timestamps": "ms", "index": False} + self.expected_pyarrow_dtypes = { + "float": pa.float64(), + "string": pa.string(), + "datetime": pa.timestamp("ms"), + } + self.panda_kwargs = { + "allow_truncated_timestamps": True, + "coerce_timestamps": "ms", + "index": False, + } self.suffix = ".parquet" def tearDown(self): shutil.rmtree(self.temp_dir) def create_default_verify_handler(self): - verify_parquet_file_kwargs = { - "schema_name": self.schema_name, - "provider_uuid": self.aws_provider_uuid, - "provider_type": self.aws_provider.type, - "simulate": True, - "bill_date": datetime(2023, 1, 1), - "cleaned_column_mapping": self.required_columns, - } - return VerifyParquetFiles(**verify_parquet_file_kwargs) + return VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=self.aws_provider_uuid, + provider_type=self.aws_provider.type, + simulate=True, + bill_date=datetime(2023, 1, 1), + cleaned_column_mapping=self.required_columns, + ) + + @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") + @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource") + def test_retrieve_verify_reload_S3_parquet(self, mock_s3_resource, _): + """Test fixes for reindexes on all required columns.""" + # build a parquet file where reindex is used for all required columns + test_metadata = [ + {"uuid": self.aws_provider_uuid, "type": self.aws_provider.type}, + {"uuid": self.azure_provider_uuid, "type": self.azure_provider.type}, + {"uuid": self.ocp_provider_uuid, "type": self.ocp_provider.type}, + {"uuid": self.oci_provider_uuid, "type": self.oci_provider.type}, + ] + for metadata in test_metadata: + with self.subTest(metadata=metadata): + bill_date = str(DateHelper().this_month_start) + required_columns = FixParquetTaskHandler.clean_column_names(metadata["type"]) + data_frame = pd.DataFrame() + data_frame = 
data_frame.reindex(columns=required_columns.keys()) + filename = f"test_{metadata['uuid']}{self.suffix}" + temp_file = os.path.join(self.temp_dir, filename) + data_frame.to_parquet(temp_file, **self.panda_kwargs) + mock_bucket = mock_s3_resource.return_value.Bucket.return_value + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=metadata["uuid"], + provider_type=metadata["type"], + simulate=False, + bill_date=bill_date, + cleaned_column_mapping=required_columns, + ) + prefixes = verify_handler._generate_s3_path_prefixes(DateHelper().this_month_start) + filter_side_effect = [[DummyS3Object(key=temp_file)]] + for _ in range(len(prefixes) - 1): + filter_side_effect.append([]) + mock_bucket.objects.filter.side_effect = filter_side_effect + mock_bucket.download_file.return_value = temp_file + VerifyParquetFiles.local_path = self.temp_dir + verify_handler.retrieve_verify_reload_S3_parquet() + mock_bucket.upload_fileobj.assert_called() + table = pq.read_table(temp_file) + schema = table.schema + for field in schema: + self.assertEqual(field.type, verify_handler.required_columns.get(field.name)) + os.remove(temp_file) def test_coerce_parquet_data_type_no_changes_needed(self): """Test a parquet file with correct dtypes.""" @@ -56,18 +116,24 @@ def test_coerce_parquet_data_type_no_changes_needed(self): verify_handler = self.create_default_verify_handler() verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) + verify_handler.file_tracker.set_state(temp_file.name, return_state) self.assertEqual(return_state, StateTracker.NO_CHANGES_NEEDED) + self.assertEqual(verify_handler.file_tracker._check_for_incomplete_files(), []) def test_coerce_parquet_data_type_coerce_needed(self): """Test that files created through reindex are fixed correctly.""" data_frame = pd.DataFrame() data_frame = data_frame.reindex(columns=self.required_columns.keys()) + filename = f"test{self.suffix}" temp_file = os.path.join(self.temp_dir, f"test{self.suffix}") data_frame.to_parquet(temp_file, **self.panda_kwargs) verify_handler = self.create_default_verify_handler() - verify_handler.file_tracker.add_local_file(temp_file, temp_file) + verify_handler.file_tracker.add_local_file(filename, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) self.assertEqual(return_state, StateTracker.COERCE_REQUIRED) + verify_handler.file_tracker.set_state(filename, return_state) + files_need_updating = verify_handler.file_tracker.get_files_that_need_updated() + self.assertTrue(files_need_updating.get(filename)) table = pq.read_table(temp_file) schema = table.schema for field in schema: @@ -86,4 +152,136 @@ def test_coerce_parquet_data_type_failed_to_coerce(self): verify_handler = self.create_default_verify_handler() verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) + verify_handler.file_tracker.set_state(temp_file.name, return_state) self.assertEqual(return_state, StateTracker.FAILED_DTYPE_CONVERSION) + self.assertNotEqual(verify_handler.file_tracker._check_for_incomplete_files(), []) + + def test_oci_s3_paths(self): + """test path generation for oci sources.""" + bill_date = DateHelper().this_month_start + expected_s3_paths = [] + for oci_report_type in PROVIDER_REPORT_TYPE_MAP.get(self.oci_provider.type): + path_kwargs = { + "account": self.schema_name, + "provider_type": self.oci_provider.type.replace("-local", ""), + 
"provider_uuid": self.oci_provider_uuid, + "start_date": bill_date, + "data_type": Config.PARQUET_DATA_TYPE, + "report_type": oci_report_type, + } + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["daily"] = True + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=self.oci_provider_uuid, + provider_type=self.oci_provider.type, + simulate=True, + bill_date=bill_date, + cleaned_column_mapping=self.required_columns, + ) + s3_prefixes = verify_handler._generate_s3_path_prefixes(bill_date) + self.assertEqual(len(s3_prefixes), len(expected_s3_paths)) + for expected_path in expected_s3_paths: + self.assertIn(expected_path, s3_prefixes) + + def test_ocp_s3_paths(self): + """test path generation for ocp sources.""" + bill_date = DateHelper().this_month_start + expected_s3_paths = [] + for ocp_report_type in PROVIDER_REPORT_TYPE_MAP.get(self.ocp_provider.type).keys(): + path_kwargs = { + "account": self.schema_name, + "provider_type": self.ocp_provider.type.replace("-local", ""), + "provider_uuid": self.ocp_provider_uuid, + "start_date": bill_date, + "data_type": Config.PARQUET_DATA_TYPE, + "report_type": ocp_report_type, + } + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["daily"] = True + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=self.ocp_provider_uuid, + provider_type=self.ocp_provider.type, + simulate=True, + bill_date=bill_date, + cleaned_column_mapping=self.required_columns, + ) + s3_prefixes = verify_handler._generate_s3_path_prefixes(bill_date) + self.assertEqual(len(s3_prefixes), len(expected_s3_paths)) + for expected_path in expected_s3_paths: + self.assertIn(expected_path, s3_prefixes) + + def test_other_providers_s3_paths(self): + bill_date = DateHelper().this_month_start + test_metadata = [ + {"uuid": self.aws_provider_uuid, "type": self.aws_provider.type.replace("-local", "")}, + {"uuid": self.azure_provider_uuid, "type": self.azure_provider.type.replace("-local", "")}, + ] + for metadata in test_metadata: + with self.subTest(metadata=metadata): + expected_s3_paths = [] + path_kwargs = { + "account": self.schema_name, + "provider_type": metadata["type"], + "provider_uuid": metadata["uuid"], + "start_date": bill_date, + "data_type": Config.PARQUET_DATA_TYPE, + } + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["daily"] = True + path_kwargs["report_type"] = "raw" + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=metadata["uuid"], + provider_type=metadata["type"], + simulate=True, + bill_date=bill_date, + cleaned_column_mapping=self.required_columns, + ) + s3_prefixes = verify_handler._generate_s3_path_prefixes(bill_date) + self.assertEqual(len(s3_prefixes), len(expected_s3_paths)) + for expected_path in expected_s3_paths: + self.assertIn(expected_path, s3_prefixes) + + @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") + @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource") + def test_retrieve_verify_reload_S3_parquet_failure(self, mock_s3_resource, _): + """Test fixes for reindexes on all required columns.""" + # build a parquet file where reindex is used for all required columns + file_data = { + "float": [datetime(2023, 1, 1), datetime(2023, 1, 1), datetime(2023, 1, 
1)], + "string": ["A", "B", "C"], + "datetime": [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 3)], + } + + bill_date = str(DateHelper().this_month_start) + temp_file = os.path.join(self.temp_dir, f"fail{self.suffix}") + pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) + mock_bucket = mock_s3_resource.return_value.Bucket.return_value + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=self.aws_provider_uuid, + provider_type=self.aws_provider.type, + simulate=True, + bill_date=bill_date, + cleaned_column_mapping=self.required_columns, + ) + prefixes = verify_handler._generate_s3_path_prefixes(DateHelper().this_month_start) + filter_side_effect = [[DummyS3Object(key=temp_file)]] + for _ in range(len(prefixes) - 1): + filter_side_effect.append([]) + mock_bucket.objects.filter.side_effect = filter_side_effect + mock_bucket.download_file.return_value = temp_file + VerifyParquetFiles.local_path = self.temp_dir + verify_handler.retrieve_verify_reload_S3_parquet() + mock_bucket.upload_fileobj.assert_not_called() + os.remove(temp_file) + + def test_local_path(self): + """Test local path.""" + verify_handler = self.create_default_verify_handler() + self.assertTrue(verify_handler.local_path) diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 81be50f0ca..1f26a87215 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -73,7 +73,8 @@ def generate_simulate_messages(self): def _clean_local_files(self): for file_path in self.local_files.values(): - os.remove(file_path) + if os.path.exists(file_path): + os.remove(file_path) def _check_for_incomplete_files(self): incomplete_files = [] diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index d4eca91230..13d7ce74c9 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -161,7 +161,7 @@ def retrieve_verify_reload_S3_parquet(self): s3_bucket.download_file(s3_object_key, local_file_path) self.file_tracker.add_local_file(s3_object_key, local_file_path) self.file_tracker.set_state(s3_object_key, self._coerce_parquet_data_type(local_file_path)) - del self.logging_context[self.S3_OBJ_LOG_KEY] + del self.logging_context[self.S3_OBJ_LOG_KEY] del self.logging_context[self.S3_PREFIX_LOG_KEY] if self.simulate: @@ -172,11 +172,13 @@ def retrieve_verify_reload_S3_parquet(self): for s3_obj_key, converted_local_file_path in files_need_updated.items(): self.logging_context[self.S3_OBJ_LOG_KEY] = s3_obj_key try: - LOG.info(log_json(self.provider_uuid, "Deleting s3 parquet file.", context=self.logging_context)) + LOG.info( + log_json(self.provider_uuid, msg="Deleting s3 parquet file.", context=self.logging_context) + ) s3_bucket.Object(s3_obj_key).delete() self.file_tracker.set_state(s3_object_key, self.file_tracker.S3_FILE_DELETED) LOG.info( - log_json(self.provider_uuid, "Deletion of s3 parquet file.", context=self.logging_context) + log_json(self.provider_uuid, msg="Deletion of s3 parquet file.", context=self.logging_context) ) except ClientError as e: LOG.info(f"Failed to delete {s3_object_key}: {str(e)}") @@ -188,7 +190,7 @@ def retrieve_verify_reload_S3_parquet(self): LOG.info( log_json( self.provider_uuid, - "Uploading revised parquet file.", + msg="Uploading revised parquet file.", context=self.logging_context, 
local_file_path=converted_local_file_path, ) From 971000d851a66aedc3fc31a756b131cb19f3e002 Mon Sep 17 00:00:00 2001 From: myersCody Date: Fri, 15 Dec 2023 07:46:43 -0500 Subject: [PATCH 06/30] Address code smells. --- .../upgrade_trino/test/test_verify_parquet_files.py | 8 ++++---- koku/masu/api/upgrade_trino/util/task_handler.py | 4 ++-- .../api/upgrade_trino/util/verify_parquet_files.py | 13 +++++++------ koku/masu/celery/tasks.py | 2 +- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index 64884514c4..e79497dce2 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -61,7 +61,7 @@ def create_default_verify_handler(self): @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource") - def test_retrieve_verify_reload_S3_parquet(self, mock_s3_resource, _): + def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _): """Test fixes for reindexes on all required columns.""" # build a parquet file where reindex is used for all required columns test_metadata = [ @@ -95,7 +95,7 @@ def test_retrieve_verify_reload_S3_parquet(self, mock_s3_resource, _): mock_bucket.objects.filter.side_effect = filter_side_effect mock_bucket.download_file.return_value = temp_file VerifyParquetFiles.local_path = self.temp_dir - verify_handler.retrieve_verify_reload_S3_parquet() + verify_handler.retrieve_verify_reload_s3_parquet() mock_bucket.upload_fileobj.assert_called() table = pq.read_table(temp_file) schema = table.schema @@ -249,7 +249,7 @@ def test_other_providers_s3_paths(self): @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource") - def test_retrieve_verify_reload_S3_parquet_failure(self, mock_s3_resource, _): + def test_retrieve_verify_reload_s3_parquet_failure(self, mock_s3_resource, _): """Test fixes for reindexes on all required columns.""" # build a parquet file where reindex is used for all required columns file_data = { @@ -277,7 +277,7 @@ def test_retrieve_verify_reload_S3_parquet_failure(self, mock_s3_resource, _): mock_bucket.objects.filter.side_effect = filter_side_effect mock_bucket.download_file.return_value = temp_file VerifyParquetFiles.local_path = self.temp_dir - verify_handler.retrieve_verify_reload_S3_parquet() + verify_handler.retrieve_verify_reload_s3_parquet() mock_bucket.upload_fileobj.assert_not_called() os.remove(temp_file) diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index c9d8ad14bf..15aac5180a 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -69,10 +69,10 @@ def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": return reprocess_kwargs @classmethod - def clean_column_names(self, provider_type): + def clean_column_names(cls, provider_type): """Creates a mapping of columns to expected pyarrow values.""" clean_column_names = {} - provider_mapping = self.REQUIRED_COLUMNS_MAPPING.get(provider_type.replace("-local", "")) + provider_mapping = cls.REQUIRED_COLUMNS_MAPPING.get(provider_type.replace("-local", "")) # Our required mapping stores the raw column name; however, # the 
parquet files will contain the cleaned column name. for raw_col, default_val in provider_mapping.items(): diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 13d7ce74c9..7ad50198e2 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -53,9 +53,9 @@ def _set_pyarrow_types(self, cleaned_column_mapping): # TODO: AWS saves datetime as timestamp[ms, tz=UTC] # Should we be storing in a standard type here? mapping[key] = pa.timestamp("ms") - elif default_val == "": + elif isinstance(default_val, str): mapping[key] = pa.string() - elif default_val == 0.0: + elif isinstance(default_val, float): mapping[key] = pa.float64() return mapping @@ -130,7 +130,7 @@ def local_path(self): local_path.mkdir(parents=True, exist_ok=True) return local_path - def retrieve_verify_reload_S3_parquet(self): + def retrieve_verify_reload_s3_parquet(self): """Retrieves the s3 files from s3""" s3_resource = get_s3_resource(settings.S3_ACCESS_KEY, settings.S3_SECRET, settings.S3_REGION) s3_bucket = s3_resource.Bucket(settings.S3_BUCKET_NAME) @@ -201,6 +201,8 @@ def retrieve_verify_reload_S3_parquet(self): def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names): """Performs a transformation to change a double to a timestamp.""" + if not field_names: + return table = pq.read_table(parquet_file_path) schema = table.schema fields = [] @@ -229,7 +231,7 @@ def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_n pq.write_table(new_table, parquet_file_path) # Same logic as last time, but combined into one method & added state tracking - def _coerce_parquet_data_type(self, parquet_file_path, transformation_enabled=True): + def _coerce_parquet_data_type(self, parquet_file_path): """If a parquet file has an incorrect dtype we can attempt to coerce it to the correct type it. @@ -296,8 +298,7 @@ def _coerce_parquet_data_type(self, parquet_file_path, transformation_enabled=Tr table = table.cast(new_schema) # Write the table back to the Parquet file pa.parquet.write_table(table, parquet_file_path) - if double_to_timestamp_fields: - self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields) + self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields) # Signal that we need to send this update to S3. 
return self.file_tracker.COERCE_REQUIRED diff --git a/koku/masu/celery/tasks.py b/koku/masu/celery/tasks.py index bc61d1c39a..4f5a2668d5 100644 --- a/koku/masu/celery/tasks.py +++ b/koku/masu/celery/tasks.py @@ -62,7 +62,7 @@ @celery_app.task(name="masu.celery.tasks.fix_parquet_data_types", queue=DEFAULT) def fix_parquet_data_types(*args, **kwargs): verify_parquet = VerifyParquetFiles(*args, **kwargs) - verify_parquet.retrieve_verify_reload_S3_parquet() + verify_parquet.retrieve_verify_reload_s3_parquet() @celery_app.task(name="masu.celery.tasks.check_report_updates", queue=DEFAULT) From 101a4721032dca0aca91b58f156876511afbe025 Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Wed, 3 Jan 2024 14:28:20 +0000 Subject: [PATCH 07/30] use download or dl_XL queue and overwrite instead of delete/upload --- koku/masu/api/upgrade_trino/util/task_handler.py | 11 +++++++++-- .../upgrade_trino/util/verify_parquet_files.py | 16 +--------------- koku/masu/celery/tasks.py | 1 - 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index 15aac5180a..4839ad6ad9 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -10,6 +10,9 @@ from api.provider.models import Provider from masu.celery.tasks import fix_parquet_data_types from masu.processor.orchestrator import get_billing_month_start +from masu.processor import is_customer_large +from masu.processor.tasks import GET_REPORT_FILES_QUEUE +from masu.processor.tasks import GET_REPORT_FILES_QUEUE_XL from masu.util.common import strip_characters_from_column_name from reporting.provider.aws.models import TRINO_REQUIRED_COLUMNS as AWS_TRINO_REQUIRED_COLUMNS from reporting.provider.azure.models import TRINO_REQUIRED_COLUMNS as AZURE_TRINO_REQUIRED_COLUMNS @@ -94,16 +97,20 @@ def build_celery_tasks(self): providers = Provider.objects.filter(active=True, paused=False, type=self.provider_type) for provider in providers: + queue_name = GET_REPORT_FILES_QUEUE + if is_customer_large(provider.account["schema_name"]): + queue_name = GET_REPORT_FILES_QUEUE_XL + account = copy.deepcopy(provider.account) report_month = get_billing_month_start(self.bill_date) - async_result = fix_parquet_data_types.delay( + async_result = fix_parquet_data_types.s( schema_name=account.get("schema_name"), provider_type=account.get("provider_type"), provider_uuid=account.get("provider_uuid"), simulate=self.simulate, bill_date=report_month, cleaned_column_mapping=self.cleaned_column_mapping, - ) + ).apply_async(queue=queue_name) LOG.info( log_json( provider.uuid, diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 7ad50198e2..d3410939b9 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -171,21 +171,7 @@ def retrieve_verify_reload_s3_parquet(self): files_need_updated = self.file_tracker.get_files_that_need_updated() for s3_obj_key, converted_local_file_path in files_need_updated.items(): self.logging_context[self.S3_OBJ_LOG_KEY] = s3_obj_key - try: - LOG.info( - log_json(self.provider_uuid, msg="Deleting s3 parquet file.", context=self.logging_context) - ) - s3_bucket.Object(s3_obj_key).delete() - self.file_tracker.set_state(s3_object_key, self.file_tracker.S3_FILE_DELETED) - LOG.info( - log_json(self.provider_uuid, msg="Deletion of s3 parquet file.", 
context=self.logging_context) - ) - except ClientError as e: - LOG.info(f"Failed to delete {s3_object_key}: {str(e)}") - self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED) - continue - - # An error here would cause a data gap. + # Overwrite s3 object with updated file data with open(converted_local_file_path, "rb") as new_file: LOG.info( log_json( diff --git a/koku/masu/celery/tasks.py b/koku/masu/celery/tasks.py index 4f5a2668d5..1ca5f09c4e 100644 --- a/koku/masu/celery/tasks.py +++ b/koku/masu/celery/tasks.py @@ -58,7 +58,6 @@ } -# TODO: Change the queue from the default queue @celery_app.task(name="masu.celery.tasks.fix_parquet_data_types", queue=DEFAULT) def fix_parquet_data_types(*args, **kwargs): verify_parquet = VerifyParquetFiles(*args, **kwargs) From 8acc01e02727ffaf2a96f9b42b87cc11b2bed151 Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Wed, 3 Jan 2024 14:57:09 +0000 Subject: [PATCH 08/30] lint --- koku/masu/api/upgrade_trino/util/task_handler.py | 2 +- koku/masu/api/upgrade_trino/util/verify_parquet_files.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index 4839ad6ad9..8c0845ddb3 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -9,8 +9,8 @@ from api.common import log_json from api.provider.models import Provider from masu.celery.tasks import fix_parquet_data_types -from masu.processor.orchestrator import get_billing_month_start from masu.processor import is_customer_large +from masu.processor.orchestrator import get_billing_month_start from masu.processor.tasks import GET_REPORT_FILES_QUEUE from masu.processor.tasks import GET_REPORT_FILES_QUEUE_XL from masu.util.common import strip_characters_from_column_name diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index d3410939b9..fd341c1990 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -7,7 +7,6 @@ import pandas as pd import pyarrow as pa import pyarrow.parquet as pq -from botocore.exceptions import ClientError from django.conf import settings from api.common import log_json From 75849f868dd38929bdb53f46e36f8442357ba9db Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Thu, 4 Jan 2024 10:46:29 +0000 Subject: [PATCH 09/30] fix tracking --- koku/masu/api/upgrade_trino/util/state_tracker.py | 1 - .../api/upgrade_trino/util/verify_parquet_files.py | 10 ++++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 1f26a87215..eb242da55c 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -15,7 +15,6 @@ class StateTracker: NO_CHANGES_NEEDED = "no_changes_needed" COERCE_REQUIRED = "coerce_required" SENT_TO_S3_COMPLETE = "sent_to_s3_complete" - S3_FILE_DELETED = "s3_file_deleted" SENT_TO_S3_FAILED = "sent_to_s3_failed" FAILED_DTYPE_CONVERSION = "failed_data_type_conversion" diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index fd341c1990..f04a68d329 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -7,6 
+7,7 @@ import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +from botocore.exceptions import ClientError from django.conf import settings from api.common import log_json @@ -180,8 +181,13 @@ def retrieve_verify_reload_s3_parquet(self): local_file_path=converted_local_file_path, ) ) - s3_bucket.upload_fileobj(new_file, s3_obj_key) - self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE) + try: + s3_bucket.upload_fileobj(new_file, s3_obj_key) + self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE) + except ClientError as e: + LOG.info(f"Failed to overwrite S3 file {s3_object_key}: {str(e)}") + self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED) + continue self.file_tracker.finalize_and_clean_up() def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names): From 6f4ebdac748754b425d6267fad52148a495ded1a Mon Sep 17 00:00:00 2001 From: myersCody Date: Thu, 4 Jan 2024 15:55:43 -0500 Subject: [PATCH 10/30] Add ocp on cloud check and fix unittests from async switchover. --- .../test/test_verify_parquet_files.py | 30 +++++++++++-------- koku/masu/api/upgrade_trino/test/test_view.py | 20 +++++++------ .../util/verify_parquet_files.py | 10 +++++-- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index e79497dce2..a990d40788 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -215,6 +215,23 @@ def test_ocp_s3_paths(self): self.assertIn(expected_path, s3_prefixes) def test_other_providers_s3_paths(self): + def _build_expected_s3_paths(metadata): + expected_s3_paths = [] + path_kwargs = { + "account": self.schema_name, + "provider_type": metadata["type"], + "provider_uuid": metadata["uuid"], + "start_date": bill_date, + "data_type": Config.PARQUET_DATA_TYPE, + } + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["daily"] = True + path_kwargs["report_type"] = "raw" + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["report_type"] = "openshift" + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + return expected_s3_paths + bill_date = DateHelper().this_month_start test_metadata = [ {"uuid": self.aws_provider_uuid, "type": self.aws_provider.type.replace("-local", "")}, @@ -222,18 +239,7 @@ def test_other_providers_s3_paths(self): ] for metadata in test_metadata: with self.subTest(metadata=metadata): - expected_s3_paths = [] - path_kwargs = { - "account": self.schema_name, - "provider_type": metadata["type"], - "provider_uuid": metadata["uuid"], - "start_date": bill_date, - "data_type": Config.PARQUET_DATA_TYPE, - } - expected_s3_paths.append(get_path_prefix(**path_kwargs)) - path_kwargs["daily"] = True - path_kwargs["report_type"] = "raw" - expected_s3_paths.append(get_path_prefix(**path_kwargs)) + expected_s3_paths = _build_expected_s3_paths(metadata) verify_handler = VerifyParquetFiles( schema_name=self.schema_name, provider_uuid=metadata["uuid"], diff --git a/koku/masu/api/upgrade_trino/test/test_view.py b/koku/masu/api/upgrade_trino/test/test_view.py index 88339ea86a..da708f692e 100644 --- a/koku/masu/api/upgrade_trino/test/test_view.py +++ b/koku/masu/api/upgrade_trino/test/test_view.py @@ -12,6 +12,7 @@ from api.models import Provider from api.utils import DateHelper from 
masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler +from masu.processor.tasks import GET_REPORT_FILES_QUEUE from masu.test import MasuTestCase @@ -47,17 +48,18 @@ def test_acceptable_parameters(self, _): cleaned_column_mapping = FixParquetTaskHandler.clean_column_names(self.aws_provider.type) for parameters in acceptable_parameters: with self.subTest(parameters=parameters): - with patch("masu.celery.tasks.fix_parquet_data_types.delay") as patch_celery: + with patch("masu.celery.tasks.fix_parquet_data_types.apply_async") as patch_celery: response = self.client.get(reverse(self.ENDPOINT), parameters) self.assertEqual(response.status_code, 200) simulate = parameters.get("simulate", False) if simulate == "bad_value": simulate = False - patch_celery.assert_called_once_with( - schema_name=self.schema_name, - provider_type=Provider.PROVIDER_AWS_LOCAL, - provider_uuid=self.aws_provider.uuid, - simulate=simulate, - bill_date=self.bill_date, - cleaned_column_mapping=cleaned_column_mapping, - ) + async_kwargs = { + "schema_name": self.schema_name, + "provider_type": Provider.PROVIDER_AWS_LOCAL, + "provider_uuid": self.aws_provider.uuid, + "simulate": simulate, + "bill_date": self.bill_date, + "cleaned_column_mapping": cleaned_column_mapping, + } + patch_celery.assert_called_once_with((), async_kwargs, queue=GET_REPORT_FILES_QUEUE) diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index f04a68d329..18e72679d7 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -9,6 +9,7 @@ import pyarrow.parquet as pq from botocore.exceptions import ClientError from django.conf import settings +from django_tenants.utils import schema_context from api.common import log_json from api.provider.models import Provider @@ -34,7 +35,6 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat self.simulate = simulate self.bill_date = bill_date self.file_tracker = StateTracker(provider_uuid) - self.openshift_data = False # Not sure if we need this self.report_types = self._set_report_types() self.required_columns = self._set_pyarrow_types(cleaned_column_mapping) self.logging_context = { @@ -115,12 +115,16 @@ def _generate_s3_path_prefixes(self, bill_date): """ generates the s3 path prefixes. """ + with schema_context(self.schema_name): + ocp_on_cloud_check = Provider.objects.filter( + infrastructure__infrastructure_provider__uuid=self.provider_uuid + ).exists() path_prefixes = set() for report_type in self.report_types: path_prefixes.add(self._parquet_path_s3(bill_date, report_type)) path_prefixes.add(self._parquet_daily_path_s3(bill_date, report_type)) - if self.openshift_data: - path_prefixes.add(self._parquet_ocp_on_cloud_path_s3(bill_date)) + if ocp_on_cloud_check: + path_prefixes.add(self._parquet_ocp_on_cloud_path_s3(bill_date)) return path_prefixes # Stolen from parquet_report_processor From 2da00de87c582138240b5958a63a27177ee852dd Mon Sep 17 00:00:00 2001 From: myersCody Date: Thu, 4 Jan 2024 15:57:33 -0500 Subject: [PATCH 11/30] Change default queue at the task level. 
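The call sites above already pick a queue per provider, so this change only makes the task's declared default agree with them. The overall routing pattern is roughly the sketch below; the queue names and the size check are placeholders standing in for GET_REPORT_FILES_QUEUE, GET_REPORT_FILES_QUEUE_XL, and is_customer_large.

    from celery import Celery

    app = Celery("sketch")
    DOWNLOAD_QUEUE = "download"        # stand-in for GET_REPORT_FILES_QUEUE
    DOWNLOAD_XL_QUEUE = "download_xl"  # stand-in for GET_REPORT_FILES_QUEUE_XL

    @app.task(name="sketch.fix_parquet_data_types", queue=DOWNLOAD_QUEUE)
    def fix_parquet_data_types(**kwargs):
        """Placeholder body; the real task builds a VerifyParquetFiles and runs it."""
        return kwargs

    def queue_fix_task(is_large_customer, **task_kwargs):
        # .s() builds a signature so the queue can still be overridden per provider
        queue_name = DOWNLOAD_XL_QUEUE if is_large_customer else DOWNLOAD_QUEUE
        return fix_parquet_data_types.s(**task_kwargs).apply_async(queue=queue_name)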
--- koku/masu/celery/tasks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/koku/masu/celery/tasks.py b/koku/masu/celery/tasks.py index 1ca5f09c4e..38ae083f4a 100644 --- a/koku/masu/celery/tasks.py +++ b/koku/masu/celery/tasks.py @@ -40,6 +40,7 @@ from masu.processor.orchestrator import Orchestrator from masu.processor.tasks import autovacuum_tune_schema from masu.processor.tasks import DEFAULT +from masu.processor.tasks import GET_REPORT_FILES_QUEUE from masu.processor.tasks import PRIORITY_QUEUE from masu.processor.tasks import REMOVE_EXPIRED_DATA_QUEUE from masu.prometheus_stats import QUEUES @@ -58,7 +59,7 @@ } -@celery_app.task(name="masu.celery.tasks.fix_parquet_data_types", queue=DEFAULT) +@celery_app.task(name="masu.celery.tasks.fix_parquet_data_types", queue=GET_REPORT_FILES_QUEUE) def fix_parquet_data_types(*args, **kwargs): verify_parquet = VerifyParquetFiles(*args, **kwargs) verify_parquet.retrieve_verify_reload_s3_parquet() From ad143dbbc47698de7c68789bd2802074c1cea528 Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Fri, 5 Jan 2024 13:26:19 +0000 Subject: [PATCH 12/30] Apply suggestions from code review --- koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py | 4 ++-- koku/masu/api/upgrade_trino/test/test_view.py | 4 ++-- koku/masu/api/upgrade_trino/view.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index a990d40788..926863d388 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -1,8 +1,8 @@ # -# Copyright 2023 Red Hat Inc. +# Copyright 2024 Red Hat Inc. # SPDX-License-Identifier: Apache-2.0 # -"""Test the hcs_report_data endpoint view.""" +"""Test the verify parquet files endpoint view.""" import os import shutil import tempfile diff --git a/koku/masu/api/upgrade_trino/test/test_view.py b/koku/masu/api/upgrade_trino/test/test_view.py index da708f692e..73f3361748 100644 --- a/koku/masu/api/upgrade_trino/test/test_view.py +++ b/koku/masu/api/upgrade_trino/test/test_view.py @@ -1,8 +1,8 @@ # -# Copyright 2023 Red Hat Inc. +# Copyright 2024 Red Hat Inc. # SPDX-License-Identifier: Apache-2.0 # -"""Test the hcs_report_data endpoint view.""" +"""Test the verify parquet files endpoint view.""" from unittest.mock import patch from uuid import uuid4 diff --git a/koku/masu/api/upgrade_trino/view.py b/koku/masu/api/upgrade_trino/view.py index d124cbf338..ae1524abc7 100644 --- a/koku/masu/api/upgrade_trino/view.py +++ b/koku/masu/api/upgrade_trino/view.py @@ -1,5 +1,5 @@ # -# Copyright 2023 Red Hat Inc. +# Copyright 2024 Red Hat Inc. 
# SPDX-License-Identifier: Apache-2.0 # """View for fixing parquet files endpoint.""" From 8b20494af824b53a91cbff781f9082a53723764b Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Fri, 5 Jan 2024 17:10:23 +0000 Subject: [PATCH 13/30] fix timestamp types --- .../api/upgrade_trino/util/verify_parquet_files.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 18e72679d7..c38ff9daa3 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -4,7 +4,6 @@ from pathlib import Path import ciso8601 -import pandas as pd import pyarrow as pa import pyarrow.parquet as pq from botocore.exceptions import ClientError @@ -48,11 +47,12 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat def _set_pyarrow_types(self, cleaned_column_mapping): mapping = {} for key, default_val in cleaned_column_mapping.items(): - if pd.isnull(default_val): - # TODO: Azure saves datetime as pa.timestamp("ms") - # TODO: AWS saves datetime as timestamp[ms, tz=UTC] - # Should we be storing in a standard type here? - mapping[key] = pa.timestamp("ms") + if str(default_val) == "NaT": + # Store original provider datetime type + if self.provider_type == "Azure": + mapping[key] = pa.timestamp("ms") + else: + mapping[key] = pa.timestamp("ms", tz="UTC") elif isinstance(default_val, str): mapping[key] = pa.string() elif isinstance(default_val, float): From ac93175e59eb66ea4f79bd7bdbd4eac13da9546b Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Fri, 5 Jan 2024 17:19:06 +0000 Subject: [PATCH 14/30] build date range --- koku/masu/api/upgrade_trino/util/verify_parquet_files.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index c38ff9daa3..5f349a2e17 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -12,6 +12,7 @@ from api.common import log_json from api.provider.models import Provider +from api.utils import DateHelper from masu.api.upgrade_trino.util.state_tracker import StateTracker from masu.config import Config from masu.processor.parquet.parquet_report_processor import OPENSHIFT_REPORT_TYPE @@ -68,7 +69,8 @@ def _set_report_types(self): def _get_bill_dates(self): # However far back we want to fix. 
- return [ciso8601.parse_datetime(self.bill_date)] + dh = DateHelper() + return dh.list_months(ciso8601.parse_datetime(self.bill_date), dh.today.replace(tzinfo=None)) # Stolen from parquet_report_processor def _parquet_path_s3(self, bill_date, report_type): From c126259f05a878c31796d8a6b4e1710e32d92f1b Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Mon, 8 Jan 2024 12:28:50 +0000 Subject: [PATCH 15/30] fix unit tests --- .../test/test_verify_parquet_files.py | 4 ++-- .../util/verify_parquet_files.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index 926863d388..1c18c0265c 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -52,8 +52,8 @@ def tearDown(self): def create_default_verify_handler(self): return VerifyParquetFiles( schema_name=self.schema_name, - provider_uuid=self.aws_provider_uuid, - provider_type=self.aws_provider.type, + provider_uuid=self.azure_provider_uuid, + provider_type=self.azure_provider.type, simulate=True, bill_date=datetime(2023, 1, 1), cleaned_column_mapping=self.required_columns, diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 5f349a2e17..5fe46dcdfa 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -70,7 +70,9 @@ def _set_report_types(self): def _get_bill_dates(self): # However far back we want to fix. dh = DateHelper() - return dh.list_months(ciso8601.parse_datetime(self.bill_date), dh.today.replace(tzinfo=None)) + return dh.list_months( + ciso8601.parse_datetime(self.bill_date).replace(tzinfo=None), dh.today.replace(tzinfo=None) + ) # Stolen from parquet_report_processor def _parquet_path_s3(self, bill_date, report_type): @@ -196,7 +198,7 @@ def retrieve_verify_reload_s3_parquet(self): continue self.file_tracker.finalize_and_clean_up() - def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names): + def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names, timestamp_std): """Performs a transformation to change a double to a timestamp.""" if not field_names: return @@ -208,7 +210,7 @@ def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_n # if len is 0 here we get an empty list, if it does # have a value for the field, overwrite it with bill_date replaced_values = [self.bill_date] * len(table[field.name]) - corrected_column = pa.array(replaced_values, type=pa.timestamp("ms")) + corrected_column = pa.array(replaced_values, type=timestamp_std) field = pa.field(field.name, corrected_column.type) fields.append(field) # Create a new schema @@ -263,8 +265,13 @@ def _coerce_parquet_data_type(self, parquet_file_path): expected_data_type=correct_data_type, ) ) - if field.type == pa.float64() and correct_data_type == pa.timestamp("ms"): + if ( + field.type == pa.float64() + and correct_data_type == pa.timestamp("ms") + or correct_data_type == pa.timestamp("ms", tz="UTC") + ): double_to_timestamp_fields.append(field.name) + timestamp_std = correct_data_type else: field = pa.field(field.name, correct_data_type) corrected_fields[field.name] = correct_data_type @@ -295,7 +302,9 @@ def _coerce_parquet_data_type(self, parquet_file_path): table = table.cast(new_schema) # 
Write the table back to the Parquet file pa.parquet.write_table(table, parquet_file_path) - self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields) + self._perform_transformation_double_to_timestamp( + parquet_file_path, double_to_timestamp_fields, timestamp_std + ) # Signal that we need to send this update to S3. return self.file_tracker.COERCE_REQUIRED From 0b6fd87bcedc3807e9837ded90f16acd054509a9 Mon Sep 17 00:00:00 2001 From: myersCody Date: Mon, 8 Jan 2024 09:11:37 -0500 Subject: [PATCH 16/30] Add bill date to the file tracking logic. --- .../api/upgrade_trino/util/state_tracker.py | 129 ++++++++++++------ .../api/upgrade_trino/util/task_handler.py | 42 +++--- .../util/verify_parquet_files.py | 50 ++++--- 3 files changed, 141 insertions(+), 80 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index eb242da55c..823f086563 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -10,6 +10,10 @@ class StateTracker: + """Tracks the state of each s3 file for the provider per bill date""" + + CONTEXT_KEY = "conversion_metadata" + CONVERTER_VERSION = "0" FOUND_S3_FILE = "found_s3_file" DOWNLOADED_LOCALLY = "downloaded_locally" NO_CHANGES_NEEDED = "no_changes_needed" @@ -23,15 +27,42 @@ def __init__(self, provider_uuid): self.tracker = {} self.local_files = {} self.provider_uuid = provider_uuid - self.context_key = "dtype_conversion" - self.failed_files_key = "dtype_failed_files" - def set_state(self, s3_obj_key, state): - self.tracker[s3_obj_key] = state + def add_to_queue(self, bill_date_data): + """ + Checks the provider object's metadata to see if we should start the task. + + Args: + conversion_metadata (dict): Metadata for the conversion. + + Returns: + bool: True if the task should be added to the queue, False otherwise. + """ + # TODO: turn the keys here into a variable. + if bill_date_data.get("version") != self.CONVERTER_VERSION: + # always kick off a task if the version does not match or exist. + return True + if bill_date_data.get("conversion_successful"): + # if the conversion was successful for this version do not kick + # off a task. + return False + return True + + def set_state(self, s3_obj_key, state, bill_date): + # TODO: make bill date a string. + bill_date_files = self.tracker.get(bill_date) + if bill_date_files: + bill_date_files[s3_obj_key] = state + self.tracker[bill_date] = bill_date_files + else: + self.tracker[bill_date] = {s3_obj_key: state} - def add_local_file(self, s3_obj_key, local_path): + def add_local_file(self, s3_obj_key, local_path, bill_date): + # TODO: make bill date a string. 
self.local_files[s3_obj_key] = local_path - self.tracker[s3_obj_key] = self.DOWNLOADED_LOCALLY + bill_date_files = self.tracker.get(bill_date) + if bill_date_files: + self.set_state(s3_obj_key, self.DOWNLOADED_LOCALLY, bill_date) def get_files_that_need_updated(self): """Returns a mapping of files in the s3 needs @@ -40,34 +71,46 @@ def get_files_that_need_updated(self): {s3_object_key: local_file_path} for """ mapping = {} - for s3_obj_key, state in self.tracker.items(): - if state == self.COERCE_REQUIRED: - mapping[s3_obj_key] = self.local_files.get(s3_obj_key) + for bill_date, bill_metadata in self.tracker.items(): + bill_date_data = {} + for s3_obj_key, state in bill_metadata.items(): + if state == self.COERCE_REQUIRED: + bill_date_data[s3_obj_key] = self.local_files.get(s3_obj_key) + mapping[bill_date] = bill_date_data return mapping def generate_simulate_messages(self): """ Generates the simulate messages. """ - files_count = 0 - files_failed = [] - files_need_updated = [] - files_correct = [] - for s3_obj_key, state in self.tracker.items(): - files_count += 1 - if state == self.COERCE_REQUIRED: - files_need_updated.append(s3_obj_key) - elif state == self.NO_CHANGES_NEEDED: - files_correct.append(s3_obj_key) - else: - files_failed.append(s3_obj_key) - simulate_info = { - "Files that have all correct data_types.": files_correct, - "Files that need to be updated.": files_need_updated, - "Files that failed to convert.": files_failed, - } - for substring, files_list in simulate_info.items(): - LOG.info(log_json(self.provider_uuid, msg=substring, file_count=len(files_list), total_count=files_count)) + for bill_date, bill_data in self.tracker.items(): + files_count = 0 + files_failed = [] + files_need_updated = [] + files_correct = [] + for s3_obj_key, state in bill_data.items(): + files_count += 1 + if state == self.COERCE_REQUIRED: + files_need_updated.append(s3_obj_key) + elif state == self.NO_CHANGES_NEEDED: + files_correct.append(s3_obj_key) + else: + files_failed.append(s3_obj_key) + simulate_info = { + "Files that have all correct data_types.": files_correct, + "Files that need to be updated.": files_need_updated, + "Files that failed to convert.": files_failed, + } + for substring, files_list in simulate_info.items(): + LOG.info( + log_json( + self.provider_uuid, + msg=substring, + file_count=len(files_list), + total_count=files_count, + bill_date=bill_date, + ) + ) self._clean_local_files() def _clean_local_files(self): @@ -75,23 +118,29 @@ def _clean_local_files(self): if os.path.exists(file_path): os.remove(file_path) - def _check_for_incomplete_files(self): - incomplete_files = [] - for file_prefix, state in self.tracker.items(): - if state not in [self.SENT_TO_S3_COMPLETE, self.NO_CHANGES_NEEDED]: - file_metadata = {"key": file_prefix, "state": state} - incomplete_files.append(file_metadata) - return incomplete_files + def _create_bill_date_metadata(self): + # Check for incomplete files + metadata = {} + for bill_date, bill_metadata in self.tracker.items(): + bill_date_data = {"version": self.CONVERTER_VERSION} + incomplete_files = [] + for file_prefix, state in bill_metadata.items(): + if state not in [self.SENT_TO_S3_COMPLETE, self.NO_CHANGES_NEEDED]: + file_metadata = {"key": file_prefix, "state": state} + incomplete_files.append(file_metadata) + if incomplete_files: + bill_date_data["conversion_successful"] = False + bill_date_data["dtype_failed_files"] = incomplete_files + if not incomplete_files: + bill_date_data["conversion_successful"] = True + metadata[bill_date] = 
bill_date_data + return metadata def _check_if_complete(self): - incomplete_files = self._check_for_incomplete_files() try: manager = ProviderManager(self.provider_uuid) context = manager.get_additional_context() - context[self.context_key] = True - if incomplete_files: - context[self.context_key] = False - context[self.failed_files_key] = incomplete_files + context[self.CONTEXT_KEY] = self._create_bill_date_metadata() manager.model.set_additional_context(context) LOG.info(self.provider_uuid, log_json(msg="setting dtype states", context=context)) except ProviderManagerError: diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index 8c0845ddb3..e6fa722538 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -8,6 +8,7 @@ from api.common import log_json from api.provider.models import Provider +from masu.api.upgrade_trino.util.state_tracker import StateTracker from masu.celery.tasks import fix_parquet_data_types from masu.processor import is_customer_large from masu.processor.orchestrator import get_billing_month_start @@ -32,6 +33,7 @@ class FixParquetTaskHandler: provider_type: Optional[str] = field(default=None) simulate: Optional[bool] = field(default=False) cleaned_column_mapping: Optional[dict] = field(default=None) + state_tracker: StateTracker = field(init=False) # Node role is the only column we add manually for OCP # Therefore, it is the only column that can be incorrect @@ -47,7 +49,7 @@ def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": """Create an instance from query parameters.""" reprocess_kwargs = cls() if start_date := query_params.get("start_date"): - reprocess_kwargs.bill_date = start_date + reprocess_kwargs.bill_date = get_billing_month_start(start_date) if provider_uuid := query_params.get("provider_uuid"): provider = Provider.objects.filter(uuid=provider_uuid).first() @@ -103,22 +105,26 @@ def build_celery_tasks(self): account = copy.deepcopy(provider.account) report_month = get_billing_month_start(self.bill_date) - async_result = fix_parquet_data_types.s( - schema_name=account.get("schema_name"), - provider_type=account.get("provider_type"), - provider_uuid=account.get("provider_uuid"), - simulate=self.simulate, - bill_date=report_month, - cleaned_column_mapping=self.cleaned_column_mapping, - ).apply_async(queue=queue_name) - LOG.info( - log_json( - provider.uuid, - msg="Calling fix_parquet_data_types", - schema=account.get("schema_name"), - provider_uuid=provider.uuid, - task_id=str(async_result), + tracker = StateTracker(self.provider_uuid) + conversion_metadata = provider.additional_context.get(StateTracker.CONTEXT_KEY, {}) + bill_data = conversion_metadata.get(report_month, {}) + if tracker.add_to_queue(bill_data): + async_result = fix_parquet_data_types.s( + schema_name=account.get("schema_name"), + provider_type=account.get("provider_type"), + provider_uuid=account.get("provider_uuid"), + simulate=self.simulate, + bill_date=report_month, + cleaned_column_mapping=self.cleaned_column_mapping, + ).apply_async(queue=queue_name) + LOG.info( + log_json( + provider.uuid, + msg="Calling fix_parquet_data_types", + schema=account.get("schema_name"), + provider_uuid=provider.uuid, + task_id=str(async_result), + ) ) - ) - async_results.append(str(async_result)) + async_results.append(str(async_result)) return async_results diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py 
b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 5fe46dcdfa..ec6dea46e8 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -24,7 +24,6 @@ class VerifyParquetFiles: - CONVERTER_VERSION = 1.0 S3_OBJ_LOG_KEY = "s3_object_key" S3_PREFIX_LOG_KEY = "s3_prefix" @@ -157,7 +156,7 @@ def retrieve_verify_reload_s3_parquet(self): for s3_object in s3_bucket.objects.filter(Prefix=prefix): s3_object_key = s3_object.key self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key - self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE) + self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE, bill_date) local_file_path = os.path.join(self.local_path, os.path.basename(s3_object_key)) LOG.info( log_json( @@ -167,8 +166,10 @@ def retrieve_verify_reload_s3_parquet(self): ) ) s3_bucket.download_file(s3_object_key, local_file_path) - self.file_tracker.add_local_file(s3_object_key, local_file_path) - self.file_tracker.set_state(s3_object_key, self._coerce_parquet_data_type(local_file_path)) + self.file_tracker.add_local_file(s3_object_key, local_file_path, bill_date) + self.file_tracker.set_state( + s3_object_key, self._coerce_parquet_data_type(local_file_path), bill_date + ) del self.logging_context[self.S3_OBJ_LOG_KEY] del self.logging_context[self.S3_PREFIX_LOG_KEY] @@ -177,25 +178,30 @@ def retrieve_verify_reload_s3_parquet(self): return False else: files_need_updated = self.file_tracker.get_files_that_need_updated() - for s3_obj_key, converted_local_file_path in files_need_updated.items(): - self.logging_context[self.S3_OBJ_LOG_KEY] = s3_obj_key - # Overwrite s3 object with updated file data - with open(converted_local_file_path, "rb") as new_file: - LOG.info( - log_json( - self.provider_uuid, - msg="Uploading revised parquet file.", - context=self.logging_context, - local_file_path=converted_local_file_path, + for bill_date, bill_date_data in files_need_updated.items(): + for s3_obj_key, converted_local_file_path in bill_date_data.items(): + self.logging_context[self.S3_OBJ_LOG_KEY] = s3_obj_key + # Overwrite s3 object with updated file data + with open(converted_local_file_path, "rb") as new_file: + LOG.info( + log_json( + self.provider_uuid, + msg="Uploading revised parquet file.", + context=self.logging_context, + local_file_path=converted_local_file_path, + ) ) - ) - try: - s3_bucket.upload_fileobj(new_file, s3_obj_key) - self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE) - except ClientError as e: - LOG.info(f"Failed to overwrite S3 file {s3_object_key}: {str(e)}") - self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED) - continue + try: + s3_bucket.upload_fileobj( + new_file, + s3_obj_key, + ExtraArgs={"Metadata": {"converter_version": StateTracker.CONVERTER_VERSION}}, + ) + self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE, bill_date) + except ClientError as e: + LOG.info(f"Failed to overwrite S3 file {s3_object_key}: {str(e)}") + self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED, bill_date) + continue self.file_tracker.finalize_and_clean_up() def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names, timestamp_std): From 1c394b768c8ebdee1608b68d47ff96ef7d12f090 Mon Sep 17 00:00:00 2001 From: myersCody Date: Mon, 8 Jan 2024 09:14:08 -0500 Subject: [PATCH 17/30] Fix add_local_file --- 
koku/masu/api/upgrade_trino/util/state_tracker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 823f086563..2f969a6e41 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -60,9 +60,7 @@ def set_state(self, s3_obj_key, state, bill_date): def add_local_file(self, s3_obj_key, local_path, bill_date): # TODO: make bill date a string. self.local_files[s3_obj_key] = local_path - bill_date_files = self.tracker.get(bill_date) - if bill_date_files: - self.set_state(s3_obj_key, self.DOWNLOADED_LOCALLY, bill_date) + self.set_state(s3_obj_key, self.DOWNLOADED_LOCALLY, bill_date) def get_files_that_need_updated(self): """Returns a mapping of files in the s3 needs From 40532964420f112c91530591775d9d3194129eb4 Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Mon, 8 Jan 2024 14:47:25 +0000 Subject: [PATCH 18/30] fix str date --- koku/masu/api/upgrade_trino/util/verify_parquet_files.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index ec6dea46e8..e50dd54be3 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -143,6 +143,7 @@ def retrieve_verify_reload_s3_parquet(self): s3_bucket = s3_resource.Bucket(settings.S3_BUCKET_NAME) bill_dates = self._get_bill_dates() for bill_date in bill_dates: + str_date = bill_date.strftime("%Y-%m-%d") for prefix in self._generate_s3_path_prefixes(bill_date): self.logging_context[self.S3_PREFIX_LOG_KEY] = prefix LOG.info( @@ -156,7 +157,7 @@ def retrieve_verify_reload_s3_parquet(self): for s3_object in s3_bucket.objects.filter(Prefix=prefix): s3_object_key = s3_object.key self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key - self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE, bill_date) + self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE, str_date) local_file_path = os.path.join(self.local_path, os.path.basename(s3_object_key)) LOG.info( log_json( @@ -166,9 +167,9 @@ def retrieve_verify_reload_s3_parquet(self): ) ) s3_bucket.download_file(s3_object_key, local_file_path) - self.file_tracker.add_local_file(s3_object_key, local_file_path, bill_date) + self.file_tracker.add_local_file(s3_object_key, local_file_path, str_date) self.file_tracker.set_state( - s3_object_key, self._coerce_parquet_data_type(local_file_path), bill_date + s3_object_key, self._coerce_parquet_data_type(local_file_path), str_date ) del self.logging_context[self.S3_OBJ_LOG_KEY] del self.logging_context[self.S3_PREFIX_LOG_KEY] From c4599bb29ee52e46b7bd41cde3f8d6bbdfba99f8 Mon Sep 17 00:00:00 2001 From: myersCody Date: Mon, 8 Jan 2024 10:18:22 -0500 Subject: [PATCH 19/30] Fix unittests. 
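The test updates below thread a bill date through every tracker call because the state and the provider metadata are now grouped per bill date. Roughly, the structures the assertions inspect look like the sketch below; the file keys and the date are invented, while the dictionary keys and the version value come from the StateTracker changes above.

    # tracker.tracker: one {s3_key: state} dict per bill-date string
    tracker_state = {
        "2024-01-01": {
            "some/s3/prefix/file_a.parquet": "<StateTracker.COERCE_REQUIRED>",
            "some/s3/prefix/file_b.parquet": "<StateTracker.NO_CHANGES_NEEDED>",
        },
    }

    # provider.additional_context[StateTracker.CONTEXT_KEY]: per-bill-date summary
    conversion_metadata = {
        "2024-01-01": {
            "version": "0",                  # StateTracker.CONVERTER_VERSION
            "conversion_successful": False,  # True only once every file completed
            "dtype_failed_files": [
                {"key": "some/s3/prefix/file_a.parquet", "state": "<incomplete state>"},
            ],
        },
    }

    # add_to_queue() re-queues a bill date unless the stored version matches and
    # conversion_successful is True for that entry.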
--- .../test/test_verify_parquet_files.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index 1c18c0265c..620f0165c2 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -45,6 +45,7 @@ def setUp(self): "index": False, } self.suffix = ".parquet" + self.bill_date = str(DateHelper().this_month_start) def tearDown(self): shutil.rmtree(self.temp_dir) @@ -55,7 +56,7 @@ def create_default_verify_handler(self): provider_uuid=self.azure_provider_uuid, provider_type=self.azure_provider.type, simulate=True, - bill_date=datetime(2023, 1, 1), + bill_date=self.bill_date, cleaned_column_mapping=self.required_columns, ) @@ -114,11 +115,12 @@ def test_coerce_parquet_data_type_no_changes_needed(self): with tempfile.NamedTemporaryFile(suffix=self.suffix) as temp_file: pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) verify_handler = self.create_default_verify_handler() - verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) + verify_handler.file_tracker.add_local_file(temp_file.name, temp_file, self.bill_date) return_state = verify_handler._coerce_parquet_data_type(temp_file) - verify_handler.file_tracker.set_state(temp_file.name, return_state) + verify_handler.file_tracker.set_state(temp_file.name, return_state, self.bill_date) self.assertEqual(return_state, StateTracker.NO_CHANGES_NEEDED) - self.assertEqual(verify_handler.file_tracker._check_for_incomplete_files(), []) + bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() + self.assertTrue(bill_metadata.get(self.bill_date, {}).get("conversion_successful")) def test_coerce_parquet_data_type_coerce_needed(self): """Test that files created through reindex are fixed correctly.""" @@ -128,12 +130,12 @@ def test_coerce_parquet_data_type_coerce_needed(self): temp_file = os.path.join(self.temp_dir, f"test{self.suffix}") data_frame.to_parquet(temp_file, **self.panda_kwargs) verify_handler = self.create_default_verify_handler() - verify_handler.file_tracker.add_local_file(filename, temp_file) + verify_handler.file_tracker.add_local_file(filename, temp_file, self.bill_date) return_state = verify_handler._coerce_parquet_data_type(temp_file) self.assertEqual(return_state, StateTracker.COERCE_REQUIRED) - verify_handler.file_tracker.set_state(filename, return_state) + verify_handler.file_tracker.set_state(filename, return_state, self.bill_date) files_need_updating = verify_handler.file_tracker.get_files_that_need_updated() - self.assertTrue(files_need_updating.get(filename)) + self.assertTrue(files_need_updating.get(self.bill_date, {}).get(filename)) table = pq.read_table(temp_file) schema = table.schema for field in schema: @@ -150,11 +152,12 @@ def test_coerce_parquet_data_type_failed_to_coerce(self): with tempfile.NamedTemporaryFile(suffix=self.suffix) as temp_file: pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) verify_handler = self.create_default_verify_handler() - verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) + verify_handler.file_tracker.add_local_file(temp_file.name, temp_file, self.bill_date) return_state = verify_handler._coerce_parquet_data_type(temp_file) - verify_handler.file_tracker.set_state(temp_file.name, return_state) + verify_handler.file_tracker.set_state(temp_file.name, return_state, 
self.bill_date) self.assertEqual(return_state, StateTracker.FAILED_DTYPE_CONVERSION) - self.assertNotEqual(verify_handler.file_tracker._check_for_incomplete_files(), []) + bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() + self.assertNotEqual(bill_metadata, {}) def test_oci_s3_paths(self): """test path generation for oci sources.""" From ef68d81b4a068473790a0b61d39731f38f1d0b89 Mon Sep 17 00:00:00 2001 From: myersCody Date: Mon, 8 Jan 2024 10:29:18 -0500 Subject: [PATCH 20/30] Clean up comments. --- koku/masu/api/upgrade_trino/util/state_tracker.py | 2 -- koku/masu/api/upgrade_trino/util/verify_parquet_files.py | 6 ------ 2 files changed, 8 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 2f969a6e41..842a80844e 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -49,7 +49,6 @@ def add_to_queue(self, bill_date_data): return True def set_state(self, s3_obj_key, state, bill_date): - # TODO: make bill date a string. bill_date_files = self.tracker.get(bill_date) if bill_date_files: bill_date_files[s3_obj_key] = state @@ -58,7 +57,6 @@ def set_state(self, s3_obj_key, state, bill_date): self.tracker[bill_date] = {s3_obj_key: state} def add_local_file(self, s3_obj_key, local_path, bill_date): - # TODO: make bill date a string. self.local_files[s3_obj_key] = local_path self.set_state(s3_obj_key, self.DOWNLOADED_LOCALLY, bill_date) diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index e50dd54be3..f72427332f 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -67,13 +67,11 @@ def _set_report_types(self): return [None] def _get_bill_dates(self): - # However far back we want to fix. dh = DateHelper() return dh.list_months( ciso8601.parse_datetime(self.bill_date).replace(tzinfo=None), dh.today.replace(tzinfo=None) ) - # Stolen from parquet_report_processor def _parquet_path_s3(self, bill_date, report_type): """The path in the S3 bucket where Parquet files are loaded.""" return get_path_prefix( @@ -85,7 +83,6 @@ def _parquet_path_s3(self, bill_date, report_type): report_type=report_type, ) - # Stolen from parquet_report_processor def _parquet_daily_path_s3(self, bill_date, report_type): """The path in the S3 bucket where Parquet files are loaded.""" if report_type is None: @@ -100,7 +97,6 @@ def _parquet_daily_path_s3(self, bill_date, report_type): daily=True, ) - # Stolen from parquet_report_processor def _parquet_ocp_on_cloud_path_s3(self, bill_date): """The path in the S3 bucket where Parquet files are loaded.""" return get_path_prefix( @@ -113,7 +109,6 @@ def _parquet_ocp_on_cloud_path_s3(self, bill_date): daily=True, ) - # Stolen from parquet_report_processor def _generate_s3_path_prefixes(self, bill_date): """ generates the s3 path prefixes. @@ -130,7 +125,6 @@ def _generate_s3_path_prefixes(self, bill_date): path_prefixes.add(self._parquet_ocp_on_cloud_path_s3(bill_date)) return path_prefixes - # Stolen from parquet_report_processor @property def local_path(self): local_path = Path(Config.TMP_DIR, self.schema_name, str(self.provider_uuid)) From 9d1ec9457124487e4545e676f46692720605f764 Mon Sep 17 00:00:00 2001 From: myersCody Date: Mon, 8 Jan 2024 13:18:22 -0500 Subject: [PATCH 21/30] Create a task per billing period start. 
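The commit below splits the work so each billing-period start gets its own Celery task, gated by per-bill-date conversion metadata on the provider. A simplified sketch of that dispatch loop is shown here; it uses dateutil in place of koku's DateHelper.list_months, takes the Celery task and metadata as parameters instead of importing koku modules, and the helper names are assumptions for illustration only.

```python
# Illustrative only: enumerate billing-month starts from a parsed start_date and
# queue one Celery task per month, skipping months already converted at the
# current converter version.
from datetime import date
from typing import Optional

from dateutil import parser
from dateutil.relativedelta import relativedelta

CONVERTER_VERSION = "1"  # placeholder


def month_starts(start_date: str, today: Optional[date] = None) -> list:
    """First-of-month dates from start_date's month through the current month."""
    today = today or date.today()
    current = parser.parse(start_date).date().replace(day=1)
    months = []
    while current <= today:
        months.append(current)
        current += relativedelta(months=1)
    return months


def needs_conversion(conversion_metadata: dict, bill_date_key: str) -> bool:
    """Re-run if the version changed or the last attempt was not successful."""
    bill_meta = conversion_metadata.get(bill_date_key, {})
    if bill_meta.get("version") != CONVERTER_VERSION:
        return True
    return not bill_meta.get("successful", False)


def dispatch(fix_task, provider_uuid: str, conversion_metadata: dict, start_date: str, queue_name: str):
    """fix_task is a Celery task object, e.g. fix_parquet_data_types."""
    task_ids = []
    for bill_date in month_starts(start_date):
        if not needs_conversion(conversion_metadata, bill_date.strftime("%Y-%m-%d")):
            continue
        result = fix_task.s(provider_uuid=provider_uuid, bill_date=bill_date).apply_async(queue=queue_name)
        task_ids.append(str(result))
    return task_ids
```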
--- .../test/test_verify_parquet_files.py | 21 +-- koku/masu/api/upgrade_trino/test/test_view.py | 6 +- koku/masu/api/upgrade_trino/util/__init__.py | 6 + .../api/upgrade_trino/util/state_tracker.py | 131 +++++++++--------- .../api/upgrade_trino/util/task_handler.py | 58 ++++---- .../util/verify_parquet_files.py | 109 +++++++-------- 6 files changed, 167 insertions(+), 164 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index 620f0165c2..bc7029eac8 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -15,6 +15,7 @@ import pyarrow.parquet as pq from api.utils import DateHelper +from masu.api.upgrade_trino.util.state_tracker import CONTEXT_KEY_MAPPING from masu.api.upgrade_trino.util.state_tracker import StateTracker from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler from masu.api.upgrade_trino.util.verify_parquet_files import VerifyParquetFiles @@ -73,7 +74,6 @@ def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _): ] for metadata in test_metadata: with self.subTest(metadata=metadata): - bill_date = str(DateHelper().this_month_start) required_columns = FixParquetTaskHandler.clean_column_names(metadata["type"]) data_frame = pd.DataFrame() data_frame = data_frame.reindex(columns=required_columns.keys()) @@ -86,7 +86,7 @@ def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _): provider_uuid=metadata["uuid"], provider_type=metadata["type"], simulate=False, - bill_date=bill_date, + bill_date=self.bill_date, cleaned_column_mapping=required_columns, ) prefixes = verify_handler._generate_s3_path_prefixes(DateHelper().this_month_start) @@ -115,12 +115,12 @@ def test_coerce_parquet_data_type_no_changes_needed(self): with tempfile.NamedTemporaryFile(suffix=self.suffix) as temp_file: pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) verify_handler = self.create_default_verify_handler() - verify_handler.file_tracker.add_local_file(temp_file.name, temp_file, self.bill_date) + verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) - verify_handler.file_tracker.set_state(temp_file.name, return_state, self.bill_date) + verify_handler.file_tracker.set_state(temp_file.name, return_state) self.assertEqual(return_state, StateTracker.NO_CHANGES_NEEDED) bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() - self.assertTrue(bill_metadata.get(self.bill_date, {}).get("conversion_successful")) + self.assertTrue(bill_metadata.get(CONTEXT_KEY_MAPPING["successful"])) def test_coerce_parquet_data_type_coerce_needed(self): """Test that files created through reindex are fixed correctly.""" @@ -130,12 +130,12 @@ def test_coerce_parquet_data_type_coerce_needed(self): temp_file = os.path.join(self.temp_dir, f"test{self.suffix}") data_frame.to_parquet(temp_file, **self.panda_kwargs) verify_handler = self.create_default_verify_handler() - verify_handler.file_tracker.add_local_file(filename, temp_file, self.bill_date) + verify_handler.file_tracker.add_local_file(filename, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) self.assertEqual(return_state, StateTracker.COERCE_REQUIRED) - verify_handler.file_tracker.set_state(filename, return_state, self.bill_date) + verify_handler.file_tracker.set_state(filename, return_state) 
files_need_updating = verify_handler.file_tracker.get_files_that_need_updated() - self.assertTrue(files_need_updating.get(self.bill_date, {}).get(filename)) + self.assertTrue(files_need_updating.get(filename)) table = pq.read_table(temp_file) schema = table.schema for field in schema: @@ -152,11 +152,12 @@ def test_coerce_parquet_data_type_failed_to_coerce(self): with tempfile.NamedTemporaryFile(suffix=self.suffix) as temp_file: pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) verify_handler = self.create_default_verify_handler() - verify_handler.file_tracker.add_local_file(temp_file.name, temp_file, self.bill_date) + verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) - verify_handler.file_tracker.set_state(temp_file.name, return_state, self.bill_date) + verify_handler.file_tracker.set_state(temp_file.name, return_state) self.assertEqual(return_state, StateTracker.FAILED_DTYPE_CONVERSION) bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() + self.assertFalse(bill_metadata.get(self.bill_date, {}).get(CONTEXT_KEY_MAPPING["successful"])) self.assertNotEqual(bill_metadata, {}) def test_oci_s3_paths(self): diff --git a/koku/masu/api/upgrade_trino/test/test_view.py b/koku/masu/api/upgrade_trino/test/test_view.py index 73f3361748..bbf169921c 100644 --- a/koku/masu/api/upgrade_trino/test/test_view.py +++ b/koku/masu/api/upgrade_trino/test/test_view.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # """Test the verify parquet files endpoint view.""" +import datetime from unittest.mock import patch from uuid import uuid4 @@ -10,7 +11,6 @@ from django.urls import reverse from api.models import Provider -from api.utils import DateHelper from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler from masu.processor.tasks import GET_REPORT_FILES_QUEUE from masu.test import MasuTestCase @@ -19,7 +19,7 @@ @override_settings(ROOT_URLCONF="masu.urls") class TestUpgradeTrinoView(MasuTestCase): ENDPOINT = "fix_parquet" - bill_date = DateHelper().month_start("2023-12-01") + bill_date = datetime.datetime(2024, 1, 1, 0, 0) @patch("koku.middleware.MASU", return_value=True) def test_required_parameters_failure(self, _): @@ -41,9 +41,9 @@ def test_provider_uuid_does_not_exist(self, _): def test_acceptable_parameters(self, _): """Test that the endpoint accepts""" acceptable_parameters = [ - {"start_date": self.bill_date, "provider_type": self.aws_provider.type}, {"start_date": self.bill_date, "provider_uuid": self.aws_provider_uuid, "simulate": True}, {"start_date": self.bill_date, "provider_uuid": self.aws_provider_uuid, "simulate": "bad_value"}, + {"start_date": self.bill_date, "provider_type": self.aws_provider.type}, ] cleaned_column_mapping = FixParquetTaskHandler.clean_column_names(self.aws_provider.type) for parameters in acceptable_parameters: diff --git a/koku/masu/api/upgrade_trino/util/__init__.py b/koku/masu/api/upgrade_trino/util/__init__.py index e69de29bb2..d1d41a128b 100644 --- a/koku/masu/api/upgrade_trino/util/__init__.py +++ b/koku/masu/api/upgrade_trino/util/__init__.py @@ -0,0 +1,6 @@ +CONTEXT_KEY_MAPPING = { + "metadata": "conversion_metadata", + "version": "version", + "successful": "successful", + "failed_files": "dtype_failed_files", +} diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 842a80844e..b13354774d 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ 
b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -1,9 +1,11 @@ import logging import os +from datetime import date from api.common import log_json from api.provider.provider_manager import ProviderManager from api.provider.provider_manager import ProviderManagerError +from masu.api.upgrade_trino.util import CONTEXT_KEY_MAPPING LOG = logging.getLogger(__name__) @@ -12,7 +14,6 @@ class StateTracker: """Tracks the state of each s3 file for the provider per bill date""" - CONTEXT_KEY = "conversion_metadata" CONVERTER_VERSION = "0" FOUND_S3_FILE = "found_s3_file" DOWNLOADED_LOCALLY = "downloaded_locally" @@ -22,13 +23,14 @@ class StateTracker: SENT_TO_S3_FAILED = "sent_to_s3_failed" FAILED_DTYPE_CONVERSION = "failed_data_type_conversion" - def __init__(self, provider_uuid): + def __init__(self, provider_uuid: str, bill_date: date): self.files = [] self.tracker = {} self.local_files = {} self.provider_uuid = provider_uuid + self.bill_date_str = bill_date.strftime("%Y-%m-%d") - def add_to_queue(self, bill_date_data): + def add_to_queue(self, conversion_metadata): """ Checks the provider object's metadata to see if we should start the task. @@ -38,27 +40,30 @@ def add_to_queue(self, bill_date_data): Returns: bool: True if the task should be added to the queue, False otherwise. """ - # TODO: turn the keys here into a variable. - if bill_date_data.get("version") != self.CONVERTER_VERSION: + bill_metadata = conversion_metadata.get(self.bill_date_str, {}) + if bill_metadata.get(CONTEXT_KEY_MAPPING["version"]) != self.CONVERTER_VERSION: # always kick off a task if the version does not match or exist. return True - if bill_date_data.get("conversion_successful"): + if bill_metadata.get(CONTEXT_KEY_MAPPING["successful"]): # if the conversion was successful for this version do not kick # off a task. + LOG.info( + log_json( + self.provider_uuid, + msg="Conversion already marked as successful", + bill_date=self.bill_date_str, + provider_uuid=self.provider_uuid, + ) + ) return False return True - def set_state(self, s3_obj_key, state, bill_date): - bill_date_files = self.tracker.get(bill_date) - if bill_date_files: - bill_date_files[s3_obj_key] = state - self.tracker[bill_date] = bill_date_files - else: - self.tracker[bill_date] = {s3_obj_key: state} + def set_state(self, s3_obj_key, state): + self.tracker[s3_obj_key] = state - def add_local_file(self, s3_obj_key, local_path, bill_date): + def add_local_file(self, s3_obj_key, local_path): self.local_files[s3_obj_key] = local_path - self.set_state(s3_obj_key, self.DOWNLOADED_LOCALLY, bill_date) + self.tracker[s3_obj_key] = self.DOWNLOADED_LOCALLY def get_files_that_need_updated(self): """Returns a mapping of files in the s3 needs @@ -67,46 +72,43 @@ def get_files_that_need_updated(self): {s3_object_key: local_file_path} for """ mapping = {} - for bill_date, bill_metadata in self.tracker.items(): - bill_date_data = {} - for s3_obj_key, state in bill_metadata.items(): - if state == self.COERCE_REQUIRED: - bill_date_data[s3_obj_key] = self.local_files.get(s3_obj_key) - mapping[bill_date] = bill_date_data + for s3_obj_key, state in self.tracker.items(): + if state == self.COERCE_REQUIRED: + mapping[s3_obj_key] = self.local_files.get(s3_obj_key) return mapping def generate_simulate_messages(self): """ Generates the simulate messages. 
""" - for bill_date, bill_data in self.tracker.items(): - files_count = 0 - files_failed = [] - files_need_updated = [] - files_correct = [] - for s3_obj_key, state in bill_data.items(): - files_count += 1 - if state == self.COERCE_REQUIRED: - files_need_updated.append(s3_obj_key) - elif state == self.NO_CHANGES_NEEDED: - files_correct.append(s3_obj_key) - else: - files_failed.append(s3_obj_key) - simulate_info = { - "Files that have all correct data_types.": files_correct, - "Files that need to be updated.": files_need_updated, - "Files that failed to convert.": files_failed, - } - for substring, files_list in simulate_info.items(): - LOG.info( - log_json( - self.provider_uuid, - msg=substring, - file_count=len(files_list), - total_count=files_count, - bill_date=bill_date, - ) + + files_count = 0 + files_failed = [] + files_need_updated = [] + files_correct = [] + for s3_obj_key, state in self.tracker.items(): + files_count += 1 + if state == self.COERCE_REQUIRED: + files_need_updated.append(s3_obj_key) + elif state == self.NO_CHANGES_NEEDED: + files_correct.append(s3_obj_key) + else: + files_failed.append(s3_obj_key) + simulate_info = { + "Files that have all correct data_types.": files_correct, + "Files that need to be updated.": files_need_updated, + "Files that failed to convert.": files_failed, + } + for substring, files_list in simulate_info.items(): + LOG.info( + log_json( + self.provider_uuid, + msg=substring, + file_count=len(files_list), + total_count=files_count, + bill_date=self.bill_date_str, ) + ) self._clean_local_files() def _clean_local_files(self): @@ -116,27 +118,26 @@ def _clean_local_files(self): def _create_bill_date_metadata(self): # Check for incomplete files - metadata = {} - for bill_date, bill_metadata in self.tracker.items(): - bill_date_data = {"version": self.CONVERTER_VERSION} - incomplete_files = [] - for file_prefix, state in bill_metadata.items(): - if state not in [self.SENT_TO_S3_COMPLETE, self.NO_CHANGES_NEEDED]: - file_metadata = {"key": file_prefix, "state": state} - incomplete_files.append(file_metadata) - if incomplete_files: - bill_date_data["conversion_successful"] = False - bill_date_data["dtype_failed_files"] = incomplete_files - if not incomplete_files: - bill_date_data["conversion_successful"] = True - metadata[bill_date] = bill_date_data - return metadata + bill_date_data = {"version": self.CONVERTER_VERSION} + incomplete_files = [] + for file_prefix, state in self.tracker.items(): + if state not in [self.SENT_TO_S3_COMPLETE, self.NO_CHANGES_NEEDED]: + file_metadata = {"key": file_prefix, "state": state} + incomplete_files.append(file_metadata) + if incomplete_files: + bill_date_data[CONTEXT_KEY_MAPPING["successful"]] = False + bill_date_data[CONTEXT_KEY_MAPPING["failed_files"]] = incomplete_files + if not incomplete_files: + bill_date_data[CONTEXT_KEY_MAPPING["successful"]] = True + return bill_date_data def _check_if_complete(self): try: manager = ProviderManager(self.provider_uuid) context = manager.get_additional_context() - context[self.CONTEXT_KEY] = self._create_bill_date_metadata() + conversion_metadata = context.get(CONTEXT_KEY_MAPPING["metadata"], {}) + conversion_metadata[self.bill_date_str] = self._create_bill_date_metadata() + context[CONTEXT_KEY_MAPPING["metadata"]] = conversion_metadata manager.model.set_additional_context(context) LOG.info(self.provider_uuid, log_json(msg="setting dtype states", context=context)) except ProviderManagerError: @@ -145,5 +146,3 @@ def _check_if_complete(self): def finalize_and_clean_up(self): 
self._check_if_complete() self._clean_local_files() - # We can decide if we want to record - # failed parquet conversion diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index e6fa722538..6625c33c2c 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -4,14 +4,16 @@ from dataclasses import field from typing import Optional +from dateutil import parser from django.http import QueryDict from api.common import log_json from api.provider.models import Provider +from api.utils import DateHelper +from masu.api.upgrade_trino.util import CONTEXT_KEY_MAPPING from masu.api.upgrade_trino.util.state_tracker import StateTracker from masu.celery.tasks import fix_parquet_data_types from masu.processor import is_customer_large -from masu.processor.orchestrator import get_billing_month_start from masu.processor.tasks import GET_REPORT_FILES_QUEUE from masu.processor.tasks import GET_REPORT_FILES_QUEUE_XL from masu.util.common import strip_characters_from_column_name @@ -28,12 +30,11 @@ class RequiredParametersError(Exception): @dataclass class FixParquetTaskHandler: - bill_date: Optional[str] = field(default=None) + start_date: Optional[str] = field(default=None) provider_uuid: Optional[str] = field(default=None) provider_type: Optional[str] = field(default=None) simulate: Optional[bool] = field(default=False) cleaned_column_mapping: Optional[dict] = field(default=None) - state_tracker: StateTracker = field(init=False) # Node role is the only column we add manually for OCP # Therefore, it is the only column that can be incorrect @@ -49,7 +50,8 @@ def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": """Create an instance from query parameters.""" reprocess_kwargs = cls() if start_date := query_params.get("start_date"): - reprocess_kwargs.bill_date = get_billing_month_start(start_date) + if start_date: + reprocess_kwargs.start_date = parser.parse(start_date).replace(day=1) if provider_uuid := query_params.get("provider_uuid"): provider = Provider.objects.filter(uuid=provider_uuid).first() @@ -67,7 +69,7 @@ def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": if not reprocess_kwargs.provider_type and not reprocess_kwargs.provider_uuid: raise RequiredParametersError("provider_uuid or provider_type must be supplied") - if not reprocess_kwargs.bill_date: + if not reprocess_kwargs.start_date: raise RequiredParametersError("start_date must be supplied as a parameter.") reprocess_kwargs.cleaned_column_mapping = reprocess_kwargs.clean_column_names(reprocess_kwargs.provider_type) @@ -104,27 +106,29 @@ def build_celery_tasks(self): queue_name = GET_REPORT_FILES_QUEUE_XL account = copy.deepcopy(provider.account) - report_month = get_billing_month_start(self.bill_date) - tracker = StateTracker(self.provider_uuid) - conversion_metadata = provider.additional_context.get(StateTracker.CONTEXT_KEY, {}) - bill_data = conversion_metadata.get(report_month, {}) - if tracker.add_to_queue(bill_data): - async_result = fix_parquet_data_types.s( - schema_name=account.get("schema_name"), - provider_type=account.get("provider_type"), - provider_uuid=account.get("provider_uuid"), - simulate=self.simulate, - bill_date=report_month, - cleaned_column_mapping=self.cleaned_column_mapping, - ).apply_async(queue=queue_name) - LOG.info( - log_json( - provider.uuid, - msg="Calling fix_parquet_data_types", - schema=account.get("schema_name"), - 
provider_uuid=provider.uuid, - task_id=str(async_result), + conversion_metadata = provider.additional_context.get(CONTEXT_KEY_MAPPING["metadata"], {}) + dh = DateHelper() + bill_datetimes = dh.list_months(self.start_date, dh.today.replace(tzinfo=None)) + for bill_date in bill_datetimes: + tracker = StateTracker(self.provider_uuid, bill_date) + if tracker.add_to_queue(conversion_metadata): + async_result = fix_parquet_data_types.s( + schema_name=account.get("schema_name"), + provider_type=account.get("provider_type"), + provider_uuid=account.get("provider_uuid"), + simulate=self.simulate, + bill_date=bill_date, + cleaned_column_mapping=self.cleaned_column_mapping, + ).apply_async(queue=queue_name) + LOG.info( + log_json( + provider.uuid, + msg="Calling fix_parquet_data_types", + schema=account.get("schema_name"), + provider_uuid=provider.uuid, + task_id=str(async_result), + bill_date=bill_date, + ) ) - ) - async_results.append(str(async_result)) + async_results.append(str(async_result)) return async_results diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index f72427332f..d4b85deca2 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -12,7 +12,6 @@ from api.common import log_json from api.provider.models import Provider -from api.utils import DateHelper from masu.api.upgrade_trino.util.state_tracker import StateTracker from masu.config import Config from masu.processor.parquet.parquet_report_processor import OPENSHIFT_REPORT_TYPE @@ -32,8 +31,8 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat self.provider_uuid = uuid.UUID(provider_uuid) self.provider_type = provider_type.replace("-local", "") self.simulate = simulate - self.bill_date = bill_date - self.file_tracker = StateTracker(provider_uuid) + self.bill_date = self._bill_date(bill_date) + self.file_tracker = StateTracker(provider_uuid, self.bill_date) self.report_types = self._set_report_types() self.required_columns = self._set_pyarrow_types(cleaned_column_mapping) self.logging_context = { @@ -44,6 +43,12 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat "bill_date": self.bill_date, } + def _bill_date(self, bill_date): + """bill_date""" + if isinstance(bill_date, str): + return ciso8601.parse_datetime(bill_date).replace(tzinfo=None).date() + return bill_date + def _set_pyarrow_types(self, cleaned_column_mapping): mapping = {} for key, default_val in cleaned_column_mapping.items(): @@ -66,12 +71,6 @@ def _set_report_types(self): return ["namespace_labels", "node_labels", "pod_usage", "storage_usage"] return [None] - def _get_bill_dates(self): - dh = DateHelper() - return dh.list_months( - ciso8601.parse_datetime(self.bill_date).replace(tzinfo=None), dh.today.replace(tzinfo=None) - ) - def _parquet_path_s3(self, bill_date, report_type): """The path in the S3 bucket where Parquet files are loaded.""" return get_path_prefix( @@ -135,68 +134,62 @@ def retrieve_verify_reload_s3_parquet(self): """Retrieves the s3 files from s3""" s3_resource = get_s3_resource(settings.S3_ACCESS_KEY, settings.S3_SECRET, settings.S3_REGION) s3_bucket = s3_resource.Bucket(settings.S3_BUCKET_NAME) - bill_dates = self._get_bill_dates() - for bill_date in bill_dates: - str_date = bill_date.strftime("%Y-%m-%d") - for prefix in self._generate_s3_path_prefixes(bill_date): - self.logging_context[self.S3_PREFIX_LOG_KEY] = prefix + for 
prefix in self._generate_s3_path_prefixes(self.bill_date): + self.logging_context[self.S3_PREFIX_LOG_KEY] = prefix + LOG.info( + log_json( + self.provider_uuid, + msg="Retrieving files from S3.", + context=self.logging_context, + prefix=prefix, + ) + ) + for s3_object in s3_bucket.objects.filter(Prefix=prefix): + s3_object_key = s3_object.key + self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key + self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE) + local_file_path = os.path.join(self.local_path, os.path.basename(s3_object_key)) LOG.info( log_json( self.provider_uuid, - msg="Retrieving files from S3.", + msg="Downloading file locally", context=self.logging_context, - prefix=prefix, ) ) - for s3_object in s3_bucket.objects.filter(Prefix=prefix): - s3_object_key = s3_object.key - self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key - self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE, str_date) - local_file_path = os.path.join(self.local_path, os.path.basename(s3_object_key)) - LOG.info( - log_json( - self.provider_uuid, - msg="Downloading file locally", - context=self.logging_context, - ) - ) - s3_bucket.download_file(s3_object_key, local_file_path) - self.file_tracker.add_local_file(s3_object_key, local_file_path, str_date) - self.file_tracker.set_state( - s3_object_key, self._coerce_parquet_data_type(local_file_path), str_date - ) - del self.logging_context[self.S3_OBJ_LOG_KEY] - del self.logging_context[self.S3_PREFIX_LOG_KEY] + s3_bucket.download_file(s3_object_key, local_file_path) + self.file_tracker.add_local_file(s3_object_key, local_file_path) + self.file_tracker.set_state(s3_object_key, self._coerce_parquet_data_type(local_file_path)) + del self.logging_context[self.S3_OBJ_LOG_KEY] + del self.logging_context[self.S3_PREFIX_LOG_KEY] if self.simulate: self.file_tracker.generate_simulate_messages() return False else: files_need_updated = self.file_tracker.get_files_that_need_updated() - for bill_date, bill_date_data in files_need_updated.items(): - for s3_obj_key, converted_local_file_path in bill_date_data.items(): - self.logging_context[self.S3_OBJ_LOG_KEY] = s3_obj_key - # Overwrite s3 object with updated file data - with open(converted_local_file_path, "rb") as new_file: - LOG.info( - log_json( - self.provider_uuid, - msg="Uploading revised parquet file.", - context=self.logging_context, - local_file_path=converted_local_file_path, - ) + for s3_obj_key, converted_local_file_path in files_need_updated.items(): + self.logging_context[self.S3_OBJ_LOG_KEY] = s3_obj_key + # Overwrite s3 object with updated file data + with open(converted_local_file_path, "rb") as new_file: + LOG.info( + log_json( + self.provider_uuid, + msg="Uploading revised parquet file.", + context=self.logging_context, + local_file_path=converted_local_file_path, ) - try: - s3_bucket.upload_fileobj( - new_file, - s3_obj_key, - ExtraArgs={"Metadata": {"converter_version": StateTracker.CONVERTER_VERSION}}, - ) - self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE, bill_date) - except ClientError as e: - LOG.info(f"Failed to overwrite S3 file {s3_object_key}: {str(e)}") - self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED, bill_date) - continue + ) + try: + s3_bucket.upload_fileobj( + new_file, + s3_obj_key, + ExtraArgs={"Metadata": {"converter_version": StateTracker.CONVERTER_VERSION}}, + ) + self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE) + except ClientError as e: + 
LOG.info(f"Failed to overwrite S3 file {s3_object_key}: {str(e)}") + self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED) + continue self.file_tracker.finalize_and_clean_up() def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names, timestamp_std): From 510b0fff9285ab542f18872330faa44f8566d101 Mon Sep 17 00:00:00 2001 From: myersCody Date: Mon, 8 Jan 2024 16:02:15 -0500 Subject: [PATCH 22/30] Fix transformation error & improve coverage to better highlight future failures. --- .../test/test_verify_parquet_files.py | 72 +++++++++++++------ .../util/verify_parquet_files.py | 19 +++-- 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index bc7029eac8..cec351dfe0 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -47,6 +47,7 @@ def setUp(self): } self.suffix = ".parquet" self.bill_date = str(DateHelper().this_month_start) + self.default_provider = self.azure_provider def tearDown(self): shutil.rmtree(self.temp_dir) @@ -54,41 +55,55 @@ def tearDown(self): def create_default_verify_handler(self): return VerifyParquetFiles( schema_name=self.schema_name, - provider_uuid=self.azure_provider_uuid, - provider_type=self.azure_provider.type, + provider_uuid=str(self.default_provider.uuid), + provider_type=self.default_provider.type, simulate=True, bill_date=self.bill_date, cleaned_column_mapping=self.required_columns, ) + def build_expected_additional_context(self, verify_hander, successful=True): + return { + CONTEXT_KEY_MAPPING["metadata"]: { + verify_hander.file_tracker.bill_date_str: { + CONTEXT_KEY_MAPPING["version"]: verify_hander.file_tracker.CONVERTER_VERSION, + CONTEXT_KEY_MAPPING["successful"]: successful, + } + } + } + @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource") def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _): """Test fixes for reindexes on all required columns.""" # build a parquet file where reindex is used for all required columns - test_metadata = [ - {"uuid": self.aws_provider_uuid, "type": self.aws_provider.type}, - {"uuid": self.azure_provider_uuid, "type": self.azure_provider.type}, - {"uuid": self.ocp_provider_uuid, "type": self.ocp_provider.type}, - {"uuid": self.oci_provider_uuid, "type": self.oci_provider.type}, - ] - for metadata in test_metadata: - with self.subTest(metadata=metadata): - required_columns = FixParquetTaskHandler.clean_column_names(metadata["type"]) - data_frame = pd.DataFrame() - data_frame = data_frame.reindex(columns=required_columns.keys()) - filename = f"test_{metadata['uuid']}{self.suffix}" - temp_file = os.path.join(self.temp_dir, filename) - data_frame.to_parquet(temp_file, **self.panda_kwargs) + + def create_tmp_test_file(provider, required_columns): + """Creates a parquet file with all empty required columns through reindexing.""" + data_frame = pd.DataFrame() + data_frame = data_frame.reindex(columns=required_columns.keys()) + filename = f"test_{str(provider.uuid)}{self.suffix}" + temp_file = os.path.join(self.temp_dir, filename) + data_frame.to_parquet(temp_file, **self.panda_kwargs) + return temp_file + + attributes = ["aws_provider", "azure_provider", "ocp_provider", "oci_provider"] + for attr in attributes: + with 
self.subTest(attr=attr): + provider = getattr(self, attr) + required_columns = FixParquetTaskHandler.clean_column_names(provider.type) + temp_file = create_tmp_test_file(provider, required_columns) mock_bucket = mock_s3_resource.return_value.Bucket.return_value verify_handler = VerifyParquetFiles( schema_name=self.schema_name, - provider_uuid=metadata["uuid"], - provider_type=metadata["type"], + provider_uuid=str(provider.uuid), + provider_type=provider.type, simulate=False, bill_date=self.bill_date, cleaned_column_mapping=required_columns, ) + conversion_metadata = provider.additional_context.get(CONTEXT_KEY_MAPPING["metadata"], {}) + self.assertTrue(verify_handler.file_tracker.add_to_queue(conversion_metadata)) prefixes = verify_handler._generate_s3_path_prefixes(DateHelper().this_month_start) filter_side_effect = [[DummyS3Object(key=temp_file)]] for _ in range(len(prefixes) - 1): @@ -103,6 +118,12 @@ def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _): for field in schema: self.assertEqual(field.type, verify_handler.required_columns.get(field.name)) os.remove(temp_file) + provider.refresh_from_db() + self.assertEqual( + provider.additional_context, self.build_expected_additional_context(verify_handler, True) + ) + conversion_metadata = provider.additional_context.get(CONTEXT_KEY_MAPPING["metadata"]) + self.assertFalse(verify_handler.file_tracker.add_to_queue(conversion_metadata)) def test_coerce_parquet_data_type_no_changes_needed(self): """Test a parquet file with correct dtypes.""" @@ -156,9 +177,18 @@ def test_coerce_parquet_data_type_failed_to_coerce(self): return_state = verify_handler._coerce_parquet_data_type(temp_file) verify_handler.file_tracker.set_state(temp_file.name, return_state) self.assertEqual(return_state, StateTracker.FAILED_DTYPE_CONVERSION) - bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() - self.assertFalse(bill_metadata.get(self.bill_date, {}).get(CONTEXT_KEY_MAPPING["successful"])) - self.assertNotEqual(bill_metadata, {}) + verify_handler.file_tracker._check_if_complete() + self.default_provider.refresh_from_db() + conversion_metadata = self.default_provider.additional_context.get(CONTEXT_KEY_MAPPING["metadata"]) + self.assertIsNotNone(conversion_metadata) + bill_metadata = conversion_metadata.get(verify_handler.file_tracker.bill_date_str) + self.assertIsNotNone(bill_metadata) + self.assertFalse(bill_metadata.get(CONTEXT_KEY_MAPPING["successful"]), True) + self.assertIsNotNone(bill_metadata.get(CONTEXT_KEY_MAPPING["failed_files"])) + # confirm nothing would be sent to s3 + self.assertEqual(verify_handler.file_tracker.get_files_that_need_updated(), {}) + # confirm that it should be retried on next run + self.assertTrue(verify_handler.file_tracker.add_to_queue(conversion_metadata)) def test_oci_s3_paths(self): """test path generation for oci sources.""" diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index d4b85deca2..b27dc62b93 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -192,7 +192,7 @@ def retrieve_verify_reload_s3_parquet(self): continue self.file_tracker.finalize_and_clean_up() - def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names, timestamp_std): + def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names): """Performs a transformation to change a double to a timestamp.""" if not 
field_names: return @@ -204,7 +204,8 @@ def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_n # if len is 0 here we get an empty list, if it does # have a value for the field, overwrite it with bill_date replaced_values = [self.bill_date] * len(table[field.name]) - corrected_column = pa.array(replaced_values, type=timestamp_std) + correct_data_type = self.required_columns.get(field.name) + corrected_column = pa.array(replaced_values, type=correct_data_type) field = pa.field(field.name, corrected_column.type) fields.append(field) # Create a new schema @@ -259,13 +260,11 @@ def _coerce_parquet_data_type(self, parquet_file_path): expected_data_type=correct_data_type, ) ) - if ( - field.type == pa.float64() - and correct_data_type == pa.timestamp("ms") - or correct_data_type == pa.timestamp("ms", tz="UTC") - ): + if field.type == pa.float64() and correct_data_type in [ + pa.timestamp("ms"), + pa.timestamp("ms", tz="UTC"), + ]: double_to_timestamp_fields.append(field.name) - timestamp_std = correct_data_type else: field = pa.field(field.name, correct_data_type) corrected_fields[field.name] = correct_data_type @@ -296,9 +295,7 @@ def _coerce_parquet_data_type(self, parquet_file_path): table = table.cast(new_schema) # Write the table back to the Parquet file pa.parquet.write_table(table, parquet_file_path) - self._perform_transformation_double_to_timestamp( - parquet_file_path, double_to_timestamp_fields, timestamp_std - ) + self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields) # Signal that we need to send this update to S3. return self.file_tracker.COERCE_REQUIRED From 478fec83db2402ca3847d02500aa8f9911f929fb Mon Sep 17 00:00:00 2001 From: myersCody Date: Tue, 9 Jan 2024 13:20:48 -0500 Subject: [PATCH 23/30] Implement StrEnum for constants. 
--- .../test/test_verify_parquet_files.py | 29 +++++----- koku/masu/api/upgrade_trino/util/__init__.py | 6 -- koku/masu/api/upgrade_trino/util/constants.py | 57 +++++++++++++++++++ .../api/upgrade_trino/util/state_tracker.py | 39 ++++++------- .../api/upgrade_trino/util/task_handler.py | 4 +- .../util/verify_parquet_files.py | 16 +++--- 6 files changed, 99 insertions(+), 52 deletions(-) create mode 100644 koku/masu/api/upgrade_trino/util/constants.py diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index cec351dfe0..9655f5aabc 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -15,8 +15,9 @@ import pyarrow.parquet as pq from api.utils import DateHelper -from masu.api.upgrade_trino.util.state_tracker import CONTEXT_KEY_MAPPING -from masu.api.upgrade_trino.util.state_tracker import StateTracker +from masu.api.upgrade_trino.util.constants import ConversionContextKeys +from masu.api.upgrade_trino.util.constants import ConversionStates as cstates +from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler from masu.api.upgrade_trino.util.verify_parquet_files import VerifyParquetFiles from masu.celery.tasks import PROVIDER_REPORT_TYPE_MAP @@ -64,10 +65,10 @@ def create_default_verify_handler(self): def build_expected_additional_context(self, verify_hander, successful=True): return { - CONTEXT_KEY_MAPPING["metadata"]: { + ConversionContextKeys.metadata: { verify_hander.file_tracker.bill_date_str: { - CONTEXT_KEY_MAPPING["version"]: verify_hander.file_tracker.CONVERTER_VERSION, - CONTEXT_KEY_MAPPING["successful"]: successful, + ConversionContextKeys.version: CONVERTER_VERSION, + ConversionContextKeys.successful: successful, } } } @@ -102,7 +103,7 @@ def create_tmp_test_file(provider, required_columns): bill_date=self.bill_date, cleaned_column_mapping=required_columns, ) - conversion_metadata = provider.additional_context.get(CONTEXT_KEY_MAPPING["metadata"], {}) + conversion_metadata = provider.additional_context.get(ConversionContextKeys.metadata, {}) self.assertTrue(verify_handler.file_tracker.add_to_queue(conversion_metadata)) prefixes = verify_handler._generate_s3_path_prefixes(DateHelper().this_month_start) filter_side_effect = [[DummyS3Object(key=temp_file)]] @@ -122,7 +123,7 @@ def create_tmp_test_file(provider, required_columns): self.assertEqual( provider.additional_context, self.build_expected_additional_context(verify_handler, True) ) - conversion_metadata = provider.additional_context.get(CONTEXT_KEY_MAPPING["metadata"]) + conversion_metadata = provider.additional_context.get(ConversionContextKeys.metadata) self.assertFalse(verify_handler.file_tracker.add_to_queue(conversion_metadata)) def test_coerce_parquet_data_type_no_changes_needed(self): @@ -139,9 +140,9 @@ def test_coerce_parquet_data_type_no_changes_needed(self): verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) verify_handler.file_tracker.set_state(temp_file.name, return_state) - self.assertEqual(return_state, StateTracker.NO_CHANGES_NEEDED) + self.assertEqual(return_state, cstates.no_changes_needed) bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() - self.assertTrue(bill_metadata.get(CONTEXT_KEY_MAPPING["successful"])) + 
self.assertTrue(bill_metadata.get(ConversionContextKeys.successful)) def test_coerce_parquet_data_type_coerce_needed(self): """Test that files created through reindex are fixed correctly.""" @@ -153,7 +154,7 @@ def test_coerce_parquet_data_type_coerce_needed(self): verify_handler = self.create_default_verify_handler() verify_handler.file_tracker.add_local_file(filename, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) - self.assertEqual(return_state, StateTracker.COERCE_REQUIRED) + self.assertEqual(return_state, cstates.coerce_required) verify_handler.file_tracker.set_state(filename, return_state) files_need_updating = verify_handler.file_tracker.get_files_that_need_updated() self.assertTrue(files_need_updating.get(filename)) @@ -176,15 +177,15 @@ def test_coerce_parquet_data_type_failed_to_coerce(self): verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) verify_handler.file_tracker.set_state(temp_file.name, return_state) - self.assertEqual(return_state, StateTracker.FAILED_DTYPE_CONVERSION) + self.assertEqual(return_state, cstates.conversion_failed) verify_handler.file_tracker._check_if_complete() self.default_provider.refresh_from_db() - conversion_metadata = self.default_provider.additional_context.get(CONTEXT_KEY_MAPPING["metadata"]) + conversion_metadata = self.default_provider.additional_context.get(ConversionContextKeys.metadata) self.assertIsNotNone(conversion_metadata) bill_metadata = conversion_metadata.get(verify_handler.file_tracker.bill_date_str) self.assertIsNotNone(bill_metadata) - self.assertFalse(bill_metadata.get(CONTEXT_KEY_MAPPING["successful"]), True) - self.assertIsNotNone(bill_metadata.get(CONTEXT_KEY_MAPPING["failed_files"])) + self.assertFalse(bill_metadata.get(ConversionContextKeys.successful), True) + self.assertIsNotNone(bill_metadata.get(ConversionContextKeys.failed_files)) # confirm nothing would be sent to s3 self.assertEqual(verify_handler.file_tracker.get_files_that_need_updated(), {}) # confirm that it should be retried on next run diff --git a/koku/masu/api/upgrade_trino/util/__init__.py b/koku/masu/api/upgrade_trino/util/__init__.py index d1d41a128b..e69de29bb2 100644 --- a/koku/masu/api/upgrade_trino/util/__init__.py +++ b/koku/masu/api/upgrade_trino/util/__init__.py @@ -1,6 +0,0 @@ -CONTEXT_KEY_MAPPING = { - "metadata": "conversion_metadata", - "version": "version", - "successful": "successful", - "failed_files": "dtype_failed_files", -} diff --git a/koku/masu/api/upgrade_trino/util/constants.py b/koku/masu/api/upgrade_trino/util/constants.py new file mode 100644 index 0000000000..cfe5c5f607 --- /dev/null +++ b/koku/masu/api/upgrade_trino/util/constants.py @@ -0,0 +1,57 @@ +from enum import Enum + +# Update this to trigger the converter to run again +# even if marked as successful +CONVERTER_VERSION = "1" + + +class ReprEnum(Enum): + """ + Only changes the repr(), leaving str() and format() to the mixed-in type. 
+ """ + + +# StrEnum is available in python 3.11, vendored over from +# https://github.com/python/cpython/blob/c31be58da8577ef140e83d4e46502c7bb1eb9abf/Lib/enum.py#L1321-L1345 +class StrEnum(str, ReprEnum): + """ + Enum where members are also (and must be) strings + """ + + def __new__(cls, *values): + "values must already be of type `str`" + if len(values) > 3: + raise TypeError(f"too many arguments for str(): {values!r}") + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): + raise TypeError(f"{values[0]!r} is not a string") + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): + raise TypeError(f"encoding must be a string, not {values[1]!r}") + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): + raise TypeError("errors must be a string, not %r" % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member + + +class ConversionContextKeys(StrEnum): + metadata = "conversion_metadata" + version = "version" + successful = "successful" + failed_files = "dtype_failed_files" + + +class ConversionStates(StrEnum): + found_s3_file = "found_s3_file" + downloaded_locally = "downloaded_locally" + no_changes_needed = "no_changes_needed" + coerce_required = "coerce_required" + s3_complete = "sent_to_s3_complete" + s3_failed = "sent_to_s3_failed" + conversion_failed = "failed_data_type_conversion" diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index b13354774d..20dc813b9a 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -5,7 +5,9 @@ from api.common import log_json from api.provider.provider_manager import ProviderManager from api.provider.provider_manager import ProviderManagerError -from masu.api.upgrade_trino.util import CONTEXT_KEY_MAPPING +from masu.api.upgrade_trino.util.constants import ConversionContextKeys +from masu.api.upgrade_trino.util.constants import ConversionStates as cstates +from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION LOG = logging.getLogger(__name__) @@ -14,15 +16,6 @@ class StateTracker: """Tracks the state of each s3 file for the provider per bill date""" - CONVERTER_VERSION = "0" - FOUND_S3_FILE = "found_s3_file" - DOWNLOADED_LOCALLY = "downloaded_locally" - NO_CHANGES_NEEDED = "no_changes_needed" - COERCE_REQUIRED = "coerce_required" - SENT_TO_S3_COMPLETE = "sent_to_s3_complete" - SENT_TO_S3_FAILED = "sent_to_s3_failed" - FAILED_DTYPE_CONVERSION = "failed_data_type_conversion" - def __init__(self, provider_uuid: str, bill_date: date): self.files = [] self.tracker = {} @@ -41,10 +34,10 @@ def add_to_queue(self, conversion_metadata): bool: True if the task should be added to the queue, False otherwise. """ bill_metadata = conversion_metadata.get(self.bill_date_str, {}) - if bill_metadata.get(CONTEXT_KEY_MAPPING["version"]) != self.CONVERTER_VERSION: + if bill_metadata.get(ConversionContextKeys.version) != CONVERTER_VERSION: # always kick off a task if the version does not match or exist. return True - if bill_metadata.get(CONTEXT_KEY_MAPPING["successful"]): + if bill_metadata.get(ConversionContextKeys.successful): # if the conversion was successful for this version do not kick # off a task. 
LOG.info( @@ -63,7 +56,7 @@ def set_state(self, s3_obj_key, state): def add_local_file(self, s3_obj_key, local_path): self.local_files[s3_obj_key] = local_path - self.tracker[s3_obj_key] = self.DOWNLOADED_LOCALLY + self.tracker[s3_obj_key] = cstates.downloaded_locally def get_files_that_need_updated(self): """Returns a mapping of files in the s3 needs @@ -73,7 +66,7 @@ def get_files_that_need_updated(self): """ mapping = {} for s3_obj_key, state in self.tracker.items(): - if state == self.COERCE_REQUIRED: + if state == cstates.coerce_required: mapping[s3_obj_key] = self.local_files.get(s3_obj_key) return mapping @@ -88,9 +81,9 @@ def generate_simulate_messages(self): files_correct = [] for s3_obj_key, state in self.tracker.items(): files_count += 1 - if state == self.COERCE_REQUIRED: + if state == cstates.coerce_required: files_need_updated.append(s3_obj_key) - elif state == self.NO_CHANGES_NEEDED: + elif state == cstates.no_changes_needed: files_correct.append(s3_obj_key) else: files_failed.append(s3_obj_key) @@ -118,26 +111,26 @@ def _clean_local_files(self): def _create_bill_date_metadata(self): # Check for incomplete files - bill_date_data = {"version": self.CONVERTER_VERSION} + bill_date_data = {"version": CONVERTER_VERSION} incomplete_files = [] for file_prefix, state in self.tracker.items(): - if state not in [self.SENT_TO_S3_COMPLETE, self.NO_CHANGES_NEEDED]: + if state not in [cstates.s3_complete, cstates.no_changes_needed]: file_metadata = {"key": file_prefix, "state": state} incomplete_files.append(file_metadata) if incomplete_files: - bill_date_data[CONTEXT_KEY_MAPPING["successful"]] = False - bill_date_data[CONTEXT_KEY_MAPPING["failed_files"]] = incomplete_files + bill_date_data[ConversionContextKeys.successful] = False + bill_date_data[ConversionContextKeys.failed_files] = incomplete_files if not incomplete_files: - bill_date_data[CONTEXT_KEY_MAPPING["successful"]] = True + bill_date_data[ConversionContextKeys.successful] = True return bill_date_data def _check_if_complete(self): try: manager = ProviderManager(self.provider_uuid) context = manager.get_additional_context() - conversion_metadata = context.get(CONTEXT_KEY_MAPPING["metadata"], {}) + conversion_metadata = context.get(ConversionContextKeys.metadata, {}) conversion_metadata[self.bill_date_str] = self._create_bill_date_metadata() - context[CONTEXT_KEY_MAPPING["metadata"]] = conversion_metadata + context[ConversionContextKeys.metadata] = conversion_metadata manager.model.set_additional_context(context) LOG.info(self.provider_uuid, log_json(msg="setting dtype states", context=context)) except ProviderManagerError: diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index 6625c33c2c..61d44180e9 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -10,7 +10,7 @@ from api.common import log_json from api.provider.models import Provider from api.utils import DateHelper -from masu.api.upgrade_trino.util import CONTEXT_KEY_MAPPING +from masu.api.upgrade_trino.util.constants import ConversionContextKeys from masu.api.upgrade_trino.util.state_tracker import StateTracker from masu.celery.tasks import fix_parquet_data_types from masu.processor import is_customer_large @@ -106,7 +106,7 @@ def build_celery_tasks(self): queue_name = GET_REPORT_FILES_QUEUE_XL account = copy.deepcopy(provider.account) - conversion_metadata = provider.additional_context.get(CONTEXT_KEY_MAPPING["metadata"], {}) + 
conversion_metadata = provider.additional_context.get(ConversionContextKeys.metadata, {}) dh = DateHelper() bill_datetimes = dh.list_months(self.start_date, dh.today.replace(tzinfo=None)) for bill_date in bill_datetimes: diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index b27dc62b93..58174bb4b1 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -12,6 +12,8 @@ from api.common import log_json from api.provider.models import Provider +from masu.api.upgrade_trino.util.constants import ConversionStates as cstates +from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION from masu.api.upgrade_trino.util.state_tracker import StateTracker from masu.config import Config from masu.processor.parquet.parquet_report_processor import OPENSHIFT_REPORT_TYPE @@ -147,7 +149,7 @@ def retrieve_verify_reload_s3_parquet(self): for s3_object in s3_bucket.objects.filter(Prefix=prefix): s3_object_key = s3_object.key self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key - self.file_tracker.set_state(s3_object_key, self.file_tracker.FOUND_S3_FILE) + self.file_tracker.set_state(s3_object_key, cstates.found_s3_file) local_file_path = os.path.join(self.local_path, os.path.basename(s3_object_key)) LOG.info( log_json( @@ -183,12 +185,12 @@ def retrieve_verify_reload_s3_parquet(self): s3_bucket.upload_fileobj( new_file, s3_obj_key, - ExtraArgs={"Metadata": {"converter_version": StateTracker.CONVERTER_VERSION}}, + ExtraArgs={"Metadata": {"converter_version": CONVERTER_VERSION}}, ) - self.file_tracker.set_state(s3_obj_key, self.file_tracker.SENT_TO_S3_COMPLETE) + self.file_tracker.set_state(s3_obj_key, cstates.s3_complete) except ClientError as e: LOG.info(f"Failed to overwrite S3 file {s3_object_key}: {str(e)}") - self.file_tracker.set_state(s3_object_key, self.file_tracker.SENT_TO_S3_FAILED) + self.file_tracker.set_state(s3_object_key, cstates.s3_failed) continue self.file_tracker.finalize_and_clean_up() @@ -280,7 +282,7 @@ def _coerce_parquet_data_type(self, parquet_file_path): local_file_path=parquet_file_path, ) ) - return self.file_tracker.NO_CHANGES_NEEDED + return cstates.no_changes_needed new_schema = pa.schema(fields) LOG.info( @@ -297,8 +299,8 @@ def _coerce_parquet_data_type(self, parquet_file_path): pa.parquet.write_table(table, parquet_file_path) self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields) # Signal that we need to send this update to S3. 
- return self.file_tracker.COERCE_REQUIRED + return cstates.coerce_required except Exception as e: LOG.info(log_json(self.provider_uuid, msg="Failed to coerce data.", context=self.logging_context, error=e)) - return self.file_tracker.FAILED_DTYPE_CONVERSION + return cstates.conversion_failed From 17e109592cd28ae8eae09a1afd1ca72458a52c94 Mon Sep 17 00:00:00 2001 From: Cody Myers Date: Tue, 9 Jan 2024 16:00:22 -0500 Subject: [PATCH 24/30] Update koku/masu/api/upgrade_trino/util/constants.py Co-authored-by: Sam Doran --- koku/masu/api/upgrade_trino/util/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/koku/masu/api/upgrade_trino/util/constants.py b/koku/masu/api/upgrade_trino/util/constants.py index cfe5c5f607..3d91d3a083 100644 --- a/koku/masu/api/upgrade_trino/util/constants.py +++ b/koku/masu/api/upgrade_trino/util/constants.py @@ -13,7 +13,7 @@ class ReprEnum(Enum): # StrEnum is available in python 3.11, vendored over from # https://github.com/python/cpython/blob/c31be58da8577ef140e83d4e46502c7bb1eb9abf/Lib/enum.py#L1321-L1345 -class StrEnum(str, ReprEnum): +class StrEnum(str, ReprEnum): # pragma: no cover """ Enum where members are also (and must be) strings """ From 8c5259d8bf5ef5c2d29a5153420069cba195748a Mon Sep 17 00:00:00 2001 From: Luke Couzens Date: Wed, 10 Jan 2024 11:33:02 +0000 Subject: [PATCH 25/30] fix double to timestamp conversion issue --- koku/masu/api/upgrade_trino/util/verify_parquet_files.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 58174bb4b1..9d92187850 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -33,7 +33,8 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat self.provider_uuid = uuid.UUID(provider_uuid) self.provider_type = provider_type.replace("-local", "") self.simulate = simulate - self.bill_date = self._bill_date(bill_date) + self.bill_date_time = self._bill_date_time(bill_date) + self.bill_date = self.bill_date_time.date() self.file_tracker = StateTracker(provider_uuid, self.bill_date) self.report_types = self._set_report_types() self.required_columns = self._set_pyarrow_types(cleaned_column_mapping) @@ -45,10 +46,10 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat "bill_date": self.bill_date, } - def _bill_date(self, bill_date): + def _bill_date_time(self, bill_date): """bill_date""" if isinstance(bill_date, str): - return ciso8601.parse_datetime(bill_date).replace(tzinfo=None).date() + return ciso8601.parse_datetime(bill_date).replace(tzinfo=None) return bill_date def _set_pyarrow_types(self, cleaned_column_mapping): @@ -205,7 +206,7 @@ def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_n if field.name in field_names: # if len is 0 here we get an empty list, if it does # have a value for the field, overwrite it with bill_date - replaced_values = [self.bill_date] * len(table[field.name]) + replaced_values = [self.bill_date_time] * len(table[field.name]) correct_data_type = self.required_columns.get(field.name) corrected_column = pa.array(replaced_values, type=correct_data_type) field = pa.field(field.name, corrected_column.type) From da5ea49efe46d3100a5e09047cb0d52dbf00b787 Mon Sep 17 00:00:00 2001 From: myersCody Date: Wed, 10 Jan 2024 11:46:29 -0500 Subject: [PATCH 26/30] Change the 
default value for transformation. --- .../test/test_verify_parquet_files.py | 64 +++++++++++++++---- .../api/upgrade_trino/util/state_tracker.py | 1 + .../util/verify_parquet_files.py | 22 ++++--- 3 files changed, 68 insertions(+), 19 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index 9655f5aabc..b03e43b1bc 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -73,6 +73,12 @@ def build_expected_additional_context(self, verify_hander, successful=True): } } + def verify_correct_types(self, temp_file, verify_handler): + table = pq.read_table(temp_file) + schema = table.schema + for field in schema: + self.assertEqual(field.type, verify_handler.required_columns.get(field.name)) + @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource") def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _): @@ -114,11 +120,8 @@ def create_tmp_test_file(provider, required_columns): VerifyParquetFiles.local_path = self.temp_dir verify_handler.retrieve_verify_reload_s3_parquet() mock_bucket.upload_fileobj.assert_called() - table = pq.read_table(temp_file) - schema = table.schema - for field in schema: - self.assertEqual(field.type, verify_handler.required_columns.get(field.name)) - os.remove(temp_file) + self.verify_correct_types(temp_file, verify_handler) + # Test that the additional context is set correctly provider.refresh_from_db() self.assertEqual( provider.additional_context, self.build_expected_additional_context(verify_handler, True) @@ -126,7 +129,41 @@ def create_tmp_test_file(provider, required_columns): conversion_metadata = provider.additional_context.get(ConversionContextKeys.metadata) self.assertFalse(verify_handler.file_tracker.add_to_queue(conversion_metadata)) - def test_coerce_parquet_data_type_no_changes_needed(self): + def test_double_to_timestamp_transformation_with_reindex(self): + """Test double to datetime transformation with values""" + file_data = { + "float": [1.1, 2.2, 3.3], + "string": ["A", "B", "C"], + "unrequired_column": ["a", "b", "c"], + } + test_file = "transformation_test.parquet" + data_frame = pd.DataFrame(file_data) + data_frame = data_frame.reindex(columns=self.required_columns) + temp_file = os.path.join(self.temp_dir, test_file) + data_frame.to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler._perform_transformation_double_to_timestamp(temp_file, ["datetime"]) + self.verify_correct_types(temp_file, verify_handler) + + def test_double_to_timestamp_transformation_with_values(self): + """Test double to datetime transformation with values""" + file_data = { + "float": [1.1, 2.2, 3.3], + "string": ["A", "B", "C"], + "datetime": [1.1, 2.2, 3.3], + "unrequired_column": ["a", "b", "c"], + } + test_file = "transformation_test.parquet" + data_frame = pd.DataFrame(file_data) + data_frame = data_frame.reindex(columns=self.required_columns) + temp_file = os.path.join(self.temp_dir, test_file) + data_frame.to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler._perform_transformation_double_to_timestamp(temp_file, ["datetime"]) + self.verify_correct_types(temp_file, verify_handler) + + 
@patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") + def test_coerce_parquet_data_type_no_changes_needed(self, _): """Test a parquet file with correct dtypes.""" file_data = { "float": [1.1, 2.2, 3.3], @@ -143,6 +180,9 @@ def test_coerce_parquet_data_type_no_changes_needed(self): self.assertEqual(return_state, cstates.no_changes_needed) bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() self.assertTrue(bill_metadata.get(ConversionContextKeys.successful)) + # Test that generated messages would contain these files. + simulated_messages = verify_handler.file_tracker.generate_simulate_messages() + self.assertIn(str(temp_file.name), simulated_messages.get("Files that have all correct data_types.")) def test_coerce_parquet_data_type_coerce_needed(self): """Test that files created through reindex are fixed correctly.""" @@ -158,11 +198,13 @@ def test_coerce_parquet_data_type_coerce_needed(self): verify_handler.file_tracker.set_state(filename, return_state) files_need_updating = verify_handler.file_tracker.get_files_that_need_updated() self.assertTrue(files_need_updating.get(filename)) - table = pq.read_table(temp_file) - schema = table.schema - for field in schema: - self.assertEqual(field.type, self.expected_pyarrow_dtypes.get(field.name)) - os.remove(temp_file) + self.verify_correct_types(temp_file, verify_handler) + # Test that generated messages would contain these files. + simulated_messages = verify_handler.file_tracker.generate_simulate_messages() + self.assertIn(filename, simulated_messages.get("Files that need to be updated.")) + # Test delete clean local files. + verify_handler.file_tracker._clean_local_files() + self.assertFalse(os.path.exists(temp_file)) def test_coerce_parquet_data_type_failed_to_coerce(self): """Test a parquet file with correct dtypes.""" diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 20dc813b9a..b393c8bb02 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -103,6 +103,7 @@ def generate_simulate_messages(self): ) ) self._clean_local_files() + return simulate_info def _clean_local_files(self): for file_path in self.local_files.values(): diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 9d92187850..9c3a570571 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -33,8 +33,7 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat self.provider_uuid = uuid.UUID(provider_uuid) self.provider_type = provider_type.replace("-local", "") self.simulate = simulate - self.bill_date_time = self._bill_date_time(bill_date) - self.bill_date = self.bill_date_time.date() + self.bill_date = self._bill_date(bill_date) self.file_tracker = StateTracker(provider_uuid, self.bill_date) self.report_types = self._set_report_types() self.required_columns = self._set_pyarrow_types(cleaned_column_mapping) @@ -46,10 +45,10 @@ def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_dat "bill_date": self.bill_date, } - def _bill_date_time(self, bill_date): + def _bill_date(self, bill_date): """bill_date""" if isinstance(bill_date, str): - return ciso8601.parse_datetime(bill_date).replace(tzinfo=None) + return ciso8601.parse_datetime(bill_date).replace(tzinfo=None).date() return 
bill_date def _set_pyarrow_types(self, cleaned_column_mapping): @@ -199,16 +198,23 @@ def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_n """Performs a transformation to change a double to a timestamp.""" if not field_names: return + LOG.info( + log_json( + self.provider_uuid, + msg="Transforming fields from double to timestamp.", + context=self.logging_context, + local_file_path=parquet_file_path, + updated_columns=field_names, + ) + ) table = pq.read_table(parquet_file_path) schema = table.schema fields = [] for field in schema: if field.name in field_names: - # if len is 0 here we get an empty list, if it does - # have a value for the field, overwrite it with bill_date - replaced_values = [self.bill_date_time] * len(table[field.name]) + replacement_value = [] correct_data_type = self.required_columns.get(field.name) - corrected_column = pa.array(replaced_values, type=correct_data_type) + corrected_column = pa.array(replacement_value, type=correct_data_type) field = pa.field(field.name, corrected_column.type) fields.append(field) # Create a new schema From 27cc7948a3514dc4763292a051f11e38dab275e1 Mon Sep 17 00:00:00 2001 From: Cody Myers Date: Fri, 12 Jan 2024 15:20:56 -0500 Subject: [PATCH 27/30] Apply suggestions from code review Co-authored-by: Sam Doran --- koku/masu/api/upgrade_trino/util/state_tracker.py | 13 ++++++------- .../api/upgrade_trino/util/verify_parquet_files.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index b393c8bb02..1a85406915 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -64,11 +64,11 @@ def get_files_that_need_updated(self): {s3_object_key: local_file_path} for """ - mapping = {} - for s3_obj_key, state in self.tracker.items(): - if state == cstates.coerce_required: - mapping[s3_obj_key] = self.local_files.get(s3_obj_key) - return mapping + return { + s3_obj_key: self.local_files.get(s3_obj_key) + for s3_obj_key, state in self.tracker.items() + if state == cstates.coerce_required + } def generate_simulate_messages(self): """ @@ -107,8 +107,7 @@ def generate_simulate_messages(self): def _clean_local_files(self): for file_path in self.local_files.values(): - if os.path.exists(file_path): - os.remove(file_path) + file_path.unlink(missing_ok=True) def _create_bill_date_metadata(self): # Check for incomplete files diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index 9c3a570571..c6f7b93426 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -150,7 +150,7 @@ def retrieve_verify_reload_s3_parquet(self): s3_object_key = s3_object.key self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key self.file_tracker.set_state(s3_object_key, cstates.found_s3_file) - local_file_path = os.path.join(self.local_path, os.path.basename(s3_object_key)) + local_file_path = self.local_path.joinpath(os.path.basename(s3_object_key)) LOG.info( log_json( self.provider_uuid, From 523e19b66138e7215ee08db1162cfb0f8892d0a1 Mon Sep 17 00:00:00 2001 From: myersCody Date: Fri, 12 Jan 2024 15:27:14 -0500 Subject: [PATCH 28/30] Address PR comments. 
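The review suggestions folded in above swap the hand-rolled loop in `get_files_that_need_updated` for a dict comprehension and move local cleanup from `os.path.exists`/`os.remove` to `Path.unlink(missing_ok=True)`. The following is a minimal standalone sketch of that pattern, not the actual `StateTracker`; the helper names and state strings are illustrative, and it assumes the tracked local paths are `pathlib.Path` objects rather than plain strings.

```python
from pathlib import Path
from tempfile import TemporaryDirectory

# Illustrative state strings (stand-ins for the ConversionStates members).
COERCE_REQUIRED = "coerce_required"
NO_CHANGES_NEEDED = "no_changes_needed"


def files_needing_update(tracker: dict, local_files: dict) -> dict:
    """Return {s3_key: local_path} for every file marked as needing coercion."""
    return {
        key: local_files.get(key)
        for key, state in tracker.items()
        if state == COERCE_REQUIRED
    }


def clean_local_files(local_files: dict) -> None:
    """Delete local copies; files that are already gone are not an error."""
    for path in local_files.values():
        path.unlink(missing_ok=True)


if __name__ == "__main__":
    with TemporaryDirectory() as tmp:
        a = Path(tmp, "a.parquet")
        b = Path(tmp, "b.parquet")
        a.touch()
        b.touch()
        tracker = {"s3/a.parquet": COERCE_REQUIRED, "s3/b.parquet": NO_CHANGES_NEEDED}
        local_files = {"s3/a.parquet": a, "s3/b.parquet": b}
        print(files_needing_update(tracker, local_files))  # only a.parquet
        clean_local_files(local_files)
        print(a.exists(), b.exists())  # False False
```

`missing_ok=True` makes the cleanup idempotent: a file that was already removed is simply skipped instead of raising `FileNotFoundError`.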
--- koku/common/__ini__.py | 0 koku/common/enum.py | 36 ++++++++++++++++++ .../test/test_verify_parquet_files.py | 8 ++-- koku/masu/api/upgrade_trino/util/constants.py | 37 +------------------ .../api/upgrade_trino/util/state_tracker.py | 19 +++++----- .../api/upgrade_trino/util/task_handler.py | 23 ++++++------ .../util/verify_parquet_files.py | 14 +++---- 7 files changed, 69 insertions(+), 68 deletions(-) create mode 100644 koku/common/__ini__.py create mode 100644 koku/common/enum.py diff --git a/koku/common/__ini__.py b/koku/common/__ini__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/koku/common/enum.py b/koku/common/enum.py new file mode 100644 index 0000000000..8052fb0c08 --- /dev/null +++ b/koku/common/enum.py @@ -0,0 +1,36 @@ +from enum import Enum + + +class ReprEnum(Enum): + """ + Only changes the repr(), leaving str() and format() to the mixed-in type. + """ + + +# StrEnum is available in python 3.11, vendored over from +# https://github.com/python/cpython/blob/c31be58da8577ef140e83d4e46502c7bb1eb9abf/Lib/enum.py#L1321-L1345 +class StrEnum(str, ReprEnum): # pragma: no cover + """ + Enum where members are also (and must be) strings + """ + + def __new__(cls, *values): + "values must already be of type `str`" + if len(values) > 3: + raise TypeError(f"too many arguments for str(): {values!r}") + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): + raise TypeError(f"{values[0]!r} is not a string") + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): + raise TypeError(f"encoding must be a string, not {values[1]!r}") + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): + raise TypeError("errors must be a string, not %r" % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index b03e43b1bc..2014bbde1f 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -16,7 +16,7 @@ from api.utils import DateHelper from masu.api.upgrade_trino.util.constants import ConversionContextKeys -from masu.api.upgrade_trino.util.constants import ConversionStates as cstates +from masu.api.upgrade_trino.util.constants import ConversionStates from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler from masu.api.upgrade_trino.util.verify_parquet_files import VerifyParquetFiles @@ -177,7 +177,7 @@ def test_coerce_parquet_data_type_no_changes_needed(self, _): verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) verify_handler.file_tracker.set_state(temp_file.name, return_state) - self.assertEqual(return_state, cstates.no_changes_needed) + self.assertEqual(return_state, ConversionStates.no_changes_needed) bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() self.assertTrue(bill_metadata.get(ConversionContextKeys.successful)) # Test that generated messages would contain these files. 
@@ -194,7 +194,7 @@ def test_coerce_parquet_data_type_coerce_needed(self): verify_handler = self.create_default_verify_handler() verify_handler.file_tracker.add_local_file(filename, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) - self.assertEqual(return_state, cstates.coerce_required) + self.assertEqual(return_state, ConversionStates.coerce_required) verify_handler.file_tracker.set_state(filename, return_state) files_need_updating = verify_handler.file_tracker.get_files_that_need_updated() self.assertTrue(files_need_updating.get(filename)) @@ -219,7 +219,7 @@ def test_coerce_parquet_data_type_failed_to_coerce(self): verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) return_state = verify_handler._coerce_parquet_data_type(temp_file) verify_handler.file_tracker.set_state(temp_file.name, return_state) - self.assertEqual(return_state, cstates.conversion_failed) + self.assertEqual(return_state, ConversionStates.conversion_failed) verify_handler.file_tracker._check_if_complete() self.default_provider.refresh_from_db() conversion_metadata = self.default_provider.additional_context.get(ConversionContextKeys.metadata) diff --git a/koku/masu/api/upgrade_trino/util/constants.py b/koku/masu/api/upgrade_trino/util/constants.py index 3d91d3a083..fff80b596c 100644 --- a/koku/masu/api/upgrade_trino/util/constants.py +++ b/koku/masu/api/upgrade_trino/util/constants.py @@ -1,45 +1,10 @@ -from enum import Enum +from common.enum import StrEnum # Update this to trigger the converter to run again # even if marked as successful CONVERTER_VERSION = "1" -class ReprEnum(Enum): - """ - Only changes the repr(), leaving str() and format() to the mixed-in type. - """ - - -# StrEnum is available in python 3.11, vendored over from -# https://github.com/python/cpython/blob/c31be58da8577ef140e83d4e46502c7bb1eb9abf/Lib/enum.py#L1321-L1345 -class StrEnum(str, ReprEnum): # pragma: no cover - """ - Enum where members are also (and must be) strings - """ - - def __new__(cls, *values): - "values must already be of type `str`" - if len(values) > 3: - raise TypeError(f"too many arguments for str(): {values!r}") - if len(values) == 1: - # it must be a string - if not isinstance(values[0], str): - raise TypeError(f"{values[0]!r} is not a string") - if len(values) >= 2: - # check that encoding argument is a string - if not isinstance(values[1], str): - raise TypeError(f"encoding must be a string, not {values[1]!r}") - if len(values) == 3: - # check that errors argument is a string - if not isinstance(values[2], str): - raise TypeError("errors must be a string, not %r" % (values[2])) - value = str(*values) - member = str.__new__(cls, value) - member._value_ = value - return member - - class ConversionContextKeys(StrEnum): metadata = "conversion_metadata" version = "version" diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py index 1a85406915..0f4288f1f1 100644 --- a/koku/masu/api/upgrade_trino/util/state_tracker.py +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -1,12 +1,11 @@ import logging -import os from datetime import date from api.common import log_json from api.provider.provider_manager import ProviderManager from api.provider.provider_manager import ProviderManagerError from masu.api.upgrade_trino.util.constants import ConversionContextKeys -from masu.api.upgrade_trino.util.constants import ConversionStates as cstates +from masu.api.upgrade_trino.util.constants import ConversionStates from 
masu.api.upgrade_trino.util.constants import CONVERTER_VERSION @@ -56,7 +55,7 @@ def set_state(self, s3_obj_key, state): def add_local_file(self, s3_obj_key, local_path): self.local_files[s3_obj_key] = local_path - self.tracker[s3_obj_key] = cstates.downloaded_locally + self.tracker[s3_obj_key] = ConversionStates.downloaded_locally def get_files_that_need_updated(self): """Returns a mapping of files in the s3 needs @@ -64,10 +63,10 @@ def get_files_that_need_updated(self): {s3_object_key: local_file_path} for """ - return { - s3_obj_key: self.local_files.get(s3_obj_key) - for s3_obj_key, state in self.tracker.items() - if state == cstates.coerce_required + return { + s3_obj_key: self.local_files.get(s3_obj_key) + for s3_obj_key, state in self.tracker.items() + if state == ConversionStates.coerce_required } def generate_simulate_messages(self): @@ -81,9 +80,9 @@ def generate_simulate_messages(self): files_correct = [] for s3_obj_key, state in self.tracker.items(): files_count += 1 - if state == cstates.coerce_required: + if state == ConversionStates.coerce_required: files_need_updated.append(s3_obj_key) - elif state == cstates.no_changes_needed: + elif state == ConversionStates.no_changes_needed: files_correct.append(s3_obj_key) else: files_failed.append(s3_obj_key) @@ -114,7 +113,7 @@ def _create_bill_date_metadata(self): bill_date_data = {"version": CONVERTER_VERSION} incomplete_files = [] for file_prefix, state in self.tracker.items(): - if state not in [cstates.s3_complete, cstates.no_changes_needed]: + if state not in [ConversionStates.s3_complete, ConversionStates.no_changes_needed]: file_metadata = {"key": file_prefix, "state": state} incomplete_files.append(file_metadata) if incomplete_files: diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index 61d44180e9..a221157997 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -28,7 +28,7 @@ class RequiredParametersError(Exception): """Handle require parameters error.""" -@dataclass +@dataclass(frozen=True) class FixParquetTaskHandler: start_date: Optional[str] = field(default=None) provider_uuid: Optional[str] = field(default=None) @@ -48,32 +48,33 @@ class FixParquetTaskHandler: @classmethod def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": """Create an instance from query parameters.""" - reprocess_kwargs = cls() + kwargs = {} if start_date := query_params.get("start_date"): if start_date: - reprocess_kwargs.start_date = parser.parse(start_date).replace(day=1) + kwargs["start_date"] = parser.parse(start_date).replace(day=1) if provider_uuid := query_params.get("provider_uuid"): provider = Provider.objects.filter(uuid=provider_uuid).first() if not provider: raise RequiredParametersError(f"The provider_uuid {provider_uuid} does not exist.") - reprocess_kwargs.provider_uuid = provider_uuid - reprocess_kwargs.provider_type = provider.type + kwargs["provider_uuid"] = provider_uuid + kwargs["provider_type"] = provider.type if provider_type := query_params.get("provider_type"): - reprocess_kwargs.provider_type = provider_type + kwargs["provider_type"] = provider_type if simulate := query_params.get("simulate"): if simulate.lower() == "true": - reprocess_kwargs.simulate = True + kwargs["simulate"] = True - if not reprocess_kwargs.provider_type and not reprocess_kwargs.provider_uuid: + if not kwargs["provider_type"] and not kwargs["provider_uuid"]: raise 
RequiredParametersError("provider_uuid or provider_type must be supplied") - if not reprocess_kwargs.start_date: + + if not kwargs["start_date"]: raise RequiredParametersError("start_date must be supplied as a parameter.") - reprocess_kwargs.cleaned_column_mapping = reprocess_kwargs.clean_column_names(reprocess_kwargs.provider_type) - return reprocess_kwargs + kwargs["cleaned_column_mapping"] = cls.clean_column_names(kwargs["provider_type"]) + return cls(**kwargs) @classmethod def clean_column_names(cls, provider_type): diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py index c6f7b93426..71575642a7 100644 --- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -12,7 +12,7 @@ from api.common import log_json from api.provider.models import Provider -from masu.api.upgrade_trino.util.constants import ConversionStates as cstates +from masu.api.upgrade_trino.util.constants import ConversionStates from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION from masu.api.upgrade_trino.util.state_tracker import StateTracker from masu.config import Config @@ -149,7 +149,7 @@ def retrieve_verify_reload_s3_parquet(self): for s3_object in s3_bucket.objects.filter(Prefix=prefix): s3_object_key = s3_object.key self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key - self.file_tracker.set_state(s3_object_key, cstates.found_s3_file) + self.file_tracker.set_state(s3_object_key, ConversionStates.found_s3_file) local_file_path = self.local_path.joinpath(os.path.basename(s3_object_key)) LOG.info( log_json( @@ -187,10 +187,10 @@ def retrieve_verify_reload_s3_parquet(self): s3_obj_key, ExtraArgs={"Metadata": {"converter_version": CONVERTER_VERSION}}, ) - self.file_tracker.set_state(s3_obj_key, cstates.s3_complete) + self.file_tracker.set_state(s3_obj_key, ConversionStates.s3_complete) except ClientError as e: LOG.info(f"Failed to overwrite S3 file {s3_object_key}: {str(e)}") - self.file_tracker.set_state(s3_object_key, cstates.s3_failed) + self.file_tracker.set_state(s3_object_key, ConversionStates.s3_failed) continue self.file_tracker.finalize_and_clean_up() @@ -289,7 +289,7 @@ def _coerce_parquet_data_type(self, parquet_file_path): local_file_path=parquet_file_path, ) ) - return cstates.no_changes_needed + return ConversionStates.no_changes_needed new_schema = pa.schema(fields) LOG.info( @@ -306,8 +306,8 @@ def _coerce_parquet_data_type(self, parquet_file_path): pa.parquet.write_table(table, parquet_file_path) self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields) # Signal that we need to send this update to S3. - return cstates.coerce_required + return ConversionStates.coerce_required except Exception as e: LOG.info(log_json(self.provider_uuid, msg="Failed to coerce data.", context=self.logging_context, error=e)) - return cstates.conversion_failed + return ConversionStates.conversion_failed From ca008550735dcba5788b6f4e00d0e2a2775d534e Mon Sep 17 00:00:00 2001 From: myersCody Date: Fri, 12 Jan 2024 15:49:38 -0500 Subject: [PATCH 29/30] Fix unittests. 
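Making the task handler a frozen dataclass above means `from_query_params` can no longer assign attributes one by one; it instead collects keyword arguments and constructs the instance with a single `cls(**kwargs)` call. Below is a rough sketch of that construction pattern under simplified, hypothetical names (the Django `QueryDict`, the provider lookup, and the column-mapping step are omitted); it reads optional keys with `kwargs.get()`, which is the `KeyError` fix the unit-test patch that follows applies.

```python
from dataclasses import dataclass, field
from typing import Optional


class RequiredParametersError(Exception):
    """Raised when required query parameters are missing."""


@dataclass(frozen=True)
class FixTaskParams:
    """Simplified, hypothetical stand-in for the frozen task-handler dataclass."""

    start_date: Optional[str] = field(default=None)
    provider_uuid: Optional[str] = field(default=None)
    provider_type: Optional[str] = field(default=None)
    simulate: bool = field(default=False)

    @classmethod
    def from_query_params(cls, query_params: dict) -> "FixTaskParams":
        # Frozen instances cannot be mutated after __init__, so collect
        # everything into kwargs first and build the object in one call.
        kwargs = {}
        if start_date := query_params.get("start_date"):
            kwargs["start_date"] = start_date
        if provider_uuid := query_params.get("provider_uuid"):
            kwargs["provider_uuid"] = provider_uuid
        if provider_type := query_params.get("provider_type"):
            kwargs["provider_type"] = provider_type
        if (simulate := query_params.get("simulate")) and simulate.lower() == "true":
            kwargs["simulate"] = True

        # Keys that were never set are simply absent, so use .get() here;
        # indexing with kwargs["..."] would raise KeyError instead of the
        # intended validation error.
        if not kwargs.get("provider_uuid") and not kwargs.get("provider_type"):
            raise RequiredParametersError("provider_uuid or provider_type must be supplied")
        if not kwargs.get("start_date"):
            raise RequiredParametersError("start_date must be supplied as a parameter.")
        return cls(**kwargs)


if __name__ == "__main__":
    params = FixTaskParams.from_query_params(
        {"start_date": "2024-01-01", "provider_type": "AWS", "simulate": "true"}
    )
    print(params)
    # params.simulate = False  # would raise dataclasses.FrozenInstanceError
```

Because the instance is frozen, any later attempt to set a field raises `dataclasses.FrozenInstanceError`, so the parameters stay immutable once validated.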
--- .../api/upgrade_trino/test/test_verify_parquet_files.py | 7 ++++--- koku/masu/api/upgrade_trino/util/task_handler.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py index 2014bbde1f..913b700f8a 100644 --- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -8,6 +8,7 @@ import tempfile from collections import namedtuple from datetime import datetime +from pathlib import Path from unittest.mock import patch import pandas as pd @@ -117,7 +118,7 @@ def create_tmp_test_file(provider, required_columns): filter_side_effect.append([]) mock_bucket.objects.filter.side_effect = filter_side_effect mock_bucket.download_file.return_value = temp_file - VerifyParquetFiles.local_path = self.temp_dir + VerifyParquetFiles.local_path = Path(self.temp_dir) verify_handler.retrieve_verify_reload_s3_parquet() mock_bucket.upload_fileobj.assert_called() self.verify_correct_types(temp_file, verify_handler) @@ -192,7 +193,7 @@ def test_coerce_parquet_data_type_coerce_needed(self): temp_file = os.path.join(self.temp_dir, f"test{self.suffix}") data_frame.to_parquet(temp_file, **self.panda_kwargs) verify_handler = self.create_default_verify_handler() - verify_handler.file_tracker.add_local_file(filename, temp_file) + verify_handler.file_tracker.add_local_file(filename, Path(temp_file)) return_state = verify_handler._coerce_parquet_data_type(temp_file) self.assertEqual(return_state, ConversionStates.coerce_required) verify_handler.file_tracker.set_state(filename, return_state) @@ -359,7 +360,7 @@ def test_retrieve_verify_reload_s3_parquet_failure(self, mock_s3_resource, _): filter_side_effect.append([]) mock_bucket.objects.filter.side_effect = filter_side_effect mock_bucket.download_file.return_value = temp_file - VerifyParquetFiles.local_path = self.temp_dir + VerifyParquetFiles.local_path = Path(self.temp_dir) verify_handler.retrieve_verify_reload_s3_parquet() mock_bucket.upload_fileobj.assert_not_called() os.remove(temp_file) diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py index a221157997..ed2b5158d1 100644 --- a/koku/masu/api/upgrade_trino/util/task_handler.py +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -67,10 +67,10 @@ def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": if simulate.lower() == "true": kwargs["simulate"] = True - if not kwargs["provider_type"] and not kwargs["provider_uuid"]: + if not kwargs.get("provider_type") and not kwargs.get("provider_uuid"): raise RequiredParametersError("provider_uuid or provider_type must be supplied") - if not kwargs["start_date"]: + if not kwargs.get("start_date"): raise RequiredParametersError("start_date must be supplied as a parameter.") kwargs["cleaned_column_mapping"] = cls.clean_column_names(kwargs["provider_type"]) From 9ae32b431f422cc8140cf5e5eb72048ff69d5679 Mon Sep 17 00:00:00 2001 From: myersCody Date: Mon, 15 Jan 2024 09:03:32 -0500 Subject: [PATCH 30/30] Fix the misspelling. --- koku/common/{__ini__.py => __init__.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename koku/common/{__ini__.py => __init__.py} (100%) diff --git a/koku/common/__ini__.py b/koku/common/__init__.py similarity index 100% rename from koku/common/__ini__.py rename to koku/common/__init__.py