diff --git a/qiskit/primitives/__init__.py b/qiskit/primitives/__init__.py index e1b38e6f2ac2..14a40199428d 100644 --- a/qiskit/primitives/__init__.py +++ b/qiskit/primitives/__init__.py @@ -65,6 +65,7 @@ from .base.sampler_result import SamplerResult from .containers import ( BindingsArray, + Observable, ObservablesArray, PrimitiveResult, PubResult, diff --git a/qiskit/primitives/containers/__init__.py b/qiskit/primitives/containers/__init__.py index 4f362c146439..b4142f0335bc 100644 --- a/qiskit/primitives/containers/__init__.py +++ b/qiskit/primitives/containers/__init__.py @@ -18,6 +18,7 @@ from .bit_array import BitArray from .data_bin import make_data_bin, DataBin from .estimator_pub import EstimatorPub, EstimatorPubLike +from .observable import Observable from .observables_array import ObservablesArray from .primitive_result import PrimitiveResult from .pub_result import PubResult diff --git a/qiskit/primitives/containers/observable.py b/qiskit/primitives/containers/observable.py new file mode 100644 index 000000000000..c9b0eb5e5e47 --- /dev/null +++ b/qiskit/primitives/containers/observable.py @@ -0,0 +1,188 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + + +""" +Container class for an Estimator observable. +""" +from __future__ import annotations + +import re +from collections.abc import Mapping, Iterable +from collections import defaultdict +from functools import lru_cache +from typing import Union +from numbers import Complex + +from qiskit.quantum_info import Pauli, SparsePauliOp + +ObservableLike = Union[ + str, + Pauli, + SparsePauliOp, + Mapping[Union[str, Pauli], complex], + Iterable[Union[str, Pauli, SparsePauliOp]], +] +"""Types that can be natively used to construct a :const:`BasisObservable`.""" + + +class Observable(Mapping): + """A sparse container for a Hermitian observable for an :class:`.Estimator` primitive.""" + + ALLOWED_BASIS: str = "IXYZ01+-lr" + """The allowed characters in :class:`.Observable` strings.""" + + def __init__( + self, + data: Mapping[str, complex], + validate: bool = True, + ): + """Initialize an observables array. + + Args: + data: The observable data. + validate: If ``True``, the input data is validated during initialization. + + Raises: + ValueError: If ``validate=True`` and the input observable-like is not valid. + """ + self._data = data + self._num_qubits = len(next(iter(data))) + if validate: + self.validate() + + def __repr__(self): + return f"{type(self).__name__}({self._data})" + + def __getitem__(self, key): + return self._data[key] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + @property + def num_qubits(self) -> int: + """The number of qubits in the observable""" + return self._num_qubits + + def validate(self): + """Validate the consistency in observables array.""" + if not isinstance(self._data, Mapping): + raise TypeError(f"Observable data type {type(self._data)} is not a Mapping.") + for key, value in self._data.items(): + try: + self._validate_basis(key) + self._validate_coeff(value) + except TypeError as ex: + raise TypeError(f"Invalid type for item ({key}, {value})") from ex + except Exception as ex: # pylint: disable = broad-except + raise ValueError(f"Invalid value for item ({key}, {value})") from ex + + @classmethod + def coerce(cls, observable: ObservableLike) -> Observable: + """Coerce an observable-like object into an :class:`.Observable`. + + Args: + observable: The observable-like input. + + Returns: + A coerced observables array. + + Raises: + TypeError: If the input cannot be formatted because its type is not valid. + ValueError: If the input observable is invalid. + """ + + # Pauli-type conversions + if isinstance(observable, SparsePauliOp): + # Call simplify to combine duplicate keys before converting to a mapping + data = dict(observable.simplify(atol=0).to_list()) + return cls.coerce(data) + + if isinstance(observable, Pauli): + label, phase = observable[:].to_label(), observable.phase + cls._validate_basis(label) + data = {label: 1} if phase == 0 else {label: (-1j) ** phase} + return Observable(data) + + # String conversion + if isinstance(observable, str): + cls._validate_basis(observable) + return cls.coerce({observable: 1}) + + # Mapping conversion (with possible Pauli keys) + if isinstance(observable, Mapping): + # NOTE: This assumes length of keys is number of qubits + # this might not be very robust + num_qubits = len(next(iter(observable))) + unique = defaultdict(complex) + for basis, coeff in observable.items(): + if isinstance(basis, Pauli): + basis, phase = basis[:].to_label(), basis.phase + if phase != 0: + coeff = coeff * (-1j) ** phase + # Validate basis + cls._validate_basis(basis) + if len(basis) != num_qubits: + raise ValueError( + "Number of qubits must be the same for all observable basis elements." + ) + unique[basis] += coeff + return Observable(dict(unique)) + + raise TypeError(f"Invalid observable type: {type(observable)}") + + @classmethod + def _validate_basis(cls, basis: any) -> None: + """Validate a basis string. + + Args: + basis: a basis object to validate. + + Raises: + TypeError: If the input basis is not a string + ValueError: If basis string contains invalid characters + """ + if not isinstance(basis, str): + raise TypeError(f"basis {basis} is not a string") + + # NOTE: the allowed basis characters can be overridden by modifying the class + # attribute ALLOWED_BASIS + allowed_pattern = _regex_match(cls.ALLOWED_BASIS) + if not allowed_pattern.match(basis): + invalid_pattern = _regex_invalid(cls.ALLOWED_BASIS) + invalid_chars = list(set(invalid_pattern.findall(basis))) + raise ValueError( + f"Observable basis string '{basis}' contains invalid characters {invalid_chars}," + f" allowed characters are {list(cls.ALLOWED_BASIS)}.", + ) + + @classmethod + def _validate_coeff(cls, coeff: any): + """Validate the consistency in observables array.""" + if not isinstance(coeff, Complex): + raise TypeError(f"Value {coeff} is not a complex number") + + +@lru_cache(1) +def _regex_match(allowed_chars: str) -> re.Pattern: + """Return pattern for matching if a string contains only the allowed characters.""" + return re.compile(f"^[{re.escape(allowed_chars)}]*$") + + +@lru_cache(1) +def _regex_invalid(allowed_chars: str) -> re.Pattern: + """Return pattern for selecting invalid strings""" + return re.compile(f"[^{re.escape(allowed_chars)}]") diff --git a/qiskit/primitives/containers/observables_array.py b/qiskit/primitives/containers/observables_array.py index 12dd51837b20..90bce56f5f42 100644 --- a/qiskit/primitives/containers/observables_array.py +++ b/qiskit/primitives/containers/observables_array.py @@ -16,10 +16,6 @@ """ from __future__ import annotations -import re -from collections import defaultdict -from collections.abc import Mapping as MappingType -from functools import lru_cache from typing import Iterable, Mapping, Union, overload import numpy as np @@ -29,30 +25,17 @@ from .object_array import object_array from .shape import ShapedMixin - -BasisObservable = Mapping[str, complex] -"""Representation type of a single observable.""" - -BasisObservableLike = Union[ - str, - Pauli, - SparsePauliOp, - Mapping[Union[str, Pauli], complex], - Iterable[Union[str, Pauli, SparsePauliOp]], -] -"""Types that can be natively used to construct a :const:`BasisObservable`.""" +from .observable import Observable, ObservableLike class ObservablesArray(ShapedMixin): - """An ND-array of :const:`.BasisObservable` for an :class:`.Estimator` primitive.""" + r"""An ND-array of :class:`.Observable`\s for an :class:`.Estimator` primitive.""" __slots__ = ("_array", "_shape") - ALLOWED_BASIS: str = "IXYZ01+-lr" - """The allowed characters in :const:`BasisObservable` strings.""" def __init__( self, - observables: BasisObservableLike | ArrayLike, + observables: ArrayLike | ObservableLike, copy: bool = True, validate: bool = True, ): @@ -76,9 +59,11 @@ def __init__( self._array = object_array(observables, copy=copy, list_types=(PauliList,)) self._shape = self._array.shape if validate: + # Convert array items to Observable objects + # and validate they are on the same number of qubits num_qubits = None for ndi, obs in np.ndenumerate(self._array): - basis_obs = self.format_observable(obs) + basis_obs = Observable.coerce(obs) basis_num_qubits = len(next(iter(basis_obs))) if num_qubits is None: num_qubits = basis_num_qubits @@ -106,7 +91,7 @@ def __array__(self, dtype=None): raise ValueError("Type must be 'None' or 'object'") @overload - def __getitem__(self, args: int | tuple[int, ...]) -> BasisObservable: + def __getitem__(self, args: int | tuple[int, ...]) -> Observable: ... @overload @@ -143,55 +128,6 @@ def ravel(self) -> ObservablesArray: """ return self.reshape(self.size) - @classmethod - def format_observable(cls, observable: BasisObservableLike) -> BasisObservable: - """Format an observable-like object into a :const:`BasisObservable`. - - Args: - observable: The observable-like to format. - - Returns: - The given observable as a :const:`~BasisObservable`. - - Raises: - TypeError: If the input cannot be formatted because its type is not valid. - ValueError: If the input observable is invalid. - """ - - # Pauli-type conversions - if isinstance(observable, SparsePauliOp): - # Call simplify to combine duplicate keys before converting to a mapping - return cls.format_observable(dict(observable.simplify(atol=0).to_list())) - - if isinstance(observable, Pauli): - label, phase = observable[:].to_label(), observable.phase - return {label: 1} if phase == 0 else {label: (-1j) ** phase} - - # String conversion - if isinstance(observable, str): - cls._validate_basis(observable) - return {observable: 1} - - # Mapping conversion (with possible Pauli keys) - if isinstance(observable, MappingType): - num_qubits = len(next(iter(observable))) - unique = defaultdict(complex) - for basis, coeff in observable.items(): - if isinstance(basis, Pauli): - basis, phase = basis[:].to_label(), basis.phase - if phase != 0: - coeff = coeff * (-1j) ** phase - # Validate basis - cls._validate_basis(basis) - if len(basis) != num_qubits: - raise ValueError( - "Number of qubits must be the same for all observable basis elements." - ) - unique[basis] += coeff - return dict(unique) - - raise TypeError(f"Invalid observable type: {type(observable)}") - @classmethod def coerce(cls, observables: ObservablesArrayLike) -> ObservablesArray: """Coerce ObservablesArrayLike into ObservableArray. @@ -204,62 +140,48 @@ def coerce(cls, observables: ObservablesArrayLike) -> ObservablesArray: """ if isinstance(observables, ObservablesArray): return observables + if isinstance(observables, (str, SparsePauliOp, Pauli, Mapping)): observables = [observables] - return cls(observables) - def validate(self): - """Validate the consistency in observables array.""" + # Convert array items to Observable objects and validate they are on the + # same number of qubis. Note that we copy some of the validation method + # here to avoid double iteration of the array + data = object_array(observables, copy=True, list_types=(PauliList,)) num_qubits = None - for obs in self._array.reshape(-1): - basis_num_qubits = len(next(iter(obs))) + for ndi, obs in np.ndenumerate(data): + basis_obs = Observable.coerce(obs) if num_qubits is None: - num_qubits = basis_num_qubits - elif basis_num_qubits != num_qubits: + num_qubits = basis_obs.num_qubits + elif basis_obs.num_qubits != num_qubits: raise ValueError( "The number of qubits must be the same for all observables in the " "observables array." ) + data[ndi] = basis_obs - @classmethod - def _validate_basis(cls, basis: str) -> None: - """Validate a basis string. - - Args: - basis: a basis string to validate. - - Raises: - ValueError: If basis string contains invalid characters - """ - # NOTE: the allowed basis characters can be overridden by modifying the class - # attribute ALLOWED_BASIS - allowed_pattern = _regex_match(cls.ALLOWED_BASIS) - if not allowed_pattern.match(basis): - invalid_pattern = _regex_invalid(cls.ALLOWED_BASIS) - invalid_chars = list(set(invalid_pattern.findall(basis))) - raise ValueError( - f"Observable basis string '{basis}' contains invalid characters {invalid_chars}," - f" allowed characters are {list(cls.ALLOWED_BASIS)}.", - ) - - -ObservablesArrayLike = Union[ObservablesArray, ArrayLike, BasisObservableLike] -"""Types that can be natively converted to an ObservablesArray""" - - -class PauliArray(ObservablesArray): - """An ND-array of Pauli-basis observables for an :class:`.Estimator` primitive.""" - - ALLOWED_BASIS = "IXYZ" + return cls(data, validate=False) + def validate(self): + """Validate the consistency in observables array.""" + # Convert array items to Observable objects + # and validate they are on the same number of qubits + if not isinstance(self._array, np.ndarray) or self._array.dtype != object: + raise TypeError("Data should be an object ndarray") -@lru_cache(1) -def _regex_match(allowed_chars: str) -> re.Pattern: - """Return pattern for matching if a string contains only the allowed characters.""" - return re.compile(f"^[{re.escape(allowed_chars)}]*$") + num_qubits = None + for ndi, obs in np.ndenumerate(self._array): + if not isinstance(obs, Observable): + raise TypeError(f"item at index {ndi} is a {type(obs)}, not an Observable.") + obs.validate() + if num_qubits is None: + num_qubits = obs.num_qubits + elif obs.num_qubits != num_qubits: + raise ValueError( + "The number of qubits must be the same for all observables in the " + "observables array." + ) -@lru_cache(1) -def _regex_invalid(allowed_chars: str) -> re.Pattern: - """Return pattern for selecting invalid strings""" - return re.compile(f"[^{re.escape(allowed_chars)}]") +ObservablesArrayLike = Union[ObservablesArray, ArrayLike, ObservableLike] +"""Types that can be natively converted to an ObservablesArray""" diff --git a/test/python/primitives/containers/test_observables_array.py b/test/python/primitives/containers/test_observables_array.py index fd43ebe09db5..1f3421e5f688 100644 --- a/test/python/primitives/containers/test_observables_array.py +++ b/test/python/primitives/containers/test_observables_array.py @@ -18,7 +18,7 @@ import numpy as np import qiskit.quantum_info as qi -from qiskit.primitives import ObservablesArray +from qiskit.primitives import ObservablesArray, Observable from qiskit.test import QiskitTestCase @@ -29,37 +29,37 @@ class ObservablesArrayTestCase(QiskitTestCase): @ddt.data(0, 1, 2) def test_format_observable_str(self, num_qubits): """Test format_observable for allowed basis str input""" - for chars in it.permutations(ObservablesArray.ALLOWED_BASIS, num_qubits): + for chars in it.permutations(Observable.ALLOWED_BASIS, num_qubits): label = "".join(chars) - obs = ObservablesArray.format_observable(label) - self.assertEqual(obs, {label: 1}) + obs = Observable.coerce(label) + self.assertEqual(obs._data, {label: 1}) def test_format_observable_custom_basis(self): """Test format_observable for custom allowed basis""" - class PauliArray(ObservablesArray): + class PauliObservable(Observable): """Custom array allowing only Paulis, not projectors""" ALLOWED_BASIS = "IXYZ" with self.assertRaises(ValueError): - PauliArray.format_observable("0101") + PauliObservable.coerce("0101") for p in qi.pauli_basis(1): - obs = PauliArray.format_observable(p) - self.assertEqual(obs, {p.to_label(): 1}) + obs = PauliObservable.coerce(p) + self.assertEqual(obs._data, {p.to_label(): 1}) @ddt.data("iXX", "012", "+/-") def test_format_observable_invalid_str(self, basis): """Test format_observable for Pauli input""" with self.assertRaises(ValueError): - ObservablesArray.format_observable(basis) + Observable.coerce(basis) @ddt.data(1, 2, 3) def test_format_observable_pauli(self, num_qubits): """Test format_observable for Pauli input""" for p in qi.pauli_basis(num_qubits): - obs = ObservablesArray.format_observable(p) - self.assertEqual(obs, {p.to_label(): 1}) + obs = Observable.coerce(p) + self.assertEqual(obs._data, {p.to_label(): 1}) @ddt.data(0, 1, 2, 3) def test_format_observable_phased_pauli(self, phase): @@ -67,8 +67,8 @@ def test_format_observable_phased_pauli(self, phase): pauli = qi.Pauli("IXYZ") pauli.phase = phase coeff = (-1j) ** phase - obs = ObservablesArray.format_observable(pauli) - self.assertIsInstance(obs, dict) + obs = Observable.coerce(pauli) + self.assertIsInstance(obs._data, dict) self.assertEqual(list(obs.keys()), ["IXYZ"]) np.testing.assert_allclose( list(obs.values()), [coeff], err_msg=f"Wrong value for Pauli {pauli}" @@ -79,8 +79,8 @@ def test_format_observable_phased_pauli_str(self, pauli): """Test format_observable for phased Pauli input""" pauli = qi.Pauli(pauli) coeff = (-1j) ** pauli.phase - obs = ObservablesArray.format_observable(pauli) - self.assertIsInstance(obs, dict) + obs = Observable.coerce(pauli) + self.assertIsInstance(obs._data, dict) self.assertEqual(list(obs.keys()), ["IXYZ"]) np.testing.assert_allclose( list(obs.values()), [coeff], err_msg=f"Wrong value for Pauli {pauli}" @@ -89,8 +89,8 @@ def test_format_observable_phased_pauli_str(self, pauli): def test_format_observable_phased_sparse_pauli_op(self): """Test format_observable for SparsePauliOp input with phase paulis""" op = qi.SparsePauliOp(["+I", "-X", "iY", "-iZ"], [1, 2, 3, 4]) - obs = ObservablesArray.format_observable(op) - self.assertIsInstance(obs, dict) + obs = Observable.coerce(op) + self.assertIsInstance(obs._data, dict) self.assertEqual(len(obs), 4) self.assertEqual(sorted(obs.keys()), sorted(["I", "X", "Y", "Z"])) np.testing.assert_allclose([obs[i] for i in ["I", "X", "Y", "Z"]], [1, -2, 3j, -4j]) @@ -98,8 +98,8 @@ def test_format_observable_phased_sparse_pauli_op(self): def test_format_observable_zero_sparse_pauli_op(self): """Test format_observable for SparsePauliOp input with zero val coeffs""" op = qi.SparsePauliOp(["I", "X", "Y", "Z"], [0, 0, 0, 1]) - obs = ObservablesArray.format_observable(op) - self.assertIsInstance(obs, dict) + obs = Observable.coerce(op) + self.assertIsInstance(obs._data, dict) self.assertEqual(len(obs), 1) self.assertEqual(sorted(obs.keys()), ["Z"]) self.assertEqual(obs["Z"], 1) @@ -107,8 +107,8 @@ def test_format_observable_zero_sparse_pauli_op(self): def test_format_observable_duplicate_sparse_pauli_op(self): """Test format_observable for SparsePauliOp wiht duplicate paulis""" op = qi.SparsePauliOp(["XX", "-XX", "iXX", "-iXX"], [2, 1, 3, 2]) - obs = ObservablesArray.format_observable(op) - self.assertIsInstance(obs, dict) + obs = Observable.coerce(op) + self.assertIsInstance(obs._data, dict) self.assertEqual(len(obs), 1) self.assertEqual(list(obs.keys()), ["XX"]) self.assertEqual(obs["XX"], 1 + 1j) @@ -116,7 +116,7 @@ def test_format_observable_duplicate_sparse_pauli_op(self): def test_format_observable_pauli_mapping(self): """Test format_observable for pauli-keyed Mapping input""" mapping = dict(zip(qi.pauli_basis(1), range(1, 5))) - obs = ObservablesArray.format_observable(mapping) + obs = Observable.coerce(mapping) target = {key.to_label(): val for key, val in mapping.items()} self.assertEqual(obs, target) @@ -124,13 +124,13 @@ def test_format_invalid_mapping_qubits(self): """Test an error is raised when different qubits in mapping keys""" mapping = {"IX": 1, "XXX": 2} with self.assertRaises(ValueError): - ObservablesArray.format_observable(mapping) + Observable.coerce(mapping) def test_format_invalid_mapping_basis(self): """Test an error is raised when keys contain invalid characters""" mapping = {"XX": 1, "0Z": 2, "02": 3} with self.assertRaises(ValueError): - ObservablesArray.format_observable(mapping) + Observable.coerce(mapping) def test_init_nested_list_str(self): """Test init with nested lists of str""" @@ -311,6 +311,6 @@ def test_validate(self): ObservablesArray([{"XX": 1}] * 5).validate() ObservablesArray([{"XX": 1}] * 15).reshape((3, 5)).validate() - obs = ObservablesArray([{"XX": 1}, {"XYZ": 1}], validate=False) + obs = ObservablesArray([Observable({"XX": 1}), Observable({"XYZ": 1})], validate=False) with self.assertRaisesRegex(ValueError, "number of qubits must be the same"): obs.validate()