Skip to content

Commit

Permalink
add distance conf filter (#250)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
  - Introduced a `filters` argument for configuration customization in
    existing functions.
  - Implemented new filtering classes for validating atomic configurations
    based on distance and geometric criteria, enhancing configuration
    selection options.

- **Tests**
  - Added unit tests for the new filtering classes to ensure robust
    functionality and validation of atomic configurations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: zjgemi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zjgemi and pre-commit-ci[bot] authored Aug 21, 2024
1 parent 716674b commit 44d3fd1
Show file tree
Hide file tree
Showing 8 changed files with 589 additions and 29 deletions.
42 changes: 42 additions & 0 deletions dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from dpgen2.exploration.report import (
conv_styles,
)
from dpgen2.exploration.selector import (
conf_filter_styles,
)
from dpgen2.fp import (
fp_styles,
)
Expand Down Expand Up @@ -174,6 +177,25 @@ def variant_conf():
)


def variant_filter():
    """Build the variant that selects a configuration filter by its type.

    One dict-typed ``Argument`` is created for every entry of
    ``conf_filter_styles``; its sub-arguments are taken from the ``args()``
    of the corresponding filter class.

    Returns
    -------
    Variant
        A variant keyed on ``"type"`` that offers every registered
        configuration filter style.
    """
    filter_choices = [
        Argument(
            style,
            dict,
            filter_cls.args(),
            doc="Configuration filter of type %s" % style,
        )
        for style, filter_cls in conf_filter_styles.items()
    ]
    return Variant(
        "type",
        filter_choices,
        doc="the type of the configuration filter.",
    )


def lmp_args():
doc_config = "Configuration of lmp exploration"
doc_max_numb_iter = "Maximum number of iterations per stage"
Expand All @@ -189,6 +211,7 @@ def lmp_args():
"Then each stage is defined by a list of exploration task groups. "
"Each task group is described in :ref:`the task group definition<task_group_sec>` "
)
doc_filters = "A list of configuration filters"

return [
Argument(
Expand Down Expand Up @@ -227,6 +250,15 @@ def lmp_args():
alias=["configuration"],
),
Argument("stages", List[List[dict]], optional=False, doc=doc_stages),
Argument(
"filters",
list,
[],
[variant_filter()],
optional=True,
default=[],
doc=doc_filters,
),
]


Expand Down Expand Up @@ -272,6 +304,7 @@ def caly_args():
"Then each stage is defined by a list of exploration task groups. "
"Each task group is described in :ref:`the task group definition<task_group_sec>` "
)
doc_filters = "A list of configuration filters"

return [
Argument(
Expand Down Expand Up @@ -310,6 +343,15 @@ def caly_args():
alias=["configuration"],
),
Argument("stages", List[List[dict]], optional=False, doc=doc_stages),
Argument(
"filters",
list,
[],
[variant_filter()],
optional=True,
default=[],
doc=doc_filters,
),
]


Expand Down
20 changes: 20 additions & 0 deletions dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import os
import pickle
import re
from copy import (
deepcopy,
)
from pathlib import (
Path,
)
Expand Down Expand Up @@ -70,7 +73,9 @@
ExplorationScheduler,
)
from dpgen2.exploration.selector import (
ConfFilters,
ConfSelectorFrames,
conf_filter_styles,
)
from dpgen2.exploration.task import (
CustomizedLmpTemplateTaskGroup,
Expand Down Expand Up @@ -272,13 +277,25 @@ def make_naive_exploration_scheduler(
)


def get_conf_filters(config):
    """Build the configuration filters described by ``config``.

    Parameters
    ----------
    config : list of dict or None
        One dict per filter. Each dict carries a ``"type"`` key naming an
        entry of ``conf_filter_styles``; all remaining items are passed as
        keyword arguments to that filter class.

    Returns
    -------
    ConfFilters or None
        The collection of instantiated filters, or ``None`` when ``config``
        is ``None`` or empty (i.e. no filtering is requested).

    Raises
    ------
    KeyError
        If an entry has no ``"type"`` key, or its type is not registered in
        ``conf_filter_styles``.
    """
    # A missing filter list is treated the same as an empty one.
    if config is None or len(config) == 0:
        return None

    conf_filters = ConfFilters()
    for entry in config:
        # Work on a copy so that popping "type" leaves the caller's config intact.
        kwargs = deepcopy(entry)
        filter_type = kwargs.pop("type")
        try:
            filter_cls = conf_filter_styles[filter_type]
        except KeyError:
            # Same exception type as a plain lookup, but say what is allowed.
            raise KeyError(
                "unknown configuration filter type %r, available types: %s"
                % (filter_type, sorted(conf_filter_styles))
            ) from None
        conf_filters.add(filter_cls(**kwargs))
    return conf_filters


def make_calypso_naive_exploration_scheduler(config):
model_devi_jobs = config["explore"]["stages"]
fp_task_max = config["fp"]["task_max"]
max_numb_iter = config["explore"]["max_numb_iter"]
fatal_at_max = config["explore"]["fatal_at_max"]
convergence = config["explore"]["convergence"]
output_nopbc = config["explore"]["output_nopbc"]
conf_filters = get_conf_filters(config["explore"]["filters"])
scheduler = ExplorationScheduler()
# report
conv_style = convergence.pop("type")
Expand All @@ -289,6 +306,7 @@ def make_calypso_naive_exploration_scheduler(config):
render,
report,
fp_task_max,
conf_filters,
)

for job_ in model_devi_jobs:
Expand Down Expand Up @@ -329,6 +347,7 @@ def make_lmp_naive_exploration_scheduler(config):
fatal_at_max = config["explore"]["fatal_at_max"]
convergence = config["explore"]["convergence"]
output_nopbc = config["explore"]["output_nopbc"]
conf_filters = get_conf_filters(config["explore"]["filters"])
scheduler = ExplorationScheduler()
# report
conv_style = convergence.pop("type")
Expand All @@ -339,6 +358,7 @@ def make_lmp_naive_exploration_scheduler(config):
render,
report,
fp_task_max,
conf_filters,
)

sys_configs_lmp = []
Expand Down
8 changes: 7 additions & 1 deletion dpgen2/exploration/render/traj_render_lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def get_confs(
type_map: Optional[List[str]] = None,
conf_filters: Optional["ConfFilters"] = None,
) -> dpdata.MultiSystems:
del conf_filters # by far does not support conf filters
ntraj = len(trajs)
traj_fmt = "lammps/dump"
ms = dpdata.MultiSystems(type_map=type_map)
Expand All @@ -74,4 +73,11 @@ def get_confs(
ss.nopbc = self.nopbc
ss = ss.sub_system(id_selected[ii])
ms.append(ss)
if conf_filters is not None:
ms2 = dpdata.MultiSystems(type_map=type_map)
for s in ms:
s2 = conf_filters.check(s)
if len(s2) > 0:
ms2.append(s2)
ms = ms2
return ms
11 changes: 11 additions & 0 deletions dpgen2/exploration/selector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,14 @@
from .conf_selector_frame import (
ConfSelectorFrames,
)
from .distance_conf_filter import (
BoxLengthFilter,
BoxSkewnessConfFilter,
DistanceConfFilter,
)

# Registry of the available configuration filter classes, keyed by style name.
# The style name is the value given as a filter's "type" in the `filters`
# configuration list.
conf_filter_styles = {
"distance": DistanceConfFilter,
"box_skewness": BoxSkewnessConfFilter,
"box_length": BoxLengthFilter,
}
27 changes: 4 additions & 23 deletions dpgen2/exploration/selector/conf_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,14 @@ class ConfFilter(ABC):
@abstractmethod
def check(
self,
coords: np.ndarray,
cell: np.ndarray,
atom_types: np.ndarray,
nopbc: bool,
frame: dpdata.System,
) -> bool:
"""Check if the configuration is valid.
Parameters
----------
coords : numpy.array
The coordinates, numpy array of shape natoms x 3
cell : numpy.array
The cell tensor. numpy array of shape 3 x 3
atom_types : numpy.array
The atom types. numpy array of shape natoms
nopbc : bool
If no periodic boundary condition.
frame : dpdata.System
A dpdata.System containing a single frame
Returns
-------
Expand Down Expand Up @@ -62,16 +53,6 @@ def check(
natoms = sum(conf["atom_numbs"]) # type: ignore
selected_idx = np.arange(conf.get_nframes())
for ff in self._filters:
fsel = np.where(
[
ff.check(
conf["coords"][ii],
conf["cells"][ii],
conf["atom_types"],
conf.nopbc,
)
for ii in range(conf.get_nframes())
]
)[0]
fsel = np.where([ff.check(conf[ii]) for ii in range(conf.get_nframes())])[0]
selected_idx = np.intersect1d(selected_idx, fsel)
return conf.sub_system(selected_idx)
Loading

0 comments on commit 44d3fd1

Please sign in to comment.