diff --git a/src/metatrain/experimental/soap_bpnn/spherical.py b/src/metatrain/experimental/soap_bpnn/spherical.py index 9c72c179..d4e29178 100644 --- a/src/metatrain/experimental/soap_bpnn/spherical.py +++ b/src/metatrain/experimental/soap_bpnn/spherical.py @@ -156,6 +156,9 @@ def forward( ) elif self.o3_lambda == 1: basis = self.vector_basis(systems, selected_atoms) + basis = basis / torch.sqrt( + torch.sum(torch.square(basis), dim=-1, keepdim=True) + ) elif self.o3_lambda == 2: basis = torch.empty( (num_atoms, 5, 5),