From b59c7c0462e8f8c051ecb55038be9de2bf71a094 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 29 Jan 2025 18:48:54 +0000 Subject: [PATCH] feat: implement test framework and move to sync only calls --- .pre-commit-config.yaml | 2 +- .vscode/settings.json | 6 +- pyproject.toml | 4 +- src/dma/__init__.py | 2 +- src/dma/cli/main.py | 59 +++++++-------- src/dma/collector/dependencies.py | 14 ++-- src/dma/collector/query_managers.py | 34 ++++----- src/dma/collector/workflows/base.py | 4 +- .../workflows/collection_extractor/base.py | 50 ++++++------- .../workflows/readiness_check/base.py | 12 +-- src/dma/lib/db/base.py | 15 ++-- src/dma/lib/db/query_manager.py | 37 +++++----- tests/integration/oracle/conftest.py | 45 +++++------- .../oracle/test_base_connectivity.py | 8 +- tests/integration/postgres/conftest.py | 73 +++++++++---------- .../postgres/test_base_connectivity.py | 11 +-- tests/integration/postgres/test_cli.py | 8 +- 17 files changed, 183 insertions(+), 201 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d46b594d..fc0fc3b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: # Ruff replaces black, flake8, autoflake and isort - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.8.0" # make sure this is always consistent with hatch configs + rev: "v0.9.3" # make sure this is always consistent with hatch configs hooks: - id: ruff args: [--config, ./pyproject.toml] diff --git a/.vscode/settings.json b/.vscode/settings.json index 8fd09fb2..f359ab17 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -150,8 +150,7 @@ "source.convertImportFormat" ], "sqltools.disableReleaseNotifications": true, - "sqltools.disableNodeDetectNotifications": true, - "cloudcode.duetAI.enable": true, +"sqltools.disableNodeDetectNotifications": true, "cloudcode.compute.sshInternalIp": true, "python.testing.pytestArgs": [ "tests" @@ -187,5 +186,6 @@ "python.analysis.experimentalserver": true, "python.analysis.diagnosticSeverityOverrides": { "reportUnknownMemberType": "none" -} +}, +"geminicodeassist.enable": true } diff --git a/pyproject.toml b/pyproject.toml index 093c624a..db4dd5f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ dependencies = [ "sqlalchemy>=2.0.25", "typing-extensions>=4.0.0", "msgspec", - "greenlet; sys_platform == \"darwin\"", + "greenlet", ] [project.urls] @@ -75,7 +75,7 @@ Source = "https://github.com/GoogleCloudPlatform/database-assessment" mssql = ["aioodbc"] mysql = ["asyncmy>=0.2.9"] oracle = ["oracledb"] -postgres = ["asyncpg>=0.29.0"] +postgres = ["psycopg[pool,binary]"] server = ["litestar[structlog,jinja]>=2.7.0", "litestar-granian>=0.2.3"] diff --git a/src/dma/__init__.py b/src/dma/__init__.py index 05bfc7d2..3ca76154 100644 --- a/src/dma/__init__.py +++ b/src/dma/__init__.py @@ -30,7 +30,7 @@ click.rich_click.ERRORS_EPILOGUE = """ For additional support, refer to the documentation at https://googlecloudplatform.github.io/database-assessment/ """ -click.rich_click.MAX_WIDTH = 80 +click.rich_click.WIDTH = 80 click.rich_click.SHOW_METAVARS_COLUMN = True click.rich_click.APPEND_METAVARS_HELP = True click.rich_click.STYLE_OPTION = "bold cyan" diff --git a/src/dma/cli/main.py b/src/dma/cli/main.py index 20f38d16..8cc0027c 100644 --- a/src/dma/cli/main.py +++ b/src/dma/cli/main.py @@ -13,7 +13,6 @@ # limitations under the License. 
from __future__ import annotations -import asyncio from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -147,26 +146,23 @@ def collect_data( password = prompt.Prompt.ask("Please enter a password", password=True) input_confirmed = True if no_prompt else prompt.Confirm.ask("Are you ready to start the assessment?") if input_confirmed: - loop = asyncio.get_event_loop() - loop.run_until_complete( - _collect_data( - console=console, - src_info=SourceInfo( - db_type=db_type.upper(), # type: ignore[arg-type] - username=username, - password=password, - hostname=hostname, - port=port, - ), - database=database, - collection_identifier=collection_identifier, - ) + _collect_data( + console=console, + src_info=SourceInfo( + db_type=db_type.upper(), # type: ignore[arg-type] + username=username, + password=password, + hostname=hostname, + port=port, + ), + database=database, + collection_identifier=collection_identifier, ) else: console.rule("Skipping execution until input is confirmed", align="left") -async def _collect_data( +def _collect_data( console: Console, src_info: SourceInfo, database: str, @@ -185,7 +181,7 @@ async def _collect_data( console=console, collection_identifier=collection_identifier, ) - await collection_extractor.execute() + collection_extractor.execute() collection_extractor.dump_database(working_path) @@ -291,26 +287,23 @@ def readiness_assessment( password = prompt.Prompt.ask("Please enter a password", password=True) input_confirmed = True if no_prompt else prompt.Confirm.ask("Are you ready to start the assessment?") if input_confirmed: - loop = asyncio.get_event_loop() - loop.run_until_complete( - _readiness_check( - console=console, - src_info=SourceInfo( - db_type=db_type.upper(), # type: ignore[arg-type] - username=username, - password=password, - hostname=hostname, - port=port, - ), - database=database, - collection_identifier=collection_identifier, - ) + _readiness_check( + console=console, + src_info=SourceInfo( + db_type=db_type.upper(), # type: ignore[arg-type] + username=username, + password=password, + hostname=hostname, + port=port, + ), + database=database, + collection_identifier=collection_identifier, ) else: console.rule("Skipping execution until input is confirmed", align="left") -async def _readiness_check( +def _readiness_check( console: Console, src_info: SourceInfo, database: str, @@ -328,7 +321,7 @@ async def _readiness_check( collection_identifier=collection_identifier, working_path=working_path, ) - await workflow.execute() + workflow.execute() console.print(Padding("", 1, expand=True)) console.rule("Processing collected data.", align="left") workflow.print_summary() diff --git a/src/dma/collector/dependencies.py b/src/dma/collector/dependencies.py index b0e31c53..fb9e8f2d 100644 --- a/src/dma/collector/dependencies.py +++ b/src/dma/collector/dependencies.py @@ -27,19 +27,19 @@ from dma.lib.exceptions import ApplicationError if TYPE_CHECKING: - from collections.abc import AsyncIterator, Generator + from collections.abc import Generator, Iterator from pathlib import Path import duckdb - from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm import Session -async def provide_collection_query_manager( - db_session: AsyncSession, +def provide_collection_query_manager( + db_session: Session, execution_id: str | None = None, source_id: str | None = None, manual_id: str | None = None, -) -> AsyncIterator[CollectionQueryManager]: +) -> Iterator[CollectionQueryManager]: """Provide collection 
query manager.

    Uses SQLAlchemy Connection management to establish and retrieve a valid database session.

    The driver dialect is detected from the session and the underlying raw DBAPI connection is fetched and passed to the Query Manager.
    """
    dialect = db_session.bind.dialect if db_session.bind is not None else db_session.get_bind().dialect
-    db_connection = await db_session.connection()
+    db_connection = db_session.connection()

-    raw_connection = await db_connection.get_raw_connection()
+    raw_connection = db_connection.connection
     if not raw_connection.driver_connection:
         msg = "Unable to fetch raw connection from session."
         raise ApplicationError(msg)
diff --git a/src/dma/collector/query_managers.py b/src/dma/collector/query_managers.py
index 44bd9cf6..dac1102b 100644
--- a/src/dma/collector/query_managers.py
+++ b/src/dma/collector/query_managers.py
@@ -47,13 +47,13 @@ def __init__(
         self.manual_id = manual_id
         super().__init__(connection, queries)

-    async def execute_ddl_scripts(self, *args: Any, **kwargs: Any) -> None:
+    def execute_ddl_scripts(self, *args: Any, **kwargs: Any) -> None:
         """Execute pre-processing queries."""
         console.print(Padding("CANONICAL DATA MODEL", 1, style="bold", expand=True), width=80)
         with console.status("[bold green]Creating tables...[/]") as status:
             for script in self.available_queries("ddl"):
                 status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]")
-                await self.execute(script)
+                self.execute(script)
                 status.console.print(rf" [green]:heavy_check_mark:[/] Created [bold magenta]`{script}`[/]")
         if not self.available_queries("ddl"):
             console.print("  [dim grey]:heavy_check_mark: No DDL scripts to load[/]")
@@ -91,7 +91,7 @@ def get_extended_collection_queries(self) -> set[str]:
             raise ApplicationError(msg)
         return set(self.available_queries("extended_collection"))

-    def get_per_db_collection_queries(self) -> set[str]:  # noqa: PLR6301
+    def get_per_db_collection_queries(self) -> set[str]:
         """Get the collection queries that need to be executed for each DB in the instance"""
         msg = "Implement this execution method."
raise NotImplementedError(msg) @@ -102,7 +102,7 @@ def get_db_version(self) -> str: raise ApplicationError(msg) return self.db_version - async def set_identifiers( + def set_identifiers( self, execution_id: str | None = None, source_id: str | None = None, @@ -119,7 +119,7 @@ async def set_identifiers( if db_version is not None: self.db_version = db_version if self.execution_id is None or self.source_id is None or self.db_version is None: - init_results = await self.execute_init_queries() + init_results = self.execute_init_queries() self.source_id = ( source_id if source_id is not None else cast("str | None", init_results.get("init_get_source_id", None)) ) @@ -139,7 +139,7 @@ async def set_identifiers( if self.expected_collection_queries is None: self.expected_collection_queries = self.get_collection_queries() - async def execute_init_queries( + def execute_init_queries( self, *args: Any, **kwargs: Any, @@ -150,7 +150,7 @@ async def execute_init_queries( results: dict[str, Any] = {} for script in self.available_queries("init"): status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]") - script_result = await self.select_one_value(script) + script_result = self.select_one_value(script) results[script] = script_result status.console.print(rf" [green]:heavy_check_mark:[/] Gathered [bold magenta]`{script}`[/]") if not self.available_queries("init"): @@ -159,7 +159,7 @@ async def execute_init_queries( ) return results - async def execute_collection_queries( + def execute_collection_queries( self, execution_id: str | None = None, source_id: str | None = None, @@ -168,13 +168,13 @@ async def execute_collection_queries( **kwargs: Any, ) -> dict[str, Any]: """Execute pre-processing queries.""" - await self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id) + self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id) console.print(Padding("COLLECTION QUERIES", 1, style="bold", expand=True), width=80) with console.status("[bold green]Executing queries...[/]") as status: results: dict[str, Any] = {} for script in self.get_collection_queries(): status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]") - script_result = await self.select( + script_result = self.select( script, PKEY=self.execution_id, DMA_SOURCE_ID=self.source_id, DMA_MANUAL_ID=self.manual_id ) results[script] = script_result @@ -183,7 +183,7 @@ async def execute_collection_queries( status.console.print(" [dim grey]:heavy_check_mark: No collection queries for this database type[/]") return results - async def execute_extended_collection_queries( + def execute_extended_collection_queries( self, execution_id: str | None = None, source_id: str | None = None, @@ -195,13 +195,13 @@ async def execute_extended_collection_queries( Returns: None """ - await self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id) + self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id) console.print(Padding("EXTENDED COLLECTION QUERIES", 1, style="bold", expand=True), width=80) with console.status("[bold green]Executing queries...[/]") as status: results: dict[str, Any] = {} for script in self.get_extended_collection_queries(): status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]") - script_result = await self.select( + script_result = self.select( script, PKEY=self.execution_id, DMA_SOURCE_ID=self.source_id, DMA_MANUAL_ID=self.manual_id ) results[script] = script_result @@ -210,7 +210,7 @@ 
async def execute_extended_collection_queries( console.print(" [dim grey]:heavy_check_mark: No extended collection queries for this database type[/]") return results - async def execute_per_db_collection_queries( + def execute_per_db_collection_queries( self, execution_id: str | None = None, source_id: str | None = None, @@ -219,14 +219,14 @@ async def execute_per_db_collection_queries( **kwargs: Any, ) -> dict[str, Any]: """Execute per DB pre-processing queries.""" - await self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id) + self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id) console.print(Padding("PER DB QUERIES", 1, style="bold", expand=True), width=80) with console.status("[bold green]Executing queries...[/]") as status: results: dict[str, Any] = {} for script in self.get_per_db_collection_queries(): status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]") try: - script_result = await self.select( + script_result = self.select( script, PKEY=self.execution_id, DMA_SOURCE_ID=self.source_id, DMA_MANUAL_ID=self.manual_id ) results[script] = script_result @@ -248,7 +248,7 @@ def __init__( source_id: str | None = None, manual_id: str | None = None, queries: Queries = aiosql.from_path( - sql_path=f"{_root_path}/collector/sql/sources/postgres", driver_adapter="asyncpg" + sql_path=f"{_root_path}/collector/sql/sources/postgres", driver_adapter="psycopg" ), ) -> None: super().__init__( diff --git a/src/dma/collector/workflows/base.py b/src/dma/collector/workflows/base.py index 1d8dd936..ec0df6c9 100644 --- a/src/dma/collector/workflows/base.py +++ b/src/dma/collector/workflows/base.py @@ -42,9 +42,9 @@ def __init__( self.db_type = db_type self.canonical_query_manager = canonical_query_manager - async def execute(self) -> None: + def execute(self) -> None: """Execute Workflow""" - await self.canonical_query_manager.execute_ddl_scripts() + self.canonical_query_manager.execute_ddl_scripts() def import_to_table(self, data: dict[str, list[dict]]) -> None: """Load a dictionary of result sets into duckdb. 
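Note for reviewers (not part of the patch itself): the collection-extractor and readiness-check diffs that follow all converge on one sync pattern: a plain `Session` replaces `AsyncSession`, and the dependency provider becomes an ordinary generator consumed with `next()` instead of `anext()`. A minimal sketch of that pattern, assuming the sync APIs introduced in this change; the `SourceInfo` values are illustrative placeholders:

    from sqlalchemy.orm import Session

    from dma.collector.dependencies import provide_collection_query_manager
    from dma.lib.db.base import SourceInfo, get_engine

    src_info = SourceInfo(  # illustrative connection values only
        db_type="POSTGRES",
        username="dma",
        password="example",
        hostname="localhost",
        port=5432,
    )
    engine = get_engine(src_info, database="postgres")  # now returns a sync Engine
    with Session(engine) as db_session:
        # the async `await anext(...)` handshake becomes a plain `next(...)`
        manager = next(provide_collection_query_manager(db_session=db_session))
        results = manager.execute_collection_queries()
    engine.dispose()
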
diff --git a/src/dma/collector/workflows/collection_extractor/base.py b/src/dma/collector/workflows/collection_extractor/base.py index 88c61550..fffa88f0 100644 --- a/src/dma/collector/workflows/collection_extractor/base.py +++ b/src/dma/collector/workflows/collection_extractor/base.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING from rich.table import Table -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from dma.__about__ import __version__ as current_version from dma.collector.dependencies import provide_collection_query_manager @@ -47,43 +47,43 @@ def __init__( self.collection_identifier = collection_identifier super().__init__(local_db, canonical_query_manager, src_info.db_type, console) - async def execute(self) -> None: - await super().execute() + def execute(self) -> None: + super().execute() execution_id = ( f"{self.src_info.db_type}_{current_version!s}_{datetime.now(tz=timezone.utc).strftime('%y%m%d%H%M%S')}" ) - await self.collect_data(execution_id) - await self.collect_db_specific_data(execution_id) + self.collect_data(execution_id) + self.collect_db_specific_data(execution_id) - async def collect_data(self, execution_id: str) -> None: - async_engine = get_engine(self.src_info, self.database) - async with AsyncSession(async_engine) as db_session: - collection_manager = await anext( # noqa: F821 # pyright: ignore[reportUndefinedVariable] + def collect_data(self, execution_id: str) -> None: + sync_engine = get_engine(self.src_info, self.database) + with Session(sync_engine) as db_session: + collection_manager = next( provide_collection_query_manager( db_session=db_session, execution_id=execution_id, manual_id=self.collection_identifier ) ) - await self.extract_collection(collection_manager) - await self.extract_extended_collection(collection_manager) - await self.process_collection() + self.extract_collection(collection_manager) + self.extract_extended_collection(collection_manager) + self.process_collection() self.db_version = collection_manager.get_db_version() - await async_engine.dispose() + sync_engine.dispose() - async def collect_db_specific_data(self, execution_id: str) -> None: - dbs = await self.get_all_dbs() + def collect_db_specific_data(self, execution_id: str) -> None: + dbs = self.get_all_dbs() for db in dbs: async_engine = get_engine(src_info=self.src_info, database=db) - async with AsyncSession(async_engine) as db_session: - collection_manager = await anext( # noqa: F821 # pyright: ignore[reportUndefinedVariable] + with Session(async_engine) as db_session: + collection_manager = next( provide_collection_query_manager( db_session=db_session, execution_id=execution_id, manual_id=self.collection_identifier ) ) - db_collection = await collection_manager.execute_per_db_collection_queries() + db_collection = collection_manager.execute_per_db_collection_queries() self.import_to_table(db_collection) - await async_engine.dispose() + async_engine.dispose() - async def get_all_dbs(self) -> set[str]: + def get_all_dbs(self) -> set[str]: result = self.local_db.sql(""" select database_name from extended_collection_postgres_all_databases """).fetchall() @@ -95,15 +95,15 @@ def get_db_version(self) -> str: raise ApplicationError(msg) return self.db_version - async def extract_collection(self, collection_query_manager: CollectionQueryManager) -> None: - collection = await collection_query_manager.execute_collection_queries() + def extract_collection(self, collection_query_manager: CollectionQueryManager) -> None: + collection = 
collection_query_manager.execute_collection_queries() self.import_to_table(collection) - async def extract_extended_collection(self, collection_query_manager: CollectionQueryManager) -> None: - extended_collection = await collection_query_manager.execute_extended_collection_queries() + def extract_extended_collection(self, collection_query_manager: CollectionQueryManager) -> None: + extended_collection = collection_query_manager.execute_extended_collection_queries() self.import_to_table(extended_collection) - async def process_collection(self) -> None: + def process_collection(self) -> None: """Process Collections""" def print_summary(self) -> None: diff --git a/src/dma/collector/workflows/readiness_check/base.py b/src/dma/collector/workflows/readiness_check/base.py index 61d188a5..2d40b010 100644 --- a/src/dma/collector/workflows/readiness_check/base.py +++ b/src/dma/collector/workflows/readiness_check/base.py @@ -66,11 +66,11 @@ def __init__( self.collection_identifier = collection_identifier self.working_path = working_path - async def execute(self) -> None: - await self.execute_data_collection() + def execute(self) -> None: + self.execute_data_collection() self.execute_readiness_check() - async def execute_data_collection(self) -> None: + def execute_data_collection(self) -> None: canonical_query_manager = next( provide_canonical_queries(local_db=self.local_db, working_path=self.working_path) ) @@ -83,7 +83,7 @@ async def execute_data_collection(self) -> None: console=self.console, collection_identifier=self.collection_identifier, ) - await self.collection_extractor.execute() + self.collection_extractor.execute() self.db_version = self.collection_extractor.get_db_version() def execute_readiness_check(self) -> None: @@ -134,7 +134,7 @@ def __init__(self, console: Console, readiness_check: ReadinessCheck) -> None: self.local_db = readiness_check.local_db self.db_version = readiness_check.db_version - def execute(self) -> None: # noqa: PLR6301 + def execute(self) -> None: """Execute checks""" msg = "Implement this execution method." raise NotImplementedError(msg) @@ -145,7 +145,7 @@ def get_all_dbs(self) -> set[str]: """).fetchall() return {row[0] for row in result} - def print_summary(self) -> None: # noqa: PLR6301 + def print_summary(self) -> None: """Summarizes results""" msg = "Implement this execution method." 
raise NotImplementedError(msg) diff --git a/src/dma/lib/db/base.py b/src/dma/lib/db/base.py index a43ace3c..86278d36 100644 --- a/src/dma/lib/db/base.py +++ b/src/dma/lib/db/base.py @@ -16,8 +16,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from sqlalchemy import URL -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlalchemy import URL, Engine, create_engine if TYPE_CHECKING: from dma.types import ( @@ -37,11 +36,11 @@ class SourceInfo: def get_engine( src_info: SourceInfo, database: str, -) -> AsyncEngine: +) -> Engine: if src_info.db_type == "POSTGRES": - return create_async_engine( + return create_engine( URL( - drivername="postgresql+asyncpg", + drivername="postgresql+psycopg", username=src_info.username, password=src_info.password, host=src_info.hostname, @@ -51,7 +50,7 @@ def get_engine( ), ) if src_info.db_type == "MYSQL": - return create_async_engine( + return create_engine( URL( drivername="mysql+asyncmy", username=src_info.username, @@ -63,7 +62,7 @@ def get_engine( ), ) if src_info.db_type == "MSSQL": - return create_async_engine( + return create_engine( URL( drivername="mssql+aioodbc", username=src_info.username, @@ -83,7 +82,7 @@ def get_engine( ), ) if src_info.db_type == "ORACLE": - return create_async_engine( + return create_engine( "oracle+oracledb://:@", thick_mode=False, connect_args={ diff --git a/src/dma/lib/db/query_manager.py b/src/dma/lib/db/query_manager.py index 2ff81b7d..830161f3 100644 --- a/src/dma/lib/db/query_manager.py +++ b/src/dma/lib/db/query_manager.py @@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Any, TypeVar from dma.lib.exceptions import ApplicationError -from dma.utils import maybe_async faulthandler.enable() if TYPE_CHECKING: - from collections.abc import AsyncIterator + from collections.abc import Iterator from aiosql.queries import Queries @@ -50,37 +49,37 @@ def available_queries(self, prefix: str | None = None) -> list[str]: ) @classmethod - @contextlib.asynccontextmanager - async def from_connection( + @contextlib.contextmanager + def from_connection( cls: type[QueryManagerT], queries: Queries, connection: Any, - ) -> AsyncIterator[QueryManagerT]: + ) -> Iterator[QueryManagerT]: """Context manager that returns instance of query manager object.""" yield cls(connection=connection, queries=queries) - async def select(self, method: str, **binds: Any) -> list[dict[str, Any]]: - data = await maybe_async(self.fn(method)(conn=self.connection, **binds)) + def select(self, method: str, **binds: Any) -> list[dict[str, Any]]: + data = self.fn(method)(conn=self.connection, **binds) return [dict(row) for row in data] - async def select_one(self, method: str, **binds: Any) -> dict[str, Any]: - data = await maybe_async(self.fn(method)(conn=self.connection, **binds)) + def select_one(self, method: str, **binds: Any) -> dict[str, Any]: + data = self.fn(method)(conn=self.connection, **binds) return dict(data) - async def select_one_value(self, method: str, **binds: Any) -> Any: - return await maybe_async(self.fn(method)(conn=self.connection, **binds)) + def select_one_value(self, method: str, **binds: Any) -> Any: + return self.fn(method)(conn=self.connection, **binds) - async def insert_update_delete(self, method: str, **binds: Any) -> None: - return await maybe_async(self.fn(method)(conn=self.connection, **binds)) + def insert_update_delete(self, method: str, **binds: Any) -> None: + return self.fn(method)(conn=self.connection, **binds) - async def insert_update_delete_many(self, method: str, 
**binds: Any) -> Any | None:
-        return await maybe_async(self.fn(method)(conn=self.connection, **binds))
+    def insert_update_delete_many(self, method: str, **binds: Any) -> Any | None:
+        return self.fn(method)(conn=self.connection, **binds)

-    async def insert_returning(self, method: str, **binds: Any) -> Any | None:
-        return await maybe_async(self.fn(method)(conn=self.connection, **binds))
+    def insert_returning(self, method: str, **binds: Any) -> Any | None:
+        return self.fn(method)(conn=self.connection, **binds)

-    async def execute(self, method: str, **binds: Any) -> Any:
-        return await maybe_async(self.fn(method)(conn=self.connection, **binds))
+    def execute(self, method: str, **binds: Any) -> Any:
+        return self.fn(method)(conn=self.connection, **binds)

     def fn(self, method: str) -> Any:
         try:
diff --git a/tests/integration/oracle/conftest.py b/tests/integration/oracle/conftest.py
index 0efbb78a..d7e4b703 100644
--- a/tests/integration/oracle/conftest.py
+++ b/tests/integration/oracle/conftest.py
@@ -16,12 +16,14 @@
 from __future__ import annotations

 import platform
-from typing import cast
+from typing import TYPE_CHECKING, cast

 import pytest
 from pytest import FixtureRequest
-from sqlalchemy import NullPool
-from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
+from sqlalchemy import Engine, NullPool, create_engine
+
+if TYPE_CHECKING:
+    from collections.abc import Generator

 pytestmark = [
     pytest.mark.anyio,
@@ -32,21 +34,16 @@


 @pytest.fixture(scope="session")
-async def oracle18c_async_engine(
+def oracle18c_sync_engine(
     oracle_docker_ip: str,
     oracle_user: str,
     oracle_password: str,
     oracle18c_port: int,
     oracle18c_service_name: str,
     oracle18c_service: None,
-) -> AsyncEngine:
-    """Oracle 18c instance for end-to-end testing.
-
-
-    Returns:
-        Async SQLAlchemy engine instance.
-    """
-    return create_async_engine(
+) -> Generator[Engine, None, None]:
+    """Oracle 18c instance for end-to-end testing."""
+    yield create_engine(
         "oracle+oracledb://:@",
         thick_mode=False,
         connect_args={
@@ -61,22 +58,16 @@


 @pytest.fixture(scope="session")
-async def oracle23ai_async_engine(
+def oracle23ai_sync_engine(
     oracle_docker_ip: str,
     oracle_user: str,
     oracle_password: str,
     oracle23ai_port: int,
     oracle23ai_service_name: str,
     oracle23ai_service: None,
-) -> AsyncEngine:
-    """Oracle 23c instance for end-to-end testing.
-
-
-
-    Returns:
-        Async SQLAlchemy engine instance.
- """ - return create_async_engine( +) -> Generator[Engine, None, None]: + """Oracle 23c instance for end-to-end testing.""" + yield create_engine( "oracle+oracledb://:@", thick_mode=False, connect_args={ @@ -92,17 +83,17 @@ async def oracle23ai_async_engine( @pytest.fixture( scope="session", - name="async_engine", + name="sync_engine", params=[ pytest.param( - "oracle18c_async_engine", + "oracle18c_sync_engine", marks=[pytest.mark.oracle], ), pytest.param( - "oracle23ai_async_engine", + "oracle23ai_sync_engine", marks=[pytest.mark.oracle], ), ], ) -def async_engine(request: FixtureRequest) -> AsyncEngine: - return cast("AsyncEngine", request.getfixturevalue(request.param)) +def sync_engine(request: FixtureRequest) -> Generator[Engine, None, None]: + yield cast("Engine", request.getfixturevalue(request.param)) diff --git a/tests/integration/oracle/test_base_connectivity.py b/tests/integration/oracle/test_base_connectivity.py index 4c3f0f07..abfcefab 100644 --- a/tests/integration/oracle/test_base_connectivity.py +++ b/tests/integration/oracle/test_base_connectivity.py @@ -21,7 +21,7 @@ from sqlalchemy import text if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncEngine + from sqlalchemy import Engine pytestmark = [ @@ -31,8 +31,8 @@ ] -async def test_engine_connectivity(async_engine: AsyncEngine) -> None: - async with async_engine.begin() as conn: - await conn.execute( +def test_engine_connectivity(sync_engine: Engine) -> None: + with sync_engine.begin() as conn: + conn.execute( text("select 1 from dual"), ) diff --git a/tests/integration/postgres/conftest.py b/tests/integration/postgres/conftest.py index 49dfd816..2a9abd61 100644 --- a/tests/integration/postgres/conftest.py +++ b/tests/integration/postgres/conftest.py @@ -21,8 +21,7 @@ import pytest from click.testing import CliRunner from pytest import FixtureRequest -from sqlalchemy import URL, NullPool, text -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlalchemy import URL, Engine, NullPool, create_engine, text if TYPE_CHECKING: from collections.abc import Generator @@ -40,18 +39,18 @@ def runner() -> CliRunner: @pytest.fixture(scope="session") -async def postgres17_async_engine( +def postgres17_sync_engine( postgres_docker_ip: str, postgres_user: str, postgres_password: str, postgres_database: str, postgres17_port: int, postgres17_service: None, -) -> AsyncEngine: +) -> Generator[Engine, None, None]: """Postgresql instance for end-to-end testing.""" - return create_async_engine( + yield create_engine( URL( - drivername="postgresql+asyncpg", + drivername="postgresql+psycopg", username=postgres_user, password=postgres_password, host=postgres_docker_ip, @@ -64,18 +63,18 @@ async def postgres17_async_engine( @pytest.fixture(scope="session") -async def postgres16_async_engine( +def postgres16_sync_engine( postgres_docker_ip: str, postgres_user: str, postgres_password: str, postgres_database: str, postgres16_port, postgres16_service: None, -) -> AsyncEngine: +) -> Generator[Engine, None, None]: """Postgresql instance for end-to-end testing.""" - return create_async_engine( + yield create_engine( URL( - drivername="postgresql+asyncpg", + drivername="postgresql+psycopg", username=postgres_user, password=postgres_password, host=postgres_docker_ip, @@ -88,18 +87,18 @@ async def postgres16_async_engine( @pytest.fixture(scope="session") -async def postgres15_async_engine( +def postgres15_sync_engine( postgres_docker_ip: str, postgres_user: str, postgres_password: str, postgres_database: str, 
postgres15_port, postgres15_service: None, -) -> AsyncEngine: +) -> Generator[Engine, None, None]: """Postgresql instance for end-to-end testing.""" - return create_async_engine( + yield create_engine( URL( - drivername="postgresql+asyncpg", + drivername="postgresql+psycopg", username=postgres_user, password=postgres_password, host=postgres_docker_ip, @@ -112,18 +111,18 @@ async def postgres15_async_engine( @pytest.fixture(scope="session") -async def postgres14_async_engine( +def postgres14_sync_engine( postgres_docker_ip: str, postgres_user: str, postgres_password: str, postgres_database: str, postgres14_port, postgres14_service: None, -) -> AsyncEngine: +) -> Generator[Engine, None, None]: """Postgresql instance for end-to-end testing.""" - return create_async_engine( + yield create_engine( URL( - drivername="postgresql+asyncpg", + drivername="postgresql+psycopg", username=postgres_user, password=postgres_password, host=postgres_docker_ip, @@ -136,18 +135,18 @@ async def postgres14_async_engine( @pytest.fixture(scope="session") -async def postgres13_async_engine( +def postgres13_sync_engine( postgres_docker_ip: str, postgres_user: str, postgres_password: str, postgres_database: str, postgres13_port, postgres13_service: None, -) -> AsyncEngine: +) -> Generator[Engine, None, None]: """Postgresql instance for end-to-end testing.""" - return create_async_engine( + yield create_engine( URL( - drivername="postgresql+asyncpg", + drivername="postgresql+psycopg", username=postgres_user, password=postgres_password, host=postgres_docker_ip, @@ -160,18 +159,18 @@ async def postgres13_async_engine( @pytest.fixture(scope="session") -async def postgres12_async_engine( +def postgres12_sync_engine( postgres_docker_ip: str, postgres_user: str, postgres_password: str, postgres_database: str, postgres12_port, postgres12_service: None, -) -> AsyncEngine: +) -> Generator[Engine, None, None]: """Postgresql instance for end-to-end testing.""" - return create_async_engine( + yield create_engine( URL( - drivername="postgresql+asyncpg", + drivername="postgresql+psycopg", username=postgres_user, password=postgres_password, host=postgres_docker_ip, @@ -187,48 +186,48 @@ async def postgres12_async_engine( scope="session", params=[ pytest.param( - "postgres12_async_engine", + "postgres12_sync_engine", marks=[ pytest.mark.postgres, ], ), pytest.param( - "postgres13_async_engine", + "postgres13_sync_engine", marks=[ pytest.mark.postgres, ], ), pytest.param( - "postgres14_async_engine", + "postgres14_sync_engine", marks=[ pytest.mark.postgres, ], ), pytest.param( - "postgres15_async_engine", + "postgres15_sync_engine", marks=[ pytest.mark.postgres, ], ), pytest.param( - "postgres16_async_engine", + "postgres16_sync_engine", marks=[ pytest.mark.postgres, ], ), pytest.param( - "postgres17_async_engine", + "postgres17_sync_engine", marks=[ pytest.mark.postgres, ], ), ], ) -def async_engine(request: FixtureRequest) -> Generator[AsyncEngine, None, None]: - yield cast("AsyncEngine", request.getfixturevalue(request.param)) +def sync_engine(request: FixtureRequest) -> Generator[Engine, None, None]: + yield cast("Engine", request.getfixturevalue(request.param)) @pytest.fixture(scope="session") -async def _seed_postgres_database(async_engine: AsyncEngine) -> None: - async with async_engine.begin() as conn: - await conn.execute(text(dedent("""create extension if not exists pg_stat_statements;"""))) +def _seed_postgres_database(sync_engine: Engine) -> None: + with sync_engine.begin() as conn: + conn.execute(text(dedent("""create 
extension if not exists pg_stat_statements;""")))
diff --git a/tests/integration/postgres/test_base_connectivity.py b/tests/integration/postgres/test_base_connectivity.py
index c12139d4..426af535 100644
--- a/tests/integration/postgres/test_base_connectivity.py
+++ b/tests/integration/postgres/test_base_connectivity.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Unit tests for the Oracle."""
+"""Unit tests for Postgres connectivity."""

 from __future__ import annotations

@@ -21,7 +21,7 @@
 from sqlalchemy import text

 if TYPE_CHECKING:
-    from sqlalchemy.ext.asyncio import AsyncEngine
+    from sqlalchemy import Engine


 pytestmark = [
@@ -31,8 +31,9 @@


-async def test_engine_connectivity(async_engine: AsyncEngine) -> None:
-    async with async_engine.begin() as conn:
-        await conn.execute(
+def test_engine_connectivity(sync_engine: Engine) -> None:
+    with sync_engine.begin() as conn:
+        result = conn.execute(
             text("select 1"),
         )
+        assert result.scalar() == 1

diff --git a/tests/integration/postgres/test_cli.py b/tests/integration/postgres/test_cli.py
index ade7b1e1..78a6c599 100644
--- a/tests/integration/postgres/test_cli.py
+++ b/tests/integration/postgres/test_cli.py
@@ -24,7 +24,7 @@

 if TYPE_CHECKING:
     from click.testing import CliRunner
-    from sqlalchemy.ext.asyncio import AsyncEngine
+    from sqlalchemy import Engine

 pytestmark = [
     pytest.mark.anyio,
@@ -33,12 +33,12 @@
 ]


-async def test_cli_postgres(
-    async_engine: AsyncEngine,
+def test_cli_postgres(
+    sync_engine: Engine,
     _seed_postgres_database: None,
     runner: CliRunner,
 ) -> None:
-    url = urlparse(str(async_engine.url))
+    url = urlparse(sync_engine.url.render_as_string(hide_password=False))
     result = runner.invoke(
         app,
         [
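One note on the `render_as_string(hide_password=False)` call in `test_cli_postgres` above: plain `str()` on a SQLAlchemy `URL` masks the password as `***`, which would feed an unusable password into the CLI arguments. A minimal sketch of the difference, using illustrative credentials:

    from urllib.parse import urlparse

    from sqlalchemy import URL

    url = URL.create(
        drivername="postgresql+psycopg",
        username="dma",
        password="example",  # illustrative only
        host="localhost",
        port=5432,
        database="postgres",
    )
    str(url)                                   # 'postgresql+psycopg://dma:***@localhost:5432/postgres'
    url.render_as_string(hide_password=False)  # keeps the real password
    parsed = urlparse(url.render_as_string(hide_password=False))
    assert parsed.password == "example"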