Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mapping tests: restructure, refactor, decouple session and engine, improve #12765

Merged
merged 15 commits into from
Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
141 changes: 141 additions & 0 deletions test/unit/data/model/mapping/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from uuid import uuid4

import pytest
from sqlalchemy import (
delete,
select,
UniqueConstraint,
)


class AbstractBaseTest(ABC):
@pytest.fixture
def cls_(self):
"""
Return class under test.
Assumptions: if the class under test is Foo, then the class grouping
the tests should be a subclass of BaseTest, named TestFoo.
"""
prefix = len("Test")
class_name = self.__class__.__name__[prefix:]
return getattr(self.get_model(), class_name)

@abstractmethod
def get_model(self):
pass


def dbcleanup_wrapper(session, obj, where_clause=None):
with dbcleanup(session, obj, where_clause):
yield obj


@contextmanager
def dbcleanup(session, obj, where_clause=None):
"""
Use the session to store obj in database; delete from database on exit, bypassing the session.

If obj does not have an id field, a SQLAlchemy WHERE clause should be provided to construct
a custom select statement.
"""
return_id = where_clause is None

try:
obj_id = persist(session, obj, return_id)
yield obj_id
finally:
table = obj.__table__
if where_clause is None:
where_clause = _get_default_where_clause(type(obj), obj_id)
stmt = delete(table).where(where_clause)
session.execute(stmt)


def persist(session, obj, return_id=True):
"""
Use the session to store obj in database, then remove obj from session,
so that on a subsequent load from the database we get a clean instance.
"""
session.add(obj)
session.flush()
obj_id = obj.id if return_id else None # save this before obj is expunged
session.expunge(obj)
return obj_id


def delete_from_database(session, objects):
"""
Delete each object in objects from database.
May be called at the end of a test if use of a context manager is impractical.
(Assume all objects have the id field as their primary key.)
"""
# Ensure we have a list of objects (check for list explicitly: a model can be iterable)
if not isinstance(objects, list):
objects = [objects]

for obj in objects:
table = obj.__table__
stmt = delete(table).where(table.c.id == obj.id)
session.execute(stmt)


def get_stored_obj(session, cls, obj_id=None, where_clause=None, unique=False):
# Either obj_id or where_clause must be provided, but not both
assert bool(obj_id) ^ (where_clause is not None)
if where_clause is None:
where_clause = _get_default_where_clause(cls, obj_id)
stmt = select(cls).where(where_clause)
result = session.execute(stmt)
# unique() is required if result contains joint eager loads against collections
# https://gerrit.sqlalchemy.org/c/sqlalchemy/sqlalchemy/+/2253
if unique:
result = result.unique()
return result.scalar_one()


def has_unique_constraint(table, fields):
for constraint in table.constraints:
if isinstance(constraint, UniqueConstraint):
col_names = {c.name for c in constraint.columns}
if set(fields) == col_names:
return True


def has_index(table, fields):
for index in table.indexes:
col_names = {c.name for c in index.columns}
if set(fields) == col_names:
return True


def collection_consists_of_objects(collection, *objects):
"""
Returns True iff list(collection) == list(objects), where object equality is determined
by primary key equality: object1.id == object2.id.
"""
if len(collection) != len(objects): # False if lengths are different
return False
if not collection: # True if both are empty
return True

# Sort, then compare each member by its 'id' attribute, which must be its primary key.
collection.sort(key=lambda item: item.id)
objects = list(objects) # type: ignore
objects.sort(key=lambda item: item.id) # type: ignore

for item1, item2 in zip(collection, objects):
if item1.id is None or item2.id is None or item1.id != item2.id:
return False
return True


def get_unique_value():
"""Generate unique values to accommodate unique constraints."""
return uuid4().hex


def _get_default_where_clause(cls, obj_id):
where_clause = cls.__table__.c.id == obj_id
return where_clause
27 changes: 27 additions & 0 deletions test/unit/data/model/mapping/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker


@pytest.fixture(scope='module')
def engine():
db_uri = 'sqlite:///:memory:'
return create_engine(db_uri)


@pytest.fixture
def session(init_model, engine):
"""
init_model is a fixture that must be defined in the test module using the
session fixture (or in any other discoverable location). Ideally, it will
have module scope and will initialize the models in the database. It must
use the same engine as this fixture. For example:

@pytest.fixture(scope='module')
def init_model(engine):
model.mapper_registry.metadata.create_all(engine)
"""
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
yield Session()
Session.remove()
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from sqlalchemy.orm import registry, Session

from galaxy.model import _HasTable
from . test_model_mapping import (
are_same_entity_collections,
from .common import (
collection_consists_of_objects,
dbcleanup,
dbcleanup_wrapper,
delete_from_database,
Expand Down Expand Up @@ -162,24 +162,39 @@ def test_has_unique_constraint(session):
assert not has_unique_constraint(Foo.__table__, ('field1',))


def test_are_same_entity_collections(session):
def test_collection_consists_of_objects(session):
# create objects
foo1 = Foo()
foo2 = Foo()
foo3 = Foo()
# store objects
persist(session, foo1)
persist(session, foo2)
persist(session, foo3)

# retrieve objects from storage
stored_foo1 = _get_stored_instance_by_id(session, Foo, foo1.id)
stored_foo2 = _get_stored_instance_by_id(session, Foo, foo2.id)
stored_foo3 = _get_stored_instance_by_id(session, Foo, foo3.id)

expected = [foo1, foo2]

assert are_same_entity_collections([stored_foo1, stored_foo2], expected)
assert are_same_entity_collections([stored_foo2, stored_foo1], expected)
assert not are_same_entity_collections([stored_foo1, stored_foo3], expected)
assert not are_same_entity_collections([stored_foo1, stored_foo1, stored_foo2], expected)
# verify retrieved objects are not the same python objects as those we stored
assert stored_foo1 is not foo1
assert stored_foo2 is not foo2
assert stored_foo3 is not foo3

# trivial case
assert collection_consists_of_objects([stored_foo1, stored_foo2], foo1, foo2)
# empty collection and no objects
assert collection_consists_of_objects([])
# ordering in collection does not matter
assert collection_consists_of_objects([stored_foo2, stored_foo1], foo1, foo2)
# contains wrong object
assert not collection_consists_of_objects([stored_foo1, stored_foo3], foo1, foo2)
# contains wrong number of objects
assert not collection_consists_of_objects([stored_foo1, stored_foo1, stored_foo2], foo1, foo2)
# if an object's primary key is not set, it cannot be equal to another object
foo1.id, stored_foo1.id = None, None
assert not collection_consists_of_objects([stored_foo1], foo1)


# Test utilities
Expand Down
Loading