-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Chore: remove grad from nlist linear model #4380
Chore: remove grad from nlist linear model #4380
Conversation
📝 Walkthrough📝 WalkthroughWalkthroughThe changes in this pull request involve modifications to the Changes
Possibly related PRs
Suggested reviewers
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (1)
🔇 Additional comments (1)deepmd/pt/model/atomic_model/linear_atomic_model.py (1)
Potential Issue with Detaching The extended_coord.detach(), Detaching Recommendation:
Consider verifying whether gradient tracking through Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
deepmd/pt/model/atomic_model/linear_atomic_model.py (3)
259-260
: LGTM! Optimized gradient trackingThe conditional gradient tracking is a good optimization that ensures gradients are only computed when needed through
do_grad_r()
ordo_grad_c()
.Consider adding a debug log to track when gradients are enabled, which could help with performance analysis:
if self.do_grad_r() or self.do_grad_c(): + logger.debug("Enabling gradients for extended_coord") extended_coord.requires_grad_(True)
Line range hint
673-674
: Consider increasing the clamp threshold for better numerical stabilityThe current clamp threshold of 1e-20 for preventing division by zero might be too small for float64 precision.
-denominator = torch.sum( +denominator = torch.clamp(torch.sum( torch.where( nlist_larger != -1, torch.exp(-pairwise_rr / self.smin_alpha), torch.zeros_like(nlist_larger), ), dim=-1, - ) # handle masked nnei. +), min=1e-16) # increased threshold for better numerical stability
Line range hint
675-686
: Optimize tensor operations for better performanceThe current implementation creates multiple intermediate tensors. Consider combining operations to reduce memory allocations.
-u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin) -coef = torch.zeros_like(u) -left_mask = sigma < self.sw_rmin -mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax) -right_mask = sigma >= self.sw_rmax -coef[left_mask] = 1 -smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1 -coef[mid_mask] = smooth[mid_mask] -coef[right_mask] = 0 +# Compute normalized distance once +u = torch.clamp((sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin), 0.0, 1.0) +# Compute smooth transition directly +coef = torch.where(sigma < self.sw_rmin, + torch.ones_like(sigma), + torch.where(sigma >= self.sw_rmax, + torch.zeros_like(sigma), + -6 * u**5 + 15 * u**4 - 10 * u**3 + 1))
for more information, see https://pre-commit.ci
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #4380 +/- ##
========================================
Coverage 84.50% 84.50%
========================================
Files 596 604 +8
Lines 56665 56942 +277
Branches 3459 3486 +27
========================================
+ Hits 47884 48120 +236
- Misses 7654 7697 +43
+ Partials 1127 1125 -2 ☔ View full report in Codecov by Sentry. 🚨 Try these New Features:
|
Summary by CodeRabbit
New Features
Bug Fixes