Feat(pt): Support fitting_net input statistics. #4504
base: devel
Conversation
Actionable comments posted: 1
🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)
`437-440`: Use a ternary operator for compactness.

Ruff suggests replacing the `if callable(...)` block with a ternary operator. This is a minor readability enhancement.

```diff
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
```
`457`: Implement `aparam` statistics.

The TODO suggests you plan to handle `aparam` similarly to `fparam`. Let me know if you’d like help implementing that.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pt/model/task/fitting.py (3 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py
437-440: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the `if`-`else` block (SIM108)
🔇 Additional comments (3)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
`301`: Consider verifying that `fitting_net` is defined before usage.

If `fitting_net` were ever `None`, invoking `compute_input_stats` would raise an exception. A quick check ensures safe execution.
deepmd/pt/model/task/fitting.py (2)
`7`: New import statements look good.

Thanks for adding `Callable`, `Union`, and `DPPath`; these additions enable clearer type hints and better modularity.

Also applies to: 43-45
`416-436`: Comprehensive documentation.

The docstring clearly explains the purpose and usage of `compute_input_stats`. This addition aligns with the PR objective to compute input statistics for fitting parameters.
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files

```
@@           Coverage Diff            @@
##            devel    #4504    +/-   ##
=========================================
  Coverage   84.58%   84.58%
=========================================
  Files         680      680
  Lines       64510    64543      +33
  Branches     3539     3539
=========================================
+ Hits        54563    54596      +33
+ Misses       8807     8806       -1
- Partials     1140     1141       +1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
♻️ Duplicate comments (1)
deepmd/pt/model/task/fitting.py (1)
`442-456`: ⚠️ Potential issue: Add protection against zero standard deviation.

The division by `fparam_std` could lead to numerical instability or `inf` values when the standard deviation is zero or very small. Apply this diff to add protection:

```diff
 fparam_std = torch.std(cat_data, axis=0)
+epsilon = 1e-12  # Small constant to prevent division by zero
+fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
 fparam_inv_std = 1.0 / fparam_std
```
🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)
`456-457`: Implement aparam statistics calculation.

The TODO comment indicates missing implementation for aparam statistics, which is needed for complete functionality as indicated by the PR objectives.
Would you like me to generate the implementation for aparam statistics calculation? It would follow a similar pattern to the fparam implementation but handle the different dimensionality of aparam.
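For illustration, a rough sketch of what such an aparam implementation could look like (a sketch only, not the PR's actual code; the `sampled` layout, the `numb_aparam` argument, and the `protection` guard are assumptions based on the surrounding discussion):

```python
import torch

def compute_aparam_stats(sampled: list[dict], numb_aparam: int, protection: float = 1e-2):
    """Sketch: accumulate running sums over all systems so that frames with
    different atom counts are weighted by the atoms they actually contribute."""
    sumv = torch.zeros(numb_aparam, dtype=torch.float64)
    sumv2 = torch.zeros(numb_aparam, dtype=torch.float64)
    sumn = 0
    for frame in sampled:
        # aparam is assumed to have shape [nframes, natoms, numb_aparam]
        aparam = frame["aparam"].reshape(-1, numb_aparam).to(torch.float64)
        sumv += aparam.sum(dim=0)
        sumv2 += (aparam**2).sum(dim=0)
        sumn += aparam.shape[0]
    aparam_avg = sumv / sumn
    aparam_std = torch.sqrt(sumv2 / sumn - aparam_avg**2)
    aparam_std = torch.clamp(aparam_std, min=protection)  # avoid division by ~0
    return aparam_avg, 1.0 / aparam_std
```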
`437-440`: Simplify conditional assignment using a ternary operator.

The if-else block can be simplified using a ternary operator. Apply this diff to improve code style:

```diff
-if callable(merged):
-    sampled = merged()
-else:
-    sampled = merged
+sampled = merged() if callable(merged) else merged
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pt/model/task/fitting.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/atomic_model/dp_atomic_model.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py
437-440: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the `if`-`else` block (SIM108)
⏰ Context from checks skipped due to timeout of 90000ms (17)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Analyze (python)
- GitHub Check: Test C++ (false)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test Python (1, 3.9)
🔇 Additional comments (1)
deepmd/pt/model/task/fitting.py (1)
`7`: LGTM!

The new imports are correctly organized and necessary for the added functionality.

Also applies to: 43-45
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pt/model/task/fitting.py (1)
`99-102`: Simplify the code using a ternary operator.

The code can be more concise using a ternary operator.

```diff
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- deepmd/pt/model/task/fitting.py (3 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py
99-102: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the `if`-`else` block (SIM108)
⏰ Context from checks skipped due to timeout of 90000ms (12)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
🔇 Additional comments (2)
deepmd/pt/model/task/fitting.py (2)
`43-45`: LGTM!

The import statement is correctly placed and follows the existing import structure.

`78-98`: LGTM!

The method signature and docstring are well-structured with clear parameter descriptions.
Actionable comments posted: 1
♻️ Duplicate comments (2)
deepmd/pt/model/task/fitting.py (2)
`103-116`: ⚠️ Potential issue: Add protection against division by zero.

The standard deviation calculation needs protection against zero or near-zero values. Apply this diff to handle potential division by zero:

```diff
 fparam_std = torch.std(cat_data, dim=0, unbiased=False)
+epsilon = 1e-12
+fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
 fparam_inv_std = 1.0 / fparam_std
```
`118-140`: ⚠️ Potential issue: Add protection against division by zero in aparam calculations.

Similar to fparam, the aparam standard deviation calculation needs protection. Apply this diff to handle potential division by zero:

```diff
 aparam_std = torch.sqrt(sumv2 / sumn - (sumv / sumn) ** 2)
+epsilon = 1e-12
+aparam_std = torch.where(aparam_std < epsilon, torch.tensor(epsilon, dtype=aparam_std.dtype, device=aparam_std.device), aparam_std)
 aparam_inv_std = 1.0 / aparam_std
```
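For reference, the running-sum expression in this diff is just the population-variance identity, written out here as a clarification (not part of the suggested change):

$$
\mathrm{std}(x) = \sqrt{\mathbb{E}[x^2] - \left(\mathbb{E}[x]\right)^2}
= \sqrt{\frac{\sum_i x_i^2}{n} - \left(\frac{\sum_i x_i}{n}\right)^2},
$$

so accumulating `sumv`, `sumv2`, and `sumn` per system and combining them afterwards gives the same result as computing the standard deviation over all atoms at once.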
🧹 Nitpick comments (2)
source/tests/pt/test_fitting_stat.py (1)
`71-93`: Add more test cases for edge conditions.

While the current test case validates the basic functionality, consider adding tests for:
- Empty data
- Zero standard deviation
- Single frame/atom scenarios
Would you like me to generate additional test cases to improve coverage?
deepmd/pt/model/task/fitting.py (1)
`98-101`: Simplify conditional with ternary operator.

The if-else block can be simplified.

```diff
-if callable(merged):
-    sampled = merged()
-else:
-    sampled = merged
+sampled = merged() if callable(merged) else merged
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/model/task/fitting.py (3 hunks)
- source/tests/pt/test_fitting_stat.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py
98-101: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the `if`-`else` block (SIM108)
⏰ Context from checks skipped due to timeout of 90000ms (19)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (3)
source/tests/pt/test_fitting_stat.py (2)
`17-41`: LGTM! Well-structured test data generation.

The `_make_fake_data_pt` function is well-implemented with:
- Clear parameter handling
- Proper data structure generation
- Correct tensor conversion
`44-55`: LGTM! Robust statistical computation helpers.

The `_brute_fparam_pt` and `_brute_aparam_pt` functions provide a reliable "brute force" approach to compute statistics, serving as a good reference for validation.

Also applies to: 58-69
deepmd/pt/model/task/fitting.py (1)
`78-97`: LGTM! Well-documented method signature.

The method signature and docstring are clear and comprehensive.
Actionable comments posted: 0
♻️ Duplicate comments (1)
deepmd/pt/model/task/fitting.py (1)
`101`: 🛠️ Refactor suggestion: Add input validation for empty data.

The method should validate the input data before processing to ensure robustness.

```diff
 sampled = merged
+if not sampled:
+    raise ValueError("No data samples provided")
+if self.numb_fparam > 0 and not all("fparam" in frame for frame in sampled):
+    raise ValueError("Missing 'fparam' in some data samples")
+if self.numb_aparam > 0 and not all("aparam" in frame for frame in sampled):
+    raise ValueError("Missing 'aparam' in some data samples")
```
🧹 Nitpick comments (1)
deepmd/pt/model/task/fitting.py (1)
`98-101`: Simplify the code using a ternary operator.

The code can be more concise by using a ternary operator.

```diff
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/model/task/fitting.py (3 hunks)
- source/tests/pt/test_fitting_stat.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/test_fitting_stat.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py
98-101: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the `if`-`else` block (SIM108)
⏰ Context from checks skipped due to timeout of 90000ms (19)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (2)
deepmd/pt/model/task/fitting.py (2)
`43-45`: LGTM!

The import statement is correctly placed and follows the project's import style.

`103-152`: LGTM! Well-implemented statistical computations.

The implementation correctly:
- Handles both frame and atomic parameters
- Prevents division by zero using epsilon
- Uses efficient tensor operations
- Properly handles data dimensions
Actionable comments posted: 0
🧹 Nitpick comments (6)
source/tests/pt/test_fitting_stat.py (3)
`18-42`: Add input validation for the helper function.

Consider adding validation for input parameters to ensure they have compatible dimensions and valid values.

```diff
 def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds):
+    if not sys_natoms or not sys_nframes:
+        raise ValueError("Empty system arrays")
+    if len(avgs) != len(stds):
+        raise ValueError("Averages and standard deviations must have the same length")
+    if any(n <= 0 for n in sys_natoms) or any(f <= 0 for f in sys_nframes):
+        raise ValueError("Number of atoms and frames must be positive")
     merged_output_stat = []
```
`45-56`: Add shape validation for frame parameters.

Consider validating the shape of input data to ensure consistent dimensions across all systems.

```diff
 def _brute_fparam_pt(data, ndim):
+    if not data:
+        raise ValueError("Empty data")
     adata = [to_numpy_array(ii["fparam"]) for ii in data]
+    shapes = [a.shape[-1] for a in adata]
+    if len(set(shapes)) > 1:
+        raise ValueError(f"Inconsistent dimensions across systems: {shapes}")
+    if shapes[0] != ndim:
+        raise ValueError(f"Expected dimension {ndim}, got {shapes[0]}")
```
`59-70`: Consider refactoring to reduce code duplication.

The function shares significant code with `_brute_fparam_pt`. Consider extracting common functionality into a shared helper.

```diff
+def _compute_stats(data, key, ndim):
+    if not data:
+        raise ValueError("Empty data")
+    adata = [to_numpy_array(ii[key]) for ii in data]
+    shapes = [a.shape[-1] for a in adata]
+    if len(set(shapes)) > 1:
+        raise ValueError(f"Inconsistent dimensions across systems: {shapes}")
+    if shapes[0] != ndim:
+        raise ValueError(f"Expected dimension {ndim}, got {shapes[0]}")
+    all_data = []
+    for ii in adata:
+        tmp = np.reshape(ii, [-1, ndim])
+        if len(all_data) == 0:
+            all_data = np.array(tmp)
+        else:
+            all_data = np.concatenate((all_data, tmp), axis=0)
+    avg = np.average(all_data, axis=0)
+    std = np.std(all_data, axis=0)
+    return avg, std
+
 def _brute_fparam_pt(data, ndim):
-    adata = [to_numpy_array(ii["fparam"]) for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
+    return _compute_stats(data, "fparam", ndim)

 def _brute_aparam_pt(data, ndim):
-    adata = [to_numpy_array(ii["aparam"]) for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
+    return _compute_stats(data, "aparam", ndim)
```

deepmd/pt/model/model/__init__.py (1)
`257-258`: Add validation for data_stat_protect parameter.

The parameter should be positive to prevent division by zero or negative values.

```diff
-data_stat_protect = model_params.get("data_stat_protect")
+data_stat_protect = model_params.get("data_stat_protect", 1e-2)
+if data_stat_protect <= 0:
+    raise ValueError("data_stat_protect must be positive")
```

Also applies to: 279-279
deepmd/pt/model/task/fitting.py (2)
`78-151`: Add input validation and optimize code.

The implementation looks good but could benefit from:
- Input validation for merged data
- Code optimization using ternary operator
- Removal of unused import
```diff
-from deepmd.utils.path import (
-    DPPath,
-)
 def compute_input_stats(
     self,
     merged: Union[Callable[[], list[dict]], list[dict]],
     protection: float = 1e-2,
 ) -> None:
+    if protection <= 0:
+        raise ValueError("protection must be positive")
-    if callable(merged):
-        sampled = merged()
-    else:
-        sampled = merged
+    sampled = merged() if callable(merged) else merged
+    if not sampled:
+        raise ValueError("No data samples provided")
     # stat fparam
     if self.numb_fparam > 0:
+        if not all("fparam" in frame for frame in sampled):
+            raise ValueError("Missing 'fparam' in some data samples")
         cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
```
`123-150`: Add validation for aparam data.

Similar to fparam, add validation for aparam data presence and shape.

```diff
 # stat aparam
 if self.numb_aparam > 0:
+    if not all("aparam" in frame for frame in sampled):
+        raise ValueError("Missing 'aparam' in some data samples")
+    if not all(frame["aparam"].shape[-1] == self.numb_aparam for frame in sampled):
+        raise ValueError(f"Inconsistent aparam dimensions, expected {self.numb_aparam}")
     sys_sumv = []
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- deepmd/pt/model/atomic_model/base_atomic_model.py (2 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pt/model/model/__init__.py (2 hunks)
- deepmd/pt/model/task/fitting.py (3 hunks)
- source/tests/pt/test_fitting_stat.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/atomic_model/dp_atomic_model.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py
44-44: `deepmd.utils.path.DPPath` imported but unused; remove the unused import (F401)

98-101: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the `if`-`else` block (SIM108)
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test C++ (false)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (true)
🔇 Additional comments (3)
source/tests/pt/test_fitting_stat.py (1)
`73-103`: Add comprehensive test cases including edge cases.

The test should include:
- Near-zero standard deviation case (as requested in previous review)
- Large value ranges to test numerical stability
- Single frame/atom case
- Error cases for validation
```diff
 class TestEnerFittingStat(unittest.TestCase):
-    def test(self) -> None:
+    def test_normal_case(self) -> None:
         descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
         fitting = EnergyFittingNet(
             descrpt.get_ntypes(),
             descrpt.get_dim_out(),
             neuron=[240, 240, 240],
             resnet_dt=True,
             numb_fparam=3,
             numb_aparam=3,
         )
         avgs = [0, 10, 100]
         stds = [2, 0.4, 0.0001]
         sys_natoms = [10, 100]
         sys_nframes = [5, 2]
         all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
         frefa, frefs = _brute_fparam_pt(all_data, len(avgs))
         arefa, arefs = _brute_aparam_pt(all_data, len(avgs))
         fitting.compute_input_stats(all_data, protection=1e-2)
         np.testing.assert_almost_equal(frefa, to_numpy_array(fitting.fparam_avg))
         np.testing.assert_almost_equal(
             frefs_inv, to_numpy_array(fitting.fparam_inv_std)
         )
         np.testing.assert_almost_equal(arefa, to_numpy_array(fitting.aparam_avg))
         np.testing.assert_almost_equal(
             arefs_inv, to_numpy_array(fitting.aparam_inv_std)
         )
+    def test_near_zero_std(self) -> None:
+        descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
+        fitting = EnergyFittingNet(
+            descrpt.get_ntypes(),
+            descrpt.get_dim_out(),
+            neuron=[240, 240, 240],
+            resnet_dt=True,
+            numb_fparam=1,
+            numb_aparam=1,
+        )
+        avgs = [1.0]
+        stds = [1e-10]  # Near-zero std
+        sys_natoms = [10]
+        sys_nframes = [5]
+        all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
+        fitting.compute_input_stats(all_data, protection=1e-2)
+        # Verify that protection prevents division by zero
+        self.assertLess(to_numpy_array(fitting.fparam_inv_std)[0], 100.0)
+        self.assertLess(to_numpy_array(fitting.aparam_inv_std)[0], 100.0)
+    def test_single_frame(self) -> None:
+        descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
+        fitting = EnergyFittingNet(
+            descrpt.get_ntypes(),
+            descrpt.get_dim_out(),
+            neuron=[240, 240, 240],
+            resnet_dt=True,
+            numb_fparam=1,
+            numb_aparam=1,
+        )
+        avgs = [1.0]
+        stds = [0.1]
+        sys_natoms = [1]
+        sys_nframes = [1]
+        all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
+        fitting.compute_input_stats(all_data, protection=1e-2)
+        # Verify stats are computed correctly for single frame
+        self.assertEqual(to_numpy_array(fitting.fparam_avg).shape, (1,))
+        self.assertEqual(to_numpy_array(fitting.aparam_avg).shape, (1,))
```

deepmd/pt/model/atomic_model/base_atomic_model.py (1)
`82`: LGTM!

The data_stat_protect parameter is well-integrated with a sensible default value.
Also applies to: 91-91
deepmd/pt/model/task/fitting.py (1)
`103-121`: LGTM! Protection against division by zero is well-implemented.

The implementation correctly handles near-zero standard deviations using the protection parameter.
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)
`95-98`: Simplify the callable check using a ternary operator.

The code can be more concise.

```diff
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
```
`122-151`: Add a comment explaining the aparam statistics computation approach.

The implementation is correct, but it would be helpful to add a comment explaining why the statistics are computed differently for atomic parameters compared to frame parameters.

```diff
 # stat aparam
 if self.numb_aparam > 0:
+    # Computing statistics for atomic parameters requires accumulating sums
+    # across all systems due to varying number of atoms per frame
     sys_sumv = []
     sys_sumv2 = []
     sys_sumn = []
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pt/model/task/fitting.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/atomic_model/dp_atomic_model.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/task/fitting.py
95-98: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the `if`-`else` block (SIM108)
⏰ Context from checks skipped due to timeout of 90000ms (15)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (3)
deepmd/pt/model/task/fitting.py (3)
`7`: LGTM!

The `Callable` import is correctly added and properly used in type hints.
`75-79`: LGTM! Protection value aligns with team's decision.

The protection value of 1e-2 was chosen based on team discussion, providing a good balance for numerical stability.
`100-120`: LGTM! Robust implementation of fparam statistics.

The implementation correctly handles:
- Data reshaping and statistics computation
- Protection against numerical instability
- Device and dtype consistency
```python
tmp_data_a = []
for jj in range(ndof):
    tmp_data_f.append(
        np.random.normal(  # noqa: NPY002
```
why ignore?
It seems to me that the random seed is not set.
The ignore was added to resolve the warning reported by pre-commit.ci.
I have set the random seed now.
I do not know why.
Solves issue #4281.

Support fitting_net statistics to calculate the mean value and standard deviation of `fparam`/`aparam`, so that `fparam`/`aparam` can be normalized automatically before being concatenated to the descriptor.
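To make the normalization step concrete, here is a minimal sketch of how the computed statistics could be applied before concatenation (illustrative only; the attribute names `fparam_avg`/`fparam_inv_std` follow the review discussion, and the exact forward-pass wiring inside the fitting net is an assumption, not quoted from the PR):

```python
import torch

def normalize_fparam(fparam, fparam_avg, fparam_inv_std):
    """Standardize frame parameters with the precomputed statistics."""
    return (fparam - fparam_avg) * fparam_inv_std

def concat_fparam(descriptor, fparam, fparam_avg, fparam_inv_std):
    """Broadcast the normalized fparam over atoms and append it to the
    per-atom descriptor features (shapes assumed here:
    descriptor [nframes, natoms, dim_descrpt], fparam [nframes, numb_fparam])."""
    nframes, natoms, _ = descriptor.shape
    fp = normalize_fparam(fparam, fparam_avg, fparam_inv_std)
    fp = fp.unsqueeze(1).expand(nframes, natoms, fp.shape[-1])
    return torch.cat([descriptor, fp], dim=-1)
```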