Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refactor update_sel and save min_nbor_dist #3829

Merged
merged 14 commits into from
May 31, 2024
Prev Previous commit
se_e3
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed May 30, 2024

Unverified

This commit is not signed, but one or more authors requires that any commit attributed to them is signed.
commit ecb7b6671c33573480082aae5646e2f4bea0ba3a
28 changes: 24 additions & 4 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
@@ -348,15 +351,32 @@
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(

Check warning on line 379 in deepmd/dpmodel/descriptor/se_t.py

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_t.py#L379

Added line #L379 was not covered by tests
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist

Check warning on line 382 in deepmd/dpmodel/descriptor/se_t.py

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_t.py#L382

Added line #L382 was not covered by tests
28 changes: 24 additions & 4 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,9 @@
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
@@ -324,18 +327,35 @@
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(

Check warning on line 355 in deepmd/pt/model/descriptor/se_t.py

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L355

Added line #L355 was not covered by tests
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist

Check warning on line 358 in deepmd/pt/model/descriptor/se_t.py

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L358

Added line #L358 was not covered by tests


@DescriptorBlock.register("se_e3")
Loading