Skip to content

Commit

Permalink
add tests for dpmodel
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed May 28, 2024
1 parent 88c5219 commit 27ea218
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 3 deletions.
175 changes: 175 additions & 0 deletions source/tests/common/dpmodel/test_update_sel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import random
import unittest
from unittest.mock import (
patch,
)

from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)


def update_sel(jdata):
type_map = jdata["model"].get("type_map")
train_data = None
jdata["model"], _ = BaseModel.update_sel(train_data, type_map, jdata["model"])
return jdata


class TestTrain(unittest.TestCase):
def setUp(self) -> None:
self.update_sel = UpdateSel()
self.mock_min_nbor_dist = random.random()
return super().setUp()

@patch("deepmd.dpmodel.utils.update_sel.UpdateSel.get_nbor_stat")
def test_update_one_sel(self, sel_mock):
sel_mock.return_value = self.mock_min_nbor_dist, [10, 20]

min_nbor_dist, sel = self.update_sel.update_one_sel(None, None, 6, "auto")
# self.assertEqual(descriptor['sel'], [11,22])
self.assertEqual(sel, [12, 24])
self.assertAlmostEqual(min_nbor_dist, self.mock_min_nbor_dist)
min_nbor_dist, sel = self.update_sel.update_one_sel(None, None, 6, "auto:1.5")
# self.assertEqual(descriptor['sel'], [15,30])
self.assertEqual(sel, [16, 32])
self.assertAlmostEqual(min_nbor_dist, self.mock_min_nbor_dist)

@patch("deepmd.dpmodel.utils.update_sel.UpdateSel.get_nbor_stat")
def test_update_sel_hybrid(self, sel_mock):
sel_mock.return_value = self.mock_min_nbor_dist, [10, 20]

jdata = {
"model": {
"descriptor": {
"type": "hybrid",
"list": [
{"type": "se_e2_a", "rcut": 6, "sel": "auto"},
{"type": "se_e2_a", "rcut": 6, "sel": "auto:1.5"},
],
}
},
"training": {"training_data": {}},
}
expected_out = {
"model": {
"descriptor": {
"type": "hybrid",
"list": [
{"type": "se_e2_a", "rcut": 6, "sel": [12, 24]},
{"type": "se_e2_a", "rcut": 6, "sel": [16, 32]},
],
}
},
"training": {"training_data": {}},
}
jdata = update_sel(jdata)
self.assertEqual(jdata, expected_out)

@patch("deepmd.dpmodel.utils.update_sel.UpdateSel.get_nbor_stat")
def test_update_sel(self, sel_mock):
sel_mock.return_value = self.mock_min_nbor_dist, [10, 20]

jdata = {
"model": {"descriptor": {"type": "se_e2_a", "rcut": 6, "sel": "auto"}},
"training": {"training_data": {}},
}
expected_out = {
"model": {"descriptor": {"type": "se_e2_a", "rcut": 6, "sel": [12, 24]}},
"training": {"training_data": {}},
}
jdata = update_sel(jdata)
self.assertEqual(jdata, expected_out)

@patch("deepmd.dpmodel.utils.update_sel.UpdateSel.get_nbor_stat")
def test_update_sel_atten_auto(self, sel_mock):
sel_mock.return_value = self.mock_min_nbor_dist, [25]

jdata = {
"model": {
"descriptor": {
"type": "se_atten",
"sel": "auto",
"rcut": 6,
}
},
"training": {"training_data": {}},
}
expected_out = {
"model": {
"descriptor": {
"type": "se_atten",
"sel": 28,
"rcut": 6,
}
},
"training": {"training_data": {}},
}
jdata = update_sel(jdata)
self.assertEqual(jdata, expected_out)

@patch("deepmd.dpmodel.utils.update_sel.UpdateSel.get_nbor_stat")
def test_update_sel_atten_int(self, sel_mock):
sel_mock.return_value = self.mock_min_nbor_dist, [25]

jdata = {
"model": {
"descriptor": {
"type": "se_atten",
"sel": 30,
"rcut": 6,
}
},
"training": {"training_data": {}},
}
expected_out = {
"model": {
"descriptor": {
"type": "se_atten",
"sel": 30,
"rcut": 6,
}
},
"training": {"training_data": {}},
}
jdata = update_sel(jdata)
self.assertEqual(jdata, expected_out)

@patch("deepmd.dpmodel.utils.update_sel.UpdateSel.get_nbor_stat")
def test_update_sel_atten_list(self, sel_mock):
sel_mock.return_value = self.mock_min_nbor_dist, [25]

jdata = {
"model": {
"descriptor": {
"type": "se_atten",
"sel": 30,
"rcut": 6,
}
},
"training": {"training_data": {}},
}
expected_out = {
"model": {
"descriptor": {
"type": "se_atten",
"sel": 30,
"rcut": 6,
}
},
"training": {"training_data": {}},
}
jdata = update_sel(jdata)
self.assertEqual(jdata, expected_out)

def test_wrap_up_4(self):
self.assertEqual(self.update_sel.wrap_up_4(12), 3 * 4)
self.assertEqual(self.update_sel.wrap_up_4(13), 4 * 4)
self.assertEqual(self.update_sel.wrap_up_4(14), 4 * 4)
self.assertEqual(self.update_sel.wrap_up_4(15), 4 * 4)
self.assertEqual(self.update_sel.wrap_up_4(16), 4 * 4)
self.assertEqual(self.update_sel.wrap_up_4(17), 5 * 4)
4 changes: 1 addition & 3 deletions source/tests/pt/test_update_sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
def update_sel(jdata):
type_map = jdata["model"].get("type_map")
train_data = None
jdata["model"], min_nbor_dist = BaseModel.update_sel(
train_data, type_map, jdata["model"]
)
jdata["model"], _ = BaseModel.update_sel(train_data, type_map, jdata["model"])
return jdata


Expand Down

0 comments on commit 27ea218

Please sign in to comment.