Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): build nlist in the SavedModel & fix nopbc for StableHLO and SavedModel #4318

Merged
merged 32 commits into from
Nov 11, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Nov 6, 2024

Per our discussion, use TF to build the neighbor list in the SavedModel format.

Also, fix a bug when the number of ghost atoms is zero. The polymorphic_shape needs to be larger than 1, and nghost == 0 triggered the error. Previously, I also tried nall or nghost - 1 but none of them worked. Finally, I export two different functions... So now four functions are stored in the model: calculate virial or not, x nghost is zero or not. The tests for nopbc are added.

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced model initialization with additional parameters for improved functionality.
    • Introduced functions for neighbor list management and region transformations in molecular simulations.
    • Added new methods for handling atomic virial calculations in model predictions.
    • New functions for transforming model outputs to accommodate local and ghost atoms.
  • Bug Fixes

    • Improved error handling in model serialization and evaluation processes.
  • Tests

    • Added comprehensive unit tests for new functionalities, ensuring consistent behavior across different scenarios, including tests for neighbor list construction and region transformations.
  • Chores

    • Updated testing workflow for better organization and efficiency.
    • Modified dependency management and linting configurations in pyproject.toml.

Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
kk_derv_r, kk_derv_c = get_deriv_name(kk)
mldims = tf.shape(mapping)

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable mldims is not used.
Copy link
Contributor

coderabbitai bot commented Nov 6, 2024

📝 Walkthrough
📝 Walkthrough

Walkthrough

The pull request introduces several enhancements across multiple files, primarily focusing on the DeepEval class and related functionalities. Key changes include the addition of new parameters for model initialization, improvements in error handling, and the introduction of new functions for neighbor list management and coordinate transformations. The testing framework has also been updated to cover new functionalities, ensuring compatibility with TensorFlow's experimental features. Overall, the changes aim to enhance model evaluation capabilities and streamline the handling of atomic parameters and ghost atoms.

Changes

File Path Change Summary
deepmd/jax/infer/deep_eval.py Updated DeepEval class to include new parameters stablehlo_no_ghost and stablehlo_atomic_virial_no_ghost in the HLO object initialization. Error handling remains unchanged.
deepmd/jax/jax2tf/__init__.py Added import for tensorflow.experimental.numpy as tnp and called tnp.experimental_enable_numpy_behavior(). Error handling for non-eager execution remains unchanged.
deepmd/jax/jax2tf/make_model.py Introduced model_call_from_call_lower function for obtaining model predictions, handling coordinates, atom types, and ghost atoms.
deepmd/jax/jax2tf/nlist.py New functions for neighbor list management: build_neighbor_list, nlist_distinguish_types, extend_coord_with_ghosts, and tf_outer.
deepmd/jax/jax2tf/region.py New functions for coordinate transformations: phys2inter, inter2phys, normalize_coord, and to_face_distance.
deepmd/jax/jax2tf/serialization.py Modified deserialize_to_file to include has_ghost_atoms in exported_whether_do_atomic_virial, affecting serialization logic for models with ghost atoms.
deepmd/jax/jax2tf/tfmodel.py Added methods _call and _call_atomic_virial in TFModelWrapper, refactored call method to handle atomic virial calculations based on the do_atomic_virial flag.
deepmd/jax/jax2tf/transform_output.py Introduced get_leading_dims and communicate_extended_output functions for managing model outputs, especially for local and ghost atoms.
deepmd/jax/model/hlo.py Updated HLO class constructor to include new parameters related to ghost atoms and modified call_lower method for enhanced control flow based on input dimensions.
deepmd/jax/utils/serialization.py Updated exported_whether_do_atomic_virial to include has_ghost_atoms, allowing for different serialized outputs based on the presence of ghost atoms.
source/tests/consistent/io/test_io.py Enhanced IOTest class to evaluate results with and without periodic boundary conditions, adding assertions for consistency across different backends.
.github/workflows/test_python.yml Updated testing workflow configuration to improve job management, matrix strategy for Python versions, and caching mechanism for test durations.
pyproject.toml Modified dependency management and linting configurations, including banned imports and per-file ignores for specific test paths.
source/tests/jax/__init__.py Added SPDX license identifier for licensing under LGPL-3.0-or-later.
source/jax2tf_tests/__init__.py Added SPDX license identifier for licensing under LGPL-3.0-or-later.
source/jax2tf_tests/test_nlist.py Introduced unit tests for neighbor list functions, validating outputs against expected results.
source/jax2tf_tests/test_region.py Introduced unit tests for region transformation functions, ensuring correctness through assertions.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant DeepEval
    participant HLO
    participant Model
    User->>DeepEval: Initialize with model file
    DeepEval->>HLO: Create HLO with parameters
    HLO->>Model: Prepare model for evaluation
    User->>DeepEval: Evaluate model with coordinates
    DeepEval->>Model: Call model with parameters
    Model-->>DeepEval: Return predictions
    DeepEval-->>User: Return evaluation results
Loading

Possibly related PRs

Suggested labels

Docs

Suggested reviewers

  • wanghan-iapcm
  • iProzd

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 215efc6 and e4bac35.

📒 Files selected for processing (2)
  • source/jax2tf_tests/test_nlist.py (1 hunks)
  • source/jax2tf_tests/test_region.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • source/jax2tf_tests/test_nlist.py
  • source/jax2tf_tests/test_region.py

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🧹 Outside diff range and nitpick comments (24)
deepmd/jax/jax2tf/region.py (3)

29-48: Extract common shape validation logic.

Similar to phys2inter, this function would benefit from shape validation. Consider extracting the validation into a shared utility function to avoid code duplication.

+def _validate_coord_cell_shapes(coord: tnp.ndarray, cell: tnp.ndarray, name: str):
+    coord_shape = tf.shape(coord)
+    cell_shape = tf.shape(cell)
+    tf.debugging.assert_equal(
+        coord_shape[-1],
+        cell_shape[-1],
+        message=f"{name}: Last dimension of coord and cell must match"
+    )
+    tf.debugging.assert_equal(
+        cell_shape[-2:],
+        tf.constant([3, 3]),
+        message=f"{name}: Cell must be a 3x3 matrix"
+    )

 def inter2phys(...):
+    _validate_coord_cell_shapes(coord, cell, "inter2phys")
     return tnp.matmul(coord, cell)

51-72: Optimize matrix operations for better performance.

The current implementation performs two matrix multiplications by calling phys2inter and inter2phys. This could be optimized by caching the inverse cell matrix and reusing it.

 def normalize_coord(
     coord: tnp.ndarray,
     cell: tnp.ndarray,
 ) -> tnp.ndarray:
+    rec_cell = tf.linalg.inv(cell)
-    icoord = phys2inter(coord, cell)
+    icoord = tnp.matmul(coord, rec_cell)
     icoord = tnp.remainder(icoord, 1.0)
-    return inter2phys(icoord, cell)
+    return tnp.matmul(icoord, cell)

75-93: Add validation for cell tensor in to_face_distance.

The function reshapes the input tensor but doesn't validate that the final dimensions are correct.

 def to_face_distance(
     cell: tnp.ndarray,
 ) -> tnp.ndarray:
     cshape = tf.shape(cell)
+    tf.debugging.assert_equal(
+        cshape[-2:],
+        tf.constant([3, 3]),
+        message="Cell must be a 3x3 matrix"
+    )
     dist = b_to_face_distance(tnp.reshape(cell, [-1, 3, 3]))
     return tnp.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0))
deepmd/jax/utils/serialization.py (2)

72-76: Simplify ghost atoms logic using ternary operator

The logic is correct, but it can be more concise using a ternary operator.

-            if has_ghost_atoms:
-                nghost_ = nghost
-            else:
-                nghost_ = 0
+            nghost_ = nghost if has_ghost_atoms else 0
🧰 Tools
🪛 Ruff

72-75: Use ternary operator nghost_ = nghost if has_ghost_atoms else 0 instead of if-else-block

Replace if-else-block with nghost_ = nghost if has_ghost_atoms else 0

(SIM108)


Tests needed for new serialization configurations

The implementation introduces new serialization configurations for ghost atoms and atomic virial combinations, but no corresponding test coverage was found. Tests should be added to verify:

  • Serialization with ghost atoms enabled/disabled
  • Serialization with atomic virial calculation enabled/disabled
  • Combined scenarios of both configurations

Specific areas needing test coverage:

  • exported_whether_do_atomic_virial() function behavior
  • Serialization of all four combinations stored in @variables
  • Proper handling of ghost atoms in the exported models
🔗 Analysis chain

Line range hint 56-119: Verify test coverage for all configurations

While the implementation looks solid, we should ensure proper test coverage for all combinations of:

  • Atomic virial calculation (True/False)
  • Ghost atoms presence (True/False)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for test files covering these scenarios
rg -l "exported_whether_do_atomic_virial" --type python | grep "test"

Length of output: 98


Script:

#!/bin/bash
# Search for test files covering these scenarios without file type filter
rg -l "exported_whether_do_atomic_virial" | grep "test"

Length of output: 53


Script:

#!/bin/bash
# Search for test files that might cover the serialization functionality
rg -l "stablehlo.*ghost|stablehlo.*atomic_virial" | grep "test"

Length of output: 63


Script:

#!/bin/bash
# Look for test files in the deepmd directory
fd "test.*\.py$" deepmd/

Length of output: 83


Script:

#!/bin/bash
# Search for test files in the jax directory specifically
fd "test.*\.py$" deepmd/jax/

# Look for any serialization tests
rg -l "serialize|serializ" deepmd/jax/ | grep "test"

# Check content of test files for related functionality
rg "atomic_virial|ghost" deepmd/jax/ -g "test*.py" -C 2

Length of output: 141

source/tests/consistent/io/test_io.py (2)

186-202: Add documentation for nopbc test cases.

While the implementation is correct, it would be helpful to add a comment explaining that passing None as the box parameter indicates no periodic boundary conditions. This would make the test's intention more explicit.

+ # Test evaluation without periodic boundary conditions by passing None as box
 ret = deep_eval.eval(
     self.coords,
     None,
     self.atype,
     fparam=fparam,
     aparam=aparam,
 )

The review comment's suggestion for additional test coverage is valid.

While the codebase has logic to handle zero ghost atoms in the JAX serialization code, there are no explicit test cases verifying this scenario. The existing tests in test_virtual_type.py only cover non-zero ghost atoms (nghost=10). Adding test cases for zero ghost atoms would help ensure the polymorphic_shape fix works correctly for this edge case.

  • Add test cases in source/tests/consistent/io/test_io.py to verify model behavior with zero ghost atoms
  • Consider adding similar test coverage in source/tests/tf/test_virtual_type.py
🔗 Analysis chain

Line range hint 145-218: Consider adding test cases for zero ghost atoms.

While the test coverage for nopbc is comprehensive, consider adding specific test cases to verify the fix for zero ghost atoms mentioned in the PR objectives. This would ensure that the polymorphic_shape issue is properly addressed.


Based on the search results, I can see that there are existing tests for ghost atoms in source/tests/tf/test_virtual_type.py and source/tests/tf/common.py that cover cases with non-zero ghost atoms (e.g., nghost = 10 and nghost = 4). The codebase also has specific handling for zero ghost atoms in JAX serialization (deepmd/jax/jax2tf/serialization.py and deepmd/jax/utils/serialization.py).

Let me verify if there are specific test cases for zero ghost atoms.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if there are any existing tests for zero ghost atoms
# Test: Search for test cases involving ghost atoms
rg -i "ghost.*atom|nghost" --type py

Length of output: 8281


Script:

#!/bin/bash
# Check for test cases that explicitly test zero ghost atoms
rg -B2 -A2 "nghost.*=.*0|has_ghost_atoms.*=.*false" --type py

Length of output: 1487

deepmd/jax/jax2tf/tfmodel.py (1)

144-147: Consider using a ternary operator for better readability

The call selection can be more concise.

-        if do_atomic_virial:
-            call = self._call_atomic_virial
-        else:
-            call = self._call
+        call = self._call_atomic_virial if do_atomic_virial else self._call
🧰 Tools
🪛 Ruff

144-147: Use ternary operator call = self._call_atomic_virial if do_atomic_virial else self._call instead of if-else-block

Replace if-else-block with call = self._call_atomic_virial if do_atomic_virial else self._call

(SIM108)

deepmd/jax/jax2tf/make_model.py (4)

50-72: Document all function parameters in the docstring

The parameters call_lower, rcut, sel, mixed_types, and model_output_def are not documented in the docstring. Please add descriptions for these parameters to enhance clarity and maintainability.

Example addition to the docstring:

     """
     Return model prediction from lower interface.

     Parameters
     ----------
+    call_lower
+        A callable that accepts extended coordinates and other parameters, returning model predictions as a dictionary.
+    rcut
+        Cut-off radius for neighbor list construction.
+    sel
+        List of integers specifying selected atom types.
+    mixed_types
+        Boolean indicating whether to treat atom types as mixed.
+    model_output_def
+        Definition of the model outputs.
     coord
         The coordinates of the atoms.
         shape: nf x (nloc x 3)

75-76: Improve variable naming for better readability

Using aliases like cc, bb, fp, ap for coord, box, fparam, aparam reduces code readability. It's clearer to use the original variable names throughout the function.

Apply this diff to enhance readability:

- cc, bb, fp, ap = coord, box, fparam, aparam
- del coord, box, fparam, aparam
+ # Use original variable names throughout the function

75-76: Remove unnecessary variable deletion

The del statements for coord, box, fparam, and aparam might be unnecessary. Python's garbage collector handles memory management, and removing these lines can simplify the code.

Apply this diff to remove unnecessary code:

  cc, bb, fp, ap = coord, box, fparam, aparam
- del coord, box, fparam, aparam

93-93: Clarify the logic of distinguish_types parameter

The expression distinguish_types=not mixed_types may be confusing. Consider renaming mixed_types to ignore_atom_types or adding a comment to clarify that distinguish_types is the inverse of mixed_types.

deepmd/jax/jax2tf/transform_output.py (3)

54-54: Remove unused variable mldims.

The variable mldims is assigned but never used, which could lead to confusion. Please remove it to keep the code clean.

🧰 Tools
🪛 Ruff

54-54: Local variable mldims is assigned to but never used

Remove assignment to unused variable mldims

(F841)

🪛 GitHub Check: CodeQL

[notice] 54-54: Unused local variable
Variable mldims is not used.


83-83: Replace assert with explicit exception handling.

Using assert statements for control flow may not be ideal in production code because they can be disabled with optimization flags. Consider raising a specific exception with a clear error message to handle cases where vdef.r_differentiable is False while vdef.c_differentiable is True.


41-44: Enhance the docstring for communicate_extended_output.

The current docstring is brief. Providing detailed descriptions of the parameters, return values, and any important computational steps would improve maintainability and clarity for other developers.

deepmd/jax/jax2tf/nlist.py (6)

2-4: Simplify import statement by removing unnecessary parentheses

The parentheses around Union are unnecessary when importing a single item.

Apply this diff to simplify the import:

-from typing import (
-    Union,
-)
+from typing import Union

9-11: Simplify import statement by removing unnecessary parentheses

The parentheses around to_face_distance are unnecessary when importing a single item.

Apply this diff to simplify the import:

-from .region import (
-    to_face_distance,
-)
+from .region import to_face_distance

14-14: Remove reference to 'chatgpt' in comment for professionalism

It's better to avoid mentioning specific tools like 'chatgpt' in code comments for professionalism and maintainability.

Apply this diff to update the comment:

-## translated from torch implementation by chatgpt
+## Translated from torch implementation

28-28: Correct typo in docstring: 'exptended' should be 'extended'

There's a typo in the parameter description for coord.

Apply this diff to fix the typo:

-        exptended coordinates of shape [batch_size, nall x 3]
+        extended coordinates of shape [batch_size, nall x 3]

154-154: Correct typo in docstring: 'peridoc' should be 'periodic'

There's a typo in the function description.

Apply this diff to fix the typo:

-    """Extend the coordinates of the atoms by appending peridoc images.
+    """Extend the coordinates of the atoms by appending periodic images.

147-147: Remove reference to 'chatgpt' in comment for professionalism

As before, it's better to avoid mentioning specific tools like 'chatgpt' in code comments.

Apply this diff to update the comment:

-## translated from torch implementation by chatgpt
+## Translated from torch implementation
deepmd/jax/model/hlo.py (2)

183-192: Consider adding comments to clarify the conditional logic in call_lower.

The conditional logic in call_lower now depends on the dimensions of extended_coord and nlist, as well as the do_atomic_virial flag. Adding comments to explain the reasoning behind these conditions can improve code readability and maintainability.

For example:

 def call_lower(
     self,
     extended_coord: jnp.ndarray,
     extended_atype: jnp.ndarray,
     nlist: jnp.ndarray,
     mapping: Optional[jnp.ndarray] = None,
     fparam: Optional[jnp.ndarray] = None,
     aparam: Optional[jnp.ndarray] = None,
     do_atomic_virial: bool = False,
 ):
+    # Determine if ghost atoms are present based on the shape of extended_coord and nlist
     if extended_coord.shape[1] > nlist.shape[1]:
+        # Case with ghost atoms
         if do_atomic_virial:
             call_lower = self._call_lower_atomic_virial
         else:
             call_lower = self._call_lower
     else:
+        # Case without ghost atoms
         if do_atomic_virial:
             call_lower = self._call_lower_atomic_virial_no_ghost
         else:
             call_lower = self._call_lower_no_ghost
     return call_lower(
         extended_coord,
         extended_atype,

183-192: Ensure unit tests cover all branches of the new conditional logic.

With the updated call_lower method introducing additional branches, it's important to verify that all scenarios are properly tested. This includes cases:

  • With and without ghost atoms (i.e., when extended_coord.shape[1] > nlist.shape[1] and when it's not).
  • With do_atomic_virial set to both True and False.
deepmd/jax/jax2tf/serialization.py (1)

55-58: Simplify assignment with a ternary operator

Consider replacing the if-else block with a ternary operator for conciseness and readability.

Apply this diff to implement the suggestion:

             # nghost >= 1 is assumed if there is
             # other workaround does not work, such as
             # nall; nloc + nghost - 1
-            if has_ghost_atoms:
-                nghost = "nghost"
-            else:
-                nghost = "0"
+            nghost = "nghost" if has_ghost_atoms else "0"
🧰 Tools
🪛 Ruff

55-58: Use ternary operator nghost = "nghost" if has_ghost_atoms else "0" instead of if-else-block

Replace if-else-block with nghost = "nghost" if has_ghost_atoms else "0"

(SIM108)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 430dfa9 and 94d2054.

📒 Files selected for processing (11)
  • deepmd/jax/infer/deep_eval.py (1 hunks)
  • deepmd/jax/jax2tf/__init__.py (2 hunks)
  • deepmd/jax/jax2tf/make_model.py (1 hunks)
  • deepmd/jax/jax2tf/nlist.py (1 hunks)
  • deepmd/jax/jax2tf/region.py (1 hunks)
  • deepmd/jax/jax2tf/serialization.py (5 hunks)
  • deepmd/jax/jax2tf/tfmodel.py (2 hunks)
  • deepmd/jax/jax2tf/transform_output.py (1 hunks)
  • deepmd/jax/model/hlo.py (3 hunks)
  • deepmd/jax/utils/serialization.py (3 hunks)
  • source/tests/consistent/io/test_io.py (2 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/jax/jax2tf/serialization.py

55-58: Use ternary operator nghost = "nghost" if has_ghost_atoms else "0" instead of if-else-block

Replace if-else-block with nghost = "nghost" if has_ghost_atoms else "0"

(SIM108)

deepmd/jax/jax2tf/tfmodel.py

144-147: Use ternary operator call = self._call_atomic_virial if do_atomic_virial else self._call instead of if-else-block

Replace if-else-block with call = self._call_atomic_virial if do_atomic_virial else self._call

(SIM108)

deepmd/jax/jax2tf/transform_output.py

54-54: Local variable mldims is assigned to but never used

Remove assignment to unused variable mldims

(F841)

deepmd/jax/utils/serialization.py

72-75: Use ternary operator nghost_ = nghost if has_ghost_atoms else 0 instead of if-else-block

Replace if-else-block with nghost_ = nghost if has_ghost_atoms else 0

(SIM108)

🪛 GitHub Check: CodeQL
deepmd/jax/jax2tf/transform_output.py

[notice] 54-54: Unused local variable
Variable mldims is not used.

🔇 Additional comments (15)
deepmd/jax/jax2tf/__init__.py (2)

14-14: Verify the impact of enabling numpy behavior globally.

The call to tnp.experimental_enable_numpy_behavior() affects the entire module. While necessary for the neighbor list calculations, we should ensure it doesn't interfere with existing functionality.

Let's check for any other numpy operations that might be affected:


3-3: LGTM, but note the experimental nature of the import.

The import of tensorflow.experimental.numpy is necessary for the neighbor list calculations, but be aware that it's from TensorFlow's experimental namespace.

Let's check if there are any stability notices or deprecation warnings:

✅ Verification successful

Import of tensorflow.experimental.numpy is consistently used across the module

The import of tensorflow.experimental.numpy as tnp is well-integrated throughout the jax2tf module:

  • Used consistently in 6 files within the jax2tf package
  • Primarily used for array operations in neighbor list calculations, coordinate transformations, and model outputs
  • The enable_numpy_behavior() is properly configured in __init__.py

The experimental nature of the import is acceptable here as:

  1. It's used consistently across the entire module
  2. The functionality is essential for numpy-like operations in TensorFlow
  3. The code follows TensorFlow's recommended patterns for numpy compatibility
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check TensorFlow documentation or issues related to experimental.numpy stability
gh search issues "tensorflow experimental.numpy stability" --repo tensorflow/tensorflow --json title,body --jq '.[] | select(.title | contains("experimental.numpy") or contains("tnp")) | {title: .title, body: .body}'

Length of output: 218


Script:

#!/bin/bash
# Let's check the actual usage of tnp in the codebase to understand its impact
rg "import tensorflow.experimental.numpy as tnp" -A 3

# Also check for any direct usage of tnp
rg "\btnp\." -B 1 -A 1

# Look for any configuration or setup related to numpy behavior
rg "enable_numpy_behavior|set_numpy_enabled"

Length of output: 17217

deepmd/jax/jax2tf/region.py (1)

2-3: Consider the stability implications of using experimental NumPy.

The code uses tensorflow.experimental.numpy which, being experimental, may have breaking changes in future releases. Consider:

  1. Adding version constraints for TensorFlow
  2. Documenting the minimum required TensorFlow version
  3. Adding error handling for potential API changes
deepmd/jax/utils/serialization.py (4)

56-58: LGTM: Function signature properly updated

The addition of has_ghost_atoms parameter aligns with the PR's objective to handle cases where the number of ghost atoms is zero. The type hints are correctly specified.


79-83: LGTM: Shape definitions properly handle ghost atoms

The shape definitions correctly use nghost_ to handle both ghost and no-ghost cases, ensuring proper tensor dimensions for coordinates, atom types, neighbor lists, and mapping arrays.


92-109: LGTM: Comprehensive export configurations

The implementation properly handles all combinations of atomic virial calculation and ghost atoms presence, which directly addresses the PR's objective. The naming convention is clear and consistent.


116-119: LGTM: Variable storage properly implemented

The new variables for no-ghost scenarios are correctly added to the data dictionary with appropriate naming and proper use of np.void for serialized data storage.

source/tests/consistent/io/test_io.py (2)

145-145: LGTM: Clear and consistent variable declaration.

The addition of rets_nopbc list follows the same pattern as the existing rets list and aligns with the PR's objective to test no periodic boundary conditions scenarios.


210-218: LGTM: Well-implemented verification logic with improved error messages.

The assertions for nopbc results:

  • Maintain consistency with existing verification patterns
  • Include helpful backend information in error messages
  • Properly handle unsupported cases with NaN checks
  • Use appropriate tolerance values
deepmd/jax/jax2tf/tfmodel.py (2)

55-56: LGTM: New wrapper methods for atomic virial calculations

The addition of _call and _call_atomic_virial attributes follows the existing pattern and maintains consistency with the codebase structure.


149-165: LGTM: Robust parameter initialization

The initialization of None parameters with empty arrays is well-implemented:

  • Preserves batch dimensions from input coordinates
  • Uses correct data type (float64)
  • Maintains appropriate tensor shapes for each parameter
  • Prevents TensorFlow errors about None values

This change improves the robustness of the model wrapper.

deepmd/jax/infer/deep_eval.py (1)

100-105: LGTM! The changes align with PR objectives.

The addition of stablehlo_no_ghost and stablehlo_atomic_virial_no_ghost parameters properly addresses the ghost atoms handling issue mentioned in the PR objectives.

Let's verify that these new variables are present in HLO files:

✅ Verification successful

The new variables are properly integrated into the codebase

The verification confirms that:

  • The variables are properly defined in the HLO class constructor (deepmd/jax/model/hlo.py)
  • They are correctly initialized and used with jax_export.deserialize() in the HLO class
  • The serialization logic in deepmd/jax/utils/serialization.py properly stores these variables in the model data

The implementation in deep_eval.py correctly reads these variables from the model data, maintaining consistency with how they are stored during serialization.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for the presence of new variables in HLO files
# Expected: Files containing these new variables

# Search for files containing the new variable names
rg -l "stablehlo_no_ghost|stablehlo_atomic_virial_no_ghost" --type py

# Search for test files that verify these variables
rg -l "test.*stablehlo.*ghost" --type py

Length of output: 202


Script:

#!/bin/bash
# Let's examine the implementation and usage of these variables in the found files

# Check hlo.py implementation
echo "=== Content from deepmd/jax/model/hlo.py ==="
rg -A 5 -B 5 "stablehlo_no_ghost|stablehlo_atomic_virial_no_ghost" deepmd/jax/model/hlo.py

echo -e "\n=== Content from deepmd/jax/utils/serialization.py ==="
rg -A 5 -B 5 "stablehlo_no_ghost|stablehlo_atomic_virial_no_ghost" deepmd/jax/utils/serialization.py

# Check if there are any tests that might not directly mention "stablehlo" and "ghost"
echo -e "\n=== Related test files ==="
fd "test.*\.py" | xargs rg -l "no_ghost"

Length of output: 1891

deepmd/jax/jax2tf/make_model.py (3)

96-103: Confirm that call_lower is called with correct arguments

Ensure that the call_lower function accepts the provided arguments extended_coord, extended_atype, nlist, mapping, fparam, and aparam. Verify that the boolean parameter do_atomic_virial is correctly handled within call_lower.

Run the following script to verify the signature of call_lower:

#!/bin/bash
# Description: Verify the signature of `call_lower`.

# Expected: `call_lower` accepts the specified arguments and matches the Callable type hint.

# Since `call_lower` is passed as a parameter, search for typical usages or definitions
rg -A 5 'def call_lower' | rg 'def call_lower\('

# Alternatively, check for any type hints or comments indicating the expected parameters
rg -A 5 'call_lower' | rg 'Callable\[['

105-109: Verify that do_atomic_virial is correctly propagated

Ensure that the do_atomic_virial flag is correctly used inside communicate_extended_output, and that the function handles it appropriately.

Please run the following script to check the usage of do_atomic_virial:

#!/bin/bash
# Description: Check how `do_atomic_virial` is used in `communicate_extended_output`.

# Expected: The flag is used to conditionally process atomic virial calculations.

# Find the definition of `communicate_extended_output` and its usage of `do_atomic_virial`
rg -A 15 'def communicate_extended_output' | rg 'do_atomic_virial'

85-86: Ensure extend_coord_with_ghosts handles box being None

Passing bb to extend_coord_with_ghosts when bb is None may cause issues inside that function. Verify that extend_coord_with_ghosts can handle a None value for box.

Please run the following script to check if extend_coord_with_ghosts can handle bb being None:

✅ Verification successful

extend_coord_with_ghosts safely handles None box/cell parameter

Looking at the implementations, both the JAX and dpmodel versions properly handle the case when the cell/box parameter is None or empty:

  • In JAX (jax2tf/nlist.py): Checks if cell shape is empty with if tf.shape(cell)[-1] == 0
  • In dpmodel (dpmodel/utils/nlist.py): Explicitly checks if cell is None

In both cases, it safely returns the original coordinates without any ghost atoms when the cell/box is None.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if `extend_coord_with_ghosts` handles `box` being `None`.

# Expected: No errors or exceptions when `box` is `None`.

# Find the definition of `extend_coord_with_ghosts`
rg -A 10 'def extend_coord_with_ghosts' | grep 'box'

# Search for any checks or handling of `box` being `None` within the function
rg -A 20 'def extend_coord_with_ghosts' | rg 'if box is None|box is not None'

Length of output: 537


Script:

#!/bin/bash
# Let's try a different approach to find the implementation of extend_coord_with_ghosts

# Search for the function definition with more context
rg -A 20 "def extend_coord_with_ghosts"

# Also search in case it's defined with different spacing
rg -A 20 "def\s+extend_coord_with_ghosts"

# Try finding any usage of this function to understand the context
rg "extend_coord_with_ghosts\("

Length of output: 15278


Script:

#!/bin/bash
# Let's check the implementation in jax2tf/nlist.py since that's the one being used in make_model.py
rg -A 50 "def extend_coord_with_ghosts" deepmd/jax/jax2tf/nlist.py

# Also check if there's any handling of None cell/box in the dpmodel implementation since it has Optional typing
rg -A 50 "def extend_coord_with_ghosts" deepmd/dpmodel/utils/nlist.py

Length of output: 3600

deepmd/jax/jax2tf/region.py Show resolved Hide resolved
deepmd/jax/jax2tf/region.py Show resolved Hide resolved
deepmd/jax/jax2tf/make_model.py Show resolved Hide resolved
deepmd/jax/jax2tf/make_model.py Show resolved Hide resolved
deepmd/jax/jax2tf/nlist.py Show resolved Hide resolved
deepmd/jax/jax2tf/nlist.py Show resolved Hide resolved
deepmd/jax/model/hlo.py Show resolved Hide resolved
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link

codecov bot commented Nov 6, 2024

Codecov Report

Attention: Patch coverage is 68.85965% with 71 lines in your changes missing coverage. Please review.

Project coverage is 84.52%. Comparing base (6c66be9) to head (e4bac35).
Report is 8 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/jax/jax2tf/transform_output.py 12.50% 35 Missing ⚠️
deepmd/jax/jax2tf/make_model.py 38.09% 13 Missing ⚠️
deepmd/jax/jax2tf/serialization.py 66.66% 8 Missing ⚠️
deepmd/jax/jax2tf/nlist.py 91.76% 7 Missing ⚠️
deepmd/jax/jax2tf/region.py 79.16% 5 Missing ⚠️
deepmd/jax/model/hlo.py 77.77% 2 Missing ⚠️
deepmd/jax/jax2tf/tfmodel.py 91.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4318      +/-   ##
==========================================
+ Coverage   84.50%   84.52%   +0.01%     
==========================================
  Files         570      575       +5     
  Lines       53162    53388     +226     
  Branches     3059     3059              
==========================================
+ Hits        44927    45124     +197     
- Misses       7272     7301      +29     
  Partials      963      963              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Outside diff range and nitpick comments (6)
source/tests/jax/jax2tf/__init__.py (1)

1-8: Add missing TensorFlow configurations

Based on the AI summary, this file should include TensorFlow's experimental numpy configuration and eager execution check.

Consider adding these configurations:

 # SPDX-License-Identifier: LGPL-3.0-or-later
 import pytest
+import tensorflow as tf
+import tensorflow.experimental.numpy as tnp
 
 from ...utils import (
     DP_TEST_TF2_ONLY,
 )
 
+if not tf.executing_eagerly():
+    raise RuntimeError("This module must be run in eager execution mode")
+
+tnp.experimental_enable_numpy_behavior()
+
 pytestmark = pytest.mark.skipif(not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1")
source/tests/jax/jax2tf/test_region.py (3)

4-4: Consider the stability implications of using experimental numpy.

While using tensorflow.experimental.numpy is acceptable in tests, be aware that its API might change in future TensorFlow versions. Consider adding a comment explaining why the experimental version is preferred over the stable numpy.


18-22: Add documentation for cell initialization.

The cell initialization involves complex shape transformations. Consider adding comments to explain:

  1. Why the cell is initialized with this specific 3D structure
  2. The purpose of reshaping to [1, 1, -1, 3]
  3. The significance of tiling to [4, 5, 1, 1]

25-34: Enhance test coverage with edge cases.

While the current test with random inputs is good, consider adding specific test cases for:

  1. Zero vectors
  2. Unit vectors
  3. Extreme values
  4. Negative coordinates

This would ensure the transformation handles all possible scenarios correctly.

source/tests/jax/jax2tf/test_nlist.py (2)

17-44: Add docstrings to improve test documentation.

The test class and setUp method would benefit from docstrings explaining:

  • The purpose of the test suite
  • The structure of the test data
  • The meaning of key parameters (nf, nloc, ns, etc.)

Example addition:

 class TestNeighList(tf.test.TestCase):
+    """Test suite for neighbor list construction and coordinate extension.
+    
+    Tests the functionality of build_neighbor_list and extend_coord_with_ghosts
+    with realistic molecular simulation data.
+    """
     def setUp(self):
+        """Initialize test data including coordinates, cell, and reference lists.
+        
+        Sets up:
+        - nf (int): Number of frames (3)
+        - nloc (int): Number of local atoms (3)
+        - ns (int): Number of shifts (75)
+        - cell (array): Simulation cell parameters
+        - coordinates and atom types
+        """

45-65: Extract magic numbers into named constants.

The test uses magic numbers for array comparisons. Consider extracting these into named constants at the class level for better maintainability.

 class TestNeighList(tf.test.TestCase):
+    # Constants for array comparisons
+    INVALID_NEIGHBOR = -1
+
     def test_build_notype(self):
-        nlist_mask = nlist[0] == -1
+        nlist_mask = nlist[0] == self.INVALID_NEIGHBOR
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 84cb819 and bd27d4f.

📒 Files selected for processing (6)
  • .github/workflows/test_python.yml (1 hunks)
  • pyproject.toml (2 hunks)
  • source/tests/jax/__init__.py (1 hunks)
  • source/tests/jax/jax2tf/__init__.py (1 hunks)
  • source/tests/jax/jax2tf/test_nlist.py (1 hunks)
  • source/tests/jax/jax2tf/test_region.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • source/tests/jax/init.py
🔇 Additional comments (6)
source/tests/jax/jax2tf/test_region.py (1)

1-53: Verify the connection with ghost atoms handling.

The PR objectives mention handling ghost atoms and neighbor lists, but this test file focuses on coordinate transformations. Please clarify how these region transformations relate to the ghost atoms functionality.

✅ Verification successful

Region transformations are essential for ghost atoms and neighbor list handling

The verification shows clear connections between region transformations and ghost atoms functionality:

  1. The inter2phys and to_face_distance functions are directly used in neighbor list construction across all implementations (TF/PT/JAX):

    • In deepmd/tf/utils/nlist.py, deepmd/pt/utils/nlist.py, and deepmd/jax/jax2tf/nlist.py
    • Used to calculate buffer sizes for ghost atoms based on cutoff distances
  2. The region transformations serve critical purposes:

    • inter2phys: Converts internal coordinates to physical coordinates needed for ghost atom positioning
    • to_face_distance: Calculates cell face distances to determine how many ghost atom copies are needed in each direction
  3. These functions are part of the core infrastructure used by the ghost atoms implementation found in files like:

    • deepmd/tf/utils/nlist.py: extend_coord_with_ghosts
    • deepmd/pt/utils/nlist.py: extend_input_and_build_neighbor_list
    • deepmd/jax/jax2tf/nlist.py: build_neighbor_list
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for connections between region transformations and ghost atoms
rg -l "ghost.*region|region.*ghost" --type py

# Look for usage of these transformation functions in neighbor list construction
ast-grep --pattern 'inter2phys|to_face_distance' 

Length of output: 98


Script:

#!/bin/bash
# Search for files containing ghost atoms implementation
rg -l "ghost.*atom" --type py

# Search for implementations using region transformations
rg -A 5 "inter2phys|to_face_distance" --type py

# Look for neighbor list implementations
rg -l "neighbor.*list|neigh.*list" --type py

# Check the actual implementation files for region.py
fd region.py

Length of output: 20719

.github/workflows/test_python.yml (1)

61-61: LGTM! Test paths align with PR objectives.

The addition of test_io.py and the entire jax2tf directory to TF2 eager mode testing appropriately covers both the DeepEval class modifications and the new neighbor list functionality.

Let's verify the test coverage:

✅ Verification successful

Test paths are appropriate and contain relevant test coverage

The verification confirms that:

  • test_io.py thoroughly tests the DeepEval class with both PBC and non-PBC conditions across different backends (TF, PyTorch, JAX)
  • test_nlist.py comprehensively tests the neighbor list functionality with:
    • Type-based and non-type-based neighbor list building
    • Ghost atom coordination extension
    • Proper handling of periodic boundary conditions
    • Validation against reference neighbor lists
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the test files exist and contain relevant test cases

# Test 1: Check if the test files exist
echo "Checking test files existence..."
fd -t f "test_io.py$" "source/tests/consistent/io"
fd -t f "test_nlist.py$" "source/tests/jax/jax2tf"

# Test 2: Check test content for neighbor list and nopbc related tests
echo "Checking test content..."
rg -l "test.*nlist|test.*pbc" "source/tests/jax/jax2tf"

Length of output: 367


Script:

#!/bin/bash
# Let's check the actual content of these test files to verify coverage

echo "Content of test_io.py:"
cat source/tests/consistent/io/test_io.py

echo -e "\nContent of test_nlist.py:"
cat source/tests/jax/jax2tf/test_nlist.py

# Also check for any other test files in jax2tf directory
echo -e "\nOther test files in jax2tf directory:"
fd -t f "^test.*\.py$" "source/tests/jax/jax2tf"

Length of output: 15126

source/tests/jax/jax2tf/test_nlist.py (2)

101-109: Add explanation for the coordinate transformation logic.

The shift vector calculation and matrix transformation logic is complex and would benefit from detailed comments explaining:

  • The purpose of the transformation
  • Why the inverse cell matrix is used
  • The expected output format
#!/bin/bash
# Check if similar matrix transformations are documented elsewhere
rg "tf.linalg.inv.*cell" -B 2 -A 2

4-4: Consider the stability implications of using experimental NumPy.

The use of tensorflow.experimental.numpy might lead to stability issues as experimental features can change or be deprecated. Consider documenting this dependency clearly and monitoring for updates.

pyproject.toml (2)

408-411: LGTM: Consistent module-level import restrictions.

The addition of "deepmd.jax" and "jax" to banned module-level imports follows the established pattern of restricting direct imports of deep learning frameworks, similar to the existing restrictions on tensorflow and torch.


424-424: LGTM: Appropriate test directory exception.

Adding "source/tests/jax/**" to the TID253 exceptions is consistent with the existing pattern for test directories and necessary for allowing JAX imports in test files.

source/tests/jax/jax2tf/__init__.py Outdated Show resolved Hide resolved
source/tests/jax/jax2tf/test_region.py Outdated Show resolved Hide resolved
source/tests/jax/jax2tf/test_nlist.py Outdated Show resolved Hide resolved
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (4)
source/tests/jax/jax2tf/test_region.py (2)

26-34: Add docstrings to explain the test class and cell structure.

The test setup is well-implemented, but would benefit from documentation explaining:

  • The purpose of the TestRegion class
  • The structure and significance of the cell array
  • Why these specific dimensions (4, 5, -1, 3) are used
 class TestRegion(tf.test.TestCase):
+    """Tests for region transformation functions.
+    
+    The test class validates coordinate transformations and distance calculations
+    using a 4x5 batch of 3x3 cell matrices. Each cell represents a transformation
+    from internal to physical coordinates.
+    """
     def setUp(self):
+        """Sets up test fixtures with a batched cell array for coordinate transforms."""

35-44: Consider adding edge cases to the coordinate transformation test.

While the basic test is solid, consider adding test cases for:

  • Zero coordinates
  • Large coordinate values
  • Negative coordinates
  • Boundary values
 def test_inter_to_phys(self):
+    """Test internal to physical coordinate transformation with various cases."""
     rng = tf.random.Generator.from_seed(GLOBAL_SEED)
     inter = rng.normal(shape=[4, 5, 3, 3])
+    # Test regular case
     phys = inter2phys(inter, self.cell)
     for ii in range(4):
         for jj in range(5):
             expected_phys = tnp.matmul(inter[ii, jj], self.cell[ii, jj])
             self.assertAllClose(
                 phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec
             )
+    
+    # Test edge cases
+    edge_cases = [
+        tnp.zeros([4, 5, 3, 3]),  # Zero coordinates
+        tnp.ones([4, 5, 3, 3]) * 1e6,  # Large values
+        tnp.ones([4, 5, 3, 3]) * -1.0,  # Negative values
+    ]
+    for case in edge_cases:
+        phys = inter2phys(case, self.cell)
+        for ii in range(4):
+            for jj in range(5):
+                expected_phys = tnp.matmul(case[ii, jj], self.cell[ii, jj])
+                self.assertAllClose(
+                    phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec
+                )
source/tests/jax/jax2tf/test_nlist.py (2)

28-55: Add docstrings to explain test setup and data structure.

The test setup is comprehensive but would benefit from documentation explaining:

  • The purpose and structure of the test data
  • The meaning of magic numbers (e.g., nloc=3, ns=553)
  • The format and meaning of ref_nlist array

Example docstring:

class TestNeighList(tf.test.TestCase):
    """Tests for neighbor list construction and coordinate extension.
    
    Test data structure:
    - nf: Number of frames
    - nloc: Number of local atoms per frame
    - ns: Number of shifts (5x5x3 grid)
    - ref_nlist: Reference neighbor list with format [...]
    """

77-98: Optimize test_build_type by reducing redundant operations.

The test method could be optimized by:

  1. Caching the mapping operation results
  2. Using vectorized operations instead of the loop
 def test_build_type(self):
     ecoord, eatype, mapping = extend_coord_with_ghosts(
         self.coord, self.atype, self.cell, self.rcut
     )
     nlist = build_neighbor_list(
         ecoord,
         eatype,
         self.nloc,
         self.rcut,
         self.nsel,
         distinguish_types=True,
     )
     self.assertAllClose(nlist[0], nlist[1])
     nlist_mask = nlist[0] == -1
     nlist_loc = mapping[0][nlist[0]]
     nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc)
-    for ii in range(2):
-        self.assertAllClose(
-            tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1),
-            tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1),
-        )
+    # Split and sort once, then compare
+    nlist_split = tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1), axis=-1)
+    ref_split = tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1), axis=-1)
+    self.assertAllClose(nlist_split, ref_split)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between a00aae8 and 8bef185.

📒 Files selected for processing (3)
  • source/tests/jax/jax2tf/__init__.py (1 hunks)
  • source/tests/jax/jax2tf/test_nlist.py (1 hunks)
  • source/tests/jax/jax2tf/test_region.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • source/tests/jax/jax2tf/init.py
🔇 Additional comments (4)
source/tests/jax/jax2tf/test_region.py (2)

1-24: LGTM! Well-structured imports and test setup.

The imports are properly organized with clear separation of concerns. Good practice using conditional imports and skip markers to handle TensorFlow version compatibility.


46-63: Add input validation and consider vectorizing calculations.

The previous review comment about input validation is still applicable. Additionally, consider vectorizing the calculations to improve test performance:

 def test_to_face_dist(self):
+    """Test face distance calculations with validation."""
     cell0 = self.cell[0][0]
     vol = tf.linalg.det(cell0)
+    # Ensure cell is not singular
+    self.assertGreater(tf.abs(vol), 1e-10)
+
     # area of surfaces xy, xz, yz
     sxy = tf.linalg.norm(tnp.cross(cell0[0], cell0[1]))
     sxz = tf.linalg.norm(tnp.cross(cell0[0], cell0[2]))
     syz = tf.linalg.norm(tnp.cross(cell0[1], cell0[2]))
+    # Validate surface areas
+    self.assertGreater(sxy, 1e-10)
+    self.assertGreater(sxz, 1e-10)
+    self.assertGreater(syz, 1e-10)
+
     # vol / area gives distance
     dz = vol / sxy
     dy = vol / sxz
     dx = vol / syz
     expected = tnp.array([dx, dy, dz])
     dists = to_face_distance(self.cell)
-    for ii in range(4):
-        for jj in range(5):
-            self.assertAllClose(
-                dists[ii][jj], expected, rtol=self.prec, atol=self.prec
-            )
+    # Vectorized comparison
+    self.assertAllClose(
+        dists, tnp.broadcast_to(expected, [4, 5, 3]),
+        rtol=self.prec, atol=self.prec
+    )
source/tests/jax/jax2tf/test_nlist.py (2)

1-26: LGTM! Well-structured imports and test setup.

The imports are properly organized with clear conditional logic to handle TensorFlow version conflicts.


56-164: Test coverage aligns well with PR objectives.

The test suite effectively validates the neighbor list construction in SavedModel format and handles ghost atoms appropriately, which directly addresses the PR objectives. The tests verify:

  1. Neighbor list construction with and without type distinction
  2. Coordinate extension with ghost atoms
  3. Grid alignment and shift vectors

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
source/tests/consistent/io/test_io.py (1)

149-149: Add docstring explaining nopbc test cases.

While the implementation of no periodic boundary condition tests is good, consider adding documentation to explain:

  1. The purpose of these test cases
  2. The expected behavior when box is None
  3. The relationship with the PR's nopbc fixes

Example docstring addition:

def test_deep_eval(self):
+    """Test deep_eval with both periodic and non-periodic boundary conditions.
+    
+    This test verifies:
+    1. Regular evaluation with periodic boundary conditions
+    2. Evaluation without periodic boundaries (box=None)
+    3. Atomic-level outputs for both cases
+    """

Also applies to: 190-206

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 8bef185 and fc7b6b7.

📒 Files selected for processing (3)
  • source/tests/consistent/io/test_io.py (4 hunks)
  • source/tests/jax/jax2tf/test_nlist.py (1 hunks)
  • source/tests/jax/jax2tf/test_region.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • source/tests/jax/jax2tf/test_nlist.py
  • source/tests/jax/jax2tf/test_region.py
🔇 Additional comments (2)
source/tests/consistent/io/test_io.py (2)

3-3: LGTM: Good memory management practice.

The addition of garbage collection after object deletion is a good practice to prevent memory leaks and OOM issues in CI environments.

Also applies to: 112-114, 207-208


216-223: Verify tolerance values for nopbc comparisons.

The assertions use the same tolerance values (rtol=1e-12, atol=1e-12) as the periodic boundary tests. Please verify if these tolerances are appropriate for nopbc cases, as different numerical behaviors might occur without periodic boundaries.

✅ Verification successful

Based on the search results, I can now generate the final response:

Tolerance values for nopbc comparisons are appropriate

The tolerance values (rtol=1e-12, atol=1e-12) are appropriate for non-periodic boundary condition (nopbc) cases because:

  1. These values are consistent with other high-precision tests in the codebase, particularly:
    • The model initialization tests use the same tolerances (1e-12) for both periodic and non-periodic cases
    • The consistent IO tests use these tolerances uniformly across different backends
  2. Non-periodic calculations are typically simpler than periodic ones (no wrapping/image calculations), so using the same strict tolerances is reasonable
  3. The test already handles special cases with the np.isnan(vv2).all() check for unsupported features
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if there are any existing tests that use different tolerances for nopbc cases

# Search for tolerance values in test files
echo "Searching for existing tolerance patterns in tests..."
rg -A 2 'assert_allclose.*nopbc' source/tests/

Length of output: 165


Script:

#!/bin/bash
# Let's search for all assert_allclose usage in test files to understand tolerance patterns
echo "Searching for assert_allclose patterns in tests..."
rg -B 2 -A 2 'assert_allclose' source/tests/consistent/io/test_io.py

# Also search for nopbc related tests
echo -e "\nSearching for nopbc related tests..."
rg -B 2 -A 2 'nopbc' source/tests/consistent/io/test_io.py

Length of output: 1551


Script:

#!/bin/bash
# Let's search for tolerance patterns in other test files to get a broader view
echo "Searching for tolerance patterns in all test files..."
rg -B 2 -A 2 'rtol.*atol' source/tests/

# Also search specifically for non-periodic boundary condition related tests
echo -e "\nSearching for other nopbc/non-periodic tests..."
rg -B 2 -A 2 'non.?periodic|nopbc' source/tests/

Length of output: 42671

@njzjz
Copy link
Member Author

njzjz commented Nov 8, 2024

The tests may fail due to the number of CPU threads is not limited.

@njzjz
Copy link
Member Author

njzjz commented Nov 9, 2024

The test failed due to out-of-memory. I don't know why JAX consumed so much memory... I add a fixture to call gc.collect() for each package.

@njzjz
Copy link
Member Author

njzjz commented Nov 9, 2024

The test failed due to out-of-memory. I don't know why JAX consumed so much memory... I add a fixture to call gc.collect() for each package.

Oh, it does not work.

@njzjz
Copy link
Member Author

njzjz commented Nov 9, 2024

I move the JAX test in test_io.py into a separated run...

Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (2)
source/tests/jax/jax2tf/test_nlist.py (2)

34-54: Consider making test data more dynamic and maintainable.

The current setup has several potential improvements:

  1. The cell matrix and reference neighbor list are hardcoded, making tests brittle
  2. The test could benefit from parameterization to cover more cases

Consider:

  1. Using @pytest.mark.parametrize to test different cell configurations
  2. Moving test data to external fixtures
  3. Adding helper methods to generate reference neighbor lists dynamically

Example parameterization:

@pytest.mark.parametrize("cell_type", [
    "cubic",
    "orthorhombic",
    "triclinic"
])
def test_build_notype(self, cell_type):
    cell = get_test_cell(cell_type)
    # ... rest of the test

77-98: Expand type-specific neighbor list testing.

The current test could be enhanced with:

  1. More diverse type combinations
  2. Boundary cases between different types
  3. Verification of type-specific cutoffs

Consider adding:

def test_build_type_diverse(self):
    # Test with more atom types
    diverse_types = tnp.array([-1, 0, 1, 2, 3], dtype=tnp.int32)
    nsel_diverse = [5] * 4  # neighbor selection for each type
    # ... rest of the test
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between b0a496c and 4e42105.

📒 Files selected for processing (2)
  • source/tests/jax/jax2tf/test_nlist.py (1 hunks)
  • source/tests/jax/jax2tf/test_region.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/jax/jax2tf/test_region.py
🔇 Additional comments (1)
source/tests/jax/jax2tf/test_nlist.py (1)

6-6: Consider the stability implications of using experimental NumPy.

The use of tensorflow.experimental.numpy could lead to potential stability issues as it's an experimental API. Consider:

  1. Adding error handling for experimental features
  2. Documenting any known limitations
  3. Having a fallback mechanism

source/tests/jax/jax2tf/test_nlist.py Outdated Show resolved Hide resolved
source/tests/jax/jax2tf/test_nlist.py Outdated Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (4)
source/jax2tf_tests/test_region.py (2)

29-33: Consider documenting the significance of tensor dimensions.

The cell tensor is initialized with specific dimensions [4, 5, -1, 3] but the significance of these dimensions (4 and 5) is not documented. Consider adding comments explaining what these dimensions represent or extracting them as named constants.


47-64: Consider enhancing the mathematical documentation.

While the code includes basic comments, it would be beneficial to add more detailed documentation about:

  1. The mathematical formulas used for face distance calculations
  2. Why volume divided by area gives the correct distance
  3. References to relevant geometric principles

This would make the test more maintainable and educational for future contributors.

Example documentation format:

def test_to_face_dist(self):
    """Tests face distance calculations using volume/area method.
    
    Mathematical background:
    - For a parallelepiped, the distance to a face is given by V/A
      where V is the volume and A is the area of the parallel face
    - Volume is calculated using determinant of cell matrix
    - Surface areas are calculated using cross products of edge vectors
    
    References:
    - [Add relevant geometric principles or papers]
    """
source/jax2tf_tests/test_nlist.py (2)

29-55: Add docstrings and comments to explain test data structure.

The setUp method initializes several parameters but lacks documentation explaining:

  • The significance of dimensions (nf=3, nloc=3, ns=75)
  • The structure of reference neighbor lists
  • The coordinate transformation process

Add docstrings following this structure:

def setUp(self):
    """Initialize test data for neighbor list construction.
    
    Test data structure:
    - nf (int): Number of frames
    - nloc (int): Number of local atoms
    - ns (int): Number of shifts (5x5x3 grid)
    - cell (array): Simulation cell matrix
    - icoord (array): Initial coordinates
    - ref_nlist (array): Reference neighbor lists
    """

56-76: Enhance test_build_notype with edge cases.

While the basic neighbor list construction is tested, consider adding tests for:

  • Empty neighbor lists
  • Maximum number of neighbors
  • Single atom systems
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 4e42105 and 215efc6.

📒 Files selected for processing (5)
  • .github/workflows/test_python.yml (1 hunks)
  • pyproject.toml (2 hunks)
  • source/jax2tf_tests/__init__.py (1 hunks)
  • source/jax2tf_tests/test_nlist.py (1 hunks)
  • source/jax2tf_tests/test_region.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • source/jax2tf_tests/init.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • .github/workflows/test_python.yml
  • pyproject.toml
🔇 Additional comments (5)
source/jax2tf_tests/test_region.py (2)

36-46: LGTM! Well-structured test with proper assertions.

The test effectively validates the coordinate transformation across batch dimensions while ensuring reproducibility through a fixed seed.


6-8: Consider the stability implications of using experimental numpy.

The code uses tensorflow.experimental.numpy which may have API changes in future versions. Consider documenting the minimum required TensorFlow version and tracking the module's stability status.

source/jax2tf_tests/test_nlist.py (3)

77-98: LGTM! Comprehensive type-aware neighbor list testing.

The test properly validates type-distinguished neighbor lists and verifies the results against reference data.


5-6: Consider the stability implications of using experimental numpy.

The code imports tensorflow.experimental.numpy, which is an experimental module. While it provides numpy-like functionality, its API might change in future TensorFlow releases.

Let's check if there's a more stable alternative:

Consider:

  1. Adding a comment explaining why the experimental module is necessary
  2. Adding version constraints in requirements.txt
  3. Creating an abstraction layer to isolate experimental API usage

29-55: 🛠️ Refactor suggestion

Add tests for zero ghost atoms case.

The PR objectives mention issues with ghost atoms when the count is zero, but this case isn't explicitly tested. Additionally, given the memory concerns mentioned in the PR comments, consider adding memory usage assertions.

Let's check if other test files cover this:

Add test cases for:

  1. Zero ghost atoms scenario
  2. Memory usage patterns using memory_profiler or similar tools

source/jax2tf_tests/test_nlist.py Show resolved Hide resolved
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz
Copy link
Member Author

njzjz commented Nov 9, 2024

It finally passed after I moved the tf2jax directory out of tests... I don't know why

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants