Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid Python op creation in commutative cancellation #12701

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/circuit/src/circuit_instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,12 @@ impl CircuitInstruction {
.and_then(|attrs| attrs.unit.as_deref())
}

pub fn is_parameterized(&self) -> bool {
self.params
.iter()
.any(|x| matches!(x, Param::ParameterExpression(_)))
}

/// Creates a shallow copy with the given fields replaced.
///
/// Returns:
Expand Down
30 changes: 30 additions & 0 deletions crates/circuit/src/dag_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::circuit_instruction::{
convert_py_to_operation_type, operation_type_to_py, CircuitInstruction,
ExtraInstructionAttributes,
};
use crate::imports::QUANTUM_CIRCUIT;
use crate::operations::Operation;
use numpy::IntoPyArray;
use pyo3::prelude::*;
Expand Down Expand Up @@ -228,6 +229,16 @@ impl DAGOpNode {
Ok(())
}

#[getter]
fn num_qubits(&self) -> u32 {
self.instruction.operation.num_qubits()
}

#[getter]
fn num_clbits(&self) -> u32 {
self.instruction.operation.num_clbits()
}

#[getter]
fn get_qargs(&self, py: Python) -> Py<PyTuple> {
self.instruction.qubits.clone_ref(py)
Expand Down Expand Up @@ -259,6 +270,10 @@ impl DAGOpNode {
self.instruction.params.to_object(py)
}

pub fn is_parameterized(&self) -> bool {
self.instruction.is_parameterized()
}

#[getter]
fn matrix(&self, py: Python) -> Option<PyObject> {
let matrix = self.instruction.operation.matrix(&self.instruction.params);
Expand Down Expand Up @@ -325,6 +340,21 @@ impl DAGOpNode {
}
}

#[getter]
fn definition<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
let definition = self
.instruction
.operation
.definition(&self.instruction.params);
definition
.map(|data| {
QUANTUM_CIRCUIT
.get_bound(py)
.call_method1(intern!(py, "_from_circuit_data"), (data,))
})
.transpose()
}

/// Sets the Instruction name corresponding to the op for this node
#[setter]
fn set_name(&mut self, py: Python, new_name: PyObject) -> PyResult<()> {
Expand Down
34 changes: 32 additions & 2 deletions crates/circuit/src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,38 @@ impl Operation for StandardGate {
.expect("Unexpected Qiskit python bug"),
)
}),
Self::RXGate => todo!("Add when we have R"),
Self::RYGate => todo!("Add when we have R"),
Self::RXGate => Python::with_gil(|py| -> Option<CircuitData> {
let theta = &params[0];
Some(
CircuitData::from_standard_gates(
py,
1,
[(
Self::RGate,
smallvec![theta.clone(), FLOAT_ZERO],
smallvec![Qubit(0)],
)],
FLOAT_ZERO,
)
.expect("Unexpected Qiskit python bug"),
)
}),
Self::RYGate => Python::with_gil(|py| -> Option<CircuitData> {
let theta = &params[0];
Some(
CircuitData::from_standard_gates(
py,
1,
[(
Self::RGate,
smallvec![theta.clone(), Param::Float(PI / 2.0)],
smallvec![Qubit(0)],
)],
FLOAT_ZERO,
)
.expect("Unexpected Qiskit python bug"),
)
}),
Self::RZGate => Python::with_gil(|py| -> Option<CircuitData> {
let theta = &params[0];
Some(
Expand Down
35 changes: 34 additions & 1 deletion qiskit/circuit/commutation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from qiskit.circuit.operation import Operation
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit.quantum_info.operators import Operator
from qiskit._accelerate.circuit import StandardGate

_skipped_op_names = {"measure", "reset", "delay", "initialize"}
_no_cache_op_names = {"annotated"}
Expand Down Expand Up @@ -57,6 +58,23 @@ def __init__(self, standard_gate_commutations: dict = None, cache_max_entries: i
self._cache_miss = 0
self._cache_hit = 0

def commute_nodes(
self,
op1,
op2,
max_num_qubits: int = 3,
) -> bool:
"""Checks if two DAGOpNodes commute."""
qargs1 = op1.qargs
cargs1 = op2.cargs
if not isinstance(op1._raw_op, StandardGate):
op1 = op1.op
qargs2 = op2.qargs
cargs2 = op2.cargs
if not isinstance(op2._raw_op, StandardGate):
op2 = op2.op
return self.commute(op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits)

def commute(
self,
op1: Operation,
Expand Down Expand Up @@ -255,9 +273,15 @@ def is_commutation_skipped(op, qargs, max_num_qubits):
if getattr(op, "is_parameterized", False) and op.is_parameterized():
return True

from qiskit.dagcircuit.dagnode import DAGOpNode

# we can proceed if op has defined: to_operator, to_matrix and __array__, or if its definition can be
# recursively resolved by operations that have a matrix. We check this by constructing an Operator.
if (hasattr(op, "to_matrix") and hasattr(op, "__array__")) or hasattr(op, "to_operator"):
if (
isinstance(op, DAGOpNode)
or (hasattr(op, "to_matrix") and hasattr(op, "__array__"))
or hasattr(op, "to_operator")
):
return False

return False
Expand Down Expand Up @@ -409,6 +433,15 @@ def _commute_matmul(
first_qarg = tuple(qarg[q] for q in first_qargs)
second_qarg = tuple(qarg[q] for q in second_qargs)

from qiskit.dagcircuit.dagnode import DAGOpNode

# If we have a DAGOpNode here we've received a StandardGate definition from
# rust and we can manually pull the matrix to use for the Operators
if isinstance(first_ops, DAGOpNode):
first_ops = first_ops.matrix
if isinstance(second_op, DAGOpNode):
second_op = second_op.matrix

# try to generate an Operator out of op, if this succeeds we can determine commutativity, otherwise
# return false
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,7 @@ def run(self, dag):
does_commute = (
isinstance(current_gate, DAGOpNode)
and isinstance(prev_gate, DAGOpNode)
and self.comm_checker.commute(
current_gate.op,
current_gate.qargs,
current_gate.cargs,
prev_gate.op,
prev_gate.qargs,
prev_gate.cargs,
)
and self.comm_checker.commute_nodes(current_gate, prev_gate)
)
if not does_commute:
break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from qiskit.circuit.library.standard_gates.rx import RXGate
from qiskit.circuit.library.standard_gates.p import PhaseGate
from qiskit.circuit.library.standard_gates.rz import RZGate
from qiskit.circuit import ControlFlowOp
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES


_CUTOFF_PRECISION = 1e-5
Expand Down Expand Up @@ -138,14 +138,14 @@ def run(self, dag):
total_phase = 0.0
for current_node in run:
if (
getattr(current_node.op, "condition", None) is not None
current_node.condition is not None
or len(current_node.qargs) != 1
or current_node.qargs[0] != run_qarg
):
raise RuntimeError("internal error")

if current_node.name in ["p", "u1", "rz", "rx"]:
current_angle = float(current_node.op.params[0])
current_angle = float(current_node.params[0])
elif current_node.name in ["z", "x"]:
current_angle = np.pi
elif current_node.name == "t":
Expand All @@ -159,8 +159,8 @@ def run(self, dag):

# Compose gates
total_angle = current_angle + total_angle
if current_node.op.definition:
total_phase += current_node.op.definition.global_phase
if current_node.definition:
total_phase += current_node.definition.global_phase

# Replace the data of the first node in the run
if cancel_set_key[0] == "z_rotation":
Expand Down Expand Up @@ -200,7 +200,9 @@ def _handle_control_flow_ops(self, dag):
"""

pass_manager = PassManager([CommutationAnalysis(), self])
for node in dag.op_nodes(ControlFlowOp):
for node in dag.op_nodes():
if node.name not in CONTROL_FLOW_OP_NAMES:
continue
mapped_blocks = []
for block in node.op.blocks:
new_circ = pass_manager.run(block)
Expand Down
Loading