Skip to content

Commit

Permalink
fix(pt): keep mapping not none during lmp steps when nghost == 0 (#4209)
Browse files Browse the repository at this point in the history
enhancement on #4144

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced tensor mapping capabilities with the addition of a new
`mapping_tensor` variable.
- Updated `compute` method to handle ghost atoms and support improved
tensor creation logic.
	- Overloaded `computew` methods to support both double and float types.

- **Bug Fixes**
- Improved error handling in the `translate_error` method for better
exception management.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
CaRoLZhangxy authored Oct 15, 2024
1 parent 5c092e6 commit 16172e6
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 1 deletion.
1 change: 1 addition & 0 deletions source/api_cc/include/DeepPotPT.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ class DeepPotPT : public DeepPotBase {
int do_message_passing; // 1:dpa2 model 0:others
bool gpu_enabled;
at::Tensor firstneigh_tensor;
c10::optional<torch::Tensor> mapping_tensor;
torch::Dict<std::string, torch::Tensor> comm_dict;
/**
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.
Expand Down
1 change: 0 additions & 1 deletion source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
std::vector<std::int64_t> atype_64(datype.begin(), datype.end());
at::Tensor atype_Tensor =
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
c10::optional<torch::Tensor> mapping_tensor;
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
Expand Down

0 comments on commit 16172e6

Please sign in to comment.