From 99685a400520a288e6e44ed51a879b1d6fa0e4a0 Mon Sep 17 00:00:00 2001 From: Andrey Bondar Date: Wed, 3 Apr 2019 18:12:13 +0300 Subject: [PATCH] Generic primary key implementation --- tortoise/__init__.py | 59 ++--- tortoise/backends/asyncpg/client.py | 9 +- tortoise/backends/asyncpg/executor.py | 15 +- tortoise/backends/asyncpg/fields.py | 20 ++ tortoise/backends/asyncpg/schema_generator.py | 2 +- tortoise/backends/base/client.py | 3 +- tortoise/backends/base/executor.py | 47 ++-- tortoise/backends/base/schema_generator.py | 16 +- tortoise/backends/mysql/client.py | 10 +- tortoise/backends/mysql/executor.py | 30 ++- tortoise/backends/sqlite/client.py | 17 +- tortoise/backends/sqlite/executor.py | 31 ++- tortoise/fields.py | 107 ++++++-- tortoise/filters.py | 22 +- tortoise/models.py | 247 ++++++++++++++---- tortoise/query_utils.py | 44 +++- tortoise/tests/fields/test_uuid.py | 28 ++ tortoise/tests/test_primary_key.py | 70 +++++ tortoise/tests/testmodels.py | 23 ++ 19 files changed, 640 insertions(+), 160 deletions(-) create mode 100644 tortoise/backends/asyncpg/fields.py create mode 100644 tortoise/tests/fields/test_uuid.py create mode 100644 tortoise/tests/test_primary_key.py diff --git a/tortoise/__init__.py b/tortoise/__init__.py index faae15938..e98988b86 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -48,6 +48,18 @@ def _init_relations(cls) -> None: reference = field_object.model_name related_app_name, related_model_name = reference.split(".") related_model = cls.apps[related_app_name][related_model_name] + + key_field = "{}_id".format(field) + field_object.source_field = key_field + key_field_object = deepcopy(related_model._meta.pk) + key_field_object.pk = False + key_field_object.index = field_object.index + key_field_object.default = field_object.default + key_field_object.null = field_object.null + key_field_object.generated = field_object.generated + key_field_object.reference = field_object + model._meta.add_field(key_field, key_field_object) + field_object.type = related_model backward_relation_name = field_object.related_name if not backward_relation_name: @@ -59,35 +71,27 @@ def _init_relations(cls) -> None: ) ) fk_relation = fields.BackwardFKRelation(model, "{}_id".format(field)) - setattr(related_model, backward_relation_name, fk_relation) - related_model._meta.filters.update( - get_backward_fk_filters(backward_relation_name, fk_relation) - ) - - related_model._meta.backward_fk_fields.add(backward_relation_name) - related_model._meta.fetch_fields.add(backward_relation_name) - related_model._meta.fields_map[backward_relation_name] = fk_relation - related_model._meta.fields.add(backward_relation_name) + related_model._meta.add_field(backward_relation_name, fk_relation) for field in model._meta.m2m_fields: - field_mobject = cast(fields.ManyToManyField, model._meta.fields_map[field]) - if field_mobject._generated: + field_object = cast(fields.ManyToManyField, model._meta.fields_map[field]) + if field_object._generated: continue - backward_key = field_mobject.backward_key + backward_key = field_object.backward_key if not backward_key: backward_key = "{}_id".format(model._meta.table) - field_mobject.backward_key = backward_key + field_object.backward_key = backward_key - reference = field_mobject.model_name + reference = field_object.model_name related_app_name, related_model_name = reference.split(".") related_model = cls.apps[related_app_name][related_model_name] - field_mobject.type = related_model + field_object.type = related_model - backward_relation_name 
= field_mobject.related_name + backward_relation_name = field_object.related_name if not backward_relation_name: - backward_relation_name = field_mobject.related_name = "{}_through".format( + backward_relation_name = field_object.related_name = "{}_through".format( model._meta.table ) if backward_relation_name in related_model._meta.fields: @@ -97,35 +101,28 @@ def _init_relations(cls) -> None: ) ) - if not field_mobject.through: + if not field_object.through: related_model_table_name = ( related_model._meta.table if related_model._meta.table else related_model.__name__.lower() ) - field_mobject.through = "{}_{}".format( + field_object.through = "{}_{}".format( model._meta.table, related_model_table_name ) m2m_relation = fields.ManyToManyField( "{}.{}".format(app_name, model_name), - field_mobject.through, - forward_key=field_mobject.backward_key, - backward_key=field_mobject.forward_key, + field_object.through, + forward_key=field_object.backward_key, + backward_key=field_object.forward_key, related_name=field, type=model, ) m2m_relation._generated = True - setattr(related_model, backward_relation_name, m2m_relation) - model._meta.filters.update(get_m2m_filters(field, field_mobject)) - related_model._meta.filters.update( - get_m2m_filters(backward_relation_name, m2m_relation) - ) - related_model._meta.m2m_fields.add(backward_relation_name) - related_model._meta.fetch_fields.add(backward_relation_name) - related_model._meta.fields_map[backward_relation_name] = m2m_relation - related_model._meta.fields.add(backward_relation_name) + model._meta.filters.update(get_m2m_filters(field, field_object)) + related_model._meta.add_field(backward_relation_name, m2m_relation) @classmethod def _discover_client_class(cls, engine: str) -> BaseDBAsyncClient: diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index 47c138e50..7b41ddd90 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -116,7 +116,12 @@ def acquire_connection(self) -> ConnectionWrapper: return ConnectionWrapper(self._connection, self._lock) def _in_transaction(self) -> "TransactionWrapper": - return self._transaction_class(self.connection_name, self._connection, self._lock) + return self._transaction_class( + connection_name=self.connection_name, + connection=self._connection, + lock=self._lock, + fetch_inserted=self.fetch_inserted, + ) @translate_exceptions async def execute_insert(self, query: str, values: list) -> int: @@ -124,7 +129,7 @@ async def execute_insert(self, query: str, values: list) -> int: self.log.debug("%s: %s", query, values) # TODO: Cache prepared statement stmt = await connection.prepare(query) - return await stmt.fetchval(*values) + return await stmt.fetch(*values) @translate_exceptions async def execute_query(self, query: str) -> List[dict]: diff --git a/tortoise/backends/asyncpg/executor.py b/tortoise/backends/asyncpg/executor.py index fb12587a7..7623f7c6e 100644 --- a/tortoise/backends/asyncpg/executor.py +++ b/tortoise/backends/asyncpg/executor.py @@ -1,18 +1,25 @@ -from typing import List +from typing import List, Any from pypika import Parameter, Table +from tortoise import Model from tortoise.backends.base.executor import BaseExecutor class AsyncpgExecutor(BaseExecutor): - EXPLAIN_PREFIX = "EXPLAIN (FORMAT JSON, VERBOSE)" def _prepare_insert_statement(self, columns: List[str]) -> str: - return str( + query = ( self.db.query_class.into(Table(self.model._meta.table)) .columns(*columns) .insert(*[Parameter("$%d" % (i + 1,)) for i in 
range(len(columns))]) - .returning("id") ) + generated_fields = self.model._meta.generated_db_fields + if generated_fields and self.db.fetch_inserted: + query = query.returning(*generated_fields) + return str(query) + + async def _process_insert_result(self, instance: Model, results: Any): + if self.model._meta.generated_db_fields and self.db.fetch_inserted: + instance.set_field_values(dict(results)) diff --git a/tortoise/backends/asyncpg/fields.py b/tortoise/backends/asyncpg/fields.py new file mode 100644 index 000000000..d77d4710f --- /dev/null +++ b/tortoise/backends/asyncpg/fields.py @@ -0,0 +1,20 @@ +from typing import Optional, Union +from uuid import UUID + +from tortoise.fields import JSONField as GenericJSONField, UUIDField as GenericUUIDField + + +class JSONField(GenericJSONField): + def to_db_value(self, value: Optional[Union[dict, list]], instance): + return value + + def to_python_value(self, value: Optional[Union[str, dict, list]]): + return value + + +class UUIDField(GenericUUIDField): + def to_db_value(self, value: Optional[UUID], instance): + return value + + def to_python_value(self, value: Optional[UUID]): + return value diff --git a/tortoise/backends/asyncpg/schema_generator.py b/tortoise/backends/asyncpg/schema_generator.py index e94ce416f..080dc8c01 100644 --- a/tortoise/backends/asyncpg/schema_generator.py +++ b/tortoise/backends/asyncpg/schema_generator.py @@ -5,7 +5,7 @@ class AsyncpgSchemaGenerator(BaseSchemaGenerator): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.FIELD_TYPE_MAP.update({fields.JSONField: "JSONB"}) + self.FIELD_TYPE_MAP.update({fields.JSONField: "JSONB", fields.UUIDField: "UUID"}) def _get_primary_key_create_string(self, field_name: str) -> str: return '"{}" SERIAL NOT NULL PRIMARY KEY'.format(field_name) diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index 5acbe0b58..1db0bdf4c 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -58,9 +58,10 @@ class BaseDBAsyncClient: schema_generator = BaseSchemaGenerator capabilities = Capabilities("") - def __init__(self, connection_name: str, **kwargs) -> None: + def __init__(self, connection_name: str, fetch_inserted: bool = True, **kwargs) -> None: self.log = logging.getLogger("db_client") self.connection_name = connection_name + self.fetch_inserted = fetch_inserted async def create_connection(self, with_db: bool) -> None: raise NotImplementedError() # pragma: nocoverage diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 3d1a2f173..95f7c073a 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -28,7 +28,7 @@ async def execute_select(self, query, custom_fields: Optional[list] = None) -> l raw_results = await self.db.execute_query(query.get_sql()) instance_list = [] for row in raw_results: - instance = self.model(**row) + instance = self.model(_from_db=True, **row) if custom_fields: for field in custom_fields: setattr(instance, field, row[field]) @@ -65,6 +65,9 @@ def _prepare_insert_statement(self, columns: List[str]) -> str: # go to descendant executors raise NotImplementedError() # pragma: nocoverage + async def _process_insert_result(self, instance: "Model", results: Any): + raise NotImplementedError() # pragma: nocoverage + async def execute_insert(self, instance): key = "{}:{}".format(self.db.connection_name, self.model._meta.table) if key not in INSERT_CACHE: @@ -75,7 +78,8 @@ async def execute_insert(self, 
instance): regular_columns, columns, query = INSERT_CACHE[key] values = self._prepare_insert_values(instance=instance, regular_columns=regular_columns) - instance.id = await self.db.execute_insert(query, values) + insert_result = await self.db.execute_insert(query, values) + await self._process_insert_result(instance, insert_result) return instance async def execute_update(self, instance): @@ -87,22 +91,22 @@ async def execute_update(self, instance): query = query.set( db_field, self._field_to_db(field_object, getattr(instance, field), instance) ) - query = query.where(table.id == instance.id) + query = query.where(getattr(table, self.model._meta.db_pk_field) == instance.pk) await self.db.execute_query(query.get_sql()) return instance async def execute_delete(self, instance): table = Table(self.model._meta.table) - query = self.model._meta.basequery.where(table.id == instance.id).delete() + query = self.model._meta.basequery.where( + getattr(table, self.model._meta.db_pk_field) == instance.pk + ).delete() await self.db.execute_query(query.get_sql()) return instance async def _prefetch_reverse_relation( self, instance_list: list, field: str, related_query ) -> list: - instance_id_set = set() # type: Set[int] - for instance in instance_list: - instance_id_set.add(instance.id) + instance_id_set = {instance.pk for instance in instance_list} # type: Set[Any] backward_relation_manager = getattr(self.model, field) relation_field = backward_relation_manager.relation_field @@ -119,13 +123,11 @@ async def _prefetch_reverse_relation( related_object_map[object_id] = [entry] for instance in instance_list: relation_container = getattr(instance, field) - relation_container._set_result_for_query(related_object_map.get(instance.id, [])) + relation_container._set_result_for_query(related_object_map.get(instance.pk, [])) return instance_list async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_query) -> list: - instance_id_set = set() # type: Set[int] - for instance in instance_list: - instance_id_set.add(instance.id) + instance_id_set = {instance.pk for instance in instance_list} # type: Set[Any] field_object = self.model._meta.fields_map[field] @@ -141,12 +143,13 @@ async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_ ) related_query_table = Table(related_query.model._meta.table) + related_pk_field = related_query.model._meta.db_pk_field query = ( related_query.query.join(subquery) - .on(subquery._forward_relation_key == related_query_table.id) + .on(subquery._forward_relation_key == getattr(related_query_table, related_pk_field)) .select( subquery._backward_relation_key.as_("_backward_relation_key"), - *[getattr(related_query_table, field).as_(field) for field in related_query.fields] + *[getattr(related_query_table, field).as_(field) for field in related_query.fields], ) ) @@ -173,12 +176,18 @@ async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_ query = query.having(having_criterion) raw_results = await self.db.execute_query(query.get_sql()) - relations = {(e["_backward_relation_key"], e["id"]) for e in raw_results} - related_object_list = [related_query.model(**e) for e in raw_results] + relations = { + ( + self.model._meta.pk.to_python_value(e["_backward_relation_key"]), + field_object.type._meta.pk.to_python_value(e[related_pk_field]), + ) + for e in raw_results + } + related_object_list = [related_query.model(_from_db=True, **e) for e in raw_results] await self.__class__( model=related_query.model, db=self.db, 
prefetch_map=related_query._prefetch_map ).fetch_for_list(related_object_list) - related_object_map = {e.id: e for e in related_object_list} + related_object_map = {e.pk: e for e in related_object_list} relation_map = {} # type: Dict[str, list] for object_id, related_object_id in relations: @@ -188,7 +197,7 @@ async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_ for instance in instance_list: relation_container = getattr(instance, field) - relation_container._set_result_for_query(relation_map.get(instance.id, [])) + relation_container._set_result_for_query(relation_map.get(instance.pk, [])) return instance_list async def _prefetch_direct_relation( @@ -200,8 +209,8 @@ async def _prefetch_direct_relation( if getattr(instance, relation_key_field): related_objects_for_fetch.add(getattr(instance, relation_key_field)) if related_objects_for_fetch: - related_object_list = await related_query.filter(id__in=list(related_objects_for_fetch)) - related_object_map = {obj.id: obj for obj in related_object_list} + related_object_list = await related_query.filter(pk__in=list(related_objects_for_fetch)) + related_object_map = {obj.pk: obj for obj in related_object_list} for instance in instance_list: setattr( instance, field, related_object_map.get(getattr(instance, relation_key_field)) diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index 2ac693aed..52ceff8ae 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -30,18 +30,24 @@ class BaseSchemaGenerator: fields.DateField: "DATE", fields.FloatField: "DOUBLE PRECISION", fields.JSONField: "TEXT", + fields.UUIDField: "CHAR(36)", } def __init__(self, client) -> None: self.client = client - def _create_string(self, db_field: str, field_type: str, nullable: str, unique: str) -> str: + def _create_string( + self, db_field: str, field_type: str, nullable: str, unique: str, is_pk: bool + ) -> str: # children can override this function to customize thier sql queries field_creation_string = self.FIELD_TEMPLATE.format( name=db_field, type=field_type, nullable=nullable, unique=unique ).strip() + if is_pk: + field_creation_string += " PRIMARY KEY" + return field_creation_string def _get_primary_key_create_string(self, field_name: str) -> str: @@ -105,7 +111,13 @@ def _get_table_sql(self, model, safe=True) -> dict: elif isinstance(field_object, fields.CharField): field_type = field_type.format(field_object.max_length) - field_creation_string = self._create_string(db_field, field_type, nullable, unique) + field_creation_string = self._create_string( + db_field=db_field, + field_type=field_type, + nullable=nullable, + unique=unique, + is_pk=field_object.pk, + ) if hasattr(field_object, "reference") and field_object.reference: field_creation_string += self.FK_TEMPLATE.format( diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index 4f5bcbe38..6f715646d 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -125,7 +125,12 @@ def acquire_connection(self) -> ConnectionWrapper: return ConnectionWrapper(self._connection, self._lock) def _in_transaction(self): - return self._transaction_class(self.connection_name, self._connection, self._lock) + return self._transaction_class( + connection_name=self.connection_name, + connection=self._connection, + lock=self._lock, + fetch_inserted=self.fetch_inserted, + ) @translate_exceptions async def execute_insert(self, query: str, values: 
list) -> int: @@ -153,7 +158,7 @@ async def execute_script(self, query: str) -> None: class TransactionWrapper(MySQLClient, BaseTransactionWrapper): - def __init__(self, connection_name, connection, lock): + def __init__(self, connection_name, connection, lock, fetch_inserted): self.connection_name = connection_name self._connection = connection self._lock = lock @@ -161,6 +166,7 @@ def __init__(self, connection_name, connection, lock): self._transaction_class = self.__class__ self._finalized = False self._old_context_value = None + self.fetch_inserted = fetch_inserted async def start(self): await self._connection.begin() diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index b0b241425..50aee076a 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -1,9 +1,11 @@ -from typing import List +from typing import List, Any from pypika import MySQLQuery, Parameter, Table, functions from pypika.enums import SqlTypes +from tortoise import Model from tortoise.backends.base.executor import BaseExecutor +from tortoise.fields import IntField, BigIntField from tortoise.filters import ( contains, ends_with, @@ -61,3 +63,29 @@ def _prepare_insert_statement(self, columns: List[str]) -> str: .columns(*columns) .insert(*[Parameter("%s") for _ in range(len(columns))]) ) + + async def _process_insert_result(self, instance: Model, results: Any): + generated_fields = self.model._meta.generated_db_fields + if not generated_fields: + return + + pk_fetched = False + pk_field_object = self.model._meta.pk + if isinstance(pk_field_object, (IntField, BigIntField)) and pk_field_object.generated: + instance.pk = results + pk_fetched = True + + if self.db.fetch_inserted: + other_generated_fields = set(generated_fields) + if pk_fetched: + other_generated_fields.remove(self.model._meta.db_pk_field) + if not other_generated_fields: + return + table = Table(self.model._meta.table) + query = str( + MySQLQuery.from_(table) + .select(*generated_fields) + .where(getattr(table, self.model._meta.db_pk_field) == instance.pk) + ) + fetch_results = await self.db.execute_query(query) + instance.set_field_values(dict(fetch_results)) diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 75dec47eb..a01898bab 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -78,7 +78,12 @@ def acquire_connection(self) -> ConnectionWrapper: return ConnectionWrapper(self._connection, self._lock) def _in_transaction(self) -> "TransactionWrapper": - return self._transaction_class(self.connection_name, self._connection, self._lock) + return self._transaction_class( + connection_name=self.connection_name, + connection=self._connection, + lock=self._lock, + fetch_inserted=self.fetch_inserted, + ) @translate_exceptions async def execute_insert(self, query: str, values: list) -> int: @@ -90,7 +95,10 @@ async def execute_insert(self, query: str, values: list) -> int: async def execute_query(self, query: str) -> List[dict]: async with self.acquire_connection() as connection: self.log.debug(query) - return [dict(row) for row in await connection.execute_fetchall(query)] + print(query) + res = [dict(row) for row in await connection.execute_fetchall(query)] + print(res) + return res @translate_exceptions async def execute_script(self, query: str) -> None: @@ -100,7 +108,9 @@ async def execute_script(self, query: str) -> None: class TransactionWrapper(SqliteClient, BaseTransactionWrapper): - def __init__(self, 
connection_name: str, connection: aiosqlite.Connection, lock) -> None: + def __init__( + self, connection_name: str, connection: aiosqlite.Connection, lock, fetch_inserted + ) -> None: self.connection_name = connection_name self._connection = connection # type: aiosqlite.Connection self._lock = lock @@ -108,6 +118,7 @@ def __init__(self, connection_name: str, connection: aiosqlite.Connection, lock) self._transaction_class = self.__class__ self._old_context_value = None self._finalized = False + self.fetch_inserted = fetch_inserted async def start(self) -> None: try: diff --git a/tortoise/backends/sqlite/executor.py b/tortoise/backends/sqlite/executor.py index ea8d92d74..d047993ee 100644 --- a/tortoise/backends/sqlite/executor.py +++ b/tortoise/backends/sqlite/executor.py @@ -1,10 +1,11 @@ from decimal import Decimal -from typing import List +from typing import List, Any from pypika import Parameter, Table -from tortoise import fields +from tortoise import fields, Model from tortoise.backends.base.executor import BaseExecutor +from tortoise.fields import IntField, BigIntField def to_db_bool(self, value, instance): @@ -33,3 +34,29 @@ def _prepare_insert_statement(self, columns: List[str]) -> str: .columns(*columns) .insert(*[Parameter("?") for _ in range(len(columns))]) ) + + async def _process_insert_result(self, instance: Model, results: Any): + generated_fields = self.model._meta.generated_db_fields + if not generated_fields: + return + + pk_fetched = False + pk_field_object = self.model._meta.pk + if isinstance(pk_field_object, (IntField, BigIntField)) and pk_field_object.generated: + instance.pk = results + pk_fetched = True + + if self.db.fetch_inserted: + other_generated_fields = set(generated_fields) + if pk_fetched: + other_generated_fields.remove(self.model._meta.db_pk_field) + if not other_generated_fields: + return + table = Table(self.model._meta.table) + query = str( + self.db.query_class.from_(table) + .select(*generated_fields) + .where(getattr(table, self.model._meta.db_pk_field) == instance.pk) + ) + fetch_results = await self.db.execute_query(query) + instance.set_field_values(dict(fetch_results)) diff --git a/tortoise/fields.py b/tortoise/fields.py index 6b4aeec0d..28e5f128a 100644 --- a/tortoise/fields.py +++ b/tortoise/fields.py @@ -1,8 +1,10 @@ import datetime import functools import json +import uuid from decimal import Decimal from typing import Any, Optional, Union +from uuid import UUID import ciso8601 from pypika import Table @@ -35,7 +37,10 @@ class Field: "unique", "index", "model_field_name", + "model", + "reference", ) + has_db_field = True def __init__( self, @@ -47,6 +52,8 @@ def __init__( default: Any = None, unique: bool = False, index: bool = False, + reference: str = None, + model: "Model" = None, **kwargs ) -> None: self.type = type @@ -58,6 +65,8 @@ def __init__( self.unique = unique self.index = index self.model_field_name = "" # Type: str + self.model = model + self.reference = reference def to_db_value(self, value: Any, instance) -> Any: if value is None or type(value) == self.type: # pylint: disable=C0123 @@ -82,13 +91,10 @@ class IntField(Field): True if field is Primary Key. 
""" - __slots__ = ("reference",) - def __init__(self, pk: bool = False, **kwargs) -> None: - kwargs["generated"] = bool(kwargs.get("generated")) | pk - super().__init__(int, **kwargs) - self.reference = kwargs.get("reference") - self.pk = pk + if pk: + kwargs["generated"] = bool(kwargs.get("generated", True)) + super().__init__(int, pk=pk, **kwargs) class BigIntField(Field): @@ -99,13 +105,10 @@ class BigIntField(Field): True if field is Primary Key. """ - __slots__ = ("reference",) - def __init__(self, pk: bool = False, **kwargs) -> None: - kwargs["generated"] = bool(kwargs.get("generated")) | pk - super().__init__(int, **kwargs) - self.reference = kwargs.get("reference") - self.pk = pk + if pk: + kwargs["generated"] = bool(kwargs.get("generated", True)) + super().__init__(int, pk=pk, **kwargs) class SmallIntField(Field): @@ -286,7 +289,7 @@ class JSONField(Field): __slots__ = ("encoder", "decoder") def __init__(self, encoder=JSON_DUMPS, decoder=JSON_LOADS, **kwargs) -> None: - super().__init__((dict, list), **kwargs) + super().__init__(type=(dict, list), **kwargs) self.encoder = encoder self.decoder = decoder @@ -303,6 +306,28 @@ def to_python_value( return self.decoder(value) +class UUIDField(Field): + """ + UUID Field + + This field can store uuid value. Postgresql will store value with + native data type, while others will store it as string + """ + + def __init__(self, *args, **kwargs): + super().__init__(type=UUID, *args, **kwargs) + + def to_db_value(self, value: Any, instance): + if value is None: + return None + return str(value) + + def to_python_value(self, value: Any): + if value is None or isinstance(value, self.type): + return value + return uuid.UUID(value) + + class ForeignKeyField(Field): """ ForeignKey relation field. @@ -334,6 +359,7 @@ class ForeignKeyField(Field): """ __slots__ = ("model_name", "related_name", "on_delete") + has_db_field = False def __init__( self, model_name: str, related_name: Optional[str] = None, on_delete=CASCADE, **kwargs @@ -384,6 +410,7 @@ class ManyToManyField(Field): "through", "_generated", ) + has_db_field = False def __init__( self, @@ -407,6 +434,7 @@ def __init__( class BackwardFKRelation(Field): __slots__ = ("type", "relation_field") + has_db_field = False def __init__(self, type, relation_field: str) -> None: # pylint: disable=W0622 super().__init__(type=type) @@ -437,11 +465,11 @@ def __init__(self, model, relation_field: str, instance) -> None: @property def _query(self): - if not self.instance.id: + if not self.instance.pk: raise OperationalError( "This objects hasn't been instanced, call .save() before" " calling related queries" ) - return self.model.filter(**{self.relation_field: self.instance.id}) + return self.model.filter(**{self.relation_field: self.instance.pk}) def __contains__(self, item) -> bool: if not self._fetched: @@ -544,15 +572,20 @@ async def add(self, *instances, using_db=None) -> None: """ if not instances: return - if self.instance.id is None: + if self.instance.pk is None: raise OperationalError( "You should first call .save() on {model}".format(model=self.instance) ) db = using_db if using_db else self.model._meta.db + pk_formatting_func = type(self.instance)._meta.pk.to_db_value + related_pk_formatting_func = type(instances[0])._meta.pk.to_db_value through_table = Table(self.field.through) select_query = ( db.query_class.from_(through_table) - .where(getattr(through_table, self.field.backward_key) == self.instance.id) + .where( + getattr(through_table, self.field.backward_key) + == 
pk_formatting_func(self.instance.pk, self.instance) + ) .select(self.field.backward_key, self.field.forward_key) ) query = db.query_class.into(through_table).columns( @@ -561,10 +594,12 @@ async def add(self, *instances, using_db=None) -> None: ) if len(instances) == 1: - criterion = getattr(through_table, self.field.forward_key) == instances[0].id + criterion = getattr( + through_table, self.field.forward_key + ) == related_pk_formatting_func(instances[0].pk, instances[0]) else: criterion = getattr(through_table, self.field.forward_key).isin( - [i.id for i in instances] + [related_pk_formatting_func(i.pk, i) for i in instances] ) select_query = select_query.where(criterion) @@ -577,13 +612,16 @@ async def add(self, *instances, using_db=None) -> None: insert_is_required = False for instance_to_add in instances: - if instance_to_add.id is None: + if instance_to_add.pk is None: raise OperationalError( "You should first call .save() on {model}".format(model=instance_to_add) ) - if (self.instance.id, instance_to_add.id) in already_existing_relations: + if (self.instance.pk, instance_to_add.pk) in already_existing_relations: continue - query = query.insert(instance_to_add.id, self.instance.id) + query = query.insert( + related_pk_formatting_func(instance_to_add.pk, instance_to_add), + pk_formatting_func(self.instance.pk, self.instance), + ) insert_is_required = True if insert_is_required: await db.execute_query(str(query)) @@ -594,9 +632,13 @@ async def clear(self, using_db=None) -> None: """ db = using_db if using_db else self.model._meta.db through_table = Table(self.field.through) + pk_formatting_func = type(self.instance)._meta.pk.to_db_value query = ( db.query_class.from_(through_table) - .where(getattr(through_table, self.field.backward_key) == self.instance.id) + .where( + getattr(through_table, self.field.backward_key) + == pk_formatting_func(self.instance.pk, self.instance) + ) .delete() ) await db.execute_query(str(query)) @@ -609,14 +651,25 @@ async def remove(self, *instances, using_db=None) -> None: if not instances: raise OperationalError("remove() called on no instances") through_table = Table(self.field.through) + pk_formatting_func = type(self.instance)._meta.pk.to_db_value + related_pk_formatting_func = type(instances[0])._meta.pk.to_db_value if len(instances) == 1: - condition = (getattr(through_table, self.field.forward_key) == instances[0].id) & ( - getattr(through_table, self.field.backward_key) == self.instance.id + condition = ( + getattr(through_table, self.field.forward_key) + == related_pk_formatting_func(instances[0].pk, instances[0]) + ) & ( + getattr(through_table, self.field.backward_key) + == pk_formatting_func(self.instance.pk, self.instance) ) else: - condition = (getattr(through_table, self.field.backward_key) == self.instance.id) & ( - getattr(through_table, self.field.forward_key).isin([i.id for i in instances]) + condition = ( + getattr(through_table, self.field.backward_key) + == pk_formatting_func(self.instance.pk, self.instance) + ) & ( + getattr(through_table, self.field.forward_key).isin( + [related_pk_formatting_func(i.pk, i) for i in instances] + ) ) query = db.query_class.from_(through_table).where(condition).delete() await db.execute_query(str(query)) diff --git a/tortoise/filters.py b/tortoise/filters.py index 7e61e654e..7d79e38d6 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -1,4 +1,5 @@ import operator +from functools import partial from typing import Dict, Iterable, Optional # noqa from pypika import Table, functions @@ -13,6 
+14,13 @@ def list_encoder(values, instance, field: Field): return [field.to_db_value(element, instance) for element in values] +def related_list_encoder(values, instance, field: Field): + return [ + field.to_db_value(element.pk if hasattr(element, "pk") else element, instance) + for element in values + ] + + def bool_encoder(value, *args): return bool(value) @@ -78,63 +86,69 @@ def insensitive_ends_with(field, value): def get_m2m_filters(field_name: str, field: fields.ManyToManyField) -> Dict[str, dict]: + target_table_pk = field.type._meta.pk return { field_name: { "field": field.forward_key, "backward_key": field.backward_key, "operator": operator.eq, "table": Table(field.through), + "value_encoder": target_table_pk.to_db_value, }, "{}__not".format(field_name): { "field": field.forward_key, "backward_key": field.backward_key, "operator": not_equal, "table": Table(field.through), + "value_encoder": target_table_pk.to_db_value, }, "{}__in".format(field_name): { "field": field.forward_key, "backward_key": field.backward_key, "operator": is_in, "table": Table(field.through), - "value_encoder": list_encoder, + "value_encoder": partial(related_list_encoder, field=target_table_pk), }, "{}__not_in".format(field_name): { "field": field.forward_key, "backward_key": field.backward_key, "operator": not_in, "table": Table(field.through), - "value_encoder": list_encoder, + "value_encoder": partial(related_list_encoder, field=target_table_pk), }, } def get_backward_fk_filters(field_name: str, field: fields.BackwardFKRelation) -> Dict[str, dict]: + target_table_pk = field.type._meta.pk return { field_name: { "field": "id", "backward_key": field.relation_field, "operator": operator.eq, "table": Table(field.type._meta.table), + "value_encoder": target_table_pk.to_db_value, }, "{}__not".format(field_name): { "field": "id", "backward_key": field.relation_field, "operator": not_equal, "table": Table(field.type._meta.table), + "value_encoder": target_table_pk.to_db_value, }, "{}__in".format(field_name): { "field": "id", "backward_key": field.relation_field, "operator": is_in, "table": Table(field.type._meta.table), - "value_encoder": list_encoder, + "value_encoder": partial(related_list_encoder, field=target_table_pk), }, "{}__not_in".format(field_name): { "field": "id", "backward_key": field.relation_field, "operator": not_in, "table": Table(field.type._meta.table), - "value_encoder": list_encoder, + "value_encoder": partial(related_list_encoder, field=target_table_pk), }, } diff --git a/tortoise/models.py b/tortoise/models.py index 1cc6de24e..09435592b 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -1,12 +1,17 @@ -from copy import copy -from typing import Dict, Hashable, List, Optional, Set, Tuple, Type, TypeVar, Union +from copy import copy, deepcopy +from typing import Dict, Hashable, List, Optional, Set, Tuple, Type, TypeVar, Union, Any from pypika import Query from tortoise import fields from tortoise.backends.base.client import BaseDBAsyncClient # noqa from tortoise.exceptions import ConfigurationError, OperationalError -from tortoise.fields import ManyToManyField, ManyToManyRelationManager, RelationQueryContainer +from tortoise.fields import ( + ManyToManyField, + ManyToManyRelationManager, + RelationQueryContainer, + Field, +) from tortoise.filters import get_filters_for_field from tortoise.queryset import QuerySet from tortoise.transactions import current_transaction_map @@ -31,15 +36,15 @@ class MetaInfo: "abstract", "table", "app", - "fields", - "db_fields", + "_fields", + 
"_db_fields", "m2m_fields", "fk_fields", "backward_fk_fields", - "fetch_fields", + "_fetch_fields", "fields_db_projection", "_inited", - "fields_db_projection_reverse", + "_fields_db_projection_reverse", "filters", "fields_map", "default_connection", @@ -47,6 +52,9 @@ class MetaInfo: "basequery_all_fields", "_filters", "unique_together", + "pk_attr", + "_generated_db_fields", + "_model", ) def __init__(self, meta) -> None: @@ -54,14 +62,14 @@ def __init__(self, meta) -> None: self.table = getattr(meta, "table", "") # type: str self.app = getattr(meta, "app", None) # type: Optional[str] self.unique_together = get_unique_together(meta) # type: Optional[Union[Tuple, List]] - self.fields = set() # type: Set[str] - self.db_fields = set() # type: Set[str] + self._fields = None # type: Optional[Set[str]] + self._db_fields = None # type: Optional[Set[str]] self.m2m_fields = set() # type: Set[str] self.fk_fields = set() # type: Set[str] self.backward_fk_fields = set() # type: Set[str] - self.fetch_fields = set() # type: Set[str] + self._fetch_fields = None # type: Optional[Set[str]] self.fields_db_projection = {} # type: Dict[str,str] - self.fields_db_projection_reverse = {} # type: Dict[str,str] + self._fields_db_projection_reverse = None # type: Optional[Dict[str,str]] self._filters = {} # type: Dict[str, Dict[str, dict]] self.filters = {} # type: Dict[str, dict] self.fields_map = {} # type: Dict[str, fields.Field] @@ -69,6 +77,89 @@ def __init__(self, meta) -> None: self.default_connection = None # type: Optional[str] self.basequery = Query() # type: Query self.basequery_all_fields = Query() # type: Query + self.pk_attr = getattr(meta, "pk_attr", "") # type: str + self._generated_db_fields = None + self._model = None + + def add_field(self, name: str, value: Field): + if name in self.fields_map: + raise ConfigurationError("Field {} already present in meta".format(name)) + setattr(self._model, name, value) + value.model = self._model + self.fields_map[name] = value + self._fields = None + + if value.has_db_field: + self.fields_db_projection[name] = value.source_field or name + self._fields_db_projection_reverse = None + + if isinstance(value, fields.ForeignKeyField): + self.fk_fields.add(name) + self._fetch_fields = None + elif isinstance(value, fields.ManyToManyField): + self.m2m_fields.add(name) + self._fetch_fields = None + elif isinstance(value, fields.BackwardFKRelation): + self.backward_fk_fields.add(name) + self._fetch_fields = None + + field_filters = get_filters_for_field( + field_name=name, field=value, source_field=value.source_field or name + ) + self._filters.update(field_filters) + self.generate_filters() + + @property + def fields_db_projection_reverse(self) -> Dict[str, str]: + if self._fields_db_projection_reverse is None: + self._fields_db_projection_reverse = { + value: key for key, value in self.fields_db_projection.items() + } + return self._fields_db_projection_reverse + + @property + def fields(self) -> Set[str]: + if self._fields is None: + self._fields = set(self.fields_map.keys()) + return self._fields + + @property + def db_fields(self) -> Set[str]: + if self._db_fields is None: + self._db_fields = set(self.fields_db_projection.values()) + return self._db_fields + + @property + def fetch_fields(self): + if self._fetch_fields is None: + self._fetch_fields = self.m2m_fields | self.backward_fk_fields | self.fk_fields + return self._fetch_fields + + @property + def pk(self): + return self.fields_map[self.pk_attr] + + @property + def db_pk_field(self) -> str: + field_object 
= self.fields_map[self.pk_attr] + return field_object.source_field or self.pk_attr + + @property + def is_pk_generated(self) -> bool: + field_object = self.fields_map[self.pk_attr] + return field_object.generated + + @property + def generated_db_fields(self) -> Tuple[str]: + """Return list of names of db fields that are generated on db side""" + if self._generated_db_fields is None: + generated_fields = [] + for field in self.fields_map.values(): + if not field.generated: + continue + generated_fields.append(field.source_field or field.model_field_name) + self._generated_db_fields = tuple(generated_fields) + return self._generated_db_fields @property def db(self) -> BaseDBAsyncClient: @@ -102,32 +193,51 @@ def __new__(mcs, name: str, bases, attrs: dict, *args, **kwargs): fk_fields = set() # type: Set[str] m2m_fields = set() # type: Set[str] - if "id" not in attrs: - attrs["id"] = fields.IntField(pk=True) + custom_pk_present = False + for key, value in attrs.items(): + if isinstance(value, fields.Field): + if value.pk: + if custom_pk_present: + raise ConfigurationError( + "Can't create model {} with two primary keys, " + "only single pk are supported".format(name) + ) + elif value.generated and not isinstance( + value, (fields.IntField, fields.BigIntField) + ): + raise ConfigurationError( + "Generated primary key allowed only for IntField and BigIntField" + ) + custom_pk_present = True + pk_attr = key + + if not custom_pk_present: + if "id" not in attrs: + attrs["id"] = fields.IntField(pk=True) + pk_attr = "id" + + if not isinstance(attrs["id"], fields.Field) or not attrs["id"].pk: + raise ConfigurationError( + "Can't create model {} without explicit primary key " + "if field 'id' already present".format(name) + ) + + meta_class = attrs.get("Meta", type("Meta", (), {})) for key, value in attrs.items(): if isinstance(value, fields.Field): + if getattr(meta_class, "abstract", None): + value = deepcopy(value) + fields_map[key] = value value.model_field_name = key + if isinstance(value, fields.ForeignKeyField): - key_field = "{}_id".format(key) - value.source_field = key_field - fields_db_projection[key_field] = key_field - fields_map[key_field] = fields.IntField( - reference=value, null=value.null, default=value.default - ) - filters.update( - get_filters_for_field( - field_name=key_field, - field=fields_map[key_field], - source_field=key_field, - ) - ) fk_fields.add(key) elif isinstance(value, fields.ManyToManyField): m2m_fields.add(key) else: - fields_db_projection[key] = value.source_field if value.source_field else key + fields_db_projection[key] = value.source_field or key filters.update( get_filters_for_field( field_name=key, @@ -135,39 +245,45 @@ def __new__(mcs, name: str, bases, attrs: dict, *args, **kwargs): source_field=fields_db_projection[key], ) ) + if value.pk: + filters.update( + get_filters_for_field( + field_name="pk", + field=fields_map[key], + source_field=fields_db_projection[key], + ) + ) - attrs["_meta"] = meta = MetaInfo(attrs.get("Meta")) + attrs["_meta"] = meta = MetaInfo(meta_class) meta.fields_map = fields_map meta.fields_db_projection = fields_db_projection - meta.fields_db_projection_reverse = { - value: key for key, value in fields_db_projection.items() - } - meta.fields = set(fields_map.keys()) - meta.db_fields = set(fields_db_projection.values()) meta._filters = filters meta.fk_fields = fk_fields meta.backward_fk_fields = set() meta.m2m_fields = m2m_fields - meta.fetch_fields = fk_fields | m2m_fields meta.default_connection = None + meta.pk_attr = pk_attr 
meta._inited = False if not fields_map: meta.abstract = True new_class = super().__new__(mcs, name, bases, attrs) + for field in meta.fields_map.values(): + field.model = new_class + meta._model = new_class return new_class class Model(metaclass=ModelMeta): # I don' like this here, but it makes autocompletion and static analysis much happier _meta = MetaInfo(None) - id = None # type: Optional[Hashable] - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, _from_db: bool = False, **kwargs) -> None: # self._meta is a very common attribute lookup, lets cache it. meta = self._meta + self._saved_in_db = _from_db or (meta.pk_attr in kwargs and meta.is_pk_generated) # Create lazy fk/m2m objects for key in meta.backward_fk_fields: @@ -193,15 +309,32 @@ def __init__(self, *args, **kwargs) -> None: # Assign values and do type conversions passed_fields = set(kwargs.keys()) passed_fields.update(meta.fetch_fields) + passed_fields |= self.set_field_values(kwargs) - for key, value in kwargs.items(): + # Assign defaults for missing fields + for key in meta.fields.difference(passed_fields): + field_object = meta.fields_map[key] + if callable(field_object.default): + setattr(self, key, field_object.default()) + else: + setattr(self, key, field_object.default) + + def set_field_values(self, values_map: Dict[str, Any]) -> List[str]: + """ + Sets values for fields honoring type transformations and + return list of fields that were set additionally + """ + meta = self._meta + passed_fields = set() + + for key, value in values_map.items(): if key in meta.fk_fields: - if hasattr(value, "id") and not value.id: + if hasattr(value, "pk") and not value.pk: raise OperationalError( "You should first call .save() on {} before referring to it".format(value) ) relation_field = "{}_id".format(key) - setattr(self, relation_field, value.id) + setattr(self, relation_field, value.pk) passed_fields.add(relation_field) elif key in meta.fields: field_object = meta.fields_map[key] @@ -209,7 +342,10 @@ def __init__(self, *args, **kwargs) -> None: raise ValueError("{} is non nullable field, but null was passed".format(key)) setattr(self, key, field_object.to_python_value(value)) elif key in meta.db_fields: - setattr(self, meta.fields_db_projection_reverse[key], value) + field_object = meta.fields_map[meta.fields_db_projection_reverse[key]] + if value is None and not field_object.null: + raise ValueError("{} is non nullable field, but null was passed".format(key)) + setattr(self, key, field_object.to_python_value(value)) elif key in meta.backward_fk_fields: raise ConfigurationError( "You can't set backward relations through init, change related model instead" @@ -219,31 +355,34 @@ def __init__(self, *args, **kwargs) -> None: "You can't set m2m relations through init, use m2m_manager instead" ) - # Assign defaults for missing fields - for key in meta.fields.difference(passed_fields): - field_object = meta.fields_map[key] - if callable(field_object.default): - setattr(self, key, field_object.default()) - else: - setattr(self, key, field_object.default) + return passed_fields + + def _get_pk_val(self): + return getattr(self, self._meta.pk_attr) + + def _set_pk_val(self, value): + setattr(self, self._meta.pk_attr, value) + + pk = property(_get_pk_val, _set_pk_val) async def _insert_instance(self, using_db=None) -> None: db = using_db if using_db else self._meta.db await db.executor_class(model=self.__class__, db=db).execute_insert(self) + self._saved_in_db = True async def _update_instance(self, using_db=None) -> 
None: db = using_db if using_db else self._meta.db await db.executor_class(model=self.__class__, db=db).execute_update(self) async def save(self, *args, **kwargs) -> None: - if not self.id: + if not self._saved_in_db: await self._insert_instance(*args, **kwargs) else: await self._update_instance(*args, **kwargs) async def delete(self, using_db=None) -> None: db = using_db if using_db else self._meta.db - if not self.id: + if not self._saved_in_db: raise OperationalError("Can't delete unpersisted record") await db.executor_class(model=self.__class__, db=db).execute_delete(self) @@ -255,18 +394,18 @@ def __str__(self) -> str: return "<{}>".format(self.__class__.__name__) def __repr__(self) -> str: - if self.id: - return "<{}: {}>".format(self.__class__.__name__, self.id) + if self.pk: + return "<{}: {}>".format(self.__class__.__name__, self.pk) return "<{}>".format(self.__class__.__name__) def __hash__(self) -> int: - if not self.id: + if not self.pk: raise TypeError("Model instances without id are unhashable") - return hash(self.id) + return hash(self.pk) def __eq__(self, other) -> bool: # pylint: disable=C0123 - if type(self) == type(other) and self.id == other.id: + if type(self) == type(other) and self.pk == other.pk: return True return False diff --git a/tortoise/query_utils.py b/tortoise/query_utils.py index 3ce99dfb6..f54e6f689 100644 --- a/tortoise/query_utils.py +++ b/tortoise/query_utils.py @@ -17,8 +17,14 @@ def _process_filter_kwarg(model, key, value) -> Tuple[Criterion, Optional[Tuple[ else: param = model._meta.get_filter(key) + pk_db_field = model._meta.db_pk_field if param.get("table"): - join = (param["table"], table.id == getattr(param["table"], param["backward_key"])) + join = ( + param["table"], + getattr(table, pk_db_field) == getattr(param["table"], param["backward_key"]), + ) + if param.get("value_encoder"): + value = param["value_encoder"](value, model) criterion = param["operator"](getattr(param["table"], param["field"]), value) else: field_object = model._meta.fields_map[param["field"]] @@ -35,24 +41,42 @@ def _get_joins_for_related_field( table, related_field, related_field_name ) -> List[Tuple[Table, Criterion]]: required_joins = [] + + table_pk = related_field.model._meta.db_pk_field + related_table_pk = related_field.type._meta.db_pk_field + if isinstance(related_field, fields.ManyToManyField): related_table = Table(related_field.type._meta.table) through_table = Table(related_field.through) required_joins.append( - (through_table, table.id == getattr(through_table, related_field.backward_key)) + ( + through_table, + getattr(table, table_pk) == getattr(through_table, related_field.backward_key), + ) ) required_joins.append( - (related_table, getattr(through_table, related_field.forward_key) == related_table.id) + ( + related_table, + getattr(through_table, related_field.forward_key) + == getattr(related_table, related_table_pk), + ) ) elif isinstance(related_field, fields.BackwardFKRelation): related_table = Table(related_field.type._meta.table) required_joins.append( - (related_table, table.id == getattr(related_table, related_field.relation_field)) + ( + related_table, + getattr(table, table_pk) == getattr(related_table, related_field.relation_field), + ) ) else: related_table = Table(related_field.type._meta.table) required_joins.append( - (related_table, related_table.id == getattr(table, "{}_id".format(related_field_name))) + ( + related_table, + getattr(related_table, related_table_pk) + == getattr(table, "{}_id".format(related_field_name)), + ) ) return 
required_joins @@ -213,11 +237,17 @@ def _resolve_regular_kwarg(self, model, key, value) -> QueryModifier: def _get_actual_filter_params(self, model, key, value) -> Tuple[str, Any]: if key in model._meta.fk_fields: field_object = model._meta.fields_map[key] - if hasattr(value, "id"): - filter_value = value.id + if hasattr(value, "pk"): + filter_value = value.pk else: filter_value = value filter_key = field_object.source_field + elif key in model._meta.m2m_fields: + filter_key = key + if hasattr(value, "pk"): + filter_value = value.pk + else: + filter_value = value elif ( key.split("__")[0] in model._meta.fetch_fields or key in self._custom_filters diff --git a/tortoise/tests/fields/test_uuid.py b/tortoise/tests/fields/test_uuid.py new file mode 100644 index 000000000..7503011bf --- /dev/null +++ b/tortoise/tests/fields/test_uuid.py @@ -0,0 +1,28 @@ +import uuid + +from tortoise.contrib import test +from tortoise.exceptions import IntegrityError +from tortoise.tests import testmodels + + +class TestUUIDFields(test.TestCase): + async def test_empty(self): + with self.assertRaises(IntegrityError): + await testmodels.UUIDFields.create() + + async def test_create(self): + data = uuid.uuid4() + obj0 = await testmodels.UUIDFields.create(data=data) + obj = await testmodels.UUIDFields.get(id=obj0.id) + self.assertEqual(obj.data, data) + self.assertEqual(obj.data_null, None) + await obj.save() + obj2 = await testmodels.UUIDFields.get(id=obj.id) + self.assertEqual(obj, obj2) + + async def test_create_not_null(self): + data = uuid.uuid4() + obj0 = await testmodels.UUIDFields.create(data=data, data_null=data) + obj = await testmodels.UUIDFields.get(id=obj0.id) + self.assertEqual(obj.data, data) + self.assertEqual(obj.data_null, data) diff --git a/tortoise/tests/test_primary_key.py b/tortoise/tests/test_primary_key.py new file mode 100644 index 000000000..9aa109564 --- /dev/null +++ b/tortoise/tests/test_primary_key.py @@ -0,0 +1,70 @@ +import uuid + +from tortoise.contrib import test +from tortoise.tests.testmodels import ( + ImplicitPkModel, + UUIDPkModel, + UUIDFkRelatedModel, + UUIDM2MRelatedModel, +) + + +class TestQueryset(test.TestCase): + async def test_implicit_pk(self): + instance = await ImplicitPkModel.create(value="test") + self.assertTrue(instance.id) + self.assertEqual(instance.pk, instance.id) + + async def test_uuid_pk(self): + value = uuid.uuid4() + await UUIDPkModel.create(id=value) + + instance2 = await UUIDPkModel.get(id=value) + self.assertEqual(instance2.id, value) + self.assertEqual(instance2.pk, value) + + async def test_uuid_pk_fk(self): + value = uuid.uuid4() + instance = await UUIDPkModel.create(id=value) + instance2 = await UUIDPkModel.create(id=uuid.uuid4()) + await UUIDFkRelatedModel.create(model=instance2) + + related_instance = await UUIDFkRelatedModel.create(model=instance) + self.assertEqual(related_instance.model_id, value) + + related_instance = await UUIDFkRelatedModel.filter(model=instance).first() + self.assertEqual(related_instance.model_id, value) + + related_instance = await UUIDFkRelatedModel.filter(model_id=value).first() + self.assertEqual(related_instance.model_id, value) + + await instance.fetch_related("children") + self.assertEqual(instance.children[0], related_instance) + + async def test_uuid_m2m(self): + value = uuid.uuid4() + instance = await UUIDPkModel.create(id=value) + instance2 = await UUIDPkModel.create(id=uuid.uuid4()) + + related_instance = await UUIDM2MRelatedModel.create() + related_instance2 = await UUIDM2MRelatedModel.create() + + await 
instance.peers.add(related_instance) + await related_instance2.models.add(instance, instance2) + + await related_instance.fetch_related("models") + print(list(related_instance.models)) + self.assertEqual(len(related_instance.models), 1) + self.assertEqual(related_instance.models[0], instance) + + await related_instance2.fetch_related("models") + self.assertEqual(len(related_instance2.models), 2) + self.assertEqual(set(m.pk for m in related_instance2.models), {instance.pk, instance2.pk}) + + related_instance_list = await UUIDM2MRelatedModel.filter(models=instance2) + self.assertEqual(len(related_instance_list), 1) + self.assertEqual(related_instance_list[0], related_instance2) + + related_instance_list = await UUIDM2MRelatedModel.filter(models__in=[instance2]) + self.assertEqual(len(related_instance_list), 1) + self.assertEqual(related_instance_list[0], related_instance2) diff --git a/tortoise/tests/testmodels.py b/tortoise/tests/testmodels.py index 30112e5da..7074e8e01 100644 --- a/tortoise/tests/testmodels.py +++ b/tortoise/tests/testmodels.py @@ -157,6 +157,12 @@ class JSONFields(Model): data_null = fields.JSONField(null=True) +class UUIDFields(Model): + id = fields.IntField(pk=True) + data = fields.UUIDField() + data_null = fields.UUIDField(null=True) + + class MinRelation(Model): id = fields.IntField(pk=True) tournament = fields.ForeignKeyField("models.Tournament") @@ -220,3 +226,20 @@ class ContactTypeEnum(IntEnum): class Contact(Model): id = fields.IntField(pk=True) type = fields.IntField(default=ContactTypeEnum.other) + + +class ImplicitPkModel(Model): + value = fields.TextField() + + +class UUIDPkModel(Model): + id = fields.UUIDField(pk=True) + + +class UUIDFkRelatedModel(Model): + model = fields.ForeignKeyField("models.UUIDPkModel", related_name="children") + + +class UUIDM2MRelatedModel(Model): + value = fields.TextField(default="test") + models = fields.ManyToManyField("models.UUIDPkModel", related_name="peers")
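
Usage sketch (illustrative, not part of the patch): with this change any single field
declared with pk=True becomes the primary key, instance.pk resolves to that field, and
foreign key columns take the type of the related model's primary key. The Event and
Participant models, the in-memory SQLite URL and the init/run wiring below are
assumptions for demonstration only, mirroring the UUIDPkModel/UUIDFkRelatedModel tests
added in tortoise/tests/testmodels.py.

    import asyncio
    import uuid

    from tortoise import Tortoise, fields
    from tortoise.models import Model


    class Event(Model):
        # Non-integer primary key; must be supplied by the caller since only
        # IntField/BigIntField may be database-generated.
        id = fields.UUIDField(pk=True)
        name = fields.TextField()


    class Participant(Model):
        # No explicit pk: an implicit IntField(pk=True) "id" is created.
        # The FK column "event_id" follows the related model's pk type (UUID).
        event = fields.ForeignKeyField("models.Event")


    async def main():
        await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})
        await Tortoise.generate_schemas()

        event = await Event.create(id=uuid.uuid4(), name="demo")
        participant = await Participant.create(event=event)

        assert event.pk == event.id                 # .pk aliases the declared primary key
        assert participant.event_id == event.pk     # FK key field stores the related pk
        assert await Participant.filter(event=event).count() == 1


    asyncio.get_event_loop().run_until_complete(main())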