Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated and will soon be removed. The replacement APIs live in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule added in JAX version 0.4.25.

PiperOrigin-RevId: 634090639
Change-Id: Ib2de4babc80eed7f32841ae70b5744a9be3d3a97
Jake VanderPlas authored and copybara-github committed May 15, 2024
1 parent 290d13a commit efaddc5
Showing 2 changed files with 34 additions and 19 deletions.
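
As a minimal sketch of the rename, assuming JAX >= 0.4.25 (where the `jax.tree` aliases exist), the deprecated spelling and its replacements look like this; the pytree here is illustrative, not taken from the diff:

import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}

# Deprecated top-level alias, the spelling this commit removes:
#   zeros = jax.tree_map(jnp.zeros_like, params)
# Canonical API in jax.tree_util:
zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
# Shorter alias added in JAX 0.4.25, used throughout this commit:
zeros = jax.tree.map(jnp.zeros_like, params)
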
neural_testbed/agents/factories/preconditioner.py: 32 changes (22 additions, 10 deletions)
@@ -51,27 +51,39 @@ def get_rmsprop_preconditioner(running_average_factor=0.99, eps=1e-7):

   def init_fn(params):
     return RMSPropPreconditionerState(
-        grad_moment_estimates=jax.tree_map(jnp.zeros_like, params))
+        grad_moment_estimates=jax.tree.map(jnp.zeros_like, params)
+    )

   def update_preconditioner_fn(gradient, preconditioner_state):
     r = running_average_factor
-    grad_moment_estimates = jax.tree_map(
-        lambda e, g: e * r + g**2 *(1-r),
-        preconditioner_state.grad_moment_estimates, gradient)
+    grad_moment_estimates = jax.tree.map(
+        lambda e, g: e * r + g**2 * (1 - r),
+        preconditioner_state.grad_moment_estimates,
+        gradient,
+    )
     return RMSPropPreconditionerState(
         grad_moment_estimates=grad_moment_estimates)

   def multiply_by_m_inv_fn(vec, preconditioner_state):
-    return jax.tree_map(lambda e, v: v / (eps + jnp.sqrt(e)),
-                        preconditioner_state.grad_moment_estimates, vec)
+    return jax.tree.map(
+        lambda e, v: v / (eps + jnp.sqrt(e)),
+        preconditioner_state.grad_moment_estimates,
+        vec,
+    )

   def multiply_by_m_sqrt_fn(vec, preconditioner_state):
-    return jax.tree_map(lambda e, v: v * jnp.sqrt(eps + jnp.sqrt(e)),
-                        preconditioner_state.grad_moment_estimates, vec)
+    return jax.tree.map(
+        lambda e, v: v * jnp.sqrt(eps + jnp.sqrt(e)),
+        preconditioner_state.grad_moment_estimates,
+        vec,
+    )

   def multiply_by_m_sqrt_inv_fn(vec, preconditioner_state):
-    return jax.tree_map(lambda e, v: v / jnp.sqrt(eps + jnp.sqrt(e)),
-                        preconditioner_state.grad_moment_estimates, vec)
+    return jax.tree.map(
+        lambda e, v: v / jnp.sqrt(eps + jnp.sqrt(e)),
+        preconditioner_state.grad_moment_estimates,
+        vec,
+    )

   return Preconditioner(
       init=init_fn,
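
For context, `jax.tree.map` applies a function leaf by leaf across pytrees with matching structure, which is all the preconditioner above relies on. The standalone sketch below shows the same second-moment update and preconditioning step on a plain dict; the pytrees are illustrative and not the repository's `Preconditioner` types:

import jax
import jax.numpy as jnp

r, eps = 0.99, 1e-7
grads = {'w': jnp.array([0.1, -0.2]), 'b': jnp.array([0.3])}
estimates = jax.tree.map(jnp.zeros_like, grads)

# Exponential moving average of squared gradients, applied per leaf.
estimates = jax.tree.map(lambda e, g: e * r + g**2 * (1 - r), estimates, grads)
# Divide each gradient leaf by the square root of its moment estimate.
preconditioned = jax.tree.map(lambda e, g: g / (eps + jnp.sqrt(e)), estimates, grads)
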
neural_testbed/agents/factories/sgld_optimizer.py: 21 changes (12 additions, 9 deletions)
@@ -35,11 +35,14 @@

 def normal_like_tree(a, key):
   """Generate Gaussian noises."""
-  treedef = jax.tree_structure(a)
-  num_vars = len(jax.tree_leaves(a))
+  treedef = jax.tree.structure(a)
+  num_vars = len(jax.tree.leaves(a))
   all_keys = jax.random.split(key, num=(num_vars + 1))
-  noise = jax.tree_map(lambda p, k: jax.random.normal(k, shape=p.shape), a,
-                       jax.tree_unflatten(treedef, all_keys[1:]))
+  noise = jax.tree.map(
+      lambda p, k: jax.random.normal(k, shape=p.shape),
+      a,
+      jax.tree.unflatten(treedef, all_keys[1:]),
+  )
   return noise, all_keys[0]


@@ -82,8 +85,9 @@ def init_fn(params):
     return OptaxSGLDState(
         count=jnp.zeros([], jnp.int32),
         rng_key=jax.random.PRNGKey(seed),
-        momentum=jax.tree_map(jnp.zeros_like, params),
-        preconditioner_state=preconditioner.init(params))
+        momentum=jax.tree.map(jnp.zeros_like, params),
+        preconditioner_state=preconditioner.init(params),
+    )

   def update_fn(gradient, state, params=None):
     del params
@@ -98,10 +102,9 @@ def update_fn(gradient, state, params=None):
     def update_momentum(m, g, n):
       return momentum_decay * m + g * jnp.sqrt(step_size) - n * noise_std

-    momentum = jax.tree_map(update_momentum, state.momentum, gradient,
-                            noise)
+    momentum = jax.tree.map(update_momentum, state.momentum, gradient, noise)
     updates = preconditioner.multiply_by_m_inv(momentum, preconditioner_state)
-    updates = jax.tree_map(lambda m: -m * jnp.sqrt(step_size), updates)
+    updates = jax.tree.map(lambda m: -m * jnp.sqrt(step_size), updates)
     return updates, OptaxSGLDState(
         count=state.count + 1,
         rng_key=new_key,
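
The `normal_like_tree` change above follows a common JAX pattern: read off the parameter tree's structure, split one PRNG key into one key per leaf, rebuild a key tree with `jax.tree.unflatten`, and map a sampler over both trees. A hedged standalone sketch of that pattern (variable names mirror the diff, but this is not the repository's code):

import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
key = jax.random.PRNGKey(0)

treedef = jax.tree.structure(params)
num_vars = len(jax.tree.leaves(params))
all_keys = jax.random.split(key, num=num_vars + 1)
# One independent key per leaf, reassembled into the same tree shape as params.
noise = jax.tree.map(
    lambda p, k: jax.random.normal(k, shape=p.shape),
    params,
    jax.tree.unflatten(treedef, all_keys[1:]),
)
next_key = all_keys[0]  # the leftover key is returned for future splits
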
