Skip to content

Commit

Permalink
Move utility functions _inverse_pattern and _get_ordered_swap to Rust (
Browse files Browse the repository at this point in the history
…#12327)

* Move utility functions _inverse_pattern and _get_ordered_swap to Rust

* fix formatting and pylint issues

* Changed input type to `PyArrayLike1<i64, AllowTypeChange>`

* Refactor `permutation.rs`, clean up imports, fix coverage error

* fix docstring for `_inverse_pattern`

Co-authored-by: Raynel Sanchez <[email protected]>

* fix docstring for `_get_ordered_swap`

Co-authored-by: Raynel Sanchez <[email protected]>

* remove pymodule nesting

* remove explicit `AllowTypeChange`

* Move input validation out of `_inverse_pattern` and `_get_ordered_swap`

---------

Co-authored-by: Raynel Sanchez <[email protected]>
  • Loading branch information
jpacold and raynelfss authored Jun 10, 2024
1 parent b2c3ffd commit 1956220
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 36 deletions.
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod isometry;
pub mod nlayout;
pub mod optimize_1q_gates;
pub mod pauli_exp_val;
pub mod permutation;
pub mod results;
pub mod sabre;
pub mod sampled_exp_val;
Expand Down
120 changes: 120 additions & 0 deletions crates/accelerate/src/permutation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use ndarray::{Array1, ArrayView1};
use numpy::PyArrayLike1;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::vec::Vec;

fn validate_permutation(pattern: &ArrayView1<i64>) -> PyResult<()> {
let n = pattern.len();
let mut seen: Vec<bool> = vec![false; n];

for &x in pattern {
if x < 0 {
return Err(PyValueError::new_err(
"Invalid permutation: input contains a negative number.",
));
}

if x as usize >= n {
return Err(PyValueError::new_err(format!(
"Invalid permutation: input has length {} and contains {}.",
n, x
)));
}

if seen[x as usize] {
return Err(PyValueError::new_err(format!(
"Invalid permutation: input contains {} more than once.",
x
)));
}

seen[x as usize] = true;
}

Ok(())
}

fn invert(pattern: &ArrayView1<i64>) -> Array1<usize> {
let mut inverse: Array1<usize> = Array1::zeros(pattern.len());
pattern.iter().enumerate().for_each(|(ii, &jj)| {
inverse[jj as usize] = ii;
});
inverse
}

fn get_ordered_swap(pattern: &ArrayView1<i64>) -> Vec<(i64, i64)> {
let mut permutation: Vec<usize> = pattern.iter().map(|&x| x as usize).collect();
let mut index_map = invert(pattern);

let n = permutation.len();
let mut swaps: Vec<(i64, i64)> = Vec::with_capacity(n);
for ii in 0..n {
let val = permutation[ii];
if val == ii {
continue;
}
let jj = index_map[ii];
swaps.push((ii as i64, jj as i64));
(permutation[ii], permutation[jj]) = (permutation[jj], permutation[ii]);
index_map[val] = jj;
index_map[ii] = ii;
}

swaps[..].reverse();
swaps
}

/// Checks whether an array of size N is a permutation of 0, 1, ..., N - 1.
#[pyfunction]
#[pyo3(signature = (pattern))]
fn _validate_permutation(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = pattern.as_array();
validate_permutation(&view)?;
Ok(py.None())
}

/// Finds inverse of a permutation pattern.
#[pyfunction]
#[pyo3(signature = (pattern))]
fn _inverse_pattern(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = pattern.as_array();
let inverse_i64: Vec<i64> = invert(&view).iter().map(|&x| x as i64).collect();
Ok(inverse_i64.to_object(py))
}

/// Sorts the input permutation by iterating through the permutation list
/// and putting each element to its correct position via a SWAP (if it's not
/// at the correct position already). If ``n`` is the length of the input
/// permutation, this requires at most ``n`` SWAPs.
///
/// More precisely, if the input permutation is a cycle of length ``m``,
/// then this creates a quantum circuit with ``m-1`` SWAPs (and of depth ``m-1``);
/// if the input permutation consists of several disjoint cycles, then each cycle
/// is essentially treated independently.
#[pyfunction]
#[pyo3(signature = (permutation_in))]
fn _get_ordered_swap(py: Python, permutation_in: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = permutation_in.as_array();
Ok(get_ordered_swap(&view).to_object(py))
}

#[pymodule]
pub fn permutation(m: &Bound<PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(_validate_permutation, m)?)?;
m.add_function(wrap_pyfunction!(_inverse_pattern, m)?)?;
m.add_function(wrap_pyfunction!(_get_ordered_swap, m)?)?;
Ok(())
}
9 changes: 5 additions & 4 deletions crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ use qiskit_accelerate::{
convert_2q_block_matrix::convert_2q_block_matrix, dense_layout::dense_layout,
error_map::error_map, euler_one_qubit_decomposer::euler_one_qubit_decomposer,
isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates,
pauli_exp_val::pauli_expval, results::results, sabre::sabre, sampled_exp_val::sampled_exp_val,
sparse_pauli_op::sparse_pauli_op, stochastic_swap::stochastic_swap,
two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate, utils::utils,
vf2_layout::vf2_layout,
pauli_exp_val::pauli_expval, permutation::permutation, results::results, sabre::sabre,
sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
stochastic_swap::stochastic_swap, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate,
utils::utils, vf2_layout::vf2_layout,
};

#[pymodule]
Expand All @@ -36,6 +36,7 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(nlayout))?;
m.add_wrapped(wrap_pymodule!(optimize_1q_gates))?;
m.add_wrapped(wrap_pymodule!(pauli_expval))?;
m.add_wrapped(wrap_pymodule!(permutation))?;
m.add_wrapped(wrap_pymodule!(results))?;
m.add_wrapped(wrap_pymodule!(sabre))?;
m.add_wrapped(wrap_pymodule!(sampled_exp_val))?;
Expand Down
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
sys.modules["qiskit._accelerate.stochastic_swap"] = qiskit._accelerate.stochastic_swap
sys.modules["qiskit._accelerate.two_qubit_decompose"] = qiskit._accelerate.two_qubit_decompose
sys.modules["qiskit._accelerate.vf2_layout"] = qiskit._accelerate.vf2_layout
sys.modules["qiskit._accelerate.permutation"] = qiskit._accelerate.permutation

from qiskit.exceptions import QiskitError, MissingOptionalLibraryError

Expand Down
36 changes: 6 additions & 30 deletions qiskit/synthesis/permutation/permutation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,12 @@

"""Utility functions for handling permutations."""


def _get_ordered_swap(permutation_in):
"""Sorts the input permutation by iterating through the permutation list
and putting each element to its correct position via a SWAP (if it's not
at the correct position already). If ``n`` is the length of the input
permutation, this requires at most ``n`` SWAPs.
More precisely, if the input permutation is a cycle of length ``m``,
then this creates a quantum circuit with ``m-1`` SWAPs (and of depth ``m-1``);
if the input permutation consists of several disjoint cycles, then each cycle
is essentially treated independently.
"""
permutation = list(permutation_in[:])
swap_list = []
index_map = _inverse_pattern(permutation_in)
for i, val in enumerate(permutation):
if val != i:
j = index_map[i]
swap_list.append((i, j))
permutation[i], permutation[j] = permutation[j], permutation[i]
index_map[val] = j
index_map[i] = i
swap_list.reverse()
return swap_list


def _inverse_pattern(pattern):
"""Finds inverse of a permutation pattern."""
b_map = {pos: idx for idx, pos in enumerate(pattern)}
return [b_map[pos] for pos in range(len(pattern))]
# pylint: disable=unused-import
from qiskit._accelerate.permutation import (
_inverse_pattern,
_get_ordered_swap,
_validate_permutation,
)


def _pattern_to_cycles(pattern):
Expand Down
48 changes: 46 additions & 2 deletions test/python/synthesis/test_permutation_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,31 @@
synth_permutation_basic,
synth_permutation_reverse_lnn_kms,
)
from qiskit.synthesis.permutation.permutation_utils import _get_ordered_swap
from qiskit.synthesis.permutation.permutation_utils import (
_inverse_pattern,
_get_ordered_swap,
_validate_permutation,
)
from test import QiskitTestCase # pylint: disable=wrong-import-order


@ddt
class TestPermutationSynthesis(QiskitTestCase):
"""Test the permutation synthesis functions."""

@data(4, 5, 10, 15, 20)
def test_inverse_pattern(self, width):
"""Test _inverse_pattern function produces correct index map."""
np.random.seed(1)
for _ in range(5):
pattern = np.random.permutation(width)
inverse = _inverse_pattern(pattern)
for ii, jj in enumerate(pattern):
self.assertTrue(inverse[jj] == ii)

@data(4, 5, 10, 15, 20)
def test_get_ordered_swap(self, width):
"""Test get_ordered_swap function produces correct swap list."""
"""Test _get_ordered_swap function produces correct swap list."""
np.random.seed(1)
for _ in range(5):
pattern = np.random.permutation(width)
Expand All @@ -46,6 +60,36 @@ def test_get_ordered_swap(self, width):
self.assertTrue(np.array_equal(pattern, output))
self.assertLess(len(swap_list), width)

@data(10, 20)
def test_invalid_permutations(self, width):
"""Check that _validate_permutation raises exceptions when the
input is not a permutation."""
np.random.seed(1)
for _ in range(5):
pattern = np.random.permutation(width)

pattern_out_of_range = np.copy(pattern)
pattern_out_of_range[0] = -1
with self.assertRaises(ValueError) as exc:
_validate_permutation(pattern_out_of_range)
self.assertIn("input contains a negative number", str(exc.exception))

pattern_out_of_range = np.copy(pattern)
pattern_out_of_range[0] = width
with self.assertRaises(ValueError) as exc:
_validate_permutation(pattern_out_of_range)
self.assertIn(
"input has length {0} and contains {0}".format(width), str(exc.exception)
)

pattern_duplicate = np.copy(pattern)
pattern_duplicate[-1] = pattern[0]
with self.assertRaises(ValueError) as exc:
_validate_permutation(pattern_duplicate)
self.assertIn(
"input contains {} more than once".format(pattern[0]), str(exc.exception)
)

@data(4, 5, 10, 15, 20)
def test_synth_permutation_basic(self, width):
"""Test synth_permutation_basic function produces the correct
Expand Down

0 comments on commit 1956220

Please sign in to comment.