Skip to content

Commit

Permalink
Fix hash of Parameter and ParameterExpression (#10875)
Browse files Browse the repository at this point in the history
* Fix hash of `Parameter` and `ParameterExpression`

This fixes the construction paths for `Parameter` and its hash such that
it will now correctly hash equal to any `ParameterExpression`s that it
compares equal to.  This is a requirement of the Python data model for
hashmaps and hashsets, which previously we were breaking.  In order to
achieve this, we slightly modify the hash key such that the `Parameter`
instances are no longer a part of the hash of `ParameterExpression`,
which means we can use the same hashing strategy for both.

This rearrangement has the benefit of removing the requirement for the
`__new__` overrides on `Parameter` and `ParameterVectorElement`.

* Add type hint

Co-authored-by: Matthew Treinish <[email protected]>

* Tweak documentation wording

---------

Co-authored-by: Matthew Treinish <[email protected]>
  • Loading branch information
jakelishman and mtreinish authored Oct 10, 2023
1 parent 4f5d90d commit 8651d34
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 87 deletions.
73 changes: 40 additions & 33 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
Parameter Class for variable parameters.
"""

from uuid import uuid4
from __future__ import annotations

from uuid import uuid4, UUID

from qiskit.circuit.exceptions import CircuitError
from qiskit.utils import optionals as _optionals
Expand Down Expand Up @@ -52,34 +54,27 @@ class Parameter(ParameterExpression):

__slots__ = ("_name", "_uuid", "_hash")

def __new__(cls, name, uuid=None): # pylint: disable=unused-argument
# Parameter relies on self._uuid being set prior to other attributes
# (e.g. symbol_map) which may depend on self._uuid for Parameter's hash
# or __eq__ functions.
obj = object.__new__(cls)

if uuid is None:
obj._uuid = uuid4()
else:
obj._uuid = uuid

obj._hash = hash(obj._uuid)
return obj
# This `__init__` does not call the super init, because we can't construct the
# `_parameter_symbols` dictionary we need to pass to it before we're entirely initialised
# anyway, because `ParameterExpression` depends heavily on the structure of `Parameter`.

def __getnewargs__(self):
# Unpickling won't in general call __init__ but will always call
# __new__. Specify arguments to be passed to __new__ when unpickling.

return (self.name, self._uuid)

def __init__(self, name: str):
def __init__(
self, name: str, *, uuid: UUID | None = None
): # pylint: disable=super-init-not-called
"""Create a new named :class:`Parameter`.
Args:
name: name of the ``Parameter``, used for visual representation. This can
be any unicode string, e.g. "ϕ".
uuid: For advanced usage only. Override the UUID of this parameter, in order to make it
compare equal to some other parameter object. By default, two parameters with the
same name do not compare equal to help catch shadowing bugs when two circuits
containing the same named parameters are spurious combined. Setting the ``uuid``
field when creating two parameters to the same thing (along with the same name)
allows them to be equal. This is useful during serialization and deserialization.
"""
self._name = name
self._uuid = uuid4() if uuid is None else uuid
if not _optionals.HAS_SYMENGINE:
from sympy import Symbol

Expand All @@ -88,7 +83,12 @@ def __init__(self, name: str):
import symengine

symbol = symengine.Symbol(name)
super().__init__(symbol_map={self: symbol}, expr=symbol)

self._symbol_expr = symbol
self._parameter_keys = frozenset((self._hash_key(),))
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: symbol}
self._name_map = None

def assign(self, parameter, value):
if parameter != self:
Expand Down Expand Up @@ -144,20 +144,27 @@ def __eq__(self, other):
else:
return False

def _hash_key(self):
# `ParameterExpression` needs to be able to hash all its contained `Parameter` instances in
# its hash as part of the equality comparison but has its own more complete symbolic
# expression, so its full hash key is split into `(parameter_keys, symbolic_expression)`.
# This method lets containing expressions get only the bits they need for equality checks in
# the first value, without wasting time re-hashing individual Sympy/Symengine symbols.
return (self._name, self._uuid)

def __hash__(self):
# This is precached for performance, since it's used a lot and we are immutable.
return self._hash

# We have to manually control the pickling so that the hash is computable before the unpickling
# operation attempts to put this parameter into a hashmap.

def __getstate__(self):
return {"name": self._name}
return (self._name, self._uuid, self._symbol_expr)

def __setstate__(self, state):
self._name = state["name"]
if not _optionals.HAS_SYMENGINE:
from sympy import Symbol

symbol = Symbol(self._name)
else:
import symengine

symbol = symengine.Symbol(self._name)
super().__init__(symbol_map={self: symbol}, expr=symbol)
self._name, self._uuid, self._symbol_expr = state
self._parameter_keys = frozenset((self._hash_key(),))
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: self._symbol_expr}
self._name_map = None
20 changes: 11 additions & 9 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class ParameterExpression:
"""ParameterExpression class to enable creating expressions of Parameters."""

__slots__ = ["_parameter_symbols", "_parameters", "_symbol_expr", "_name_map"]
__slots__ = ["_parameter_symbols", "_parameter_keys", "_symbol_expr", "_name_map"]

def __init__(self, symbol_map: dict, expr):
"""Create a new :class:`ParameterExpression`.
Expand All @@ -47,21 +47,24 @@ def __init__(self, symbol_map: dict, expr):
serving as their placeholder in expr.
expr (sympy.Expr): Expression of :class:`sympy.Symbol` s.
"""
# NOTE: `Parameter.__init__` does not call up to this method, since this method is dependent
# on `Parameter` instances already being initialised enough to be hashable. If changing
# this method, check that `Parameter.__init__` and `__setstate__` are still valid.
self._parameter_symbols = symbol_map
self._parameters = set(self._parameter_symbols)
self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols)
self._symbol_expr = expr
self._name_map: dict | None = None

@property
def parameters(self) -> set:
"""Returns a set of the unbound Parameters in the expression."""
return self._parameters
return self._parameter_symbols.keys()

@property
def _names(self) -> dict:
"""Returns a mapping of parameter names to Parameters in the expression."""
if self._name_map is None:
self._name_map = {p.name: p for p in self._parameters}
self._name_map = {p.name: p for p in self._parameter_symbols}
return self._name_map

def conjugate(self) -> "ParameterExpression":
Expand Down Expand Up @@ -121,8 +124,7 @@ def bind(

symbol_values = {}
for parameter, value in parameter_values.items():
if parameter in self._parameters:
param_expr = self._parameter_symbols[parameter]
if (param_expr := self._parameter_symbols.get(parameter)) is not None:
symbol_values[param_expr] = value

bound_symbol_expr = self._symbol_expr.subs(symbol_values)
Expand Down Expand Up @@ -197,8 +199,8 @@ def subs(
# but with our sympy symbols instead of theirs.
symbol_map = {}
for old_param, new_param in parameter_map.items():
if old_param in self._parameters:
symbol_map[self._parameter_symbols[old_param]] = new_param._symbol_expr
if (old_symbol := self._parameter_symbols.get(old_param)) is not None:
symbol_map[old_symbol] = new_param._symbol_expr
for p in new_param.parameters:
new_parameter_symbols[p] = symbol_type(p.name)

Expand Down Expand Up @@ -518,7 +520,7 @@ def __int__(self):
raise TypeError("could not cast expression to int") from exc

def __hash__(self):
return hash((frozenset(self._parameter_symbols), self._symbol_expr))
return hash((self._parameter_keys, self._symbol_expr))

def __copy__(self):
return self
Expand Down
35 changes: 7 additions & 28 deletions qiskit/circuit/parametervector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,8 @@ class ParameterVectorElement(Parameter):

___slots__ = ("_vector", "_index")

def __new__(cls, vector, index, uuid=None): # pylint:disable=unused-argument
obj = object.__new__(cls)

if uuid is None:
obj._uuid = uuid4()
else:
obj._uuid = uuid

obj._hash = hash(obj._uuid)
return obj

def __getnewargs__(self):
return (self.vector, self.index, self._uuid)

def __init__(self, vector, index, uuid=None): # pylint: disable=unused-argument
name = f"{vector.name}[{index}]"
super().__init__(name)
def __init__(self, vector, index, uuid=None):
super().__init__(f"{vector.name}[{index}]", uuid=uuid)
self._vector = vector
self._index = index

Expand All @@ -53,19 +38,13 @@ def vector(self):
return self._vector

def __getstate__(self):
return {
"name": self._name,
"uuid": self._uuid,
"vector": self._vector,
"index": self._index,
}
return super().__getstate__() + (self._vector, self._index)

def __setstate__(self, state):
self._name = state["name"]
self._uuid = state["uuid"]
self._vector = state["vector"]
self._index = state["index"]
super().__init__(self._name)
*super_state, vector, index = state
super().__setstate__(super_state)
self._vector = vector
self._index = index


class ParameterVector:
Expand Down
3 changes: 1 addition & 2 deletions qiskit/pulse/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,8 +1969,7 @@ def _collect_scoped_parameters(

if filter_regex and not re.search(filter_regex, new_name):
continue
scoped_param = Parameter.__new__(Parameter, new_name, uuid=getattr(param, "_uuid"))
scoped_param.__init__(new_name)
scoped_param = Parameter(new_name, uuid=getattr(param, "_uuid"))

unique_key = new_name, hash(param)
parameters_out[unique_key] = scoped_param
Expand Down
9 changes: 2 additions & 7 deletions qiskit/qpy/binary_io/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def _read_parameter(file_obj):
)
param_uuid = uuid.UUID(bytes=data.uuid)
name = file_obj.read(data.name_size).decode(common.ENCODE)
param = Parameter.__new__(Parameter, name, uuid=param_uuid)
param.__init__(name)
return param
return Parameter(name, uuid=param_uuid)


def _read_parameter_vec(file_obj, vectors):
Expand All @@ -216,10 +214,7 @@ def _read_parameter_vec(file_obj, vectors):
vector = vectors[name][0]
if vector[data.index]._uuid != param_uuid:
vectors[name][1].add(data.index)
vector._params[data.index] = ParameterVectorElement.__new__(
ParameterVectorElement, vector, data.index, uuid=param_uuid
)
vector._params[data.index].__init__(vector, data.index)
vector._params[data.index] = ParameterVectorElement(vector, data.index, uuid=param_uuid)
return vector[data.index]


Expand Down
13 changes: 13 additions & 0 deletions releasenotes/notes/fix-parameter-hash-d22c270090ffc80e.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
features:
- |
:class:`.Parameter` now has an advanced-usage keyword argument ``uuid`` in its constructor,
which can be used to make the :class:`.Parameter` compare equal to another of the same name.
This should not typically be used by users, and is most useful for custom serialisation and
deserialisation.
fixes:
- |
The hash of a :class:`.Parameter` is now equal to the hashes of any
:class:`.ParameterExpression` that it compares equal to. Previously the hashes were different,
which would cause spurious additional entries in hashmaps when :class:`.Parameter` and
:class:`.ParameterExpression` values were mixed in the same map as it violated Python's data model.
60 changes: 52 additions & 8 deletions test/python/circuit/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,16 +804,11 @@ def test_instruction_ryrz_vector(self):
for param in vec:
self.assertIn(param, qc_aer.parameters)

@data("single", "vector")
def test_parameter_equality_through_serialization(self, ptype):
def test_parameter_equality_through_serialization(self):
"""Verify parameters maintain their equality after serialization."""

if ptype == "single":
x1 = Parameter("x")
x2 = Parameter("x")
else:
x1 = ParameterVector("x", 2)[0]
x2 = ParameterVector("x", 2)[0]
x1 = Parameter("x")
x2 = Parameter("x")

x1_p = pickle.loads(pickle.dumps(x1))
x2_p = pickle.loads(pickle.dumps(x2))
Expand All @@ -824,6 +819,55 @@ def test_parameter_equality_through_serialization(self, ptype):
self.assertNotEqual(x1, x2_p)
self.assertNotEqual(x2, x1_p)

def test_parameter_vector_equality_through_serialization(self):
"""Verify elements of parameter vectors maintain their equality after serialization."""

x1 = ParameterVector("x", 2)
x2 = ParameterVector("x", 2)

x1_p = pickle.loads(pickle.dumps(x1))
x2_p = pickle.loads(pickle.dumps(x2))

self.assertEqual(x1[0], x1_p[0])
self.assertEqual(x2[0], x2_p[0])

self.assertNotEqual(x1[0], x2_p[0])
self.assertNotEqual(x2[0], x1_p[0])

self.assertIs(x1_p[0].vector, x1_p)
self.assertIs(x2_p[0].vector, x2_p)
self.assertEqual([p.index for p in x1_p], list(range(len(x1_p))))
self.assertEqual([p.index for p in x2_p], list(range(len(x2_p))))

@data("single", "vector")
def test_parameter_equality_to_expression(self, ptype):
"""Verify that parameters compare equal to `ParameterExpression`s that represent the same
thing."""

if ptype == "single":
x1 = Parameter("x")
x2 = Parameter("x")
else:
x1 = ParameterVector("x", 2)[0]
x2 = ParameterVector("x", 2)[0]

x1_expr = x1 + 0
# Smoke test: the test isn't valid if that above expression remains a `Parameter`; we need
# it to have upcast to `ParameterExpression`.
self.assertNotIsInstance(x1_expr, Parameter)
x2_expr = x2 + 0
self.assertNotIsInstance(x2_expr, Parameter)

self.assertEqual(x1, x1_expr)
self.assertEqual(x2, x2_expr)

self.assertNotEqual(x1, x2_expr)
self.assertNotEqual(x2, x1_expr)

# Since these two pairs of objects compared equal, they must have the same hash as well.
self.assertEqual(hash(x1), hash(x1_expr))
self.assertEqual(hash(x2), hash(x2_expr))

def test_binding_parameterized_circuits_built_in_multiproc(self):
"""Verify subcircuits built in a subprocess can still be bound."""
# ref: https://github.com/Qiskit/qiskit-terra/issues/2429
Expand Down

0 comments on commit 8651d34

Please sign in to comment.