Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cluster_node_limit for MWPF decoder to better tune decoding time and accuracy #857

Merged
merged 10 commits into from
Dec 4, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ jobs:
- run: bazel build :stim_dev_wheel
- run: pip install bazel-bin/stim-0.0.dev0-py3-none-any.whl
- run: pip install -e glue/sample
- run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.1
- run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.5
- run: pytest glue/sample
- run: dev/doctest_proper.py --module sinter
- run: sinter help
Expand Down
37 changes: 24 additions & 13 deletions glue/sample/src/sinter/_decoding/_decoding_mwpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError:
return ImportError(
"The decoder 'MWPF' isn't installed\n"
"To fix this, install the python package 'MWPF' into your environment.\n"
"For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n"
"For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n"
)


Expand Down Expand Up @@ -75,12 +75,18 @@ def compile_decoder_for_dem(
# For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
# grows the clusters until the first valid solution appears; some more optimized solvers uses
# one or more plugins to further optimize the solution, which requires longer decoding time.
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster.
) -> CompiledDecoder:
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
dem, decoder_cls=decoder_cls
dem,
decoder_cls=decoder_cls,
cluster_node_limit=cluster_node_limit,
)
return MwpfCompiledDecoder(
solver, fault_masks, dem.num_detectors, dem.num_observables
solver,
fault_masks,
dem.num_detectors,
dem.num_observables,
)

def decode_via_files(
Expand Down Expand Up @@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
def deduplicate_hyperedges(
hyperedges: List[Tuple[List[int], float, int]]
) -> List[Tuple[List[int], float, int]]:
indices: dict[frozenset[int], int] = dict()
indices: dict[frozenset[int], Tuple[int, float]] = dict()
result: List[Tuple[List[int], float, int]] = []
for dets, weight, mask in hyperedges:
dets_set = frozenset(dets)
if dets_set in indices:
idx = indices[dets_set]
idx, min_weight = indices[dets_set]
p1 = 1 / (1 + math.exp(weight))
p2 = 1 / (1 + math.exp(result[idx][1]))
p = p1 * (1 - p2) + p2 * (1 - p1)
# not sure why would this fail? two hyperedges with different masks?
# assert mask == result[idx][2], (result[idx], (dets, weight, mask))
result[idx] = (dets, math.log((1 - p) / p), result[idx][2])
# choosing the mask from the most likely error
new_mask = result[idx][2]
if weight < min_weight:
indices[dets_set] = (idx, weight)
new_mask = mask
result[idx] = (dets, math.log((1 - p) / p), new_mask)
else:
indices[dets_set] = len(result)
indices[dets_set] = (len(result), weight)
result.append((dets, weight, mask))
return result


def detector_error_model_to_mwpf_solver_and_fault_masks(
model: stim.DetectorErrorModel, decoder_cls: Any = None
model: stim.DetectorErrorModel,
decoder_cls: Any = None,
cluster_node_limit: int = 50,
) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]:
"""Convert a stim error model into a NetworkX graph."""

Expand All @@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]):
# Accept it and keep going, though of course decoding will probably perform terribly.
return
if p > 0.5:
# mwpf doesn't support negative edge weights.
# mwpf doesn't support negative edge weights (yet, will be supported in the next version).
# approximate them as weight 0.
p = 0.5
weight = math.log((1 - p) / p)
Expand All @@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
# mwpf package panic on duplicate edges, thus we need to handle them here
hyperedges = deduplicate_hyperedges(hyperedges)

# fix the input by connecting an edge to all isolated vertices
# fix the input by connecting an edge to all isolated vertices; will be supported in the next version
for idx in range(num_detectors):
if not is_detector_connected[idx]:
hyperedges.append(([idx], 0, 0))
Expand All @@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
decoder_cls = mwpf.SolverSerialJointSingleHair
return (
(
decoder_cls(initializer)
decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
if num_detectors > 0 and len(rescaled_edges) > 0
else None
),
Expand Down
Loading