From 71ac105b4c28708811c9b4bd8d896ba449423919 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 12 Jun 2024 13:05:45 -0400 Subject: [PATCH] Revert "Merge pull request #29 from f0uriest/rc/periodic" This reverts commit 019d2ec68f2399f4723253b9bccadd4f9f78a293, reversing changes made to 850d8468d18d40acbc61a247baf33099a5990e8b. --- interpax/_spline.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 6013582..c44230d 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1072,14 +1072,23 @@ def derivative1(): def _make_periodic(xq: jax.Array, x: jax.Array, period: float, axis: int, *arrs): """Make arrays periodic along a specified axis.""" period = abs(period) - xq = jnp.where(jnp.logical_or(xq > period, xq < 0), xq % period, xq) - x = jnp.where(jnp.logical_or(x > period, x < 0), x % period, x) + xq = xq % period + x = x % period i = jnp.argsort(x) x = x[i] + x = jnp.concatenate([x[-1:] - period, x, x[:1] + period]) arrs = list(arrs) for k in range(len(arrs)): if arrs[k] is not None: arrs[k] = jnp.take(arrs[k], i, axis, mode="wrap") + arrs[k] = jnp.concatenate( + [ + jnp.take(arrs[k], jnp.array([-1]), axis), + arrs[k], + jnp.take(arrs[k], jnp.array([0]), axis), + ], + axis=axis, + ) return (xq, x, *arrs)