Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breaking: pt: remove data stat from model init #3245

Merged
merged 9 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,13 @@
"""
pass

@abstractmethod
def compute_input_stats(self, merged):
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
"""Update mean and stddev for descriptor elements."""
pass
raise NotImplementedError

Check warning on line 74 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L74

Added line #L74 was not covered by tests

@abstractmethod
def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
def init_desc_stat(self, **kwargs):
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize the model bias by the statistics."""
pass
raise NotImplementedError

Check warning on line 78 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L78

Added line #L78 was not covered by tests

@abstractmethod
def fwd(
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@
else:
raise KeyError(key)

def compute_output_stats(self, merged):
"""Update the output bias for fitting net."""
raise NotImplementedError

Check warning on line 241 in deepmd/dpmodel/fitting/invar_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/invar_fitting.py#L241

Added line #L241 was not covered by tests

def init_fitting_stat(self, result_dict):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

Check warning on line 245 in deepmd/dpmodel/fitting/invar_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/invar_fitting.py#L245

Added line #L245 was not covered by tests

def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/fitting/make_base_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@
"""Calculate fitting."""
pass

def compute_output_stats(self, merged):
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
"""Update the output bias for fitting net."""
raise NotImplementedError

Check warning on line 57 in deepmd/dpmodel/fitting/make_base_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/make_base_fitting.py#L57

Added line #L57 was not covered by tests

def init_fitting_stat(self, **kwargs):
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize the model bias by the statistics."""
raise NotImplementedError

Check warning on line 61 in deepmd/dpmodel/fitting/make_base_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/make_base_fitting.py#L61

Added line #L61 was not covered by tests

@abstractmethod
def serialize(self) -> dict:
"""Serialize the obj to dict."""
Expand Down
75 changes: 29 additions & 46 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
from deepmd.pt.model.descriptor import (
Descriptor,
)
from deepmd.pt.model.task import (
Fitting,
)
from deepmd.pt.train import (
training,
)
Expand All @@ -60,6 +63,7 @@
)
from deepmd.pt.utils.stat import (
make_stat_input,
process_stat_path,
)
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter

Expand Down Expand Up @@ -125,51 +129,18 @@

# stat files
hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid"
has_stat_file_path = True
if not hybrid_descrpt:
### this design requires "rcut", "rcut_smth" and "sel" in the descriptor
### VERY BAD DESIGN!!!!
### not all descriptors provides these parameter in their constructor
default_stat_file_name = Descriptor.get_stat_name(
model_params_single["descriptor"]
)
model_params_single["stat_file_dir"] = data_dict_single.get(
"stat_file_dir", f"stat_files{suffix}"
)
model_params_single["stat_file"] = data_dict_single.get(
"stat_file", default_stat_file_name
)
model_params_single["stat_file_path"] = os.path.join(
model_params_single["stat_file_dir"], model_params_single["stat_file"]
)
if not os.path.exists(model_params_single["stat_file_path"]):
has_stat_file_path = False
else: ### need to remove this
default_stat_file_name = []
for descrpt in model_params_single["descriptor"]["list"]:
default_stat_file_name.append(
f'stat_file_rcut{descrpt["rcut"]:.2f}_'
f'smth{descrpt["rcut_smth"]:.2f}_'
f'sel{descrpt["sel"]}_{descrpt["type"]}.npz'
)
model_params_single["stat_file_dir"] = data_dict_single.get(
"stat_file_dir", f"stat_files{suffix}"
stat_file_path_single, has_stat_file_path = process_stat_path(
data_dict_single.get("stat_file", None),
data_dict_single.get("stat_file_dir", f"stat_files{suffix}"),
model_params_single,
Descriptor,
Fitting,
)
model_params_single["stat_file"] = data_dict_single.get(
"stat_file", default_stat_file_name
else: ### TODO hybrid descriptor not implemented
raise NotImplementedError(

Check warning on line 141 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L141

Added line #L141 was not covered by tests
"data stat for hybrid descriptor is not implemented!"
)
assert isinstance(
model_params_single["stat_file"], list
), "Stat file of hybrid descriptor must be a list!"
stat_file_path = []
for stat_file_path_item in model_params_single["stat_file"]:
single_file_path = os.path.join(
model_params_single["stat_file_dir"], stat_file_path_item
)
stat_file_path.append(single_file_path)
if not os.path.exists(single_file_path):
has_stat_file_path = False
model_params_single["stat_file_path"] = stat_file_path

# validation and training data
validation_data_single = DpLoaderSet(
Expand Down Expand Up @@ -209,19 +180,30 @@
type_split=type_split,
noise_settings=noise_settings,
)
return train_data_single, validation_data_single, sampled_single
return (
train_data_single,
validation_data_single,
sampled_single,
stat_file_path_single,
)

if not multi_task:
train_data, validation_data, sampled = prepare_trainer_input_single(
(
train_data,
validation_data,
sampled,
stat_file_path,
) = prepare_trainer_input_single(
config["model"], config["training"], config["loss"]
)
else:
train_data, validation_data, sampled = {}, {}, {}
train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}

Check warning on line 200 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L200

Added line #L200 was not covered by tests
for model_key in config["model"]["model_dict"]:
(
train_data[model_key],
validation_data[model_key],
sampled[model_key],
stat_file_path[model_key],
) = prepare_trainer_input_single(
config["model"]["model_dict"][model_key],
config["training"]["data_dict"][model_key],
Expand All @@ -232,7 +214,8 @@
trainer = training.Trainer(
config,
train_data,
sampled,
sampled=sampled,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
restart_model=restart_model,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
assert not self.multi_task, "multitask mode currently not supported!"
self.type_split = self.input_param["descriptor"]["type"] in ["se_e2_a"]
self.type_map = self.input_param["type_map"]
self.dp = ModelWrapper(get_model(self.input_param, None).to(DEVICE))
self.dp = ModelWrapper(get_model(self.input_param).to(DEVICE))
self.dp.load_state_dict(state_dict)
self.rcut = self.dp.model["Default"].descriptor.get_rcut()
self.sec = np.cumsum(self.dp.model["Default"].descriptor.get_sel())
Expand Down
150 changes: 134 additions & 16 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from abc import (
ABC,
abstractmethod,
Expand All @@ -7,6 +8,7 @@
Callable,
List,
Optional,
Union,
)

import numpy as np
Expand All @@ -23,6 +25,8 @@
BaseDescriptor,
)

log = logging.getLogger(__name__)


class Descriptor(torch.nn.Module, BaseDescriptor):
"""The descriptor.
Expand Down Expand Up @@ -56,15 +60,130 @@
return Descriptor.__plugins.register(key)

@classmethod
def get_stat_name(cls, config):
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)
def get_stat_name(cls, ntypes, type_name, **kwargs):
"""
Get the name for the statistic file of the descriptor.
Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
"""
if cls is not Descriptor:
raise NotImplementedError("get_stat_name is not implemented!")

Check warning on line 69 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L69

Added line #L69 was not covered by tests
descrpt_type = type_name
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(
ntypes, type_name, **kwargs
)

@classmethod
def get_data_process_key(cls, config):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
"""
Get the keys for the data preprocess.
Usually need the information of rcut and sel.
TODO Need to be deprecated when the dataloader has been cleaned up.
"""
if cls is not Descriptor:
raise NotImplementedError("get_data_process_key is not implemented!")

Check warning on line 83 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L83

Added line #L83 was not covered by tests
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)

@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the descriptor.
Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
"""
raise NotImplementedError("data_stat_key is not implemented!")

Check warning on line 93 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L93

Added line #L93 was not covered by tests

def compute_or_load_stat(
self,
type_map: List[str],
sampled=None,
stat_file_path: Optional[Union[str, List[str]]] = None,
):
"""
Compute or load the statistics parameters of the descriptor.
Calculate and save the mean and standard deviation of the descriptor to `stat_file_path`
if `sampled` is not None, otherwise load them from `stat_file_path`.

Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
sampled
The sampled data frames from different data systems.
stat_file_path
The path to the statistics files.
"""
# TODO support hybrid descriptor
descrpt_stat_key = self.data_stat_key
if sampled is not None: # compute the statistics results
tmp_dict = self.compute_input_stats(sampled)
result_dict = {key: tmp_dict[key] for key in descrpt_stat_key}
result_dict["type_map"] = type_map
if stat_file_path is not None:
self.save_stats(result_dict, stat_file_path)
else: # load the statistics results
assert stat_file_path is not None, "No stat file to load!"
result_dict = self.load_stats(type_map, stat_file_path)
self.init_desc_stat(**result_dict)

def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]):
"""
Save the statistics results to `stat_file_path`.

Parameters
----------
result_dict
The dictionary of statistics results.
stat_file_path
The path to the statistics file(s).
"""
if not isinstance(stat_file_path, list):
log.info(f"Saving stat file to {stat_file_path}")
np.savez_compressed(stat_file_path, **result_dict)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
else: # TODO hybrid descriptor not implemented
raise NotImplementedError(

Check warning on line 144 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L144

Added line #L144 was not covered by tests
"save_stats for hybrid descriptor is not implemented!"
)

def load_stats(self, type_map, stat_file_path: Union[str, List[str]]):
"""
Load the statistics results to `stat_file_path`.

Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
stat_file_path
The path to the statistics file(s).

Returns
-------
result_dict
The dictionary of statistics results.
"""
descrpt_stat_key = self.data_stat_key
target_type_map = type_map
if not isinstance(stat_file_path, list):
log.info(f"Loading stat file from {stat_file_path}")
stats = np.load(stat_file_path)
stat_type_map = list(stats["type_map"])
missing_type = [i for i in target_type_map if i not in stat_type_map]
assert not missing_type, (
f"These type are not in stat file {stat_file_path}: {missing_type}! "
f"Please change the stat file path!"
)
idx_map = [stat_type_map.index(i) for i in target_type_map]
if stats[descrpt_stat_key[0]].size: # not empty
result_dict = {key: stats[key][idx_map] for key in descrpt_stat_key}
else:
result_dict = {key: [] for key in descrpt_stat_key}

Check warning on line 180 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L180

Added line #L180 was not covered by tests
else: # TODO hybrid descriptor not implemented
raise NotImplementedError(

Check warning on line 182 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L182

Added line #L182 was not covered by tests
"load_stats for hybrid descriptor is not implemented!"
)
return result_dict

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
Expand Down Expand Up @@ -156,15 +275,13 @@
"""Returns the embedding dimension."""
pass

@abstractmethod
def compute_input_stats(self, merged):
"""Update mean and stddev for DescriptorBlock elements."""
pass
raise NotImplementedError

Check warning on line 280 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L280

Added line #L280 was not covered by tests

@abstractmethod
def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
"""Initialize the model bias by the statistics."""
pass
def init_desc_stat(self, **kwargs):
"""Initialize mean and stddev by the statistics."""
raise NotImplementedError

Check warning on line 284 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L284

Added line #L284 was not covered by tests

def share_params(self, base_class, shared_level, resume=False):
assert (
Expand All @@ -188,13 +305,14 @@
self.sumr2,
self.suma2,
)
base_class.init_desc_stat(
sumr_base + sumr,
suma_base + suma,
sumn_base + sumn,
sumr2_base + sumr2,
suma2_base + suma2,
)
stat_dict = {

Check warning on line 308 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L308

Added line #L308 was not covered by tests
"sumr": sumr_base + sumr,
"suma": suma_base + suma,
"sumn": sumn_base + sumn,
"sumr2": sumr2_base + sumr2,
"suma2": suma2_base + suma2,
}
base_class.init_desc_stat(**stat_dict)

Check warning on line 315 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L315

Added line #L315 was not covered by tests
self.mean = base_class.mean
self.stddev = base_class.stddev
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
Expand Down
Loading