Skip to content

Commit

Permalink
Add smoothing spline
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Mar 26, 2024
1 parent 850d846 commit bd34f08
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
1 change: 1 addition & 0 deletions interpax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CubicSpline,
PchipInterpolator,
PPoly,
SmoothingSpline,
)
from ._spline import (
Interpolator1D,
Expand Down
127 changes: 127 additions & 0 deletions interpax/_ppoly.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,130 @@ def __init__(
x, _, y, axis, _ = prepare_input(x, y, axis, check=check)
df = approx_df(x, y, "cubic2", axis, bc_type=bc_type)
super().__init__(x, y, df, axis=axis, extrapolate=extrapolate, check=check)


class SmoothingSpline(CubicSpline):
    """Smoothing spline for noisy data.

    The spline f minimizes

        p ∑ᵢ wᵢ ||yᵢ − f(xᵢ)||² + (1−p) ∫ₓ ||f''(x)||²

    The fit is computed as a linear least-squares problem for the spline's
    values at the data sites ``x``; those values are then interpolated with a
    regular `CubicSpline` using the requested boundary conditions.

    Parameters
    ----------
    x : array_like, shape (n,)
        1-D array containing values of the independent variable.
        Values must be real, finite and in strictly increasing order.
    y : array_like, shape(n,...)
        Array containing values of the dependent variable. It can have
        arbitrary number of dimensions, but the length along ``axis``
        (see below) must match the length of ``x``. Values must be finite.
    p : float, optional
        Smoothing parameter in the range [0,1].

        - As ``p -> 0`` the spline approaches a straight line fit to the data.
        - For ``p=1``, it is the cubic spline interpolant.

        If not given, ``p=0.5`` is used. In either case the value is rescaled
        internally by a factor that depends on the span of ``x``, the number of
        points, the uniformity of the spacing and the weights, before it is
        used to balance the data-fidelity and roughness terms.
    w : array_like, shape(n,)
        Weights for spline fitting. Must be positive. If None, then weights are all
        equal. Default is None.
    axis : int, optional
        Axis along which `y` is assumed to be varying. Meaning that for
        ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``.
        Default is 0.
    bc_type : string or 2-tuple, optional
        Boundary condition type. Two additional equations, given by the
        boundary conditions, are required to determine all coefficients of
        polynomials on each segment.

        If `bc_type` is a string, then the specified condition will be applied
        at both ends of a spline. Available conditions are:

        * 'not-a-knot': The first and second segment at a curve end
          are the same polynomial. It is a good choice when there is no
          information on boundary conditions.
        * 'periodic': The interpolated functions is assumed to be periodic
          of period ``x[-1] - x[0]``. The first and last value of `y` must be
          identical: ``y[0] == y[-1]``. This boundary condition will result in
          ``y'[0] == y'[-1]`` and ``y''[0] == y''[-1]``.
        * 'clamped': The first derivative at curves ends are zero. Assuming
          a 1D `y`, ``bc_type=((1, 0.0), (1, 0.0))`` is the same condition.
        * 'natural' (default): The second derivative at curve ends are zero.
          Assuming a 1D `y`, ``bc_type=((2, 0.0), (2, 0.0))`` is the same
          condition.

        If `bc_type` is a 2-tuple, the first and the second value will be
        applied at the curve start and end respectively. The tuple values can
        be one of the previously mentioned strings (except 'periodic') or a
        tuple `(order, deriv_values)` allowing to specify arbitrary
        derivatives at curve ends:

        * `order`: the derivative order, 1 or 2.
        * `deriv_value`: array_like containing derivative values, shape must
          be the same as `y`, excluding ``axis`` dimension. For example, if
          `y` is 1-D, then `deriv_value` must be a scalar. If `y` is 3-D with
          the shape (n0, n1, n2) and axis=2, then `deriv_value` must be 2-D
          and have the shape (n0, n1).
    extrapolate : {bool, 'periodic', None}, optional
        If bool, determines whether to extrapolate to out-of-bounds points
        based on first and last intervals, or to return NaNs. If 'periodic',
        periodic extrapolation is used. If None (default), ``extrapolate`` is
        set to 'periodic' for ``bc_type='periodic'`` and to True otherwise.
    check : bool
        Whether to perform checks on the input. Should be False if used under JIT.
    """

    def __init__(
        self,
        x: jax.Array,
        y: jax.Array,
        p: Union[float, None] = None,
        w: Union[jax.Array, None] = None,
        axis: int = 0,
        bc_type: Union[str, tuple] = "natural",
        extrapolate: Union[bool, str, None] = None,
        check: bool = True,
    ):
        x = jnp.asarray(x)
        y = jnp.asarray(y)
        w = jnp.ones_like(x) if w is None else jnp.asarray(w)
        n = x.size

        # Quadrature weights for the roughness integral: each node owns the
        # interval between the midpoints to its neighbours (half intervals at
        # the two ends), so the weights sum to x[-1] - x[0].
        dx = jnp.diff((x[1:] + x[:-1]) / 2, prepend=x[0], append=x[-1])

        def values_and_curvature(f):
            # The basis spline is always built from a 1-D vector of nodal
            # values, so it must use axis=0 regardless of the caller's `axis`.
            g = CubicSpline(
                x, f, axis=0, bc_type=bc_type, extrapolate=extrapolate, check=False
            )
            return jnp.concatenate([g(x), g(x, nu=2)])

        # The spline depends linearly on its nodal values, so the Jacobian at
        # zero gives the matrices mapping nodal values f to f(x) (first n rows)
        # and f''(x) (last n rows).
        # NOTE(review): for bc_type with nonzero prescribed derivatives the map
        # is affine and the constant part is not included here -- confirm that
        # this is intended.
        AB = jax.jit(jax.jacfwd(values_and_curvature))(jnp.zeros_like(x))
        A = jnp.sqrt(w)[:, None] * AB[:n]
        B = jnp.sqrt(dx)[:, None] * AB[n:]

        # normalize smoothing parameter
        # (the constant 80 and the exponents look empirical -- TODO confirm
        # their origin)
        span = jnp.ptp(x)
        eff_x = 1 + (span**2) / jnp.sum(jnp.diff(x) ** 2)
        eff_w = jnp.sum(w) ** 2 / jnp.sum(w**2)
        k = 80 * (span**3) * (n**-2) * (eff_x**-0.5) * (eff_w**-0.5)
        s = 0.5 if p is None else p
        p = s / (s + (1 - s) * k)

        # Minimize  p * sum_i w_i (f_i - y_i)^2 + (1 - p) * sum_i dx_i f''(x_i)^2
        # by stacking both terms into a single least-squares system:
        #   || sqrt(p) * sqrt(w) * (f - y) ||^2 + || sqrt(1 - p) * sqrt(dx) * f'' ||^2
        lhs = jnp.vstack([jnp.sqrt(p) * A, jnp.sqrt(1 - p) * B])

        # Bring the data axis to the front and flatten all remaining axes so
        # the weights scale rows (data sites) and lstsq sees a 2-D right-hand
        # side no matter how many dimensions y has.
        y = jnp.moveaxis(y, axis, 0)
        trailing_shape = y.shape[1:]
        y2d = y.reshape((y.shape[0], -1))
        rhs = jnp.concatenate(
            [jnp.sqrt(p * w)[:, None] * y2d, jnp.zeros_like(y2d)], axis=0
        )
        f = jnp.linalg.lstsq(lhs, rhs, rcond=None)[0]
        f = f.reshape((y.shape[0],) + trailing_shape)
        f = jnp.moveaxis(f, 0, axis)
        super().__init__(
            x, f, axis=axis, bc_type=bc_type, extrapolate=extrapolate, check=check
        )

0 comments on commit bd34f08

Please sign in to comment.