Skip to content

Commit

Permalink
Track QuantumCircuit.global_phase in ParameterTable
Browse files Browse the repository at this point in the history
We have previously always had a split where circuit parameters used in
instructions were tracked in the `ParameterTable`, but any parameters
used in the global phase were not.  Any method that influenced the
parameters needed to separately check the global phase, and merge that
information with that in the `ParameterTable`.  This made it easy to
forget, or easy for the handling of it to become out of sync.

This commit now tracks the global phase as part of the `ParameterTable`,
so this object is now the canonical source of parameter information for
the circuit (outside the context of calibrations, which are handled
entirely separately).  The `ParameterTable` is an internal detail, and
only accessible through private attributes, so is not part of the public
interface.
  • Loading branch information
jakelishman committed Dec 18, 2023
1 parent 1c023bd commit e446589
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 23 deletions.
33 changes: 33 additions & 0 deletions qiskit/circuit/parametertable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""
Look-up table for variable parameters in QuantumCircuit.
"""
import operator
from collections.abc import MappingView, MutableMapping, MutableSet


Expand All @@ -21,6 +22,9 @@ class ParameterReferences(MutableSet):
testing is overridden such that items that are otherwise value-wise equal
are still considered distinct if their ``instruction``\\ s are referentially
distinct.
In the case of the special value :attr:`.ParameterTable.GLOBAL_PHASE` for ``instruction``, the
``param_index`` should be ``None``.
"""

def _instance_key(self, ref):
Expand Down Expand Up @@ -83,6 +87,24 @@ class ParameterTable(MutableMapping):

__slots__ = ["_table", "_keys", "_names"]

class _GlobalPhaseSentinel:
__slots__ = ()

def __copy__(self):
return self

def __deepcopy__(self, memo=None):
return self

def __reduce__(self):
return (operator.attrgetter("GLOBAL_PHASE"), (ParameterTable,))

def __repr__(self):
return "<global-phase sentinel>"

GLOBAL_PHASE = _GlobalPhaseSentinel()
"""Tracking object to indicate that a reference refers to the global phase of a circuit."""

def __init__(self, mapping=None):
"""Create a new instance, initialized with ``mapping`` if provided.
Expand Down Expand Up @@ -145,6 +167,17 @@ def get_names(self):
"""
return self._names

def discard_references(self, expression, key):
"""Remove all references to parameters contained within ``expression`` at the given table
``key``. This also discards parameter entries from the table if they have no further
references. No action is taken if the object is not tracked."""
for parameter in expression.parameters:
if (refs := self._table.get(parameter)) is not None:
if len(refs) == 1:
del self[parameter]
else:
refs.discard(key)

def __delitem__(self, key):
del self._table[key]
self._keys.discard(key)
Expand Down
74 changes: 51 additions & 23 deletions qiskit/circuit/quantumcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ def data(self, data_input: Iterable):
else:
data_input = list(data_input)
self._data.clear()
self._parameters = None
self._parameter_table = ParameterTable()
# Repopulate the parameter table with any global-phase entries.
self.global_phase = self.global_phase
if not data_input:
return
if isinstance(data_input[0], CircuitInstruction):
Expand Down Expand Up @@ -2411,6 +2414,9 @@ def copy(self, name: str | None = None) -> "QuantumCircuit":
operation_copies = {
id(instruction.operation): instruction.operation.copy() for instruction in self._data
}
# The special global-phase sentinel doesn't need copying, but this ensures that it'll get
# recognised. The global phase itself was already copied over in 'copy_empty_like`.
operation_copies[id(ParameterTable.GLOBAL_PHASE)] = ParameterTable.GLOBAL_PHASE

cpy._parameter_table = ParameterTable(
{
Expand Down Expand Up @@ -2473,6 +2479,10 @@ def copy_empty_like(self, name: str | None = None) -> "QuantumCircuit":
cpy._vars_capture = self._vars_capture.copy()

cpy._parameter_table = ParameterTable()
for parameter in getattr(cpy.global_phase, "parameters", ()):
cpy._parameter_table[parameter] = ParameterReferences(
[(ParameterTable.GLOBAL_PHASE, None)]
)
cpy._data = CircuitData(self._data.qubits, self._data.clbits)

cpy._calibrations = copy.deepcopy(self._calibrations)
Expand All @@ -2489,6 +2499,8 @@ def clear(self) -> None:
"""
self._data.clear()
self._parameter_table.clear()
# Repopulate the parameter table with any phase symbols.
self.global_phase = self.global_phase

def _create_creg(self, length: int, name: str) -> ClassicalRegister:
"""Creates a creg, checking if ClassicalRegister with same name exists"""
Expand Down Expand Up @@ -2752,6 +2764,8 @@ def remove_final_measurements(self, inplace: bool = True) -> Optional["QuantumCi
for creg in cregs_to_add:
circ.add_register(creg)

# Clear instruction info
circ.clear()
# Set circ instructions to match the new DAG
for node in new_dag.topological_op_nodes():
# Get arguments for classical condition (if any)
Expand Down Expand Up @@ -2825,9 +2839,20 @@ def global_phase(self, angle: ParameterValueType):
Args:
angle (float, ParameterExpression): radians
"""
if not (isinstance(angle, ParameterExpression) and angle.parameters):
# Set the phase to the [0, 2π) interval
angle = float(angle) % (2 * np.pi)
# If we're currently parametric, we need to throw away the references. This setter is
# called by some subclasses before the inner `_global_phase` is initialised.
global_phase_reference = (ParameterTable.GLOBAL_PHASE, None)
if isinstance(previous := getattr(self, "_global_phase", None), ParameterExpression):
self._parameter_table.discard_references(previous, global_phase_reference)

if isinstance(angle, ParameterExpression) and angle.parameters:
for parameter in angle.parameters:
if parameter not in self._parameter_table:
self._parameters = None
self._parameter_table[parameter] = ParameterReferences(())
self._parameter_table[parameter].add(global_phase_reference)
else:
angle = _normalize_global_phase(angle)
if self._control_flow_scopes:
self._control_flow_scopes[-1].global_phase = angle
else:
Expand Down Expand Up @@ -2902,9 +2927,7 @@ def parameters(self) -> ParameterView:
def num_parameters(self) -> int:
"""The number of parameter objects in the circuit."""
# Avoid a (potential) object creation if we can.
if self._parameters is not None:
return len(self._parameters)
return len(self._unsorted_parameters())
return len(self._parameter_table.get_keys())

def _unsorted_parameters(self) -> set[Parameter]:
"""Efficiently get all parameters in the circuit, without any sorting overhead.
Expand All @@ -2915,13 +2938,7 @@ def _unsorted_parameters(self) -> set[Parameter]:
should not be mutated. This is an internal performance detail. Code outside of this
package should not use this method.
"""
# This should be free, by accessing the actual backing data structure of the table, but that
# means that we need to copy it if adding keys from the global phase.
parameters = self._parameter_table.get_keys()
if isinstance(self.global_phase, ParameterExpression):
# Deliberate copy.
parameters = parameters | self.global_phase.parameters
return parameters
return self._parameter_table.get_keys()

@overload
def assign_parameters(
Expand Down Expand Up @@ -3038,7 +3055,6 @@ def assign_parameters( # pylint: disable=missing-raises-doc
# 'target' so we can take advantage of any caching we might be doing.
if isinstance(parameters, dict):
raw_mapping = parameters if flat_input else self._unroll_param_dict(parameters)
# Remember that we _must not_ mutate the output of `_unsorted_parameters`.
our_parameters = self._unsorted_parameters()
if strict and (extras := raw_mapping.keys() - our_parameters):
raise CircuitError(
Expand Down Expand Up @@ -3073,15 +3089,20 @@ def assign_parameters( # pylint: disable=missing-raises-doc
)
for operation, index in references:
seen_operations[id(operation)] = operation
assignee = operation.params[index]
if operation is ParameterTable.GLOBAL_PHASE:
assignee = target.global_phase
validate = _normalize_global_phase
else:
assignee = operation.params[index]
validate = operation.validate_parameter
if isinstance(assignee, ParameterExpression):
new_parameter = assignee.assign(to_bind, bound_value)
for parameter in update_parameters:
if parameter not in target._parameter_table:
target._parameter_table[parameter] = ParameterReferences(())
target._parameter_table[parameter].add((operation, index))
if not new_parameter.parameters:
new_parameter = operation.validate_parameter(new_parameter.numeric())
new_parameter = validate(new_parameter.numeric())
elif isinstance(assignee, QuantumCircuit):
new_parameter = assignee.assign_parameters(
{to_bind: bound_value}, inplace=False, flat_input=True
Expand All @@ -3091,7 +3112,12 @@ def assign_parameters( # pylint: disable=missing-raises-doc
f"Saw an unknown type during symbolic binding: {assignee}."
" This may indicate an internal logic error in symbol tracking."
)
operation.params[index] = new_parameter
if operation is ParameterTable.GLOBAL_PHASE:
# We've already handled parameter table updates in bulk, so we need to skip the
# public setter trying to do it again.
target._global_phase = new_parameter
else:
operation.params[index] = new_parameter

# After we've been through everything at the top level, make a single visit to each
# operation we've seen, rebinding its definition if necessary.
Expand All @@ -3103,12 +3129,6 @@ def assign_parameters( # pylint: disable=missing-raises-doc
parameter_binds.mapping, inplace=True, flat_input=True, strict=False
)

if isinstance(target.global_phase, ParameterExpression):
new_phase = target.global_phase
for parameter in new_phase.parameters & parameter_binds.mapping.keys():
new_phase = new_phase.assign(parameter, parameter_binds.mapping[parameter])
target.global_phase = new_phase

# Finally, assign the parameters inside any of the calibrations. We don't track these in
# the `ParameterTable`, so we manually reconstruct things.
def map_calibration(qubits, parameters, schedule):
Expand Down Expand Up @@ -6064,3 +6084,11 @@ def _bit_argument_conversion_scalar(specifier, bit_sequence, bit_set, type_):
else f"Invalid bit index: '{specifier}' of type '{type(specifier)}'"
)
raise CircuitError(message)


def _normalize_global_phase(angle):
"""Return the normalized form of an angle for use in the global phase. This coerces to float if
possible, and fixes to the interval :math:`[0, 2\\pi)`."""
if isinstance(angle, ParameterExpression) and angle.parameters:
return angle
return float(angle) % (2.0 * np.pi)
37 changes: 37 additions & 0 deletions test/python/circuit/test_circuit_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,25 @@ def test_copy_copies_registers(self):
self.assertEqual(len(qc.cregs), 1)
self.assertEqual(len(copied.cregs), 2)

def test_copy_handles_global_phase(self):
"""Test that the global phase is included in the copy, including parameters."""
a, b = Parameter("a"), Parameter("b")

nonparametric = QuantumCircuit(global_phase=1.0).copy()
self.assertEqual(nonparametric.global_phase, 1.0)
self.assertEqual(set(nonparametric.parameters), set())

parameter_phase = QuantumCircuit(global_phase=a).copy()
self.assertEqual(parameter_phase.global_phase, a)
self.assertEqual(set(parameter_phase.parameters), {a})
# The `assign_parameters` is an indirect test that the `ParameterTable` is fully valid.
self.assertEqual(parameter_phase.assign_parameters({a: 1.0}).global_phase, 1.0)

expression_phase = QuantumCircuit(global_phase=a - b).copy()
self.assertEqual(expression_phase.global_phase, a - b)
self.assertEqual(set(expression_phase.parameters), {a, b})
self.assertEqual(expression_phase.assign_parameters({a: 3, b: 2}).global_phase, 1.0)

def test_copy_empty_like_circuit(self):
"""Test copy_empty_like method makes a clear copy."""
qr = QuantumRegister(2)
Expand Down Expand Up @@ -463,6 +482,24 @@ def test_copy_empty_variables(self):
self.assertEqual({b, d}, set(copied.iter_captured_vars()))
self.assertEqual({b}, set(qc.iter_captured_vars()))

def test_copy_empty_like_parametric_phase(self):
"""Test that the parameter table of an empty circuit remains valid after copying a circuit
with a parametric global phase."""
a, b = Parameter("a"), Parameter("b")

single = QuantumCircuit(global_phase=a).copy_empty_like()
self.assertEqual(single.global_phase, a)
self.assertEqual(set(single.parameters), {a})
# The `assign_parameters` is an indirect test that the `ParameterTable` is fully valid.
self.assertEqual(single.assign_parameters({a: 1.0}).global_phase, 1.0)

stripped_instructions = QuantumCircuit(1, global_phase=a - b)
stripped_instructions.rz(a, 0)
stripped_instructions = stripped_instructions.copy_empty_like()
self.assertEqual(stripped_instructions.global_phase, a - b)
self.assertEqual(set(stripped_instructions.parameters), {a, b})
self.assertEqual(stripped_instructions.assign_parameters({a: 3, b: 2}).global_phase, 1.0)

def test_circuit_copy_rejects_invalid_types(self):
"""Test copy method rejects argument with type other than 'string' and 'None' type."""
qc = QuantumCircuit(1, 1)
Expand Down

0 comments on commit e446589

Please sign in to comment.