Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): reformat nlist in the TF model #4336

Merged
merged 1 commit into from
Nov 12, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Nov 11, 2024

Reformat the neighbor list in the TF model to convert the dynamic shape to the determined shape so the TF model can accept the neighbor list with a dynamic shape.

Summary by CodeRabbit

  • New Features

    • Introduced a new function to format neighbor lists based on selected neighbors and cutoff radius.
    • Enhanced deserialization process to incorporate the new formatting function for improved neighbor list handling.
  • Tests

    • Added a new test suite for the neighbor list formatting function, ensuring its functionality under various scenarios.

Format the neighbor list in the TF model to convert the dynamic shape to the determined shape, so the TF model can accept the neighbor list with a dynamic shape.

Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

coderabbitai bot commented Nov 11, 2024

📝 Walkthrough

Walkthrough

The changes introduce a new function format_nlist in deepmd/jax/jax2tf/format_nlist.py, which formats a neighbor list based on selected neighbors and a cutoff radius. This function is integrated into the deserialize_to_file function in deepmd/jax/jax2tf/serialization.py, modifying its parameter handling. Additionally, a new test suite for format_nlist is created in source/jax2tf_tests/test_format_nlist.py, which includes multiple unit tests to validate the functionality of the new method across different scenarios.

Changes

File Change Summary
deepmd/jax/jax2tf/format_nlist.py Added new function format_nlist(extended_coord, nlist, nsel, rcut) for neighbor list formatting.
deepmd/jax/jax2tf/serialization.py Updated deserialize_to_file to import and call format_nlist. Changed parameter types for call_lower_without_atomic_virial and call_lower_with_atomic_virial.
source/jax2tf_tests/test_format_nlist.py Added new test class TestFormatNlist with four unit tests for format_nlist.

Suggested labels

Python


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (8)
deepmd/jax/jax2tf/format_nlist.py (4)

2-4: Consider documenting the experimental numpy dependency.

The code relies on TensorFlow's experimental numpy module. While this is necessary for numpy-like operations in TensorFlow, it's worth documenting this dependency and potential stability considerations.

Add a comment explaining the experimental nature:

+# Using experimental numpy for tensor operations - required for numpy-like functionality in TF
import tensorflow.experimental.numpy as tnp

13-37: Enhance docstring with type and constraint information.

The docstring is well-structured but could benefit from additional details about types and constraints.

Consider adding:

  • The expected dtype of nlist (presumably integer)
  • Valid range for nsel (must be positive)
  • Valid range for rcut (must be positive)
  • Whether negative values in nlist (other than -1) are valid

52-66: Consider optimizing the truncation logic.

While the implementation is correct, there are potential improvements for robustness and performance:

  1. Replace float("inf") with tf.float32.max or the appropriate dtype max:
-        rr2 = tnp.where(m_real_nei, rr2, float("inf"))
+        rr2 = tnp.where(m_real_nei, rr2, tf.float32.max)
  1. Consider fusing operations to reduce memory usage:
-        coord1 = tnp.take_along_axis(extended_coord, index, axis=1)
-        coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3)
+        coord1 = tnp.take_along_axis(extended_coord, index, axis=1).reshape(n_nf, n_nloc, n_nsel, 3)

69-71: Enhance the XLA-related comment.

While the comment explains the purpose, it could be more detailed about the XLA implications.

Consider expanding the comment:

-    # do a reshape any way; this will tell the xla the shape without any dynamic shape
+    # Explicitly reshape to help XLA compiler optimize the graph by providing
+    # static shape information, even though the shape hasn't changed. This
+    # eliminates the need for dynamic shape handling in downstream operations.
source/jax2tf_tests/test_format_nlist.py (3)

16-44: Add docstrings and comments to explain the test setup.

The test setup uses specific numerical values and configurations that would benefit from documentation:

  • Purpose of the test class
  • Explanation of the test parameters (nf, nloc, ns, etc.)
  • Reasoning behind the chosen coordinate values
  • Expected behavior of the neighbor list construction

Consider adding a class docstring like:

class TestFormatNlist(tf.test.TestCase):
    """Tests the format_nlist function with various neighbor list configurations.
    
    The test setup simulates a small molecular system with:
    - 3 local atoms (nloc)
    - 5x5x3 periodic images (ns)
    - 2 atom types
    - Triclinic cell
    """

45-48: Add more assertions to test_format_nlist_equal.

While testing for equality is good, consider adding more specific assertions:

  • Shape of the output
  • Data type consistency
  • Range of indices
 def test_format_nlist_equal(self):
     nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut)
+    self.assertEqual(nlist.shape, self.nlist.shape)
+    self.assertEqual(nlist.dtype, self.nlist.dtype)
+    self.assertTrue(tf.reduce_all(nlist >= 0))
+    self.assertTrue(tf.reduce_all(nlist < len(self.ecoord)))
     self.assertAllEqual(nlist, self.nlist)

77-91: Add test cases for edge cases and error conditions.

The current tests focus on valid inputs, but it would be valuable to test error conditions:

  • Empty neighbor list
  • Zero cutoff radius
  • Invalid coordinates

Consider adding tests like:

def test_format_nlist_empty(self):
    empty_nlist = tf.zeros((1, self.nloc, 0), dtype=tf.int32)
    with self.assertRaises(ValueError):
        format_nlist(self.ecoord, empty_nlist, sum(self.nsel), self.rcut)

def test_format_nlist_zero_cutoff(self):
    with self.assertRaises(ValueError):
        format_nlist(self.ecoord, self.nlist, sum(self.nsel), 0.0)
deepmd/jax/jax2tf/serialization.py (1)

91-91: Consider adding error handling for format_nlist calls.

While the format_nlist calls are correctly placed and use appropriate model parameters, consider adding error handling to gracefully handle potential formatting failures. This is especially important as the function deals with shape transformations that could fail with invalid input.

Example error handling:

-            nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
+            try:
+                nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
+            except ValueError as e:
+                tf.print("Error formatting neighbor list:", e)
+                raise

Also applies to: 116-116

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 02a3048 and 1eecf10.

📒 Files selected for processing (3)
  • deepmd/jax/jax2tf/format_nlist.py (1 hunks)
  • deepmd/jax/jax2tf/serialization.py (4 hunks)
  • source/jax2tf_tests/test_format_nlist.py (1 hunks)
🔇 Additional comments (5)
deepmd/jax/jax2tf/format_nlist.py (2)

42-50: LGTM! Efficient padding implementation.

The padding implementation correctly handles the case where we need more neighbors, using efficient tensor operations and preserving the original dtype.


67-68: LGTM! Efficient handling of the no-change case.

The implementation correctly handles the case where no modification is needed.

source/jax2tf_tests/test_format_nlist.py (1)

3-3: Consider the stability implications of using experimental numpy.

The code uses tensorflow.experimental.numpy which might be subject to breaking changes. Consider adding a version check or documenting the minimum required TensorFlow version to ensure compatibility.

deepmd/jax/jax2tf/serialization.py (2)

13-15: LGTM: Import statement is correctly placed.

The import of format_nlist is properly organized and aligns with the PR's objective of reformatting neighbor lists.


82-82: LGTM: Input signature changes enable dynamic neighbor lists.

The modification of nlist parameter spec to accept dynamic shapes ([None, None, None]) aligns with the PR objective. The change is consistently applied to both atomic virial and non-atomic virial functions.

Let's verify if there are any other functions in the codebase that expect the old fixed-shape neighbor list format:

Also applies to: 109-109

✅ Verification successful

Based on the search results, I can now generate the final response since I have sufficient information to verify the review comment:

LGTM: Dynamic neighbor list shape change is safe and consistent

The change from fixed shape [None, None, model.get_nnei()] to dynamic shape [None, None, None] for the nlist parameter is safe because:

  1. The format_nlist function is called before any processing to ensure the neighbor list has the correct shape using model.get_nnei() and model.get_rcut() (see jax2tf/serialization.py)
  2. All internal processing functions properly handle the neighbor list shape by extracting dimensions at runtime rather than assuming fixed sizes
  3. The dynamic shape provides more flexibility while maintaining compatibility with the existing codebase's shape handling
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for TensorSpec definitions with model.get_nnei()
rg "TensorSpec.*model\.get_nnei\(\)" --type py

Length of output: 48


Script:

#!/bin/bash
# Let's search more broadly for neighbor list related tensor specs and shapes
rg "TensorSpec.*\[None,\s*None,.*\].*int64" -A 2 -B 2 --type py

# Also search for any references to neighbor list shapes
rg "nlist.*shape" --type py -A 2 -B 2

# Look for any model.get_nnei() usage
rg "get_nnei\(\)" --type py -A 2 -B 2

Length of output: 65349

deepmd/jax/jax2tf/format_nlist.py Show resolved Hide resolved
source/jax2tf_tests/test_format_nlist.py Show resolved Hide resolved
self.nsel = [10, 10]
self.rcut = 1.01

self.ecoord, self.eatype, mapping = extend_coord_with_ghosts(

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable mapping is not used.
Copy link

codecov bot commented Nov 11, 2024

Codecov Report

Attention: Patch coverage is 24.13793% with 22 lines in your changes missing coverage. Please review.

Project coverage is 84.48%. Comparing base (dcbf607) to head (1eecf10).
Report is 7 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/jax/jax2tf/format_nlist.py 15.38% 22 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4336      +/-   ##
==========================================
- Coverage   84.51%   84.48%   -0.04%     
==========================================
  Files         575      576       +1     
  Lines       53398    53429      +31     
  Branches     3059     3059              
==========================================
+ Hits        45129    45138       +9     
- Misses       7306     7328      +22     
  Partials      963      963              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@njzjz njzjz marked this pull request as draft November 11, 2024 11:03
@njzjz
Copy link
Member Author

njzjz commented Nov 11, 2024

This PR needs more effort to work with the C++ interface.

@njzjz
Copy link
Member Author

njzjz commented Nov 11, 2024

This PR needs more effort to work with the C++ interface.

The issue doesn't come from this PR, but come from the C++ interface. I've fixed it.

@njzjz njzjz marked this pull request as ready for review November 11, 2024 14:25
@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Nov 12, 2024
Merged via the queue into deepmodeling:devel with commit 560d82e Nov 12, 2024
60 checks passed
github-merge-queue bot pushed a commit that referenced this pull request Nov 13, 2024
Including nlist and no nlist interface.

The limitation: A SavedModel created on a device cannot be run on
another. For example, a CUDA model cannot be run on the CPU.

The model is generated using #4336.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Added support for the JAX backend, including specific model and
checkpoint file formats.
- Introduced a new shell script for model conversion to enhance
usability.
- Updated installation documentation to clarify JAX support and
requirements.
- New section in documentation detailing limitations of the JAX backend
with LAMMPS.

- **Bug Fixes**
- Enhanced error handling for model initialization and backend
compatibility.

- **Documentation**
- Updated backend documentation to include JAX details and limitations.
- Improved clarity in installation instructions for both TensorFlow and
JAX.

- **Tests**
- Added comprehensive unit tests for JAX integration with the Deep
Potential class.
  - Expanded test coverage for LAMMPS integration with DeepMD.

- **Chores**
- Updated CMake configurations and workflow files for improved testing
and dependency management.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Your Name <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants