diff --git a/interpax/_spline.py b/interpax/_spline.py index c44230d..6013582 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1072,23 +1072,14 @@ def derivative1(): def _make_periodic(xq: jax.Array, x: jax.Array, period: float, axis: int, *arrs): """Make arrays periodic along a specified axis.""" period = abs(period) - xq = xq % period - x = x % period + xq = jnp.where(jnp.logical_or(xq > period, xq < 0), xq % period, xq) + x = jnp.where(jnp.logical_or(x > period, x < 0), x % period, x) i = jnp.argsort(x) x = x[i] - x = jnp.concatenate([x[-1:] - period, x, x[:1] + period]) arrs = list(arrs) for k in range(len(arrs)): if arrs[k] is not None: arrs[k] = jnp.take(arrs[k], i, axis, mode="wrap") - arrs[k] = jnp.concatenate( - [ - jnp.take(arrs[k], jnp.array([-1]), axis), - arrs[k], - jnp.take(arrs[k], jnp.array([0]), axis), - ], - axis=axis, - ) return (xq, x, *arrs)