From 3d25536796b51ac78c387e266fdc327d7fed7b53 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 27 Oct 2023 11:23:58 -0700 Subject: [PATCH] add class BaseMaker(Maker, ABC) in new atomate2/common/jobs/base.py --- src/atomate2/common/jobs/base.py | 44 +++++++++++++++++++++++++ src/atomate2/forcefields/flows/relax.py | 11 ++++--- src/atomate2/vasp/flows/amset.py | 9 +++-- 3 files changed, 57 insertions(+), 7 deletions(-) create mode 100644 src/atomate2/common/jobs/base.py diff --git a/src/atomate2/common/jobs/base.py b/src/atomate2/common/jobs/base.py new file mode 100644 index 0000000000..97077986bd --- /dev/null +++ b/src/atomate2/common/jobs/base.py @@ -0,0 +1,44 @@ +"""BaseMaker enforces a specific 'make' method signature for all atomate2 makers.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from jobflow import Flow, Maker, Response + +if TYPE_CHECKING: + from pymatgen.core import Structure + + +class BaseMaker(Maker, ABC): + """ + Abstract base class for atomate2 Makers. + + This class is designed to enforce a consistent signature for the 'make' method. + All subclasses must implement this method with identical signature so they are + easily exchangeable in Flows. + """ + + @abstractmethod + def make( + self, + structure: Structure, + *args: Any, + **kwargs: Any, + ) -> Response | Flow: + """ + Abstract method for making a job or task. Must be implemented by subclasses. + + Parameters + ---------- + structure : Structure + The structure for the task or job. + prev_dir : str | Path | None, optional + The previous directory path, if applicable. + + Returns + ------- + Response + A jobflow.Response object containing the outcome of the task or job. + """ diff --git a/src/atomate2/forcefields/flows/relax.py b/src/atomate2/forcefields/flows/relax.py index 27ca123927..65f9dadf6e 100644 --- a/src/atomate2/forcefields/flows/relax.py +++ b/src/atomate2/forcefields/flows/relax.py @@ -5,8 +5,9 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING -from jobflow import Flow, Maker +from jobflow import Flow +from atomate2.common.jobs.base import BaseMaker from atomate2.forcefields.jobs import CHGNetRelaxMaker, M3GNetRelaxMaker from atomate2.vasp.jobs.core import RelaxMaker @@ -17,7 +18,7 @@ @dataclass -class CHGNetVaspRelaxMaker(Maker): +class CHGNetVaspRelaxMaker(BaseMaker): """ Maker to (pre)relax a structure using CHGNet and then run VASP. @@ -35,7 +36,7 @@ class CHGNetVaspRelaxMaker(Maker): chgnet_maker: CHGNetRelaxMaker = field(default_factory=CHGNetRelaxMaker) vasp_maker: BaseVaspMaker = field(default_factory=RelaxMaker) - def make(self, structure: Structure) -> Flow: + def make(self, structure: Structure, *args, **kwargs) -> Flow: """ Create a flow with a CHGNet (pre)relaxation followed by a VASP relaxation. @@ -58,7 +59,7 @@ def make(self, structure: Structure) -> Flow: @dataclass -class M3GNetVaspRelaxMaker(Maker): +class M3GNetVaspRelaxMaker(BaseMaker): """ Maker to (pre)relax a structure using M3GNet and then run VASP. @@ -76,7 +77,7 @@ class M3GNetVaspRelaxMaker(Maker): m3gnet_maker: M3GNetRelaxMaker = field(default_factory=M3GNetRelaxMaker) vasp_maker: BaseVaspMaker = field(default_factory=RelaxMaker) - def make(self, structure: Structure) -> Flow: + def make(self, structure: Structure, *args, **kwargs) -> Flow: """ Create a flow with a M3GNet (pre)relaxation followed by a VASP relaxation. diff --git a/src/atomate2/vasp/flows/amset.py b/src/atomate2/vasp/flows/amset.py index cfc4ef1966..3ef7881edb 100644 --- a/src/atomate2/vasp/flows/amset.py +++ b/src/atomate2/vasp/flows/amset.py @@ -11,6 +11,7 @@ from atomate2 import SETTINGS from atomate2.amset.jobs import AmsetMaker +from atomate2.common.jobs.base import BaseMaker from atomate2.vasp.flows.core import DoubleRelaxMaker from atomate2.vasp.flows.elastic import ElasticMaker from atomate2.vasp.jobs.amset import ( @@ -59,7 +60,7 @@ @dataclass -class DeformationPotentialMaker(Maker): +class DeformationPotentialMaker(BaseMaker): """ Maker to generate acoustic deformation potentials for amset. @@ -86,8 +87,10 @@ class DeformationPotentialMaker(Maker): def make( self, structure: Structure, + *args, prev_dir: str | Path | None = None, ibands: tuple[list[int], list[int]] = None, + **kwargs, ) -> Flow: """ Make flow to calculate acoustic deformation potentials. @@ -323,7 +326,7 @@ def make( @dataclass -class HSEVaspAmsetMaker(Maker): +class HSEVaspAmsetMaker(BaseMaker): """ Maker to calculate transport properties using AMSET with HSE06 VASP inputs. @@ -385,7 +388,9 @@ class HSEVaspAmsetMaker(Maker): def make( self, structure: Structure, + *args, prev_dir: str | Path | None = None, + **kwargs, ) -> Flow: """ Make flow to calculate electronic transport properties using AMSET and VASP.