feat: implement test framework and move to sync only calls
cofin committed Feb 6, 2025
1 parent 693c8d6 commit b59c7c0
Showing 17 changed files with 183 additions and 201 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -32,7 +32,7 @@ repos:

# Ruff replaces black, flake8, autoflake and isort
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.8.0" # make sure this is always consistent with hatch configs
rev: "v0.9.3" # make sure this is always consistent with hatch configs
hooks:
- id: ruff
args: [--config, ./pyproject.toml]
6 changes: 3 additions & 3 deletions .vscode/settings.json
@@ -150,8 +150,7 @@
"source.convertImportFormat"
],
"sqltools.disableReleaseNotifications": true,
"sqltools.disableNodeDetectNotifications": true,
"cloudcode.duetAI.enable": true,
"sqltools.disableNodeDetectNotifications": true,
"cloudcode.compute.sshInternalIp": true,
"python.testing.pytestArgs": [
"tests"
@@ -187,5 +186,6 @@
"python.analysis.experimentalserver": true,
"python.analysis.diagnosticSeverityOverrides": {
"reportUnknownMemberType": "none"
}
},
"geminicodeassist.enable": true
}
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -62,7 +62,7 @@ dependencies = [
"sqlalchemy>=2.0.25",
"typing-extensions>=4.0.0",
"msgspec",
"greenlet; sys_platform == \"darwin\"",
"greenlet",
]

[project.urls]
@@ -75,7 +75,7 @@ Source = "https://github.com/GoogleCloudPlatform/database-assessment"
mssql = ["aioodbc"]
mysql = ["asyncmy>=0.2.9"]
oracle = ["oracledb"]
postgres = ["asyncpg>=0.29.0"]
postgres = ["psycopg[pool,binary]"]
server = ["litestar[structlog,jinja]>=2.7.0", "litestar-granian>=0.2.3"]


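The dependency changes above swap the async asyncpg driver for the synchronous psycopg (v3) driver and make greenlet an unconditional dependency rather than macOS-only. A minimal sketch of what the swap implies for engine construction, not taken from this commit (connection URL and credentials are placeholders):

from sqlalchemy import create_engine, text

# The async setup would have used "postgresql+asyncpg://..." with create_async_engine;
# the synchronous psycopg (v3) dialect is addressed as "postgresql+psycopg://...".
engine = create_engine("postgresql+psycopg://user:password@localhost:5432/postgres")

with engine.connect() as conn:
    print(conn.execute(text("select version()")).scalar_one())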
2 changes: 1 addition & 1 deletion src/dma/__init__.py
@@ -30,7 +30,7 @@
click.rich_click.ERRORS_EPILOGUE = """
For additional support, refer to the documentation at https://googlecloudplatform.github.io/database-assessment/
"""
click.rich_click.MAX_WIDTH = 80
click.rich_click.WIDTH = 80
click.rich_click.SHOW_METAVARS_COLUMN = True
click.rich_click.APPEND_METAVARS_HELP = True
click.rich_click.STYLE_OPTION = "bold cyan"
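The change above switches the rich-click setting from MAX_WIDTH to WIDTH; as I read rich-click's options, WIDTH fixes the rendered help width while MAX_WIDTH only caps it. A tiny, self-contained sketch (the command itself is a placeholder, not from the repository):

import rich_click as click

click.rich_click.WIDTH = 80  # render help output at a fixed 80 columns

@click.command()
def cli() -> None:
    """Print a greeting."""
    click.echo("hello")

if __name__ == "__main__":
    cli()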
59 changes: 26 additions & 33 deletions src/dma/cli/main.py
@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations

import asyncio
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Literal
@@ -147,26 +146,23 @@ def collect_data(
password = prompt.Prompt.ask("Please enter a password", password=True)
input_confirmed = True if no_prompt else prompt.Confirm.ask("Are you ready to start the assessment?")
if input_confirmed:
loop = asyncio.get_event_loop()
loop.run_until_complete(
_collect_data(
console=console,
src_info=SourceInfo(
db_type=db_type.upper(), # type: ignore[arg-type]
username=username,
password=password,
hostname=hostname,
port=port,
),
database=database,
collection_identifier=collection_identifier,
)
_collect_data(
console=console,
src_info=SourceInfo(
db_type=db_type.upper(), # type: ignore[arg-type]
username=username,
password=password,
hostname=hostname,
port=port,
),
database=database,
collection_identifier=collection_identifier,
)
else:
console.rule("Skipping execution until input is confirmed", align="left")


async def _collect_data(
def _collect_data(
console: Console,
src_info: SourceInfo,
database: str,
@@ -185,7 +181,7 @@ async def _collect_data(
console=console,
collection_identifier=collection_identifier,
)
await collection_extractor.execute()
collection_extractor.execute()
collection_extractor.dump_database(working_path)


@@ -291,26 +287,23 @@ def readiness_assessment(
password = prompt.Prompt.ask("Please enter a password", password=True)
input_confirmed = True if no_prompt else prompt.Confirm.ask("Are you ready to start the assessment?")
if input_confirmed:
loop = asyncio.get_event_loop()
loop.run_until_complete(
_readiness_check(
console=console,
src_info=SourceInfo(
db_type=db_type.upper(), # type: ignore[arg-type]
username=username,
password=password,
hostname=hostname,
port=port,
),
database=database,
collection_identifier=collection_identifier,
)
_readiness_check(
console=console,
src_info=SourceInfo(
db_type=db_type.upper(), # type: ignore[arg-type]
username=username,
password=password,
hostname=hostname,
port=port,
),
database=database,
collection_identifier=collection_identifier,
)
else:
console.rule("Skipping execution until input is confirmed", align="left")


async def _readiness_check(
def _readiness_check(
console: Console,
src_info: SourceInfo,
database: str,
@@ -328,7 +321,7 @@ async def _readiness_check(
collection_identifier=collection_identifier,
working_path=working_path,
)
await workflow.execute()
workflow.execute()
console.print(Padding("", 1, expand=True))
console.rule("Processing collected data.", align="left")
workflow.print_summary()
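The CLI changes above delete the asyncio import and the event-loop plumbing: _collect_data and _readiness_check are now plain functions invoked directly. A self-contained before/after sketch of that calling pattern (the function names here are illustrative, not from the repository):

import asyncio

async def _collect(x: int) -> int:  # stand-in for the old coroutine
    return x * 2

def _collect_sync(x: int) -> int:  # stand-in for the new plain function
    return x * 2

# Before: the command had to own an event loop just to drive one coroutine.
loop = asyncio.new_event_loop()
try:
    result_old = loop.run_until_complete(_collect(21))
finally:
    loop.close()

# After: an ordinary function call, no loop management.
result_new = _collect_sync(21)
assert result_old == result_new == 42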
14 changes: 7 additions & 7 deletions src/dma/collector/dependencies.py
@@ -27,29 +27,29 @@
from dma.lib.exceptions import ApplicationError

if TYPE_CHECKING:
from collections.abc import AsyncIterator, Generator
from collections.abc import Generator, Iterator
from pathlib import Path

import duckdb
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session


async def provide_collection_query_manager(
db_session: AsyncSession,
def provide_collection_query_manager(
db_session: Session,
execution_id: str | None = None,
source_id: str | None = None,
manual_id: str | None = None,
) -> AsyncIterator[CollectionQueryManager]:
) -> Iterator[CollectionQueryManager]:
"""Provide collection query manager.
Uses SQLAlchemy Connection management to establish and retrieve a valid database session.
The driver dialect is detected from the session and the underlying raw DBAPI connection is fetched and passed to the Query Manager.
"""
dialect = db_session.bind.dialect if db_session.bind is not None else db_session.get_bind().dialect
db_connection = await db_session.connection()
db_connection = db_session.connection()

raw_connection = await db_connection.get_raw_connection()
raw_connection = db_connection.engine.raw_connection()
if not raw_connection.driver_connection:
msg = "Unable to fetch raw connection from session."
raise ApplicationError(msg)
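With the provider now synchronous, the raw DBAPI connection is reached through the ordinary Session API instead of awaited coroutines. A minimal, self-contained sketch of that path, using an in-memory SQLite engine as a stand-in for the real source database:

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")

with Session(engine) as db_session:
    dialect = db_session.get_bind().dialect      # driver dialect, no await needed
    db_connection = db_session.connection()      # synchronous Connection checkout
    raw_connection = db_connection.engine.raw_connection()  # pool-proxied DBAPI connection
    print(dialect.name, raw_connection.driver_connection is not None)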
34 changes: 17 additions & 17 deletions src/dma/collector/query_managers.py
@@ -47,13 +47,13 @@ def __init__(
self.manual_id = manual_id
super().__init__(connection, queries)

async def execute_ddl_scripts(self, *args: Any, **kwargs: Any) -> None:
def execute_ddl_scripts(self, *args: Any, **kwargs: Any) -> None:
"""Execute pre-processing queries."""
console.print(Padding("CANONICAL DATA MODEL", 1, style="bold", expand=True), width=80)
with console.status("[bold green]Creating tables...[/]") as status:
for script in self.available_queries("ddl"):
status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]")
await self.execute(script)
self.execute(script)
status.console.print(rf" [green]:heavy_check_mark:[/] Created [bold magenta]`{script}`[/]")
if not self.available_queries("ddl"):
console.print(" [dim grey]:heavy_check_mark: No DDL scripts to load[/]")
@@ -91,7 +91,7 @@ def get_extended_collection_queries(self) -> set[str]:
raise ApplicationError(msg)
return set(self.available_queries("extended_collection"))

def get_per_db_collection_queries(self) -> set[str]: # noqa: PLR6301
def get_per_db_collection_queries(self) -> set[str]:
"""Get the collection queries that need to be executed for each DB in the instance"""
msg = "Implement this execution method."
raise NotImplementedError(msg)
@@ -102,7 +102,7 @@ def get_db_version(self) -> str:
raise ApplicationError(msg)
return self.db_version

async def set_identifiers(
def set_identifiers(
self,
execution_id: str | None = None,
source_id: str | None = None,
@@ -119,7 +119,7 @@ async def set_identifiers(
if db_version is not None:
self.db_version = db_version
if self.execution_id is None or self.source_id is None or self.db_version is None:
init_results = await self.execute_init_queries()
init_results = self.execute_init_queries()
self.source_id = (
source_id if source_id is not None else cast("str | None", init_results.get("init_get_source_id", None))
)
@@ -139,7 +139,7 @@ async def set_identifiers(
if self.expected_collection_queries is None:
self.expected_collection_queries = self.get_collection_queries()

async def execute_init_queries(
def execute_init_queries(
self,
*args: Any,
**kwargs: Any,
@@ -150,7 +150,7 @@ async def execute_init_queries(
results: dict[str, Any] = {}
for script in self.available_queries("init"):
status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]")
script_result = await self.select_one_value(script)
script_result = self.select_one_value(script)
results[script] = script_result
status.console.print(rf" [green]:heavy_check_mark:[/] Gathered [bold magenta]`{script}`[/]")
if not self.available_queries("init"):
@@ -159,7 +159,7 @@ async def execute_init_queries(
)
return results

async def execute_collection_queries(
def execute_collection_queries(
self,
execution_id: str | None = None,
source_id: str | None = None,
@@ -168,13 +168,13 @@ async def execute_collection_queries(
**kwargs: Any,
) -> dict[str, Any]:
"""Execute pre-processing queries."""
await self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id)
self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id)
console.print(Padding("COLLECTION QUERIES", 1, style="bold", expand=True), width=80)
with console.status("[bold green]Executing queries...[/]") as status:
results: dict[str, Any] = {}
for script in self.get_collection_queries():
status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]")
script_result = await self.select(
script_result = self.select(
script, PKEY=self.execution_id, DMA_SOURCE_ID=self.source_id, DMA_MANUAL_ID=self.manual_id
)
results[script] = script_result
@@ -183,7 +183,7 @@ async def execute_collection_queries(
status.console.print(" [dim grey]:heavy_check_mark: No collection queries for this database type[/]")
return results

async def execute_extended_collection_queries(
def execute_extended_collection_queries(
self,
execution_id: str | None = None,
source_id: str | None = None,
@@ -195,13 +195,13 @@
Returns: None
"""
await self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id)
self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id)
console.print(Padding("EXTENDED COLLECTION QUERIES", 1, style="bold", expand=True), width=80)
with console.status("[bold green]Executing queries...[/]") as status:
results: dict[str, Any] = {}
for script in self.get_extended_collection_queries():
status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]")
script_result = await self.select(
script_result = self.select(
script, PKEY=self.execution_id, DMA_SOURCE_ID=self.source_id, DMA_MANUAL_ID=self.manual_id
)
results[script] = script_result
@@ -210,7 +210,7 @@
console.print(" [dim grey]:heavy_check_mark: No extended collection queries for this database type[/]")
return results

async def execute_per_db_collection_queries(
def execute_per_db_collection_queries(
self,
execution_id: str | None = None,
source_id: str | None = None,
@@ -219,14 +219,14 @@ async def execute_per_db_collection_queries(
**kwargs: Any,
) -> dict[str, Any]:
"""Execute per DB pre-processing queries."""
await self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id)
self.set_identifiers(execution_id=execution_id, source_id=source_id, manual_id=manual_id)
console.print(Padding("PER DB QUERIES", 1, style="bold", expand=True), width=80)
with console.status("[bold green]Executing queries...[/]") as status:
results: dict[str, Any] = {}
for script in self.get_per_db_collection_queries():
status.update(rf" [yellow]*[/] Executing [bold magenta]`{script}`[/]")
try:
script_result = await self.select(
script_result = self.select(
script, PKEY=self.execution_id, DMA_SOURCE_ID=self.source_id, DMA_MANUAL_ID=self.manual_id
)
results[script] = script_result
@@ -248,7 +248,7 @@ def __init__(
source_id: str | None = None,
manual_id: str | None = None,
queries: Queries = aiosql.from_path(
sql_path=f"{_root_path}/collector/sql/sources/postgres", driver_adapter="asyncpg"
sql_path=f"{_root_path}/collector/sql/sources/postgres", driver_adapter="psycopg"
),
) -> None:
super().__init__(
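The Postgres query manager above now loads its aiosql queries with the synchronous "psycopg" driver adapter instead of "asyncpg". A hedged, self-contained sketch of that pattern in isolation (the query text and connection string are illustrative, not from the repository):

import aiosql
import psycopg

queries = aiosql.from_str(
    "-- name: init_get_db_version$\nselect version();\n",
    driver_adapter="psycopg",
)

with psycopg.connect("postgresql://localhost/postgres") as conn:
    # With a synchronous adapter the generated method is called directly, no await.
    print(queries.init_get_db_version(conn))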
4 changes: 2 additions & 2 deletions src/dma/collector/workflows/base.py
@@ -42,9 +42,9 @@ def __init__(
self.db_type = db_type
self.canonical_query_manager = canonical_query_manager

async def execute(self) -> None:
def execute(self) -> None:
"""Execute Workflow"""
await self.canonical_query_manager.execute_ddl_scripts()
self.canonical_query_manager.execute_ddl_scripts()

def import_to_table(self, data: dict[str, list[dict]]) -> None:
"""Load a dictionary of result sets into duckdb.
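The base workflow's execute() is now synchronous and simply runs the canonical DDL scripts; import_to_table then loads a dictionary of result sets into duckdb. A rough, self-contained sketch of that loading idea (the table name, columns, and all-varchar typing are invented for illustration and may differ from the real implementation):

import duckdb

data: dict[str, list[dict]] = {
    "collection_postgres_settings": [
        {"setting_name": "max_connections", "setting_value": "100"},
        {"setting_name": "shared_buffers", "setting_value": "128MB"},
    ],
}

con = duckdb.connect()  # in-memory database
for table_name, rows in data.items():
    if not rows:
        continue
    columns = list(rows[0].keys())
    con.execute(f"create table {table_name} ({', '.join(f'{c} varchar' for c in columns)})")
    con.executemany(
        f"insert into {table_name} values ({', '.join('?' for _ in columns)})",
        [[row[c] for c in columns] for row in rows],
    )

print(con.execute("select count(*) from collection_postgres_settings").fetchone())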