diff --git a/neural_testbed/agents/factories/preconditioner.py b/neural_testbed/agents/factories/preconditioner.py
index 2268fb0..3cedfe1 100644
--- a/neural_testbed/agents/factories/preconditioner.py
+++ b/neural_testbed/agents/factories/preconditioner.py
@@ -51,27 +51,39 @@ def get_rmsprop_preconditioner(running_average_factor=0.99, eps=1e-7):
 
   def init_fn(params):
     return RMSPropPreconditionerState(
-        grad_moment_estimates=jax.tree_map(jnp.zeros_like, params))
+        grad_moment_estimates=jax.tree.map(jnp.zeros_like, params)
+    )
 
   def update_preconditioner_fn(gradient, preconditioner_state):
     r = running_average_factor
-    grad_moment_estimates = jax.tree_map(
-        lambda e, g: e * r + g**2 *(1-r),
-        preconditioner_state.grad_moment_estimates, gradient)
+    grad_moment_estimates = jax.tree.map(
+        lambda e, g: e * r + g**2 * (1 - r),
+        preconditioner_state.grad_moment_estimates,
+        gradient,
+    )
     return RMSPropPreconditionerState(
         grad_moment_estimates=grad_moment_estimates)
 
   def multiply_by_m_inv_fn(vec, preconditioner_state):
-    return jax.tree_map(lambda e, v: v / (eps + jnp.sqrt(e)),
-                        preconditioner_state.grad_moment_estimates, vec)
+    return jax.tree.map(
+        lambda e, v: v / (eps + jnp.sqrt(e)),
+        preconditioner_state.grad_moment_estimates,
+        vec,
+    )
 
   def multiply_by_m_sqrt_fn(vec, preconditioner_state):
-    return jax.tree_map(lambda e, v: v * jnp.sqrt(eps + jnp.sqrt(e)),
-                        preconditioner_state.grad_moment_estimates, vec)
+    return jax.tree.map(
+        lambda e, v: v * jnp.sqrt(eps + jnp.sqrt(e)),
+        preconditioner_state.grad_moment_estimates,
+        vec,
+    )
 
   def multiply_by_m_sqrt_inv_fn(vec, preconditioner_state):
-    return jax.tree_map(lambda e, v: v / jnp.sqrt(eps + jnp.sqrt(e)),
-                        preconditioner_state.grad_moment_estimates, vec)
+    return jax.tree.map(
+        lambda e, v: v / jnp.sqrt(eps + jnp.sqrt(e)),
+        preconditioner_state.grad_moment_estimates,
+        vec,
+    )
 
   return Preconditioner(
       init=init_fn,
diff --git a/neural_testbed/agents/factories/sgld_optimizer.py b/neural_testbed/agents/factories/sgld_optimizer.py
index 823caae..80a714c 100644
--- a/neural_testbed/agents/factories/sgld_optimizer.py
+++ b/neural_testbed/agents/factories/sgld_optimizer.py
@@ -35,11 +35,14 @@
 
 def normal_like_tree(a, key):
   """Generate Gaussian noises."""
-  treedef = jax.tree_structure(a)
-  num_vars = len(jax.tree_leaves(a))
+  treedef = jax.tree.structure(a)
+  num_vars = len(jax.tree.leaves(a))
   all_keys = jax.random.split(key, num=(num_vars + 1))
-  noise = jax.tree_map(lambda p, k: jax.random.normal(k, shape=p.shape), a,
-                       jax.tree_unflatten(treedef, all_keys[1:]))
+  noise = jax.tree.map(
+      lambda p, k: jax.random.normal(k, shape=p.shape),
+      a,
+      jax.tree.unflatten(treedef, all_keys[1:]),
+  )
   return noise, all_keys[0]
 
 
@@ -82,8 +85,9 @@ def init_fn(params):
     return OptaxSGLDState(
         count=jnp.zeros([], jnp.int32),
         rng_key=jax.random.PRNGKey(seed),
-        momentum=jax.tree_map(jnp.zeros_like, params),
-        preconditioner_state=preconditioner.init(params))
+        momentum=jax.tree.map(jnp.zeros_like, params),
+        preconditioner_state=preconditioner.init(params),
+    )
 
   def update_fn(gradient, state, params=None):
     del params
@@ -98,10 +102,9 @@ def update_fn(gradient, state, params=None):
 
     def update_momentum(m, g, n):
       return momentum_decay * m + g * jnp.sqrt(step_size) - n * noise_std
-    momentum = jax.tree_map(update_momentum, state.momentum, gradient,
-                            noise)
+    momentum = jax.tree.map(update_momentum, state.momentum, gradient, noise)
     updates = preconditioner.multiply_by_m_inv(momentum, preconditioner_state)
-    updates = jax.tree_map(lambda m: -m * jnp.sqrt(step_size), updates)
+    updates = jax.tree.map(lambda m: -m * jnp.sqrt(step_size), updates)
     return updates, OptaxSGLDState(
         count=state.count + 1,
         rng_key=new_key,