feature: support local simulators for jobs #309

Merged: 14 commits, May 4, 2022
19 changes: 15 additions & 4 deletions src/braket/aws/aws_quantum_job.py
@@ -73,6 +73,7 @@ def create(
hyperparameters: Dict[str, Any] = None,
input_data: Union[str, Dict, S3DataSourceConfig] = None,
instance_config: InstanceConfig = None,
distribution: str = None,
stopping_condition: StoppingCondition = None,
output_data_config: OutputDataConfig = None,
copy_checkpoints_from_job: str = None,
@@ -84,8 +85,11 @@
"""Creates a job by invoking the Braket CreateJob API.
Args:
device (str): ARN for the AWS device which is primarily
accessed for the execution of this job.
device (str): ARN for the AWS device which is primarily accessed for the execution
of this job. Alternatively, a string of the format "local:<provider>/<simulator>"
for using a local simulator for the job. This string will be available as the
environment variable `AMZN_BRAKET_DEVICE_ARN` inside the job container when
using a Braket container.
source_module (str): Path (absolute, relative or an S3 URI) to a python module to be
tarred and uploaded. If `source_module` is an S3 URI, it must point to a
@@ -129,7 +133,11 @@ def create(
instance_config (InstanceConfig): Configuration of the instances to be used
to execute the job. Default: InstanceConfig(instanceType='ml.m5.large',
instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None).
instanceCount=1, volumeSizeInGB=30).
distribution (str): A str that specifies how the job should be distributed. If set to
"data_parallel", the hyperparameters for the job will be set to use data parallelism
features for PyTorch or TensorFlow. Default: None.
stopping_condition (StoppingCondition): The maximum length of time, in seconds,
and the maximum number of tasks that a job can run before being forcefully stopped.
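As a usage illustration, a minimal sketch of targeting a local simulator through the device string documented above; the provider/simulator name and script path are hypothetical, not values from this PR:

from braket.aws import AwsQuantumJob

# "local:braket/simulator_name" is a hypothetical provider/simulator
# string; inside a Braket container it is surfaced as the
# AMZN_BRAKET_DEVICE_ARN environment variable.
job = AwsQuantumJob.create(
    device="local:braket/simulator_name",
    source_module="algorithm_script.py",  # hypothetical path to the job script
)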
@@ -178,6 +186,7 @@ def create(
hyperparameters=hyperparameters,
input_data=input_data,
instance_config=instance_config,
distribution=distribution,
stopping_condition=stopping_condition,
output_data_config=output_data_config,
copy_checkpoints_from_job=copy_checkpoints_from_job,
@@ -311,7 +320,7 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
stream_prefix = f"{self.name}/"
stream_names = [] # The list of log streams
positions = {} # The current position in each stream, map of stream name -> position
instance_count = 1 # currently only support a single instance
instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"]
has_streams = False
color_wrap = logs.ColorWrap()

@@ -531,6 +540,8 @@ def __hash__(self) -> int:
@staticmethod
def _initialize_session(session_value, device, logger):
aws_session = session_value or AwsSession()
if device.startswith("local:"):
return aws_session
device_region = AwsDevice.get_device_region(device)
return (
AwsQuantumJob._initialize_regional_device_session(aws_session, device, logger)
1 change: 1 addition & 0 deletions src/braket/jobs/config.py
@@ -29,6 +29,7 @@ class InstanceConfig:

instanceType: str = "ml.m5.large"
volumeSizeInGb: int = 30
instanceCount: int = 1


@dataclass
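A minimal sketch of requesting a multi-instance job with the new field; values other than instanceCount are the defaults shown above:

from braket.jobs.config import InstanceConfig

# instanceCount is the new field; the other values match the defaults
# in the dataclass above.
config = InstanceConfig(
    instanceType="ml.m5.large",
    volumeSizeInGb=30,
    instanceCount=2,  # run the job across two instances
)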
2 changes: 1 addition & 1 deletion src/braket/jobs/image_uri_config/pl_pytorch.json
@@ -1,6 +1,6 @@
{
"versions": {
"1.8.1": {
"1.9.1": {
"registries": {
"us-east-1": "292282985366",
"us-west-1": "292282985366",
8 changes: 7 additions & 1 deletion src/braket/jobs/image_uris.py
@@ -45,7 +45,13 @@ def retrieve_image(framework: Framework, region: str):
framework_version = max(version for version in config["versions"])
version_config = config["versions"][framework_version]
registry = _registry_for_region(version_config, region)
tag = f"{version_config['repository']}:{framework_version}-cpu-py37-ubuntu18.04"
tag = ""
if framework == Framework.PL_TENSORFLOW:
tag = f"{version_config['repository']}:{framework_version}-gpu-py37-cu110-ubuntu18.04"
elif framework == Framework.PL_PYTORCH:
tag = f"{version_config['repository']}:{framework_version}-gpu-py38-cu111-ubuntu20.04"
else:
tag = f"{version_config['repository']}:{framework_version}-cpu-py37-ubuntu18.04"
return f"{registry}.dkr.ecr.{region}.amazonaws.com/{tag}"


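For reference, a sketch of the image URI shapes the branches above produce; the repository name is hypothetical, while the registry and region values come from the JSON config in this PR:

# Tag shapes per branch (repository name "pl-pytorch-jobs" is hypothetical):
#   PL_PYTORCH    -> <repo>:<version>-gpu-py38-cu111-ubuntu20.04
#   PL_TENSORFLOW -> <repo>:<version>-gpu-py37-cu110-ubuntu18.04
#   otherwise     -> <repo>:<version>-cpu-py37-ubuntu18.04
registry, region = "292282985366", "us-west-1"
tag = "pl-pytorch-jobs:1.9.1-gpu-py38-cu111-ubuntu20.04"
print(f"{registry}.dkr.ecr.{region}.amazonaws.com/{tag}")
# -> 292282985366.dkr.ecr.us-west-1.amazonaws.com/pl-pytorch-jobs:1.9.1-gpu-py38-cu111-ubuntu20.04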
7 changes: 5 additions & 2 deletions src/braket/jobs/local/local_job.py
@@ -53,8 +53,11 @@ def create(
docker container.
Args:
device (str): ARN for the AWS device which is primarily
accessed for the execution of this job.
device (str): ARN for the AWS device which is primarily accessed for the execution
of this job. Alternatively, a string of the format "local:<provider>/<simulator>"
for using a local simulator for the job. This string will be available as the
environment variable `AMZN_BRAKET_DEVICE_ARN` inside the job container when
using a Braket container.
source_module (str): Path (absolute, relative or an S3 URI) to a python module to be
tarred and uploaded. If `source_module` is an S3 URI, it must point to a
@@ -167,7 +167,7 @@ def get_metrics_for_job(
query = (
f"fields @timestamp, @message "
f"| filter @logStream like /^{job_name}\\// "
f"| filter @message like /^Metrics - /"
f"| filter @message like /Metrics - /"
)

response = self._logs_client.start_query(
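The dropped ^ anchor matters for multi-node jobs, where a node tag precedes the metrics text. A quick sketch of the two message shapes the relaxed filter must match; the node name is illustrative:

import re

messages = [
    "Metrics - loss=0.25;",                   # single-node: metrics at line start
    "[algo-1]<stdout>:Metrics - loss=0.25;",  # node-tagged: metrics mid-line
]
# The unanchored pattern matches both; the old ^-anchored pattern would
# have missed the node-tagged line.
assert all(re.search(r"Metrics - ", m) for m in messages)
assert not re.search(r"^Metrics - ", messages[1])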
43 changes: 27 additions & 16 deletions src/braket/jobs/metrics_data/log_metrics_parser.py
@@ -27,6 +27,8 @@ class LogMetricsParser(object):
METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*([^;]+)\s*;")
TIMESTAMP = "timestamp"
ITERATION_NUMBER = "iteration_number"
NODE_ID = "node_id"
NODE_TAG = re.compile(r"^\[([^\]]*)\]")

def __init__(
self,
@@ -102,32 +104,37 @@ def parse_log_message(self, timestamp: str, message: str) -> None:
return
if timestamp and self.TIMESTAMP not in parsed_metrics:
parsed_metrics[self.TIMESTAMP] = timestamp
node_match = self.NODE_TAG.match(message)
if node_match:
parsed_metrics[self.NODE_ID] = node_match.group(1)
self.all_metrics.append(parsed_metrics)

def get_columns_and_pivot_indices(
self, pivot: str
) -> Tuple[Dict[str, List[Union[str, float, int]]], Dict[int, int]]:
) -> Tuple[Dict[str, List[Union[str, float, int]]], Dict[Tuple[int, str], int]]:
"""
Parses the metrics to find all the metrics that have the pivot column. The values of
the pivot column are assigned a row index, so that all metrics with the same pivot value
are stored in the same row.
Parses the metrics to find all the metrics that have the pivot column. The values of the
pivot column are paired with the node_id and assigned a row index, so that all metrics
with the same pivot value and node_id are stored in the same row.
Args:
pivot (str): The name of the pivot column. Must be TIMESTAMP or ITERATION_NUMBER.
Returns:
Tuple[Dict[str, List[Any]], Dict[int, int]]:
The Dict[str, List[Any]] the result table with all the metrics values initialized
Tuple[Dict[str, List[Any]], Dict[Tuple[int, str], int]]:
The Dict[str, List[Any]] is the result table with all the metrics values initialized
to None
The Dict[int, int] is the list of pivot indices, where the value of a pivot column
is mapped to a row index.
The Dict[Tuple[int, str], int] is the list of pivot indices, where the value of a
pivot column and node_id is mapped to a row index.
"""
row_count = 0
pivot_indices: dict[int, int] = {}
table: dict[str, list[Optional[Union[str, float, int]]]] = {}
for metric in self.all_metrics:
if pivot in metric:
if metric[pivot] not in pivot_indices:
pivot_indices[metric[pivot]] = row_count
# If no node_id is present, pair pivot value with None for the key.
metric_pivot = (metric[pivot], metric.get(self.NODE_ID))
if metric_pivot not in pivot_indices:
pivot_indices[metric_pivot] = row_count
row_count += 1
for column_name in metric:
table[column_name] = [None]
@@ -141,18 +148,21 @@ def get_metric_data_with_pivot(
"""
Gets the metric data for a given pivot column name. Metrics without the pivot column
are not included in the results. Metrics that have the same value in the pivot column
are returned in the same row. If the a metric has multiple values for the pivot value,
the statistic is used to determine which value is returned.
from the same node are returned in the same row. Metrics from different nodes are stored
in different rows. If the metric has multiple values for the row, the statistic is used
to determine which value is returned.
For example, for the metrics:
"iteration_number" : 0, "metricA" : 2, "metricB" : 1,
"iteration_number" : 0, "metricA" : 1,
"no_pivot_column" : 0, "metricA" : 0,
"iteration_number" : 1, "metricA" : 2,
"iteration_number" : 1, "node_id" : "nodeB", "metricB" : 0,
With iteration_number as the pivot and a statistic of MIN, the result will be:
iteration_number metricA metricB
0 1 1
1 2 None
iteration_number node_id metricA metricB
0 None 1 1
1 None 2 None
1 nodeB None 0
Args:
pivot (str): The name of the pivot column. Must be TIMESTAMP or ITERATION_NUMBER.
@@ -164,7 +174,8 @@
table, pivot_indices = self.get_columns_and_pivot_indices(pivot)
for metric in self.all_metrics:
if pivot in metric:
row = pivot_indices[metric[pivot]]
metric_pivot = (metric[pivot], metric.get(self.NODE_ID))
row = pivot_indices[metric_pivot]
for column_name in metric:
table[column_name][row] = self._get_value(
table[column_name][row], metric[column_name], statistic
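A short sketch of how the node tag feeds the new row keying, using the NODE_TAG pattern defined above; the log line and node names are illustrative:

import re

NODE_TAG = re.compile(r"^\[([^\]]*)\]")  # same pattern as in the parser above

message = "[nodeA]<stdout>:Metrics - metric0=1.0;"  # illustrative log line
match = NODE_TAG.match(message)
node_id = match.group(1) if match else None  # -> "nodeA"

# Rows are keyed by (pivot value, node_id), so the same iteration number
# reported by different nodes lands in different rows:
pivot_indices, row_count = {}, 0
for metric_pivot in [(0, "nodeA"), (0, "nodeB"), (0, "nodeA")]:
    if metric_pivot not in pivot_indices:
        pivot_indices[metric_pivot] = row_count
        row_count += 1
print(pivot_indices)  # {(0, 'nodeA'): 0, (0, 'nodeB'): 1}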
11 changes: 11 additions & 0 deletions src/braket/jobs/quantum_job_creation.py
@@ -44,6 +44,7 @@ def prepare_quantum_job(
hyperparameters: Dict[str, Any] = None,
input_data: Union[str, Dict, S3DataSourceConfig] = None,
instance_config: InstanceConfig = None,
distribution: str = None,
stopping_condition: StoppingCondition = None,
output_data_config: OutputDataConfig = None,
copy_checkpoints_from_job: str = None,
@@ -98,6 +99,10 @@
to execute the job. Default: InstanceConfig(instanceType='ml.m5.large',
instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None).
distribution (str): A str that specifies how the job should be distributed. If set to
"data_parallel", the hyperparameters for the job will be set to use data parallelism
features for PyTorch or TensorFlow. Default: None.
stopping_condition (StoppingCondition): The maximum length of time, in seconds,
and the maximum number of tasks that a job can run before being forcefully stopped.
Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
@@ -192,6 +197,12 @@ def prepare_quantum_job(
"s3Uri"
]
aws_session.copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri)
if distribution == "data_parallel":
distributed_hyperparams = {
"sagemaker_distributed_dataparallel_enabled": "true",
"sagemaker_instance_type": instance_config.instanceType,
}
hyperparameters.update(distributed_hyperparams)

create_job_kwargs = {
"jobName": job_name,
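Putting the new knobs together, a hedged sketch of a caller requesting data parallelism; the device ARN is Braket's SV1 simulator, while the script path and instance type are illustrative. With distribution="data_parallel", the two SageMaker hyperparameters above are injected automatically:

from braket.aws import AwsQuantumJob
from braket.jobs.config import InstanceConfig

# distribution="data_parallel" triggers the hyperparameter injection
# shown above; illustrative script path and instance type.
job = AwsQuantumJob.create(
    device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
    source_module="algorithm_script.py",
    distribution="data_parallel",
    instance_config=InstanceConfig(instanceType="ml.p3.16xlarge", instanceCount=2),
)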
29 changes: 23 additions & 6 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
@@ -15,6 +15,7 @@
import json
import logging
import os
import re
import tarfile
import tempfile
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -184,10 +185,11 @@ def test_quantum_job_constructor_default_session(
assert job._aws_session == aws_session_mock.return_value


@pytest.mark.xfail(raises=ValueError)
def test_quantum_job_constructor_invalid_region(aws_session):
region_mismatch = "The aws session region does not match the region for the supplied arn."
arn = "arn:aws:braket:unknown-region:875981177017:job/quantum_job_name"
AwsQuantumJob(arn, aws_session)
with pytest.raises(ValueError, match=region_mismatch):
AwsQuantumJob(arn, aws_session)


@patch("braket.aws.aws_quantum_job.boto3.Session")
@@ -504,6 +506,7 @@ def prepare_job_args(aws_session, device_arn):
"hyperparameters": Mock(),
"input_data": Mock(),
"instance_config": Mock(),
"distribution": Mock(),
"stopping_condition": Mock(),
"output_data_config": Mock(),
"copy_checkpoints_from_job": Mock(),
@@ -528,9 +531,9 @@ def test_name(quantum_job_arn, quantum_job_name, aws_session):
assert quantum_job.name == quantum_job_name


@pytest.mark.xfail(raises=AttributeError)
def test_no_arn_setter(quantum_job):
quantum_job.arn = 123
with pytest.raises(AttributeError, match="can't set attribute"):
quantum_job.arn = 123


@pytest.mark.parametrize("wait_until_complete", [True, False])
@@ -577,16 +580,20 @@ def test_cancel_job(quantum_job_arn, aws_session, generate_cancel_job_response):
assert status == cancellation_status


@pytest.mark.xfail(raises=ClientError)
def test_cancel_job_surfaces_exception(quantum_job, aws_session):
exception_response = {
"Error": {
"Code": "ValidationException",
"Message": "unit-test-error",
}
}
error_string = re.escape(
"An error occurred (ValidationException) when calling the "
"cancel_job operation: unit-test-error"
)
aws_session.cancel_job.side_effect = ClientError(exception_response, "cancel_job")
quantum_job.cancel()
with pytest.raises(ClientError, match=error_string):
quantum_job.cancel()


@pytest.mark.parametrize(
@@ -987,6 +994,16 @@ def test_exceptions_in_all_device_regions(aws_session):
AwsQuantumJob._initialize_session(aws_session, device_arn, logger)


@patch("braket.aws.aws_quantum_job.AwsSession")
def test_initialize_session_local_device(mock_new_session, aws_session):
logger = logging.getLogger(__name__)
device = "local:provider.device.name"
# don't change a provided AwsSession
assert AwsQuantumJob._initialize_session(aws_session, device, logger) == aws_session
# otherwise, create an AwsSession with the profile defaults
assert AwsQuantumJob._initialize_session(None, device, logger) == mock_new_session()


def test_bad_arn_format(aws_session):
logger = logging.getLogger(__name__)
device_not_found = (
@@ -73,7 +73,7 @@ def test_get_all_metrics_complete_results(mock_add_metrics, mock_get_metrics, aw
startTime=1,
endTime=2,
queryString="fields @timestamp, @message | filter @logStream like /^test_job\\//"
" | filter @message like /^Metrics - /",
" | filter @message like /Metrics - /",
limit=10000,
)
assert mock_add_metrics.call_args_list == EXPECTED_CALL_LIST
@@ -63,6 +63,40 @@
"metric3": [None, None, 0.2, None],
}

MULTINODE_METRICS_LOG_LINES = [
{
"timestamp": "Test timestamp 0",
"message": "[nodeA]<stdout>:Metrics - metric0=1.0;",
},
# This line logs the same metric from a different node.
{
"timestamp": "Test timestamp 0",
"message": "[nodeB]<stdout>:Metrics - metric0=2.0;",
},
# This line also logs a metric unique to one node.
{
"timestamp": "Test timestamp 1",
"message": "[nodeA]<stdout>:Metrics - metricA=3.0;",
},
# This line logs a metric without a node tag.
{
"timestamp": "Test timestamp 1",
"message": "Metrics - metric0=0.0;",
},
]

MULTINODES_METRICS_RESULT = {
"timestamp": ["Test timestamp 0", "Test timestamp 0", "Test timestamp 1", "Test timestamp 1"],
"node_id": [
"nodeA",
"nodeB",
"nodeA",
None,
],
"metric0": [1.0, 2.0, None, 0.0],
"metricA": [None, None, 3.0, None],
}

# This will test how metrics are combined when multiple metrics have the same timestamp
SINGLE_TIMESTAMP_METRICS_LOG_LINES = [
{"timestamp": "Test timestamp 0", "message": "Metrics - metric0=0.0;"},
Expand Down Expand Up @@ -135,6 +169,12 @@
MetricStatistic.MAX,
SIMPLE_METRICS_RESULT,
),
(
MULTINODE_METRICS_LOG_LINES,
MetricType.TIMESTAMP,
MetricStatistic.MAX,
MULTINODES_METRICS_RESULT,
),
(
SINGLE_TIMESTAMP_METRICS_LOG_LINES,
MetricType.TIMESTAMP,