Feat(pt): Support fitting_net input statistics. #4504

Open · wants to merge 26 commits into devel

Conversation

@Chengqian-Zhang (Collaborator) commented Dec 25, 2024

Solves issue #4281.
This PR adds fitting_net statistics that compute the mean and standard deviation of fparam/aparam, so that fparam/aparam can be normalized automatically before being concatenated to the descriptor.
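As a rough sketch of the intended behaviour (the function and tensor names here are illustrative, not the PR's exact code), the collected statistics would be used to shift and scale the raw parameters before they are concatenated to the descriptor output:

import torch

def normalize_fparam(fparam: torch.Tensor,
                     fparam_avg: torch.Tensor,
                     fparam_inv_std: torch.Tensor) -> torch.Tensor:
    # Standardize frame parameters with statistics gathered from the
    # training data: (x - mean) * (1 / std).
    return (fparam - fparam_avg) * fparam_inv_std

# Example: fparam of shape (nframes, numb_fparam)
fparam = torch.tensor([[1.0, 10.0], [3.0, 30.0]])
avg = fparam.mean(dim=0)
inv_std = 1.0 / fparam.std(dim=0, unbiased=False)
print(normalize_fparam(fparam, avg, inv_std))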

Summary by CodeRabbit

  • New Features

    • Introduced a method to compute input statistics, including mean and standard deviation for fitting parameters.
    • Fitting-net input statistics are now computed alongside the existing descriptor statistics.
    • Added a new data_stat_protect parameter to model configurations to guard the computed statistics.
    • Added unit tests to validate the energy fitting model's statistical computations.
  • Bug Fixes

    • Improved error handling for input data dimensions to ensure consistency (see the dimension-check sketch below).
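A minimal sketch of the kind of dimension check implied by that last item (names are illustrative, not the PR's exact code):

import torch

def check_fparam_dim(frame: dict, numb_fparam: int) -> torch.Tensor:
    fparam = frame["fparam"]
    # Reject frames whose fparam width does not match the fitting net's setting.
    if fparam.shape[-1] != numb_fparam:
        raise ValueError(
            f"fparam has dimension {fparam.shape[-1]}, expected {numb_fparam}"
        )
    return fparam.reshape(-1, numb_fparam)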

@Chengqian-Zhang marked this pull request as draft December 25, 2024 10:44

coderabbitai bot commented Dec 25, 2024

Warning

Rate limit exceeded

@Chengqian-Zhang has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 15 minutes and 26 seconds before requesting another review.


📥 Commits

Reviewing files that changed from the base of the PR and between e651bf4 and 9cb04eb.

📒 Files selected for processing (1)
  • source/tests/pt/test_fitting_stat.py (1 hunks)
📝 Walkthrough

The pull request modifies the compute_or_load_stat method in the DPAtomicModel class and introduces a new method compute_input_stats in the Fitting class. The changes enhance the functionality of statistical computations for input data within the deep potential model, specifically allowing the fitting network to compute input statistics alongside existing descriptor statistics. Additionally, a new test file is added to validate these functionalities. A new parameter data_stat_protect is also introduced in relevant classes and functions to enhance configuration capabilities.

Changes

File | Change Summary
deepmd/pt/model/task/fitting.py | Added a compute_input_stats method to the Fitting class for calculating input statistics from provided data.
deepmd/pt/model/atomic_model/dp_atomic_model.py | Updated the compute_or_load_stat method to invoke compute_input_stats on the fitting network.
deepmd/pt/model/atomic_model/base_atomic_model.py | Added a data_stat_protect parameter to the __init__ method of BaseAtomicModel.
deepmd/pt/model/model/__init__.py | Updated the get_standard_model function to include the data_stat_protect parameter.
source/tests/pt/test_fitting_stat.py | Added unit tests for energy fitting, including synthetic data generation and validation of the computed statistics.
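A minimal sketch of the flow summarized in the table above (the method names follow the summary; the argument handling and the descriptor call are assumptions, not verbatim PR code):

# Sketch of DPAtomicModel.compute_or_load_stat after this PR (names assumed).
def compute_or_load_stat(self, sampled_func, stat_file_path=None):
    # Existing behaviour: descriptor statistics.
    self.descriptor.compute_input_stats(sampled_func, stat_file_path)
    # Added by this PR: fitting-net input statistics (mean and inverse std
    # of fparam/aparam), guarded by the data_stat_protect value.
    self.fitting_net.compute_input_stats(
        sampled_func, protection=self.data_stat_protect
    )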

Possibly related PRs

Suggested reviewers

  • njzjz
  • wanghan-iapcm


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)

437-440: Use a ternary operator for compactness.
Ruff suggests replacing the if callable(...) block with a ternary operator. This is a minor readability enhancement.

- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)

437-440: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)


457-457: Implement aparam statistics.
The TODO suggests you plan to handle aparam similarly to fparam. Let me know if you’d like help implementing that.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between beeb3d9 and af5e589.

📒 Files selected for processing (2)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/task/fitting.py (3 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py

437-440: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

🔇 Additional comments (3)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

301-301: Consider verifying that fitting_net is defined before usage.
If fitting_net were ever None, invoking compute_input_stats would raise an exception. A quick check ensures safe execution.
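A minimal guard along the lines suggested here (the function wrapper and variable names are assumptions for illustration, not the PR's exact code):

def safe_compute_fitting_stats(model, sampled_func, protection=1e-2):
    # Fail loudly if the fitting network is missing instead of raising a
    # less obvious AttributeError later on.
    if model.fitting_net is None:
        raise RuntimeError(
            "fitting_net is not initialized; cannot compute input statistics"
        )
    model.fitting_net.compute_input_stats(sampled_func, protection=protection)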

deepmd/pt/model/task/fitting.py (2)

7-7: New import statements look good.
Thanks for adding Callable, Union, and DPPath; these additions enable clearer type hints and better modularity.

Also applies to: 43-45


416-436: Comprehensive documentation.
The docstring clearly explains the purpose and usage of compute_input_stats. This addition aligns with the PR objective to compute input statistics for fitting parameters.

deepmd/pt/model/task/fitting.py: outdated review thread (resolved)

codecov bot commented Dec 25, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.58%. Comparing base (62184e1) to head (9cb04eb).

Additional details and impacted files
@@           Coverage Diff           @@
##            devel    #4504   +/-   ##
=======================================
  Coverage   84.58%   84.58%           
=======================================
  Files         680      680           
  Lines       64510    64543   +33     
  Branches     3539     3539           
=======================================
+ Hits        54563    54596   +33     
+ Misses       8807     8806    -1     
- Partials     1140     1141    +1     

☔ View full report in Codecov by Sentry.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
deepmd/pt/model/task/fitting.py (1)

442-456: ⚠️ Potential issue

Add protection against zero standard deviation.

The division by fparam_std could lead to numerical instability or inf values when the standard deviation is zero or very small.

Apply this diff to add protection:

 fparam_std = torch.std(cat_data, axis=0)
+epsilon = 1e-12  # Small constant to prevent division by zero
+fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
 fparam_inv_std = 1.0 / fparam_std
🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)

456-457: Implement aparam statistics calculation.

The TODO comment indicates missing implementation for aparam statistics, which is needed for complete functionality as indicated by the PR objectives.

Would you like me to generate the implementation for aparam statistics calculation? It would follow a similar pattern to the fparam implementation but handle the different dimensionality of aparam.


437-440: Simplify conditional assignment using ternary operator.

The if-else block can be simplified using a ternary operator.

Apply this diff to improve code style:

-if callable(merged):
-    sampled = merged()
-else:
-    sampled = merged
+sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)

437-440: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between af5e589 and e7dfe91.

📒 Files selected for processing (2)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/task/fitting.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py

437-440: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

⏰ Context from checks skipped due to timeout of 90000ms (17)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Test Python (1, 3.9)
🔇 Additional comments (1)
deepmd/pt/model/task/fitting.py (1)

7-7: LGTM!

The new imports are correctly organized and necessary for the added functionality.

Also applies to: 43-45

deepmd/pt/model/task/fitting.py: outdated review thread (resolved)
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
deepmd/pt/model/task/fitting.py (1)

99-102: Simplify the code using a ternary operator.

The code can be more concise using a ternary operator.

-        if callable(merged):
-            sampled = merged()
-        else:
-            sampled = merged
+        sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)

99-102: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e7dfe91 and 95229c5.

📒 Files selected for processing (1)
  • deepmd/pt/model/task/fitting.py (3 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py

99-102: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

⏰ Context from checks skipped due to timeout of 90000ms (12)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (1, 3.9)
🔇 Additional comments (2)
deepmd/pt/model/task/fitting.py (2)

43-45: LGTM!

The import statement is correctly placed and follows the existing import structure.


78-98: LGTM!

The method signature and docstring are well-structured with clear parameter descriptions.

deepmd/pt/model/task/fitting.py: outdated review thread (resolved)
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
deepmd/pt/model/task/fitting.py (2)

103-116: ⚠️ Potential issue

Add protection against division by zero.

The standard deviation calculation needs protection against zero or near-zero values.

Apply this diff to handle potential division by zero:

 fparam_std = torch.std(cat_data, dim=0, unbiased=False)
+epsilon = 1e-12
+fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
 fparam_inv_std = 1.0 / fparam_std

118-140: ⚠️ Potential issue

Add protection against division by zero in aparam calculations.

Similar to fparam, the aparam standard deviation calculation needs protection.

Apply this diff to handle potential division by zero:

 aparam_std = torch.sqrt(sumv2 / sumn - (sumv / sumn) ** 2)
+epsilon = 1e-12
+aparam_std = torch.where(aparam_std < epsilon, torch.tensor(epsilon, dtype=aparam_std.dtype, device=aparam_std.device), aparam_std)
 aparam_inv_std = 1.0 / aparam_std
🧹 Nitpick comments (2)
source/tests/pt/test_fitting_stat.py (1)

71-93: Add more test cases for edge conditions.

While the current test case validates the basic functionality, consider adding tests for:

  • Empty data
  • Zero standard deviation
  • Single frame/atom scenarios

Would you like me to generate additional test cases to improve coverage?

deepmd/pt/model/task/fitting.py (1)

98-101: Simplify conditional with ternary operator.

The if-else block can be simplified.

-if callable(merged):
-    sampled = merged()
-else:
-    sampled = merged
+sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)

98-101: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 95229c5 and 9e0b2ff.

📒 Files selected for processing (2)
  • deepmd/pt/model/task/fitting.py (3 hunks)
  • source/tests/pt/test_fitting_stat.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py

98-101: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

⏰ Context from checks skipped due to timeout of 90000ms (19)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (3)
source/tests/pt/test_fitting_stat.py (2)

17-41: LGTM! Well-structured test data generation.

The _make_fake_data_pt function is well-implemented with:

  • Clear parameter handling
  • Proper data structure generation
  • Correct tensor conversion

44-55: LGTM! Robust statistical computation helpers.

The _brute_fparam_pt and _brute_aparam_pt functions provide a reliable "brute force" approach to compute statistics, serving as a good reference for validation.

Also applies to: 58-69

deepmd/pt/model/task/fitting.py (1)

78-97: LGTM! Well-documented method signature.

The method signature and docstring are clear and comprehensive.
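For reference, a reconstruction of what such a signature and docstring could look like, based on the types and defaults quoted elsewhere in this review (not a verbatim copy of the PR):

from typing import Callable, Union

# Method of the Fitting class (sketch).
def compute_input_stats(
    self,
    merged: Union[Callable[[], list[dict]], list[dict]],
    protection: float = 1e-2,
) -> None:
    """Compute the mean and inverse standard deviation of fparam/aparam.

    Parameters
    ----------
    merged
        A list of data frames, or a lazy callable returning such a list;
        each frame is a dict that may contain "fparam" and "aparam" tensors.
    protection
        Lower bound applied to the standard deviation so that its inverse
        stays finite for (near-)constant inputs.
    """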

deepmd/pt/model/task/fitting.py: outdated review thread (resolved)
@Chengqian-Zhang marked this pull request as ready for review February 17, 2025 09:19
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
deepmd/pt/model/task/fitting.py (1)

101-101: 🛠️ Refactor suggestion

Add input validation for empty data.

The method should validate the input data before processing to ensure robustness.

         sampled = merged
+        if not sampled:
+            raise ValueError("No data samples provided")
+        if self.numb_fparam > 0 and not all("fparam" in frame for frame in sampled):
+            raise ValueError("Missing 'fparam' in some data samples")
+        if self.numb_aparam > 0 and not all("aparam" in frame for frame in sampled):
+            raise ValueError("Missing 'aparam' in some data samples")
🧰 Tools
🪛 Ruff (0.8.2)

98-101: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

🧹 Nitpick comments (1)
deepmd/pt/model/task/fitting.py (1)

98-101: Simplify the code using a ternary operator.

The code can be more concise by using a ternary operator.

-        if callable(merged):
-            sampled = merged()
-        else:
-            sampled = merged
+        sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)

98-101: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9e0b2ff and 1a4836c.

📒 Files selected for processing (2)
  • deepmd/pt/model/task/fitting.py (3 hunks)
  • source/tests/pt/test_fitting_stat.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/pt/test_fitting_stat.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py

98-101: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

⏰ Context from checks skipped due to timeout of 90000ms (19)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (2)
deepmd/pt/model/task/fitting.py (2)

43-45: LGTM!

The import statement is correctly placed and follows the project's import style.


103-152: LGTM! Well-implemented statistical computations.

The implementation correctly:

  1. Handles both frame and atomic parameters
  2. Prevents division by zero using epsilon
  3. Uses efficient tensor operations
  4. Properly handles data dimensions
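A compact sketch of the fparam part of the computation described in the points above (the exact names and the use of torch.clamp are illustrative assumptions, not the PR's code):

import torch

def fparam_stats(sampled: list[dict], numb_fparam: int, protection: float = 1e-2):
    # Stack fparam from every frame of every system: (total_nframes, numb_fparam).
    cat_data = torch.cat(
        [frame["fparam"].reshape(-1, numb_fparam) for frame in sampled], dim=0
    )
    fparam_avg = cat_data.mean(dim=0)
    # Clamp the std from below so the inverse stays finite for constant inputs.
    fparam_std = cat_data.std(dim=0, unbiased=False).clamp(min=protection)
    return fparam_avg, 1.0 / fparam_std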

@wanghan-iapcm changed the title from "Feat(pt): Support fitting_net statistics." to "Feat(pt): Support fitting_net input statistics." Feb 18, 2025
deepmd/pt/model/task/fitting.py: outdated review thread (resolved)
source/tests/pt/test_fitting_stat.py: review thread (resolved)
@njzjz linked an issue Feb 18, 2025 that may be closed by this pull request
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (6)
source/tests/pt/test_fitting_stat.py (3)

18-42: Add input validation for the helper function.

Consider adding validation for input parameters to ensure they have compatible dimensions and valid values.

 def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds):
+    if not sys_natoms or not sys_nframes:
+        raise ValueError("Empty system arrays")
+    if len(avgs) != len(stds):
+        raise ValueError("Averages and standard deviations must have the same length")
+    if any(n <= 0 for n in sys_natoms) or any(f <= 0 for f in sys_nframes):
+        raise ValueError("Number of atoms and frames must be positive")
     merged_output_stat = []

45-56: Add shape validation for frame parameters.

Consider validating the shape of input data to ensure consistent dimensions across all systems.

 def _brute_fparam_pt(data, ndim):
+    if not data:
+        raise ValueError("Empty data")
     adata = [to_numpy_array(ii["fparam"]) for ii in data]
+    shapes = [a.shape[-1] for a in adata]
+    if len(set(shapes)) > 1:
+        raise ValueError(f"Inconsistent dimensions across systems: {shapes}")
+    if shapes[0] != ndim:
+        raise ValueError(f"Expected dimension {ndim}, got {shapes[0]}")

59-70: Consider refactoring to reduce code duplication.

The function shares significant code with _brute_fparam_pt. Consider extracting common functionality into a shared helper.

+def _compute_stats(data, key, ndim):
+    if not data:
+        raise ValueError("Empty data")
+    adata = [to_numpy_array(ii[key]) for ii in data]
+    shapes = [a.shape[-1] for a in adata]
+    if len(set(shapes)) > 1:
+        raise ValueError(f"Inconsistent dimensions across systems: {shapes}")
+    if shapes[0] != ndim:
+        raise ValueError(f"Expected dimension {ndim}, got {shapes[0]}")
+    all_data = []
+    for ii in adata:
+        tmp = np.reshape(ii, [-1, ndim])
+        if len(all_data) == 0:
+            all_data = np.array(tmp)
+        else:
+            all_data = np.concatenate((all_data, tmp), axis=0)
+    avg = np.average(all_data, axis=0)
+    std = np.std(all_data, axis=0)
+    return avg, std
+
 def _brute_fparam_pt(data, ndim):
-    adata = [to_numpy_array(ii["fparam"]) for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
+    return _compute_stats(data, "fparam", ndim)

 def _brute_aparam_pt(data, ndim):
-    adata = [to_numpy_array(ii["aparam"]) for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
+    return _compute_stats(data, "aparam", ndim)
deepmd/pt/model/model/__init__.py (1)

257-258: Add validation for data_stat_protect parameter.

The parameter should be positive to prevent division by zero or negative values.

-    data_stat_protect = model_params.get("data_stat_protect")
+    data_stat_protect = model_params.get("data_stat_protect", 1e-2)
+    if data_stat_protect <= 0:
+        raise ValueError("data_stat_protect must be positive")

Also applies to: 279-279

deepmd/pt/model/task/fitting.py (2)

78-151: Add input validation and optimize code.

The implementation looks good but could benefit from:

  1. Input validation for merged data
  2. Code optimization using ternary operator
  3. Removal of unused import
-from deepmd.utils.path import (
-    DPPath,
-)

 def compute_input_stats(
     self,
     merged: Union[Callable[[], list[dict]], list[dict]],
     protection: float = 1e-2,
 ) -> None:
+    if protection <= 0:
+        raise ValueError("protection must be positive")
-    if callable(merged):
-        sampled = merged()
-    else:
-        sampled = merged
+    sampled = merged() if callable(merged) else merged
+    if not sampled:
+        raise ValueError("No data samples provided")

     # stat fparam
     if self.numb_fparam > 0:
+        if not all("fparam" in frame for frame in sampled):
+            raise ValueError("Missing 'fparam' in some data samples")
         cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
🧰 Tools
🪛 Ruff (0.8.2)

98-101: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)


123-150: Add validation for aparam data.

Similar to fparam, add validation for aparam data presence and shape.

     # stat aparam
     if self.numb_aparam > 0:
+        if not all("aparam" in frame for frame in sampled):
+            raise ValueError("Missing 'aparam' in some data samples")
+        if not all(frame["aparam"].shape[-1] == self.numb_aparam for frame in sampled):
+            raise ValueError(f"Inconsistent aparam dimensions, expected {self.numb_aparam}")
         sys_sumv = []
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1a4836c and 3e5cbe4.

📒 Files selected for processing (5)
  • deepmd/pt/model/atomic_model/base_atomic_model.py (2 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/model/__init__.py (2 hunks)
  • deepmd/pt/model/task/fitting.py (3 hunks)
  • source/tests/pt/test_fitting_stat.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py

44-44: deepmd.utils.path.DPPath imported but unused

Remove unused import: deepmd.utils.path.DPPath

(F401)


98-101: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

⏰ Context from checks skipped due to timeout of 90000ms (9)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (3)
source/tests/pt/test_fitting_stat.py (1)

73-103: Add comprehensive test cases including edge cases.

The test should include:

  1. Near-zero standard deviation case (as requested in previous review)
  2. Large value ranges to test numerical stability
  3. Single frame/atom case
  4. Error cases for validation
 class TestEnerFittingStat(unittest.TestCase):
-    def test(self) -> None:
+    def test_normal_case(self) -> None:
         descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
         fitting = EnergyFittingNet(
             descrpt.get_ntypes(),
             descrpt.get_dim_out(),
             neuron=[240, 240, 240],
             resnet_dt=True,
             numb_fparam=3,
             numb_aparam=3,
         )
         avgs = [0, 10, 100]
         stds = [2, 0.4, 0.0001]
         sys_natoms = [10, 100]
         sys_nframes = [5, 2]
         all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
        frefa, frefs = _brute_fparam_pt(all_data, len(avgs))
        arefa, arefs = _brute_aparam_pt(all_data, len(avgs))
        # Reference inverse std for the asserts below (assuming protection acts
        # as a lower bound on the std, as in the earlier review suggestion).
        protection = 1e-2
        frefs_inv = 1.0 / np.clip(frefs, protection, None)
        arefs_inv = 1.0 / np.clip(arefs, protection, None)
        fitting.compute_input_stats(all_data, protection=protection)
         np.testing.assert_almost_equal(frefa, to_numpy_array(fitting.fparam_avg))
         np.testing.assert_almost_equal(
             frefs_inv, to_numpy_array(fitting.fparam_inv_std)
         )
         np.testing.assert_almost_equal(arefa, to_numpy_array(fitting.aparam_avg))
         np.testing.assert_almost_equal(
             arefs_inv, to_numpy_array(fitting.aparam_inv_std)
         )

+    def test_near_zero_std(self) -> None:
+        descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
+        fitting = EnergyFittingNet(
+            descrpt.get_ntypes(),
+            descrpt.get_dim_out(),
+            neuron=[240, 240, 240],
+            resnet_dt=True,
+            numb_fparam=1,
+            numb_aparam=1,
+        )
+        avgs = [1.0]
+        stds = [1e-10]  # Near-zero std
+        sys_natoms = [10]
+        sys_nframes = [5]
+        all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
+        fitting.compute_input_stats(all_data, protection=1e-2)
+        # Verify that protection bounds the inverse std at 1 / protection = 100.
+        self.assertLessEqual(to_numpy_array(fitting.fparam_inv_std)[0], 100.0)
+        self.assertLessEqual(to_numpy_array(fitting.aparam_inv_std)[0], 100.0)

+    def test_single_frame(self) -> None:
+        descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
+        fitting = EnergyFittingNet(
+            descrpt.get_ntypes(),
+            descrpt.get_dim_out(),
+            neuron=[240, 240, 240],
+            resnet_dt=True,
+            numb_fparam=1,
+            numb_aparam=1,
+        )
+        avgs = [1.0]
+        stds = [0.1]
+        sys_natoms = [1]
+        sys_nframes = [1]
+        all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
+        fitting.compute_input_stats(all_data, protection=1e-2)
+        # Verify stats are computed correctly for single frame
+        self.assertEqual(to_numpy_array(fitting.fparam_avg).shape, (1,))
+        self.assertEqual(to_numpy_array(fitting.aparam_avg).shape, (1,))
deepmd/pt/model/atomic_model/base_atomic_model.py (1)

82-82: LGTM!

The data_stat_protect parameter is well-integrated with a sensible default value.

Also applies to: 91-91

deepmd/pt/model/task/fitting.py (1)

103-121: LGTM! Protection against division by zero is well-implemented.

The implementation correctly handles near-zero standard deviations using the protection parameter.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)

95-98: Simplify the callable check using a ternary operator.

The code can be more concise.

-        if callable(merged):
-            sampled = merged()
-        else:
-            sampled = merged
+        sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)

95-98: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)


122-151: Add a comment explaining the aparam statistics computation approach.

The implementation is correct, but it would be helpful to add a comment explaining why the statistics are computed differently for atomic parameters compared to frame parameters.

         # stat aparam
         if self.numb_aparam > 0:
+            # Computing statistics for atomic parameters requires accumulating sums
+            # across all systems due to varying number of atoms per frame
             sys_sumv = []
             sys_sumv2 = []
             sys_sumn = []
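Standalone, the per-system accumulation being described amounts to std = sqrt(E[x^2] - E[x]^2); a hedged sketch of that idea (the names and the clamping are illustrative assumptions, not the PR's code):

import torch

def aparam_stats(sampled: list[dict], numb_aparam: int, protection: float = 1e-2):
    dtype = sampled[0]["aparam"].dtype
    sumv = torch.zeros(numb_aparam, dtype=dtype)   # sum of values
    sumv2 = torch.zeros(numb_aparam, dtype=dtype)  # sum of squared values
    sumn = 0                                       # total atom count
    for frame in sampled:
        # aparam carries one value per atom; flatten frames and atoms together.
        aparam = frame["aparam"].reshape(-1, numb_aparam)
        sumv += aparam.sum(dim=0)
        sumv2 += (aparam ** 2).sum(dim=0)
        sumn += aparam.shape[0]
    aparam_avg = sumv / sumn
    # Guard against tiny negative variance from round-off, then clamp the std.
    variance = (sumv2 / sumn - aparam_avg ** 2).clamp(min=0.0)
    aparam_std = torch.sqrt(variance).clamp(min=protection)
    return aparam_avg, 1.0 / aparam_std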
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3e5cbe4 and 4f1e009.

📒 Files selected for processing (2)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/task/fitting.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py

95-98: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

Replace if-else-block with sampled = merged() if callable(merged) else merged

(SIM108)

⏰ Context from checks skipped due to timeout of 90000ms (15)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (3)
deepmd/pt/model/task/fitting.py (3)

7-7: LGTM!

The Callable import is correctly added and properly used in type hints.


75-79: LGTM! Protection value aligns with team's decision.

The protection value of 1e-2 was chosen based on team discussion, providing a good balance for numerical stability.


100-120: LGTM! Robust implementation of fparam statistics.

The implementation correctly handles:

  • Data reshaping and statistics computation
  • Protection against numerical instability
  • Device and dtype consistency

Review thread on source/tests/pt/test_fitting_stat.py, anchored at this snippet:

tmp_data_a = []
for jj in range(ndof):
    tmp_data_f.append(
        np.random.normal(  # noqa: NPY002
Member: why ignore?

Member: It seems to me that the random seed is not set.

Collaborator Author: The ignore is there to suppress the warning reported by pre-commit.ci.

Collaborator Author: I have set the random seed now.
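For what it's worth, one way to make the fake-data generation deterministic while also avoiding the NPY002 warning is NumPy's Generator API (a sketch; the actual test keeps the legacy np.random.normal with a noqa and sets the seed separately):

import numpy as np

# Deterministic draws without the legacy global RNG (the seed value is illustrative).
rng = np.random.default_rng(seed=2024)
avg, std, nframes = 10.0, 0.4, 5
fparam_column = rng.normal(loc=avg, scale=std, size=(nframes,))
print(fparam_column)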

@Chengqian-Zhang (Collaborator Author): I do not know why the docs/readthedocs.org check fails.


Successfully merging this pull request may close these issues.

[BUG] Do not compute fparam and aparam stat in pytorch backend.
3 participants