From 3ade11e128e14e13284082b93d863d0cb398ec4b Mon Sep 17 00:00:00 2001 From: Gerardo Duran-Martin Date: Sun, 7 May 2023 08:58:13 +0000 Subject: [PATCH] refactor: bandits/* Refactor to make it compatible with JAX>0.2.22; compatible with JSL@a5580c7~ --- bandits/agents/limited_memory_neural_linear.py | 11 +++++------ bandits/agents/linear_bandit.py | 11 +++++------ bandits/agents/linear_bandit_wide.py | 3 +-- bandits/agents/linear_kf_bandit.py | 9 ++++----- 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/bandits/agents/limited_memory_neural_linear.py b/bandits/agents/limited_memory_neural_linear.py index 0954aeb..b47a5f6 100644 --- a/bandits/agents/limited_memory_neural_linear.py +++ b/bandits/agents/limited_memory_neural_linear.py @@ -3,7 +3,6 @@ from jax.random import split from jax.lax import scan, cond from jax.nn import one_hot -from jax.ops import index_update import optax @@ -87,7 +86,7 @@ def _update_buffer(self, buffer, new_item, index): """ source: https://github.com/google/jax/issues/4590 """ - buffer = index_update(buffer, index, new_item) + buffer = buffer.at[index].set(new_item) index = (index + 1) % self.buffer_size return buffer, index @@ -143,10 +142,10 @@ def loss_fn(params): b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2 # update only the chosen action at time t - mu = index_update(mu, action, mu_update) - Sigma = index_update(Sigma, action, Sigma_update) - a = index_update(a, action, a_update) - b = index_update(b, action, b_update) + mu = mu.at[action].set(mu_update) + Sigma = Sigma.at[action].set(Sigma_update) + a = a.at[action].set(a_update) + b = b.at[action].set(b_update) t = t + 1 buffer = (context_buffer, reward_buffer, action_buffer, buffer_ix) diff --git a/bandits/agents/linear_bandit.py b/bandits/agents/linear_bandit.py index f361011..724bb88 100644 --- a/bandits/agents/linear_bandit.py +++ b/bandits/agents/linear_bandit.py @@ -1,7 +1,6 @@ import jax.numpy as jnp from jax import lax from jax import random -from jax.ops import index_update from tensorflow_probability.substrates import jax as tfp @@ -47,11 +46,11 @@ def update_bel(self, bel, context, action, reward): b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2 # Update only the chosen action at time t - mu = index_update(mu, action, mu_update) - Sigma = index_update(Sigma, action, Sigma_update) - a = index_update(a, action, a_update) - b = index_update(b, action, b_update) - + mu = mu.at[action].set(mu_update) + Sigma = Sigma.at[action].set(Sigma_update) + a = a.at[action].set(a_update) + b = b.at[action].set(b_update) + bel = (mu, Sigma, a, b) return bel diff --git a/bandits/agents/linear_bandit_wide.py b/bandits/agents/linear_bandit_wide.py index ef4afb5..8925bb6 100644 --- a/bandits/agents/linear_bandit_wide.py +++ b/bandits/agents/linear_bandit_wide.py @@ -10,7 +10,6 @@ from jax import vmap from jax.random import split from jax.nn import one_hot -from jax.ops import index_update from jax.lax import scan from .agent_utils import NIGupdate @@ -29,7 +28,7 @@ def __init__(self, num_features, num_arms, eta=6.0, lmbda=0.25): def widen(self, context, action): phi = jnp.zeros((self.num_arms, self.num_features)) - phi = index_update(phi, action, context) + phi = phi.at[action].set(context) return phi.flatten() def init_bel(self, key, contexts, states, actions, rewards): diff --git a/bandits/agents/linear_kf_bandit.py b/bandits/agents/linear_kf_bandit.py index e831bd7..d8602af 100644 --- a/bandits/agents/linear_kf_bandit.py +++ b/bandits/agents/linear_kf_bandit.py @@ -1,5 +1,4 @@ import jax.numpy as jnp -from jax.ops import index_update from jax.lax import scan from jax.random import split from jsl.lds.kalman_filter import KalmanFilterNoiseEstimation @@ -46,10 +45,10 @@ def update_bel(self, bel, context, action, reward): mu_k, Sigma_k, v_k, tau_k = self.kf.kalman_step(state, xs) - mu = index_update(mu, action, mu_k) - Sigma = index_update(Sigma, action, Sigma_k) - v = index_update(v, action, v_k) - tau = index_update(tau, action, tau_k) + mu = mu.at[action].set(mu_k) + Sigma = Sigma.at[action].set(Sigma_k) + v = v.at[action].set(v_k) + tau = tau.at[action].set(tau_k) bel = (mu, Sigma, v, tau)