Skip to content

Commit

Permalink
Generic primary key implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
abondar committed Apr 3, 2019
1 parent d4b32c0 commit 99685a4
Show file tree
Hide file tree
Showing 19 changed files with 640 additions and 160 deletions.
59 changes: 28 additions & 31 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def _init_relations(cls) -> None:
reference = field_object.model_name
related_app_name, related_model_name = reference.split(".")
related_model = cls.apps[related_app_name][related_model_name]

key_field = "{}_id".format(field)
field_object.source_field = key_field
key_field_object = deepcopy(related_model._meta.pk)
key_field_object.pk = False
key_field_object.index = field_object.index
key_field_object.default = field_object.default
key_field_object.null = field_object.null
key_field_object.generated = field_object.generated
key_field_object.reference = field_object
model._meta.add_field(key_field, key_field_object)

field_object.type = related_model
backward_relation_name = field_object.related_name
if not backward_relation_name:
Expand All @@ -59,35 +71,27 @@ def _init_relations(cls) -> None:
)
)
fk_relation = fields.BackwardFKRelation(model, "{}_id".format(field))
setattr(related_model, backward_relation_name, fk_relation)
related_model._meta.filters.update(
get_backward_fk_filters(backward_relation_name, fk_relation)
)

related_model._meta.backward_fk_fields.add(backward_relation_name)
related_model._meta.fetch_fields.add(backward_relation_name)
related_model._meta.fields_map[backward_relation_name] = fk_relation
related_model._meta.fields.add(backward_relation_name)
related_model._meta.add_field(backward_relation_name, fk_relation)

for field in model._meta.m2m_fields:
field_mobject = cast(fields.ManyToManyField, model._meta.fields_map[field])
if field_mobject._generated:
field_object = cast(fields.ManyToManyField, model._meta.fields_map[field])
if field_object._generated:
continue

backward_key = field_mobject.backward_key
backward_key = field_object.backward_key
if not backward_key:
backward_key = "{}_id".format(model._meta.table)
field_mobject.backward_key = backward_key
field_object.backward_key = backward_key

reference = field_mobject.model_name
reference = field_object.model_name
related_app_name, related_model_name = reference.split(".")
related_model = cls.apps[related_app_name][related_model_name]

field_mobject.type = related_model
field_object.type = related_model

backward_relation_name = field_mobject.related_name
backward_relation_name = field_object.related_name
if not backward_relation_name:
backward_relation_name = field_mobject.related_name = "{}_through".format(
backward_relation_name = field_object.related_name = "{}_through".format(
model._meta.table
)
if backward_relation_name in related_model._meta.fields:
Expand All @@ -97,35 +101,28 @@ def _init_relations(cls) -> None:
)
)

if not field_mobject.through:
if not field_object.through:
related_model_table_name = (
related_model._meta.table
if related_model._meta.table
else related_model.__name__.lower()
)

field_mobject.through = "{}_{}".format(
field_object.through = "{}_{}".format(
model._meta.table, related_model_table_name
)

m2m_relation = fields.ManyToManyField(
"{}.{}".format(app_name, model_name),
field_mobject.through,
forward_key=field_mobject.backward_key,
backward_key=field_mobject.forward_key,
field_object.through,
forward_key=field_object.backward_key,
backward_key=field_object.forward_key,
related_name=field,
type=model,
)
m2m_relation._generated = True
setattr(related_model, backward_relation_name, m2m_relation)
model._meta.filters.update(get_m2m_filters(field, field_mobject))
related_model._meta.filters.update(
get_m2m_filters(backward_relation_name, m2m_relation)
)
related_model._meta.m2m_fields.add(backward_relation_name)
related_model._meta.fetch_fields.add(backward_relation_name)
related_model._meta.fields_map[backward_relation_name] = m2m_relation
related_model._meta.fields.add(backward_relation_name)
model._meta.filters.update(get_m2m_filters(field, field_object))
related_model._meta.add_field(backward_relation_name, m2m_relation)

@classmethod
def _discover_client_class(cls, engine: str) -> BaseDBAsyncClient:
Expand Down
9 changes: 7 additions & 2 deletions tortoise/backends/asyncpg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,20 @@ def acquire_connection(self) -> ConnectionWrapper:
return ConnectionWrapper(self._connection, self._lock)

def _in_transaction(self) -> "TransactionWrapper":
return self._transaction_class(self.connection_name, self._connection, self._lock)
return self._transaction_class(
connection_name=self.connection_name,
connection=self._connection,
lock=self._lock,
fetch_inserted=self.fetch_inserted,
)

@translate_exceptions
async def execute_insert(self, query: str, values: list) -> int:
async with self.acquire_connection() as connection:
self.log.debug("%s: %s", query, values)
# TODO: Cache prepared statement
stmt = await connection.prepare(query)
return await stmt.fetchval(*values)
return await stmt.fetch(*values)

@translate_exceptions
async def execute_query(self, query: str) -> List[dict]:
Expand Down
15 changes: 11 additions & 4 deletions tortoise/backends/asyncpg/executor.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from typing import List
from typing import List, Any

from pypika import Parameter, Table

from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor


class AsyncpgExecutor(BaseExecutor):

EXPLAIN_PREFIX = "EXPLAIN (FORMAT JSON, VERBOSE)"

def _prepare_insert_statement(self, columns: List[str]) -> str:
return str(
query = (
self.db.query_class.into(Table(self.model._meta.table))
.columns(*columns)
.insert(*[Parameter("$%d" % (i + 1,)) for i in range(len(columns))])
.returning("id")
)
generated_fields = self.model._meta.generated_db_fields
if generated_fields and self.db.fetch_inserted:
query = query.returning(*generated_fields)
return str(query)

async def _process_insert_result(self, instance: Model, results: Any):
if self.model._meta.generated_db_fields and self.db.fetch_inserted:
instance.set_field_values(dict(results))
20 changes: 20 additions & 0 deletions tortoise/backends/asyncpg/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Optional, Union
from uuid import UUID

from tortoise.fields import JSONField as GenericJSONField, UUIDField as GenericUUIDField


class JSONField(GenericJSONField):
def to_db_value(self, value: Optional[Union[dict, list]], instance):
return value

def to_python_value(self, value: Optional[Union[str, dict, list]]):
return value


class UUIDField(GenericUUIDField):
def to_db_value(self, value: Optional[UUID], instance):
return value

def to_python_value(self, value: Optional[UUID]):
return value
2 changes: 1 addition & 1 deletion tortoise/backends/asyncpg/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class AsyncpgSchemaGenerator(BaseSchemaGenerator):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.FIELD_TYPE_MAP.update({fields.JSONField: "JSONB"})
self.FIELD_TYPE_MAP.update({fields.JSONField: "JSONB", fields.UUIDField: "UUID"})

def _get_primary_key_create_string(self, field_name: str) -> str:
return '"{}" SERIAL NOT NULL PRIMARY KEY'.format(field_name)
3 changes: 2 additions & 1 deletion tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ class BaseDBAsyncClient:
schema_generator = BaseSchemaGenerator
capabilities = Capabilities("")

def __init__(self, connection_name: str, **kwargs) -> None:
def __init__(self, connection_name: str, fetch_inserted: bool = True, **kwargs) -> None:
self.log = logging.getLogger("db_client")
self.connection_name = connection_name
self.fetch_inserted = fetch_inserted

async def create_connection(self, with_db: bool) -> None:
raise NotImplementedError() # pragma: nocoverage
Expand Down
47 changes: 28 additions & 19 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def execute_select(self, query, custom_fields: Optional[list] = None) -> l
raw_results = await self.db.execute_query(query.get_sql())
instance_list = []
for row in raw_results:
instance = self.model(**row)
instance = self.model(_from_db=True, **row)
if custom_fields:
for field in custom_fields:
setattr(instance, field, row[field])
Expand Down Expand Up @@ -65,6 +65,9 @@ def _prepare_insert_statement(self, columns: List[str]) -> str:
# go to descendant executors
raise NotImplementedError() # pragma: nocoverage

async def _process_insert_result(self, instance: "Model", results: Any):
raise NotImplementedError() # pragma: nocoverage

async def execute_insert(self, instance):
key = "{}:{}".format(self.db.connection_name, self.model._meta.table)
if key not in INSERT_CACHE:
Expand All @@ -75,7 +78,8 @@ async def execute_insert(self, instance):
regular_columns, columns, query = INSERT_CACHE[key]

values = self._prepare_insert_values(instance=instance, regular_columns=regular_columns)
instance.id = await self.db.execute_insert(query, values)
insert_result = await self.db.execute_insert(query, values)
await self._process_insert_result(instance, insert_result)
return instance

async def execute_update(self, instance):
Expand All @@ -87,22 +91,22 @@ async def execute_update(self, instance):
query = query.set(
db_field, self._field_to_db(field_object, getattr(instance, field), instance)
)
query = query.where(table.id == instance.id)
query = query.where(getattr(table, self.model._meta.db_pk_field) == instance.pk)
await self.db.execute_query(query.get_sql())
return instance

async def execute_delete(self, instance):
table = Table(self.model._meta.table)
query = self.model._meta.basequery.where(table.id == instance.id).delete()
query = self.model._meta.basequery.where(
getattr(table, self.model._meta.db_pk_field) == instance.pk
).delete()
await self.db.execute_query(query.get_sql())
return instance

async def _prefetch_reverse_relation(
self, instance_list: list, field: str, related_query
) -> list:
instance_id_set = set() # type: Set[int]
for instance in instance_list:
instance_id_set.add(instance.id)
instance_id_set = {instance.pk for instance in instance_list} # type: Set[Any]
backward_relation_manager = getattr(self.model, field)
relation_field = backward_relation_manager.relation_field

Expand All @@ -119,13 +123,11 @@ async def _prefetch_reverse_relation(
related_object_map[object_id] = [entry]
for instance in instance_list:
relation_container = getattr(instance, field)
relation_container._set_result_for_query(related_object_map.get(instance.id, []))
relation_container._set_result_for_query(related_object_map.get(instance.pk, []))
return instance_list

async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_query) -> list:
instance_id_set = set() # type: Set[int]
for instance in instance_list:
instance_id_set.add(instance.id)
instance_id_set = {instance.pk for instance in instance_list} # type: Set[Any]

field_object = self.model._meta.fields_map[field]

Expand All @@ -141,12 +143,13 @@ async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_
)

related_query_table = Table(related_query.model._meta.table)
related_pk_field = related_query.model._meta.db_pk_field
query = (
related_query.query.join(subquery)
.on(subquery._forward_relation_key == related_query_table.id)
.on(subquery._forward_relation_key == getattr(related_query_table, related_pk_field))
.select(
subquery._backward_relation_key.as_("_backward_relation_key"),
*[getattr(related_query_table, field).as_(field) for field in related_query.fields]
*[getattr(related_query_table, field).as_(field) for field in related_query.fields],
)
)

Expand All @@ -173,12 +176,18 @@ async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_
query = query.having(having_criterion)

raw_results = await self.db.execute_query(query.get_sql())
relations = {(e["_backward_relation_key"], e["id"]) for e in raw_results}
related_object_list = [related_query.model(**e) for e in raw_results]
relations = {
(
self.model._meta.pk.to_python_value(e["_backward_relation_key"]),
field_object.type._meta.pk.to_python_value(e[related_pk_field]),
)
for e in raw_results
}
related_object_list = [related_query.model(_from_db=True, **e) for e in raw_results]
await self.__class__(
model=related_query.model, db=self.db, prefetch_map=related_query._prefetch_map
).fetch_for_list(related_object_list)
related_object_map = {e.id: e for e in related_object_list}
related_object_map = {e.pk: e for e in related_object_list}
relation_map = {} # type: Dict[str, list]

for object_id, related_object_id in relations:
Expand All @@ -188,7 +197,7 @@ async def _prefetch_m2m_relation(self, instance_list: list, field: str, related_

for instance in instance_list:
relation_container = getattr(instance, field)
relation_container._set_result_for_query(relation_map.get(instance.id, []))
relation_container._set_result_for_query(relation_map.get(instance.pk, []))
return instance_list

async def _prefetch_direct_relation(
Expand All @@ -200,8 +209,8 @@ async def _prefetch_direct_relation(
if getattr(instance, relation_key_field):
related_objects_for_fetch.add(getattr(instance, relation_key_field))
if related_objects_for_fetch:
related_object_list = await related_query.filter(id__in=list(related_objects_for_fetch))
related_object_map = {obj.id: obj for obj in related_object_list}
related_object_list = await related_query.filter(pk__in=list(related_objects_for_fetch))
related_object_map = {obj.pk: obj for obj in related_object_list}
for instance in instance_list:
setattr(
instance, field, related_object_map.get(getattr(instance, relation_key_field))
Expand Down
16 changes: 14 additions & 2 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,24 @@ class BaseSchemaGenerator:
fields.DateField: "DATE",
fields.FloatField: "DOUBLE PRECISION",
fields.JSONField: "TEXT",
fields.UUIDField: "CHAR(36)",
}

def __init__(self, client) -> None:
self.client = client

def _create_string(self, db_field: str, field_type: str, nullable: str, unique: str) -> str:
def _create_string(
self, db_field: str, field_type: str, nullable: str, unique: str, is_pk: bool
) -> str:
# children can override this function to customize thier sql queries

field_creation_string = self.FIELD_TEMPLATE.format(
name=db_field, type=field_type, nullable=nullable, unique=unique
).strip()

if is_pk:
field_creation_string += " PRIMARY KEY"

return field_creation_string

def _get_primary_key_create_string(self, field_name: str) -> str:
Expand Down Expand Up @@ -105,7 +111,13 @@ def _get_table_sql(self, model, safe=True) -> dict:
elif isinstance(field_object, fields.CharField):
field_type = field_type.format(field_object.max_length)

field_creation_string = self._create_string(db_field, field_type, nullable, unique)
field_creation_string = self._create_string(
db_field=db_field,
field_type=field_type,
nullable=nullable,
unique=unique,
is_pk=field_object.pk,
)

if hasattr(field_object, "reference") and field_object.reference:
field_creation_string += self.FK_TEMPLATE.format(
Expand Down
Loading

0 comments on commit 99685a4

Please sign in to comment.