How do I do tree_map over a list of NN params? #16992
niladridas
asked this question in
General
Depending on the network, the following command will give you a better idea of where the function is being applied on the pytree: jax.tree_util.tree_map(lambda x: f'function applied to x of shape {x.shape}', example_pytrees)
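To see what that command reveals without needing the original network, here is a minimal sketch. It assumes a hypothetical nested dict of arrays standing in for Haiku params (the layer name "linear" and the shapes are made up for illustration). It shows that tree_map descends all the way to the individual arrays, so the mapped function never receives a whole params dict.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for Haiku params: a nested dict of arrays.
params = {"linear": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}}
example_pytrees = [params, params]

# tree_map calls the function once per array leaf, not once per params dict.
shapes = jax.tree_util.tree_map(lambda x: x.shape, example_pytrees)

# Flattening the list shows how many leaves tree_map visits in total.
leaves = jax.tree_util.tree_leaves(example_pytrees)

print(shapes)
print(len(leaves))
```

Each entry of the list is itself a pytree, so the function is handed "w" and "b" separately, which is consistent with the ArrayImpl type reported in the error message.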
I have a NN made in Haiku as
net = hk.transform(net_fn)
I have a training batch, and then I do
params = net.init(jax.random.PRNGKey(42), next(iter(train_batches)))
to get the params. I make a pytree of params as:
example_pytrees = [params,params]
I am trying to iterate over this pytree to produce output from the NN, which is now
apply_net = lambda x: net.apply(x, jax.random.PRNGKey(0), next(iter(train_batches)))
My question is:
apply_net(params)
works, but
jax.tree_util.tree_map(apply_net, example_pytrees)
shows this error: "params argument does not appear valid. It should be a mapping but is of type <class 'jaxlib.xla_extension.ArrayImpl'>."
What should I do?
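One possible way around this, sketched under assumptions: the code below uses a hypothetical apply_fn and a hand-built nested dict in place of net.apply and the real Haiku params, and a made-up batch in place of next(iter(train_batches)). It also assumes each params object is a plain dict; if your Haiku version returns another mapping type, the is_leaf check would need adjusting. It shows three options: a plain list comprehension, tree_map with is_leaf so that each params dict is treated as a single leaf, and stacking the params and using jax.vmap.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for net.apply: a single linear layer.
def apply_fn(params, x):
    return x @ params["linear"]["w"] + params["linear"]["b"]

# Hypothetical stand-ins for the Haiku params and the training batch.
params = {"linear": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}}
example_pytrees = [params, params]
batch = jnp.ones((4, 3))

apply_net = lambda p: apply_fn(p, batch)

# Option 1: iterate over the list directly.
outputs_list = [apply_net(p) for p in example_pytrees]

# Option 2: tell tree_map to stop at each params dict instead of each array.
outputs_tree = jax.tree_util.tree_map(
    apply_net, example_pytrees, is_leaf=lambda node: isinstance(node, dict)
)

# Option 3: stack matching leaves along a new leading axis, then vmap over it.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *example_pytrees)
outputs_vmap = jax.vmap(apply_fn, in_axes=(0, None))(stacked, batch)
```

The list comprehension is the simplest choice for a handful of parameter sets. The vmap variant only works when every params tree has the same structure and shapes, but it evaluates all of them in one batched call.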