From 1000b71c81a7f0de4308e9014dae6a9565735f4b Mon Sep 17 00:00:00 2001
From: Stephen Face
Date: Tue, 19 Apr 2022 16:29:21 -0700
Subject: [PATCH 1/3] feat: add convenient selection of data parallelism

---
 src/braket/aws/aws_quantum_job.py                  |  6 ++++++
 src/braket/jobs/quantum_job_creation.py            | 11 +++++++++++
 test/unit_tests/braket/aws/test_aws_quantum_job.py |  1 +
 3 files changed, 18 insertions(+)

diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py
index 85aaf770a..d57f14db6 100644
--- a/src/braket/aws/aws_quantum_job.py
+++ b/src/braket/aws/aws_quantum_job.py
@@ -73,6 +73,7 @@ def create(
         hyperparameters: Dict[str, Any] = None,
         input_data: Union[str, Dict, S3DataSourceConfig] = None,
         instance_config: InstanceConfig = None,
+        distribution: str = None,
         stopping_condition: StoppingCondition = None,
         output_data_config: OutputDataConfig = None,
         copy_checkpoints_from_job: str = None,
@@ -134,6 +135,10 @@ def create(
                 to execute the job. Default: InstanceConfig(instanceType='ml.m5.large',
                 instanceCount=1, volumeSizeInGB=30).

+            distribution (str): A str that specifies how the job should be distributed. If set to
+                "dataparallel", the hyperparameters for the job will be set to use data parallelism
+                features for PyTorch or TensorFlow. Default: None.
+
             stopping_condition (StoppingCondition): The maximum length of time, in seconds,
                 and the maximum number of tasks that a job can run before being forcefully stopped.
                 Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
@@ -181,6 +186,7 @@ def create(
             hyperparameters=hyperparameters,
             input_data=input_data,
             instance_config=instance_config,
+            distribution=distribution,
             stopping_condition=stopping_condition,
             output_data_config=output_data_config,
             copy_checkpoints_from_job=copy_checkpoints_from_job,
diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py
index db159eccf..01f7275cb 100644
--- a/src/braket/jobs/quantum_job_creation.py
+++ b/src/braket/jobs/quantum_job_creation.py
@@ -44,6 +44,7 @@ def prepare_quantum_job(
     hyperparameters: Dict[str, Any] = None,
     input_data: Union[str, Dict, S3DataSourceConfig] = None,
     instance_config: InstanceConfig = None,
+    distribution: str = None,
     stopping_condition: StoppingCondition = None,
     output_data_config: OutputDataConfig = None,
     copy_checkpoints_from_job: str = None,
@@ -98,6 +99,10 @@ def prepare_quantum_job(
             to execute the job. Default: InstanceConfig(instanceType='ml.m5.large',
             instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None).

+        distribution (str): A str that specifies how the job should be distributed. If set to
+            "dataparallel", the hyperparameters for the job will be set to use data parallelism
+            features for PyTorch or TensorFlow. Default: None.
+
         stopping_condition (StoppingCondition): The maximum length of time, in seconds,
             and the maximum number of tasks that a job can run before being forcefully stopped.
             Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
@@ -191,6 +196,12 @@ def prepare_quantum_job(
             "s3Uri"
         ]
         aws_session.copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri)
+    if distribution == "dataparallel":
+        distribruted_hyperparams = {
+            "sagemaker_distributed_dataparallel_enabled": "true",
+            "sagemaker_instance_type": instance_config.instanceType,
+        }
+        hyperparameters.update(distribruted_hyperparams)

     create_job_kwargs = {
         "jobName": job_name,
diff --git a/test/unit_tests/braket/aws/test_aws_quantum_job.py b/test/unit_tests/braket/aws/test_aws_quantum_job.py
index 265a0cdbd..3954ccb6d 100644
--- a/test/unit_tests/braket/aws/test_aws_quantum_job.py
+++ b/test/unit_tests/braket/aws/test_aws_quantum_job.py
@@ -506,6 +506,7 @@ def prepare_job_args(aws_session, device_arn):
         "hyperparameters": Mock(),
         "input_data": Mock(),
         "instance_config": Mock(),
+        "distribution": Mock(),
         "stopping_condition": Mock(),
         "output_data_config": Mock(),
         "copy_checkpoints_from_job": Mock(),

From fca5c5bc55a0154bd3f0a19cef71275596bd5c8b Mon Sep 17 00:00:00 2001
From: Stephen Face <60493521+shpface@users.noreply.github.com>
Date: Mon, 25 Apr 2022 10:20:06 -0700
Subject: [PATCH 2/3] fix spelling

Co-authored-by: Aaron Berdy
---
 src/braket/jobs/quantum_job_creation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py
index 01f7275cb..5f848fd82 100644
--- a/src/braket/jobs/quantum_job_creation.py
+++ b/src/braket/jobs/quantum_job_creation.py
@@ -197,11 +197,11 @@ def prepare_quantum_job(
         ]
         aws_session.copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri)
     if distribution == "dataparallel":
-        distribruted_hyperparams = {
+        distributed_hyperparams = {
             "sagemaker_distributed_dataparallel_enabled": "true",
             "sagemaker_instance_type": instance_config.instanceType,
         }
-        hyperparameters.update(distribruted_hyperparams)
+        hyperparameters.update(distributed_hyperparams)

     create_job_kwargs = {
         "jobName": job_name,

From 50d371ceff789092e71582430fe6eace019d6df3 Mon Sep 17 00:00:00 2001
From: Stephen Face
Date: Tue, 26 Apr 2022 10:15:20 -0700
Subject: [PATCH 3/3] Rename distribution to 'data_parallel'

---
 src/braket/aws/aws_quantum_job.py       | 2 +-
 src/braket/jobs/quantum_job_creation.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py
index d57f14db6..8bb0c7201 100644
--- a/src/braket/aws/aws_quantum_job.py
+++ b/src/braket/aws/aws_quantum_job.py
@@ -136,7 +136,7 @@ def create(
                 instanceCount=1, volumeSizeInGB=30).

             distribution (str): A str that specifies how the job should be distributed. If set to
-                "dataparallel", the hyperparameters for the job will be set to use data parallelism
+                "data_parallel", the hyperparameters for the job will be set to use data parallelism
                 features for PyTorch or TensorFlow. Default: None.

             stopping_condition (StoppingCondition): The maximum length of time, in seconds,
diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py
index 5f848fd82..397d5b575 100644
--- a/src/braket/jobs/quantum_job_creation.py
+++ b/src/braket/jobs/quantum_job_creation.py
@@ -100,7 +100,7 @@ def prepare_quantum_job(
             instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None).

         distribution (str): A str that specifies how the job should be distributed. If set to
-            "dataparallel", the hyperparameters for the job will be set to use data parallelism
+            "data_parallel", the hyperparameters for the job will be set to use data parallelism
             features for PyTorch or TensorFlow. Default: None.

         stopping_condition (StoppingCondition): The maximum length of time, in seconds,
@@ -196,7 +196,7 @@ def prepare_quantum_job(
             "s3Uri"
         ]
         aws_session.copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri)
-    if distribution == "dataparallel":
+    if distribution == "data_parallel":
         distributed_hyperparams = {
             "sagemaker_distributed_dataparallel_enabled": "true",
             "sagemaker_instance_type": instance_config.instanceType,
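
With these patches applied, a caller requests data parallelism by passing distribution="data_parallel" to AwsQuantumJob.create(); prepare_quantum_job() then adds "sagemaker_distributed_dataparallel_enabled": "true" and the configured instance type to the job's hyperparameters before the job is created. The sketch below shows how the new argument might be used; the device ARN, source module, entry point, hyperparameters, and instance type are illustrative placeholders rather than values taken from these patches, and the imports assume the SDK's public braket.aws and braket.jobs.config modules.

from braket.aws import AwsQuantumJob
from braket.jobs.config import InstanceConfig

# Placeholder device ARN, script module, and instance settings -- substitute real values.
job = AwsQuantumJob.create(
    device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
    source_module="algorithm_script",
    entry_point="algorithm_script:start_here",
    hyperparameters={"epochs": "5"},
    # Two instances; distribution="data_parallel" injects the SageMaker
    # data-parallel hyperparameters shown in the patches above.
    instance_config=InstanceConfig(instanceType="ml.p3.16xlarge", instanceCount=2),
    distribution="data_parallel",
)

Note that the patches only set the hyperparameters that enable the feature; the training script running inside the job still has to opt in to the framework's (PyTorch or TensorFlow) distributed data-parallel support to actually shard its dataset across instances.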