Skip to content

Commit

Permalink
Add unitary gate to rust
Browse files Browse the repository at this point in the history
  • Loading branch information
mtreinish committed Jan 29, 2025
1 parent 9f6faf5 commit 223cea7
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 16 deletions.
74 changes: 73 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ indexmap.version = "2.7.1"
hashbrown.version = "0.14.5"
num-bigint = "0.4"
num-complex = "0.4"
nalgebra = "0.33"
ndarray = "0.15"
numpy = "0.23"
smallvec = "1.13"
Expand Down
3 changes: 3 additions & 0 deletions crates/accelerate/src/target_transpiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,9 @@ impl Target {
OperationRef::Gate(gate) => gate.gate.clone_ref(py),
OperationRef::Instruction(instruction) => instruction.instruction.clone_ref(py),
OperationRef::Operation(operation) => operation.operation.clone_ref(py),
OperationRef::Unitary(unitary) => unitary
.create_py_op(py, &ExtraInstructionAttributes::default())?
.into_any(),
},
TargetOperation::Variadic(op_cls) => op_cls.clone_ref(py),
};
Expand Down
6 changes: 5 additions & 1 deletion crates/circuit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ bytemuck.workspace = true
bitfield-struct.workspace = true
num-complex.workspace = true
ndarray.workspace = true
numpy.workspace = true
thiserror.workspace = true
approx.workspace = true
itertools.workspace = true
nalgebra.workspace = true

[dependencies.pyo3]
workspace = true
Expand All @@ -41,6 +41,10 @@ features = ["rayon"]
workspace = true
features = ["union"]

[dependencies.numpy]
workspace = true
features = ["nalgebra"]

[features]
cache_pygates = []

Expand Down
49 changes: 46 additions & 3 deletions crates/circuit/src/circuit_instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@
#[cfg(feature = "cache_pygates")]
use std::sync::OnceLock;

use numpy::{IntoPyArray, PyArray2};
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
use pyo3::basic::CompareOp;
use pyo3::exceptions::{PyDeprecationWarning, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyList, PyString, PyTuple, PyType};
use pyo3::IntoPyObjectExt;
use pyo3::{intern, PyObject, PyResult};

use nalgebra::{MatrixView2, MatrixView4};
use num_complex::Complex64;
use smallvec::SmallVec;

use crate::imports::{
CONTROLLED_GATE, CONTROL_FLOW_OP, GATE, INSTRUCTION, OPERATION, WARNINGS_WARN,
};
use crate::operations::{
Operation, OperationRef, Param, PyGate, PyInstruction, PyOperation, StandardGate,
StandardInstruction, StandardInstructionType,
ArrayType, Operation, OperationRef, Param, PyGate, PyInstruction, PyOperation, StandardGate,
StandardInstruction, StandardInstructionType, UnitaryGate,
};
use crate::packed_instruction::PackedOperation;

Expand Down Expand Up @@ -341,6 +342,9 @@ impl CircuitInstruction {
OperationRef::Gate(gate) => gate.gate.clone_ref(py),
OperationRef::Instruction(instruction) => instruction.instruction.clone_ref(py),
OperationRef::Operation(operation) => operation.operation.clone_ref(py),
OperationRef::Unitary(unitary) => {
unitary.create_py_op(py, &self.extra_attrs)?.into_any()
}
};

#[cfg(feature = "cache_pygates")]
Expand Down Expand Up @@ -742,6 +746,45 @@ impl<'py> FromPyObject<'py> for OperationFromPython {
});
}

// We need to check by name here to avoid a circular import during initial loading
if ob.getattr(intern!(py, "name"))?.extract::<String>()? == "unitary" {
let params = extract_params()?;
if let Param::Obj(data) = &params[0] {
let py_matrix: PyReadonlyArray2<Complex64> = data.extract(py)?;
let matrix: Option<MatrixView2<Complex64>> = py_matrix.try_as_matrix();
if let Some(x) = matrix {
let unitary_gate = UnitaryGate {
array: ArrayType::OneQ(x.into_owned()),
};
return Ok(OperationFromPython {
operation: PackedOperation::from_unitary(Box::new(unitary_gate)),
params: SmallVec::new(),
extra_attrs: extract_extra()?,
});
}
let matrix: Option<MatrixView4<Complex64>> = py_matrix.try_as_matrix();
if let Some(x) = matrix {
let unitary_gate = UnitaryGate {
array: ArrayType::TwoQ(x.into_owned()),
};
return Ok(OperationFromPython {
operation: PackedOperation::from_unitary(Box::new(unitary_gate)),
params: SmallVec::new(),
extra_attrs: extract_extra()?,
});
} else {
let unitary_gate = UnitaryGate {
array: ArrayType::NDArray(py_matrix.as_array().to_owned()),
};
return Ok(OperationFromPython {
operation: PackedOperation::from_unitary(Box::new(unitary_gate)),
params: SmallVec::new(),
extra_attrs: extract_extra()?,
});
};
}
}

if ob_type.is_subclass(GATE.get_bound(py))? {
let params = extract_params()?;
let gate = Box::new(PyGate {
Expand Down
9 changes: 5 additions & 4 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3295,7 +3295,8 @@ def _format(operand):
py_op.operation.setattr(py, "condition", new_condition)?;
}
OperationRef::StandardGate(_)
| OperationRef::StandardInstruction(_) => {}
| OperationRef::StandardInstruction(_)
| OperationRef::Unitary(_) => {}
}
}
}
Expand Down Expand Up @@ -6245,9 +6246,9 @@ impl DAGCircuit {
};
#[cfg(feature = "cache_pygates")]
let py_op = match new_op.operation.view() {
OperationRef::StandardGate(_) | OperationRef::StandardInstruction(_) => {
OnceLock::new()
}
OperationRef::StandardGate(_)
| OperationRef::StandardInstruction(_)
| OperationRef::Unitary(_) => OnceLock::new(),
OperationRef::Gate(gate) => OnceLock::from(gate.gate.clone_ref(py)),
OperationRef::Instruction(instruction) => {
OnceLock::from(instruction.instruction.clone_ref(py))
Expand Down
Loading

0 comments on commit 223cea7

Please sign in to comment.