-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
pd: support dpa2 #4418
pd: support dpa2 #4418
Conversation
for more information, see https://pre-commit.ci
…y to coverage newly added code
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (2)
source/tests/pd/model/test_descriptor_dpa2.py (2)
24-118
: Add docstring and improve test data organization.While the test setup is thorough, consider these improvements for better maintainability:
- Add a docstring explaining the test setup and the meaning of the test data
- Consider moving the hardcoded test values to separate data files
- Add comments explaining the significance of the reference values
Example docstring:
class TestDPA2(unittest.TestCase): def setUp(self): + """Initialize test data for DPA2 descriptor testing. + + Sets up: + - Simulation cell parameters + - Atomic coordinates + - Atom types + - Reference descriptor values + - Paths to model parameters and embeddings + """
119-186
: Enhance test coverage and improve test organization.The test covers the happy path but could be improved:
- Split into separate test methods for different configurations
- Add error cases and edge cases
- Add input validation tests
Consider refactoring like this:
def test_descriptor_basic_config(self): """Test DPA2 descriptor with basic configuration.""" # Current test logic with concat_output_tebd=False def test_descriptor_with_concat_output(self): """Test DPA2 descriptor with concatenated output.""" # Current test logic with concat_output_tebd=True def test_descriptor_invalid_inputs(self): """Test DPA2 descriptor with invalid inputs.""" # Add tests for invalid coordinates, types, etc. def test_descriptor_edge_cases(self): """Test DPA2 descriptor with edge cases.""" # Add tests for empty system, single atom, etc.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/tests/pd/model/test_descriptor_dpa2.py
(1 hunks)
🔇 Additional comments (1)
source/tests/pd/model/test_descriptor_dpa2.py (1)
1-23
: LGTM! Well-organized imports following best practices.
The imports are properly organized (standard library → third-party → local) and use modern Python practices like Path for file operations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
source/tests/pd/test_multitask.py
(2 hunks)source/tests/pd/test_training.py
(2 hunks)
🔇 Additional comments (3)
source/tests/pd/test_training.py (2)
27-27
: LGTM!
The import statement follows the established pattern and is properly placed with other model imports.
199-213
: LGTM! Verify test data availability.
The test class implementation follows the established pattern and maintains consistency with other test classes. The implementation looks correct and well-structured.
Let's verify the test data availability:
#!/bin/bash
# Description: Verify that the required test data exists
# Test data path: water/data/data_0
# Check if the test data directory exists
if [ -d "source/tests/pd/water/data/data_0" ]; then
echo "Test data directory exists"
# List contents to verify structure
ls -la source/tests/pd/water/data/data_0
else
echo "Warning: Test data directory not found"
fi
# Check if the configuration file exists
if [ -f "source/tests/pd/water/se_atten.json" ]; then
echo "Configuration file exists"
else
echo "Warning: Configuration file not found"
fi
source/tests/pd/test_multitask.py (1)
33-33
: LGTM!
The import statement follows the established pattern and is properly placed with other model imports.
repinit_variable.pop("embeddings_strip") | ||
) | ||
obj.repinit["davg"] = t_cvt(statistic_repinit["davg"]) | ||
obj.repinit["dstd"] = t_cvt(statistic_repinit["dstd"]) |
Check warning
Code scanning / CodeQL
Variable defined multiple times Warning
redefined
This assignment to 'env_mat' is unnecessary as it is
redefined
NetworkCollection.deserialize( | ||
repinit_three_body_variable.pop("embeddings_strip") | ||
) | ||
) |
Check warning
Code scanning / CodeQL
Variable defined multiple times Warning
redefined
obj.repformers.layers = paddle.nn.LayerList( | ||
[RepformerLayer.deserialize(layer) for layer in repformer_layers] | ||
) | ||
return obj |
Check notice
Code scanning / CodeQL
Unused local variable Note
# h2: nb x nloc x nnei x 3 | ||
# msk: nb x nloc x nnei | ||
nb, nloc, nnei, _ = g2.shape | ||
ng2 = g2.shape[-1] |
Check notice
Code scanning / CodeQL
Unused local variable Note
for ii in range(1, nitem): | ||
uu = uu + scale * update_list[ii] | ||
return uu | ||
|
Check notice
Code scanning / CodeQL
Unused local variable Note
[False, True], # use_econf_tebd | ||
[True], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) | ||
): | ||
dtype = PRECISION_DICT[prec] |
Check notice
Code scanning / CodeQL
Unused local variable Note test
) | ||
dd0.repformers.stddev = paddle.to_tensor(dstd_2, dtype=dtype).to( | ||
device=env.DEVICE | ||
) |
Check notice
Code scanning / CodeQL
Unused local variable Note test
def _test_unused(self, model_params): | ||
self.model = get_model(model_params).to(env.DEVICE) | ||
natoms = 5 | ||
generator = paddle.seed(GLOBAL_SEED) |
Check notice
Code scanning / CodeQL
Unused local variable Note test
coord = paddle.rand([natoms, 3], dtype=dtype).to(device=env.DEVICE) | ||
coord = paddle.matmul(coord, cell) | ||
atype = paddle.to_tensor([0, 0, 0, 1, 1]).to(env.DEVICE) | ||
idx_perm = [1, 0, 4, 3, 2] |
Check notice
Code scanning / CodeQL
Unused local variable Note test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Nitpick comments (3)
deepmd/pd/model/task/fitting.py (1)
214-215
: Improve tensor initialization and device placement.The change from using
place
argument to.to(device)
is good for consistency with PyTorch-style device placement. However, consider usingpaddle.zeros_like
for better efficiency when reinitializing tensors.- paddle.zeros(self.dim_case_embd, dtype=self.prec).to(device=device), + paddle.zeros_like(paddle.empty([self.dim_case_embd], dtype=self.prec)).to(device=device),deepmd/pd/utils/multi_task.py (2)
99-101
: LGTM! Consider adding parameter type hints.The addition of the return type annotation
-> None
is a good practice as it clearly indicates the function modifies the input in-place. Consider also adding type hints for the parameters to further improve code clarity.def replace_one_item( - params_dict, key_type, key_in_dict, suffix="", index=None + params_dict: dict, + key_type: str, + key_in_dict: str, + suffix: str = "", + index: Optional[int] = None ) -> None:
99-101
: Document sharing levels in function docstring.The function handles sophisticated parameter sharing with different sharing levels (e.g., "my_des1:1"), but this feature isn't well-documented. Consider adding a docstring explaining:
- The sharing level syntax (e.g., ":1")
- How sharing levels affect parameter sharing
- Special handling of hybrid descriptors
def replace_one_item( params_dict, key_type, key_in_dict, suffix="", index=None ) -> None: """Replace parameters with shared values based on sharing configuration. Args: params_dict: Dictionary to modify with shared parameters key_type: Type of parameter (e.g., "descriptor", "fitting_net") key_in_dict: Key in shared_dict, optionally with sharing level (e.g., "my_des1:1") suffix: Optional suffix for hybrid models index: Optional index for hybrid model list Sharing levels: - No suffix (e.g., "my_des1"): Share all parameters (level 0) - With suffix (e.g., "my_des1:1"): Share subset of parameters (level 1+) """
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
deepmd/pd/model/task/fitting.py
(1 hunks)deepmd/pd/utils/multi_task.py
(1 hunks)source/tests/pd/model/water/multitask.json
(1 hunks)source/tests/pd/model/water/multitask_sharefit.json
(2 hunks)source/tests/pd/test_multitask.py
(4 hunks)
🔇 Additional comments (5)
deepmd/pd/model/task/fitting.py (1)
214-215
: Verify case embedding usage patterns across the codebase.
Let's verify how case embeddings are used across the codebase to ensure this change doesn't impact other components.
✅ Verification successful
Based on the search results, I can now generate the final response:
Change from zeros to eye initialization appears safe
The change from paddle.zeros()
to commented-out paddle.eye()[0]
for case embeddings initialization is safe because:
- The actual case embedding used during inference is overridden in the
get_case_embd
method (line 273) with:
self.case_embd = paddle.eye(self.dim_case_embd, dtype=self.prec).to(device)[case_idx]
-
The initial value is only used as a placeholder during model initialization and doesn't affect the model's behavior since it gets replaced with the proper one-hot encoded case embedding during runtime.
-
The device placement is consistent in both versions since both use the same
.to(device=device)
pattern.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for case embedding usage patterns
echo "Searching for case embedding usage..."
rg -l "case_embd" --type py
# Search for potential device placement patterns
echo "Checking device placement patterns..."
rg "\.to\(device" --type py -A 2
Length of output: 39817
source/tests/pd/model/water/multitask_sharefit.json (1)
Line range hint 94-118
: LGTM! Path updates are consistent with Paddle backend support.
The changes correctly update the data paths from PyTorch (pt) to Paddle (pd) backend.
source/tests/pd/test_multitask.py (2)
324-324
: False positive: data_file
variable is used.
The static analysis tool incorrectly flags data_file
as unused. This variable is used in the configuration setup for both training and validation data systems.
Line range hint 239-394
: LGTM! Test classes are well-structured and complete.
The new test classes (TestMultiTaskSeASharefit
, TestMultiTaskDPA2
, TestMultiTaskDPA2Tebd
) are properly implemented with:
- Correct inheritance from unittest.TestCase and MultiTaskTrainTest
- Complete setUp methods with proper configuration initialization
- Proper tearDown methods
- Consistent test patterns with existing test classes
deepmd/pd/utils/multi_task.py (1)
99-101
: Verify DPA-2 descriptor registration.
Since this PR adds DPA-2 support, we should verify that the DPA-2 descriptor is properly registered and can be handled by the parameter sharing mechanism.
✅ Verification successful
Let me gather more information to verify the DPA-2 descriptor registration and implementation.
Based on the search results, I can now provide a final response:
DPA-2 descriptor is properly integrated and registered
The codebase shows comprehensive DPA-2 descriptor implementation across multiple backends:
- Paddle implementation in
deepmd/pd/model/descriptor/dpa2.py
- PyTorch implementation in
deepmd/pt/model/descriptor/dpa2.py
- JAX implementation in
deepmd/jax/descriptor/dpa2.py
- Base implementation in
deepmd/dpmodel/descriptor/dpa2.py
The descriptor is also properly used in configuration files (*.json, *.yaml) and has corresponding test cases. The parameter sharing mechanism in multi_task.py
will work correctly with DPA-2 descriptors.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify DPA-2 descriptor registration and usage in parameter sharing
# Check if DPA-2 descriptor is registered
rg -A 5 "register_descriptor.*dpa2"
# Check for DPA-2 usage in configuration files
rg -g "*.json" -g "*.yaml" "type.*dpa2"
# Verify descriptor class implementation
ast-grep --pattern 'class DescrptDPA2(BaseDescriptor):
$$$
'
Length of output: 807
Script:
#!/bin/bash
# Check for DPA-2 descriptor registration in registry files
rg -A 5 "register.*descriptor.*dpa"
# Look for DPA-2 descriptor class definition
rg -A 10 "class.*Descrpt.*DPA2?"
# Check imports and usage of DPA-2 descriptor
rg "from.*descrpt.*dpa|import.*descrpt.*dpa"
Length of output: 17188
@njzjz This PR is ready for review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (3)
deepmd/pd/model/descriptor/dpa2.py (3)
504-504
: Rename unused loop variable 'ii'
You are not using the loop index anywhere within the loop body, so rename it to underscore or another meaningful placeholder.- for ii, descrpt in enumerate(descrpt_list): + for _, descrpt in enumerate(descrpt_list):🧰 Tools
🪛 Ruff (0.8.2)
504-504: Loop control variable
ii
not used within loop bodyRename unused
ii
to_ii
(B007)
689-689
: Remove the unused assignment to 'env_mat'
Assigning to 'env_mat' serves no purpose if the variable is never referenced afterwards, so it can be removed to reduce clutter.- env_mat = repformers_variable.pop("env_mat") + repformers_variable.pop("env_mat")🧰 Tools
🪛 Ruff (0.8.2)
689-689: Local variable
env_mat
is assigned to but never usedRemove assignment to unused variable
env_mat
(F841)
901-902
: Implementation incomplete: 'enable_compression'
The method raises NotImplementedError, indicating it’s unfinished. If needed, you can delegate this to a future enhancement or provide a stub to guide the next steps.Would you like help creating a placeholder implementation or opening a GitHub issue to track this?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pd/model/descriptor/dpa2.py
(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pd/model/descriptor/dpa2.py
90-90: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
504-504: Loop control variable ii
not used within loop body
Rename unused ii
to _ii
(B007)
689-689: Local variable env_mat
is assigned to but never used
Remove assignment to unused variable env_mat
(F841)
🔇 Additional comments (2)
deepmd/pd/model/descriptor/dpa2.py (2)
289-291
: Rename 'g1_shape_tranform' to 'g1_shape_transform'
This is a typographical error, which was also identified in a past review. Renaming it improves clarity.
Example fix:
- self.g1_shape_tranform = Identity()
+ self.g1_shape_transform = Identity()
90-90
:
Avoid mutable default arguments for 'exclude_types'
Repeated instantiations of this function share the same list object for the default argument. This could lead to surprising behavior.
Use a sentinel value like None for the default, and assign an empty list within the constructor, for example:
- exclude_types: list[tuple[int, int]] = [],
+ exclude_types: Optional[list[tuple[int, int]]] = None,
...
if exclude_types is None:
exclude_types = []
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff (0.8.2)
90-90: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
@coderabbitai resolve |
1 similar comment
@coderabbitai resolve |
@njzjz I noticed that there might be some subtle typos in the deepmd project. In Paddle, we have introduced a tool called typos in the pre-commit process, which can automatically detect potential typos. You can try it and added into deepmd if interested: PaddlePaddle/Paddle#69377. |
Support DPA-2 in paddle backend. This PR will be updated after #4414 is merged.
Training curve:
Accuracy test(left: paddle, right: torch):
Ralated optimization of Paddle framework:
paddle.where
andpaddle.where_
in eager mode PaddlePaddle/Paddle#69556Summary by CodeRabbit
New Features
DescrptDPA2
,DescrptBlockRepformers
,DescrptSeTTebd
, andDescrptBlockSeTTebd
.Bug Fixes
Documentation
Tests
DescrptDPA2
and multitask training scenarios.