diff --git a/dpgen/simplify/arginfo.py b/dpgen/simplify/arginfo.py index 9aa8f0234..5db5b0812 100644 --- a/dpgen/simplify/arginfo.py +++ b/dpgen/simplify/arginfo.py @@ -34,6 +34,10 @@ def general_simplify_arginfo() -> Argument: ) doc_model_devi_e_trust_lo = "The lower bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2." doc_model_devi_e_trust_hi = "The higher bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2." + doc_true_error_f_trust_lo = "The lower bound of forces for the selection for the true error. Requires DeePMD-kit version >=2.2.4." + doc_true_error_f_trust_hi = "The higher bound of forces for the selection for the true error. Requires DeePMD-kit version >=2.2.4." + doc_true_error_e_trust_lo = "The lower bound of energy per atom for the selection for the true error. Requires DeePMD-kit version >=2.2.4." + doc_true_error_e_trust_hi = "The higher bound of energy per atom for the selection for the true error. Requires DeePMD-kit version >=2.2.4." return [ Argument("labeled", bool, optional=True, default=False, doc=doc_labeled), @@ -66,6 +70,34 @@ def general_simplify_arginfo() -> Argument: default=float("inf"), doc=doc_model_devi_e_trust_hi, ), + Argument( + "true_error_f_trust_lo", + float, + optional=True, + default=float("inf"), + doc=doc_true_error_f_trust_lo, + ), + Argument( + "true_error_f_trust_hi", + float, + optional=True, + default=float("inf"), + doc=doc_true_error_f_trust_hi, + ), + Argument( + "true_error_e_trust_lo", + float, + optional=True, + default=float("inf"), + doc=doc_true_error_e_trust_lo, + ), + Argument( + "true_error_e_trust_hi", + float, + optional=True, + default=float("inf"), + doc=doc_true_error_e_trust_hi, + ), ] diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index e5dc24d7c..c392adbf5 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -52,6 +52,7 @@ rest_data_name = "data.rest" accurate_data_name = "data.accurate" detail_file_name_prefix = "details" +true_error_file_name = "true_error" sys_name_fmt = "sys." + data_system_fmt sys_name_pattern = "sys.[0-9]*[0-9]" @@ -238,6 +239,18 @@ def run_model_devi(iter_index, jdata, mdata): forward_files = [system_file_name] backward_files = [detail_file_name] + f_trust_lo_err = jdata.get("true_error_f_trust_lo", float("inf")) + e_trust_lo_err = jdata.get("true_error_e_trust_lo", float("inf")) + if f_trust_lo_err < float("inf") or e_trust_lo_err < float("inf"): + command_true_error = "{dp} model-devi -m {model} -s {system} -o {detail_file} --real_error".format( + dp=mdata.get("model_devi_command", "dp"), + model=" ".join(task_model_list), + system=system_file_name, + detail_file=true_error_file_name, + ) + commands.append(command_true_error) + backward_files.append(true_error_file_name) + api_version = mdata.get("api_version", "1.0") if Version(api_version) < Version("1.0"): raise RuntimeError( @@ -270,6 +283,11 @@ def post_model_devi(iter_index, jdata, mdata): f_trust_hi = jdata["model_devi_f_trust_hi"] e_trust_lo = jdata["model_devi_e_trust_lo"] e_trust_hi = jdata["model_devi_e_trust_hi"] + f_trust_lo_err = jdata.get("true_error_f_trust_lo", float("inf")) + f_trust_hi_err = jdata.get("true_error_f_trust_hi", float("inf")) + e_trust_lo_err = jdata.get("true_error_e_trust_lo", float("inf")) + e_trust_hi_err = jdata.get("true_error_e_trust_hi", float("inf")) + use_true_error = f_trust_lo_err < float("inf") or e_trust_lo_err < float("inf") type_map = jdata.get("type_map", []) sys_accurate = dpdata.MultiSystems(type_map=type_map) @@ -282,38 +300,86 @@ def post_model_devi(iter_index, jdata, mdata): ) detail_file_name = detail_file_name_prefix - with open(os.path.join(work_path, detail_file_name)) as f: - for line in f: - if line.startswith("# data.rest.old"): - name = (line.split()[1]).split("/")[-1] - elif line.startswith("#"): - columns = line.split()[1:] - cidx_step = columns.index("step") - cidx_max_devi_f = columns.index("max_devi_f") - try: - cidx_devi_e = columns.index("devi_e") - except ValueError: - # DeePMD-kit < 2.2.2 - cidx_devi_e = None - else: - idx = int(line.split()[cidx_step]) - f_devi = float(line.split()[cidx_max_devi_f]) - if cidx_devi_e is not None: - e_devi = float(line.split()[cidx_devi_e]) + if not use_true_error: + with open(os.path.join(work_path, detail_file_name)) as f: + for line in f: + if line.startswith("# data.rest.old"): + name = (line.split()[1]).split("/")[-1] + elif line.startswith("#"): + columns = line.split()[1:] + cidx_step = columns.index("step") + cidx_max_devi_f = columns.index("max_devi_f") + try: + cidx_devi_e = columns.index("devi_e") + except ValueError: + # DeePMD-kit < 2.2.2 + cidx_devi_e = None else: - e_devi = 0.0 - subsys = sys_entire[name][idx] - if f_devi >= f_trust_hi or e_devi >= e_trust_hi: - sys_failed.append(subsys) - elif ( - f_trust_lo <= f_devi < f_trust_hi - or e_trust_lo <= e_devi < e_trust_hi - ): - sys_candinate.append(subsys) - elif f_devi < f_trust_lo and e_devi < e_trust_lo: - sys_accurate.append(subsys) + idx = int(line.split()[cidx_step]) + f_devi = float(line.split()[cidx_max_devi_f]) + if cidx_devi_e is not None: + e_devi = float(line.split()[cidx_devi_e]) + else: + e_devi = 0.0 + subsys = sys_entire[name][idx] + if f_devi >= f_trust_hi or e_devi >= e_trust_hi: + sys_failed.append(subsys) + elif ( + f_trust_lo <= f_devi < f_trust_hi + or e_trust_lo <= e_devi < e_trust_hi + ): + sys_candinate.append(subsys) + elif f_devi < f_trust_lo and e_devi < e_trust_lo: + sys_accurate.append(subsys) + else: + raise RuntimeError( + "reach a place that should NOT be reached..." + ) + else: + with open(os.path.join(work_path, detail_file_name)) as f, open( + os.path.join(work_path, true_error_file_name) + ) as f_err: + for line, line_err in zip(f, f_err): + if line.startswith("# data.rest.old"): + name = (line.split()[1]).split("/")[-1] + elif line.startswith("#"): + columns = line.split()[1:] + cidx_step = columns.index("step") + cidx_max_devi_f = columns.index("max_devi_f") + cidx_devi_e = columns.index("devi_e") else: - raise RuntimeError("reach a place that should NOT be reached...") + idx = int(line.split()[cidx_step]) + f_devi = float(line.split()[cidx_max_devi_f]) + f_err = float(line_err.split()[cidx_max_devi_f]) + e_devi = float(line.split()[cidx_devi_e]) + e_err = float(line_err.split()[cidx_devi_e]) + + subsys = sys_entire[name][idx] + if ( + f_devi >= f_trust_hi + or e_devi >= e_trust_hi + or f_err >= f_trust_hi_err + or e_err >= e_trust_hi_err + ): + sys_failed.append(subsys) + elif ( + f_trust_lo <= f_devi < f_trust_hi + or e_trust_lo <= e_devi < e_trust_hi + or f_trust_lo_err <= f_err < f_trust_hi_err + or e_trust_lo_err <= e_err < e_trust_hi_err + ): + sys_candinate.append(subsys) + elif ( + f_devi < f_trust_lo + and e_devi < e_trust_lo + and f_err < f_trust_lo_err + and e_err < e_trust_lo_err + ): + sys_accurate.append(subsys) + else: + raise RuntimeError( + "reach a place that should NOT be reached..." + ) counter = { "candidate": sys_candinate.get_nframes(), diff --git a/tests/simplify/test_post_model_devi.py b/tests/simplify/test_post_model_devi.py index 0eeac7fc2..539ad0f76 100644 --- a/tests/simplify/test_post_model_devi.py +++ b/tests/simplify/test_post_model_devi.py @@ -41,6 +41,14 @@ def setUp(self): + self.system.formula + "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e", ) + np.savetxt( + self.work_path / "true_error", + model_devi, + fmt=["%12d"] + ["%19.6e" for _ in range(7)], + header="data.rest.old/" + + self.system.formula + + "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e", + ) def tearDown(self): shutil.rmtree("iter.000001", ignore_errors=True) @@ -114,3 +122,21 @@ def test_post_model_devi_accurate(self): {}, ) assert (self.work_path / "data.accurate" / self.system.formula).exists() + + def test_post_model_devi_true_error_candidate(self): + dpgen.simplify.simplify.post_model_devi( + 1, + { + "model_devi_e_trust_lo": 0.15, + "model_devi_e_trust_hi": 0.25, + "model_devi_f_trust_lo": float("inf"), + "model_devi_f_trust_hi": float("inf"), + "true_error_e_trust_lo": float("inf"), + "true_error_e_trust_hi": float("inf"), + "true_error_f_trust_lo": 0.15, + "true_error_f_trust_hi": 0.25, + "iter_pick_number": 1, + }, + {}, + ) + assert (self.work_path / "data.picked" / self.system.formula).exists() diff --git a/tests/simplify/test_run_model_devi.py b/tests/simplify/test_run_model_devi.py index 2499a9f9c..e928afa8e 100644 --- a/tests/simplify/test_run_model_devi.py +++ b/tests/simplify/test_run_model_devi.py @@ -84,3 +84,29 @@ def test_one_h5(self): }, } dpgen.simplify.simplify.run_model_devi(0, jdata=jdata, mdata=mdata) + + def test_true_error(self): + jdata = { + "type_map": ["H"], + "true_error_f_trust_lo": 0.15, + "true_error_f_trust_hi": 0.25, + } + with tempfile.TemporaryDirectory() as remote_root: + mdata = { + "model_devi_command": ( + f"test -d {dpgen.simplify.simplify.rest_data_name}.old" + f"&& touch {dpgen.simplify.simplify.detail_file_name_prefix}" + f"&& touch {dpgen.simplify.simplify.true_error_file_name}" + "&& echo dp" + ), + "model_devi_machine": { + "context_type": "LocalContext", + "batch_type": "shell", + "local_root": "./", + "remote_root": remote_root, + }, + "model_devi_resources": { + "group_size": 1, + }, + } + dpgen.simplify.simplify.run_model_devi(0, jdata=jdata, mdata=mdata)