From 4913f118507aa9778ddc783a82d81acf46496eeb Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Thu, 21 Sep 2023 12:44:42 +0100 Subject: [PATCH] Fix hash of `Parameter` and `ParameterExpression` This fixes the construction paths for `Parameter` and its hash such that it will now correctly hash equal to any `ParameterExpression`s that it compares equal to. This is a requirement of the Python data model for hashmaps and hashsets, which previously we were breaking. In order to achieve this, we slightly modify the hash key such that the `Parameter` instances are no longer a part of the hash of `ParameterExpression`, which means we can use the same hashing strategy for both. This rearrangement has the benefit of removing the requirement for the `__new__` overrides on `Parameter` and `ParameterVectorElement`. --- qiskit/circuit/parameter.py | 67 ++++++++++--------- qiskit/circuit/parameterexpression.py | 20 +++--- qiskit/circuit/parametervector.py | 35 ++-------- qiskit/pulse/schedule.py | 3 +- qiskit/qpy/binary_io/value.py | 9 +-- .../fix-parameter-hash-d22c270090ffc80e.yaml | 13 ++++ test/python/circuit/test_parameters.py | 60 ++++++++++++++--- 7 files changed, 121 insertions(+), 86 deletions(-) create mode 100644 releasenotes/notes/fix-parameter-hash-d22c270090ffc80e.yaml diff --git a/qiskit/circuit/parameter.py b/qiskit/circuit/parameter.py index 3e347b0beac1..bb15f8f6471d 100644 --- a/qiskit/circuit/parameter.py +++ b/qiskit/circuit/parameter.py @@ -52,34 +52,25 @@ class Parameter(ParameterExpression): __slots__ = ("_name", "_uuid", "_hash") - def __new__(cls, name, uuid=None): # pylint: disable=unused-argument - # Parameter relies on self._uuid being set prior to other attributes - # (e.g. symbol_map) which may depend on self._uuid for Parameter's hash - # or __eq__ functions. - obj = object.__new__(cls) - - if uuid is None: - obj._uuid = uuid4() - else: - obj._uuid = uuid - - obj._hash = hash(obj._uuid) - return obj - - def __getnewargs__(self): - # Unpickling won't in general call __init__ but will always call - # __new__. Specify arguments to be passed to __new__ when unpickling. - - return (self.name, self._uuid) + # This `__init__` does not call the super init, because we can't construct the + # `_parameter_symbols` dictionary we need to pass to it before we're entirely initialised + # anyway, because `ParameterExpression` depends heavily on the structure of `Parameter`. - def __init__(self, name: str): + def __init__(self, name: str, *, uuid=None): # pylint: disable=super-init-not-called """Create a new named :class:`Parameter`. Args: name: name of the ``Parameter``, used for visual representation. This can be any unicode string, e.g. "ϕ". + uuid: For advanced usage only. Override the UUID of this parameter, in order to make it + compare equal to some other parameter object. By default, two parameters with the + same name do not compare equal to help catch shadowing bugs when two circuits + containing the same named parameters are spurious combined. Setting the ``uuid`` + field when creating two parameters to the same thing (along with the same name) + allows them to be equal. This is useful during serialization and deserialization. """ self._name = name + self._uuid = uuid4() if uuid is None else uuid if not _optionals.HAS_SYMENGINE: from sympy import Symbol @@ -88,7 +79,12 @@ def __init__(self, name: str): import symengine symbol = symengine.Symbol(name) - super().__init__(symbol_map={self: symbol}, expr=symbol) + + self._symbol_expr = symbol + self._parameter_keys = frozenset((self._hash_key(),)) + self._hash = hash((self._parameter_keys, self._symbol_expr)) + self._parameter_symbols = {self: symbol} + self._name_map = None def assign(self, parameter, value): if parameter != self: @@ -144,20 +140,27 @@ def __eq__(self, other): else: return False + def _hash_key(self): + # This isn't the entirety of the object that's passed to `hash`, just the "key" part of + # individual parameters. The hash of a full `ParameterExpression` needs to depend on the + # "keys" of `Parameter`s, and our hash needs to be computable before we can be fully + # initialised as a `ParameterExpression`, so we break the cycle by making our "key" + # accessible separately. + return (self._name, self._uuid) + def __hash__(self): + # This is precached for performance, since it's used a lot and we are immutable. return self._hash + # We have to manually control the pickling so that the hash is computable before the unpickling + # operation attempts to put this parameter into a hashmap. + def __getstate__(self): - return {"name": self._name} + return (self._name, self._uuid, self._symbol_expr) def __setstate__(self, state): - self._name = state["name"] - if not _optionals.HAS_SYMENGINE: - from sympy import Symbol - - symbol = Symbol(self._name) - else: - import symengine - - symbol = symengine.Symbol(self._name) - super().__init__(symbol_map={self: symbol}, expr=symbol) + self._name, self._uuid, self._symbol_expr = state + self._parameter_keys = frozenset((self._hash_key(),)) + self._hash = hash((self._parameter_keys, self._symbol_expr)) + self._parameter_symbols = {self: self._symbol_expr} + self._name_map = None diff --git a/qiskit/circuit/parameterexpression.py b/qiskit/circuit/parameterexpression.py index dc7614a30586..00b35646fef8 100644 --- a/qiskit/circuit/parameterexpression.py +++ b/qiskit/circuit/parameterexpression.py @@ -33,7 +33,7 @@ class ParameterExpression: """ParameterExpression class to enable creating expressions of Parameters.""" - __slots__ = ["_parameter_symbols", "_parameters", "_symbol_expr", "_name_map"] + __slots__ = ["_parameter_symbols", "_parameter_keys", "_symbol_expr", "_name_map"] def __init__(self, symbol_map: dict, expr): """Create a new :class:`ParameterExpression`. @@ -47,21 +47,24 @@ def __init__(self, symbol_map: dict, expr): serving as their placeholder in expr. expr (sympy.Expr): Expression of :class:`sympy.Symbol` s. """ + # NOTE: `Parameter.__init__` does not call up to this method, since this method is dependent + # on `Parameter` instances already being initialised enough to be hashable. If changing + # this method, check that `Parameter.__init__` and `__setstate__` are still valid. self._parameter_symbols = symbol_map - self._parameters = set(self._parameter_symbols) + self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols) self._symbol_expr = expr self._name_map: dict | None = None @property def parameters(self) -> set: """Returns a set of the unbound Parameters in the expression.""" - return self._parameters + return self._parameter_symbols.keys() @property def _names(self) -> dict: """Returns a mapping of parameter names to Parameters in the expression.""" if self._name_map is None: - self._name_map = {p.name: p for p in self._parameters} + self._name_map = {p.name: p for p in self._parameter_symbols} return self._name_map def conjugate(self) -> "ParameterExpression": @@ -121,8 +124,7 @@ def bind( symbol_values = {} for parameter, value in parameter_values.items(): - if parameter in self._parameters: - param_expr = self._parameter_symbols[parameter] + if (param_expr := self._parameter_symbols.get(parameter)) is not None: symbol_values[param_expr] = value bound_symbol_expr = self._symbol_expr.subs(symbol_values) @@ -197,8 +199,8 @@ def subs( # but with our sympy symbols instead of theirs. symbol_map = {} for old_param, new_param in parameter_map.items(): - if old_param in self._parameters: - symbol_map[self._parameter_symbols[old_param]] = new_param._symbol_expr + if (old_symbol := self._parameter_symbols.get(old_param)) is not None: + symbol_map[old_symbol] = new_param._symbol_expr for p in new_param.parameters: new_parameter_symbols[p] = symbol_type(p.name) @@ -507,7 +509,7 @@ def __int__(self): raise TypeError("could not cast expression to int") from exc def __hash__(self): - return hash((frozenset(self._parameter_symbols), self._symbol_expr)) + return hash((self._parameter_keys, self._symbol_expr)) def __copy__(self): return self diff --git a/qiskit/circuit/parametervector.py b/qiskit/circuit/parametervector.py index 024abd0b5294..abc8a6f60ef7 100644 --- a/qiskit/circuit/parametervector.py +++ b/qiskit/circuit/parametervector.py @@ -22,23 +22,8 @@ class ParameterVectorElement(Parameter): ___slots__ = ("_vector", "_index") - def __new__(cls, vector, index, uuid=None): # pylint:disable=unused-argument - obj = object.__new__(cls) - - if uuid is None: - obj._uuid = uuid4() - else: - obj._uuid = uuid - - obj._hash = hash(obj._uuid) - return obj - - def __getnewargs__(self): - return (self.vector, self.index, self._uuid) - - def __init__(self, vector, index, uuid=None): # pylint: disable=unused-argument - name = f"{vector.name}[{index}]" - super().__init__(name) + def __init__(self, vector, index, uuid=None): + super().__init__(f"{vector.name}[{index}]", uuid=uuid) self._vector = vector self._index = index @@ -53,19 +38,13 @@ def vector(self): return self._vector def __getstate__(self): - return { - "name": self._name, - "uuid": self._uuid, - "vector": self._vector, - "index": self._index, - } + return super().__getstate__() + (self._vector, self._index) def __setstate__(self, state): - self._name = state["name"] - self._uuid = state["uuid"] - self._vector = state["vector"] - self._index = state["index"] - super().__init__(self._name) + *super_state, vector, index = state + super().__setstate__(super_state) + self._vector = vector + self._index = index class ParameterVector: diff --git a/qiskit/pulse/schedule.py b/qiskit/pulse/schedule.py index 6643a03d1a31..a1f7ced82d8b 100644 --- a/qiskit/pulse/schedule.py +++ b/qiskit/pulse/schedule.py @@ -1969,8 +1969,7 @@ def _collect_scoped_parameters( if filter_regex and not re.search(filter_regex, new_name): continue - scoped_param = Parameter.__new__(Parameter, new_name, uuid=getattr(param, "_uuid")) - scoped_param.__init__(new_name) + scoped_param = Parameter(new_name, uuid=getattr(param, "_uuid")) unique_key = new_name, hash(param) parameters_out[unique_key] = scoped_param diff --git a/qiskit/qpy/binary_io/value.py b/qiskit/qpy/binary_io/value.py index 2edac6c81b52..698726a812d2 100644 --- a/qiskit/qpy/binary_io/value.py +++ b/qiskit/qpy/binary_io/value.py @@ -192,9 +192,7 @@ def _read_parameter(file_obj): ) param_uuid = uuid.UUID(bytes=data.uuid) name = file_obj.read(data.name_size).decode(common.ENCODE) - param = Parameter.__new__(Parameter, name, uuid=param_uuid) - param.__init__(name) - return param + return Parameter(name, uuid=param_uuid) def _read_parameter_vec(file_obj, vectors): @@ -211,10 +209,7 @@ def _read_parameter_vec(file_obj, vectors): vector = vectors[name][0] if vector[data.index]._uuid != param_uuid: vectors[name][1].add(data.index) - vector._params[data.index] = ParameterVectorElement.__new__( - ParameterVectorElement, vector, data.index, uuid=param_uuid - ) - vector._params[data.index].__init__(vector, data.index) + vector._params[data.index] = ParameterVectorElement(vector, data.index, uuid=param_uuid) return vector[data.index] diff --git a/releasenotes/notes/fix-parameter-hash-d22c270090ffc80e.yaml b/releasenotes/notes/fix-parameter-hash-d22c270090ffc80e.yaml new file mode 100644 index 000000000000..c325ea30e648 --- /dev/null +++ b/releasenotes/notes/fix-parameter-hash-d22c270090ffc80e.yaml @@ -0,0 +1,13 @@ +--- +features: + - | + :class:`.Parameter` now has an advanced-usage keyword argument ``uuid`` in its constructor, + which can be used to make the :class:`.Parameter` compare equal to another of the same name. + This should not typically be used by users, and is most useful for custom serialisation and + deserialisation. +fixes: + - | + The hash of a :class:`.Parameter` is now equal to the hashes of any + :class:`.ParameterExpression` that it compares equal to. Previously the hashes were different, + which would cause spurious additional entries in hashmaps when :class:`.Parameter` and + :class:`.ParameterExpression` values were mixed in the same map. diff --git a/test/python/circuit/test_parameters.py b/test/python/circuit/test_parameters.py index caa3a148cd7f..9a18a6882920 100644 --- a/test/python/circuit/test_parameters.py +++ b/test/python/circuit/test_parameters.py @@ -804,16 +804,11 @@ def test_instruction_ryrz_vector(self): for param in vec: self.assertIn(param, qc_aer.parameters) - @data("single", "vector") - def test_parameter_equality_through_serialization(self, ptype): + def test_parameter_equality_through_serialization(self): """Verify parameters maintain their equality after serialization.""" - if ptype == "single": - x1 = Parameter("x") - x2 = Parameter("x") - else: - x1 = ParameterVector("x", 2)[0] - x2 = ParameterVector("x", 2)[0] + x1 = Parameter("x") + x2 = Parameter("x") x1_p = pickle.loads(pickle.dumps(x1)) x2_p = pickle.loads(pickle.dumps(x2)) @@ -824,6 +819,55 @@ def test_parameter_equality_through_serialization(self, ptype): self.assertNotEqual(x1, x2_p) self.assertNotEqual(x2, x1_p) + def test_parameter_vector_equality_through_serialization(self): + """Verify elements of parameter vectors maintain their equality after serialization.""" + + x1 = ParameterVector("x", 2) + x2 = ParameterVector("x", 2) + + x1_p = pickle.loads(pickle.dumps(x1)) + x2_p = pickle.loads(pickle.dumps(x2)) + + self.assertEqual(x1[0], x1_p[0]) + self.assertEqual(x2[0], x2_p[0]) + + self.assertNotEqual(x1[0], x2_p[0]) + self.assertNotEqual(x2[0], x1_p[0]) + + self.assertIs(x1_p[0].vector, x1_p) + self.assertIs(x2_p[0].vector, x2_p) + self.assertEqual([p.index for p in x1_p], list(range(len(x1_p)))) + self.assertEqual([p.index for p in x2_p], list(range(len(x2_p)))) + + @data("single", "vector") + def test_parameter_equality_to_expression(self, ptype): + """Verify that parameters compare equal to `ParameterExpression`s that represent the same + thing.""" + + if ptype == "single": + x1 = Parameter("x") + x2 = Parameter("x") + else: + x1 = ParameterVector("x", 2)[0] + x2 = ParameterVector("x", 2)[0] + + x1_expr = x1 + 0 + # Smoke test: the test isn't valid if that above expression remains a `Parameter`; we need + # it to have upcast to `ParameterExpression`. + self.assertNotIsInstance(x1_expr, Parameter) + x2_expr = x2 + 0 + self.assertNotIsInstance(x2_expr, Parameter) + + self.assertEqual(x1, x1_expr) + self.assertEqual(x2, x2_expr) + + self.assertNotEqual(x1, x2_expr) + self.assertNotEqual(x2, x1_expr) + + # Since these two pairs of objects compared equal, they must have the same hash as well. + self.assertEqual(hash(x1), hash(x1_expr)) + self.assertEqual(hash(x2), hash(x2_expr)) + def test_binding_parameterized_circuits_built_in_multiproc(self): """Verify subcircuits built in a subprocess can still be bound.""" # ref: https://github.com/Qiskit/qiskit-terra/issues/2429