From f4a5241cd64baf300bc791855bc8546d1eb246c4 Mon Sep 17 00:00:00 2001 From: Ikko Hamamura Date: Tue, 6 Sep 2022 00:37:43 +0900 Subject: [PATCH] Default run_options for Primitives (#8513) * Add run_options to Primitives * rm unnecessary comments * initial commit of Settings dataclass * Revert "initial commit of Settings dataclass" This reverts commit 96b8479fe4865f7dfe3cce16f056b814ba60f3fb. * fix lint, improve docs, don't return self Co-authored-by: Julien Gacon Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- qiskit/primitives/base_estimator.py | 47 ++++++++++++++++--- qiskit/primitives/base_sampler.py | 42 +++++++++++++++-- qiskit/primitives/estimator.py | 13 +++++ qiskit/primitives/sampler.py | 8 ++-- ...imitives-run_options-eb4a360c3f1e197d.yaml | 6 +++ test/python/primitives/test_estimator.py | 18 +++++++ test/python/primitives/test_sampler.py | 14 ++++++ 7 files changed, 133 insertions(+), 15 deletions(-) create mode 100644 releasenotes/notes/primitives-run_options-eb4a360c3f1e197d.yaml diff --git a/qiskit/primitives/base_estimator.py b/qiskit/primitives/base_estimator.py index 998e162e9497..10be880f3f37 100644 --- a/qiskit/primitives/base_estimator.py +++ b/qiskit/primitives/base_estimator.py @@ -113,6 +113,7 @@ from qiskit.exceptions import QiskitError from qiskit.opflow import PauliSumOp from qiskit.providers import JobV1 as Job +from qiskit.providers import Options from qiskit.quantum_info.operators import SparsePauliOp from qiskit.quantum_info.operators.base_operator import BaseOperator from qiskit.utils.deprecation import deprecate_arguments, deprecate_function @@ -126,13 +127,14 @@ class BaseEstimator(ABC): Base class for Estimator that estimates expectation values of quantum circuits and observables. """ - __hash__ = None # type: ignore + __hash__ = None def __init__( self, circuits: Iterable[QuantumCircuit] | QuantumCircuit | None = None, observables: Iterable[SparsePauliOp] | SparsePauliOp | None = None, parameters: Iterable[Iterable[Parameter]] | None = None, + run_options: dict | None = None, ): """ Creating an instance of an Estimator, or using one in a ``with`` context opens a session that @@ -145,6 +147,7 @@ def __init__( will be bound. Defaults to ``[circ.parameters for circ in circuits]`` The indexing is such that ``parameters[i, j]`` is the j-th formal parameter of ``circuits[i]``. + run_options: runtime options. Raises: QiskitError: For mismatch of circuits and parameters list. @@ -185,6 +188,9 @@ def __init__( f"Different numbers of parameters of {i}-th circuit: " f"expected {circ.num_parameters}, actual {len(params)}." ) + self._run_options = Options() + if run_options is not None: + self._run_options.update_options(**run_options) def __new__( cls, @@ -258,6 +264,23 @@ def parameters(self) -> tuple[ParameterView, ...]: """ return tuple(self._parameters) + @property + def run_options(self) -> Options: + """Return options values for the estimator. + + Returns: + run_options + """ + return self._run_options + + def set_run_options(self, **fields) -> BaseEstimator: + """Set options values for the estimator. + + Args: + **fields: The fields to update the options + """ + self._run_options.update_options(**fields) + @deprecate_function( "The BaseSampler.__call__ method is deprecated as of Qiskit Terra 0.22.0 " "and will be removed no sooner than 3 months after the releasedate. " @@ -296,7 +319,7 @@ def __call__( circuits: the list of circuit indices or circuit objects. observables: the list of observable indices or observable objects. parameter_values: concrete parameters to be bound. - run_options: runtime options used for circuit execution. + run_options: Default runtime options used for circuit execution. Returns: EstimatorResult: The result of the estimator. @@ -312,7 +335,7 @@ def __call__( # Allow objects circuits = [ - self._circuit_ids.get(id(circuit)) # type: ignore + self._circuit_ids.get(id(circuit)) if not isinstance(circuit, (int, np.integer)) else circuit for circuit in circuits @@ -323,7 +346,7 @@ def __call__( "initialize the session." ) observables = [ - self._observable_ids.get(id(observable)) # type: ignore + self._observable_ids.get(id(observable)) if not isinstance(observable, (int, np.integer)) else observable for observable in observables @@ -386,12 +409,14 @@ def __call__( f"The number of circuits is {len(self.observables)}, " f"but the index {max(observables)} is given." ) + run_opts = copy(self.run_options) + run_opts.update_options(**run_options) return self._call( circuits=circuits, observables=observables, parameter_values=parameter_values, - **run_options, + **run_opts.__dict__, ) def run( @@ -495,8 +520,16 @@ def run( f"not match the number of qubits of the {i}-th observable " f"({observable.num_qubits})." ) - - return self._run(circuits, observables, parameter_values, parameter_views, **run_options) + run_opts = copy(self.run_options) + run_opts.update_options(**run_options) + + return self._run( + circuits, + observables, + parameter_values, + parameter_views, + **run_opts.__dict__, + ) @abstractmethod def _call( diff --git a/qiskit/primitives/base_sampler.py b/qiskit/primitives/base_sampler.py index 2b2f14b0d32f..6cd957996029 100644 --- a/qiskit/primitives/base_sampler.py +++ b/qiskit/primitives/base_sampler.py @@ -101,6 +101,7 @@ from qiskit.circuit.parametertable import ParameterView from qiskit.exceptions import QiskitError from qiskit.providers import JobV1 as Job +from qiskit.providers import Options from qiskit.utils.deprecation import deprecate_arguments, deprecate_function from .sampler_result import SamplerResult @@ -112,18 +113,20 @@ class BaseSampler(ABC): Base class of Sampler that calculates quasi-probabilities of bitstrings from quantum circuits. """ - __hash__ = None # type: ignore + __hash__ = None def __init__( self, circuits: Iterable[QuantumCircuit] | QuantumCircuit | None = None, parameters: Iterable[Iterable[Parameter]] | None = None, + run_options: dict | None = None, ): """ Args: circuits: Quantum circuits to be executed. parameters: Parameters of each of the quantum circuits. Defaults to ``[circ.parameters for circ in circuits]``. + run_options: Default runtime options. Raises: QiskitError: For mismatch of circuits and parameters list. @@ -153,6 +156,9 @@ def __init__( f"Different number of parameters ({len(self._parameters)}) " f"and circuits ({len(self._circuits)})" ) + self._run_options = Options() + if run_options is not None: + self._run_options.update_options(**run_options) def __new__( cls, @@ -209,6 +215,23 @@ def parameters(self) -> tuple[ParameterView, ...]: """ return tuple(self._parameters) + @property + def run_options(self) -> Options: + """Return options values for the estimator. + + Returns: + run_options + """ + return self._run_options + + def set_run_options(self, **fields) -> BaseSampler: + """Set options values for the estimator. + + Args: + **fields: The fields to update the options + """ + self._run_options.update_options(**fields) + @deprecate_function( "The BaseSampler.__call__ method is deprecated as of Qiskit Terra 0.22.0 " "and will be removed no sooner than 3 months after the releasedate. " @@ -243,7 +266,7 @@ def __call__( # Allow objects circuits = [ - self._circuit_ids.get(id(circuit)) # type: ignore + self._circuit_ids.get(id(circuit)) if not isinstance(circuit, (int, np.integer)) else circuit for circuit in circuits @@ -285,11 +308,13 @@ def __call__( f"The number of circuits is {len(self.circuits)}, " f"but the index {max(circuits)} is given." ) + run_opts = copy(self.run_options) + run_opts.update_options(**run_options) return self._call( circuits=circuits, parameter_values=parameter_values, - **run_options, + **run_opts.__dict__, ) def run( @@ -358,8 +383,15 @@ def run( f"The number of values ({len(parameter_value)}) does not match " f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit." ) - - return self._run(circuits, parameter_values, parameter_views, **run_options) + run_opts = copy(self.run_options) + run_opts.update_options(**run_options) + + return self._run( + circuits, + parameter_values, + parameter_views, + **run_opts.__dict__, + ) @abstractmethod def _call( diff --git a/qiskit/primitives/estimator.py b/qiskit/primitives/estimator.py index 0715abfcd859..5fa73e00e27f 100644 --- a/qiskit/primitives/estimator.py +++ b/qiskit/primitives/estimator.py @@ -54,7 +54,19 @@ def __init__( circuits: QuantumCircuit | Iterable[QuantumCircuit] | None = None, observables: BaseOperator | PauliSumOp | Iterable[BaseOperator | PauliSumOp] | None = None, parameters: Iterable[Iterable[Parameter]] | None = None, + run_options: dict | None = None, ): + """ + Args: + circuits: circuits that represent quantum states. + observables: observables to be estimated. + parameters: Parameters of each of the quantum circuits. + Defaults to ``[circ.parameters for circ in circuits]``. + run_options: Default runtime options. + + Raises: + QiskitError: if some classical bits are not used for measurements. + """ if isinstance(circuits, QuantumCircuit): circuits = (circuits,) if circuits is not None: @@ -69,6 +81,7 @@ def __init__( circuits=circuits, observables=observables, # type: ignore parameters=parameters, + run_options=run_options, ) self._is_closed = False diff --git a/qiskit/primitives/sampler.py b/qiskit/primitives/sampler.py index f9c1c2910ecb..40d35a0778f2 100644 --- a/qiskit/primitives/sampler.py +++ b/qiskit/primitives/sampler.py @@ -16,7 +16,7 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from typing import Any, cast +from typing import Any import numpy as np @@ -53,12 +53,14 @@ def __init__( self, circuits: QuantumCircuit | Iterable[QuantumCircuit] | None = None, parameters: Iterable[Iterable[Parameter]] | None = None, + run_options: dict | None = None, ): """ Args: circuits: circuits to be executed parameters: Parameters of each of the quantum circuits. Defaults to ``[circ.parameters for circ in circuits]``. + run_options: Default runtime options. Raises: QiskitError: if some classical bits are not used for measurements. @@ -74,7 +76,7 @@ def __init__( preprocessed_circuits.append(circuit) else: preprocessed_circuits = None - super().__init__(preprocessed_circuits, parameters) + super().__init__(preprocessed_circuits, parameters, run_options) self._is_closed = False def _call( @@ -164,5 +166,5 @@ def _preprocess_circuit(circuit: QuantumCircuit): ) c_q_mapping = sorted((c, q) for q, c in q_c_mapping.items()) qargs = [q for _, q in c_q_mapping] - circuit = cast(QuantumCircuit, circuit.remove_final_measurements(inplace=False)) + circuit = circuit.remove_final_measurements(inplace=False) return circuit, qargs diff --git a/releasenotes/notes/primitives-run_options-eb4a360c3f1e197d.yaml b/releasenotes/notes/primitives-run_options-eb4a360c3f1e197d.yaml new file mode 100644 index 000000000000..153dbabd31e9 --- /dev/null +++ b/releasenotes/notes/primitives-run_options-eb4a360c3f1e197d.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added ``run_options`` arguments in constructor of primitives and ``run_options`` methods to + primitives. It is now possible to set default ``run_options`` in addition to passing + ``run_options`` at runtime. diff --git a/test/python/primitives/test_estimator.py b/test/python/primitives/test_estimator.py index 5ce1571867a9..58212a358627 100644 --- a/test/python/primitives/test_estimator.py +++ b/test/python/primitives/test_estimator.py @@ -568,6 +568,24 @@ def test_run_with_shots_option(self): self.assertIsInstance(result, EstimatorResult) np.testing.assert_allclose(result.values, [-1.307397243478641]) + def test_run_options(self): + """Test for run_options""" + with self.subTest("init"): + estimator = Estimator(run_options={"shots": 3000}) + self.assertEqual(estimator.run_options.get("shots"), 3000) + with self.subTest("set_run_options"): + estimator.set_run_options(shots=1024, seed=15) + self.assertEqual(estimator.run_options.get("shots"), 1024) + self.assertEqual(estimator.run_options.get("seed"), 15) + with self.subTest("run"): + result = estimator.run( + [self.ansatz], + [self.observable], + parameter_values=[[0, 1, 1, 2, 3, 5]], + ).result() + self.assertIsInstance(result, EstimatorResult) + np.testing.assert_allclose(result.values, [-1.307397243478641]) + if __name__ == "__main__": unittest.main() diff --git a/test/python/primitives/test_sampler.py b/test/python/primitives/test_sampler.py index b126f927de97..461d23bbd657 100644 --- a/test/python/primitives/test_sampler.py +++ b/test/python/primitives/test_sampler.py @@ -658,6 +658,20 @@ def test_primitive_job_status_done(self): job = sampler.run(circuits=[bell]) self.assertEqual(job.status(), JobStatus.DONE) + def test_run_options(self): + """Test for run_options""" + with self.subTest("init"): + sampler = Sampler(run_options={"shots": 3000}) + self.assertEqual(sampler.run_options.get("shots"), 3000) + with self.subTest("set_run_options"): + sampler.set_run_options(shots=1024, seed=15) + self.assertEqual(sampler.run_options.get("shots"), 1024) + self.assertEqual(sampler.run_options.get("seed"), 15) + with self.subTest("run"): + params, target = self._generate_params_target([1]) + result = sampler.run([self._pqc], parameter_values=params).result() + self._compare_probs(result.quasi_dists, target) + if __name__ == "__main__": unittest.main()