diff --git a/.changes/unreleased/Under the Hood-20240426-142511.yaml b/.changes/unreleased/Under the Hood-20240426-142511.yaml new file mode 100644 index 00000000..33a067d9 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240426-142511.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Support record/replay mode. +time: 2024-04-26T14:25:11.251251-04:00 +custom: + Author: peterallenwebb + Issue: "407" diff --git a/dbt/__init__.py b/dbt/__init__.py deleted file mode 100644 index 782ff40f..00000000 --- a/dbt/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# N.B. -# This will add to the package’s __path__ all subdirectories of directories on sys.path named after the package which effectively combines both modules into a single namespace (dbt.adapters) - -from pkgutil import extend_path - -__path__ = extend_path(__path__, __name__) diff --git a/dbt/adapters/__init__.py b/dbt/adapters/__init__.py deleted file mode 100644 index 1713e032..00000000 --- a/dbt/adapters/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -This adds all subdirectories of directories on `sys.path` to this package’s `__path__` . -It effectively combines all adapters into a single namespace (dbt.adapter). -""" - -from pkgutil import extend_path - -__path__ = extend_path(__path__, __name__) diff --git a/dbt/adapters/record.py b/dbt/adapters/record.py new file mode 100644 index 00000000..3fdc027b --- /dev/null +++ b/dbt/adapters/record.py @@ -0,0 +1,67 @@ +import dataclasses +from io import StringIO +import json +import re +from typing import Any, Optional, Mapping + +from agate import Table + +from dbt_common.events.contextvars import get_node_info +from dbt_common.record import Record, Recorder + +from dbt.adapters.contracts.connection import AdapterResponse + + +@dataclasses.dataclass +class QueryRecordParams: + sql: str + auto_begin: bool = False + fetch: bool = False + limit: Optional[int] = None + node_unique_id: Optional[str] = None + + def __post_init__(self): + if self.node_unique_id is None: + node_info = get_node_info() + self.node_unique_id = node_info["unique_id"] if node_info else "" + + @staticmethod + def _clean_up_sql(sql: str) -> str: + sql = re.sub(r"--.*?\n", "", sql) # Remove single-line comments (--) + sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL) # Remove multi-line comments (/* */) + return sql.replace(" ", "").replace("\n", "") + + def _matches(self, other: "QueryRecordParams") -> bool: + return self.node_unique_id == other.node_unique_id and self._clean_up_sql( + self.sql + ) == self._clean_up_sql(other.sql) + + +@dataclasses.dataclass +class QueryRecordResult: + adapter_response: Optional["AdapterResponse"] + table: Optional[Table] + + def _to_dict(self) -> Any: + buf = StringIO() + self.table.to_json(buf) # type: ignore + + return { + "adapter_response": self.adapter_response.to_dict(), # type: ignore + "table": buf.getvalue(), + } + + @classmethod + def _from_dict(cls, dct: Mapping) -> "QueryRecordResult": + return QueryRecordResult( + adapter_response=AdapterResponse.from_dict(dct["adapter_response"]), + table=Table.from_object(json.loads(dct["table"])), + ) + + +class QueryRecord(Record): + params_cls = QueryRecordParams + result_cls = QueryRecordResult + + +Recorder.register_record_type(QueryRecord) diff --git a/dbt/adapters/sql/connections.py b/dbt/adapters/sql/connections.py index 78cd3c9b..9adaafce 100644 --- a/dbt/adapters/sql/connections.py +++ b/dbt/adapters/sql/connections.py @@ -5,6 +5,7 @@ from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event from dbt_common.exceptions import DbtInternalError, NotImplementedError +from dbt_common.record import record_function from dbt_common.utils import cast_to_str from dbt.adapters.base import BaseConnectionManager @@ -19,6 +20,7 @@ SQLQuery, SQLQueryStatus, ) +from dbt.adapters.record import QueryRecord if TYPE_CHECKING: import agate @@ -143,6 +145,7 @@ def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> "agate.Tab return table_from_data_flat(data, column_names) + @record_function(QueryRecord, method=True, tuple_result=True) def execute( self, sql: str,