Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: migrate and refactor polar and dos bias #3662

Merged
merged 28 commits into devel from chore/migrate-bias
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
774f2ce
chore: try rename to atom_
anyangml Apr 10, 2024
2aff780
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
d6a6571
fix: UTs
anyangml Apr 10, 2024
bd75e48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
14a002c
Merge branch 'devel' into chore/migrate-bias
anyangml Apr 11, 2024
f0baf2e
fix: data shape
anyangml Apr 12, 2024
5300d98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
90ede06
fix: var name
anyangml Apr 12, 2024
8f9dc5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
d8d3f16
fix: var_name
anyangml Apr 12, 2024
ed5c543
fix: loss name
anyangml Apr 15, 2024
88da7ce
fix: dp var name
anyangml Apr 15, 2024
7176a39
fix: dp var name
anyangml Apr 15, 2024
3136c10
Merge branch 'devel' into chore/migrate-bias
anyangml Apr 16, 2024
88e41e5
chore: remove bias in fitting
anyangml Apr 16, 2024
c94608a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
23c7fdf
chore: remove UTs
anyangml Apr 16, 2024
ead2a38
fix: UT import
anyangml Apr 16, 2024
ec89624
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
5698bfc
chore: move polar bias
anyangml Apr 17, 2024
3f11f7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
d7036b8
feat: add UT on out_std
anyangml Apr 17, 2024
09d775d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
cd2e78e
Merge branch 'devel' into chore/migrate-bias
anyangml Apr 17, 2024
0eacfe9
fix: UTs
anyangml Apr 18, 2024
c0c08ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2024
ae709c4
fix: UTs
anyangml Apr 18, 2024
e57dd7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ def apply_out_stat(
if self.fitting_net.shift_diag:
nframes, nloc = atype.shape
device = out_bias[self.bias_keys[0]].device
dtype = out_bias[self.bias_keys[0]].dtype
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = torch.zeros(ntypes, device=device)
temp = torch.zeros(ntypes, dtype=dtype, device=device)
for i in range(ntypes):
temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3)))
modified_bias = temp[atype]
Expand All @@ -51,7 +52,7 @@ def apply_out_stat(
modified_bias.unsqueeze(-1) * self.fitting_net.scale[atype]
)

eye = torch.eye(3, device=device)
eye = torch.eye(3, dtype=dtype, device=device)
eye = eye.repeat(nframes, nloc, 1, 1)
# (nframes, nloc, 3, 3)
modified_bias = modified_bias.unsqueeze(-1) * eye
Expand Down
3 changes: 1 addition & 2 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,9 +567,8 @@ def compute_output_stats_atomic(
nan_padding = np.empty((missing_types, bias_atom_e[kk].shape[1]))
nan_padding.fill(np.nan)
bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0)
std_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0)
std_atom_e[kk] = np.concatenate([std_atom_e[kk], nan_padding], axis=0)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
else:
# this key does not have atomic labels, skip it.
continue
bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e)
return bias_atom_e, std_atom_e
23 changes: 18 additions & 5 deletions source/tests/pt/model/test_atomic_model_atomic_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,15 @@ def cvt_ret(x):
self.merged_output_stat, stat_file_path=self.stat_file_path
)
ret1 = md0.forward_common_atomic(*args)
expected_std = np.ones((2, 2, 2)) # 2 keys, 2 atypes, 2 max dims.
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)
expected_std = np.ones(
(2, 2, 2), dtype=np.float64
) # 2 keys, 2 atypes, 2 max dims.
expected_std[0, :, :1] = np.array([0.0, 0.816496]).reshape(
2, 1
) # updating std for foo based on [[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]
np.testing.assert_almost_equal(
to_numpy_array(md0.out_std), expected_std, decimal=4
)
ret1 = cvt_ret(ret1)
# nt x odim
foo_bias = np.array([5.0, 6.0]).reshape(2, 1)
Expand All @@ -233,7 +240,9 @@ def raise_error():
ret2 = cvt_ret(ret2)
for kk in ["foo", "bar"]:
np.testing.assert_almost_equal(ret1[kk], ret2[kk])
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)
np.testing.assert_almost_equal(
to_numpy_array(md0.out_std), expected_std, decimal=4
)

# 4. test change bias
BaseAtomicModel.change_out_bias(
Expand All @@ -249,15 +258,19 @@ def raise_error():
]
ret3 = md0.forward_common_atomic(*args)
ret3 = cvt_ret(ret3)

expected_std[0, :, :1] = np.array([1.24722, 0.47140]).reshape(
2, 1
) # updating std for foo based on [[4.0, 3.0, 2.0], [1.0, 1.0, 1.0]]
expected_ret3 = {}
# new bias [2.666, 1.333]
expected_ret3["foo"] = np.array(
[[3.6667, 4.6667, 4.3333], [6.6667, 6.3333, 7.3333]]
).reshape(2, 3, 1)
for kk in ["foo"]:
np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4)
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)
np.testing.assert_almost_equal(
to_numpy_array(md0.out_std), expected_std, decimal=4
)


class TestAtomicModelStatMergeGlobalAtomic(
Expand Down
33 changes: 32 additions & 1 deletion source/tests/pt/model/test_polar_atomic_model_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,10 @@ def cvt_ret(x):
)
ret1 = md0.forward_common_atomic(*args)
ret1 = cvt_ret(ret1)
expected_std = np.ones((1, 2, 9)) # 1 keys, 2 atypes, 9 max dims.
expected_std = np.zeros(
(1, 2, 9), dtype=np.float64
) # 1 keys, 2 atypes, 9 max dims.
expected_std[:, 1, :] = np.ones(9, dtype=np.float64) * 0.8164966 # updating std
# nt x odim (dia)
diagnoal_bias = np.array(
[
Expand Down Expand Up @@ -239,6 +242,34 @@ def raise_error():
ret3 = cvt_ret(ret3)

expected_ret3 = {}
expected_std = np.array(
[
[
[
1.4142136,
1.4142136,
1.4142136,
1.2472191,
1.2472191,
1.2472191,
1.2472191,
1.2472191,
1.2472191,
],
[
0.4714045,
0.4714045,
0.4714045,
0.8164966,
0.8164966,
0.8164966,
2.6246693,
2.6246693,
2.6246693,
],
]
]
)
# new bias [[[3.0000, -, -, -, 2.6667, -, -, -, 2.3333],
# [1.6667, -, -, -, 2.0000, -, -, -, 1.3333]]]
# which yields [2.667, 1.667]
Expand Down
Loading