Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Updated AwsQuantumTask to use new device parameters #127

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/braket/aws/aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def run(
>>> )
>>> device = AwsDevice("arn3")
>>> device.run(problem, ("bucket-foo", "key-bar"),
>>> device_parameters = {"dWaveParameters": {"postprocessingType": "SAMPLING"}})
>>> device_parameters={
>>> "providerLevelParameters": {"postprocessingType": "SAMPLING"}}
>>> )

See Also:
`braket.aws.aws_quantum_task.AwsQuantumTask.create()`
Expand Down
40 changes: 28 additions & 12 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from __future__ import annotations

import asyncio
import json
import time
from functools import singledispatch
from logging import Logger, getLogger
Expand All @@ -26,6 +25,11 @@
from braket.aws.aws_session import AwsSession
from braket.circuits.circuit import Circuit
from braket.circuits.circuit_helpers import validate_circuit_and_shots
from braket.device_schema import GateModelParameters
from braket.device_schema.dwave import DwaveDeviceParameters
from braket.device_schema.ionq import IonqDeviceParameters
from braket.device_schema.rigetti import RigettiDeviceParameters
from braket.device_schema.simulators import GateModelSimulatorDeviceParameters
from braket.schema_common import BraketSchemaBase
from braket.task_result import AnnealingTaskResult, GateModelTaskResult
from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask
Expand Down Expand Up @@ -80,7 +84,7 @@ def create(

device_parameters (Dict[str, Any]): Additional parameters to send to the device.
For example, for D-Wave:
`{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`
`{"providerLevelParameters": {"postprocessingType": "OPTIMIZATION"}}`

Returns:
AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.
Expand Down Expand Up @@ -110,6 +114,7 @@ def create(
aws_session,
create_task_kwargs,
device_parameters or {},
device_arn,
*args,
**kwargs,
)
Expand Down Expand Up @@ -346,7 +351,8 @@ def _create_internal(
task_specification: Union[Circuit, Problem],
aws_session: AwsSession,
create_task_kwargs: Dict[str, Any],
device_parameters: Dict[str, Any],
device_parameters: Union[dict, BraketSchemaBase],
device_arn: str,
*args,
**kwargs,
) -> AwsQuantumTask:
Expand All @@ -358,18 +364,27 @@ def _(
circuit: Circuit,
aws_session: AwsSession,
create_task_kwargs: Dict[str, Any],
device_parameters: Dict[str, Any],
device_parameters: Union[dict, BraketSchemaBase],
device_arn: str,
*args,
**kwargs,
) -> AwsQuantumTask:
validate_circuit_and_shots(circuit, create_task_kwargs["shots"])

# TODO: Update this to use `deviceCapabilities` from Amazon Braket's GetDevice operation
# in order to decide what parameters to build.
paradigm_paramters = GateModelParameters(qubitCount=circuit.qubit_count)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*parameters

if "ionq" in device_arn:
device_parameters = IonqDeviceParameters(paradigmParameters=paradigm_paramters)
elif "rigetti" in device_arn:
device_parameters = RigettiDeviceParameters(paradigmParameters=paradigm_paramters)
else: # default to use simulator
device_parameters = GateModelSimulatorDeviceParameters(
paradigmParameters=paradigm_paramters
)

create_task_kwargs.update(
{
"action": circuit.to_ir().json(),
"deviceParameters": json.dumps(
{"gateModelParameters": {"qubitCount": circuit.qubit_count}}
),
}
{"action": circuit.to_ir().json(), "deviceParameters": device_parameters.json()}
)
task_arn = aws_session.create_quantum_task(**create_task_kwargs)
return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)
Expand All @@ -380,14 +395,15 @@ def _(
problem: Problem,
aws_session: AwsSession,
create_task_kwargs: Dict[str, Any],
device_parameters: Dict[str, Any],
device_parameters: Union[dict, DwaveDeviceParameters],
device_arn: str,
*args,
**kwargs,
) -> AwsQuantumTask:
create_task_kwargs.update(
{
"action": problem.to_ir().json(),
"deviceParameters": json.dumps({"annealingModelParameters": device_parameters}),
"deviceParameters": DwaveDeviceParameters.parse_obj(device_parameters).json(),
}
)

Expand Down
2 changes: 1 addition & 1 deletion test/integ_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def s3_bucket(s3_resource, boto_session):

region_name = boto_session.region_name
account_id = boto_session.client("sts").get_caller_identity()["Account"]
bucket_name = f"braket-sdk-integ-tests-{account_id}"
bucket_name = f"amazon-braket-sdk-integ-tests-{account_id}"
bucket = s3_resource.Bucket(bucket_name)

try:
Expand Down
51 changes: 34 additions & 17 deletions test/unit_tests/braket/aws/test_aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# language governing permissions and limitations under the License.

import asyncio
import json
import threading
import time
from unittest.mock import MagicMock, Mock, patch
Expand All @@ -24,6 +23,11 @@
from braket.aws import AwsQuantumTask
from braket.aws.aws_session import AwsSession
from braket.circuits import Circuit
from braket.device_schema import GateModelParameters
from braket.device_schema.dwave import DwaveDeviceParameters
from braket.device_schema.ionq import IonqDeviceParameters
from braket.device_schema.rigetti import RigettiDeviceParameters
from braket.device_schema.simulators import GateModelSimulatorDeviceParameters
from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult

S3_TARGET = AwsSession.S3DestinationFolder("foo", "bar")
Expand Down Expand Up @@ -240,7 +244,7 @@ def test_timeout_completed(aws_session):

# Setup the poll timing such that the timeout will occur after one API poll
quantum_task = AwsQuantumTask(
"foo:bar:arn", aws_session, poll_timeout_seconds=0.5, poll_interval_seconds=1,
"foo:bar:arn", aws_session, poll_timeout_seconds=0.5, poll_interval_seconds=1
)
assert quantum_task.result() is None
_mock_metadata(aws_session, "COMPLETED")
Expand All @@ -256,7 +260,7 @@ def test_timeout_no_result_terminal_state(aws_session):

# Setup the poll timing such that the timeout will occur after one API poll
quantum_task = AwsQuantumTask(
"foo:bar:arn", aws_session, poll_timeout_seconds=0.5, poll_interval_seconds=1,
"foo:bar:arn", aws_session, poll_timeout_seconds=0.5, poll_interval_seconds=1
)
assert quantum_task.result() is None

Expand All @@ -276,23 +280,33 @@ def test_create_invalid_task_specification(aws_session, arn):
AwsQuantumTask.create(aws_session, arn, "foo", S3_TARGET, 1000)


def test_from_circuit_with_shots(aws_session, arn, circuit):
@pytest.mark.parametrize(
"device_arn,device_parameters_class",
[
("device/qpu/ionq", IonqDeviceParameters),
("device/qpu/rigetti", RigettiDeviceParameters),
("device/quantum-simulator", GateModelSimulatorDeviceParameters),
],
)
def test_from_circuit_with_shots(device_arn, device_parameters_class, aws_session, circuit):
mocked_task_arn = "task-arn-1"
aws_session.create_quantum_task.return_value = mocked_task_arn
shots = 53

task = AwsQuantumTask.create(aws_session, arn, circuit, S3_TARGET, shots)
task = AwsQuantumTask.create(aws_session, device_arn, circuit, S3_TARGET, shots)
assert task == AwsQuantumTask(
mocked_task_arn, aws_session, GateModelQuantumTaskResult.from_string
)

_assert_create_quantum_task_called_with(
aws_session,
arn,
device_arn,
circuit,
S3_TARGET,
shots,
{"gateModelParameters": {"qubitCount": circuit.qubit_count}},
device_parameters_class(
paradigmParameters=GateModelParameters(qubitCount=circuit.qubit_count)
),
)


Expand All @@ -303,17 +317,20 @@ def test_from_circuit_with_shots_value_error(aws_session, arn, circuit):
AwsQuantumTask.create(aws_session, arn, circuit, S3_TARGET, 0)


def test_from_annealing(aws_session, arn, problem):
@pytest.mark.parametrize(
"device_parameters",
[
{"providerLevelParameters": {"postprocessingType": "OPTIMIZATION"}},
DwaveDeviceParameters.parse_obj(
{"providerLevelParameters": {"postprocessingType": "OPTIMIZATION"}}
),
],
)
def test_from_annealing(device_parameters, aws_session, arn, problem):
mocked_task_arn = "task-arn-1"
aws_session.create_quantum_task.return_value = mocked_task_arn

task = AwsQuantumTask.create(
aws_session,
arn,
problem,
S3_TARGET,
1000,
device_parameters={"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}},
aws_session, arn, problem, S3_TARGET, 1000, device_parameters=device_parameters
)
assert task == AwsQuantumTask(
mocked_task_arn, aws_session, AnnealingQuantumTaskResult.from_string
Expand All @@ -325,7 +342,7 @@ def test_from_annealing(aws_session, arn, problem):
problem,
S3_TARGET,
1000,
{"annealingModelParameters": {"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}},
DwaveDeviceParameters.parse_obj(device_parameters),
)


Expand Down Expand Up @@ -361,7 +378,7 @@ def _assert_create_quantum_task_called_with(
"outputS3Bucket": s3_results_prefix[0],
"outputS3KeyPrefix": s3_results_prefix[1],
"action": task_description.to_ir().json(),
"deviceParameters": json.dumps(device_parameters),
"deviceParameters": device_parameters.json(),
"shots": shots,
}
)
Expand Down