diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index a4aa43ede1..7c7c8a2969 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -415,20 +415,28 @@ def _extract_spline_coefficient( # (nframes, nloc, nnei) expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1]) - # (nframes, nloc, nnei, nspline, 4) - expanded_tab_data = tab_data[expanded_i_type, j_type] - - # (nframes, nloc, nnei, 1, 4) - expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4) - # handle the case where idx is beyond the number of splines - clipped_indices = torch.clamp(expanded_idx, 0, nspline - 1).to(torch.int64) - + clipped_indices = torch.clamp(idx, 0, nspline - 1).to(torch.int64) + + nframes = i_type.shape[0] + nloc = i_type.shape[1] + nnei = j_type.shape[2] + ntypes = tab_data.shape[0] + # tab_data_idx: (nframes, nloc, nnei) + tab_data_idx = ( + expanded_i_type * ntypes * nspline + j_type * nspline + clipped_indices + ) + # tab_data: (ntype, ntype, nspline, 4) + tab_data = tab_data.view(ntypes * ntypes * nspline, 4) + # tab_data_idx: (nframes * nloc * nnei, 4) + tab_data_idx = tab_data_idx.view(nframes * nloc * nnei, 1).expand(-1, 4) # (nframes, nloc, nnei, 4) - final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze() + final_coef = torch.gather(tab_data, 0, tab_data_idx).view( + nframes, nloc, nnei, 4 + ) # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`. - final_coef[expanded_idx.squeeze() > nspline] = 0 + final_coef[idx > nspline] = 0 return final_coef @staticmethod