refactor: bandits/*
Refactor to make it compatible with JAX>0.2.22; compatible with JSL@a5580c7~
gerdm committed May 7, 2023
1 parent 2ead775 commit 3ade11e
Showing 4 changed files with 15 additions and 19 deletions.
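For reference, the whole commit is one mechanical migration: `jax.ops.index_update`, deprecated in JAX 0.2.22 and removed in later releases, is replaced everywhere by the functional `.at[...].set(...)` indexed-update API. A minimal sketch of the pattern (array and values illustrative):

import jax.numpy as jnp

x = jnp.zeros(5)

# Old API (fails to import on JAX releases after the 0.2.x series):
#   from jax.ops import index_update
#   x = index_update(x, 2, 1.0)

# New API: a pure functional update that returns a modified copy.
x = x.at[2].set(1.0)   # -> [0., 0., 1., 0., 0.]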
11 changes: 5 additions & 6 deletions bandits/agents/limited_memory_neural_linear.py
@@ -3,7 +3,6 @@
 from jax.random import split
 from jax.lax import scan, cond
 from jax.nn import one_hot
-from jax.ops import index_update

 import optax

@@ -87,7 +86,7 @@ def _update_buffer(self, buffer, new_item, index):
         """
         source: https://github.com/google/jax/issues/4590
         """
-        buffer = index_update(buffer, index, new_item)
+        buffer = buffer.at[index].set(new_item)
         index = (index + 1) % self.buffer_size
         return buffer, index

@@ -143,10 +142,10 @@ def loss_fn(params):
         b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2

         # update only the chosen action at time t
-        mu = index_update(mu, action, mu_update)
-        Sigma = index_update(Sigma, action, Sigma_update)
-        a = index_update(a, action, a_update)
-        b = index_update(b, action, b_update)
+        mu = mu.at[action].set(mu_update)
+        Sigma = Sigma.at[action].set(Sigma_update)
+        a = a.at[action].set(a_update)
+        b = b.at[action].set(b_update)
         t = t + 1

         buffer = (context_buffer, reward_buffer, action_buffer, buffer_ix)
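The `_update_buffer` change keeps the agent's fixed-size memory purely functional: overwrite one slot, then advance the write pointer modulo the buffer size. A standalone sketch of the same ring-buffer step (names and shapes illustrative; in the class above, `buffer_size` is an attribute of `self`):

import jax.numpy as jnp

def update_buffer(buffer, new_item, index, buffer_size):
    # Write `new_item` into the slot at `index` (returns a copy, no mutation),
    # then advance the circular write pointer.
    buffer = buffer.at[index].set(new_item)
    index = (index + 1) % buffer_size
    return buffer, index

buffer = jnp.zeros((3, 2))   # room for 3 two-dimensional items
buffer, ix = update_buffer(buffer, jnp.array([1.0, 2.0]), index=0, buffer_size=3)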
11 changes: 5 additions & 6 deletions bandits/agents/linear_bandit.py
@@ -1,7 +1,6 @@
 import jax.numpy as jnp
 from jax import lax
 from jax import random
-from jax.ops import index_update

 from tensorflow_probability.substrates import jax as tfp

@@ -47,11 +46,11 @@ def update_bel(self, bel, context, action, reward):
         b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2

         # Update only the chosen action at time t
-        mu = index_update(mu, action, mu_update)
-        Sigma = index_update(Sigma, action, Sigma_update)
-        a = index_update(a, action, a_update)
-        b = index_update(b, action, b_update)
-
+        mu = mu.at[action].set(mu_update)
+        Sigma = Sigma.at[action].set(Sigma_update)
+        a = a.at[action].set(a_update)
+        b = b.at[action].set(b_update)
         bel = (mu, Sigma, a, b)

         return bel
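Unlike NumPy in-place assignment, `x.at[idx].set(v)` is pure and accepts a traced index, which is why it composes with the `lax.scan`/`jit` machinery these agents run under. A sketch of the per-arm belief update with assumed shapes (the `*_update` values stand in for the NIG computation shown in the diff):

import jax
import jax.numpy as jnp

@jax.jit
def update_chosen_arm(bel, action, mu_update, a_update):
    mu, a = bel
    # Only the pulled arm's posterior parameters change; other rows are untouched.
    mu = mu.at[action].set(mu_update)
    a = a.at[action].set(a_update)
    return mu, a

mu = jnp.zeros((4, 3))   # per-arm posterior means
a = jnp.ones(4)          # per-arm inverse-gamma shape parameters
mu, a = update_chosen_arm((mu, a), 2, jnp.ones(3), 1.5)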
3 changes: 1 addition & 2 deletions bandits/agents/linear_bandit_wide.py
@@ -10,7 +10,6 @@
 from jax import vmap
 from jax.random import split
 from jax.nn import one_hot
-from jax.ops import index_update
 from jax.lax import scan

 from .agent_utils import NIGupdate

@@ -29,7 +28,7 @@ def __init__(self, num_features, num_arms, eta=6.0, lmbda=0.25):

     def widen(self, context, action):
         phi = jnp.zeros((self.num_arms, self.num_features))
-        phi = index_update(phi, action, context)
+        phi = phi.at[action].set(context)
         return phi.flatten()

     def init_bel(self, key, contexts, states, actions, rewards):
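`widen` builds the feature vector for a disjoint linear model: the context fills the block belonging to the chosen arm and every other block stays zero. A self-contained version of the diffed method (class scaffolding dropped):

import jax.numpy as jnp

def widen(context, action, num_arms, num_features):
    # Place `context` in the chosen arm's row, zeros elsewhere,
    # then flatten to a (num_arms * num_features,) vector.
    phi = jnp.zeros((num_arms, num_features))
    phi = phi.at[action].set(context)
    return phi.flatten()

widen(jnp.array([1.0, 2.0]), action=1, num_arms=3, num_features=2)
# -> [0., 0., 1., 2., 0., 0.]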
9 changes: 4 additions & 5 deletions bandits/agents/linear_kf_bandit.py
@@ -1,5 +1,4 @@
 import jax.numpy as jnp
-from jax.ops import index_update
 from jax.lax import scan
 from jax.random import split
 from jsl.lds.kalman_filter import KalmanFilterNoiseEstimation

@@ -46,10 +45,10 @@ def update_bel(self, bel, context, action, reward):

         mu_k, Sigma_k, v_k, tau_k = self.kf.kalman_step(state, xs)

-        mu = index_update(mu, action, mu_k)
-        Sigma = index_update(Sigma, action, Sigma_k)
-        v = index_update(v, action, v_k)
-        tau = index_update(tau, action, tau_k)
+        mu = mu.at[action].set(mu_k)
+        Sigma = Sigma.at[action].set(Sigma_k)
+        v = v.at[action].set(v_k)
+        tau = tau.at[action].set(tau_k)

         bel = (mu, Sigma, v, tau)
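The Kalman-filter agent stores one filter state per arm as stacked arrays; after a `kalman_step` for the pulled arm (the JSL call above), only that arm's slices are rewritten. A sketch with a stand-in step function, since the real one lives in `jsl.lds.kalman_filter`:

import jax.numpy as jnp

num_arms, dim = 4, 3
mu = jnp.zeros((num_arms, dim))
Sigma = jnp.stack([jnp.eye(dim)] * num_arms)
v = jnp.ones(num_arms)
tau = jnp.ones(num_arms)

def kalman_step_stub(mu_k, Sigma_k, v_k, tau_k):
    # Placeholder for KalmanFilterNoiseEstimation's per-step update.
    return mu_k, Sigma_k, v_k + 1.0, tau_k

action = 1
mu_k, Sigma_k, v_k, tau_k = kalman_step_stub(mu[action], Sigma[action], v[action], tau[action])

mu = mu.at[action].set(mu_k)
Sigma = Sigma.at[action].set(Sigma_k)
v = v.at[action].set(v_k)
tau = tau.at[action].set(tau_k)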
