From 6df932cb807664e3e583ea93eabca3c0b8b8e1cf Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 4 Mar 2024 00:17:24 -0500 Subject: [PATCH] Add lineax back in, WIP, still need to debug some nans --- interpax/_fd_derivs.py | 5 +++-- requirements.txt | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/interpax/_fd_derivs.py b/interpax/_fd_derivs.py index 895f506..57d97b4 100644 --- a/interpax/_fd_derivs.py +++ b/interpax/_fd_derivs.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp +import lineax as lx from jax import jit from .utils import errorif @@ -237,9 +238,9 @@ def _cubic2(x, f, axis, bc_type): lower_diag = lower_diag.at[-1].set(dx[-1]) b = b.at[-1].set(0.5 * bc_end[1] * dx[-1] ** 2 + 3 * (f[-1] - f[-2])) - A = jnp.diag(diag) + jnp.diag(upper_diag, k=1) + jnp.diag(lower_diag, k=-1) + A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag) - solve = lambda b: jnp.linalg.solve(A, b) + solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T fx = jnp.moveaxis(fx, 0, axis) return fx diff --git a/requirements.txt b/requirements.txt index 6272e0f..9fa60ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ equinox jax >= 0.3.2, <= 0.5.0 +lineax numpy >= 1.20.0, < 2.0