diff --git a/source/api_cc/include/DeepPotPT.h b/source/api_cc/include/DeepPotPT.h index 8f69168b5a..207a13286c 100644 --- a/source/api_cc/include/DeepPotPT.h +++ b/source/api_cc/include/DeepPotPT.h @@ -335,7 +335,7 @@ class DeepPotPT : public DeepPotBackend { NeighborListData nlist_data; int max_num_neighbors; int gpu_id; - int do_message_passing; // 1:dpa2 model 0:others + bool do_message_passing; // 1:dpa2 model 0:others bool gpu_enabled; at::Tensor firstneigh_tensor; c10::optional mapping_tensor; diff --git a/source/api_cc/include/DeepSpinPT.h b/source/api_cc/include/DeepSpinPT.h index 643557eb07..be4c85d898 100644 --- a/source/api_cc/include/DeepSpinPT.h +++ b/source/api_cc/include/DeepSpinPT.h @@ -257,7 +257,7 @@ class DeepSpinPT : public DeepSpinBackend { NeighborListData nlist_data; int max_num_neighbors; int gpu_id; - int do_message_passing; // 1:dpa2 model 0:others + bool do_message_passing; // 1:dpa2 model 0:others bool gpu_enabled; at::Tensor firstneigh_tensor; c10::optional mapping_tensor; diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index ce104b0f8e..7e5d391b1f 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -171,7 +171,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, nlist_data.copy_from_nlist(lmp_list); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - if (do_message_passing == 1) { + if (do_message_passing) { int nswap = lmp_list.nswap; torch::Tensor sendproc_tensor = torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); @@ -234,7 +234,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, .to(device); } c10::Dict outputs = - (do_message_passing == 1) + (do_message_passing) ? module .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, firstneigh_tensor, mapping_tensor, fparam_tensor, diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 3ae0eb3bb7..c72cb34b15 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -179,7 +179,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, nlist_data.copy_from_nlist(lmp_list); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - if (do_message_passing == 1) { + if (do_message_passing) { int nswap = lmp_list.nswap; torch::Tensor sendproc_tensor = torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); @@ -234,7 +234,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, .to(device); } c10::Dict outputs = - (do_message_passing == 1) + (do_message_passing) ? module .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, spin_wrapped_Tensor, firstneigh_tensor,