Skip to content

Commit

Permalink
Adaptive trust levels (deepmodeling#495)
Browse files Browse the repository at this point in the history
* support adaptive trust level

* update README

* fix bugs in readme

* adaptive lower trust level support percentage of total number of frames

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
2 people authored and Cloudac7 committed Dec 1, 2021
1 parent 8ae9f1e commit 56e6541
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 54 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,13 @@ The bold notation of key (such as **type_map**) means that it's a necessary key
| **model_devi_skip** | Integer | 0 | Number of structures skipped for fp in each MD
| **model_devi_f_trust_lo** | Float | 0.05 | Lower bound of forces for the selection.
| **model_devi_f_trust_hi** | Float | 0.15 | Upper bound of forces for the selection
| **model_devi_e_trust_lo** | Float | 1e10 | Lower bound of energies for the selection. Recommend to set them a high number, since forces provide more precise information. Special cases such as energy minimization may need this. |
| **model_devi_e_trust_hi** | Float | 1e10 | Upper bound of energies for the selection. |
| **model_devi_v_trust_lo** | Float | 1e10 | Lower bound of virial for the selection. Should be used with DeePMD-kit v2.x |
| **model_devi_v_trust_hi** | Float | 1e10 | Upper bound of virial for the selection. Should be used with DeePMD-kit v2.x |
| model_devi_adapt_trust_lo | Boolean | False | Adaptively determines the lower trust levels of force and virial. This option should be used together with `model_devi_numb_candi_f`, `model_devi_numb_candi_v` and optionally with `model_devi_perc_candi_f` and `model_devi_perc_candi_v` (both given as percentages, i.e. a value of 10 means 10%). Let `n_frames` be the number of frames whose force and virial model deviations do not exceed `model_devi_f_trust_hi` and `model_devi_v_trust_hi`. `dpgen` will make two sets: 1. From these frames, select the `max(model_devi_numb_candi_f, model_devi_perc_candi_f*0.01*n_frames)` frames with the largest force model deviation. 2. From these frames, select the `max(model_devi_numb_candi_v, model_devi_perc_candi_v*0.01*n_frames)` frames with the largest virial model deviation. The union of the two sets is used as the candidate dataset. |
| model_devi_numb_candi_f | Int | 10 | See `model_devi_adapt_trust_lo`.|
| model_devi_numb_candi_v | Int | 0 | See `model_devi_adapt_trust_lo`.|
| model_devi_perc_candi_f | Float | 0.0 | See `model_devi_adapt_trust_lo`.|
| model_devi_perc_candi_v | Float | 0.0 | See `model_devi_adapt_trust_lo`.|
| **model_devi_clean_traj** | Boolean | true | Deciding whether to clean traj folders in MD since they are too large. |
| **model_devi_nopbc** | Boolean | False | Assume open boundary condition in MD simulations. |
| model_devi_activation_func | List of list of string | [["tanh","tanh"],["tanh","gelu"],["gelu","tanh"],["gelu","gelu"]] | Set activation functions for models, length of the List should be the same as `numb_models`, and two elements in the list of string respectively assign activation functions to the embedding and fitting nets within each model. *Backward compatibility*: the original "List of String" format is still supported, where embedding and fitting nets of one model use the same activation function, and the length of the List should be the same as `numb_models`|
Expand Down
241 changes: 189 additions & 52 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import scipy.constants as pc
from collections import Counter
from distutils.version import LooseVersion
from typing import List
from numpy.linalg import norm
from dpgen import dlog
from dpgen import SHORT_CMD
Expand Down Expand Up @@ -1321,11 +1322,169 @@ def check_bad_box(conf_name,
raise RuntimeError('unknow key', key)
return is_bad


def _select_by_model_devi_standard(
modd_system_task: List[str],
f_trust_lo : float,
f_trust_hi : float,
v_trust_lo : float,
v_trust_hi : float,
cluster_cutoff : float,
model_devi_skip : int = 0,
detailed_report_make_fp : bool = True,
):
fp_candidate = []
if detailed_report_make_fp:
fp_rest_accurate = []
fp_rest_failed = []
cc = 0
counter = Counter()
counter['candidate'] = 0
counter['failed'] = 0
counter['accurate'] = 0
for tt in modd_system_task :
with warnings.catch_warnings():
warnings.simplefilter("ignore")
all_conf = np.loadtxt(os.path.join(tt, 'model_devi.out'))
for ii in range(all_conf.shape[0]) :
if all_conf[ii][0] < model_devi_skip :
continue
cc = int(all_conf[ii][0])
if cluster_cutoff is None:
if (all_conf[ii][1] < v_trust_hi and all_conf[ii][1] >= v_trust_lo) or \
(all_conf[ii][4] < f_trust_hi and all_conf[ii][4] >= f_trust_lo) :
fp_candidate.append([tt, cc])
counter['candidate'] += 1
elif (all_conf[ii][1] >= v_trust_hi ) or (all_conf[ii][4] >= f_trust_hi ):
if detailed_report_make_fp:
fp_rest_failed.append([tt, cc])
counter['failed'] += 1
elif (all_conf[ii][1] < v_trust_lo and all_conf[ii][4] < f_trust_lo ):
if detailed_report_make_fp:
fp_rest_accurate.append([tt, cc])
counter['accurate'] += 1
else :
raise RuntimeError('md traj %s frame %d with f devi %f does not belong to either accurate, candidiate and failed, it should not happen' % (tt, ii, all_conf[ii][4]))
else:
idx_candidate = np.where(np.logical_and(all_conf[ii][7:] < f_trust_hi, all_conf[ii][7:] >= f_trust_lo))[0]
for jj in idx_candidate:
fp_candidate.append([tt, cc, jj])
counter['candidate'] += len(idx_candidate)
idx_rest_accurate = np.where(all_conf[ii][7:] < f_trust_lo)[0]
if detailed_report_make_fp:
for jj in idx_rest_accurate:
fp_rest_accurate.append([tt, cc, jj])
counter['accurate'] += len(idx_rest_accurate)
idx_rest_failed = np.where(all_conf[ii][7:] >= f_trust_hi)[0]
if detailed_report_make_fp:
for jj in idx_rest_failed:
fp_rest_failed.append([tt, cc, jj])
counter['failed'] += len(idx_rest_failed)

return fp_rest_accurate, fp_candidate, fp_rest_failed, counter



def _select_by_model_devi_adaptive_trust_low(
modd_system_task: List[str],
f_trust_hi : float,
numb_candi_f : int,
perc_candi_f : float,
v_trust_hi : float,
numb_candi_v : int,
perc_candi_v : float,
model_devi_skip : int = 0
):
"""
modd_system_task model deviation tasks belonging to one system
f_trust_hi
numb_candi_f number of candidate due to the f model deviation
perc_candi_f percentage of candidate due to the f model deviation
v_trust_hi
numb_candi_v number of candidate due to the v model deviation
perc_candi_v percentage of candidate due to the v model deviation
model_devi_skip
returns
accur the accurate set
candi the candidate set
failed the failed set
counter counters, number of elements in the sets
f_trust_lo adapted trust level of f
v_trust_lo adapted trust level of v
"""
idx_v = 1
idx_f = 4
accur = set()
candi = set()
failed = []
coll_v = []
coll_f = []
for tt in modd_system_task:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model_devi = np.loadtxt(os.path.join(tt, 'model_devi.out'))
for ii in range(model_devi.shape[0]) :
if model_devi[ii][0] < model_devi_skip :
continue
cc = int(model_devi[ii][0])
# tt: name of task folder
# cc: time step of the frame
md_v = model_devi[ii][idx_v]
md_f = model_devi[ii][idx_f]
if md_f > f_trust_hi or md_v > v_trust_hi:
failed.append([tt, cc])
else:
coll_v.append([model_devi[ii][idx_v], tt, cc])
coll_f.append([model_devi[ii][idx_f], tt, cc])
# now accur takes all non-failed frames,
# will be substracted by candidate lat er
accur.add((tt, cc))
# sort
coll_v.sort()
coll_f.sort()
assert(len(coll_v) == len(coll_f))
# calcuate numbers
numb_candi_v = max(numb_candi_v, int(perc_candi_v * 0.01 * len(coll_v)))
numb_candi_f = max(numb_candi_f, int(perc_candi_f * 0.01 * len(coll_f)))
# adjust number of candidate
if len(coll_v) < numb_candi_v:
numb_candi_v = len(coll_v)
if len(coll_f) < numb_candi_f:
numb_candi_f = len(coll_f)
# compute trust lo
if numb_candi_v == 0:
v_trust_lo = v_trust_hi
else:
v_trust_lo = coll_v[-numb_candi_v][0]
if numb_candi_f == 0:
f_trust_lo = f_trust_hi
else:
f_trust_lo = coll_f[-numb_candi_f][0]
# add to candidate set
for ii in range(len(coll_v) - numb_candi_v, len(coll_v)):
candi.add(tuple(coll_v[ii][1:]))
for ii in range(len(coll_f) - numb_candi_f, len(coll_f)):
candi.add(tuple(coll_f[ii][1:]))
# accurate set is substracted by the candidate set
accur = accur - candi
# convert to list
candi = [list(ii) for ii in candi]
accur = [list(ii) for ii in accur]
# counters
counter = Counter()
counter['candidate'] = len(candi)
counter['failed'] = len(failed)
counter['accurate'] = len(accur)

return accur, candi, failed, counter, f_trust_lo, v_trust_lo


def _make_fp_vasp_inner (modd_path,
work_path,
model_devi_skip,
e_trust_lo,
e_trust_hi,
v_trust_lo,
v_trust_hi,
f_trust_lo,
f_trust_hi,
fp_task_min,
Expand All @@ -1352,63 +1511,41 @@ def _make_fp_vasp_inner (modd_path,

fp_tasks = []
cluster_cutoff = jdata['cluster_cutoff'] if jdata.get('use_clusters', False) else None
model_devi_adapt_trust_lo = jdata.get('model_devi_adapt_trust_lo', False)
# skip save *.out if detailed_report_make_fp is False, default is True
detailed_report_make_fp = jdata.get("detailed_report_make_fp", True)
# skip bad box criteria
skip_bad_box = jdata.get('fp_skip_bad_box')
# skip discrete structure in cluster
fp_cluster_vacuum = jdata.get('fp_cluster_vacuum',None)
for ss in system_index :
fp_candidate = []
if detailed_report_make_fp:
fp_rest_accurate = []
fp_rest_failed = []
modd_system_glob = os.path.join(modd_path, 'task.' + ss + '.*')
modd_system_task = glob.glob(modd_system_glob)
modd_system_task.sort()
cc = 0
counter = Counter()
counter['candidate'] = 0
counter['failed'] = 0
counter['accurate'] = 0
for tt in modd_system_task :
with warnings.catch_warnings():
warnings.simplefilter("ignore")
all_conf = np.loadtxt(os.path.join(tt, 'model_devi.out'))
for ii in range(all_conf.shape[0]) :
if all_conf[ii][0] < model_devi_skip :
continue
cc = int(all_conf[ii][0])
if cluster_cutoff is None:
if (all_conf[ii][1] < e_trust_hi and all_conf[ii][1] >= e_trust_lo) or \
(all_conf[ii][4] < f_trust_hi and all_conf[ii][4] >= f_trust_lo) :
fp_candidate.append([tt, cc])
counter['candidate'] += 1
elif (all_conf[ii][1] >= e_trust_hi ) or (all_conf[ii][4] >= f_trust_hi ):
if detailed_report_make_fp:
fp_rest_failed.append([tt, cc])
counter['failed'] += 1
elif (all_conf[ii][1] < e_trust_lo and all_conf[ii][4] < f_trust_lo ):
if detailed_report_make_fp:
fp_rest_accurate.append([tt, cc])
counter['accurate'] += 1
else :
raise RuntimeError('md traj %s frame %d with f devi %f does not belong to either accurate, candidiate and failed, it should not happen' % (tt, ii, all_conf[ii][4]))
else:
idx_candidate = np.where(np.logical_and(all_conf[ii][7:] < f_trust_hi, all_conf[ii][7:] >= f_trust_lo))[0]
for jj in idx_candidate:
fp_candidate.append([tt, cc, jj])
counter['candidate'] += len(idx_candidate)
idx_rest_accurate = np.where(all_conf[ii][7:] < f_trust_lo)[0]
if detailed_report_make_fp:
for jj in idx_rest_accurate:
fp_rest_accurate.append([tt, cc, jj])
counter['accurate'] += len(idx_rest_accurate)
idx_rest_failed = np.where(all_conf[ii][7:] >= f_trust_hi)[0]
if detailed_report_make_fp:
for jj in idx_rest_failed:
fp_rest_failed.append([tt, cc, jj])
counter['failed'] += len(idx_rest_failed)

# assumed e -> v
if not model_devi_adapt_trust_lo:
fp_rest_accurate, fp_candidate, fp_rest_failed, counter \
= _select_by_model_devi_standard(
modd_system_task,
f_trust_lo, f_trust_hi,
v_trust_lo, v_trust_hi,
cluster_cutoff,
model_devi_skip,
detailed_report_make_fp = detailed_report_make_fp)
else:
numb_candi_f = jdata.get('model_devi_numb_candi_f', 10)
numb_candi_v = jdata.get('model_devi_numb_candi_v', 0)
perc_candi_f = jdata.get('model_devi_perc_candi_f', 0.)
perc_candi_v = jdata.get('model_devi_perc_candi_v', 0.)
fp_rest_accurate, fp_candidate, fp_rest_failed, counter, f_trust_lo_ad, v_trust_lo_ad \
= _select_by_model_devi_adaptive_trust_low(
modd_system_task,
f_trust_hi, numb_candi_f, perc_candi_f,
v_trust_hi, numb_candi_v, perc_candi_v,
model_devi_skip = model_devi_skip)
dlog.info("system {0:s} {1:9s} : f_trust_lo {2:6.3f} v_trust_lo {3:6.3f}".format(ss, 'adapted', f_trust_lo_ad, v_trust_lo_ad))

# print a report
fp_sum = sum(counter.values())
for cc_key, cc_value in counter.items():
Expand Down Expand Up @@ -1768,8 +1905,8 @@ def _make_fp_vasp_configs(iter_index,
jdata):
fp_task_max = jdata['fp_task_max']
model_devi_skip = jdata['model_devi_skip']
e_trust_lo = 1e+10
e_trust_hi = 1e+10
v_trust_lo = jdata.get('model_devi_v_trust_lo', 1e10)
v_trust_hi = jdata.get('model_devi_v_trust_hi', 1e10)
f_trust_lo = jdata['model_devi_f_trust_lo']
f_trust_hi = jdata['model_devi_f_trust_hi']
type_map = jdata['type_map']
Expand All @@ -1787,7 +1924,7 @@ def _make_fp_vasp_configs(iter_index,
# make configs
fp_tasks = _make_fp_vasp_inner(modd_path, work_path,
model_devi_skip,
e_trust_lo, e_trust_hi,
v_trust_lo, v_trust_hi,
f_trust_lo, f_trust_hi,
task_min, fp_task_max,
[],
Expand Down

0 comments on commit 56e6541

Please sign in to comment.