Hi,
I am trying to run Tutorial_new_model_training but I am getting an error at step 3 [construct and initialize the NequIP energy model].
Could you please help me resolve the following error:
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[8], line 1
----> 1 model_fn, params, num_message_passing = NequIP_JAXMD_model(
2 r_max=r_max,
3 atomic_energies_dict={},
4 train_graphs=train_loader.graphs,
5 initialize_seed=config["model"]["seed"],
6 num_species = config["model"]["num_species"],
7 use_sc = True,
8 graph_net_steps = config["model"]["num_layers"],
9 hidden_irreps = config["model"]["internal_irreps"],
10 nonlinearities = {'e': 'swish', 'o': 'tanh'},
11 save_dir_name = save_dir_name,
12 reload = config["initialization"]['reload'] if 'reload' in config["initialization"] else None,
13 )
15 print("num_params:", sum(p.size for p in jax.tree_util.tree_leaves(params)))
17 predictor = jax.jit(
18 lambda w, g: predict_energy_forces_stress(lambda *x: model_fn(w, *x), g)
19 )
File ~/phonax/phonax/phonax/nequip_model.py:716, in NequIP_JAXMD_model(r_max, atomic_energies_dict, train_graphs, initialize_seed, scaling, atomic_energies, avg_num_neighbors, avg_r_min, num_species, path_normalization, gradient_normalization, learnable_atomic_energies, radial_basis, radial_envelope, save_dir_name, reload, **kwargs)
713 return node_energies
715 if (initialize_seed is not None) and reload is None:
--> 716 params = jax.jit(model_.init)(
717 jax.random.PRNGKey(initialize_seed),
718 jnp.zeros((1, 3)),
719 jnp.array([16]),
720 jnp.array([0]),
721 jnp.array([0]),
722 )
723 elif reload is not None:
724 with open(f"{reload}/params.pkl", "rb") as f:
[... skipping hidden 11 frame]
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/transform.py:166, in without_state.<locals>.init_fn(*args, **kwargs)
165 def init_fn(*args, **kwargs) -> hk.MutableParams:
--> 166 params, state = f.init(*args, **kwargs)
167 if state:
168 raise base.NonEmptyStateError(
169 "If your transformed function uses hk.{get,set}_state then use "
170 "hk.transform_with_state.")
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/transform.py:422, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
420 with base.new_context(rng=rng) as ctx:
421 try:
--> 422 f(*args, **kwargs)
423 except jax.errors.UnexpectedTracerError as e:
424 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:464, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
461 if method_name != "__call__":
462 f = jax.named_call(f, name=method_name)
--> 464 out = f(*args, **kwargs)
466 # Module names are set in the constructor. If f is the constructor then
467 # its name will only be set after f has run. For methods other
468 # than __init__ we need the name before running in order to wrap their
469 # execution with named_call.
470 if module_name is None:
File ~/.conda/envs/phonax/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
76 @wraps(func)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:305, in run_interceptors(bound_method, method_name, self, orig_class, *args, **kwargs)
303 """Runs any method interceptors or the original method."""
304 if not interceptor_stack:
--> 305 return bound_method(*args, **kwargs)
307 ctx = MethodContext(module=self,
308 method_name=method_name,
309 orig_method=bound_method,
310 orig_class=orig_class)
311 interceptor_stack_copy = interceptor_stack.clone()
File ~/phonax/phonax/phonax/nequip_model.py:466, in NequIPEnergyModel.__call__(self, vectors, node_specie, senders, receivers)
464 # convolutions
465 for _ in range(self.graph_net_steps):
--> 466 h_node = NequIPConvolution(
467 hidden_irreps=hidden_irreps,
468 use_sc=self.use_sc,
469 nonlinearities=self.nonlinearities,
470 radial_net_nonlinearity=self.radial_net_nonlinearity,
471 radial_net_n_hidden=self.radial_net_n_hidden,
472 radial_net_n_layers=self.radial_net_n_layers,
473 num_basis=self.num_basis,
474 avg_num_neighbors=self.avg_num_neighbors,
475 scalar_mlp_std=self.scalar_mlp_std
476 )(h_node,
477 node_attrs,
478 edge_sh,
479 edge_src,
480 edge_dst,
481 embedded_dr_edge
482 )
484 # output block, two Linears that decay dimensions from h to h//2 to 1
485 for mul, ir in h_node.irreps:
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:464, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
461 if method_name != "__call__":
462 f = jax.named_call(f, name=method_name)
--> 464 out = f(*args, **kwargs)
466 # Module names are set in the constructor. If f is the constructor then
467 # its name will only be set after f has run. For methods other
468 # than __init__ we need the name before running in order to wrap their
469 # execution with named_call.
470 if module_name is None:
File ~/.conda/envs/phonax/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
76 @wraps(func)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:305, in run_interceptors(bound_method, method_name, self, orig_class, *args, **kwargs)
303 """Runs any method interceptors or the original method."""
304 if not interceptor_stack:
--> 305 return bound_method(*args, **kwargs)
307 ctx = MethodContext(module=self,
308 method_name=method_name,
309 orig_method=bound_method,
310 orig_class=orig_class)
311 interceptor_stack_copy = interceptor_stack.clone()
File ~/phonax/phonax/phonax/nequip_model.py:336, in NequIPConvolution.__call__(self, node_features, node_attributes, edge_sh, edge_src, edge_dst, edge_embedded)
333 # self-connection, similar to a resnet-update that sums the output from
334 # the TP to chemistry-weighted h
335 if self.use_sc:
--> 336 h = h + self_connection
338 # gate nonlinearity, applied to gate data, consisting of:
339 # a) regular scalars,
340 # b) gate scalars, and
341 # c) non-scalars to be gated
342 # in this order
343 gate_fn = partial(
344 e3nn.gate,
345 even_act=get_nonlinearity_by_name(self.nonlinearities['e']),
(...)
348 odd_gate_act=get_nonlinearity_by_name(self.nonlinearities['o'])
349 )
File ~/.conda/envs/phonax/lib/python3.10/site-packages/e3nn_jax/_src/irreps_array.py:311, in IrrepsArray.__add__(self, other)
306 raise ValueError(
307 f"IrrepsArray({self.irreps}, shape={self.shape}) + scalar is not equivariant."
308 )
310 if self.irreps != other.irreps:
--> 311 raise ValueError(
312 f"IrrepsArray({self.irreps}, shape={self.shape}) + IrrepsArray({other.irreps}) is not equivariant."
313 )
315 zero_flags = tuple(x and y for x, y in zip(self.zero_flags, other.zero_flags))
316 chunks = None
ValueError: IrrepsArray(36x0e+12x1o+8x2e, shape=(1, 112)) + IrrepsArray(16x0e+12x0e+8x0e+12x1o+8x2e) is not equivariant.
```
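From the last frame it looks like the failure happens inside e3nn_jax, where NequIPConvolution does `h = h + self_connection`: both sides carry the same 112 features, but `h` lists its scalars as a single 36x0e block while `self_connection` keeps them split as 16x0e+12x0e+8x0e, and IrrepsArray addition only succeeds when the irreps strings match exactly. Here is a minimal sketch of what I think is going on (my own illustration with dummy arrays, not phonax code, assuming the usual e3nn_jax IrrepsArray/regroup API):

```python
import jax.numpy as jnp
import e3nn_jax as e3nn

# Two arrays holding the same 112 numbers but with differently grouped
# irreps, mirroring the h / self_connection pair from the traceback.
h = e3nn.IrrepsArray("36x0e+12x1o+8x2e", jnp.zeros((1, 112)))
sc = e3nn.IrrepsArray("16x0e+12x0e+8x0e+12x1o+8x2e", jnp.zeros((1, 112)))

# h + sc            # raises ValueError: ... is not equivariant

# regroup() merges the adjacent scalar blocks (16+12+8 -> 36x0e),
# after which the layouts match and the addition works.
out = h + sc.regroup()
print(out.irreps)    # 36x0e+12x1o+8x2e
```

So the self-connection output apparently comes out with its scalar channels un-merged relative to the hidden_irreps (config["model"]["internal_irreps"] in the tutorial). Is this a version mismatch with e3nn_jax, or should I change internal_irreps in the config?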