From 7e70ef2023096d4a771f1d6567a98d5b16f24f17 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Wed, 25 Oct 2023 12:03:37 +0100 Subject: [PATCH] Ensure `TemplateOptimization` returns native-symbolic objects `TemplateOptimization` currently uses Sympy internally to get access to solvers that have no equivalent in Symengine. Previously, it then did not convert the solutions to its equations back into Symengine format (if appropriate) before passing them on to `ParameterExpression`, which could lead to different assumptions about the type of the contained objects and bugs if `ParameterExpression` attempted to use Symengine-specific forms of methods on Sympy objects. --- .../template_substitution.py | 35 +++++++++++------- ...mplate-match-symbols-c00786155f101e39.yaml | 8 ++++ .../transpiler/test_template_matching.py | 37 +++++++++++++++++++ 3 files changed, 67 insertions(+), 13 deletions(-) create mode 100644 releasenotes/notes/template-match-symbols-c00786155f101e39.yaml diff --git a/qiskit/transpiler/passes/optimization/template_matching/template_substitution.py b/qiskit/transpiler/passes/optimization/template_matching/template_substitution.py index b648cbaca944..06c5186d284c 100644 --- a/qiskit/transpiler/passes/optimization/template_matching/template_substitution.py +++ b/qiskit/transpiler/passes/optimization/template_matching/template_substitution.py @@ -22,6 +22,7 @@ from qiskit.dagcircuit.dagcircuit import DAGCircuit from qiskit.dagcircuit.dagdependency import DAGDependency from qiskit.converters.dagdependency_to_dag import dagdependency_to_dag +from qiskit.utils import optionals as _optionals class SubstitutionConfig: @@ -496,6 +497,15 @@ def _attempt_bind(self, template_sublist, circuit_sublist): import sympy as sym from sympy.parsing.sympy_parser import parse_expr + if _optionals.HAS_SYMENGINE: + import symengine + + # Converts Sympy expressions to Symengine ones. + to_native_symbolic = symengine.sympify + else: + # Our native form is sympy, so we don't need to do anything. + to_native_symbolic = lambda x: x + circuit_params, template_params = [], [] # Set of all parameter names that are present in the circuits to be optimised. circuit_params_set = set() @@ -555,7 +565,7 @@ def dummy_parameter(): node.op.params = sub_node_params # Create the fake binding dict and check - equations, circ_dict, temp_symbols, sol, fake_bind = [], {}, {}, {}, {} + equations, circ_dict, temp_symbols = [], {}, {} for circuit_param, template_param in zip(circuit_params, template_params): if isinstance(template_param, ParameterExpression): if isinstance(circuit_param, ParameterExpression): @@ -577,19 +587,18 @@ def dummy_parameter(): if not temp_symbols: return template_dag_dep - # Check compatibility by solving the resulting equation - sym_sol = sym.solve(equations, set(temp_symbols.values())) - for key in sym_sol: - try: - sol[str(key)] = ParameterExpression(circ_dict, sym_sol[key]) - except TypeError: - return None - - if not sol: + # Check compatibility by solving the resulting equation. `dict=True` (surprisingly) forces + # the output to always be a list, even if there's exactly one solution. + sym_sol = sym.solve(equations, set(temp_symbols.values()), dict=True) + if not sym_sol: + # No solutions. return None - - for key in temp_symbols: - fake_bind[key] = sol[str(key)] + # If there's multiple solutions, arbitrarily pick the first one. + sol = { + param.name: ParameterExpression(circ_dict, to_native_symbolic(expr)) + for param, expr in sym_sol[0].items() + } + fake_bind = {key: sol[key.name] for key in temp_symbols} for node in template_dag_dep.get_nodes(): bound_params = [] diff --git a/releasenotes/notes/template-match-symbols-c00786155f101e39.yaml b/releasenotes/notes/template-match-symbols-c00786155f101e39.yaml new file mode 100644 index 000000000000..a005879d22a2 --- /dev/null +++ b/releasenotes/notes/template-match-symbols-c00786155f101e39.yaml @@ -0,0 +1,8 @@ +--- +fixes: + - | + The :class:`.TemplateOptimization` pass will now return parametric expressions using the native + symbolic expression format of :class:`.ParameterExpression`, rather than always using Sympy. + For most supported platforms, this means that the expressions will be Symengine objects. + Previously, the pass could return mismatched objects, which could lead to later failures in + parameter-handling code. diff --git a/test/python/transpiler/test_template_matching.py b/test/python/transpiler/test_template_matching.py index 015e19e4416d..9e461305397f 100644 --- a/test/python/transpiler/test_template_matching.py +++ b/test/python/transpiler/test_template_matching.py @@ -344,6 +344,43 @@ def count_cx(qc): # however these are equivalent if the operators are the same self.assertTrue(Operator(circuit_in).equiv(circuit_out)) + def test_output_symbolic_library_equal(self): + """Test that the template matcher returns parameter expressions that use the same symbolic + library as the default; it should not coerce everything to Sympy when playing with the + `ParameterExpression` internals.""" + + a, b = Parameter("a"), Parameter("b") + + template = QuantumCircuit(1) + template.p(a, 0) + template.p(-a, 0) + template.rz(a, 0) + template.rz(-a, 0) + + circuit = QuantumCircuit(1) + circuit.p(-b, 0) + circuit.p(b, 0) + + pass_ = TemplateOptimization(template_list=[template], user_cost_dict={"p": 100, "rz": 1}) + out = pass_(circuit) + + expected = QuantumCircuit(1) + expected.rz(-b, 0) + expected.rz(b, 0) + self.assertEqual(out, expected) + + def symbolic_library(expr): + """Get the symbolic library of the expression - 'sympy' or 'symengine'.""" + return type(expr._symbol_expr).__module__.split(".")[0] + + out_exprs = [expr for instruction in out.data for expr in instruction.operation.params] + self.assertEqual( + [symbolic_library(b)] * len(out_exprs), [symbolic_library(expr) for expr in out_exprs] + ) + + # Assert that the result still works with parametric assignment. + self.assertEqual(out.assign_parameters({b: 1.5}), expected.assign_parameters({b: 1.5})) + def test_optimizer_does_not_replace_unbound_partial_match(self): """ Test that partial matches with parameters will not raise errors.