From f8bf6c412822e9301054312dc8c849ba21b0321d Mon Sep 17 00:00:00 2001 From: "Christopher J. Wood" Date: Mon, 22 Jan 2024 12:31:18 -0500 Subject: [PATCH] Remove basis validation from Observable/ObservableArray The intent is that an Estimator can retrieve the basis terms in an observables array from the `terms` attribute and validate them itself, rather than hard code the allowed terms here. If we do still want to hardcode this in the classes we could add it back, but the current version of expecting subclasses to override ALLOWED_BASIS doesn't work with the current implementation of class methods which done return the subclass type from methods like reshape etc --- qiskit/primitives/containers/observable.py | 90 ++++++------------- .../containers/observables_array.py | 47 ++++++++-- 2 files changed, 65 insertions(+), 72 deletions(-) diff --git a/qiskit/primitives/containers/observable.py b/qiskit/primitives/containers/observable.py index c9b0eb5e5e47..8ec10aba11ed 100644 --- a/qiskit/primitives/containers/observable.py +++ b/qiskit/primitives/containers/observable.py @@ -16,11 +16,9 @@ """ from __future__ import annotations -import re -from collections.abc import Mapping, Iterable +from typing import Union, Iterable, Mapping as MappingType +from collections.abc import Mapping from collections import defaultdict -from functools import lru_cache -from typing import Union from numbers import Complex from qiskit.quantum_info import Pauli, SparsePauliOp @@ -29,7 +27,7 @@ str, Pauli, SparsePauliOp, - Mapping[Union[str, Pauli], complex], + MappingType[Union[str, Pauli], complex], Iterable[Union[str, Pauli, SparsePauliOp]], ] """Types that can be natively used to construct a :const:`BasisObservable`.""" @@ -38,8 +36,7 @@ class Observable(Mapping): """A sparse container for a Hermitian observable for an :class:`.Estimator` primitive.""" - ALLOWED_BASIS: str = "IXYZ01+-lr" - """The allowed characters in :class:`.Observable` strings.""" + __slots__ = ("_data", "_num_qubits", "_terms") def __init__( self, @@ -56,7 +53,8 @@ def __init__( ValueError: If ``validate=True`` and the input observable-like is not valid. """ self._data = data - self._num_qubits = len(next(iter(data))) + self._num_qubits: int = len(next(iter(data))) + self._terms: str = "" if validate: self.validate() @@ -72,6 +70,15 @@ def __iter__(self): def __len__(self): return len(self._data) + @property + def terms(self) -> str: + """Return a string containing all unique basis terms used in the observable""" + if not self._terms: + # QUESTION: Should terms be `tuple[str, ...]` instead + # to allow for basis identification using more than 1 character? + self._terms = "".join(set().union(*self._data)) + return self._terms + @property def num_qubits(self) -> int: """The number of qubits in the observable""" @@ -82,13 +89,10 @@ def validate(self): if not isinstance(self._data, Mapping): raise TypeError(f"Observable data type {type(self._data)} is not a Mapping.") for key, value in self._data.items(): - try: - self._validate_basis(key) - self._validate_coeff(value) - except TypeError as ex: - raise TypeError(f"Invalid type for item ({key}, {value})") from ex - except Exception as ex: # pylint: disable = broad-except - raise ValueError(f"Invalid value for item ({key}, {value})") from ex + if not isinstance(key, str): + raise TypeError(f"Item {key} is not a str") + if not isinstance(value, Complex): + raise TypeError(f"Value {value} is not a complex number") @classmethod def coerce(cls, observable: ObservableLike) -> Observable: @@ -113,19 +117,19 @@ def coerce(cls, observable: ObservableLike) -> Observable: if isinstance(observable, Pauli): label, phase = observable[:].to_label(), observable.phase - cls._validate_basis(label) data = {label: 1} if phase == 0 else {label: (-1j) ** phase} - return Observable(data) + return cls.coerce(data) # String conversion if isinstance(observable, str): - cls._validate_basis(observable) return cls.coerce({observable: 1}) # Mapping conversion (with possible Pauli keys) if isinstance(observable, Mapping): # NOTE: This assumes length of keys is number of qubits # this might not be very robust + # We also compute terms while iterating to save doing it again later + terms = set() num_qubits = len(next(iter(observable))) unique = defaultdict(complex) for basis, coeff in observable.items(): @@ -134,55 +138,15 @@ def coerce(cls, observable: ObservableLike) -> Observable: if phase != 0: coeff = coeff * (-1j) ** phase # Validate basis - cls._validate_basis(basis) if len(basis) != num_qubits: raise ValueError( "Number of qubits must be the same for all observable basis elements." ) + terms = terms.union(basis) unique[basis] += coeff - return Observable(dict(unique)) + obs = Observable(dict(unique)) + # Manually set terms so a second iteration over data is not needed + obs._data = "".join(terms) + return obs raise TypeError(f"Invalid observable type: {type(observable)}") - - @classmethod - def _validate_basis(cls, basis: any) -> None: - """Validate a basis string. - - Args: - basis: a basis object to validate. - - Raises: - TypeError: If the input basis is not a string - ValueError: If basis string contains invalid characters - """ - if not isinstance(basis, str): - raise TypeError(f"basis {basis} is not a string") - - # NOTE: the allowed basis characters can be overridden by modifying the class - # attribute ALLOWED_BASIS - allowed_pattern = _regex_match(cls.ALLOWED_BASIS) - if not allowed_pattern.match(basis): - invalid_pattern = _regex_invalid(cls.ALLOWED_BASIS) - invalid_chars = list(set(invalid_pattern.findall(basis))) - raise ValueError( - f"Observable basis string '{basis}' contains invalid characters {invalid_chars}," - f" allowed characters are {list(cls.ALLOWED_BASIS)}.", - ) - - @classmethod - def _validate_coeff(cls, coeff: any): - """Validate the consistency in observables array.""" - if not isinstance(coeff, Complex): - raise TypeError(f"Value {coeff} is not a complex number") - - -@lru_cache(1) -def _regex_match(allowed_chars: str) -> re.Pattern: - """Return pattern for matching if a string contains only the allowed characters.""" - return re.compile(f"^[{re.escape(allowed_chars)}]*$") - - -@lru_cache(1) -def _regex_invalid(allowed_chars: str) -> re.Pattern: - """Return pattern for selecting invalid strings""" - return re.compile(f"[^{re.escape(allowed_chars)}]") diff --git a/qiskit/primitives/containers/observables_array.py b/qiskit/primitives/containers/observables_array.py index 90bce56f5f42..0dd5ee117af1 100644 --- a/qiskit/primitives/containers/observables_array.py +++ b/qiskit/primitives/containers/observables_array.py @@ -58,21 +58,47 @@ def __init__( observables = observables._array self._array = object_array(observables, copy=copy, list_types=(PauliList,)) self._shape = self._array.shape + self._num_qubits = None + self._terms = None if validate: # Convert array items to Observable objects - # and validate they are on the same number of qubits - num_qubits = None + # and set terms and num_qubits, validating consistency + terms = set() + num_qubits = set() for ndi, obs in np.ndenumerate(self._array): - basis_obs = Observable.coerce(obs) - basis_num_qubits = len(next(iter(basis_obs))) - if num_qubits is None: - num_qubits = basis_num_qubits - elif basis_num_qubits != num_qubits: + obs = Observable.coerce(obs) + terms.update(obs.terms) + num_qubits.add(obs.num_qubits) + if len(num_qubits) > 1: raise ValueError( "The number of qubits must be the same for all observables in the " "observables array." ) - self._array[ndi] = basis_obs + self._array[ndi] = obs + self._num_qubits = num_qubits + self._terms = "".join(terms) + + @property + def terms(self) -> str: + """Return a string containing all unique basis terms used in the observable""" + if not self._terms: + # QUESTION: Should terms be `tuple[str, ...]` instead + # to allow for basis identification using more than 1 character? + self._terms = "".join(set().union(*(elem.terms for elem in self._array.ravel()))) + return self._terms + + @property + def num_qubits(self) -> int: + """The number of qubits in the observable""" + if self._num_qubits is None: + qubits = {elem.num_qubits for elem in self._array.ravel()} + if len(qubits) > 1: + raise ValueError( + "The number of qubits must be the same for all observables in the " + "observables array." + ) + self._num_qubits = next(iter(qubits)) + return self._num_qubits def __repr__(self): prefix = f"{type(self).__name__}(" @@ -115,7 +141,10 @@ def reshape(self, shape: int | Iterable[int]) -> ObservablesArray: Returns: A new array. """ - return ObservablesArray(self._array.reshape(shape), copy=False, validate=False) + obs = ObservablesArray(self._array.reshape(shape), copy=False, validate=False) + obs._num_qubits = self._num_qubits + obs._terms = self._terms + return obs def ravel(self) -> ObservablesArray: """Return a new array with one dimension.