Skip to content

Commit

Permalink
refact: the DPA2 descriptor (#3758)
Browse files Browse the repository at this point in the history
- Refact the DPA2 descriptor in PyTorch with clearer interface
- Support residual
- Remove bn
- Add numpy implement

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added a new descriptor class `DescrptDPA2` implementing DPA-2
functionality for computing descriptors and representations based on
input coordinates and atom types.
- Expanded supported backends for DPA-2 descriptor to include DP in
addition to PyTorch.

- **Documentation**
- Updated the supported backends information in the documentation for
the DPA-2 descriptor to reflect the addition of DP backend support.
- Added a reference to the model implementation and a training example
link in the DPA-2 descriptor documentation.

- **Tests**
- Introduced test cases for the `DescrptDPA2` class in different
frameworks like TensorFlow, PyTorch, and DeepMD to cover various
parameters and configurations.
- Validated the functionality of the `DescrptDPA2` descriptor class for
deep learning models in the test case class `TestDescrptDPA2`.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored May 10, 2024
1 parent 7ab3040 commit 74dce7f
Show file tree
Hide file tree
Showing 35 changed files with 6,165 additions and 828 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .dpa1 import (
DescrptDPA1,
)
from .dpa2 import (
DescrptDPA2,
)
from .hybrid import (
DescrptHybrid,
)
Expand All @@ -19,6 +22,7 @@
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptDPA2",
"DescrptHybrid",
"make_base_descriptor",
]
127 changes: 127 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from abc import (
ABC,
abstractmethod,
)
from typing import (
Callable,
Dict,
List,
Optional,
Union,
)

import numpy as np

from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)

log = logging.getLogger(__name__)


class DescriptorBlock(ABC, make_plugin_registry("DescriptorBlock")):
"""The building block of descriptor.
Given the input descriptor, provide with the atomic coordinates,
atomic types and neighbor list, calculate the new descriptor.
"""

local_cluster = False

def __new__(cls, *args, **kwargs):
if cls is DescriptorBlock:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of DescriptorBlock should be set by `type`")
cls = cls.get_class_by_type(descrpt_type)
return super().__new__(cls)

@abstractmethod
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
pass

@abstractmethod
def get_nsel(self) -> int:
"""Returns the number of selected atoms in the cut-off radius."""
pass

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
pass

@abstractmethod
def get_ntypes(self) -> int:
"""Returns the number of element types."""
pass

@abstractmethod
def get_dim_out(self) -> int:
"""Returns the output dimension."""
pass

@abstractmethod
def get_dim_in(self) -> int:
"""Returns the input dimension."""
pass

@abstractmethod
def get_dim_emb(self) -> int:
"""Returns the embedding dimension."""
pass

def compute_input_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
path : Optional[DPPath]
The path to the stat file.
"""
raise NotImplementedError

def get_stats(self) -> Dict[str, StatItem]:
"""Get the statistics of the descriptor."""
raise NotImplementedError

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes.
"""
raise NotImplementedError

@abstractmethod
def call(
self,
nlist: np.ndarray,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
extended_atype_embd: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
):
"""Calculate DescriptorBlock."""
pass
Loading

0 comments on commit 74dce7f

Please sign in to comment.