Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor QPE bloqs #1297

Merged
merged 9 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qualtran/bloqs/phase_estimation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
# limitations under the License.

from qualtran.bloqs.phase_estimation.lp_resource_state import LPResourceState
from qualtran.bloqs.phase_estimation.qpe_window_state import RectangularWindowState
from qualtran.bloqs.phase_estimation.qubitization_qpe import QubitizationQPE
from qualtran.bloqs.phase_estimation.text_book_qpe import TextbookQPE
143 changes: 80 additions & 63 deletions qualtran/bloqs/phase_estimation/lp_resource_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,23 @@
# limitations under the License.

"""Resource states proposed by A. Luis and J. Peřina (1996) for optimal phase measurements"""
from collections import Counter
from functools import cached_property
from typing import Iterator, Set, Tuple, TYPE_CHECKING, Union
from typing import Dict, Set, TYPE_CHECKING

import attrs
import cirq
import numpy as np
import sympy
from numpy.typing import NDArray

from qualtran import (
Bloq,
bloq_example,
BloqDocSpec,
GateWithRegisters,
QUInt,
Register,
Side,
Signature,
)
from qualtran.bloqs.basic_gates import CZPowGate, GlobalPhase, Hadamard, OnEach, Ry, Rz, XGate
from qualtran.bloqs.mcmt import MultiControlZ

from qualtran import Bloq, bloq_example, BloqDocSpec, GateWithRegisters, QBit, Signature
from qualtran.bloqs.basic_gates import CZ, Hadamard, OnEach, Ry, Rz, XGate
from qualtran.bloqs.phase_estimation.qpe_window_state import QPEWindowStateBase
from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.symbolics import acos, HasLength, is_symbolic, pi, SymbolicInt
from qualtran.symbolics import acos, ceil, is_symbolic, log2, pi, SymbolicFloat, SymbolicInt

if TYPE_CHECKING:
from qualtran import BloqBuilder, Soquet, SoquetT
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator


Expand Down Expand Up @@ -67,32 +59,33 @@ def signature(self) -> 'Signature':
def pretty_name(self) -> str:
return 'LPRS'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> Iterator[cirq.OP_TREE]:
def build_composite_bloq(
self, bb: 'BloqBuilder', *, m: 'SoquetT', anc: 'Soquet'
) -> Dict[str, 'SoquetT']:
if isinstance(self.bitsize, sympy.Expr):
raise ValueError(f'Symbolic bitsize {self.bitsize} not supported')
q, anc = quregs['m'].tolist()[::-1], quregs['anc']
yield [OnEach(self.bitsize, Hadamard()).on(*q), Hadamard().on(*anc)]
m = bb.add(OnEach(self.bitsize, Hadamard()), q=m)
q = bb.split(m)[::-1]
anc = bb.add(Hadamard(), q=anc)
for i in range(self.bitsize):
rz_angle = -2 * np.pi * (2**i) / (2**self.bitsize + 1)
yield Rz(angle=rz_angle).controlled().on(q[i], *anc)
yield Rz(angle=-2 * np.pi / (2**self.bitsize + 1)).on(*anc)
yield Hadamard().on(*anc)
q[i], anc = bb.add(Rz(angle=rz_angle).controlled(), ctrl=q[i], q=anc)
anc = bb.add(Rz(angle=-2 * np.pi / (2**self.bitsize + 1)), q=anc)
anc = bb.add(Hadamard(), q=anc)
return {'m': bb.join(q[::-1]), 'anc': anc}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
rz_angle = -2 * pi(self.bitsize) / (2**self.bitsize + 1)
ret: Set[Tuple[Bloq, SymbolicInt]] = {
(Rz(angle=rz_angle), 1),
(Hadamard(), 2 + self.bitsize),
}
ret: Counter['Bloq'] = Counter()
ret[Rz(angle=rz_angle)] += 1
ret[OnEach(self.bitsize, Hadamard())] += 1
ret[Hadamard()] += 2
if is_symbolic(self.bitsize):
ret |= {(Rz(angle=rz_angle).controlled(), self.bitsize)}
ret[Rz(angle=rz_angle).controlled()] += self.bitsize
else:
ret |= {
(Rz(angle=rz_angle * (2**i)).controlled(), 1) for i in range(int(self.bitsize))
}
return ret
for i in range(self.bitsize):
ret[Rz(angle=rz_angle * (2**i)).controlled()] += 1
return set(ret.items())

def _t_complexity_(self) -> 'TComplexity':
# Uses self.bitsize controlled-Rz rotations which decomposes into
Expand All @@ -102,7 +95,7 @@ def _t_complexity_(self) -> 'TComplexity':


@attrs.frozen
class LPResourceState(GateWithRegisters):
class LPResourceState(QPEWindowStateBase):
r"""Prepares optimal resource state $\chi_{m}$ proposed by A. Luis and J. Peřina (1996)

Uses a single round of amplitude amplification, as described in Ref 2, to prepare the
Expand All @@ -128,53 +121,77 @@ class LPResourceState(GateWithRegisters):

@cached_property
def signature(self) -> 'Signature':
return Signature([Register('m', QUInt(self.bitsize), side=Side.THRU)])
return Signature([self.m_register])

@classmethod
def from_standard_deviation_eps(cls, eps: SymbolicFloat):
r"""Estimate the phase $\phi$ with uncertainty in standard deviation bounded by $\epsilon$.

The standard deviation of phase estimation using optimal resource states scales as the
square of Holevo variance $\tan{\frac{\pi}{2^m}}$.
This bound can be used to estimate the size of the phase register s.t. the estimated phase
has a standard deviation of at-most $\epsilon$. See the class docstring for more details.

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> Iterator[cirq.OP_TREE]:
"""Use the _LPResourceStateHelper and do a single round of amplitude amplification."""
q = quregs['m'].flatten().tolist()
anc, flag = context.qubit_manager.qalloc(2)
```
m = ceil(log2(pi/eps))
```

Args:
eps: Maximum standard deviation of the estimated phase.
"""
return LPResourceState(ceil(log2(pi(eps) / eps)))

@property
def m_bits(self) -> SymbolicInt:
return self.bitsize

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
qpe_reg = bb.allocate(dtype=self.m_register.dtype)
anc, flag = bb.allocate(dtype=QBit()), bb.allocate(dtype=QBit())

flag_angle = np.arccos(1 / (1 + 2**self.bitsize))

# Prepare initial state
yield Ry(angle=flag_angle).on(flag)
yield LPRSInterimPrep(self.bitsize).on(*q, anc)
flag = bb.add(Ry(angle=flag_angle), q=flag)
qpe_reg, anc = bb.add(LPRSInterimPrep(self.bitsize), m=qpe_reg, anc=anc)

# Reflect around the target state
yield CZPowGate().on(flag, anc)
flag, anc = bb.add(CZ(), q1=flag, q2=anc)

# Reflect around the initial state
yield LPRSInterimPrep(self.bitsize).adjoint().on(*q, anc)
yield Ry(angle=-flag_angle).on(flag)

yield XGate().on(flag)
yield MultiControlZ((0,) * (self.bitsize + 1)).on(*q, anc, flag)
yield XGate().on(flag)
qpe_reg, anc = bb.add(LPRSInterimPrep(self.bitsize).adjoint(), m=qpe_reg, anc=anc)
flag = bb.add(Ry(angle=-flag_angle), q=flag)

flag, anc, qpe_reg = bb.add(
ReflectionUsingPrepare.reflection_around_zero([1, 1, self.bitsize], global_phase=1j),
reg0_=flag,
reg1_=anc,
reg2_=qpe_reg,
)

yield LPRSInterimPrep(self.bitsize).on(*q, anc)
yield Ry(angle=flag_angle).on(flag)
qpe_reg, anc = bb.add(LPRSInterimPrep(self.bitsize), m=qpe_reg, anc=anc)
flag = bb.add(Ry(angle=flag_angle), q=flag)

# Reset ancilla to |0> state.
yield [XGate().on(flag), XGate().on(anc)]
yield GlobalPhase(exponent=0.5).on()
context.qubit_manager.qfree([flag, anc])
flag = bb.add(XGate(), q=flag)
anc = bb.add(XGate(), q=anc)
bb.free(flag)
bb.free(anc)
return {'qpe_reg': qpe_reg}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
flag_angle = acos(1 / (1 + 2**self.bitsize))
cvs: Union[HasLength, Tuple[int, ...]] = (
HasLength(self.bitsize + 1) if is_symbolic(self.bitsize) else (0,) * (self.bitsize + 1)
reflection_bloq: 'Bloq' = ReflectionUsingPrepare.reflection_around_zero(
[1, 1, self.bitsize], global_phase=1j
)
return {
(LPRSInterimPrep(self.bitsize), 2),
(LPRSInterimPrep(self.bitsize).adjoint(), 1),
(Ry(angle=flag_angle), 3),
(MultiControlZ(cvs), 1),
(XGate(), 4),
(GlobalPhase(exponent=0.5), 1),
(CZPowGate(), 1),
(Ry(angle=flag_angle), 2),
(Ry(angle=-1 * flag_angle), 1),
(reflection_bloq, 1),
(XGate(), 2),
(CZ(), 1),
}


Expand Down
62 changes: 13 additions & 49 deletions qualtran/bloqs/phase_estimation/lp_resource_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
LPRSInterimPrep,
)
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.cirq_interop.testing import GateHelper
from qualtran.resource_counting.generalizers import (
generalize_rotation_angle,
ignore_alloc_free,
Expand All @@ -42,99 +41,64 @@ def test_lp_resource_state_auto(bloq_autotester):

def test_lp_resource_state_symb():
bloq = _lp_resource_state_symbolic.make()
assert bloq.t_complexity().t == 4 * bloq.bitsize
assert bloq.t_complexity().t == 4 * bloq.bitsize + 4


def get_interim_resource_state(m: int) -> np.ndarray:
N = 2**m
state_vector = np.zeros(2 * N, dtype=np.complex128)
state_vector[:N] = np.cos(np.pi * (1 + np.arange(N)) / (1 + N))
state_vector[N:] = 1j * np.sin(np.pi * (1 + np.arange(N)) / (1 + N))
return np.sqrt(1 / N) * state_vector
state_vector = np.zeros((N, 2), dtype=np.complex128)
state_vector[:, 0] = np.cos(np.pi * (1 + np.arange(N)) / (1 + N))
state_vector[:, 1] = 1j * np.sin(np.pi * (1 + np.arange(N)) / (1 + N))
return np.sqrt(1 / N) * state_vector.reshape(2 * N)


def get_resource_state(m: int) -> np.ndarray:
N = 2**m
return np.sqrt(2 / (1 + N)) * np.sin(np.pi * (1 + np.arange(N)) / (1 + N))


def test_intermediate_resource_state_cirq_quick():
n = 3
bloq = LPRSInterimPrep(n)
state = GateHelper(bloq).circuit.final_state_vector()
np.testing.assert_allclose(state, get_interim_resource_state(n))


def test_intermediate_resource_state_tensor_quick():
n = 3
bloq = LPRSInterimPrep(n)
state_prep = initialize_from_zero(bloq)
state_vec = state_prep.tensor_contract()
pytest.xfail("https://github.com/quantumlib/Qualtran/issues/1068")
np.testing.assert_allclose(state_vec, get_interim_resource_state(n))


@pytest.mark.slow
@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_intermediate_resource_state_cirq(n):
def test_intermediate_resource_state(n):
bloq = LPRSInterimPrep(n)
state = GateHelper(bloq).circuit.final_state_vector()
state = initialize_from_zero(bloq).tensor_contract()
np.testing.assert_allclose(state, get_interim_resource_state(n))


@pytest.mark.slow
@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_intermediate_resource_state_tensor(n):
bloq = LPRSInterimPrep(n)
state_prep = initialize_from_zero(bloq)
state_vec = state_prep.tensor_contract()
pytest.xfail("https://github.com/quantumlib/Qualtran/issues/1068")
np.testing.assert_allclose(state_vec, get_interim_resource_state(n))


def test_prepares_resource_state_cirq_quick():
def test_prepares_resource_state_quick():
n = 3
bloq = LPResourceState(n)
state = GateHelper(bloq).circuit.final_state_vector()
state = bloq.tensor_contract()
np.testing.assert_allclose(state, get_resource_state(n))


def test_prepares_resource_state_tensor_quick():
n = 3
bloq = LPResourceState(n)
state_prep = initialize_from_zero(bloq)
state_vec = state_prep.tensor_contract()
np.testing.assert_allclose(state_vec, get_resource_state(n))


@pytest.mark.slow
@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_prepares_resource_state_cirq(n):
def test_prepares_resource_state(n):
bloq = LPResourceState(n)
state = GateHelper(bloq).circuit.final_state_vector()
state = bloq.tensor_contract()
np.testing.assert_allclose(state, get_resource_state(n))


@pytest.mark.slow
@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_prepares_resource_state_tensor(n):
bloq = LPResourceState(n)
state_prep = initialize_from_zero(bloq)
state_vec = state_prep.tensor_contract()
np.testing.assert_allclose(state_vec, get_resource_state(n))


@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_t_complexity(n):
bloq = LPResourceState(n)
qlt_testing.assert_equivalent_bloq_counts(
bloq, [ignore_split_join, ignore_alloc_free, generalize_rotation_angle]
)
lprs_interim_count = 3 * TComplexity(rotations=2 * n + 1, clifford=2 + 3 * n)
multi_control_pauli_count = TComplexity(t=4 * n, clifford=17 * n + 5)
reflection_using_prepare = TComplexity(t=4 * n + 4, clifford=17 * n + 22)
misc_count = TComplexity(rotations=3, clifford=5)

assert bloq.t_complexity() == (lprs_interim_count + multi_control_pauli_count + misc_count)
assert bloq.t_complexity() == (lprs_interim_count + reflection_using_prepare + misc_count)


@pytest.mark.parametrize('bitsize', [*range(1, 14, 2)])
Expand Down
Loading
Loading