From c8d6166ffca2c41cbb21736629941f70ef4b714f Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Tue, 12 Jul 2022 15:52:40 -0700 Subject: [PATCH] GridDevice: Exclude MeasurementGates in validation of qubit pairs (#5654) I'm comfortable with special-casing MeasurementGate because it's the only gate today with the property that it can be applied to any subset of qubits. Fixes #5652 @maffoo --- .../cirq_google/devices/grid_device.py | 4 ++ .../cirq_google/devices/grid_device_test.py | 59 +++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index e920f6f981e..cf0c1f2d6be 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -37,6 +37,9 @@ MEASUREMENT_GATE_FAMILY = cirq.GateFamily(cirq.MeasurementGate) WAIT_GATE_FAMILY = cirq.GateFamily(cirq.WaitGate) +# Families of gates which can be applied to any subset of valid qubits. +_VARIADIC_GATE_FAMILIES = [MEASUREMENT_GATE_FAMILY, WAIT_GATE_FAMILY] + def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: """Raises a ValueError if the `DeviceSpecification` proto is invalid.""" @@ -338,6 +341,7 @@ def validate_operation(self, operation: cirq.Operation) -> None: if ( len(operation.qubits) == 2 + and not any(operation in gf for gf in _VARIADIC_GATE_FAMILIES) and frozenset(operation.qubits) not in self._metadata.qubit_pairs ): raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}.') diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index c6d019f9482..4f1c88d8db2 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -252,15 +252,74 @@ def test_grid_device_from_proto(): def test_grid_device_validate_operations_positive(): device_info, spec = _create_device_spec_with_horizontal_couplings() device = cirq_google.GridDevice.from_proto(spec) + # Gates that can be applied to any subset of valid qubits + variadic_gates = [cirq.measure, cirq.WaitGate(cirq.Duration(nanos=1), num_qubits=2)] for q in device_info.grid_qubits: device.validate_operation(cirq.X(q)) + device.validate_operation(cirq.measure(q)) # horizontal qubit pairs for i in range(GRID_HEIGHT): device.validate_operation( cirq.CZ(device_info.grid_qubits[2 * i], device_info.grid_qubits[2 * i + 1]) ) + for gate in variadic_gates: + device.validate_operation( + gate(device_info.grid_qubits[2 * i], device_info.grid_qubits[2 * i + 1]) + ) + + +@pytest.mark.parametrize( + 'gate_func', + [ + lambda _: cirq.measure, + lambda num_qubits: cirq.WaitGate(cirq.Duration(nanos=1), num_qubits=num_qubits), + ], +) +def test_grid_device_validate_operations_variadic_gates_positive(gate_func): + device_info, spec = _create_device_spec_with_horizontal_couplings() + device = cirq_google.GridDevice.from_proto(spec) + + # Single qubit operations + for q in device_info.grid_qubits: + device.validate_operation(gate_func(1)(q)) + + # horizontal qubit pairs (coupled) + for i in range(GRID_HEIGHT): + device.validate_operation( + gate_func(2)(device_info.grid_qubits[2 * i], device_info.grid_qubits[2 * i + 1]) + ) + + # Variadic gates across vertical qubit pairs (uncoupled pairs) should succeed. + for i in range(GRID_HEIGHT - 1): + device.validate_operation( + gate_func(2)(device_info.grid_qubits[2 * i], device_info.grid_qubits[2 * (i + 1)]) + ) + device.validate_operation( + gate_func(2)( + device_info.grid_qubits[2 * i + 1], device_info.grid_qubits[2 * (i + 1) + 1] + ) + ) + + # 3-qubit measurements + for i in range(GRID_HEIGHT - 2): + device.validate_operation( + gate_func(3)( + device_info.grid_qubits[2 * i], + device_info.grid_qubits[2 * (i + 1)], + device_info.grid_qubits[2 * (i + 2)], + ) + ) + device.validate_operation( + gate_func(3)( + device_info.grid_qubits[2 * i + 1], + device_info.grid_qubits[2 * (i + 1) + 1], + device_info.grid_qubits[2 * (i + 2) + 1], + ) + ) + # All-qubit measurement + device.validate_operation(gate_func(len(device_info.grid_qubits))(*device_info.grid_qubits)) def test_grid_device_validate_operations_negative():