Skip to content

Commit

Permalink
using jax lax transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Dec 10, 2024
1 parent b4e03d9 commit 3cdc5f1
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 0 additions & 2 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def apply_operator(
return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims)

# Apply operator to density matrix: ρ' = O ρ O†

support_perm = target + tuple(set(tuple(range(state.ndim // 2))) - set(target))
# print("init 1", state.reshape((4, 4)).round(4))
state = permute_basis(state, support_perm, False)
state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, new_state_dims))

Expand Down
3 changes: 2 additions & 1 deletion horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def permute_basis(operator: Array, qubit_support: tuple, inv: bool = False) -> A
perm = tuple(ranked_support) + tuple(ranked_support + n_qubits)
if inv:
perm = np.argsort(perm)
return jnp.moveaxis(operator, source=tuple(range(operator.ndim)), destination=perm)
return jax.lax.transpose(operator, perm)
# return jnp.moveaxis(operator, source=tuple(range(operator.ndim)), destination=perm)


class StrEnum(str, Enum):
Expand Down

0 comments on commit 3cdc5f1

Please sign in to comment.