Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support local simulators for jobs #309

Merged
merged 14 commits into from
May 4, 2022
9 changes: 7 additions & 2 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,11 @@ def create(
"""Creates a job by invoking the Braket CreateJob API.

Args:
device (str): ARN for the AWS device which is primarily
accessed for the execution of this job.
device (str): ARN for the AWS device which is primarily accessed for the execution
ajberdy marked this conversation as resolved.
Show resolved Hide resolved
of this job. Alternatively, a string of the format
"local:<provider>.<simulator>.<name>" for using a local simulator for the job.
ajberdy marked this conversation as resolved.
Show resolved Hide resolved
This string will be available as the environment variable `AMZN_BRAKET_DEVICE_ARN`
inside the job container when using a Braket container.

source_module (str): Path (absolute, relative or an S3 URI) to a python module to be
tarred and uploaded. If `source_module` is an S3 URI, it must point to a
Expand Down Expand Up @@ -531,6 +534,8 @@ def __hash__(self) -> int:
@staticmethod
def _initialize_session(session_value, device, logger):
aws_session = session_value or AwsSession()
if device.startswith("local:"):
return aws_session
device_region = device.split(":")[3]
return (
AwsQuantumJob._initialize_regional_device_session(aws_session, device, logger)
Expand Down
28 changes: 22 additions & 6 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import os
import re
import tarfile
import tempfile
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -184,10 +185,11 @@ def test_quantum_job_constructor_default_session(
assert job._aws_session == aws_session_mock.return_value


@pytest.mark.xfail(raises=ValueError)
def test_quantum_job_constructor_invalid_region(aws_session):
region_mismatch = "The aws session region does not match the region for the supplied arn."
arn = "arn:aws:braket:unknown-region:875981177017:job/quantum_job_name"
AwsQuantumJob(arn, aws_session)
with pytest.raises(ValueError, match=region_mismatch):
AwsQuantumJob(arn, aws_session)


@patch("braket.aws.aws_quantum_job.boto3.Session")
Expand Down Expand Up @@ -528,9 +530,9 @@ def test_name(quantum_job_arn, quantum_job_name, aws_session):
assert quantum_job.name == quantum_job_name


@pytest.mark.xfail(raises=AttributeError)
def test_no_arn_setter(quantum_job):
quantum_job.arn = 123
with pytest.raises(AttributeError, match="can't set attribute"):
quantum_job.arn = 123


@pytest.mark.parametrize("wait_until_complete", [True, False])
Expand Down Expand Up @@ -577,16 +579,20 @@ def test_cancel_job(quantum_job_arn, aws_session, generate_cancel_job_response):
assert status == cancellation_status


@pytest.mark.xfail(raises=ClientError)
def test_cancel_job_surfaces_exception(quantum_job, aws_session):
exception_response = {
"Error": {
"Code": "ValidationException",
"Message": "unit-test-error",
}
}
error_string = re.escape(
"An error occurred (ValidationException) when calling the "
"cancel_job operation: unit-test-error"
)
aws_session.cancel_job.side_effect = ClientError(exception_response, "cancel_job")
quantum_job.cancel()
with pytest.raises(ClientError, match=error_string):
quantum_job.cancel()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -985,3 +991,13 @@ def test_exceptions_in_all_device_regions(aws_session):
)
with pytest.raises(ClientError, match=error_message):
AwsQuantumJob._initialize_session(aws_session, device_arn, logger)


@patch("braket.aws.aws_quantum_job.AwsSession")
def test_initialize_session_local_device(mock_new_session, aws_session):
logger = logging.getLogger(__name__)
device = "local:provider.device.name"
# don't change a provided AwsSession
assert AwsQuantumJob._initialize_session(aws_session, device, logger) == aws_session
# otherwise, create an AwsSession with the profile defaults
assert AwsQuantumJob._initialize_session(None, device, logger) == mock_new_session()