From 1431e091717afbd8c30ec926c11ec0b6cb32c7fb Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Thu, 21 Sep 2023 13:50:29 -0700 Subject: [PATCH] fix: fixing search device when don't have access to a region. --- src/braket/aws/aws_device.py | 42 +++++++----- test/unit_tests/braket/aws/test_aws_device.py | 68 +++++++++++++++++++ 2 files changed, 94 insertions(+), 16 deletions(-) diff --git a/src/braket/aws/aws_device.py b/src/braket/aws/aws_device.py index fbd493071..98255e0fc 100644 --- a/src/braket/aws/aws_device.py +++ b/src/braket/aws/aws_device.py @@ -17,6 +17,7 @@ import json import os import urllib.request +import warnings from datetime import datetime from enum import Enum from typing import Dict, List, Optional, Tuple, Union @@ -601,23 +602,32 @@ def get_devices( types_for_region = sorted( types if region == session_region else types - {AwsDeviceType.SIMULATOR} ) - region_device_arns = [ - result["deviceArn"] - for result in session_for_region.search_devices( - arns=arns, - names=names, - types=types_for_region, - statuses=statuses, - provider_names=provider_names, + try: + region_device_arns = [ + result["deviceArn"] + for result in session_for_region.search_devices( + arns=arns, + names=names, + types=types_for_region, + statuses=statuses, + provider_names=provider_names, + ) + ] + device_map.update( + { + arn: AwsDevice(arn, session_for_region) + for arn in region_device_arns + if arn not in device_map + } ) - ] - device_map.update( - { - arn: AwsDevice(arn, session_for_region) - for arn in region_device_arns - if arn not in device_map - } - ) + except ClientError as e: + error_code = e.response["Error"]["Code"] + warnings.warn( + f"{error_code}: Unable to search region '{region}' for devices." + " Please check your settings or try again later." + f" Continuing without devices in '{region}'." + ) + devices = list(device_map.values()) devices.sort(key=lambda x: getattr(x, order_by)) return devices diff --git a/test/unit_tests/braket/aws/test_aws_device.py b/test/unit_tests/braket/aws/test_aws_device.py index 40abf6389..d1f6fa845 100644 --- a/test/unit_tests/braket/aws/test_aws_device.py +++ b/test/unit_tests/braket/aws/test_aws_device.py @@ -1669,6 +1669,74 @@ def test_get_devices_simulators_only(mock_copy_session, aws_session): assert [result.name for result in results] == ["SV1"] +@patch("braket.aws.aws_device.AwsSession.copy_session") +def test_get_devices_with_error_in_region(mock_copy_session, aws_session): + aws_session.search_devices.side_effect = [ + # us-west-1 + [ + { + "deviceArn": SV1_ARN, + "deviceName": "SV1", + "deviceType": "SIMULATOR", + "deviceStatus": "ONLINE", + "providerName": "Amazon Braket", + } + ], + ValueError("should not be reachable"), + ] + aws_session.get_device.side_effect = [ + MOCK_GATE_MODEL_SIMULATOR, + ValueError("should not be reachable"), + ] + session_for_region = Mock() + session_for_region.search_devices.side_effect = [ + # us-east-1 + [ + { + "deviceArn": IONQ_ARN, + "deviceName": "IonQ Device", + "deviceType": "QPU", + "deviceStatus": "ONLINE", + "providerName": "IonQ", + }, + ], + # us-west-2 + ClientError( + { + "Error": { + "Code": "Test Code", + "Message": "Test Message", + } + }, + "Test Operation", + ), + # eu-west-2 + [ + { + "deviceArn": OQC_ARN, + "deviceName": "Lucy", + "deviceType": "QPU", + "deviceStatus": "ONLINE", + "providerName": "OQC", + } + ], + # Only two regions to search outside of current + ValueError("should not be reachable"), + ] + session_for_region.get_device.side_effect = [ + MOCK_GATE_MODEL_QPU_2, + MOCK_GATE_MODEL_QPU_3, + ValueError("should not be reachable"), + ] + mock_copy_session.return_value = session_for_region + # Search order: us-east-1, us-west-1, us-west-2, eu-west-2 + results = AwsDevice.get_devices( + statuses=["ONLINE"], + aws_session=aws_session, + ) + assert [result.name for result in results] == ["Blah", "Lucy", "SV1"] + + @pytest.mark.xfail(raises=ValueError) def test_get_devices_invalid_order_by(): AwsDevice.get_devices(order_by="foo")