diff --git a/tests/__init__.py b/tests/__init__.py index 8b825a042..de57f6eab 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -10,60 +10,3 @@ user=os.getenv("DJ_USER"), password=os.getenv("DJ_PASS"), ) - - -@pytest.fixture -def connection_root(): - """Root user database connection.""" - dj.config["safemode"] = False - connection = dj.Connection( - host=os.getenv("DJ_HOST"), - user=os.getenv("DJ_USER"), - password=os.getenv("DJ_PASS"), - ) - yield connection - dj.config["safemode"] = True - connection.close() - - -@pytest.fixture -def connection_test(connection_root): - """Test user database connection.""" - database = f"{PREFIX}%%" - credentials = dict( - host=os.getenv("DJ_HOST"), user="datajoint", password="datajoint" - ) - permission = "ALL PRIVILEGES" - - # Create MySQL users - if version.parse( - connection_root.query("select @@version;").fetchone()[0] - ) >= version.parse("8.0.0"): - # create user if necessary on mysql8 - connection_root.query( - f""" - CREATE USER IF NOT EXISTS '{credentials["user"]}'@'%%' - IDENTIFIED BY '{credentials["password"]}'; - """ - ) - connection_root.query( - f""" - GRANT {permission} ON `{database}`.* - TO '{credentials["user"]}'@'%%'; - """ - ) - else: - # grant permissions. For MySQL 5.7 this also automatically creates user - # if not exists - connection_root.query( - f""" - GRANT {permission} ON `{database}`.* - TO '{credentials["user"]}'@'%%' - IDENTIFIED BY '{credentials["password"]}'; - """ - ) - - connection = dj.Connection(**credentials) - yield connection - connection_root.query(f"""DROP USER `{credentials["user"]}`""") - connection.close() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..e13a13632 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,154 @@ +import datajoint as dj +from packaging import version +import os +import pytest +from . 
import PREFIX, schema, schema_simple, schema_advanced + +namespace = locals() + + +@pytest.fixture(scope="session") +def connection_root(): + """Root user database connection.""" + dj.config["safemode"] = False + connection = dj.Connection( + host=os.getenv("DJ_HOST"), + user=os.getenv("DJ_USER"), + password=os.getenv("DJ_PASS"), + ) + yield connection + dj.config["safemode"] = True + connection.close() + + +@pytest.fixture(scope="session") +def connection_test(connection_root): + """Test user database connection.""" + database = f"{PREFIX}%%" + credentials = dict( + host=os.getenv("DJ_HOST"), user="datajoint", password="datajoint" + ) + permission = "ALL PRIVILEGES" + + # Create MySQL users + if version.parse( + connection_root.query("select @@version;").fetchone()[0] + ) >= version.parse("8.0.0"): + # create user if necessary on mysql8 + connection_root.query( + f""" + CREATE USER IF NOT EXISTS '{credentials["user"]}'@'%%' + IDENTIFIED BY '{credentials["password"]}'; + """ + ) + connection_root.query( + f""" + GRANT {permission} ON `{database}`.* + TO '{credentials["user"]}'@'%%'; + """ + ) + else: + # grant permissions. 
For MySQL 5.7 this also automatically creates user + # if not exists + connection_root.query( + f""" + GRANT {permission} ON `{database}`.* + TO '{credentials["user"]}'@'%%' + IDENTIFIED BY '{credentials["password"]}'; + """ + ) + + connection = dj.Connection(**credentials) + yield connection + connection_root.query(f"""DROP USER `{credentials["user"]}`""") + connection.close() + + +@pytest.fixture(scope="module") +def schema_any(connection_test): + schema_any = dj.Schema( + PREFIX + "_test1", schema.__dict__, connection=connection_test + ) + schema_any(schema.TTest) + schema_any(schema.TTest2) + schema_any(schema.TTest3) + schema_any(schema.NullableNumbers) + schema_any(schema.TTestExtra) + schema_any(schema.TTestNoExtra) + schema_any(schema.Auto) + schema_any(schema.User) + schema_any(schema.Subject) + schema_any(schema.Language) + schema_any(schema.Experiment) + schema_any(schema.Trial) + schema_any(schema.Ephys) + schema_any(schema.Image) + schema_any(schema.UberTrash) + schema_any(schema.UnterTrash) + schema_any(schema.SimpleSource) + schema_any(schema.SigIntTable) + schema_any(schema.SigTermTable) + schema_any(schema.DjExceptionName) + schema_any(schema.ErrorClass) + schema_any(schema.DecimalPrimaryKey) + schema_any(schema.IndexRich) + schema_any(schema.ThingA) + schema_any(schema.ThingB) + schema_any(schema.ThingC) + schema_any(schema.Parent) + schema_any(schema.Child) + schema_any(schema.ComplexParent) + schema_any(schema.ComplexChild) + schema_any(schema.SubjectA) + schema_any(schema.SessionA) + schema_any(schema.SessionStatusA) + schema_any(schema.SessionDateA) + schema_any(schema.Stimulus) + schema_any(schema.Longblob) + yield schema_any + schema_any.drop() + + +@pytest.fixture(scope="module") +def schema_simp(connection_test): + schema = dj.Schema( + PREFIX + "_relational", schema_simple.__dict__, connection=connection_test + ) + schema(schema_simple.IJ) + schema(schema_simple.JI) + schema(schema_simple.A) + schema(schema_simple.B) + 
schema(schema_simple.L) + schema(schema_simple.D) + schema(schema_simple.E) + schema(schema_simple.F) + schema(schema_simple.F) + schema(schema_simple.DataA) + schema(schema_simple.DataB) + schema(schema_simple.Website) + schema(schema_simple.Profile) + schema(schema_simple.Website) + schema(schema_simple.TTestUpdate) + schema(schema_simple.ArgmaxTest) + schema(schema_simple.ReservedWord) + schema(schema_simple.OutfitLaunch) + yield schema + schema.drop() + + +@pytest.fixture(scope="module") +def schema_adv(connection_test): + schema = dj.Schema( + PREFIX + "_advanced", schema_advanced.__dict__, connection=connection_test + ) + schema(schema_advanced.Person) + schema(schema_advanced.Parent) + schema(schema_advanced.Subject) + schema(schema_advanced.Prep) + schema(schema_advanced.Slice) + schema(schema_advanced.Cell) + schema(schema_advanced.InputCell) + schema(schema_advanced.LocalSynapse) + schema(schema_advanced.GlobalSynapse) + yield schema + schema.drop() diff --git a/tests/schema.py b/tests/schema.py new file mode 100644 index 000000000..864c5efe4 --- /dev/null +++ b/tests/schema.py @@ -0,0 +1,452 @@ +""" +Sample schema with realistic tables for testing +""" + +import random +import numpy as np +import datajoint as dj +import inspect + +LOCALS_ANY = locals() + + +class TTest(dj.Lookup): + """ + doc string + """ + + definition = """ + key : int # key + --- + value : int # value + """ + contents = [(k, 2 * k) for k in range(10)] + + +class TTest2(dj.Manual): + definition = """ + key : int # key + --- + value : int # value + """ + + +class TTest3(dj.Manual): + definition = """ + key : int + --- + value : varchar(300) + """ + + +class NullableNumbers(dj.Manual): + definition = """ + key : int + --- + fvalue = null : float + dvalue = null : double + ivalue = null : int + """ + + +class TTestExtra(dj.Manual): + """ + clone of Test but with an extra field + """ + + definition = TTest.definition + "\nextra : int # extra int\n" + + +class TTestNoExtra(dj.Manual): + """ 
+ clone of Test but with no extra fields + """ + + definition = TTest.definition + + +class Auto(dj.Lookup): + definition = """ + id :int auto_increment + --- + name :varchar(12) + """ + + def fill(self): + if not self: + self.insert([dict(name="Godel"), dict(name="Escher"), dict(name="Bach")]) + + +class User(dj.Lookup): + definition = """ # lab members + username: varchar(12) + """ + contents = [ + ["Jake"], + ["Cathryn"], + ["Shan"], + ["Fabian"], + ["Edgar"], + ["George"], + ["Dimitri"], + ] + + +class Subject(dj.Lookup): + definition = """ # Basic information about animal subjects used in experiments + subject_id :int # unique subject id + --- + real_id :varchar(40) # real-world name. Omit if the same as subject_id + species = "mouse" :enum('mouse', 'monkey', 'human') + date_of_birth :date + subject_notes :varchar(4000) + unique index (real_id, species) + """ + + contents = [ + [1551, "1551", "mouse", "2015-04-01", "genetically engineered super mouse"], + [10, "Curious George", "monkey", "2008-06-30", ""], + [1552, "1552", "mouse", "2015-06-15", ""], + [1553, "1553", "mouse", "2016-07-01", ""], + ] + + +class Language(dj.Lookup): + definition = """ + # languages spoken by some of the developers + # additional comments are ignored + name : varchar(40) # name of the developer + language : varchar(40) # language + """ + contents = [ + ("Fabian", "English"), + ("Edgar", "English"), + ("Dimitri", "English"), + ("Dimitri", "Ukrainian"), + ("Fabian", "German"), + ("Edgar", "Japanese"), + ] + + +class Experiment(dj.Imported): + definition = """ # information about experiments + -> Subject + experiment_id :smallint # experiment number for this subject + --- + experiment_date :date # date when experiment was started + -> [nullable] User + data_path="" :varchar(255) # file path to recorded data + notes="" :varchar(2048) # e.g. 
purpose of experiment + entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp + """ + + fake_experiments_per_subject = 5 + + def make(self, key): + """ + populate with random data + """ + from datetime import date, timedelta + + users = [None, None] + list(User().fetch()["username"]) + random.seed("Amazing Seed") + self.insert( + dict( + key, + experiment_id=experiment_id, + experiment_date=( + date.today() - timedelta(random.expovariate(1 / 30)) + ).isoformat(), + username=random.choice(users), + ) + for experiment_id in range(self.fake_experiments_per_subject) + ) + + +class Trial(dj.Imported): + definition = """ # a trial within an experiment + -> Experiment.proj(animal='subject_id') + trial_id :smallint # trial number + --- + start_time :double # (s) + """ + + class Condition(dj.Part): + definition = """ # trial conditions + -> Trial + cond_idx : smallint # condition number + ---- + orientation : float # degrees + """ + + def make(self, key): + """populate with random data (pretend reading from raw files)""" + random.seed("Amazing Seed") + trial = self.Condition() + for trial_id in range(10): + key["trial_id"] = trial_id + self.insert1(dict(key, start_time=random.random() * 1e9)) + trial.insert( + dict(key, cond_idx=cond_idx, orientation=random.random() * 360) + for cond_idx in range(30) + ) + + +class Ephys(dj.Imported): + definition = """ # some kind of electrophysiological recording + -> Trial + ---- + sampling_frequency :double # (Hz) + duration :decimal(7,3) # (s) + """ + + class Channel(dj.Part): + definition = """ # subtable containing individual channels + -> master + channel :tinyint unsigned # channel number within Ephys + ---- + voltage : longblob + current = null : longblob # optional current to test null handling + """ + + def _make_tuples(self, key): + """ + populate with random data + """ + random.seed(str(key)) + row = dict( + key, sampling_frequency=6000, duration=np.minimum(2, random.expovariate(1)) + ) + self.insert1(row) + 
number_samples = int(row["duration"] * row["sampling_frequency"] + 0.5) + sub = self.Channel() + sub.insert( + dict( + key, + channel=channel, + voltage=np.float32(np.random.randn(number_samples)), + ) + for channel in range(2) + ) + + +class Image(dj.Manual): + definition = """ + # table for testing blob inserts + id : int # image identifier + --- + img : longblob # image + """ + + +class UberTrash(dj.Lookup): + definition = """ + id : int + --- + """ + contents = [(1,)] + + +class UnterTrash(dj.Lookup): + definition = """ + -> UberTrash + my_id : int + --- + """ + contents = [(1, 1), (1, 2)] + + +class SimpleSource(dj.Lookup): + definition = """ + id : int # id + """ + contents = ((x,) for x in range(10)) + + +class SigIntTable(dj.Computed): + definition = """ + -> SimpleSource + """ + + def _make_tuples(self, key): + raise KeyboardInterrupt + + +class SigTermTable(dj.Computed): + definition = """ + -> SimpleSource + """ + + def make(self, key): + raise SystemExit("SIGTERM received") + + +class DjExceptionName(dj.Lookup): + definition = """ + dj_exception_name: char(64) + """ + + @property + def contents(self): + return [ + [member_name] + for member_name, member_type in inspect.getmembers(dj.errors) + if inspect.isclass(member_type) and issubclass(member_type, Exception) + ] + + +class ErrorClass(dj.Computed): + definition = """ + -> DjExceptionName + """ + + def make(self, key): + exception_name = key["dj_exception_name"] + raise getattr(dj.errors, exception_name) + + +class DecimalPrimaryKey(dj.Lookup): + definition = """ + id : decimal(4,3) + """ + contents = zip((0.1, 0.25, 3.99)) + + +class IndexRich(dj.Manual): + definition = """ + -> Subject + --- + -> [unique, nullable] User.proj(first="username") + first_date : date + value : int + index (first_date, value) + """ + + +# Schema for issue 656 +class ThingA(dj.Manual): + definition = """ + a: int + """ + + +class ThingB(dj.Manual): + definition = """ + b1: int + b2: int + --- + b3: int + """ + + +class 
ThingC(dj.Manual): + definition = """ + -> ThingA + --- + -> [unique, nullable] ThingB + """ + + +class Parent(dj.Lookup): + definition = """ + parent_id: int + --- + name: varchar(30) + """ + contents = [(1, "Joe")] + + +class Child(dj.Lookup): + definition = """ + -> Parent + child_id: int + --- + name: varchar(30) + """ + contents = [(1, 12, "Dan")] + + +# Related to issue #886 (8), #883 (5) +class ComplexParent(dj.Lookup): + definition = "\n".join(["parent_id_{}: int".format(i + 1) for i in range(8)]) + contents = [tuple(i for i in range(8))] + + +class ComplexChild(dj.Lookup): + definition = "\n".join( + ["-> ComplexParent"] + ["child_id_{}: int".format(i + 1) for i in range(1)] + ) + contents = [tuple(i for i in range(9))] + + +class SubjectA(dj.Lookup): + definition = """ + subject_id: varchar(32) + --- + dob : date + sex : enum('M', 'F', 'U') + """ + contents = [ + ("mouse1", "2020-09-01", "M"), + ("mouse2", "2020-03-19", "F"), + ("mouse3", "2020-08-23", "F"), + ] + + +class SessionA(dj.Lookup): + definition = """ + -> SubjectA + session_start_time: datetime + --- + session_dir='' : varchar(32) + """ + contents = [ + ("mouse1", "2020-12-01 12:32:34", ""), + ("mouse1", "2020-12-02 12:32:34", ""), + ("mouse1", "2020-12-03 12:32:34", ""), + ("mouse1", "2020-12-04 12:32:34", ""), + ] + + +class SessionStatusA(dj.Lookup): + definition = """ + -> SessionA + --- + status: enum('in_training', 'trained_1a', 'trained_1b', 'ready4ephys') + """ + contents = [ + ("mouse1", "2020-12-01 12:32:34", "in_training"), + ("mouse1", "2020-12-02 12:32:34", "trained_1a"), + ("mouse1", "2020-12-03 12:32:34", "trained_1b"), + ("mouse1", "2020-12-04 12:32:34", "ready4ephys"), + ] + + +class SessionDateA(dj.Lookup): + definition = """ + -> SubjectA + session_date: date + """ + contents = [ + ("mouse1", "2020-12-01"), + ("mouse1", "2020-12-02"), + ("mouse1", "2020-12-03"), + ("mouse1", "2020-12-04"), + ] + + +class Stimulus(dj.Lookup): + definition = """ + id: int + --- + contrast: int 
+ brightness: int + """ + + +class Longblob(dj.Manual): + definition = """ + id: int + --- + data: longblob + """ diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py new file mode 100644 index 000000000..104e4d1e4 --- /dev/null +++ b/tests/schema_advanced.py @@ -0,0 +1,137 @@ +import datajoint as dj + +LOCALS_ADVANCED = locals() + + +class Person(dj.Manual): + definition = """ + person_id : int + ---- + full_name : varchar(60) + sex : enum('M','F') + """ + + def fill(self): + """ + fill fake names from www.fakenamegenerator.com + """ + self.insert( + ( + (0, "May K. Hall", "F"), + (1, "Jeffrey E. Gillen", "M"), + (2, "Hanna R. Walters", "F"), + (3, "Russel S. James", "M"), + (4, "Robbin J. Fletcher", "F"), + (5, "Wade J. Sullivan", "M"), + (6, "Dorothy J. Chen", "F"), + (7, "Michael L. Kowalewski", "M"), + (8, "Kimberly J. Stringer", "F"), + (9, "Mark G. Hair", "M"), + (10, "Mary R. Thompson", "F"), + (11, "Graham C. Gilpin", "M"), + (12, "Nelda T. Ruggeri", "F"), + (13, "Bryan M. Cummings", "M"), + (14, "Sara C. Le", "F"), + (15, "Myron S. 
Jaramillo", "M"), + ) + ) + + +class Parent(dj.Manual): + definition = """ + -> Person + parent_sex : enum('M','F') + --- + -> Person.proj(parent='person_id') + """ + + def fill(self): + def make_parent(pid, parent): + return dict( + person_id=pid, + parent=parent, + parent_sex=(Person & {"person_id": parent}).fetch1("sex"), + ) + + self.insert( + make_parent(*r) + for r in ( + (0, 2), + (0, 3), + (1, 4), + (1, 5), + (2, 4), + (2, 5), + (3, 4), + (3, 7), + (4, 7), + (4, 8), + (5, 9), + (5, 10), + (6, 9), + (6, 10), + (7, 11), + (7, 12), + (8, 11), + (8, 14), + (9, 11), + (9, 12), + (10, 13), + (10, 14), + (11, 14), + (11, 15), + (12, 14), + (12, 15), + ) + ) + + +class Subject(dj.Manual): + definition = """ + subject : int + --- + -> [unique, nullable] Person + """ + + +class Prep(dj.Manual): + definition = """ + prep : int + """ + + +class Slice(dj.Manual): + definition = """ + -> Prep + slice : int + """ + + +class Cell(dj.Manual): + definition = """ + -> Slice + cell : int + """ + + +class InputCell(dj.Manual): + definition = """ # a synapse within the slice + -> Cell + -> Cell.proj(input="cell") + """ + + +class LocalSynapse(dj.Manual): + definition = """ # a synapse within the slice + -> Cell.proj(presynaptic='cell') + -> Cell.proj(postsynaptic='cell') + """ + + +class GlobalSynapse(dj.Manual): + # Mix old-style and new-style projected foreign keys + definition = """ + # a synapse within the slice + -> Cell.proj(pre_slice="slice", pre_cell="cell") + -> Cell.proj(post_slice="slice", post_cell="cell") + """ diff --git a/tests/schema_simple.py b/tests/schema_simple.py new file mode 100644 index 000000000..bb5c21ff5 --- /dev/null +++ b/tests/schema_simple.py @@ -0,0 +1,262 @@ +""" +A simple, abstract schema to test relational algebra +""" +import random +import datajoint as dj +import itertools +import hashlib +import uuid +import faker +import numpy as np +from datetime import date, timedelta + +LOCALS_SIMPLE = locals() + + +class IJ(dj.Lookup): + definition = 
""" # tests restrictions + i : int + j : int + """ + contents = list(dict(i=i, j=j + 2) for i in range(3) for j in range(3)) + + +class JI(dj.Lookup): + definition = """ # tests restrictions by relations when attributes are reordered + j : int + i : int + """ + contents = list(dict(i=i + 1, j=j) for i in range(3) for j in range(3)) + + +class A(dj.Lookup): + definition = """ + id_a :int + --- + cond_in_a :tinyint + """ + contents = [(i, i % 4 > i % 3) for i in range(10)] + + +class B(dj.Computed): + definition = """ + -> A + id_b :int + --- + mu :float # mean value + sigma :float # standard deviation + n :smallint # number samples + """ + + class C(dj.Part): + definition = """ + -> B + id_c :int + --- + value :float # normally distributed variables according to parameters in B + """ + + def make(self, key): + random.seed(str(key)) + sub = B.C() + for i in range(4): + key["id_b"] = i + mu = random.normalvariate(0, 10) + sigma = random.lognormvariate(0, 4) + n = random.randint(0, 10) + self.insert1(dict(key, mu=mu, sigma=sigma, n=n)) + sub.insert( + dict(key, id_c=j, value=random.normalvariate(mu, sigma)) + for j in range(n) + ) + + +class L(dj.Lookup): + definition = """ + id_l: int + --- + cond_in_l :tinyint + """ + contents = [(i, i % 3 >= i % 5) for i in range(30)] + + +class D(dj.Computed): + definition = """ + -> A + id_d :int + --- + -> L + """ + + def _make_tuples(self, key): + # make reference to a random tuple from L + random.seed(str(key)) + lookup = list(L().fetch("KEY")) + self.insert(dict(key, id_d=i, **random.choice(lookup)) for i in range(4)) + + +class E(dj.Computed): + definition = """ + -> B + -> D + --- + -> L + """ + + class F(dj.Part): + definition = """ + -> E + id_f :int + --- + -> B.C + """ + + def make(self, key): + random.seed(str(key)) + self.insert1(dict(key, **random.choice(list(L().fetch("KEY"))))) + sub = E.F() + references = list((B.C() & key).fetch("KEY")) + random.shuffle(references) + sub.insert( + dict(key, id_f=i, **ref) + for i, 
ref in enumerate(references) + if random.getrandbits(1) + ) + + +class F(dj.Manual): + definition = """ + id: int + ---- + date=null: date + """ + + +class DataA(dj.Lookup): + definition = """ + idx : int + --- + a : int + """ + contents = list(zip(range(5), range(5))) + + +class DataB(dj.Lookup): + definition = """ + idx : int + --- + a : int + """ + contents = list(zip(range(5), range(5, 10))) + + +class Website(dj.Lookup): + definition = """ + url_hash : uuid + --- + url : varchar(1000) + """ + + def insert1_url(self, url): + hashed = hashlib.sha1() + hashed.update(url.encode()) + url_hash = uuid.UUID(bytes=hashed.digest()[:16]) + self.insert1(dict(url=url, url_hash=url_hash), skip_duplicates=True) + return url_hash + + +class Profile(dj.Manual): + definition = """ + ssn : char(11) + --- + name : varchar(70) + residence : varchar(255) + blood_group : enum('A+', 'A-', 'AB+', 'AB-', 'B+', 'B-', 'O+', 'O-') + username : varchar(120) + birthdate : date + job : varchar(120) + sex : enum('M', 'F') + """ + + class Website(dj.Part): + definition = """ + -> master + -> Website + """ + + def populate_random(self, n=10): + fake = faker.Faker() + faker.Faker.seed(0) # make test deterministic + for _ in range(n): + profile = fake.profile() + with self.connection.transaction: + self.insert1(profile, ignore_extra_fields=True) + for url in profile["website"]: + self.Website().insert1( + dict(ssn=profile["ssn"], url_hash=Website().insert1_url(url)) + ) + + +class TTestUpdate(dj.Lookup): + definition = """ + primary_key : int + --- + string_attr : varchar(255) + num_attr=null : float + blob_attr : longblob + """ + + contents = [ + (0, "my_string", 0.0, np.random.randn(10, 2)), + (1, "my_other_string", 1.0, np.random.randn(20, 1)), + ] + + +class ArgmaxTest(dj.Lookup): + definition = """ + primary_key : int + --- + secondary_key : char(2) + val : float + """ + + n = 10 + + @property + def contents(self): + n = self.n + yield from zip( + range(n**2), + 
itertools.chain(*itertools.repeat(tuple(map(chr, range(100, 100 + n))), n)), + np.random.rand(n**2), + ) + + +class ReservedWord(dj.Manual): + definition = """ + # Test of SQL reserved words + key : int + --- + in : varchar(25) + from : varchar(25) + int : int + select : varchar(25) + """ + + +class OutfitLaunch(dj.Lookup): + definition = """ + # Monthly released designer outfits + release_id: int + --- + day: date + """ + contents = [(0, date.today() - timedelta(days=15))] + + class OutfitPiece(dj.Part, dj.Lookup): + definition = """ + # Outfit piece associated with outfit + -> OutfitLaunch + piece: varchar(20) + """ + contents = [(0, "jeans"), (0, "sneakers"), (0, "polo")] diff --git a/tests_old/test_blob.py b/tests/test_blob.py similarity index 67% rename from tests_old/test_blob.py rename to tests/test_blob.py index 3765edc57..23de7be76 100644 --- a/tests_old/test_blob.py +++ b/tests/test_blob.py @@ -2,20 +2,12 @@ import timeit import numpy as np import uuid -from . import schema from decimal import Decimal from datetime import datetime from datajoint.blob import pack, unpack from numpy.testing import assert_array_equal -from nose.tools import ( - assert_equal, - assert_true, - assert_false, - assert_list_equal, - assert_set_equal, - assert_tuple_equal, - assert_dict_equal, -) +from pytest import approx +from .schema import * def test_pack(): @@ -24,19 +16,19 @@ def test_pack(): -3.7e-2, np.float64(3e31), -np.inf, - np.int8(-3), - np.uint8(-1), + np.array(-3).astype(np.uint8), + np.array(-1).astype(np.uint8), np.int16(-33), - np.uint16(-33), + np.array(-33).astype(np.uint16), np.int32(-3), - np.uint32(-1), + np.array(-1).astype(np.uint32), np.int64(373), - np.uint64(-3), + np.array(-3).astype(np.uint64), ): - assert_equal(x, unpack(pack(x)), "Scalars don't match!") + assert x == approx(unpack(pack(x)), rel=1e-6), "Scalars don't match!" 
x = np.nan - assert_true(np.isnan(unpack(pack(x))), "nan scalar did not match!") + assert np.isnan(unpack(pack(x))), "nan scalar did not match!" x = np.random.randn(8, 10) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") @@ -45,7 +37,7 @@ def test_pack(): assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") x = 7j - assert_equal(x, unpack(pack(x)), "Complex scalar does not match") + assert x == unpack(pack(x)), "Complex scalar does not match" x = np.float32(np.random.randn(3, 4, 5)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") @@ -54,41 +46,37 @@ def test_pack(): assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") x = None - assert_true(unpack(pack(x)) is None, "None did not match") + assert unpack(pack(x)) is None, "None did not match" x = -255 y = unpack(pack(x)) - assert_true( - x == y and isinstance(y, int) and not isinstance(y, np.ndarray), - "Scalar int did not match", - ) + assert ( + x == y and isinstance(y, int) and not isinstance(y, np.ndarray) + ), "Scalar int did not match" x = -25523987234234287910987234987098245697129798713407812347 y = unpack(pack(x)) - assert_true( - x == y and isinstance(y, int) and not isinstance(y, np.ndarray), - "Unbounded int did not match", - ) + assert ( + x == y and isinstance(y, int) and not isinstance(y, np.ndarray) + ), "Unbounded int did not match" x = 7.0 y = unpack(pack(x)) - assert_true( - x == y and isinstance(y, float) and not isinstance(y, np.ndarray), - "Scalar float did not match", - ) + assert ( + x == y and isinstance(y, float) and not isinstance(y, np.ndarray) + ), "Scalar float did not match" x = 7j y = unpack(pack(x)) - assert_true( - x == y and isinstance(y, complex) and not isinstance(y, np.ndarray), - "Complex scalar did not match", - ) + assert ( + x == y and isinstance(y, complex) and not isinstance(y, np.ndarray) + ), "Complex scalar did not match" x = True - assert_true(unpack(pack(x)) is True, "Scalar bool did not match") + assert unpack(pack(x)) 
is True, "Scalar bool did not match" x = [None] - assert_list_equal(x, unpack(pack(x))) + assert [None] == unpack(pack(x)) x = { "name": "Anonymous", @@ -98,22 +86,22 @@ def test_pack(): (11, 12): None, } y = unpack(pack(x)) - assert_dict_equal(x, y, "Dict do not match!") - assert_false( - isinstance(["range"][0], np.ndarray), "Scalar int was coerced into array." - ) + assert x == y, "Dict do not match!" + assert not isinstance( + ["range"][0], np.ndarray + ), "Scalar int was coerced into array." x = uuid.uuid4() - assert_equal(x, unpack(pack(x)), "UUID did not match") + assert x == unpack(pack(x)), "UUID did not match" x = Decimal("-112122121.000003000") - assert_equal(x, unpack(pack(x)), "Decimal did not pack/unpack correctly") + assert x == unpack(pack(x)), "Decimal did not pack/unpack correctly" x = [1, datetime.now(), {1: "one", "two": 2}, (1, 2)] - assert_list_equal(x, unpack(pack(x)), "List did not pack/unpack correctly") + assert x == unpack(pack(x)), "List did not pack/unpack correctly" x = (1, datetime.now(), {1: "one", "two": 2}, (uuid.uuid4(), 2)) - assert_tuple_equal(x, unpack(pack(x)), "Tuple did not pack/unpack correctly") + assert x == unpack(pack(x)), "Tuple did not pack/unpack correctly" x = ( 1, @@ -121,36 +109,34 @@ def test_pack(): {"yes!": [1, 2, np.array((3, 4))]}, ) y = unpack(pack(x)) - assert_dict_equal(x[1], y[1]) + assert x[1] == y[1] assert_array_equal(x[2]["yes!"][2], y[2]["yes!"][2]) x = {"elephant"} - assert_set_equal(x, unpack(pack(x)), "Set did not pack/unpack correctly") + assert x == unpack(pack(x)), "Set did not pack/unpack correctly" x = tuple(range(10)) - assert_tuple_equal( - x, unpack(pack(range(10))), "Iterator did not pack/unpack correctly" - ) + assert x == unpack(pack(range(10))), "Iterator did not pack/unpack correctly" x = Decimal("1.24") - assert_true(x == unpack(pack(x)), "Decimal object did not pack/unpack correctly") + assert x == approx(unpack(pack(x))), "Decimal object did not pack/unpack correctly" x = 
datetime.now() - assert_true(x == unpack(pack(x)), "Datetime object did not pack/unpack correctly") + assert x == unpack(pack(x)), "Datetime object did not pack/unpack correctly" x = np.bool_(True) - assert_true(x == unpack(pack(x)), "Numpy bool object did not pack/unpack correctly") + assert x == unpack(pack(x)), "Numpy bool object did not pack/unpack correctly" x = "test" - assert_true(x == unpack(pack(x)), "String object did not pack/unpack correctly") + assert x == unpack(pack(x)), "String object did not pack/unpack correctly" x = np.array(["yes"]) - assert_true( - x == unpack(pack(x)), "Numpy string array object did not pack/unpack correctly" - ) + assert x == unpack( + pack(x) + ), "Numpy string array object did not pack/unpack correctly" x = np.datetime64("1998").astype("datetime64[us]") - assert_true(x == unpack(pack(x))) + assert x == unpack(pack(x)) def test_recarrays(): @@ -183,18 +169,16 @@ def test_complex(): assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") -def test_insert_longblob(): +def test_insert_longblob(schema_any): insert_dj_blob = {"id": 1, "data": [1, 2, 3]} - schema.Longblob.insert1(insert_dj_blob) - assert (schema.Longblob & "id=1").fetch1() == insert_dj_blob - (schema.Longblob & "id=1").delete() + Longblob.insert1(insert_dj_blob) + assert (Longblob & "id=1").fetch1() == insert_dj_blob + (Longblob & "id=1").delete() query_mym_blob = {"id": 1, "data": np.array([1, 2, 3])} - schema.Longblob.insert1(query_mym_blob) - assert (schema.Longblob & "id=1").fetch1()["data"].all() == query_mym_blob[ - "data" - ].all() - (schema.Longblob & "id=1").delete() + Longblob.insert1(query_mym_blob) + assert (Longblob & "id=1").fetch1()["data"].all() == query_mym_blob["data"].all() + (Longblob & "id=1").delete() query_32_blob = ( "INSERT INTO djtest_test1.longblob (id, data) VALUES (1, " @@ -207,7 +191,7 @@ def test_insert_longblob(): ) dj.conn().query(query_32_blob).fetchall() dj.blob.use_32bit_dims = True - assert (schema.Longblob & 
"id=1").fetch1() == { + assert (Longblob & "id=1").fetch1() == { "id": 1, "data": np.rec.array( [ @@ -223,7 +207,7 @@ def test_insert_longblob(): dtype=[("hits", "O"), ("sides", "O"), ("tasks", "O"), ("stage", "O")], ), } - (schema.Longblob & "id=1").delete() + (Longblob & "id=1").delete() dj.blob.use_32bit_dims = False diff --git a/tests/test_blob_matlab.py b/tests/test_blob_matlab.py new file mode 100644 index 000000000..06154b1fc --- /dev/null +++ b/tests/test_blob_matlab.py @@ -0,0 +1,183 @@ +import numpy as np +import pytest +import datajoint as dj +from datajoint.blob import pack, unpack +from numpy.testing import assert_array_equal + +from . import PREFIX + + +class Blob(dj.Manual): + definition = """ # diverse types of blobs + id : int + ----- + comment : varchar(255) + blob : longblob + """ + + +@pytest.fixture(scope="module") +def schema(connection_test): + schema = dj.Schema(PREFIX + "_test1", locals(), connection=connection_test) + schema(Blob) + yield schema + schema.drop() + + +@pytest.fixture(scope="module") +def insert_blobs_func(schema): + def insert_blobs(): + """ + This function inserts blobs resulting from the following datajoint-matlab code: + + self.insert({ + 1 'simple string' 'character string' + 2 '1D vector' 1:15:180 + 3 'string array' {'string1' 'string2'} + 4 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) + 5 '3D double array' reshape(1:24, [2,3,4]) + 6 '3D uint8 array' reshape(uint8(1:24), [2,3,4]) + 7 '3D complex array' fftn(reshape(1:24, [2,3,4])) + }) + + and then dumped using the command + mysqldump -u username -p --hex-blob test_schema blob_table > blob.sql + """ + + schema.connection.query( + """ + INSERT INTO {table_name} VALUES + (1,'simple string',0x6D596D00410200000000000000010000000000000010000000000000000400000000000000630068006100720061006300740065007200200073007400720069006E006700), + (2,'1D 
vector',0x6D596D0041020000000000000001000000000000000C000000000000000600000000000000000000000000F03F00000000000030400000000000003F4000000000000047400000000000804E4000000000000053400000000000C056400000000000805A400000000000405E4000000000000061400000000000E062400000000000C06440), + (3,'string array',0x6D596D00430200000000000000010000000000000002000000000000002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E00670031002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E0067003200), + (4,'struct array',0x6D596D005302000000000000000100000000000000020000000000000002000000610062002900000000000000410200000000000000010000000000000001000000000000000600000000000000000000000000F03F9000000000000000530200000000000000010000000000000001000000000000000100000063006900000000000000410200000000000000030000000000000003000000000000000600000000000000000000000000204000000000000008400000000000001040000000000000F03F0000000000001440000000000000224000000000000018400000000000001C40000000000000004029000000000000004102000000000000000100000000000000010000000000000006000000000000000000000000000040100100000000000053020000000000000001000000000000000100000000000000010000004300E9000000000000004102000000000000000500000000000000050000000000000006000000000000000000000000003140000000000000374000000000000010400000000000002440000000000000264000000000000038400000000000001440000000000000184000000000000028400000000000003240000000000000F03F0000000000001C400000000000002A400000000000003340000000000000394000000000000020400000000000002C400000000000003440000000000000354000000000000000400000000000002E400000000000003040000000000000364000000000000008400000000000002240), + (5,'3D double 
array',0x6D596D004103000000000000000200000000000000030000000000000004000000000000000600000000000000000000000000F03F000000000000004000000000000008400000000000001040000000000000144000000000000018400000000000001C40000000000000204000000000000022400000000000002440000000000000264000000000000028400000000000002A400000000000002C400000000000002E40000000000000304000000000000031400000000000003240000000000000334000000000000034400000000000003540000000000000364000000000000037400000000000003840), + (6,'3D uint8 array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000009000000000000000102030405060708090A0B0C0D0E0F101112131415161718), + (7,'3D complex array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000006000000010000000000000000C0724000000000000028C000000000000038C0000000000000000000000000000038C0000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000AA4C58E87AB62B400000000000000000AA4C58E87AB62BC0000000000000008000000000000052400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000008000000000000052C000000000000000800000000000000080000000000000008000000000000000800000000000000080 + ); + """.format( + table_name=Blob.full_table_name + ) + ) + + yield insert_blobs + + +@pytest.fixture(scope="class") +def setup_class(schema, insert_blobs_func): + assert not dj.config["safemode"], "safemode must be disabled" + Blob().delete() + insert_blobs_func() + + +class TestFetch: + @staticmethod + def test_complex_matlab_blobs(setup_class): + """ + test correct de-serialization of various blob types + """ + blobs = Blob().fetch("blob", 
order_by="KEY") + + blob = blobs[0] # 'simple string' 'character string' + assert blob[0] == "character string" + + blob = blobs[1] # '1D vector' 1:15:180 + assert_array_equal(blob, np.r_[1:180:15][None, :]) + assert_array_equal(blob, unpack(pack(blob))) + + blob = blobs[2] # 'string array' {'string1' 'string2'} + assert isinstance(blob, dj.MatCell) + assert_array_equal(blob, np.array([["string1", "string2"]])) + assert_array_equal(blob, unpack(pack(blob))) + + blob = blobs[ + 3 + ] # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) + assert isinstance(blob, dj.MatStruct) + assert tuple(blob.dtype.names) == ("a", "b") + assert_array_equal(blob.a[0, 0], np.array([[1.0]])) + assert_array_equal(blob.a[0, 1], np.array([[2.0]])) + assert isinstance(blob.b[0, 1], dj.MatStruct) + assert tuple(blob.b[0, 1].C[0, 0].shape) == (5, 5) + b = unpack(pack(blob)) + assert_array_equal(b[0, 0].b[0, 0].c, blob[0, 0].b[0, 0].c) + assert_array_equal(b[0, 1].b[0, 0].C, blob[0, 1].b[0, 0].C) + + blob = blobs[4] # '3D double array' reshape(1:24, [2,3,4]) + assert_array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "float64" + assert_array_equal(blob, unpack(pack(blob))) + + blob = blobs[5] # reshape(uint8(1:24), [2,3,4]) + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "uint8" + assert_array_equal(blob, unpack(pack(blob))) + + blob = blobs[6] # fftn(reshape(1:24, [2,3,4])) + assert tuple(blob.shape) == (2, 3, 4) + assert blob.dtype == "complex128" + assert_array_equal(blob, unpack(pack(blob))) + + @staticmethod + def test_complex_matlab_squeeze(setup_class): + """ + test correct de-serialization of various blob types + """ + blob = (Blob & "id=1").fetch1( + "blob", squeeze=True + ) # 'simple string' 'character string' + assert blob == "character string" + + blob = (Blob & "id=2").fetch1( + "blob", squeeze=True + ) # '1D vector' 1:15:180 + assert_array_equal(blob, 
np.r_[1:180:15]) + + blob = (Blob & "id=3").fetch1( + "blob", squeeze=True + ) # 'string array' {'string1' 'string2'} + assert isinstance(blob, dj.MatCell) + assert_array_equal(blob, np.array(["string1", "string2"])) + + blob = (Blob & "id=4").fetch1( + "blob", squeeze=True + ) # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) + assert isinstance(blob, dj.MatStruct) + assert tuple(blob.dtype.names) == ("a", "b") + assert_array_equal( + blob.a, + np.array( + [ + 1.0, + 2, + ] + ), + ) + assert isinstance(blob[1].b, dj.MatStruct) + assert tuple(blob[1].b.C.item().shape) == (5, 5) + + blob = (Blob & "id=5").fetch1( + "blob", squeeze=True + ) # '3D double array' reshape(1:24, [2,3,4]) + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "float64" + + blob = (Blob & "id=6").fetch1( + "blob", squeeze=True + ) # reshape(uint8(1:24), [2,3,4]) + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "uint8" + + blob = (Blob & "id=7").fetch1( + "blob", squeeze=True + ) # fftn(reshape(1:24, [2,3,4])) + assert tuple(blob.shape) == (2, 3, 4) + assert blob.dtype == "complex128" + + def test_iter(self, setup_class): + """ + test iterator over the entity set + """ + from_iter = {d["id"]: d for d in Blob()} + assert len(from_iter) == len(Blob()) + assert from_iter[1]["blob"] == "character string" diff --git a/tests/test_connection.py b/tests/test_connection.py index 1916da951..795d3761e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,8 +5,7 @@ import datajoint as dj from datajoint import DataJointError import numpy as np -from . import CONN_INFO_ROOT, connection_root, connection_test - +from . import CONN_INFO_ROOT from . 
import PREFIX import pytest diff --git a/tests_old/test_dependencies.py b/tests/test_dependencies.py similarity index 62% rename from tests_old/test_dependencies.py rename to tests/test_dependencies.py index c359b602a..312e5f8ad 100644 --- a/tests_old/test_dependencies.py +++ b/tests/test_dependencies.py @@ -1,60 +1,56 @@ -from nose.tools import assert_true, raises, assert_list_equal -from .schema import * +import datajoint as dj +from datajoint import errors +from pytest import raises from datajoint.dependencies import unite_master_parts +from .schema import * def test_unite_master_parts(): - assert_list_equal( - unite_master_parts( - [ - "`s`.`a`", - "`s`.`a__q`", - "`s`.`b`", - "`s`.`c`", - "`s`.`c__q`", - "`s`.`b__q`", - "`s`.`d`", - "`s`.`a__r`", - ] - ), + assert unite_master_parts( [ "`s`.`a`", "`s`.`a__q`", - "`s`.`a__r`", "`s`.`b`", - "`s`.`b__q`", "`s`.`c`", "`s`.`c__q`", + "`s`.`b__q`", "`s`.`d`", - ], - ) - assert_list_equal( - unite_master_parts( - [ - "`lab`.`#equipment`", - "`cells`.`cell_analysis_method`", - "`cells`.`cell_analysis_method_task_type`", - "`cells`.`cell_analysis_method_users`", - "`cells`.`favorite_selection`", - "`cells`.`cell_analysis_method__cell_selection_params`", - "`lab`.`#equipment__config`", - "`cells`.`cell_analysis_method__field_detect_params`", - ] - ), + "`s`.`a__r`", + ] + ) == [ + "`s`.`a`", + "`s`.`a__q`", + "`s`.`a__r`", + "`s`.`b`", + "`s`.`b__q`", + "`s`.`c`", + "`s`.`c__q`", + "`s`.`d`", + ] + assert unite_master_parts( [ "`lab`.`#equipment`", - "`lab`.`#equipment__config`", "`cells`.`cell_analysis_method`", - "`cells`.`cell_analysis_method__cell_selection_params`", - "`cells`.`cell_analysis_method__field_detect_params`", "`cells`.`cell_analysis_method_task_type`", "`cells`.`cell_analysis_method_users`", "`cells`.`favorite_selection`", - ], - ) - - -def test_nullable_dependency(): + "`cells`.`cell_analysis_method__cell_selection_params`", + "`lab`.`#equipment__config`", + 
"`cells`.`cell_analysis_method__field_detect_params`", + ] + ) == [ + "`lab`.`#equipment`", + "`lab`.`#equipment__config`", + "`cells`.`cell_analysis_method`", + "`cells`.`cell_analysis_method__cell_selection_params`", + "`cells`.`cell_analysis_method__field_detect_params`", + "`cells`.`cell_analysis_method_task_type`", + "`cells`.`cell_analysis_method_users`", + "`cells`.`favorite_selection`", + ] + + +def test_nullable_dependency(schema_any): """test nullable unique foreign key""" # Thing C has a nullable dependency on B whose primary key is composite a = ThingA() @@ -80,11 +76,10 @@ def test_nullable_dependency(): c.insert1(dict(a=3, b1=1, b2=1)) c.insert1(dict(a=4, b1=1, b2=2)) - assert_true(len(c) == len(c.fetch()) == 5) + assert len(c) == len(c.fetch()) == 5 -@raises(dj.errors.DuplicateError) -def test_unique_dependency(): +def test_unique_dependency(schema_any): """test nullable unique foreign key""" # Thing C has a nullable dependency on B whose primary key is composite @@ -104,4 +99,5 @@ def test_unique_dependency(): c.insert1(dict(a=0, b1=1, b2=1)) # duplicate foreign key attributes = not ok - c.insert1(dict(a=1, b1=1, b2=1)) + with raises(errors.DuplicateError): + c.insert1(dict(a=1, b1=1, b2=1)) diff --git a/tests/test_erd.py b/tests/test_erd.py new file mode 100644 index 000000000..f1274ec1b --- /dev/null +++ b/tests/test_erd.py @@ -0,0 +1,64 @@ +import datajoint as dj +from .schema_simple import LOCALS_SIMPLE, A, B, D, E, L, OutfitLaunch +from .schema_advanced import * + + +def test_decorator(schema_simp): + assert issubclass(A, dj.Lookup) + assert not issubclass(A, dj.Part) + assert B.database == schema_simp.database + assert issubclass(B.C, dj.Part) + assert B.C.database == schema_simp.database + assert B.C.master is B and E.F.master is E + + +def test_dependencies(schema_simp): + deps = schema_simp.connection.dependencies + deps.load() + assert all(cls.full_table_name in deps for cls in (A, B, B.C, D, E, E.F, L)) + assert set(A().children()) == 
set([B.full_table_name, D.full_table_name]) + assert set(D().parents(primary=True)) == set([A.full_table_name]) + assert set(D().parents(primary=False)) == set([L.full_table_name]) + assert set(deps.descendants(L.full_table_name)).issubset( + cls.full_table_name for cls in (L, D, E, E.F) + ) + + +def test_erd(schema_simp): + assert dj.diagram.diagram_active, "Failed to import networkx and pydot" + erd = dj.ERD(schema_simp, context=LOCALS_SIMPLE) + graph = erd._make_graph() + assert set(cls.__name__ for cls in (A, B, D, E, L)).issubset(graph.nodes()) + + +def test_erd_algebra(schema_simp): + erd0 = dj.ERD(B) + erd1 = erd0 + 3 + erd2 = dj.Di(E) - 3 + erd3 = erd1 * erd2 + erd4 = (erd0 + E).add_parts() - B - E + assert erd0.nodes_to_show == set(cls.full_table_name for cls in [B]) + assert erd1.nodes_to_show == set(cls.full_table_name for cls in (B, B.C, E, E.F)) + assert erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L)) + assert erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E)) + assert erd4.nodes_to_show == set(cls.full_table_name for cls in (B.C, E.F)) + + +def test_repr_svg(schema_adv): + erd = dj.ERD(schema_adv, context=locals()) + svg = erd._repr_svg_() + assert svg.startswith("<svg") and svg.endswith("svg>") + + +def test_make_image(schema_simp): + erd = dj.ERD(schema_simp, context=locals()) + img = erd.make_image() + assert img.ndim == 3 and img.shape[2] in (3, 4) + + +def test_part_table_parsing(schema_simp): + # https://github.com/datajoint/datajoint-python/issues/882 + erd = dj.Di(schema_simp) + graph = erd._make_graph() + assert "OutfitLaunch" in graph.nodes() + assert "OutfitLaunch.OutfitPiece" in graph.nodes() diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py new file mode 100644 index 000000000..18daa952a --- /dev/null +++ b/tests/test_foreign_keys.py @@ -0,0 +1,47 @@ +from datajoint.declare import declare +from .schema_advanced import * + + +def test_aliased_fk(schema_adv): + person = Person() + parent = Parent() + 
person.delete() + assert not person + assert not parent + person.fill() + parent.fill() + assert person + assert parent + link = person.proj(parent_name="full_name", parent="person_id") + parents = person * parent * link + parents &= dict(full_name="May K. Hall") + assert set(parents.fetch("parent_name")) == {"Hanna R. Walters", "Russel S. James"} + delete_count = person.delete() + assert delete_count == 16 + + +def test_describe(schema_adv): + """real_definition should match original definition""" + for rel in (LocalSynapse, GlobalSynapse): + describe = rel.describe() + s1 = declare(rel.full_table_name, rel.definition, schema_adv.context)[0].split( + "\n" + ) + s2 = declare(rel.full_table_name, describe, globals())[0].split("\n") + for c1, c2 in zip(s1, s2): + assert c1 == c2 + + +def test_delete(schema_adv): + person = Person() + parent = Parent() + person.delete() + assert not person + assert not parent + person.fill() + parent.fill() + assert parent + original_len = len(parent) + to_delete = len(parent & "11 in (person_id, parent)") + (person & "person_id=11").delete() + assert to_delete and len(parent) == original_len - to_delete diff --git a/tests_old/test_groupby.py b/tests/test_groupby.py similarity index 93% rename from tests_old/test_groupby.py rename to tests/test_groupby.py index 3d3be530e..109972760 100644 --- a/tests_old/test_groupby.py +++ b/tests/test_groupby.py @@ -1,7 +1,7 @@ from .schema_simple import A, D -def test_aggr_with_proj(): +def test_aggr_with_proj(schema_simp): # issue #944 - only breaks with MariaDB # MariaDB implements the SQL:1992 standard that prohibits fields in the select statement that are # not also in the GROUP BY statement. 
diff --git a/tests/test_hash.py b/tests/test_hash.py new file mode 100644 index 000000000..a88c45316 --- /dev/null +++ b/tests/test_hash.py @@ -0,0 +1,6 @@ +from datajoint import hash + + +def test_hash(): + assert hash.uuid_from_buffer(b"abc").hex == "900150983cd24fb0d6963f7d28e17f72" + assert hash.uuid_from_buffer(b"").hex == "d41d8cd98f00b204e9800998ecf8427e" diff --git a/tests_old/test_json.py b/tests/test_json.py similarity index 98% rename from tests_old/test_json.py rename to tests/test_json.py index b9b13e4ee..760475a1a 100644 --- a/tests_old/test_json.py +++ b/tests/test_json.py @@ -2,12 +2,10 @@ from datajoint.declare import declare import datajoint as dj import numpy as np -from distutils.version import LooseVersion +from packaging.version import Version from . import PREFIX -if LooseVersion(dj.conn().query("select @@version;").fetchone()[0]) >= LooseVersion( - "8.0.0" -): +if Version(dj.conn().query("select @@version;").fetchone()[0]) >= Version("8.0.0"): schema = dj.Schema(PREFIX + "_json") Team = None diff --git a/tests/test_log.py b/tests/test_log.py new file mode 100644 index 000000000..4b6e64613 --- /dev/null +++ b/tests/test_log.py @@ -0,0 +1,5 @@ +def test_log(schema_any): + ts, events = (schema_any.log & 'event like "Declared%%"').fetch( + "timestamp", "event" + ) + assert len(ts) >= 2 diff --git a/tests/test_nan.py b/tests/test_nan.py new file mode 100644 index 000000000..299c0d9f8 --- /dev/null +++ b/tests/test_nan.py @@ -0,0 +1,47 @@ +import numpy as np +import datajoint as dj +from . 
import PREFIX +import pytest + + +class NanTest(dj.Manual): + definition = """ + id :int + --- + value=null :double + """ + + +@pytest.fixture(scope="module") +def schema(connection_test): + schema = dj.Schema(PREFIX + "_nantest", connection=connection_test) + schema(NanTest) + yield schema + schema.drop() + + +@pytest.fixture(scope="class") +def setup_class(request, schema): + rel = NanTest() + with dj.config(safemode=False): + rel.delete() + a = np.array([0, 1 / 3, np.nan, np.pi, np.nan]) + rel.insert(((i, value) for i, value in enumerate(a))) + request.cls.rel = rel + request.cls.a = a + + +class TestNaNInsert: + def test_insert_nan(self, setup_class): + """Test fetching of null values""" + b = self.rel.fetch("value", order_by="id") + assert (np.isnan(self.a) == np.isnan(b)).all(), "incorrect handling of Nans" + assert np.allclose( + self.a[np.logical_not(np.isnan(self.a))], b[np.logical_not(np.isnan(b))] + ), "incorrect storage of floats" + + def test_nulls_do_not_affect_primary_keys(self, setup_class): + """Test against a case that previously caused a bug when skipping existing entries.""" + self.rel.insert( + ((i, value) for i, value in enumerate(self.a)), skip_duplicates=True + ) diff --git a/tests_old/test_plugin.py b/tests/test_plugin.py similarity index 100% rename from tests_old/test_plugin.py rename to tests/test_plugin.py diff --git a/tests/test_relation_u.py b/tests/test_relation_u.py new file mode 100644 index 000000000..d225bccbb --- /dev/null +++ b/tests/test_relation_u.py @@ -0,0 +1,88 @@ +import pytest +import datajoint as dj +from pytest import raises +from .schema import * +from .schema_simple import * + + +@pytest.fixture(scope="class") +def setup_class(request, schema_any): + request.cls.user = User() + request.cls.language = Language() + request.cls.subject = Subject() + request.cls.experiment = Experiment() + request.cls.trial = Trial() + request.cls.ephys = Ephys() + request.cls.channel = Ephys.Channel() + request.cls.img = Image() + 
request.cls.trash = UberTrash() + + +class TestU: + """ + Test tables: insert, delete + """ + + def test_restriction(self, setup_class): + language_set = {s[1] for s in self.language.contents} + rel = dj.U("language") & self.language + assert list(rel.heading.names) == ["language"] + assert len(rel) == len(language_set) + assert set(rel.fetch("language")) == language_set + # Test for issue #342 + rel = self.trial * dj.U("start_time") + assert list(rel.primary_key) == self.trial.primary_key + ["start_time"] + assert list(rel.primary_key) == list((rel & "trial_id>3").primary_key) + assert list((dj.U("start_time") & self.trial).primary_key) == ["start_time"] + + def test_invalid_restriction(self, setup_class): + with raises(dj.DataJointError): + result = dj.U("color") & dict(color="red") + + def test_ineffective_restriction(self, setup_class): + rel = self.language & dj.U("language") + assert rel.make_sql() == self.language.make_sql() + + def test_join(self, setup_class): + rel = self.experiment * dj.U("experiment_date") + assert self.experiment.primary_key == ["subject_id", "experiment_id"] + assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] + + rel = dj.U("experiment_date") * self.experiment + assert self.experiment.primary_key == ["subject_id", "experiment_id"] + assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] + + def test_invalid_join(self, setup_class): + with raises(dj.DataJointError): + rel = dj.U("language") * dict(language="English") + + def test_repr_without_attrs(self, setup_class): + """test dj.U() display""" + query = dj.U().aggr(Language, n="count(*)") + repr(query) + + def test_aggregations(self, setup_class): + lang = Language() + # test total aggregation on expression object + n1 = dj.U().aggr(lang, n="count(*)").fetch1("n") + assert n1 == len(lang.fetch()) + # test total aggregation on expression class + n2 = dj.U().aggr(Language, n="count(*)").fetch1("n") + assert n1 == n2 + rel = 
dj.U("language").aggr(Language, number_of_speakers="count(*)") + assert len(rel) == len(set(l[1] for l in Language.contents)) + assert (rel & 'language="English"').fetch1("number_of_speakers") == 3 + + def test_argmax(self, setup_class): + rel = TTest() + # get the tuples corresponding to the maximum value + mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value" + assert mx.fetch("value")[0] == max(rel.fetch("value")) + + def test_aggr(self, setup_class, schema_simp): + rel = ArgmaxTest() + amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)") + amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)") + assert ( + len(amax1) == len(amax2) == rel.n + ), "Aggregated argmax with join and restriction does not yield the same length." diff --git a/tests/test_schema_keywords.py b/tests/test_schema_keywords.py new file mode 100644 index 000000000..c8b7d5a24 --- /dev/null +++ b/tests/test_schema_keywords.py @@ -0,0 +1,50 @@ +from . import PREFIX +import datajoint as dj +import pytest + + +class A(dj.Manual): + definition = """ + a_id: int # a id + """ + + +class B(dj.Manual): + source = None + definition = """ + -> self.source + b_id: int # b id + """ + + class H(dj.Part): + definition = """ + -> master + name: varchar(128) # name + """ + + class C(dj.Part): + definition = """ + -> master + -> master.H + """ + + +class D(B): + source = A + + +@pytest.fixture(scope="module") +def schema(connection_test): + schema = dj.Schema(PREFIX + "_keywords", connection=connection_test) + schema(A) + schema(D) + yield schema + schema.drop() + + +def test_inherited_part_table(schema): + assert "a_id" in D().heading.attributes + assert "b_id" in D().heading.attributes + assert "a_id" in D.C().heading.attributes + assert "b_id" in D.C().heading.attributes + assert "name" in D.C().heading.attributes diff --git a/tests_old/test_settings.py b/tests/test_settings.py similarity index 69% rename from tests_old/test_settings.py rename to 
tests/test_settings.py index 63c3dad36..b937d5ad3 100644 --- a/tests_old/test_settings.py +++ b/tests/test_settings.py @@ -1,8 +1,8 @@ import pprint import random import string -from datajoint import settings -from nose.tools import assert_true, assert_equal, raises +import pytest +from datajoint import DataJointError, settings import datajoint as dj import os @@ -14,7 +14,7 @@ def test_load_save(): dj.config.save("tmp.json") conf = settings.Config() conf.load("tmp.json") - assert_true(conf == dj.config, "Two config files do not match.") + assert conf == dj.config os.remove("tmp.json") @@ -25,7 +25,7 @@ def test_singleton(): conf.load("tmp.json") conf["dummy.val"] = 2 - assert_true(conf == dj.config, "Config does not behave like a singleton.") + assert conf == dj.config os.remove("tmp.json") @@ -34,36 +34,36 @@ def test_singleton2(): conf = settings.Config() conf["dummy.val"] = 2 _ = settings.Config() # a new instance should not delete dummy.val - assert_true(conf["dummy.val"] == 2, "Config does not behave like a singleton.") + assert conf["dummy.val"] == 2 -@raises(dj.DataJointError) def test_validator(): """Testing validator""" - dj.config["database.port"] = "harbor" + with pytest.raises(DataJointError): + dj.config["database.port"] = "harbor" def test_del(): """Testing del""" dj.config["peter"] = 2 - assert_true("peter" in dj.config) + assert "peter" in dj.config del dj.config["peter"] - assert_true("peter" not in dj.config) + assert "peter" not in dj.config def test_len(): """Testing len""" - assert_equal(len(dj.config), len(dj.config._conf)) + assert len(dj.config) == len(dj.config._conf) def test_str(): """Testing str""" - assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) + assert str(dj.config) == pprint.pformat(dj.config._conf, indent=4) def test_repr(): """Testing repr""" - assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) + assert repr(dj.config) == pprint.pformat(dj.config._conf, indent=4) def test_save(): @@ -76,7 +76,7 @@ def 
test_save(): os.rename(settings.LOCALCONFIG, tmpfile) moved = True dj.config.save_local() - assert_true(os.path.isfile(settings.LOCALCONFIG)) + assert os.path.isfile(settings.LOCALCONFIG) if moved: os.rename(tmpfile, settings.LOCALCONFIG) @@ -101,5 +101,5 @@ def test_contextmanager(): """Testing context manager""" dj.config["arbitrary.stuff"] = 7 with dj.config(arbitrary__stuff=10): - assert_true(dj.config["arbitrary.stuff"] == 10) - assert_true(dj.config["arbitrary.stuff"] == 7) + assert dj.config["arbitrary.stuff"] == 10 + assert dj.config["arbitrary.stuff"] == 7 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..936badb1c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,33 @@ +""" +Collection of test cases to test core module. +""" +from datajoint import DataJointError +from datajoint.utils import from_camel_case, to_camel_case +import pytest + + +def setup(): + pass + + +def teardown(): + pass + + +def test_from_camel_case(): + assert from_camel_case("AllGroups") == "all_groups" + with pytest.raises(DataJointError): + from_camel_case("repNames") + with pytest.raises(DataJointError): + from_camel_case("10_all") + with pytest.raises(DataJointError): + from_camel_case("hello world") + with pytest.raises(DataJointError): + from_camel_case("#baisc_names") + + +def test_to_camel_case(): + assert to_camel_case("all_groups") == "AllGroups" + assert to_camel_case("hello") == "Hello" + assert to_camel_case("this_is_a_sample_case") == "ThisIsASampleCase" + assert to_camel_case("This_is_Mixed") == "ThisIsMixed" diff --git a/tests/test_virtual_module.py b/tests/test_virtual_module.py new file mode 100644 index 000000000..bd8a0c754 --- /dev/null +++ b/tests/test_virtual_module.py @@ -0,0 +1,7 @@ +import datajoint as dj +from datajoint.user_tables import UserTable + + +def test_virtual_module(schema_any, connection_test): + module = dj.VirtualModule("module", schema_any.database, connection=connection_test) + assert 
issubclass(module.Experiment, UserTable) diff --git a/tests_old/test_blob_matlab.py b/tests_old/test_blob_matlab.py deleted file mode 100644 index 6104c9291..000000000 --- a/tests_old/test_blob_matlab.py +++ /dev/null @@ -1,175 +0,0 @@ -import numpy as np -import datajoint as dj -from datajoint.blob import pack, unpack - -from nose.tools import assert_equal, assert_true, assert_tuple_equal, assert_false -from numpy.testing import assert_array_equal - -from . import PREFIX, CONN_INFO - -schema = dj.Schema(PREFIX + "_test1", locals(), connection=dj.conn(**CONN_INFO)) - - -@schema -class Blob(dj.Manual): - definition = """ # diverse types of blobs - id : int - ----- - comment : varchar(255) - blob : longblob - """ - - -def insert_blobs(): - """ - This function inserts blobs resulting from the following datajoint-matlab code: - - self.insert({ - 1 'simple string' 'character string' - 2 '1D vector' 1:15:180 - 3 'string array' {'string1' 'string2'} - 4 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - 5 '3D double array' reshape(1:24, [2,3,4]) - 6 '3D uint8 array' reshape(uint8(1:24), [2,3,4]) - 7 '3D complex array' fftn(reshape(1:24, [2,3,4])) - }) - - and then dumped using the command - mysqldump -u username -p --hex-blob test_schema blob_table > blob.sql - """ - - schema.connection.query( - """ - INSERT INTO {table_name} VALUES - (1,'simple string',0x6D596D00410200000000000000010000000000000010000000000000000400000000000000630068006100720061006300740065007200200073007400720069006E006700), - (2,'1D vector',0x6D596D0041020000000000000001000000000000000C000000000000000600000000000000000000000000F03F00000000000030400000000000003F4000000000000047400000000000804E4000000000000053400000000000C056400000000000805A400000000000405E4000000000000061400000000000E062400000000000C06440), - (3,'string 
array',0x6D596D00430200000000000000010000000000000002000000000000002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E00670031002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E0067003200), - (4,'struct array',0x6D596D005302000000000000000100000000000000020000000000000002000000610062002900000000000000410200000000000000010000000000000001000000000000000600000000000000000000000000F03F9000000000000000530200000000000000010000000000000001000000000000000100000063006900000000000000410200000000000000030000000000000003000000000000000600000000000000000000000000204000000000000008400000000000001040000000000000F03F0000000000001440000000000000224000000000000018400000000000001C40000000000000004029000000000000004102000000000000000100000000000000010000000000000006000000000000000000000000000040100100000000000053020000000000000001000000000000000100000000000000010000004300E9000000000000004102000000000000000500000000000000050000000000000006000000000000000000000000003140000000000000374000000000000010400000000000002440000000000000264000000000000038400000000000001440000000000000184000000000000028400000000000003240000000000000F03F0000000000001C400000000000002A400000000000003340000000000000394000000000000020400000000000002C400000000000003440000000000000354000000000000000400000000000002E400000000000003040000000000000364000000000000008400000000000002240), - (5,'3D double array',0x6D596D004103000000000000000200000000000000030000000000000004000000000000000600000000000000000000000000F03F000000000000004000000000000008400000000000001040000000000000144000000000000018400000000000001C40000000000000204000000000000022400000000000002440000000000000264000000000000028400000000000002A400000000000002C400000000000002E40000000000000304000000000000031400000000000003240000000000000334000000000000034400000000000003540000000000000364000000000000037400000000000003840), - (6,'3D uint8 
array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000009000000000000000102030405060708090A0B0C0D0E0F101112131415161718), - (7,'3D complex array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000006000000010000000000000000C0724000000000000028C000000000000038C0000000000000000000000000000038C0000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000AA4C58E87AB62B400000000000000000AA4C58E87AB62BC0000000000000008000000000000052400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000008000000000000052C000000000000000800000000000000080000000000000008000000000000000800000000000000080 - ); - """.format( - table_name=Blob.full_table_name - ) - ) - - -class TestFetch: - @classmethod - def setup_class(cls): - assert_false(dj.config["safemode"], "safemode must be disabled") - Blob().delete() - insert_blobs() - - @staticmethod - def test_complex_matlab_blobs(): - """ - test correct de-serialization of various blob types - """ - blobs = Blob().fetch("blob", order_by="KEY") - - blob = blobs[0] # 'simple string' 'character string' - assert_equal(blob[0], "character string") - - blob = blobs[1] # '1D vector' 1:15:180 - assert_array_equal(blob, np.r_[1:180:15][None, :]) - assert_array_equal(blob, unpack(pack(blob))) - - blob = blobs[2] # 'string array' {'string1' 'string2'} - assert_true(isinstance(blob, dj.MatCell)) - assert_array_equal(blob, np.array([["string1", "string2"]])) - assert_array_equal(blob, unpack(pack(blob))) - - blob = blobs[ - 3 - ] # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - 
assert_true(isinstance(blob, dj.MatStruct)) - assert_tuple_equal(blob.dtype.names, ("a", "b")) - assert_array_equal(blob.a[0, 0], np.array([[1.0]])) - assert_array_equal(blob.a[0, 1], np.array([[2.0]])) - assert_true(isinstance(blob.b[0, 1], dj.MatStruct)) - assert_tuple_equal(blob.b[0, 1].C[0, 0].shape, (5, 5)) - b = unpack(pack(blob)) - assert_array_equal(b[0, 0].b[0, 0].c, blob[0, 0].b[0, 0].c) - assert_array_equal(b[0, 1].b[0, 0].C, blob[0, 1].b[0, 0].C) - - blob = blobs[4] # '3D double array' reshape(1:24, [2,3,4]) - assert_array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) - assert_true(blob.dtype == "float64") - assert_array_equal(blob, unpack(pack(blob))) - - blob = blobs[5] # reshape(uint8(1:24), [2,3,4]) - assert_true(np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F"))) - assert_true(blob.dtype == "uint8") - assert_array_equal(blob, unpack(pack(blob))) - - blob = blobs[6] # fftn(reshape(1:24, [2,3,4])) - assert_tuple_equal(blob.shape, (2, 3, 4)) - assert_true(blob.dtype == "complex128") - assert_array_equal(blob, unpack(pack(blob))) - - @staticmethod - def test_complex_matlab_squeeze(): - """ - test correct de-serialization of various blob types - """ - blob = (Blob & "id=1").fetch1( - "blob", squeeze=True - ) # 'simple string' 'character string' - assert_equal(blob, "character string") - - blob = (Blob & "id=2").fetch1( - "blob", squeeze=True - ) # '1D vector' 1:15:180 - assert_array_equal(blob, np.r_[1:180:15]) - - blob = (Blob & "id=3").fetch1( - "blob", squeeze=True - ) # 'string array' {'string1' 'string2'} - assert_true(isinstance(blob, dj.MatCell)) - assert_array_equal(blob, np.array(["string1", "string2"])) - - blob = (Blob & "id=4").fetch1( - "blob", squeeze=True - ) # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - assert_true(isinstance(blob, dj.MatStruct)) - assert_tuple_equal(blob.dtype.names, ("a", "b")) - assert_array_equal( - blob.a, - np.array( - [ - 1.0, - 2, - ] - ), - ) - 
assert_true(isinstance(blob[1].b, dj.MatStruct)) - assert_tuple_equal(blob[1].b.C.item().shape, (5, 5)) - - blob = (Blob & "id=5").fetch1( - "blob", squeeze=True - ) # '3D double array' reshape(1:24, [2,3,4]) - assert_true(np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F"))) - assert_true(blob.dtype == "float64") - - blob = (Blob & "id=6").fetch1( - "blob", squeeze=True - ) # reshape(uint8(1:24), [2,3,4]) - assert_true(np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F"))) - assert_true(blob.dtype == "uint8") - - blob = (Blob & "id=7").fetch1( - "blob", squeeze=True - ) # fftn(reshape(1:24, [2,3,4])) - assert_tuple_equal(blob.shape, (2, 3, 4)) - assert_true(blob.dtype == "complex128") - - @staticmethod - def test_iter(): - """ - test iterator over the entity set - """ - from_iter = {d["id"]: d for d in Blob()} - assert_equal(len(from_iter), len(Blob())) - assert_equal(from_iter[1]["blob"], "character string") diff --git a/tests_old/test_erd.py b/tests_old/test_erd.py deleted file mode 100644 index 1a6293431..000000000 --- a/tests_old/test_erd.py +++ /dev/null @@ -1,87 +0,0 @@ -from nose.tools import assert_false, assert_true -import datajoint as dj -from .schema_simple import A, B, D, E, L, schema, OutfitLaunch -from . import schema_advanced - -namespace = locals() - - -class TestERD: - @staticmethod - def setup(): - """ - class-level test setup. Executes before each test method. 
- """ - - @staticmethod - def test_decorator(): - assert_true(issubclass(A, dj.Lookup)) - assert_false(issubclass(A, dj.Part)) - assert_true(B.database == schema.database) - assert_true(issubclass(B.C, dj.Part)) - assert_true(B.C.database == schema.database) - assert_true(B.C.master is B and E.F.master is E) - - @staticmethod - def test_dependencies(): - deps = schema.connection.dependencies - deps.load() - assert_true( - all(cls.full_table_name in deps for cls in (A, B, B.C, D, E, E.F, L)) - ) - assert_true(set(A().children()) == set([B.full_table_name, D.full_table_name])) - assert_true(set(D().parents(primary=True)) == set([A.full_table_name])) - assert_true(set(D().parents(primary=False)) == set([L.full_table_name])) - assert_true( - set(deps.descendants(L.full_table_name)).issubset( - cls.full_table_name for cls in (L, D, E, E.F) - ) - ) - - @staticmethod - def test_erd(): - assert_true(dj.diagram.diagram_active, "Failed to import networkx and pydot") - erd = dj.ERD(schema, context=namespace) - graph = erd._make_graph() - assert_true( - set(cls.__name__ for cls in (A, B, D, E, L)).issubset(graph.nodes()) - ) - - @staticmethod - def test_erd_algebra(): - erd0 = dj.ERD(B) - erd1 = erd0 + 3 - erd2 = dj.Di(E) - 3 - erd3 = erd1 * erd2 - erd4 = (erd0 + E).add_parts() - B - E - assert_true(erd0.nodes_to_show == set(cls.full_table_name for cls in [B])) - assert_true( - erd1.nodes_to_show == set(cls.full_table_name for cls in (B, B.C, E, E.F)) - ) - assert_true( - erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L)) - ) - assert_true(erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E))) - assert_true( - erd4.nodes_to_show == set(cls.full_table_name for cls in (B.C, E.F)) - ) - - @staticmethod - def test_repr_svg(): - erd = dj.ERD(schema_advanced, context=namespace) - svg = erd._repr_svg_() - assert_true(svg.startswith("")) - - @staticmethod - def test_make_image(): - erd = dj.ERD(schema, context=namespace) - img = erd.make_image() - 
assert_true(img.ndim == 3 and img.shape[2] in (3, 4)) - - @staticmethod - def test_part_table_parsing(): - # https://github.com/datajoint/datajoint-python/issues/882 - erd = dj.Di(schema) - graph = erd._make_graph() - assert "OutfitLaunch" in graph.nodes() - assert "OutfitLaunch.OutfitPiece" in graph.nodes() diff --git a/tests_old/test_foreign_keys.py b/tests_old/test_foreign_keys.py deleted file mode 100644 index d082960e4..000000000 --- a/tests_old/test_foreign_keys.py +++ /dev/null @@ -1,51 +0,0 @@ -from nose.tools import assert_equal, assert_false, assert_true -from datajoint.declare import declare - -from . import schema_advanced - - -def test_aliased_fk(): - person = schema_advanced.Person() - parent = schema_advanced.Parent() - person.delete() - assert_false(person) - assert_false(parent) - person.fill() - parent.fill() - assert_true(person) - assert_true(parent) - link = person.proj(parent_name="full_name", parent="person_id") - parents = person * parent * link - parents &= dict(full_name="May K. Hall") - assert_equal( - set(parents.fetch("parent_name")), {"Hanna R. Walters", "Russel S. 
James"} - ) - delete_count = person.delete() - assert delete_count == 16 - - -def test_describe(): - """real_definition should match original definition""" - for rel in (schema_advanced.LocalSynapse, schema_advanced.GlobalSynapse): - describe = rel.describe() - s1 = declare( - rel.full_table_name, rel.definition, schema_advanced.schema.context - )[0].split("\n") - s2 = declare(rel.full_table_name, describe, globals())[0].split("\n") - for c1, c2 in zip(s1, s2): - assert_equal(c1, c2) - - -def test_delete(): - person = schema_advanced.Person() - parent = schema_advanced.Parent() - person.delete() - assert_false(person) - assert_false(parent) - person.fill() - parent.fill() - assert_true(parent) - original_len = len(parent) - to_delete = len(parent & "11 in (person_id, parent)") - (person & "person_id=11").delete() - assert_true(to_delete and len(parent) == original_len - to_delete) diff --git a/tests_old/test_hash.py b/tests_old/test_hash.py deleted file mode 100644 index dc88290eb..000000000 --- a/tests_old/test_hash.py +++ /dev/null @@ -1,7 +0,0 @@ -from nose.tools import assert_equal -from datajoint import hash - - -def test_hash(): - assert_equal(hash.uuid_from_buffer(b"abc").hex, "900150983cd24fb0d6963f7d28e17f72") - assert_equal(hash.uuid_from_buffer(b"").hex, "d41d8cd98f00b204e9800998ecf8427e") diff --git a/tests_old/test_log.py b/tests_old/test_log.py deleted file mode 100644 index 86a48bc37..000000000 --- a/tests_old/test_log.py +++ /dev/null @@ -1,9 +0,0 @@ -from nose.tools import assert_true -from . import schema - - -def test_log(): - ts, events = (schema.schema.log & 'event like "Declared%%"').fetch( - "timestamp", "event" - ) - assert_true(len(ts) >= 2) diff --git a/tests_old/test_nan.py b/tests_old/test_nan.py deleted file mode 100644 index b06848fdf..000000000 --- a/tests_old/test_nan.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np -from nose.tools import assert_true -import datajoint as dj -from . 
import PREFIX, CONN_INFO - -schema = dj.Schema(PREFIX + "_nantest", locals(), connection=dj.conn(**CONN_INFO)) - - -@schema -class NanTest(dj.Manual): - definition = """ - id :int - --- - value=null :double - """ - - -class TestNaNInsert: - @classmethod - def setup_class(cls): - cls.rel = NanTest() - with dj.config(safemode=False): - cls.rel.delete() - a = np.array([0, 1 / 3, np.nan, np.pi, np.nan]) - cls.rel.insert(((i, value) for i, value in enumerate(a))) - cls.a = a - - def test_insert_nan(self): - """Test fetching of null values""" - b = self.rel.fetch("value", order_by="id") - assert_true( - (np.isnan(self.a) == np.isnan(b)).all(), "incorrect handling of Nans" - ) - assert_true( - np.allclose( - self.a[np.logical_not(np.isnan(self.a))], b[np.logical_not(np.isnan(b))] - ), - "incorrect storage of floats", - ) - - def test_nulls_do_not_affect_primary_keys(self): - """Test against a case that previously caused a bug when skipping existing entries.""" - self.rel.insert( - ((i, value) for i, value in enumerate(self.a)), skip_duplicates=True - ) diff --git a/tests_old/test_relation_u.py b/tests_old/test_relation_u.py deleted file mode 100644 index ff30711b3..000000000 --- a/tests_old/test_relation_u.py +++ /dev/null @@ -1,88 +0,0 @@ -from nose.tools import assert_equal, assert_true, raises, assert_list_equal -from . 
import schema, schema_simple -import datajoint as dj - - -class TestU: - """ - Test tables: insert, delete - """ - - @classmethod - def setup_class(cls): - cls.user = schema.User() - cls.language = schema.Language() - cls.subject = schema.Subject() - cls.experiment = schema.Experiment() - cls.trial = schema.Trial() - cls.ephys = schema.Ephys() - cls.channel = schema.Ephys.Channel() - cls.img = schema.Image() - cls.trash = schema.UberTrash() - - def test_restriction(self): - language_set = {s[1] for s in self.language.contents} - rel = dj.U("language") & self.language - assert_list_equal(rel.heading.names, ["language"]) - assert_true(len(rel) == len(language_set)) - assert_true(set(rel.fetch("language")) == language_set) - # Test for issue #342 - rel = self.trial * dj.U("start_time") - assert_list_equal(rel.primary_key, self.trial.primary_key + ["start_time"]) - assert_list_equal(rel.primary_key, (rel & "trial_id>3").primary_key) - assert_list_equal((dj.U("start_time") & self.trial).primary_key, ["start_time"]) - - @staticmethod - @raises(dj.DataJointError) - def test_invalid_restriction(): - result = dj.U("color") & dict(color="red") - - def test_ineffective_restriction(self): - rel = self.language & dj.U("language") - assert_true(rel.make_sql() == self.language.make_sql()) - - def test_join(self): - rel = self.experiment * dj.U("experiment_date") - assert_equal(self.experiment.primary_key, ["subject_id", "experiment_id"]) - assert_equal(rel.primary_key, self.experiment.primary_key + ["experiment_date"]) - - rel = dj.U("experiment_date") * self.experiment - assert_equal(self.experiment.primary_key, ["subject_id", "experiment_id"]) - assert_equal(rel.primary_key, self.experiment.primary_key + ["experiment_date"]) - - @staticmethod - @raises(dj.DataJointError) - def test_invalid_join(): - rel = dj.U("language") * dict(language="English") - - def test_repr_without_attrs(self): - """test dj.U() display""" - query = dj.U().aggr(schema.Language, n="count(*)") - 
repr(query) - - def test_aggregations(self): - lang = schema.Language() - # test total aggregation on expression object - n1 = dj.U().aggr(lang, n="count(*)").fetch1("n") - assert_equal(n1, len(lang.fetch())) - # test total aggregation on expression class - n2 = dj.U().aggr(schema.Language, n="count(*)").fetch1("n") - assert_equal(n1, n2) - rel = dj.U("language").aggr(schema.Language, number_of_speakers="count(*)") - assert_equal(len(rel), len(set(l[1] for l in schema.Language.contents))) - assert_equal((rel & 'language="English"').fetch1("number_of_speakers"), 3) - - def test_argmax(self): - rel = schema.TTest() - # get the tuples corresponding to maximum value - mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value" - assert_equal(mx.fetch("value")[0], max(rel.fetch("value"))) - - def test_aggr(self): - rel = schema_simple.ArgmaxTest() - amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)") - amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)") - assert_true( - len(amax1) == len(amax2) == rel.n, - "Aggregated argmax with join and restriction does not yield same length.", - ) diff --git a/tests_old/test_schema_keywords.py b/tests_old/test_schema_keywords.py deleted file mode 100644 index 49f380f57..000000000 --- a/tests_old/test_schema_keywords.py +++ /dev/null @@ -1,46 +0,0 @@ -from . 
import PREFIX, CONN_INFO -import datajoint as dj -from nose.tools import assert_true - - -schema = dj.Schema(PREFIX + "_keywords", connection=dj.conn(**CONN_INFO)) - - -@schema -class A(dj.Manual): - definition = """ - a_id: int # a id - """ - - -class B(dj.Manual): - source = None - definition = """ - -> self.source - b_id: int # b id - """ - - class H(dj.Part): - definition = """ - -> master - name: varchar(128) # name - """ - - class C(dj.Part): - definition = """ - -> master - -> master.H - """ - - -@schema -class D(B): - source = A - - -def test_inherited_part_table(): - assert_true("a_id" in D().heading.attributes) - assert_true("b_id" in D().heading.attributes) - assert_true("a_id" in D.C().heading.attributes) - assert_true("b_id" in D.C().heading.attributes) - assert_true("name" in D.C().heading.attributes) diff --git a/tests_old/test_utils.py b/tests_old/test_utils.py deleted file mode 100644 index b5ed96af3..000000000 --- a/tests_old/test_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Collection of test cases to test core module. 
-""" -from nose.tools import assert_true, assert_raises, assert_equal -from datajoint import DataJointError -from datajoint.utils import from_camel_case, to_camel_case - - -def setup(): - pass - - -def teardown(): - pass - - -def test_from_camel_case(): - assert_equal(from_camel_case("AllGroups"), "all_groups") - with assert_raises(DataJointError): - from_camel_case("repNames") - with assert_raises(DataJointError): - from_camel_case("10_all") - with assert_raises(DataJointError): - from_camel_case("hello world") - with assert_raises(DataJointError): - from_camel_case("#baisc_names") - - -def test_to_camel_case(): - assert_equal(to_camel_case("all_groups"), "AllGroups") - assert_equal(to_camel_case("hello"), "Hello") - assert_equal(to_camel_case("this_is_a_sample_case"), "ThisIsASampleCase") - assert_equal(to_camel_case("This_is_Mixed"), "ThisIsMixed") diff --git a/tests_old/test_virtual_module.py b/tests_old/test_virtual_module.py deleted file mode 100644 index 58180916f..000000000 --- a/tests_old/test_virtual_module.py +++ /dev/null @@ -1,12 +0,0 @@ -from nose.tools import assert_true -import datajoint as dj -from datajoint.user_tables import UserTable -from . import schema -from . import CONN_INFO - - -def test_virtual_module(): - module = dj.VirtualModule( - "module", schema.schema.database, connection=dj.conn(**CONN_INFO) - ) - assert_true(issubclass(module.Experiment, UserTable))