Skip to content

Commit

Permalink
feat: Update region switching for regional device arns (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijc committed Feb 15, 2022
1 parent 8d27257 commit 8bd1bc3
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 23 deletions.
23 changes: 23 additions & 0 deletions src/braket/aws/aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,29 @@ def refresh_metadata(self) -> None:
self._populate_properties(self._aws_session)

def _get_session_and_initialize(self, session):
device_region = self._arn.split(":")[3]
return (
self._get_regional_device_session(session)
if device_region
else self._get_non_regional_device_session(session)
)

def _get_regional_device_session(self, session):
device_region = self._arn.split(":")[3]
region_session = (
session
if session.region == device_region
else AwsSession.copy_session(session, device_region)
)
try:
self._populate_properties(region_session)
return region_session
except ClientError as e:
raise ValueError(f"'{self._arn}' not found") if e.response["Error"][
"Code"
] == "ResourceNotFoundException" else e

def _get_non_regional_device_session(self, session):
current_region = session.region
try:
self._populate_properties(session)
Expand Down
29 changes: 24 additions & 5 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,22 +531,41 @@ def __hash__(self) -> int:
@staticmethod
def _initialize_session(session_value, device, logger):
aws_session = session_value or AwsSession()
device_region = device.split(":")[3]
return (
AwsQuantumJob._initialize_regional_device_session(aws_session, device, logger)
if device_region
else AwsQuantumJob._initialize_non_regional_device_session(aws_session, device, logger)
)

@staticmethod
def _initialize_regional_device_session(aws_session, device, logger):
device_region = device.split(":")[3]
current_region = aws_session.region
if current_region != device_region:
aws_session = aws_session.copy_session(region=device_region)
logger.info(f"Changed session region from '{current_region}' to '{device_region}'")
try:
aws_session.get_device(device)
return aws_session
except ClientError as e:
raise ValueError(f"'{device}' not found.") if e.response["Error"][
"Code"
] == "ResourceNotFoundException" else e

@staticmethod
def _initialize_non_regional_device_session(aws_session, device, logger):
original_region = aws_session.region
try:
aws_session.get_device(device)
return aws_session
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
if "qpu" not in device:
raise ValueError(f"Simulator '{device}' not found in '{current_region}'")
raise ValueError(f"Simulator '{device}' not found in '{original_region}'")
else:
raise e

return AwsQuantumJob._find_device_session(aws_session, device, current_region, logger)

@staticmethod
def _find_device_session(aws_session, device, original_region, logger):
for region in frozenset(AwsDevice.REGIONS) - {original_region}:
device_session = aws_session.copy_session(region=region)
try:
Expand Down
88 changes: 78 additions & 10 deletions test/unit_tests/braket/aws/test_aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,14 @@ def test_gate_model_sim_schema():
)


@pytest.fixture
def arn():
return "test_arn"
@pytest.fixture(
params=[
"arn:aws:braket:us-west-1::device/quantum-simulator/amazon/sim",
"arn:aws:braket:::device/quantum-simulator/amazon/sim",
]
)
def arn(request):
return request.param


@pytest.fixture
Expand Down Expand Up @@ -303,6 +308,7 @@ def _device(arn):
def test_device_aws_session(device_capabilities, get_device_data, arn):
mock_session = Mock()
mock_session.get_device.return_value = get_device_data
mock_session.region = RIGETTI_REGION
device = AwsDevice(arn, mock_session)
_assert_device_fields(device, device_capabilities, get_device_data)

Expand Down Expand Up @@ -356,9 +362,69 @@ def test_device_qpu_no_aws_session(
_assert_device_fields(device, MOCK_GATE_MODEL_QPU_CAPABILITIES_1, MOCK_GATE_MODEL_QPU_1)


@patch("braket.aws.aws_device.AwsSession.copy_session")
@patch("braket.aws.aws_device.AwsSession")
def test_regional_device_region_switch(aws_session_init, mock_copy_session, aws_session):
device_region = "device-region"
arn = f"arn:aws:braket:{device_region}::device/quantum-simulator/amazon/sim"
aws_session_init.return_value = aws_session
mock_session = Mock()
mock_session.get_device.return_value = MOCK_GATE_MODEL_SIMULATOR
mock_copy_session.return_value = mock_session
device = AwsDevice(arn)
aws_session.get_device.assert_not_called()
mock_copy_session.assert_called_once()
mock_copy_session.assert_called_with(aws_session, device_region)
_assert_device_fields(device, MOCK_GATE_MODEL_SIMULATOR_CAPABILITIES, MOCK_GATE_MODEL_SIMULATOR)


@patch("braket.aws.aws_device.AwsSession")
@pytest.mark.parametrize(
"get_device_side_effect, expected_exception",
[
(
[
ClientError(
{
"Error": {
"Code": "ResourceNotFoundException",
}
},
"getDevice",
)
],
ValueError,
),
(
[
ClientError(
{
"Error": {
"Code": "ThrottlingException",
}
},
"getDevice",
)
],
ClientError,
),
],
)
def test_regional_device_raises_error(
aws_session_init, get_device_side_effect, expected_exception, aws_session
):
arn = "arn:aws:braket:us-west-1::device/quantum-simulator/amazon/sim"
aws_session.get_device.side_effect = get_device_side_effect
aws_session_init.return_value = aws_session
with pytest.raises(expected_exception):
AwsDevice(arn)
aws_session.get_device.assert_called_once()


def test_device_refresh_metadata(arn):
mock_session = Mock()
mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1
mock_session.region = RIGETTI_REGION
device = AwsDevice(arn, mock_session)
_assert_device_fields(device, MOCK_GATE_MODEL_QPU_CAPABILITIES_1, MOCK_GATE_MODEL_QPU_1)

Expand All @@ -370,9 +436,10 @@ def test_device_refresh_metadata(arn):
def test_equality(arn):
mock_session = Mock()
mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1
mock_session.region = RIGETTI_REGION
device_1 = AwsDevice(arn, mock_session)
device_2 = AwsDevice(arn, mock_session)
other_device = AwsDevice("foo_bar", mock_session)
other_device = AwsDevice("arn:aws:braket:::device/quantum-simulator/amazon/bar", mock_session)
non_device = "HI"

assert device_1 == device_2
Expand All @@ -384,6 +451,7 @@ def test_equality(arn):
def test_repr(arn):
mock_session = Mock()
mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1
mock_session.region = RIGETTI_REGION
device = AwsDevice(arn, mock_session)
expected = "Device('name': {}, 'arn': {})".format(device.name, device.arn)
assert repr(device) == expected
Expand Down Expand Up @@ -616,8 +684,8 @@ def test_run_with_positional_args_and_kwargs(
{"AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://env_bucket/env/path"},
)
@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create")
def test_run_env_variables(aws_quantum_task_mock, device, circuit):
device("foo:bar").run(circuit)
def test_run_env_variables(aws_quantum_task_mock, device, circuit, arn):
device(arn).run(circuit)
assert aws_quantum_task_mock.call_args_list[0][0][3] == ("env_bucket", "env/path")


Expand Down Expand Up @@ -670,8 +738,8 @@ def test_run_batch_with_max_parallel_and_kwargs(
{"AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://env_bucket/env/path"},
)
@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create")
def test_run_batch_env_variables(aws_quantum_task_mock, device, circuit):
device("foo:bar").run_batch([circuit])
def test_run_batch_env_variables(aws_quantum_task_mock, device, circuit, arn):
device(arn).run_batch([circuit])
assert aws_quantum_task_mock.call_args_list[0][0][3] == ("env_bucket", "env/path")


Expand All @@ -688,7 +756,7 @@ def _run_and_assert(
):
run_and_assert(
aws_quantum_task_mock,
device_factory("foo_bar"),
device_factory("arn:aws:braket:::device/quantum-simulator/amazon/sim"),
MOCK_DEFAULT_S3_DESTINATION_FOLDER,
AwsDevice.DEFAULT_SHOTS_SIMULATOR,
AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
Expand Down Expand Up @@ -720,7 +788,7 @@ def _run_batch_and_assert(
run_batch_and_assert(
aws_quantum_task_mock,
aws_session_mock,
device_factory("foo_bar"),
device_factory("arn:aws:braket:::device/quantum-simulator/amazon/sim"),
MOCK_DEFAULT_S3_DESTINATION_FOLDER,
AwsDevice.DEFAULT_SHOTS_SIMULATOR,
AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
Expand Down
93 changes: 85 additions & 8 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,15 +481,20 @@ def role_arn():
return "arn:aws:iam::0000000000:role/AmazonBraketInternalSLR"


@pytest.fixture
def device_arn():
return "arn:aws:braket:::device/qpu/test/device-name"
@pytest.fixture(
params=[
"arn:aws:braket:us-test-1::device/qpu/test/device-name",
"arn:aws:braket:::device/qpu/test/device-name",
]
)
def device_arn(request):
return request.param


@pytest.fixture
def prepare_job_args(aws_session):
def prepare_job_args(aws_session, device_arn):
return {
"device": Mock(),
"device": device_arn,
"source_module": Mock(),
"entry_point": Mock(),
"image_uri": Mock(),
Expand Down Expand Up @@ -796,7 +801,8 @@ def test_logs_error(quantum_job, generate_get_job_response, capsys):
quantum_job.logs(wait=True, poll_interval_seconds=0)


def test_initialize_session_for_valid_device(device_arn, aws_session, caplog):
def test_initialize_session_for_valid_non_regional_device(aws_session, caplog):
device_arn = "arn:aws:braket:::device/qpu/test/device-name"
first_region = aws_session.region
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -826,6 +832,76 @@ def test_initialize_session_for_valid_device(device_arn, aws_session, caplog):
assert f"Changed session region from '{first_region}' to '{aws_session.region}'" in caplog.text


def test_initialize_session_for_valid_regional_device(aws_session, caplog):
device_arn = f"arn:aws:braket:{aws_session.region}::device/qpu/test/device-name"
logger = logging.getLogger(__name__)
aws_session.get_device.return_value = device_arn
caplog.set_level(logging.INFO)
AwsQuantumJob._initialize_session(aws_session, device_arn, logger)
assert not caplog.text


@pytest.mark.parametrize(
"get_device_side_effect, expected_exception",
[
(
[
ClientError(
{
"Error": {
"Code": "ResourceNotFoundException",
}
},
"getDevice",
)
],
ValueError,
),
(
[
ClientError(
{
"Error": {
"Code": "ThrottlingException",
}
},
"getDevice",
)
],
ClientError,
),
],
)
def test_regional_device_raises_error(
get_device_side_effect, expected_exception, aws_session, caplog
):
device_arn = f"arn:aws:braket:{aws_session.region}::device/qpu/test/device-name"
aws_session.get_device.side_effect = get_device_side_effect
logger = logging.getLogger(__name__)
caplog.set_level(logging.INFO)
with pytest.raises(expected_exception):
AwsQuantumJob._initialize_session(aws_session, device_arn, logger)
aws_session.get_device.assert_called_with(device_arn)
assert not caplog.text


def test_regional_device_switches(aws_session, caplog):
original_region = aws_session.region
device_region = "us-east-1"
device_arn = f"arn:aws:braket:{device_region}::device/qpu/test/device-name"
mock_session = Mock()
mock_session.get_device.side_effect = device_arn
aws_session.copy_session.side_effect = [mock_session]
logger = logging.getLogger(__name__)
caplog.set_level(logging.INFO)

assert mock_session == AwsQuantumJob._initialize_session(aws_session, device_arn, logger)

aws_session.copy_session.assert_called_with(region=device_region)
mock_session.get_device.assert_called_with(device_arn)
assert f"Changed session region from '{original_region}' to '{device_region}'" in caplog.text


def test_initialize_session_for_invalid_device(aws_session, device_arn):
logger = logging.getLogger(__name__)
aws_session.get_device.side_effect = ClientError(
Expand All @@ -837,7 +913,7 @@ def test_initialize_session_for_invalid_device(aws_session, device_arn):
"getDevice",
)

device_not_found = "QPU 'arn:aws:braket:::device/qpu/test/device-name' not found."
device_not_found = f"'{device_arn}' not found."
with pytest.raises(ValueError, match=device_not_found):
AwsQuantumJob._initialize_session(aws_session, device_arn, logger)

Expand Down Expand Up @@ -880,7 +956,8 @@ def test_exception_in_credentials_session_region(device_arn, aws_session):
AwsQuantumJob._initialize_session(aws_session, device_arn, logger)


def test_exceptions_in_all_device_regions(device_arn, aws_session):
def test_exceptions_in_all_device_regions(aws_session):
device_arn = "arn:aws:braket:::device/qpu/test/device-name"
logger = logging.getLogger(__name__)

aws_session.get_device.side_effect = [
Expand Down

0 comments on commit 8bd1bc3

Please sign in to comment.