Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Problem running Tutorial_new_model_training.ipynb #1

Open
VinaySingh561 opened this issue Sep 25, 2024 · 3 comments
Open

Problem running Tutorial_new_model_training.ipynb #1

VinaySingh561 opened this issue Sep 25, 2024 · 3 comments

Comments

@VinaySingh561
Copy link

Hi,
I am trying to run Tutorial_new_model_training.ipynb,
but I am getting an error in step 3 [construct and initialize the NequIP energy model].
Please help me resolve the following error:
`---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[8], line 1
----> 1 model_fn, params, num_message_passing = NequIP_JAXMD_model(
2 r_max=r_max,
3 atomic_energies_dict={},
4 train_graphs=train_loader.graphs,
5 initialize_seed=config["model"]["seed"],
6 num_species = config["model"]["num_species"],
7 use_sc = True,
8 graph_net_steps = config["model"]["num_layers"],
9 hidden_irreps = config["model"]["internal_irreps"],
10 nonlinearities = {'e': 'swish', 'o': 'tanh'},
11 save_dir_name = save_dir_name,
12 reload = config["initialization"]['reload'] if 'reload' in config["initialization"] else None,
13 )
15 print("num_params:", sum(p.size for p in jax.tree_util.tree_leaves(params)))
17 predictor = jax.jit(
18 lambda w, g: predict_energy_forces_stress(lambda *x: model_fn(w, *x), g)
19 )

File ~/phonax/phonax/phonax/nequip_model.py:716, in NequIP_JAXMD_model(r_max, atomic_energies_dict, train_graphs, initialize_seed, scaling, atomic_energies, avg_num_neighbors, avg_r_min, num_species, path_normalization, gradient_normalization, learnable_atomic_energies, radial_basis, radial_envelope, save_dir_name, reload, **kwargs)
713 return node_energies
715 if (initialize_seed is not None) and reload is None:
--> 716 params = jax.jit(model_.init)(
717 jax.random.PRNGKey(initialize_seed),
718 jnp.zeros((1, 3)),
719 jnp.array([16]),
720 jnp.array([0]),
721 jnp.array([0]),
722 )
723 elif reload is not None:
724 with open(f"{reload}/params.pkl", "rb") as f:

[... skipping hidden 11 frame]

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/transform.py:166, in without_state.<locals>.init_fn(*args, **kwargs)
165 def init_fn(*args, **kwargs) -> hk.MutableParams:
--> 166 params, state = f.init(*args, **kwargs)
167 if state:
168 raise base.NonEmptyStateError(
169 "If your transformed function uses hk.{get,set}_state then use "
170 "hk.transform_with_state.")

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/transform.py:422, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
420 with base.new_context(rng=rng) as ctx:
421 try:
--> 422 f(*args, **kwargs)
423 except jax.errors.UnexpectedTracerError as e:
424 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

File ~/phonax/phonax/phonax/nequip_model.py:696, in NequIP_JAXMD_model.<locals>.model_(vectors, node_z, senders, receivers)
689 if hk.running_init():
690 logging.info(
691 "model: "
692 f"hidden_irreps={nequip.hidden_irreps} "
693 f"sh_irreps={nequip.sh_irreps} ",
694 )
--> 696 contributions = nequip(
697 vectors, node_z, senders, receivers
698 ) # [n_nodes, num_interactions, 0e]
699 node_energies = contributions[:, 0]
701 node_energies = mean + std * node_energies

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:464, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
461 if method_name != "__call__":
462 f = jax.named_call(f, name=method_name)
--> 464 out = f(*args, **kwargs)
466 # Module names are set in the constructor. If f is the constructor then
467 # its name will only be set after f has run. For methods other
468 # than __init__ we need the name before running in order to wrap their
469 # execution with named_call.
470 if module_name is None:

File ~/.conda/envs/phonax/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
76 @wraps(func)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:305, in run_interceptors(bound_method, method_name, self, orig_class, *args, **kwargs)
303 """Runs any method interceptors or the original method."""
304 if not interceptor_stack:
--> 305 return bound_method(*args, **kwargs)
307 ctx = MethodContext(module=self,
308 method_name=method_name,
309 orig_method=bound_method,
310 orig_class=orig_class)
311 interceptor_stack_copy = interceptor_stack.clone()

File ~/phonax/phonax/phonax/nequip_model.py:466, in NequIPEnergyModel.__call__(self, vectors, node_specie, senders, receivers)
464 # convolutions
465 for _ in range(self.graph_net_steps):
--> 466 h_node = NequIPConvolution(
467 hidden_irreps=hidden_irreps,
468 use_sc=self.use_sc,
469 nonlinearities=self.nonlinearities,
470 radial_net_nonlinearity=self.radial_net_nonlinearity,
471 radial_net_n_hidden=self.radial_net_n_hidden,
472 radial_net_n_layers=self.radial_net_n_layers,
473 num_basis=self.num_basis,
474 avg_num_neighbors=self.avg_num_neighbors,
475 scalar_mlp_std=self.scalar_mlp_std
476 )(h_node,
477 node_attrs,
478 edge_sh,
479 edge_src,
480 edge_dst,
481 embedded_dr_edge
482 )
484 # output block, two Linears that decay dimensions from h to h//2 to 1
485 for mul, ir in h_node.irreps:

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:464, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
461 if method_name != "__call__":
462 f = jax.named_call(f, name=method_name)
--> 464 out = f(*args, **kwargs)
466 # Module names are set in the constructor. If f is the constructor then
467 # its name will only be set after f has run. For methods other
468 # than __init__ we need the name before running in order to wrap their
469 # execution with named_call.
470 if module_name is None:

File ~/.conda/envs/phonax/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
76 @wraps(func)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:305, in run_interceptors(bound_method, method_name, self, orig_class, *args, **kwargs)
303 """Runs any method interceptors or the original method."""
304 if not interceptor_stack:
--> 305 return bound_method(*args, **kwargs)
307 ctx = MethodContext(module=self,
308 method_name=method_name,
309 orig_method=bound_method,
310 orig_class=orig_class)
311 interceptor_stack_copy = interceptor_stack.clone()

File ~/phonax/phonax/phonax/nequip_model.py:336, in NequIPConvolution.__call__(self, node_features, node_attributes, edge_sh, edge_src, edge_dst, edge_embedded)
333 # self-connection, similar to a resnet-update that sums the output from
334 # the TP to chemistry-weighted h
335 if self.use_sc:
--> 336 h = h + self_connection
338 # gate nonlinearity, applied to gate data, consisting of:
339 # a) regular scalars,
340 # b) gate scalars, and
341 # c) non-scalars to be gated
342 # in this order
343 gate_fn = partial(
344 e3nn.gate,
345 even_act=get_nonlinearity_by_name(self.nonlinearities['e']),
(...)
348 odd_gate_act=get_nonlinearity_by_name(self.nonlinearities['o'])
349 )

File ~/.conda/envs/phonax/lib/python3.10/site-packages/e3nn_jax/_src/irreps_array.py:311, in IrrepsArray.__add__(self, other)
306 raise ValueError(
307 f"IrrepsArray({self.irreps}, shape={self.shape}) + scalar is not equivariant."
308 )
310 if self.irreps != other.irreps:
--> 311 raise ValueError(
312 f"IrrepsArray({self.irreps}, shape={self.shape}) + IrrepsArray({other.irreps}) is not equivariant."
313 )
315 zero_flags = tuple(x and y for x, y in zip(self.zero_flags, other.zero_flags))
316 chunks = None

ValueError: IrrepsArray(36x0e+12x1o+8x2e, shape=(1, 112)) + IrrepsArray(16x0e+12x0e+8x0e+12x1o+8x2e) is not equivariant.`

@AI4TE
Copy link

AI4TE commented Oct 3, 2024

I had the same problem. Did you solve it?

@AI4TE
Copy link

AI4TE commented Oct 3, 2024

After downgrading these libraries, the errors no longer occurred: e3nn-jax==0.20.6, dm-haiku==0.0.12, jax==0.4.31, jaxlib==0.4.31.

@VinaySingh561
Copy link
Author

VinaySingh561 commented Oct 5, 2024 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants