Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add act_on protocol #3019

Merged
merged 8 commits into from
May 20, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@
)

from cirq.sim import (
ActOnStateVectorArgs,
StabilizerStateChForm,
CIRCUIT_LIKE,
CliffordSimulator,
Expand Down Expand Up @@ -394,6 +395,7 @@

# pylint: disable=redefined-builtin
from cirq.protocols import (
act_on,
apply_channel,
apply_mixture,
apply_unitaries,
Expand Down Expand Up @@ -438,10 +440,11 @@
QuilFormatter,
read_json,
resolve_parameters,
SupportsActOn,
SupportsApplyChannel,
SupportsApplyMixture,
SupportsConsistentApplyUnitary,
SupportsApproximateEquality,
SupportsConsistentApplyUnitary,
SupportsChannel,
SupportsCircuitDiagramInfo,
SupportsCommutes,
Expand Down
23 changes: 23 additions & 0 deletions cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,29 @@ def __init__(self, dimension: int = 2) -> None:
def _qid_shape_(self):
return (self._dimension,)

def _act_on_(self, args: Any):
from cirq import sim

if isinstance(args, sim.ActOnStateVectorArgs):
# Do a silent measurement.
measurements, _ = sim.measure_state_vector(
args.target_tensor,
args.axes,
out=args.target_tensor,
qid_shape=args.target_tensor.shape)
result = measurements[0]

# Use measurement result to zero the qid.
if result:
zero = args.subspace_index(0)
other = args.subspace_index(result)
args.target_tensor[zero] = args.target_tensor[other]
args.target_tensor[other] = 0

return True

return NotImplemented

def _channel_(self) -> Iterable[np.ndarray]:
# The first axis is over the list of channel matrices
channel = np.zeros((self._dimension,) * 3, dtype=np.complex64)
Expand Down
62 changes: 50 additions & 12 deletions cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.Operation':
return self.gate.on(*new_qubits)

def with_gate(self, new_gate: 'cirq.Gate') -> 'cirq.Operation':
if self.gate is new_gate:
return self
return new_gate.on(*self.qubits)

def __repr__(self):
Expand Down Expand Up @@ -103,7 +105,7 @@ def _value_equality_values_(self):
return self.gate, self._group_interchangeable_qubits()

def _qid_shape_(self):
return protocols.qid_shape(self.gate)
return self.gate._qid_shape_()

def _num_qubits_(self):
return len(self._qubits)
Expand All @@ -114,33 +116,57 @@ def _decompose_(self) -> 'cirq.OP_TREE':
NotImplemented)

def _pauli_expansion_(self) -> value.LinearDict[str]:
return protocols.pauli_expansion(self.gate)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain the advantage of delegating in these methods? I don't see why this is better and adds a lot of boilerplate getattr code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. The triggering event for me is that the allow_decompose parameter was not being respected in several places when GateOperation was in play. For example, has_unitary(gate_op, allow_decompose=True) would call GateOperation._has_unitary_, which would call has_unitary(gate) with no allow_decompose=True, resulting in the decomposition being computed.

Basically, if we bounce a protocol method back into itself then we can get an amplification effect where more work than necessary is done because fallback logic is run once per bounce-back-level. By delegating directly into the implementation of the sub gate, fallback logic is only run once. This makes things more efficient, and honestly a bit more stable or at least predictable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well about decompose :). We should probably make this the norm then throughout.

getter = getattr(self.gate, '_pauli_expansion_', None)
if getter is not None:
return getter()
return NotImplemented

def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs'
) -> Union[np.ndarray, None, NotImplementedType]:
return protocols.apply_unitary(self.gate, args, default=None)
getter = getattr(self.gate, '_apply_unitary_', None)
if getter is not None:
return getter(args)
return NotImplemented

def _has_unitary_(self) -> bool:
return protocols.has_unitary(self.gate)
getter = getattr(self.gate, '_has_unitary_', None)
if getter is not None:
return getter()
return NotImplemented

def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return protocols.unitary(self.gate, default=None)
getter = getattr(self.gate, '_unitary_', None)
if getter is not None:
return getter()
return NotImplemented

def _commutes_(self, other: Any,
atol: float) -> Union[bool, NotImplementedType, None]:
return self.gate._commutes_on_qids_(self.qubits, other, atol=atol)

def _has_mixture_(self) -> bool:
return protocols.has_mixture(self.gate)
getter = getattr(self.gate, '_has_mixture_', None)
if getter is not None:
return getter()
return NotImplemented

def _mixture_(self) -> Sequence[Tuple[float, Any]]:
return protocols.mixture(self.gate, NotImplemented)
getter = getattr(self.gate, '_mixture_', None)
if getter is not None:
return getter()
return NotImplemented

def _has_channel_(self) -> bool:
return protocols.has_channel(self.gate)
getter = getattr(self.gate, '_has_channel_', None)
if getter is not None:
return getter()
return NotImplemented

def _channel_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
return protocols.channel(self.gate, NotImplemented)
getter = getattr(self.gate, '_channel_', None)
if getter is not None:
return getter()
return NotImplemented

def _measurement_key_(self) -> Optional[str]:
getter = getattr(self.gate, '_measurement_key_', None)
Expand All @@ -154,12 +180,21 @@ def _measurement_keys_(self) -> Optional[Iterable[str]]:
return getter()
return NotImplemented

def _act_on_(self, args: Any):
getter = getattr(self.gate, '_act_on_', None)
if getter is not None:
return getter(args)
return NotImplemented

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self.gate)
getter = getattr(self.gate, '_is_parameterized_', None)
if getter is not None:
return getter()
return NotImplemented

def _resolve_parameters_(self, resolver):
resolved_gate = protocols.resolve_parameters(self.gate, resolver)
return GateOperation(resolved_gate, self._qubits)
return self.with_gate(resolved_gate)

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
Expand All @@ -174,7 +209,10 @@ def _decompose_into_clifford_(self):
return sub(self.qubits)

def _trace_distance_bound_(self) -> float:
return protocols.trace_distance_bound(self.gate)
getter = getattr(self.gate, '_trace_distance_bound_', None)
if getter is not None:
return getter()
return NotImplemented

def _phase_by_(self, phase_turns: float,
qubit_index: int) -> 'GateOperation':
Expand Down
52 changes: 52 additions & 0 deletions cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,22 @@ def test_pauli_expansion():
assert (cirq.pauli_expansion(cirq.CNOT(a, b)) == cirq.pauli_expansion(
cirq.CNOT))

class No(cirq.Gate):

def num_qubits(self) -> int:
return 1

class Yes(cirq.Gate):

def num_qubits(self) -> int:
return 1

def _pauli_expansion_(self):
return cirq.LinearDict({'X': 0.5})

assert cirq.pauli_expansion(No().on(a), default=None) is None
assert cirq.pauli_expansion(Yes().on(a)) == cirq.LinearDict({'X': 0.5})


def test_unitary():
a = cirq.NamedQubit('a')
Expand Down Expand Up @@ -344,3 +360,39 @@ def _mul_with_qubits(self, qubits, other):
# Handles the symmetric type case correctly.
assert m * m == 6
assert r * r == 4


def test_with_gate():
g1 = cirq.GateOperation(cirq.X, cirq.LineQubit.range(1))
g2 = cirq.GateOperation(cirq.Y, cirq.LineQubit.range(1))
assert g1.with_gate(cirq.X) is g1
assert g1.with_gate(cirq.Y) == g2


def test_is_parameterized():

class No1(cirq.Gate):

def num_qubits(self) -> int:
return 1

class No2(cirq.Gate):

def num_qubits(self) -> int:
return 1

def _is_parameterized_(self):
return False

class Yes(cirq.Gate):

def num_qubits(self) -> int:
return 1

def _is_parameterized_(self):
return True

q = cirq.LineQubit(0)
assert not cirq.is_parameterized(No1().on(q))
assert not cirq.is_parameterized(No2().on(q))
assert cirq.is_parameterized(Yes().on(q))
25 changes: 24 additions & 1 deletion cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING
from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, \
TYPE_CHECKING, Union

import numpy as np

Expand Down Expand Up @@ -215,6 +216,28 @@ def _from_json_dict_(cls,
invert_mask=tuple(invert_mask),
qid_shape=None if qid_shape is None else tuple(qid_shape))

def _act_on_(self, args: Any) -> bool:
dabacon marked this conversation as resolved.
Show resolved Hide resolved
from cirq import sim

if isinstance(args, sim.ActOnStateVectorArgs):

invert_mask = self.full_invert_mask()
bits, _ = sim.measure_state_vector(
args.target_tensor,
args.axes,
out=args.target_tensor,
qid_shape=args.target_tensor.shape,
seed=args.prng)
corrected = [
bit ^ (bit < 2 and mask)
for bit, mask in zip(bits, invert_mask)
]
args.record_measurement_result(self.key, corrected)

return True

return NotImplemented


def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)
4 changes: 4 additions & 0 deletions cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
# limitations under the License.


from cirq.protocols.act_on_protocol import (
act_on,
SupportsActOn,
)
from cirq.protocols.apply_unitary_protocol import (
apply_unitaries,
apply_unitary,
Expand Down
Loading