pt: refact training code #3359

Merged: 47 commits, merged Mar 1, 2024
Changes from 28 commits

Commits
3812866
Fix single-task training&data stat
iProzd Feb 28, 2024
08e18fe
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
ae27607
Fix EnergyFittingNetDirect
iProzd Feb 28, 2024
7f573ab
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
f9265d5
Add data_requirement for dataloader
iProzd Feb 28, 2024
f8d2980
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
c9eb767
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2024
00105c7
Update make_base_descriptor.py
iProzd Feb 28, 2024
5a9df83
Update typing
iProzd Feb 28, 2024
75da5b1
Update training.py
iProzd Feb 28, 2024
6c171c5
Fix uts
iProzd Feb 28, 2024
2e87e1d
Fix uts
iProzd Feb 28, 2024
eb8094d
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
2618d98
Support multi-task training
iProzd Feb 28, 2024
f1585b2
Take advice from QL scan
iProzd Feb 28, 2024
463f9fb
Support no validation
iProzd Feb 28, 2024
e8575af
Update se_r.py
iProzd Feb 28, 2024
66d03b8
omit data prob log
iProzd Feb 28, 2024
e9e0d95
omit seed log
iProzd Feb 28, 2024
90be50e
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
ab35653
Add fparam and aparam
iProzd Feb 29, 2024
64d6079
Add type hint for `Callable`
iProzd Feb 29, 2024
6020a2b
Fix nopbc
iProzd Feb 29, 2024
5db7883
Add DataRequirementItem
iProzd Feb 29, 2024
c03a5ba
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
cce52da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 29, 2024
18cbf9e
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
cdcfcb2
Fix neighbor-stat for multitask (#31)
iProzd Feb 29, 2024
a7d44d1
Revert "Fix neighbor-stat for multitask (#31)"
iProzd Feb 29, 2024
fdca653
Move label requirement to loss func
iProzd Feb 29, 2024
525ce93
resolve conversations
iProzd Feb 29, 2024
46ee16c
set label_requirement abstractmethod
iProzd Feb 29, 2024
9d18dc4
make label_requirement dynamic
iProzd Feb 29, 2024
ad7227d
update docs
iProzd Feb 29, 2024
35598d2
replace lazy with functools.lru_cache
iProzd Feb 29, 2024
c0a0cfc
Update training.py
iProzd Feb 29, 2024
d50e2a2
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
66edca5
Update deepmd/pt/train/training.py
wanghan-iapcm Feb 29, 2024
d5a1549
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 29, 2024
c51f865
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
e17546a
Update test_multitask.py
iProzd Feb 29, 2024
1debf4f
Fix h5py files in multitask DDP
iProzd Feb 29, 2024
db31edc
FIx h5py file read block
iProzd Feb 29, 2024
60dda49
Merge branch 'devel' into train_rf
iProzd Mar 1, 2024
3dfc31e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2024
615446f
Update hybrid.py
iProzd Mar 1, 2024
e26c118
Update hybrid.py
iProzd Mar 1, 2024
11 changes: 10 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -4,8 +4,10 @@
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Union,
)

from deepmd.common import (
@@ -84,8 +86,15 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def share_params(self, base_class, shared_level, resume=False):
"""Share the parameters of self to the base_class with shared_level."""
pass

def compute_input_stats(
self, merged: List[dict], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError
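With this change, `compute_input_stats` accepts either a pre-sampled list of frames or a zero-argument callable that produces one lazily. A minimal sketch of how an implementation might resolve the new argument; the helper name `resolve_merged` is hypothetical and not part of this PR.

```python
# Minimal sketch, not the PR's code: resolve the `merged` argument, which may
# be a list of sampled frames or a lazy zero-argument callable returning one.
from typing import Callable, List, Union


def resolve_merged(
    merged: Union[Callable[[], List[dict]], List[dict]],
) -> List[dict]:
    # Trigger data sampling only when statistics actually have to be computed;
    # a restart that loads cached stats from `path` can skip the call entirely.
    return merged() if callable(merged) else merged
```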
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
@@ -243,6 +243,10 @@ def mixed_types(self):
"""
return False

def share_params(self, base_class, shared_level, resume=False):
"""Share the parameters of self to the base_class with shared_level."""
raise NotImplementedError

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
@@ -203,6 +203,10 @@ def mixed_types(self):
"""
return False

def share_params(self, base_class, shared_level, resume=False):
"""Share the parameters of self to the base_class with shared_level."""
raise NotImplementedError

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
7 changes: 7 additions & 0 deletions deepmd/dpmodel/model/base_model.py
@@ -10,6 +10,9 @@
Type,
)

from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
@@ -92,6 +95,10 @@ def is_aparam_nall(self) -> bool:
def model_output_type(self) -> str:
"""Get the output type for the model."""

@abstractmethod
def data_requirement(self) -> List[DataRequirementItem]:
"""Get the data requirement for the model."""

@abstractmethod
def serialize(self) -> dict:
"""Serialize the model.
11 changes: 11 additions & 0 deletions deepmd/dpmodel/model/dp_model.py
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
@@ -8,6 +12,9 @@
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .make_model import (
make_model,
@@ -17,6 +24,10 @@
# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
def data_requirement(self) -> List[DataRequirementItem]:
"""Get the data requirement for the model."""
raise NotImplementedError

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
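`DPModel` on the dpmodel side keeps `data_requirement` unimplemented; concrete models are expected to return the labels they need. A hedged sketch of what such an implementation could look like for an energy model — the class and the constructor arguments are assumptions loosely mirroring `DeepmdData.add`, not code from this PR.

```python
# Illustrative only; verify DataRequirementItem's actual signature in
# deepmd/utils/data.py before relying on these keyword arguments.
from typing import List

from deepmd.utils.data import DataRequirementItem


class MyEnergyModel:  # hypothetical model class, for illustration
    def data_requirement(self) -> List[DataRequirementItem]:
        """Labels this model asks the dataloader to provide."""
        return [
            DataRequirementItem("energy", ndof=1, atomic=False, must=False, high_prec=True),
            DataRequirementItem("force", ndof=3, atomic=True, must=False, high_prec=False),
        ]
```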
94 changes: 43 additions & 51 deletions deepmd/pt/entrypoints/main.py
@@ -3,6 +3,9 @@
import json
import logging
import os
from copy import (
deepcopy,
)
from pathlib import (
Path,
)
@@ -50,9 +53,6 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.utils.argcheck import (
normalize,
)
@@ -75,9 +75,11 @@ def get_trainer(
model_branch="",
force_load=False,
init_frz_model=None,
shared_links=None,
):
multi_task = "model_dict" in config.get("model", {})
# argcheck
if "model_dict" not in config.get("model", {}):
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

@@ -88,7 +90,6 @@
assert dist.is_nccl_available()
dist.init_process_group(backend="nccl")

multi_task = "model_dict" in config["model"]
ckpt = init_model if init_model is not None else restart_model
config["model"] = change_finetune_model_params(
ckpt,
@@ -97,10 +98,6 @@
multi_task=multi_task,
model_branch=model_branch,
)
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

def prepare_trainer_input_single(
model_params_single, data_dict_single, loss_dict_single, suffix=""
@@ -109,26 +106,11 @@ def prepare_trainer_input_single(
type_split = False
if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
type_split = True
validation_dataset_params = data_dict_single["validation_data"]
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
validation_systems = validation_dataset_params["systems"]

# noise params
noise_settings = None
if loss_dict_single.get("type", "ener") == "denoise":
noise_settings = {
"noise_type": loss_dict_single.pop("noise_type", "uniform"),
"noise": loss_dict_single.pop("noise", 1.0),
"noise_mode": loss_dict_single.pop("noise_mode", "fix_num"),
"mask_num": loss_dict_single.pop("mask_num", 8),
"mask_prob": loss_dict_single.pop("mask_prob", 0.15),
"same_mask": loss_dict_single.pop("same_mask", False),
"mask_coord": loss_dict_single.pop("mask_coord", False),
"mask_type": loss_dict_single.pop("mask_type", False),
"max_fail_num": loss_dict_single.pop("max_fail_num", 10),
"mask_type_idx": len(model_params_single["type_map"]) - 1,
}
# noise_settings = None

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
@@ -143,59 +125,47 @@
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
validation_data_single = DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single,
validation_data_single = (
DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single,
)
if validation_systems
else None
)
if ckpt or finetune_model:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
sampled_single = None
else:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10)
sampled_single = make_stat_input(
train_data_single.systems,
train_data_single.dataloaders,
data_stat_nbatch,
)
if noise_settings is not None:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
return (
train_data_single,
validation_data_single,
sampled_single,
stat_file_path_single,
)

if not multi_task:
(
train_data,
validation_data,
sampled,
stat_file_path,
) = prepare_trainer_input_single(
config["model"], config["training"], config["loss"]
)
else:
train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
train_data, validation_data, stat_file_path = {}, {}, {}
for model_key in config["model"]["model_dict"]:
(
train_data[model_key],
validation_data[model_key],
sampled[model_key],
stat_file_path[model_key],
) = prepare_trainer_input_single(
config["model"]["model_dict"][model_key],
@@ -207,7 +177,6 @@ def prepare_trainer_input_single(
trainer = training.Trainer(
config,
train_data,
sampled=sampled,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
@@ -252,11 +221,33 @@ def train(FLAGS):
SummaryPrinter()()
with open(FLAGS.INPUT) as fin:
config = json.load(fin)

# update multitask config
multi_task = "model_dict" in config["model"]
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

# do neighbor stat
if not FLAGS.skip_neighbor_stat:
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
config["model"] = BaseModel.update_sel(config, config["model"])
if not multi_task:
config["model"] = BaseModel.update_sel(config, config["model"])
else:
training_jdata = deepcopy(config["training"])
training_jdata.pop("data_dict", {})
training_jdata.pop("model_prob", {})
for model_item in config["model"]["model_dict"]:
fake_global_jdata = {
"model": deepcopy(config["model"]["model_dict"][model_item]),
"training": deepcopy(config["training"]["data_dict"][model_item]),
}
fake_global_jdata["training"].update(training_jdata)
config["model"]["model_dict"][model_item] = BaseModel.update_sel(
fake_global_jdata, config["model"]["model_dict"][model_item]
)

trainer = get_trainer(
config,
@@ -266,6 +257,7 @@
FLAGS.model_branch,
FLAGS.force_load,
FLAGS.init_frz_model,
shared_links=shared_links,
)
trainer.run()

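In the multi-task branch, `update_sel` now runs once per model branch on a per-branch "global" config assembled from `model_dict` and `data_dict`. A sketch of the input layout that loop walks over — only the `model_dict`, `data_dict`, and `model_prob` keys (and the `se_e2_a` descriptor type) come from the code above; branch names, paths, and settings are made up.

```python
# Illustrative multi-task input layout; values are placeholders.
config = {
    "model": {
        "model_dict": {
            "branch_1": {
                "type_map": ["O", "H"],
                "descriptor": {"type": "se_e2_a"},  # placeholder settings
                "fitting_net": {},
            },
            "branch_2": {
                "type_map": ["O", "H"],
                "descriptor": {"type": "se_e2_a"},
                "fitting_net": {},
            },
        },
    },
    "training": {
        "model_prob": {"branch_1": 0.5, "branch_2": 0.5},
        "data_dict": {
            "branch_1": {"training_data": {"systems": ["path/to/sys_1"], "batch_size": 1}},
            "branch_2": {"training_data": {"systems": ["path/to/sys_2"], "batch_size": 1}},
        },
    },
}
# For each branch, train() builds {"model": model_dict[k], "training": data_dict[k]
# merged with the shared training keys} and passes it to BaseModel.update_sel.
```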
17 changes: 5 additions & 12 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -18,9 +18,6 @@
from deepmd.pt.model.task.base_fitting import (
BaseFitting,
)
from deepmd.pt.utils.utils import (
dict_to_device,
)
from deepmd.utils.path import (
DPPath,
)
@@ -185,7 +182,7 @@ def forward_atomic(

def compute_or_load_stat(
self,
sampled,
sampled_func,
stat_file_path: Optional[DPPath] = None,
):
"""
@@ -198,22 +195,18 @@ def compute_or_load_stat(

Parameters
----------
sampled
The sampled data frames from different data systems.
sampled_func
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
"""
if stat_file_path is not None and self.type_map is not None:
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)
for data_sys in sampled:
dict_to_device(data_sys)
if sampled is None:
sampled = []
self.descriptor.compute_input_stats(sampled, stat_file_path)
self.descriptor.compute_input_stats(sampled_func, stat_file_path)
if self.fitting_net is not None:
self.fitting_net.compute_output_stats(sampled, stat_file_path)
self.fitting_net.compute_output_stats(sampled_func, stat_file_path)

@torch.jit.export
def get_dim_fparam(self) -> int:
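`compute_or_load_stat` now receives a lazy `sampled_func` instead of pre-sampled frames, so data sampling happens only when statistics must actually be recomputed. A sketch of how such a callable might be assembled; the `make_stat_input` call shape is taken from the code removed from `deepmd/pt/entrypoints/main.py` above, while the `lru_cache` wrapper is illustrative (per the "replace lazy with functools.lru_cache" commit).

```python
# Sketch only: build a cached, zero-argument sampling function to pass to
# compute_or_load_stat so descriptor and fitting net share one sampling pass.
from functools import lru_cache

from deepmd.pt.utils.stat import make_stat_input


def make_sampled_func(train_data, data_stat_nbatch: int = 10):
    @lru_cache(maxsize=None)  # sample once, reuse on every later call
    def sampled_func():
        return make_stat_input(
            train_data.systems, train_data.dataloaders, data_stat_nbatch
        )

    return sampled_func


# usage sketch:
# model.compute_or_load_stat(make_sampled_func(train_data), stat_file_path)
```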
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_descriptor import (
BaseDescriptor,
)
from .descriptor import (
DescriptorBlock,
make_default_type_embedding,
@@ -31,6 +34,7 @@
)

__all__ = [
"BaseDescriptor",
"DescriptorBlock",
"make_default_type_embedding",
"DescrptBlockSeA",
8 changes: 7 additions & 1 deletion deepmd/pt/model/descriptor/descriptor.py
@@ -5,9 +5,11 @@
abstractmethod,
)
from typing import (
Callable,
Dict,
List,
Optional,
Union,
)

import torch
@@ -86,7 +88,11 @@ def get_dim_emb(self) -> int:
"""Returns the embedding dimension."""
pass

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
def compute_input_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for DescriptorBlock elements."""
raise NotImplementedError
