Add bulk of the new archive import code
chrisjsewell committed Sep 28, 2021
1 parent a536201 commit e8f337d
Showing 36 changed files with 1,257 additions and 307 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -85,6 +85,7 @@ repos:
aiida/repository/.*py|
aiida/tools/graph/graph_traversers.py|
aiida/tools/groups/paths.py|
aiida/tools/archive/.*py|
aiida/tools/importexport/archive/.*py|
aiida/tools/importexport/dbexport/__init__.py|
aiida/tools/importexport/dbimport/backends/.*.py|
5 changes: 5 additions & 0 deletions aiida/backends/sqlalchemy/models/group.py
@@ -31,6 +31,11 @@
)


class DbGroupNode(Base):
"""Class to store group to nodes relation using SQLA backend."""
__table__ = table_groups_nodes


class DbGroup(Base):
"""Class to store groups using SQLA backend."""

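Mapping the bare `table_groups_nodes` association table onto a `DbGroupNode` class gives the new import code an ORM handle for writing group-membership rows in bulk. A minimal sketch of what that enables — the open SQLAlchemy `session` and the `dbgroup_id`/`dbnode_id` column names are assumptions, not part of this commit:

# Hedged sketch: bulk-insert group membership via the mapped join table.
rows = [{'dbgroup_id': 1, 'dbnode_id': node_pk} for node_pk in (10, 11, 12)]
session.execute(DbGroupNode.__table__.insert(), rows)  # executemany-style insert
session.commit()
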
6 changes: 4 additions & 2 deletions aiida/backends/sqlalchemy/models/node.py
@@ -156,8 +156,10 @@ class DbLink(Base):
Integer, ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), index=True
)

input = relationship('DbNode', primaryjoin='DbLink.input_id == DbNode.id')
output = relationship('DbNode', primaryjoin='DbLink.output_id == DbNode.id')
# TODO check these overlaps are correct:
# https://docs.sqlalchemy.org/en/14/errors.html#relationship-x-will-copy-column-q-to-column-p-which-conflicts-with-relationship-s-y
input = relationship('DbNode', primaryjoin='DbLink.input_id == DbNode.id', overlaps='inputs_q,outputs_q')
output = relationship('DbNode', primaryjoin='DbLink.output_id == DbNode.id', overlaps='inputs_q,outputs_q')

label = Column(String(255), index=True, nullable=False)
type = Column(String(255), index=True)
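For context on the `overlaps` change: SQLAlchemy 1.4 warns when two relationships write to the same foreign-key columns, and `overlaps` declares that sharing as intentional. A standalone toy reproduction of the pattern (an illustrative schema, not AiiDA's actual one):

from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()

class Node(Base):
    __tablename__ = 'node'
    id = Column(Integer, primary_key=True)
    # Many-to-many self-join through the link table below; it reads/writes the
    # same link.input_id / link.output_id columns as Link.input / Link.output.
    outputs_q = relationship(
        'Node',
        secondary='link',
        primaryjoin='Node.id == Link.input_id',
        secondaryjoin='Node.id == Link.output_id',
        backref='inputs_q',
        overlaps='input,output',
    )

class Link(Base):
    __tablename__ = 'link'
    id = Column(Integer, primary_key=True)
    input_id = Column(Integer, ForeignKey('node.id'))
    output_id = Column(Integer, ForeignKey('node.id'))
    # Without overlaps=..., SQLAlchemy 1.4 emits the warning linked in the TODO
    # above for each of these relationships.
    input = relationship('Node', primaryjoin='Link.input_id == Node.id', overlaps='inputs_q,outputs_q')
    output = relationship('Node', primaryjoin='Link.output_id == Node.id', overlaps='inputs_q,outputs_q')
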
88 changes: 49 additions & 39 deletions aiida/cmdline/commands/cmd_archive.py
@@ -63,9 +63,12 @@ def inspect(archive, version, meta_data):
if version:
echo.echo(current_version)
elif meta_data:
echo.echo_dictionary(archive_format.read_metadata(archive), sort_keys=False)
with archive_format.get_reader_cls()(archive) as archive_reader:
metadata = archive_reader.get_metadata()
echo.echo_dictionary(metadata, sort_keys=False)
else:
metadata = archive_format.read_metadata(archive)
with archive_format.get_reader_cls()(archive) as archive_reader:
metadata = archive_reader.get_metadata()
echo.echo(
tabulate.tabulate([[name, metadata[key]] for key, name in [
['aiida_version', 'Version aiida'],
@@ -83,9 +86,7 @@ def inspect(archive, version, meta_data):
echo.echo(
tabulate.tabulate(metadata['entity_counts'].items(), headers=['Entity type', 'Count'], tablefmt='rst')
)
# TODO repo files count
# TODO file size
# TODO conversion info
# TODO archive inspect: repo files count, file size, conversion info


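In the hunk above, the one-shot `archive_format.read_metadata(archive)` call is replaced by the new reader interface: `get_format()` returns the archive format object, whose `get_reader_cls()` yields a context-manager reader. A hedged sketch of the access pattern outside the CLI (the archive path is assumed):

from aiida.tools.archive.abstract import get_format

archive_format = get_format()
# The reader is a context manager; metadata is only read while it is open.
with archive_format.get_reader_cls()('export.aiida') as reader:  # path assumed
    metadata = reader.get_metadata()
print(metadata.get('aiida_version'))
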
@verdi_archive.command('create')
@@ -133,10 +134,13 @@ def create(
"""
# pylint: disable=too-many-branches
from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter
from aiida.tools.archive.abstract import get_format
from aiida.tools.archive.create import export
from aiida.tools.importexport.common.exceptions import ArchiveExportError
from aiida.tools.archive.exceptions import ArchiveExportError

archive_format = get_format()

# TODO export all
# TODO allow export all
entities = []

if codes:
@@ -189,7 +193,7 @@ def create(
set_progress_reporter(None)

try:
export(entities, filename=output_file, **kwargs)
export(entities, filename=output_file, archive_format=archive_format, **kwargs)
except ArchiveExportError as exception:
echo.echo_critical(f'failed to write the archive file. Exception: {exception}')
else:
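As the hunk above shows, `export` now receives the archive format explicitly rather than implying it globally. A minimal call sketch, with the entity loading assumed for illustration:

from aiida.orm import load_group
from aiida.tools.archive.abstract import get_format
from aiida.tools.archive.create import export

group = load_group('my-group')  # hypothetical group label
export([group], filename='export.aiida', archive_format=get_format())
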
@@ -230,7 +234,7 @@ def migrate(input_file, output_file, force, in_place, version):
)

if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member
set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level == logging.INFO))
set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level <= logging.INFO))
else:
set_progress_reporter(None)

@@ -255,11 +259,11 @@ def migrate(input_file, output_file, force, in_place, version):
class ExtrasImportCode(Enum):
"""Exit codes for the verdi command line."""
# pylint: disable=invalid-name
keep_existing = 'kcl'
update_existing = 'kcu'
mirror = 'ncu'
none = 'knl'
ask = 'kca'
keep_existing = ('k', 'c', 'l')
update_existing = ('k', 'c', 'u')
mirror = ('n', 'c', 'u')
none = ('k', 'n', 'l')
ask = ('k', 'c', 'a')


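The enum above unpacks the legacy packed mode strings ('kcl', 'kcu', ...) into explicit character tuples, which the import keyword arguments further down pass straight through as `merge_extras`. One reading of the three positions — an interpretation, since this commit does not spell it out: keep or drop existing extras missing from the archive, create or skip new extras, and leave/update/ask on value collisions.

# Interpretation sketch only; the tuple is consumed opaquely by the importer.
mode = ExtrasImportCode['keep_existing'].value
assert mode == ('k', 'c', 'l')
# e.g. passed on as merge_extras=ExtrasImportCode['mirror'].value  # ('n', 'c', 'u')
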
@verdi_archive.command('import')
@@ -281,14 +285,14 @@ class ExtrasImportCode(Enum):
'-e',
'--extras-mode-existing',
type=click.Choice(EXTRAS_MODE_EXISTING),
default='keep_existing',
default='none', # TODO changed this to none (fastest, since it does not require any action)
help='Specify which extras from the export archive should be imported for nodes that are already contained in the '
'database: '
'ask: import all extras and prompt what to do for existing extras. '
'none: do not import any extras. '
'keep_existing: import all extras and keep original value of existing extras. '
'update_existing: import all extras and overwrite value of existing extras. '
'mirror: import all extras and remove any existing extras that are not present in the archive. '
)
@click.option(
'-n',
@@ -313,11 +317,16 @@ class ExtrasImportCode(Enum):
show_default=True,
help='Force migration of archive files, if needed.'
)
@click.option(
'-b', '--batch-size', default=1000, type=int, help='Stream database rows in batches, to reduce memory usage.'
)
@click.option('--test-run', is_flag=True, help='Determine entities to import, but do not actually import them.')
@options.NON_INTERACTIVE()
@decorators.with_dbenv()
@click.pass_context
def import_archive(
ctx, archives, webpages, group, extras_mode_existing, extras_mode_new, comment_mode, migration, non_interactive
ctx, archives, webpages, group, extras_mode_existing, extras_mode_new, comment_mode, migration, batch_size,
test_run, non_interactive
):
"""Import data from an AiiDA archive file.
@@ -327,7 +336,7 @@ def import_archive(
from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter

if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member
set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level == logging.DEBUG))
set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level <= logging.INFO))
else:
set_progress_reporter(None)

@@ -339,10 +348,12 @@ def import_archive(

# Shared import key-word arguments
import_kwargs = {
'group': group,
'extras_mode_existing': ExtrasImportCode[extras_mode_existing].value,
'extras_mode_new': extras_mode_new,
'comment_mode': comment_mode,
'import_new_extras': extras_mode_new == 'import',
'merge_extras': ExtrasImportCode[extras_mode_existing].value,
# 'comment_mode': comment_mode,
'batch_size': batch_size,
# 'group': group,
'test_run': test_run,
}

for archive, web_based in all_archives:
@@ -357,12 +368,14 @@ def _echo_exception(msg: str, exception, warn_only: bool = False):
:param warn_only: If True only print a warning, otherwise calls sys.exit with a non-zero exit status
"""
from aiida.tools.importexport import IMPORT_LOGGER
from aiida.tools.archive.imports import IMPORT_LOGGER
message = f'{msg}: {exception.__class__.__name__}: {str(exception)}'
if warn_only:
# TODO note to the user about running with debug to see traceback
echo.echo_warning(message)
else:
IMPORT_LOGGER.debug('%s', traceback.format_exc())
# TODO change back to debug?
IMPORT_LOGGER.info('%s', traceback.format_exc())
echo.echo_critical(message)


@@ -372,7 +385,7 @@ def _gather_imports(archives, webpages) -> List[Tuple[str, bool]]:
:returns: list of (archive path, whether it is web based)
"""
from aiida.tools.importexport.common.utils import get_valid_import_links
from aiida.tools.archive.common import get_valid_import_links

final_archives = []

@@ -410,13 +423,11 @@ def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migr
"""
from aiida.common.folders import SandboxFolder
from aiida.tools.importexport import (
EXPORT_VERSION,
IncompatibleArchiveVersionError,
detect_archive_type,
import_data,
)
from aiida.tools.importexport.archive.migrators import get_migrator
from aiida.tools.archive.abstract import get_format
from aiida.tools.archive.exceptions import IncompatibleArchiveVersionError
from aiida.tools.archive.imports import import_archive as _import_archive

archive_format = get_format()

with SandboxFolder() as temp_folder:

@@ -435,22 +446,21 @@

echo.echo_report(f'starting import: {archive}')
try:
import_data(archive_path, **import_kwargs)
_import_archive(archive_path, archive_format, **import_kwargs)
except IncompatibleArchiveVersionError as exception:
if try_migration:

echo.echo_report(f'incompatible version detected for {archive}, trying migration')
try:
migrator = get_migrator(detect_archive_type(archive_path))(archive_path)
archive_path = migrator.migrate(
EXPORT_VERSION, None, out_compression='none', work_dir=temp_folder.abspath
)
new_path = temp_folder.get_abs_path('migrated_archive.aiida')
archive_format.migrate(archive_path, new_path, archive_format.get_latest_version(), compression=0)
archive_path = new_path
except Exception as exception:
_echo_exception(f'an exception occurred while migrating the archive {archive}', exception)

echo.echo_report('proceeding with import of migrated archive')
try:
import_data(archive_path, **import_kwargs)
_import_archive(archive_path, archive_format, **import_kwargs)
except Exception as exception:
_echo_exception(
f'an exception occurred while trying to import the migrated archive {archive}', exception
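Stripped of the CLI plumbing, the migrate-and-retry flow in the hunk above reduces to the following hedged sketch (paths assumed; the diff passes `archive_format` positionally, so the keyword form here is an assumption):

from aiida.common.folders import SandboxFolder
from aiida.tools.archive.abstract import get_format
from aiida.tools.archive.exceptions import IncompatibleArchiveVersionError
from aiida.tools.archive.imports import import_archive

archive_format = get_format()
path = 'old_export.aiida'  # assumed input archive
with SandboxFolder() as temp_folder:
    try:
        import_archive(path, archive_format=archive_format)
    except IncompatibleArchiveVersionError:
        # Migrate to the latest version in a scratch folder, then retry once.
        new_path = temp_folder.get_abs_path('migrated_archive.aiida')
        archive_format.migrate(path, new_path, archive_format.get_latest_version(), compression=0)
        import_archive(new_path, archive_format=archive_format)
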
20 changes: 14 additions & 6 deletions aiida/cmdline/commands/cmd_database.py
@@ -198,7 +198,7 @@ def detect_invalid_nodes():

@verdi_database.command('summary')
def database_summary():
"""Summarise the entities in the database."""
"""Summarise the entities in the database (`-v info` shows more detail)."""
from aiida.cmdline import is_verbose
from aiida.orm import Comment, Computer, Group, Log, Node, QueryBuilder, User
data = {}
@@ -207,28 +207,32 @@ def database_summary():
query_user = QueryBuilder().append(User, project=['email'])
data['Users'] = {'count': query_user.count()}
if is_verbose():
data['Users']['emails'] = query_user.distinct().all(flat=True)
data['Users']['emails'] = sorted({email for email, in query_user.iterall() if email is not None})

# Computer
query_comp = QueryBuilder().append(Computer, project=['label'])
data['Computers'] = {'count': query_comp.count()}
if is_verbose():
data['Computers']['labels'] = query_comp.distinct().all(flat=True)
data['Computers']['labels'] = sorted({comp for comp, in query_comp.iterall() if comp is not None})

# Node
count = QueryBuilder().append(Node).count()
data['Nodes'] = {'count': count}
if is_verbose():
node_types = QueryBuilder().append(Node, project=['node_type']).distinct().all(flat=True)
node_types = sorted({
typ for typ, in QueryBuilder().append(Node, project=['node_type']).iterall() if typ is not None
})
data['Nodes']['node_types'] = node_types
process_types = QueryBuilder().append(Node, project=['process_type']).distinct().all(flat=True)
process_types = sorted({
typ for typ, in QueryBuilder().append(Node, project=['process_type']).iterall() if typ is not None
})
data['Nodes']['process_types'] = [p for p in process_types if p]

# Group
query_group = QueryBuilder().append(Group, project=['type_string'])
data['Groups'] = {'count': query_group.count()}
if is_verbose():
data['Groups']['type_strings'] = query_group.distinct().all(flat=True)
data['Groups']['type_strings'] = sorted({typ for typ, in query_group.iterall() if typ is not None})

# Comment
count = QueryBuilder().append(Comment).count()
@@ -238,4 +242,8 @@ def database_summary():
count = QueryBuilder().append(Log).count()
data['Logs'] = {'count': count}

# Links
count = QueryBuilder().append(entity_type='link').count()
data['Links'] = {'count': count}

echo.echo_dictionary(data, sort_keys=False, fmt='yaml')
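The summary queries above switch from `.distinct().all(flat=True)` to streaming with `iterall()` plus a client-side set, which batches memory use and also drops NULL projections. The pattern in isolation:

from aiida.orm import QueryBuilder, User

query = QueryBuilder().append(User, project=['email'])
# iterall() yields projected rows in batches rather than materialising the
# whole result; the set comprehension deduplicates and filters out None.
emails = sorted({email for email, in query.iterall() if email is not None})
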
1 change: 1 addition & 0 deletions aiida/cmdline/utils/shell.py
@@ -31,6 +31,7 @@
('aiida.orm', 'Group', 'Group'),
('aiida.orm', 'QueryBuilder', 'QueryBuilder'),
('aiida.orm', 'User', 'User'),
('aiida.orm', 'AuthInfo', 'AuthInfo'),
('aiida.orm', 'load_code', 'load_code'),
('aiida.orm', 'load_computer', 'load_computer'),
('aiida.orm', 'load_group', 'load_group'),
1 change: 1 addition & 0 deletions aiida/orm/__init__.py
@@ -53,6 +53,7 @@
'Entity',
'EntityAttributesMixin',
'EntityExtrasMixin',
'EntityTypes',
'Float',
'FolderData',
'Group',
4 changes: 4 additions & 0 deletions aiida/orm/convert.py
@@ -51,6 +51,10 @@ def _(backend_entity):
Note that we do not register on `collections.abc.Sequence` because that will also match strings.
"""
if hasattr(backend_entity, '_asdict'):
# it is a NamedTuple, so return as is
return backend_entity

converted = []

# Note that we cannot use a simple comprehension because raised `TypeError` should be caught here otherwise only
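The `_asdict` check above is duck typing for the NamedTuple protocol, letting query results already packaged as named rows pass through the converter untouched. For illustration:

from collections import namedtuple

Row = namedtuple('Row', ['id', 'label'])
row = Row(1, 'my-node')
# `_asdict` is what identifies `row` as a NamedTuple to the converter above.
assert hasattr(row, '_asdict')
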
16 changes: 15 additions & 1 deletion aiida/orm/entities.py
@@ -10,6 +10,7 @@
"""Module for all common top level AiiDA entity classes and methods"""
import abc
import copy
from enum import Enum
import typing

from plumpy.base.utils import call_with_super_check, super_check
@@ -18,13 +19,26 @@
from aiida.common.lang import classproperty, type_check
from aiida.manage.manager import get_manager

__all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin')
__all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin', 'EntityTypes')

EntityType = typing.TypeVar('EntityType') # pylint: disable=invalid-name

_NO_DEFAULT = tuple()


class EntityTypes(Enum):
"""Enum for referring to ORM entities in a backend-agnostic manner."""
AUTHINFO = 'authinfo'
COMMENT = 'comment'
COMPUTER = 'computer'
GROUP = 'group'
LOG = 'log'
NODE = 'node'
USER = 'user'
LINK = 'link'
GROUP_NODE = 'group_node'


class Collection(typing.Generic[EntityType]):
"""Container class that represents the collection of objects of a particular type."""

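`EntityTypes` names every row type the new import code addresses in a backend-agnostic way, including the two pure join tables (`LINK`, `GROUP_NODE`) that have no front-end ORM class. A hedged sketch of how it pairs with the `bulk_insert` backend method introduced below — the `transaction` context usage follows that method's docstring, and the join-table column names are assumptions:

from aiida.manage.manager import get_manager
from aiida.orm.entities import EntityTypes

backend = get_manager().get_backend()
with backend.transaction() as transaction:  # per the bulk_insert docstring below
    pks = backend.bulk_insert(
        EntityTypes.GROUP_NODE,
        [{'dbgroup_id': 1, 'dbnode_id': 2}],  # assumed join-table columns
        transaction,
    )
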
15 changes: 14 additions & 1 deletion aiida/orm/implementation/backends.py
@@ -9,11 +9,12 @@
###########################################################################
"""Generic backend related objects"""
import abc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, List

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from aiida.orm.entities import EntityTypes
from aiida.orm.implementation import (
BackendAuthInfoCollection,
BackendCommentCollection,
@@ -90,3 +91,15 @@ def get_session(self) -> 'Session':
:return: an instance of :class:`sqlalchemy.orm.session.Session`
"""

@abc.abstractmethod
def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], transaction: Any) -> List[int]:
"""Insert a list of entities into the database, directly into a backend transaction.
:param entity_type: The type of the entity
:param data: A list of dictionaries, containing all fields of the backend model,
except the `id` field (a.k.a primary key), which will be generated dynamically
:param transaction: the returned object of the ``self.transaction`` context
:returns: The list of generated primary keys for the entities
"""
