Skip to content

Commit

Permalink
Merge pull request #2 from aws/jobmetrics
Browse files Browse the repository at this point in the history
feature: add utility class for getting jobs metrics
  • Loading branch information
krneta authored Jun 28, 2021
2 parents a24894c + 6e5b248 commit f427f08
Show file tree
Hide file tree
Showing 11 changed files with 824 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,15 @@ def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str:
obj = s3.Object(s3_bucket, s3_object_key)
return obj.get()["Body"].read().decode("utf-8")

def create_logs_client(self) -> "boto3.session.Session.client":
"""
Create a CloudWatch Logs boto client.
Returns:
'boto3.session.Session.client': The CloudWatch Logs boto client.
"""
return self.boto_session.client("logs", config=self._config)

def get_device(self, arn: str) -> Dict[str, Any]:
"""
Calls the Amazon Braket `get_device` API to
Expand Down
12 changes: 12 additions & 0 deletions src/braket/jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
17 changes: 17 additions & 0 deletions src/braket/jobs/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from braket.jobs.metrics.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher # noqa: F401
from braket.jobs.metrics.cwl_metrics import CwlMetrics # noqa: F401
from braket.jobs.metrics.cwl_metrics_fetcher import CwlMetricsFetcher # noqa: F401
from braket.jobs.metrics.exceptions import MetricsRetrievalError # noqa: F401
172 changes: 172 additions & 0 deletions src/braket/jobs/metrics/cwl_insights_metrics_fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import re
import time
from logging import Logger, getLogger
from typing import Any, Dict, List, Optional

from braket.aws.aws_session import AwsSession
from braket.jobs.metrics.cwl_metrics import CwlMetrics
from braket.jobs.metrics.exceptions import MetricsRetrievalError


class CwlInsightsMetricsFetcher(object):

LOG_GROUP_NAME = "/aws/braket/jobs"
QUERY_DEFAULT_JOB_DURATION = 3 * 60 * 60

def __init__(
self,
aws_session: AwsSession,
poll_timeout_seconds: float = 10,
poll_interval_seconds: float = 1,
logger: Logger = getLogger(__name__),
):
"""
Args:
aws_session (AwsSession): AwsSession to connect to AWS with.
poll_timeout_seconds (float): The polling timeout for retrieving the metrics,
in seconds. Default: 10 seconds.
poll_interval_seconds (float): The polling interval for results in seconds.
Default: 1 second.
logger (Logger): Logger object with which to write logs, such as task statuses
while waiting for task to be in a terminal state. Default is `getLogger(__name__)`
"""
self._poll_timeout_seconds = poll_timeout_seconds
self._poll_interval_seconds = poll_interval_seconds
self._logger = logger
self._logs_client = aws_session.create_logs_client()

@staticmethod
def _get_element_from_log_line(
element_name: str, log_line: List[Dict[str, Any]]
) -> Optional[str]:
"""
Finds and returns an element of a log line from CloudWatch Insights results.
Args:
element_name (str): The element to find.
log_line (List[Dict[str, Any]]): An iterator for RegEx matches on a log line.
Returns:
Optional[str] : The value of the element with the element name, or None if no such
element is found.
"""
return next(
(element["value"] for element in log_line if element["field"] == element_name), None
)

def _get_metrics_results_sync(self, query_id: str) -> List[Any]:
"""
Waits for the CloudWatch Insights query to complete and returns all the results.
Args:
query_id (str): CloudWatch Insights query ID.
Returns:
List[Any]: The results from CloudWatch insights 'GetQueryResults' operation.
"""
timeout_time = time.time() + self._poll_timeout_seconds
while time.time() < timeout_time:
response = self._logs_client.get_query_results(queryId=query_id)
query_status = response["status"]
if query_status in ["Failed", "Cancelled"]:
raise MetricsRetrievalError(f"Query {query_id} failed with status {query_status}")
elif query_status == "Complete":
return response["results"]
else:
time.sleep(self._poll_interval_seconds)
self._logger.warning(f"Timed out waiting for query {query_id}.")
return []

def _parse_log_line(self, result_entry: List[Dict[str, Any]], metrics: CwlMetrics) -> None:
"""
Parses the single entry from CloudWatch Insights results and adds any metrics it finds
to 'all_metrics', along with the timestamp for the entry.
Args:
result_entry (List[Dict[str, Any]]): A structured result from calling CloudWatch
Insights to get logs that contain metrics. A single entry will contain the message
(the actual line logged to output), the timestamp (generated by CloudWatch Logs),
and other metadata that we (currently) do not use.
all_metrics (Dict[int, Dict[str, List[Any]]]) : The list of all metrics.
"""
message = self._get_element_from_log_line("@message", result_entry)
if message:
timestamp = self._get_element_from_log_line("@timestamp", result_entry)
metrics.add_metrics_from_log_message(timestamp, message)

def _parse_log_query_results(self, results: List[Any]) -> Dict[int, Dict[str, List[Any]]]:
"""
Parses CloudWatch Insights results and returns all found metrics.
Args:
results (List[Any]): A structured result from calling CloudWatch Insights to get
logs that contain metrics.
Returns:
Dict[int, Dict[str, List[Any]]] : The list of all metrics that can be found in the
CloudWatch Logs results. Each unique set of metrics will be represented by a key
in the topmost dictionary. We can think of this set of metrics as a single table.
Each table will have a set of metrics, indexed by the column name. The
entries are not sorted.
"""
metrics = CwlMetrics()
for result in results:
self._parse_log_line(result, metrics)
return metrics.get_metric_data_as_list()

def get_all_metrics_for_job(
self, job_name: str, job_start_time: int = None, job_end_time: int = None
) -> List[Dict[str, List[Any]]]:
"""
Synchronously retrieves all the algorithm metrics logged by a given Job.
Args:
job_name (str): The name of the Job. The name must be exact to ensure only the relevant
metrics are retrieved.
job_start_time (int): The time at which the job started.
Default: 3 hours before job_end_time.
job_end_time (int): If the job is complete, this should be the time at which the
job finished. Default: current time.
Returns:
List[Dict[str, List[Any]]] : The list of all metrics that can be found in the
CloudWatch Logs results. Each item in the list can be thought of as a separate
table. Each table will have a set of metrics, indexed by the column name. The
entries are not sorted.
"""
query_end_time = job_end_time or int(time.time())
query_start_time = job_start_time or query_end_time - self.QUERY_DEFAULT_JOB_DURATION

# job name needs to be specific to prevent jobs with similar names from being conflated
query = (
f"fields @timestamp, @message "
f"| filter @logStream like /^{re.escape(job_name)}$/ "
f"| filter @message like /^Metrics - /"
)

response = self._logs_client.start_query(
logGroupName=self.LOG_GROUP_NAME,
startTime=query_start_time,
endTime=query_end_time,
queryString=query,
limit=10000,
)

query_id = response["queryId"]

results = self._get_metrics_results_sync(query_id)

return self._parse_log_query_results(results)
97 changes: 97 additions & 0 deletions src/braket/jobs/metrics/cwl_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import re
from collections import defaultdict
from logging import Logger, getLogger
from typing import Dict, FrozenSet, Iterator


class CwlMetrics(object):

METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*([^;]+)\s*;")
TIMESTAMP = "Timestamp"

def __init__(
self,
logger: Logger = getLogger(__name__),
):
self._logger = logger
self.metric_tables = defaultdict(lambda: defaultdict(list))

def _add_metrics_to_appropriate_table(
self,
columns: FrozenSet[str],
metrics: Dict[str, float],
) -> None:
"""
Adds the given metrics to the appropriate table.
Args:
columns (FrozenSet[str]): The set of column names representing the metrics.
metrics (Dict[str, float]): A set of metrics in the format {<metric name> : <value>}.
"""
metrics_table = self.metric_tables[columns]
for column_name in metrics.keys():
metrics_table[column_name].append(metrics[column_name])

def _get_metrics_from_log_line_matches(self, all_matches: Iterator) -> Dict[str, float]:
"""
Converts matches from a RegEx to a set of metrics.
Args:
all_matches (Iterator): An iterator for RegEx matches on a log line.
Returns:
Dict[str, float]: The set of metrics found by the RegEx. The result will be in the
format {<metric name> : <value>}. This implies that multiple metrics with
the same name will be deduped to the last instance of that metric.
"""
metrics = {}
for match in all_matches:
subgroup = match.groups()
value = subgroup[1]
try:
metrics[subgroup[0]] = float(value)
except ValueError:
self._logger.warning(f"Unable to convert value {value} to a float.")
return metrics

def add_metrics_from_log_message(self, timestamp: str, message: str) -> None:
"""
Parses a line from CloudWatch Logs adds all the metrics that have been logged
on that line. The timestamp is also added to match the corresponding values.
Args:
timestamp (str): A formatted string representing the timestamp for any found metrics.
message (str): A log line from CloudWatch Logs.
"""
if not message:
return
all_matches = self.METRICS_DEFINITIONS.finditer(message)
parsed_metrics = self._get_metrics_from_log_line_matches(all_matches)
if not parsed_metrics:
return
columns = frozenset(parsed_metrics.keys())
self._add_metrics_to_appropriate_table(columns, parsed_metrics)
self.metric_tables[columns][self.TIMESTAMP].append(timestamp or "N/A")

def get_metric_data_as_list(self):
"""
Gets all the metrics data for all tables, as a list.
Returns:
List[Dict[str, List[Any]]] : The list of all tables. Each table will have a set
of metrics, indexed by the column name.
"""
return list(self.metric_tables.values())
Loading

0 comments on commit f427f08

Please sign in to comment.