Skip to content

Commit

Permalink
Merge pull request #1122 from ethho/dev-tests-plat-146-aggr
Browse files Browse the repository at this point in the history
PLAT-146: Migrate test_aggr_regressions.py
  • Loading branch information
A-Baji authored Dec 11, 2023
2 parents 90209a6 + 110d642 commit 9c0e407
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 1 deletion.
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
schema_advanced,
schema_adapted,
schema_external,
schema_uuid as schema_uuid_module,
)


Expand Down Expand Up @@ -307,6 +308,20 @@ def schema_ext(connection_test, stores_config, enable_filepath_feature):
schema.drop()


@pytest.fixture
def schema_uuid(connection_test):
schema = dj.Schema(
PREFIX + "_test1",
context=schema_uuid_module.LOCALS_UUID,
connection=connection_test,
)
schema(schema_uuid_module.Basic)
schema(schema_uuid_module.Topic)
schema(schema_uuid_module.Item)
yield schema
schema.drop()


@pytest.fixture(scope="session")
def http_client():
# Initialize httpClient with relevant timeout.
Expand Down
51 changes: 51 additions & 0 deletions tests/schema_aggr_regress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import datajoint as dj
import itertools
import inspect


class R(dj.Lookup):
definition = """
r : char(1)
"""
contents = zip("ABCDFGHIJKLMNOPQRST")


class Q(dj.Lookup):
definition = """
-> R
"""
contents = zip("ABCDFGH")


class S(dj.Lookup):
definition = """
-> R
s : int
"""
contents = itertools.product("ABCDF", range(10))


class A(dj.Lookup):
definition = """
id: int
"""
contents = zip(range(10))


class B(dj.Lookup):
definition = """
-> A
id2: int
"""
contents = zip(range(5), range(5, 10))


class X(dj.Lookup):
definition = """
id: int
"""
contents = zip(range(10))


LOCALS_AGGR_REGRESS = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_AGGR_REGRESS)
50 changes: 50 additions & 0 deletions tests/schema_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import uuid
import inspect
import datajoint as dj
from . import PREFIX, CONN_INFO

top_level_namespace_id = uuid.UUID("00000000-0000-0000-0000-000000000000")


class Basic(dj.Manual):
definition = """
item : uuid
---
number : int
"""


class Topic(dj.Manual):
definition = """
# A topic for items
topic_id : uuid # internal identification of a topic, reflects topic name
---
topic : varchar(8000) # full topic name used to generate the topic id
"""

def add(self, topic):
"""add a new topic with a its UUID"""
self.insert1(
dict(topic_id=uuid.uuid5(top_level_namespace_id, topic), topic=topic)
)


class Item(dj.Computed):
definition = """
item_id : uuid # internal identification of
---
-> Topic
word : varchar(8000)
"""

key_source = Topic # test key source that is not instantiated

def make(self, key):
for word in ("Habenula", "Hippocampus", "Hypothalamus", "Hypophysis"):
self.insert1(
dict(key, word=word, item_id=uuid.uuid5(key["topic_id"], word))
)


LOCALS_UUID = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_UUID)
130 changes: 130 additions & 0 deletions tests/test_aggr_regressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Regression tests for issues 386, 449, 484, and 558 — all related to processing complex aggregations and projections.
"""

import pytest
import datajoint as dj
from . import PREFIX
import uuid
from .schema_uuid import Topic, Item, top_level_namespace_id
from .schema_aggr_regress import R, Q, S, A, B, X, LOCALS_AGGR_REGRESS


@pytest.fixture(scope="function")
def schema_aggr_reg(connection_test):
context = LOCALS_AGGR_REGRESS
schema = dj.Schema(
PREFIX + "_aggr_regress",
context=context,
connection=connection_test,
)
schema(R)
schema(Q)
schema(S)
yield schema
schema.drop()


@pytest.fixture(scope="function")
def schema_aggr_reg_with_abx(connection_test):
context = LOCALS_AGGR_REGRESS
schema = dj.Schema(
PREFIX + "_aggr_regress_with_abx",
context=context,
connection=connection_test,
)
schema(R)
schema(Q)
schema(S)
schema(A)
schema(B)
schema(X)
yield schema
schema.drop()


def test_issue386(schema_aggr_reg):
"""
--------------- ISSUE 386 -------------------
Issue 386 resulted from the loss of aggregated attributes when the aggregation was used as the restrictor
Q & (R.aggr(S, n='count(*)') & 'n=2')
Error: Unknown column 'n' in HAVING
"""
result = R.aggr(S, n="count(*)") & "n=10"
result = Q & result
result.fetch()


def test_issue449(schema_aggr_reg):
"""
---------------- ISSUE 449 ------------------
Issue 449 arises from incorrect group by attributes after joining with a dj.U()
"""
result = dj.U("n") * R.aggr(S, n="max(s)")
result.fetch()


def test_issue484(schema_aggr_reg):
"""
---------------- ISSUE 484 -----------------
Issue 484
"""
q = dj.U().aggr(S, n="max(s)")
n = q.fetch("n")
n = q.fetch1("n")
q = dj.U().aggr(S, n="avg(s)")
result = dj.U().aggr(q, m="max(n)")
result.fetch()


def test_union_join(schema_aggr_reg_with_abx):
"""
This test fails if it runs after TestIssue558.
https://github.com/datajoint/datajoint-python/issues/930
"""
A.insert(zip([100, 200, 300, 400, 500, 600]))
B.insert([(100, 11), (200, 22), (300, 33), (400, 44)])
q1 = B & "id < 300"
q2 = B & "id > 300"

expected_data = [
{"id": 0, "id2": 5},
{"id": 1, "id2": 6},
{"id": 2, "id2": 7},
{"id": 3, "id2": 8},
{"id": 4, "id2": 9},
{"id": 100, "id2": 11},
{"id": 200, "id2": 22},
{"id": 400, "id2": 44},
]

assert ((q1 + q2) * A).fetch(as_dict=True) == expected_data


class TestIssue558:
"""
--------------- ISSUE 558 ------------------
Issue 558 resulted from the fact that DataJoint saves subqueries and often combines a restriction followed
by a projection into a single SELECT statement, which in several unusual cases produces unexpected results.
"""

def test_issue558_part1(self, schema_aggr_reg_with_abx):
q = (A - B).proj(id2="3")
assert len(A - B) == len(q)

def test_issue558_part2(self, schema_aggr_reg_with_abx):
d = dict(id=3, id2=5)
assert len(X & d) == len((X & d).proj(id2="3"))


def test_left_join_len(schema_uuid):
Topic().add("jeff")
Item.populate()
Topic().add("jeff2")
Topic().add("jeff3")
q = Topic.join(
Item - dict(topic_id=uuid.uuid5(top_level_namespace_id, "jeff")), left=True
)
qf = q.fetch()
assert len(q) == len(qf)
2 changes: 1 addition & 1 deletion tests/test_erd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_make_image(schema_simp):

def test_part_table_parsing(schema_simp):
# https://github.com/datajoint/datajoint-python/issues/882
erd = dj.Di(schema_simp)
erd = dj.Di(schema_simp, context=LOCALS_SIMPLE)
graph = erd._make_graph()
assert "OutfitLaunch" in graph.nodes()
assert "OutfitLaunch.OutfitPiece" in graph.nodes()

0 comments on commit 9c0e407

Please sign in to comment.