-
Notifications
You must be signed in to change notification settings - Fork 85
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1122 from ethho/dev-tests-plat-146-aggr
PLAT-146: Migrate test_aggr_regressions.py
- Loading branch information
Showing
5 changed files
with
247 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import datajoint as dj | ||
import itertools | ||
import inspect | ||
|
||
|
||
class R(dj.Lookup): | ||
definition = """ | ||
r : char(1) | ||
""" | ||
contents = zip("ABCDFGHIJKLMNOPQRST") | ||
|
||
|
||
class Q(dj.Lookup): | ||
definition = """ | ||
-> R | ||
""" | ||
contents = zip("ABCDFGH") | ||
|
||
|
||
class S(dj.Lookup): | ||
definition = """ | ||
-> R | ||
s : int | ||
""" | ||
contents = itertools.product("ABCDF", range(10)) | ||
|
||
|
||
class A(dj.Lookup): | ||
definition = """ | ||
id: int | ||
""" | ||
contents = zip(range(10)) | ||
|
||
|
||
class B(dj.Lookup): | ||
definition = """ | ||
-> A | ||
id2: int | ||
""" | ||
contents = zip(range(5), range(5, 10)) | ||
|
||
|
||
class X(dj.Lookup): | ||
definition = """ | ||
id: int | ||
""" | ||
contents = zip(range(10)) | ||
|
||
|
||
LOCALS_AGGR_REGRESS = {k: v for k, v in locals().items() if inspect.isclass(v)} | ||
__all__ = list(LOCALS_AGGR_REGRESS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import uuid | ||
import inspect | ||
import datajoint as dj | ||
from . import PREFIX, CONN_INFO | ||
|
||
top_level_namespace_id = uuid.UUID("00000000-0000-0000-0000-000000000000") | ||
|
||
|
||
class Basic(dj.Manual): | ||
definition = """ | ||
item : uuid | ||
--- | ||
number : int | ||
""" | ||
|
||
|
||
class Topic(dj.Manual): | ||
definition = """ | ||
# A topic for items | ||
topic_id : uuid # internal identification of a topic, reflects topic name | ||
--- | ||
topic : varchar(8000) # full topic name used to generate the topic id | ||
""" | ||
|
||
def add(self, topic): | ||
"""add a new topic with a its UUID""" | ||
self.insert1( | ||
dict(topic_id=uuid.uuid5(top_level_namespace_id, topic), topic=topic) | ||
) | ||
|
||
|
||
class Item(dj.Computed): | ||
definition = """ | ||
item_id : uuid # internal identification of | ||
--- | ||
-> Topic | ||
word : varchar(8000) | ||
""" | ||
|
||
key_source = Topic # test key source that is not instantiated | ||
|
||
def make(self, key): | ||
for word in ("Habenula", "Hippocampus", "Hypothalamus", "Hypophysis"): | ||
self.insert1( | ||
dict(key, word=word, item_id=uuid.uuid5(key["topic_id"], word)) | ||
) | ||
|
||
|
||
LOCALS_UUID = {k: v for k, v in locals().items() if inspect.isclass(v)} | ||
__all__ = list(LOCALS_UUID) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
""" | ||
Regression tests for issues 386, 449, 484, and 558 — all related to processing complex aggregations and projections. | ||
""" | ||
|
||
import pytest | ||
import datajoint as dj | ||
from . import PREFIX | ||
import uuid | ||
from .schema_uuid import Topic, Item, top_level_namespace_id | ||
from .schema_aggr_regress import R, Q, S, A, B, X, LOCALS_AGGR_REGRESS | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def schema_aggr_reg(connection_test): | ||
context = LOCALS_AGGR_REGRESS | ||
schema = dj.Schema( | ||
PREFIX + "_aggr_regress", | ||
context=context, | ||
connection=connection_test, | ||
) | ||
schema(R) | ||
schema(Q) | ||
schema(S) | ||
yield schema | ||
schema.drop() | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def schema_aggr_reg_with_abx(connection_test): | ||
context = LOCALS_AGGR_REGRESS | ||
schema = dj.Schema( | ||
PREFIX + "_aggr_regress_with_abx", | ||
context=context, | ||
connection=connection_test, | ||
) | ||
schema(R) | ||
schema(Q) | ||
schema(S) | ||
schema(A) | ||
schema(B) | ||
schema(X) | ||
yield schema | ||
schema.drop() | ||
|
||
|
||
def test_issue386(schema_aggr_reg): | ||
""" | ||
--------------- ISSUE 386 ------------------- | ||
Issue 386 resulted from the loss of aggregated attributes when the aggregation was used as the restrictor | ||
Q & (R.aggr(S, n='count(*)') & 'n=2') | ||
Error: Unknown column 'n' in HAVING | ||
""" | ||
result = R.aggr(S, n="count(*)") & "n=10" | ||
result = Q & result | ||
result.fetch() | ||
|
||
|
||
def test_issue449(schema_aggr_reg): | ||
""" | ||
---------------- ISSUE 449 ------------------ | ||
Issue 449 arises from incorrect group by attributes after joining with a dj.U() | ||
""" | ||
result = dj.U("n") * R.aggr(S, n="max(s)") | ||
result.fetch() | ||
|
||
|
||
def test_issue484(schema_aggr_reg): | ||
""" | ||
---------------- ISSUE 484 ----------------- | ||
Issue 484 | ||
""" | ||
q = dj.U().aggr(S, n="max(s)") | ||
n = q.fetch("n") | ||
n = q.fetch1("n") | ||
q = dj.U().aggr(S, n="avg(s)") | ||
result = dj.U().aggr(q, m="max(n)") | ||
result.fetch() | ||
|
||
|
||
def test_union_join(schema_aggr_reg_with_abx): | ||
""" | ||
This test fails if it runs after TestIssue558. | ||
https://github.com/datajoint/datajoint-python/issues/930 | ||
""" | ||
A.insert(zip([100, 200, 300, 400, 500, 600])) | ||
B.insert([(100, 11), (200, 22), (300, 33), (400, 44)]) | ||
q1 = B & "id < 300" | ||
q2 = B & "id > 300" | ||
|
||
expected_data = [ | ||
{"id": 0, "id2": 5}, | ||
{"id": 1, "id2": 6}, | ||
{"id": 2, "id2": 7}, | ||
{"id": 3, "id2": 8}, | ||
{"id": 4, "id2": 9}, | ||
{"id": 100, "id2": 11}, | ||
{"id": 200, "id2": 22}, | ||
{"id": 400, "id2": 44}, | ||
] | ||
|
||
assert ((q1 + q2) * A).fetch(as_dict=True) == expected_data | ||
|
||
|
||
class TestIssue558: | ||
""" | ||
--------------- ISSUE 558 ------------------ | ||
Issue 558 resulted from the fact that DataJoint saves subqueries and often combines a restriction followed | ||
by a projection into a single SELECT statement, which in several unusual cases produces unexpected results. | ||
""" | ||
|
||
def test_issue558_part1(self, schema_aggr_reg_with_abx): | ||
q = (A - B).proj(id2="3") | ||
assert len(A - B) == len(q) | ||
|
||
def test_issue558_part2(self, schema_aggr_reg_with_abx): | ||
d = dict(id=3, id2=5) | ||
assert len(X & d) == len((X & d).proj(id2="3")) | ||
|
||
|
||
def test_left_join_len(schema_uuid): | ||
Topic().add("jeff") | ||
Item.populate() | ||
Topic().add("jeff2") | ||
Topic().add("jeff3") | ||
q = Topic.join( | ||
Item - dict(topic_id=uuid.uuid5(top_level_namespace_id, "jeff")), left=True | ||
) | ||
qf = q.fetch() | ||
assert len(q) == len(qf) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters