Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact: the DPA2 descriptor #3758

Merged
merged 42 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
7d82945
feat: Support `stripped_type_embedding` in PT/DP
iProzd Apr 25, 2024
a230198
Update train-se-atten.md
iProzd Apr 25, 2024
5157781
Update graph.py
iProzd Apr 25, 2024
f780d58
Update deepmd/utils/argcheck.py
iProzd Apr 25, 2024
3b3d25e
Update deepmd/pt/model/descriptor/se_atten.py
iProzd Apr 25, 2024
cf841f2
Update deepmd/tf/descriptor/se_a.py
iProzd Apr 25, 2024
a9e24d9
Update deepmd/tf/descriptor/se_a.py
iProzd Apr 25, 2024
764cab7
Update deepmd/tf/descriptor/se_atten.py
iProzd Apr 25, 2024
0b9cea1
Update deepmd/tf/descriptor/se_atten.py
iProzd Apr 25, 2024
30a594a
Merge branch 'devel' into add_strip_dpa1
iProzd Apr 25, 2024
1e86b75
Update docs
iProzd Apr 25, 2024
f3056ee
resolve conversations
iProzd Apr 26, 2024
4e231e4
rf dpa2 with identity implement
iProzd Apr 28, 2024
b7af498
Update test_dpa2.py
iProzd Apr 30, 2024
61d9794
Add residual support
iProzd May 7, 2024
0e4fe1c
rm bn support
iProzd May 7, 2024
7a1095c
Add numpy impl for DPA2
iProzd May 7, 2024
0924505
Merge branch 'devel' into rf_dpa2_consist
iProzd May 7, 2024
b19a0e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
c952798
update argcheck
iProzd May 7, 2024
bd1d5d9
Fix uts
iProzd May 8, 2024
0ebadc2
Update test_permutation.py
iProzd May 8, 2024
fe6ed6e
fix uts
iProzd May 8, 2024
104527d
Update test_dpa2.py
iProzd May 8, 2024
cbba7a7
Merge branch 'devel' into rf_dpa2_consist
iProzd May 9, 2024
e1270bd
Update argcheck.py
iProzd May 9, 2024
ceaaa07
Update se_atten.py
iProzd May 9, 2024
d2bcdbf
Fix typo
iProzd May 9, 2024
385e1f7
revert 'nf' to 'nb'
iProzd May 9, 2024
d1e38ad
Update repformers.py
iProzd May 9, 2024
9d0ad7f
Update repformers.py
iProzd May 9, 2024
375c03e
mv symmetrization_op into static
iProzd May 9, 2024
2f280e6
Merge branch 'devel' into rf_dpa2_consist
iProzd May 9, 2024
d85eef0
Update test_descriptor_dpa2.py
iProzd May 9, 2024
244c8e5
Update dpa2.md
iProzd May 9, 2024
e9fe376
separate args for repinit and repformers
iProzd May 9, 2024
515c534
Update repformer_layer.py
iProzd May 9, 2024
bd25aa6
Update repformer_layer.py
iProzd May 9, 2024
f17f40f
Update repformer_layer.py
iProzd May 9, 2024
a8c89dc
Update repformer_layer.py
iProzd May 9, 2024
e329e30
Update repformer_layer.py
iProzd May 9, 2024
e223336
Merge branch 'devel' into rf_dpa2_consist
iProzd May 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .dpa1 import (
DescrptDPA1,
)
from .dpa2 import (
DescrptDPA2,
)
from .hybrid import (
DescrptHybrid,
)
Expand All @@ -19,6 +22,7 @@
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptDPA2",
"DescrptHybrid",
"make_base_descriptor",
]
127 changes: 127 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from abc import (
ABC,
abstractmethod,
)
from typing import (
Callable,
Dict,
List,
Optional,
Union,
)

import numpy as np

from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)

log = logging.getLogger(__name__)


class DescriptorBlock(ABC, make_plugin_registry("DescriptorBlock")):
"""The building block of descriptor.
Given the input descriptor, provide with the atomic coordinates,
atomic types and neighbor list, calculate the new descriptor.
"""

local_cluster = False

def __new__(cls, *args, **kwargs):
if cls is DescriptorBlock:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of DescriptorBlock should be set by `type`")
cls = cls.get_class_by_type(descrpt_type)
return super().__new__(cls)

@abstractmethod
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
pass

@abstractmethod
def get_nsel(self) -> int:
"""Returns the number of selected atoms in the cut-off radius."""
pass

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
pass

@abstractmethod
def get_ntypes(self) -> int:
"""Returns the number of element types."""
pass

@abstractmethod
def get_dim_out(self) -> int:
"""Returns the output dimension."""
pass

@abstractmethod
def get_dim_in(self) -> int:
"""Returns the input dimension."""
pass

@abstractmethod
def get_dim_emb(self) -> int:
"""Returns the embedding dimension."""
pass

def compute_input_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
path : Optional[DPPath]
The path to the stat file.

"""
raise NotImplementedError

def get_stats(self) -> Dict[str, StatItem]:
"""Get the statistics of the descriptor."""
raise NotImplementedError

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes.
"""
raise NotImplementedError

@abstractmethod
def call(
self,
nlist: np.ndarray,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
extended_atype_embd: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
):
"""Calculate DescriptorBlock."""
pass
Loading
Loading