diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 68d5395d5..0c6835714 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,8 @@ Bugfixes: - Fixed missing table/column comment generation for ``ForeignKeyField`` and ``ManyToManyField`` - Fixed comment generation to escape properly for ``SQLite`` - Fixed comment generation for ``PostgreSQL`` to not duplicate comments +- Fixed generation of schema for fields that defined custom ``source_field`` values defined +- Fixed working with Models that have fields with custom ``source_field`` values defined Docs/examples: ^^^^^^^^^^^^^^ diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 63a95a3d8..747b8d8ae 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -27,7 +27,7 @@ logger = logging.getLogger("tortoise") -if sys.version_info < (3, 6): +if sys.version_info < (3, 6): # pragma: nocoverage warnings.warn("Tortoise-ORM is soon going to require Python 3.6", DeprecationWarning) @@ -99,7 +99,6 @@ def split_reference(reference: str) -> Tuple[str, str]: related_model = get_related_model(related_app_name, related_model_name) key_field = "{}_id".format(field) - fk_object.source_field = key_field key_fk_object = deepcopy(related_model._meta.pk) key_fk_object.pk = False key_fk_object.index = fk_object.index @@ -107,6 +106,12 @@ def split_reference(reference: str) -> Tuple[str, str]: key_fk_object.null = fk_object.null key_fk_object.generated = fk_object.generated key_fk_object.reference = fk_object + if fk_object.source_field: + key_fk_object.source_field = fk_object.source_field + fk_object.source_field = key_field + else: + fk_object.source_field = key_field + key_fk_object.source_field = key_field model._meta.add_field(key_field, key_fk_object) fk_object.type = related_model diff --git a/tortoise/backends/asyncpg/executor.py b/tortoise/backends/asyncpg/executor.py index 7e30b2c1f..66ff69fb7 100644 --- a/tortoise/backends/asyncpg/executor.py +++ b/tortoise/backends/asyncpg/executor.py @@ -24,5 +24,6 @@ def _prepare_insert_statement(self, columns: List[str]) -> str: async def _process_insert_result(self, instance: Model, results: Optional[asyncpg.Record]): if results: generated_fields = self.model._meta.generated_db_fields + db_projection = instance._meta.fields_db_projection_reverse for key, val in zip(generated_fields, results): - setattr(instance, key, val) + setattr(instance, db_projection[key], val) diff --git a/tortoise/backends/asyncpg/schema_generator.py b/tortoise/backends/asyncpg/schema_generator.py index 21e574a4e..113c9fcec 100644 --- a/tortoise/backends/asyncpg/schema_generator.py +++ b/tortoise/backends/asyncpg/schema_generator.py @@ -42,6 +42,8 @@ def _column_comment_generator(self, table: str, column: str, comment: str) -> st return "" def _post_table_hook(self) -> str: - val = "\n" + "\n".join(self.comments_array) + val = "\n".join(self.comments_array) self.comments_array = [] - return val + if val: + return "\n" + val + return "" diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 0e0bdc92a..84407399e 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -214,7 +214,7 @@ async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_ ) for e in raw_results } - related_object_list = [related_query.model(_from_db=True, **e) for e in raw_results] + related_object_list = [related_query.model._init_from_db(**e) for e in raw_results] await self.__class__( model=related_query.model, db=self.db, prefetch_map=related_query._prefetch_map ).fetch_for_list(related_object_list) diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index 7891d880f..b9a966852 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -11,13 +11,13 @@ class BaseSchemaGenerator: FIELD_TEMPLATE = '"{name}" {type} {nullable} {unique}{primary}{comment}' INDEX_CREATE_TEMPLATE = 'CREATE INDEX {exists}"{index_name}" ON "{table_name}" ({fields});' UNIQUE_CONSTRAINT_CREATE_TEMPLATE = "UNIQUE ({fields})" - FK_TEMPLATE = ' REFERENCES "{table}" (id) ON DELETE {on_delete}{comment}' + FK_TEMPLATE = ' REFERENCES "{table}" ({field}) ON DELETE {on_delete}{comment}' M2M_TABLE_TEMPLATE = ( 'CREATE TABLE {exists}"{table_name}" (\n' - ' "{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" (id)' - " ON DELETE CASCADE,\n" - ' "{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" (id)' - " ON DELETE CASCADE\n" + ' "{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}"' + " ({backward_field}) ON DELETE CASCADE,\n" + ' "{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}"' + " ({forward_field}) ON DELETE CASCADE\n" "){comment};" ) @@ -145,7 +145,7 @@ def _get_table_sql(self, model, safe=True) -> dict: else "" ) if isinstance(field_object, (fields.IntField, fields.BigIntField)) and field_object.pk: - fields_to_create.append(self._get_primary_key_create_string(field_name, comment)) + fields_to_create.append(self._get_primary_key_create_string(db_field, comment)) continue nullable = "NOT NULL" if not field_object.null else "" unique = "UNIQUE" if field_object.unique else "" @@ -170,6 +170,7 @@ def _get_table_sql(self, model, safe=True) -> dict: ) field_creation_string += self.FK_TEMPLATE.format( table=field_object.reference.type._meta.table, + field=field_object.reference.type._meta.db_pk_field, on_delete=field_object.reference.on_delete, comment=comment, ) @@ -177,7 +178,7 @@ def _get_table_sql(self, model, safe=True) -> dict: fields_to_create.append(field_creation_string) if field_object.index: - fields_with_index.append(field_name) + fields_with_index.append(db_field) if model._meta.unique_together is not None: unique_together_sqls = [] @@ -244,6 +245,8 @@ def _get_table_sql(self, model, safe=True) -> dict: table_name=field_object.through, backward_table=model._meta.table, forward_table=field_object.type._meta.table, + backward_field=model._meta.db_pk_field, + forward_field=field_object.type._meta.db_pk_field, backward_key=field_object.backward_key, backward_type=self._get_field_type(model._meta.pk), forward_key=field_object.forward_key, @@ -251,7 +254,7 @@ def _get_table_sql(self, model, safe=True) -> dict: comment=self._table_comment_generator( table=field_object.through, comment=field_object.description ) - if model._meta.table_description + if field_object.description else "", ) m2m_create_string += self._post_table_hook() diff --git a/tortoise/backends/mysql/schema_generator.py b/tortoise/backends/mysql/schema_generator.py index c74123eab..2ba5023fe 100644 --- a/tortoise/backends/mysql/schema_generator.py +++ b/tortoise/backends/mysql/schema_generator.py @@ -6,13 +6,13 @@ class MySQLSchemaGenerator(BaseSchemaGenerator): TABLE_CREATE_TEMPLATE = "CREATE TABLE {exists}`{table_name}` ({fields}){comment};" INDEX_CREATE_TEMPLATE = "CREATE INDEX `{index_name}` ON `{table_name}` ({fields});" FIELD_TEMPLATE = "`{name}` {type} {nullable} {unique}{comment}" - FK_TEMPLATE = " REFERENCES `{table}` (`id`) ON DELETE {on_delete}{comment}" + FK_TEMPLATE = " REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}{comment}" M2M_TABLE_TEMPLATE = ( "CREATE TABLE `{table_name}` (\n" - " `{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`id`)" - " ON DELETE CASCADE,\n" - " `{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`id`)" - " ON DELETE CASCADE\n" + " `{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}`" + " (`{backward_field}`) ON DELETE CASCADE,\n" + " `{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}`" + " (`{forward_field}`) ON DELETE CASCADE\n" "){comment};" ) FIELD_TYPE_MAP = { diff --git a/tortoise/filters.py b/tortoise/filters.py index 9a9a445e4..8ae21b909 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -158,60 +158,97 @@ def get_filters_for_field( return get_m2m_filters(field_name, field) if isinstance(field, fields.BackwardFKRelation): return get_backward_fk_filters(field_name, field) + actual_field_name = field_name + if field_name == "pk" and field: + actual_field_name = field.model_field_name return { - field_name: {"field": source_field, "operator": operator.eq}, - "{}__not".format(field_name): {"field": source_field, "operator": not_equal}, + field_name: { + "field": actual_field_name, + "source_field": source_field, + "operator": operator.eq, + }, + "{}__not".format(field_name): { + "field": actual_field_name, + "source_field": source_field, + "operator": not_equal, + }, "{}__in".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": is_in, "value_encoder": list_encoder, }, "{}__not_in".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": not_in, "value_encoder": list_encoder, }, "{}__isnull".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": is_null, "value_encoder": bool_encoder, }, "{}__not_isnull".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": not_null, "value_encoder": bool_encoder, }, - "{}__gte".format(field_name): {"field": source_field, "operator": operator.ge}, - "{}__lte".format(field_name): {"field": source_field, "operator": operator.le}, - "{}__gt".format(field_name): {"field": source_field, "operator": operator.gt}, - "{}__lt".format(field_name): {"field": source_field, "operator": operator.lt}, + "{}__gte".format(field_name): { + "field": actual_field_name, + "source_field": source_field, + "operator": operator.ge, + }, + "{}__lte".format(field_name): { + "field": actual_field_name, + "source_field": source_field, + "operator": operator.le, + }, + "{}__gt".format(field_name): { + "field": actual_field_name, + "source_field": source_field, + "operator": operator.gt, + }, + "{}__lt".format(field_name): { + "field": actual_field_name, + "source_field": source_field, + "operator": operator.lt, + }, "{}__contains".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": contains, "value_encoder": string_encoder, }, "{}__startswith".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": starts_with, "value_encoder": string_encoder, }, "{}__endswith".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": ends_with, "value_encoder": string_encoder, }, "{}__icontains".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": insensitive_contains, "value_encoder": string_encoder, }, "{}__istartswith".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": insensitive_starts_with, "value_encoder": string_encoder, }, "{}__iendswith".format(field_name): { - "field": source_field, + "field": actual_field_name, + "source_field": source_field, "operator": insensitive_ends_with, "value_encoder": string_encoder, }, diff --git a/tortoise/models.py b/tortoise/models.py index 341ff16c3..1287bc607 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -283,10 +283,10 @@ class Model(metaclass=ModelMeta): # I don' like this here, but it makes autocompletion and static analysis much happier _meta = MetaInfo(None) - def __init__(self, *args, _from_db: bool = False, **kwargs) -> None: + def __init__(self, *args, **kwargs) -> None: # self._meta is a very common attribute lookup, lets cache it. meta = self._meta - self._saved_in_db = _from_db or (meta.pk_attr in kwargs and meta.pk.generated) + self._saved_in_db = meta.pk_attr in kwargs and meta.pk.generated self._init_lazy_fkm2m() # Assign values and do type conversions @@ -311,8 +311,9 @@ def _init_from_db(cls, **kwargs) -> MODEL_TYPE: meta = self._meta for key, value in kwargs.items(): - if key in meta.fields: - setattr(self, key, meta.fields_map[key].to_python_value(value)) + model_field = meta.fields_db_projection_reverse.get(key) + if model_field: + setattr(self, model_field, meta.fields_map[model_field].to_python_value(value)) return self diff --git a/tortoise/query_utils.py b/tortoise/query_utils.py index 8b3a3479c..b6480166a 100644 --- a/tortoise/query_utils.py +++ b/tortoise/query_utils.py @@ -34,7 +34,7 @@ def _process_filter_kwarg(model, key, value) -> Tuple[Criterion, Optional[Tuple[ if param.get("value_encoder") else model._meta.db.executor_class._field_to_db(field_object, value, model) ) - criterion = param["operator"](getattr(table, param["field"]), encoded_value) + criterion = param["operator"](getattr(table, param["source_field"]), encoded_value) return criterion, join diff --git a/tortoise/tests/models_schema_create.py b/tortoise/tests/models_schema_create.py index a45bc4be4..688eb80e7 100644 --- a/tortoise/tests/models_schema_create.py +++ b/tortoise/tests/models_schema_create.py @@ -6,7 +6,7 @@ class Tournament(Model): - id = fields.IntField(pk=True) + tid = fields.IntField(pk=True) name = fields.TextField(description="Tournament name", index=True) created = fields.DatetimeField(auto_now_add=True, description="Created */'`/* datetime") @@ -23,7 +23,7 @@ class Event(Model): participants = fields.ManyToManyField( "models.Team", related_name="events", - through="event_team", + through="teamevents", description="How participants relate", ) modified = fields.DatetimeField(auto_now=True) @@ -36,6 +36,26 @@ class Meta: class Team(Model): name = fields.CharField(max_length=50, pk=True, description="The TEAM name (and PK)") + manager = fields.ForeignKeyField("models.Team", related_name="team_members", null=True) + talks_to = fields.ManyToManyField("models.Team", related_name="gets_talked_to") class Meta: table_description = "The TEAMS!" + + +class SourceFields(Model): + id = fields.IntField(pk=True, source_field="sometable_id") + chars = fields.CharField(max_length=255, source_field="some_chars_table", index=True) + fk = fields.ForeignKeyField( + "models.SourceFields", related_name="team_members", null=True, source_field="fk_sometable" + ) + rel_to = fields.ManyToManyField( + "models.SourceFields", + related_name="rel_from", + through="sometable_self", + forward_key="sts_forward", + backward_key="backward_sts", + ) + + class Meta: + table = "sometable" diff --git a/tortoise/tests/test_generate_schema.py b/tortoise/tests/test_generate_schema.py index abf920dee..58d5b2baa 100644 --- a/tortoise/tests/test_generate_schema.py +++ b/tortoise/tests/test_generate_schema.py @@ -148,11 +148,19 @@ async def test_schema(self): sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) self.assertEqual( sql.strip(), - """CREATE TABLE "team" ( - "name" VARCHAR(50) NOT NULL PRIMARY KEY /* The TEAM name (and PK) */ + """ +CREATE TABLE "sometable" ( + "sometable_id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "some_chars_table" VARCHAR(255) NOT NULL, + "fk_sometable" INT REFERENCES "sometable" (sometable_id) ON DELETE CASCADE +); +CREATE INDEX "sometable_some_ch_115115_idx" ON "sometable" (some_chars_table); +CREATE TABLE "team" ( + "name" VARCHAR(50) NOT NULL PRIMARY KEY /* The TEAM name (and PK) */, + "manager_id" VARCHAR(50) /* The TEAM name (and PK) */ REFERENCES "team" (name) ON DELETE CASCADE ) /* The TEAMS! */; CREATE TABLE "tournament" ( - "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "tid" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "name" TEXT NOT NULL /* Tournament name */, "created" TIMESTAMP NOT NULL /* Created *\\/'`\\/* datetime */ ) /* What Tournaments *\\/'`\\/* we have */; @@ -163,12 +171,21 @@ async def test_schema(self): "modified" TIMESTAMP NOT NULL, "prize" VARCHAR(40), "token" VARCHAR(100) NOT NULL UNIQUE /* Unique token */, - "tournament_id" INT NOT NULL REFERENCES "tournament" (id) ON DELETE CASCADE /* FK to tournament */ + "tournament_id" INT NOT NULL REFERENCES "tournament" (tid) ON DELETE CASCADE /* FK to tournament */ ) /* This table contains a list of all the events */; -CREATE TABLE "event_team" ( +CREATE TABLE "sometable_self" ( + "backward_sts" INT NOT NULL REFERENCES "sometable" (sometable_id) ON DELETE CASCADE, + "sts_forward" INT NOT NULL REFERENCES "sometable" (sometable_id) ON DELETE CASCADE +); +CREATE TABLE "team_team" ( + "team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" (name) ON DELETE CASCADE, + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" (name) ON DELETE CASCADE +); +CREATE TABLE "teamevents" ( "event_id" INT NOT NULL REFERENCES "event" (id) ON DELETE CASCADE, - "team_id" VARCHAR(50) NOT NULL REFERENCES "team" (id) ON DELETE CASCADE -) /* How participants relate */;""", # noqa + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" (name) ON DELETE CASCADE +) /* How participants relate */; +""".strip(), # noqa ) @@ -234,11 +251,19 @@ async def test_schema(self): sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) self.assertEqual( sql.strip(), - """CREATE TABLE `team` ( - `name` VARCHAR(50) NOT NULL COMMENT 'The TEAM name (and PK)' + """ +CREATE TABLE `sometable` ( + `sometable_id` INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, + `some_chars_table` VARCHAR(255) NOT NULL, + `fk_sometable` INT REFERENCES `sometable` (`sometable_id`) ON DELETE CASCADE +); +CREATE INDEX `sometable_some_ch_115115_idx` ON `sometable` (some_chars_table); +CREATE TABLE `team` ( + `name` VARCHAR(50) NOT NULL COMMENT 'The TEAM name (and PK)', + `manager_id` VARCHAR(50) COMMENT 'The TEAM name (and PK)' REFERENCES `team` (`name`) ON DELETE CASCADE ) COMMENT='The TEAMS!'; CREATE TABLE `tournament` ( - `id` INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, + `tid` INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, `name` TEXT NOT NULL COMMENT 'Tournament name', `created` DATETIME(6) NOT NULL COMMENT 'Created */\\'`/* datetime' ) COMMENT='What Tournaments */\\'`/* we have'; @@ -249,12 +274,21 @@ async def test_schema(self): `modified` DATETIME(6) NOT NULL, `prize` DECIMAL(10,2), `token` VARCHAR(100) NOT NULL UNIQUE COMMENT 'Unique token', - `tournament_id` INT NOT NULL REFERENCES `tournament` (`id`) ON DELETE CASCADE COMMENT 'FK to tournament' + `tournament_id` INT NOT NULL REFERENCES `tournament` (`tid`) ON DELETE CASCADE COMMENT 'FK to tournament' ) COMMENT='This table contains a list of all the events'; -CREATE TABLE `event_team` ( +CREATE TABLE `sometable_self` ( + `backward_sts` INT NOT NULL REFERENCES `sometable` (`sometable_id`) ON DELETE CASCADE, + `sts_forward` INT NOT NULL REFERENCES `sometable` (`sometable_id`) ON DELETE CASCADE +); +CREATE TABLE `team_team` ( + `team_rel_id` VARCHAR(50) NOT NULL REFERENCES `team` (`name`) ON DELETE CASCADE, + `team_id` VARCHAR(50) NOT NULL REFERENCES `team` (`name`) ON DELETE CASCADE +); +CREATE TABLE `teamevents` ( `event_id` INT NOT NULL REFERENCES `event` (`id`) ON DELETE CASCADE, - `team_id` VARCHAR(50) NOT NULL REFERENCES `team` (`id`) ON DELETE CASCADE -) COMMENT='How participants relate';""", # noqa + `team_id` VARCHAR(50) NOT NULL REFERENCES `team` (`name`) ON DELETE CASCADE +) COMMENT='How participants relate'; +""".strip(), # noqa ) @@ -307,13 +341,22 @@ async def test_schema(self): sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) self.assertEqual( sql.strip(), - """CREATE TABLE "team" ( - "name" VARCHAR(50) NOT NULL PRIMARY KEY + """ +CREATE TABLE "sometable" ( + "sometable_id" SERIAL NOT NULL PRIMARY KEY, + "some_chars_table" VARCHAR(255) NOT NULL, + "fk_sometable" INT REFERENCES "sometable" (sometable_id) ON DELETE CASCADE +); +CREATE INDEX "sometable_some_ch_115115_idx" ON "sometable" (some_chars_table); +CREATE TABLE "team" ( + "name" VARCHAR(50) NOT NULL PRIMARY KEY, + "manager_id" VARCHAR(50) REFERENCES "team" (name) ON DELETE CASCADE ); COMMENT ON COLUMN team.name IS 'The TEAM name (and PK)'; +COMMENT ON COLUMN team.manager_id IS 'The TEAM name (and PK)'; COMMENT ON TABLE team IS 'The TEAMS!'; CREATE TABLE "tournament" ( - "id" SERIAL NOT NULL PRIMARY KEY, + "tid" SERIAL NOT NULL PRIMARY KEY, "name" TEXT NOT NULL, "created" TIMESTAMP NOT NULL ); @@ -327,15 +370,24 @@ async def test_schema(self): "modified" TIMESTAMP NOT NULL, "prize" DECIMAL(10,2), "token" VARCHAR(100) NOT NULL UNIQUE, - "tournament_id" INT NOT NULL REFERENCES "tournament" (id) ON DELETE CASCADE + "tournament_id" INT NOT NULL REFERENCES "tournament" (tid) ON DELETE CASCADE ); COMMENT ON COLUMN event.id IS 'Event ID'; COMMENT ON COLUMN event.token IS 'Unique token'; COMMENT ON COLUMN event.tournament_id IS 'FK to tournament'; COMMENT ON TABLE event IS 'This table contains a list of all the events'; -CREATE TABLE "event_team" ( +CREATE TABLE "sometable_self" ( + "backward_sts" INT NOT NULL REFERENCES "sometable" (sometable_id) ON DELETE CASCADE, + "sts_forward" INT NOT NULL REFERENCES "sometable" (sometable_id) ON DELETE CASCADE +); +CREATE TABLE "team_team" ( + "team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" (name) ON DELETE CASCADE, + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" (name) ON DELETE CASCADE +); +CREATE TABLE "teamevents" ( "event_id" INT NOT NULL REFERENCES "event" (id) ON DELETE CASCADE, - "team_id" VARCHAR(50) NOT NULL REFERENCES "team" (id) ON DELETE CASCADE + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" (name) ON DELETE CASCADE ); -COMMENT ON TABLE event_team IS 'How participants relate';""", +COMMENT ON TABLE teamevents IS 'How participants relate'; +""".strip(), ) diff --git a/tortoise/tests/test_source_field.py b/tortoise/tests/test_source_field.py new file mode 100644 index 000000000..1447acace --- /dev/null +++ b/tortoise/tests/test_source_field.py @@ -0,0 +1,166 @@ +""" +This module does a series of use tests on a non-source_field model, + and then the EXACT same ones on a source_field'ed model. + +This is to test that behaviour doesn't change when one defined source_field parameters. +""" +from tortoise.contrib import test +from tortoise.tests.testmodels import SourceFields, StraightFields + + +class StraightFieldTests(test.TestCase): + def setUp(self) -> None: + self.model = StraightFields + + async def test_get_all(self): + obj1 = await self.model.create(chars="aaa") + self.assertIsNotNone(obj1.id, str(dir(obj1))) + obj2 = await self.model.create(chars="bbb") + + objs = await self.model.all() + self.assertEqual(objs, [obj1, obj2]) + + async def test_get_by_pk(self): + obj = await self.model.create(chars="aaa") + obj1 = await self.model.get(id=obj.id) + + self.assertEqual(obj, obj1) + + async def test_get_by_chars(self): + obj = await self.model.create(chars="aaa") + obj1 = await self.model.get(chars="aaa") + + self.assertEqual(obj, obj1) + + async def test_get_fk_forward_fetch_related(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb", fk=obj1) + + obj2a = await self.model.get(id=obj2.id) + await obj2a.fetch_related("fk") + self.assertEqual(obj2, obj2a) + self.assertEqual(obj1, obj2a.fk) + + async def test_get_fk_forward_prefetch_related(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb", fk=obj1) + + obj2a = await self.model.get(id=obj2.id).prefetch_related("fk") + self.assertEqual(obj2, obj2a) + self.assertEqual(obj1, obj2a.fk) + + async def test_get_fk_reverse_await(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb", fk=obj1) + obj3 = await self.model.create(chars="ccc", fk=obj1) + + obj1a = await self.model.get(id=obj1.id) + self.assertEqual(await obj1a.fkrev, [obj2, obj3]) + + async def test_get_fk_reverse_filter(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb", fk=obj1) + obj3 = await self.model.create(chars="ccc", fk=obj1) + + objs = await self.model.filter(fk=obj1) + self.assertEqual(objs, [obj2, obj3]) + + async def test_get_fk_reverse_async_for(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb", fk=obj1) + obj3 = await self.model.create(chars="ccc", fk=obj1) + + obj1a = await self.model.get(id=obj1.id) + objs = [] + async for obj in obj1a.fkrev: + objs.append(obj) + self.assertEqual(objs, [obj2, obj3]) + + async def test_get_fk_reverse_fetch_related(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb", fk=obj1) + obj3 = await self.model.create(chars="ccc", fk=obj1) + + obj1a = await self.model.get(id=obj1.id) + await obj1a.fetch_related("fkrev") + self.assertEqual(list(obj1a.fkrev), [obj2, obj3]) + + async def test_get_fk_reverse_prefetch_related(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb", fk=obj1) + obj3 = await self.model.create(chars="ccc", fk=obj1) + + obj1a = await self.model.get(id=obj1.id).prefetch_related("fkrev") + self.assertEqual(list(obj1a.fkrev), [obj2, obj3]) + + async def test_get_m2m_forward_await(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb") + await obj1.rel_to.add(obj2) + + obj2a = await self.model.get(id=obj2.id) + self.assertEqual(await obj2a.rel_from, [obj1]) + + obj1a = await self.model.get(id=obj1.id) + self.assertEqual(await obj1a.rel_to, [obj2]) + + async def test_get_m2m_reverse_await(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb") + await obj2.rel_from.add(obj1) + + obj2a = await self.model.get(id=obj2.id) + self.assertEqual(await obj2a.rel_from, [obj1]) + + obj1a = await self.model.get(id=obj1.id) + self.assertEqual(await obj1a.rel_to, [obj2]) + + async def test_get_m2m_filter(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb") + await obj1.rel_to.add(obj2) + + rel_froms = await self.model.filter(rel_from=obj1) + self.assertEqual(rel_froms, [obj2]) + + rel_tos = await self.model.filter(rel_to=obj2) + self.assertEqual(rel_tos, [obj1]) + + async def test_get_m2m_forward_fetch_related(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb") + await obj1.rel_to.add(obj2) + + obj2a = await self.model.get(id=obj2.id) + await obj2a.fetch_related("rel_from") + self.assertEqual(list(obj2a.rel_from), [obj1]) + + async def test_get_m2m_reverse_fetch_related(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb") + await obj1.rel_to.add(obj2) + + obj1a = await self.model.get(id=obj1.id) + await obj1a.fetch_related("rel_to") + self.assertEqual(list(obj1a.rel_to), [obj2]) + + async def test_get_m2m_forward_prefetch_related(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb") + await obj1.rel_to.add(obj2) + + obj2a = await self.model.get(id=obj2.id).prefetch_related("rel_from") + self.assertEqual(list(obj2a.rel_from), [obj1]) + + async def test_get_m2m_reverse_prefetch_related(self): + obj1 = await self.model.create(chars="aaa") + obj2 = await self.model.create(chars="bbb") + await obj1.rel_to.add(obj2) + + obj1a = await self.model.get(id=obj1.id).prefetch_related("rel_to") + self.assertEqual(list(obj1a.rel_to), [obj2]) + + +class SourceFieldTests(StraightFieldTests): + def setUp(self) -> None: + self.model = SourceFields # type: ignore diff --git a/tortoise/tests/testmodels.py b/tortoise/tests/testmodels.py index 84a26b20b..8a4165707 100644 --- a/tortoise/tests/testmodels.py +++ b/tortoise/tests/testmodels.py @@ -341,3 +341,28 @@ async def full_hierarchy__fetch_related(self, level=0): for member in self.team_members: text.append(await member.full_hierarchy__fetch_related(level + 1)) return "\n".join(text) + + +class StraightFields(Model): + id = fields.IntField(pk=True) + chars = fields.CharField(max_length=255, index=True) + fk = fields.ForeignKeyField("models.StraightFields", related_name="fkrev", null=True) + rel_to = fields.ManyToManyField("models.StraightFields", related_name="rel_from") + + +class SourceFields(Model): + id = fields.IntField(pk=True, source_field="sometable_id") + chars = fields.CharField(max_length=255, source_field="some_chars_table", index=True) + fk = fields.ForeignKeyField( + "models.SourceFields", related_name="fkrev", null=True, source_field="fk_sometable" + ) + rel_to = fields.ManyToManyField( + "models.SourceFields", + related_name="rel_from", + through="sometable_self", + forward_key="sts_forward", + backward_key="backward_sts", + ) + + class Meta: + table = "sometable"