Skip to content

Commit

Permalink
fix(pt): fix seed in dpmodel fitting (deepmodeling#3916)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Introduced a new optional `seed` parameter across various fitting
modules to enhance customization and reproducibility of model fitting
processes.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 1f798a6 commit 2b1e0d6
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,8 @@ def __init__(
c_differentiable: bool = True,
type_map: Optional[List[str]] = None,
old_impl=False,
# not used
seed: Optional[Union[int, List[int]]] = None,
):
# seed, uniform_seed are not included
if tot_ener_zero:
raise NotImplementedError("tot_ener_zero is not implemented")
if spin is not None:
Expand Down Expand Up @@ -143,6 +141,7 @@ def __init__(
mixed_types=mixed_types,
exclude_types=exclude_types,
type_map=type_map,
seed=seed,
)
self.old_impl = False

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(
mixed_types: bool = False,
exclude_types: List[int] = [],
type_map: Optional[List[str]] = None,
# not used
seed: Optional[Union[int, List[int]]] = None,
):
if bias_dos is not None:
Expand All @@ -69,6 +68,7 @@ def __init__(
mixed_types=mixed_types,
exclude_types=exclude_types,
type_map=type_map,
seed=seed,
)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __init__(
mixed_types: bool = False,
exclude_types: List[int] = [],
type_map: Optional[List[str]] = None,
# not used
seed: Optional[Union[int, List[int]]] = None,
):
super().__init__(
Expand All @@ -70,6 +69,7 @@ def __init__(
mixed_types=mixed_types,
exclude_types=exclude_types,
type_map=type_map,
seed=seed,
)

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Dict,
List,
Optional,
Union,
)

import numpy as np
Expand Down Expand Up @@ -134,8 +135,8 @@ def __init__(
mixed_types: bool = True,
exclude_types: List[int] = [],
type_map: Optional[List[str]] = None,
seed: Optional[Union[int, List[int]]] = None,
):
# seed, uniform_seed are not included
if tot_ener_zero:
raise NotImplementedError("tot_ener_zero is not implemented")
if spin is not None:
Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
else [x is not None for x in atom_ener],
type_map=type_map,
seed=seed,
)

def serialize(self) -> dict:
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,8 @@ def __init__(
scale: Optional[List[float]] = None,
shift_diag: bool = True,
type_map: Optional[List[str]] = None,
# not used
seed: Optional[Union[int, List[int]]] = None,
):
# seed, uniform_seed are not included
if tot_ener_zero:
raise NotImplementedError("tot_ener_zero is not implemented")
if spin is not None:
Expand Down Expand Up @@ -167,6 +165,7 @@ def __init__(
mixed_types=mixed_types,
exclude_types=exclude_types,
type_map=type_map,
seed=seed,
)
self.old_impl = False

Expand Down

0 comments on commit 2b1e0d6

Please sign in to comment.