Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize field conversion to database format #1840

Merged
merged 7 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ Changelog

.. rst-class:: emphasize-children

0.23
0.24
====

0.23.1
0.24.0 (unreleased)
------
Fixed
^^^^^
- Rename pypika to pypika_tortoise for fixing package name conflict (#1829)
- Concurrent connection pool initialization (#1825)
Changed
^^^^^^^
- Optimize field conversion to database format to speed up `create` and `bulk_create` (#1840)

0.23.0
------
Expand Down
8 changes: 8 additions & 0 deletions tests/fields/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ async def test_empty(self):


class TestDatetimeFields(TestEmpty):
async def asyncSetUp(self):
await super().asyncSetUp()
timezone._reset_timezone_cache()

async def asyncTearDown(self):
await super().asyncTearDown()
timezone._reset_timezone_cache()

def test_both_auto_bad(self):
with self.assertRaisesRegex(
ConfigurationError, "You can choose only 'auto_now' or 'auto_now_add'"
Expand Down
2 changes: 2 additions & 0 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tortoise.filters import get_m2m_filters
from tortoise.log import logger
from tortoise.models import Model, ModelMeta
from tortoise.timezone import _reset_timezone_cache
from tortoise.utils import generate_schema_for_client


Expand Down Expand Up @@ -614,6 +615,7 @@ async def _drop_databases(cls) -> None:
def _init_timezone(cls, use_tz: bool, timezone: str) -> None:
os.environ["USE_TZ"] = str(use_tz)
os.environ["TIMEZONE"] = timezone
_reset_timezone_cache()


def run_async(coro: Coroutine) -> None:
Expand Down
35 changes: 14 additions & 21 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import datetime
import decimal
from copy import copy
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -24,7 +23,6 @@

from tortoise.exceptions import OperationalError
from tortoise.expressions import Expression, ResolveContext
from tortoise.fields.base import Field
from tortoise.fields.relational import (
BackwardFKRelation,
BackwardOneToOneRelation,
Expand All @@ -42,12 +40,11 @@

EXECUTOR_CACHE: Dict[
Tuple[str, Optional[str], str],
Tuple[list, str, list, str, Dict[str, Callable], str, Dict[str, str]],
Tuple[list, str, list, str, str, Dict[str, str]],
] = {}


class BaseExecutor:
TO_DB_OVERRIDE: Dict[Type[Field], Callable] = {}
FILTER_FUNC_OVERRIDE: Dict[Callable, Callable] = {}
EXPLAIN_PREFIX: str = "EXPLAIN"
DB_NATIVE = {bytes, str, int, float, decimal.Decimal, datetime.datetime, datetime.date}
Expand Down Expand Up @@ -81,16 +78,6 @@ def __init__(
self._prepare_insert_statement(columns_all, has_generated=False)
)

self.column_map: Dict[str, Callable[[Any, Any], Any]] = {}
for column in self.regular_columns_all:
field_object = self.model._meta.fields_map[column]
if field_object.__class__ in self.TO_DB_OVERRIDE:
self.column_map[column] = partial(
self.TO_DB_OVERRIDE[field_object.__class__], field_object
)
else:
self.column_map[column] = field_object.to_db_value

table = self.model._meta.basetable
basequery = cast(QueryBuilder, self.model._meta.basequery)
self.delete_query = str(
Expand All @@ -103,7 +90,6 @@ def __init__(
self.insert_query,
self.regular_columns_all,
self.insert_query_all,
self.column_map,
self.delete_query,
self.update_cache,
)
Expand All @@ -114,7 +100,6 @@ def __init__(
self.insert_query,
self.regular_columns_all,
self.insert_query_all,
self.column_map,
self.delete_query,
self.update_cache,
) = EXECUTOR_CACHE[key]
Expand Down Expand Up @@ -194,15 +179,19 @@ def parameter(self, pos: int) -> Parameter:
async def execute_insert(self, instance: "Model") -> None:
if not instance._custom_generated_pk:
values = [
self.column_map[field_name](getattr(instance, field_name), instance)
self.model._meta.fields_map[field_name].to_db_value(
getattr(instance, field_name), instance
)
for field_name in self.regular_columns
]
insert_result = await self.db.execute_insert(self.insert_query, values)
await self._process_insert_result(instance, insert_result)

else:
values = [
self.column_map[field_name](getattr(instance, field_name), instance)
self.model._meta.fields_map[field_name].to_db_value(
getattr(instance, field_name), instance
)
for field_name in self.regular_columns_all
]
await self.db.execute_insert(self.insert_query_all, values)
Expand All @@ -219,14 +208,18 @@ async def execute_bulk_insert(
if instance._custom_generated_pk:
values_lists_all.append(
[
self.column_map[field_name](getattr(instance, field_name), instance)
self.model._meta.fields_map[field_name].to_db_value(
getattr(instance, field_name), instance
)
for field_name in self.regular_columns_all
]
)
else:
values_lists.append(
[
self.column_map[field_name](getattr(instance, field_name), instance)
self.model._meta.fields_map[field_name].to_db_value(
getattr(instance, field_name), instance
)
for field_name in self.regular_columns
]
)
Expand Down Expand Up @@ -292,7 +285,7 @@ async def execute_update(
if isinstance(instance_field, Expression):
expressions[field] = instance_field
else:
value = self.column_map[field](instance_field, instance)
value = self.model._meta.fields_map[field].to_db_value(instance_field, instance)
values.append(value)
values.append(self.model._meta.pk.to_db_value(instance.pk, instance))
return (
Expand Down
17 changes: 1 addition & 16 deletions tortoise/backends/mssql/executor.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,9 @@
from typing import Any, Optional, Type, Union
from typing import Any

from tortoise import Model, fields
from tortoise.backends.odbc.executor import ODBCExecutor
from tortoise.exceptions import UnSupportedError
from tortoise.fields import BooleanField


def to_db_bool(
self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model]
) -> Optional[int]:
self.validate(value)
if value is None:
return None
return int(bool(value))


class MSSQLExecutor(ODBCExecutor):
TO_DB_OVERRIDE = {
fields.BooleanField: to_db_bool,
}

async def execute_explain(self, sql: str) -> Any:
raise UnSupportedError("MSSQL does not support explain")
85 changes: 5 additions & 80 deletions tortoise/backends/sqlite/executor.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,24 @@
import datetime
import sqlite3
from decimal import Decimal
from typing import Optional, Type, Union

import pytz

from tortoise import Model, fields, timezone
from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor
from tortoise.contrib.sqlite.regex import (
insensitive_posix_sqlite_regexp,
posix_sqlite_regexp,
)
from tortoise.fields import (
BigIntField,
BooleanField,
DatetimeField,
DecimalField,
IntField,
SmallIntField,
TimeField,
)
from tortoise.fields import BigIntField, IntField, SmallIntField
from tortoise.filters import insensitive_posix_regex, posix_regex


def to_db_bool(
self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model]
) -> Optional[int]:
self.validate(value)
if value is None:
return None
return int(bool(value))


def to_db_decimal(
self: DecimalField,
value: Optional[Union[str, float, int, Decimal]],
instance: Union[Type[Model], Model],
) -> Optional[str]:
self.validate(value)
if value is None:
return None
return str(Decimal(value).quantize(self.quant).normalize())


def to_db_datetime(
self: DatetimeField, value: Optional[datetime.datetime], instance: Union[Type[Model], Model]
) -> Optional[str]:
self.validate(value)
# Only do this if it is a Model instance, not class. Test for guaranteed instance var
if hasattr(instance, "_saved_in_db") and (
self.auto_now
or (self.auto_now_add and getattr(instance, self.model_field_name, None) is None)
):
if timezone.get_use_tz():
value = datetime.datetime.now(tz=pytz.utc)
else:
value = datetime.datetime.now(tz=timezone.get_default_timezone())
setattr(instance, self.model_field_name, value)
return value.isoformat(" ")
if isinstance(value, datetime.datetime):
return value.isoformat(" ")
return None


def to_db_time(
self: TimeField, value: Optional[datetime.time], instance: Union[Type[Model], Model]
) -> Optional[str]:
self.validate(value)
if hasattr(instance, "_saved_in_db") and (
self.auto_now
or (self.auto_now_add and getattr(instance, self.model_field_name, None) is None)
):
if timezone.get_use_tz():
value = datetime.datetime.now(tz=pytz.utc).time()
else:
value = datetime.datetime.now(tz=timezone.get_default_timezone()).time()
setattr(instance, self.model_field_name, value)
return value.isoformat()
if isinstance(value, datetime.time):
return value.isoformat()
return None


# Converts Decimal to string for sqlite in cases where it's hard to know the
# Conversion for the cases where it's hard to know the
# related field, e.g. in raw queries, math or annotations.
sqlite3.register_adapter(Decimal, str)
sqlite3.register_adapter(datetime.date, lambda val: val.isoformat())
sqlite3.register_adapter(datetime.datetime, lambda val: val.isoformat(" "))


class SqliteExecutor(BaseExecutor):
TO_DB_OVERRIDE = {
fields.BooleanField: to_db_bool,
fields.DecimalField: to_db_decimal,
fields.DatetimeField: to_db_datetime,
fields.TimeField: to_db_time,
}
EXPLAIN_PREFIX = "EXPLAIN QUERY PLAN"
DB_NATIVE = {bytes, str, int, float}
FILTER_FUNC_OVERRIDE = {
Expand Down
5 changes: 0 additions & 5 deletions tortoise/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,6 @@ def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Any:
if value is not None and not isinstance(value, self.field_type):
value = self.field_type(value) # pylint: disable=E1102

if self.__class__ in self.model._meta.db.executor_class.TO_DB_OVERRIDE:
value = self.model._meta.db.executor_class.TO_DB_OVERRIDE[self.__class__](
self, value, instance
)

self.validate(value)
return value

Expand Down
14 changes: 6 additions & 8 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,8 +1195,6 @@ def _make_query(self) -> None:
self.resolve_ordering(self.model, table, self._orderings, self._annotations)

self.resolve_filters()
# Need to get executor to get correct column_map
executor = self._db.executor_class(model=self.model, db=self._db)
for key, value in self.update_kwargs.items():
field_object = self.model._meta.fields_map.get(key)
if not field_object:
Expand All @@ -1207,7 +1205,7 @@ def _make_query(self) -> None:
self.model._validate_relation_type(key, value)
fk_field: str = field_object.source_field # type: ignore
db_field = self.model._meta.fields_map[fk_field].source_field
value = executor.column_map[fk_field](
value = self.model._meta.fields_map[fk_field].to_db_value(
getattr(value, field_object.to_field_instance.model_field_name),
None,
)
Expand All @@ -1227,7 +1225,7 @@ def _make_query(self) -> None:
)
).term
else:
value = executor.column_map[key](value, None)
value = self.model._meta.fields_map[key].to_db_value(value, None)

self.query = self.query.set(db_field, value)

Expand Down Expand Up @@ -1838,7 +1836,6 @@ def _make_queries(self) -> List[Tuple[str, List[Any]]]:
)

self.resolve_filters()
executor = self._db.executor_class(model=self.model, db=self._db)
pk_attr = self.model._meta.pk_attr
source_pk_attr = self.model._meta.fields_map[pk_attr].source_field or pk_attr
pk = Field(source_pk_attr)
Expand All @@ -1848,7 +1845,7 @@ def _make_queries(self) -> List[Tuple[str, List[Any]]]:
case = Case()
pk_list = []
for obj in objects_item:
pk_value = executor.column_map[pk_attr](obj.pk, None)
pk_value = self.model._meta.fields_map[pk_attr].to_db_value(obj.pk, None)
field_obj = obj._meta.fields_map[field]
field_value = field_obj.to_db_value(getattr(obj, field), obj)
case.when(
Expand Down Expand Up @@ -1945,14 +1942,15 @@ def _make_queries(self) -> Tuple[str, str]:
return self._executor.insert_query, self._executor.insert_query_all

async def _execute_many(self, insert_sql: str, insert_sql_all: str) -> None:
fields_map = self.model._meta.fields_map
for instance_chunk in chunk(self._objects, self._batch_size):
values_lists_all = []
values_lists = []
for instance in instance_chunk:
if instance._custom_generated_pk:
values_lists_all.append(
[
self._executor.column_map[field_name](
fields_map[field_name].to_db_value(
getattr(instance, field_name), instance
)
for field_name in self._executor.regular_columns_all
Expand All @@ -1961,7 +1959,7 @@ async def _execute_many(self, insert_sql: str, insert_sql_all: str) -> None:
else:
values_lists.append(
[
self._executor.column_map[field_name](
fields_map[field_name].to_db_value(
getattr(instance, field_name), instance
)
for field_name in self._executor.regular_columns
Expand Down
Loading
Loading