add cluster_node_limit for MWPF decoder to better tune decoding time and accuracy (#857)

This is a new parameter that is agnostic to an individual machine's clock
speed. We limit the maximum number of dual variables inside each cluster
to avoid wasting computing time on small yet complicated clusters. A
default value of 50 is good enough for small codes and substantially
improves decoding speed. Intuitively (but not precisely), this means we
cap the number of dual variables (blossoms and their children) in an
alternating tree at 50; once a cluster hits this limit, the decoder falls
back to union-find behavior. In hypergraph cases the situation is a bit
more complicated, and the limit is hit more often.
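
For orientation, here is a minimal usage sketch (not part of this commit) showing how the new knob flows into the helper changed below. It assumes `mwpf~=0.1.5` is installed; the import path is sinter-internal and may move:

```python
import stim
from sinter._decoding._decoding_mwpf import (
    detector_error_model_to_mwpf_solver_and_fault_masks,
)

# Build a small detector error model to decode.
dem = stim.Circuit.generated(
    "surface_code:rotated_memory_x",
    distance=3,
    rounds=3,
    after_clifford_depolarization=1e-3,
).detector_error_model()

# Smaller limit: clusters hit the cap sooner and fall back to union-find
# style solutions (faster). Larger limit: more time is spent optimizing
# complicated clusters (more accurate).
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
    dem, cluster_node_limit=50
)
```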
yuewuo authored Dec 4, 2024
1 parent e6fd563 commit a3e080c
Showing 2 changed files with 25 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -431,7 +431,7 @@ jobs:
     - run: bazel build :stim_dev_wheel
     - run: pip install bazel-bin/stim-0.0.dev0-py3-none-any.whl
     - run: pip install -e glue/sample
-    - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.1
+    - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.5
     - run: pytest glue/sample
     - run: dev/doctest_proper.py --module sinter
     - run: sinter help
37 changes: 24 additions & 13 deletions glue/sample/src/sinter/_decoding/_decoding_mwpf.py
@@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError:
     return ImportError(
         "The decoder 'MWPF' isn't installed\n"
         "To fix this, install the python package 'MWPF' into your environment.\n"
-        "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n"
+        "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n"
     )


@@ -75,12 +75,18 @@ def compile_decoder_for_dem(
     # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
     # grows the clusters until the first valid solution appears; some more optimized solvers use
     # one or more plugins to further optimize the solution, which requires longer decoding time.
+    cluster_node_limit: int = 50,  # The maximum number of nodes in a cluster.
 ) -> CompiledDecoder:
     solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
-        dem, decoder_cls=decoder_cls
+        dem,
+        decoder_cls=decoder_cls,
+        cluster_node_limit=cluster_node_limit,
     )
     return MwpfCompiledDecoder(
-        solver, fault_masks, dem.num_detectors, dem.num_observables
+        solver,
+        fault_masks,
+        dem.num_detectors,
+        dem.num_observables,
     )
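
As an aside, a hedged sketch of trading accuracy for speed through `decoder_cls` as well (the solver class names are taken from the comment and code in this diff; exact availability depends on the installed mwpf version):

```python
import mwpf
import stim
from sinter._decoding._decoding_mwpf import (
    detector_error_model_to_mwpf_solver_and_fault_masks,
)

dem = stim.Circuit.generated(
    "repetition_code:memory",
    distance=5,
    rounds=5,
    after_clifford_depolarization=1e-3,
).detector_error_model()

# Most basic solver: grows clusters until the first valid solution appears.
uf_solver, masks = detector_error_model_to_mwpf_solver_and_fault_masks(
    dem, decoder_cls=mwpf.SolverSerialUnionFind
)

# Default: a plugin-based solver that spends longer improving the solution.
jsh_solver, masks = detector_error_model_to_mwpf_solver_and_fault_masks(
    dem, decoder_cls=mwpf.SolverSerialJointSingleHair
)
```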

def decode_via_files(
@@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
 def deduplicate_hyperedges(
     hyperedges: List[Tuple[List[int], float, int]]
 ) -> List[Tuple[List[int], float, int]]:
-    indices: dict[frozenset[int], int] = dict()
+    indices: dict[frozenset[int], Tuple[int, float]] = dict()
     result: List[Tuple[List[int], float, int]] = []
     for dets, weight, mask in hyperedges:
         dets_set = frozenset(dets)
         if dets_set in indices:
-            idx = indices[dets_set]
+            idx, min_weight = indices[dets_set]
             p1 = 1 / (1 + math.exp(weight))
             p2 = 1 / (1 + math.exp(result[idx][1]))
             p = p1 * (1 - p2) + p2 * (1 - p1)
             # not sure why this would fail? two hyperedges with different masks?
             # assert mask == result[idx][2], (result[idx], (dets, weight, mask))
-            result[idx] = (dets, math.log((1 - p) / p), result[idx][2])
+            # choosing the mask from the most likely error
+            new_mask = result[idx][2]
+            if weight < min_weight:
+                indices[dets_set] = (idx, weight)
+                new_mask = mask
+            result[idx] = (dets, math.log((1 - p) / p), new_mask)
         else:
-            indices[dets_set] = len(result)
+            indices[dets_set] = (len(result), weight)
             result.append((dets, weight, mask))
     return result
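
To make the merge rule above concrete, here is a standalone check (illustrative, not part of the commit) of how two parallel hyperedges combine: each weight converts back to a probability, the probabilities combine as the chance that exactly one mechanism fires, and the result converts back to a weight:

```python
import math

def weight_of(p: float) -> float:
    # Same convention as the rest of the file: weight = log((1 - p) / p).
    return math.log((1 - p) / p)

def prob_of(w: float) -> float:
    # Inverse map, as used in deduplicate_hyperedges: p = 1 / (1 + e^w).
    return 1 / (1 + math.exp(w))

w1, w2 = weight_of(0.01), weight_of(0.02)
p1, p2 = prob_of(w1), prob_of(w2)    # recovers 0.01 and 0.02
p = p1 * (1 - p2) + p2 * (1 - p1)    # 0.01*0.98 + 0.02*0.99 = 0.0296
print(weight_of(p))                  # merged weight, about 3.49
```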


 def detector_error_model_to_mwpf_solver_and_fault_masks(
-    model: stim.DetectorErrorModel, decoder_cls: Any = None
+    model: stim.DetectorErrorModel,
+    decoder_cls: Any = None,
+    cluster_node_limit: int = 50,
 ) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]:
     """Convert a stim error model into a NetworkX graph."""

@@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]):
             # Accept it and keep going, though of course decoding will probably perform terribly.
             return
         if p > 0.5:
-            # mwpf doesn't support negative edge weights.
+            # mwpf doesn't support negative edge weights (yet, will be supported in the next version).
             # approximate them as weight 0.
             p = 0.5
         weight = math.log((1 - p) / p)
@@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
     # the mwpf package panics on duplicate edges, thus we need to handle them here
     hyperedges = deduplicate_hyperedges(hyperedges)
 
-    # fix the input by connecting an edge to all isolated vertices
+    # fix the input by connecting an edge to all isolated vertices; will be supported in the next version
     for idx in range(num_detectors):
         if not is_detector_connected[idx]:
             hyperedges.append(([idx], 0, 0))
@@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
         decoder_cls = mwpf.SolverSerialJointSingleHair
     return (
         (
-            decoder_cls(initializer)
+            decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
             if num_detectors > 0 and len(rescaled_edges) > 0
             else None
         ),