Skip to content

Commit

Permalink
Merge pull request #29 from f0uriest/rc/periodic
Browse files Browse the repository at this point in the history
Fix periodic transformation to avoid dx==0
  • Loading branch information
f0uriest authored Mar 26, 2024
2 parents 850d846 + 706d394 commit 019d2ec
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,23 +1072,14 @@ def derivative1():
def _make_periodic(xq: jax.Array, x: jax.Array, period: float, axis: int, *arrs):
"""Make arrays periodic along a specified axis."""
period = abs(period)
xq = xq % period
x = x % period
xq = jnp.where(jnp.logical_or(xq > period, xq < 0), xq % period, xq)
x = jnp.where(jnp.logical_or(x > period, x < 0), x % period, x)
i = jnp.argsort(x)
x = x[i]
x = jnp.concatenate([x[-1:] - period, x, x[:1] + period])
arrs = list(arrs)
for k in range(len(arrs)):
if arrs[k] is not None:
arrs[k] = jnp.take(arrs[k], i, axis, mode="wrap")
arrs[k] = jnp.concatenate(
[
jnp.take(arrs[k], jnp.array([-1]), axis),
arrs[k],
jnp.take(arrs[k], jnp.array([0]), axis),
],
axis=axis,
)
return (xq, x, *arrs)


Expand Down

0 comments on commit 019d2ec

Please sign in to comment.