Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breaking: pt: remove data stat from model init #3233

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
PRECISION_DICT,
NativeOP,
)
from .descriptor import (
DescrptSeA,
)
from .model import (
DPAtomicModel,
DPModel,
Expand All @@ -17,13 +20,26 @@
get_reduce_name,
model_check_output,
)
from .utils import (
EmbeddingNet,
EnvMat,
FittingNet,
NativeLayer,
NativeNet,
)

__all__ = [
"DPModel",
"DPAtomicModel",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"NativeOP",
"EnvMat",
"NativeLayer",
"NativeNet",
"EmbeddingNet",
"FittingNet",
"DescrptSeA",
Comment on lines +37 to +42
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason for exporting all of these classes here?

"ModelOutputDef",
"FittingOutputDef",
"OutputVariableDef",
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
assert not self.multi_task, "multitask mode currently not supported!"
self.type_split = self.input_param["descriptor"]["type"] in ["se_e2_a"]
self.type_map = self.input_param["type_map"]
self.dp = ModelWrapper(get_model(self.input_param, None).to(DEVICE))
self.dp = ModelWrapper(get_model(self.input_param).to(DEVICE))
self.dp.load_state_dict(state_dict)
self.rcut = self.dp.model["Default"].descriptor.get_rcut()
self.sec = np.cumsum(self.dp.model["Default"].descriptor.get_sel())
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,19 @@

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)

@classmethod
def get_data_stat_key(cls, config):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not a good idea to pass a dict at the interface. Clearly state what the method needs.

"""Get the keys for the data statistic of the descriptor."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One would not understand the method from such a docstring.

descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_stat_key(config)

Check warning on line 68 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L67-L68

Added lines #L67 - L68 were not covered by tests

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One would not understand the method from such a docstring.

descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)
Comment on lines 58 to 74
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

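If a subclass doesn't implement these methods, the program will get stuck: the base-class method dispatches to the plugin class's method of the same name, so an un-overridden method ends up calling itself recursively!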

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest raising an error with a clear message for this case:

Suggested change
@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)
@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_stat_key(config)
@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)
@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
if cls is not Descriptor:
raise NotImplementedError("get_stat_name is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)
@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
if cls is not Descriptor:
raise NotImplementedError("get_data_stat_key is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_stat_key(config)
@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
if cls is not Descriptor:
raise NotImplementedError("get_data_process_key is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)


Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,19 @@

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa1", "se_atten"]
return f'stat_file_dpa1_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz'

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

Check warning on line 138 in deepmd/pt/model/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa1.py#L138

Added line #L138 was not covered by tests

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa1", "se_atten"]
return {"sel": config["sel"], "rcut": config["rcut"]}
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,22 @@

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa2"]
return (
f'stat_file_dpa2_repinit_rcut{config["repinit_rcut"]:.2f}_smth{config["repinit_rcut_smth"]:.2f}_sel{config["repinit_nsel"]}'
f'_repformer_rcut{config["repformer_rcut"]:.2f}_smth{config["repformer_rcut_smth"]:.2f}_sel{config["repformer_nsel"]}.npz'
)

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

Check warning on line 330 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L330

Added line #L330 was not covered by tests

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa2"]
return {
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,19 @@

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["se_e2_a"]
return f'stat_file_sea_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz'

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

Check warning on line 124 in deepmd/pt/model/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_a.py#L124

Added line #L124 was not covered by tests

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["se_e2_a"]
return {"sel": config["sel"], "rcut": config["rcut"]}
Expand Down
13 changes: 2 additions & 11 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


def get_model(model_params, sampled=None):
def get_model(model_params):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
# descriptor
Expand All @@ -35,16 +35,7 @@ def get_model(model_params, sampled=None):
fitting_net["return_energy"] = True
fitting = Fitting(**fitting_net)

return EnergyModel(
descriptor,
fitting,
type_map=model_params["type_map"],
type_embedding=model_params.get("type_embedding", None),
resuming=model_params.get("resuming", False),
stat_file_dir=model_params.get("stat_file_dir", None),
stat_file_path=model_params.get("stat_file_path", None),
sampled=sampled,
)
return EnergyModel(descriptor, fitting, type_map=model_params["type_map"])


__all__ = [
Expand Down
43 changes: 2 additions & 41 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,9 @@ class DPAtomicModel(BaseModel, BaseAtomicModel):
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
type_embedding
Type embedding net
resuming
Whether to resume/fine-tune from checkpoint or not.
stat_file_dir
The directory to the state files.
stat_file_path
The path to the state files.
sampled
Sampled frames to compute the statistics.
"""

# I am enough with the shit interface!
def __init__(
self,
descriptor,
fitting,
type_map: Optional[List[str]],
type_embedding: Optional[dict] = None,
resuming: bool = False,
stat_file_dir=None,
stat_file_path=None,
sampled=None,
**kwargs,
):
def __init__(self, descriptor, fitting, type_map: Optional[List[str]]):
super().__init__()
ntypes = len(type_map)
self.type_map = type_map
Expand All @@ -72,17 +50,6 @@ def __init__(
self.rcut = self.descriptor.get_rcut()
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
# Statistics
fitting_net = None # TODO: hack!!! not sure if it is correct.
self.compute_or_load_stat(
fitting_net,
ntypes,
resuming=resuming,
type_map=type_map,
stat_file_dir=stat_file_dir,
stat_file_path=stat_file_path,
sampled=sampled,
)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -122,13 +89,7 @@ def deserialize(cls, data) -> "DPAtomicModel":
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
# TODO: dirty hack to provide type_map and avoid data stat!!!
obj = cls(
descriptor_obj,
fitting_obj,
type_map=data["type_map"],
resuming=True,
)
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj

def forward_atomic(
Expand Down
Loading
Loading