diff --git a/aiida_pseudo/cli/install.py b/aiida_pseudo/cli/install.py index e6f1300..1991c7a 100644 --- a/aiida_pseudo/cli/install.py +++ b/aiida_pseudo/cli/install.py @@ -84,7 +84,6 @@ def cmd_install_sssp(version, functional, protocol, traceback): from aiida.orm import Group, QueryBuilder from aiida_pseudo import __version__ - from aiida_pseudo.common import units from aiida_pseudo.groups.family import SsspConfiguration, SsspFamily from .utils import attempt, create_family_from_archive @@ -136,13 +135,10 @@ def cmd_install_sssp(version, functional, protocol, traceback): echo.echo_critical(msg) # Cutoffs are in Rydberg but need to be stored in the family in electronvolt. - cutoffs[element] = { - 'cutoff_wfc': values['cutoff_wfc'] * units.RY_TO_EV, - 'cutoff_rho': values['cutoff_rho'] * units.RY_TO_EV, - } + cutoffs[element] = {'cutoff_wfc': values['cutoff_wfc'], 'cutoff_rho': values['cutoff_rho']} family.description = description - family.set_cutoffs({'normal': cutoffs}) + family.set_cutoffs({'normal': cutoffs}, unit='Ry') echo.echo_success(f'installed `{label}` containing {family.count()} pseudo potentials') @@ -264,6 +260,6 @@ def cmd_install_pseudo_dojo(version, functional, relativistic, protocol, pseudo_ echo.echo_warning(msg) family.description = description - family.set_cutoffs(cutoffs, default_stringency=default_stringency) + family.set_cutoffs(cutoffs, default_stringency=default_stringency, unit='Eh') echo.echo_success(f'installed `{label}` containing {family.count()} pseudo potentials') diff --git a/aiida_pseudo/common/units.py b/aiida_pseudo/common/units.py index 9687214..2d3cb0a 100644 --- a/aiida_pseudo/common/units.py +++ b/aiida_pseudo/common/units.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Module with constants for unit conversions.""" +from pint import UnitRegistry -RY_TO_EV = 13.6056917253 # Taken from `qe_tools.constants` v2.0 -HA_TO_EV = RY_TO_EV * 2.0 +# This unit registry singleton should be used to construct new quantities with a unit and to convert them to other units +U = UnitRegistry() diff --git a/aiida_pseudo/groups/family/pseudo_dojo.py b/aiida_pseudo/groups/family/pseudo_dojo.py index 4e9b9fd..30af539 100644 --- a/aiida_pseudo/groups/family/pseudo_dojo.py +++ b/aiida_pseudo/groups/family/pseudo_dojo.py @@ -9,7 +9,6 @@ from aiida.common.exceptions import ParsingError -from aiida_pseudo.common import units from aiida_pseudo.data.pseudo import UpfData, PsmlData, Psp8Data, JthXmlData from ..mixins import RecommendedCutoffMixin from .pseudo import PseudoPotentialFamily @@ -232,9 +231,7 @@ def get_cutoffs_from_djrepo(cls, djrepo, pseudo_type): except KeyError as exception: raise ParsingError(f'stringency `{stringency}` is not defined in the djrepo `hints`') from exception - ecutwfc = ecutwfc * units.HA_TO_EV - ecutrho = ecutwfc * dual - cutoffs[stringency] = {'cutoff_wfc': ecutwfc, 'cutoff_rho': ecutrho} + cutoffs[stringency] = {'cutoff_wfc': ecutwfc, 'cutoff_rho': ecutwfc * dual} return cutoffs diff --git a/aiida_pseudo/groups/mixins/cutoffs.py b/aiida_pseudo/groups/mixins/cutoffs.py index 849107b..a9022fb 100644 --- a/aiida_pseudo/groups/mixins/cutoffs.py +++ b/aiida_pseudo/groups/mixins/cutoffs.py @@ -5,6 +5,8 @@ from aiida.common.lang import type_check from aiida.plugins import DataFactory +from aiida_pseudo.common.units import U + StructureData = DataFactory('structure') # pylint: disable=invalid-name __all__ = ('RecommendedCutoffMixin',) @@ -17,11 +19,14 @@ class RecommendedCutoffMixin: functions and the charge density. The units have to be in electronvolt. """ + DEFAULT_UNIT = 'eV' + _key_cutoffs = '_cutoffs' + _key_cutoffs_unit = '_cutoffs_unit' _key_default_stringency = '_default_stringency' - @classmethod - def validate_cutoffs(cls, elements: set, cutoffs: dict) -> None: + @staticmethod + def validate_cutoffs(elements: set, cutoffs: dict) -> None: """Validate a cutoff dictionary for a given set of elements. :param elements: set of elements for which to validate the cutoffs dictionary. @@ -57,6 +62,22 @@ def validate_cutoffs(cls, elements: set, cutoffs: dict) -> None: f'invalid cutoff values for stringency `{stringency}` and element {element}: {values}' ) + @staticmethod + def validate_cutoffs_unit(unit: str) -> None: + """Validate the cutoffs unit. + + The unit should be a name that is recognized by the ``pint`` library to be a unit of energy. + + :raises ValueError: if an invalid unit is specified. + """ + type_check(unit, str) + + if unit not in U: + raise ValueError(f'`{unit}` is not a valid unit.') + + if not U.Quantity(1, unit).check('[energy]'): + raise ValueError(f'`{unit}` is not a valid energy unit.') + def validate_stringency(self, stringency: str) -> None: """Validate a cutoff stringency. @@ -94,7 +115,7 @@ def get_cutoff_stringencies(self) -> tuple: """ return tuple(self._get_cutoffs().keys()) - def set_cutoffs(self, cutoffs: dict, default_stringency: str = None) -> None: + def set_cutoffs(self, cutoffs: dict, default_stringency: str = None, unit: str = None) -> None: """Set the recommended cutoffs for the pseudos in this family. .. note:: units of the cutoffs should be in electronvolt. @@ -107,9 +128,13 @@ def set_cutoffs(self, cutoffs: dict, default_stringency: str = None) -> None: :param default_stringency: the default stringency to be used when ``get_recommended_cutoffs`` is called. If is possible to not specify this if and only if the cutoffs only contain a single stringency set. That one will then automatically be set as default. + :param unit: string definition of a unit of energy as recognized by the ``UnitRegistry`` of the ``pint`` lib. :raises ValueError: if the cutoffs have an invalid format or the default stringency is invalid. """ + unit = unit or self.DEFAULT_UNIT + self.validate_cutoffs(set(self.elements), cutoffs) + self.validate_cutoffs_unit(unit) if default_stringency is None and len(cutoffs) != 1: raise ValueError('have to explicitly specify a default stringency when specifying multiple cutoff sets.') @@ -117,6 +142,7 @@ def set_cutoffs(self, cutoffs: dict, default_stringency: str = None) -> None: default_stringency = default_stringency or list(cutoffs.keys())[0] self.set_extra(self._key_cutoffs, cutoffs) + self.set_extra(self._key_cutoffs_unit, unit) self.set_extra(self._key_default_stringency, default_stringency) def get_cutoffs(self, stringency=None) -> Union[dict, None]: @@ -133,7 +159,7 @@ def get_cutoffs(self, stringency=None) -> Union[dict, None]: except KeyError as exception: raise ValueError(f'stringency `{stringency}` is not defined for this family.') from exception - def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=None): + def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=None, unit=None): """Return tuple of recommended wavefunction and density cutoffs for the given elements or ``StructureData``. .. note:: at least one and only one of arguments ``elements`` or ``structure`` should be passed. @@ -141,8 +167,10 @@ def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=N :param elements: single or tuple of elements. :param structure: a ``StructureData`` node. :param stringency: optional stringency if different from the default. + :param unit: string definition of a unit of energy as recognized by the ``UnitRegistry`` of the ``pint`` lib. :return: tuple of recommended wavefunction and density cutoff. :raises ValueError: if the requested stringency is not defined for this family. + :raises ValueError: if optional unit specified is invalid. """ if (elements is None and structure is None) or (elements is not None and structure is not None): raise ValueError('at least one and only one of `elements` or `structure` should be defined') @@ -150,6 +178,9 @@ def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=N type_check(elements, (tuple, str), allow_none=True) type_check(structure, StructureData, allow_none=True) + if unit is not None: + self.validate_cutoffs_unit(unit) + if structure is not None: symbols = structure.get_symbols_set() elif isinstance(elements, tuple): @@ -162,8 +193,21 @@ def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=N cutoffs = self.get_cutoffs(stringency=stringency) for element in symbols: - values = cutoffs[element] + + if unit is not None: + current_unit = self.get_cutoffs_unit() + values = {k: U.Quantity(v, current_unit).to(unit).to_tuple()[0] for k, v in cutoffs[element].items()} + else: + values = cutoffs[element] + cutoffs_wfc.append(values['cutoff_wfc']) cutoffs_rho.append(values['cutoff_rho']) return (max(cutoffs_wfc), max(cutoffs_rho)) + + def get_cutoffs_unit(self) -> str: + """Return the cutoffs unit. + + :return: the string representation of the unit of the cutoffs. + """ + return self.get_extra(self._key_cutoffs_unit, self.DEFAULT_UNIT) diff --git a/setup.json b/setup.json index d71ab66..8fc9e6d 100644 --- a/setup.json +++ b/setup.json @@ -39,6 +39,7 @@ "aiida-core~=1.4", "click~=7.0", "click-completion~=0.5", + "pint~=0.16.1", "requests~=2.20", "sqlalchemy<1.4" ], diff --git a/tests/conftest.py b/tests/conftest.py index f6663f8..afda81f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -117,7 +117,7 @@ def _get_pseudo_potential_data(element='Ar', entry_point=None) -> PseudoPotentia @pytest.fixture def get_pseudo_family(tmpdir, filepath_pseudos): - """Return a factory for a `PseudoPotentialFamily` instance.""" + """Return a factory for a ``PseudoPotentialFamily`` instance.""" def _get_pseudo_family( label='family', @@ -125,6 +125,7 @@ def _get_pseudo_family( pseudo_type=PseudoPotentialData, elements=None, cutoffs=None, + unit=None, default_stringency=None ) -> PseudoPotentialFamily: """Return an instance of `PseudoPotentialFamily` or subclass containing the given elements. @@ -132,6 +133,7 @@ def _get_pseudo_family( :param elements: optional list of elements to include instead of all the available ones :params cutoffs: optional dictionary of cutoffs to specify. Needs to respect the format expected by the method `aiida_pseudo.groups.mixins.cutoffs.RecommendedCutoffMixin.set_cutoffs`. + :param unit: string definition of a unit of energy as recognized by the ``UnitRegistry`` of the ``pint`` lib. :param default_stringency: string with the default stringency name, if not specified, the first one specified in the ``cutoffs`` argument will be used if specified. :return: the pseudo family @@ -155,7 +157,7 @@ def _get_pseudo_family( if cutoffs is not None and isinstance(family, CutoffsFamily): default_stringency = default_stringency or list(cutoffs.keys())[0] - family.set_cutoffs(cutoffs, default_stringency) + family.set_cutoffs(cutoffs, default_stringency, unit) return family diff --git a/tests/groups/mixins/test_cutoffs.py b/tests/groups/mixins/test_cutoffs.py index a7921d0..e4f69aa 100644 --- a/tests/groups/mixins/test_cutoffs.py +++ b/tests/groups/mixins/test_cutoffs.py @@ -28,7 +28,7 @@ def _get_cutoffs(family, stringencies=('default',)): @pytest.mark.usefixtures('clear_db') def test_get_cutoffs_private(get_pseudo_family, get_cutoffs): - """Test the ``RecommendedCutoffMixin._get_cutoffs`` method.""" + """Test the ``CutoffsFamily._get_cutoffs`` method.""" family = get_pseudo_family(cls=CutoffsFamily) assert family._get_cutoffs() == {} # pylint: disable=protected-access @@ -36,9 +36,22 @@ def test_get_cutoffs_private(get_pseudo_family, get_cutoffs): assert family._get_cutoffs() == get_cutoffs(family) # pylint: disable=protected-access +@pytest.mark.usefixtures('clear_db') +def test_validate_cutoffs_unit(): + """Test the ``CutoffsFamily.validate_cutoffs_unit`` method.""" + with pytest.raises(TypeError): + CutoffsFamily.validate_cutoffs_unit(10) + + with pytest.raises(ValueError, match=r'`invalid` is not a valid unit.'): + CutoffsFamily.validate_cutoffs_unit('invalid') + + with pytest.raises(ValueError, match=r'`watt` is not a valid energy unit.'): + CutoffsFamily.validate_cutoffs_unit('watt') + + @pytest.mark.usefixtures('clear_db') def test_validate_stringency(get_pseudo_family, get_cutoffs): - """Test the ``RecommendedCutoffMixin.validate_stringency`` method.""" + """Test the ``CutoffsFamily.validate_stringency`` method.""" family = get_pseudo_family(cls=CutoffsFamily) with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'): @@ -56,7 +69,7 @@ def test_validate_stringency(get_pseudo_family, get_cutoffs): @pytest.mark.usefixtures('clear_db') def test_get_default_stringency(get_pseudo_family, get_cutoffs): - """Test the ``RecommendedCutoffMixin.get_default_stringency`` method.""" + """Test the ``CutoffsFamily.get_default_stringency`` method.""" family = get_pseudo_family(cls=CutoffsFamily) with pytest.raises(ValueError, match='no default stringency has been defined.'): @@ -71,7 +84,7 @@ def test_get_default_stringency(get_pseudo_family, get_cutoffs): @pytest.mark.usefixtures('clear_db') def test_get_cutoff_stringencies(get_pseudo_family, get_cutoffs): - """Test the ``RecommendedCutoffMixin.get_cutoff_stringencies`` method.""" + """Test the ``CutoffsFamily.get_cutoff_stringencies`` method.""" family = get_pseudo_family(cls=CutoffsFamily) assert family.get_cutoff_stringencies() == () @@ -84,7 +97,7 @@ def test_get_cutoff_stringencies(get_pseudo_family, get_cutoffs): @pytest.mark.usefixtures('clear_db') def test_set_cutoffs(get_pseudo_family): - """Test the `RecommendedCutoffMixin.set_cutoffs` method.""" + """Test the ``CutoffsFamily.set_cutoffs`` method.""" elements = ['Ar', 'He'] family = get_pseudo_family(label='SSSP/1.0/PBE/efficiency', cls=CutoffsFamily, elements=elements) cutoffs = {'default': {element: {'cutoff_wfc': 1.0, 'cutoff_rho': 2.0} for element in elements}} @@ -118,9 +131,21 @@ def test_set_cutoffs(get_pseudo_family): family.set_cutoffs(cutoffs_invalid, 'default') +@pytest.mark.usefixtures('clear_db') +def test_set_cutoffs_unit_default(get_pseudo_family): + """Test the ``CutoffsFamily.set_cutoffs`` sets a default unit if not specified.""" + elements = ['Ar'] + family = get_pseudo_family(label='SSSP/1.0/PBE/efficiency', cls=CutoffsFamily, elements=elements) + values = {element: {'cutoff_wfc': 1.0, 'cutoff_rho': 2.0} for element in elements} + cutoffs = {'default': values} + + family.set_cutoffs(cutoffs) + assert family.get_cutoffs_unit() == CutoffsFamily.DEFAULT_UNIT + + @pytest.mark.usefixtures('clear_db') def test_set_cutoffs_auto_default(get_pseudo_family): - """Test the `RecommendedCutoffMixin.set_cutoffs` method when not specifying explicit default. + """Test the ``CutoffsFamily.set_cutoffs`` method when not specifying explicit default. If the cutoffs specified only contain a single set, the `default_stringency` is determined automatically. """ @@ -139,7 +164,7 @@ def test_set_cutoffs_auto_default(get_pseudo_family): @pytest.mark.usefixtures('clear_db') def test_get_cutoffs(get_pseudo_family): - """Test the `RecommendedCutoffMixin.get_cutoffs` method.""" + """Test the ``CutoffsFamily.get_cutoffs`` method.""" elements = ['Ar', 'He'] family = get_pseudo_family(label='SSSP/1.0/PBE/efficiency', cls=CutoffsFamily, elements=elements) cutoffs = {'default': {element: {'cutoff_wfc': 1.0, 'cutoff_rho': 2.0} for element in elements}} @@ -157,7 +182,7 @@ def test_get_cutoffs(get_pseudo_family): @pytest.mark.usefixtures('clear_db') def test_get_recommended_cutoffs(get_pseudo_family, generate_structure): - """Test the `RecommendedCutoffMixin.get_recommended_cutoffs` method.""" + """Test the ``CutoffsFamily.get_recommended_cutoffs`` method.""" elements = ['Ar', 'He'] cutoffs = { 'default': { @@ -171,8 +196,13 @@ def test_get_recommended_cutoffs(get_pseudo_family, generate_structure): }, } } - family = get_pseudo_family(label='SSSP/1.0/PBE/efficiency', cls=CutoffsFamily, elements=elements) - family.set_cutoffs(cutoffs, 'default') + family = get_pseudo_family( + label='SSSP/1.0/PBE/efficiency', + cls=CutoffsFamily, + elements=elements, + cutoffs=cutoffs, + default_stringency='default' + ) structure = generate_structure(elements=elements) with pytest.raises(ValueError): @@ -201,3 +231,51 @@ def test_get_recommended_cutoffs(get_pseudo_family, generate_structure): expected = cutoffs['default']['He'] structure = generate_structure(elements=['He1', 'He2']) assert family.get_recommended_cutoffs(structure=structure) == (expected['cutoff_wfc'], expected['cutoff_rho']) + + +@pytest.mark.usefixtures('clear_db') +def test_get_recommended_cutoffs_unit(get_pseudo_family): + """Test the ``CutoffsFamily.get_recommended_cutoffs`` method with the ``unit`` argument.""" + elements = ['Ar', 'He'] + unit = 'Eh' + cutoffs = { + 'default': { + 'Ar': { + 'cutoff_wfc': 1.0, + 'cutoff_rho': 2.0 + }, + 'He': { + 'cutoff_wfc': 3.0, + 'cutoff_rho': 8.0 + }, + } + } + family = get_pseudo_family( + label='SSSP/1.0/PBE/efficiency', + cls=CutoffsFamily, + elements=elements, + cutoffs=cutoffs, + default_stringency='default', + unit=unit + ) + + cutoffs_ar = cutoffs['default']['Ar'] + + expected = (cutoffs_ar['cutoff_wfc'], cutoffs_ar['cutoff_rho']) + assert family.get_recommended_cutoffs(elements='Ar') == expected + + expected = (cutoffs_ar['cutoff_wfc'] * 2, cutoffs_ar['cutoff_rho'] * 2) + assert family.get_recommended_cutoffs(elements='Ar', unit='Ry') == expected + + +@pytest.mark.usefixtures('clear_db') +def test_get_cutoffs_unit(get_pseudo_family, get_cutoffs): + """Test the ``CutoffsFamily.get_cutoffs_unit`` method.""" + family = get_pseudo_family(cls=CutoffsFamily) + assert family.get_cutoffs_unit() == 'eV' + + family.set_cutoffs(get_cutoffs(family), unit='Ry') + assert family.get_cutoffs_unit() == 'Ry' + + family.set_cutoffs(get_cutoffs(family), unit='Eh') + assert family.get_cutoffs_unit() == 'Eh'