diff --git a/qiskit/primitives/containers/observable.py b/qiskit/primitives/containers/observable.py index c9b0eb5e5e47..8ec10aba11ed 100644 --- a/qiskit/primitives/containers/observable.py +++ b/qiskit/primitives/containers/observable.py @@ -16,11 +16,9 @@ """ from __future__ import annotations -import re -from collections.abc import Mapping, Iterable +from typing import Union, Iterable, Mapping as MappingType +from collections.abc import Mapping from collections import defaultdict -from functools import lru_cache -from typing import Union from numbers import Complex from qiskit.quantum_info import Pauli, SparsePauliOp @@ -29,7 +27,7 @@ str, Pauli, SparsePauliOp, - Mapping[Union[str, Pauli], complex], + MappingType[Union[str, Pauli], complex], Iterable[Union[str, Pauli, SparsePauliOp]], ] """Types that can be natively used to construct a :const:`BasisObservable`.""" @@ -38,8 +36,7 @@ class Observable(Mapping): """A sparse container for a Hermitian observable for an :class:`.Estimator` primitive.""" - ALLOWED_BASIS: str = "IXYZ01+-lr" - """The allowed characters in :class:`.Observable` strings.""" + __slots__ = ("_data", "_num_qubits", "_terms") def __init__( self, @@ -56,7 +53,8 @@ def __init__( ValueError: If ``validate=True`` and the input observable-like is not valid. """ self._data = data - self._num_qubits = len(next(iter(data))) + self._num_qubits: int = len(next(iter(data))) + self._terms: str = "" if validate: self.validate() @@ -72,6 +70,15 @@ def __iter__(self): def __len__(self): return len(self._data) + @property + def terms(self) -> str: + """Return a string containing all unique basis terms used in the observable""" + if not self._terms: + # QUESTION: Should terms be `tuple[str, ...]` instead + # to allow for basis identification using more than 1 character? + self._terms = "".join(set().union(*self._data)) + return self._terms + @property def num_qubits(self) -> int: """The number of qubits in the observable""" @@ -82,13 +89,10 @@ def validate(self): if not isinstance(self._data, Mapping): raise TypeError(f"Observable data type {type(self._data)} is not a Mapping.") for key, value in self._data.items(): - try: - self._validate_basis(key) - self._validate_coeff(value) - except TypeError as ex: - raise TypeError(f"Invalid type for item ({key}, {value})") from ex - except Exception as ex: # pylint: disable = broad-except - raise ValueError(f"Invalid value for item ({key}, {value})") from ex + if not isinstance(key, str): + raise TypeError(f"Item {key} is not a str") + if not isinstance(value, Complex): + raise TypeError(f"Value {value} is not a complex number") @classmethod def coerce(cls, observable: ObservableLike) -> Observable: @@ -113,19 +117,19 @@ def coerce(cls, observable: ObservableLike) -> Observable: if isinstance(observable, Pauli): label, phase = observable[:].to_label(), observable.phase - cls._validate_basis(label) data = {label: 1} if phase == 0 else {label: (-1j) ** phase} - return Observable(data) + return cls.coerce(data) # String conversion if isinstance(observable, str): - cls._validate_basis(observable) return cls.coerce({observable: 1}) # Mapping conversion (with possible Pauli keys) if isinstance(observable, Mapping): # NOTE: This assumes length of keys is number of qubits # this might not be very robust + # We also compute terms while iterating to save doing it again later + terms = set() num_qubits = len(next(iter(observable))) unique = defaultdict(complex) for basis, coeff in observable.items(): @@ -134,55 +138,15 @@ def coerce(cls, observable: ObservableLike) -> Observable: if phase != 0: coeff = coeff * (-1j) ** phase # Validate basis - cls._validate_basis(basis) if len(basis) != num_qubits: raise ValueError( "Number of qubits must be the same for all observable basis elements." ) + terms = terms.union(basis) unique[basis] += coeff - return Observable(dict(unique)) + obs = Observable(dict(unique)) + # Manually set terms so a second iteration over data is not needed + obs._data = "".join(terms) + return obs raise TypeError(f"Invalid observable type: {type(observable)}") - - @classmethod - def _validate_basis(cls, basis: any) -> None: - """Validate a basis string. - - Args: - basis: a basis object to validate. - - Raises: - TypeError: If the input basis is not a string - ValueError: If basis string contains invalid characters - """ - if not isinstance(basis, str): - raise TypeError(f"basis {basis} is not a string") - - # NOTE: the allowed basis characters can be overridden by modifying the class - # attribute ALLOWED_BASIS - allowed_pattern = _regex_match(cls.ALLOWED_BASIS) - if not allowed_pattern.match(basis): - invalid_pattern = _regex_invalid(cls.ALLOWED_BASIS) - invalid_chars = list(set(invalid_pattern.findall(basis))) - raise ValueError( - f"Observable basis string '{basis}' contains invalid characters {invalid_chars}," - f" allowed characters are {list(cls.ALLOWED_BASIS)}.", - ) - - @classmethod - def _validate_coeff(cls, coeff: any): - """Validate the consistency in observables array.""" - if not isinstance(coeff, Complex): - raise TypeError(f"Value {coeff} is not a complex number") - - -@lru_cache(1) -def _regex_match(allowed_chars: str) -> re.Pattern: - """Return pattern for matching if a string contains only the allowed characters.""" - return re.compile(f"^[{re.escape(allowed_chars)}]*$") - - -@lru_cache(1) -def _regex_invalid(allowed_chars: str) -> re.Pattern: - """Return pattern for selecting invalid strings""" - return re.compile(f"[^{re.escape(allowed_chars)}]") diff --git a/qiskit/primitives/containers/observables_array.py b/qiskit/primitives/containers/observables_array.py index 90bce56f5f42..0dd5ee117af1 100644 --- a/qiskit/primitives/containers/observables_array.py +++ b/qiskit/primitives/containers/observables_array.py @@ -58,21 +58,47 @@ def __init__( observables = observables._array self._array = object_array(observables, copy=copy, list_types=(PauliList,)) self._shape = self._array.shape + self._num_qubits = None + self._terms = None if validate: # Convert array items to Observable objects - # and validate they are on the same number of qubits - num_qubits = None + # and set terms and num_qubits, validating consistency + terms = set() + num_qubits = set() for ndi, obs in np.ndenumerate(self._array): - basis_obs = Observable.coerce(obs) - basis_num_qubits = len(next(iter(basis_obs))) - if num_qubits is None: - num_qubits = basis_num_qubits - elif basis_num_qubits != num_qubits: + obs = Observable.coerce(obs) + terms.update(obs.terms) + num_qubits.add(obs.num_qubits) + if len(num_qubits) > 1: raise ValueError( "The number of qubits must be the same for all observables in the " "observables array." ) - self._array[ndi] = basis_obs + self._array[ndi] = obs + self._num_qubits = num_qubits + self._terms = "".join(terms) + + @property + def terms(self) -> str: + """Return a string containing all unique basis terms used in the observable""" + if not self._terms: + # QUESTION: Should terms be `tuple[str, ...]` instead + # to allow for basis identification using more than 1 character? + self._terms = "".join(set().union(*(elem.terms for elem in self._array.ravel()))) + return self._terms + + @property + def num_qubits(self) -> int: + """The number of qubits in the observable""" + if self._num_qubits is None: + qubits = {elem.num_qubits for elem in self._array.ravel()} + if len(qubits) > 1: + raise ValueError( + "The number of qubits must be the same for all observables in the " + "observables array." + ) + self._num_qubits = next(iter(qubits)) + return self._num_qubits def __repr__(self): prefix = f"{type(self).__name__}(" @@ -115,7 +141,10 @@ def reshape(self, shape: int | Iterable[int]) -> ObservablesArray: Returns: A new array. """ - return ObservablesArray(self._array.reshape(shape), copy=False, validate=False) + obs = ObservablesArray(self._array.reshape(shape), copy=False, validate=False) + obs._num_qubits = self._num_qubits + obs._terms = self._terms + return obs def ravel(self) -> ObservablesArray: """Return a new array with one dimension.