diff --git a/src/braket/aws/aws_device.py b/src/braket/aws/aws_device.py index 44862279c..431b9bd48 100644 --- a/src/braket/aws/aws_device.py +++ b/src/braket/aws/aws_device.py @@ -229,6 +229,29 @@ def refresh_metadata(self) -> None: self._populate_properties(self._aws_session) def _get_session_and_initialize(self, session): + device_region = self._arn.split(":")[3] + return ( + self._get_regional_device_session(session) + if device_region + else self._get_non_regional_device_session(session) + ) + + def _get_regional_device_session(self, session): + device_region = self._arn.split(":")[3] + region_session = ( + session + if session.region == device_region + else AwsSession.copy_session(session, device_region) + ) + try: + self._populate_properties(region_session) + return region_session + except ClientError as e: + raise ValueError(f"'{self._arn}' not found") if e.response["Error"][ + "Code" + ] == "ResourceNotFoundException" else e + + def _get_non_regional_device_session(self, session): current_region = session.region try: self._populate_properties(session) diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py index 48f9a6984..8a3f7ccb5 100644 --- a/src/braket/aws/aws_quantum_job.py +++ b/src/braket/aws/aws_quantum_job.py @@ -531,22 +531,41 @@ def __hash__(self) -> int: @staticmethod def _initialize_session(session_value, device, logger): aws_session = session_value or AwsSession() + device_region = device.split(":")[3] + return ( + AwsQuantumJob._initialize_regional_device_session(aws_session, device, logger) + if device_region + else AwsQuantumJob._initialize_non_regional_device_session(aws_session, device, logger) + ) + + @staticmethod + def _initialize_regional_device_session(aws_session, device, logger): + device_region = device.split(":")[3] current_region = aws_session.region + if current_region != device_region: + aws_session = aws_session.copy_session(region=device_region) + logger.info(f"Changed session region from '{current_region}' to '{device_region}'") + try: + aws_session.get_device(device) + return aws_session + except ClientError as e: + raise ValueError(f"'{device}' not found.") if e.response["Error"][ + "Code" + ] == "ResourceNotFoundException" else e + @staticmethod + def _initialize_non_regional_device_session(aws_session, device, logger): + original_region = aws_session.region try: aws_session.get_device(device) return aws_session except ClientError as e: if e.response["Error"]["Code"] == "ResourceNotFoundException": if "qpu" not in device: - raise ValueError(f"Simulator '{device}' not found in '{current_region}'") + raise ValueError(f"Simulator '{device}' not found in '{original_region}'") else: raise e - return AwsQuantumJob._find_device_session(aws_session, device, current_region, logger) - - @staticmethod - def _find_device_session(aws_session, device, original_region, logger): for region in frozenset(AwsDevice.REGIONS) - {original_region}: device_session = aws_session.copy_session(region=region) try: diff --git a/test/unit_tests/braket/aws/test_aws_device.py b/test/unit_tests/braket/aws/test_aws_device.py index d78fc59ec..f86e4fa81 100644 --- a/test/unit_tests/braket/aws/test_aws_device.py +++ b/test/unit_tests/braket/aws/test_aws_device.py @@ -241,9 +241,14 @@ def test_gate_model_sim_schema(): ) -@pytest.fixture -def arn(): - return "test_arn" +@pytest.fixture( + params=[ + "arn:aws:braket:us-west-1::device/quantum-simulator/amazon/sim", + "arn:aws:braket:::device/quantum-simulator/amazon/sim", + ] +) +def arn(request): + return request.param @pytest.fixture @@ -303,6 +308,7 @@ def _device(arn): def test_device_aws_session(device_capabilities, get_device_data, arn): mock_session = Mock() mock_session.get_device.return_value = get_device_data + mock_session.region = RIGETTI_REGION device = AwsDevice(arn, mock_session) _assert_device_fields(device, device_capabilities, get_device_data) @@ -356,9 +362,69 @@ def test_device_qpu_no_aws_session( _assert_device_fields(device, MOCK_GATE_MODEL_QPU_CAPABILITIES_1, MOCK_GATE_MODEL_QPU_1) +@patch("braket.aws.aws_device.AwsSession.copy_session") +@patch("braket.aws.aws_device.AwsSession") +def test_regional_device_region_switch(aws_session_init, mock_copy_session, aws_session): + device_region = "device-region" + arn = f"arn:aws:braket:{device_region}::device/quantum-simulator/amazon/sim" + aws_session_init.return_value = aws_session + mock_session = Mock() + mock_session.get_device.return_value = MOCK_GATE_MODEL_SIMULATOR + mock_copy_session.return_value = mock_session + device = AwsDevice(arn) + aws_session.get_device.assert_not_called() + mock_copy_session.assert_called_once() + mock_copy_session.assert_called_with(aws_session, device_region) + _assert_device_fields(device, MOCK_GATE_MODEL_SIMULATOR_CAPABILITIES, MOCK_GATE_MODEL_SIMULATOR) + + +@patch("braket.aws.aws_device.AwsSession") +@pytest.mark.parametrize( + "get_device_side_effect, expected_exception", + [ + ( + [ + ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ) + ], + ValueError, + ), + ( + [ + ClientError( + { + "Error": { + "Code": "ThrottlingException", + } + }, + "getDevice", + ) + ], + ClientError, + ), + ], +) +def test_regional_device_raises_error( + aws_session_init, get_device_side_effect, expected_exception, aws_session +): + arn = "arn:aws:braket:us-west-1::device/quantum-simulator/amazon/sim" + aws_session.get_device.side_effect = get_device_side_effect + aws_session_init.return_value = aws_session + with pytest.raises(expected_exception): + AwsDevice(arn) + aws_session.get_device.assert_called_once() + + def test_device_refresh_metadata(arn): mock_session = Mock() mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1 + mock_session.region = RIGETTI_REGION device = AwsDevice(arn, mock_session) _assert_device_fields(device, MOCK_GATE_MODEL_QPU_CAPABILITIES_1, MOCK_GATE_MODEL_QPU_1) @@ -370,9 +436,10 @@ def test_device_refresh_metadata(arn): def test_equality(arn): mock_session = Mock() mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1 + mock_session.region = RIGETTI_REGION device_1 = AwsDevice(arn, mock_session) device_2 = AwsDevice(arn, mock_session) - other_device = AwsDevice("foo_bar", mock_session) + other_device = AwsDevice("arn:aws:braket:::device/quantum-simulator/amazon/bar", mock_session) non_device = "HI" assert device_1 == device_2 @@ -384,6 +451,7 @@ def test_equality(arn): def test_repr(arn): mock_session = Mock() mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1 + mock_session.region = RIGETTI_REGION device = AwsDevice(arn, mock_session) expected = "Device('name': {}, 'arn': {})".format(device.name, device.arn) assert repr(device) == expected @@ -616,8 +684,8 @@ def test_run_with_positional_args_and_kwargs( {"AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://env_bucket/env/path"}, ) @patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") -def test_run_env_variables(aws_quantum_task_mock, device, circuit): - device("foo:bar").run(circuit) +def test_run_env_variables(aws_quantum_task_mock, device, circuit, arn): + device(arn).run(circuit) assert aws_quantum_task_mock.call_args_list[0][0][3] == ("env_bucket", "env/path") @@ -670,8 +738,8 @@ def test_run_batch_with_max_parallel_and_kwargs( {"AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://env_bucket/env/path"}, ) @patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") -def test_run_batch_env_variables(aws_quantum_task_mock, device, circuit): - device("foo:bar").run_batch([circuit]) +def test_run_batch_env_variables(aws_quantum_task_mock, device, circuit, arn): + device(arn).run_batch([circuit]) assert aws_quantum_task_mock.call_args_list[0][0][3] == ("env_bucket", "env/path") @@ -688,7 +756,7 @@ def _run_and_assert( ): run_and_assert( aws_quantum_task_mock, - device_factory("foo_bar"), + device_factory("arn:aws:braket:::device/quantum-simulator/amazon/sim"), MOCK_DEFAULT_S3_DESTINATION_FOLDER, AwsDevice.DEFAULT_SHOTS_SIMULATOR, AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, @@ -720,7 +788,7 @@ def _run_batch_and_assert( run_batch_and_assert( aws_quantum_task_mock, aws_session_mock, - device_factory("foo_bar"), + device_factory("arn:aws:braket:::device/quantum-simulator/amazon/sim"), MOCK_DEFAULT_S3_DESTINATION_FOLDER, AwsDevice.DEFAULT_SHOTS_SIMULATOR, AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, diff --git a/test/unit_tests/braket/aws/test_aws_quantum_job.py b/test/unit_tests/braket/aws/test_aws_quantum_job.py index 2265fd3c0..22bd64054 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_job.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_job.py @@ -481,15 +481,20 @@ def role_arn(): return "arn:aws:iam::0000000000:role/AmazonBraketInternalSLR" -@pytest.fixture -def device_arn(): - return "arn:aws:braket:::device/qpu/test/device-name" +@pytest.fixture( + params=[ + "arn:aws:braket:us-test-1::device/qpu/test/device-name", + "arn:aws:braket:::device/qpu/test/device-name", + ] +) +def device_arn(request): + return request.param @pytest.fixture -def prepare_job_args(aws_session): +def prepare_job_args(aws_session, device_arn): return { - "device": Mock(), + "device": device_arn, "source_module": Mock(), "entry_point": Mock(), "image_uri": Mock(), @@ -796,7 +801,8 @@ def test_logs_error(quantum_job, generate_get_job_response, capsys): quantum_job.logs(wait=True, poll_interval_seconds=0) -def test_initialize_session_for_valid_device(device_arn, aws_session, caplog): +def test_initialize_session_for_valid_non_regional_device(aws_session, caplog): + device_arn = "arn:aws:braket:::device/qpu/test/device-name" first_region = aws_session.region logger = logging.getLogger(__name__) @@ -826,6 +832,76 @@ def test_initialize_session_for_valid_device(device_arn, aws_session, caplog): assert f"Changed session region from '{first_region}' to '{aws_session.region}'" in caplog.text +def test_initialize_session_for_valid_regional_device(aws_session, caplog): + device_arn = f"arn:aws:braket:{aws_session.region}::device/qpu/test/device-name" + logger = logging.getLogger(__name__) + aws_session.get_device.return_value = device_arn + caplog.set_level(logging.INFO) + AwsQuantumJob._initialize_session(aws_session, device_arn, logger) + assert not caplog.text + + +@pytest.mark.parametrize( + "get_device_side_effect, expected_exception", + [ + ( + [ + ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ) + ], + ValueError, + ), + ( + [ + ClientError( + { + "Error": { + "Code": "ThrottlingException", + } + }, + "getDevice", + ) + ], + ClientError, + ), + ], +) +def test_regional_device_raises_error( + get_device_side_effect, expected_exception, aws_session, caplog +): + device_arn = f"arn:aws:braket:{aws_session.region}::device/qpu/test/device-name" + aws_session.get_device.side_effect = get_device_side_effect + logger = logging.getLogger(__name__) + caplog.set_level(logging.INFO) + with pytest.raises(expected_exception): + AwsQuantumJob._initialize_session(aws_session, device_arn, logger) + aws_session.get_device.assert_called_with(device_arn) + assert not caplog.text + + +def test_regional_device_switches(aws_session, caplog): + original_region = aws_session.region + device_region = "us-east-1" + device_arn = f"arn:aws:braket:{device_region}::device/qpu/test/device-name" + mock_session = Mock() + mock_session.get_device.side_effect = device_arn + aws_session.copy_session.side_effect = [mock_session] + logger = logging.getLogger(__name__) + caplog.set_level(logging.INFO) + + assert mock_session == AwsQuantumJob._initialize_session(aws_session, device_arn, logger) + + aws_session.copy_session.assert_called_with(region=device_region) + mock_session.get_device.assert_called_with(device_arn) + assert f"Changed session region from '{original_region}' to '{device_region}'" in caplog.text + + def test_initialize_session_for_invalid_device(aws_session, device_arn): logger = logging.getLogger(__name__) aws_session.get_device.side_effect = ClientError( @@ -837,7 +913,7 @@ def test_initialize_session_for_invalid_device(aws_session, device_arn): "getDevice", ) - device_not_found = "QPU 'arn:aws:braket:::device/qpu/test/device-name' not found." + device_not_found = f"'{device_arn}' not found." with pytest.raises(ValueError, match=device_not_found): AwsQuantumJob._initialize_session(aws_session, device_arn, logger) @@ -880,7 +956,8 @@ def test_exception_in_credentials_session_region(device_arn, aws_session): AwsQuantumJob._initialize_session(aws_session, device_arn, logger) -def test_exceptions_in_all_device_regions(device_arn, aws_session): +def test_exceptions_in_all_device_regions(aws_session): + device_arn = "arn:aws:braket:::device/qpu/test/device-name" logger = logging.getLogger(__name__) aws_session.get_device.side_effect = [