Skip to content

Commit

Permalink
Remove basis validation from Observable/ObservableArray
Browse files Browse the repository at this point in the history
The intent is that an Estimator can retrieve the basis terms in an observables array from the `terms` attribute and validate them itself, rather than hard code the allowed terms here. If we do still want to hardcode this in the classes we could add it back, but the current version of expecting subclasses to override ALLOWED_BASIS doesn't work with the current implementation of class methods which done return the subclass type from methods like reshape etc
  • Loading branch information
chriseclectic committed Jan 22, 2024
1 parent 942d997 commit f8bf6c4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 72 deletions.
90 changes: 27 additions & 63 deletions qiskit/primitives/containers/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
"""
from __future__ import annotations

import re
from collections.abc import Mapping, Iterable
from typing import Union, Iterable, Mapping as MappingType
from collections.abc import Mapping
from collections import defaultdict
from functools import lru_cache
from typing import Union
from numbers import Complex

from qiskit.quantum_info import Pauli, SparsePauliOp
Expand All @@ -29,7 +27,7 @@
str,
Pauli,
SparsePauliOp,
Mapping[Union[str, Pauli], complex],
MappingType[Union[str, Pauli], complex],
Iterable[Union[str, Pauli, SparsePauliOp]],
]
"""Types that can be natively used to construct a :const:`BasisObservable`."""
Expand All @@ -38,8 +36,7 @@
class Observable(Mapping):
"""A sparse container for a Hermitian observable for an :class:`.Estimator` primitive."""

ALLOWED_BASIS: str = "IXYZ01+-lr"
"""The allowed characters in :class:`.Observable` strings."""
__slots__ = ("_data", "_num_qubits", "_terms")

def __init__(
self,
Expand All @@ -56,7 +53,8 @@ def __init__(
ValueError: If ``validate=True`` and the input observable-like is not valid.
"""
self._data = data
self._num_qubits = len(next(iter(data)))
self._num_qubits: int = len(next(iter(data)))
self._terms: str = ""
if validate:
self.validate()

Expand All @@ -72,6 +70,15 @@ def __iter__(self):
def __len__(self):
return len(self._data)

@property
def terms(self) -> str:
"""Return a string containing all unique basis terms used in the observable"""
if not self._terms:
# QUESTION: Should terms be `tuple[str, ...]` instead
# to allow for basis identification using more than 1 character?
self._terms = "".join(set().union(*self._data))
return self._terms

@property
def num_qubits(self) -> int:
"""The number of qubits in the observable"""
Expand All @@ -82,13 +89,10 @@ def validate(self):
if not isinstance(self._data, Mapping):
raise TypeError(f"Observable data type {type(self._data)} is not a Mapping.")
for key, value in self._data.items():
try:
self._validate_basis(key)
self._validate_coeff(value)
except TypeError as ex:
raise TypeError(f"Invalid type for item ({key}, {value})") from ex
except Exception as ex: # pylint: disable = broad-except
raise ValueError(f"Invalid value for item ({key}, {value})") from ex
if not isinstance(key, str):
raise TypeError(f"Item {key} is not a str")
if not isinstance(value, Complex):
raise TypeError(f"Value {value} is not a complex number")

@classmethod
def coerce(cls, observable: ObservableLike) -> Observable:
Expand All @@ -113,19 +117,19 @@ def coerce(cls, observable: ObservableLike) -> Observable:

if isinstance(observable, Pauli):
label, phase = observable[:].to_label(), observable.phase
cls._validate_basis(label)
data = {label: 1} if phase == 0 else {label: (-1j) ** phase}
return Observable(data)
return cls.coerce(data)

# String conversion
if isinstance(observable, str):
cls._validate_basis(observable)
return cls.coerce({observable: 1})

# Mapping conversion (with possible Pauli keys)
if isinstance(observable, Mapping):
# NOTE: This assumes length of keys is number of qubits
# this might not be very robust
# We also compute terms while iterating to save doing it again later
terms = set()
num_qubits = len(next(iter(observable)))
unique = defaultdict(complex)
for basis, coeff in observable.items():
Expand All @@ -134,55 +138,15 @@ def coerce(cls, observable: ObservableLike) -> Observable:
if phase != 0:
coeff = coeff * (-1j) ** phase
# Validate basis
cls._validate_basis(basis)
if len(basis) != num_qubits:
raise ValueError(
"Number of qubits must be the same for all observable basis elements."
)
terms = terms.union(basis)
unique[basis] += coeff
return Observable(dict(unique))
obs = Observable(dict(unique))
# Manually set terms so a second iteration over data is not needed
obs._data = "".join(terms)
return obs

raise TypeError(f"Invalid observable type: {type(observable)}")

@classmethod
def _validate_basis(cls, basis: any) -> None:
"""Validate a basis string.
Args:
basis: a basis object to validate.
Raises:
TypeError: If the input basis is not a string
ValueError: If basis string contains invalid characters
"""
if not isinstance(basis, str):
raise TypeError(f"basis {basis} is not a string")

# NOTE: the allowed basis characters can be overridden by modifying the class
# attribute ALLOWED_BASIS
allowed_pattern = _regex_match(cls.ALLOWED_BASIS)
if not allowed_pattern.match(basis):
invalid_pattern = _regex_invalid(cls.ALLOWED_BASIS)
invalid_chars = list(set(invalid_pattern.findall(basis)))
raise ValueError(
f"Observable basis string '{basis}' contains invalid characters {invalid_chars},"
f" allowed characters are {list(cls.ALLOWED_BASIS)}.",
)

@classmethod
def _validate_coeff(cls, coeff: any):
"""Validate the consistency in observables array."""
if not isinstance(coeff, Complex):
raise TypeError(f"Value {coeff} is not a complex number")


@lru_cache(1)
def _regex_match(allowed_chars: str) -> re.Pattern:
"""Return pattern for matching if a string contains only the allowed characters."""
return re.compile(f"^[{re.escape(allowed_chars)}]*$")


@lru_cache(1)
def _regex_invalid(allowed_chars: str) -> re.Pattern:
"""Return pattern for selecting invalid strings"""
return re.compile(f"[^{re.escape(allowed_chars)}]")
47 changes: 38 additions & 9 deletions qiskit/primitives/containers/observables_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,47 @@ def __init__(
observables = observables._array
self._array = object_array(observables, copy=copy, list_types=(PauliList,))
self._shape = self._array.shape
self._num_qubits = None
self._terms = None
if validate:
# Convert array items to Observable objects
# and validate they are on the same number of qubits
num_qubits = None
# and set terms and num_qubits, validating consistency
terms = set()
num_qubits = set()
for ndi, obs in np.ndenumerate(self._array):
basis_obs = Observable.coerce(obs)
basis_num_qubits = len(next(iter(basis_obs)))
if num_qubits is None:
num_qubits = basis_num_qubits
elif basis_num_qubits != num_qubits:
obs = Observable.coerce(obs)
terms.update(obs.terms)
num_qubits.add(obs.num_qubits)
if len(num_qubits) > 1:
raise ValueError(
"The number of qubits must be the same for all observables in the "
"observables array."
)
self._array[ndi] = basis_obs
self._array[ndi] = obs
self._num_qubits = num_qubits
self._terms = "".join(terms)

@property
def terms(self) -> str:
"""Return a string containing all unique basis terms used in the observable"""
if not self._terms:
# QUESTION: Should terms be `tuple[str, ...]` instead
# to allow for basis identification using more than 1 character?
self._terms = "".join(set().union(*(elem.terms for elem in self._array.ravel())))
return self._terms

@property
def num_qubits(self) -> int:
"""The number of qubits in the observable"""
if self._num_qubits is None:
qubits = {elem.num_qubits for elem in self._array.ravel()}
if len(qubits) > 1:
raise ValueError(
"The number of qubits must be the same for all observables in the "
"observables array."
)
self._num_qubits = next(iter(qubits))
return self._num_qubits

def __repr__(self):
prefix = f"{type(self).__name__}("
Expand Down Expand Up @@ -115,7 +141,10 @@ def reshape(self, shape: int | Iterable[int]) -> ObservablesArray:
Returns:
A new array.
"""
return ObservablesArray(self._array.reshape(shape), copy=False, validate=False)
obs = ObservablesArray(self._array.reshape(shape), copy=False, validate=False)
obs._num_qubits = self._num_qubits
obs._terms = self._terms
return obs

def ravel(self) -> ObservablesArray:
"""Return a new array with one dimension.
Expand Down

0 comments on commit f8bf6c4

Please sign in to comment.