From 817eb4b9f4810d033b3624c56830bdd3d9aa3380 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Mon, 14 Oct 2024 10:49:12 +0200 Subject: [PATCH 01/17] reformat postgres regex enums and support case insentive regexp # Conflicts: # tortoise/contrib/postgres/regex.py --- tortoise/backends/base_postgres/executor.py | 7 ++++++- tortoise/contrib/postgres/regex.py | 12 +++++++++--- tortoise/contrib/sqlite/regex.py | 0 tortoise/filters.py | 15 ++++++++++++++- 4 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 tortoise/contrib/sqlite/regex.py diff --git a/tortoise/backends/base_postgres/executor.py b/tortoise/backends/base_postgres/executor.py index fc6345dd2..e7253e833 100644 --- a/tortoise/backends/base_postgres/executor.py +++ b/tortoise/backends/base_postgres/executor.py @@ -11,9 +11,13 @@ postgres_json_contains, postgres_json_filter, ) -from tortoise.contrib.postgres.regex import postgres_posix_regex +from tortoise.contrib.postgres.regex import ( + postgres_insensitive_posix_regex, + postgres_posix_regex, +) from tortoise.contrib.postgres.search import SearchCriterion from tortoise.filters import ( + insensitive_posix_regex, json_contained_by, json_contains, json_filter, @@ -35,6 +39,7 @@ class BasePostgresExecutor(BaseExecutor): json_contained_by: postgres_json_contained_by, json_filter: postgres_json_filter, posix_regex: postgres_posix_regex, + insensitive_posix_regex: postgres_insensitive_posix_regex, } def _prepare_insert_statement( diff --git a/tortoise/contrib/postgres/regex.py b/tortoise/contrib/postgres/regex.py index fcb951d45..e273dad1b 100644 --- a/tortoise/contrib/postgres/regex.py +++ b/tortoise/contrib/postgres/regex.py @@ -5,9 +5,15 @@ class PostgresRegexMatching(enum.Enum): - posix_regex = "~" + POSIX_REGEX = " ~ " + IPOSIX_REGEX = " *~ " -def postgres_posix_regex(field: Term, value: str) -> BasicCriterion: +def postgres_posix_regex(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(PostgresRegexMatching.posix_regex, field, term) + return BasicCriterion(PostgresRegexMatching.POSIX_REGEX, field, term) + + +def postgres_insensitive_posix_regex(field: Term, value: str): + term = cast(Term, field.wrap_constant(value)) + return BasicCriterion(PostgresRegexMatching.IPOSIX_REGEX, field, term) diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py new file mode 100644 index 000000000..e69de29bb diff --git a/tortoise/filters.py b/tortoise/filters.py index 828c1a2dd..98b1b0e80 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -144,7 +144,14 @@ def search(field: Term, value: str) -> Any: def posix_regex(field: Term, value: str) -> Any: # Will be overridden in each executor raise NotImplementedError( - "The postgres_posix_regex filter operator is not supported by your database backend" + "The posix_regex filter operator is not supported by your database backend" + ) + + +def insensitive_posix_regex(field: Term, value: str): + # Will be overridden in each executor + raise NotImplementedError( + "The insensitive_posix_regex filter operator is not supported by your database backend" ) @@ -510,6 +517,12 @@ def get_filters_for_field( "operator": posix_regex, "value_encoder": string_encoder, }, + f"{field_name}__iposix_regex": { + "field": actual_field_name, + "source_field": source_field, + "operator": insensitive_posix_regex, + "value_encoder": string_encoder, + }, f"{field_name}__year": { "field": actual_field_name, "source_field": source_field, From 8d80eaf9a4b0a43aac1e91e07cb198e736098fd6 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Mon, 14 Oct 2024 10:49:27 +0200 Subject: [PATCH 02/17] add regexp supprot for sqlite --- tortoise/backends/sqlite/executor.py | 9 +++++++++ tortoise/contrib/sqlite/regex.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/tortoise/backends/sqlite/executor.py b/tortoise/backends/sqlite/executor.py index dba3dee03..a46d05960 100644 --- a/tortoise/backends/sqlite/executor.py +++ b/tortoise/backends/sqlite/executor.py @@ -7,6 +7,10 @@ from tortoise import Model, fields, timezone from tortoise.backends.base.executor import BaseExecutor +from tortoise.contrib.sqlite.regex import ( + insensitive_posix_sqlite_regexp, + posix_sqlite_regexp, +) from tortoise.fields import ( BigIntField, BooleanField, @@ -16,6 +20,7 @@ SmallIntField, TimeField, ) +from tortoise.filters import insensitive_posix_regex, posix_regex def to_db_bool( @@ -91,6 +96,10 @@ class SqliteExecutor(BaseExecutor): } EXPLAIN_PREFIX = "EXPLAIN QUERY PLAN" DB_NATIVE = {bytes, str, int, float} + FILTER_FUNC_OVERRIDE = { + posix_regex: posix_sqlite_regexp, + insensitive_posix_regex: insensitive_posix_sqlite_regexp, + } async def _process_insert_result(self, instance: Model, results: int) -> None: pk_field_object = self.model._meta.pk diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index e69de29bb..ff0063a5e 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -0,0 +1,17 @@ +import enum + +from pypika.terms import BasicCriterion, Term + + +class SQLiteRegexMatching(enum.Enum): + POSIX_REGEX = " REGEXP " + + +def posix_sqlite_regexp(field: Term, value: str): + return BasicCriterion(SQLiteRegexMatching.POSIX_REGEX, field, field.wrap_constant(value)) + + +def insensitive_posix_sqlite_regexp(field: Term, value: str): + return BasicCriterion( + SQLiteRegexMatching.POSIX_REGEX, field, field.wrap_constant(f"(?i) {value}") + ) From eff21da6833f9e31ceec09cc73b9f6704b8d2432 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Mon, 14 Oct 2024 11:05:54 +0200 Subject: [PATCH 03/17] update docs --- docs/query.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/query.rst b/docs/query.rst index 632e711cd..75cc7975b 100644 --- a/docs/query.rst +++ b/docs/query.rst @@ -271,8 +271,9 @@ The ``filter`` option allows you to filter the JSON object by its keys and value obj5 = await JSONModel.filter(data__filter={"owner__name__isnull": True}).first() obj6 = await JSONModel.filter(data__filter={"owner__last__not_isnull": False}).first() -In PostgreSQL and MySQL, you can use ``postgres_posix_regex`` to make comparisons using POSIX regular expressions: -In PostgreSQL, this is done with the ``~`` operator, while in MySQL the ``REGEXP`` operator is used. +In PostgreSQL and MySQL and SQLite, you can use ``posix_regex`` to make comparisons using POSIX regular expressions: +On PostgreSQL, this uses the ``~`` operator, on MySQL and SQLite it uses the ``REGEXP`` operator. +PostgreSQL and SQLite also support ``iposix_regex``, which makes case insensive comparisons. .. code-block:: python3 @@ -281,6 +282,7 @@ In PostgreSQL, this is done with the ``~`` operator, while in MySQL the ``REGEXP await DemoModel.create(demo_text="Hello World") obj = await DemoModel.filter(demo_text__posix_regex="^Hello World$").first() + obj = await DemoModel.filter(demo_text__iposix_regex="^hello world$").first() In PostgreSQL, ``filter`` supports additional lookup types: From 0fcd613f7f5db1e99146d2289e3d83c0453f6a4a Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 13:45:05 +0100 Subject: [PATCH 04/17] fx wrong order of case insenstivie regexp operator --- tortoise/contrib/postgres/regex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/contrib/postgres/regex.py b/tortoise/contrib/postgres/regex.py index e273dad1b..097ea5bc8 100644 --- a/tortoise/contrib/postgres/regex.py +++ b/tortoise/contrib/postgres/regex.py @@ -6,7 +6,7 @@ class PostgresRegexMatching(enum.Enum): POSIX_REGEX = " ~ " - IPOSIX_REGEX = " *~ " + IPOSIX_REGEX = " ~* " def postgres_posix_regex(field: Term, value: str): From 55573785b311fa7ae5abc4712df2a2a24f20faaf Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 13:53:12 +0100 Subject: [PATCH 05/17] add tests --- tests/test_posix_regex_filter.py | 52 ++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py index 106464c19..1951be8c1 100644 --- a/tests/test_posix_regex_filter.py +++ b/tests/test_posix_regex_filter.py @@ -4,9 +4,32 @@ class TestPosixRegexFilter(test.TestCase): - @test.requireCapability(dialect="mysql") @test.requireCapability(dialect="postgres") - async def test_regex_filter(self): + async def test_regex_filter_postgres(self): + author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") + self.assertEqual( + set( + await testmodels.Author.filter( + name__posix_regex="^Johann [a-zA-Z]+ von Goethe$" + ).values_list("name", flat=True) + ), + {author.name}, + ) + + @test.requireCapability(dialect="mysql") + async def test_regex_filter_mysql(self): + author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") + self.assertEqual( + set( + await testmodels.Author.filter( + name__posix_regex="^Johann [a-zA-Z]+ von Goethe$" + ).values_list("name", flat=True) + ), + {author.name}, + ) + + @test.requireCapability(dialect="sqlite") + async def test_regex_filter_sqlite(self): author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") self.assertEqual( set( @@ -16,3 +39,28 @@ async def test_regex_filter(self): ), {author.name}, ) + +class TestCaseInsensitivePosixRegexFilter(test.TestCase): + @test.requireCapability(dialect="postgres") + async def test_case_insensitive_regex_filter_postgres(self): + author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") + self.assertEqual( + set( + await testmodels.Author.filter( + name__iposix_regex="^johann [a-zA-Z]+ Von goethe$" + ).values_list("name", flat=True) + ), + {author.name}, + ) + + @test.requireCapability(dialect="sqlite") + async def test_case_insensitive_regex_filter_sqlite(self): + author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") + self.assertEqual( + set( + await testmodels.Author.filter( + name__iposix_regex="^johann [a-zA-Z]+ Von goethe$" + ).values_list("name", flat=True) + ), + {author.name}, + ) \ No newline at end of file From e0adcf3f2250229a7c3c8a57d49a1d989df8795a Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 14:32:26 +0100 Subject: [PATCH 06/17] fix sqlite iregexp # Conflicts: # tortoise/backends/sqlite/client.py --- tests/test_posix_regex_filter.py | 1 + tortoise/backends/sqlite/client.py | 2 ++ tortoise/contrib/sqlite/regex.py | 26 +++++++++++++++++++++----- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py index 1951be8c1..7b2535f0b 100644 --- a/tests/test_posix_regex_filter.py +++ b/tests/test_posix_regex_filter.py @@ -56,6 +56,7 @@ async def test_case_insensitive_regex_filter_postgres(self): @test.requireCapability(dialect="sqlite") async def test_case_insensitive_regex_filter_sqlite(self): author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") + print(testmodels.Author.filter(name__iposix_regex="^johann [a-zA-Z]+ Von goethe$").sql(params_inline=True)) self.assertEqual( set( await testmodels.Author.filter( diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index bb97abc14..4156f1ef0 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -29,6 +29,7 @@ ) from tortoise.backends.sqlite.executor import SqliteExecutor from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator +from tortoise.contrib.sqlite.regex import install_regexp_function from tortoise.connection import connections from tortoise.exceptions import ( IntegrityError, @@ -84,6 +85,7 @@ async def create_connection(self, with_db: bool) -> None: for pragma, val in self.pragmas.items(): cursor = await self._connection.execute(f"PRAGMA {pragma}={val}") await cursor.close() + await install_regexp_function(self._connection) self.log.debug( "Created connection %s with params: filename=%s %s", self._connection, diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index ff0063a5e..fbb96d6be 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -1,17 +1,33 @@ import enum +import re +from typing import cast -from pypika.terms import BasicCriterion, Term +import aiosqlite +from pypika.terms import BasicCriterion, Term, Function class SQLiteRegexMatching(enum.Enum): POSIX_REGEX = " REGEXP " + IPOSIX_REGEX = " MATCH " def posix_sqlite_regexp(field: Term, value: str): - return BasicCriterion(SQLiteRegexMatching.POSIX_REGEX, field, field.wrap_constant(value)) + term = cast(Term, field.wrap_constant(value)) + return BasicCriterion(SQLiteRegexMatching.POSIX_REGEX, field, term) + # return Function("regexp", field, term) def insensitive_posix_sqlite_regexp(field: Term, value: str): - return BasicCriterion( - SQLiteRegexMatching.POSIX_REGEX, field, field.wrap_constant(f"(?i) {value}") - ) + term = cast(Term, field.wrap_constant(value)) + return BasicCriterion(SQLiteRegexMatching.IPOSIX_REGEX, field, term) + # return Function("iregexp", field, term) + +async def install_regexp_function(connection: aiosqlite.Connection): + def regexp(expr, item): + return re.search(expr, item) is not None + + def iregexp(expr, item): + return re.search(expr, item, re.IGNORECASE) is not None + + await connection.create_function("regexp", 2, regexp) + await connection.create_function("match", 2, iregexp) \ No newline at end of file From 6ca34db5074c2db7c91ad8e3693854da3405f9ab Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 14:35:06 +0100 Subject: [PATCH 07/17] remove unused import --- tortoise/contrib/sqlite/regex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index fbb96d6be..99c36ac1d 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -3,7 +3,7 @@ from typing import cast import aiosqlite -from pypika.terms import BasicCriterion, Term, Function +from pypika.terms import BasicCriterion, Term class SQLiteRegexMatching(enum.Enum): From c1cf54e0141bd301330f089d09797757ac817feb Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 14:43:23 +0100 Subject: [PATCH 08/17] add cast to varchar # Conflicts: # tortoise/contrib/postgres/regex.py --- tests/test_posix_regex_filter.py | 1 - tortoise/contrib/postgres/regex.py | 2 ++ tortoise/contrib/sqlite/regex.py | 9 ++++----- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py index 7b2535f0b..1951be8c1 100644 --- a/tests/test_posix_regex_filter.py +++ b/tests/test_posix_regex_filter.py @@ -56,7 +56,6 @@ async def test_case_insensitive_regex_filter_postgres(self): @test.requireCapability(dialect="sqlite") async def test_case_insensitive_regex_filter_sqlite(self): author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") - print(testmodels.Author.filter(name__iposix_regex="^johann [a-zA-Z]+ Von goethe$").sql(params_inline=True)) self.assertEqual( set( await testmodels.Author.filter( diff --git a/tortoise/contrib/postgres/regex.py b/tortoise/contrib/postgres/regex.py index 097ea5bc8..0c06f0b3c 100644 --- a/tortoise/contrib/postgres/regex.py +++ b/tortoise/contrib/postgres/regex.py @@ -2,6 +2,8 @@ from typing import cast from pypika_tortoise.terms import BasicCriterion, Term +from pypika_tortoise.functions import Cast +from pypika_tortoise.enums import SqlTypes class PostgresRegexMatching(enum.Enum): diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index 99c36ac1d..05e71ffc1 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -4,7 +4,8 @@ import aiosqlite from pypika.terms import BasicCriterion, Term - +from pypika.functions import Cast +from pypika.enums import SqlTypes class SQLiteRegexMatching(enum.Enum): POSIX_REGEX = " REGEXP " @@ -13,14 +14,12 @@ class SQLiteRegexMatching(enum.Enum): def posix_sqlite_regexp(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(SQLiteRegexMatching.POSIX_REGEX, field, term) - # return Function("regexp", field, term) + return BasicCriterion(SQLiteRegexMatching.POSIX_REGEX, Cast(field, SqlTypes.VARCHAR), term) def insensitive_posix_sqlite_regexp(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(SQLiteRegexMatching.IPOSIX_REGEX, field, term) - # return Function("iregexp", field, term) + return BasicCriterion(SQLiteRegexMatching.IPOSIX_REGEX, Cast(field, SqlTypes.VARCHAR), term) async def install_regexp_function(connection: aiosqlite.Connection): def regexp(expr, item): From 6205b329aa3062d9ae973ab8ce93b591dbfda64c Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 14:55:22 +0100 Subject: [PATCH 09/17] fix tests for mysql --- tortoise/backends/mysql/executor.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 45be08e7d..503c67d80 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -1,7 +1,10 @@ +import enum + from pypika_tortoise import functions from pypika_tortoise.enums import SqlTypes from pypika_tortoise.terms import BasicCriterion, Criterion from pypika_tortoise.utils import format_quotes +from pypika_tortoise.functions import Cast from tortoise import Model from tortoise.backends.base.executor import BaseExecutor @@ -31,6 +34,9 @@ ) +class MySQLRegexpComparators(enum.Enum): + REGEXP = " REGEXP " + class StrWrapper(ValueWrapper): """ Naive str wrapper that doesn't use the monkey-patched pypika ValueWrapper for MySQL @@ -97,7 +103,7 @@ def mysql_search(field: Term, value: str) -> SearchCriterion: def mysql_posix_regex(field: Term, value: str) -> BasicCriterion: - return BasicCriterion(" REGEXP ", field, StrWrapper(value)) # type:ignore[arg-type] + return BasicCriterion(MySQLRegexpComparators.REGEXP, Cast(field, SqlTypes.VARCHAR), StrWrapper(value)) # type:ignore[arg-type] class MySQLExecutor(BaseExecutor): From 2635329d26bce765edab4fce7fc1f13c70424718 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 15:12:59 +0100 Subject: [PATCH 10/17] add null field test --- tests/test_posix_regex_filter.py | 26 ++++++++++++++++++++++++++ tortoise/contrib/postgres/regex.py | 2 ++ tortoise/contrib/sqlite/regex.py | 2 ++ 3 files changed, 30 insertions(+) diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py index 1951be8c1..e3632da84 100644 --- a/tests/test_posix_regex_filter.py +++ b/tests/test_posix_regex_filter.py @@ -40,6 +40,32 @@ async def test_regex_filter_sqlite(self): {author.name}, ) + @test.requireCapability(dialect="postgres") + async def test_regex_filter_works_with_null_field_postgres(self): + t = await testmodels.Tournament.create(name="Test") + print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) + self.assertEqual( + set( + await testmodels.Tournament.filter( + desc__posix_regex="^test$" + ).values_list("name", flat=True) + ), + set(), + ) + + @test.requireCapability(dialect="sqlite") + async def test_regex_filter_works_with_null_field_sqlite(self): + t = await testmodels.Tournament.create(name="Test") + print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) + self.assertEqual( + set( + await testmodels.Tournament.filter( + desc__posix_regex="^test$" + ).values_list("name", flat=True) + ), + set(), + ) + class TestCaseInsensitivePosixRegexFilter(test.TestCase): @test.requireCapability(dialect="postgres") async def test_case_insensitive_regex_filter_postgres(self): diff --git a/tortoise/contrib/postgres/regex.py b/tortoise/contrib/postgres/regex.py index 0c06f0b3c..17e8890f2 100644 --- a/tortoise/contrib/postgres/regex.py +++ b/tortoise/contrib/postgres/regex.py @@ -5,6 +5,8 @@ from pypika_tortoise.functions import Cast from pypika_tortoise.enums import SqlTypes +from tortoise.functions import Coalesce + class PostgresRegexMatching(enum.Enum): POSIX_REGEX = " ~ " diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index 05e71ffc1..13977bb22 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -23,6 +23,8 @@ def insensitive_posix_sqlite_regexp(field: Term, value: str): async def install_regexp_function(connection: aiosqlite.Connection): def regexp(expr, item): + if not expr or not item: + return False return re.search(expr, item) is not None def iregexp(expr, item): From a557ff3d26b1f7c69c20a1ce779af9cca622480a Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 15:14:01 +0100 Subject: [PATCH 11/17] remove type-ignore statement from mysql regex impl --- tortoise/backends/mysql/executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 503c67d80..0e84bc8de 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -103,7 +103,7 @@ def mysql_search(field: Term, value: str) -> SearchCriterion: def mysql_posix_regex(field: Term, value: str) -> BasicCriterion: - return BasicCriterion(MySQLRegexpComparators.REGEXP, Cast(field, SqlTypes.VARCHAR), StrWrapper(value)) # type:ignore[arg-type] + return BasicCriterion(MySQLRegexpComparators.REGEXP, Cast(field, SqlTypes.VARCHAR), StrWrapper(value)) class MySQLExecutor(BaseExecutor): From 4cfda6cf77df287e9056b84c6220312a23db8199 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 15:22:32 +0100 Subject: [PATCH 12/17] add coalesce --- tortoise/backends/mysql/executor.py | 4 ++-- tortoise/contrib/postgres/regex.py | 8 +++----- tortoise/contrib/sqlite/regex.py | 7 ++++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 0e84bc8de..b823bcb2e 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -4,7 +4,7 @@ from pypika_tortoise.enums import SqlTypes from pypika_tortoise.terms import BasicCriterion, Criterion from pypika_tortoise.utils import format_quotes -from pypika_tortoise.functions import Cast +from pypika_tortoise.functions import Cast, Coalesce from tortoise import Model from tortoise.backends.base.executor import BaseExecutor @@ -103,7 +103,7 @@ def mysql_search(field: Term, value: str) -> SearchCriterion: def mysql_posix_regex(field: Term, value: str) -> BasicCriterion: - return BasicCriterion(MySQLRegexpComparators.REGEXP, Cast(field, SqlTypes.VARCHAR), StrWrapper(value)) + return BasicCriterion(MySQLRegexpComparators.REGEXP, Coalesce(Cast(field, SqlTypes.VARCHAR)), StrWrapper(value)) class MySQLExecutor(BaseExecutor): diff --git a/tortoise/contrib/postgres/regex.py b/tortoise/contrib/postgres/regex.py index 17e8890f2..c41fe07d3 100644 --- a/tortoise/contrib/postgres/regex.py +++ b/tortoise/contrib/postgres/regex.py @@ -2,11 +2,9 @@ from typing import cast from pypika_tortoise.terms import BasicCriterion, Term -from pypika_tortoise.functions import Cast +from pypika_tortoise.functions import Cast, Coalesce from pypika_tortoise.enums import SqlTypes -from tortoise.functions import Coalesce - class PostgresRegexMatching(enum.Enum): POSIX_REGEX = " ~ " @@ -15,9 +13,9 @@ class PostgresRegexMatching(enum.Enum): def postgres_posix_regex(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(PostgresRegexMatching.POSIX_REGEX, field, term) + return BasicCriterion(PostgresRegexMatching.POSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term) def postgres_insensitive_posix_regex(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(PostgresRegexMatching.IPOSIX_REGEX, field, term) + return BasicCriterion(PostgresRegexMatching.IPOSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term) diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index 13977bb22..11e1aa725 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -4,9 +4,10 @@ import aiosqlite from pypika.terms import BasicCriterion, Term -from pypika.functions import Cast +from pypika.functions import Cast, Coalesce from pypika.enums import SqlTypes + class SQLiteRegexMatching(enum.Enum): POSIX_REGEX = " REGEXP " IPOSIX_REGEX = " MATCH " @@ -14,12 +15,12 @@ class SQLiteRegexMatching(enum.Enum): def posix_sqlite_regexp(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(SQLiteRegexMatching.POSIX_REGEX, Cast(field, SqlTypes.VARCHAR), term) + return BasicCriterion(SQLiteRegexMatching.POSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term) def insensitive_posix_sqlite_regexp(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(SQLiteRegexMatching.IPOSIX_REGEX, Cast(field, SqlTypes.VARCHAR), term) + return BasicCriterion(SQLiteRegexMatching.IPOSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term) async def install_regexp_function(connection: aiosqlite.Connection): def regexp(expr, item): From bbb083dcb7d3e140adfd2c0cd18f748dd9f644db Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 15:26:51 +0100 Subject: [PATCH 13/17] make style and make lint fixes --- tests/contrib/test_pydantic.py | 1 - tests/test_posix_regex_filter.py | 19 ++++++++++--------- tortoise/backends/mysql/executor.py | 7 +++++-- tortoise/contrib/postgres/regex.py | 12 ++++++++---- tortoise/contrib/sqlite/regex.py | 15 ++++++++++----- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index a4fcd4466..1d03be6eb 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -7,7 +7,6 @@ Address, CamelCaseAliasPerson, Employee, - EnumFields, Event, IntFields, JSONFields, diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py index e3632da84..5d233b438 100644 --- a/tests/test_posix_regex_filter.py +++ b/tests/test_posix_regex_filter.py @@ -42,30 +42,31 @@ async def test_regex_filter_sqlite(self): @test.requireCapability(dialect="postgres") async def test_regex_filter_works_with_null_field_postgres(self): - t = await testmodels.Tournament.create(name="Test") + await testmodels.Tournament.create(name="Test") print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) self.assertEqual( set( - await testmodels.Tournament.filter( - desc__posix_regex="^test$" - ).values_list("name", flat=True) + await testmodels.Tournament.filter(desc__posix_regex="^test$").values_list( + "name", flat=True + ) ), set(), ) @test.requireCapability(dialect="sqlite") async def test_regex_filter_works_with_null_field_sqlite(self): - t = await testmodels.Tournament.create(name="Test") + await testmodels.Tournament.create(name="Test") print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) self.assertEqual( set( - await testmodels.Tournament.filter( - desc__posix_regex="^test$" - ).values_list("name", flat=True) + await testmodels.Tournament.filter(desc__posix_regex="^test$").values_list( + "name", flat=True + ) ), set(), ) + class TestCaseInsensitivePosixRegexFilter(test.TestCase): @test.requireCapability(dialect="postgres") async def test_case_insensitive_regex_filter_postgres(self): @@ -89,4 +90,4 @@ async def test_case_insensitive_regex_filter_sqlite(self): ).values_list("name", flat=True) ), {author.name}, - ) \ No newline at end of file + ) diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index b823bcb2e..56b2b4192 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -2,9 +2,9 @@ from pypika_tortoise import functions from pypika_tortoise.enums import SqlTypes +from pypika_tortoise.functions import Cast, Coalesce from pypika_tortoise.terms import BasicCriterion, Criterion from pypika_tortoise.utils import format_quotes -from pypika_tortoise.functions import Cast, Coalesce from tortoise import Model from tortoise.backends.base.executor import BaseExecutor @@ -37,6 +37,7 @@ class MySQLRegexpComparators(enum.Enum): REGEXP = " REGEXP " + class StrWrapper(ValueWrapper): """ Naive str wrapper that doesn't use the monkey-patched pypika ValueWrapper for MySQL @@ -103,7 +104,9 @@ def mysql_search(field: Term, value: str) -> SearchCriterion: def mysql_posix_regex(field: Term, value: str) -> BasicCriterion: - return BasicCriterion(MySQLRegexpComparators.REGEXP, Coalesce(Cast(field, SqlTypes.VARCHAR)), StrWrapper(value)) + return BasicCriterion( + MySQLRegexpComparators.REGEXP, Coalesce(Cast(field, SqlTypes.VARCHAR)), StrWrapper(value) + ) class MySQLExecutor(BaseExecutor): diff --git a/tortoise/contrib/postgres/regex.py b/tortoise/contrib/postgres/regex.py index c41fe07d3..18fd9870f 100644 --- a/tortoise/contrib/postgres/regex.py +++ b/tortoise/contrib/postgres/regex.py @@ -1,9 +1,9 @@ import enum from typing import cast -from pypika_tortoise.terms import BasicCriterion, Term -from pypika_tortoise.functions import Cast, Coalesce from pypika_tortoise.enums import SqlTypes +from pypika_tortoise.functions import Cast, Coalesce +from pypika_tortoise.terms import BasicCriterion, Term class PostgresRegexMatching(enum.Enum): @@ -13,9 +13,13 @@ class PostgresRegexMatching(enum.Enum): def postgres_posix_regex(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(PostgresRegexMatching.POSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term) + return BasicCriterion( + PostgresRegexMatching.POSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term + ) def postgres_insensitive_posix_regex(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(PostgresRegexMatching.IPOSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term) + return BasicCriterion( + PostgresRegexMatching.IPOSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term + ) diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index 11e1aa725..97dc1aef4 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -3,9 +3,9 @@ from typing import cast import aiosqlite -from pypika.terms import BasicCriterion, Term -from pypika.functions import Cast, Coalesce from pypika.enums import SqlTypes +from pypika.functions import Cast, Coalesce +from pypika.terms import BasicCriterion, Term class SQLiteRegexMatching(enum.Enum): @@ -15,12 +15,17 @@ class SQLiteRegexMatching(enum.Enum): def posix_sqlite_regexp(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(SQLiteRegexMatching.POSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term) + return BasicCriterion( + SQLiteRegexMatching.POSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term + ) def insensitive_posix_sqlite_regexp(field: Term, value: str): term = cast(Term, field.wrap_constant(value)) - return BasicCriterion(SQLiteRegexMatching.IPOSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term) + return BasicCriterion( + SQLiteRegexMatching.IPOSIX_REGEX, Coalesce(Cast(field, SqlTypes.VARCHAR), ""), term + ) + async def install_regexp_function(connection: aiosqlite.Connection): def regexp(expr, item): @@ -32,4 +37,4 @@ def iregexp(expr, item): return re.search(expr, item, re.IGNORECASE) is not None await connection.create_function("regexp", 2, regexp) - await connection.create_function("match", 2, iregexp) \ No newline at end of file + await connection.create_function("match", 2, iregexp) From 982d571bf837c2358c2869692a18d057145cbf05 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 13 Dec 2024 15:44:28 +0100 Subject: [PATCH 14/17] fix mysql tests --- tortoise/backends/mysql/executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 56b2b4192..c35f39507 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -105,7 +105,7 @@ def mysql_search(field: Term, value: str) -> SearchCriterion: def mysql_posix_regex(field: Term, value: str) -> BasicCriterion: return BasicCriterion( - MySQLRegexpComparators.REGEXP, Coalesce(Cast(field, SqlTypes.VARCHAR)), StrWrapper(value) + MySQLRegexpComparators.REGEXP, Coalesce(Cast(field, SqlTypes.CHAR)), StrWrapper(value) ) From 7f8fac5d9057fc4c6556c17bc1b8ca0181106c16 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Mon, 16 Dec 2024 14:20:40 +0100 Subject: [PATCH 15/17] make sqlite regexp function installation option, add capability for it # Conflicts: # tortoise/backends/sqlite/client.py --- Makefile | 3 ++ tests/test_connection.py | 11 +++--- tests/test_posix_regex_filter.py | 39 ++++++---------------- tortoise/backends/base/client.py | 3 ++ tortoise/backends/base/config_generator.py | 5 ++- tortoise/backends/base_postgres/client.py | 2 +- tortoise/backends/mysql/client.py | 2 +- tortoise/backends/sqlite/__init__.py | 8 ++++- tortoise/backends/sqlite/client.py | 21 ++++++++++-- tortoise/connection.py | 14 +++++--- tortoise/contrib/sqlite/regex.py | 2 +- 11 files changed, 64 insertions(+), 46 deletions(-) diff --git a/Makefile b/Makefile index 5ed9b7b64..8bb915271 100644 --- a/Makefile +++ b/Makefile @@ -52,6 +52,9 @@ test: deps test_sqlite: $(py_warn) TORTOISE_TEST_DB=sqlite://:memory: pytest --cov-report= $(pytest_opts) +test_sqlite_regexp: + $(py_warn) TORTOISE_TEST_DB=sqlite://:memory:?install_regexp_functions=True pytest --cov-report= $(pytest_opts) + test_postgres_asyncpg: python -V | grep PyPy || $(py_warn) TORTOISE_TEST_DB="asyncpg://postgres:$(TORTOISE_POSTGRES_PASS)@127.0.0.1:5432/test_\{\}" pytest $(pytest_opts) --cov-append --cov-report= diff --git a/tests/test_connection.py b/tests/test_connection.py index 7f6841330..22faf844c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -71,18 +71,21 @@ def test_clear_storage(self, mocked_get_storage: Mock): @patch("tortoise.connection.importlib.import_module") def test_discover_client_class_proper_impl(self, mocked_import_module: Mock): mocked_import_module.return_value = Mock(client_class="some_class") - client_class = self.conn_handler._discover_client_class("blah") + del mocked_import_module.return_value.get_client_class + client_class = self.conn_handler._discover_client_class({"engine": "blah"}) + mocked_import_module.assert_called_once_with("blah") self.assertEqual(client_class, "some_class") @patch("tortoise.connection.importlib.import_module") def test_discover_client_class_improper_impl(self, mocked_import_module: Mock): del mocked_import_module.return_value.client_class + del mocked_import_module.return_value.get_client_class engine = "some_engine" with self.assertRaises( ConfigurationError, msg=f'Backend for engine "{engine}" does not implement db client' ): - _ = self.conn_handler._discover_client_class(engine) + _ = self.conn_handler._discover_client_class({"engine": engine}) @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) def test_get_db_info_present(self, mocked_db_config: Mock): @@ -156,7 +159,7 @@ def test_create_connection_db_info_str( mocked_get_db_info.assert_called_once_with(alias) mocked_expand_db_url.assert_called_once_with("some_db_url") - mocked_discover_client_class.assert_called_once_with("some_engine") + mocked_discover_client_class.assert_called_once_with({"engine": "some_engine", "credentials": {"cred_key": "some_val"}}) expected_client_class.assert_called_once_with(**expected_db_params) self.assertEqual(ret_val, "some_connection") @@ -182,7 +185,7 @@ def test_create_connection_db_info_not_str( mocked_get_db_info.assert_called_once_with(alias) mocked_expand_db_url.assert_not_called() - mocked_discover_client_class.assert_called_once_with("some_engine") + mocked_discover_client_class.assert_called_once_with({"engine": "some_engine", "credentials": {"cred_key": "some_val"}}) expected_client_class.assert_called_once_with(**expected_db_params) self.assertEqual(ret_val, "some_connection") diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py index 5d233b438..0d09b01ce 100644 --- a/tests/test_posix_regex_filter.py +++ b/tests/test_posix_regex_filter.py @@ -1,35 +1,16 @@ from tests import testmodels from tortoise.contrib import test +class RegexTestCase(test.TestCase): + async def asyncSetUp(self) -> None: + await super().asyncSetUp() -class TestPosixRegexFilter(test.TestCase): - @test.requireCapability(dialect="postgres") - async def test_regex_filter_postgres(self): - author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") - self.assertEqual( - set( - await testmodels.Author.filter( - name__posix_regex="^Johann [a-zA-Z]+ von Goethe$" - ).values_list("name", flat=True) - ), - {author.name}, - ) - @test.requireCapability(dialect="mysql") - async def test_regex_filter_mysql(self): - author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") - self.assertEqual( - set( - await testmodels.Author.filter( - name__posix_regex="^Johann [a-zA-Z]+ von Goethe$" - ).values_list("name", flat=True) - ), - {author.name}, - ) +class TestPosixRegexFilter(test.TestCase): - @test.requireCapability(dialect="sqlite") - async def test_regex_filter_sqlite(self): + @test.requireCapability(support_for_posix_regex_queries=True) + async def test_regex_filter(self): author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") self.assertEqual( set( @@ -40,7 +21,7 @@ async def test_regex_filter_sqlite(self): {author.name}, ) - @test.requireCapability(dialect="postgres") + @test.requireCapability(dialect="postgres", support_for_posix_regex_queries=True) async def test_regex_filter_works_with_null_field_postgres(self): await testmodels.Tournament.create(name="Test") print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) @@ -53,7 +34,7 @@ async def test_regex_filter_works_with_null_field_postgres(self): set(), ) - @test.requireCapability(dialect="sqlite") + @test.requireCapability(dialect="sqlite", support_for_posix_regex_queries=True) async def test_regex_filter_works_with_null_field_sqlite(self): await testmodels.Tournament.create(name="Test") print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) @@ -68,7 +49,7 @@ async def test_regex_filter_works_with_null_field_sqlite(self): class TestCaseInsensitivePosixRegexFilter(test.TestCase): - @test.requireCapability(dialect="postgres") + @test.requireCapability(dialect="postgres", support_for_posix_regex_queries=True) async def test_case_insensitive_regex_filter_postgres(self): author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") self.assertEqual( @@ -80,7 +61,7 @@ async def test_case_insensitive_regex_filter_postgres(self): {author.name}, ) - @test.requireCapability(dialect="sqlite") + @test.requireCapability(dialect="sqlite", support_for_posix_regex_queries=True) async def test_case_insensitive_regex_filter_sqlite(self): author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") self.assertEqual( diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index f938583f3..2e8a11f5d 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -46,6 +46,7 @@ class Capabilities: :param support_for_update: Indicates that this DB supports SELECT ... FOR UPDATE SQL statement. :param support_index_hint: Support force index or use index. :param support_update_limit_order_by: support update/delete with limit and order by. + :param: support_for_posix_regex_queries: indicated if the db supports posix regex queries """ def __init__( @@ -63,6 +64,7 @@ def __init__( support_index_hint: bool = False, # support update/delete with limit and order by support_update_limit_order_by: bool = True, + support_for_posix_regex_queries: bool = False, ) -> None: super().__setattr__("_mutable", True) @@ -74,6 +76,7 @@ def __init__( self.support_for_update = support_for_update self.support_index_hint = support_index_hint self.support_update_limit_order_by = support_update_limit_order_by + self.support_for_posix_regex_queries = support_for_posix_regex_queries super().__setattr__("_mutable", False) def __setattr__(self, attr: str, value: Any) -> None: diff --git a/tortoise/backends/base/config_generator.py b/tortoise/backends/base/config_generator.py index cd29cc051..514f50175 100644 --- a/tortoise/backends/base/config_generator.py +++ b/tortoise/backends/base/config_generator.py @@ -62,7 +62,10 @@ "skip_first_char": False, "vmap": {"path": "file_path"}, "defaults": {"journal_mode": "WAL", "journal_size_limit": 16384}, - "cast": {"journal_size_limit": int}, + "cast": { + "journal_size_limit": int, + "install_regexp_functions": bool, + }, }, "mysql": { "engine": "tortoise.backends.mysql", diff --git a/tortoise/backends/base_postgres/client.py b/tortoise/backends/base_postgres/client.py index 06e6d3288..365acb78a 100644 --- a/tortoise/backends/base_postgres/client.py +++ b/tortoise/backends/base_postgres/client.py @@ -53,7 +53,7 @@ class BasePostgresClient(BaseDBAsyncClient, abc.ABC): query_class: Type[PostgreSQLQuery] = PostgreSQLQuery executor_class: Type[BasePostgresExecutor] = BasePostgresExecutor schema_generator: Type[BasePostgresSchemaGenerator] = BasePostgresSchemaGenerator - capabilities = Capabilities("postgres", support_update_limit_order_by=False) + capabilities = Capabilities("postgres", support_update_limit_order_by=False, support_for_posix_regex_queries=True) connection_class: "Optional[Union[AsyncConnection, Connection]]" = None loop: Optional[AbstractEventLoop] = None _pool: Optional[Any] = None diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index a2d469ede..c2c6c34be 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -74,7 +74,7 @@ class MySQLClient(BaseDBAsyncClient): executor_class = MySQLExecutor schema_generator = MySQLSchemaGenerator capabilities = Capabilities( - "mysql", requires_limit=True, inline_comment=True, support_index_hint=True + "mysql", requires_limit=True, inline_comment=True, support_index_hint=True, support_for_posix_regex_queries=True ) def __init__( diff --git a/tortoise/backends/sqlite/__init__.py b/tortoise/backends/sqlite/__init__.py index 9209b053d..5c51ddf6f 100644 --- a/tortoise/backends/sqlite/__init__.py +++ b/tortoise/backends/sqlite/__init__.py @@ -1,3 +1,9 @@ -from .client import SqliteClient +from .client import SqliteClient, SqliteClientWithRegexpSupport client_class = SqliteClient + +def get_client_class(db_info: dict): + if db_info.get("credentials", {}).get("install_regexp_functions"): + return SqliteClientWithRegexpSupport + else: + return SqliteClient \ No newline at end of file diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 4156f1ef0..24411a00a 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -12,7 +12,7 @@ Sequence, Tuple, TypeVar, - cast, + cast, override, ) import aiosqlite @@ -29,7 +29,7 @@ ) from tortoise.backends.sqlite.executor import SqliteExecutor from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator -from tortoise.contrib.sqlite.regex import install_regexp_function +from tortoise.contrib.sqlite.regex import install_regexp_functions as install_regexp_functions_to_db from tortoise.connection import connections from tortoise.exceptions import ( IntegrityError, @@ -85,7 +85,6 @@ async def create_connection(self, with_db: bool) -> None: for pragma, val in self.pragmas.items(): cursor = await self._connection.execute(f"PRAGMA {pragma}={val}") await cursor.close() - await install_regexp_function(self._connection) self.log.debug( "Created connection %s with params: filename=%s %s", self._connection, @@ -214,6 +213,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class SqliteTransactionWrapper(SqliteClient, TransactionalDBClient): def __init__(self, connection: SqliteClient) -> None: + self.capabilities = connection.capabilities self.connection_name = connection.connection_name self._connection: aiosqlite.Connection = cast(aiosqlite.Connection, connection._connection) self._lock = asyncio.Lock() @@ -274,3 +274,18 @@ async def release_savepoint(self) -> None: def _gen_savepoint_name(_c=count()) -> str: return f"tortoise_savepoint_{next(_c)}" + + +class SqliteClientWithRegexpSupport(SqliteClient): + capabilities = Capabilities( + "sqlite", daemon=False, requires_limit=True, inline_comment=True, support_for_update=False, support_for_posix_regex_queries=True + ) + + async def create_connection(self, with_db: bool) -> None: + await super().create_connection(with_db) + await install_regexp_functions_to_db(self._connection) + + +# class TransactionWrapperWithRegexpSupport(TransactionWrapper): +# def __init__(self, connection: SqliteClientWithRegexpSupport) -> None: +# super().__init__(connection) \ No newline at end of file diff --git a/tortoise/connection.py b/tortoise/connection.py index 2ac6ddf6e..aee908468 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -65,13 +65,17 @@ def _copy_storage(self) -> Dict[str, "BaseDBAsyncClient"]: def _clear_storage(self) -> None: self._get_storage().clear() - def _discover_client_class(self, engine: str) -> Type["BaseDBAsyncClient"]: + def _discover_client_class(self, db_info: dict) -> Type["BaseDBAsyncClient"]: # Let exception bubble up for transparency - engine_module = importlib.import_module(engine) + engine_str = db_info.get("engine", "") + engine_module = importlib.import_module(engine_str) try: - client_class = engine_module.client_class + if hasattr(engine_module, "get_client_class"): + client_class = engine_module.get_client_class(db_info) + else: + client_class = engine_module.client_class except AttributeError: - raise ConfigurationError(f'Backend for engine "{engine}" does not implement db client') + raise ConfigurationError(f'Backend for engine "{engine_str}" does not implement db client') return client_class def _get_db_info(self, conn_alias: str) -> Union[str, Dict]: @@ -93,7 +97,7 @@ def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient": db_info = self._get_db_info(conn_alias) if isinstance(db_info, str): db_info = expand_db_url(db_info) - client_class = self._discover_client_class(db_info.get("engine", "")) + client_class = self._discover_client_class(db_info) db_params = db_info["credentials"].copy() db_params.update({"connection_name": conn_alias}) connection: "BaseDBAsyncClient" = client_class(**db_params) diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index 97dc1aef4..123971ae0 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -27,7 +27,7 @@ def insensitive_posix_sqlite_regexp(field: Term, value: str): ) -async def install_regexp_function(connection: aiosqlite.Connection): +async def install_regexp_functions(connection: aiosqlite.Connection): def regexp(expr, item): if not expr or not item: return False From 4df395cd41dec6bd08e9d22dd45601b35bfea0cd Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Mon, 16 Dec 2024 14:38:28 +0100 Subject: [PATCH 16/17] make style & make lint --- tests/contrib/test_pydantic.py | 1 + tests/test_connection.py | 8 ++++++-- tests/test_posix_regex_filter.py | 2 +- tortoise/backends/base_postgres/client.py | 4 +++- tortoise/backends/mysql/client.py | 6 +++++- tortoise/backends/sqlite/__init__.py | 3 ++- tortoise/backends/sqlite/client.py | 21 ++++++++++++--------- tortoise/connection.py | 4 +++- 8 files changed, 33 insertions(+), 16 deletions(-) diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index 1d03be6eb..a4fcd4466 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -7,6 +7,7 @@ Address, CamelCaseAliasPerson, Employee, + EnumFields, Event, IntFields, JSONFields, diff --git a/tests/test_connection.py b/tests/test_connection.py index 22faf844c..75e44a151 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -159,7 +159,9 @@ def test_create_connection_db_info_str( mocked_get_db_info.assert_called_once_with(alias) mocked_expand_db_url.assert_called_once_with("some_db_url") - mocked_discover_client_class.assert_called_once_with({"engine": "some_engine", "credentials": {"cred_key": "some_val"}}) + mocked_discover_client_class.assert_called_once_with( + {"engine": "some_engine", "credentials": {"cred_key": "some_val"}} + ) expected_client_class.assert_called_once_with(**expected_db_params) self.assertEqual(ret_val, "some_connection") @@ -185,7 +187,9 @@ def test_create_connection_db_info_not_str( mocked_get_db_info.assert_called_once_with(alias) mocked_expand_db_url.assert_not_called() - mocked_discover_client_class.assert_called_once_with({"engine": "some_engine", "credentials": {"cred_key": "some_val"}}) + mocked_discover_client_class.assert_called_once_with( + {"engine": "some_engine", "credentials": {"cred_key": "some_val"}} + ) expected_client_class.assert_called_once_with(**expected_db_params) self.assertEqual(ret_val, "some_connection") diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py index 0d09b01ce..422f100fb 100644 --- a/tests/test_posix_regex_filter.py +++ b/tests/test_posix_regex_filter.py @@ -1,12 +1,12 @@ from tests import testmodels from tortoise.contrib import test + class RegexTestCase(test.TestCase): async def asyncSetUp(self) -> None: await super().asyncSetUp() - class TestPosixRegexFilter(test.TestCase): @test.requireCapability(support_for_posix_regex_queries=True) diff --git a/tortoise/backends/base_postgres/client.py b/tortoise/backends/base_postgres/client.py index 365acb78a..ddca7fbe5 100644 --- a/tortoise/backends/base_postgres/client.py +++ b/tortoise/backends/base_postgres/client.py @@ -53,7 +53,9 @@ class BasePostgresClient(BaseDBAsyncClient, abc.ABC): query_class: Type[PostgreSQLQuery] = PostgreSQLQuery executor_class: Type[BasePostgresExecutor] = BasePostgresExecutor schema_generator: Type[BasePostgresSchemaGenerator] = BasePostgresSchemaGenerator - capabilities = Capabilities("postgres", support_update_limit_order_by=False, support_for_posix_regex_queries=True) + capabilities = Capabilities( + "postgres", support_update_limit_order_by=False, support_for_posix_regex_queries=True + ) connection_class: "Optional[Union[AsyncConnection, Connection]]" = None loop: Optional[AbstractEventLoop] = None _pool: Optional[Any] = None diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index c2c6c34be..983a1f969 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -74,7 +74,11 @@ class MySQLClient(BaseDBAsyncClient): executor_class = MySQLExecutor schema_generator = MySQLSchemaGenerator capabilities = Capabilities( - "mysql", requires_limit=True, inline_comment=True, support_index_hint=True, support_for_posix_regex_queries=True + "mysql", + requires_limit=True, + inline_comment=True, + support_index_hint=True, + support_for_posix_regex_queries=True, ) def __init__( diff --git a/tortoise/backends/sqlite/__init__.py b/tortoise/backends/sqlite/__init__.py index 5c51ddf6f..a32652fa0 100644 --- a/tortoise/backends/sqlite/__init__.py +++ b/tortoise/backends/sqlite/__init__.py @@ -2,8 +2,9 @@ client_class = SqliteClient + def get_client_class(db_info: dict): if db_info.get("credentials", {}).get("install_regexp_functions"): return SqliteClientWithRegexpSupport else: - return SqliteClient \ No newline at end of file + return SqliteClient diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 24411a00a..433b8be7b 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -12,7 +12,7 @@ Sequence, Tuple, TypeVar, - cast, override, + cast, ) import aiosqlite @@ -29,8 +29,10 @@ ) from tortoise.backends.sqlite.executor import SqliteExecutor from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator -from tortoise.contrib.sqlite.regex import install_regexp_functions as install_regexp_functions_to_db from tortoise.connection import connections +from tortoise.contrib.sqlite.regex import ( + install_regexp_functions as install_regexp_functions_to_db, +) from tortoise.exceptions import ( IntegrityError, OperationalError, @@ -278,14 +280,15 @@ def _gen_savepoint_name(_c=count()) -> str: class SqliteClientWithRegexpSupport(SqliteClient): capabilities = Capabilities( - "sqlite", daemon=False, requires_limit=True, inline_comment=True, support_for_update=False, support_for_posix_regex_queries=True + "sqlite", + daemon=False, + requires_limit=True, + inline_comment=True, + support_for_update=False, + support_for_posix_regex_queries=True, ) async def create_connection(self, with_db: bool) -> None: await super().create_connection(with_db) - await install_regexp_functions_to_db(self._connection) - - -# class TransactionWrapperWithRegexpSupport(TransactionWrapper): -# def __init__(self, connection: SqliteClientWithRegexpSupport) -> None: -# super().__init__(connection) \ No newline at end of file + if self._connection: + await install_regexp_functions_to_db(self._connection) diff --git a/tortoise/connection.py b/tortoise/connection.py index aee908468..4ac511c88 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -75,7 +75,9 @@ def _discover_client_class(self, db_info: dict) -> Type["BaseDBAsyncClient"]: else: client_class = engine_module.client_class except AttributeError: - raise ConfigurationError(f'Backend for engine "{engine_str}" does not implement db client') + raise ConfigurationError( + f'Backend for engine "{engine_str}" does not implement db client' + ) return client_class def _get_db_info(self, conn_alias: str) -> Union[str, Dict]: From aa477532d94c77f7be01f5337864a70e1d756d04 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Mon, 30 Dec 2024 09:20:21 +0100 Subject: [PATCH 17/17] fix import issue after rebase --- tortoise/contrib/sqlite/regex.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index 123971ae0..e70e87a5f 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -3,9 +3,9 @@ from typing import cast import aiosqlite -from pypika.enums import SqlTypes -from pypika.functions import Cast, Coalesce -from pypika.terms import BasicCriterion, Term +from pypika_tortoise.enums import SqlTypes +from pypika_tortoise.functions import Cast, Coalesce +from pypika_tortoise.terms import BasicCriterion, Term class SQLiteRegexMatching(enum.Enum):