From 9fa413fa85e7d53761d2d46899157f7937ddf8f7 Mon Sep 17 00:00:00 2001 From: henadzit Date: Thu, 14 Nov 2024 12:57:31 +0100 Subject: [PATCH 1/8] Parametrize SELECT queries --- pyproject.toml | 2 +- tests/contrib/test_functions.py | 2 +- tests/test_case_when.py | 122 ++++++++++++++++------------ tests/test_fuzz.py | 15 ++++ tests/test_model_methods.py | 6 +- tests/test_queryset.py | 42 +++++----- tests/test_sql.py | 64 +++++++++++++++ tests/test_values.py | 2 +- tortoise/backends/base/executor.py | 9 +- tortoise/backends/mysql/executor.py | 2 +- tortoise/contrib/mysql/functions.py | 4 +- tortoise/expressions.py | 3 - tortoise/queryset.py | 44 +++++----- 13 files changed, 207 insertions(+), 110 deletions(-) create mode 100644 tests/test_sql.py diff --git a/pyproject.toml b/pyproject.toml index 734975f69..21503251f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.8" -pypika-tortoise = "^0.2.2" +pypika-tortoise = { git = "https://github.com/henadzit/pypika-tortoise.git", branch = "parameterization-changes" } iso8601 = "^2.1.0" aiosqlite = ">=0.16.0, <0.21.0" pytz = "*" diff --git a/tests/contrib/test_functions.py b/tests/contrib/test_functions.py index 6f4e2c1ff..276dd360e 100644 --- a/tests/contrib/test_functions.py +++ b/tests/contrib/test_functions.py @@ -21,7 +21,7 @@ async def test_mysql_func_rand(self): @test.requireCapability(dialect="mysql") async def test_mysql_func_rand_with_seed(self): sql = IntFields.all().annotate(randnum=Rand(0)).values("intnum", "randnum").sql() - expected_sql = "SELECT `intnum` `intnum`,RAND(0) `randnum` FROM `intfields`" + expected_sql = "SELECT `intnum` `intnum`,RAND(%s) `randnum` FROM `intfields`" self.assertEqual(sql, expected_sql) @test.requireCapability(dialect="postgres") diff --git a/tests/test_case_when.py b/tests/test_case_when.py index b2f310dde..197a1ab29 100644 --- a/tests/test_case_when.py +++ b/tests/test_case_when.py @@ -11,16 +11,18 @@ async def asyncSetUp(self): await super().asyncSetUp() self.intfields = [await IntFields.create(intnum=val) for val in range(10)] self.db = connections.get("models") + self.dialect = self.db.schema_generator.DIALECT async def test_single_when(self): category = Case(When(intnum__gte=8, then="big"), default="default") sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' ELSE 'default' END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE %s END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE $3 END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' ELSE \'default\' END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE ? 
END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_multi_when(self): @@ -29,33 +31,36 @@ async def test_multi_when(self): ) sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' WHEN `intnum`<=2 THEN 'small' ELSE 'default' END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s WHEN `intnum`<=%s THEN %s ELSE %s END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 WHEN "intnum"<=$3 THEN $4 ELSE $5 END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' WHEN "intnum"<=2 THEN \'small\' ELSE \'default\' END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? WHEN "intnum"<=? THEN ? ELSE ? END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_q_object_when(self): category = Case(When(Q(intnum__gt=2, intnum__lt=8), then="middle"), default="default") sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>2 AND `intnum`<8 THEN 'middle' ELSE 'default' END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>%s AND `intnum`<%s THEN %s ELSE %s END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">$1 AND "intnum"<$2 THEN $3 ELSE $4 END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">2 AND "intnum"<8 THEN \'middle\' ELSE \'default\' END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">? AND "intnum"=8 THEN `intnum_null` ELSE 'default' END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN `intnum_null` ELSE %s END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN "intnum_null" ELSE $2 END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN "intnum_null" ELSE \'default\' END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN "intnum_null" ELSE ? 
END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_AE_then(self): @@ -63,33 +68,36 @@ async def test_AE_then(self): category = Case(When(intnum__gte=8, then=F("intnum") + 1), default="default") sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN `intnum`+1 ELSE 'default' END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN `intnum`+%s ELSE %s END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN "intnum"+$2 ELSE $3 END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN "intnum"+1 ELSE \'default\' END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN "intnum"+? ELSE ? END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_func_then(self): category = Case(When(intnum__gte=8, then=Coalesce("intnum_null", 10)), default="default") sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN COALESCE(`intnum_null`,10) ELSE 'default' END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN COALESCE(`intnum_null`,%s) ELSE %s END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN COALESCE("intnum_null",$2) ELSE $3 END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN COALESCE("intnum_null",10) ELSE \'default\' END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN COALESCE("intnum_null",?) ELSE ? END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_F_default(self): category = Case(When(intnum__gte=8, then="big"), default=F("intnum_null")) sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' ELSE `intnum_null` END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE `intnum_null` END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE "intnum_null" END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' ELSE "intnum_null" END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? 
ELSE "intnum_null" END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_AE_default(self): @@ -97,22 +105,24 @@ async def test_AE_default(self): category = Case(When(intnum__gte=8, then=8), default=F("intnum") + 1) sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 8 ELSE `intnum`+1 END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE `intnum`+%s END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE "intnum"+$3 END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN 8 ELSE "intnum"+1 END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE "intnum"+? END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_func_default(self): category = Case(When(intnum__gte=8, then=8), default=Coalesce("intnum_null", 10)) sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 8 ELSE COALESCE(`intnum_null`,10) END `category` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE COALESCE(`intnum_null`,%s) END `category` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE COALESCE("intnum_null",$3) END "category" FROM "intfields"' else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN 8 ELSE COALESCE("intnum_null",10) END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE COALESCE("intnum_null",?) END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_case_when_in_where(self): @@ -126,11 +136,12 @@ async def test_case_when_in_where(self): .values("intnum") .sql() ) - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum` FROM `intfields` WHERE CASE WHEN `intnum`>=8 THEN 'big' WHEN `intnum`<=2 THEN 'small' ELSE 'middle' END IN ('big','small')" + if self.dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum` FROM `intfields` WHERE CASE WHEN `intnum`>=%s THEN %s WHEN `intnum`<=%s THEN %s ELSE %s END IN (%s,%s)" + elif self.dialect == "postgres": + expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=$1 THEN $2 WHEN "intnum"<=$3 THEN $4 ELSE $5 END IN ($6,$7)' else: - expected_sql = "SELECT \"intnum\" \"intnum\" FROM \"intfields\" WHERE CASE WHEN \"intnum\">=8 THEN 'big' WHEN \"intnum\"<=2 THEN 'small' ELSE 'middle' END IN ('big','small')" + expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=? THEN ? WHEN "intnum"<=? THEN ? ELSE ? 
END IN (?,?)' self.assertEqual(sql, expected_sql) async def test_annotation_in_when_annotation(self): @@ -142,11 +153,12 @@ async def test_annotation_in_when_annotation(self): .sql() ) - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `id` `id`,`intnum` `intnum`,`intnum`+1 `intnum_plus_1`,CASE WHEN `intnum`+1>=10 THEN true ELSE false END `bigger_than_10` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `id` `id`,`intnum` `intnum`,`intnum`+%s `intnum_plus_1`,CASE WHEN `intnum`+%s>=%s THEN %s ELSE %s END `bigger_than_10` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+$1 "intnum_plus_1",CASE WHEN "intnum"+$2>=$3 THEN $4 ELSE $5 END "bigger_than_10" FROM "intfields"' else: - expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+1 "intnum_plus_1",CASE WHEN "intnum"+1>=10 THEN true ELSE false END "bigger_than_10" FROM "intfields"' + expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+? "intnum_plus_1",CASE WHEN "intnum"+?>=? THEN ? ELSE ? END "bigger_than_10" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_func_annotation_in_when_annotation(self): @@ -158,11 +170,12 @@ async def test_func_annotation_in_when_annotation(self): .sql() ) - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `id` `id`,COALESCE(`intnum`,0) `intnum_col`,CASE WHEN COALESCE(`intnum`,0)=0 THEN true ELSE false END `is_zero` FROM `intfields`" + if self.dialect == "mysql": + expected_sql = "SELECT `id` `id`,COALESCE(`intnum`,%s) `intnum_col`,CASE WHEN COALESCE(`intnum`,%s)=%s THEN %s ELSE %s END `is_zero` FROM `intfields`" + elif self.dialect == "postgres": + expected_sql = 'SELECT "id" "id",COALESCE("intnum",$1) "intnum_col",CASE WHEN COALESCE("intnum",$2)=$3 THEN $4 ELSE $5 END "is_zero" FROM "intfields"' else: - expected_sql = 'SELECT "id" "id",COALESCE("intnum",0) "intnum_col",CASE WHEN COALESCE("intnum",0)=0 THEN true ELSE false END "is_zero" FROM "intfields"' + expected_sql = 'SELECT "id" "id",COALESCE("intnum",?) "intnum_col",CASE WHEN COALESCE("intnum",?)=? THEN ? ELSE ? END "is_zero" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_case_when_in_group_by(self): @@ -175,13 +188,14 @@ async def test_case_when_in_group_by(self): .sql() ) - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT CASE WHEN `intnum`=0 THEN true ELSE false END `is_zero`,COUNT(`id`) `count` FROM `intfields` GROUP BY `is_zero`" - elif dialect == "mssql": - expected_sql = 'SELECT CASE WHEN "intnum"=0 THEN true ELSE false END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY CASE WHEN "intnum"=0 THEN true ELSE false END' + if self.dialect == "mysql": + expected_sql = "SELECT CASE WHEN `intnum`=%s THEN %s ELSE %s END `is_zero`,COUNT(`id`) `count` FROM `intfields` GROUP BY `is_zero`" + elif self.dialect == "postgres": + expected_sql = 'SELECT CASE WHEN "intnum"=$1 THEN $2 ELSE $3 END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' + elif self.dialect == "mssql": + expected_sql = 'SELECT CASE WHEN "intnum"=? THEN ? ELSE ? END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY CASE WHEN "intnum"=? THEN ? ELSE ? END' else: - expected_sql = 'SELECT CASE WHEN "intnum"=0 THEN true ELSE false END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' + expected_sql = 'SELECT CASE WHEN "intnum"=? THEN ? ELSE ? 
END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' self.assertEqual(sql, expected_sql) async def test_unknown_field_in_when_annotation(self): diff --git a/tests/test_fuzz.py b/tests/test_fuzz.py index 2cdfe9dd8..d6bd10633 100644 --- a/tests/test_fuzz.py +++ b/tests/test_fuzz.py @@ -1,6 +1,7 @@ from tests.testmodels import CharFields from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ +from tortoise.functions import Upper DODGY_STRINGS = [ "a/", @@ -9,6 +10,11 @@ "a\\x39", "a'", '"', + '""', + "'", + "''", + "\\_", + "\\\\_", "‘a", "a’", "‘a’", @@ -134,3 +140,12 @@ async def test_char_fuzz(self): ) self.assertEqual(obj1.pk, obj5.pk) self.assertEqual(char, obj5.char) + + # Filter by a function + obj6 = ( + await CharFields.annotate(upper_char=Upper("char")) + .filter(id=obj1.pk, upper_char=Upper("char")) + .first() + ) + self.assertEqual(obj1.pk, obj6.pk) + self.assertEqual(char, obj6.char) diff --git a/tests/test_model_methods.py b/tests/test_model_methods.py index 02bce5cd4..e5d354b06 100644 --- a/tests/test_model_methods.py +++ b/tests/test_model_methods.py @@ -296,14 +296,14 @@ async def test_index_access(self): async def test_index_badval(self): with self.assertRaises(ObjectDoesNotExistError) as cm: - await self.cls[100000] + await self.cls[32767] the_exception = cm.exception # For compatibility reasons this should be an instance of KeyError self.assertIsInstance(the_exception, KeyError) self.assertIs(the_exception.model, self.cls) self.assertEqual(the_exception.pk_name, "id") - self.assertEqual(the_exception.pk_val, 100000) - self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=100000") + self.assertEqual(the_exception.pk_val, 32767) + self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=32767") async def test_index_badtype(self): with self.assertRaises(ObjectDoesNotExistError) as cm: diff --git a/tests/test_queryset.py b/tests/test_queryset.py index cdf54f965..1e0f8fb5d 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -65,7 +65,7 @@ async def test_limit_zero(self): sql = IntFields.all().only("id").limit(0).sql() self.assertEqual( sql, - 'SELECT "id" "id" FROM "intfields" LIMIT 0', + 'SELECT "id" "id" FROM "intfields" LIMIT ?', ) async def test_offset_count(self): @@ -587,13 +587,13 @@ async def test_force_index(self): sql = IntFields.filter(pk=1).only("id").force_index("index_name").sql() self.assertEqual( sql, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_again = IntFields.filter(pk=1).only("id").force_index("index_name").sql() self.assertEqual( sql_again, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) @test.requireCapability(support_index_hint=True) @@ -601,7 +601,7 @@ async def test_force_index_available_in_more_query(self): sql_ValuesQuery = IntFields.filter(pk=1).force_index("index_name").values("id").sql() self.assertEqual( sql_ValuesQuery, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_ValuesListQuery = ( @@ -609,19 +609,19 @@ async def test_force_index_available_in_more_query(self): ) self.assertEqual( sql_ValuesListQuery, - "SELECT `id` `0` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + 
"SELECT `id` `0` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_CountQuery = IntFields.filter(pk=1).force_index("index_name").count().sql() self.assertEqual( sql_CountQuery, - "SELECT COUNT('*') FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT COUNT('*') FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql() self.assertEqual( sql_ExistsQuery, - "SELECT 1 FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1 LIMIT 1", + "SELECT %s FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", ) @test.requireCapability(support_index_hint=True) @@ -629,13 +629,13 @@ async def test_use_index(self): sql = IntFields.filter(pk=1).only("id").use_index("index_name").sql() self.assertEqual( sql, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_again = IntFields.filter(pk=1).only("id").use_index("index_name").sql() self.assertEqual( sql_again, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) @test.requireCapability(support_index_hint=True) @@ -643,25 +643,25 @@ async def test_use_index_available_in_more_query(self): sql_ValuesQuery = IntFields.filter(pk=1).use_index("index_name").values("id").sql() self.assertEqual( sql_ValuesQuery, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_ValuesListQuery = IntFields.filter(pk=1).use_index("index_name").values_list("id").sql() self.assertEqual( sql_ValuesListQuery, - "SELECT `id` `0` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `0` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_CountQuery = IntFields.filter(pk=1).use_index("index_name").count().sql() self.assertEqual( sql_CountQuery, - "SELECT COUNT('*') FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT COUNT('*') FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql() self.assertEqual( sql_ExistsQuery, - "SELECT 1 FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1 LIMIT 1", + "SELECT %s FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", ) @test.requireCapability(support_for_update=True) @@ -675,36 +675,36 @@ async def test_select_for_update(self): if dialect == "postgres": self.assertEqual( sql1, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE', + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE', ) self.assertEqual( sql2, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE NOWAIT', + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE NOWAIT', ) self.assertEqual( sql3, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE SKIP LOCKED', + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE SKIP LOCKED', ) self.assertEqual( sql4, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE OF "intfields"', + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE OF "intfields"', ) elif dialect == "mysql": self.assertEqual( sql1, - "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE", + "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE", ) self.assertEqual( sql2, - "SELECT `id` `id` FROM `intfields` WHERE `id`=1 
FOR UPDATE NOWAIT", + "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE NOWAIT", ) self.assertEqual( sql3, - "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE SKIP LOCKED", + "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE SKIP LOCKED", ) self.assertEqual( sql4, - "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE OF `intfields`", + "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE OF `intfields`", ) async def test_select_related(self): diff --git a/tests/test_sql.py b/tests/test_sql.py new file mode 100644 index 000000000..b932a2155 --- /dev/null +++ b/tests/test_sql.py @@ -0,0 +1,64 @@ +from tortoise import connections +from tests.testmodels import CharPkModel, IntFields +from tortoise.contrib import test +from tortoise.expressions import F +from tortoise.functions import Concat + + +class TestSQL(test.TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.db = connections.get("models") + self.dialect = self.db.schema_generator.DIALECT + + def test_filter(self): + sql = CharPkModel.all().filter(id="123").sql() + if self.dialect == "mysql": + expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s" + elif self.dialect == "postgres": + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=?' + + self.assertEqual(sql, expected) + + def test_filter_with_limit_offset(self): + sql = CharPkModel.all().filter(id="123").limit(10).offset(0).sql() + if self.dialect == "mysql": + expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s LIMIT %s OFFSET %s" + elif self.dialect == "postgres": + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1 LIMIT $2 OFFSET $3' + elif self.dialect == "mssql": + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? LIMIT ? OFFSET ?' + + self.assertEqual(sql, expected) + + def test_group_by(self): + sql = IntFields.all().group_by("intnum").values("intnum").sql() + if self.dialect == "mysql": + expected = "SELECT `intnum` `intnum` FROM `intfields` GROUP BY `intnum`" + else: + expected = 'SELECT "intnum" "intnum" FROM "intfields" GROUP BY "intnum"' + self.assertEqual(sql, expected) + + def test_annotate(self): + sql = CharPkModel.all().annotate(id_plus_one=Concat(F("id"), "_postfix")).sql() + if self.dialect == "mysql": + expected = "SELECT `id`,CONCAT(`id`,%s) `id_plus_one` FROM `charpkmodel`" + elif self.dialect == "postgres": + expected = 'SELECT "id",CONCAT("id",$1) "id_plus_one" FROM "charpkmodel"' + else: + expected = 'SELECT "id",CONCAT("id",?) "id_plus_one" FROM "charpkmodel"' + self.assertEqual(sql, expected) + + def test_update(self): + sql = IntFields.filter(intnum=2).update(intnum=1).sql() + if self.dialect == "mysql": + expected = "UPDATE `intfields` SET `intnum`=%s WHERE `intnum`=%s" + elif self.dialect == "postgres": + expected = 'UPDATE "intfields" SET "intnum"=$1 WHERE "intnum"=$2' + else: + expected = 'UPDATE "intfields" SET "intnum"=? WHERE "intnum"=?' 
+ self.assertEqual(sql, expected) diff --git a/tests/test_values.py b/tests/test_values.py index c74f955fa..870c544ab 100644 --- a/tests/test_values.py +++ b/tests/test_values.py @@ -212,5 +212,5 @@ class TruncMonth(Function): sql = Tournament.all().annotate(date=TruncMonth("created", "%Y-%m-%d")).values("date").sql() self.assertEqual( sql, - 'SELECT DATE_FORMAT("created",\'%Y-%m-%d\') "date" FROM "tournament"', + 'SELECT DATE_FORMAT("created",?) "date" FROM "tournament"', ) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 8b0e1bc60..82ae9ac0e 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -23,7 +23,7 @@ from pypika.queries import QueryBuilder from tortoise.exceptions import OperationalError -from tortoise.expressions import Expression, RawSQL, ResolveContext +from tortoise.expressions import Expression, ResolveContext from tortoise.fields.base import Field from tortoise.fields.relational import ( BackwardFKRelation, @@ -124,9 +124,12 @@ async def execute_explain(self, query: Query) -> Any: return (await self.db.execute_query(sql))[1] async def execute_select( - self, query: Union[Query, RawSQL], custom_fields: Optional[list] = None + self, + sql: str, + values: Optional[list] = None, + custom_fields: Optional[list] = None, ) -> list: - _, raw_results = await self.db.execute_query(query.get_sql()) # type:ignore[union-attr] + _, raw_results = await self.db.execute_query(sql, values) instance_list = [] for row in raw_results: if self.select_related_idx: diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 343d2c5ef..cadc4ccbd 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -43,7 +43,7 @@ def get_value_sql(self, **kwargs) -> str: def escape_like(val: str) -> str: - return val.replace("\\", "\\\\\\\\").replace("%", "\\%").replace("_", "\\_") + return val.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") def mysql_contains(field: Term, value: str) -> Criterion: diff --git a/tortoise/contrib/mysql/functions.py b/tortoise/contrib/mysql/functions.py index 32c035e63..830948ab2 100644 --- a/tortoise/contrib/mysql/functions.py +++ b/tortoise/contrib/mysql/functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pypika.terms import Function, Parameter +from pypika.terms import Function class Rand(Function): @@ -12,4 +12,4 @@ class Rand(Function): def __init__(self, seed: int | None = None, alias=None) -> None: super().__init__("RAND", seed, alias=alias) - self.args = [self.wrap_constant(seed) if seed is not None else Parameter("")] + self.args = [self.wrap_constant(seed)] if seed is not None else [] diff --git a/tortoise/expressions.py b/tortoise/expressions.py index f7049e4be..50af26d2b 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -386,9 +386,6 @@ def _process_filter_kwarg( else model._meta.db.executor_class._field_to_db(field_object, value, model) ) op = param["operator"] - # this is an ugly hack - if op == operator.eq: - encoded_value = model._meta.db.query_class._builder()._wrapper_cls(encoded_value) criterion = op(table[param["source_field"]], encoded_value) return criterion, join diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 31f3e5d73..41acaaefc 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -24,7 +24,7 @@ from pypika.analytics import Count from pypika.functions import Cast from pypika.queries import QueryBuilder -from pypika.terms import 
Case, Field, Term, ValueWrapper +from pypika.terms import Case, Field, Term, ValueWrapper, Parameterizer from typing_extensions import Literal, Protocol from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities @@ -282,7 +282,7 @@ def _resolve_annotate(self) -> bool: def sql(self, **kwargs) -> str: """Return the actual SQL.""" - return self.as_query().get_sql(**kwargs) + return self.as_query().get_sql(parameterizer=Parameterizer(), **kwargs) def as_query(self) -> QueryBuilder: """Return the actual query.""" @@ -1091,9 +1091,9 @@ def _make_query(self) -> None: ) self.resolve_filters() if self._limit is not None: - self.query._limit = self._limit - if self._offset: - self.query._offset = self._offset + self.query._limit = self.query._wrapper_cls(self._limit) + if self._offset is not None: + self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True if self._select_for_update: @@ -1130,15 +1130,18 @@ async def __aiter__(self) -> AsyncIterator[MODEL]: yield val async def _execute(self) -> List[MODEL]: + parameterizer = Parameterizer() + sql = self.query.get_sql(parameterizer=parameterizer) instance_list = await self._db.executor_class( model=self.model, db=self._db, prefetch_map=self._prefetch_map, prefetch_queries=self._prefetch_queries, - select_related_idx=self._select_related_idx, # type:ignore[arg-type] + select_related_idx=self._select_related_idx, ).execute_select( - self.query, # type:ignore[arg-type] - custom_fields=list(self._annotations), + sql, + parameterizer.values, + custom_fields=list(self._annotations.keys()), ) if self._single: if len(instance_list) == 1: @@ -1184,7 +1187,7 @@ def _make_query(self) -> None: table = self.model._meta.basetable self.query = self._db.query_class.update(table) if self.capabilities.support_update_limit_order_by and self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) self.resolve_ordering(self.model, table, self._orderings, self._annotations) self.resolve_filters() @@ -1267,7 +1270,7 @@ def __init__( def _make_query(self) -> None: self.query = copy(self.model._meta.basequery) if self.capabilities.support_update_limit_order_by and self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) self.resolve_ordering( model=self.model, table=self.model._meta.basetable, @@ -1314,8 +1317,8 @@ def __init__( def _make_query(self) -> None: self.query = copy(self.model._meta.basequery) self.resolve_filters() - self.query._limit = 1 - self.query._select_other(ValueWrapper(1)) # type:ignore[arg-type] + self.query._limit = self.query._wrapper_cls(1) + self.query._select_other(ValueWrapper(1)) if self._force_indexes: self.query._force_indexes = [] @@ -1582,9 +1585,9 @@ def _make_query(self) -> None: ) self.resolve_filters() if self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) if self._offset: - self.query._offset = self._offset + self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True if self._group_bys: @@ -1710,9 +1713,9 @@ def _make_query(self) -> None: ] if self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) if self._offset: - self.query._offset = self._offset + self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True if self._group_bys: @@ -1786,9 +1789,10 @@ def _make_query(self) -> None: 
self.query = RawSQL(self._sql) # type:ignore[assignment] async def _execute(self) -> Any: - instance_list = await self._db.executor_class(model=self.model, db=self._db).execute_select( - self.query # type:ignore[arg-type] - ) + instance_list = await self._db.executor_class( + model=self.model, + db=self._db, + ).execute_select(self.query.get_sql()) return instance_list def __await__(self) -> Generator[Any, None, List[MODEL]]: @@ -1833,7 +1837,7 @@ def _make_query(self) -> None: table = self.model._meta.basetable self.query = self._db.query_class.update(table) if self.capabilities.support_update_limit_order_by and self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) self.resolve_ordering( model=self.model, table=table, From 07114ac3184d741d887d4e04e0455afa075a1972 Mon Sep 17 00:00:00 2001 From: henadzit Date: Thu, 14 Nov 2024 13:50:53 +0100 Subject: [PATCH 2/8] Use new parameter interface --- poetry.lock | 20 ++++++++++++-------- tortoise/backends/base/executor.py | 10 +++++----- tortoise/backends/base_postgres/executor.py | 4 ---- tortoise/backends/mysql/executor.py | 5 +---- tortoise/backends/odbc/executor.py | 5 ----- tortoise/backends/psycopg/executor.py | 5 ----- tortoise/backends/sqlite/executor.py | 4 ---- tortoise/queryset.py | 6 +++--- 8 files changed, 21 insertions(+), 38 deletions(-) diff --git a/poetry.lock b/poetry.lock index d3849fa45..b920cd0b2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "aiofiles" @@ -286,7 +286,7 @@ name = "asyncmy" version = "0.2.10rc1" description = "A fast asyncio MySQL driver" optional = true -python-versions = "^3.7" +python-versions = "<4.0,>=3.7" files = [ {file = "asyncmy-0.2.10rc1.tar.gz", hash = "sha256:ba97b7f9b9719b6cb15169f0bffbf20be63767ff5052a24c3663a1d558bced5a"}, ] @@ -2658,11 +2658,15 @@ name = "pypika-tortoise" version = "0.2.2" description = "Forked from pypika and streamline just for tortoise-orm" optional = false -python-versions = ">=3.8,<4.0" -files = [ - {file = "pypika_tortoise-0.2.2-py3-none-any.whl", hash = "sha256:e93190aedd95acb08b69636bc2328cc053b2c9971307b6d44405bc6d9f9b71a5"}, - {file = "pypika_tortoise-0.2.2.tar.gz", hash = "sha256:f0fbc9e0c3ddc33118a5be69907428863849df60788e125edef1f46a6261d63b"}, -] +python-versions = "^3.7" +files = [] +develop = false + +[package.source] +type = "git" +url = "https://github.com/henadzit/pypika-tortoise.git" +reference = "parameterization-changes" +resolved_reference = "36a49ec465eedc7c17c6b3621360c3d700614146" [[package]] name = "pytest" @@ -3855,4 +3859,4 @@ psycopg = ["psycopg"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "fe730e9093d0549d2152dfdc8775c428e72a9a2ea48b2b789f16bfa172ebf66e" +content-hash = "ccee25fb7393e24ddf0389c7787bbbc941c46e0afb090caf4c0ffbe3a81d744a" diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 82ae9ac0e..8e4f4717d 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -197,7 +197,7 @@ async def _process_insert_result(self, instance: "Model", results: Any) -> None: raise NotImplementedError() # pragma: nocoverage def parameter(self, pos: int) -> Parameter: - raise NotImplementedError() # pragma: nocoverage + return Parameter(idx=pos + 1) async def execute_insert(self, instance: 
"Model") -> None: if not instance._custom_generated_pk: @@ -259,14 +259,14 @@ def get_update_sql( expressions = expressions or {} table = self.model._meta.basetable query = self.db.query_class.update(table) - count = 0 + parameter_idx = 0 for field in update_fields or self.model._meta.fields_db_projection.keys(): db_column = self.model._meta.fields_db_projection[field] field_object = self.model._meta.fields_map[field] if not field_object.pk: if field not in expressions.keys(): - query = query.set(db_column, self.parameter(count)) - count += 1 + query = query.set(db_column, self.parameter(parameter_idx)) + parameter_idx += 1 else: value = ( expressions[field] @@ -282,7 +282,7 @@ def get_update_sql( ) query = query.set(db_column, value) - query = query.where(table[self.model._meta.db_pk_column] == self.parameter(count)) + query = query.where(table[self.model._meta.db_pk_column] == self.parameter(parameter_idx)) sql = query.get_sql() if not expressions: diff --git a/tortoise/backends/base_postgres/executor.py b/tortoise/backends/base_postgres/executor.py index 2703b04aa..0b85e8eb7 100644 --- a/tortoise/backends/base_postgres/executor.py +++ b/tortoise/backends/base_postgres/executor.py @@ -1,7 +1,6 @@ import uuid from typing import Optional, Sequence, cast -from pypika import Parameter from pypika.dialects import PostgreSQLQueryBuilder from pypika.terms import Term @@ -38,9 +37,6 @@ class BasePostgresExecutor(BaseExecutor): posix_regex: postgres_posix_regex, } - def parameter(self, pos: int) -> Parameter: - return Parameter("$%d" % (pos + 1,)) - def _prepare_insert_statement( self, columns: Sequence[str], has_generated: bool = True, ignore_conflicts: bool = False ) -> PostgreSQLQueryBuilder: diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index cadc4ccbd..741c1b38f 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -1,4 +1,4 @@ -from pypika import Parameter, functions +from pypika import functions from pypika.enums import SqlTypes from pypika.terms import BasicCriterion, Criterion from pypika.utils import format_quotes @@ -117,9 +117,6 @@ class MySQLExecutor(BaseExecutor): } EXPLAIN_PREFIX = "EXPLAIN FORMAT=JSON" - def parameter(self, pos: int) -> Parameter: - return Parameter("%s") - async def _process_insert_result(self, instance: Model, results: int) -> None: pk_field_object = self.model._meta.pk if ( diff --git a/tortoise/backends/odbc/executor.py b/tortoise/backends/odbc/executor.py index 620fbb239..9f54f5300 100644 --- a/tortoise/backends/odbc/executor.py +++ b/tortoise/backends/odbc/executor.py @@ -1,14 +1,9 @@ -from pypika import Parameter - from tortoise import Model from tortoise.backends.base.executor import BaseExecutor from tortoise.fields import BigIntField, IntField, SmallIntField class ODBCExecutor(BaseExecutor): - def parameter(self, pos: int) -> Parameter: - return Parameter("?") - async def _process_insert_result(self, instance: Model, results: int) -> None: pk_field_object = self.model._meta.pk if ( diff --git a/tortoise/backends/psycopg/executor.py b/tortoise/backends/psycopg/executor.py index e53492494..435e2fca7 100644 --- a/tortoise/backends/psycopg/executor.py +++ b/tortoise/backends/psycopg/executor.py @@ -2,8 +2,6 @@ from typing import Optional -from pypika import Parameter - from tortoise import Model from tortoise.backends.base_postgres.executor import BasePostgresExecutor @@ -23,6 +21,3 @@ async def _process_insert_result( for key, val in zip(generated_fields, results): 
setattr(instance, db_projection[key], val) - - def parameter(self, pos: int) -> Parameter: - return Parameter("%s") diff --git a/tortoise/backends/sqlite/executor.py b/tortoise/backends/sqlite/executor.py index 86236d2b8..2971f35b4 100644 --- a/tortoise/backends/sqlite/executor.py +++ b/tortoise/backends/sqlite/executor.py @@ -3,7 +3,6 @@ from typing import Optional, Type, Union import pytz -from pypika import Parameter from tortoise import Model, fields, timezone from tortoise.backends.base.executor import BaseExecutor @@ -87,9 +86,6 @@ class SqliteExecutor(BaseExecutor): EXPLAIN_PREFIX = "EXPLAIN QUERY PLAN" DB_NATIVE = {bytes, str, int, float} - def parameter(self, pos: int) -> Parameter: - return Parameter("?") - async def _process_insert_result(self, instance: Model, results: int) -> None: pk_field_object = self.model._meta.pk if ( diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 41acaaefc..e6c5c0827 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1193,7 +1193,7 @@ def _make_query(self) -> None: self.resolve_filters() # Need to get executor to get correct column_map executor = self._db.executor_class(model=self.model, db=self._db) - count = 0 + parameter_idx = 0 for key, value in self.update_kwargs.items(): field_object = self.model._meta.fields_map.get(key) if not field_object: @@ -1227,9 +1227,9 @@ def _make_query(self) -> None: if isinstance(value, Term): self.query = self.query.set(db_field, value) else: - self.query = self.query.set(db_field, executor.parameter(count)) + self.query = self.query.set(db_field, executor.parameter(parameter_idx)) self.values.append(value) - count += 1 + parameter_idx += 1 def __await__(self) -> Generator[Any, None, int]: if self._db is None: From 505d29afe572bbd110682da0963bb80ca84f4f2e Mon Sep 17 00:00:00 2001 From: henadzit Date: Fri, 15 Nov 2024 11:53:39 +0100 Subject: [PATCH 3/8] Make sure _execute() uses the same query as returned from sql() --- tests/test_sql.py | 3 +- tortoise/backends/base/executor.py | 6 +- tortoise/expressions.py | 7 +- tortoise/queryset.py | 169 ++++++++++++++++------------- 4 files changed, 105 insertions(+), 80 deletions(-) diff --git a/tests/test_sql.py b/tests/test_sql.py index b932a2155..76e7fb11f 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,5 +1,5 @@ -from tortoise import connections from tests.testmodels import CharPkModel, IntFields +from tortoise import connections from tortoise.contrib import test from tortoise.expressions import F from tortoise.functions import Concat @@ -53,6 +53,7 @@ def test_annotate(self): expected = 'SELECT "id",CONCAT("id",?) 
"id_plus_one" FROM "charpkmodel"' self.assertEqual(sql, expected) + @test.skip("Update queries are not parameterized yet") def test_update(self): sql = IntFields.filter(intnum=2).update(intnum=1).sql() if self.dialect == "mysql": diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 8e4f4717d..dc44c6e3a 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -19,7 +19,7 @@ cast, ) -from pypika import JoinType, Parameter, Query, Table +from pypika import JoinType, Parameter, Table from pypika.queries import QueryBuilder from tortoise.exceptions import OperationalError @@ -119,8 +119,8 @@ def __init__( self.update_cache, ) = EXECUTOR_CACHE[key] - async def execute_explain(self, query: Query) -> Any: - sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql())) # type:ignore[attr-defined] + async def execute_explain(self, sql: str) -> Any: + sql = " ".join((self.EXPLAIN_PREFIX, sql)) return (await self.db.execute_query(sql))[1] async def execute_select( diff --git a/tortoise/expressions.py b/tortoise/expressions.py index 50af26d2b..d01a92a2c 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -215,10 +215,11 @@ def __init__(self, query: "AwaitableQuery") -> None: self.query = query def get_sql(self, **kwargs: Any) -> str: - return self.query.as_query().get_sql(**kwargs) + return self.query._make_query(**kwargs)[0] - def as_(self, alias: str) -> "Selectable": # type:ignore[override] - return self.query.as_query().as_(alias) + def as_(self, alias: str) -> "Selectable": + self.query._make_query() + return self.query.query.as_(alias) class RawSQL(Term): diff --git a/tortoise/queryset.py b/tortoise/queryset.py index e6c5c0827..24baeb484 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -128,6 +128,10 @@ def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: db = router.db_for_read(self.model) return db or self.model._meta.db + def _choose_db_if_not_chosen(self, for_write: bool = False) -> None: + if self._db is None: + self._db = self._choose_db(for_write) # type: ignore + def resolve_filters(self) -> None: """Builds the common filters for a QuerySet.""" has_aggregate = self._resolve_annotate() @@ -280,21 +284,23 @@ def _resolve_annotate(self) -> bool: return any(info.term.is_aggregate for info in annotation_info.values()) - def sql(self, **kwargs) -> str: + def sql(self) -> str: """Return the actual SQL.""" - return self.as_query().get_sql(parameterizer=Parameterizer(), **kwargs) - - def as_query(self) -> QueryBuilder: - """Return the actual query.""" if self._db is None: self._db = self._choose_db() # type: ignore - self._make_query() - return self.query - def _make_query(self) -> None: + sql, _ = self._make_query() + return sql + + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + """Build the query + + :param pypika_kwargs: Required for Subquery making + :return: Tuple[str, List[Any]]: The query string and the parameters + """ raise NotImplementedError() # pragma: nocoverage - async def _execute(self) -> Any: + async def _execute(self, sql: str, values: List[Any]) -> Any: raise NotImplementedError() # pragma: nocoverage @@ -1000,10 +1006,8 @@ async def explain(self) -> Any: """ if self._db is None: self._db = self._choose_db() # type: ignore - self._make_query() - return await self._db.executor_class(model=self.model, db=self._db).execute_explain( - self.query # type:ignore[arg-type] - ) + sql, _ = self._make_query() + return await 
self._db.executor_class(model=self.model, db=self._db).execute_explain(sql) def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]": """ @@ -1055,7 +1059,7 @@ def _join_table_with_select_related( ) return self.query - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: # clean tmp records first self._select_related_idx = [] self._joined_tables = [] @@ -1119,19 +1123,23 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + parameterizer = Parameterizer() + return ( + self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), + parameterizer.values, + ) + def __await__(self) -> Generator[Any, None, List[MODEL]]: if self._db is None: self._db = self._choose_db(self._select_for_update) # type: ignore - self._make_query() - return self._execute().__await__() + sql, values = self._make_query() + return self._execute(sql, values).__await__() async def __aiter__(self) -> AsyncIterator[MODEL]: for val in await self: yield val - async def _execute(self) -> List[MODEL]: - parameterizer = Parameterizer() - sql = self.query.get_sql(parameterizer=parameterizer) + async def _execute(self, sql: str, values: List[Any]) -> List[MODEL]: instance_list = await self._db.executor_class( model=self.model, db=self._db, @@ -1140,7 +1148,7 @@ async def _execute(self) -> List[MODEL]: select_related_idx=self._select_related_idx, ).execute_select( sql, - parameterizer.values, + values, custom_fields=list(self._annotations.keys()), ) if self._single: @@ -1183,7 +1191,7 @@ def __init__( self._orderings = orderings self.values: List[Any] = [] - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: table = self.model._meta.basetable self.query = self._db.query_class.update(table) if self.capabilities.support_update_limit_order_by and self._limit: @@ -1230,15 +1238,15 @@ def _make_query(self) -> None: self.query = self.query.set(db_field, executor.parameter(parameter_idx)) self.values.append(value) parameter_idx += 1 + return self.query.get_sql(), self.values def __await__(self) -> Generator[Any, None, int]: - if self._db is None: - self._db = self._choose_db(True) # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen(True) + sql, values = self._make_query() + return self._execute(sql, values).__await__() - async def _execute(self) -> int: - return (await self._db.execute_query(str(self.query), self.values))[0] + async def _execute(self, sql, values) -> int: + return (await self._db.execute_query(sql, values))[0] class DeleteQuery(AwaitableQuery): @@ -1267,7 +1275,7 @@ def __init__( self._limit = limit self._orderings = orderings - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query = copy(self.model._meta.basequery) if self.capabilities.support_update_limit_order_by and self._limit: self.query._limit = self.query._wrapper_cls(self._limit) @@ -1279,15 +1287,15 @@ def _make_query(self) -> None: ) self.resolve_filters() self.query._delete_from = True + return self.query.get_sql(), [] def __await__(self) -> Generator[Any, None, int]: - if self._db is None: - self._db = self._choose_db(True) # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen(True) + sql, values = self._make_query() + return self._execute(sql, values).__await__() - async def _execute(self) -> int: - return (await 
self._db.execute_query(str(self.query)))[0] + async def _execute(self, sql: str, values: List[Any]) -> int: + return (await self._db.execute_query(sql, values))[0] class ExistsQuery(AwaitableQuery): @@ -1314,7 +1322,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query = copy(self.model._meta.basequery) self.resolve_filters() self.query._limit = self.query._wrapper_cls(1) @@ -1327,13 +1335,15 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self.query.get_sql(), [] + def __await__(self) -> Generator[Any, None, bool]: if self._db is None: self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() + sql, values = self._make_query() + return self._execute(sql, values).__await__() - async def _execute(self) -> bool: + async def _execute(self, sql: str, values: List[Any]) -> bool: result, _ = await self._db.execute_query(str(self.query)) return bool(result) @@ -1368,7 +1378,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query = copy(self.model._meta.basequery) self.resolve_filters() count_term = Count("*") @@ -1385,14 +1395,15 @@ def _make_query(self) -> None: if self._use_indexes: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self.query.get_sql(**pypika_kwargs), [] def __await__(self) -> Generator[Any, None, int]: if self._db is None: self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() + sql, values = self._make_query() + return self._execute(sql, values).__await__() - async def _execute(self) -> int: + async def _execute(self, sql: str, values: List[Any]) -> int: _, result = await self._db.execute_query(str(self.query)) if not result: return 0 @@ -1570,7 +1581,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self._joined_tables = [] self.query = copy(self.model._meta.basequery) @@ -1600,6 +1611,8 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self.query.get_sql(**pypika_kwargs), [] + @overload def __await__( self: "ValuesListQuery[Literal[False]]", @@ -1613,15 +1626,15 @@ def __await__( def __await__(self) -> Generator[Any, None, Union[List[Any], Tuple[Any, ...]]]: if self._db is None: self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() # pylint: disable=E1101 + sql, values = self._make_query() + return self._execute(sql, values).__await__() # pylint: disable=E1101 async def __aiter__(self: "ValuesListQuery[Any]") -> AsyncIterator[Any]: for val in await self: yield val - async def _execute(self) -> Union[List[Any], Tuple]: - _, result = await self._db.execute_query(str(self.query)) + async def _execute(self, sql: str, values: List[Any]) -> Union[List[Any], Tuple]: + _, result = await self._db.execute_query(sql, values) columns = [ (key, self.resolve_to_python_value(self.model, name)) for key, name in self.fields.items() @@ -1692,7 +1705,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def 
_make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self._joined_tables = [] self.query = copy(self.model._meta.basequery) @@ -1728,6 +1741,8 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self.query.get_sql(**pypika_kwargs), [] + @overload def __await__( self: "ValuesQuery[Literal[False]]", @@ -1743,15 +1758,15 @@ def __await__( ) -> Generator[Any, None, Union[List[Dict[str, Any]], Dict[str, Any]]]: if self._db is None: self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() # pylint: disable=E1101 + sql, values = self._make_query() + return self._execute(sql, values).__await__() # pylint: disable=E1101 async def __aiter__(self: "ValuesQuery[Any]") -> AsyncIterator[Dict[str, Any]]: for val in await self: yield val - async def _execute(self) -> Union[List[dict], Dict]: - result = await self._db.execute_query_dict(str(self.query)) + async def _execute(self, sql: str, values: List[Any]) -> Union[List[dict], Dict]: + result = await self._db.execute_query_dict(sql, values) columns = [ val for val in [ @@ -1785,21 +1800,22 @@ def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: self._sql = sql self._db = db - def _make_query(self) -> None: - self.query = RawSQL(self._sql) # type:ignore[assignment] + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + self.query = RawSQL(self._sql) + return self.query.get_sql(**pypika_kwargs), [] - async def _execute(self) -> Any: + async def _execute(self, sql: str, values: List[Any]) -> Any: instance_list = await self._db.executor_class( model=self.model, db=self._db, - ).execute_select(self.query.get_sql()) + ).execute_select(sql, values) return instance_list def __await__(self) -> Generator[Any, None, List[MODEL]]: if self._db is None: self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() + sql, values = self._make_query() + return self._execute(sql, values).__await__() class BulkUpdateQuery(UpdateQuery, Generic[MODEL]): @@ -1833,7 +1849,7 @@ def __init__( self._batch_size = batch_size self._queries: List[QueryBuilder] = [] - def _make_query(self) -> None: + def _make_queries(self) -> List[Tuple[str, List[Any]]]: table = self.model._meta.basetable self.query = self._db.query_class.update(table) if self.capabilities.support_update_limit_order_by and self._limit: @@ -1875,16 +1891,23 @@ def _make_query(self) -> None: query = query.set(field, case) query = query.where(pk.isin(pk_list)) self._queries.append(query) + return [(query.get_sql(), []) for query in self._queries] - async def _execute(self) -> int: + async def _execute_many(self, queries_with_params: List[Tuple[str, List[Any]]]) -> int: count = 0 - for query in self._queries: - count += (await self._db.execute_query(str(query)))[0] + for sql, values in queries_with_params: + count += (await self._db.execute_query(sql, values))[0] return count - def sql(self, **kwargs) -> str: - self.as_query() - return ";".join([str(query) for query in self._queries]) + def __await__(self) -> Generator[Any, Any, int]: + self._choose_db_if_not_chosen(True) + queries = self._make_queries() + return self._execute_many(queries).__await__() + + def sql(self) -> str: + self._choose_db_if_not_chosen() + queries = self._make_queries() + return ";".join([str(sql) for sql, _ in queries]) class BulkCreateQuery(AwaitableQuery, Generic[MODEL]): @@ -1918,7 +1941,7 @@ def __init__( 
self._update_fields = update_fields self._on_conflict = on_conflict - def _make_query(self) -> None: + def _make_queries(self) -> None: self._executor = self._db.executor_class(model=self.model, db=self._db) if self._ignore_conflicts or self._update_fields: _, columns = self._executor._prepare_insert_columns() @@ -1948,7 +1971,7 @@ def _make_query(self) -> None: self._insert_query_all = self._executor.insert_query_all # type:ignore[assignment] self._insert_query = self._executor.insert_query # type:ignore[assignment] - async def _execute(self) -> None: + async def _execute_many(self) -> None: for instance_chunk in chunk(self._objects, self._batch_size): values_lists_all = [] values_lists = [] @@ -1977,13 +2000,13 @@ async def _execute(self) -> None: await self._db.execute_many(str(self._insert_query), values_lists) def __await__(self) -> Generator[Any, None, None]: - if self._db is None: - self._db = self._choose_db(True) # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen(True) + self._make_queries() + return self._execute_many().__await__() - def sql(self, **kwargs) -> str: - self.as_query() + def sql(self) -> str: + self._choose_db_if_not_chosen() + self._make_queries() if self._insert_query and self._insert_query_all: return ";".join([str(self._insert_query), str(self._insert_query_all)]) return str(self._insert_query or self._insert_query_all) From 03796ae57568f42e9a6bb7307205059dce612c69 Mon Sep 17 00:00:00 2001 From: henadzit Date: Fri, 15 Nov 2024 19:35:36 +0100 Subject: [PATCH 4/8] Parametrize .values, .values_list, .exists and .count queries --- tests/test_filters.py | 8 ++++++ tests/test_queryset.py | 4 +-- tests/test_sql.py | 40 ++++++++++++++++++++++++++++++ tortoise/backends/base/executor.py | 23 ++++------------- tortoise/expressions.py | 2 +- tortoise/fields/base.py | 6 +++++ tortoise/filters.py | 4 +-- tortoise/queryset.py | 33 ++++++++++++++++-------- 8 files changed, 87 insertions(+), 33 deletions(-) diff --git a/tests/test_filters.py b/tests/test_filters.py index 70be46c78..dd80a84e8 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -268,6 +268,14 @@ async def test_between_and(self): [Decimal("1.2345")], ) + async def test_in(self): + self.assertEqual( + await DecimalFields.filter( + decimal__in=[Decimal("1.2345"), Decimal("1000")] + ).values_list("decimal", flat=True), + [Decimal("1.2345")], + ) + class TestCharFkFieldFilters(test.TestCase): async def asyncSetUp(self): diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 1e0f8fb5d..a286efedf 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -615,7 +615,7 @@ async def test_force_index_available_in_more_query(self): sql_CountQuery = IntFields.filter(pk=1).force_index("index_name").count().sql() self.assertEqual( sql_CountQuery, - "SELECT COUNT('*') FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", + "SELECT COUNT(*) FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql() @@ -655,7 +655,7 @@ async def test_use_index_available_in_more_query(self): sql_CountQuery = IntFields.filter(pk=1).use_index("index_name").count().sql() self.assertEqual( sql_CountQuery, - "SELECT COUNT('*') FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", + "SELECT COUNT(*) FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql() diff --git a/tests/test_sql.py 
b/tests/test_sql.py index 76e7fb11f..26f9d9a04 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -53,6 +53,46 @@ def test_annotate(self): expected = 'SELECT "id",CONCAT("id",?) "id_plus_one" FROM "charpkmodel"' self.assertEqual(sql, expected) + def test_values(self): + sql = IntFields.filter(intnum=1).values("intnum").sql() + if self.dialect == "mysql": + expected = "SELECT `intnum` `intnum` FROM `intfields` WHERE `intnum`=%s" + elif self.dialect == "postgres": + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=?' + self.assertEqual(sql, expected) + + def test_values_list(self): + sql = IntFields.filter(intnum=1).values_list("intnum").sql() + if self.dialect == "mysql": + expected = "SELECT `intnum` `0` FROM `intfields` WHERE `intnum`=%s" + elif self.dialect == "postgres": + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=?' + self.assertEqual(sql, expected) + + def test_exists(self): + sql = IntFields.filter(intnum=1).exists().sql() + if self.dialect == "mysql": + expected = "SELECT %s FROM `intfields` WHERE `intnum`=%s LIMIT %s" + elif self.dialect == "postgres": + expected = 'SELECT $1 FROM "intfields" WHERE "intnum"=$2 LIMIT $3' + else: + expected = 'SELECT ? FROM "intfields" WHERE "intnum"=? LIMIT ?' + self.assertEqual(sql, expected) + + def test_count(self): + sql = IntFields.all().filter(intnum=1).count().sql() + if self.dialect == "mysql": + expected = "SELECT COUNT(*) FROM `intfields` WHERE `intnum`=%s" + elif self.dialect == "postgres": + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=?' 
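# Illustrative sketch, not from the patch itself: what these parameterized read
# queries look like from the caller's side, inside the test suite's initialised
# connection. .sql() now renders dialect placeholders (SQLite shown) instead of
# inlined literals, and awaiting the queryset passes the bound values separately,
# roughly execute_query_dict(sql, [1]) for this filter.
from tests.testmodels import IntFields

qs = IntFields.filter(intnum=1).values("intnum")
assert qs.sql() == 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=?'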
+ self.assertEqual(sql, expected) + @test.skip("Update queries are not parameterized yet") def test_update(self): sql = IntFields.filter(intnum=2).update(intnum=1).sql() diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index dc44c6e3a..1c2a5845a 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -170,14 +170,6 @@ def _prepare_insert_columns( result_columns = [self.model._meta.fields_db_projection[c] for c in regular_columns] return regular_columns, result_columns - @classmethod - def _field_to_db( - cls, field_object: Field, attr: Any, instance: "Union[Type[Model], Model]" - ) -> Any: - if field_object.__class__ in cls.TO_DB_OVERRIDE: - return cls.TO_DB_OVERRIDE[field_object.__class__](field_object, attr, instance) - return field_object.to_db_value(attr, instance) - def _prepare_insert_statement( self, columns: Sequence[str], has_generated: bool = True, ignore_conflicts: bool = False ) -> QueryBuilder: @@ -330,10 +322,8 @@ async def _prefetch_reverse_relation( if relation_field not in related_objects_for_fetch: related_objects_for_fetch[relation_field] = [] related_objects_for_fetch[relation_field].append( - self._field_to_db( - instance._meta.fields_map[related_field_name], - getattr(instance, related_field_name), - instance, + instance._meta.fields_map[related_field_name].to_db_value( + getattr(instance, related_field_name), instance ) ) @@ -375,10 +365,8 @@ async def _prefetch_reverse_o2o_relation( if relation_field not in related_objects_for_fetch: related_objects_for_fetch[relation_field] = [] related_objects_for_fetch[relation_field].append( - self._field_to_db( - instance._meta.fields_map[related_field_name], - getattr(instance, related_field_name), - instance, + instance._meta.fields_map[related_field_name].to_db_value( + getattr(instance, related_field_name), instance ) ) @@ -410,8 +398,7 @@ async def _prefetch_m2m_relation( ) -> "Iterable[Model]": to_attr, related_query = related_query instance_id_set: set = { - self._field_to_db(instance._meta.pk, instance.pk, instance) - for instance in instance_list + instance._meta.pk.to_db_value(instance.pk, instance) for instance in instance_list } field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore diff --git a/tortoise/expressions.py b/tortoise/expressions.py index d01a92a2c..5e352e235 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -384,7 +384,7 @@ def _process_filter_kwarg( encoded_value = ( param["value_encoder"](value, model, field_object) if param.get("value_encoder") - else model._meta.db.executor_class._field_to_db(field_object, value, model) + else field_object.to_db_value(value, model) ) op = param["operator"] criterion = op(table[param["source_field"]], encoded_value) diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 966e0a266..3bfcb2408 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -258,6 +258,12 @@ def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Any: """ if value is not None and not isinstance(value, self.field_type): value = self.field_type(value) # pylint: disable=E1102 + + if self.__class__ in self.model._meta.db.executor_class.TO_DB_OVERRIDE: + value = self.model._meta.db.executor_class.TO_DB_OVERRIDE[self.__class__]( + self, value, instance + ) + self.validate(value) return value diff --git a/tortoise/filters.py b/tortoise/filters.py index d5d9ea719..f21e674bb 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py 
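# Illustrative sketch, not from the patch itself: the value-conversion path after
# dropping BaseExecutor._field_to_db. Field.to_db_value() now applies the executor's
# TO_DB_OVERRIDE on its own (for example MSSQL's BooleanField override), so filters
# and prefetching only ever go through the field. IntFields is the test model used above.
from tests.testmodels import IntFields

field_object = IntFields._meta.fields_map["intnum"]
db_value = field_object.to_db_value(3, IntFields)  # validated, backend-ready value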
@@ -89,14 +89,14 @@ def is_in(field: Term, value: Any) -> Criterion: if value: return field.isin(value) # SQL has no False, so we return 1=0 - return BasicCriterion(Equality.eq, ValueWrapper(1), ValueWrapper(0)) + return BasicCriterion(Equality.eq, ValueWrapper("1"), ValueWrapper("0")) def not_in(field: Term, value: Any) -> Criterion: if value: return field.notin(value) | field.isnull() # SQL has no True, so we return 1=1 - return BasicCriterion(Equality.eq, ValueWrapper(1), ValueWrapper(1)) + return BasicCriterion(Equality.eq, ValueWrapper("1"), ValueWrapper("1")) def between_and(field: Term, value: Tuple[Any, Any]) -> Criterion: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 24baeb484..e608ace64 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -24,7 +24,7 @@ from pypika.analytics import Count from pypika.functions import Cast from pypika.queries import QueryBuilder -from pypika.terms import Case, Field, Term, ValueWrapper, Parameterizer +from pypika.terms import Case, Field, Parameterizer, Star, Term, ValueWrapper from typing_extensions import Literal, Protocol from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities @@ -1123,7 +1123,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - parameterizer = Parameterizer() + parameterizer = pypika_kwargs.pop("parameterizer", Parameterizer()) return ( self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), parameterizer.values, @@ -1335,7 +1335,8 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self.query.get_sql(), [] + parameterizer = pypika_kwargs.pop("parameterizer", Parameterizer()) + return self.query.get_sql(parameterizer=parameterizer), parameterizer.values def __await__(self) -> Generator[Any, None, bool]: if self._db is None: @@ -1344,7 +1345,7 @@ def __await__(self) -> Generator[Any, None, bool]: return self._execute(sql, values).__await__() async def _execute(self, sql: str, values: List[Any]) -> bool: - result, _ = await self._db.execute_query(str(self.query)) + result, _ = await self._db.execute_query(sql, values) return bool(result) @@ -1381,7 +1382,7 @@ def __init__( def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query = copy(self.model._meta.basequery) self.resolve_filters() - count_term = Count("*") + count_term = Count(Star()) if self.query._groupbys: count_term = count_term.over() @@ -1395,7 +1396,11 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: if self._use_indexes: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self.query.get_sql(**pypika_kwargs), [] + parameterizer = pypika_kwargs.pop("parameterizer", Parameterizer()) + return ( + self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), + parameterizer.values, + ) def __await__(self) -> Generator[Any, None, int]: if self._db is None: @@ -1404,7 +1409,7 @@ def __await__(self) -> Generator[Any, None, int]: return self._execute(sql, values).__await__() async def _execute(self, sql: str, values: List[Any]) -> int: - _, result = await self._db.execute_query(str(self.query)) + _, result = await self._db.execute_query(sql, values) if not result: return 0 count = list(dict(result[0]).values())[0] - self._offset @@ -1611,7 +1616,11 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: 
self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self.query.get_sql(**pypika_kwargs), [] + parameterizer = Parameterizer() + return ( + self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), + parameterizer.values, + ) @overload def __await__( @@ -1741,7 +1750,11 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self.query.get_sql(**pypika_kwargs), [] + parameterizer = pypika_kwargs.pop("parameterizer", Parameterizer()) + return ( + self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), + parameterizer.values, + ) @overload def __await__( @@ -1907,7 +1920,7 @@ def __await__(self) -> Generator[Any, Any, int]: def sql(self) -> str: self._choose_db_if_not_chosen() queries = self._make_queries() - return ";".join([str(sql) for sql, _ in queries]) + return ";".join([sql for sql, _ in queries]) class BulkCreateQuery(AwaitableQuery, Generic[MODEL]): From 3bde481f8100acc0a17b484d83be52f0e232cd7d Mon Sep 17 00:00:00 2001 From: henadzit Date: Sun, 17 Nov 2024 18:01:05 +0100 Subject: [PATCH 5/8] Configure Parameterizer for psycopg --- tests/test_case_when.py | 67 +++++++++++++++++++++------ tests/test_queryset.py | 51 +++++++++++++------- tests/test_sql.py | 42 +++++++++++++---- tortoise/backends/base/executor.py | 5 ++ tortoise/backends/psycopg/executor.py | 9 ++++ tortoise/expressions.py | 2 + tortoise/queryset.py | 12 ++--- 7 files changed, 145 insertions(+), 43 deletions(-) diff --git a/tests/test_case_when.py b/tests/test_case_when.py index 197a1ab29..386d2e7a5 100644 --- a/tests/test_case_when.py +++ b/tests/test_case_when.py @@ -1,5 +1,6 @@ from tests.testmodels import IntFields from tortoise import connections +from tortoise.backends.psycopg.client import PsycopgClient from tortoise.contrib import test from tortoise.exceptions import FieldError from tortoise.expressions import Case, F, Q, When @@ -12,6 +13,7 @@ async def asyncSetUp(self): self.intfields = [await IntFields.create(intnum=val) for val in range(10)] self.db = connections.get("models") self.dialect = self.db.schema_generator.DIALECT + self.is_psycopg = isinstance(self.db, PsycopgClient) async def test_single_when(self): category = Case(When(intnum__gte=8, then="big"), default="default") @@ -20,7 +22,10 @@ async def test_single_when(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE %s END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE $3 END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s ELSE %s END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE $3 END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE ? 
END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -34,7 +39,10 @@ async def test_multi_when(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s WHEN `intnum`<=%s THEN %s ELSE %s END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 WHEN "intnum"<=$3 THEN $4 ELSE $5 END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s WHEN "intnum"<=%s THEN %s ELSE %s END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 WHEN "intnum"<=$3 THEN $4 ELSE $5 END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? WHEN "intnum"<=? THEN ? ELSE ? END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -46,7 +54,10 @@ async def test_q_object_when(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>%s AND `intnum`<%s THEN %s ELSE %s END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">$1 AND "intnum"<$2 THEN $3 ELSE $4 END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">%s AND "intnum"<%s THEN %s ELSE %s END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">$1 AND "intnum"<$2 THEN $3 ELSE $4 END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">? AND "intnum"=%s THEN `intnum_null` ELSE %s END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN "intnum_null" ELSE $2 END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN "intnum_null" ELSE %s END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN "intnum_null" ELSE $2 END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN "intnum_null" ELSE ? END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -71,7 +85,10 @@ async def test_AE_then(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN `intnum`+%s ELSE %s END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN "intnum"+$2 ELSE $3 END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN "intnum"+%s ELSE %s END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN "intnum"+$2 ELSE $3 END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN "intnum"+? ELSE ? 
END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -83,7 +100,10 @@ async def test_func_then(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN COALESCE(`intnum_null`,%s) ELSE %s END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN COALESCE("intnum_null",$2) ELSE $3 END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN COALESCE("intnum_null",%s) ELSE %s END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN COALESCE("intnum_null",$2) ELSE $3 END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN COALESCE("intnum_null",?) ELSE ? END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -95,7 +115,10 @@ async def test_F_default(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE `intnum_null` END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE "intnum_null" END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s ELSE "intnum_null" END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE "intnum_null" END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE "intnum_null" END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -108,7 +131,10 @@ async def test_AE_default(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE `intnum`+%s END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE "intnum"+$3 END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s ELSE "intnum"+%s END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE "intnum"+$3 END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE "intnum"+? END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -120,7 +146,10 @@ async def test_func_default(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE COALESCE(`intnum_null`,%s) END `category` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE COALESCE("intnum_null",$3) END "category" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s ELSE COALESCE("intnum_null",%s) END "category" FROM "intfields"' + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE COALESCE("intnum_null",$3) END "category" FROM "intfields"' else: expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE COALESCE("intnum_null",?) 
END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -139,7 +168,10 @@ async def test_case_when_in_where(self): if self.dialect == "mysql": expected_sql = "SELECT `intnum` `intnum` FROM `intfields` WHERE CASE WHEN `intnum`>=%s THEN %s WHEN `intnum`<=%s THEN %s ELSE %s END IN (%s,%s)" elif self.dialect == "postgres": - expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=$1 THEN $2 WHEN "intnum"<=$3 THEN $4 ELSE $5 END IN ($6,$7)' + if self.is_psycopg: + expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=%s THEN %s WHEN "intnum"<=%s THEN %s ELSE %s END IN (%s,%s)' + else: + expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=$1 THEN $2 WHEN "intnum"<=$3 THEN $4 ELSE $5 END IN ($6,$7)' else: expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=? THEN ? WHEN "intnum"<=? THEN ? ELSE ? END IN (?,?)' self.assertEqual(sql, expected_sql) @@ -156,7 +188,10 @@ async def test_annotation_in_when_annotation(self): if self.dialect == "mysql": expected_sql = "SELECT `id` `id`,`intnum` `intnum`,`intnum`+%s `intnum_plus_1`,CASE WHEN `intnum`+%s>=%s THEN %s ELSE %s END `bigger_than_10` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+$1 "intnum_plus_1",CASE WHEN "intnum"+$2>=$3 THEN $4 ELSE $5 END "bigger_than_10" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+%s "intnum_plus_1",CASE WHEN "intnum"+%s>=%s THEN %s ELSE %s END "bigger_than_10" FROM "intfields"' + else: + expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+$1 "intnum_plus_1",CASE WHEN "intnum"+$2>=$3 THEN $4 ELSE $5 END "bigger_than_10" FROM "intfields"' else: expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+? "intnum_plus_1",CASE WHEN "intnum"+?>=? THEN ? ELSE ? END "bigger_than_10" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -173,7 +208,10 @@ async def test_func_annotation_in_when_annotation(self): if self.dialect == "mysql": expected_sql = "SELECT `id` `id`,COALESCE(`intnum`,%s) `intnum_col`,CASE WHEN COALESCE(`intnum`,%s)=%s THEN %s ELSE %s END `is_zero` FROM `intfields`" elif self.dialect == "postgres": - expected_sql = 'SELECT "id" "id",COALESCE("intnum",$1) "intnum_col",CASE WHEN COALESCE("intnum",$2)=$3 THEN $4 ELSE $5 END "is_zero" FROM "intfields"' + if self.is_psycopg: + expected_sql = 'SELECT "id" "id",COALESCE("intnum",%s) "intnum_col",CASE WHEN COALESCE("intnum",%s)=%s THEN %s ELSE %s END "is_zero" FROM "intfields"' + else: + expected_sql = 'SELECT "id" "id",COALESCE("intnum",$1) "intnum_col",CASE WHEN COALESCE("intnum",$2)=$3 THEN $4 ELSE $5 END "is_zero" FROM "intfields"' else: expected_sql = 'SELECT "id" "id",COALESCE("intnum",?) "intnum_col",CASE WHEN COALESCE("intnum",?)=? THEN ? ELSE ? 
END "is_zero" FROM "intfields"' self.assertEqual(sql, expected_sql) @@ -191,7 +229,10 @@ async def test_case_when_in_group_by(self): if self.dialect == "mysql": expected_sql = "SELECT CASE WHEN `intnum`=%s THEN %s ELSE %s END `is_zero`,COUNT(`id`) `count` FROM `intfields` GROUP BY `is_zero`" elif self.dialect == "postgres": - expected_sql = 'SELECT CASE WHEN "intnum"=$1 THEN $2 ELSE $3 END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' + if self.is_psycopg: + expected_sql = 'SELECT CASE WHEN "intnum"=%s THEN %s ELSE %s END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' + else: + expected_sql = 'SELECT CASE WHEN "intnum"=$1 THEN $2 ELSE $3 END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' elif self.dialect == "mssql": expected_sql = 'SELECT CASE WHEN "intnum"=? THEN ? ELSE ? END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY CASE WHEN "intnum"=? THEN ? ELSE ? END' else: diff --git a/tests/test_queryset.py b/tests/test_queryset.py index a286efedf..9d1234fe4 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -12,6 +12,7 @@ Tree, ) from tortoise import connections +from tortoise.backends.psycopg.client import PsycopgClient from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ from tortoise.exceptions import ( @@ -673,22 +674,40 @@ async def test_select_for_update(self): dialect = self.db.schema_generator.DIALECT if dialect == "postgres": - self.assertEqual( - sql1, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE', - ) - self.assertEqual( - sql2, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE NOWAIT', - ) - self.assertEqual( - sql3, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE SKIP LOCKED', - ) - self.assertEqual( - sql4, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE OF "intfields"', - ) + if isinstance(self.db, PsycopgClient): + self.assertEqual( + sql1, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE', + ) + self.assertEqual( + sql2, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE NOWAIT', + ) + self.assertEqual( + sql3, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE SKIP LOCKED', + ) + self.assertEqual( + sql4, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE OF "intfields"', + ) + else: + self.assertEqual( + sql1, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE', + ) + self.assertEqual( + sql2, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE NOWAIT', + ) + self.assertEqual( + sql3, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE SKIP LOCKED', + ) + self.assertEqual( + sql4, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE OF "intfields"', + ) elif dialect == "mysql": self.assertEqual( sql1, diff --git a/tests/test_sql.py b/tests/test_sql.py index 26f9d9a04..ebabee16e 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,5 +1,6 @@ from tests.testmodels import CharPkModel, IntFields from tortoise import connections +from tortoise.backends.psycopg.client import PsycopgClient from tortoise.contrib import test from tortoise.expressions import F from tortoise.functions import Concat @@ -10,13 +11,17 @@ async def asyncSetUp(self): await super().asyncSetUp() self.db = connections.get("models") self.dialect = self.db.schema_generator.DIALECT + self.is_psycopg = isinstance(self.db, PsycopgClient) def test_filter(self): sql = CharPkModel.all().filter(id="123").sql() if self.dialect == "mysql": expected = 
"SELECT `id` FROM `charpkmodel` WHERE `id`=%s" elif self.dialect == "postgres": - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1' + if self.is_psycopg: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=%s' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1' else: expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=?' @@ -27,7 +32,10 @@ def test_filter_with_limit_offset(self): if self.dialect == "mysql": expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s LIMIT %s OFFSET %s" elif self.dialect == "postgres": - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1 LIMIT $2 OFFSET $3' + if self.is_psycopg: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=%s LIMIT %s OFFSET %s' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1 LIMIT $2 OFFSET $3' elif self.dialect == "mssql": expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY' else: @@ -48,7 +56,10 @@ def test_annotate(self): if self.dialect == "mysql": expected = "SELECT `id`,CONCAT(`id`,%s) `id_plus_one` FROM `charpkmodel`" elif self.dialect == "postgres": - expected = 'SELECT "id",CONCAT("id",$1) "id_plus_one" FROM "charpkmodel"' + if self.is_psycopg: + expected = 'SELECT "id",CONCAT("id",%s) "id_plus_one" FROM "charpkmodel"' + else: + expected = 'SELECT "id",CONCAT("id",$1) "id_plus_one" FROM "charpkmodel"' else: expected = 'SELECT "id",CONCAT("id",?) "id_plus_one" FROM "charpkmodel"' self.assertEqual(sql, expected) @@ -58,7 +69,10 @@ def test_values(self): if self.dialect == "mysql": expected = "SELECT `intnum` `intnum` FROM `intfields` WHERE `intnum`=%s" elif self.dialect == "postgres": - expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=$1' + if self.is_psycopg: + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=%s' + else: + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=$1' else: expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=?' self.assertEqual(sql, expected) @@ -68,7 +82,10 @@ def test_values_list(self): if self.dialect == "mysql": expected = "SELECT `intnum` `0` FROM `intfields` WHERE `intnum`=%s" elif self.dialect == "postgres": - expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=$1' + if self.is_psycopg: + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=%s' + else: + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=$1' else: expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=?' self.assertEqual(sql, expected) @@ -78,7 +95,10 @@ def test_exists(self): if self.dialect == "mysql": expected = "SELECT %s FROM `intfields` WHERE `intnum`=%s LIMIT %s" elif self.dialect == "postgres": - expected = 'SELECT $1 FROM "intfields" WHERE "intnum"=$2 LIMIT $3' + if self.is_psycopg: + expected = 'SELECT %s FROM "intfields" WHERE "intnum"=%s LIMIT %s' + else: + expected = 'SELECT $1 FROM "intfields" WHERE "intnum"=$2 LIMIT $3' else: expected = 'SELECT ? FROM "intfields" WHERE "intnum"=? LIMIT ?' self.assertEqual(sql, expected) @@ -88,7 +108,10 @@ def test_count(self): if self.dialect == "mysql": expected = "SELECT COUNT(*) FROM `intfields` WHERE `intnum`=%s" elif self.dialect == "postgres": - expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=$1' + if self.is_psycopg: + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=%s' + else: + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=$1' else: expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=?' 
self.assertEqual(sql, expected) @@ -99,7 +122,10 @@ def test_update(self): if self.dialect == "mysql": expected = "UPDATE `intfields` SET `intnum`=%s WHERE `intnum`=%s" elif self.dialect == "postgres": - expected = 'UPDATE "intfields" SET "intnum"=$1 WHERE "intnum"=$2' + if self.is_psycopg: + expected = 'UPDATE "intfields" SET "intnum"=%s WHERE "intnum"=%s' + else: + expected = 'UPDATE "intfields" SET "intnum"=$1 WHERE "intnum"=$2' else: expected = 'UPDATE "intfields" SET "intnum"=? WHERE "intnum"=?' self.assertEqual(sql, expected) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 1c2a5845a..9de684cc9 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -20,6 +20,7 @@ ) from pypika import JoinType, Parameter, Table +from pypika.terms import Parameterizer from pypika.queries import QueryBuilder from tortoise.exceptions import OperationalError @@ -191,6 +192,10 @@ async def _process_insert_result(self, instance: "Model", results: Any) -> None: def parameter(self, pos: int) -> Parameter: return Parameter(idx=pos + 1) + @classmethod + def parameterizer(cls) -> Parameterizer: + return Parameterizer() + async def execute_insert(self, instance: "Model") -> None: if not instance._custom_generated_pk: values = [ diff --git a/tortoise/backends/psycopg/executor.py b/tortoise/backends/psycopg/executor.py index 435e2fca7..5ea001f6d 100644 --- a/tortoise/backends/psycopg/executor.py +++ b/tortoise/backends/psycopg/executor.py @@ -2,6 +2,8 @@ from typing import Optional +from pypika import Parameter, Parameterizer + from tortoise import Model from tortoise.backends.base_postgres.executor import BasePostgresExecutor @@ -21,3 +23,10 @@ async def _process_insert_result( for key, val in zip(generated_fields, results): setattr(instance, db_projection[key], val) + + def parameter(self, pos: int) -> Parameter: + return Parameter("%s") + + @classmethod + def parameterizer(cls) -> Parameterizer: + return Parameterizer(placeholder_factory=lambda _: "%s") diff --git a/tortoise/expressions.py b/tortoise/expressions.py index 5e352e235..51f427356 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -215,9 +215,11 @@ def __init__(self, query: "AwaitableQuery") -> None: self.query = query def get_sql(self, **kwargs: Any) -> str: + self.query._choose_db_if_not_chosen() return self.query._make_query(**kwargs)[0] def as_(self, alias: str) -> "Selectable": + self.query._choose_db_if_not_chosen() self.query._make_query() return self.query.query.as_(alias) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index e608ace64..092ca050e 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -24,7 +24,7 @@ from pypika.analytics import Count from pypika.functions import Cast from pypika.queries import QueryBuilder -from pypika.terms import Case, Field, Parameterizer, Star, Term, ValueWrapper +from pypika.terms import Case, Field, Star, Term, ValueWrapper from typing_extensions import Literal, Protocol from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities @@ -1123,7 +1123,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - parameterizer = pypika_kwargs.pop("parameterizer", Parameterizer()) + parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) return ( self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), parameterizer.values, @@ -1335,7 +1335,7 @@ def 
_make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - parameterizer = pypika_kwargs.pop("parameterizer", Parameterizer()) + parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) return self.query.get_sql(parameterizer=parameterizer), parameterizer.values def __await__(self) -> Generator[Any, None, bool]: @@ -1396,7 +1396,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: if self._use_indexes: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - parameterizer = pypika_kwargs.pop("parameterizer", Parameterizer()) + parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) return ( self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), parameterizer.values, @@ -1616,7 +1616,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - parameterizer = Parameterizer() + parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) return ( self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), parameterizer.values, @@ -1750,7 +1750,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - parameterizer = pypika_kwargs.pop("parameterizer", Parameterizer()) + parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) return ( self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), parameterizer.values, From 46d47ac7fd3f7e71eb179654f17089ff56100451 Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 18 Nov 2024 13:34:56 +0100 Subject: [PATCH 6/8] Fix Postgres issues --- poetry.lock | 2 +- tests/test_queryset.py | 4 ++-- tests/test_sql.py | 18 ++++++++++++------ tortoise/filters.py | 12 ++++++++++-- tortoise/functions.py | 14 +++++++++++++- tortoise/queryset.py | 2 +- 6 files changed, 39 insertions(+), 13 deletions(-) diff --git a/poetry.lock b/poetry.lock index b920cd0b2..21e3ee8b8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2666,7 +2666,7 @@ develop = false type = "git" url = "https://github.com/henadzit/pypika-tortoise.git" reference = "parameterization-changes" -resolved_reference = "36a49ec465eedc7c17c6b3621360c3d700614146" +resolved_reference = "e4db01e6ab73045c0cf45ed71eeadba30395905e" [[package]] name = "pytest" diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 9d1234fe4..bc04d98a6 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -622,7 +622,7 @@ async def test_force_index_available_in_more_query(self): sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql() self.assertEqual( sql_ExistsQuery, - "SELECT %s FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", + "SELECT 1 FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", ) @test.requireCapability(support_index_hint=True) @@ -662,7 +662,7 @@ async def test_use_index_available_in_more_query(self): sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql() self.assertEqual( sql_ExistsQuery, - "SELECT %s FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", + "SELECT 1 FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", ) @test.requireCapability(support_for_update=True) diff --git a/tests/test_sql.py 
b/tests/test_sql.py index ebabee16e..9aacd9241 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -57,9 +57,13 @@ def test_annotate(self): expected = "SELECT `id`,CONCAT(`id`,%s) `id_plus_one` FROM `charpkmodel`" elif self.dialect == "postgres": if self.is_psycopg: - expected = 'SELECT "id",CONCAT("id",%s) "id_plus_one" FROM "charpkmodel"' + expected = ( + 'SELECT "id",CONCAT("id"::text,%s::text) "id_plus_one" FROM "charpkmodel"' + ) else: - expected = 'SELECT "id",CONCAT("id",$1) "id_plus_one" FROM "charpkmodel"' + expected = ( + 'SELECT "id",CONCAT("id"::text,$1::text) "id_plus_one" FROM "charpkmodel"' + ) else: expected = 'SELECT "id",CONCAT("id",?) "id_plus_one" FROM "charpkmodel"' self.assertEqual(sql, expected) @@ -93,14 +97,16 @@ def test_values_list(self): def test_exists(self): sql = IntFields.filter(intnum=1).exists().sql() if self.dialect == "mysql": - expected = "SELECT %s FROM `intfields` WHERE `intnum`=%s LIMIT %s" + expected = "SELECT 1 FROM `intfields` WHERE `intnum`=%s LIMIT %s" elif self.dialect == "postgres": if self.is_psycopg: - expected = 'SELECT %s FROM "intfields" WHERE "intnum"=%s LIMIT %s' + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=%s LIMIT %s' else: - expected = 'SELECT $1 FROM "intfields" WHERE "intnum"=$2 LIMIT $3' + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=$1 LIMIT $2' + elif self.dialect == "mssql": + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=? ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT ? ROWS ONLY' else: - expected = 'SELECT ? FROM "intfields" WHERE "intnum"=? LIMIT ?' + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=? LIMIT ?' self.assertEqual(sql, expected) def test_count(self): diff --git a/tortoise/filters.py b/tortoise/filters.py index f21e674bb..ec710aee2 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -89,14 +89,22 @@ def is_in(field: Term, value: Any) -> Criterion: if value: return field.isin(value) # SQL has no False, so we return 1=0 - return BasicCriterion(Equality.eq, ValueWrapper("1"), ValueWrapper("0")) + return BasicCriterion( + Equality.eq, + ValueWrapper(1, allow_parametrize=False), + ValueWrapper(0, allow_parametrize=False), + ) def not_in(field: Term, value: Any) -> Criterion: if value: return field.notin(value) | field.isnull() # SQL has no True, so we return 1=1 - return BasicCriterion(Equality.eq, ValueWrapper("1"), ValueWrapper("1")) + return BasicCriterion( + Equality.eq, + ValueWrapper(1, allow_parametrize=False), + ValueWrapper(1, allow_parametrize=False), + ) def between_and(field: Term, value: Tuple[Any, Any]) -> Criterion: diff --git a/tortoise/functions.py b/tortoise/functions.py index 8cea15661..cf69790de 100644 --- a/tortoise/functions.py +++ b/tortoise/functions.py @@ -57,6 +57,18 @@ class Upper(Function): database_func = functions.Upper +class _Concat(functions.Concat): # type: ignore + @staticmethod + def get_arg_sql(arg, **kwargs): + sql = arg.get_sql(with_alias=False, **kwargs) if hasattr(arg, "get_sql") else str(arg) + # explicitly convert to text for postgres to avoid errors like + # "could not determine data type of parameter $1" + dialect = kwargs.get("dialect", None) + if dialect and dialect.value == "postgresql": + return f"{sql}::text" + return sql + + class Concat(Function): """ Concate field or constant text. 
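# Illustrative sketch, not from the patch itself: what the ::text cast added in
# _Concat.get_arg_sql changes on Postgres. A parameterized CONCAT("id",$1) can fail
# with "could not determine data type of parameter $1"; casting the arguments keeps
# them typed. The annotate() arguments below are hypothetical, but the rendered SQL
# should match the updated test_annotate expectation for asyncpg.
from tests.testmodels import CharPkModel
from tortoise.functions import Concat

qs = CharPkModel.all().annotate(id_plus_one=Concat("id", "+1"))
# qs.sql() -> 'SELECT "id",CONCAT("id"::text,$1::text) "id_plus_one" FROM "charpkmodel"'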
@@ -65,7 +77,7 @@ class Concat(Function): :samp:`Concat("{FIELD_NAME}", {ANOTHER_FIELD_NAMES or CONSTANT_TEXT}, *args)` """ - database_func = functions.Concat + database_func = _Concat ############################################################################## diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 092ca050e..2d9c8175d 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1326,7 +1326,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query = copy(self.model._meta.basequery) self.resolve_filters() self.query._limit = self.query._wrapper_cls(1) - self.query._select_other(ValueWrapper(1)) + self.query._select_other(ValueWrapper(1, allow_parametrize=False)) if self._force_indexes: self.query._force_indexes = [] From 0d2fe634f0c684a1852e331cc585698f45749ad4 Mon Sep 17 00:00:00 2001 From: henadzit Date: Tue, 19 Nov 2024 23:51:53 +0100 Subject: [PATCH 7/8] Add params_inline arg to QuerySet.sql() --- poetry.lock | 4 +- tests/test_case_when.py | 254 +++++++++++++--------------- tortoise/backends/base/executor.py | 2 +- tortoise/backends/mssql/executor.py | 4 +- tortoise/expressions.py | 2 +- tortoise/functions.py | 2 +- tortoise/indexes.py | 4 +- tortoise/queryset.py | 78 ++++----- 8 files changed, 165 insertions(+), 185 deletions(-) diff --git a/poetry.lock b/poetry.lock index 21e3ee8b8..753d3a20f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2658,7 +2658,7 @@ name = "pypika-tortoise" version = "0.2.2" description = "Forked from pypika and streamline just for tortoise-orm" optional = false -python-versions = "^3.7" +python-versions = "^3.8" files = [] develop = false @@ -2666,7 +2666,7 @@ develop = false type = "git" url = "https://github.com/henadzit/pypika-tortoise.git" reference = "parameterization-changes" -resolved_reference = "e4db01e6ab73045c0cf45ed71eeadba30395905e" +resolved_reference = "31eea5a7d1299d33ce1776e97aab44879c54de35" [[package]] name = "pytest" diff --git a/tests/test_case_when.py b/tests/test_case_when.py index 386d2e7a5..f696b9930 100644 --- a/tests/test_case_when.py +++ b/tests/test_case_when.py @@ -1,6 +1,5 @@ from tests.testmodels import IntFields from tortoise import connections -from tortoise.backends.psycopg.client import PsycopgClient from tortoise.contrib import test from tortoise.exceptions import FieldError from tortoise.expressions import Case, F, Q, When @@ -12,146 +11,153 @@ async def asyncSetUp(self): await super().asyncSetUp() self.intfields = [await IntFields.create(intnum=val) for val in range(10)] self.db = connections.get("models") - self.dialect = self.db.schema_generator.DIALECT - self.is_psycopg = isinstance(self.db, PsycopgClient) async def test_single_when(self): category = Case(When(intnum__gte=8, then="big"), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE %s END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s ELSE %s END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE $3 END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` 
`intnum`,CASE WHEN `intnum`>=8 THEN 'big' ELSE 'default' END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE ? END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' ELSE \'default\' END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_multi_when(self): category = Case( When(intnum__gte=8, then="big"), When(intnum__lte=2, then="small"), default="default" ) - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s WHEN `intnum`<=%s THEN %s ELSE %s END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s WHEN "intnum"<=%s THEN %s ELSE %s END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 WHEN "intnum"<=$3 THEN $4 ELSE $5 END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' WHEN `intnum`<=2 THEN 'small' ELSE 'default' END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? WHEN "intnum"<=? THEN ? ELSE ? END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' WHEN "intnum"<=2 THEN \'small\' ELSE \'default\' END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_q_object_when(self): category = Case(When(Q(intnum__gt=2, intnum__lt=8), then="middle"), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>%s AND `intnum`<%s THEN %s ELSE %s END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">%s AND "intnum"<%s THEN %s ELSE %s END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">$1 AND "intnum"<$2 THEN $3 ELSE $4 END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>2 AND `intnum`<8 THEN 'middle' ELSE 'default' END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">? 
AND "intnum"2 AND "intnum"<8 THEN \'middle\' ELSE \'default\' END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_F_then(self): category = Case(When(intnum__gte=8, then=F("intnum_null")), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN `intnum_null` ELSE %s END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN "intnum_null" ELSE %s END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN "intnum_null" ELSE $2 END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN `intnum_null` ELSE 'default' END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN "intnum_null" ELSE ? END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN "intnum_null" ELSE \'default\' END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_AE_then(self): # AE: ArithmeticExpression category = Case(When(intnum__gte=8, then=F("intnum") + 1), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN `intnum`+%s ELSE %s END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN "intnum"+%s ELSE %s END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN "intnum"+$2 ELSE $3 END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN `intnum`+1 ELSE 'default' END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN "intnum"+? ELSE ? 
END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN "intnum"+1 ELSE \'default\' END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_func_then(self): category = Case(When(intnum__gte=8, then=Coalesce("intnum_null", 10)), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN COALESCE(`intnum_null`,%s) ELSE %s END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN COALESCE("intnum_null",%s) ELSE %s END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN COALESCE("intnum_null",$2) ELSE $3 END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN COALESCE(`intnum_null`,10) ELSE 'default' END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN COALESCE("intnum_null",?) ELSE ? END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN COALESCE("intnum_null",10) ELSE \'default\' END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_F_default(self): category = Case(When(intnum__gte=8, then="big"), default=F("intnum_null")) - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE `intnum_null` END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s ELSE "intnum_null" END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE "intnum_null" END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' ELSE `intnum_null` END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? 
ELSE "intnum_null" END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' ELSE "intnum_null" END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_AE_default(self): # AE: ArithmeticExpression category = Case(When(intnum__gte=8, then=8), default=F("intnum") + 1) - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE `intnum`+%s END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s ELSE "intnum"+%s END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE "intnum"+$3 END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 8 ELSE `intnum`+1 END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE "intnum"+? END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN 8 ELSE "intnum"+1 END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_func_default(self): category = Case(When(intnum__gte=8, then=8), default=Coalesce("intnum_null", 10)) - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() - - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=%s THEN %s ELSE COALESCE(`intnum_null`,%s) END `category` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=%s THEN %s ELSE COALESCE("intnum_null",%s) END "category" FROM "intfields"' - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=$1 THEN $2 ELSE COALESCE("intnum_null",$3) END "category" FROM "intfields"' + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 8 ELSE COALESCE(`intnum_null`,10) END `category` FROM `intfields`" else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=? THEN ? ELSE COALESCE("intnum_null",?) 
END "category" FROM "intfields"' + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN 8 ELSE COALESCE("intnum_null",10) END "category" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_case_when_in_where(self): @@ -163,17 +169,13 @@ async def test_case_when_in_where(self): .annotate(category=category) .filter(category__in=["big", "small"]) .values("intnum") - .sql() + .sql(params_inline=True) ) - if self.dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum` FROM `intfields` WHERE CASE WHEN `intnum`>=%s THEN %s WHEN `intnum`<=%s THEN %s ELSE %s END IN (%s,%s)" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=%s THEN %s WHEN "intnum"<=%s THEN %s ELSE %s END IN (%s,%s)' - else: - expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=$1 THEN $2 WHEN "intnum"<=$3 THEN $4 ELSE $5 END IN ($6,$7)' + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum` FROM `intfields` WHERE CASE WHEN `intnum`>=8 THEN 'big' WHEN `intnum`<=2 THEN 'small' ELSE 'middle' END IN ('big','small')" else: - expected_sql = 'SELECT "intnum" "intnum" FROM "intfields" WHERE CASE WHEN "intnum">=? THEN ? WHEN "intnum"<=? THEN ? ELSE ? END IN (?,?)' + expected_sql = "SELECT \"intnum\" \"intnum\" FROM \"intfields\" WHERE CASE WHEN \"intnum\">=8 THEN 'big' WHEN \"intnum\"<=2 THEN 'small' ELSE 'middle' END IN ('big','small')" self.assertEqual(sql, expected_sql) async def test_annotation_in_when_annotation(self): @@ -182,18 +184,14 @@ async def test_annotation_in_when_annotation(self): .annotate(intnum_plus_1=F("intnum") + 1) .annotate(bigger_than_10=Case(When(Q(intnum_plus_1__gte=10), then=True), default=False)) .values("id", "intnum", "intnum_plus_1", "bigger_than_10") - .sql() + .sql(params_inline=True) ) - if self.dialect == "mysql": - expected_sql = "SELECT `id` `id`,`intnum` `intnum`,`intnum`+%s `intnum_plus_1`,CASE WHEN `intnum`+%s>=%s THEN %s ELSE %s END `bigger_than_10` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+%s "intnum_plus_1",CASE WHEN "intnum"+%s>=%s THEN %s ELSE %s END "bigger_than_10" FROM "intfields"' - else: - expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+$1 "intnum_plus_1",CASE WHEN "intnum"+$2>=$3 THEN $4 ELSE $5 END "bigger_than_10" FROM "intfields"' + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `id` `id`,`intnum` `intnum`,`intnum`+1 `intnum_plus_1`,CASE WHEN `intnum`+1>=10 THEN true ELSE false END `bigger_than_10` FROM `intfields`" else: - expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+? "intnum_plus_1",CASE WHEN "intnum"+?>=? THEN ? ELSE ? 
END "bigger_than_10" FROM "intfields"' + expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+1 "intnum_plus_1",CASE WHEN "intnum"+1>=10 THEN true ELSE false END "bigger_than_10" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_func_annotation_in_when_annotation(self): @@ -202,18 +200,14 @@ async def test_func_annotation_in_when_annotation(self): .annotate(intnum_col=Coalesce("intnum", 0)) .annotate(is_zero=Case(When(Q(intnum_col=0), then=True), default=False)) .values("id", "intnum_col", "is_zero") - .sql() + .sql(params_inline=True) ) - if self.dialect == "mysql": - expected_sql = "SELECT `id` `id`,COALESCE(`intnum`,%s) `intnum_col`,CASE WHEN COALESCE(`intnum`,%s)=%s THEN %s ELSE %s END `is_zero` FROM `intfields`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT "id" "id",COALESCE("intnum",%s) "intnum_col",CASE WHEN COALESCE("intnum",%s)=%s THEN %s ELSE %s END "is_zero" FROM "intfields"' - else: - expected_sql = 'SELECT "id" "id",COALESCE("intnum",$1) "intnum_col",CASE WHEN COALESCE("intnum",$2)=$3 THEN $4 ELSE $5 END "is_zero" FROM "intfields"' + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `id` `id`,COALESCE(`intnum`,0) `intnum_col`,CASE WHEN COALESCE(`intnum`,0)=0 THEN true ELSE false END `is_zero` FROM `intfields`" else: - expected_sql = 'SELECT "id" "id",COALESCE("intnum",?) "intnum_col",CASE WHEN COALESCE("intnum",?)=? THEN ? ELSE ? END "is_zero" FROM "intfields"' + expected_sql = 'SELECT "id" "id",COALESCE("intnum",0) "intnum_col",CASE WHEN COALESCE("intnum",0)=0 THEN true ELSE false END "is_zero" FROM "intfields"' self.assertEqual(sql, expected_sql) async def test_case_when_in_group_by(self): @@ -223,24 +217,20 @@ async def test_case_when_in_group_by(self): .annotate(count=Count("id")) .group_by("is_zero") .values("is_zero", "count") - .sql() + .sql(params_inline=True) ) - if self.dialect == "mysql": - expected_sql = "SELECT CASE WHEN `intnum`=%s THEN %s ELSE %s END `is_zero`,COUNT(`id`) `count` FROM `intfields` GROUP BY `is_zero`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected_sql = 'SELECT CASE WHEN "intnum"=%s THEN %s ELSE %s END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' - else: - expected_sql = 'SELECT CASE WHEN "intnum"=$1 THEN $2 ELSE $3 END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' - elif self.dialect == "mssql": - expected_sql = 'SELECT CASE WHEN "intnum"=? THEN ? ELSE ? END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY CASE WHEN "intnum"=? THEN ? ELSE ? END' + dialect = self.db.schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT CASE WHEN `intnum`=0 THEN true ELSE false END `is_zero`,COUNT(`id`) `count` FROM `intfields` GROUP BY `is_zero`" + elif dialect == "mssql": + expected_sql = 'SELECT CASE WHEN "intnum"=0 THEN true ELSE false END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY CASE WHEN "intnum"=0 THEN true ELSE false END' else: - expected_sql = 'SELECT CASE WHEN "intnum"=? THEN ? ELSE ? 
END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' + expected_sql = 'SELECT CASE WHEN "intnum"=0 THEN true ELSE false END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' self.assertEqual(sql, expected_sql) async def test_unknown_field_in_when_annotation(self): with self.assertRaisesRegex(FieldError, "Unknown filter param 'unknown'.+"): IntFields.all().annotate(intnum_col=Coalesce("intnum", 0)).annotate( is_zero=Case(When(Q(unknown=0), then="1"), default="2") - ).sql() + ).sql(params_inline=True) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 9de684cc9..7e2902322 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -20,8 +20,8 @@ ) from pypika import JoinType, Parameter, Table -from pypika.terms import Parameterizer from pypika.queries import QueryBuilder +from pypika.terms import Parameterizer from tortoise.exceptions import OperationalError from tortoise.expressions import Expression, ResolveContext diff --git a/tortoise/backends/mssql/executor.py b/tortoise/backends/mssql/executor.py index 3b18ff9f1..9d41171a6 100644 --- a/tortoise/backends/mssql/executor.py +++ b/tortoise/backends/mssql/executor.py @@ -1,7 +1,5 @@ from typing import Any, Optional, Type, Union -from pypika import Query - from tortoise import Model, fields from tortoise.backends.odbc.executor import ODBCExecutor from tortoise.exceptions import UnSupportedError @@ -22,5 +20,5 @@ class MSSQLExecutor(ODBCExecutor): fields.BooleanField: to_db_bool, } - async def execute_explain(self, query: Query) -> Any: + async def execute_explain(self, sql: str) -> Any: raise UnSupportedError("MSSQL does not support explain") diff --git a/tortoise/expressions.py b/tortoise/expressions.py index 51f427356..10e544471 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -218,7 +218,7 @@ def get_sql(self, **kwargs: Any) -> str: self.query._choose_db_if_not_chosen() return self.query._make_query(**kwargs)[0] - def as_(self, alias: str) -> "Selectable": + def as_(self, alias: str) -> "Selectable": # type: ignore self.query._choose_db_if_not_chosen() self.query._make_query() return self.query.query.as_(alias) diff --git a/tortoise/functions.py b/tortoise/functions.py index cf69790de..225fc3c65 100644 --- a/tortoise/functions.py +++ b/tortoise/functions.py @@ -57,7 +57,7 @@ class Upper(Function): database_func = functions.Upper -class _Concat(functions.Concat): # type: ignore +class _Concat(functions.Concat): @staticmethod def get_arg_sql(arg, **kwargs): sql = arg.get_sql(with_alias=False, **kwargs) if hasattr(arg, "get_sql") else str(arg) diff --git a/tortoise/indexes.py b/tortoise/indexes.py index ffab65da2..c9b8d9e02 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -38,7 +38,9 @@ def __init__( self.expressions = expressions self.extra = "" - def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool) -> str: + def get_sql( + self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool + ) -> str: if self.fields: fields = ", ".join(schema_generator.quote(f) for f in self.fields) else: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 2d9c8175d..e029b41d9 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -284,12 +284,18 @@ def _resolve_annotate(self) -> bool: return any(info.term.is_aggregate for info in annotation_info.values()) - def sql(self) -> str: - """Return the actual SQL.""" - if self._db is None: - self._db = 
diff --git a/tortoise/queryset.py b/tortoise/queryset.py
index 2d9c8175d..e029b41d9 100644
--- a/tortoise/queryset.py
+++ b/tortoise/queryset.py
@@ -284,12 +284,18 @@ def _resolve_annotate(self) -> bool:
 
         return any(info.term.is_aggregate for info in annotation_info.values())
 
-    def sql(self) -> str:
-        """Return the actual SQL."""
-        if self._db is None:
-            self._db = self._choose_db()  # type: ignore
+    def sql(self, params_inline=False) -> str:
+        """
+        Returns the SQL query that will be executed. By default, it will return the query with
+        placeholders, but if you set `params_inline=True`, it will inline the parameters.
+
+        :param params_inline: Whether to inline the parameters
+        """
+        self._choose_db_if_not_chosen()
         sql, _ = self._make_query()
+        if params_inline:
+            sql = self.query.get_sql()
         return sql
 
     def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]:
@@ -300,6 +306,13 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]:
         """
         raise NotImplementedError()  # pragma: nocoverage
 
+    def _parametrize_query(self, query: QueryBuilder, **pypika_kwargs) -> Tuple[str, List[Any]]:
+        parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer())
+        return (
+            query.get_sql(parameterizer=parameterizer, **pypika_kwargs),
+            parameterizer.values,
+        )
+
     async def _execute(self, sql: str, values: List[Any]) -> Any:
         raise NotImplementedError()  # pragma: nocoverage
 
@@ -1004,8 +1017,7 @@ async def explain(self) -> Any:
         and query optimization. **The output format may (and will) vary greatly depending on the
         database backend.**
         """
-        if self._db is None:
-            self._db = self._choose_db()  # type: ignore
+        self._choose_db_if_not_chosen()
         sql, _ = self._make_query()
         return await self._db.executor_class(model=self.model, db=self._db).execute_explain(sql)
 
@@ -1123,11 +1135,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]:
             self.query._use_indexes = []
             self.query = self.query.use_index(*self._use_indexes)
 
-        parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer())
-        return (
-            self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs),
-            parameterizer.values,
-        )
+        return self._parametrize_query(self.query, **pypika_kwargs)
 
     def __await__(self) -> Generator[Any, None, List[MODEL]]:
         if self._db is None:
@@ -1145,7 +1153,7 @@ async def _execute(self, sql: str, values: List[Any]) -> List[MODEL]:
             db=self._db,
             prefetch_map=self._prefetch_map,
             prefetch_queries=self._prefetch_queries,
-            select_related_idx=self._select_related_idx,
+            select_related_idx=self._select_related_idx,  # type: ignore
         ).execute_select(
             sql,
            values,
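[Illustrative aside, not part of the diff: _parametrize_query above becomes the single place where a query is rendered together with its bound values. A rough sketch of the contract it relies on; render_with_params, db and query are stand-in names for a connected client and a pypika-tortoise query builder, not patch code.]

    from typing import Any, List, Tuple

    def render_with_params(db, query) -> Tuple[str, List[Any]]:
        # db.executor_class.parameterizer() yields a dialect-specific Parameterizer
        parameterizer = db.executor_class.parameterizer()
        sql = query.get_sql(parameterizer=parameterizer)  # values are collected while rendering
        return sql, parameterizer.values                  # same (sql, values) shape as _make_query()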
pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) - return ( - self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), - parameterizer.values, - ) + + return self._parametrize_query(self.query, **pypika_kwargs) def __await__(self) -> Generator[Any, None, int]: - if self._db is None: - self._db = self._choose_db() # type: ignore + self._choose_db_if_not_chosen() sql, values = self._make_query() return self._execute(sql, values).__await__() @@ -1616,11 +1618,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) - return ( - self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), - parameterizer.values, - ) + return self._parametrize_query(self.query, **pypika_kwargs) @overload def __await__( @@ -1633,8 +1631,7 @@ def __await__( ) -> Generator[Any, None, Tuple[Any, ...]]: ... def __await__(self) -> Generator[Any, None, Union[List[Any], Tuple[Any, ...]]]: - if self._db is None: - self._db = self._choose_db() # type: ignore + self._choose_db_if_not_chosen() sql, values = self._make_query() return self._execute(sql, values).__await__() # pylint: disable=E1101 @@ -1750,11 +1747,7 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) - return ( - self.query.get_sql(parameterizer=parameterizer, **pypika_kwargs), - parameterizer.values, - ) + return self._parametrize_query(self.query, **pypika_kwargs) @overload def __await__( @@ -1769,8 +1762,7 @@ def __await__( def __await__( self, ) -> Generator[Any, None, Union[List[Dict[str, Any]], Dict[str, Any]]]: - if self._db is None: - self._db = self._choose_db() # type: ignore + self._choose_db_if_not_chosen() sql, values = self._make_query() return self._execute(sql, values).__await__() # pylint: disable=E1101 @@ -1814,8 +1806,7 @@ def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: self._db = db def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: - self.query = RawSQL(self._sql) - return self.query.get_sql(**pypika_kwargs), [] + return RawSQL(self._sql).get_sql(**pypika_kwargs), [] async def _execute(self, sql: str, values: List[Any]) -> Any: instance_list = await self._db.executor_class( @@ -1825,8 +1816,7 @@ async def _execute(self, sql: str, values: List[Any]) -> Any: return instance_list def __await__(self) -> Generator[Any, None, List[MODEL]]: - if self._db is None: - self._db = self._choose_db() # type: ignore + self._choose_db_if_not_chosen() sql, values = self._make_query() return self._execute(sql, values).__await__() @@ -1917,7 +1907,7 @@ def __await__(self) -> Generator[Any, Any, int]: queries = self._make_queries() return self._execute_many(queries).__await__() - def sql(self) -> str: + def sql(self, params_inline=False) -> str: self._choose_db_if_not_chosen() queries = self._make_queries() return ";".join([sql for sql, _ in queries]) @@ -2017,7 +2007,7 @@ def __await__(self) -> Generator[Any, None, None]: self._make_queries() return self._execute_many().__await__() - def sql(self) -> str: + def sql(self, params_inline=False) -> str: self._choose_db_if_not_chosen() self._make_queries() if self._insert_query and self._insert_query_all: From 
From dfe7553cc0a8da86a469ca3f54599d565f741714 Mon Sep 17 00:00:00 2001
From: henadzit
Date: Thu, 21 Nov 2024 18:27:16 +0100
Subject: [PATCH 8/8] Use pypika-tortoise 0.3.0

---
 poetry.lock    | 18 +++++++-----------
 pyproject.toml |  2 +-
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 753d3a20f..ccf6e9fd2 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2655,18 +2655,14 @@ files = [
 
 [[package]]
 name = "pypika-tortoise"
-version = "0.2.2"
+version = "0.3.0"
 description = "Forked from pypika and streamline just for tortoise-orm"
 optional = false
-python-versions = "^3.8"
-files = []
-develop = false
-
-[package.source]
-type = "git"
-url = "https://github.com/henadzit/pypika-tortoise.git"
-reference = "parameterization-changes"
-resolved_reference = "31eea5a7d1299d33ce1776e97aab44879c54de35"
+python-versions = "<4.0,>=3.8"
+files = [
+    {file = "pypika_tortoise-0.3.0-py3-none-any.whl", hash = "sha256:c374a09591cdb24828d1c28bd0dfcfa2916094f4d3561a65c965b2549aa7c52f"},
+    {file = "pypika_tortoise-0.3.0.tar.gz", hash = "sha256:9bfb796e15ff8b395355ff42d9c4a4146fd716d3cbf9679391ac3a1c06d0e56a"},
+]
 
 [[package]]
 name = "pytest"
@@ -3859,4 +3855,4 @@ psycopg = ["psycopg"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8"
-content-hash = "ccee25fb7393e24ddf0389c7787bbbc941c46e0afb090caf4c0ffbe3a81d744a"
+content-hash = "e39d83526d00453748662852417b58c5a4f3d6e326671493125b79cad305f801"
diff --git a/pyproject.toml b/pyproject.toml
index 21503251f..d271fd205 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ classifiers = [
 
 [tool.poetry.dependencies]
 python = "^3.8"
-pypika-tortoise = { git = "https://github.com/henadzit/pypika-tortoise.git", branch = "parameterization-changes" }
+pypika-tortoise = "^0.3.0"
 iso8601 = "^2.1.0"
 aiosqlite = ">=0.16.0, <0.21.0"
 pytz = "*"
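[Illustrative aside, not part of the series: once this lands, a downstream environment can sanity-check that the released fork, rather than the development git branch, is what got installed. Purely a suggested check, not patch code.]

    from importlib.metadata import version  # stdlib, Python 3.8+

    assert version("pypika-tortoise").startswith("0.3."), "expected pypika-tortoise 0.3.x"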