pt: refact training code (#3359)
This PR
- add `data_requirement` for the dataloader (see the sketch after the file summary below)
- reformat `make_stat_input` and related training code
- support single-task & multi-task training

---------

Signed-off-by: Duo <[email protected]>
Signed-off-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <[email protected]>
3 people authored Mar 1, 2024
1 parent 54efc03 commit 16c6db6
Showing 42 changed files with 1,409 additions and 331 deletions.
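For context on the first bullet of the PR description: each loss now advertises the labels it needs as a list of `DataRequirementItem` objects (see the additions to `deepmd/pt/loss/ener.py` further down), and the dataloader side registers those keys before batching. A minimal, illustrative sketch (the consumer hook named here is a placeholder, not necessarily the real API):

```python
# Sketch only: a loss declares the labels it needs; the field names mirror
# the additions to deepmd/pt/loss/ener.py shown below.
from deepmd.utils.data import DataRequirementItem

label_requirement = [
    DataRequirementItem("energy", ndof=1, atomic=False, must=False, high_prec=True),
    DataRequirementItem("force", ndof=3, atomic=True, must=False, high_prec=False),
]

# A dataloader-side consumer would walk this list and register each key;
# `add_data_requirement` is a hypothetical name for that hook.
# for item in label_requirement:
#     dataset.add_data_requirement(item)
```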
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
@@ -127,6 +127,14 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be recalculated across different classes.
"""
raise NotImplementedError

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
for descrpt in self.descrpt_list:
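The `share_params` contract above is only declared in the dpmodel backend (hence the `NotImplementedError`); during multi-task setup, each descriptor that shares parameters would be linked back to a base instance, presumably by the PyTorch descriptors. A rough usage sketch with hypothetical model keys and accessor names:

```python
# Hypothetical multi-task wiring: "model_1" acts as the base and the other
# models link their descriptor parameters to it. shared_level and resume
# follow the docstring above; get_descriptor() is an assumed accessor name.
def link_shared_descriptors(models: dict, base_key: str = "model_1", resume: bool = False) -> None:
    base_descriptor = models[base_key].get_descriptor()
    for key, model in models.items():
        if key != base_key:
            model.get_descriptor().share_params(base_descriptor, shared_level=0, resume=resume)
```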
15 changes: 14 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -4,8 +4,10 @@
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Union,
)

from deepmd.common import (
@@ -84,8 +86,19 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be recalculated across different classes.
"""
pass

def compute_input_stats(
self, merged: List[dict], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError
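The widened `merged` type above accepts either an already-materialized list of stat batches or a zero-argument callable that produces one, presumably so that the (potentially expensive) sampling can be deferred or skipped when statistics are restored from a stat file. A small sketch of how an implementation can treat both forms uniformly:

```python
from typing import Callable, List, Union

# Minimal stand-in: resolve `merged` to a list of stat batches regardless of
# which form the caller passed. Real descriptors would also consult `path`
# before touching the data at all.
def resolve_merged(merged: Union[Callable[[], List[dict]], List[dict]]) -> List[dict]:
    return merged() if callable(merged) else merged

eager = resolve_merged([{"coord": None}])          # already-sampled batches
lazy = resolve_merged(lambda: [{"coord": None}])   # sampling deferred until here
```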
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
@@ -243,6 +243,14 @@ def mixed_types(self):
"""
return False

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be recalculated across different classes.
"""
raise NotImplementedError

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
@@ -203,6 +203,14 @@ def mixed_types(self):
"""
return False

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be recalculated across different classes.
"""
raise NotImplementedError

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
1 change: 1 addition & 0 deletions deepmd/dpmodel/model/dp_model.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
70 changes: 23 additions & 47 deletions deepmd/pt/entrypoints/main.py
@@ -53,9 +53,6 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.utils.argcheck import (
normalize,
)
@@ -104,36 +101,23 @@ def get_trainer(
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)

def prepare_trainer_input_single(
model_params_single, data_dict_single, loss_dict_single, suffix=""
model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0
):
training_dataset_params = data_dict_single["training_data"]
type_split = False
if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
type_split = True
validation_dataset_params = data_dict_single["validation_data"]
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
validation_systems = validation_dataset_params["systems"]

# noise params
noise_settings = None
if loss_dict_single.get("type", "ener") == "denoise":
noise_settings = {
"noise_type": loss_dict_single.pop("noise_type", "uniform"),
"noise": loss_dict_single.pop("noise", 1.0),
"noise_mode": loss_dict_single.pop("noise_mode", "fix_num"),
"mask_num": loss_dict_single.pop("mask_num", 8),
"mask_prob": loss_dict_single.pop("mask_prob", 0.15),
"same_mask": loss_dict_single.pop("same_mask", False),
"mask_coord": loss_dict_single.pop("mask_coord", False),
"mask_type": loss_dict_single.pop("mask_type", False),
"max_fail_num": loss_dict_single.pop("max_fail_num", 10),
"mask_type_idx": len(model_params_single["type_map"]) - 1,
}
# noise_settings = None

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
if stat_file_path_single is not None:
if rank != 0:
stat_file_path_single = None
elif stat_file_path_single is not None:
if Path(stat_file_path_single).is_dir():
raise ValueError(
f"stat_file should be a file, not a directory: {stat_file_path_single}"
@@ -144,71 +128,63 @@ def prepare_trainer_input_single(
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
validation_data_single = DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single,
validation_data_single = (
DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single,
)
if validation_systems
else None
)
if ckpt or finetune_model:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
sampled_single = None
else:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10)
sampled_single = make_stat_input(
train_data_single.systems,
train_data_single.dataloaders,
data_stat_nbatch,
)
if noise_settings is not None:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
return (
train_data_single,
validation_data_single,
sampled_single,
stat_file_path_single,
)

rank = dist.get_rank() if dist.is_initialized() else 0
if not multi_task:
(
train_data,
validation_data,
sampled,
stat_file_path,
) = prepare_trainer_input_single(
config["model"], config["training"], config["loss"]
config["model"],
config["training"],
config["loss"],
rank=rank,
)
else:
train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
train_data, validation_data, stat_file_path = {}, {}, {}
for model_key in config["model"]["model_dict"]:
(
train_data[model_key],
validation_data[model_key],
sampled[model_key],
stat_file_path[model_key],
) = prepare_trainer_input_single(
config["model"]["model_dict"][model_key],
config["training"]["data_dict"][model_key],
config["loss_dict"][model_key],
suffix=f"_{model_key}",
rank=rank,
)

trainer = training.Trainer(
config,
train_data,
sampled=sampled,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
107 changes: 106 additions & 1 deletion deepmd/pt/loss/ener.py
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

import torch
import torch.nn.functional as F

@@ -11,6 +15,9 @@
from deepmd.pt.utils.env import (
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class EnergyStdLoss(TaskLoss):
@@ -23,16 +30,57 @@ def __init__(
limit_pref_f=0.0,
start_pref_v=0.0,
limit_pref_v=0.0,
start_pref_ae: float = 0.0,
limit_pref_ae: float = 0.0,
start_pref_pf: float = 0.0,
limit_pref_pf: float = 0.0,
use_l1_all: bool = False,
inference=False,
**kwargs,
):
"""Construct a layer to compute loss on energy, force and virial."""
r"""Construct a layer to compute loss on energy, force and virial.
Parameters
----------
starter_learning_rate : float
The learning rate at the start of the training.
start_pref_e : float
The prefactor of energy loss at the start of the training.
limit_pref_e : float
The prefactor of energy loss at the end of the training.
start_pref_f : float
The prefactor of force loss at the start of the training.
limit_pref_f : float
The prefactor of force loss at the end of the training.
start_pref_v : float
The prefactor of virial loss at the start of the training.
limit_pref_v : float
The prefactor of virial loss at the end of the training.
start_pref_ae : float
The prefactor of atomic energy loss at the start of the training.
limit_pref_ae : float
The prefactor of atomic energy loss at the end of the training.
start_pref_pf : float
The prefactor of atomic prefactor force loss at the start of the training.
limit_pref_pf : float
The prefactor of atomic prefactor force loss at the end of the training.
use_l1_all : bool
Whether to use L1 loss, if False (default), it will use L2 loss.
inference : bool
If True, it will output all losses found in output, ignoring the prefactors.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.starter_learning_rate = starter_learning_rate
self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference
self.has_f = (start_pref_f != 0.0 and limit_pref_f != 0.0) or inference
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference

# TODO need support for atomic energy and atomic pref
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference

self.start_pref_e = start_pref_e
self.limit_pref_e = limit_pref_e
self.start_pref_f = start_pref_f
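For reference, the start/limit prefactors documented in this constructor are, in DeePMD-kit's usual scheme, interpolated linearly in the current learning rate; this hunk does not show that code, so take the exact form as an assumption:

```python
# Assumed DeePMD-style prefactor schedule: each prefactor moves from its start
# value toward its limit value as the learning rate decays.
def scale_pref(start_pref: float, limit_pref: float, lr: float, starter_lr: float) -> float:
    coef = lr / starter_lr
    return limit_pref + (start_pref - limit_pref) * coef

print(scale_pref(1000.0, 1.0, 1e-3, 1e-3))  # at the start (lr == starter_lr): 1000.0
print(scale_pref(1000.0, 1.0, 1e-5, 1e-3))  # late in training, close to the limit: 10.99
```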
@@ -153,3 +201,60 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_e:
label_requirement.append(
DataRequirementItem(
"energy",
ndof=1,
atomic=False,
must=False,
high_prec=True,
)
)
if self.has_f:
label_requirement.append(
DataRequirementItem(
"force",
ndof=3,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_v:
label_requirement.append(
DataRequirementItem(
"virial",
ndof=9,
atomic=False,
must=False,
high_prec=False,
)
)
if self.has_ae:
label_requirement.append(
DataRequirementItem(
"atom_ener",
ndof=1,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_pf:
label_requirement.append(
DataRequirementItem(
"atom_pref",
ndof=1,
atomic=True,
must=False,
high_prec=False,
repeat=3,
)
)
return label_requirement
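A usage sketch for the new property: instantiate the loss and inspect which labels it will request from the dataloader (attribute access on `DataRequirementItem` is assumed here):

```python
# Sketch: ask an EnergyStdLoss which labels the dataloader should provide.
from deepmd.pt.loss.ener import EnergyStdLoss

loss = EnergyStdLoss(
    starter_learning_rate=1e-3,
    start_pref_e=0.02,
    limit_pref_e=1.0,
    start_pref_f=1000.0,
    limit_pref_f=1.0,
)
for item in loss.label_requirement:
    # expected: "energy" and "force", since only their prefactors are non-zero
    print(item.key, item.ndof, item.atomic)
```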
20 changes: 19 additions & 1 deletion deepmd/pt/loss/loss.py
@@ -1,12 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from typing import (
List,
)

import torch

from deepmd.utils.data import (
DataRequirementItem,
)


class TaskLoss(torch.nn.Module):
class TaskLoss(torch.nn.Module, ABC):
def __init__(self, **kwargs):
"""Construct loss."""
super().__init__()

def forward(self, model_pred, label, natoms, learning_rate):
"""Return loss ."""
raise NotImplementedError

@property
@abstractmethod
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
pass
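Since `label_requirement` is now abstract, every `TaskLoss` subclass has to declare its labels. A toy subclass (not part of the PR) showing the minimal surface, with the `(loss, more_loss)` return convention taken from `EnergyStdLoss` above:

```python
from typing import List

import torch

from deepmd.pt.loss.loss import TaskLoss
from deepmd.utils.data import DataRequirementItem


class DummyEnergyLoss(TaskLoss):
    """Toy loss: plain MSE on the per-frame energy label."""

    def forward(self, model_pred, label, natoms, learning_rate):
        loss = torch.nn.functional.mse_loss(model_pred["energy"], label["energy"])
        more_loss = {"rmse_e": torch.sqrt(loss.detach())}
        return loss, more_loss

    @property
    def label_requirement(self) -> List[DataRequirementItem]:
        return [
            DataRequirementItem("energy", ndof=1, atomic=False, must=True, high_prec=True),
        ]
```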