Skip to content

Commit

Permalink
simplify: support using true error as error indicator (#1321)
Browse files Browse the repository at this point in the history
i.e. Eq. B.2 in the DP-GEN paper.

Requires DeePMD-kit version >=2.2.4.
  • Loading branch information
njzjz authored Sep 2, 2023
1 parent 2b6fecf commit 1f0505d
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 30 deletions.
32 changes: 32 additions & 0 deletions dpgen/simplify/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def general_simplify_arginfo() -> Argument:
)
doc_model_devi_e_trust_lo = "The lower bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
doc_model_devi_e_trust_hi = "The higher bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
doc_true_error_f_trust_lo = "The lower bound of forces for the selection for the true error. Requires DeePMD-kit version >=2.2.4."
doc_true_error_f_trust_hi = "The higher bound of forces for the selection for the true error. Requires DeePMD-kit version >=2.2.4."
doc_true_error_e_trust_lo = "The lower bound of energy per atom for the selection for the true error. Requires DeePMD-kit version >=2.2.4."
doc_true_error_e_trust_hi = "The higher bound of energy per atom for the selection for the true error. Requires DeePMD-kit version >=2.2.4."

return [
Argument("labeled", bool, optional=True, default=False, doc=doc_labeled),
Expand Down Expand Up @@ -66,6 +70,34 @@ def general_simplify_arginfo() -> Argument:
default=float("inf"),
doc=doc_model_devi_e_trust_hi,
),
Argument(
"true_error_f_trust_lo",
float,
optional=True,
default=float("inf"),
doc=doc_true_error_f_trust_lo,
),
Argument(
"true_error_f_trust_hi",
float,
optional=True,
default=float("inf"),
doc=doc_true_error_f_trust_hi,
),
Argument(
"true_error_e_trust_lo",
float,
optional=True,
default=float("inf"),
doc=doc_true_error_e_trust_lo,
),
Argument(
"true_error_e_trust_hi",
float,
optional=True,
default=float("inf"),
doc=doc_true_error_e_trust_hi,
),
]


Expand Down
126 changes: 96 additions & 30 deletions dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
rest_data_name = "data.rest"
accurate_data_name = "data.accurate"
detail_file_name_prefix = "details"
true_error_file_name = "true_error"
sys_name_fmt = "sys." + data_system_fmt
sys_name_pattern = "sys.[0-9]*[0-9]"

Expand Down Expand Up @@ -238,6 +239,18 @@ def run_model_devi(iter_index, jdata, mdata):
forward_files = [system_file_name]
backward_files = [detail_file_name]

f_trust_lo_err = jdata.get("true_error_f_trust_lo", float("inf"))
e_trust_lo_err = jdata.get("true_error_e_trust_lo", float("inf"))
if f_trust_lo_err < float("inf") or e_trust_lo_err < float("inf"):
command_true_error = "{dp} model-devi -m {model} -s {system} -o {detail_file} --real_error".format(
dp=mdata.get("model_devi_command", "dp"),
model=" ".join(task_model_list),
system=system_file_name,
detail_file=true_error_file_name,
)
commands.append(command_true_error)
backward_files.append(true_error_file_name)

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
Expand Down Expand Up @@ -270,6 +283,11 @@ def post_model_devi(iter_index, jdata, mdata):
f_trust_hi = jdata["model_devi_f_trust_hi"]
e_trust_lo = jdata["model_devi_e_trust_lo"]
e_trust_hi = jdata["model_devi_e_trust_hi"]
f_trust_lo_err = jdata.get("true_error_f_trust_lo", float("inf"))
f_trust_hi_err = jdata.get("true_error_f_trust_hi", float("inf"))
e_trust_lo_err = jdata.get("true_error_e_trust_lo", float("inf"))
e_trust_hi_err = jdata.get("true_error_e_trust_hi", float("inf"))
use_true_error = f_trust_lo_err < float("inf") or e_trust_lo_err < float("inf")

type_map = jdata.get("type_map", [])
sys_accurate = dpdata.MultiSystems(type_map=type_map)
Expand All @@ -282,38 +300,86 @@ def post_model_devi(iter_index, jdata, mdata):
)

detail_file_name = detail_file_name_prefix
with open(os.path.join(work_path, detail_file_name)) as f:
for line in f:
if line.startswith("# data.rest.old"):
name = (line.split()[1]).split("/")[-1]
elif line.startswith("#"):
columns = line.split()[1:]
cidx_step = columns.index("step")
cidx_max_devi_f = columns.index("max_devi_f")
try:
cidx_devi_e = columns.index("devi_e")
except ValueError:
# DeePMD-kit < 2.2.2
cidx_devi_e = None
else:
idx = int(line.split()[cidx_step])
f_devi = float(line.split()[cidx_max_devi_f])
if cidx_devi_e is not None:
e_devi = float(line.split()[cidx_devi_e])
if not use_true_error:
with open(os.path.join(work_path, detail_file_name)) as f:
for line in f:
if line.startswith("# data.rest.old"):
name = (line.split()[1]).split("/")[-1]
elif line.startswith("#"):
columns = line.split()[1:]
cidx_step = columns.index("step")
cidx_max_devi_f = columns.index("max_devi_f")
try:
cidx_devi_e = columns.index("devi_e")
except ValueError:
# DeePMD-kit < 2.2.2
cidx_devi_e = None
else:
e_devi = 0.0
subsys = sys_entire[name][idx]
if f_devi >= f_trust_hi or e_devi >= e_trust_hi:
sys_failed.append(subsys)
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
):
sys_candinate.append(subsys)
elif f_devi < f_trust_lo and e_devi < e_trust_lo:
sys_accurate.append(subsys)
idx = int(line.split()[cidx_step])
f_devi = float(line.split()[cidx_max_devi_f])
if cidx_devi_e is not None:
e_devi = float(line.split()[cidx_devi_e])
else:
e_devi = 0.0
subsys = sys_entire[name][idx]
if f_devi >= f_trust_hi or e_devi >= e_trust_hi:
sys_failed.append(subsys)
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
):
sys_candinate.append(subsys)
elif f_devi < f_trust_lo and e_devi < e_trust_lo:
sys_accurate.append(subsys)
else:
raise RuntimeError(
"reach a place that should NOT be reached..."
)
else:
with open(os.path.join(work_path, detail_file_name)) as f, open(
os.path.join(work_path, true_error_file_name)
) as f_err:
for line, line_err in zip(f, f_err):
if line.startswith("# data.rest.old"):
name = (line.split()[1]).split("/")[-1]
elif line.startswith("#"):
columns = line.split()[1:]
cidx_step = columns.index("step")
cidx_max_devi_f = columns.index("max_devi_f")
cidx_devi_e = columns.index("devi_e")
else:
raise RuntimeError("reach a place that should NOT be reached...")
idx = int(line.split()[cidx_step])
f_devi = float(line.split()[cidx_max_devi_f])
f_err = float(line_err.split()[cidx_max_devi_f])
e_devi = float(line.split()[cidx_devi_e])
e_err = float(line_err.split()[cidx_devi_e])

subsys = sys_entire[name][idx]
if (
f_devi >= f_trust_hi
or e_devi >= e_trust_hi
or f_err >= f_trust_hi_err
or e_err >= e_trust_hi_err
):
sys_failed.append(subsys)
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
or f_trust_lo_err <= f_err < f_trust_hi_err
or e_trust_lo_err <= e_err < e_trust_hi_err
):
sys_candinate.append(subsys)
elif (
f_devi < f_trust_lo
and e_devi < e_trust_lo
and f_err < f_trust_lo_err
and e_err < e_trust_lo_err
):
sys_accurate.append(subsys)
else:
raise RuntimeError(
"reach a place that should NOT be reached..."
)

counter = {
"candidate": sys_candinate.get_nframes(),
Expand Down
26 changes: 26 additions & 0 deletions tests/simplify/test_post_model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def setUp(self):
+ self.system.formula
+ "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e",
)
np.savetxt(
self.work_path / "true_error",
model_devi,
fmt=["%12d"] + ["%19.6e" for _ in range(7)],
header="data.rest.old/"
+ self.system.formula
+ "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e",
)

def tearDown(self):
shutil.rmtree("iter.000001", ignore_errors=True)
Expand Down Expand Up @@ -114,3 +122,21 @@ def test_post_model_devi_accurate(self):
{},
)
assert (self.work_path / "data.accurate" / self.system.formula).exists()

def test_post_model_devi_true_error_candidate(self):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.15,
"model_devi_e_trust_hi": 0.25,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"true_error_e_trust_lo": float("inf"),
"true_error_e_trust_hi": float("inf"),
"true_error_f_trust_lo": 0.15,
"true_error_f_trust_hi": 0.25,
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.picked" / self.system.formula).exists()
26 changes: 26 additions & 0 deletions tests/simplify/test_run_model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,29 @@ def test_one_h5(self):
},
}
dpgen.simplify.simplify.run_model_devi(0, jdata=jdata, mdata=mdata)

def test_true_error(self):
jdata = {
"type_map": ["H"],
"true_error_f_trust_lo": 0.15,
"true_error_f_trust_hi": 0.25,
}
with tempfile.TemporaryDirectory() as remote_root:
mdata = {
"model_devi_command": (
f"test -d {dpgen.simplify.simplify.rest_data_name}.old"
f"&& touch {dpgen.simplify.simplify.detail_file_name_prefix}"
f"&& touch {dpgen.simplify.simplify.true_error_file_name}"
"&& echo dp"
),
"model_devi_machine": {
"context_type": "LocalContext",
"batch_type": "shell",
"local_root": "./",
"remote_root": remote_root,
},
"model_devi_resources": {
"group_size": 1,
},
}
dpgen.simplify.simplify.run_model_devi(0, jdata=jdata, mdata=mdata)

0 comments on commit 1f0505d

Please sign in to comment.