Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom backend transpiler stages #8648

Merged
merged 16 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion qiskit/compiler/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def _log_transpile_time(start_time, end_time):
def _combine_args(shared_transpiler_args, unique_config):
# Pop optimization_level to exclude it from the kwargs when building a
# PassManagerConfig
level = shared_transpiler_args.get("optimization_level")
level = shared_transpiler_args.pop("optimization_level")
pass_manager_config = shared_transpiler_args
pass_manager_config.update(unique_config.pop("pass_manager_config"))
pass_manager_config = PassManagerConfig(**pass_manager_config)
Expand Down Expand Up @@ -660,6 +660,9 @@ def _parse_transpile_args(
}

list_transpile_args = []
if scheduling_method is None and hasattr(backend, "get_scheduling_stage"):
scheduling_method = backend.get_scheduling_stage()
mtreinish marked this conversation as resolved.
Show resolved Hide resolved

for key, value in {
"inst_map": inst_map,
"coupling_map": coupling_map,
Expand Down Expand Up @@ -691,6 +694,8 @@ def _parse_transpile_args(
"pass_manager_config": kwargs,
}
list_transpile_args.append(transpile_args)
if hasattr(backend, "get_post_translation_stage"):
shared_dict["post_translation_pm"] = backend.get_post_translation_stage()

return list_transpile_args, shared_dict

Expand Down
48 changes: 48 additions & 0 deletions qiskit/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,54 @@ def _define(self):
transpiler will ensure that it continues to be well supported by Qiskit
moving forward.

.. _custom_transpiler_backend:

Custom Transpiler Passes
^^^^^^^^^^^^^^^^^^^^^^^^
As part of the transpiler there is a provision for backends to provide custom
stage implementation to facilitate hardware specific optimizations and
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
circuit transformations. Currently there are two hook points supported,
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
``get_post_translation_stage()`` which is used for a backend to specify a
:class:`~PassManager` which will be run after basis translation stage in the
compiler and ``get_scheduling_stage()`` which is used for a backend to
specify a :class:`~PassManager` which will be run for the scheduling stage
by default (which is the last defined stage in a default compilation). These
hook points in a :class:`~.BackendV2` class should only be used if your
backend has special requirements for compilation that are not met by the
default backend.

To leverage these hook points you just need to add the methods to your
:class:`~.BackendV2` implementation and have them return a
:class:`~.PassManager` object. For example::

from qiskit.circuit.library import XGate
from qiskit.transpiler.passes import (
ALAPScheduleAnalysis,
PadDynamicalDecoupling,
ResetAfterMeasureSimplification
)

class Mybackend(BackendV2):

def get_scheduling_stage(self):
dd_sequence = [XGate(), XGate()]
pm = PassManager([
ALAPScheduleAnalysis(self.instruction_durations),
PadDynamicalDecoupling(self.instruction_durations, dd_sequence)
])
return pm

def get_post_translation_stage(self):
pm = PassManager([ResetAfterMeasureSimplification()])
return pm

This snippet of a backend implementation will now have the :func:`~.transpile`
function run a custom stage for scheduling (unless the user manually requests a
different one explicitly) which will insert dynamical decoupling sequences and
also simplify resets after measurements after the basis translation stage. This
way if these two compilation steps are **required** for running on ``Mybackend``
the transpiler will be able to perform these steps without any manual user input.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we planning to add a different interface for passes that are suggested but not required? For example, ResetAfterMeasureSimplification is (presumably) never required, but for certain backends, it's a generally good idea.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess required is probably too strong a word. I was trying to reinforce here that these passes will be run by default for every transpile() call so only use it for things you want to be run all the time. My intent was for something like ResetAfterMeasureSimplification to be usable in post-translation which is why I put it in the example. For the appropriate backend it's an easy optimization that improves the fidelity, but that is a backend specific thing and only the backend object will have sufficient context to know whether it's a good idea or not. It's not something we can cleanly express via the target (which is the point of having this interface).

But, writing this I just realize that I think we should add an optimization level kwarg to the hook method. Because, I realize to really support using ResetAfterMeasureSimplification we want to not inject that into the pipeline for level 0. So giving backends the context of which optimization level the pass will enable doing this and potential tweak the custom stages in a similar way for higher optimization levels.


Run Method
----------

Expand Down
14 changes: 14 additions & 0 deletions qiskit/providers/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,20 @@ class BackendV2(Backend, ABC):
will build a :class:`~qiskit.providers.models.BackendConfiguration` object
and :class:`~qiskit.providers.models.BackendProperties` from the attributes
defined in this class for backwards compatibility.

A backend object can optionally contain methods named
``get_post_translation_stage`` and ``get_scheduling_stage``. If these
methods are present on a backend object and this object is used for
:func:`~.transpile` or :func:`~.generate_preset_pass_manager` the
transpilation process will default to using the output from those methods
as the scheduling stage and the post-translation compilation stage. This
enables a backend which has custom requirements for compilation to transform
the circuit to ensure it is runnable on the backend. These hooks are enabled
by default and should only be used to enable extra compilation steps
if they are **required** to ensure a circuit is executable on the backend.
These methods are passed no input arguments and are expected to return
a :class:`~.PassManager` object representing that stage of the transpilation
process.
"""

version = 2
Expand Down
11 changes: 8 additions & 3 deletions qiskit/transpiler/passmanager_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
target=None,
init_method=None,
optimization_method=None,
optimization_level=None,
kdk marked this conversation as resolved.
Show resolved Hide resolved
post_translation_pm=None,
):
"""Initialize a PassManagerConfig object

Expand Down Expand Up @@ -80,7 +80,8 @@ def __init__(
init_method (str): The plugin name for the init stage plugin to use
optimization_method (str): The plugin name for the optimization stage plugin
to use.
optimization_level (int): The optimization level being used for compilation.
post_translation_pm (PassManager): An optional pass manager representing a
post-translation stage.
"""
self.initial_layout = initial_layout
self.basis_gates = basis_gates
Expand All @@ -100,7 +101,7 @@ def __init__(
self.unitary_synthesis_method = unitary_synthesis_method
self.unitary_synthesis_plugin_config = unitary_synthesis_plugin_config
self.target = target
self.optimization_level = optimization_level
self.post_translation_pm = post_translation_pm

@classmethod
def from_backend(cls, backend, **pass_manager_options):
Expand Down Expand Up @@ -152,6 +153,10 @@ def from_backend(cls, backend, **pass_manager_options):
if res.target is None:
if backend_version >= 2:
res.target = backend.target
if res.scheduling_method is None and hasattr(backend, "get_scheduling_stage"):
res.scheduling_method = backend.get_scheduling_stage()
if hasattr(backend, "get_post_translation_stage"):
res.post_translation_pm = backend.get_post_translation_stage()
return res

def __str__(self):
Expand Down
1 change: 0 additions & 1 deletion qiskit/transpiler/preset_passmanagers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def generate_preset_pass_manager(
initial_layout=initial_layout,
init_method=init_method,
optimization_method=optimization_method,
optimization_level=optimization_level,
)

if backend is not None:
Expand Down
6 changes: 6 additions & 0 deletions qiskit/transpiler/preset_passmanagers/level0.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def _choose_layout_condition(property_set):
sched = common.generate_scheduling(
instruction_durations, scheduling_method, timing_constraints, inst_map
)
elif isinstance(scheduling_method, PassManager):
sched = scheduling_method
else:
sched = plugin_manager.get_passmanager_stage(
"scheduling", scheduling_method, pass_manager_config, optimization_level=0
Expand All @@ -191,13 +193,17 @@ def _choose_layout_condition(property_set):
optimization = plugin_manager.get_passmanager_stage(
"optimization", optimization_method, pass_manager_config, optimization_level=0
)
post_translation = None
if pass_manager_config.post_translation_pm is not None:
post_translation = pass_manager_config.post_translation_pm
mtreinish marked this conversation as resolved.
Show resolved Hide resolved

return StagedPassManager(
init=init,
layout=layout,
pre_routing=pre_routing,
routing=routing,
translation=translation,
post_translation=post_translation,
pre_optimization=pre_opt,
optimization=optimization,
scheduling=sched,
Expand Down
7 changes: 7 additions & 0 deletions qiskit/transpiler/preset_passmanagers/level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ def _unroll_condition(property_set):
sched = common.generate_scheduling(
instruction_durations, scheduling_method, timing_constraints, inst_map
)
elif isinstance(scheduling_method, PassManager):
sched = scheduling_method
else:
sched = plugin_manager.get_passmanager_stage(
"scheduling", scheduling_method, pass_manager_config, optimization_level=1
Expand All @@ -291,12 +293,17 @@ def _unroll_condition(property_set):
else:
init = unroll_3q

post_translation = None
if pass_manager_config.post_translation_pm is not None:
post_translation = pass_manager_config.post_translation_pm

return StagedPassManager(
init=init,
layout=layout,
pre_routing=pre_routing,
routing=routing,
translation=translation,
post_translation=post_translation,
pre_optimization=pre_optimization,
optimization=optimization,
scheduling=sched,
Expand Down
7 changes: 7 additions & 0 deletions qiskit/transpiler/preset_passmanagers/level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ def _unroll_condition(property_set):
sched = common.generate_scheduling(
instruction_durations, scheduling_method, timing_constraints, inst_map
)
elif isinstance(scheduling_method, PassManager):
sched = scheduling_method
else:
sched = plugin_manager.get_passmanager_stage(
"scheduling", scheduling_method, pass_manager_config, optimization_level=2
Expand All @@ -266,12 +268,17 @@ def _unroll_condition(property_set):
else:
init = unroll_3q

post_translation = None
if pass_manager_config.post_translation_pm is not None:
post_translation = pass_manager_config.post_translation_pm

return StagedPassManager(
init=init,
layout=layout,
pre_routing=pre_routing,
routing=routing,
translation=translation,
post_translation=post_translation,
pre_optimization=pre_optimization,
optimization=optimization,
scheduling=sched,
Expand Down
13 changes: 6 additions & 7 deletions qiskit/transpiler/preset_passmanagers/level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@ def level_3_pass_manager(pass_manager_config: PassManagerConfig) -> StagedPassMa
timing_constraints = pass_manager_config.timing_constraints or TimingConstraints()
unitary_synthesis_plugin_config = pass_manager_config.unitary_synthesis_plugin_config
target = pass_manager_config.target
# Override an unset optimization_level for stage plugin use.
# it will be restored to None before this is returned
optimization_level = pass_manager_config.optimization_level
if optimization_level is None:
pass_manager_config.optimization_level = 3

# Layout on good qubits if calibration info available, otherwise on dense links
_given_layout = SetLayout(initial_layout)
Expand Down Expand Up @@ -313,20 +308,24 @@ def _unroll_condition(property_set):
sched = common.generate_scheduling(
instruction_durations, scheduling_method, timing_constraints, inst_map
)
elif isinstance(scheduling_method, PassManager):
sched = scheduling_method
else:
sched = plugin_manager.get_passmanager_stage(
"scheduling", scheduling_method, pass_manager_config, optimization_level=3
)

# Restore PassManagerConfig optimization_level override
pass_manager_config.optimization_level = optimization_level
post_translation = None
if pass_manager_config.post_translation_pm is not None:
post_translation = pass_manager_config.post_translation_pm

return StagedPassManager(
init=init,
layout=layout,
pre_routing=pre_routing,
routing=routing,
translation=translation,
post_translation=post_translation,
pre_optimization=pre_optimization,
optimization=optimization,
scheduling=sched,
Expand Down
11 changes: 11 additions & 0 deletions releasenotes/notes/add-backend-custom-passes-cddfd05c8704a4b1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
features:
- |
The :class:`~.BackendV2` class now has support for two new optional hook
points enabling backends to inject custom compilation steps as part of
:func:`~.transpile` and :func:`~.generate_preset_pass_manager`. If a
:class:`~.BackendV2` implementation includes the methods
``get_scheduling_stage()`` or ``get_post_translation_stage()`` the
transpiler will use the returned :class:`~.PassManager` object to run
additional custom transpiler passes when targetting that backend.
For more details on how to use this see :ref:`custom_transpiler_backend`.
13 changes: 5 additions & 8 deletions releasenotes/notes/stage-plugin-interface-47daae40f7d0ad3c.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,9 @@ features:
``optimization_method`` which are used to specify alternative plugins to
use for the ``init`` stage and ``optimization`` stages respectively.
- |
The :class:`~.PassManagerConfig` class has 3 new attributes,
:attr:`~.PassManagerConfig.init_method`,
:attr:`~.PassManagerConfig.optimization_method`, and
:attr:`~.PassManagerConfig.optimization_level` along with matching keyword
arguments on the constructor methods. The first two attributes represent
The :class:`~.PassManagerConfig` class has 2 new attributes,
:attr:`~.PassManagerConfig.init_method` and
:attr:`~.PassManagerConfig.optimization_method`
along with matching keyword arguments on the constructor methods. These represent
the user specified ``init`` and ``optimization`` plugins to use for
compilation. The :attr:`~.PassManagerConfig.optimization_level` attribute
represents the compilations optimization level if specified which can
be used to inform stage plugin behavior.
compilation.
77 changes: 76 additions & 1 deletion test/python/transpiler/test_preset_passmanagers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from qiskit.circuit import Qubit
from qiskit.compiler import transpile, assemble
from qiskit.transpiler import CouplingMap, Layout, PassManager, TranspilerError
from qiskit.circuit.library import U2Gate, U3Gate, QuantumVolume
from qiskit.transpiler.passes import (
ALAPScheduleAnalysis,
PadDynamicalDecoupling,
RemoveResetInZeroState,
)
from qiskit.circuit.library import U2Gate, U3Gate, XGate, QuantumVolume
from qiskit.test import QiskitTestCase
from qiskit.providers.fake_provider import (
FakeBelem,
Expand Down Expand Up @@ -431,6 +436,39 @@ def test_partial_layout_fully_connected_cm(self, level):
Layout.from_qubit_list([ancilla[0], ancilla[1], qr[1], ancilla[2], qr[0]]),
)

@data(0, 1, 2, 3)
def test_backend_with_custom_stages(self, optimization_level):
"""Test transpile() executes backend specific custom stage."""

class TargetBackend(FakeLagosV2):
"""Fake Lagos subclass with custom transpiler stages."""

def get_scheduling_stage(self):
"""Custom scheduling passes."""
dd_sequence = [XGate(), XGate()]
pm = PassManager(
[
ALAPScheduleAnalysis(self.instruction_durations),
PadDynamicalDecoupling(self.instruction_durations, dd_sequence),
]
)
return pm

def get_post_translation_stage(self):
"""Custom post translation stage."""
pm = PassManager([RemoveResetInZeroState()])
return pm

target = TargetBackend()
qr = QuantumRegister(2, "q")
qc = QuantumCircuit(qr)
qc.h(qr[0])
qc.cx(qr[0], qr[1])
_ = transpile(qc, target, optimization_level=optimization_level, callback=self.callback)
self.assertIn("ALAPScheduleAnalysis", self.passes)
self.assertIn("PadDynamicalDecoupling", self.passes)
self.assertIn("RemoveResetInZeroState", self.passes)


@ddt
class TestInitialLayouts(QiskitTestCase):
Expand Down Expand Up @@ -1052,3 +1090,40 @@ def test_invalid_optimization_level(self):
"""Assert we fail with an invalid optimization_level."""
with self.assertRaises(ValueError):
generate_preset_pass_manager(42)

@data(0, 1, 2, 3)
def test_backend_with_custom_stages(self, optimization_level):
"""Test generated preset pass manager includes backend specific custom stages."""

class TargetBackend(FakeLagosV2):
"""Fake lagos subclass with custom transpiler stages."""

def get_scheduling_stage(self):
"""Custom scheduling stage."""
dd_sequence = [XGate(), XGate()]
pm = PassManager(
[
ALAPScheduleAnalysis(self.instruction_durations),
PadDynamicalDecoupling(self.instruction_durations, dd_sequence),
]
)
return pm

def get_post_translation_stage(self):
"""Custom post translation stage."""
pm = PassManager([RemoveResetInZeroState()])
return pm

target = TargetBackend()
pm = generate_preset_pass_manager(optimization_level, target)
self.assertIsInstance(pm, PassManager)
pass_list = [y.__class__.__name__ for x in pm.passes() for y in x["passes"]]
self.assertIn("PadDynamicalDecoupling", pass_list)
self.assertIn("ALAPScheduleAnalysis", pass_list)
self.assertIsNotNone(pm.post_translation) # pylint: disable=no-member
post_translation_pass_list = [
y.__class__.__name__
for x in pm.post_translation.passes() # pylint: disable=no-member
for y in x["passes"]
]
self.assertIn("RemoveResetInZeroState", post_translation_pass_list)