Skip to content

Commit

Permalink
feat: Allow differentiating/smoothing any axis
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed May 14, 2024
1 parent a76b94b commit d2cddac
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions derivative/differentiation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import numpy as np
from numpy.typing import NDArray

from .utils import _memoize_arrays

Expand Down Expand Up @@ -216,41 +217,38 @@ def x(self, X, t, axis=1):
Returns:
:obj:`ndarray` of float: Returns dX/dt along axis.
"""
X, flat = _align_axes(X, t, axis)
X, orig_shape = _align_axes(X, t, axis)

if X.shape[1] == 1:
dX = X
else:
dX = np.array([list(self.compute_x_for(t, x, np.arange(len(t)))) for x in X])

return _restore_axes(dX, axis, flat)
return _restore_axes(dX, axis, orig_shape)


def _align_axes(X, t, axis):
# Cast
def _align_axes(X, t, axis) -> tuple[NDArray, tuple[int, ...]]:
X = np.array(X)
flat = False
# Check shape and axis
if len(X.shape) == 1:
orig_shape = X.shape
# By convention, differentiate axis 1
if len(orig_shape) == 1:
X = X.reshape(1, -1)
flat = True
elif len(X.shape) == 2:
if axis == 0:
X = X.T
elif axis == 1:
pass
else:
raise ValueError("Invalid axis.")
else:
raise ValueError("Invalid shape of X.")

ax_len = orig_shape[axis]
# order of operations coupled with _restore_axes. Move differentiation axis to
# zero so reshape does not skew differentiation axis
X = np.moveaxis(X, axis, 0).reshape((ax_len, -1)).T
if X.shape[1] != len(t):
raise ValueError("Desired X axis size does not match t size.")
return X, flat
return X, orig_shape


def _restore_axes(dX, axis, flat):
if flat:
def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArray:
if len(orig_shape) == 1:
return dX.flatten()
else:
return dX if axis == 1 else dX.T
# order of operations coupled with _align_axes
extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis)
moved_shape = (orig_shape[axis],) + extra_dims
dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis)
return dX

0 comments on commit d2cddac

Please sign in to comment.