diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py
index f5d1e48a4..8d0b7870d 100644
--- a/src/braket/aws/aws_quantum_job.py
+++ b/src/braket/aws/aws_quantum_job.py
@@ -73,6 +73,7 @@ def create(
         hyperparameters: Dict[str, Any] = None,
         input_data: Union[str, Dict, S3DataSourceConfig] = None,
         instance_config: InstanceConfig = None,
+        distribution: str = None,
         stopping_condition: StoppingCondition = None,
         output_data_config: OutputDataConfig = None,
         copy_checkpoints_from_job: str = None,
@@ -84,8 +85,11 @@ def create(
         """Creates a job by invoking the Braket CreateJob API.

         Args:
-            device (str): ARN for the AWS device which is primarily
-                accessed for the execution of this job.
+            device (str): ARN for the AWS device which is primarily accessed for the execution
+                of this job. Alternatively, a string of the format "local:<provider>/<simulator>"
+                for using a local simulator for the job. This string will be available as the
+                environment variable `AMZN_BRAKET_DEVICE_ARN` inside the job container when
+                using a Braket container.

             source_module (str): Path (absolute, relative or an S3 URI) to a python module to be
                 tarred and uploaded. If `source_module` is an S3 URI, it must point to a
@@ -129,7 +133,11 @@ def create(
             instance_config (InstanceConfig): Configuration of the instances to be used to
                 execute the job. Default: InstanceConfig(instanceType='ml.m5.large',
-                instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None).
+                instanceCount=1, volumeSizeInGB=30).
+
+            distribution (str): A str that specifies how the job should be distributed. If set to
+                "data_parallel", the hyperparameters for the job will be set to use data parallelism
+                features for PyTorch or TensorFlow. Default: None.

             stopping_condition (StoppingCondition): The maximum length of time, in seconds,
                 and the maximum number of tasks that a job can run before being forcefully stopped.
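As a reviewer's note, here is a minimal sketch of how the new arguments are meant to be used together. The device ARN, module path, and instance type are hypothetical placeholders, not values taken from this change:

```python
from braket.aws import AwsQuantumJob
from braket.jobs.config import InstanceConfig

# Hypothetical job: two instances with SageMaker data parallelism enabled.
job = AwsQuantumJob.create(
    device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",  # placeholder ARN
    source_module="algorithm_script.py",                            # placeholder module
    instance_config=InstanceConfig(instanceType="ml.p3.16xlarge", instanceCount=2),
    distribution="data_parallel",
)
```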
@@ -178,6 +186,7 @@ def create(
             hyperparameters=hyperparameters,
             input_data=input_data,
             instance_config=instance_config,
+            distribution=distribution,
             stopping_condition=stopping_condition,
             output_data_config=output_data_config,
             copy_checkpoints_from_job=copy_checkpoints_from_job,
@@ -311,7 +320,7 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
         stream_prefix = f"{self.name}/"
         stream_names = []  # The list of log streams
         positions = {}  # The current position in each stream, map of stream name -> position
-        instance_count = 1  # currently only support a single instance
+        instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"]
         has_streams = False
         color_wrap = logs.ColorWrap()

@@ -531,6 +540,8 @@ def __hash__(self) -> int:
     @staticmethod
     def _initialize_session(session_value, device, logger):
         aws_session = session_value or AwsSession()
+        if device.startswith("local:"):
+            return aws_session
         device_region = AwsDevice.get_device_region(device)
         return (
             AwsQuantumJob._initialize_regional_device_session(aws_session, device, logger)
diff --git a/src/braket/jobs/config.py b/src/braket/jobs/config.py
index bbe1f2cbd..28b7c05d3 100644
--- a/src/braket/jobs/config.py
+++ b/src/braket/jobs/config.py
@@ -29,6 +29,7 @@ class InstanceConfig:
     instanceType: str = "ml.m5.large"
     volumeSizeInGb: int = 30
+    instanceCount: int = 1


 @dataclass
diff --git a/src/braket/jobs/image_uri_config/pl_pytorch.json b/src/braket/jobs/image_uri_config/pl_pytorch.json
index 384848cb4..50197b1cc 100644
--- a/src/braket/jobs/image_uri_config/pl_pytorch.json
+++ b/src/braket/jobs/image_uri_config/pl_pytorch.json
@@ -1,6 +1,6 @@
 {
     "versions": {
-        "1.8.1": {
+        "1.9.1": {
             "registries": {
                 "us-east-1": "292282985366",
                 "us-west-1": "292282985366",
diff --git a/src/braket/jobs/image_uris.py b/src/braket/jobs/image_uris.py
index 0a9f27bbd..4996a7dd9 100644
--- a/src/braket/jobs/image_uris.py
+++ b/src/braket/jobs/image_uris.py
@@ -45,7 +45,13 @@ def retrieve_image(framework: Framework, region: str):
     framework_version = max(version for version in config["versions"])
     version_config = config["versions"][framework_version]
     registry = _registry_for_region(version_config, region)
-    tag = f"{version_config['repository']}:{framework_version}-cpu-py37-ubuntu18.04"
+    tag = ""
+    if framework == Framework.PL_TENSORFLOW:
+        tag = f"{version_config['repository']}:{framework_version}-gpu-py37-cu110-ubuntu18.04"
+    elif framework == Framework.PL_PYTORCH:
+        tag = f"{version_config['repository']}:{framework_version}-gpu-py38-cu111-ubuntu20.04"
+    else:
+        tag = f"{version_config['repository']}:{framework_version}-cpu-py37-ubuntu18.04"
     return f"{registry}.dkr.ecr.{region}.amazonaws.com/{tag}"
diff --git a/src/braket/jobs/local/local_job.py b/src/braket/jobs/local/local_job.py
index 8b7394261..b240658c6 100644
--- a/src/braket/jobs/local/local_job.py
+++ b/src/braket/jobs/local/local_job.py
@@ -53,8 +53,11 @@ def create(
         docker container.

         Args:
-            device (str): ARN for the AWS device which is primarily
-                accessed for the execution of this job.
+            device (str): ARN for the AWS device which is primarily accessed for the execution
+                of this job. Alternatively, a string of the format "local:<provider>/<simulator>"
+                for using a local simulator for the job. This string will be available as the
+                environment variable `AMZN_BRAKET_DEVICE_ARN` inside the job container when
+                using a Braket container.

             source_module (str): Path (absolute, relative or an S3 URI) to a python module to be
                 tarred and uploaded. If `source_module` is an S3 URI, it must point to a
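A short usage sketch of the local-device path added above, assuming the class defined in `local_job.py` is exported as `LocalQuantumJob`; the provider/simulator string and module name are placeholders:

```python
from braket.jobs.local import LocalQuantumJob

# The "local:" prefix makes _initialize_session skip the device-region lookup, and the
# string is handed to the job container as AMZN_BRAKET_DEVICE_ARN.
job = LocalQuantumJob.create(
    device="local:provider/simulator_name",  # placeholder local-simulator string
    source_module="algorithm_script.py",     # placeholder module
)
```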
diff --git a/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py
index 251d17ded..c40d70570 100644
--- a/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py
+++ b/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py
@@ -167,7 +167,7 @@ def get_metrics_for_job(
         query = (
             f"fields @timestamp, @message "
             f"| filter @logStream like /^{job_name}\\// "
-            f"| filter @message like /^Metrics - /"
+            f"| filter @message like /Metrics - /"
         )

         response = self._logs_client.start_query(
diff --git a/src/braket/jobs/metrics_data/log_metrics_parser.py b/src/braket/jobs/metrics_data/log_metrics_parser.py
index 76ef319b4..d0faacb57 100644
--- a/src/braket/jobs/metrics_data/log_metrics_parser.py
+++ b/src/braket/jobs/metrics_data/log_metrics_parser.py
@@ -27,6 +27,8 @@ class LogMetricsParser(object):
     METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*([^;]+)\s*;")
     TIMESTAMP = "timestamp"
     ITERATION_NUMBER = "iteration_number"
+    NODE_ID = "node_id"
+    NODE_TAG = re.compile(r"^\[([^\]]*)\]")

     def __init__(
         self,
@@ -102,32 +104,37 @@ def parse_log_message(self, timestamp: str, message: str) -> None:
             return
         if timestamp and self.TIMESTAMP not in parsed_metrics:
             parsed_metrics[self.TIMESTAMP] = timestamp
+        node_match = self.NODE_TAG.match(message)
+        if node_match:
+            parsed_metrics[self.NODE_ID] = node_match.group(1)
         self.all_metrics.append(parsed_metrics)

     def get_columns_and_pivot_indices(
         self, pivot: str
-    ) -> Tuple[Dict[str, List[Union[str, float, int]]], Dict[int, int]]:
+    ) -> Tuple[Dict[str, List[Union[str, float, int]]], Dict[Tuple[int, str], int]]:
         """
-        Parses the metrics to find all the metrics that have the pivot column. The values of
-        the pivot column are assigned a row index, so that all metrics with the same pivot value
-        are stored in the same row.
+        Parses the metrics to find all the metrics that have the pivot column. The values of the
+        pivot column are paired with the node_id and assigned a row index, so that all metrics
+        with the same pivot value and node_id are stored in the same row.

         Args:
             pivot (str): The name of the pivot column. Must be TIMESTAMP or ITERATION_NUMBER.

         Returns:
-            Tuple[Dict[str, List[Any]], Dict[int, int]]:
-                The Dict[str, List[Any]] the result table with all the metrics values initialized
+            Tuple[Dict[str, List[Any]], Dict[Tuple[int, str], int]]:
+                The Dict[str, List[Any]] is the result table with all the metrics values initialized
                 to None
-                The Dict[int, int] is the list of pivot indices, where the value of a pivot column
-                is mapped to a row index.
+                The Dict[Tuple[int, str], int] is the list of pivot indices, where the value of a
+                pivot column and node_id is mapped to a row index.
         """
         row_count = 0
         pivot_indices: dict[int, int] = {}
         table: dict[str, list[Optional[Union[str, float, int]]]] = {}
         for metric in self.all_metrics:
             if pivot in metric:
-                if metric[pivot] not in pivot_indices:
-                    pivot_indices[metric[pivot]] = row_count
+                # If no node_id is present, pair pivot value with None for the key.
+                metric_pivot = (metric[pivot], metric.get(self.NODE_ID))
+                if metric_pivot not in pivot_indices:
+                    pivot_indices[metric_pivot] = row_count
                     row_count += 1
                 for column_name in metric:
                     table[column_name] = [None]
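To make the new node handling concrete, here is a standalone sketch of what the two class-level patterns above capture on a multi-node log line (reproduced outside the class purely for illustration):

```python
import re

# Same patterns as the class attributes above, shown standalone.
NODE_TAG = re.compile(r"^\[([^\]]*)\]")
METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*([^;]+)\s*;")

message = "[nodeB]:Metrics - metric0=2.0;"
node_match = NODE_TAG.match(message)
print(node_match.group(1))                   # nodeB
print(METRICS_DEFINITIONS.findall(message))  # [('metric0', '2.0')]
```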
@@ -141,18 +148,21 @@ def get_metric_data_with_pivot(
         """
         Gets the metric data for a given pivot column name. Metrics without the pivot column
         are not included in the results. Metrics that have the same value in the pivot column
-        are returned in the same row. If the a metric has multiple values for the pivot value,
-        the statistic is used to determine which value is returned.
+        from the same node are returned in the same row. Metrics from different nodes are stored
+        in different rows. If the metric has multiple values for the row, the statistic is used
+        to determine which value is returned.
         For example, for the metrics:
            "iteration_number" : 0, "metricA" : 2, "metricB" : 1,
            "iteration_number" : 0, "metricA" : 1,
            "no_pivot_column" : 0, "metricA" : 0,
            "iteration_number" : 1, "metricA" : 2,
+           "iteration_number" : 1, "node_id" : "nodeB", "metricB" : 0,

         The result with iteration_number as the pivot, statistic of MIN the result will be:
-           iteration_number metricA metricB
-           0                1       1
-           1                2       None
+           iteration_number node_id metricA metricB
+           0                None    1       1
+           1                None    2       None
+           1                nodeB   None    0

         Args:
             pivot (str): The name of the pivot column. Must be TIMESTAMP or ITERATION_NUMBER.
@@ -164,7 +174,8 @@ def get_metric_data_with_pivot(
         table, pivot_indices = self.get_columns_and_pivot_indices(pivot)
         for metric in self.all_metrics:
             if pivot in metric:
-                row = pivot_indices[metric[pivot]]
+                metric_pivot = (metric[pivot], metric.get(self.NODE_ID))
+                row = pivot_indices[metric_pivot]
                 for column_name in metric:
                     table[column_name][row] = self._get_value(
                         table[column_name][row], metric[column_name], statistic
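A rough, self-contained sketch of the behaviour documented above, assuming LogMetricsParser can be constructed without arguments and that MetricStatistic is importable from braket.jobs.metrics_data.definitions; the timestamps are placeholders:

```python
from braket.jobs.metrics_data.definitions import MetricStatistic
from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser

parser = LogMetricsParser()
parser.parse_log_message("ts-0", "[nodeA]:Metrics - metric0=1.0;")
parser.parse_log_message("ts-0", "[nodeB]:Metrics - metric0=2.0;")
parser.parse_log_message("ts-1", "Metrics - metric0=0.0;")

table = parser.get_metric_data_with_pivot(LogMetricsParser.TIMESTAMP, MetricStatistic.MAX)
# Expected shape: one row per (timestamp, node_id) pair, e.g.
# {"timestamp": ["ts-0", "ts-0", "ts-1"],
#  "node_id":   ["nodeA", "nodeB", None],
#  "metric0":   [1.0, 2.0, 0.0]}
```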
diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py
index 093bfce78..d5292c77a 100644
--- a/src/braket/jobs/quantum_job_creation.py
+++ b/src/braket/jobs/quantum_job_creation.py
@@ -44,6 +44,7 @@ def prepare_quantum_job(
     hyperparameters: Dict[str, Any] = None,
     input_data: Union[str, Dict, S3DataSourceConfig] = None,
     instance_config: InstanceConfig = None,
+    distribution: str = None,
     stopping_condition: StoppingCondition = None,
    output_data_config: OutputDataConfig = None,
     copy_checkpoints_from_job: str = None,
@@ -98,6 +99,10 @@ def prepare_quantum_job(
             to execute the job. Default: InstanceConfig(instanceType='ml.m5.large',
             instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None).

+        distribution (str): A str that specifies how the job should be distributed. If set to
+            "data_parallel", the hyperparameters for the job will be set to use data parallelism
+            features for PyTorch or TensorFlow. Default: None.
+
         stopping_condition (StoppingCondition): The maximum length of time, in seconds,
             and the maximum number of tasks that a job can run before being forcefully stopped.
             Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
@@ -192,6 +197,12 @@ def prepare_quantum_job(
             "s3Uri"
         ]
         aws_session.copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri)
+    if distribution == "data_parallel":
+        distributed_hyperparams = {
+            "sagemaker_distributed_dataparallel_enabled": "true",
+            "sagemaker_instance_type": instance_config.instanceType,
+        }
+        hyperparameters.update(distributed_hyperparams)

     create_job_kwargs = {
         "jobName": job_name,
diff --git a/test/unit_tests/braket/aws/test_aws_quantum_job.py b/test/unit_tests/braket/aws/test_aws_quantum_job.py
index b71114109..bbb1ad79c 100644
--- a/test/unit_tests/braket/aws/test_aws_quantum_job.py
+++ b/test/unit_tests/braket/aws/test_aws_quantum_job.py
@@ -15,6 +15,7 @@
 import json
 import logging
 import os
+import re
 import tarfile
 import tempfile
 from unittest.mock import Mock, patch
@@ -184,10 +185,11 @@ def test_quantum_job_constructor_default_session(
     assert job._aws_session == aws_session_mock.return_value


-@pytest.mark.xfail(raises=ValueError)
 def test_quantum_job_constructor_invalid_region(aws_session):
+    region_mismatch = "The aws session region does not match the region for the supplied arn."
     arn = "arn:aws:braket:unknown-region:875981177017:job/quantum_job_name"
-    AwsQuantumJob(arn, aws_session)
+    with pytest.raises(ValueError, match=region_mismatch):
+        AwsQuantumJob(arn, aws_session)


 @patch("braket.aws.aws_quantum_job.boto3.Session")
@@ -504,6 +506,7 @@ def prepare_job_args(aws_session, device_arn):
         "hyperparameters": Mock(),
         "input_data": Mock(),
         "instance_config": Mock(),
+        "distribution": Mock(),
         "stopping_condition": Mock(),
         "output_data_config": Mock(),
         "copy_checkpoints_from_job": Mock(),
@@ -528,9 +531,9 @@ def test_name(quantum_job_arn, quantum_job_name, aws_session):
     assert quantum_job.name == quantum_job_name


-@pytest.mark.xfail(raises=AttributeError)
 def test_no_arn_setter(quantum_job):
-    quantum_job.arn = 123
+    with pytest.raises(AttributeError, match="can't set attribute"):
+        quantum_job.arn = 123


 @pytest.mark.parametrize("wait_until_complete", [True, False])
@@ -577,7 +580,6 @@ def test_cancel_job(quantum_job_arn, aws_session, generate_cancel_job_response):
     assert status == cancellation_status


-@pytest.mark.xfail(raises=ClientError)
 def test_cancel_job_surfaces_exception(quantum_job, aws_session):
     exception_response = {
         "Error": {
             "Code": "ValidationException",
             "Message": "unit-test-error",
         }
     }
+    error_string = re.escape(
+        "An error occurred (ValidationException) when calling the "
+        "cancel_job operation: unit-test-error"
+    )
     aws_session.cancel_job.side_effect = ClientError(exception_response, "cancel_job")
-    quantum_job.cancel()
+    with pytest.raises(ClientError, match=error_string):
+        quantum_job.cancel()


 @pytest.mark.parametrize(
@@ -987,6 +994,16 @@
             AwsQuantumJob._initialize_session(aws_session, device_arn, logger)


+@patch("braket.aws.aws_quantum_job.AwsSession")
+def test_initialize_session_local_device(mock_new_session, aws_session):
+    logger = logging.getLogger(__name__)
+    device = "local:provider.device.name"
+    # don't change a provided AwsSession
+    assert AwsQuantumJob._initialize_session(aws_session, device, logger) == aws_session
+    # otherwise, create an AwsSession with the profile defaults
+    assert AwsQuantumJob._initialize_session(None, device, logger) == mock_new_session()
+
+
 def test_bad_arn_format(aws_session):
     logger = logging.getLogger(__name__)
     device_not_found = (
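Stepping back to the prepare_quantum_job change above, a purely illustrative sketch of its effect on the hyperparameters (the user hyperparameters and instance type below are hypothetical):

```python
# What distribution="data_parallel" adds before CreateJob is called.
hyperparameters = {"learning_rate": "0.01"}  # hypothetical user hyperparameters
instance_type = "ml.p3.16xlarge"             # hypothetical InstanceConfig.instanceType

hyperparameters.update(
    {
        "sagemaker_distributed_dataparallel_enabled": "true",
        "sagemaker_instance_type": instance_type,
    }
)
# hyperparameters is now:
# {"learning_rate": "0.01",
#  "sagemaker_distributed_dataparallel_enabled": "true",
#  "sagemaker_instance_type": "ml.p3.16xlarge"}
```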
diff --git a/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py
index f0d42f523..17027cec4 100644
--- a/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py
+++ b/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py
@@ -73,7 +73,7 @@ def test_get_all_metrics_complete_results(mock_add_metrics, mock_get_metrics, aw
         startTime=1,
         endTime=2,
         queryString="fields @timestamp, @message | filter @logStream like /^test_job\\//"
-        " | filter @message like /^Metrics - /",
+        " | filter @message like /Metrics - /",
         limit=10000,
     )
     assert mock_add_metrics.call_args_list == EXPECTED_CALL_LIST
diff --git a/test/unit_tests/braket/jobs/metrics_data/test_log_metrics_parser.py b/test/unit_tests/braket/jobs/metrics_data/test_log_metrics_parser.py
index f3edadb02..4d9ca9dce 100644
--- a/test/unit_tests/braket/jobs/metrics_data/test_log_metrics_parser.py
+++ b/test/unit_tests/braket/jobs/metrics_data/test_log_metrics_parser.py
@@ -63,6 +63,40 @@
     "metric3": [None, None, 0.2, None],
 }

+MULTINODE_METRICS_LOG_LINES = [
+    {
+        "timestamp": "Test timestamp 0",
+        "message": "[nodeA]:Metrics - metric0=1.0;",
+    },
+    # This line logs the same metric from a different node.
+    {
+        "timestamp": "Test timestamp 0",
+        "message": "[nodeB]:Metrics - metric0=2.0;",
+    },
+    # This line also logs a metric unique to one node.
+    {
+        "timestamp": "Test timestamp 1",
+        "message": "[nodeA]:Metrics - metricA=3.0;",
+    },
+    # This line logs a metric without a node tag.
+    {
+        "timestamp": "Test timestamp 1",
+        "message": "Metrics - metric0=0.0;",
+    },
+]
+
+MULTINODES_METRICS_RESULT = {
+    "timestamp": ["Test timestamp 0", "Test timestamp 0", "Test timestamp 1", "Test timestamp 1"],
+    "node_id": [
+        "nodeA",
+        "nodeB",
+        "nodeA",
+        None,
+    ],
+    "metric0": [1.0, 2.0, None, 0.0],
+    "metricA": [None, None, 3.0, None],
+}
+
 # This will test how metrics are combined when the multiple metrics have the same timestamp
 SINGLE_TIMESTAMP_METRICS_LOG_LINES = [
     {"timestamp": "Test timestamp 0", "message": "Metrics - metric0=0.0;"},
@@ -135,6 +169,12 @@
         MetricStatistic.MAX,
         SIMPLE_METRICS_RESULT,
     ),
+    (
+        MULTINODE_METRICS_LOG_LINES,
+        MetricType.TIMESTAMP,
+        MetricStatistic.MAX,
+        MULTINODES_METRICS_RESULT,
+    ),
     (
         SINGLE_TIMESTAMP_METRICS_LOG_LINES,
         MetricType.TIMESTAMP,
diff --git a/test/unit_tests/braket/jobs/test_image_uris.py b/test/unit_tests/braket/jobs/test_image_uris.py
index ab608c5a8..9f44c7ef9 100644
--- a/test/unit_tests/braket/jobs/test_image_uris.py
+++ b/test/unit_tests/braket/jobs/test_image_uris.py
@@ -29,13 +29,13 @@
             "us-east-1",
             Framework.PL_TENSORFLOW,
             "292282985366.dkr.ecr.us-east-1.amazonaws.com/amazon-braket-tensorflow-jobs:"
-            "2.4.1-cpu-py37-ubuntu18.04",
+            "2.4.1-gpu-py37-cu110-ubuntu18.04",
         ),
         (
             "us-west-2",
             Framework.PL_PYTORCH,
             "292282985366.dkr.ecr.us-west-2.amazonaws.com/"
-            "amazon-braket-pytorch-jobs:1.8.1-cpu-py37-ubuntu18.04",
+            "amazon-braket-pytorch-jobs:1.9.1-gpu-py38-cu111-ubuntu20.04",
         ),
     ],
 )
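As a quick sanity check of the new tag logic, assuming Framework and retrieve_image are importable from braket.jobs.image_uris as the test above does:

```python
from braket.jobs.image_uris import Framework, retrieve_image

# With the version bump to 1.9.1 and the GPU tag branch for PL_PYTORCH, this should
# resolve to the URI asserted in the test above.
uri = retrieve_image(Framework.PL_PYTORCH, "us-west-2")
print(uri)
# 292282985366.dkr.ecr.us-west-2.amazonaws.com/amazon-braket-pytorch-jobs:1.9.1-gpu-py38-cu111-ubuntu20.04
```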
data_parallel: + return "data_parallel" + return None + + @pytest.fixture def stopping_condition(): return StoppingCondition( @@ -227,6 +239,7 @@ def create_job_args( hyperparameters, input_data, instance_config, + distribution, stopping_condition, output_data_config, checkpoint_config, @@ -246,6 +259,7 @@ def create_job_args( "hyperparameters": hyperparameters, "input_data": input_data, "instance_config": instance_config, + "distribution": distribution, "stopping_condition": stopping_condition, "output_data_config": output_data_config, "checkpoint_config": checkpoint_config, @@ -308,6 +322,12 @@ def _translate_creation_args(create_job_args): hyperparameters = {str(key): str(value) for key, value in hyperparameters.items()} input_data = create_job_args["input_data"] or {} instance_config = create_job_args["instance_config"] or InstanceConfig() + if create_job_args["distribution"] == "data_parallel": + distributed_hyperparams = { + "sagemaker_distributed_dataparallel_enabled": "true", + "sagemaker_instance_type": instance_config.instanceType, + } + hyperparameters.update(distributed_hyperparams) output_data_config = create_job_args["output_data_config"] or OutputDataConfig( s3Path=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, "data") )