Skip to content

Commit

Permalink
Make fields of querysets private
Browse files Browse the repository at this point in the history
  • Loading branch information
henadzit committed Oct 28, 2024
1 parent fb4b234 commit 7af50c7
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 60 deletions.
1 change: 1 addition & 0 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ async def test_order_by(self):
tournaments = await base_query.order_by("-name")
self.assertEqual(tournaments, [b, a])

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_values_with_annotations(self):
await Tournament.create(name="Championship")
await Tournament.create(name="Super Bowl")
Expand Down
120 changes: 60 additions & 60 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def values(

class AwaitableQuery(Generic[MODEL]):
__slots__ = (
"_joined_tables",
"query",
"model",
"_joined_tables",
"_db",
"capabilities",
"_annotations",
Expand Down Expand Up @@ -1452,8 +1452,8 @@ class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]):
"_orderings",
"_single",
"_raise_does_not_exist",
"fields_for_select_list",
"flat",
"_fields_for_select_list",
"_flat",
"_group_bys",
"_force_indexes",
"_use_indexes",
Expand Down Expand Up @@ -1492,8 +1492,8 @@ def __init__(
self._q_objects = q_objects
self._single = single
self._raise_does_not_exist = raise_does_not_exist
self.fields_for_select_list = fields_for_select_list
self.flat = flat
self._fields_for_select_list = fields_for_select_list
self._flat = flat
self._db = db
self._group_bys = group_bys
self._force_indexes = force_indexes
Expand Down Expand Up @@ -1555,7 +1555,7 @@ async def _execute(self) -> Union[List[Any], Tuple]:
(key, self.resolve_to_python_value(self.model, name))
for key, name in self.fields.items()
]
if self.flat:
if self._flat:
func = columns[0][1]
flatmap = lambda entry: func(entry["0"]) # noqa
lst_values = list(map(flatmap, result))
Expand All @@ -1576,7 +1576,7 @@ async def _execute(self) -> Union[List[Any], Tuple]:

class ValuesQuery(FieldSelectQuery, Generic[SINGLE]):
__slots__ = (
"fields_for_select",
"_fields_for_select",
"_limit",
"_offset",
"_distinct",
Expand Down Expand Up @@ -1607,7 +1607,7 @@ def __init__(
use_indexes: Set[str],
) -> None:
super().__init__(model, annotations)
self.fields_for_select = fields_for_select
self._fields_for_select = fields_for_select
self._limit = limit
self._offset = offset
self._distinct = distinct
Expand All @@ -1625,7 +1625,7 @@ def _make_query(self) -> None:
self._joined_tables = []

self.query = copy(self.model._meta.basequery)
for return_as, field in self.fields_for_select.items():
for return_as, field in self._fields_for_select.items():
self.add_field_to_select_query(field, return_as)

self.resolve_ordering(
Expand All @@ -1638,7 +1638,7 @@ def _make_query(self) -> None:

# remove annotations that are not in fields_for_select
self.query._selects = [
select for select in self.query._selects if select.alias in self.fields_for_select
select for select in self.query._selects if select.alias in self._fields_for_select
]

if self._limit:
Expand Down Expand Up @@ -1685,7 +1685,7 @@ async def _execute(self) -> Union[List[dict], Dict]:
val
for val in [
(alias, self.resolve_to_python_value(self.model, field_name))
for alias, field_name in self.fields_for_select.items()
for alias, field_name in self._fields_for_select.items()
]
if not isinstance(val[1], types.LambdaType)
]
Expand Down Expand Up @@ -1732,7 +1732,7 @@ def __await__(self) -> Generator[Any, None, List[MODEL]]:


class BulkUpdateQuery(UpdateQuery, Generic[MODEL]):
__slots__ = ("objects", "fields", "batch_size", "queries")
__slots__ = ("fields", "_objects", "_batch_size", "_queries")

def __init__(
self,
Expand All @@ -1757,10 +1757,10 @@ def __init__(
limit=limit,
orderings=orderings,
)
self.objects = objects
self.fields = fields
self.batch_size = batch_size
self.queries: List[QueryBuilder] = []
self._objects = objects
self._batch_size = batch_size
self._queries: List[QueryBuilder] = []

def _make_query(self) -> None:
table = self.model._meta.basetable
Expand All @@ -1779,7 +1779,7 @@ def _make_query(self) -> None:
pk_attr = self.model._meta.pk_attr
source_pk_attr = self.model._meta.fields_map[pk_attr].source_field or pk_attr
pk = Field(source_pk_attr)
for objects_item in chunk(self.objects, self.batch_size):
for objects_item in chunk(self._objects, self._batch_size):
query = copy(self.query)
for field in self.fields:
case = Case()
Expand All @@ -1803,30 +1803,30 @@ def _make_query(self) -> None:
pk_list.append(value)
query = query.set(field, case)
query = query.where(pk.isin(pk_list))
self.queries.append(query)
self._queries.append(query)

async def _execute(self) -> int:
count = 0
for query in self.queries:
for query in self._queries:
count += (await self._db.execute_query(str(query)))[0]
return count

def sql(self, **kwargs) -> str:
self.as_query()
return ";".join([str(query) for query in self.queries])
return ";".join([str(query) for query in self._queries])


class BulkCreateQuery(AwaitableQuery, Generic[MODEL]):
__slots__ = (
"objects",
"ignore_conflicts",
"batch_size",
"_objects",
"_ignore_conflicts",
"_batch_size",
"_db",
"executor",
"insert_query",
"insert_query_all",
"update_fields",
"on_conflict",
"_executor",
"_insert_query",
"_insert_query_all",
"_update_fields",
"_on_conflict",
)

def __init__(
Expand All @@ -1840,70 +1840,70 @@ def __init__(
on_conflict: Optional[Iterable[str]] = None,
):
super().__init__(model)
self.objects = objects
self.ignore_conflicts = ignore_conflicts
self.batch_size = batch_size
self._objects = objects
self._ignore_conflicts = ignore_conflicts
self._batch_size = batch_size
self._db = db
self.update_fields = update_fields
self.on_conflict = on_conflict
self._update_fields = update_fields
self._on_conflict = on_conflict

def _make_query(self) -> None:
self.executor = self._db.executor_class(model=self.model, db=self._db)
if self.ignore_conflicts or self.update_fields:
regular_columns, columns = self.executor._prepare_insert_columns()
self.insert_query = self.executor._prepare_insert_statement(
columns, ignore_conflicts=self.ignore_conflicts
self._executor = self._db.executor_class(model=self.model, db=self._db)
if self._ignore_conflicts or self._update_fields:
regular_columns, columns = self._executor._prepare_insert_columns()
self._insert_query = self._executor._prepare_insert_statement(
columns, ignore_conflicts=self._ignore_conflicts
)
self.insert_query_all = self.insert_query
self._insert_query_all = self._insert_query
if self.model._meta.generated_db_fields:
regular_columns_all, columns_all = self.executor._prepare_insert_columns(
regular_columns_all, columns_all = self._executor._prepare_insert_columns(
include_generated=True
)
self.insert_query_all = self.executor._prepare_insert_statement(
self._insert_query_all = self._executor._prepare_insert_statement(
columns_all,
has_generated=False,
ignore_conflicts=self.ignore_conflicts,
ignore_conflicts=self._ignore_conflicts,
)
if self.update_fields:
if self._update_fields:
alias = f"new_{self.model._meta.db_table}"
self.insert_query_all = self.insert_query_all.as_(alias).on_conflict(
*self.on_conflict
self._insert_query_all = self._insert_query_all.as_(alias).on_conflict(
*self._on_conflict
)
self.insert_query = self.insert_query.as_(alias).on_conflict(*self.on_conflict)
for update_field in self.update_fields:
self.insert_query_all = self.insert_query_all.do_update(update_field)
self.insert_query = self.insert_query.do_update(update_field)
self._insert_query = self._insert_query.as_(alias).on_conflict(*self._on_conflict)
for update_field in self._update_fields:
self._insert_query_all = self._insert_query_all.do_update(update_field)
self._insert_query = self._insert_query.do_update(update_field)
else:
self.insert_query_all = self.executor.insert_query_all
self.insert_query = self.executor.insert_query
self._insert_query_all = self._executor.insert_query_all
self._insert_query = self._executor.insert_query

async def _execute(self) -> None:
for instance_chunk in chunk(self.objects, self.batch_size):
for instance_chunk in chunk(self._objects, self._batch_size):
values_lists_all = []
values_lists = []
for instance in instance_chunk:
if instance._custom_generated_pk:
values_lists_all.append(
[
self.executor.column_map[field_name](
self._executor.column_map[field_name](
getattr(instance, field_name), instance
)
for field_name in self.executor.regular_columns_all
for field_name in self._executor.regular_columns_all
]
)
else:
values_lists.append(
[
self.executor.column_map[field_name](
self._executor.column_map[field_name](
getattr(instance, field_name), instance
)
for field_name in self.executor.regular_columns
for field_name in self._executor.regular_columns
]
)
if values_lists_all:
await self._db.execute_many(str(self.insert_query_all), values_lists_all)
await self._db.execute_many(str(self._insert_query_all), values_lists_all)
if values_lists:
await self._db.execute_many(str(self.insert_query), values_lists)
await self._db.execute_many(str(self._insert_query), values_lists)

def __await__(self) -> Generator[Any, None, None]:
if self._db is None:
Expand All @@ -1913,6 +1913,6 @@ def __await__(self) -> Generator[Any, None, None]:

def sql(self, **kwargs) -> str:
self.as_query()
if self.insert_query and self.insert_query_all:
return ";".join([str(self.insert_query), str(self.insert_query_all)])
return str(self.insert_query or self.insert_query_all)
if self._insert_query and self._insert_query_all:
return ";".join([str(self._insert_query), str(self._insert_query_all)])
return str(self._insert_query or self._insert_query_all)

0 comments on commit 7af50c7

Please sign in to comment.