From 7b0a9cd6f302daf96272a97cfdcb8676e773a7fd Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Wed, 16 Jun 2021 13:41:54 -0700 Subject: [PATCH 1/8] feature: add utility class for getting jobs metrics --- src/braket/aws/aws_session.py | 9 + src/braket/jobs/__init__.py | 12 + src/braket/jobs/metrics/__init__.py | 14 + .../metrics/cwl_insights_metrics_fetcher.py | 284 ++++++++++++++++++ .../unit_tests/braket/aws/test_aws_session.py | 6 + .../test_cwl_insights_metrics_fetcher.py | 164 ++++++++++ 6 files changed, 489 insertions(+) create mode 100644 src/braket/jobs/__init__.py create mode 100644 src/braket/jobs/metrics/__init__.py create mode 100644 src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py create mode 100644 test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py diff --git a/src/braket/aws/aws_session.py b/src/braket/aws/aws_session.py index 2969bd9b8..2097d7a95 100644 --- a/src/braket/aws/aws_session.py +++ b/src/braket/aws/aws_session.py @@ -137,6 +137,15 @@ def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str: obj = s3.Object(s3_bucket, s3_object_key) return obj.get()["Body"].read().decode("utf-8") + def create_logs_client(self) -> Any: + """ + Create a CloudWatch Logs boto client. + + Returns: + Any: The CloudWatch Logs boto client. + """ + return self.boto_session.client("logs", config=self._config) + def get_device(self, arn: str) -> Dict[str, Any]: """ Calls the Amazon Braket `get_device` API to diff --git a/src/braket/jobs/__init__.py b/src/braket/jobs/__init__.py new file mode 100644 index 000000000..f4243de11 --- /dev/null +++ b/src/braket/jobs/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2021-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. diff --git a/src/braket/jobs/metrics/__init__.py b/src/braket/jobs/metrics/__init__.py new file mode 100644 index 000000000..fb113bc16 --- /dev/null +++ b/src/braket/jobs/metrics/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from braket.jobs.metrics.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher # noqa: F401 diff --git a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py new file mode 100644 index 000000000..358f374c8 --- /dev/null +++ b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py @@ -0,0 +1,284 @@ +# Copyright 2021-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import re +import time +from logging import Logger, getLogger +from typing import Any, Dict, Iterator, List, Optional + +from braket.aws.aws_session import AwsSession + + +class CwlInsightsMetricsFetcher(object): + + LOG_GROUP_NAME = "/aws/lambda/my-python-test-function" + METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*(\d*\.?\d*)\s*;") + TIMESTAMP = "Timestamp" + QUERY_DEFAULT_JOB_DURATION = 3 * 60 * 60 + + def __init__( + self, + aws_session: AwsSession, + poll_timeout_seconds: float = 10, + poll_interval_seconds: float = 1, + logger: Logger = getLogger(__name__), + ): + """ + Args: + aws_session (AwsSession): AwsSession to connect to AWS with. + poll_timeout_seconds (float): The polling timeout for retrieving the metrics, + in seconds. Default: 10 seconds. + poll_interval_seconds (float): The polling interval for results in seconds. + Default: 1 second. + logger (Logger): Logger object with which to write logs, such as task statuses + while waiting for task to be in a terminal state. Default is `getLogger(__name__)` + """ + self._poll_timeout_seconds = poll_timeout_seconds + self._poll_interval_seconds = poll_interval_seconds + self._logger = logger + self._logs_client = aws_session.create_logs_client() + + def _get_metrics_results_sync(self, query_id: str) -> List[Any]: + """ + Waits for the CloudWatch Insights query to complete and returns all the results. + + Args: + query_id (str): CloudWatch Insights query ID. + + Returns: + List[Any]: The results from CloudWatch insights 'GetQueryResults' operation. + """ + start_time = time.time() + while (time.time() - start_time) < self._poll_timeout_seconds: + response = self._logs_client.get_query_results(queryId=query_id) + query_status = response["status"] + if query_status in ["Failed", "Cancelled"]: + raise Exception(f"Query {query_id} failed with status {query_status}") + elif query_status == "Complete": + return response["results"] + else: + time.sleep(self._poll_interval_seconds) + self._logger.warning(f"Timed out waiting for query {query_id}.") + return [] + + @staticmethod + def _metrics_id_from_metrics(metrics: Dict[str, float]) -> int: + """ + Determines the semi-unique ID for a set of metrics that will represent the table for that + set of columns. The current implementation doesn't treat a difference in order as unique; + therefore, the same metrics output in various different orders will map to the same table. + + Args: + metrics (Dict[str, float]): The set of metrics. + + Returns: + int : The Metrics ID. + """ + metrics_id = 0 + for column_name in metrics.keys(): + metrics_id += hash(column_name) + return metrics_id + + @staticmethod + def _get_metrics_from_log_line_matches(all_matches: Iterator) -> Dict[str, float]: + """ + Converts matches from a RegEx to a set of metrics. + + Args: + all_matches (Iterator): An iterator for RegEx matches on a log line. + + Returns: + Dict[str, float]: The set of metrics found by the RegEx. The result will be in the + format { : }. This implies that multiple metrics with + the same name will be deduped to the last instance of that metric. 
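+
+        Example (illustrative; at this stage values are kept as the strings
+        captured by the RegEx)::
+
+            matches = CwlInsightsMetricsFetcher.METRICS_DEFINITIONS.finditer(
+                "metric0=0.0; metric0=0.1;"
+            )
+            CwlInsightsMetricsFetcher._get_metrics_from_log_line_matches(matches)
+            # returns {"metric0": "0.1"} -- the later value wins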
+ """ + metrics = {} + for match in all_matches: + subgroup = match.groups() + metrics[subgroup[0]] = subgroup[1] + return metrics + + @staticmethod + def _get_timestamp_from_log_line(log_line: List[Dict[str, Any]]) -> str: + """ + Finds and returns the timestamp of a log line from CloudWatch Insights results. + + Args: + log_line (List[Dict[str, Any]]): An iterator for RegEx matches on a log line. + + Returns: + str : The timestamp of the log line or 'N/A' if no timestamp is found. + """ + for element in log_line: + if element["field"] == "@timestamp": + return element["value"] + return "N/A" + + def _add_metrics( + self, + metrics_id: int, + metrics: Dict[str, float], + all_metrics: Dict[int, Dict[str, List[Any]]], + ) -> None: + """ + Adds the given metrics to the appropriate table in 'all_metrics'. If 'all_metrics' does not + currently have a table that represents the metrics, a new entry will be created to + represent the data. When the table entry is created, a "Timestamp" field is also added by + default, since all log metrics should have a timestamp. + + Args: + metrics_id (int): An ID to represent the given set of metrics. + metrics (Dict[str, float]): A set of metrics in the format { : }. + all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. + """ + if metrics_id not in all_metrics: + metrics_table = {} + for column_name in metrics.keys(): + metrics_table[column_name] = [] + metrics_table[self.TIMESTAMP] = [] + all_metrics[metrics_id] = metrics_table + metrics_table = all_metrics[metrics_id] + for column_name in metrics.keys(): + metrics_table[column_name].append(metrics[column_name]) + + def _parse_metrics_from_message( + self, message: str, all_metrics: Dict[int, Dict[str, List[Any]]] + ) -> Optional[int]: + """ + Parses a line from CloudWatch Logs to find all the metrics that have been logged + on that line. Any found metrics will be added to 'all_metrics'. + + Args: + message (str): A log line from CloudWatch Logs. + all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. + + Returns: + int : The ID to represent the given set of logged metrics, or None if no metrics + are found in the message. + """ + all_matches = self.METRICS_DEFINITIONS.finditer(message) + metrics = self._get_metrics_from_log_line_matches(all_matches) + if not metrics: + return None + metrics_id = self._metrics_id_from_metrics(metrics) + self._add_metrics(metrics_id, metrics, all_metrics) + return metrics_id + + def _parse_metrics_from_result_entry( + self, result_entry: List[Any], timestamp: str, all_metrics: Dict[int, Dict[str, List[Any]]] + ) -> None: + """ + Finds the actual log line containing metrics from a given CloudWatch Insights result entry, + and adds them to 'all_metrics'. The timestamp is added to match the corresponding values in + 'all_metrics'. + + Args: + result_entry (List[Any]): A structured result from calling CloudWatch Insights to get + logs that contain metrics. A single entry will contain the message + (the actual line logged to output), the timestamp (generated by CloudWatch Logs), + and other metadata that we (currently) do not use. + timestamp (str): A formatted string representing the timestamp for any found metrics. + all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. 
+ """ + for element in result_entry: + if element["field"] == "@message": + metrics_id = self._parse_metrics_from_message(element["value"], all_metrics) + if metrics_id is not None: + all_metrics[metrics_id][self.TIMESTAMP].append(timestamp) + break + + def _parse_log_line( + self, result_entry: List[Any], all_metrics: Dict[int, Dict[str, List[Any]]] + ) -> None: + """ + Parses the single entry from CloudWatch Insights results and adds any metrics it finds + to 'all_metrics', along with the timestamp for the entry. + + Args: + result_entry (List[Any]): A structured result from calling CloudWatch Insights to get + logs that contain metrics. A single entry will contain the message + (the actual line logged to output), the timestamp (generated by CloudWatch Logs), + and other metadata that we (currently) do not use. + all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. + """ + timestamp = self._get_timestamp_from_log_line(result_entry) + self._parse_metrics_from_result_entry(result_entry, timestamp, all_metrics) + + def _parse_log_query_results(self, results: List[Any]) -> Dict[int, Dict[str, List[Any]]]: + """ + Parses CloudWatch Insights results and returns all found metrics. + + Args: + results (List[Any]): A structured result from calling CloudWatch Insights to get + logs that contain metrics. + + Returns: + Dict[int, Dict[str, List[Any]]] : The list of all metrics that can be found in the + CloudWatch Logs results. Each unique set of metrics will be represented by a key + in the topmost dictionary. We can think of this set of metrics as a single table. + Each table will have a set of metrics, indexed by the column name. The + entries are not sorted. + """ + all_metrics = {} + for result in results: + self._parse_log_line(result, all_metrics) + return all_metrics + + def get_all_metrics_for_job( + self, job_name: str, job_start_time: int = None, job_end_time: int = None + ) -> List[Dict[str, List[Any]]]: + """ + Synchronously retrieves all the algorithm metrics logged by a given Job. + + Args: + job_name (str): The name of the Job. The name must be exact to ensure only the relevant + metrics are retrieved. + job_start_time (int): The time at which the job started. + Default: 3 hours before job_end_time. + job_end_time (int): If the job is complete, this should be the time at which the + job finished. Default: current time. + + Returns: + List[Dict[str, List[Any]]] : The list of all metrics that can be found in the + CloudWatch Logs results. Each item in the list can be thought of as a separate + table. Each table will have a set of metrics, indexed by the column name. The + entries are not sorted. 
+ """ + query_end_time = job_end_time or int(time.time()) + query_start_time = job_start_time or query_end_time - self.QUERY_DEFAULT_JOB_DURATION + + # job name needs to be specific to prevent jobs with similar names from being conflated + query = ( + f"fields @timestamp, @message " + f"| filter @logStream like /^{re.escape(job_name)}$/ " + f"| filter @message like /^Metrics - /" + ) + + response = self._logs_client.start_query( + logGroupName=self.LOG_GROUP_NAME, + startTime=query_start_time, + endTime=query_end_time, + queryString=query, + ) + + query_id = response["queryId"] + + results = self._get_metrics_results_sync(query_id) + + metric_data = self._parse_log_query_results(results) + + metric_data_list = [] + for metric_graph in metric_data.keys(): + metric_data_list.append(metric_data[metric_graph]) + + return metric_data_list diff --git a/test/unit_tests/braket/aws/test_aws_session.py b/test/unit_tests/braket/aws/test_aws_session.py index e3247f7c2..000948e11 100644 --- a/test/unit_tests/braket/aws/test_aws_session.py +++ b/test/unit_tests/braket/aws/test_aws_session.py @@ -131,6 +131,12 @@ def test_retrieve_s3_object_body_client_error(boto_session): aws_session.retrieve_s3_object_body(bucket_name, filename) +def test_create_logs_client(boto_session): + aws_session = AwsSession(boto_session=boto_session) + aws_session.create_logs_client() + boto_session.client.assert_called_with("logs", config=None) + + def test_get_device(boto_session): braket_client = Mock() return_val = {"deviceArn": "arn1", "deviceName": "name1"} diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py new file mode 100644 index 000000000..acc7a6edd --- /dev/null +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py @@ -0,0 +1,164 @@ +# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +from unittest.mock import Mock + +import pytest + +from braket.jobs.metrics import CwlInsightsMetricsFetcher + + +@pytest.fixture +def aws_session(): + _aws_session = Mock() + return _aws_session + + +MALFORMED_METRICS_LOG_LINES = [ + [ + {"field": "@timestamp", "value": "Test timestamp 0"}, + {"field": "@message", "value": ""}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 1"}, + {"field": "@message", "value": "Test Test"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 2"}, + {"field": "@message", "value": "metric0=not_a_number;"}, + ], + [{"field": "@timestamp", "value": "Test timestamp 0"}], + [ + {"field": "@unknown", "value": "Unknown"}, + ], +] + + +SIMPLE_METRICS_LOG_LINES = [ + [ + {"field": "@timestamp", "value": "Test timestamp 0"}, + {"field": "@message", "value": "metric0=0.0; metric1=1.0; metric2=2.0"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 1"}, + {"field": "@message", "value": "metric0=0.1; metric1=1.1; metric2=2.1"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 2"}, + {"field": "@message", "value": "metric0=0.2; metric1=1.2; metric2=2.2"}, + ], +] + +SIMPLE_METRICS_RESULT = [ + { + "Timestamp": ["Test timestamp 0", "Test timestamp 1", "Test timestamp 2"], + "metric0": ["0.0", "0.1", "0.2"], + "metric1": ["1.0", "1.1", "1.2"], + } +] + + +MULTIPLE_TABLES_METRICS_LOG_LINES = [ + [ + {"field": "@timestamp", "value": "Test timestamp 0"}, + {"field": "@message", "value": "metric0=0.0;"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 1"}, + {"field": "@message", "value": "metric0=0.1; metric1=1.1;"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 2"}, + {"field": "@message", "value": "metric0=0.2; metric2=2.2;"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 3"}, + {"field": "@message", "value": "metric0=0.3; metric1=1.3;"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 4"}, + {"field": "@message", "value": "metric1=1.4; metric2=2.4;"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 5"}, + {"field": "@message", "value": "metric0=0.5; metric1=1.5; metric2=2.5;"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 6"}, + {"field": "@message", "value": "metric1=0.6; metric0=0.6;"}, + ], + [ + {"field": "@message", "value": "metric0=0.7; "}, + ], +] + + +MULTIPLE_TABLES_METRICS_RESULT = [ + {"Timestamp": ["Test timestamp 0", "N/A"], "metric0": ["0.0", "0.7"]}, + { + "Timestamp": ["Test timestamp 1", "Test timestamp 3", "Test timestamp 6"], + "metric0": ["0.1", "0.3", "0.6"], + "metric1": ["1.1", "1.3", "0.6"], + }, + {"Timestamp": ["Test timestamp 2"], "metric0": ["0.2"], "metric2": ["2.2"]}, + {"Timestamp": ["Test timestamp 4"], "metric1": ["1.4"], "metric2": ["2.4"]}, + {"Timestamp": ["Test timestamp 5"], "metric0": ["0.5"], "metric1": ["1.5"], "metric2": ["2.5"]}, +] + + +@pytest.mark.parametrize( + "log_insights_results, metrics_results", + [ + ([], []), + (MALFORMED_METRICS_LOG_LINES, []), + (SIMPLE_METRICS_LOG_LINES, SIMPLE_METRICS_RESULT), + (MULTIPLE_TABLES_METRICS_LOG_LINES, MULTIPLE_TABLES_METRICS_RESULT), + ], +) +def test_get_all_metrics_complete_results(aws_session, log_insights_results, metrics_results): + logs_client_mock = Mock() + aws_session.create_logs_client.return_value = logs_client_mock + + logs_client_mock.start_query.return_value = {"queryId": "test"} + logs_client_mock.get_query_results.return_value = { + "status": "Complete", + "results": log_insights_results, + } + + fetcher = CwlInsightsMetricsFetcher(aws_session) 
+ result = fetcher.get_all_metrics_for_job("test_job") + assert result == metrics_results + + +def test_get_all_metrics_timeout(aws_session): + logs_client_mock = Mock() + aws_session.create_logs_client.return_value = logs_client_mock + + logs_client_mock.start_query.return_value = {"queryId": "test"} + logs_client_mock.get_query_results.return_value = {"status": "Queued"} + + fetcher = CwlInsightsMetricsFetcher(aws_session, 0.25, 0.5) + result = fetcher.get_all_metrics_for_job("test_job") + assert result == [] + + +@pytest.mark.xfail(raises=Exception) +def test_get_all_metrics_failed(aws_session): + logs_client_mock = Mock() + aws_session.create_logs_client.return_value = logs_client_mock + + logs_client_mock.start_query.return_value = {"queryId": "test"} + logs_client_mock.get_query_results.return_value = {"status": "Failed"} + + fetcher = CwlInsightsMetricsFetcher(aws_session) + fetcher.get_all_metrics_for_job("test_job") From b18a20226f066c2cc7696ffc4f01e4e509a7f1a0 Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Mon, 21 Jun 2021 16:21:53 -0700 Subject: [PATCH 2/8] incorporating feedback. --- src/braket/jobs/metrics/__init__.py | 1 + .../metrics/cwl_insights_metrics_fetcher.py | 117 ++++++------------ .../jobs/metrics/metrics_timeout_error.py | 18 +++ .../unit_tests/braket/aws/test_aws_session.py | 5 +- .../test_cwl_insights_metrics_fetcher.py | 6 +- 5 files changed, 61 insertions(+), 86 deletions(-) create mode 100644 src/braket/jobs/metrics/metrics_timeout_error.py diff --git a/src/braket/jobs/metrics/__init__.py b/src/braket/jobs/metrics/__init__.py index fb113bc16..9313dda69 100644 --- a/src/braket/jobs/metrics/__init__.py +++ b/src/braket/jobs/metrics/__init__.py @@ -12,3 +12,4 @@ # language governing permissions and limitations under the License. from braket.jobs.metrics.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher # noqa: F401 +from braket.jobs.metrics.metrics_timeout_error import MetricsTimeoutError # noqa: F401 diff --git a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py index 358f374c8..60eae9926 100644 --- a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py +++ b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py @@ -13,14 +13,17 @@ import re import time +from collections import defaultdict from logging import Logger, getLogger -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, FrozenSet, Iterator, List, Optional from braket.aws.aws_session import AwsSession +from braket.jobs.metrics.metrics_timeout_error import MetricsTimeoutError class CwlInsightsMetricsFetcher(object): + # TODO : Update this once we know the log group name for jobs. LOG_GROUP_NAME = "/aws/lambda/my-python-test-function" METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*(\d*\.?\d*)\s*;") TIMESTAMP = "Timestamp" @@ -58,12 +61,12 @@ def _get_metrics_results_sync(self, query_id: str) -> List[Any]: Returns: List[Any]: The results from CloudWatch insights 'GetQueryResults' operation. 
""" - start_time = time.time() - while (time.time() - start_time) < self._poll_timeout_seconds: + timeout_time = time.time() + self._poll_timeout_seconds + while time.time() < timeout_time: response = self._logs_client.get_query_results(queryId=query_id) query_status = response["status"] if query_status in ["Failed", "Cancelled"]: - raise Exception(f"Query {query_id} failed with status {query_status}") + raise MetricsTimeoutError(f"Query {query_id} failed with status {query_status}") elif query_status == "Complete": return response["results"] else: @@ -71,24 +74,6 @@ def _get_metrics_results_sync(self, query_id: str) -> List[Any]: self._logger.warning(f"Timed out waiting for query {query_id}.") return [] - @staticmethod - def _metrics_id_from_metrics(metrics: Dict[str, float]) -> int: - """ - Determines the semi-unique ID for a set of metrics that will represent the table for that - set of columns. The current implementation doesn't treat a difference in order as unique; - therefore, the same metrics output in various different orders will map to the same table. - - Args: - metrics (Dict[str, float]): The set of metrics. - - Returns: - int : The Metrics ID. - """ - metrics_id = 0 - for column_name in metrics.keys(): - metrics_id += hash(column_name) - return metrics_id - @staticmethod def _get_metrics_from_log_line_matches(all_matches: Iterator) -> Dict[str, float]: """ @@ -109,93 +94,63 @@ def _get_metrics_from_log_line_matches(all_matches: Iterator) -> Dict[str, float return metrics @staticmethod - def _get_timestamp_from_log_line(log_line: List[Dict[str, Any]]) -> str: + def _get_element_from_log_line( + element_name: str, log_line: List[Dict[str, Any]] + ) -> Optional[str]: """ - Finds and returns the timestamp of a log line from CloudWatch Insights results. + Finds and returns an element of a log line from CloudWatch Insights results. Args: + element_name (str): The element to find. log_line (List[Dict[str, Any]]): An iterator for RegEx matches on a log line. Returns: - str : The timestamp of the log line or 'N/A' if no timestamp is found. + Optional[str] : The value of the element with the element name, or None if no such + element is found. """ for element in log_line: - if element["field"] == "@timestamp": + if element["field"] == element_name: return element["value"] - return "N/A" + return None + @staticmethod def _add_metrics( - self, - metrics_id: int, + columns: FrozenSet[str], metrics: Dict[str, float], all_metrics: Dict[int, Dict[str, List[Any]]], ) -> None: """ - Adds the given metrics to the appropriate table in 'all_metrics'. If 'all_metrics' does not - currently have a table that represents the metrics, a new entry will be created to - represent the data. When the table entry is created, a "Timestamp" field is also added by - default, since all log metrics should have a timestamp. + Adds the given metrics to the appropriate table in 'all_metrics'. Args: - metrics_id (int): An ID to represent the given set of metrics. + columns (FrozenSet[str]): The set of column names representing the metrics. metrics (Dict[str, float]): A set of metrics in the format { : }. all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. 
""" - if metrics_id not in all_metrics: - metrics_table = {} - for column_name in metrics.keys(): - metrics_table[column_name] = [] - metrics_table[self.TIMESTAMP] = [] - all_metrics[metrics_id] = metrics_table - metrics_table = all_metrics[metrics_id] + metrics_table = all_metrics[columns] for column_name in metrics.keys(): metrics_table[column_name].append(metrics[column_name]) def _parse_metrics_from_message( - self, message: str, all_metrics: Dict[int, Dict[str, List[Any]]] - ) -> Optional[int]: + self, timestamp: str, message: str, all_metrics: Dict[int, Dict[str, List[Any]]] + ) -> None: """ Parses a line from CloudWatch Logs to find all the metrics that have been logged - on that line. Any found metrics will be added to 'all_metrics'. + on that line. Any found metrics will be added to 'all_metrics'. The timestamp is + also added to match the corresponding values in 'all_metrics'. Args: + timestamp (str): A formatted string representing the timestamp for any found metrics. message (str): A log line from CloudWatch Logs. all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. - - Returns: - int : The ID to represent the given set of logged metrics, or None if no metrics - are found in the message. """ all_matches = self.METRICS_DEFINITIONS.finditer(message) metrics = self._get_metrics_from_log_line_matches(all_matches) if not metrics: return None - metrics_id = self._metrics_id_from_metrics(metrics) - self._add_metrics(metrics_id, metrics, all_metrics) - return metrics_id - - def _parse_metrics_from_result_entry( - self, result_entry: List[Any], timestamp: str, all_metrics: Dict[int, Dict[str, List[Any]]] - ) -> None: - """ - Finds the actual log line containing metrics from a given CloudWatch Insights result entry, - and adds them to 'all_metrics'. The timestamp is added to match the corresponding values in - 'all_metrics'. - - Args: - result_entry (List[Any]): A structured result from calling CloudWatch Insights to get - logs that contain metrics. A single entry will contain the message - (the actual line logged to output), the timestamp (generated by CloudWatch Logs), - and other metadata that we (currently) do not use. - timestamp (str): A formatted string representing the timestamp for any found metrics. - all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. - """ - for element in result_entry: - if element["field"] == "@message": - metrics_id = self._parse_metrics_from_message(element["value"], all_metrics) - if metrics_id is not None: - all_metrics[metrics_id][self.TIMESTAMP].append(timestamp) - break + columns = frozenset(metrics.keys()) + self._add_metrics(columns, metrics, all_metrics) + all_metrics[columns][self.TIMESTAMP].append(timestamp or "N/A") def _parse_log_line( self, result_entry: List[Any], all_metrics: Dict[int, Dict[str, List[Any]]] @@ -211,8 +166,10 @@ def _parse_log_line( and other metadata that we (currently) do not use. all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. 
""" - timestamp = self._get_timestamp_from_log_line(result_entry) - self._parse_metrics_from_result_entry(result_entry, timestamp, all_metrics) + message = self._get_element_from_log_line("@message", result_entry) + if message: + timestamp = self._get_element_from_log_line("@timestamp", result_entry) + self._parse_metrics_from_message(timestamp, message, all_metrics) def _parse_log_query_results(self, results: List[Any]) -> Dict[int, Dict[str, List[Any]]]: """ @@ -229,7 +186,7 @@ def _parse_log_query_results(self, results: List[Any]) -> Dict[int, Dict[str, Li Each table will have a set of metrics, indexed by the column name. The entries are not sorted. """ - all_metrics = {} + all_metrics = defaultdict(lambda: defaultdict(list)) for result in results: self._parse_log_line(result, all_metrics) return all_metrics @@ -277,8 +234,4 @@ def get_all_metrics_for_job( metric_data = self._parse_log_query_results(results) - metric_data_list = [] - for metric_graph in metric_data.keys(): - metric_data_list.append(metric_data[metric_graph]) - - return metric_data_list + return list(metric_data.values()) diff --git a/src/braket/jobs/metrics/metrics_timeout_error.py b/src/braket/jobs/metrics/metrics_timeout_error.py new file mode 100644 index 000000000..df6e9b195 --- /dev/null +++ b/src/braket/jobs/metrics/metrics_timeout_error.py @@ -0,0 +1,18 @@ +# Copyright 2021-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ + +class MetricsTimeoutError(Exception): + """Raised when retrieving metrics times out.""" + + pass diff --git a/test/unit_tests/braket/aws/test_aws_session.py b/test/unit_tests/braket/aws/test_aws_session.py index 000948e11..9ecf3655c 100644 --- a/test/unit_tests/braket/aws/test_aws_session.py +++ b/test/unit_tests/braket/aws/test_aws_session.py @@ -132,9 +132,10 @@ def test_retrieve_s3_object_body_client_error(boto_session): def test_create_logs_client(boto_session): - aws_session = AwsSession(boto_session=boto_session) + config = Mock() + aws_session = AwsSession(boto_session=boto_session, config=config) aws_session.create_logs_client() - boto_session.client.assert_called_with("logs", config=None) + boto_session.client.assert_called_with("logs", config=config) def test_get_device(boto_session): diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py index acc7a6edd..430b46d50 100644 --- a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py @@ -15,7 +15,7 @@ import pytest -from braket.jobs.metrics import CwlInsightsMetricsFetcher +from braket.jobs.metrics import CwlInsightsMetricsFetcher, MetricsTimeoutError @pytest.fixture @@ -123,6 +123,8 @@ def aws_session(): (MALFORMED_METRICS_LOG_LINES, []), (SIMPLE_METRICS_LOG_LINES, SIMPLE_METRICS_RESULT), (MULTIPLE_TABLES_METRICS_LOG_LINES, MULTIPLE_TABLES_METRICS_RESULT), + # TODO: https://app.asana.com/0/1199668788990775/1200502190825620 + # We should also test some real-world data, once we have it. ], ) def test_get_all_metrics_complete_results(aws_session, log_insights_results, metrics_results): @@ -152,7 +154,7 @@ def test_get_all_metrics_timeout(aws_session): assert result == [] -@pytest.mark.xfail(raises=Exception) +@pytest.mark.xfail(raises=MetricsTimeoutError) def test_get_all_metrics_failed(aws_session): logs_client_mock = Mock() aws_session.create_logs_client.return_value = logs_client_mock From 9b9c059f5d1210f62241a33efe32fc6eacc6386c Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Tue, 22 Jun 2021 09:40:18 -0700 Subject: [PATCH 3/8] Modifying the doc for returning the CW client, and updating copyright range. --- src/braket/aws/aws_session.py | 4 ++-- src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py | 2 +- src/braket/jobs/metrics/metrics_timeout_error.py | 2 +- .../braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/braket/aws/aws_session.py b/src/braket/aws/aws_session.py index 2097d7a95..ce9e19bb9 100644 --- a/src/braket/aws/aws_session.py +++ b/src/braket/aws/aws_session.py @@ -137,12 +137,12 @@ def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str: obj = s3.Object(s3_bucket, s3_object_key) return obj.get()["Body"].read().decode("utf-8") - def create_logs_client(self) -> Any: + def create_logs_client(self) -> "boto3.session.Session.client": """ Create a CloudWatch Logs boto client. Returns: - Any: The CloudWatch Logs boto client. + 'boto3.session.Session.client': The CloudWatch Logs boto client. 
""" return self.boto_session.client("logs", config=self._config) diff --git a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py index 60eae9926..b41dedccd 100644 --- a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py +++ b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py @@ -1,4 +1,4 @@ -# Copyright 2021-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of diff --git a/src/braket/jobs/metrics/metrics_timeout_error.py b/src/braket/jobs/metrics/metrics_timeout_error.py index df6e9b195..774c284f2 100644 --- a/src/braket/jobs/metrics/metrics_timeout_error.py +++ b/src/braket/jobs/metrics/metrics_timeout_error.py @@ -1,4 +1,4 @@ -# Copyright 2021-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py index 430b46d50..aef5ff05b 100644 --- a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py @@ -1,4 +1,4 @@ -# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of From 5f02948deeb90c2cb65fad39b69029b149a4b364 Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Tue, 22 Jun 2021 09:42:05 -0700 Subject: [PATCH 4/8] Forgot two copyright notices. --- src/braket/jobs/__init__.py | 2 +- src/braket/jobs/metrics/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/braket/jobs/__init__.py b/src/braket/jobs/__init__.py index f4243de11..654905217 100644 --- a/src/braket/jobs/__init__.py +++ b/src/braket/jobs/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of diff --git a/src/braket/jobs/metrics/__init__.py b/src/braket/jobs/metrics/__init__.py index 9313dda69..24747a1b6 100644 --- a/src/braket/jobs/metrics/__init__.py +++ b/src/braket/jobs/metrics/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of From cbae744057e753ecf666b7d2f6b00285cfbb4565 Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Wed, 23 Jun 2021 16:19:14 -0700 Subject: [PATCH 5/8] More fixes per PR comments. 
--- src/braket/jobs/metrics/__init__.py | 2 +- .../metrics/cwl_insights_metrics_fetcher.py | 61 ++++++++++--------- ...metrics_timeout_error.py => exceptions.py} | 0 .../test_cwl_insights_metrics_fetcher.py | 34 +++++++---- 4 files changed, 55 insertions(+), 42 deletions(-) rename src/braket/jobs/metrics/{metrics_timeout_error.py => exceptions.py} (100%) diff --git a/src/braket/jobs/metrics/__init__.py b/src/braket/jobs/metrics/__init__.py index 24747a1b6..93cd5f23d 100644 --- a/src/braket/jobs/metrics/__init__.py +++ b/src/braket/jobs/metrics/__init__.py @@ -12,4 +12,4 @@ # language governing permissions and limitations under the License. from braket.jobs.metrics.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher # noqa: F401 -from braket.jobs.metrics.metrics_timeout_error import MetricsTimeoutError # noqa: F401 +from braket.jobs.metrics.exceptions import MetricsTimeoutError # noqa: F401 diff --git a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py index b41dedccd..5dfd10abb 100644 --- a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py +++ b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py @@ -18,14 +18,14 @@ from typing import Any, Dict, FrozenSet, Iterator, List, Optional from braket.aws.aws_session import AwsSession -from braket.jobs.metrics.metrics_timeout_error import MetricsTimeoutError +from braket.jobs.metrics.exceptions import MetricsTimeoutError class CwlInsightsMetricsFetcher(object): # TODO : Update this once we know the log group name for jobs. LOG_GROUP_NAME = "/aws/lambda/my-python-test-function" - METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*(\d*\.?\d*)\s*;") + METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*([^;]+)\s*;") TIMESTAMP = "Timestamp" QUERY_DEFAULT_JOB_DURATION = 3 * 60 * 60 @@ -74,25 +74,6 @@ def _get_metrics_results_sync(self, query_id: str) -> List[Any]: self._logger.warning(f"Timed out waiting for query {query_id}.") return [] - @staticmethod - def _get_metrics_from_log_line_matches(all_matches: Iterator) -> Dict[str, float]: - """ - Converts matches from a RegEx to a set of metrics. - - Args: - all_matches (Iterator): An iterator for RegEx matches on a log line. - - Returns: - Dict[str, float]: The set of metrics found by the RegEx. The result will be in the - format { : }. This implies that multiple metrics with - the same name will be deduped to the last instance of that metric. - """ - metrics = {} - for match in all_matches: - subgroup = match.groups() - metrics[subgroup[0]] = subgroup[1] - return metrics - @staticmethod def _get_element_from_log_line( element_name: str, log_line: List[Dict[str, Any]] @@ -108,10 +89,9 @@ def _get_element_from_log_line( Optional[str] : The value of the element with the element name, or None if no such element is found. """ - for element in log_line: - if element["field"] == element_name: - return element["value"] - return None + return next( + (element["value"] for element in log_line if element["field"] == element_name), None + ) @staticmethod def _add_metrics( @@ -131,6 +111,28 @@ def _add_metrics( for column_name in metrics.keys(): metrics_table[column_name].append(metrics[column_name]) + def _get_metrics_from_log_line_matches(self, all_matches: Iterator) -> Dict[str, float]: + """ + Converts matches from a RegEx to a set of metrics. + + Args: + all_matches (Iterator): An iterator for RegEx matches on a log line. + + Returns: + Dict[str, float]: The set of metrics found by the RegEx. The result will be in the + format { : }. 
This implies that multiple metrics with + the same name will be deduped to the last instance of that metric. + """ + metrics = {} + for match in all_matches: + subgroup = match.groups() + value = subgroup[1] + try: + metrics[subgroup[0]] = float(value) + except ValueError: + self._logger.warning(f"Unable to convert value {value} to a float.") + return metrics + def _parse_metrics_from_message( self, timestamp: str, message: str, all_metrics: Dict[int, Dict[str, List[Any]]] ) -> None: @@ -147,21 +149,21 @@ def _parse_metrics_from_message( all_matches = self.METRICS_DEFINITIONS.finditer(message) metrics = self._get_metrics_from_log_line_matches(all_matches) if not metrics: - return None + return columns = frozenset(metrics.keys()) self._add_metrics(columns, metrics, all_metrics) all_metrics[columns][self.TIMESTAMP].append(timestamp or "N/A") def _parse_log_line( - self, result_entry: List[Any], all_metrics: Dict[int, Dict[str, List[Any]]] + self, result_entry: List[Dict[str, Any]], all_metrics: Dict[int, Dict[str, List[Any]]] ) -> None: """ Parses the single entry from CloudWatch Insights results and adds any metrics it finds to 'all_metrics', along with the timestamp for the entry. Args: - result_entry (List[Any]): A structured result from calling CloudWatch Insights to get - logs that contain metrics. A single entry will contain the message + result_entry (List[Dict[str, Any]]): A structured result from calling CloudWatch + Insights to get logs that contain metrics. A single entry will contain the message (the actual line logged to output), the timestamp (generated by CloudWatch Logs), and other metadata that we (currently) do not use. all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. @@ -226,6 +228,7 @@ def get_all_metrics_for_job( startTime=query_start_time, endTime=query_end_time, queryString=query, + limit=10000, ) query_id = response["queryId"] diff --git a/src/braket/jobs/metrics/metrics_timeout_error.py b/src/braket/jobs/metrics/exceptions.py similarity index 100% rename from src/braket/jobs/metrics/metrics_timeout_error.py rename to src/braket/jobs/metrics/exceptions.py diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py index aef5ff05b..2c5a1e4f9 100644 --- a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py @@ -47,23 +47,33 @@ def aws_session(): SIMPLE_METRICS_LOG_LINES = [ [ {"field": "@timestamp", "value": "Test timestamp 0"}, - {"field": "@message", "value": "metric0=0.0; metric1=1.0; metric2=2.0"}, + {"field": "@message", "value": "metric0=0.0; metric1=1.0; metric2=2.0;"}, ], [ {"field": "@timestamp", "value": "Test timestamp 1"}, - {"field": "@message", "value": "metric0=0.1; metric1=1.1; metric2=2.1"}, + {"field": "@message", "value": "metric0=0.1; metric1=1.1; metric2=2.1;"}, ], [ {"field": "@timestamp", "value": "Test timestamp 2"}, - {"field": "@message", "value": "metric0=0.2; metric1=1.2; metric2=2.2"}, + {"field": "@message", "value": "metric0=0.2; metric1=1.2; metric2=2.2;"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 3"}, + {"field": "@message", "value": "metric0=-0.4; metric1=3.14e-22; metric2=3.14E22;"}, ], ] SIMPLE_METRICS_RESULT = [ { - "Timestamp": ["Test timestamp 0", "Test timestamp 1", "Test timestamp 2"], - "metric0": ["0.0", "0.1", "0.2"], - "metric1": ["1.0", "1.1", "1.2"], + "Timestamp": [ + "Test timestamp 
0", + "Test timestamp 1", + "Test timestamp 2", + "Test timestamp 3", + ], + "metric0": [0.0, 0.1, 0.2, -0.4], + "metric1": [1.0, 1.1, 1.2, 3.14e-22], + "metric2": [2.0, 2.1, 2.2, 3.14e22], } ] @@ -104,15 +114,15 @@ def aws_session(): MULTIPLE_TABLES_METRICS_RESULT = [ - {"Timestamp": ["Test timestamp 0", "N/A"], "metric0": ["0.0", "0.7"]}, + {"Timestamp": ["Test timestamp 0", "N/A"], "metric0": [0.0, 0.7]}, { "Timestamp": ["Test timestamp 1", "Test timestamp 3", "Test timestamp 6"], - "metric0": ["0.1", "0.3", "0.6"], - "metric1": ["1.1", "1.3", "0.6"], + "metric0": [0.1, 0.3, 0.6], + "metric1": [1.1, 1.3, 0.6], }, - {"Timestamp": ["Test timestamp 2"], "metric0": ["0.2"], "metric2": ["2.2"]}, - {"Timestamp": ["Test timestamp 4"], "metric1": ["1.4"], "metric2": ["2.4"]}, - {"Timestamp": ["Test timestamp 5"], "metric0": ["0.5"], "metric1": ["1.5"], "metric2": ["2.5"]}, + {"Timestamp": ["Test timestamp 2"], "metric0": [0.2], "metric2": [2.2]}, + {"Timestamp": ["Test timestamp 4"], "metric1": [1.4], "metric2": [2.4]}, + {"Timestamp": ["Test timestamp 5"], "metric0": [0.5], "metric1": [1.5], "metric2": [2.5]}, ] From 345ae053946afe98d6aa62387272824f11b596ed Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Thu, 24 Jun 2021 14:14:13 -0700 Subject: [PATCH 6/8] Adding a metrics fetcher that gets logs directly from CloudWatch logs. --- src/braket/jobs/metrics/__init__.py | 4 +- .../metrics/cwl_insights_metrics_fetcher.py | 117 +++----------- src/braket/jobs/metrics/cwl_metrics.py | 97 +++++++++++ .../jobs/metrics/cwl_metrics_fetcher.py | 151 ++++++++++++++++++ src/braket/jobs/metrics/exceptions.py | 4 +- .../test_cwl_insights_metrics_fetcher.py | 138 ++++------------ .../braket/jobs/metrics/test_cwl_metrics.py | 102 ++++++++++++ .../jobs/metrics/test_cwl_metrics_fetcher.py | 135 ++++++++++++++++ 8 files changed, 548 insertions(+), 200 deletions(-) create mode 100644 src/braket/jobs/metrics/cwl_metrics.py create mode 100644 src/braket/jobs/metrics/cwl_metrics_fetcher.py create mode 100644 test/unit_tests/braket/jobs/metrics/test_cwl_metrics.py create mode 100644 test/unit_tests/braket/jobs/metrics/test_cwl_metrics_fetcher.py diff --git a/src/braket/jobs/metrics/__init__.py b/src/braket/jobs/metrics/__init__.py index 93cd5f23d..0af7c1212 100644 --- a/src/braket/jobs/metrics/__init__.py +++ b/src/braket/jobs/metrics/__init__.py @@ -12,4 +12,6 @@ # language governing permissions and limitations under the License. 
from braket.jobs.metrics.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher # noqa: F401 -from braket.jobs.metrics.exceptions import MetricsTimeoutError # noqa: F401 +from braket.jobs.metrics.cwl_metrics import CwlMetrics # noqa: F401 +from braket.jobs.metrics.cwl_metrics_fetcher import CwlMetricsFetcher # noqa: F401 +from braket.jobs.metrics.exceptions import MetricsRetrievalError # noqa: F401 diff --git a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py index 5dfd10abb..3454d25b3 100644 --- a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py +++ b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py @@ -13,20 +13,18 @@ import re import time -from collections import defaultdict from logging import Logger, getLogger -from typing import Any, Dict, FrozenSet, Iterator, List, Optional +from typing import Any, Dict, List, Optional from braket.aws.aws_session import AwsSession -from braket.jobs.metrics.exceptions import MetricsTimeoutError +from braket.jobs.metrics.cwl_metrics import CwlMetrics +from braket.jobs.metrics.exceptions import MetricsRetrievalError class CwlInsightsMetricsFetcher(object): # TODO : Update this once we know the log group name for jobs. LOG_GROUP_NAME = "/aws/lambda/my-python-test-function" - METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*([^;]+)\s*;") - TIMESTAMP = "Timestamp" QUERY_DEFAULT_JOB_DURATION = 3 * 60 * 60 def __init__( @@ -51,29 +49,6 @@ def __init__( self._logger = logger self._logs_client = aws_session.create_logs_client() - def _get_metrics_results_sync(self, query_id: str) -> List[Any]: - """ - Waits for the CloudWatch Insights query to complete and returns all the results. - - Args: - query_id (str): CloudWatch Insights query ID. - - Returns: - List[Any]: The results from CloudWatch insights 'GetQueryResults' operation. - """ - timeout_time = time.time() + self._poll_timeout_seconds - while time.time() < timeout_time: - response = self._logs_client.get_query_results(queryId=query_id) - query_status = response["status"] - if query_status in ["Failed", "Cancelled"]: - raise MetricsTimeoutError(f"Query {query_id} failed with status {query_status}") - elif query_status == "Complete": - return response["results"] - else: - time.sleep(self._poll_interval_seconds) - self._logger.warning(f"Timed out waiting for query {query_id}.") - return [] - @staticmethod def _get_element_from_log_line( element_name: str, log_line: List[Dict[str, Any]] @@ -93,70 +68,30 @@ def _get_element_from_log_line( (element["value"] for element in log_line if element["field"] == element_name), None ) - @staticmethod - def _add_metrics( - columns: FrozenSet[str], - metrics: Dict[str, float], - all_metrics: Dict[int, Dict[str, List[Any]]], - ) -> None: - """ - Adds the given metrics to the appropriate table in 'all_metrics'. - - Args: - columns (FrozenSet[str]): The set of column names representing the metrics. - metrics (Dict[str, float]): A set of metrics in the format { : }. - all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. - """ - metrics_table = all_metrics[columns] - for column_name in metrics.keys(): - metrics_table[column_name].append(metrics[column_name]) - - def _get_metrics_from_log_line_matches(self, all_matches: Iterator) -> Dict[str, float]: + def _get_metrics_results_sync(self, query_id: str) -> List[Any]: """ - Converts matches from a RegEx to a set of metrics. + Waits for the CloudWatch Insights query to complete and returns all the results. 
Args: - all_matches (Iterator): An iterator for RegEx matches on a log line. + query_id (str): CloudWatch Insights query ID. Returns: - Dict[str, float]: The set of metrics found by the RegEx. The result will be in the - format { : }. This implies that multiple metrics with - the same name will be deduped to the last instance of that metric. - """ - metrics = {} - for match in all_matches: - subgroup = match.groups() - value = subgroup[1] - try: - metrics[subgroup[0]] = float(value) - except ValueError: - self._logger.warning(f"Unable to convert value {value} to a float.") - return metrics - - def _parse_metrics_from_message( - self, timestamp: str, message: str, all_metrics: Dict[int, Dict[str, List[Any]]] - ) -> None: + List[Any]: The results from CloudWatch insights 'GetQueryResults' operation. """ - Parses a line from CloudWatch Logs to find all the metrics that have been logged - on that line. Any found metrics will be added to 'all_metrics'. The timestamp is - also added to match the corresponding values in 'all_metrics'. + timeout_time = time.time() + self._poll_timeout_seconds + while time.time() < timeout_time: + response = self._logs_client.get_query_results(queryId=query_id) + query_status = response["status"] + if query_status in ["Failed", "Cancelled"]: + raise MetricsRetrievalError(f"Query {query_id} failed with status {query_status}") + elif query_status == "Complete": + return response["results"] + else: + time.sleep(self._poll_interval_seconds) + self._logger.warning(f"Timed out waiting for query {query_id}.") + return [] - Args: - timestamp (str): A formatted string representing the timestamp for any found metrics. - message (str): A log line from CloudWatch Logs. - all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics. - """ - all_matches = self.METRICS_DEFINITIONS.finditer(message) - metrics = self._get_metrics_from_log_line_matches(all_matches) - if not metrics: - return - columns = frozenset(metrics.keys()) - self._add_metrics(columns, metrics, all_metrics) - all_metrics[columns][self.TIMESTAMP].append(timestamp or "N/A") - - def _parse_log_line( - self, result_entry: List[Dict[str, Any]], all_metrics: Dict[int, Dict[str, List[Any]]] - ) -> None: + def _parse_log_line(self, result_entry: List[Dict[str, Any]], metrics: CwlMetrics) -> None: """ Parses the single entry from CloudWatch Insights results and adds any metrics it finds to 'all_metrics', along with the timestamp for the entry. @@ -171,7 +106,7 @@ def _parse_log_line( message = self._get_element_from_log_line("@message", result_entry) if message: timestamp = self._get_element_from_log_line("@timestamp", result_entry) - self._parse_metrics_from_message(timestamp, message, all_metrics) + metrics.add_metrics_from_log_message(timestamp, message) def _parse_log_query_results(self, results: List[Any]) -> Dict[int, Dict[str, List[Any]]]: """ @@ -188,10 +123,10 @@ def _parse_log_query_results(self, results: List[Any]) -> Dict[int, Dict[str, Li Each table will have a set of metrics, indexed by the column name. The entries are not sorted. 
""" - all_metrics = defaultdict(lambda: defaultdict(list)) + metrics = CwlMetrics() for result in results: - self._parse_log_line(result, all_metrics) - return all_metrics + self._parse_log_line(result, metrics) + return metrics.get_metric_data_as_list() def get_all_metrics_for_job( self, job_name: str, job_start_time: int = None, job_end_time: int = None @@ -235,6 +170,4 @@ def get_all_metrics_for_job( results = self._get_metrics_results_sync(query_id) - metric_data = self._parse_log_query_results(results) - - return list(metric_data.values()) + return self._parse_log_query_results(results) diff --git a/src/braket/jobs/metrics/cwl_metrics.py b/src/braket/jobs/metrics/cwl_metrics.py new file mode 100644 index 000000000..5b14acc47 --- /dev/null +++ b/src/braket/jobs/metrics/cwl_metrics.py @@ -0,0 +1,97 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import re +from collections import defaultdict +from logging import Logger, getLogger +from typing import Dict, FrozenSet, Iterator + + +class CwlMetrics(object): + + METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*([^;]+)\s*;") + TIMESTAMP = "Timestamp" + + def __init__( + self, + logger: Logger = getLogger(__name__), + ): + self._logger = logger + self.metric_tables = defaultdict(lambda: defaultdict(list)) + + def _add_metrics_to_appropriate_table( + self, + columns: FrozenSet[str], + metrics: Dict[str, float], + ) -> None: + """ + Adds the given metrics to the appropriate table. + + Args: + columns (FrozenSet[str]): The set of column names representing the metrics. + metrics (Dict[str, float]): A set of metrics in the format { : }. + """ + metrics_table = self.metric_tables[columns] + for column_name in metrics.keys(): + metrics_table[column_name].append(metrics[column_name]) + + def _get_metrics_from_log_line_matches(self, all_matches: Iterator) -> Dict[str, float]: + """ + Converts matches from a RegEx to a set of metrics. + + Args: + all_matches (Iterator): An iterator for RegEx matches on a log line. + + Returns: + Dict[str, float]: The set of metrics found by the RegEx. The result will be in the + format { : }. This implies that multiple metrics with + the same name will be deduped to the last instance of that metric. + """ + metrics = {} + for match in all_matches: + subgroup = match.groups() + value = subgroup[1] + try: + metrics[subgroup[0]] = float(value) + except ValueError: + self._logger.warning(f"Unable to convert value {value} to a float.") + return metrics + + def add_metrics_from_log_message(self, timestamp: str, message: str) -> None: + """ + Parses a line from CloudWatch Logs adds all the metrics that have been logged + on that line. The timestamp is also added to match the corresponding values. + + Args: + timestamp (str): A formatted string representing the timestamp for any found metrics. + message (str): A log line from CloudWatch Logs. 
+ """ + if not message: + return + all_matches = self.METRICS_DEFINITIONS.finditer(message) + parsed_metrics = self._get_metrics_from_log_line_matches(all_matches) + if not parsed_metrics: + return + columns = frozenset(parsed_metrics.keys()) + self._add_metrics_to_appropriate_table(columns, parsed_metrics) + self.metric_tables[columns][self.TIMESTAMP].append(timestamp or "N/A") + + def get_metric_data_as_list(self): + """ + Gets all the metrics data for all tables, as a list. + + Returns: + List[Dict[str, List[Any]]] : The list of all tables. Each table will have a set + of metrics, indexed by the column name. + """ + return list(self.metric_tables.values()) diff --git a/src/braket/jobs/metrics/cwl_metrics_fetcher.py b/src/braket/jobs/metrics/cwl_metrics_fetcher.py new file mode 100644 index 000000000..5f92c3eb0 --- /dev/null +++ b/src/braket/jobs/metrics/cwl_metrics_fetcher.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import time +from logging import Logger, getLogger +from typing import Any, Dict, List + +from braket.aws.aws_session import AwsSession +from braket.jobs.metrics.cwl_metrics import CwlMetrics + + +class CwlMetricsFetcher(object): + + # TODO : Update this once we know the log group name for jobs. + LOG_GROUP_NAME = "/aws/lambda/my-python-test-function" + + def __init__( + self, + aws_session: AwsSession, + poll_timeout_seconds: float = 10, + logger: Logger = getLogger(__name__), + ): + """ + Args: + aws_session (AwsSession): AwsSession to connect to AWS with. + poll_timeout_seconds (float): The polling timeout for retrieving the metrics, + in seconds. Default: 10 seconds. + logger (Logger): Logger object with which to write logs, such as task statuses + while waiting for task to be in a terminal state. Default is `getLogger(__name__)` + """ + self._poll_timeout_seconds = poll_timeout_seconds + self._logger = logger + self._logs_client = aws_session.create_logs_client() + + @staticmethod + def _is_metrics_message(message): + """ + Returns true if a given message is designated as containing Metrics. + + Args: + message (str): The message to check. + + Returns: + True if the given message is designated as containing Metrics, False otherwise. + """ + if message: + return "Metrics -" in message + return False + + def _get_metrics_from_log_stream( + self, + stream_name: str, + timeout_time: float, + metrics: CwlMetrics, + ) -> None: + """ + Synchronously retrieves the algorithm metrics logged in a given job log stream. + + Args: + stream_name (str): The name of the log stream. + timeout_time (float) : We stop getting metrics if the current time is beyond + the timeout time. + metrics (CwlMetrics) : The metrics object to add the metrics to. 
+ + Returns: + None + """ + kwargs = { + "logGroupName": self.LOG_GROUP_NAME, + "logStreamName": stream_name, + "startFromHead": True, + "limit": 10000, + } + + previous_token = None + while time.time() < timeout_time: + response = self._logs_client.get_log_events(**kwargs) + for event in response.get("events"): + message = event.get("message") + if self._is_metrics_message(message): + metrics.add_metrics_from_log_message(event.get("timestamp"), message) + next_token = response.get("nextForwardToken") + if not next_token or next_token == previous_token: + return + previous_token = next_token + kwargs["nextToken"] = next_token + self._logger.warning("Timed out waiting for all metrics. Data may be incomplete.") + + def _get_log_streams_for_job(self, job_name: str, timeout_time: float) -> List[str]: + """ + Retrieves the list of log streams relevant to a job. + + Args: + job_name (str): The name of the job. + timeout_time (float) : We stop getting metrics if the current time is beyond + the timeout time. + Returns: + List[str] : a list of log stream names for the given job. + """ + kwargs = { + "logGroupName": self.LOG_GROUP_NAME, + "logStreamNamePrefix": job_name + "/algo-", + } + log_streams = [] + while time.time() < timeout_time: + response = self._logs_client.describe_log_streams(**kwargs) + streams = response.get("logStreams") + if streams: + for stream in streams: + name = stream.get("logStreamName") + if name: + log_streams.append(name) + next_token = response.get("nextToken") + if not next_token: + return log_streams + kwargs["nextToken"] = next_token + self._logger.warning("Timed out waiting for all metrics. Data may be incomplete.") + return log_streams + + def get_all_metrics_for_job(self, job_name: str) -> List[Dict[str, List[Any]]]: + """ + Synchronously retrieves all the algorithm metrics logged by a given Job. + + Args: + job_name (str): The name of the Job. The name must be exact to ensure only the relevant + metrics are retrieved. + + Returns: + List[Dict[str, List[Any]]] : The list of all metrics that can be found in the + CloudWatch Logs results. Each item in the list can be thought of as a separate + table. Each table will have a set of metrics, indexed by the column name. + """ + timeout_time = time.time() + self._poll_timeout_seconds + + metrics = CwlMetrics() + + log_streams = self._get_log_streams_for_job(job_name, timeout_time) + for log_stream in log_streams: + self._get_metrics_from_log_stream(log_stream, timeout_time, metrics) + + return metrics.get_metric_data_as_list() diff --git a/src/braket/jobs/metrics/exceptions.py b/src/braket/jobs/metrics/exceptions.py index 774c284f2..677a3a447 100644 --- a/src/braket/jobs/metrics/exceptions.py +++ b/src/braket/jobs/metrics/exceptions.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. -class MetricsTimeoutError(Exception): - """Raised when retrieving metrics times out.""" +class MetricsRetrievalError(Exception): + """Raised when retrieving metrics fails.""" pass diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py index 2c5a1e4f9..73c5c1164 100644 --- a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py @@ -11,11 +11,11 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-from unittest.mock import Mock +from unittest.mock import Mock, call, patch import pytest -from braket.jobs.metrics import CwlInsightsMetricsFetcher, MetricsTimeoutError +from braket.jobs.metrics import CwlInsightsMetricsFetcher, MetricsRetrievalError @pytest.fixture @@ -24,132 +24,59 @@ def aws_session(): return _aws_session -MALFORMED_METRICS_LOG_LINES = [ +EXAMPLE_METRICS_LOG_LINES = [ [ {"field": "@timestamp", "value": "Test timestamp 0"}, - {"field": "@message", "value": ""}, + {"field": "@message", "value": "Test value 0"}, ], [ {"field": "@timestamp", "value": "Test timestamp 1"}, - {"field": "@message", "value": "Test Test"}, + {"field": "@message", "value": "Test value 1"}, ], [ {"field": "@timestamp", "value": "Test timestamp 2"}, - {"field": "@message", "value": "metric0=not_a_number;"}, ], - [{"field": "@timestamp", "value": "Test timestamp 0"}], [ - {"field": "@unknown", "value": "Unknown"}, + {"field": "@message", "value": "Test value 3"}, ], + [], ] - -SIMPLE_METRICS_LOG_LINES = [ - [ - {"field": "@timestamp", "value": "Test timestamp 0"}, - {"field": "@message", "value": "metric0=0.0; metric1=1.0; metric2=2.0;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 1"}, - {"field": "@message", "value": "metric0=0.1; metric1=1.1; metric2=2.1;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 2"}, - {"field": "@message", "value": "metric0=0.2; metric1=1.2; metric2=2.2;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 3"}, - {"field": "@message", "value": "metric0=-0.4; metric1=3.14e-22; metric2=3.14E22;"}, - ], -] - -SIMPLE_METRICS_RESULT = [ - { - "Timestamp": [ - "Test timestamp 0", - "Test timestamp 1", - "Test timestamp 2", - "Test timestamp 3", - ], - "metric0": [0.0, 0.1, 0.2, -0.4], - "metric1": [1.0, 1.1, 1.2, 3.14e-22], - "metric2": [2.0, 2.1, 2.2, 3.14e22], - } -] - - -MULTIPLE_TABLES_METRICS_LOG_LINES = [ - [ - {"field": "@timestamp", "value": "Test timestamp 0"}, - {"field": "@message", "value": "metric0=0.0;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 1"}, - {"field": "@message", "value": "metric0=0.1; metric1=1.1;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 2"}, - {"field": "@message", "value": "metric0=0.2; metric2=2.2;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 3"}, - {"field": "@message", "value": "metric0=0.3; metric1=1.3;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 4"}, - {"field": "@message", "value": "metric1=1.4; metric2=2.4;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 5"}, - {"field": "@message", "value": "metric0=0.5; metric1=1.5; metric2=2.5;"}, - ], - [ - {"field": "@timestamp", "value": "Test timestamp 6"}, - {"field": "@message", "value": "metric1=0.6; metric0=0.6;"}, - ], - [ - {"field": "@message", "value": "metric0=0.7; "}, - ], +EXPECTED_CALL_LIST = [ + call("Test timestamp 0", "Test value 0"), + call("Test timestamp 1", "Test value 1"), + call(None, "Test value 3"), ] -MULTIPLE_TABLES_METRICS_RESULT = [ - {"Timestamp": ["Test timestamp 0", "N/A"], "metric0": [0.0, 0.7]}, - { - "Timestamp": ["Test timestamp 1", "Test timestamp 3", "Test timestamp 6"], - "metric0": [0.1, 0.3, 0.6], - "metric1": [1.1, 1.3, 0.6], - }, - {"Timestamp": ["Test timestamp 2"], "metric0": [0.2], "metric2": [2.2]}, - {"Timestamp": ["Test timestamp 4"], "metric1": [1.4], "metric2": [2.4]}, - {"Timestamp": ["Test timestamp 5"], "metric0": [0.5], "metric1": [1.5], "metric2": [2.5]}, -] - - -@pytest.mark.parametrize( - "log_insights_results, 
metrics_results", - [ - ([], []), - (MALFORMED_METRICS_LOG_LINES, []), - (SIMPLE_METRICS_LOG_LINES, SIMPLE_METRICS_RESULT), - (MULTIPLE_TABLES_METRICS_LOG_LINES, MULTIPLE_TABLES_METRICS_RESULT), - # TODO: https://app.asana.com/0/1199668788990775/1200502190825620 - # We should also test some real-world data, once we have it. - ], -) -def test_get_all_metrics_complete_results(aws_session, log_insights_results, metrics_results): +@patch("braket.jobs.metrics.cwl_insights_metrics_fetcher.CwlMetrics.get_metric_data_as_list") +@patch("braket.jobs.metrics.cwl_insights_metrics_fetcher.CwlMetrics.add_metrics_from_log_message") +def test_get_all_metrics_complete_results(mock_add_metrics, mock_get_metrics, aws_session): logs_client_mock = Mock() aws_session.create_logs_client.return_value = logs_client_mock logs_client_mock.start_query.return_value = {"queryId": "test"} logs_client_mock.get_query_results.return_value = { "status": "Complete", - "results": log_insights_results, + "results": EXAMPLE_METRICS_LOG_LINES, } + expected_result = ["Test"] + mock_get_metrics.return_value = expected_result fetcher = CwlInsightsMetricsFetcher(aws_session) - result = fetcher.get_all_metrics_for_job("test_job") - assert result == metrics_results + + result = fetcher.get_all_metrics_for_job("test_job", job_start_time=1, job_end_time=2) + logs_client_mock.get_query_results.assert_called_with(queryId="test") + logs_client_mock.start_query.assert_called_with( + logGroupName="/aws/lambda/my-python-test-function", + startTime=1, + endTime=2, + queryString="fields @timestamp, @message | filter @logStream like /^test_job$/" + " | filter @message like /^Metrics - /", + limit=10000, + ) + assert mock_add_metrics.call_args_list == EXPECTED_CALL_LIST + assert result == expected_result def test_get_all_metrics_timeout(aws_session): @@ -159,12 +86,13 @@ def test_get_all_metrics_timeout(aws_session): logs_client_mock.start_query.return_value = {"queryId": "test"} logs_client_mock.get_query_results.return_value = {"status": "Queued"} - fetcher = CwlInsightsMetricsFetcher(aws_session, 0.25, 0.5) + fetcher = CwlInsightsMetricsFetcher(aws_session, 0.1, 0.2) result = fetcher.get_all_metrics_for_job("test_job") + logs_client_mock.get_query_results.assert_called() assert result == [] -@pytest.mark.xfail(raises=MetricsTimeoutError) +@pytest.mark.xfail(raises=MetricsRetrievalError) def test_get_all_metrics_failed(aws_session): logs_client_mock = Mock() aws_session.create_logs_client.return_value = logs_client_mock diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_metrics.py b/test/unit_tests/braket/jobs/metrics/test_cwl_metrics.py new file mode 100644 index 000000000..b67068231 --- /dev/null +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_metrics.py @@ -0,0 +1,102 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +import pytest + +from braket.jobs.metrics import CwlMetrics + +MALFORMED_METRICS_LOG_LINES = [ + {"timestamp": "Test timestamp 0", "message": ""}, + {"timestamp": "Test timestamp 1", "message": "No metrics prefix metric0=2.0"}, + {"timestamp": "Test timestamp 2", "message": "Metrics - metric0=not_a_number;"}, + {"timestamp": "Test timestamp 3"}, + {"unknown": "Unknown"}, +] + + +SIMPLE_METRICS_LOG_LINES = [ + { + "timestamp": "Test timestamp 0", + "message": "Metrics - metric0=0.0; metric1=1.0; metric2=2.0;", + }, + { + "timestamp": "Test timestamp 1", + "message": "Metrics - metric0=0.1; metric1=1.1; metric2=2.1;", + }, + { + "timestamp": "Test timestamp 2", + "message": "Metrics - metric0=0.2; metric1=1.2; metric2=2.2;", + }, + { + "timestamp": "Test timestamp 3", + "message": "Metrics - metric0=-0.4; metric1=3.14e-22; metric2=3.14E22;", + }, +] + +SIMPLE_METRICS_RESULT = [ + { + "Timestamp": [ + "Test timestamp 0", + "Test timestamp 1", + "Test timestamp 2", + "Test timestamp 3", + ], + "metric0": [0.0, 0.1, 0.2, -0.4], + "metric1": [1.0, 1.1, 1.2, 3.14e-22], + "metric2": [2.0, 2.1, 2.2, 3.14e22], + } +] + +MULTIPLE_TABLES_METRICS_LOG_LINES = [ + {"timestamp": "Test timestamp 0", "message": "Metrics - metric0=0.0;"}, + {"timestamp": "Test timestamp 1", "message": "Metrics - metric0=0.1; metric1=1.1;"}, + {"timestamp": "Test timestamp 2", "message": "Metrics - metric0=0.2; metric2=2.2;"}, + {"timestamp": "Test timestamp 3", "message": "Metrics - metric0=0.3; metric1=1.3;"}, + {"timestamp": "Test timestamp 4", "message": "Metrics - metric1=1.4; metric2=2.4;"}, + { + "timestamp": "Test timestamp 5", + "message": "Metrics - metric0=0.5; metric1=1.5; metric2=2.5;", + }, + {"timestamp": "Test timestamp 6", "message": "Metrics - metric1=0.6; metric0=0.6;"}, + {"message": "Metrics - metric0=0.7; "}, +] + +MULTIPLE_TABLES_METRICS_RESULT = [ + {"Timestamp": ["Test timestamp 0", "N/A"], "metric0": [0.0, 0.7]}, + { + "Timestamp": ["Test timestamp 1", "Test timestamp 3", "Test timestamp 6"], + "metric0": [0.1, 0.3, 0.6], + "metric1": [1.1, 1.3, 0.6], + }, + {"Timestamp": ["Test timestamp 2"], "metric0": [0.2], "metric2": [2.2]}, + {"Timestamp": ["Test timestamp 4"], "metric1": [1.4], "metric2": [2.4]}, + {"Timestamp": ["Test timestamp 5"], "metric0": [0.5], "metric1": [1.5], "metric2": [2.5]}, +] + + +@pytest.mark.parametrize( + "log_events, metrics_results", + [ + ([], []), + (MALFORMED_METRICS_LOG_LINES, []), + (SIMPLE_METRICS_LOG_LINES, SIMPLE_METRICS_RESULT), + (MULTIPLE_TABLES_METRICS_LOG_LINES, MULTIPLE_TABLES_METRICS_RESULT), + # TODO: https://app.asana.com/0/1199668788990775/1200502190825620 + # We should also test some real-world data, once we have it. + ], +) +def test_get_all_metrics_complete_results(log_events, metrics_results): + metrics = CwlMetrics() + for log_event in log_events: + metrics.add_metrics_from_log_message(log_event.get("timestamp"), log_event.get("message")) + assert metrics.get_metric_data_as_list() == metrics_results diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics/test_cwl_metrics_fetcher.py new file mode 100644 index 000000000..363b406be --- /dev/null +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_metrics_fetcher.py @@ -0,0 +1,135 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from unittest.mock import Mock, call, patch + +import pytest + +from braket.jobs.metrics import CwlMetricsFetcher + + +@pytest.fixture +def aws_session(): + _aws_session = Mock() + return _aws_session + + +EXAMPLE_METRICS_LOG_LINES = [ + { + "timestamp": "Test timestamp 0", + "message": "Metrics - Test value 0", + }, + { + "timestamp": "Test timestamp 1", + "message": "Metrics - Test value 1", + }, + { + "timestamp": "Test timestamp 2", + }, + { + "message": "Metrics - Test value 3", + }, + { + # This metrics fetcher will filter out log line that don't have a "Metrics -" tag. + "message": "No prefix, Test value 4", + }, +] + +EXPECTED_CALL_LIST = [ + call("Test timestamp 0", "Metrics - Test value 0"), + call("Test timestamp 1", "Metrics - Test value 1"), + call(None, "Metrics - Test value 3"), +] + + +@patch("braket.jobs.metrics.cwl_metrics_fetcher.CwlMetrics.get_metric_data_as_list") +@patch("braket.jobs.metrics.cwl_metrics_fetcher.CwlMetrics.add_metrics_from_log_message") +def test_get_all_metrics_complete_results(mock_add_metrics, mock_get_metrics, aws_session): + logs_client_mock = Mock() + aws_session.create_logs_client.return_value = logs_client_mock + + logs_client_mock.describe_log_streams.return_value = { + "logStreams": [{"logStreamName": "stream name"}, {}] + } + logs_client_mock.get_log_events.return_value = { + "events": EXAMPLE_METRICS_LOG_LINES, + "nextForwardToken": None, + } + expected_result = ["Test"] + mock_get_metrics.return_value = expected_result + + fetcher = CwlMetricsFetcher(aws_session) + result = fetcher.get_all_metrics_for_job("test_job") + assert mock_add_metrics.call_args_list == EXPECTED_CALL_LIST + assert result == expected_result + + +@patch("braket.jobs.metrics.cwl_metrics_fetcher.CwlMetrics.add_metrics_from_log_message") +def test_get_log_streams_timeout(mock_add_metrics, aws_session): + logs_client_mock = Mock() + aws_session.create_logs_client.return_value = logs_client_mock + + logs_client_mock.describe_log_streams.return_value = { + "logStreams": [{"logStreamName": "stream name"}], + "nextToken": "forever", + } + logs_client_mock.get_log_events.return_value = { + "events": EXAMPLE_METRICS_LOG_LINES, + } + + fetcher = CwlMetricsFetcher(aws_session, 0.1) + result = fetcher.get_all_metrics_for_job("test_job") + mock_add_metrics.assert_not_called() + assert result == [] + + +@patch("braket.jobs.metrics.cwl_metrics_fetcher.CwlMetrics.add_metrics_from_log_message") +def test_get_no_streams_returned(mock_add_metrics, aws_session): + logs_client_mock = Mock() + aws_session.create_logs_client.return_value = logs_client_mock + + logs_client_mock.describe_log_streams.return_value = {} + + fetcher = CwlMetricsFetcher(aws_session) + result = fetcher.get_all_metrics_for_job("test_job") + logs_client_mock.describe_log_streams.assert_called() + mock_add_metrics.assert_not_called() + assert result == [] + + +@patch("braket.jobs.metrics.cwl_metrics_fetcher.CwlMetrics.get_metric_data_as_list") +@patch("braket.jobs.metrics.cwl_metrics_fetcher.CwlMetrics.add_metrics_from_log_message") +def test_get_metrics_timeout(mock_add_metrics, mock_get_metrics, aws_session): + logs_client_mock = Mock() + 
aws_session.create_logs_client.return_value = logs_client_mock + + logs_client_mock.describe_log_streams.return_value = { + "logStreams": [{"logStreamName": "stream name"}] + } + logs_client_mock.get_log_events.side_effect = get_log_events_forever + expected_result = ["Test"] + mock_get_metrics.return_value = expected_result + + fetcher = CwlMetricsFetcher(aws_session, 0.1) + result = fetcher.get_all_metrics_for_job("test_job") + logs_client_mock.get_log_events.assert_called() + mock_add_metrics.assert_called() + assert result == expected_result + + +def get_log_events_forever(*args, **kwargs): + next_token = "1" + token = kwargs.get("nextToken") + if token and token == "1": + next_token = "2" + return {"events": EXAMPLE_METRICS_LOG_LINES, "nextForwardToken": next_token} From a3523202a38b466b885da6845069ccac528e334b Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Thu, 24 Jun 2021 14:35:26 -0700 Subject: [PATCH 7/8] Updating braket jobs log group name. --- src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py | 3 +-- src/braket/jobs/metrics/cwl_metrics_fetcher.py | 3 +-- .../braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py index 3454d25b3..6fe58411d 100644 --- a/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py +++ b/src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py @@ -23,8 +23,7 @@ class CwlInsightsMetricsFetcher(object): - # TODO : Update this once we know the log group name for jobs. - LOG_GROUP_NAME = "/aws/lambda/my-python-test-function" + LOG_GROUP_NAME = "/aws/braket/jobs" QUERY_DEFAULT_JOB_DURATION = 3 * 60 * 60 def __init__( diff --git a/src/braket/jobs/metrics/cwl_metrics_fetcher.py b/src/braket/jobs/metrics/cwl_metrics_fetcher.py index 5f92c3eb0..c9849c2f2 100644 --- a/src/braket/jobs/metrics/cwl_metrics_fetcher.py +++ b/src/braket/jobs/metrics/cwl_metrics_fetcher.py @@ -21,8 +21,7 @@ class CwlMetricsFetcher(object): - # TODO : Update this once we know the log group name for jobs. - LOG_GROUP_NAME = "/aws/lambda/my-python-test-function" + LOG_GROUP_NAME = "/aws/braket/jobs" def __init__( self, diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py index 73c5c1164..e35d1dd36 100644 --- a/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_insights_metrics_fetcher.py @@ -68,7 +68,7 @@ def test_get_all_metrics_complete_results(mock_add_metrics, mock_get_metrics, aw result = fetcher.get_all_metrics_for_job("test_job", job_start_time=1, job_end_time=2) logs_client_mock.get_query_results.assert_called_with(queryId="test") logs_client_mock.start_query.assert_called_with( - logGroupName="/aws/lambda/my-python-test-function", + logGroupName="/aws/braket/jobs", startTime=1, endTime=2, queryString="fields @timestamp, @message | filter @logStream like /^test_job$/" From 6e5b248ed8db7f28cc4ba8e19521909ffd58e717 Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Thu, 24 Jun 2021 14:46:20 -0700 Subject: [PATCH 8/8] Fixing comment in one of the test cases. 
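For reference, a minimal usage sketch of the two fetchers added in this series. It is illustrative only and not part of the change: it assumes an AwsSession instance is already available (its construction is not shown in these patches), the job name is a placeholder, and the start/end times are assumed to be epoch seconds as expected by the CloudWatch Logs Insights start_query API.

    from braket.aws.aws_session import AwsSession
    from braket.jobs.metrics import CwlInsightsMetricsFetcher, CwlMetricsFetcher


    def print_job_metrics(aws_session: AwsSession, job_name: str, start: int, end: int) -> None:
        # Stream-based fetcher: walks the job's log streams in the Braket jobs log
        # group and parses every "Metrics - ..." line it finds.
        log_fetcher = CwlMetricsFetcher(aws_session)
        for table in log_fetcher.get_all_metrics_for_job(job_name):
            print(table)

        # Insights-based fetcher: runs a CloudWatch Logs Insights query over an
        # explicit time window instead of reading the log streams directly.
        insights_fetcher = CwlInsightsMetricsFetcher(aws_session)
        tables = insights_fetcher.get_all_metrics_for_job(
            job_name, job_start_time=start, job_end_time=end
        )
        for table in tables:
            print(table)

Each returned table is a dict mapping a column name (including "Timestamp") to a list of values; log lines that report the same set of metric names are grouped into the same table.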
--- .../braket/jobs/metrics/test_cwl_metrics.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/unit_tests/braket/jobs/metrics/test_cwl_metrics.py b/test/unit_tests/braket/jobs/metrics/test_cwl_metrics.py index b67068231..1af95b912 100644 --- a/test/unit_tests/braket/jobs/metrics/test_cwl_metrics.py +++ b/test/unit_tests/braket/jobs/metrics/test_cwl_metrics.py @@ -17,9 +17,10 @@ MALFORMED_METRICS_LOG_LINES = [ {"timestamp": "Test timestamp 0", "message": ""}, - {"timestamp": "Test timestamp 1", "message": "No metrics prefix metric0=2.0"}, - {"timestamp": "Test timestamp 2", "message": "Metrics - metric0=not_a_number;"}, - {"timestamp": "Test timestamp 3"}, + {"timestamp": "Test timestamp 1", "message": "No semicolon metric0=2.0"}, + {"timestamp": "Test timestamp 2", "message": "metric0=not_a_number;"}, + {"timestamp": "Test timestamp 3", "message": "also not a number metric0=2 . 0;"}, + {"timestamp": "Test timestamp 4"}, {"unknown": "Unknown"}, ] @@ -27,15 +28,15 @@ SIMPLE_METRICS_LOG_LINES = [ { "timestamp": "Test timestamp 0", - "message": "Metrics - metric0=0.0; metric1=1.0; metric2=2.0;", + "message": "Metrics - metric0=0.0; metric1=1.0; metric2=2.0 ;", }, { "timestamp": "Test timestamp 1", - "message": "Metrics - metric0=0.1; metric1=1.1; metric2=2.1;", + "message": "Metrics - metric0=0.1; metric1=1.1; metric2= 2.1;", }, { "timestamp": "Test timestamp 2", - "message": "Metrics - metric0=0.2; metric1=1.2; metric2=2.2;", + "message": "Metrics - metric0=0.2; metric1=1.2; metric2= 2.2 ;", }, { "timestamp": "Test timestamp 3",