From 96b8479fe4865f7dfe3cce16f056b814ba60f3fb Mon Sep 17 00:00:00 2001 From: ikkoham Date: Thu, 11 Aug 2022 23:57:01 +0900 Subject: [PATCH] initial commit of Settings dataclass --- qiskit/primitives/base_sampler.py | 42 +++++++++++++------------------ qiskit/primitives/sampler.py | 7 ++++-- qiskit/primitives/settings.py | 38 ++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 27 deletions(-) create mode 100644 qiskit/primitives/settings.py diff --git a/qiskit/primitives/base_sampler.py b/qiskit/primitives/base_sampler.py index 9dbae175a1e4..6aa818585554 100644 --- a/qiskit/primitives/base_sampler.py +++ b/qiskit/primitives/base_sampler.py @@ -92,6 +92,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Sequence from copy import copy +from dataclasses import asdict from typing import cast from warnings import warn @@ -105,6 +106,7 @@ from qiskit.utils.deprecation import deprecate_arguments, deprecate_function from .sampler_result import SamplerResult +from .settings import Settings class BaseSampler(ABC): @@ -117,9 +119,9 @@ class BaseSampler(ABC): def __init__( self, + settings: Settings, circuits: Iterable[QuantumCircuit] | QuantumCircuit | None = None, parameters: Iterable[Iterable[Parameter]] | None = None, - run_options: dict | None = None, ): """ Args: @@ -156,9 +158,7 @@ def __init__( f"Different number of parameters ({len(self._parameters)}) " f"and circuits ({len(self._circuits)})" ) - self._run_options = Options() - if run_options is not None: - self._run_options.update_options(**run_options) + self._settings = settings def __new__( cls, @@ -216,25 +216,17 @@ def parameters(self) -> tuple[ParameterView, ...]: return tuple(self._parameters) @property - def run_options(self) -> Options: - """Return options values for the estimator. + def settings(self) -> Settings: + """Return settings values for the sampler. Returns: - run_options + Settings for sampler. """ - return self._run_options + return self._settings - def set_run_options(self, **fields) -> BaseSampler: - """Set options values for the estimator. - - Args: - **fields: The fields to update the options - - Returns: - self - """ - self._run_options.update_options(**fields) - return self + @settings.setter + def settings(self, settings): + self._settings = settings @deprecate_function( "The BaseSampler.__call__ method is deprecated as of Qiskit Terra 0.21.0 " @@ -312,13 +304,13 @@ def __call__( f"The number of circuits is {len(self.circuits)}, " f"but the index {max(circuits)} is given." ) - run_opts = copy(self.run_options) - run_opts.update_options(**run_options) + run_opts = asdict(self.settings.run_options) + run_opts.update(run_options) return self._call( circuits=circuits, parameter_values=parameter_values, - **run_opts.__dict__, + **run_opts, ) def run( @@ -387,14 +379,14 @@ def run( f"The number of values ({len(parameter_value)}) does not match " f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit." ) - run_opts = copy(self.run_options) - run_opts.update_options(**run_options) + run_opts = asdict(self.settings.run_options) + run_opts.update(run_options) return self._run( circuits, parameter_values, parameter_views, - **run_opts.__dict__, + **run_opts, ) @abstractmethod diff --git a/qiskit/primitives/sampler.py b/qiskit/primitives/sampler.py index 7d740f8a8a32..a410f0677a04 100644 --- a/qiskit/primitives/sampler.py +++ b/qiskit/primitives/sampler.py @@ -29,6 +29,7 @@ from .base_sampler import BaseSampler from .primitive_job import PrimitiveJob from .sampler_result import SamplerResult +from .settings import ReferenceSettings from .utils import final_measurement_mapping, init_circuit @@ -53,7 +54,7 @@ def __init__( self, circuits: QuantumCircuit | Iterable[QuantumCircuit] | None = None, parameters: Iterable[Iterable[Parameter]] | None = None, - run_options: dict | None = None, + settings: ReferenceSettings | None = None, ): """ Args: @@ -76,7 +77,9 @@ def __init__( preprocessed_circuits.append(circuit) else: preprocessed_circuits = None - super().__init__(preprocessed_circuits, parameters, run_options) + if settings is None: + settings = ReferenceSettings() + super().__init__(settings, preprocessed_circuits, parameters) self._is_closed = False def _call( diff --git a/qiskit/primitives/settings.py b/qiskit/primitives/settings.py new file mode 100644 index 000000000000..4297c1f12e68 --- /dev/null +++ b/qiskit/primitives/settings.py @@ -0,0 +1,38 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +"""Run options for Primitives.""" +from __future__ import annotations + +from dataclasses import dataclass, field + +import numpy as np + + +@dataclass +class RunOptions: + ... + + +@dataclass +class Settings: + run_options: RunOptions + + +@dataclass +class ReferenceRunOptions(RunOptions): + shots: int | None = None + seed: int | np.random.Generator | None = None + + +@dataclass +class ReferenceSettings(Settings): + run_options: RunOptions = field(default_factory=ReferenceRunOptions)