-
Notifications
You must be signed in to change notification settings - Fork 526
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(jax/array-api): dipole/polarizability fitting (#4278)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced `DipoleFittingNet` and `PolarFittingNet` classes for enhanced fitting functionality. - Expanded support for JAX as a backend for fitting tensors, alongside existing TensorFlow and PyTorch support. - **Bug Fixes** - Improved error handling and parameter validation in the `DipoleFitting` and `PolarFitting` classes. - **Documentation** - Updated documentation to reflect JAX as a supported backend for fitting tensors. - **Tests** - Enhanced testing framework to support evaluations with JAX and Array API Strict, including new test methods and properties. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
8 changed files
with
184 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,14 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from deepmd.jax.fitting.fitting import ( | ||
DipoleFittingNet, | ||
DOSFittingNet, | ||
EnergyFittingNet, | ||
PolarFittingNet, | ||
) | ||
|
||
__all__ = [ | ||
"EnergyFittingNet", | ||
"DOSFittingNet", | ||
"DipoleFittingNet", | ||
"PolarFittingNet", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters