diff --git a/.github/workflows/check-format.yml b/.github/workflows/check-format.yml index 160e8e000..631b90360 100644 --- a/.github/workflows/check-format.yml +++ b/.github/workflows/check-format.yml @@ -16,7 +16,7 @@ jobs: check-code-format: runs-on: ubuntu-latest steps: - - uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 + - uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.1.0 - name: Set up Python uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 # v4.7.0 with: diff --git a/.github/workflows/dependent-tests.yml b/.github/workflows/dependent-tests.yml index 732600ba2..10752e27e 100644 --- a/.github/workflows/dependent-tests.yml +++ b/.github/workflows/dependent-tests.yml @@ -21,7 +21,7 @@ jobs: - amazon-braket-pennylane-plugin-python steps: - - uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 + - uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.1.0 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 # v4.7.0 with: diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index a26168c11..fbd9a8942 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -12,7 +12,7 @@ jobs: name: Build and publish distribution to PyPi runs-on: ubuntu-latest steps: - - uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 + - uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.1.0 - name: Set up Python uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 # v4.7.0 with: diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 591ffed15..1e316a8cc 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -24,7 +24,7 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 + - uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.1.0 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 # v4.7.0 with: diff --git a/.github/workflows/twine-check.yml b/.github/workflows/twine-check.yml index 5a5966763..f73925f8a 100644 --- a/.github/workflows/twine-check.yml +++ b/.github/workflows/twine-check.yml @@ -14,7 +14,7 @@ jobs: name: Check long description runs-on: ubuntu-latest steps: - - uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 + - uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.1.0 - name: Set up Python uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 # v4.7.0 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index cb5d60c81..e059789fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## v1.56.0 (2023-09-26) + +### Features + + * add queue visibility information + ## v1.55.1.post0 (2023-09-18) ### Documentation Changes diff --git a/setup.py b/setup.py index f21f27c3d..4dff452cf 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ "setuptools", "backoff", "boltons", - "boto3>=1.22.3", + "boto3>=1.28.53", "nest-asyncio", "networkx", "numpy", diff --git a/src/braket/_sdk/_version.py b/src/braket/_sdk/_version.py index 865a962de..ed1952453 100644 --- a/src/braket/_sdk/_version.py +++ b/src/braket/_sdk/_version.py @@ -15,4 +15,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "1.55.2.dev0" +__version__ = "1.56.1.dev0" diff --git a/src/braket/aws/aws_device.py b/src/braket/aws/aws_device.py index 98255e0fc..d495dbe3f 100644 --- a/src/braket/aws/aws_device.py +++ b/src/braket/aws/aws_device.py @@ -30,6 +30,7 @@ from braket.aws.aws_quantum_task import AwsQuantumTask from braket.aws.aws_quantum_task_batch import AwsQuantumTaskBatch from braket.aws.aws_session import AwsSession +from braket.aws.queue_information import QueueDepthInfo, QueueType from braket.circuits import Circuit, Gate, QubitSet from braket.circuits.gate_calibrations import GateCalibrations from braket.device_schema import DeviceCapabilities, ExecutionDay, GateModelQpuParadigmProperties @@ -677,6 +678,54 @@ def get_device_region(device_arn: str) -> str: "see 'https://docs.aws.amazon.com/braket/latest/developerguide/braket-devices.html'" ) + def queue_depth(self) -> QueueDepthInfo: + """ + Task queue depth refers to the total number of quantum tasks currently waiting + to run on a particular device. + + Returns: + QueueDepthInfo: Instance of the QueueDepth class representing queue depth + information for quantum tasks and hybrid jobs. + Queue depth refers to the number of quantum tasks and hybrid jobs queued on a particular + device. The normal tasks refers to the quantum tasks not submitted via Hybrid Jobs. + Whereas, the priority tasks refers to the total number of quantum tasks waiting to run + submitted through Amazon Braket Hybrid Jobs. These tasks run before the normal tasks. + If the queue depth for normal or priority quantum tasks is greater than 4000, we display + their respective queue depth as '>4000'. Similarly, for hybrid jobs if there are more + than 1000 jobs queued on a device, display the hybrid jobs queue depth as '>1000'. + Additionally, for QPUs if hybrid jobs queue depth is 0, we display information about + priority and count of the running hybrid job. + + Example: + Queue depth information for a running job. + >>> device = AwsDevice(Device.Amazon.SV1) + >>> print(device.queue_depth()) + QueueDepthInfo(quantum_tasks={: '0', + : '1'}, jobs='0 (1 prioritized job(s) running)') + + If more than 4000 quantum tasks queued on a device. + >>> device = AwsDevice(Device.Amazon.DM1) + >>> print(device.queue_depth()) + QueueDepthInfo(quantum_tasks={: '>4000', + : '2000'}, jobs='100') + """ + metadata = self.aws_session.get_device(arn=self.arn) + queue_metadata = metadata.get("deviceQueueInfo") + queue_info = {} + + for response in queue_metadata: + queue_name = response.get("queue") + queue_priority = response.get("queuePriority") + queue_size = response.get("queueSize") + + if queue_name == "QUANTUM_TASKS_QUEUE": + priority_enum = QueueType(queue_priority) + queue_info.setdefault("quantum_tasks", {})[priority_enum] = queue_size + else: + queue_info["jobs"] = queue_size + + return QueueDepthInfo(**queue_info) + def refresh_gate_calibrations(self) -> Optional[GateCalibrations]: """ Refreshes the gate calibration data upon request. diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py index 288f17257..3120384fe 100644 --- a/src/braket/aws/aws_quantum_job.py +++ b/src/braket/aws/aws_quantum_job.py @@ -27,6 +27,7 @@ from braket.aws import AwsDevice from braket.aws.aws_session import AwsSession +from braket.aws.queue_information import HybridJobQueueInfo from braket.jobs import logs from braket.jobs.config import ( CheckpointConfig, @@ -278,6 +279,38 @@ def state(self, use_cached_value: bool = False) -> str: """ return self.metadata(use_cached_value).get("status") + def queue_position(self) -> HybridJobQueueInfo: + """ + The queue position details for the hybrid job. + + Returns: + HybridJobQueueInfo: Instance of HybridJobQueueInfo class representing + the queue position information for the hybrid job. The queue_position is + only returned when the hybrid job is not in RUNNING/CANCELLING/TERMINAL states, + else queue_position is returned as None. If the queue position of the hybrid + job is greater than 15, we return '>15' as the queue_position return value. + + Examples: + job status = QUEUED and position is 2 in the queue. + >>> job.queue_position() + HybridJobQueueInfo(queue_position='2', message=None) + + job status = QUEUED and position is 18 in the queue. + >>> job.queue_position() + HybridJobQueueInfo(queue_position='>15', message=None) + + job status = COMPLETED + >>> job.queue_position() + HybridJobQueueInfo(queue_position=None, + message='Job is in COMPLETED status. AmazonBraket does + not show queue position for this status.') + """ + response = self.metadata()["queueInfo"] + queue_position = None if response.get("position") == "None" else response.get("position") + message = response.get("message") + + return HybridJobQueueInfo(queue_position=queue_position, message=message) + def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: """Display logs for a given hybrid job, optionally tailing them until hybrid job is complete. diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index a4538cdd9..57fecb26c 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -24,6 +24,7 @@ from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation from braket.annealing.problem import Problem from braket.aws.aws_session import AwsSession +from braket.aws.queue_information import QuantumTaskQueueInfo, QueueType from braket.circuits import Instruction from braket.circuits.circuit import Circuit, Gate, QubitSet from braket.circuits.circuit_helpers import validate_circuit_and_shots @@ -314,6 +315,40 @@ def state(self, use_cached_value: bool = False) -> str: """ return self._status(use_cached_value) + def queue_position(self) -> QuantumTaskQueueInfo: + """ + The queue position details for the quantum task. + + Returns: + QuantumTaskQueueInfo: Instance of QuantumTaskQueueInfo class + representing the queue position information for the quantum task. + The queue_position is only returned when quantum task is not in + RUNNING/CANCELLING/TERMINAL states, else queue_position is returned as None. + The normal tasks refers to the quantum tasks not submitted via Hybrid Jobs. + Whereas, the priority tasks refers to the total number of quantum tasks waiting to run + submitted through Amazon Braket Hybrid Jobs. These tasks run before the normal tasks. + If the queue position for normal or priority quantum tasks is greater than 2000, + we display their respective queue position as '>2000'. + + Examples: + task status = QUEUED and queue position is 2050 + >>> task.queue_position() + QuantumTaskQueueInfo(queue_type=, + queue_position='>2000', message=None) + + task status = COMPLETED + >>> task.queue_position() + QuantumTaskQueueInfo(queue_type=, + queue_position=None, message='Task is in COMPLETED status. AmazonBraket does + not show queue position for this status.') + """ + response = self.metadata()["queueInfo"] + queue_type = QueueType(response["queuePriority"]) + queue_position = None if response.get("position") == "None" else response.get("position") + message = response.get("message") + + return QuantumTaskQueueInfo(queue_type, queue_position, message) + def _status(self, use_cached_value: bool = False) -> str: metadata = self.metadata(use_cached_value) status = metadata.get("status") diff --git a/src/braket/aws/aws_session.py b/src/braket/aws/aws_session.py index 300aee13e..5ce04a826 100644 --- a/src/braket/aws/aws_session.py +++ b/src/braket/aws/aws_session.py @@ -279,7 +279,9 @@ def get_quantum_task(self, arn: str) -> Dict[str, Any]: Returns: Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation. """ - response = self.braket_client.get_quantum_task(quantumTaskArn=arn) + response = self.braket_client.get_quantum_task( + quantumTaskArn=arn, additionalAttributeNames=["QueueInfo"] + ) broadcast_event(_TaskStatusEvent(arn=response["quantumTaskArn"], status=response["status"])) return response @@ -324,7 +326,7 @@ def get_job(self, arn: str) -> Dict[str, Any]: Returns: Dict[str, Any]: The response from the Amazon Braket `GetQuantumJob` operation. """ - return self.braket_client.get_job(jobArn=arn) + return self.braket_client.get_job(jobArn=arn, additionalAttributeNames=["QueueInfo"]) def cancel_job(self, arn: str) -> Dict[str, Any]: """ diff --git a/src/braket/aws/queue_information.py b/src/braket/aws/queue_information.py new file mode 100644 index 000000000..d45ed8761 --- /dev/null +++ b/src/braket/aws/queue_information.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Optional + + +class QueueType(str, Enum): + """ + Enumerates the possible priorities for the queue. + + Values: + NORMAL: Represents normal queue for the device. + PRIORITY: Represents priority queue for the device. + """ + + NORMAL = "Normal" + PRIORITY = "Priority" + + +@dataclass() +class QueueDepthInfo: + """ + Represents quantum tasks and hybrid jobs queue depth information. + + Attributes: + quantum_tasks (Dict[QueueType, str]): number of quantum tasks waiting + to run on a device. This includes both 'Normal' and 'Priority' tasks. + For Example, {'quantum_tasks': {QueueType.NORMAL: '7', QueueType.PRIORITY: '3'}} + jobs (str): number of hybrid jobs waiting to run on a device. Additionally, for QPUs if + hybrid jobs queue depth is 0, we display information about priority and count of the + running hybrid jobs. Example, 'jobs': '0 (1 prioritized job(s) running)' + """ + + quantum_tasks: Dict[QueueType, str] + jobs: str + + +@dataclass +class QuantumTaskQueueInfo: + """ + Represents quantum tasks queue information. + + Attributes: + queue_type (QueueType): type of the quantum_task queue either 'Normal' + or 'Priority'. + queue_position (Optional[str]): current position of your quantum task within a respective + device queue. This value can be None based on the state of the task. Default: None. + message (Optional[str]): Additional message information. This key is present only + if 'queue_position' is None. Default: None. + """ + + queue_type: QueueType + queue_position: Optional[str] = None + message: Optional[str] = None + + +@dataclass +class HybridJobQueueInfo: + """ + Represents hybrid job queue information. + + Attributes: + queue_position (Optional[str]): current position of your hybrid job within a respective + device queue. If the queue position of the hybrid job is greater than 15, we + return '>15' as the queue_position return value. The queue_position is only + returned when hybrid job is not in RUNNING/CANCELLING/TERMINAL states, else + queue_position is returned as None. + message (Optional[str]): Additional message information. This key is present only + if 'queue_position' is None. Default: None. + """ + + queue_position: Optional[str] = None + message: Optional[str] = None diff --git a/test/integ_tests/test_queue_information.py b/test/integ_tests/test_queue_information.py new file mode 100644 index 000000000..3398fde40 --- /dev/null +++ b/test/integ_tests/test_queue_information.py @@ -0,0 +1,84 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from braket.aws import AwsDevice, AwsQuantumJob +from braket.aws.queue_information import ( + HybridJobQueueInfo, + QuantumTaskQueueInfo, + QueueDepthInfo, + QueueType, +) +from braket.circuits import Circuit +from braket.devices import Devices + + +def test_task_queue_position(): + device = AwsDevice(Devices.Amazon.SV1) + + bell = Circuit().h(0).cnot(0, 1) + task = device.run(bell, shots=10) + + # call the queue_position method. + queue_information = task.queue_position() + + # data type validations + assert isinstance(queue_information, QuantumTaskQueueInfo) + assert isinstance(queue_information.queue_type, QueueType) + assert isinstance(queue_information.queue_position, (str, type(None))) + + # assert queue priority + assert queue_information.queue_type in [QueueType.NORMAL, QueueType.PRIORITY] + + # assert message + if queue_information.queue_position is None: + assert queue_information.message is not None + assert isinstance(queue_information.message, (str, type(None))) + else: + assert queue_information.message is None + + +def test_job_queue_position(aws_session): + job = AwsQuantumJob.create( + device=Devices.Amazon.SV1, + source_module="test/integ_tests/job_test_script.py", + entry_point="job_test_script:start_here", + aws_session=aws_session, + wait_until_complete=True, + hyperparameters={"test_case": "completed"}, + ) + + # call the queue_position method. + queue_information = job.queue_position() + + # data type validations + assert isinstance(queue_information, HybridJobQueueInfo) + + # assert message + assert queue_information.queue_position is None + assert isinstance(queue_information.message, str) + + +def test_queue_depth(): + device = AwsDevice(Devices.Amazon.SV1) + + # call the queue_depth method. + queue_information = device.queue_depth() + + # data type validations + assert isinstance(queue_information, QueueDepthInfo) + assert isinstance(queue_information.quantum_tasks, dict) + assert isinstance(queue_information.jobs, str) + + for key, value in queue_information.quantum_tasks.items(): + assert isinstance(key, QueueType) + assert isinstance(value, str) diff --git a/test/unit_tests/braket/aws/test_aws_device.py b/test/unit_tests/braket/aws/test_aws_device.py index d1f6fa845..c1e034834 100644 --- a/test/unit_tests/braket/aws/test_aws_device.py +++ b/test/unit_tests/braket/aws/test_aws_device.py @@ -34,6 +34,7 @@ from jsonschema import validate from braket.aws import AwsDevice, AwsDeviceType, AwsQuantumTask +from braket.aws.queue_information import QueueDepthInfo, QueueType from braket.circuits import Circuit, FreeParameter, Gate, QubitSet from braket.circuits.gate_calibrations import GateCalibrations from braket.device_schema.device_execution_window import DeviceExecutionWindow @@ -77,7 +78,6 @@ MOCK_GATE_MODEL_QPU_CAPABILITIES_JSON_1 ) - MOCK_gate_calibrations_JSON = { "gates": { "0": { @@ -218,6 +218,11 @@ def test_mock_rigetti_schema_1(): "providerName": "Rigetti", "deviceStatus": "OFFLINE", "deviceCapabilities": MOCK_GATE_MODEL_QPU_CAPABILITIES_1.json(), + "deviceQueueInfo": [ + {"queue": "QUANTUM_TASKS_QUEUE", "queueSize": "19", "queuePriority": "Normal"}, + {"queue": "QUANTUM_TASKS_QUEUE", "queueSize": "3", "queuePriority": "Priority"}, + {"queue": "JOBS_QUEUE", "queueSize": "0 (3 prioritized job(s) running)"}, + ], } MOCK_GATE_MODEL_QPU_CAPABILITIES_JSON_2 = { @@ -628,7 +633,6 @@ def test_device_refresh_metadata(arn): "nativeGateCalibrationsRef": "file://hostname/foo/bar", } - MOCK_PULSE_MODEL_QPU_PULSE_CAPABILITIES_JSON_2 = { "braketSchemaHeader": { "name": "braket.device_schema.pulse.pulse_device_action_properties", @@ -2005,3 +2009,14 @@ def test_parse_calibration_data_bad_instr(bad_input): ) device = AwsDevice(DWAVE_ARN, mock_session) device._parse_calibration_json(bad_input) + + +def test_queue_depth(arn): + mock_session = Mock() + mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1 + mock_session.region = RIGETTI_REGION + device = AwsDevice(arn, mock_session) + assert device.queue_depth() == QueueDepthInfo( + quantum_tasks={QueueType.NORMAL: "19", QueueType.PRIORITY: "3"}, + jobs="0 (3 prioritized job(s) running)", + ) diff --git a/test/unit_tests/braket/aws/test_aws_quantum_job.py b/test/unit_tests/braket/aws/test_aws_quantum_job.py index cbc535a2a..ffc9bdb39 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_job.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_job.py @@ -24,6 +24,7 @@ from botocore.exceptions import ClientError from braket.aws import AwsQuantumJob, AwsSession +from braket.aws.queue_information import HybridJobQueueInfo @pytest.fixture @@ -226,6 +227,27 @@ def test_metadata_caching(quantum_job, aws_session, generate_get_job_response, q assert aws_session.get_job.call_count == 1 +def test_queue_position(quantum_job, aws_session, generate_get_job_response): + state_1 = "COMPLETED" + queue_info = { + "queue": "JOBS_QUEUE", + "position": "None", + "message": "Job is in COMPLETED status. " + "AmazonBraket does not show queue position for this status.", + } + get_job_response_completed = generate_get_job_response(status=state_1, queueInfo=queue_info) + aws_session.get_job.return_value = get_job_response_completed + assert quantum_job.queue_position() == HybridJobQueueInfo( + queue_position=None, message=queue_info["message"] + ) + + state_2 = "QUEUED" + queue_info = {"queue": "JOBS_QUEUE", "position": "2"} + get_job_response_queued = generate_get_job_response(status=state_2, queueInfo=queue_info) + aws_session.get_job.return_value = get_job_response_queued + assert quantum_job.queue_position() == HybridJobQueueInfo(queue_position="2", message=None) + + def test_state(quantum_job, aws_session, generate_get_job_response, quantum_job_arn): state_1 = "RUNNING" get_job_response_running = generate_get_job_response(status=state_1) diff --git a/test/unit_tests/braket/aws/test_aws_quantum_task.py b/test/unit_tests/braket/aws/test_aws_quantum_task.py index de8ead78a..99270ad25 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_task.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_task.py @@ -26,6 +26,7 @@ from braket.aws import AwsQuantumTask from braket.aws.aws_quantum_task import _create_annealing_device_params from braket.aws.aws_session import AwsSession +from braket.aws.queue_information import QuantumTaskQueueInfo, QueueType from braket.circuits import Circuit from braket.circuits.gates import PulseGate from braket.circuits.serialization import ( @@ -202,6 +203,23 @@ def test_metadata_call_if_none(quantum_task): quantum_task._aws_session.get_quantum_task.assert_called_with(quantum_task.id) +def test_queue_position(quantum_task): + state_1 = "QUEUED" + _mock_metadata(quantum_task._aws_session, state_1) + assert quantum_task.queue_position() == QuantumTaskQueueInfo( + queue_type=QueueType.NORMAL, queue_position="2", message=None + ) + + state_2 = "COMPLETED" + message = ( + f"'Task is in {state_2} status. AmazonBraket does not show queue position for this status.'" + ) + _mock_metadata(quantum_task._aws_session, state_2) + assert quantum_task.queue_position() == QuantumTaskQueueInfo( + queue_type=QueueType.NORMAL, queue_position=None, message=message + ) + + def test_state(quantum_task): state_1 = "RUNNING" _mock_metadata(quantum_task._aws_session, state_1) @@ -1097,11 +1115,32 @@ def _assert_create_quantum_task_called_with( def _mock_metadata(aws_session, state): - aws_session.get_quantum_task.return_value = { - "status": state, - "outputS3Bucket": S3_TARGET.bucket, - "outputS3Directory": S3_TARGET.key, - } + message = ( + f"'Task is in {state} status. AmazonBraket does not show queue position for this status.'" + ) + if state in AwsQuantumTask.TERMINAL_STATES or state in ["RUNNING", "CANCELLING"]: + aws_session.get_quantum_task.return_value = { + "status": state, + "outputS3Bucket": S3_TARGET.bucket, + "outputS3Directory": S3_TARGET.key, + "queueInfo": { + "queue": "QUANTUM_TASKS_QUEUE", + "position": "None", + "queuePriority": "Normal", + "message": message, + }, + } + else: + aws_session.get_quantum_task.return_value = { + "status": state, + "outputS3Bucket": S3_TARGET.bucket, + "outputS3Directory": S3_TARGET.key, + "queueInfo": { + "queue": "QUANTUM_TASKS_QUEUE", + "position": "2", + "queuePriority": "Normal", + }, + } def _mock_s3(aws_session, result): diff --git a/test/unit_tests/braket/aws/test_aws_session.py b/test/unit_tests/braket/aws/test_aws_session.py index 53423d99f..bfc65d54b 100644 --- a/test/unit_tests/braket/aws/test_aws_session.py +++ b/test/unit_tests/braket/aws/test_aws_session.py @@ -418,16 +418,20 @@ def test_create_quantum_task_with_job_token(aws_session): def test_get_quantum_task(aws_session): arn = "foo:bar:arn" status = "STATUS" + queue_info = ["QueueInfo"] return_value = {"quantumTaskArn": arn, "status": status} aws_session.braket_client.get_quantum_task.return_value = return_value assert aws_session.get_quantum_task(arn) == return_value - aws_session.braket_client.get_quantum_task.assert_called_with(quantumTaskArn=arn) + aws_session.braket_client.get_quantum_task.assert_called_with( + quantumTaskArn=arn, additionalAttributeNames=queue_info + ) def test_get_quantum_task_retry(aws_session, throttling_response, resource_not_found_response): arn = "foo:bar:arn" status = "STATUS" + queue_info = ["QueueInfo"] return_value = {"quantumTaskArn": arn, "status": status} aws_session.braket_client.get_quantum_task.side_effect = [ @@ -437,7 +441,9 @@ def test_get_quantum_task_retry(aws_session, throttling_response, resource_not_f ] assert aws_session.get_quantum_task(arn) == return_value - aws_session.braket_client.get_quantum_task.assert_called_with(quantumTaskArn=arn) + aws_session.braket_client.get_quantum_task.assert_called_with( + quantumTaskArn=arn, additionalAttributeNames=queue_info + ) assert aws_session.braket_client.get_quantum_task.call_count == 3 @@ -474,16 +480,20 @@ def test_get_quantum_task_does_not_retry_other_exceptions(aws_session): def test_get_job(aws_session, get_job_response): arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + queue_info = ["QueueInfo"] aws_session.braket_client.get_job.return_value = get_job_response assert aws_session.get_job(arn) == get_job_response - aws_session.braket_client.get_job.assert_called_with(jobArn=arn) + aws_session.braket_client.get_job.assert_called_with( + jobArn=arn, additionalAttributeNames=queue_info + ) def test_get_job_retry( aws_session, get_job_response, throttling_response, resource_not_found_response ): arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + queue_info = ["QueueInfo"] aws_session.braket_client.get_job.side_effect = [ ClientError(resource_not_found_response, "unit-test"), @@ -492,12 +502,15 @@ def test_get_job_retry( ] assert aws_session.get_job(arn) == get_job_response - aws_session.braket_client.get_job.assert_called_with(jobArn=arn) + aws_session.braket_client.get_job.assert_called_with( + jobArn=arn, additionalAttributeNames=queue_info + ) assert aws_session.braket_client.get_job.call_count == 3 def test_get_job_fail_after_retries(aws_session, throttling_response, resource_not_found_response): arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + queue_info = ["QueueInfo"] aws_session.braket_client.get_job.side_effect = [ ClientError(resource_not_found_response, "unit-test"), @@ -507,12 +520,15 @@ def test_get_job_fail_after_retries(aws_session, throttling_response, resource_n with pytest.raises(ClientError): aws_session.get_job(arn) - aws_session.braket_client.get_job.assert_called_with(jobArn=arn) + aws_session.braket_client.get_job.assert_called_with( + jobArn=arn, additionalAttributeNames=queue_info + ) assert aws_session.braket_client.get_job.call_count == 3 def test_get_job_does_not_retry_other_exceptions(aws_session): arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + queue_info = ["QueueInfo"] exception_response = { "Error": { "Code": "SomeOtherException", @@ -526,7 +542,9 @@ def test_get_job_does_not_retry_other_exceptions(aws_session): with pytest.raises(ClientError): aws_session.get_job(arn) - aws_session.braket_client.get_job.assert_called_with(jobArn=arn) + aws_session.braket_client.get_job.assert_called_with( + jobArn=arn, additionalAttributeNames=queue_info + ) assert aws_session.braket_client.get_job.call_count == 1