diff --git a/cirq/__init__.py b/cirq/__init__.py index 6c4f2cdb9e9..220604d45f7 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -315,6 +315,7 @@ ) from cirq.sim import ( + ActOnStateVectorArgs, StabilizerStateChForm, CIRCUIT_LIKE, CliffordSimulator, @@ -394,6 +395,7 @@ # pylint: disable=redefined-builtin from cirq.protocols import ( + act_on, apply_channel, apply_mixture, apply_unitaries, @@ -438,10 +440,11 @@ QuilFormatter, read_json, resolve_parameters, + SupportsActOn, SupportsApplyChannel, SupportsApplyMixture, - SupportsConsistentApplyUnitary, SupportsApproximateEquality, + SupportsConsistentApplyUnitary, SupportsChannel, SupportsCircuitDiagramInfo, SupportsCommutes, diff --git a/cirq/ops/common_channels.py b/cirq/ops/common_channels.py index f2ed4d67140..987bd433cce 100644 --- a/cirq/ops/common_channels.py +++ b/cirq/ops/common_channels.py @@ -552,6 +552,29 @@ def __init__(self, dimension: int = 2) -> None: def _qid_shape_(self): return (self._dimension,) + def _act_on_(self, args: Any): + from cirq import sim + + if isinstance(args, sim.ActOnStateVectorArgs): + # Do a silent measurement. + measurements, _ = sim.measure_state_vector( + args.target_tensor, + args.axes, + out=args.target_tensor, + qid_shape=args.target_tensor.shape) + result = measurements[0] + + # Use measurement result to zero the qid. + if result: + zero = args.subspace_index(0) + other = args.subspace_index(result) + args.target_tensor[zero] = args.target_tensor[other] + args.target_tensor[other] = 0 + + return True + + return NotImplemented + def _channel_(self) -> Iterable[np.ndarray]: # The first axis is over the list of channel matrices channel = np.zeros((self._dimension,) * 3, dtype=np.complex64) diff --git a/cirq/ops/common_channels_test.py b/cirq/ops/common_channels_test.py index d5d01984b4f..1e2b647cb10 100644 --- a/cirq/ops/common_channels_test.py +++ b/cirq/ops/common_channels_test.py @@ -314,6 +314,37 @@ def test_reset_channel_text_diagram(): cirq.ResetChannel(3)) == cirq.CircuitDiagramInfo(wire_symbols=('R',))) +def test_reset_act_on(): + with pytest.raises(TypeError, match="Failed to act"): + cirq.act_on(cirq.ResetChannel(), object()) + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(1, 1, 1, 1, 1), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + axes=[1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + + cirq.act_on(cirq.ResetChannel(), args) + assert args.log_of_measurement_results == {} + np.testing.assert_allclose( + args.target_tensor, + cirq.one_hot(index=(1, 0, 1, 1, 1), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64)) + + cirq.act_on(cirq.ResetChannel(), args) + assert args.log_of_measurement_results == {} + np.testing.assert_allclose( + args.target_tensor, + cirq.one_hot(index=(1, 0, 1, 1, 1), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64)) + + def test_phase_damping_channel(): d = cirq.phase_damp(0.3) np.testing.assert_almost_equal(cirq.channel(d), diff --git a/cirq/ops/gate_operation.py b/cirq/ops/gate_operation.py index d7300b28a24..a3f8f733419 100644 --- a/cirq/ops/gate_operation.py +++ b/cirq/ops/gate_operation.py @@ -56,6 +56,8 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.Operation': return self.gate.on(*new_qubits) def with_gate(self, new_gate: 'cirq.Gate') -> 'cirq.Operation': + if self.gate is new_gate: + return self return new_gate.on(*self.qubits) def __repr__(self): @@ -103,7 +105,7 @@ def _value_equality_values_(self): return self.gate, self._group_interchangeable_qubits() def _qid_shape_(self): - return protocols.qid_shape(self.gate) + return self.gate._qid_shape_() def _num_qubits_(self): return len(self._qubits) @@ -114,33 +116,57 @@ def _decompose_(self) -> 'cirq.OP_TREE': NotImplemented) def _pauli_expansion_(self) -> value.LinearDict[str]: - return protocols.pauli_expansion(self.gate) + getter = getattr(self.gate, '_pauli_expansion_', None) + if getter is not None: + return getter() + return NotImplemented def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs' ) -> Union[np.ndarray, None, NotImplementedType]: - return protocols.apply_unitary(self.gate, args, default=None) + getter = getattr(self.gate, '_apply_unitary_', None) + if getter is not None: + return getter(args) + return NotImplemented def _has_unitary_(self) -> bool: - return protocols.has_unitary(self.gate) + getter = getattr(self.gate, '_has_unitary_', None) + if getter is not None: + return getter() + return NotImplemented def _unitary_(self) -> Union[np.ndarray, NotImplementedType]: - return protocols.unitary(self.gate, default=None) + getter = getattr(self.gate, '_unitary_', None) + if getter is not None: + return getter() + return NotImplemented def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]: return self.gate._commutes_on_qids_(self.qubits, other, atol=atol) def _has_mixture_(self) -> bool: - return protocols.has_mixture(self.gate) + getter = getattr(self.gate, '_has_mixture_', None) + if getter is not None: + return getter() + return NotImplemented def _mixture_(self) -> Sequence[Tuple[float, Any]]: - return protocols.mixture(self.gate, NotImplemented) + getter = getattr(self.gate, '_mixture_', None) + if getter is not None: + return getter() + return NotImplemented def _has_channel_(self) -> bool: - return protocols.has_channel(self.gate) + getter = getattr(self.gate, '_has_channel_', None) + if getter is not None: + return getter() + return NotImplemented def _channel_(self) -> Union[Tuple[np.ndarray], NotImplementedType]: - return protocols.channel(self.gate, NotImplemented) + getter = getattr(self.gate, '_channel_', None) + if getter is not None: + return getter() + return NotImplemented def _measurement_key_(self) -> Optional[str]: getter = getattr(self.gate, '_measurement_key_', None) @@ -154,12 +180,21 @@ def _measurement_keys_(self) -> Optional[Iterable[str]]: return getter() return NotImplemented + def _act_on_(self, args: Any): + getter = getattr(self.gate, '_act_on_', None) + if getter is not None: + return getter(args) + return NotImplemented + def _is_parameterized_(self) -> bool: - return protocols.is_parameterized(self.gate) + getter = getattr(self.gate, '_is_parameterized_', None) + if getter is not None: + return getter() + return NotImplemented def _resolve_parameters_(self, resolver): resolved_gate = protocols.resolve_parameters(self.gate, resolver) - return GateOperation(resolved_gate, self._qubits) + return self.with_gate(resolved_gate) def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs' ) -> 'cirq.CircuitDiagramInfo': @@ -174,7 +209,10 @@ def _decompose_into_clifford_(self): return sub(self.qubits) def _trace_distance_bound_(self) -> float: - return protocols.trace_distance_bound(self.gate) + getter = getattr(self.gate, '_trace_distance_bound_', None) + if getter is not None: + return getter() + return NotImplemented def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'GateOperation': diff --git a/cirq/ops/gate_operation_test.py b/cirq/ops/gate_operation_test.py index c0ba51e6bd5..53cee754f54 100644 --- a/cirq/ops/gate_operation_test.py +++ b/cirq/ops/gate_operation_test.py @@ -207,6 +207,22 @@ def test_pauli_expansion(): assert (cirq.pauli_expansion(cirq.CNOT(a, b)) == cirq.pauli_expansion( cirq.CNOT)) + class No(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + class Yes(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _pauli_expansion_(self): + return cirq.LinearDict({'X': 0.5}) + + assert cirq.pauli_expansion(No().on(a), default=None) is None + assert cirq.pauli_expansion(Yes().on(a)) == cirq.LinearDict({'X': 0.5}) + def test_unitary(): a = cirq.NamedQubit('a') @@ -344,3 +360,39 @@ def _mul_with_qubits(self, qubits, other): # Handles the symmetric type case correctly. assert m * m == 6 assert r * r == 4 + + +def test_with_gate(): + g1 = cirq.GateOperation(cirq.X, cirq.LineQubit.range(1)) + g2 = cirq.GateOperation(cirq.Y, cirq.LineQubit.range(1)) + assert g1.with_gate(cirq.X) is g1 + assert g1.with_gate(cirq.Y) == g2 + + +def test_is_parameterized(): + + class No1(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + class No2(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _is_parameterized_(self): + return False + + class Yes(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _is_parameterized_(self): + return True + + q = cirq.LineQubit(0) + assert not cirq.is_parameterized(No1().on(q)) + assert not cirq.is_parameterized(No2().on(q)) + assert cirq.is_parameterized(Yes().on(q)) diff --git a/cirq/ops/measurement_gate.py b/cirq/ops/measurement_gate.py index 05e3c1cd828..4044dc2bb4c 100644 --- a/cirq/ops/measurement_gate.py +++ b/cirq/ops/measurement_gate.py @@ -215,6 +215,28 @@ def _from_json_dict_(cls, invert_mask=tuple(invert_mask), qid_shape=None if qid_shape is None else tuple(qid_shape)) + def _act_on_(self, args: Any) -> bool: + from cirq import sim + + if isinstance(args, sim.ActOnStateVectorArgs): + + invert_mask = self.full_invert_mask() + bits, _ = sim.measure_state_vector( + args.target_tensor, + args.axes, + out=args.target_tensor, + qid_shape=args.target_tensor.shape, + seed=args.prng) + corrected = [ + bit ^ (bit < 2 and mask) + for bit, mask in zip(bits, invert_mask) + ] + args.record_measurement_result(self.key, corrected) + + return True + + return NotImplemented + def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str: return ','.join(str(q) for q in qubits) diff --git a/cirq/ops/measurement_gate_test.py b/cirq/ops/measurement_gate_test.py index 3f19c0a7534..f8797d61f9f 100644 --- a/cirq/ops/measurement_gate_test.py +++ b/cirq/ops/measurement_gate_test.py @@ -189,3 +189,89 @@ def test_op_repr(): "cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), " "key='out', " "invert_mask=(False, True))") + + +def test_act_on(): + a, b = cirq.LineQubit.range(2) + m = cirq.measure(a, b, key='out', invert_mask=(True,)) + + with pytest.raises(TypeError, match="Failed to act"): + cirq.act_on(m, object()) + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(shape=(2, 2, 2, 2, 2), dtype=np.complex64), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [1, 0]} + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 1, 0, 0, 0), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [1, 1]} + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 1, 0, 1, 0), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [0, 1]} + + with pytest.raises(ValueError, match="already logged to key"): + cirq.act_on(m, args) + + +def test_act_on_qutrit(): + a, b = cirq.LineQid.range(2, dimension=3) + m = cirq.measure(a, b, key='out', invert_mask=(True,)) + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 2, 0, 2, 0), + shape=(3, 3, 3, 3, 3), + dtype=np.complex64), + available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [2, 2]} + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 1, 0, 2, 0), + shape=(3, 3, 3, 3, 3), + dtype=np.complex64), + available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [2, 1]} + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 2, 0, 1, 0), + shape=(3, 3, 3, 3, 3), + dtype=np.complex64), + available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [0, 2]} diff --git a/cirq/protocols/__init__.py b/cirq/protocols/__init__.py index bbf9e8d3585..ccd5d295940 100644 --- a/cirq/protocols/__init__.py +++ b/cirq/protocols/__init__.py @@ -13,6 +13,10 @@ # limitations under the License. +from cirq.protocols.act_on_protocol import ( + act_on, + SupportsActOn, +) from cirq.protocols.apply_unitary_protocol import ( apply_unitaries, apply_unitary, diff --git a/cirq/protocols/act_on_protocol.py b/cirq/protocols/act_on_protocol.py new file mode 100644 index 00000000000..71fec022cf1 --- /dev/null +++ b/cirq/protocols/act_on_protocol.py @@ -0,0 +1,121 @@ +# Copyright 2020 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A protocol that wouldn't exist if python had __rimul__.""" + +from typing import (Any, TYPE_CHECKING, Union) + +from typing_extensions import Protocol + +from cirq._doc import document +from cirq.type_workarounds import NotImplementedType + +if TYPE_CHECKING: + pass + + +class SupportsActOn(Protocol): + """An object that explicitly specifies how to act on simulator states.""" + + @document + def _act_on_(self, args: Any) -> Union[NotImplementedType, bool]: + """Applies an action to the given argument, if it is a supported type. + + For example, unitary operations can implement an `_act_on_` method that + checks if `isinstance(args, cirq.ActOnStateVectorArgs)` and, if so, + apply their unitary effect to the state vector. + + The global `cirq.act_on` method looks for whether or not the given + argument has this value, before attempting any fallback strategies + specified by the argument being acted on. + + This method is analogous to python's `__imul__` in that it is expected + to perform an inline effect if it recognizes the type of an argument, + and return NotImplemented otherwise. It is also analogous to python's + `__rmul__` in that dispatch is being done on the right hand side value + instead of the left hand side value. If python had an `__rimul__` + method, then `_act_on_` would not exist because it would be redundant. + + Args: + args: An object of unspecified type. The method must check if this + object is of a recognized type and act on it if so. + + Returns: + True: The receiving object (`self`) acted on the argument. + NotImplemented: The receiving object did not act on the argument. + + All other return values are considered to be errors. + """ + + +def act_on( + action: Any, + args: Any, + *, + allow_decompose: bool = True, +): + """Applies an action to a state argument. + + For example, the action may be a `cirq.Operation` and the state argument may + represent the internal state of a state vector simulator (a + `cirq.ActOnStateVectorArgs`). + + The action is applied by first checking if `action._act_on_` exists and + returns `True` (instead of `NotImplemented`) for the given object. Then + fallback strategies specified by the state argument via `_act_on_fallback_` + are attempted. If those also fail, the method fails with a `TypeError`. + + Args: + action: The action to apply to the state tensor. Typically a + `cirq.Operation`. + args: A mutable state object that should be modified by the action. May + specify an `_act_on_fallback_` method to use in case the action + doesn't recognize it. + allow_decompose: Defaults to True. Forwarded into the + `_act_on_fallback_` method of `args`. Determines if decomposition + should be used or avoided when attempting to act `action` on `args`. + Used by internal methods to avoid redundant decompositions. + + Returns: + Nothing. Results are communicated by editing `args`. + + Raises: + TypeError: Failed to act `action` on `args`. + """ + + action_act_on = getattr(action, '_act_on_', None) + if action_act_on is not None: + result = action_act_on(args) + if result is True: + return + if result is not NotImplemented: + raise ValueError( + f'_act_on_ must return True or NotImplemented but got ' + f'{result!r} from {action!r}._act_on_') + + arg_fallback = getattr(args, '_act_on_fallback_', None) + if arg_fallback is not None: + result = arg_fallback(action, allow_decompose=allow_decompose) + if result is True: + return + if result is not NotImplemented: + raise ValueError( + f'_act_on_fallback_ must return True or NotImplemented but got ' + f'{result!r} from {type(args)}._act_on_fallback_') + + raise TypeError("Failed to act action on state argument.\n" + "Tried both action._act_on_ and args._act_on_fallback_.\n" + "\n" + f"State argument type: {type(args)}\n" + f"Action type: {type(action)}\n" + f"Action repr: {action!r}\n") diff --git a/cirq/protocols/act_on_protocol_test.py b/cirq/protocols/act_on_protocol_test.py new file mode 100644 index 00000000000..2b7315e2f67 --- /dev/null +++ b/cirq/protocols/act_on_protocol_test.py @@ -0,0 +1,35 @@ +# Copyright 2020 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A protocol that wouldn't exist if python had __rimul__.""" + +import pytest + +import cirq + + +def test_act_on_checks(): + + class Bad(): + + def _act_on_(self, args): + return False + + def _act_on_fallback_(self, action, allow_decompose): + return False + + with pytest.raises(ValueError, match="must return True or NotImplemented"): + _ = cirq.act_on(Bad(), object()) + + with pytest.raises(ValueError, match="must return True or NotImplemented"): + _ = cirq.act_on(object(), Bad()) diff --git a/cirq/protocols/apply_unitary_protocol.py b/cirq/protocols/apply_unitary_protocol.py index 986110e6868..0ed9addfbe2 100644 --- a/cirq/protocols/apply_unitary_protocol.py +++ b/cirq/protocols/apply_unitary_protocol.py @@ -190,9 +190,6 @@ def subspace_index(self, bit of the integer is the desired bit for the first axis, and so forth in decreasing order. Can't be specified at the same time as `little_endian_bits_int`. - value_tuple: The desired value of the qids at the targeted `axes`, - packed into a tuple. Specify either `little_endian_bits_int` or - `value_tuple`. Returns: A value that can be used to index into `target_tensor` and @@ -215,7 +212,8 @@ def subspace_index(self, return linalg.slice_for_qubits_equal_to( self.axes, little_endian_qureg_value=little_endian_bits_int, - big_endian_qureg_value=big_endian_bits_int) + big_endian_qureg_value=big_endian_bits_int, + qid_shape=self.target_tensor.shape) class SupportsConsistentApplyUnitary(Protocol): @@ -265,10 +263,13 @@ def _apply_unitary_(self, args: ApplyUnitaryArgs """ -def apply_unitary(unitary_value: Any, - args: ApplyUnitaryArgs, - default: TDefault = RaiseTypeErrorIfNotProvided - ) -> Union[np.ndarray, TDefault]: +def apply_unitary( + unitary_value: Any, + args: ApplyUnitaryArgs, + default: TDefault = RaiseTypeErrorIfNotProvided, + *, + allow_decompose: bool = True, +) -> Union[np.ndarray, TDefault]: """High performance left-multiplication of a unitary effect onto a tensor. Applies the unitary effect of `unitary_value` to the tensor specified in @@ -290,7 +291,7 @@ def apply_unitary(unitary_value: Any, Case c) Method returns a numpy array. Multiply the matrix onto the target tensor and return to the caller. - C. Try to use `unitary_value._decompose_()`. + C. Try to use `unitary_value._decompose_()` (if `allow_decompose`). Case a) Method not present or returns `NotImplemented` or `None`. Continue to next strategy. Case b) Method returns an OP_TREE. @@ -311,6 +312,9 @@ def apply_unitary(unitary_value: Any, default: What should be returned if `unitary_value` doesn't have a unitary effect. If not specified, a TypeError is raised instead of returning a default value. + allow_decompose: Defaults to True. If set to False, and applying the + unitary effect requires decomposing the object, the method will + pretend the object has no unitary effect. Returns: If the receiving object does not have a unitary effect, then the @@ -341,6 +345,8 @@ def apply_unitary(unitary_value: Any, _strat_apply_unitary_from_decompose, _strat_apply_unitary_from_unitary ] + if not allow_decompose: + strats.remove(_strat_apply_unitary_from_decompose) # Try each strategy, stopping if one works. for strat in strats: diff --git a/cirq/protocols/apply_unitary_protocol_test.py b/cirq/protocols/apply_unitary_protocol_test.py index 2fc5e95940f..ba5d5960ed3 100644 --- a/cirq/protocols/apply_unitary_protocol_test.py +++ b/cirq/protocols/apply_unitary_protocol_test.py @@ -275,8 +275,10 @@ def test_big_endian_subspace_index(): state = np.zeros(shape=(2, 3, 4, 5, 1, 6, 1, 1)) args = cirq.ApplyUnitaryArgs(state, np.empty_like(state), [1, 3]) s = slice(None) - assert args.subspace_index(little_endian_bits_int=1) == (s, 1, s, 0, ...) - assert args.subspace_index(big_endian_bits_int=1) == (s, 0, s, 1, ...) + assert args.subspace_index(little_endian_bits_int=1) == (s, 1, s, 0, s, s, + s, s) + assert args.subspace_index(big_endian_bits_int=1) == (s, 0, s, 1, s, s, s, + s) def test_apply_unitaries(): diff --git a/cirq/protocols/has_unitary_protocol.py b/cirq/protocols/has_unitary_protocol.py index 98f5c818138..048fb6959d3 100644 --- a/cirq/protocols/has_unitary_protocol.py +++ b/cirq/protocols/has_unitary_protocol.py @@ -52,7 +52,7 @@ def _has_unitary_(self) -> bool: """ -def has_unitary(val: Any) -> bool: +def has_unitary(val: Any, *, allow_decompose: bool = True) -> bool: """Determines whether the value has a unitary effect. Determines whether `val` has a unitary effect by attempting the following @@ -104,6 +104,8 @@ def has_unitary(val: Any) -> bool: _strat_has_unitary_from_has_unitary, _strat_has_unitary_from_decompose, _strat_has_unitary_from_apply_unitary, _strat_has_unitary_from_unitary ] + if not allow_decompose: + strats.remove(_strat_has_unitary_from_decompose) for strat in strats: result = strat(val) if result is not None: diff --git a/cirq/protocols/has_unitary_protocol_test.py b/cirq/protocols/has_unitary_protocol_test.py index ed2381c9487..b636b790c55 100644 --- a/cirq/protocols/has_unitary_protocol_test.py +++ b/cirq/protocols/has_unitary_protocol_test.py @@ -47,6 +47,7 @@ def _unitary_(self): assert not cirq.has_unitary(No1()) assert not cirq.has_unitary(No2()) assert cirq.has_unitary(Yes()) + assert cirq.has_unitary(Yes(), allow_decompose=False) def test_via_apply_unitary(): @@ -82,6 +83,7 @@ def _apply_unitary_(self, args): return args.target_tensor assert cirq.has_unitary(Yes1()) + assert cirq.has_unitary(Yes1(), allow_decompose=False) assert cirq.has_unitary(Yes2()) assert not cirq.has_unitary(No1()) assert not cirq.has_unitary(No2()) @@ -122,6 +124,10 @@ def _decompose_(self): assert not cirq.has_unitary(No2()) assert not cirq.has_unitary(No3()) + assert not cirq.has_unitary(Yes1(), allow_decompose=False) + assert not cirq.has_unitary(Yes2(), allow_decompose=False) + assert not cirq.has_unitary(No1(), allow_decompose=False) + def test_via_has_unitary(): diff --git a/cirq/protocols/json_serialization_test.py b/cirq/protocols/json_serialization_test.py index 2991f8eb4ab..fd540074bc3 100644 --- a/cirq/protocols/json_serialization_test.py +++ b/cirq/protocols/json_serialization_test.py @@ -111,6 +111,11 @@ def test_fail_to_resolve(): # cirq.Circuit(cirq.rx(sympy.Symbol('theta')).on(Q0)), SHOULDNT_BE_SERIALIZED = [ + # Intermediate states with work buffers and unknown external prng guts. + 'ActOnStateVectorArgs', + 'ApplyChannelArgs', + 'ApplyMixtureArgs', + 'ApplyUnitaryArgs', # Circuit optimizers are function-like. Only attributes # are ignore_failures, tolerance, and other feature flags @@ -257,9 +262,6 @@ def test_mutually_exclusive_blacklist(): NOT_YET_SERIALIZABLE = [ - 'ApplyChannelArgs', - 'ApplyMixtureArgs', - 'ApplyUnitaryArgs', 'AsymmetricDepolarizingChannel', 'AxisAngleDecomposition', 'Calibration', diff --git a/cirq/protocols/measurement_key_protocol.py b/cirq/protocols/measurement_key_protocol.py index 064bdaf623b..eea88b01195 100644 --- a/cirq/protocols/measurement_key_protocol.py +++ b/cirq/protocols/measurement_key_protocol.py @@ -109,7 +109,8 @@ def measurement_keys(val: Any, *, don't directly specify their measurement keys will be decomposed in order to find measurement keys within the decomposed operations. If not set, composite operations will appear to have no measurement - keys. + keys. Used by internal methods to stop redundant decompositions from + being performed. Returns: The measurement keys of the value. If the value has no measurement, diff --git a/cirq/protocols/mixture_protocol.py b/cirq/protocols/mixture_protocol.py index bbdf2874308..edc92713586 100644 --- a/cirq/protocols/mixture_protocol.py +++ b/cirq/protocols/mixture_protocol.py @@ -19,6 +19,8 @@ from typing_extensions import Protocol from cirq._doc import document +from cirq.protocols.decompose_protocol import \ + _try_decompose_into_operations_and_qubits from cirq.protocols.has_unitary_protocol import has_unitary from cirq.type_workarounds import NotImplementedType @@ -161,12 +163,22 @@ def mixture_channel(val: Any, default: Any = RaiseTypeErrorIfNotProvided "method, but it returned NotImplemented.".format(type(val))) -def has_mixture_channel(val: Any) -> bool: +def has_mixture_channel(val: Any, *, allow_decompose: bool = True) -> bool: """Returns whether the value has a mixture channel representation. In contrast to `has_mixture` this method falls back to checking whether the value has a unitary representation via `has_channel`. + Args: + val: The value to check. + allow_decompose: Used by internal methods to stop redundant + decompositions from being performed (e.g. there's no need to + decompose an object to check if it is unitary as part of determining + if the object is a quantum channel, when the quantum channel check + will already be doing a more general decomposition check). Defaults + to True. When false, the decomposition strategy for determining + the result is skipped. + Returns: If `val` has a `_has_mixture_` method and its result is not NotImplemented, that result is returned. Otherwise, if `val` has a @@ -180,9 +192,13 @@ def has_mixture_channel(val: Any) -> bool: if result is not NotImplemented: return result - result = has_unitary(val) - if result is not NotImplemented and result: - return result + if has_unitary(val, allow_decompose=False): + return True + + if allow_decompose: + operations, _, _ = _try_decompose_into_operations_and_qubits(val) + if operations is not None: + return all(has_mixture_channel(val) for val in operations) # No _has_mixture_ or _has_unitary_ function, use _mixture_ instead. return mixture_channel(val, None) is not None diff --git a/cirq/protocols/mixture_protocol_test.py b/cirq/protocols/mixture_protocol_test.py index efef5512bb2..7b14fcd8365 100644 --- a/cirq/protocols/mixture_protocol_test.py +++ b/cirq/protocols/mixture_protocol_test.py @@ -147,6 +147,28 @@ def test_has_mixture_channel(): assert cirq.has_mixture_channel(ReturnsUnitary()) assert not cirq.has_mixture_channel(ReturnsNotImplementedUnitary()) + class NoAtom(cirq.Operation): + + @property + def qubits(self): + return cirq.LineQubit.range(2) + + def with_qubits(self): + raise NotImplementedError() + + class No1: + + def _decompose_(self): + return [NoAtom()] + + class Yes1: + + def _decompose_(self): + return [cirq.X(cirq.LineQubit(0))] + + assert not cirq.has_mixture_channel(No1()) + assert cirq.has_mixture_channel(Yes1()) + def test_valid_mixture(): cirq.validate_mixture(ReturnsValidTuple()) diff --git a/cirq/qis/states.py b/cirq/qis/states.py index 88319bb9742..f343492b801 100644 --- a/cirq/qis/states.py +++ b/cirq/qis/states.py @@ -392,7 +392,7 @@ def validate_qid_shape(state: np.ndarray, return qid_shape -def validate_indices(num_qubits: int, indices: List[int]) -> None: +def validate_indices(num_qubits: int, indices: Sequence[int]) -> None: """Validates that the indices have values within range of num_qubits.""" if any(index < 0 for index in indices): raise IndexError('Negative index in indices: {}'.format(indices)) diff --git a/cirq/sim/__init__.py b/cirq/sim/__init__.py index d6d379c89b4..bdf79d4a14d 100644 --- a/cirq/sim/__init__.py +++ b/cirq/sim/__init__.py @@ -14,6 +14,9 @@ """Base simulation classes and generic simulators.""" +from cirq.sim.act_on_state_vector_args import ( + ActOnStateVectorArgs,) + from cirq.sim.density_matrix_utils import ( measure_density_matrix, sample_density_matrix, diff --git a/cirq/sim/act_on_state_vector_args.py b/cirq/sim/act_on_state_vector_args.py new file mode 100644 index 00000000000..94ee0d773ff --- /dev/null +++ b/cirq/sim/act_on_state_vector_args.py @@ -0,0 +1,229 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Objects and methods for acting efficiently on a state vector.""" + +from typing import Any, Iterable, Sequence, Tuple, TYPE_CHECKING, Union, Dict + +import numpy as np + +from cirq import linalg, protocols +from cirq.protocols.decompose_protocol import ( + _try_decompose_into_operations_and_qubits,) + +if TYPE_CHECKING: + import cirq + + +class ActOnStateVectorArgs: + """State and context for an operation acting on a state vector. + + There are three common ways to act on this object: + + 1. Directly edit the `target_tensor` property, which is storing the state + vector of the quantum system as a numpy array with one axis per qudit. + 2. Overwrite the `available_buffer` property with the new state vector, and + then pass `available_buffer` into `swap_target_tensor_for`. + 3. Call `record_measurement_result(key, val)` to log a measurement result. + """ + + def __init__(self, target_tensor: np.ndarray, available_buffer: np.ndarray, + axes: Iterable[int], prng: np.random.RandomState, + log_of_measurement_results: Dict[str, Any]): + """ + Args: + target_tensor: The state vector to act on, stored as a numpy array + with one dimension for each qubit in the system. Operations are + expected to perform inplace edits of this object. + available_buffer: A workspace with the same shape and dtype as + `target_tensor`. Used by operations that cannot be applied to + `target_tensor` inline, in order to avoid unnecessary + allocations. Passing `available_buffer` into + `swap_target_tensor_for` will swap it for `target_tensor`. + axes: The indices of axes corresponding to the qubits that the + operation is supposed to act upon. + prng: The pseudo random number generator to use for probabilistic + effects. + log_of_measurement_results: A mutable object that measurements are + being recorded into. Edit it easily by calling + `ActOnStateVectorArgs.record_measurement_result`. + """ + self.target_tensor = target_tensor + self.available_buffer = available_buffer + self.axes = tuple(axes) + self.prng = prng + self.log_of_measurement_results = log_of_measurement_results + + def swap_target_tensor_for(self, new_target_tensor: np.ndarray): + """Gives a new state vector for the system. + + Typically, the new state vector should be `args.available_buffer` where + `args` is this `cirq.ActOnStateVectorArgs` instance. + + Args: + new_target_tensor: The new system state. Must have the same shape + and dtype as the old system state. + """ + if new_target_tensor is self.available_buffer: + self.available_buffer = self.target_tensor + self.target_tensor = new_target_tensor + + def record_measurement_result(self, key: str, value: Any): + """Adds a measurement result to the log. + + Args: + key: The key the measurement result should be logged under. Note + that operations should only store results under keys they have + declared in a `_measurement_keys_` method. + value: The value to log for the measurement. + """ + if key in self.log_of_measurement_results: + raise ValueError(f"Measurement already logged to key {key!r}") + self.log_of_measurement_results[key] = value + + def subspace_index(self, + little_endian_bits_int: int = 0, + *, + big_endian_bits_int: int = 0 + ) -> Tuple[Union[slice, int, 'ellipsis'], ...]: + """An index for the subspace where the target axes equal a value. + + Args: + little_endian_bits_int: The desired value of the qubits at the + targeted `axes`, packed into an integer. The least significant + bit of the integer is the desired bit for the first axis, and + so forth in increasing order. Can't be specified at the same + time as `big_endian_bits_int`. + + When operating on qudits instead of qubits, the same basic logic + applies but in a different basis. For example, if the target + axes have dimension [a:2, b:3, c:2] then the integer 10 + decomposes into [a=0, b=2, c=1] via 7 = 1*(3*2) + 2*(2) + 0. + + big_endian_bits_int: The desired value of the qubits at the + targeted `axes`, packed into an integer. The most significant + bit of the integer is the desired bit for the first axis, and + so forth in decreasing order. Can't be specified at the same + time as `little_endian_bits_int`. + + When operating on qudits instead of qubits, the same basic logic + applies but in a different basis. For example, if the target + axes have dimension [a:2, b:3, c:2] then the integer 10 + decomposes into [a=1, b=2, c=0] via 7 = 1*(3*2) + 2*(2) + 0. + + Returns: + A value that can be used to index into `target_tensor` and + `available_buffer`, and manipulate only the part of Hilbert space + corresponding to a given bit assignment. + + Example: + If `target_tensor` is a 4 qubit tensor and `axes` is `[1, 3]` and + then this method will return the following when given + `little_endian_bits=0b01`: + + `(slice(None), 0, slice(None), 1, Ellipsis)` + + Therefore the following two lines would be equivalent: + + args.target_tensor[args.subspace_index(0b01)] += 1 + + args.target_tensor[:, 0, :, 1] += 1 + """ + return linalg.slice_for_qubits_equal_to( + self.axes, + little_endian_qureg_value=little_endian_bits_int, + big_endian_qureg_value=big_endian_bits_int, + qid_shape=self.target_tensor.shape) + + def _act_on_fallback_(self, action: Any, allow_decompose: bool): + strats = [ + _strat_act_on_state_vector_from_apply_unitary, + _strat_act_on_state_vector_from_mixture, + ] + if allow_decompose: + strats.append(_strat_act_on_state_vector_from_apply_decompose) + + # Try each strategy, stopping if one works. + for strat in strats: + result = strat(action, self) + if result is False: + break # coverage: ignore + if result is True: + return True + assert result is NotImplemented, str(result) + + return NotImplemented + + +def _strat_act_on_state_vector_from_apply_unitary( + unitary_value: Any, + args: 'cirq.ActOnStateVectorArgs', +) -> bool: + new_target_tensor = protocols.apply_unitary( + unitary_value, + protocols.ApplyUnitaryArgs( + target_tensor=args.target_tensor, + available_buffer=args.available_buffer, + axes=args.axes, + ), + allow_decompose=False, + default=NotImplemented) + if new_target_tensor is NotImplemented: + return NotImplemented + args.swap_target_tensor_for(new_target_tensor) + return True + + +def _strat_act_on_state_vector_from_apply_decompose( + val: Any, + args: ActOnStateVectorArgs, +) -> bool: + operations, qubits, _ = _try_decompose_into_operations_and_qubits(val) + if operations is None: + return NotImplemented + return _act_all_on_state_vector(operations, qubits, args) + + +def _act_all_on_state_vector(actions: Iterable[Any], + qubits: Sequence['cirq.Qid'], + args: 'cirq.ActOnStateVectorArgs'): + assert len(qubits) == len(args.axes) + qubit_map = {q: args.axes[i] for i, q in enumerate(qubits)} + + old_axes = args.axes + try: + for action in actions: + args.axes = tuple(qubit_map[q] for q in action.qubits) + protocols.act_on(action, args) + finally: + args.axes = old_axes + return True + + +def _strat_act_on_state_vector_from_mixture(action: Any, + args: 'cirq.ActOnStateVectorArgs' + ) -> bool: + mixture = protocols.mixture(action, default=None) + if mixture is None: + return NotImplemented + probabilities, unitaries = zip(*mixture) + + index = args.prng.choice(range(len(unitaries)), p=probabilities) + shape = protocols.qid_shape(action) * 2 + unitary = unitaries[index].astype(args.target_tensor.dtype).reshape(shape) + linalg.targeted_left_multiply(unitary, + args.target_tensor, + args.axes, + out=args.available_buffer) + args.swap_target_tensor_for(args.available_buffer) + return True diff --git a/cirq/sim/act_on_state_vector_args_test.py b/cirq/sim/act_on_state_vector_args_test.py new file mode 100644 index 00000000000..878dcc7858d --- /dev/null +++ b/cirq/sim/act_on_state_vector_args_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import cirq + + +def test_decomposed_fallback(): + + class Composite(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _decompose_(self, qubits): + yield cirq.X(*qubits) + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64), + available_buffer=np.empty((2, 2, 2), dtype=np.complex64), + axes=[1], + prng=np.random.RandomState(), + log_of_measurement_results={}) + + cirq.act_on(Composite(), args) + np.testing.assert_allclose( + args.target_tensor, + cirq.one_hot(index=(0, 1, 0), shape=(2, 2, 2), dtype=np.complex64)) + + +def test_cannot_act(): + + class NoDetails: + pass + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64), + available_buffer=np.empty((2, 2, 2), dtype=np.complex64), + axes=[1], + prng=np.random.RandomState(), + log_of_measurement_results={}) + + with pytest.raises(TypeError, match="Failed to act"): + cirq.act_on(NoDetails(), args) diff --git a/cirq/sim/clifford/clifford_simulator.py b/cirq/sim/clifford/clifford_simulator.py index 00f713e64fc..4a43f5786aa 100644 --- a/cirq/sim/clifford/clifford_simulator.py +++ b/cirq/sim/clifford/clifford_simulator.py @@ -61,7 +61,8 @@ def __init__(self, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None): @staticmethod def is_supported_operation(op: 'cirq.Operation') -> bool: """Checks whether given operation can be simulated by this simulator.""" - if protocols.is_measurement(op): return True + # TODO: support more general Pauli measurements + if isinstance(op.gate, cirq.MeasurementGate): return True if isinstance(op, GlobalPhaseOperation): return True if not protocols.has_unitary(op): return False u = cirq.unitary(op) @@ -96,23 +97,26 @@ def _base_iterator(self, circuit: circuits.Circuit, state=CliffordState( qubit_map, initial_state=initial_state)) - else: - state = CliffordState(qubit_map, initial_state=initial_state) - - for moment in circuit: - measurements = collections.defaultdict( - list) # type: Dict[str, List[np.ndarray]] - - for op in moment: - if protocols.has_unitary(op): - state.apply_unitary(op) - elif protocols.is_measurement(op): - key = protocols.measurement_key(op) - measurements[key].extend( - state.perform_measurement(op.qubits, self._prng)) - - yield CliffordSimulatorStepResult(measurements=measurements, - state=state) + return + + state = CliffordState(qubit_map, initial_state=initial_state) + + for moment in circuit: + measurements: Dict[str, List[np.ndarray]] = collections.defaultdict( + list) + + for op in moment: + if isinstance(op.gate, ops.MeasurementGate): + key = protocols.measurement_key(op) + measurements[key].extend( + state.perform_measurement(op.qubits, self._prng)) + elif protocols.has_unitary(op): + state.apply_unitary(op) + else: + raise NotImplementedError(f"Unrecognized operation: {op!r}") + + yield CliffordSimulatorStepResult(measurements=measurements, + state=state) def _simulator_iterator( self, diff --git a/cirq/sim/simulator.py b/cirq/sim/simulator.py index 2225051807e..c56e608fcbc 100644 --- a/cirq/sim/simulator.py +++ b/cirq/sim/simulator.py @@ -578,10 +578,10 @@ def _qubit_map_to_shape(qubit_map: Dict[ops.Qid, int]) -> Tuple[int, ...]: def _verify_unique_measurement_keys(circuit: circuits.Circuit): result = collections.Counter( - protocols.measurement_key(op, default=None) - for op in ops.flatten_op_tree(iter(circuit))) - result[None] = 0 - duplicates = [k for k, v in result.most_common() if v > 1] - if duplicates: - raise ValueError('Measurement key {} repeated'.format( - ",".join(duplicates))) + key for op in ops.flatten_op_tree(iter(circuit)) + for key in protocols.measurement_keys(op)) + if result: + duplicates = [k for k, v in result.most_common() if v > 1] + if duplicates: + raise ValueError('Measurement key {} repeated'.format( + ",".join(duplicates))) diff --git a/cirq/sim/sparse_simulator.py b/cirq/sim/sparse_simulator.py index 899a539a523..6996c4461e0 100644 --- a/cirq/sim/sparse_simulator.py +++ b/cirq/sim/sparse_simulator.py @@ -15,43 +15,26 @@ """A simulator that uses numpy's einsum or sparse matrix operations.""" import collections - -from typing import Dict, Iterator, List, Tuple, Type, TYPE_CHECKING +from typing import Dict, Iterator, List, Type, TYPE_CHECKING, DefaultDict import numpy as np -from cirq import circuits, linalg, ops, protocols, qis, study, value -from cirq.sim import simulator, wave_function, wave_function_simulator +from cirq import circuits, ops, protocols, qis, study, value +from cirq.sim import ( + simulator, + wave_function, + wave_function_simulator, + act_on_state_vector_args, +) if TYPE_CHECKING: import cirq -class _FlipGate(ops.SingleQubitGate): - """A unitary gate that flips the |0> state with another state. - - Used by `Simulator` to reset a qubit. - """ - - def __init__(self, dimension: int, reset_value: int): - assert 0 < reset_value < dimension - self.dimension = dimension - self.reset_value = reset_value - - def _qid_shape_(self) -> Tuple[int, ...]: - return (self.dimension,) - - def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray: - args.available_buffer[..., 0] = args.target_tensor[..., self. - reset_value] - args.available_buffer[..., self. - reset_value] = args.target_tensor[..., 0] - return args.available_buffer - - # Mutable named tuple to hold state and a buffer. -class _StateAndBuffer(): - def __init__(self, state, buffer): +class _StateAndBuffer: + + def __init__(self, state: np.ndarray, buffer: np.ndarray): self.state = state self.buffer = buffer @@ -152,11 +135,9 @@ def __init__(self, self._dtype = dtype self._prng = value.parse_random_state(seed) - def _run( - self, - circuit: circuits.Circuit, - param_resolver: study.ParamResolver, - repetitions: int) -> Dict[str, List[np.ndarray]]: + def _run(self, circuit: circuits.Circuit, + param_resolver: study.ParamResolver, + repetitions: int) -> Dict[str, np.ndarray]: """See definition in `cirq.SimulatesSamples`.""" param_resolver = param_resolver or study.ParamResolver({}) resolved_circuit = protocols.resolve_parameters(circuit, param_resolver) @@ -165,13 +146,12 @@ def _run( def measure_or_mixture(op): return protocols.is_measurement(op) or protocols.has_mixture(op) if circuit.are_all_matches_terminal(measure_or_mixture): - return self._run_sweep_sample(resolved_circuit, repetitions) + return self._run_sweep_terminal_sample(resolved_circuit, + repetitions) return self._run_sweep_repeat(resolved_circuit, repetitions) - def _run_sweep_sample( - self, - circuit: circuits.Circuit, - repetitions: int) -> Dict[str, List[np.ndarray]]: + def _run_sweep_terminal_sample(self, circuit: circuits.Circuit, + repetitions: int) -> Dict[str, np.ndarray]: for step_result in self._base_iterator( circuit=circuit, qubit_order=ops.QubitOrder.DEFAULT, @@ -187,16 +167,16 @@ def _run_sweep_sample( repetitions, seed=self._prng) - def _run_sweep_repeat( - self, - circuit: circuits.Circuit, - repetitions: int) -> Dict[str, List[np.ndarray]]: - measurements = {} # type: Dict[str, List[np.ndarray]] + def _run_sweep_repeat(self, circuit: circuits.Circuit, + repetitions: int) -> Dict[str, np.ndarray]: if repetitions == 0: - for _, op, _ in circuit.findall_operations_with_gate_type( - ops.MeasurementGate): - measurements[protocols.measurement_key(op)] = np.empty([0, 1]) + return { + key: np.empty(shape=[0, 1]) + for key in protocols.measurement_keys(circuit) + } + measurements: DefaultDict[str, List[ + np.ndarray]] = collections.defaultdict(list) for _ in range(repetitions): all_step_results = self._base_iterator( circuit, @@ -205,8 +185,6 @@ def _run_sweep_repeat( for step_result in all_step_results: for k, v in step_result.measurements.items(): - if not k in measurements: - measurements[k] = [] measurements[k].append(np.array(v, dtype=np.uint8)) return {k: np.array(v) for k, v in measurements.items()} @@ -240,7 +218,7 @@ def _base_iterator( qubit_order: ops.QubitOrderOrList, initial_state: 'cirq.STATE_VECTOR_LIKE', perform_measurements: bool = True, - ) -> Iterator: + ) -> Iterator['SparseSimulatorStep']: qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for( circuit.all_qubits()) num_qubits = len(qubits) @@ -253,120 +231,27 @@ def _base_iterator( if len(circuit) == 0: yield SparseSimulatorStep(state, {}, qubit_map, self._dtype) - def on_stuck(bad_op: ops.Operation): - return TypeError( - "Can't simulate unknown operations that don't specify a " - "_unitary_ method, a _decompose_ method, " - "(_has_unitary_ + _apply_unitary_) methods," - "(_has_mixture_ + _mixture_) methods, or are measurements." - ": {!r}".format(bad_op)) - - def keep(potential_op: ops.Operation) -> bool: - # The order of this is optimized to call has_xxx methods first. - return (protocols.has_unitary(potential_op) or - protocols.has_mixture(potential_op) or - protocols.is_measurement(potential_op) or - isinstance(potential_op.gate, ops.ResetChannel)) - - data = _StateAndBuffer(state=np.reshape(state, qid_shape), - buffer=np.empty(qid_shape, dtype=self._dtype)) + sim_state = act_on_state_vector_args.ActOnStateVectorArgs( + target_tensor=np.reshape(state, qid_shape), + available_buffer=np.empty(qid_shape, dtype=self._dtype), + axes=[], + prng=self._prng, + log_of_measurement_results={}) + for moment in circuit: - measurements = collections.defaultdict( - list) # type: Dict[str, List[int]] - - unitary_ops_and_measurements = protocols.decompose( - moment, keep=keep, on_stuck_raise=on_stuck) - - for op in unitary_ops_and_measurements: - indices = [qubit_map[qubit] for qubit in op.qubits] - if isinstance(op.gate, ops.ResetChannel): - self._simulate_reset(op, data, indices) - elif protocols.has_unitary(op): - self._simulate_unitary(op, data, indices) - elif protocols.is_measurement(op): - # Do measurements second, since there may be mixtures that - # operate as measurements. - # TODO: support measurement outside the computational basis. - # Github issue: - # https://github.com/quantumlib/Cirq/issues/1357 - if perform_measurements: - self._simulate_measurement(op, data, indices, - measurements, num_qubits) - elif protocols.has_mixture(op): - self._simulate_mixture(op, data, indices) - - yield SparseSimulatorStep( - state_vector=data.state, - measurements=measurements, - qubit_map=qubit_map, - dtype=self._dtype) - - def _simulate_unitary(self, op: ops.Operation, data: _StateAndBuffer, - indices: List[int]) -> None: - """Simulate an op that has a unitary.""" - result = protocols.apply_unitary( - op, - args=protocols.ApplyUnitaryArgs( - data.state, - data.buffer, - indices)) - if result is data.buffer: - data.buffer = data.state - data.state = result - - def _simulate_reset(self, op: ops.Operation, data: _StateAndBuffer, - indices: List[int]) -> None: - """Simulate an op that is a reset to the |0> state.""" - if isinstance(op.gate, ops.ResetChannel): - reset = op.gate - # Do a silent measurement. - bits, _ = wave_function.measure_state_vector( - data.state, indices, out=data.state, qid_shape=data.state.shape) - # Apply bit flip(s) to change the reset the bits to 0. - for b, i, d in zip(bits, indices, protocols.qid_shape(reset)): - if b == 0: - continue # Already zero, no reset needed - reset_unitary = _FlipGate(d, reset_value=b)(*op.qubits) - self._simulate_unitary(reset_unitary, data, [i]) - - def _simulate_measurement(self, op: ops.Operation, data: _StateAndBuffer, - indices: List[int], - measurements: Dict[str, List[int]], - num_qubits: int) -> None: - """Simulate an op that is a measurement in the computational basis.""" - # TODO: support measurement outside computational basis. - # Github issue: https://github.com/quantumlib/Cirq/issues/1357 - if isinstance(op.gate, ops.MeasurementGate): - meas = op.gate - invert_mask = meas.full_invert_mask() - # Measure updates inline. - bits, _ = wave_function.measure_state_vector( - data.state, - indices, - out=data.state, - qid_shape=data.state.shape, - seed=self._prng) - corrected = [ - bit ^ (bit < 2 and mask) - for bit, mask in zip(bits, invert_mask) - ] - key = protocols.measurement_key(meas) - measurements[key].extend(corrected) - - def _simulate_mixture(self, op: ops.Operation, data: _StateAndBuffer, - indices: List[int]) -> None: - """Simulate an op that is a mixtures of unitaries.""" - probs, unitaries = zip(*protocols.mixture(op)) - # We work around numpy barfing on choosing from a list of - # numpy arrays (which is not `one-dimensional`) by selecting - # the index of the unitary. - index = self._prng.choice(range(len(unitaries)), p=probs) - shape = protocols.qid_shape(op) * 2 - unitary = unitaries[index].astype(self._dtype).reshape(shape) - result = linalg.targeted_left_multiply(unitary, data.state, indices, - out=data.buffer) - data.buffer = data.state - data.state = result + for op in moment: + if perform_measurements or not isinstance( + op.gate, ops.MeasurementGate): + sim_state.axes = tuple( + qubit_map[qubit] for qubit in op.qubits) + protocols.act_on(op, sim_state) + + yield SparseSimulatorStep(state_vector=sim_state.target_tensor, + measurements=dict( + sim_state.log_of_measurement_results), + qubit_map=qubit_map, + dtype=self._dtype) + sim_state.log_of_measurement_results.clear() def _check_all_resolved(self, circuit): """Raises if the circuit contains unresolved symbols.""" diff --git a/cirq/sim/sparse_simulator_test.py b/cirq/sim/sparse_simulator_test.py index 2402734646f..a6bdf91241d 100644 --- a/cirq/sim/sparse_simulator_test.py +++ b/cirq/sim/sparse_simulator_test.py @@ -104,6 +104,13 @@ def test_run_measure_at_end_no_repetitions(dtype): assert mock_sim.call_count == 4 +def test_run_repetitions_terminal_measurement_stochastic(): + q = cirq.LineQubit(0) + c = cirq.Circuit(cirq.H(q), cirq.measure(q, key='q')) + results = cirq.Simulator().run(c, repetitions=10000) + assert 1000 <= sum(v[0] for v in results.measurements['q']) < 9000 + + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) def test_run_repetitions_measure_at_end(dtype): q0, q1 = cirq.LineQubit.range(2) diff --git a/cirq/sim/wave_function.py b/cirq/sim/wave_function.py index b4c2e590b6a..ec0b7dbb065 100644 --- a/cirq/sim/wave_function.py +++ b/cirq/sim/wave_function.py @@ -13,13 +13,7 @@ # limitations under the License. """Helpers for handling quantum wavefunctions.""" -from typing import ( - Dict, - List, - Optional, - Tuple, - TYPE_CHECKING, -) +from typing import (Dict, List, Optional, Tuple, TYPE_CHECKING, Sequence) import abc import numpy as np @@ -248,7 +242,7 @@ def sample_state_vector( def measure_state_vector( state: np.ndarray, - indices: List[int], + indices: Sequence[int], *, # Force keyword args qid_shape: Optional[Tuple[int, ...]] = None, out: np.ndarray = None, @@ -337,7 +331,7 @@ def measure_state_vector( return measurement_bits, out -def _probs(state: np.ndarray, indices: List[int], +def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray: """Returns the probabilities for a measurement on the given indices.""" tensor = np.reshape(state, qid_shape) diff --git a/docs/api.rst b/docs/api.rst index ff2ae7b04ef..71a5259c6a9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -240,6 +240,7 @@ results. cirq.validate_mixture cirq.validate_probability cirq.xeb_fidelity + cirq.ActOnStateVectorArgs cirq.CircuitSampleJob cirq.CliffordSimulator cirq.CliffordSimulatorStepResult @@ -308,6 +309,7 @@ the magic methods that can be implemented. :toctree: generated/ cirq.DEFAULT_RESOLVERS + cirq.act_on cirq.apply_channel cirq.apply_mixture cirq.apply_unitaries @@ -355,6 +357,7 @@ the magic methods that can be implemented. cirq.QasmOutput cirq.QuilFormatter cirq.QuilOutput + cirq.SupportsActOn cirq.SupportsApplyChannel cirq.SupportsApplyMixture cirq.SupportsApproximateEquality