From 36c5f32804c1de1b8b6b94b93cead0bcf6d16742 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerardo=20Dur=C3=A1n-Mart=C3=ADn?= Date: Wed, 29 Dec 2021 23:04:20 +0000 Subject: [PATCH 1/4] remove nlds_lib/* --- bandits/nlds_lib/base.py | 79 -------- .../diagonal_extended_kalman_filter.py | 74 -------- bandits/nlds_lib/extended_kalman_filter.py | 81 --------- bandits/nlds_lib/lds_lib_orig.py | 125 ------------- bandits/nlds_lib/unscented_kalman_filter.py | 171 ------------------ 5 files changed, 530 deletions(-) delete mode 100644 bandits/nlds_lib/base.py delete mode 100644 bandits/nlds_lib/diagonal_extended_kalman_filter.py delete mode 100644 bandits/nlds_lib/extended_kalman_filter.py delete mode 100644 bandits/nlds_lib/lds_lib_orig.py delete mode 100644 bandits/nlds_lib/unscented_kalman_filter.py diff --git a/bandits/nlds_lib/base.py b/bandits/nlds_lib/base.py deleted file mode 100644 index 8e6e712..0000000 --- a/bandits/nlds_lib/base.py +++ /dev/null @@ -1,79 +0,0 @@ -# Library of nonlinear dynamical systems -# Usage: Every discrete xKF class inherits from NLDS. -# There are two ways to use this library in the discrete case: -# 1) Explicitly initialize a discrete NLDS object with the desired parameters, -# then pass it onto the xKF class of your choice. -# 2) Initialize the xKF object with the desired NLDS parameters using -# the .from_base constructor. -# Way 1 is preferable whenever you want to use the same NLDS for multiple -# filtering processes. Way 2 is preferred whenever you want to use a single NLDS -# for a single filtering process - -# Author: Gerardo Durán-Martín (@gerdm) - -import jax -from jax.random import split, multivariate_normal - - -class NLDS: - """ - Base class for the Nonliear dynamical systems' module - """ - - def __init__(self, fz, fx, Q, R): - self.fz = fz - self.fx = fx - self.__Q = Q - self.__R = R - - def Q(self, z, *args): - if callable(self.__Q): - return self.__Q(z, *args) - else: - return self.__Q - - def R(self, x, *args): - if callable(self.__R): - return self.__R(x, *args) - else: - return self.__R - - def __sample_step(self, input_vals, obs): - key, state_t = input_vals - key_system, key_obs, key = split(key, 3) - - state_t = multivariate_normal(key_system, self.fz(state_t), self.Q(state_t)) - obs_t = multivariate_normal(key_obs, self.fx(state_t, *obs), self.R(state_t, *obs)) - - return (key, state_t), (state_t, obs_t) - - def sample(self, key, x0, nsteps, obs=None): - """ - Sample discrete elements of a nonlinear system - Parameters - ---------- - key: jax.random.PRNGKey - x0: array(state_size) - Initial state of simulation - nsteps: int - Total number of steps to sample from the system - obs: None, tuple of arrays - Observed values to pass to fx and R - Returns - ------- - * array(nsamples, state_size) - State-space values - * array(nsamples, obs_size) - Observed-space values - """ - obs = () if obs is None else obs - state_t = x0.copy() - obs_t = self.fx(state_t) - - self.state_size, *_ = state_t.shape - self.obs_t, *_ = obs_t.shape - - init_state = (key, state_t) - _, hist = jax.lax.scan(self.__sample_step, init_state, obs, length=nsteps) - - return hist diff --git a/bandits/nlds_lib/diagonal_extended_kalman_filter.py b/bandits/nlds_lib/diagonal_extended_kalman_filter.py deleted file mode 100644 index 2494d3e..0000000 --- a/bandits/nlds_lib/diagonal_extended_kalman_filter.py +++ /dev/null @@ -1,74 +0,0 @@ -import jax.numpy as jnp -from jax import jacrev -from .base import NLDS - - -class DiagonalExtendedKalmanFilter(NLDS): - """ - Implementation of the Diagonal Extended Kalman Filter for a nonlinear - dynamical system with discrete observations. Also known as the - Node-decoupled Extended Kalman Filter (NDEKF) - """ - - def __init__(self, fz, fx, Q, R): - super().__init__(fz, fx, Q, R) - self.Dfz = jacrev(fz) - self.Dfx = jacrev(fx) - - @classmethod - def from_base(cls, model): - """ - Initialise class from an instance of the NLDS parent class - """ - return cls(model.fz, model.fx, model.Q, model.R) - - def filter_step(self, state, xs, eps=0.001): - """ - Run the Extended Kalman filter algorithm for a single step - Paramters - --------- - state: tuple - Mean, covariance at time t-1 - xs: tuple - Target value and observations at time t - """ - mu_t, Vt, t = state - xt, obs = xs - - mu_t_cond = self.fz(mu_t) - Ht = self.Dfx(mu_t_cond, *obs) - - Rt = self.R(mu_t_cond, *obs) - xt_hat = self.fx(mu_t_cond, *obs) - xi = xt - xt_hat - A = jnp.linalg.inv(Rt + jnp.einsum("id,jd,d->ij", Ht, Ht, Vt)) - mu_t = mu_t_cond + jnp.einsum("s,is,ij,j->s", Vt, Ht, A, xi) - Vt = Vt - jnp.einsum("s,is,ij,is,s->s", Vt, Ht, A, Ht, Vt) + self.Q(mu_t, t) - - return (mu_t, Vt, t + 1), (mu_t, None) - - def filter(self, init_state, sample_obs, observations=None, Vinit=None): - """ - Run the Extended Kalman Filter algorithm over a set of observed samples. - Parameters - ---------- - init_state: array(state_size) - sample_obs: array(nsamples, obs_size) - Returns - ------- - * array(nsamples, state_size) - History of filtered mean terms - * array(nsamples, state_size, state_size) - History of filtered covariance terms - """ - self.state_size, *_ = init_state.shape - - Vt = self.Q(init_state) if Vinit is None else Vinit - - t = 0 - state = (init_state, Vinit, t) - observations = (observations,) if type(observations) is not tuple else observations - xs = (sample_obs, observations) - (mu_t, Vt, _), mu_t_hist = jax.lax.scan(self.filter_step, state, xs) - - return (mu_t, Vt), mu_t_hist diff --git a/bandits/nlds_lib/extended_kalman_filter.py b/bandits/nlds_lib/extended_kalman_filter.py deleted file mode 100644 index 5629183..0000000 --- a/bandits/nlds_lib/extended_kalman_filter.py +++ /dev/null @@ -1,81 +0,0 @@ -import jax.numpy as jnp -from jax import jacrev -from jax.lax import scan - -from .base import NLDS - - -class ExtendedKalmanFilter(NLDS): - """ - Implementation of the Extended Kalman Filter for a nonlinear - dynamical system with discrete observations - """ - - def __init__(self, fz, fx, Q, R): - super().__init__(fz, fx, Q, R) - self.Dfz = jacrev(fz) - self.Dfx = jacrev(fx) - - @classmethod - def from_base(cls, model): - """ - Initialise class from an instance of the NLDS parent class - """ - return cls(model.fz, model.fx, model.Q, model.R) - - def filter_step(self, state, xs, eps=0.001): - """ - Run the Extended Kalman filter algorithm for a single step - Paramters - --------- - state: tuple - Mean, covariance at time t-1 - xs: tuple - Target value and observations at time t - """ - mu_t, Vt, t = state - xt, obs = xs - - state_size, *_ = mu_t.shape - I = jnp.eye(state_size) - Gt = self.Dfz(mu_t) - mu_t_cond = self.fz(mu_t) - Vt_cond = Gt @ Vt @ Gt.T + self.Q(mu_t, t) - Ht = self.Dfx(mu_t_cond, *obs) - - Rt = self.R(mu_t_cond, *obs) - num_inputs, *_ = Rt.shape - - xt_hat = self.fx(mu_t_cond, *obs) - Mt = Ht @ Vt_cond @ Ht.T + Rt + eps * jnp.eye(num_inputs) - Kt = Vt_cond @ Ht.T @ jnp.linalg.inv(Mt) - mu_t = mu_t_cond + Kt @ (xt - xt_hat) - Vt = (I - Kt @ Ht) @ Vt_cond @ (I - Kt @ Ht).T + Kt @ Rt @ Kt.T - # Vt = (I - Kt @ Ht) @ Vt_cond - return (mu_t, Vt, t + 1), (mu_t, None) - - def filter(self, init_state, sample_obs, observations=None, Vinit=None): - """ - Run the Extended Kalman Filter algorithm over a set of observed samples. - Parameters - ---------- - init_state: array(state_size) - sample_obs: array(nsamples, obs_size) - Returns - ------- - * array(nsamples, state_size) - History of filtered mean terms - * array(nsamples, state_size, state_size) - History of filtered covariance terms - """ - self.state_size, *_ = init_state.shape - - Vt = self.Q(init_state) if Vinit is None else Vinit - - t = 0 - state = (init_state, Vinit, t) - observations = (observations,) if type(observations) is not tuple else observations - xs = (sample_obs, observations) - (mu_t, Vt, _), mu_t_hist = scan(self.filter_step, state, xs) - - return (mu_t, Vt), mu_t_hist diff --git a/bandits/nlds_lib/lds_lib_orig.py b/bandits/nlds_lib/lds_lib_orig.py deleted file mode 100644 index 63991ba..0000000 --- a/bandits/nlds_lib/lds_lib_orig.py +++ /dev/null @@ -1,125 +0,0 @@ -from jax import vmap -from jax.lax import Precision -import jax.numpy as jnp -from jax.lax import scan -from tensorflow_probability.substrates import jax as tfp - -tfd = tfp.distributions - - -class KalmanFilterNoiseEstimation: - """ - Implementation of the Kalman Filtering and Smoothing - procedure of a Linear Dynamical System with known parameters. - This class exemplifies the use of Kalman Filtering assuming - the model parameters are known. - Parameters - ---------- - A: array(state_size, state_size) - Transition matrix - C: array(observation_size, state_size) - Observation matrix - Q: array(state_size, state_size) - Transition covariance matrix - R: array(observation_size, observation_size) - Observation covariance - mu0: array(state_size) - Mean of initial configuration - Sigma0: array(state_size, state_size) or 0 - Covariance of initial configuration. If value is set - to zero, the initial state will be completely determined - by mu0 - timesteps: int - Total number of steps to sample - """ - - def __init__(self, A, Q, mu0, Sigma0, v0, tau0, update_fn=None): - self.A = A - self.Q = Q - self.mu0 = mu0 - self.Sigma0 = Sigma0 - self.v = v0 - self.tau = tau0 - self.__update_fn = update_fn - - def update(self, state, bel, *args): - if self.__update_fn is None: - return bel - else: - return self.__update_fn(state, bel, *args) - - def kalman_step(self, state, xt): - mu, Sigma, v, tau = state - x, y = xt - - mu_cond = jnp.matmul(self.A, mu, precision=Precision.HIGHEST) - Sigmat_cond = jnp.matmul(jnp.matmul(self.A, Sigma, precision=Precision.HIGHEST), self.A, - precision=Precision.HIGHEST) + self.Q - - e_k = y - x.T @ mu_cond - s_k = x.T @ Sigmat_cond @ x + 1 - Kt = (Sigmat_cond @ x) / s_k - - mu = mu + e_k * Kt - Sigma = Sigmat_cond - jnp.outer(Kt, Kt) * s_k - - v_update = v + 1 - tau = (v * tau + (e_k * e_k) / s_k) / v_update - - return mu, Sigma, v_update, tau - - def __kalman_filter(self, x_hist): - """ - Compute the online version of the Kalman-Filter, i.e, - the one-step-ahead prediction for the hidden state or the - time update step - Parameters - ---------- - x_hist: array(timesteps, observation_size) - Returns - ------- - * array(timesteps, state_size): - Filtered means mut - * array(timesteps, state_size, state_size) - Filtered covariances Sigmat - * array(timesteps, state_size) - Filtered conditional means mut|t-1 - * array(timesteps, state_size, state_size) - Filtered conditional covariances Sigmat|t-1 - """ - _, (mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist) = scan(self.kalman_step, - (self.mu0, self.Sigma0, 0), x_hist) - return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist - - def filter(self, x_hist): - """ - Compute the online version of the Kalman-Filter, i.e, - the one-step-ahead prediction for the hidden state or the - time update step. - Note that x_hist can optionally be of dimensionality two, - This corresponds to different samples of the same underlying - Linear Dynamical System - Parameters - ---------- - x_hist: array(n_samples?, timesteps, observation_size) - Returns - ------- - * array(n_samples?, timesteps, state_size): - Filtered means mut - * array(n_samples?, timesteps, state_size, state_size) - Filtered covariances Sigmat - * array(n_samples?, timesteps, state_size) - Filtered conditional means mut|t-1 - * array(n_samples?, timesteps, state_size, state_size) - Filtered conditional covariances Sigmat|t-1 - """ - has_one_sim = False - if x_hist.ndim == 2: - x_hist = x_hist[None, ...] - has_one_sim = True - kalman_map = vmap(self.__kalman_filter, 0) - mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = kalman_map(x_hist) - if has_one_sim: - mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = mu_hist[0, ...], Sigma_hist[0, ...], mu_cond_hist[ - 0, ...], Sigma_cond_hist[0, ...] - return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist diff --git a/bandits/nlds_lib/unscented_kalman_filter.py b/bandits/nlds_lib/unscented_kalman_filter.py deleted file mode 100644 index 0a0cb44..0000000 --- a/bandits/nlds_lib/unscented_kalman_filter.py +++ /dev/null @@ -1,171 +0,0 @@ -import jax.numpy as jnp -from jax.random import split, choice, multivariate_normal -from jax.lax import scan -from jax.ops import index_update -from jax.scipy import stats - -from .base import NLDS - - -class UnscentedKalmanFilter(NLDS): - """ - Implementation of the Unscented Kalman Filter for discrete time systems - """ - - def __init__(self, fz, fx, Q, R, alpha, beta, kappa, d): - super().__init__(fz, fx, Q, R) - self.d = d - self.alpha = alpha - self.beta = beta - self.kappa = kappa - self.lmbda = alpha ** 2 * (self.d + kappa) - self.d - self.gamma = jnp.sqrt(self.d + self.lmbda) - - @classmethod - def from_base(cls, model, alpha, beta, kappa, d): - """ - Initialise class from an instance of the NLDS parent class - """ - return cls(model.fz, model.fx, model.Q, model.R, alpha, beta, kappa, d) - - @staticmethod - def sqrtm(M): - """ - Compute the matrix square-root of a hermitian - matrix M. i,e, R such that RR = M - - Parameters - ---------- - M: array(m, m) - Hermitian matrix - - Returns - ------- - array(m, m): square-root matrix - """ - evals, evecs = jnp.linalg.eigh(M) - R = evecs @ jnp.sqrt(jnp.diag(evals)) @ jnp.linalg.inv(evecs) - return R - - def filter(self, init_state, sample_obs, observations=None, Vinit=None): - """ - Run the Unscented Kalman Filter algorithm over a set of observed samples. - Parameters - ---------- - sample_obs: array(nsamples, obs_size) - Returns - ------- - * array(nsamples, state_size) - History of filtered mean terms - * array(nsamples, state_size, state_size) - History of filtered covariance terms - """ - wm_vec = jnp.array([1 / (2 * (self.d + self.lmbda)) if i > 0 - else self.lmbda / (self.d + self.lmbda) - for i in range(2 * self.d + 1)]) - wc_vec = jnp.array([1 / (2 * (self.d + self.lmbda)) if i > 0 - else self.lmbda / (self.d + self.lmbda) + (1 - self.alpha ** 2 + self.beta) - for i in range(2 * self.d + 1)]) - nsteps, *_ = sample_obs.shape - mu_t = init_state - Sigma_t = self.Q(init_state) if Vinit is None else Vinit - if observations is None: - observations = [()] * nsteps - else: - observations = [(obs,) for obs in observations] - - mu_hist = jnp.zeros((nsteps, self.d)) - Sigma_hist = jnp.zeros((nsteps, self.d, self.d)) - - mu_hist = index_update(mu_hist, 0, mu_t) - Sigma_hist = index_update(Sigma_hist, 0, Sigma_t) - - for t in range(nsteps): - # TO-DO: use jax.scipy.linalg.sqrtm when it gets added to lib - comp1 = mu_t[:, None] + self.gamma * self.sqrtm(Sigma_t) - comp2 = mu_t[:, None] - self.gamma * self.sqrtm(Sigma_t) - # sigma_points = jnp.c_[mu_t, comp1, comp2] - sigma_points = jnp.concatenate((mu_t[:, None], comp1, comp2), axis=1) - - z_bar = self.fz(sigma_points) - mu_bar = z_bar @ wm_vec - Sigma_bar = (z_bar - mu_bar[:, None]) - Sigma_bar = jnp.einsum("i,ji,ki->jk", wc_vec, Sigma_bar, Sigma_bar) + self.Q(mu_t) - - Sigma_bar_half = self.sqrtm(Sigma_bar) - comp1 = mu_bar[:, None] + self.gamma * Sigma_bar_half - comp2 = mu_bar[:, None] - self.gamma * Sigma_bar_half - # sigma_points = jnp.c_[mu_bar, comp1, comp2] - sigma_points = jnp.concatenate((mu_bar[:, None], comp1, comp2), axis=1) - - x_bar = self.fx(sigma_points, *observations[t]) - x_hat = x_bar @ wm_vec - St = x_bar - x_hat[:, None] - St = jnp.einsum("i,ji,ki->jk", wc_vec, St, St) + self.R(mu_t, *observations[t]) - - mu_hat_component = z_bar - mu_bar[:, None] - x_hat_component = x_bar - x_hat[:, None] - Sigma_bar_y = jnp.einsum("i,ji,ki->jk", wc_vec, mu_hat_component, x_hat_component) - Kt = Sigma_bar_y @ jnp.linalg.inv(St) - - mu_t = mu_bar + Kt @ (sample_obs[t] - x_hat) - Sigma_t = Sigma_bar - Kt @ St @ Kt.T - - mu_hist = index_update(mu_hist, t, mu_t) - Sigma_hist = index_update(Sigma_hist, t, Sigma_t) - - return mu_hist, Sigma_hist - - def __init__(self, fz, fx, Q, R): - """ - Implementation of the Bootrstrap Filter for discrete time systems - **This implementation considers the case of multivariate normals** - to-do: extend to general case - """ - super().__init__(fz, fx, Q, R) - - def __filter_step(self, state, obs_t): - nsamples = self.nsamples - indices = jnp.arange(nsamples) - zt_rvs, key_t = state - - key_t, key_reindex, key_next = split(key_t, 3) - # 1. Draw new points from the dynamic model - zt_rvs = multivariate_normal(key_t, self.fz(zt_rvs), self.Q(zt_rvs)) - - # 2. Calculate unnormalised weights - xt_rvs = self.fx(zt_rvs) - weights_t = stats.multivariate_normal.pdf(obs_t, xt_rvs, self.R(zt_rvs, obs_t)) - - # 3. Resampling - pi = choice(key_reindex, indices, - p=weights_t, shape=(nsamples,)) - zt_rvs = zt_rvs[pi, ...] - weights_t = jnp.ones(nsamples) / nsamples - - # 4. Compute latent-state estimate, - # Set next covariance state matrix - mu_t = jnp.einsum("im,i->m", zt_rvs, weights_t) - - return (zt_rvs, key_next), mu_t - - def filter(self, key, init_state, sample_obs, nsamples=2000, Vinit=None): - """ - init_state: array(state_size,) - Initial state estimate - sample_obs: array(nsamples, obs_size) - Samples of the observations - """ - m, *_ = init_state.shape - nsteps = sample_obs.shape[0] - mu_hist = jnp.zeros((nsteps, m)) - - key, key_init = split(key, 2) - V = self.Q(init_state) if Vinit is None else Vinit - zt_rvs = multivariate_normal(key_init, init_state, V, shape=(nsamples,)) - - init_state = (zt_rvs, key) - self.nsamples = nsamples - _, mu_hist = scan(self.__filter_step, init_state, sample_obs) - - return mu_hist From 740f384dd34ef3cadf8507f09171bbe81d4076d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerardo=20Dur=C3=A1n-Mart=C3=ADn?= Date: Thu, 30 Dec 2021 01:10:22 +0000 Subject: [PATCH 2/4] feat: requirements.txt Add jsl elements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index aaf950b..5c014e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,3 +44,4 @@ tomli>=1.2.2 toolz>=0.11.2 typing_extensions>=4.0.1 urllib3>=1.26.7 +jsl @ git+git://github.com/probml/jsl@8c35b11bb11f83218e1958d5ad9f32b48e546595 \ No newline at end of file From e678f24a9484d03973932e289408bdea9514abd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerardo=20Dur=C3=A1n-Mart=C3=ADn?= Date: Thu, 30 Dec 2021 01:12:04 +0000 Subject: [PATCH 3/4] refactor: bandits/agents/* Import from JSL library --- bandits/agents/diagonal_subspace.py | 2 +- bandits/agents/ekf_orig_diag.py | 2 +- bandits/agents/ekf_orig_full.py | 2 +- bandits/agents/ekf_subspace.py | 2 +- bandits/agents/linear_kf_bandit.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bandits/agents/diagonal_subspace.py b/bandits/agents/diagonal_subspace.py index ede01e6..0f72d09 100644 --- a/bandits/agents/diagonal_subspace.py +++ b/bandits/agents/diagonal_subspace.py @@ -1,5 +1,5 @@ import jax.numpy as jnp -from nlds_lib.diagonal_extended_kalman_filter import DiagonalExtendedKalmanFilter +from jsl.nlds.diagonal_extended_kalman_filter import DiagonalExtendedKalmanFilter from .ekf_subspace import SubspaceNeuralBandit from tensorflow_probability.substrates import jax as tfp diff --git a/bandits/agents/ekf_orig_diag.py b/bandits/agents/ekf_orig_diag.py index 802e409..50d7984 100644 --- a/bandits/agents/ekf_orig_diag.py +++ b/bandits/agents/ekf_orig_diag.py @@ -7,7 +7,7 @@ from .agent_utils import train from scripts.training_utils import MLP -from nlds_lib.diagonal_extended_kalman_filter import DiagonalExtendedKalmanFilter +from jsl.nlds.diagonal_extended_kalman_filter import DiagonalExtendedKalmanFilter from tensorflow_probability.substrates import jax as tfp diff --git a/bandits/agents/ekf_orig_full.py b/bandits/agents/ekf_orig_full.py index e9cfa27..d58ea28 100644 --- a/bandits/agents/ekf_orig_full.py +++ b/bandits/agents/ekf_orig_full.py @@ -7,7 +7,7 @@ from flax.training import train_state from .agent_utils import train -from nlds_lib.extended_kalman_filter import ExtendedKalmanFilter +from jsl.nlds.extended_kalman_filter import ExtendedKalmanFilter from scripts.training_utils import MLP from tensorflow_probability.substrates import jax as tfp diff --git a/bandits/agents/ekf_subspace.py b/bandits/agents/ekf_subspace.py index fa4a97f..86c2e27 100644 --- a/bandits/agents/ekf_subspace.py +++ b/bandits/agents/ekf_subspace.py @@ -7,7 +7,7 @@ from sklearn.decomposition import PCA from .agent_utils import train, generate_random_basis, convert_params_from_subspace_to_full from scripts.training_utils import MLP -from nlds_lib.extended_kalman_filter import ExtendedKalmanFilter +from jsl.nlds.extended_kalman_filter import ExtendedKalmanFilter from tensorflow_probability.substrates import jax as tfp tfd = tfp.distributions diff --git a/bandits/agents/linear_kf_bandit.py b/bandits/agents/linear_kf_bandit.py index e44428a..e831bd7 100644 --- a/bandits/agents/linear_kf_bandit.py +++ b/bandits/agents/linear_kf_bandit.py @@ -2,7 +2,7 @@ from jax.ops import index_update from jax.lax import scan from jax.random import split -from nlds_lib.lds_lib_orig import KalmanFilterNoiseEstimation +from jsl.lds.kalman_filter import KalmanFilterNoiseEstimation from tensorflow_probability.substrates import jax as tfp tfd = tfp.distributions From 83c799c180154af2da3a95be91ba124954269fc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerardo=20Dur=C3=A1n-Mart=C3=ADn?= Date: Thu, 30 Dec 2021 01:12:49 +0000 Subject: [PATCH 4/4] fix: requirements.txt JSL depends on latest version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5c014e5..d8e02c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,4 +44,4 @@ tomli>=1.2.2 toolz>=0.11.2 typing_extensions>=4.0.1 urllib3>=1.26.7 -jsl @ git+git://github.com/probml/jsl@8c35b11bb11f83218e1958d5ad9f32b48e546595 \ No newline at end of file +jsl @ git+git://github.com/probml/jsl \ No newline at end of file