Skip to content

Commit

Permalink
Implement cutting of general 2-qubit unitaries
Browse files Browse the repository at this point in the history
Builds on #294.  Closes #186.
  • Loading branch information
garrison committed Jul 2, 2023
1 parent 69a1e69 commit 727d159
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 4 deletions.
39 changes: 38 additions & 1 deletion circuit_knitting/cutting/qpd/qpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
iSwapGate,
DCXGate,
)
from qiskit.extensions import UnitaryGate
from qiskit.quantum_info.synthesis.two_qubit_decompose import TwoQubitWeylDecomposition

from .qpd_basis import QPDBasis
from .instructions import BaseQPDGate, TwoQubitQPDGate, QPDMeasure
Expand Down Expand Up @@ -225,10 +227,25 @@ def qpdbasis_from_gate(gate: Gate) -> QPDBasis:
try:
f = _qpdbasis_from_gate_funcs[gate.name]
except KeyError:
raise ValueError(f"Gate not supported: {gate.name}") from None
pass
else:
return f(gate)

if isinstance(gate, Gate) and gate.num_qubits == 2:
mat = gate.to_matrix()
d = TwoQubitWeylDecomposition(mat)
u = _u_from_thetavec([d.a, d.b, d.c])
retval = _nonlocal_qpd_basis_from_u(u)
for operations in unique_by_id(m[0] for m in retval.maps):
operations.insert(0, UnitaryGate(d.K2r))
operations.append(UnitaryGate(d.K1r))
for operations in unique_by_id(m[1] for m in retval.maps):
operations.insert(0, UnitaryGate(d.K2l))
operations.append(UnitaryGate(d.K1l))
return retval

raise ValueError(f"Gate not supported: {gate.name}") from None


def supported_gates() -> set[str]:
"""
Expand Down Expand Up @@ -256,6 +273,26 @@ def _copy_unique_sublists(lsts: tuple[list, ...], /) -> tuple[list, ...]:
return tuple(copy_by_id[id(lst)] for lst in lsts)


def _u_from_thetavec(
theta: np.typing.NDArray[np.float64] | Sequence[float], /
) -> np.ndarray[np.float64]:
theta = np.asarray(theta)
if theta.shape != (3,):
raise ValueError(
f"theta vector has wrong shape: {theta.shape} (1D vector of length 3 expected)"
)
eigvals = np.array(
[
-np.sum(theta),
-theta[0] + theta[1] + theta[2],
-theta[1] + theta[2] + theta[0],
-theta[2] + theta[0] + theta[1],
]
)
eigvecs = np.ones([1, 1]) / 2 - np.eye(4)
return np.transpose(eigvecs) @ (np.exp(1j * eigvals) * eigvecs[:, 0])


def _nonlocal_qpd_basis_from_u(
u: np.typing.NDArray[np.complex128] | Sequence[complex], /
) -> QPDBasis:
Expand Down
22 changes: 21 additions & 1 deletion test/cutting/qpd/test_qpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
generate_qpd_samples,
)
from circuit_knitting.cutting.qpd.qpd import *
from circuit_knitting.cutting.qpd.qpd import _nonlocal_qpd_basis_from_u
from circuit_knitting.cutting.qpd.qpd import (
_nonlocal_qpd_basis_from_u,
_u_from_thetavec,
)


@ddt
Expand Down Expand Up @@ -318,3 +321,20 @@ def test_nonlocal_qpd_basis_from_u(self):
e_info.value.args[0]
== "u vector has wrong shape: (3,) (1D vector of length 4 expected)"
)

@data(
([np.pi / 4] * 3, [(1 + 1j) / np.sqrt(8)] * 4),
([np.pi / 4, np.pi / 4, 0], [0.5, 0.5j, 0.5j, 0.5]),
)
@unpack
def test_u_from_thetavec(self, theta, expected):
assert _u_from_thetavec(theta) == pytest.approx(expected)

def test_u_from_thetavec_exceptions(self):
with self.subTest("Invalid shape"):
with pytest.raises(ValueError) as e_info:
_u_from_thetavec([0, 1, 2, 3])
assert (
e_info.value.args[0]
== "theta vector has wrong shape: (4,) (1D vector of length 3 expected)"
)
4 changes: 2 additions & 2 deletions test/cutting/qpd/test_qpd_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def test_eq(self):

def test_unsupported_gate(self):
with pytest.raises(ValueError) as e_info:
QPDBasis.from_gate(XXMinusYYGate(0.1))
assert e_info.value.args[0] == "Gate not supported: xx_minus_yy"
QPDBasis.from_gate(C3XGate())
assert e_info.value.args[0] == "Gate not supported: mcx"

def test_unbound_parameter(self):
with pytest.raises(ValueError) as e_info:
Expand Down
1 change: 1 addition & 0 deletions test/cutting/test_cutting_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def append_random_unitary(circuit: QuantumCircuit, qubits):
[CRYGate(np.pi / 7)],
[CRZGate(np.pi / 11)],
[RXXGate(np.pi / 3), CRYGate(np.pi / 7)],
[UnitaryGate(random_unitary(2**2))],
]
)
def example_circuit(
Expand Down

0 comments on commit 727d159

Please sign in to comment.