Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: refact training code #3359

Merged
merged 47 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
3812866
Fix single-task training&data stat
iProzd Feb 28, 2024
08e18fe
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
ae27607
Fix EnergyFittingNetDirect
iProzd Feb 28, 2024
7f573ab
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
f9265d5
Add data_requirement for dataloader
iProzd Feb 28, 2024
f8d2980
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
c9eb767
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2024
00105c7
Update make_base_descriptor.py
iProzd Feb 28, 2024
5a9df83
Update typing
iProzd Feb 28, 2024
75da5b1
Update training.py
iProzd Feb 28, 2024
6c171c5
Fix uts
iProzd Feb 28, 2024
2e87e1d
Fix uts
iProzd Feb 28, 2024
eb8094d
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
2618d98
Support multi-task training
iProzd Feb 28, 2024
f1585b2
Take advice from QL scan
iProzd Feb 28, 2024
463f9fb
Support no validation
iProzd Feb 28, 2024
e8575af
Update se_r.py
iProzd Feb 28, 2024
66d03b8
omit data prob log
iProzd Feb 28, 2024
e9e0d95
omit seed log
iProzd Feb 28, 2024
90be50e
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
ab35653
Add fparam and aparam
iProzd Feb 29, 2024
64d6079
Add type hint for `Callable`
iProzd Feb 29, 2024
6020a2b
Fix nopbc
iProzd Feb 29, 2024
5db7883
Add DataRequirementItem
iProzd Feb 29, 2024
c03a5ba
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
cce52da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 29, 2024
18cbf9e
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
cdcfcb2
Fix neighbor-stat for multitask (#31)
iProzd Feb 29, 2024
a7d44d1
Revert "Fix neighbor-stat for multitask (#31)"
iProzd Feb 29, 2024
fdca653
Move label requirement to loss func
iProzd Feb 29, 2024
525ce93
resolve conversations
iProzd Feb 29, 2024
46ee16c
set label_requirement abstractmethod
iProzd Feb 29, 2024
9d18dc4
make label_requirement dynamic
iProzd Feb 29, 2024
ad7227d
update docs
iProzd Feb 29, 2024
35598d2
replace lazy with functools.lru_cache
iProzd Feb 29, 2024
c0a0cfc
Update training.py
iProzd Feb 29, 2024
d50e2a2
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
66edca5
Update deepmd/pt/train/training.py
wanghan-iapcm Feb 29, 2024
d5a1549
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 29, 2024
c51f865
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
e17546a
Update test_multitask.py
iProzd Feb 29, 2024
1debf4f
Fix h5py files in multitask DDP
iProzd Feb 29, 2024
db31edc
FIx h5py file read block
iProzd Feb 29, 2024
60dda49
Merge branch 'devel' into train_rf
iProzd Mar 1, 2024
3dfc31e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2024
615446f
Update hybrid.py
iProzd Mar 1, 2024
e26c118
Update hybrid.py
iProzd Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@
@abstractmethod
def share_params(self, base_class, shared_level, resume=False):
"""Share the parameters of self to the base_class with shared_level."""
pass

Check warning on line 92 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L92

Added line #L92 was not covered by tests

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
iProzd marked this conversation as resolved.
Show resolved Hide resolved
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError
Expand Down
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
Type,
)

from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
Expand Down Expand Up @@ -93,7 +96,7 @@ def model_output_type(self) -> str:
"""Get the output type for the model."""

@abstractmethod
def data_requirement(self) -> dict:
def data_requirement(self) -> List[DataRequirementItem]:
"""Get the data requirement for the model."""

@abstractmethod
Expand Down
26 changes: 25 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
Expand All @@ -8,6 +12,9 @@
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .make_model import (
make_model,
Expand All @@ -17,6 +24,23 @@
# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
def data_requirement(self) -> dict:
def data_requirement(self) -> List[DataRequirementItem]:
"""Get the data requirement for the model."""
raise NotImplementedError

Check warning on line 29 in deepmd/dpmodel/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/dp_model.py#L29

Added line #L29 was not covered by tests

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = BaseDescriptor.update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy
6 changes: 6 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.utils.argcheck import (
normalize,
)
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -99,8 +105,8 @@
type_split = False
if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
type_split = True
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (

Check warning on line 109 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L108-L109

Added lines #L108 - L109 were not covered by tests
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
Expand All @@ -118,7 +124,7 @@
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
validation_data_single = (

Check warning on line 127 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L127

Added line #L127 was not covered by tests
DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
Expand Down Expand Up @@ -154,7 +160,7 @@
config["model"], config["training"], config["loss"]
)
else:
train_data, validation_data, stat_file_path = {}, {}, {}

Check warning on line 163 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L163

Added line #L163 was not covered by tests
for model_key in config["model"]["model_dict"]:
(
train_data[model_key],
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,10 @@
"""Returns the embedding dimension."""
pass

def compute_input_stats(

Check warning on line 91 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L91

Added line #L91 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for DescriptorBlock elements."""
raise NotImplementedError
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,23 +147,23 @@
"""
return self.se_atten.mixed_types()

def share_params(self, base_class, shared_level, resume=False):
assert (

Check warning on line 151 in deepmd/pt/model/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa1.py#L150-L151

Added lines #L150 - L151 were not covered by tests
self.__class__ == base_class.__class__
), "Only descriptors of the same type can share params!"
# For DPA1 descriptors, the user-defined share-level
# shared_level: 0
# share all parameters in both type_embedding and se_atten
if shared_level == 0:
self._modules["type_embedding"] = base_class._modules["type_embedding"]
self.se_atten.share_params(base_class.se_atten, 0, resume=resume)

Check warning on line 159 in deepmd/pt/model/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa1.py#L157-L159

Added lines #L157 - L159 were not covered by tests
# shared_level: 1
# share all parameters in type_embedding
elif shared_level == 1:
self._modules["type_embedding"] = base_class._modules["type_embedding"]

Check warning on line 163 in deepmd/pt/model/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa1.py#L162-L163

Added lines #L162 - L163 were not covered by tests
# Other shared levels
else:
raise NotImplementedError

Check warning on line 166 in deepmd/pt/model/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa1.py#L166

Added line #L166 was not covered by tests

@property
def dim_out(self):
Expand All @@ -173,8 +173,10 @@
def dim_emb(self):
return self.get_dim_emb()

def compute_input_stats(

Check warning on line 176 in deepmd/pt/model/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa1.py#L176

Added line #L176 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
return self.se_atten.compute_input_stats(merged, path)

Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,40 +291,40 @@
"""
return True

def share_params(self, base_class, shared_level, resume=False):
assert (

Check warning on line 295 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L294-L295

Added lines #L294 - L295 were not covered by tests
self.__class__ == base_class.__class__
), "Only descriptors of the same type can share params!"
# For DPA2 descriptors, the user-defined share-level
# shared_level: 0
# share all parameters in type_embedding, repinit and repformers
if shared_level == 0:
self._modules["type_embedding"] = base_class._modules["type_embedding"]
self.repinit.share_params(base_class.repinit, 0, resume=resume)
self._modules["g1_shape_tranform"] = base_class._modules[

Check warning on line 304 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L301-L304

Added lines #L301 - L304 were not covered by tests
"g1_shape_tranform"
]
self.repformers.share_params(base_class.repformers, 0, resume=resume)

Check warning on line 307 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L307

Added line #L307 was not covered by tests
# shared_level: 1
# share all parameters in type_embedding and repinit
elif shared_level == 1:
self._modules["type_embedding"] = base_class._modules["type_embedding"]
self.repinit.share_params(base_class.repinit, 0, resume=resume)

Check warning on line 312 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L310-L312

Added lines #L310 - L312 were not covered by tests
# shared_level: 2
# share all parameters in type_embedding and repformers
elif shared_level == 2:
self._modules["type_embedding"] = base_class._modules["type_embedding"]
self._modules["g1_shape_tranform"] = base_class._modules[

Check warning on line 317 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L315-L317

Added lines #L315 - L317 were not covered by tests
"g1_shape_tranform"
]
self.repformers.share_params(base_class.repformers, 0, resume=resume)

Check warning on line 320 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L320

Added line #L320 was not covered by tests
# shared_level: 3
# share all parameters in type_embedding
elif shared_level == 3:
self._modules["type_embedding"] = base_class._modules["type_embedding"]

Check warning on line 324 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L323-L324

Added lines #L323 - L324 were not covered by tests
# Other shared levels
else:
raise NotImplementedError

Check warning on line 327 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L327

Added line #L327 was not covered by tests

@property
def dim_out(self):
Expand All @@ -335,11 +335,13 @@
"""Returns the embedding dimension g2."""
return self.get_dim_emb()

def compute_input_stats(

Check warning on line 338 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L338

Added line #L338 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt.compute_input_stats(merged, path)

Check warning on line 344 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L344

Added line #L344 was not covered by tests

def serialize(self) -> dict:
"""Serialize the obj to dict."""
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,15 @@
else:
raise NotImplementedError

def compute_input_stats(

Check warning on line 160 in deepmd/pt/model/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/hybrid.py#L160

Added line #L160 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
for ii, descrpt in enumerate(self.descriptor_list):
# need support for hybrid descriptors
descrpt.compute_input_stats(merged, path)

Check warning on line 168 in deepmd/pt/model/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/hybrid.py#L168

Added line #L168 was not covered by tests

def forward(
self,
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,22 +280,24 @@

return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw

def compute_input_stats(

Check warning on line 283 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L283

Added line #L283 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSe(self)
if path is not None:
path = path / env_mat_stat.get_hash()
if path is None or not path.is_dir():
if callable(merged):

Check warning on line 293 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L292-L293

Added lines #L292 - L293 were not covered by tests
# only get data for once
sampled = merged()

Check warning on line 295 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L295

Added line #L295 was not covered by tests
else:
sampled = merged

Check warning on line 297 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L297

Added line #L297 was not covered by tests
else:
sampled = []
env_mat_stat.load_or_compute_stats(sampled, path)

Check warning on line 300 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L299-L300

Added lines #L299 - L300 were not covered by tests
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
if not self.set_davg_zero:
Expand Down
8 changes: 6 additions & 2 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,26 +129,28 @@
"""
return self.sea.mixed_types()

def share_params(self, base_class, shared_level, resume=False):
assert (

Check warning on line 133 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L132-L133

Added lines #L132 - L133 were not covered by tests
self.__class__ == base_class.__class__
), "Only descriptors of the same type can share params!"
# For SeA descriptors, the user-defined share-level
# shared_level: 0
# share all parameters in sea
if shared_level == 0:
self.sea.share_params(base_class.sea, 0, resume=resume)

Check warning on line 140 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L139-L140

Added lines #L139 - L140 were not covered by tests
# Other shared levels
else:
raise NotImplementedError

Check warning on line 143 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L143

Added line #L143 was not covered by tests

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
return self.sea.dim_out

def compute_input_stats(

Check warning on line 150 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L150

Added line #L150 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
return self.sea.compute_input_stats(merged, path)
Expand Down Expand Up @@ -428,22 +430,24 @@
else:
raise KeyError(key)

def compute_input_stats(

Check warning on line 433 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L433

Added line #L433 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSe(self)
if path is not None:
path = path / env_mat_stat.get_hash()
if path is None or not path.is_dir():
if callable(merged):

Check warning on line 443 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L442-L443

Added lines #L442 - L443 were not covered by tests
# only get data for once
sampled = merged()

Check warning on line 445 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L445

Added line #L445 was not covered by tests
else:
sampled = merged

Check warning on line 447 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L447

Added line #L447 was not covered by tests
else:
sampled = []
env_mat_stat.load_or_compute_stats(sampled, path)

Check warning on line 450 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L449-L450

Added lines #L449 - L450 were not covered by tests
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
if not self.set_davg_zero:
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,22 +202,24 @@
"""Returns the output dimension of embedding."""
return self.get_dim_emb()

def compute_input_stats(

Check warning on line 205 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L205

Added line #L205 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSe(self)
if path is not None:
path = path / env_mat_stat.get_hash()
if path is None or not path.is_dir():
if callable(merged):

Check warning on line 215 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L214-L215

Added lines #L214 - L215 were not covered by tests
# only get data for once
sampled = merged()

Check warning on line 217 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L217

Added line #L217 was not covered by tests
else:
sampled = merged

Check warning on line 219 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L219

Added line #L219 was not covered by tests
else:
sampled = []
env_mat_stat.load_or_compute_stats(sampled, path)

Check warning on line 222 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L221-L222

Added lines #L221 - L222 were not covered by tests
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
if not self.set_davg_zero:
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,50 +153,52 @@
"""
return False

def share_params(self, base_class, shared_level, resume=False):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
assert (

Check warning on line 157 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L156-L157

Added lines #L156 - L157 were not covered by tests
self.__class__ == base_class.__class__
), "Only descriptors of the same type can share params!"
# For SeR descriptors, the user-defined share-level
# shared_level: 0
if shared_level == 0:

Check warning on line 162 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L162

Added line #L162 was not covered by tests
# link buffers
if hasattr(self, "mean") and not resume:

Check warning on line 164 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L164

Added line #L164 was not covered by tests
# in case of change params during resume
base_env = EnvMatStatSe(base_class)
base_env.stats = base_class.stats
for kk in base_class.get_stats():
base_env.stats[kk] += self.get_stats()[kk]
mean, stddev = base_env()
if not base_class.set_davg_zero:
base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE))
base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
self.mean = base_class.mean
self.stddev = base_class.stddev

Check warning on line 175 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L166-L175

Added lines #L166 - L175 were not covered by tests
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
# the following will successfully link all the params except buffers
for item in self._modules:
self._modules[item] = base_class._modules[item]

Check warning on line 179 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L178-L179

Added lines #L178 - L179 were not covered by tests
# Other shared levels
else:
raise NotImplementedError

Check warning on line 182 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L182

Added line #L182 was not covered by tests

def compute_input_stats(

Check warning on line 184 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L184

Added line #L184 was not covered by tests
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSe(self)
if path is not None:
path = path / env_mat_stat.get_hash()
if path is None or not path.is_dir():
if callable(merged):

Check warning on line 194 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L193-L194

Added lines #L193 - L194 were not covered by tests
# only get data for once
sampled = merged()

Check warning on line 196 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L196

Added line #L196 was not covered by tests
else:
sampled = merged

Check warning on line 198 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L198

Added line #L198 was not covered by tests
else:
sampled = []
env_mat_stat.load_or_compute_stats(sampled, path)

Check warning on line 201 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L200-L201

Added lines #L200 - L201 were not covered by tests
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
if not self.set_davg_zero:
Expand Down
41 changes: 24 additions & 17 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
)

import torch

from deepmd.utils.data import (

Check warning on line 10 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L10

Added line #L10 was not covered by tests
DataRequirementItem,
)

from .dp_model import (
DPModel,
)
Expand Down Expand Up @@ -91,22 +96,24 @@
model_predict = model_ret
return model_predict

@property
def data_requirement(self):
data_requirement = {
"dipole": {
"ndof": 3,
"atomic": False,
"must": False,
"high_prec": False,
"type_sel": self.get_sel_type(),
},
"atomic_dipole": {
"ndof": 3,
"atomic": True,
"must": False,
"high_prec": False,
"type_sel": self.get_sel_type(),
},
}
def data_requirement(self) -> List[DataRequirementItem]:
data_requirement = [

Check warning on line 101 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L99-L101

Added lines #L99 - L101 were not covered by tests
DataRequirementItem(
iProzd marked this conversation as resolved.
Show resolved Hide resolved
"dipole",
ndof=3,
atomic=False,
must=False,
high_prec=False,
type_sel=self.get_sel_type(),
),
DataRequirementItem(
"atomic_dipole",
ndof=3,
atomic=True,
must=False,
high_prec=False,
type_sel=self.get_sel_type(),
),
]
return data_requirement

Check warning on line 119 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L119

Added line #L119 was not covered by tests
94 changes: 60 additions & 34 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
)

Expand All @@ -15,6 +16,9 @@
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.utils.data import (

Check warning on line 19 in deepmd/pt/model/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_zbl_model.py#L19

Added line #L19 was not covered by tests
DataRequirementItem,
)

from .make_model import (
make_model,
Expand Down Expand Up @@ -101,39 +105,61 @@
model_predict = model_ret
return model_predict

@property
def data_requirement(self):
data_requirement = {
"energy": {
"ndof": 1,
"atomic": False,
"must": False,
"high_prec": True,
},
"force": {
"ndof": 3,
"atomic": True,
"must": False,
"high_prec": False,
},
"virial": {
"ndof": 9,
"atomic": False,
"must": False,
"high_prec": False,
},
"atom_ener": {
"ndof": 1,
"atomic": True,
"must": False,
"high_prec": False,
},
"atom_pref": {
"ndof": 1,
"atomic": True,
"must": False,
"high_prec": False,
"repeat": 3,
},
}
def data_requirement(self) -> List[DataRequirementItem]:
data_requirement = [

Check warning on line 110 in deepmd/pt/model/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_zbl_model.py#L108-L110

Added lines #L108 - L110 were not covered by tests
DataRequirementItem(
"energy",
ndof=1,
atomic=False,
must=False,
high_prec=True,
),
DataRequirementItem(
"force",
ndof=3,
atomic=True,
must=False,
high_prec=False,
),
DataRequirementItem(
"virial",
ndof=9,
atomic=False,
must=False,
high_prec=False,
),
DataRequirementItem(
"atom_ener",
ndof=1,
atomic=True,
must=False,
high_prec=False,
),
DataRequirementItem(
"atom_pref",
ndof=1,
atomic=True,
must=False,
high_prec=False,
repeat=3,
),
]
return data_requirement

Check warning on line 148 in deepmd/pt/model/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_zbl_model.py#L148

Added line #L148 was not covered by tests

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"] = DPModel.update_sel(
global_jdata, local_jdata["dpmodel"]
)
return local_jdata_cpy
Loading
Loading