diff --git a/bqskit/passes/control/foreach.py b/bqskit/passes/control/foreach.py index 8c3c9da02..7eb7b69ad 100644 --- a/bqskit/passes/control/foreach.py +++ b/bqskit/passes/control/foreach.py @@ -38,6 +38,11 @@ class ForEachBlockPass(BasePass): pass_down_key_prefix = 'ForEachBlockPass_pass_down_' """If a key exists in the pass data with this prefix, pass it to blocks.""" + pass_down_block_specific_key_prefix = ( + 'ForEachBlockPass_specific_pass_down_' + ) + """Key for injecting a block specific pass data.""" + def __init__( self, loop_body: WorkflowLike, @@ -197,6 +202,10 @@ async def run(self, circuit: Circuit, data: PassData) -> None: for key in data: if key.startswith(self.pass_down_key_prefix): block_data[key] = data[key] + elif key.startswith( + self.pass_down_block_specific_key_prefix, + ) and i in data[key]: + block_data[key] = data[key][i] block_data.seed = data.seed subcircuits.append(subcircuit) diff --git a/tests/passes/control/test_foreachblock.py b/tests/passes/control/test_foreachblock.py index 605392b85..23526adfd 100644 --- a/tests/passes/control/test_foreachblock.py +++ b/tests/passes/control/test_foreachblock.py @@ -5,8 +5,10 @@ from bqskit.compiler.basepass import BasePass from bqskit.compiler.compiler import Compiler from bqskit.compiler.passdata import PassData +from bqskit.compiler.workflow import Workflow from bqskit.ir.circuit import Circuit from bqskit.ir.gates import CircuitGate +from bqskit.ir.gates import CNOTGate from bqskit.ir.gates import HGate from bqskit.ir.gates import XGate from bqskit.ir.gates import YGate @@ -14,6 +16,9 @@ from bqskit.ir.operation import Operation from bqskit.passes import UnfoldPass from bqskit.passes.control.foreach import ForEachBlockPass +from bqskit.passes.partitioning.quick import QuickPartitioner +from bqskit.passes.synthesis.qsearch import QSearchSynthesisPass +from bqskit.passes.util.update import UpdateDataPass @pytest.fixture @@ -130,3 +135,31 @@ def test_no_hang_on_empty_collection(compiler: Compiler) -> None: circuit.append_gate(XGate(), 0) feb_pass = ForEachBlockPass(RemoveXGatePass(), collection_filter=empty_coll) compiler.compile(circuit, feb_pass) + + +def test_pass_down_seeds(compiler: Compiler) -> None: + circuit = Circuit(3) + circuit.append_gate(CNOTGate(), (0, 1)) + circuit.append_gate(CNOTGate(), (1, 2)) + + seed = Circuit(3) + seed.append_gate(CNOTGate(), (0, 1)) + seed.append_gate(CNOTGate(), (1, 2)) + + # Manually sed seed for block 0 + seeds = {0: [seed]} + + key = 'ForEachBlockPass_specific_pass_down_seed_circuits' + + partitioner = QuickPartitioner() + updater = UpdateDataPass(key, seeds) + qsearch = QSearchSynthesisPass() + foreach = ForEachBlockPass(qsearch) + unfolder = UnfoldPass() + workflow = Workflow([partitioner, updater, foreach, unfolder]) + compiled, data = compiler.compile( + circuit, workflow, request_data=True, + ) + dist = compiled.get_unitary().get_distance_from(circuit.get_unitary()) + assert dist <= 1e-5 + assert key in data