Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid exploring extraneous minima in the cut-finder search space #585

Merged
merged 14 commits into from
May 14, 2024
Merged
43 changes: 28 additions & 15 deletions circuit_knitting/cutting/cut_finding/best_first_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import heapq
import numpy as np
from typing import TYPE_CHECKING, Callable, cast
from typing import TYPE_CHECKING, Callable, cast, NamedTuple
from itertools import count

from .optimization_settings import OptimizationSettings
Expand All @@ -26,6 +26,21 @@
from .cut_optimization import CutOptimizationFuncArgs


class SearchStats(NamedTuple):
"""NamedTuple for collecting search statistics.

It carries information about the number of states visited
(dequeued from the search queue), the number of next-states generated,
the number of next-states that are enqueued after cost pruning,
and the number of backjumps performed.
"""

states_visited: int
next_states_generated: int
states_enqueued: int
backjumps: int


class BestFirstPriorityQueue:
"""Class that implements priority queues for best-first search.

Expand Down Expand Up @@ -149,6 +164,8 @@ class BestFirstSearch:

``stop_at_first_min`` (Boolean) is a flag that indicates whether or not to
stop the search after the first minimum-cost goal state has been reached.
In the absence of any non-LO QPD assignments, it always makes sense to stop once
the first minimum has been reached and therefore, we set this bool to ``True``.

``max_backjumps`` (int or None) is the maximum number of backjump operations that
can be performed before the search is forced to terminate. None indicates
Expand Down Expand Up @@ -185,7 +202,7 @@ def __init__(
self,
optimization_settings: OptimizationSettings,
search_functions: SearchFunctions,
stop_at_first_min: bool = False,
stop_at_first_min: bool = True,
):
"""Initialize an instance of :class:`BestFirstSearch`.

Expand Down Expand Up @@ -213,7 +230,7 @@ def __init__(
self.num_next_states = 0
self.num_enqueues = 0
self.num_backjumps = 0
self.penultimate_stats: np.typing.NDArray | None = None
self.penultimate_stats: NamedTuple | None = None

def initialize(
self,
Expand Down Expand Up @@ -258,7 +275,6 @@ def optimization_pass(
self.mincost_bound = self.mincost_bound_func(*args) # type: ignore

prev_depth = None

while (
self.pqueue.qsize() > 0
and (not self.stop_at_first_min or not self.min_reached)
Expand All @@ -267,7 +283,6 @@ def optimization_pass(
state, depth, cost = self.pqueue.get()

self.update_minimum_reached(cost)

if cost is None or self.cost_bounds_exceeded(cost):
return None, None

Expand Down Expand Up @@ -299,10 +314,10 @@ def minimum_reached(self) -> bool:
"""Return True if the optimization reached a global minimum."""
return self.min_reached

def get_stats(self, penultimate: bool = False) -> np.typing.NDArray[np.int_] | None:
def get_stats(self, penultimate: bool = False) -> NamedTuple | None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it complain if you use SearchStats as the return type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't, I have changed that here (d931ec6).

"""Return statistics of the search that was performed.

This is a Numpy array containing the number of states visited
This is a NamedTuple containing the number of states visited
(dequeued), the number of next-states generated, the number of
next-states that are enqueued after cost pruning, and the number
of backjumps performed. Return None if no search is performed.
Expand All @@ -312,15 +327,13 @@ def get_stats(self, penultimate: bool = False) -> np.typing.NDArray[np.int_] | N
if penultimate:
return self.penultimate_stats

return np.array(
(
self.num_states_visited,
self.num_next_states,
self.num_enqueues,
self.num_backjumps,
),
dtype=int,
search_stats = SearchStats(
states_visited=self.num_states_visited,
next_states_generated=self.num_next_states,
states_enqueued=self.num_enqueues,
backjumps=self.num_backjumps,
)
return search_stats
ibrahim-shehzad marked this conversation as resolved.
Show resolved Hide resolved

def get_upperbound_cost(
self,
Expand Down
7 changes: 3 additions & 4 deletions circuit_knitting/cutting/cut_finding/cut_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

import numpy as np
from dataclasses import dataclass
from typing import cast
from numpy.typing import NDArray
from typing import cast, NamedTuple
from .search_space_generator import ActionNames
from .cco_utils import select_search_engine, greedy_best_first_search
from .cutting_actions import disjoint_subcircuit_actions
Expand Down Expand Up @@ -261,7 +260,7 @@ def __init__(
"CutOptimization",
self.settings,
self.search_funcs,
stop_at_first_min=False,
stop_at_first_min=True,
)
sq.initialize([start_state], self.func_args)

Expand Down Expand Up @@ -299,7 +298,7 @@ def minimum_reached(self) -> bool:
"""
return self.search_engine.minimum_reached()

def get_stats(self, penultimate: bool = False) -> NDArray[np.int_]:
def get_stats(self, penultimate: bool = False) -> NamedTuple | None:
"""Return the search-engine statistics.

This is a Numpy array containing the number of states visited
Expand Down
9 changes: 3 additions & 6 deletions circuit_knitting/cutting/cut_finding/lo_cuts_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""File containing the wrapper class for optimizing LO gate and wire cuts."""
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, NamedTuple

from .cut_optimization import CutOptimization
from .cut_optimization import disjoint_subcircuit_actions
Expand All @@ -21,9 +21,6 @@
from .cut_optimization import cut_optimization_min_cost_bound_func
from .cut_optimization import cut_optimization_upper_bound_cost_func
from .search_space_generator import SearchFunctions, SearchSpaceGenerator

import numpy as np
from numpy.typing import NDArray
from .disjoint_subcircuits_state import DisjointSubcircuitsState

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -155,10 +152,10 @@ def get_results(self) -> DisjointSubcircuitsState | None:
"""Return the optimization results."""
return self.best_result

def get_stats(self, penultimate=False) -> dict[str, NDArray[np.int_]]:
def get_stats(self, penultimate=False) -> dict[str, NamedTuple | None]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea that there might one day be more keys in this dict than just CutOptimization?

I wonder if this would be better.

Suggested change
def get_stats(self, penultimate=False) -> dict[str, NamedTuple | None]:
def get_stats(self, penultimate=False) -> dict[str, Any]:

but the docstring is still a little bit weird, because it talks about the "value" of the dict without referencing what the key(s) are.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea that there might one day be more keys in this dict than just CutOptimization?

Yes, that's right.

Copy link
Member

@garrison garrison May 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, I think my suggestion of Any makes the most sense, and maybe edit the docstring for clarity too.

"""Return a dictionary containing optimization results.

The value is a Numpy array containing the number of states visited
The value is a NamedTuple containing the number of states visited
(dequeued), the number of next-states generated, the number of
next-states that are enqueued after cost pruning, and the number
of backjumps performed. Return None if no search is performed.
Expand Down
14 changes: 7 additions & 7 deletions docs/circuit_cutting/tutorials/04_automatic_cut_finding.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
upgrade:
- |
The search engine inside the automated cut-finder has been primed to avoid extraneous searches and is therefore expected to run faster.
83 changes: 82 additions & 1 deletion test/cutting/cut_finding/test_best_first_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,31 @@
CircuitElement,
GateSpec,
)
from circuit_knitting.cutting.cut_finding.cut_optimization import CutOptimization
from circuit_knitting.cutting.cut_finding.cut_optimization import (
cut_optimization_next_state_func,
cut_optimization_min_cost_bound_func,
cut_optimization_cost_func,
cut_optimization_goal_state_func,
cut_optimization_upper_bound_cost_func,
CutOptimizationFuncArgs,
CutOptimization,
)
from circuit_knitting.cutting.cut_finding.optimization_settings import (
OptimizationSettings,
)
from circuit_knitting.cutting.automated_cut_finding import DeviceConstraints
from circuit_knitting.cutting.cut_finding.disjoint_subcircuits_state import (
get_actions_list,
)
from circuit_knitting.cutting.cut_finding.cutting_actions import (
disjoint_subcircuit_actions,
DisjointSubcircuitsState,
)

from circuit_knitting.cutting.cut_finding.best_first_search import (
BestFirstSearch,
SearchFunctions,
)


@fixture
Expand Down Expand Up @@ -124,3 +141,67 @@ def test_best_first_search(test_circuit: SimpleGateList):
assert op.get_upperbound_cost() == (27, inf)
assert op.minimum_reached() is True
assert out is None


def test_best_first_search_termination():
"""Test that if the best first search is run multiple times, it terminates once no further feasible cut states can be found,
in which case None is returned for both the cost and the state. This test also serves to describe the workflow of the optimizer
at a granular level."""

# Specify circuit
circuit = [
CircuitElement(name="cx", params=[], qubits=[0, 1], gamma=3),
CircuitElement(name="cx", params=[], qubits=[2, 3], gamma=3),
CircuitElement(name="cx", params=[], qubits=[1, 2], gamma=3),
]

interface = SimpleGateList(circuit)

# Specify optimization settings, search engine, and device constraints.
settings = OptimizationSettings(seed=123)
settings.set_engine_selection("CutOptimization", "BestFirst")

constraints = DeviceConstraints(qubits_per_subcircuit=3)

# Initialize and pass arguments to search space generating object.
func_args = CutOptimizationFuncArgs()
func_args.entangling_gates = interface.get_multiqubit_gates()
func_args.search_actions = disjoint_subcircuit_actions
func_args.max_gamma = settings.get_max_gamma
func_args.qpu_width = constraints.get_qpu_width()

# Initialize search functions object, needed to explore a search space.
cut_optimization_search_funcs = SearchFunctions(
cost_func=cut_optimization_cost_func,
upperbound_cost_func=cut_optimization_upper_bound_cost_func,
next_state_func=cut_optimization_next_state_func,
goal_state_func=cut_optimization_goal_state_func,
mincost_bound_func=cut_optimization_min_cost_bound_func,
)

# Initialize disjoint subcircuits state object
# while specifying number of qubits and max allowed wire cuts.
state = DisjointSubcircuitsState(interface.get_num_qubits(), 2)

# Initialize bfs object.
bfs = BestFirstSearch(
optimization_settings=settings, search_functions=cut_optimization_search_funcs
)

# Push an input state.
bfs.initialize([state], func_args)

counter = 0

cut_state = state
while cut_state is not None:
cut_state, cut_cost = bfs.optimization_pass(func_args)
counter += 1

# There are 5 possible cut states that can be found for this circuit,
# given that there need to be 3 qubits per subcircuit. These correspond
# to 3 gate cuts (i.e cutting any of the 3 gates) and cutting either of
# the input wires to the CNOT between qubits 1 and 2.
# After these 5 possible cuts are returned, at the 6th iteration, None
# is returned for both the state and the cost.
assert counter == 6 and cut_cost is None
6 changes: 3 additions & 3 deletions test/cutting/cut_finding/test_cut_finder_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from __future__ import annotations

import numpy as np
from numpy import array
from pytest import fixture, raises
from qiskit import QuantumCircuit
from typing import Callable
Expand Down Expand Up @@ -190,8 +189,9 @@ def test_four_qubit_circuit_two_qubit_qpu(
) # circuit separated into 2 subcircuits.

assert (
optimization_pass.get_stats()["CutOptimization"] == array([15, 46, 15, 6])
).all() # matches known stats.
optimization_pass.get_stats()["CutOptimization"].backjumps
<= settings.max_backjumps
)


def test_seven_qubit_circuit_two_qubit_qpu(
Expand Down