Skip to content

Commit

Permalink
Avoid floating point error in log2
Browse files Browse the repository at this point in the history
  • Loading branch information
obackhouse authored and gamatos committed Feb 24, 2025
1 parent 4fecd5a commit 153cd10
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
5 changes: 3 additions & 2 deletions qujax/statetensor_observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jax.Array]]) -> jax.

single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq]
single_arrs = [
h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs
h_arr.reshape((2,) * int(jnp.rint(jnp.log2(h_arr.size))))
for h_arr in single_arrs
]

full_mat = single_arrs[0]
for single_matrix in single_arrs[1:]:
full_mat = jnp.kron(full_mat, single_matrix)
full_mat = full_mat.reshape((2,) * int(jnp.log2(full_mat.size)))
full_mat = full_mat.reshape((2,) * int(jnp.rint(jnp.log2(full_mat.size))))
return full_mat


Expand Down
9 changes: 7 additions & 2 deletions tests/test_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def test_sampling():
target_pmf = jnp.array([0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0])
target_pmf /= target_pmf.sum()

target_st = jnp.sqrt(target_pmf).reshape((2,) * int(jnp.log2(target_pmf.size)))
target_st = jnp.sqrt(target_pmf).reshape(
(2,) * int(jnp.rint(jnp.log2(target_pmf.size)))
)

n_samps = 7

Expand All @@ -244,5 +246,8 @@ def test_sampling():
assert all(target_pmf[sample_ints] > 0)

sample_bitstrings = qujax.sample_bitstrings(random.PRNGKey(0), target_st, n_samps)
assert sample_bitstrings.shape == (n_samps, int(jnp.log2(target_pmf.size)))
assert sample_bitstrings.shape == (
n_samps,
int(jnp.rint(jnp.log2(target_pmf.size))),
)
assert all(qujax.bitstrings_to_integers(sample_bitstrings) == sample_ints)

0 comments on commit 153cd10

Please sign in to comment.