Skip to content

Commit

Permalink
Move input validation out of _inverse_pattern and _get_ordered_swap
Browse files Browse the repository at this point in the history
  • Loading branch information
jpacold committed Jun 7, 2024
1 parent b0992b3 commit 3647b16
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
12 changes: 10 additions & 2 deletions crates/accelerate/src/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,20 @@ fn get_ordered_swap(pattern: &ArrayView1<i64>) -> Vec<(i64, i64)> {
swaps
}

/// Checks whether an array of size N is a permutation of 0, 1, ..., N - 1.
#[pyfunction]
#[pyo3(signature = (pattern))]
fn _validate_permutation(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = pattern.as_array();
validate_permutation(&view)?;
Ok(py.None())
}

/// Finds inverse of a permutation pattern.
#[pyfunction]
#[pyo3(signature = (pattern))]
fn _inverse_pattern(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = pattern.as_array();
validate_permutation(&view)?;
let inverse_i64: Vec<i64> = invert(&view).iter().map(|&x| x as i64).collect();
Ok(inverse_i64.to_object(py))
}
Expand All @@ -100,12 +108,12 @@ fn _inverse_pattern(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject
#[pyo3(signature = (permutation_in))]
fn _get_ordered_swap(py: Python, permutation_in: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = permutation_in.as_array();
validate_permutation(&view)?;
Ok(get_ordered_swap(&view).to_object(py))
}

#[pymodule]
pub fn permutation(m: &Bound<PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(_validate_permutation, m)?)?;
m.add_function(wrap_pyfunction!(_inverse_pattern, m)?)?;
m.add_function(wrap_pyfunction!(_get_ordered_swap, m)?)?;
Ok(())
Expand Down
6 changes: 5 additions & 1 deletion qiskit/synthesis/permutation/permutation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
"""Utility functions for handling permutations."""

# pylint: disable=unused-import
from qiskit._accelerate.permutation import _inverse_pattern, _get_ordered_swap
from qiskit._accelerate.permutation import (
_inverse_pattern,
_get_ordered_swap,
_validate_permutation,
)


def _pattern_to_cycles(pattern):
Expand Down
14 changes: 9 additions & 5 deletions test/python/synthesis/test_permutation_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
synth_permutation_basic,
synth_permutation_reverse_lnn_kms,
)
from qiskit.synthesis.permutation.permutation_utils import _inverse_pattern, _get_ordered_swap
from qiskit.synthesis.permutation.permutation_utils import (
_inverse_pattern,
_get_ordered_swap,
_validate_permutation,
)
from test import QiskitTestCase # pylint: disable=wrong-import-order


Expand Down Expand Up @@ -58,7 +62,7 @@ def test_get_ordered_swap(self, width):

@data(10, 20)
def test_invalid_permutations(self, width):
"""Check that synth_permutation_basic raises exceptions when the
"""Check that _validate_permutation raises exceptions when the
input is not a permutation."""
np.random.seed(1)
for _ in range(5):
Expand All @@ -67,21 +71,21 @@ def test_invalid_permutations(self, width):
pattern_out_of_range = np.copy(pattern)
pattern_out_of_range[0] = -1
with self.assertRaises(ValueError) as exc:
_ = synth_permutation_basic(pattern_out_of_range)
_validate_permutation(pattern_out_of_range)
self.assertIn("input contains a negative number", str(exc.exception))

pattern_out_of_range = np.copy(pattern)
pattern_out_of_range[0] = width
with self.assertRaises(ValueError) as exc:
_ = synth_permutation_basic(pattern_out_of_range)
_validate_permutation(pattern_out_of_range)
self.assertIn(
"input has length {0} and contains {0}".format(width), str(exc.exception)
)

pattern_duplicate = np.copy(pattern)
pattern_duplicate[-1] = pattern[0]
with self.assertRaises(ValueError) as exc:
_ = synth_permutation_basic(pattern_duplicate)
_validate_permutation(pattern_duplicate)
self.assertIn(
"input contains {} more than once".format(pattern[0]), str(exc.exception)
)
Expand Down

0 comments on commit 3647b16

Please sign in to comment.