Skip to content

Commit

Permalink
feat: jitted random gaussian
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Feb 20, 2025
1 parent 6c7fecc commit 8993e63
Showing 1 changed file with 55 additions and 24 deletions.
79 changes: 55 additions & 24 deletions src/qibojit/custom_operators/quantum_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,20 @@ def _pauli_basis(
setattr(QINFO, "_pauli_basis", _pauli_basis)


@njit(
["c16[:,::1](c16[:,::1], i8[:])", "c16[:,:,::1](c16[:,:,::1], i8[:])"], cache=True
)
@njit(["c16[:,:](c16[:,:], i8[:])", "c16[:,:,:](c16[:,:,:], i8[:])"], cache=True)
def numba_transpose(array, axes):
axes = to_fixed_tuple(axes, array.ndim)
array = np.transpose(array, axes)

Check warning on line 112 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L111-L112

Added lines #L111 - L112 were not covered by tests
return np.ascontiguousarray(array)
# return np.ascontiguousarray(array)
return array

Check warning on line 114 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L114

Added line #L114 was not covered by tests


@njit(["c16[:,::1](c16[:,::1], i8)", "c16[:,::1](c16[:,:,::1], i8)"], cache=True)
@njit(["c16[:,::1](c16[:,:], i8)", "c16[:,::1](c16[:,:,:], i8)"], cache=True)
def _vectorization_column(state, dim):
indices = ENGINE.arange(state.ndim)
indices[-2:] = indices[-2:][::-1]
state = numba_transpose(state, indices)
state = ENGINE.ascontiguousarray(state)
return ENGINE.reshape(state, (-1, dim**2))

Check warning on line 123 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L119-L123

Added lines #L119 - L123 were not covered by tests


Expand All @@ -141,10 +141,12 @@ def _vectorization_system(state, dim=0):
# setattr(QINFO, "_vectorization_system", _vectorization_system)


@njit(["c16[:,:,::1](c16[:,::1], i8)", "c16[:,:,::1](c16[:,:,::1], i8)"], cache=True)
@njit(["c16[:,:,:](c16[:,:], i8)", "c16[:,:,:](c16[:,:,:], i8)"], cache=True)
def _unvectorization_column(state, dim):
axes = ENGINE.arange(state.ndim)[::-1]
state = numba_transpose(state, axes).reshape(dim, dim, state.shape[0])
last_dim = state.shape[0]
state = numba_transpose(state, axes)
state = ENGINE.ascontiguousarray(state).reshape(dim, dim, last_dim)
return numba_transpose(state, ENGINE.array([2, 1, 0], dtype=ENGINE.int64))

Check warning on line 150 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L146-L150

Added lines #L146 - L150 were not covered by tests


Expand All @@ -153,12 +155,10 @@ def _unvectorization_column(state, dim):

@njit(
[
nbt.complex128[::1](
nbt.complex128[:, ::1], nbt.Tuple((nbt.int64[::1], nbt.int64[::1]))
),
nbt.float64[::1](
nbt.float64[:, ::1], nbt.Tuple((nbt.int64[::1], nbt.int64[::1]))
nbt.complex128[:](
nbt.complex128[:, :], nbt.Tuple((nbt.int64[:], nbt.int64[:]))
),
nbt.float64[:](nbt.float64[:, :], nbt.Tuple((nbt.int64[:], nbt.int64[:]))),
],
parallel=True,
cache=True,
Expand All @@ -172,14 +172,14 @@ def _array_at_2d_indices(array, indices):

@njit(
nbt.Tuple((nbt.complex128[:, ::1], nbt.int64[:, ::1]))(
nbt.complex128[:, ::1], nbt.int64
nbt.complex128[:, :], nbt.int64
),
cache=True,
)
def _post_sparse_pauli_basis_vectorization(basis, dim):
indices = ENGINE.nonzero(basis)
basis = _array_at_2d_indices(basis, indices)
basis = basis.reshape(-1, dim)
basis = ENGINE.ascontiguousarray(basis).reshape(-1, dim)
indices = indices[1].reshape(-1, dim)
return basis, indices

Check warning on line 184 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L180-L184

Added lines #L180 - L184 were not covered by tests

Expand Down Expand Up @@ -276,7 +276,7 @@ def _vectorize_sparse_pauli_basis_column(


@njit(
nbt.Tuple((nbt.complex128[:, ::1], nbt.int64[::1]))(
nbt.Tuple((nbt.complex128[:, ::1], nbt.int64[:]))(
nbt.int64,
nbt.complex128[:, ::1],
nbt.complex128[:, ::1],
Expand All @@ -296,27 +296,27 @@ def _pauli_to_comp_basis_sparse_row(
unitary = numba_transpose(unitary, ENGINE.arange(unitary.ndim)[::-1])
nonzero = ENGINE.nonzero(unitary)
unitary = _array_at_2d_indices(unitary, nonzero)
return unitary.reshape(unitary.shape[0], -1), nonzero[1]
return ENGINE.ascontiguousarray(unitary).reshape(unitary.shape[0], -1), nonzero[1]

Check warning on line 299 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L296-L299

Added lines #L296 - L299 were not covered by tests


setattr(QINFO, "_pauli_to_comp_basis_sparse_row", _pauli_to_comp_basis_sparse_row)


@njit(
nbt.Tuple((nbt.complex128[:, ::1], nbt.complex128[:, ::1], nbt.float64[:, :, ::1]))(
nbt.complex128[:, ::1]
nbt.Tuple((nbt.complex128[:, :], nbt.complex128[:, :], nbt.float64[:, :, ::1]))(
nbt.complex128[:, :]
),
parallel=True,
cache=True,
)
def _choi_to_kraus_preamble(choi_super_op):
U, coefficients, V = ENGINE.linalg.svd(choi_super_op)

Check warning on line 313 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L313

Added line #L313 was not covered by tests
U = np.ascontiguousarray(U)
# U = np.ascontiguousarray(U)
U = numba_transpose(U, ENGINE.arange(U.ndim)[::-1])
coefficients = ENGINE.sqrt(coefficients)
V = ENGINE.conj(V)
coefficients = coefficients.reshape(U.shape[0], 1, 1)

Check warning on line 318 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L315-L318

Added lines #L315 - L318 were not covered by tests
V = np.ascontiguousarray(V)
# V = np.ascontiguousarray(V)
return U, V, coefficients

Check warning on line 320 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L320

Added line #L320 was not covered by tests


Expand All @@ -337,15 +337,46 @@ def _kraus_operators(kraus_left, kraus_right):
def _choi_to_kraus_row(choi_super_op):
U, V, coefficients = _choi_to_kraus_preamble(choi_super_op)
dim = int(np.sqrt(U.shape[-1]))
kraus_left = coefficients * _unvectorization_row(U, dim)
kraus_right = coefficients * _unvectorization_row(V, dim)
kraus_left = coefficients * _unvectorization_row(ENGINE.ascontiguousarray(U), dim)
kraus_right = coefficients * _unvectorization_row(ENGINE.ascontiguousarray(V), dim)
kraus_ops = _kraus_operators(kraus_left, kraus_right)
return kraus_ops, coefficients

Check warning on line 343 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L338-L343

Added lines #L338 - L343 were not covered by tests


setattr(QINFO, "_choi_to_kraus_row", _choi_to_kraus_row)

# TODO: choi to kraus column

@njit(
nbt.Tuple((nbt.complex128[:, :, :, :], nbt.float64[:, :, ::1]))(
nbt.complex128[:, ::1]
),
cache=True,
)
def _choi_to_kraus_column(choi_super_op):
U, V, coefficients = _choi_to_kraus_preamble(choi_super_op)
dim = int(np.sqrt(U.shape[-1]))
kraus_left = coefficients * _unvectorization_column(U, dim)
kraus_right = coefficients * _unvectorization_column(V, dim)
kraus_ops = _kraus_operators(kraus_left, kraus_right)
return kraus_ops, coefficients

Check warning on line 361 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L356-L361

Added lines #L356 - L361 were not covered by tests


setattr(QINFO, "_choi_to_kraus_column", _choi_to_kraus_column)


@njit("c16[:,:](i8, i8, f8, f8)", parallel=True, cache=True)
def _random_gaussian_matrix(dims: int, rank: int, mean: float, stddev: float):
matrix = ENGINE.empty((dims, rank), dtype=ENGINE.complex128)
for i in prange(dims):
for j in prange(rank):
matrix[i, j] = ENGINE.random.normal(
loc=mean, scale=stddev
) + 1.0j * ENGINE.random.normal(loc=mean, scale=stddev)
return matrix


setattr(QINFO, "_random_gaussian_matrix", _random_gaussian_matrix)


"""
@njit(
Expand Down Expand Up @@ -383,4 +414,4 @@ def _kraus_to_stinespring(
)
"""

setattr(QINFO, "_kraus_to_stinespring", _kraus_to_stinespring)
# setattr(QINFO, "_kraus_to_stinespring", _kraus_to_stinespring)

0 comments on commit 8993e63

Please sign in to comment.