diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 099fbc42a4..d573ba9b7f 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -475,13 +475,13 @@ def forward( nt = extended_atype_embd.shape[-1] atype_tebd_ext = extended_atype_embd # nb x (nloc x nnei) x nt - index = nlist_copy.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt) + index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt) # nb x (nloc x nnei) x nt atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) # nb x nloc x nnei x nt atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt) # (nb x nloc) x nnei - exclude_mask = self.emask(nlist_copy, extended_atype).view(nb * nloc, nnei) + exclude_mask = self.emask(nlist, extended_atype).view(nb * nloc, nnei) if self.old_impl: assert self.filter_layers_old is not None dmatrix = dmatrix.view(