From 7af50c7324a7d39b42255b07fb36be1cc89b35b3 Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 28 Oct 2024 17:18:07 +0100 Subject: [PATCH] Make fields of querysets private --- tests/test_queryset.py | 1 + tortoise/queryset.py | 120 ++++++++++++++++++++--------------------- 2 files changed, 61 insertions(+), 60 deletions(-) diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 5559322f4..49da96e85 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -759,6 +759,7 @@ async def test_order_by(self): tournaments = await base_query.order_by("-name") self.assertEqual(tournaments, [b, a]) + @test.requireCapability(dialect=NotEQ("mssql")) async def test_values_with_annotations(self): await Tournament.create(name="Championship") await Tournament.create(name="Super Bowl") diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 2919ca17f..287fd5474 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -88,9 +88,9 @@ def values( class AwaitableQuery(Generic[MODEL]): __slots__ = ( - "_joined_tables", "query", "model", + "_joined_tables", "_db", "capabilities", "_annotations", @@ -1452,8 +1452,8 @@ class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): "_orderings", "_single", "_raise_does_not_exist", - "fields_for_select_list", - "flat", + "_fields_for_select_list", + "_flat", "_group_bys", "_force_indexes", "_use_indexes", @@ -1492,8 +1492,8 @@ def __init__( self._q_objects = q_objects self._single = single self._raise_does_not_exist = raise_does_not_exist - self.fields_for_select_list = fields_for_select_list - self.flat = flat + self._fields_for_select_list = fields_for_select_list + self._flat = flat self._db = db self._group_bys = group_bys self._force_indexes = force_indexes @@ -1555,7 +1555,7 @@ async def _execute(self) -> Union[List[Any], Tuple]: (key, self.resolve_to_python_value(self.model, name)) for key, name in self.fields.items() ] - if self.flat: + if self._flat: func = columns[0][1] flatmap = lambda entry: func(entry["0"]) # noqa lst_values = list(map(flatmap, result)) @@ -1576,7 +1576,7 @@ async def _execute(self) -> Union[List[Any], Tuple]: class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( - "fields_for_select", + "_fields_for_select", "_limit", "_offset", "_distinct", @@ -1607,7 +1607,7 @@ def __init__( use_indexes: Set[str], ) -> None: super().__init__(model, annotations) - self.fields_for_select = fields_for_select + self._fields_for_select = fields_for_select self._limit = limit self._offset = offset self._distinct = distinct @@ -1625,7 +1625,7 @@ def _make_query(self) -> None: self._joined_tables = [] self.query = copy(self.model._meta.basequery) - for return_as, field in self.fields_for_select.items(): + for return_as, field in self._fields_for_select.items(): self.add_field_to_select_query(field, return_as) self.resolve_ordering( @@ -1638,7 +1638,7 @@ def _make_query(self) -> None: # remove annotations that are not in fields_for_select self.query._selects = [ - select for select in self.query._selects if select.alias in self.fields_for_select + select for select in self.query._selects if select.alias in self._fields_for_select ] if self._limit: @@ -1685,7 +1685,7 @@ async def _execute(self) -> Union[List[dict], Dict]: val for val in [ (alias, self.resolve_to_python_value(self.model, field_name)) - for alias, field_name in self.fields_for_select.items() + for alias, field_name in self._fields_for_select.items() ] if not isinstance(val[1], types.LambdaType) ] @@ -1732,7 +1732,7 @@ def __await__(self) -> Generator[Any, None, List[MODEL]]: class BulkUpdateQuery(UpdateQuery, Generic[MODEL]): - __slots__ = ("objects", "fields", "batch_size", "queries") + __slots__ = ("fields", "_objects", "_batch_size", "_queries") def __init__( self, @@ -1757,10 +1757,10 @@ def __init__( limit=limit, orderings=orderings, ) - self.objects = objects self.fields = fields - self.batch_size = batch_size - self.queries: List[QueryBuilder] = [] + self._objects = objects + self._batch_size = batch_size + self._queries: List[QueryBuilder] = [] def _make_query(self) -> None: table = self.model._meta.basetable @@ -1779,7 +1779,7 @@ def _make_query(self) -> None: pk_attr = self.model._meta.pk_attr source_pk_attr = self.model._meta.fields_map[pk_attr].source_field or pk_attr pk = Field(source_pk_attr) - for objects_item in chunk(self.objects, self.batch_size): + for objects_item in chunk(self._objects, self._batch_size): query = copy(self.query) for field in self.fields: case = Case() @@ -1803,30 +1803,30 @@ def _make_query(self) -> None: pk_list.append(value) query = query.set(field, case) query = query.where(pk.isin(pk_list)) - self.queries.append(query) + self._queries.append(query) async def _execute(self) -> int: count = 0 - for query in self.queries: + for query in self._queries: count += (await self._db.execute_query(str(query)))[0] return count def sql(self, **kwargs) -> str: self.as_query() - return ";".join([str(query) for query in self.queries]) + return ";".join([str(query) for query in self._queries]) class BulkCreateQuery(AwaitableQuery, Generic[MODEL]): __slots__ = ( - "objects", - "ignore_conflicts", - "batch_size", + "_objects", + "_ignore_conflicts", + "_batch_size", "_db", - "executor", - "insert_query", - "insert_query_all", - "update_fields", - "on_conflict", + "_executor", + "_insert_query", + "_insert_query_all", + "_update_fields", + "_on_conflict", ) def __init__( @@ -1840,70 +1840,70 @@ def __init__( on_conflict: Optional[Iterable[str]] = None, ): super().__init__(model) - self.objects = objects - self.ignore_conflicts = ignore_conflicts - self.batch_size = batch_size + self._objects = objects + self._ignore_conflicts = ignore_conflicts + self._batch_size = batch_size self._db = db - self.update_fields = update_fields - self.on_conflict = on_conflict + self._update_fields = update_fields + self._on_conflict = on_conflict def _make_query(self) -> None: - self.executor = self._db.executor_class(model=self.model, db=self._db) - if self.ignore_conflicts or self.update_fields: - regular_columns, columns = self.executor._prepare_insert_columns() - self.insert_query = self.executor._prepare_insert_statement( - columns, ignore_conflicts=self.ignore_conflicts + self._executor = self._db.executor_class(model=self.model, db=self._db) + if self._ignore_conflicts or self._update_fields: + regular_columns, columns = self._executor._prepare_insert_columns() + self._insert_query = self._executor._prepare_insert_statement( + columns, ignore_conflicts=self._ignore_conflicts ) - self.insert_query_all = self.insert_query + self._insert_query_all = self._insert_query if self.model._meta.generated_db_fields: - regular_columns_all, columns_all = self.executor._prepare_insert_columns( + regular_columns_all, columns_all = self._executor._prepare_insert_columns( include_generated=True ) - self.insert_query_all = self.executor._prepare_insert_statement( + self._insert_query_all = self._executor._prepare_insert_statement( columns_all, has_generated=False, - ignore_conflicts=self.ignore_conflicts, + ignore_conflicts=self._ignore_conflicts, ) - if self.update_fields: + if self._update_fields: alias = f"new_{self.model._meta.db_table}" - self.insert_query_all = self.insert_query_all.as_(alias).on_conflict( - *self.on_conflict + self._insert_query_all = self._insert_query_all.as_(alias).on_conflict( + *self._on_conflict ) - self.insert_query = self.insert_query.as_(alias).on_conflict(*self.on_conflict) - for update_field in self.update_fields: - self.insert_query_all = self.insert_query_all.do_update(update_field) - self.insert_query = self.insert_query.do_update(update_field) + self._insert_query = self._insert_query.as_(alias).on_conflict(*self._on_conflict) + for update_field in self._update_fields: + self._insert_query_all = self._insert_query_all.do_update(update_field) + self._insert_query = self._insert_query.do_update(update_field) else: - self.insert_query_all = self.executor.insert_query_all - self.insert_query = self.executor.insert_query + self._insert_query_all = self._executor.insert_query_all + self._insert_query = self._executor.insert_query async def _execute(self) -> None: - for instance_chunk in chunk(self.objects, self.batch_size): + for instance_chunk in chunk(self._objects, self._batch_size): values_lists_all = [] values_lists = [] for instance in instance_chunk: if instance._custom_generated_pk: values_lists_all.append( [ - self.executor.column_map[field_name]( + self._executor.column_map[field_name]( getattr(instance, field_name), instance ) - for field_name in self.executor.regular_columns_all + for field_name in self._executor.regular_columns_all ] ) else: values_lists.append( [ - self.executor.column_map[field_name]( + self._executor.column_map[field_name]( getattr(instance, field_name), instance ) - for field_name in self.executor.regular_columns + for field_name in self._executor.regular_columns ] ) if values_lists_all: - await self._db.execute_many(str(self.insert_query_all), values_lists_all) + await self._db.execute_many(str(self._insert_query_all), values_lists_all) if values_lists: - await self._db.execute_many(str(self.insert_query), values_lists) + await self._db.execute_many(str(self._insert_query), values_lists) def __await__(self) -> Generator[Any, None, None]: if self._db is None: @@ -1913,6 +1913,6 @@ def __await__(self) -> Generator[Any, None, None]: def sql(self, **kwargs) -> str: self.as_query() - if self.insert_query and self.insert_query_all: - return ";".join([str(self.insert_query), str(self.insert_query_all)]) - return str(self.insert_query or self.insert_query_all) + if self._insert_query and self._insert_query_all: + return ";".join([str(self._insert_query), str(self._insert_query_all)]) + return str(self._insert_query or self._insert_query_all)