-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax/array-api): hybrid descriptor #4275
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 WalkthroughWalkthroughThe pull request introduces several modifications primarily to the Changes
Possibly related PRs
Suggested labels
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (6)
source/tests/array_api_strict/descriptor/se_e2_r.py (1)
Line range hint 33-34
: Enhance the env_mat comment documentation
The current comment "env_mat doesn't store any value" could be more descriptive about why this is the case.
- # env_mat doesn't store any value
+ # env_mat is a computed property and doesn't persist any value
source/tests/array_api_strict/descriptor/se_e2_a.py (1)
Line range hint 24-36
: Add return type hint to setattr.
The method signature should include the return type hint for better type safety and code clarity.
- def __setattr__(self, name: str, value: Any) -> None:
+ def __setattr__(self, name: str, value: Any) -> None: # type: ignore
Enhance the env_mat comment.
The current comment could be more descriptive about why env_mat doesn't store any value.
- # env_mat doesn't store any value
+ # env_mat is a computed property and doesn't store any value directly
doc/model/train-hybrid.md (1)
Line range hint 1-100
: Consider adding JAX-specific usage examples.
While the documentation comprehensively covers the theory and general usage, it might be helpful to add JAX-specific examples or notes, particularly if there are any unique considerations when using the hybrid descriptor with JAX.
Would you like me to help draft JAX-specific usage examples or notes to add to the documentation?
source/tests/array_api_strict/descriptor/dpa1.py (1)
Line range hint 78-85
: Consider adding docstring explaining multiple identifiers
The class is registered with two different identifiers ("dpa1" and "se_atten"). Consider adding a docstring to explain why both identifiers exist and their intended usage.
@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
class DescrptDPA1(DescrptDPA1DP):
+ """Descriptor implementation that can be accessed via 'dpa1' or 'se_atten' identifiers.
+
+ The dual registration supports both the standard name (dpa1) and the
+ implementation-specific name (se_atten) for backward compatibility.
+ """
def __setattr__(self, name: str, value: Any) -> None:
source/tests/consistent/descriptor/test_hybrid.py (1)
152-168
: Consider using more specific return type annotations.
The implementation of the evaluation methods is correct and consistent with existing patterns. However, consider replacing Any
with a more specific return type annotation to improve type safety.
For example:
- def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
+ def eval_array_api_strict(self, array_api_strict_obj: Any) -> tuple[np.ndarray, ...]:
- def eval_jax(self, jax_obj: Any) -> Any:
+ def eval_jax(self, jax_obj: Any) -> tuple[np.ndarray, ...]:
This matches the return type used in extract_ret
method and provides better type information.
deepmd/dpmodel/descriptor/hybrid.py (1)
Line range hint 247-275
: Consider adding error handling for array operations.
While the array operations are now backend-agnostic, it would be beneficial to add error handling for potential array operation failures, especially when dealing with different backends that might have varying support for these operations.
Consider wrapping the array operations in try-except blocks:
def call(self, coord_ext, atype_ext, nlist, mapping: Optional[np.ndarray] = None):
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
out_descriptor = []
out_gr = []
out_g2 = None
out_h2 = None
out_sw = None
+ try:
if self.sel_no_mixed_types is not None:
nl_distinguish_types = nlist_distinguish_types(
nlist,
atype_ext,
self.sel_no_mixed_types,
)
else:
nl_distinguish_types = None
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx):
if self.mixed_types() == descrpt.mixed_types():
nl = xp.take(nlist, nci, axis=2)
else:
assert nl_distinguish_types is not None
nl = nl_distinguish_types[:, :, nci]
odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping)
out_descriptor.append(odescriptor)
if gr is not None:
out_gr.append(gr)
out_descriptor = xp.concat(out_descriptor, axis=-1)
out_gr = xp.concat(out_gr, axis=-2) if out_gr else None
+ except Exception as e:
+ raise RuntimeError(f"Array operation failed with backend {xp.__name__}: {str(e)}")
return out_descriptor, out_gr, out_g2, out_h2, out_sw
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (11)
- deepmd/dpmodel/descriptor/hybrid.py (6 hunks)
- deepmd/jax/descriptor/init.py (2 hunks)
- deepmd/jax/descriptor/hybrid.py (1 hunks)
- doc/model/train-hybrid.md (1 hunks)
- source/tests/array_api_strict/descriptor/init.py (1 hunks)
- source/tests/array_api_strict/descriptor/base_descriptor.py (1 hunks)
- source/tests/array_api_strict/descriptor/dpa1.py (2 hunks)
- source/tests/array_api_strict/descriptor/hybrid.py (1 hunks)
- source/tests/array_api_strict/descriptor/se_e2_a.py (1 hunks)
- source/tests/array_api_strict/descriptor/se_e2_r.py (1 hunks)
- source/tests/consistent/descriptor/test_hybrid.py (4 hunks)
✅ Files skipped from review due to trivial changes (2)
- source/tests/array_api_strict/descriptor/init.py
- source/tests/array_api_strict/descriptor/base_descriptor.py
🔇 Additional comments (22)
deepmd/jax/descriptor/__init__.py (1)
5-7
: LGTM! Clean and consistent changes.
The import statement and __all__
list update follow the established patterns in the codebase, maintaining consistency with other descriptor implementations.
Also applies to: 19-19
source/tests/array_api_strict/descriptor/hybrid.py (3)
1-14
: LGTM! Well-organized imports with proper licensing.
The file structure follows best practices with:
- Clear license header
- Properly organized imports
- Appropriate type hints
16-17
: LGTM! Proper class registration and inheritance.
The class is correctly registered as a descriptor plugin and extends the appropriate base class.
18-24
: 🛠️ Refactor suggestion
Verify error handling for array conversion and deserialization.
While the implementation is functionally correct, it might benefit from explicit error handling:
- Array conversion might fail for invalid inputs in
nlist_cut_idx
- Serialization/deserialization might fail for incompatible descriptors in
descrpt_list
Consider adding error handling:
def __setattr__(self, name: str, value: Any) -> None:
if name in {"nlist_cut_idx"}:
- value = [to_array_api_strict_array(vv) for vv in value]
+ try:
+ value = [to_array_api_strict_array(vv) for vv in value]
+ except (TypeError, ValueError) as e:
+ raise ValueError(f"Failed to convert nlist_cut_idx: {e}")
elif name in {"descrpt_list"}:
- value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value]
+ try:
+ value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value]
+ except Exception as e:
+ raise ValueError(f"Failed to process descrpt_list: {e}")
return super().__setattr__(name, value)
Let's verify the usage patterns:
deepmd/jax/descriptor/hybrid.py (3)
1-15
: LGTM! Clean and well-organized imports.
The imports are properly organized, with clear grouping and specific imports rather than wildcards. The license header is correctly included.
17-19
: LGTM! Proper class registration and decoration.
The class is correctly registered with BaseDescriptor and properly decorated as a flax module.
20-26
: Verify consistent attribute handling across JAX descriptors.
Let's ensure this attribute handling pattern is consistent with other JAX descriptors in the codebase.
✅ Verification successful
Let me gather more specific information about the attribute handling in the other JAX descriptors.
Attribute handling in hybrid.py is unique and correct for its use case
The verification shows that nlist_cut_idx
and descrpt_list
attributes are unique to the hybrid descriptor and not present in other JAX descriptors. Each descriptor implements __setattr__
specifically for its own attributes:
- se_e2_r.py and se_e2_a.py handle
dstd
,davg
,embeddings
, andenv_mat
- dpa1.py handles various attributes related to attention layers and networks
- hybrid.py correctly handles its unique attributes
nlist_cut_idx
anddescrpt_list
The implementation in hybrid.py follows the same pattern as other descriptors - converting values to appropriate types using serialization/deserialization or array conversion, which is consistent with the codebase's approach.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for similar attribute handling patterns in other JAX descriptors
# Look for other JAX descriptor implementations
echo "Searching for other JAX descriptors..."
rg -l "BaseDescriptor.register" "deepmd/jax/descriptor/"
# Check for similar __setattr__ patterns
echo "Checking for similar attribute handling patterns..."
ast-grep --pattern 'class $_ {
$$$
__setattr__($_, $_, $_) {
$$$
}
$$$
}'
Length of output: 470
Script:
#!/bin/bash
# Check implementation details of __setattr__ in other JAX descriptors
echo "Checking se_e2_r.py..."
rg "__setattr__" "deepmd/jax/descriptor/se_e2_r.py" -A 10
echo -e "\nChecking se_e2_a.py..."
rg "__setattr__" "deepmd/jax/descriptor/se_e2_a.py" -A 10
echo -e "\nChecking dpa1.py..."
rg "__setattr__" "deepmd/jax/descriptor/dpa1.py" -A 10
# Also check for usage of nlist_cut_idx and descrpt_list
echo -e "\nChecking usage of nlist_cut_idx and descrpt_list..."
rg "nlist_cut_idx|descrpt_list" "deepmd/jax/descriptor/"
Length of output: 3690
source/tests/array_api_strict/descriptor/se_e2_r.py (3)
17-19
: LGTM: Clean import statement
The import is well-structured and appropriately scoped.
Line range hint 24-37
: LGTM: Well-structured attribute handling
The implementation provides clear, type-specific handling for each attribute type with appropriate conversions and validations.
22-23
: Verify descriptor registration uniqueness
The class is registered under two different names. While this is valid, we should verify there are no conflicts with existing registrations.
✅ Verification successful
Let me gather more information to understand the implementation details of these registrations.
Multiple framework-specific implementations found - no conflicts
The dual registration of "se_e2_r" and "se_r" is intentionally replicated across different framework implementations (TensorFlow, PyTorch, JAX) and each implementation is properly scoped within its respective framework's namespace through different base classes:
- TensorFlow:
@Descriptor.register()
- PyTorch/JAX/Tests:
@BaseDescriptor.register()
- DPModel:
@BaseDescriptor.register()
The registrations in the test file follow the same pattern as the framework implementations, and there are no conflicting registrations within the same namespace.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for any existing uses of these descriptor names
# to ensure we're not overwriting existing functionality
# Search for any other registrations of these descriptors
echo "Checking for existing registrations of se_e2_r and se_r..."
rg -l "register\([\"']se_e2_r[\"']\)" --type py
rg -l "register\([\"']se_r[\"']\)" --type py
Length of output: 597
Script:
#!/bin/bash
# Check the actual implementation of these registrations to understand potential conflicts
echo "Checking se_e2_r implementations..."
rg -A 5 "register\([\"']se_e2_r[\"']\)" --type py
echo -e "\nChecking se_r implementations..."
rg -A 5 "register\([\"']se_r[\"']\)" --type py
Length of output: 4554
source/tests/array_api_strict/descriptor/se_e2_a.py (2)
17-23
: LGTM! Clean import organization and clear class registration.
The dual registration with both "se_e2_a" and "se_a" identifiers provides good flexibility for users while maintaining backward compatibility.
Line range hint 24-36
: Verify array API compatibility across the codebase.
Let's ensure consistent usage of array API strict arrays in related files.
✅ Verification successful
Based on the search results, I can now generate the final response:
Array API strict array conversion is consistently implemented across the codebase.
The implementation in DescrptSeA
follows the established pattern used throughout the codebase:
- The
to_array_api_strict_array
function is used consistently in related descriptor files (se_e2_r.py
,se_e2_a.py
,hybrid.py
,dpa1.py
) - The array conversion is properly handled for the specific attributes (
dstd
,davg
) - The implementation aligns with other array API strict implementations in the utils and fitting modules
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for consistent array API usage in related descriptor files
# Expected: All array conversions should use to_array_api_strict_array
# Search for array conversion patterns in descriptor-related files
rg -l "to_array_api_strict_array" --type python
Length of output: 78
Script:
#!/bin/bash
# Search for array conversion patterns in all files
rg -l "to_array_api_strict_array"
# Also search for the specific class and its usage
rg -l "DescrptSeA"
# Look for array API related imports and usage
rg "array_api" -A 2
Length of output: 44835
doc/model/train-hybrid.md (1)
1-1
: Verify JAX backend implementation completeness.
The documentation has been updated to include JAX support. Let's verify that all necessary JAX implementation components are in place.
Also applies to: 4-4
✅ Verification successful
Let me gather more specific information about the JAX implementation.
Let me check one more aspect to verify the actual JAX implementation:
JAX hybrid descriptor implementation is verified and complete
The documentation update is accurate as the JAX backend implementation for hybrid descriptors is properly in place with:
- Implementation file at
deepmd/jax/descriptor/hybrid.py
- Proper class registration and imports in
deepmd/jax/descriptor/__init__.py
- Test coverage in
source/tests/consistent/descriptor/test_hybrid.py
- Inheritance from base implementation (
DescrptHybridDP
) with JAX-specific array handling
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify JAX implementation components
# Test 1: Check for JAX descriptor implementation
echo "Checking for JAX descriptor implementation..."
ast-grep --pattern 'class DescrptHybrid'
# Test 2: Check for hybrid descriptor registration
echo "Checking for hybrid descriptor registration..."
rg -l "register_descriptor.*hybrid"
# Test 3: Check for JAX-specific tests
echo "Checking for JAX-specific tests..."
fd test_hybrid.py
Length of output: 433
Script:
#!/bin/bash
# Check for JAX-specific hybrid descriptor implementation
echo "Checking JAX hybrid descriptor implementation..."
rg -A 5 "class.*Hybrid.*JAX|class.*JAX.*Hybrid"
# Check for hybrid descriptor registration in JAX
echo "Checking JAX hybrid descriptor registration..."
rg -A 5 "register.*hybrid.*jax|jax.*hybrid"
# Check test coverage for JAX hybrid
echo "Checking JAX hybrid tests..."
rg -A 5 "test.*hybrid.*jax|jax.*hybrid" source/tests/consistent/descriptor/test_hybrid.py
# Check for JAX-specific configuration or examples
echo "Checking for JAX hybrid examples..."
rg "jax.*hybrid|hybrid.*jax" doc/model/train-hybrid.md
Length of output: 2334
Script:
#!/bin/bash
# Check the actual JAX hybrid descriptor implementation
echo "Checking JAX hybrid descriptor implementation file..."
rg -A 10 "class DescrptHybrid" deepmd/jax/descriptor/hybrid.py
# Check if the implementation file exists
echo "Verifying file existence..."
fd hybrid.py deepmd/jax/descriptor/
Length of output: 692
source/tests/array_api_strict/descriptor/dpa1.py (2)
30-32
: LGTM: Clean import addition
The BaseDescriptor import is properly organized and necessary for the registration decorators.
78-79
: Verify unique registration identifiers
The registration looks good, but let's verify that these identifiers aren't used elsewhere in the codebase.
✅ Verification successful
Multiple framework-specific implementations found - registration pattern is correct
The registrations are intentionally duplicated across different framework implementations:
- PyTorch implementation in
deepmd/pt/model/descriptor/dpa1.py
- JAX implementation in
deepmd/jax/descriptor/dpa1.py
- Native implementation in
deepmd/dpmodel/descriptor/dpa1.py
- Array API test implementation in
source/tests/array_api_strict/descriptor/dpa1.py
Each implementation is framework-specific and uses the same identifiers to provide consistent API across different backends. This is a common pattern in ML frameworks supporting multiple backends.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for any other uses of these identifiers in descriptor registrations
# Expect: Only one registration for each identifier
# Search for other registrations using these identifiers
echo "Checking for 'dpa1' registrations:"
rg -l "@BaseDescriptor\.register\(['\"]dpa1['\"]\)"
echo "Checking for 'se_atten' registrations:"
rg -l "@BaseDescriptor\.register\(['\"]se_atten['\"]\)"
Length of output: 593
Script:
#!/bin/bash
# Let's check the content of these files to understand if they are different implementations or duplicates
echo "Content of source/tests/array_api_strict/descriptor/dpa1.py:"
rg -A 2 "@BaseDescriptor\.register\(['\"]dpa1['\"]\)" source/tests/array_api_strict/descriptor/dpa1.py
echo -e "\nContent of deepmd/pt/model/descriptor/dpa1.py:"
rg -A 2 "@BaseDescriptor\.register\(['\"]dpa1['\"]\)" deepmd/pt/model/descriptor/dpa1.py
echo -e "\nContent of deepmd/dpmodel/descriptor/dpa1.py:"
rg -A 2 "@BaseDescriptor\.register\(['\"]dpa1['\"]\)" deepmd/dpmodel/descriptor/dpa1.py
echo -e "\nContent of deepmd/jax/descriptor/dpa1.py:"
rg -A 2 "@BaseDescriptor\.register\(['\"]dpa1['\"]\)" deepmd/jax/descriptor/dpa1.py
Length of output: 1309
source/tests/consistent/descriptor/test_hybrid.py (4)
15-16
: LGTM! Import changes are consistent with existing patterns.
The new imports for JAX and Array API Strict installation flags follow the established pattern.
33-42
: LGTM! Backend class definitions follow established patterns.
The conditional imports for JAX and Array API Strict backends are well-structured and consistent with existing implementations.
83-84
: LGTM! Class properties are properly defined.
The new backend class properties maintain consistency with existing backend definitions.
87-89
: LGTM! Skip flags are properly implemented.
The skip flags for JAX and Array API Strict backends are correctly defined and follow the established pattern for conditional test execution.
deepmd/dpmodel/descriptor/hybrid.py (3)
9-9
: LGTM: Good addition of array_api_compat.
The addition of array_api_compat import enables backend-agnostic array operations, which is essential for supporting multiple array libraries like NumPy and JAX.
70-70
: LGTM: Good refactoring of variable scope.
Converting nlist_cut_idx
to a local variable during initialization is a good practice as it:
- Keeps intermediate computation results local
- Makes the code more maintainable by clearly separating temporary computation from instance state
Also applies to: 96-97
247-247
: LGTM with a performance consideration.
The changes to use array_api_compat operations (xp.take
and xp.concat
) make the code backend-agnostic, which is great for flexibility. However, xp.take
might have different performance characteristics compared to direct indexing depending on the backend.
Let's verify the performance impact:
Also applies to: 264-264, 274-275
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #4275 +/- ##
=======================================
Coverage 84.37% 84.37%
=======================================
Files 551 552 +1
Lines 51585 51602 +17
Branches 3052 3052
=======================================
+ Hits 43524 43540 +16
- Misses 7100 7102 +2
+ Partials 961 960 -1 ☔ View full report in Codecov by Sentry. |
Summary by CodeRabbit
New Features
DescrptHybrid
class with specialized attribute handling.Bug Fixes
Documentation