Skip to content

Commit

Permalink
Make stim.Tableau.from_stabilizers faster (#713)
Browse files Browse the repository at this point in the history
- Store the growing reduction as a circuit instead of as a tableau
- Measured 20x faster (140ms -> 6ms) on a 144 qubit case
- Measured 100x faster (15s -> 0.13s) on a 432 qubit case
  • Loading branch information
Strilanc authored Mar 14, 2024
1 parent 4040fd8 commit 4f1d217
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 41 deletions.
4 changes: 3 additions & 1 deletion src/stim/stabilizers/conversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@ Circuit stabilizer_state_vector_to_circuit(
/// ignore_noise: If the circuit contains noise channels, ignore them instead of raising an exception.
/// ignore_measurement: If the circuit contains measurements, ignore them instead of raising an exception.
/// ignore_reset: If the circuit contains resets, ignore them instead of raising an exception.
/// inverse: The last step of the implementation is to invert the tableau. Setting this argument
/// to true will skip this inversion, saving time but returning the inverse tableau.
///
/// Returns:
/// A tableau encoding the given circuit's Clifford operation.
template <size_t W>
Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset);
Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset, bool inverse = false);

/// Simulates the given circuit and outputs a state vector.
///
Expand Down
89 changes: 49 additions & 40 deletions src/stim/stabilizers/conversions.inl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ std::vector<std::vector<std::complex<float>>> tableau_to_unitary(const Tableau<W
}

template <size_t W>
Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset) {
Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset, bool inverse) {
Tableau<W> result(circuit.count_qubits());
TableauSimulator<W> sim(std::mt19937_64(0), circuit.count_qubits());

Expand Down Expand Up @@ -185,7 +185,10 @@ Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ig
}
});

return sim.inv_state.inverse();
if (!inverse) {
return sim.inv_state.inverse();
}
return sim.inv_state;
}

template <size_t W>
Expand Down Expand Up @@ -556,7 +559,7 @@ Tableau<W> stabilizers_to_tableau(
}

for (size_t k1 = 0; k1 < stabilizers.size(); k1++) {
for (size_t k2 = 0; k2 < stabilizers.size(); k2++) {
for (size_t k2 = k1 + 1; k2 < stabilizers.size(); k2++) {
if (!stabilizers[k1].ref().commutes(stabilizers[k2])) {
std::stringstream ss;
ss << "Some of the given stabilizers anticommute.\n";
Expand All @@ -568,44 +571,39 @@ Tableau<W> stabilizers_to_tableau(
}
}
}
Tableau<W> inverted(num_qubits);

PauliString<W> cur(num_qubits);
std::vector<size_t> targets;
while (targets.size() < num_qubits) {
targets.push_back(targets.size());
}
auto overwrite_cur_apply_recorded = [&](const PauliString<W> &e) {
PauliStringRef<W> cur_ref = cur.ref();
cur.xs.clear();
cur.zs.clear();
cur.xs.word_range_ref(0, e.xs.num_simd_words) = e.xs;
cur.zs.word_range_ref(0, e.xs.num_simd_words) = e.zs;
cur.sign = e.sign;
inverted.apply_within(cur_ref, targets);
};
Circuit elimination_instructions;
PauliString<W> buf(num_qubits);

size_t used = 0;
for (const auto &e : stabilizers) {
overwrite_cur_apply_recorded(e);
if (e.num_qubits == num_qubits) {
buf = e;
} else {
buf.xs.clear();
buf.zs.clear();
memcpy(buf.xs.u8, e.xs.u8, e.xs.num_u8_padded());
memcpy(buf.zs.u8, e.zs.u8, e.zs.num_u8_padded());
buf.sign = e.sign;
}
buf.ref().do_circuit(elimination_instructions);

// Find a non-identity term in the Pauli string past the region used by other stabilizers.
size_t pivot;
for (pivot = used; pivot < num_qubits; pivot++) {
if (cur.xs[pivot] || cur.zs[pivot]) {
if (buf.xs[pivot] || buf.zs[pivot]) {
break;
}
}

// Check for incompatible / redundant stabilizers.
if (pivot == num_qubits) {
if (cur.xs.not_zero()) {
if (buf.xs.not_zero()) {
throw std::invalid_argument("Some of the given stabilizers anticommute.");
}
if (cur.sign) {
if (buf.sign) {
throw std::invalid_argument("Some of the given stabilizers contradict each other.");
}
if (!allow_redundant && cur.zs.not_zero()) {
if (!allow_redundant && buf.zs.not_zero()) {
throw std::invalid_argument(
"Didn't specify allow_redundant=True but one of the given stabilizers is a product of the others. "
"To allow redundant stabilizers, pass the argument allow_redundant=True.");
Expand All @@ -614,32 +612,36 @@ Tableau<W> stabilizers_to_tableau(
}

// Change pivot basis to the Z axis.
if (cur.xs[pivot]) {
std::string name = cur.zs[pivot] ? "H_YZ" : "H_XZ";
inverted.inplace_scatter_append(GATE_DATA.at(name).tableau<W>(), {pivot});
if (buf.xs[pivot]) {
GateType g = buf.zs[pivot] ? GateType::H_YZ : GateType::H;
GateTarget t = GateTarget::qubit(pivot);
CircuitInstruction instruction{g, {}, &t};
elimination_instructions.safe_append(instruction);
buf.ref().do_instruction(instruction);
}
// Cancel other terms in Pauli string.
for (size_t q = 0; q < num_qubits; q++) {
int p = cur.xs[q] + cur.zs[q] * 2;
int p = buf.xs[q] + buf.zs[q] * 2;
if (p && q != pivot) {
inverted.inplace_scatter_append(
GATE_DATA.at(p == 1 ? "XCX"
: p == 2 ? "XCZ"
: "XCY")
.tableau<W>(),
{pivot, q});
std::array<GateTarget, 2> targets{GateTarget::qubit(pivot), GateTarget::qubit(q)};
CircuitInstruction instruction{p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY, {}, targets};
elimination_instructions.safe_append(instruction);
buf.ref().do_instruction(instruction);
}
}

// Move pivot to diagonal.
if (pivot != used) {
inverted.inplace_scatter_append(GATE_DATA.at("SWAP").tableau<W>(), {pivot, used});
std::array<GateTarget, 2> targets{GateTarget::qubit(pivot), GateTarget::qubit(used)};
CircuitInstruction instruction{GateType::SWAP, {}, targets};
elimination_instructions.safe_append(instruction);
}

// Fix sign.
overwrite_cur_apply_recorded(e);
if (cur.sign) {
inverted.inplace_scatter_append(GATE_DATA.at("X").tableau<W>(), {used});
if (buf.sign) {
GateTarget t = GateTarget::qubit(used);
CircuitInstruction instruction{GateType::X, {}, &t};
elimination_instructions.safe_append(instruction);
}

used++;
Expand All @@ -653,10 +655,17 @@ Tableau<W> stabilizers_to_tableau(
}
}

if (num_qubits > 0) {
// Force size of resulting tableau to be correct.
GateTarget t = GateTarget::qubit(num_qubits - 1);
elimination_instructions.safe_append(CircuitInstruction{GateType::X, {}, &t});
elimination_instructions.safe_append(CircuitInstruction{GateType::X, {}, &t});
}

if (invert) {
return inverted;
return circuit_to_tableau<W>(elimination_instructions.inverse(), false, false, false, true);
}
return inverted.inverse();
return circuit_to_tableau<W>(elimination_instructions, false, false, false, true);
}

} // namespace stim
48 changes: 48 additions & 0 deletions src/stim/stabilizers/conversions.perf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,51 @@ BENCHMARK(independent_to_disjoint_xyz_errors) {
std::cout << "data dependence";
}
}

BENCHMARK(stabilizers_to_tableau) {
std::vector<std::complex<float>> offsets{
{1, 0},
{-1, 0},
{0, 1},
{0, -1},
{3, 6},
{-6, 3},
};
size_t w = 24;
size_t h = 12;

auto normalize = [&](std::complex<float> c) -> std::complex<float> {
return {fmodf(c.real() + w*10, w), fmodf(c.imag() + h*10, h)};
};
auto q2i = [&](std::complex<float> c) -> size_t {
c = normalize(c);
return (int)c.real() / 2 + c.imag() * (w / 2);
};

std::vector<stim::PauliString<64>> stabilizers;
for (size_t x = 0; x < w; x++) {
for (size_t y = x % 2; y < h; y += 2) {
std::complex<float> s{x % 2 ? -1.0f : +1.0f, 0.0f};
std::complex<float> c{(float)x, (float)y};
stim::PauliString<64> ps(w * h / 2);
for (const auto &offset : offsets) {
size_t i = q2i(c + offset * s);
if (x % 2 == 0) {
ps.xs[i] = 1;
} else {
ps.zs[i] = 1;
}
}
stabilizers.push_back(ps);
}
}

size_t dep = 0;
benchmark_go([&]() {
Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false);
dep += t.xs[0].zs[0];
}).goal_millis(5);
if (dep == 99999999) {
std::cout << "data dependence";
}
}

0 comments on commit 4f1d217

Please sign in to comment.