From a14dc37441f73e8c9630d43b03ba8a1c9de5c757 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Thu, 24 Mar 2022 14:58:04 +0000 Subject: [PATCH] Fix template parameter substitution This fixes issues introduced recently in the PR that caused parameters to be incorrectly bound in the result. This meant that the actual numbers in the produced circuits were incorrect. This happened mostly by tracking data structures being updated at the wrong levels within loops. In addition, this commit also updates some data structures to more robust and efficient versions: - Testing whether a parameter has a clash is best done by constructing a set of parameters used in the input circuits, then testing directly on this, rather than stringifying expressions and using subsearch matches; this avoids problems if two parameters have contained names, or if more than one match is catenated into a single string. - Using a dictionary with a missing-element constructor to build the replacement parameters allows the looping logic to be simpler; the "build missing element" logic can be separated out to happen automatically. --- .../template_substitution.py | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/qiskit/transpiler/passes/optimization/template_matching/template_substitution.py b/qiskit/transpiler/passes/optimization/template_matching/template_substitution.py index 1e47acec8b5d..9eab25363f65 100644 --- a/qiskit/transpiler/passes/optimization/template_matching/template_substitution.py +++ b/qiskit/transpiler/passes/optimization/template_matching/template_substitution.py @@ -14,9 +14,9 @@ Template matching substitution, given a list of maximal matches it substitutes them in circuit and creates a new optimized dag version of the circuit. """ +import collections import copy -import random -import string +import itertools from qiskit.circuit import Parameter, ParameterExpression from qiskit.dagcircuit.dagcircuit import DAGCircuit @@ -502,42 +502,44 @@ def _attempt_bind(self, template_sublist, circuit_sublist): from sympy.parsing.sympy_parser import parse_expr circuit_params, template_params = [], [] - circ_param_str = "" - sub_params = {} + # Set of all parameter names that are present in the circuits to be optimised. + circuit_params_set = set() template_dag_dep = copy.deepcopy(self.template_dag_dep) - # add parameters from circuit to circuit_params, as well as a - # string containing parameters to check for duplicates from the template + # add parameters from circuit to circuit_params for idx, t_idx in enumerate(template_sublist): qc_idx = circuit_sublist[idx] - circuit_params += self.circuit_dag_dep.get_node(qc_idx).op.params - for param_exp in self.circuit_dag_dep.get_node(qc_idx).op.params: - if isinstance(param_exp, ParameterExpression): - circ_param_str += str(param_exp) + parameters = self.circuit_dag_dep.get_node(qc_idx).op.params + circuit_params += parameters + for parameter in parameters: + if isinstance(parameter, ParameterExpression): + circuit_params_set.update(x.name for x in parameter.parameters) + + _dummy_counter = itertools.count() + + def dummy_parameter(): + # Strictly not _guaranteed_ to avoid naming clashes, but if someone's calling their + # parameters this then that's their own fault. + return Parameter(f"_qiskit_template_dummy_{next(_dummy_counter)}") + + # Substitutions for parameters that have clashing names between the input circuits and the + # defined templates. + template_clash_substitutions = collections.defaultdict(dummy_parameter) - # add parameters from template to template_params, as well as replace - # node parameters that have duplicate names in the circuit params - # create a dict of sub_params to substitute remaining nodes not in - # the template sublist + # add parameters from template to template_params, replacing parameters with names that + # clash with those in the circuit. for idx, t_idx in enumerate(template_sublist): node = template_dag_dep.get_node(t_idx) sub_node_params = [] for t_param_exp in node.op.params: if isinstance(t_param_exp, ParameterExpression): for t_param in t_param_exp.parameters: - if t_param.name in circ_param_str: - new_param_name = "".join( - random.choice(string.ascii_lowercase) for i in range(8) - ) - sub_params[t_param] = Parameter(new_param_name) - t_param_exp = t_param_exp.assign(t_param, sub_params[t_param]) - sub_node_params.append(t_param_exp) - template_params.append(t_param_exp) - else: - sub_node_params.append(t_param_exp) - template_params.append(t_param_exp) - + if t_param.name in circuit_params_set: + new_param = template_clash_substitutions[t_param.name] + t_param_exp = t_param_exp.assign(t_param, new_param) + sub_node_params.append(t_param_exp) + template_params.append(t_param_exp) node.op.params = sub_node_params for node in template_dag_dep.get_nodes(): @@ -545,14 +547,11 @@ def _attempt_bind(self, template_sublist, circuit_sublist): for param_exp in node.op.params: if isinstance(param_exp, ParameterExpression): for param in param_exp.parameters: - if param in sub_params: - sub_node_params.append( - param_exp.subs(sub_params) - ) # prolly need to specify exact entry - else: - sub_node_params.append(param_exp) - else: - sub_node_params.append(param_exp) + if param.name in template_clash_substitutions: + param_exp = param_exp.assign( + param, template_clash_substitutions[param.name] + ) + sub_node_params.append(param_exp) node.op.params = sub_node_params