From a3e080cb723394c85f17fd51c300feb91e61241e Mon Sep 17 00:00:00 2001
From: Yue Wu
Date: Wed, 4 Dec 2024 03:39:08 -0500
Subject: [PATCH] add `cluster_node_limit` for MWPF decoder to better tune
 decoding time and accuracy (#857)

This is a new parameter that is agnostic to an individual machine's clock
speed. We limit the maximum number of dual variables inside each cluster to
avoid wasting computing time on small yet complicated clusters. A default
value of 50 is good enough for small codes and greatly improves decoding
speed. Intuitively (but not precisely), this means the number of dual
variables (blossoms and their children) in an alternating tree is capped at
50; once a cluster hits the limit, the decoder falls back to the union-find
decoder. In hypergraph cases the situation is a bit more complicated, and the
limit is hit more often. A usage sketch and a worked example of the
edge-merging formula follow the diff.
---
 .github/workflows/ci.yml                   |  2 +-
 .../src/sinter/_decoding/_decoding_mwpf.py | 37 ++++++++++++-------
 2 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index aa535b4f..6173fafb 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -431,7 +431,7 @@ jobs:
       - run: bazel build :stim_dev_wheel
       - run: pip install bazel-bin/stim-0.0.dev0-py3-none-any.whl
       - run: pip install -e glue/sample
-      - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.1
+      - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.5
       - run: pytest glue/sample
       - run: dev/doctest_proper.py --module sinter
       - run: sinter help
diff --git a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
index 461cbc0f..2b69c608 100644
--- a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
+++ b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py
@@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError:
     return ImportError(
         "The decoder 'MWPF' isn't installed\n"
         "To fix this, install the python package 'MWPF' into your environment.\n"
-        "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n"
+        "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n"
     )


@@ -75,12 +75,18 @@ def compile_decoder_for_dem(
         # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
         # grows the clusters until the first valid solution appears; some more optimized solvers uses
         # one or more plugins to further optimize the solution, which requires longer decoding time.
+        cluster_node_limit: int = 50,  # The maximum number of nodes in a cluster.
     ) -> CompiledDecoder:
         solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
-            dem, decoder_cls=decoder_cls
+            dem,
+            decoder_cls=decoder_cls,
+            cluster_node_limit=cluster_node_limit,
         )
         return MwpfCompiledDecoder(
-            solver, fault_masks, dem.num_detectors, dem.num_observables
+            solver,
+            fault_masks,
+            dem.num_detectors,
+            dem.num_observables,
         )

     def decode_via_files(
@@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
 def deduplicate_hyperedges(
     hyperedges: List[Tuple[List[int], float, int]]
 ) -> List[Tuple[List[int], float, int]]:
-    indices: dict[frozenset[int], int] = dict()
+    indices: dict[frozenset[int], Tuple[int, float]] = dict()
     result: List[Tuple[List[int], float, int]] = []
     for dets, weight, mask in hyperedges:
         dets_set = frozenset(dets)
         if dets_set in indices:
-            idx = indices[dets_set]
+            idx, min_weight = indices[dets_set]
             p1 = 1 / (1 + math.exp(weight))
             p2 = 1 / (1 + math.exp(result[idx][1]))
             p = p1 * (1 - p2) + p2 * (1 - p1)
-            # not sure why would this fail? two hyperedges with different masks?
-            # assert mask == result[idx][2], (result[idx], (dets, weight, mask))
-            result[idx] = (dets, math.log((1 - p) / p), result[idx][2])
+            # choose the mask from the most likely error
+            new_mask = result[idx][2]
+            if weight < min_weight:
+                indices[dets_set] = (idx, weight)
+                new_mask = mask
+            result[idx] = (dets, math.log((1 - p) / p), new_mask)
         else:
-            indices[dets_set] = len(result)
+            indices[dets_set] = (len(result), weight)
             result.append((dets, weight, mask))
     return result


 def detector_error_model_to_mwpf_solver_and_fault_masks(
-    model: stim.DetectorErrorModel, decoder_cls: Any = None
+    model: stim.DetectorErrorModel,
+    decoder_cls: Any = None,
+    cluster_node_limit: int = 50,
 ) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]:
     """Convert a stim error model into a NetworkX graph."""

@@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]):
             # Accept it and keep going, though of course decoding will probably perform terribly.
             return
         if p > 0.5:
-            # mwpf doesn't support negative edge weights.
+            # mwpf doesn't support negative edge weights yet; they will be supported in the next version.
             # approximate them as weight 0.
             p = 0.5
         weight = math.log((1 - p) / p)
@@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
     # mwpf package panic on duplicate edges, thus we need to handle them here
     hyperedges = deduplicate_hyperedges(hyperedges)

-    # fix the input by connecting an edge to all isolated vertices
+    # fix the input by connecting an edge to all isolated vertices; isolated vertices will be supported in the next mwpf version
     for idx in range(num_detectors):
         if not is_detector_connected[idx]:
             hyperedges.append(([idx], 0, 0))
@@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
         decoder_cls = mwpf.SolverSerialJointSingleHair
     return (
         (
-            decoder_cls(initializer)
+            decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
             if num_detectors > 0 and len(rescaled_edges) > 0
             else None
         ),
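
Usage sketch (untested; assumes the internal sinter helper keeps the signature
shown in this patch, and uses an illustrative circuit, not one from the
change):

    import stim
    from sinter._decoding._decoding_mwpf import (
        detector_error_model_to_mwpf_solver_and_fault_masks,
    )

    # Build a small detector error model to decode.
    circuit = stim.Circuit.generated(
        "surface_code:rotated_memory_x",
        distance=3,
        rounds=3,
        after_clifford_depolarization=1e-3,
    )
    dem = circuit.detector_error_model()

    # Lower limits fall back to union-find-style decoding sooner (faster but
    # less accurate); higher limits spend more time optimizing each cluster.
    solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
        dem,
        cluster_node_limit=50,
    )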
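Edge-merging note: `deduplicate_hyperedges` combines two error mechanisms that
flip the same detector set with the exclusive-or rule
p = p1*(1-p2) + p2*(1-p1), then converts back to a weight ln((1-p)/p). A
self-contained sketch of that arithmetic, with illustrative probabilities:

    import math

    def merged_weight(w1: float, w2: float) -> float:
        # Recover probabilities from weights w = ln((1 - p) / p).
        p1 = 1 / (1 + math.exp(w1))
        p2 = 1 / (1 + math.exp(w2))
        # The merged mechanism fires when exactly one of the two fires;
        # both firing together cancels out on the detectors.
        p = p1 * (1 - p2) + p2 * (1 - p1)
        return math.log((1 - p) / p)

    # Merging 1% and 2% edges gives p = 0.0296, so the merged edge is more
    # likely (lower weight) than either original edge.
    print(merged_weight(math.log(0.99 / 0.01), math.log(0.98 / 0.02)))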