Skip to content

Commit

Permalink
feat(jax/array-api): DOS fitting (#4218)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced the `DOSFittingNet` class for enhanced fitting
capabilities.
- Added methods to evaluate different backends (JAX and Array API
Strict) for computing density of states.
- Enhanced testing framework to conditionally include tests based on
library availability.

- **Bug Fixes**
- Improved serialization of the `bias_atom_e` variable to ensure
consistent data representation.

- **Tests**
- Expanded the `TestDOS` class with new attributes and methods for
better backend evaluation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored Oct 16, 2024
1 parent cfb4731 commit 5050f61
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 1 deletion.
3 changes: 2 additions & 1 deletion deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from deepmd.dpmodel.common import (
DEFAULT_PRECISION,
to_numpy_array,
)
from deepmd.dpmodel.fitting.invar_fitting import (
InvarFitting,
Expand Down Expand Up @@ -89,6 +90,6 @@ def serialize(self) -> dict:
**super().serialize(),
"type": "dos",
}
dd["@variables"]["bias_atom_e"] = self.bias_atom_e
dd["@variables"]["bias_atom_e"] = to_numpy_array(self.bias_atom_e)

return dd
8 changes: 8 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Any,
)

from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.jax.common import (
flax_module,
Expand Down Expand Up @@ -37,3 +38,10 @@ class EnergyFittingNet(EnergyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


@flax_module
class DOSFittingNet(DOSFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)
7 changes: 7 additions & 0 deletions source/tests/array_api_strict/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Any,
)

from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP

from ..common import (
Expand Down Expand Up @@ -36,3 +37,9 @@ class EnergyFittingNet(EnergyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


class DOSFittingNet(DOSFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)
59 changes: 59 additions & 0 deletions source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -36,6 +38,20 @@
fitting_dos,
)

if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import DOSFittingNet as DOSFittingJAX
else:
DOSFittingJAX = object
if INSTALLED_ARRAY_API_STRICT:
import array_api_strict

from ...array_api_strict.fitting.fitting import DOSFittingNet as DOSFittingStrict
else:
DOSFittingStrict = object


@parameterized(
(True, False), # resnet_dt
Expand Down Expand Up @@ -74,9 +90,19 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_jax(self) -> bool:
return not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
return not INSTALLED_ARRAY_API_STRICT

tf_class = DOSFittingTF
dp_class = DOSFittingDP
pt_class = DOSFittingPT
jax_class = DOSFittingJAX
array_api_strict_class = DOSFittingStrict
args = fitting_dos()

def setUp(self):
Expand Down Expand Up @@ -157,6 +183,39 @@ def eval_dp(self, dp_obj: Any) -> Any:
fparam=self.fparam if numb_fparam else None,
)["dos"]

def eval_jax(self, jax_obj: Any) -> Any:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_dos,
) = self.param
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
)["dos"]
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_dos,
) = self.param
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
)["dos"]
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
Expand Down

0 comments on commit 5050f61

Please sign in to comment.