Skip to content

Commit

Permalink
Add custom PathOrUrl parameter type.
Browse files Browse the repository at this point in the history
Instead of using the `PathOrUrl` parameter type from `aiida-core` we
add our own implementation that uses the `requests` library instead of
`urllib`, but most importantly, it attempts to retrieve the URL in the
`attempt` context manager that will catch any exceptions and properly
display the error and exit the command, just like other important parts
of the command that can fail.
  • Loading branch information
sphuber committed May 3, 2021
1 parent 2fb8330 commit ee025ea
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 66 deletions.
111 changes: 57 additions & 54 deletions aiida_pseudo/cli/install.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
# -*- coding: utf-8 -*-
"""Command to install a pseudo potential family."""
import json
import os
import pathlib
import shutil
import tempfile
import urllib.request

import click
import requests

from aiida.cmdline.utils import decorators, echo
from aiida.cmdline.params import options as options_core
from aiida.cmdline.params import types

from aiida_pseudo.groups.family import PseudoDojoConfiguration, SsspConfiguration
from .params import options
from .params import options, types
from .root import cmd_root


Expand All @@ -24,63 +22,72 @@ def cmd_install():


@cmd_install.command('family')
@click.argument('archive_or_folder', type=types.PathOrUrl(exists=True, file_okay=True))
@click.argument('archive', type=types.PathOrUrl(exists=True, file_okay=True))
@click.argument('label', type=click.STRING)
@options_core.DESCRIPTION(help='Description for the family.')
@options.ARCHIVE_FORMAT()
@options.PSEUDO_TYPE()
@options.TRACEBACK()
@decorators.with_dbenv()
def cmd_install_family(archive_or_folder, label, description, archive_format, pseudo_type, traceback): # pylint: disable=too-many-arguments
"""Install a standard pseudo potential family from a FOLDER or an ARCHIVE (on the local file system or from a URL).
def cmd_install_family(archive, label, description, archive_format, pseudo_type, traceback): # pylint: disable=too-many-arguments
"""Install a standard pseudopotential family from an ARCHIVE.
The command will attempt first to recognize the passed ARCHIVE_FOLDER as a folder in the local system. If not,
`archive_or_folder` is assumed to be an archive and the command will attempt to infer the archive format from the
filename extension of the ARCHIVE. If this fails, the archive format can be specified explicitly with the archive
format option, which will also display which formats are supported.
The ARCHIVE can be a (compressed) archive of a directory containing the pseudopotentials on the local file system or
provided by an HTTP URL. Alternatively, it can be a normal directory on the local file system. The (unarchived)
directory should only contain the pseudopotential files and they cannot be in any subdirectory. In addition,
depending on the chosen pseudopotential type (see the option `-P/--pseudo-type`) there can be additional
requirements on the pseudopotential file and filename format.
By default, the command will create a base `PseudoPotentialFamily`, but the type can be changed with the pseudos
type option. If the base type is used, the pseudo potential files in the archive *have* to have filenames that
strictly follow the format `ELEMENT.EXTENSION`, because otherwise the element cannot be determined automatically.
If the ARCHIVE corresponds to a (compressed) archive, the command will attempt to infer the archive format from the
filename extension of the ARCHIVE. If this fails, the archive format can be specified explicitly with the archive
format option `-F/--archive-format`, which will also display which formats are supported. These format suffixes
follow the naming of the `shutil.unpack_archive` standard library method.
Once the ARCHIVE is downloaded, uncompressed and unarchived into a directory on the local file system, the command
will create a `PseudoPotentialFamily` instance where the type of the pseudopotential data nodes that are stored
within it is set through the `-P/--pseudo-type` option. If the default `pseudo` (which corresponds to the data
plugin `PseudoPotentialData`) is used, the pseudopotential files in the archive *have* to have filenames that strictly follow
the format `ELEMENT.EXTENSION`, or the creation of the family will fail. This is because for the default
pseudopotential type, the format of the file is unknown and the family requires the element to be known, which in
this case can then only be parsed from the filename.
"""
from .utils import attempt, create_family_from_archive
from aiida_pseudo.groups.family.pseudo import PseudoPotentialFamily
from .utils import attempt, create_family_from_archive

# `archive_or_folder` is a simple string, containing the name of the folder / file / url location.

if pathlib.Path(archive_or_folder).is_dir():
try:
family = PseudoPotentialFamily.create_from_folder(archive_or_folder, label, pseudo_type=pseudo_type)
except ValueError as exception:
raise OSError(f'failed to parse pseudos from `{archive_or_folder}`: {exception}') from exception
elif pathlib.Path(archive_or_folder).is_file():
archive = archive_or_folder
if isinstance(archive, pathlib.Path) and archive.is_dir():
with attempt(f'creating a pseudopotential family from directory `{archive}`...', include_traceback=traceback):
family = PseudoPotentialFamily.create_from_folder(archive, label, pseudo_type=pseudo_type)
elif isinstance(archive, pathlib.Path) and archive.is_file():
with attempt('unpacking archive and parsing pseudos... ', include_traceback=traceback):
family = create_family_from_archive(
PseudoPotentialFamily, label, pathlib.Path(archive), fmt=archive_format, pseudo_type=pseudo_type
PseudoPotentialFamily, label, archive, fmt=archive_format, pseudo_type=pseudo_type
)
else:
# The file of the url must be copied to a local temporary file. Maybe better ways to do it?
# The `create_family_from_archive` does currently not accept filelike objects because the underlying
# `shutil.unpack_archive` does not. Likewise, `unpack_archive` will attempt to deduce the archive format
# from the filename extension, so it is important we maintain the original filename.
# Of course if this fails, users can specify the archive format explicitly with the corresponding option.
with urllib.request.urlopen(archive_or_folder) as archive:
suffix = os.path.basename(archive.url)
with tempfile.NamedTemporaryFile(mode='w+b', suffix=suffix) as handle:
shutil.copyfileobj(archive, handle)
handle.flush()
with attempt('unpacking archive and parsing pseudos... ', include_traceback=traceback):
family = create_family_from_archive(
PseudoPotentialFamily,
label,
pathlib.Path(handle.name),
fmt=archive_format,
pseudo_type=pseudo_type
)
# At this point, we can assume that it is not a valid filepath on disk, but rather a URL and the ``archive``
# variable will contain the result objects from the ``requests`` library. The validation of the URL will already
# have been done by the ``PathOrUrl`` parameter type, so the URL is reachable. The content of the URL must be
# copied to a local temporary file because `create_family_from_archive` does currently not accept filelike
# objects, because in turn the underlying `shutil.unpack_archive` does not. In addition, `unpack_archive` will
# attempt to deduce the archive format from the filename extension, so it is important we maintain the original
# filename. Of course if this fails, users can specify the archive format explicitly with the corresponding
# option. We get the filename by converting the URL to a ``Path`` object and taking the filename, using that as
# a suffix for the temporary file that is generated on disk to copy the content to.
suffix = pathlib.Path(archive.url).name
with tempfile.NamedTemporaryFile(mode='w+b', suffix=suffix) as handle:
handle.write(archive.content)
handle.flush()

with attempt('unpacking archive and parsing pseudos... ', include_traceback=traceback):
family = create_family_from_archive(
PseudoPotentialFamily,
label,
pathlib.Path(handle.name),
fmt=archive_format,
pseudo_type=pseudo_type
)

family.description = description
echo.echo_success(f'installed `{label}` containing {family.count()} pseudo potentials')
echo.echo_success(f'installed `{label}` containing {family.count()} pseudopotentials')


def download_sssp(
Expand All @@ -96,23 +103,21 @@ def download_sssp(
:param filepath_metadata: absolute filepath to write the metadata file to.
:param traceback: boolean, if true, print the traceback when an exception occurs.
"""
import requests

from aiida_pseudo.groups.family import SsspFamily
from .utils import attempt

url_sssp_base = 'https://legacy-archive.materialscloud.org/file/2018.0001/v4/'
url_archive = f"{url_sssp_base}/{SsspFamily.format_configuration_filename(configuration, 'tar.gz')}"
url_metadata = f"{url_sssp_base}/{SsspFamily.format_configuration_filename(configuration, 'json')}"

with attempt('downloading selected pseudo potentials archive... ', include_traceback=traceback):
with attempt('downloading selected pseudopotentials archive... ', include_traceback=traceback):
response = requests.get(url_archive)
response.raise_for_status()
with open(filepath_archive, 'wb') as handle:
handle.write(response.content)
handle.flush()

with attempt('downloading selected pseudo potentials metadata... ', include_traceback=traceback):
with attempt('downloading selected pseudopotentials metadata... ', include_traceback=traceback):
response = requests.get(url_metadata)
response.raise_for_status()
with open(filepath_metadata, 'wb') as handle:
Expand All @@ -133,23 +138,21 @@ def download_pseudo_dojo(
:param filepath_metadata: absolute filepath to write the metadata archive to.
:param traceback: boolean, if true, print the traceback when an exception occurs.
"""
import requests

from aiida_pseudo.groups.family import PseudoDojoFamily
from .utils import attempt

label = PseudoDojoFamily.format_configuration_label(configuration)
url_archive = PseudoDojoFamily.get_url_archive(label)
url_metadata = PseudoDojoFamily.get_url_metadata(label)

with attempt('downloading selected pseudo potentials archive... ', include_traceback=traceback):
with attempt('downloading selected pseudopotentials archive... ', include_traceback=traceback):
response = requests.get(url_archive)
response.raise_for_status()
with open(filepath_archive, 'wb') as handle:
handle.write(response.content)
handle.flush()

with attempt('downloading selected pseudo potentials metadata archive... ', include_traceback=traceback):
with attempt('downloading selected pseudopotentials metadata archive... ', include_traceback=traceback):
response = requests.get(url_metadata)
response.raise_for_status()
with open(filepath_metadata, 'wb') as handle:
Expand Down Expand Up @@ -230,7 +233,7 @@ def cmd_install_sssp(version, functional, protocol, download_only, traceback):
family.description = description
family.set_cutoffs(cutoffs, 'normal', unit='Ry')

echo.echo_success(f'installed `{label}` containing {family.count()} pseudo potentials')
echo.echo_success(f'installed `{label}` containing {family.count()} pseudopotentials')


@cmd_install.command('pseudo-dojo')
Expand Down Expand Up @@ -353,4 +356,4 @@ def cmd_install_pseudo_dojo(
family.set_cutoffs(cutoff_values, stringency, unit='Eh')
family.set_default_stringency(default_stringency)

echo.echo_success(f'installed `{label}` containing {family.count()} pseudo potentials')
echo.echo_success(f'installed `{label}` containing {family.count()} pseudopotentials')
24 changes: 12 additions & 12 deletions aiida_pseudo/cli/params/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@
help='Select the pseudopotential file format of the installed configuration.'
)

PSEUDO_TYPE = OverridableOption(
'-P',
'--pseudo-type',
type=PseudoPotentialTypeParam(),
default='pseudo',
show_default=True,
help=(
'Select the pseudopotential type to be used for the family. Should be the entry point name of a '
'subclass of `PseudoPotentialData`, for example, `pseudo.upf`.'
)
)

STRINGENCY = OverridableOption(
'-s', '--stringency', type=click.STRING, required=False, help='Stringency level for the recommended cutoffs.'
)
Expand Down Expand Up @@ -91,15 +103,3 @@
'pseudopotential family.'
)
)

PSEUDO_TYPE = OverridableOption(
'-P',
'--pseudo-type',
type=PseudoPotentialTypeParam(),
default='pseudo',
show_default=True,
help=(
'Select the pseudopotential type to be used for the family. Should be the entry point name of a '
'subclass of `PseudoPotentialData`.'
)
)
29 changes: 29 additions & 0 deletions aiida_pseudo/cli/params/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# -*- coding: utf-8 -*-
# pylint: disable=no-self-use
"""Custom parameter types for command line interface commands."""
import pathlib
import typing

import click
import requests

from aiida.cmdline.params.types import GroupParamType
from ..utils import attempt

__all__ = ('PseudoPotentialFamilyTypeParam', 'PseudoPotentialFamilyParam', 'PseudoPotentialTypeParam')

Expand Down Expand Up @@ -88,3 +93,27 @@ def complete(self, _, incomplete):
from aiida.plugins.entry_point import get_entry_point_names
entry_points = get_entry_point_names('aiida.groups')
return [(ep, '') for ep in entry_points if (ep.startswith('pseudo.family') and ep.startswith(incomplete))]


class PathOrUrl(click.Path):
    """Extension of ``click``'s ``Path``-type that also supports URLs."""

    name = 'PathOrUrl'

    def convert(self, value, param, ctx) -> typing.Union[pathlib.Path, requests.Response]:
        """Convert the string value to a local path or to a fetched HTTP response.

        If the ``value`` corresponds to a valid path on the local filesystem, return it as a ``pathlib.Path`` instance.
        Otherwise, treat it as a URL and try to fetch it. The request is performed inside the ``attempt`` context
        manager so that a failure is reported in the same way as for the other steps of the command.

        :param value: the filepath on the local filesystem or a URL.
        :param param: the parameter that is being converted, forwarded as is to ``click.Path.convert``.
        :param ctx: the current ``click`` context, forwarded as is to ``click.Path.convert``.
        :return: a ``pathlib.Path`` if ``value`` is a valid local path, otherwise the ``requests.Response`` obtained
            for the URL. Callers can read the downloaded bytes from its ``content`` attribute and the URL from its
            ``url`` attribute.
        """
        try:
            # Call the method of the super class, which will raise if ``value`` is not a valid path.
            return pathlib.Path(super().convert(value, param, ctx))
        except click.exceptions.BadParameter:
            # Not a valid local path, so interpret the value as a URL instead.
            # NOTE(review): no ``timeout`` is passed to ``requests.get``, so an unresponsive server could make the
            # command wait indefinitely - consider adding one.
            with attempt(f'attempting to download data from `{value}`...'):
                response = requests.get(value)
                # Turn HTTP error status codes (4xx/5xx) into an exception so they are reported by ``attempt``.
                response.raise_for_status()
                return response

0 comments on commit ee025ea

Please sign in to comment.