Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add args and kwargs to create_task internal #30

Merged
merged 2 commits into from
Feb 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/braket/aws/aws_qpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def run(
task_specification (Union[Circuit, Problem]): Specification of task
(circuit or annealing problem) to run on device.
s3_destination_folder: The S3 location to save the task's results
shots (Optional[int]): The number of times to run the circuit or annealing task
shots (Optional[int]): The number of times to run the circuit or annealing problem
*aws_quantum_task_args: Variable length positional arguments for
`braket.aws.aws_quantum_task.AwsQuantumTask.create()`.
**aws_quantum_task_kwargs: Variable length keyword arguments for
Expand Down
2 changes: 1 addition & 1 deletion src/braket/aws/aws_quantum_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def run(
task_specification (Union[Circuit, Problem]): Specification of task
(circuit or annealing problem) to run on device.
s3_destination_folder: The S3 location to save the task's results
shots (Optional[int]): The number of times to run the circuit or annealing task
shots (Optional[int]): The number of times to run the circuit or annealing problem
*aws_quantum_task_args: Variable length positional arguments for
`braket.aws.aws_quantum_task.AwsQuantumTask.create()`.
**aws_quantum_task_kwargs: Variable length keyword arguments for
Expand Down
31 changes: 21 additions & 10 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio
import time
from functools import singledispatch
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Optional, Union

from braket.annealing.problem import Problem
from braket.aws.aws_session import AwsSession
Expand All @@ -36,13 +36,16 @@ class AwsQuantumTask(QuantumTask):
ANNEALING_IR_TYPE = "annealing"
DEFAULT_SHOTS = 1_000

DEFAULT_RESULTS_POLL_TIMEOUT = 120
DEFAULT_RESULTS_POLL_INTERVAL = 0.25

@staticmethod
def create(
aws_session: AwsSession,
device_arn: str,
task_specification: Union[Circuit, Problem],
s3_destination_folder: AwsSession.S3DestinationFolder,
shots: int = DEFAULT_SHOTS,
shots: Optional[int] = None,
backend_parameters: Dict[str, Any] = None,
*args,
**kwargs,
Expand All @@ -59,9 +62,9 @@ def create(
(circuit or annealing problem) to run on device.
s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple with bucket (index 0)
and key (index 1) that is the results destination folder in S3.
shots (int): The number of times to run the circuit or annealing task on the device.
If the device is a classical simulator then this implies sampling the state N times,
where N = `shots`. Default = 1_000.
shots (Optional[int]): The number of times to run the circuit or annealing problem
on the device. If the device is a classical simulator then this implies sampling
the state N times, where N = `shots`. If not set, will default to 1_000.
backend_parameters (Dict[str, Any]): Additional parameters to pass to the device.
For example, for D-Wave:
>>> backend_parameters = {"dWaveParameters": {"postprocess": "OPTIMIZATION"}}
Expand All @@ -80,7 +83,11 @@ def create(
"s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
)

create_task_kwargs = _create_common_params(device_arn, s3_destination_folder, shots)
create_task_kwargs = _create_common_params(
device_arn,
s3_destination_folder,
shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
)
return _create_internal(
task_specification,
aws_session,
Expand All @@ -95,8 +102,8 @@ def __init__(
arn: str,
aws_session: AwsSession,
results_formatter: Callable[[str], Any],
poll_timeout_seconds: int = 120,
poll_interval_seconds: int = 0.25,
poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: int = DEFAULT_RESULTS_POLL_INTERVAL,
):
"""
Args:
Expand Down Expand Up @@ -277,7 +284,9 @@ def _(
)

task_arn = aws_session.create_quantum_task(**create_task_kwargs)
return AwsQuantumTask(task_arn, aws_session, GateModelQuantumTaskResult.from_string)
return AwsQuantumTask(
task_arn, aws_session, GateModelQuantumTaskResult.from_string, *args, **kwargs
)


@_create_internal.register
Expand All @@ -298,7 +307,9 @@ def _(
)

task_arn = aws_session.create_quantum_task(**create_task_kwargs)
return AwsQuantumTask(task_arn, aws_session, AnnealingQuantumTaskResult.from_string)
return AwsQuantumTask(
task_arn, aws_session, AnnealingQuantumTaskResult.from_string, *args, **kwargs
)


def _create_common_params(
Expand Down
7 changes: 6 additions & 1 deletion src/braket/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def __init__(self, name: str, status: str, status_reason: str):

@abstractmethod
def run(
self, task_specification: Union[Circuit, Problem], location, shots: Optional[int]
self,
task_specification: Union[Circuit, Problem],
location,
shots: Optional[int],
*args,
**kwargs
) -> QuantumTask:
""" Run a quantum task specification (circuit or annealing program) on this quantum device.

Expand Down