diff --git a/crates/accelerate/src/permutation.rs b/crates/accelerate/src/permutation.rs index f80d87882312..31ba433ddd30 100644 --- a/crates/accelerate/src/permutation.rs +++ b/crates/accelerate/src/permutation.rs @@ -77,12 +77,20 @@ fn get_ordered_swap(pattern: &ArrayView1) -> Vec<(i64, i64)> { swaps } +/// Checks whether an array of size N is a permutation of 0, 1, ..., N - 1. +#[pyfunction] +#[pyo3(signature = (pattern))] +fn _validate_permutation(py: Python, pattern: PyArrayLike1) -> PyResult { + let view = pattern.as_array(); + validate_permutation(&view)?; + Ok(py.None()) +} + /// Finds inverse of a permutation pattern. #[pyfunction] #[pyo3(signature = (pattern))] fn _inverse_pattern(py: Python, pattern: PyArrayLike1) -> PyResult { let view = pattern.as_array(); - validate_permutation(&view)?; let inverse_i64: Vec = invert(&view).iter().map(|&x| x as i64).collect(); Ok(inverse_i64.to_object(py)) } @@ -100,12 +108,12 @@ fn _inverse_pattern(py: Python, pattern: PyArrayLike1) -> PyResult) -> PyResult { let view = permutation_in.as_array(); - validate_permutation(&view)?; Ok(get_ordered_swap(&view).to_object(py)) } #[pymodule] pub fn permutation(m: &Bound) -> PyResult<()> { + m.add_function(wrap_pyfunction!(_validate_permutation, m)?)?; m.add_function(wrap_pyfunction!(_inverse_pattern, m)?)?; m.add_function(wrap_pyfunction!(_get_ordered_swap, m)?)?; Ok(()) diff --git a/qiskit/synthesis/permutation/permutation_utils.py b/qiskit/synthesis/permutation/permutation_utils.py index e77068de0976..dbd73bfe8111 100644 --- a/qiskit/synthesis/permutation/permutation_utils.py +++ b/qiskit/synthesis/permutation/permutation_utils.py @@ -13,7 +13,11 @@ """Utility functions for handling permutations.""" # pylint: disable=unused-import -from qiskit._accelerate.permutation import _inverse_pattern, _get_ordered_swap +from qiskit._accelerate.permutation import ( + _inverse_pattern, + _get_ordered_swap, + _validate_permutation, +) def _pattern_to_cycles(pattern): diff --git a/test/python/synthesis/test_permutation_synthesis.py b/test/python/synthesis/test_permutation_synthesis.py index c8a6cb942042..a879d5251f90 100644 --- a/test/python/synthesis/test_permutation_synthesis.py +++ b/test/python/synthesis/test_permutation_synthesis.py @@ -25,7 +25,11 @@ synth_permutation_basic, synth_permutation_reverse_lnn_kms, ) -from qiskit.synthesis.permutation.permutation_utils import _inverse_pattern, _get_ordered_swap +from qiskit.synthesis.permutation.permutation_utils import ( + _inverse_pattern, + _get_ordered_swap, + _validate_permutation, +) from test import QiskitTestCase # pylint: disable=wrong-import-order @@ -58,7 +62,7 @@ def test_get_ordered_swap(self, width): @data(10, 20) def test_invalid_permutations(self, width): - """Check that synth_permutation_basic raises exceptions when the + """Check that _validate_permutation raises exceptions when the input is not a permutation.""" np.random.seed(1) for _ in range(5): @@ -67,13 +71,13 @@ def test_invalid_permutations(self, width): pattern_out_of_range = np.copy(pattern) pattern_out_of_range[0] = -1 with self.assertRaises(ValueError) as exc: - _ = synth_permutation_basic(pattern_out_of_range) + _validate_permutation(pattern_out_of_range) self.assertIn("input contains a negative number", str(exc.exception)) pattern_out_of_range = np.copy(pattern) pattern_out_of_range[0] = width with self.assertRaises(ValueError) as exc: - _ = synth_permutation_basic(pattern_out_of_range) + _validate_permutation(pattern_out_of_range) self.assertIn( "input has length {0} and contains {0}".format(width), str(exc.exception) ) @@ -81,7 +85,7 @@ def test_invalid_permutations(self, width): pattern_duplicate = np.copy(pattern) pattern_duplicate[-1] = pattern[0] with self.assertRaises(ValueError) as exc: - _ = synth_permutation_basic(pattern_duplicate) + _validate_permutation(pattern_duplicate) self.assertIn( "input contains {} more than once".format(pattern[0]), str(exc.exception) )