Skip to content

Commit

Permalink
fix(pt): Fix PairTabAtomicModel OOM error (#3484)
Browse files Browse the repository at this point in the history
Reduce memory usage of `_extract_spline_coefficient`.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Mar 18, 2024
1 parent abf3477 commit eca5b30
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,20 +415,28 @@ def _extract_spline_coefficient(
# (nframes, nloc, nnei)
expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1])

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4)

# handle the case where idx is beyond the number of splines
clipped_indices = torch.clamp(expanded_idx, 0, nspline - 1).to(torch.int64)

clipped_indices = torch.clamp(idx, 0, nspline - 1).to(torch.int64)

nframes = i_type.shape[0]
nloc = i_type.shape[1]
nnei = j_type.shape[2]
ntypes = tab_data.shape[0]
# tab_data_idx: (nframes, nloc, nnei)
tab_data_idx = (
expanded_i_type * ntypes * nspline + j_type * nspline + clipped_indices
)
# tab_data: (ntype, ntype, nspline, 4)
tab_data = tab_data.view(ntypes * ntypes * nspline, 4)
# tab_data_idx: (nframes * nloc * nnei, 4)
tab_data_idx = tab_data_idx.view(nframes * nloc * nnei, 1).expand(-1, 4)
# (nframes, nloc, nnei, 4)
final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze()
final_coef = torch.gather(tab_data, 0, tab_data_idx).view(
nframes, nloc, nnei, 4
)

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() > nspline] = 0
final_coef[idx > nspline] = 0
return final_coef

@staticmethod
Expand Down

0 comments on commit eca5b30

Please sign in to comment.