Skip to content

Commit

Permalink
fix(pt ut): make separated uts deterministic (#4162)
Browse files Browse the repository at this point in the history
Fix the failing unit tests reported in #4145.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Added a `"seed"` property to multiple JSON configuration files,
enhancing control over randomness in model training and evaluation.
- Introduced a global seed parameter in various test functions to
improve reproducibility across test runs.

- **Bug Fixes**
- Ensured consistent random number generation in tests by integrating a
global seed parameter.

- **Documentation**
- Updated configuration files and test methods to reflect the addition
of the seed parameter for clarity and consistency.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored Sep 25, 2024
1 parent 0b72dae commit 508759c
Show file tree
Hide file tree
Showing 15 changed files with 62 additions and 3 deletions.
3 changes: 2 additions & 1 deletion source/tests/pt/model/models/dpa1.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
"activation_function": "tanh",
"scaling_factor": 1.0,
"normalize": true,
"temperature": 1.0
"temperature": 1.0,
"seed": 1
},
"fitting_net": {
"neuron": [
Expand Down
1 change: 1 addition & 0 deletions source/tests/pt/model/models/dpa2.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"g1_out_conv": false,
"g1_out_mlp": false
},
"seed": 1,
"add_tebd_to_repinit_out": false
},
"fitting_net": {
Expand Down
3 changes: 3 additions & 0 deletions source/tests/pt/model/test_descriptor_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_consistency(
resnet_dt=idt,
old_impl=False,
exclude_mask=em,
seed=GLOBAL_SEED,
).to(env.DEVICE)
dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
Expand Down Expand Up @@ -130,6 +131,7 @@ def test_load_stat(self):
precision=prec,
resnet_dt=idt,
old_impl=False,
seed=GLOBAL_SEED,
)
dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
Expand Down Expand Up @@ -180,6 +182,7 @@ def test_jit(
precision=prec,
resnet_dt=idt,
old_impl=False,
seed=GLOBAL_SEED,
)
dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_consistency(
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
ft1 = DPDipoleFitting.deserialize(ft0.serialize())
ft2 = DipoleFittingNet.deserialize(ft1.serialize())
Expand Down Expand Up @@ -139,6 +140,7 @@ def test_jit(
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=mixed_types,
seed=GLOBAL_SEED,
).to(env.DEVICE)
torch.jit.script(ft0)

Expand Down Expand Up @@ -180,6 +182,7 @@ def test_rot(self):
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
if nfp > 0:
ifp = torch.tensor(
Expand Down Expand Up @@ -234,6 +237,7 @@ def test_permu(self):
numb_fparam=0,
numb_aparam=0,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
res = []
for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
Expand Down Expand Up @@ -280,6 +284,7 @@ def test_trans(self):
numb_fparam=0,
numb_aparam=0,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
res = []
for xyz in [self.coord, coord_s]:
Expand Down Expand Up @@ -327,6 +332,7 @@ def setUp(self):
numb_fparam=0,
numb_aparam=0,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
self.type_mapping = ["O", "H", "B"]
self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping)
Expand Down
3 changes: 3 additions & 0 deletions source/tests/pt/model/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_consistency(
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
old_impl=False,
seed=GLOBAL_SEED,
).to(env.DEVICE)
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
Expand Down Expand Up @@ -125,6 +126,7 @@ def test_consistency(
resnet_dt=idt,
smooth_type_embedding=sm,
old_impl=True,
seed=GLOBAL_SEED,
).to(env.DEVICE)
dd0_state_dict = dd0.se_atten.state_dict()
dd3_state_dict = dd3.se_atten.state_dict()
Expand Down Expand Up @@ -210,6 +212,7 @@ def test_jit(
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
old_impl=False,
seed=GLOBAL_SEED,
)
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
PRECISION_DICT,
)

from ...seed import (
GLOBAL_SEED,
)
from .test_env_mat import (
TestCaseSingleFrameWithNlist,
)
Expand Down Expand Up @@ -152,6 +155,7 @@ def test_consistency(
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
old_impl=False,
seed=GLOBAL_SEED,
).to(env.DEVICE)

dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
Expand Down Expand Up @@ -201,6 +205,7 @@ def test_consistency(
add_tebd_to_repinit_out=False,
precision=prec,
old_impl=True,
seed=GLOBAL_SEED,
).to(env.DEVICE)
dd0_state_dict = dd0.state_dict()
dd3_state_dict = dd3.state_dict()
Expand Down Expand Up @@ -346,6 +351,7 @@ def test_jit(
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
old_impl=False,
seed=GLOBAL_SEED,
).to(env.DEVICE)

dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
Expand Down
5 changes: 4 additions & 1 deletion source/tests/pt/model/test_embedding_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
)
from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf

from ...seed import (
GLOBAL_SEED,
)
from ..test_finetune import (
energy_data_requirement,
)
Expand Down Expand Up @@ -153,7 +156,7 @@ def test_consistency(self):
sel=self.sel,
neuron=self.filter_neuron,
axis_neuron=self.axis_neuron,
seed=1,
seed=GLOBAL_SEED,
)
dp_embedding, dp_force, dp_vars = base_se_a(
descriptor=dp_d,
Expand Down
3 changes: 3 additions & 0 deletions source/tests/pt/model/test_ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_consistency(
mixed_types=mixed_types,
exclude_types=et,
neuron=nn,
seed=GLOBAL_SEED,
).to(env.DEVICE)
ft1 = DPInvarFitting.deserialize(ft0.serialize())
ft2 = InvarFitting.deserialize(ft0.serialize())
Expand Down Expand Up @@ -168,6 +169,7 @@ def test_jit(
numb_aparam=nap,
mixed_types=mixed_types,
exclude_types=et,
seed=GLOBAL_SEED,
).to(env.DEVICE)
torch.jit.script(ft0)

Expand All @@ -177,6 +179,7 @@ def test_get_set(self):
self.nt,
3,
1,
seed=GLOBAL_SEED,
)
rng = np.random.default_rng(GLOBAL_SEED)
foo = rng.normal([3, 4])
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
"temperature": 1.0,
"set_davg_zero": True,
"type_one_side": True,
"seed": 1,
},
"fitting_net": {
"neuron": [24, 24, 24],
Expand Down Expand Up @@ -155,6 +156,7 @@
"update_g2_has_attn": True,
"attn2_has_gate": True,
},
"seed": 1,
"add_tebd_to_repinit_out": False,
},
"fitting_net": {
Expand Down Expand Up @@ -207,6 +209,7 @@
"g1_out_conv": True,
"g1_out_mlp": True,
},
"seed": 1,
"add_tebd_to_repinit_out": False,
},
"fitting_net": {
Expand Down Expand Up @@ -235,6 +238,7 @@
"temperature": 1.0,
"set_davg_zero": True,
"type_one_side": True,
"seed": 1,
},
"fitting_net": {
"neuron": [24, 24, 24],
Expand Down Expand Up @@ -264,6 +268,7 @@
"scaling_factor": 1.0,
"normalize": True,
"temperature": 1.0,
"seed": 1,
},
{
"type": "dpa2",
Expand Down Expand Up @@ -296,6 +301,7 @@
"update_g2_has_attn": True,
"attn2_has_gate": True,
},
"seed": 1,
"add_tebd_to_repinit_out": False,
},
],
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_consistency(
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
seed=GLOBAL_SEED,
).to(env.DEVICE)
ft1 = DPPolarFitting.deserialize(ft0.serialize())
ft2 = PolarFittingNet.deserialize(ft0.serialize())
Expand Down Expand Up @@ -143,6 +144,7 @@ def test_jit(
numb_aparam=nap,
mixed_types=mixed_types,
fit_diag=fit_diag,
seed=GLOBAL_SEED,
).to(env.DEVICE)
torch.jit.script(ft0)

Expand Down Expand Up @@ -186,6 +188,7 @@ def test_rot(self):
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
seed=GLOBAL_SEED,
).to(env.DEVICE)
if nfp > 0:
ifp = torch.tensor(
Expand Down Expand Up @@ -248,6 +251,7 @@ def test_permu(self):
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
seed=GLOBAL_SEED,
).to(env.DEVICE)
res = []
for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
Expand Down Expand Up @@ -298,6 +302,7 @@ def test_trans(self):
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
seed=GLOBAL_SEED,
).to(env.DEVICE)
res = []
for xyz in [self.coord, coord_s]:
Expand Down Expand Up @@ -347,6 +352,7 @@ def setUp(self):
numb_fparam=0,
numb_aparam=0,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
self.type_mapping = ["O", "H", "B"]
self.model = PolarModel(self.dd0, self.ft0, self.type_mapping)
Expand Down
10 changes: 10 additions & 0 deletions source/tests/pt/model/test_property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
to_numpy_array,
)

from ...seed import (
GLOBAL_SEED,
)
from .test_env_mat import (
TestCaseSingleFrameWithNlist,
)
Expand Down Expand Up @@ -78,6 +81,7 @@ def test_consistency(
bias_atom_p=bias_atom_p,
intensive=intensive,
bias_method=bias_method,
seed=GLOBAL_SEED,
).to(env.DEVICE)

ft1 = DPProperFittingNet.deserialize(ft0.serialize())
Expand Down Expand Up @@ -146,6 +150,7 @@ def test_jit(
mixed_types=self.dd0.mixed_types(),
intensive=intensive,
bias_method=bias_method,
seed=GLOBAL_SEED,
).to(env.DEVICE)
torch.jit.script(ft0)

Expand Down Expand Up @@ -199,6 +204,7 @@ def test_trans(self):
numb_fparam=0,
numb_aparam=0,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
res = []
for xyz in [self.coord, coord_s]:
Expand Down Expand Up @@ -266,6 +272,7 @@ def test_rot(self):
mixed_types=self.dd0.mixed_types(),
intensive=intensive,
bias_method=bias_method,
seed=GLOBAL_SEED,
).to(env.DEVICE)
if nfp > 0:
ifp = torch.tensor(
Expand Down Expand Up @@ -320,6 +327,7 @@ def test_permu(self):
numb_fparam=0,
numb_aparam=0,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
res = []
for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
Expand Down Expand Up @@ -367,6 +375,7 @@ def test_trans(self):
numb_fparam=0,
numb_aparam=0,
mixed_types=self.dd0.mixed_types(),
seed=GLOBAL_SEED,
).to(env.DEVICE)
res = []
for xyz in [self.coord, coord_s]:
Expand Down Expand Up @@ -417,6 +426,7 @@ def setUp(self):
numb_aparam=0,
mixed_types=self.dd0.mixed_types(),
intensive=True,
seed=GLOBAL_SEED,
).to(env.DEVICE)
self.type_mapping = ["O", "H", "B"]
self.model = PropertyModel(self.dd0, self.ft0, self.type_mapping)
Expand Down
5 changes: 5 additions & 0 deletions source/tests/pt/model/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
PRECISION_DICT,
)

from ...seed import (
GLOBAL_SEED,
)
from .test_env_mat import (
TestCaseSingleFrameWithNlist,
)
Expand Down Expand Up @@ -64,6 +67,7 @@ def test_consistency(
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
old_impl=False,
seed=GLOBAL_SEED,
).to(env.DEVICE)
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
Expand Down Expand Up @@ -135,6 +139,7 @@ def test_jit(
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
old_impl=False,
seed=GLOBAL_SEED,
)
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
Expand Down
Loading

0 comments on commit 508759c

Please sign in to comment.