diff --git a/interpax/_spline.py b/interpax/_spline.py index 6013582..c44230d 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1072,14 +1072,23 @@ def derivative1(): def _make_periodic(xq: jax.Array, x: jax.Array, period: float, axis: int, *arrs): """Make arrays periodic along a specified axis.""" period = abs(period) - xq = jnp.where(jnp.logical_or(xq > period, xq < 0), xq % period, xq) - x = jnp.where(jnp.logical_or(x > period, x < 0), x % period, x) + xq = xq % period + x = x % period i = jnp.argsort(x) x = x[i] + x = jnp.concatenate([x[-1:] - period, x, x[:1] + period]) arrs = list(arrs) for k in range(len(arrs)): if arrs[k] is not None: arrs[k] = jnp.take(arrs[k], i, axis, mode="wrap") + arrs[k] = jnp.concatenate( + [ + jnp.take(arrs[k], jnp.array([-1]), axis), + arrs[k], + jnp.take(arrs[k], jnp.array([0]), axis), + ], + axis=axis, + ) return (xq, x, *arrs)