Skip to content

Commit

Permalink
Add args and kwargs to create_task internal (#30)
Browse files Browse the repository at this point in the history
* Add args and kwargs to create_task internal

So polling parameters can be set

* Cleaner shots default passing

Use None for default shots until create_task API is invoked
  • Loading branch information
speller26 authored Feb 19, 2020
1 parent 6d47fd9 commit d2861d6
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/braket/aws/aws_qpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def run(
task_specification (Union[Circuit, Problem]): Specification of task
(circuit or annealing problem) to run on device.
s3_destination_folder: The S3 location to save the task's results
shots (Optional[int]): The number of times to run the circuit or annealing task
shots (Optional[int]): The number of times to run the circuit or annealing problem
*aws_quantum_task_args: Variable length positional arguments for
`braket.aws.aws_quantum_task.AwsQuantumTask.create()`.
**aws_quantum_task_kwargs: Variable length keyword arguments for
Expand Down
2 changes: 1 addition & 1 deletion src/braket/aws/aws_quantum_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def run(
task_specification (Union[Circuit, Problem]): Specification of task
(circuit or annealing problem) to run on device.
s3_destination_folder: The S3 location to save the task's results
shots (Optional[int]): The number of times to run the circuit or annealing task
shots (Optional[int]): The number of times to run the circuit or annealing problem
*aws_quantum_task_args: Variable length positional arguments for
`braket.aws.aws_quantum_task.AwsQuantumTask.create()`.
**aws_quantum_task_kwargs: Variable length keyword arguments for
Expand Down
31 changes: 21 additions & 10 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio
import time
from functools import singledispatch
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Optional, Union

from braket.annealing.problem import Problem
from braket.aws.aws_session import AwsSession
Expand All @@ -36,13 +36,16 @@ class AwsQuantumTask(QuantumTask):
ANNEALING_IR_TYPE = "annealing"
DEFAULT_SHOTS = 1_000

DEFAULT_RESULTS_POLL_TIMEOUT = 120
DEFAULT_RESULTS_POLL_INTERVAL = 0.25

@staticmethod
def create(
aws_session: AwsSession,
device_arn: str,
task_specification: Union[Circuit, Problem],
s3_destination_folder: AwsSession.S3DestinationFolder,
shots: int = DEFAULT_SHOTS,
shots: Optional[int] = None,
backend_parameters: Dict[str, Any] = None,
*args,
**kwargs,
Expand All @@ -59,9 +62,9 @@ def create(
(circuit or annealing problem) to run on device.
s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple with bucket (index 0)
and key (index 1) that is the results destination folder in S3.
shots (int): The number of times to run the circuit or annealing task on the device.
If the device is a classical simulator then this implies sampling the state N times,
where N = `shots`. Default = 1_000.
shots (Optional[int]): The number of times to run the circuit or annealing problem
on the device. If the device is a classical simulator then this implies sampling
the state N times, where N = `shots`. If not set, will default to 1_000.
backend_parameters (Dict[str, Any]): Additional parameters to pass to the device.
For example, for D-Wave:
>>> backend_parameters = {"dWaveParameters": {"postprocess": "OPTIMIZATION"}}
Expand All @@ -80,7 +83,11 @@ def create(
"s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
)

create_task_kwargs = _create_common_params(device_arn, s3_destination_folder, shots)
create_task_kwargs = _create_common_params(
device_arn,
s3_destination_folder,
shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
)
return _create_internal(
task_specification,
aws_session,
Expand All @@ -95,8 +102,8 @@ def __init__(
arn: str,
aws_session: AwsSession,
results_formatter: Callable[[str], Any],
poll_timeout_seconds: int = 120,
poll_interval_seconds: int = 0.25,
poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: int = DEFAULT_RESULTS_POLL_INTERVAL,
):
"""
Args:
Expand Down Expand Up @@ -277,7 +284,9 @@ def _(
)

task_arn = aws_session.create_quantum_task(**create_task_kwargs)
return AwsQuantumTask(task_arn, aws_session, GateModelQuantumTaskResult.from_string)
return AwsQuantumTask(
task_arn, aws_session, GateModelQuantumTaskResult.from_string, *args, **kwargs
)


@_create_internal.register
Expand All @@ -298,7 +307,9 @@ def _(
)

task_arn = aws_session.create_quantum_task(**create_task_kwargs)
return AwsQuantumTask(task_arn, aws_session, AnnealingQuantumTaskResult.from_string)
return AwsQuantumTask(
task_arn, aws_session, AnnealingQuantumTaskResult.from_string, *args, **kwargs
)


def _create_common_params(
Expand Down
7 changes: 6 additions & 1 deletion src/braket/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def __init__(self, name: str, status: str, status_reason: str):

@abstractmethod
def run(
self, task_specification: Union[Circuit, Problem], location, shots: Optional[int]
self,
task_specification: Union[Circuit, Problem],
location,
shots: Optional[int],
*args,
**kwargs
) -> QuantumTask:
""" Run a quantum task specification (circuit or annealing program) on this quantum device.
Expand Down

0 comments on commit d2861d6

Please sign in to comment.