From 6b62d37af89a6978e8bfdfdc8a0678d07c607d26 Mon Sep 17 00:00:00 2001
From: Steinthor Palsson
Date: Thu, 29 Aug 2024 18:27:47 -0400
Subject: [PATCH] More tests passing

---
 dlt/common/time.py                            |  8 ++
 .../impl/sqlalchemy/db_api_client.py          | 89 ++++++++++++++-----
 dlt/destinations/impl/sqlalchemy/factory.py   |  2 +-
 .../impl/sqlalchemy/sqlalchemy_job_client.py  | 18 +++-
 tests/load/pipeline/test_arrow_loading.py     |  9 +-
 tests/load/test_job_client.py                 | 43 ++++++---
 tests/load/test_sql_client.py                 | 12 +--
 tests/load/utils.py                           |  6 +-
 8 files changed, 135 insertions(+), 52 deletions(-)

diff --git a/dlt/common/time.py b/dlt/common/time.py
index 8532f566b8..26de0b5645 100644
--- a/dlt/common/time.py
+++ b/dlt/common/time.py
@@ -143,6 +143,14 @@ def ensure_pendulum_time(value: Union[str, datetime.time]) -> pendulum.Time:
             return result
         else:
             raise ValueError(f"{value} is not a valid ISO time string.")
+    elif isinstance(value, timedelta):
+        # Assume timedelta is seconds passed since midnight. Some drivers (mysqlclient) return time in this format
+        return pendulum.time(
+            value.seconds // 3600,
+            (value.seconds // 60) % 60,
+            value.seconds % 60,
+            value.microseconds,
+        )
     raise TypeError(f"Cannot coerce {value} to a pendulum.Time object.")


diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py
index a79cc90d7f..0969fcb628 100644
--- a/dlt/destinations/impl/sqlalchemy/db_api_client.py
+++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py
@@ -1,4 +1,15 @@
-from typing import Optional, Iterator, Any, Sequence, ContextManager, AnyStr, Union, Tuple, List, Dict
+from typing import (
+    Optional,
+    Iterator,
+    Any,
+    Sequence,
+    ContextManager,
+    AnyStr,
+    Union,
+    Tuple,
+    List,
+    Dict,
+)
 from contextlib import contextmanager
 from functools import wraps
 import inspect
@@ -11,11 +22,12 @@
 from dlt.destinations.exceptions import (
     DatabaseUndefinedRelation,
     DatabaseTerminalException,
+    DatabaseTransientException,
     LoadClientNotConnected,
     DatabaseException,
 )
-from dlt.destinations.typing import DBTransaction
-from dlt.destinations.sql_client import SqlClientBase
+from dlt.destinations.typing import DBTransaction, DBApiCursor
+from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl
 from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials
 from dlt.common.typing import TFun

@@ -25,10 +37,12 @@ def __init__(self, sqla_transaction: sa.engine.Transaction) -> None:
         self.sqla_transaction = sqla_transaction

     def commit_transaction(self) -> None:
-        self.sqla_transaction.commit()
+        if self.sqla_transaction.is_active:
+            self.sqla_transaction.commit()

     def rollback_transaction(self) -> None:
-        self.sqla_transaction.rollback()
+        if self.sqla_transaction.is_active:
+            self.sqla_transaction.rollback()


 def raise_database_error(f: TFun) -> TFun:
@@ -51,8 +65,38 @@ def _wrap(self: "SqlalchemyClient", *args: Any, **kwargs: Any) -> Any:
     return _wrap  # type: ignore[return-value]


+class SqlaDbApiCursor(DBApiCursorImpl):
+    def __init__(self, curr: sa.engine.CursorResult) -> None:
+        # Sqlalchemy CursorResult is *mostly* compatible with DB-API cursor
+        self.native_cursor = curr  # type: ignore[assignment]
+        curr.columns
+
+        self.fetchall = curr.fetchall  # type: ignore[method-assign]
+        self.fetchone = curr.fetchone  # type: ignore[method-assign]
+        self.fetchmany = curr.fetchmany  # type: ignore[method-assign]
+
+    def _get_columns(self) -> List[str]:
+        return list(self.native_cursor.keys())  # type: ignore[attr-defined]
+
+    # @property
+    # def description(self) -> Any:
+    #     # Get the underlying driver's cursor description, this is mostly used in tests
+    #     return self.native_cursor.cursor.description  # type: ignore[attr-defined]
+
+    def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None:
+        raise NotImplementedError("execute not implemented")
+
+
+class DbApiProps:
+    # Only needed for some tests
+    paramstyle = "named"
+
+
 class SqlalchemyClient(SqlClientBase[Connection]):
     external_engine: bool = False
+    dialect: sa.engine.interfaces.Dialect
+    dialect_name: str
+    dbapi = DbApiProps  # type: ignore[assignment]

     def __init__(
         self,
@@ -60,7 +104,7 @@ def __init__(
         staging_dataset_name: str,
         credentials: SqlalchemyCredentials,
         capabilities: DestinationCapabilitiesContext,
-        engine_args: Optional[Dict[str, Any]] = None
+        engine_args: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities)
         self.credentials = credentials
@@ -76,14 +120,8 @@ def __init__(
         self._current_connection: Optional[Connection] = None
         self._current_transaction: Optional[SqlaTransactionWrapper] = None
         self.metadata = sa.MetaData()
-
-    @property
-    def dialect(self) -> sa.engine.interfaces.Dialect:
-        return self.engine.dialect
-
-    @property
-    def dialect_name(self) -> str:
-        return self.dialect.name  # type: ignore[attr-defined]
+        self.dialect = self.engine.dialect
+        self.dialect_name = self.dialect.name  # type: ignore[attr-defined]

     def open_connection(self) -> Connection:
         if self._current_connection is None:
@@ -191,7 +229,7 @@ def drop_dataset(self) -> None:
             return self._sqlite_drop_dataset(self.dataset_name)
         try:
             self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True))
-        except DatabaseTerminalException as e:
+        except DatabaseTransientException as e:
             if isinstance(e.__cause__, sa.exc.ProgrammingError):
                 # May not support CASCADE
                 self.execute_sql(sa.schema.DropSchema(self.dataset_name))
@@ -220,7 +258,7 @@ def execute_sql(
     @contextmanager
     def execute_query(
         self, query: Union[AnyStr, sa.sql.Executable], *args: Any, **kwargs: Any
-    ) -> Iterator[sa.engine.CursorResult]:
+    ) -> Iterator[DBApiCursor]:
         if isinstance(query, str):
             if args:
                 # Sqlalchemy text supports :named paramstyle for all dialects
@@ -228,7 +266,7 @@
                 args = ()
             query = sa.text(query)
         with self._transaction():
-            yield self._current_connection.execute(query, *args, **kwargs)
+            yield SqlaDbApiCursor(self._current_connection.execute(query, *args, **kwargs))  # type: ignore[abstract]

     def get_existing_table(self, table_name: str) -> Optional[sa.Table]:
         """Get a table object from metadata if it exists"""
@@ -316,16 +354,27 @@ def compare_storage_table(self, table_name: str) -> Tuple[sa.Table, List[sa.Colu
         return existing, missing_columns, reflected is not None

     @staticmethod
-    def _make_database_exception(e: Exception) -> DatabaseException:
+    def _make_database_exception(e: Exception) -> Exception:
         if isinstance(e, sa.exc.NoSuchTableError):
             return DatabaseUndefinedRelation(e)
+        msg = str(e).lower()
         if isinstance(e, (sa.exc.ProgrammingError, sa.exc.OperationalError)):
-            msg = str(e)
             if "exist" in msg:  # TODO: Hack
                 return DatabaseUndefinedRelation(e)
             elif "no such" in msg:  # sqlite # TODO: Hack
                 return DatabaseUndefinedRelation(e)
-        return DatabaseTerminalException(e)
+            elif "unknown table" in msg:
+                return DatabaseUndefinedRelation(e)
+            elif "unknown database" in msg:
+                return DatabaseUndefinedRelation(e)
+            elif isinstance(e, (sa.exc.OperationalError, sa.exc.IntegrityError)):
+                raise DatabaseTerminalException(e)
+            return DatabaseTransientException(e)
+        elif isinstance(e, sa.exc.SQLAlchemyError):
+            return DatabaseTransientException(e)
+        else:
+            return e
+        # return DatabaseTerminalException(e)

     def _ensure_native_conn(self) -> None:
         if not self.native_connection:
diff --git a/dlt/destinations/impl/sqlalchemy/factory.py b/dlt/destinations/impl/sqlalchemy/factory.py
index a0fa542573..c6b78baf3b 100644
--- a/dlt/destinations/impl/sqlalchemy/factory.py
+++ b/dlt/destinations/impl/sqlalchemy/factory.py
@@ -20,7 +20,7 @@ class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient

     def _raw_capabilities(self) -> DestinationCapabilitiesContext:
         # https://www.sqlalchemyql.org/docs/current/limits.html
-        caps = DestinationCapabilitiesContext()
+        caps = DestinationCapabilitiesContext.generic_capabilities()
         caps.preferred_loader_file_format = "typed-jsonl"
         caps.supported_loader_file_formats = ["typed-jsonl", "parquet"]
         caps.preferred_staging_file_format = None
diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
index 6e2992faf7..7dfbaa9015 100644
--- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
+++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -87,13 +87,19 @@ def _create_double_type(self) -> sa.types.TypeEngine:
             from sqlalchemy.dialects.mysql import DOUBLE
         return sa.Float(precision=53)  # Otherwise use float

+    def _to_db_decimal_type(self, column: TColumnSchema) -> sa.types.TypeEngine:
+        precision, scale = column.get("precision"), column.get("scale")
+        if precision is None and scale is None:
+            precision, scale = self.capabilities.decimal_precision
+        return sa.Numeric(precision, scale)
+
     def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.types.TypeEngine:
         sc_t = column["data_type"]
         precision = column.get("precision")
         # TODO: Precision and scale for supported types
         if sc_t == "text":
             length = precision
-            if length is None and (column.get("unique") or column.get("primary_key")):
+            if length is None and column.get("unique"):
                 length = 128
             if length is None:
                 return sa.Text()  # type: ignore[no-any-return]
@@ -111,7 +117,7 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.ty
         elif sc_t == "complex":
             return sa.JSON(none_as_null=True)
         elif sc_t == "decimal":
-            return sa.Numeric(precision, column.get("scale", 0))
+            return self._to_db_decimal_type(column)
         elif sc_t == "wei":
             wei_precision, wei_scale = self.capabilities.wei_precision
             return sa.Numeric(precision=wei_precision, scale=wei_scale)
@@ -134,6 +140,7 @@ def _from_db_decimal_type(self, db_type: sa.Numeric) -> TColumnType:
         precision, scale = db_type.precision, db_type.scale
         if (precision, scale) == self.capabilities.wei_precision:
             return dict(data_type="wei")
+
         return dict(data_type="decimal", precision=precision, scale=scale)

     def from_db_type(self, db_type: sa.types.TypeEngine) -> TColumnType:
@@ -249,7 +256,11 @@ def __init__(
         capabilities: DestinationCapabilitiesContext,
     ) -> None:
         self.sql_client = SqlalchemyClient(
-            config.normalize_dataset_name(schema), None, config.credentials, capabilities, engine_args=config.engine_args
+            config.normalize_dataset_name(schema),
+            None,
+            config.credentials,
+            capabilities,
+            engine_args=config.engine_args,
         )

         self.schema = schema
@@ -283,7 +294,6 @@ def _to_column_object(
             schema_column["name"],
             self.type_mapper.to_db_type(schema_column, table_format),
             nullable=schema_column.get("nullable", True),
-            primary_key=schema_column.get("primary_key", False),
             unique=schema_column.get("unique", False),
         )

diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py
index 8998f77673..1a42db7f4a 100644
--- a/tests/load/pipeline/test_arrow_loading.py
+++ b/tests/load/pipeline/test_arrow_loading.py
@@ -9,7 +9,7 @@
 import dlt
 from dlt.common import pendulum
-from dlt.common.time import reduce_pendulum_datetime_precision
+from dlt.common.time import reduce_pendulum_datetime_precision, ensure_pendulum_time
 from dlt.common.utils import uniq_id

 from tests.load.utils import destinations_configs, DestinationTestConfiguration

@@ -135,12 +135,7 @@ def some_data():
                 row[i] = round(row[i], 4)
             if isinstance(row[i], timedelta) and isinstance(first_record[i], dt_time):
                 # Some drivers (mysqlclient) return TIME columns as timedelta as seconds since midnight
-                row[i] = dt_time(
-                    hour=row[i].seconds // 3600,
-                    minute=(row[i].seconds // 60) % 60,
-                    second=row[i].seconds % 60,
-                    microsecond=row[i].microseconds,
-                )
+                row[i] = ensure_pendulum_time(row[i])

     expected = sorted([list(r.values()) for r in records])

diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py
index 84bcf7cab0..c09b92717c 100644
--- a/tests/load/test_job_client.py
+++ b/tests/load/test_job_client.py
@@ -586,18 +586,23 @@ def test_load_with_all_types(
     client.schema._bump_version()
     client.update_stored_schema()

-    should_load_to_staging = client.should_load_data_to_staging_dataset(client.schema.tables[table_name])  # type: ignore[attr-defined]
-    if should_load_to_staging:
-        with client.with_staging_dataset():  # type: ignore[attr-defined]
-            # create staging for merge dataset
-            client.initialize_storage()
-            client.update_stored_schema()
+    if isinstance(client, WithStagingDataset):
+        should_load_to_staging = client.should_load_data_to_staging_dataset(
+            client.schema.tables[table_name]
+        )
+        if should_load_to_staging:
+            with client.with_staging_dataset():
+                # create staging for merge dataset
+                client.initialize_storage()
+                client.update_stored_schema()

-    with client.sql_client.with_alternative_dataset_name(
-        client.sql_client.staging_dataset_name
-        if should_load_to_staging
-        else client.sql_client.dataset_name
-    ):
+        with client.sql_client.with_alternative_dataset_name(
+            client.sql_client.staging_dataset_name
+            if should_load_to_staging
+            else client.sql_client.dataset_name
+        ):
+            canonical_name = client.sql_client.make_qualified_table_name(table_name)
+    else:
         canonical_name = client.sql_client.make_qualified_table_name(table_name)
     # write row
     print(data_row)
@@ -656,6 +661,8 @@ def test_write_dispositions(
     client.update_stored_schema()

     if write_disposition == "merge":
+        if not client.capabilities.supported_merge_strategies:
+            pytest.skip("destination does not support merge")
         # add root key
         client.schema.tables[table_name]["columns"]["col1"]["root_key"] = True
         # create staging for merge dataset
@@ -675,9 +682,11 @@ def test_write_dispositions(
         with io.BytesIO() as f:
             write_dataset(client, f, [data_row], column_schemas)
             query = f.getvalue()
-            if client.should_load_data_to_staging_dataset(client.schema.tables[table_name]):  # type: ignore[attr-defined]
+            if isinstance(
+                client, WithStagingDataset
+            ) and client.should_load_data_to_staging_dataset(client.schema.tables[table_name]):
                 # load to staging dataset on merge
-                with client.with_staging_dataset():  # type: ignore[attr-defined]
+                with client.with_staging_dataset():
                     expect_load_file(client, file_storage, query, t)
             else:
                 # load directly on other
@@ -824,6 +833,8 @@ def test_get_stored_state(

     # get state
     stored_state = client.get_stored_state("pipeline")
+    # Ensure timezone aware datetime for comparing
+    stored_state.created_at = pendulum.instance(stored_state.created_at)
     assert doc == stored_state.as_doc()


@@ -919,7 +930,11 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None:
         "mandatory_column", "text", nullable=False
     )
     client.schema._bump_version()
-    if destination_config.destination == "clickhouse":
+    if destination_config.destination == "clickhouse" or (
+        # mysql allows adding not-null columns (they have an implicit default)
+        destination_config.destination == "sqlalchemy"
+        and client.sql_client.dialect_name == "mysql"
+    ):
         client.update_stored_schema()
     else:
         with pytest.raises(DatabaseException) as py_ex:
diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py
index e167f0ceda..fa154c65dc 100644
--- a/tests/load/test_sql_client.py
+++ b/tests/load/test_sql_client.py
@@ -20,7 +20,7 @@
 from dlt.destinations.sql_client import DBApiCursor, SqlClientBase
 from dlt.destinations.job_client_impl import SqlJobClientBase
 from dlt.destinations.typing import TNativeConn
-from dlt.common.time import ensure_pendulum_datetime
+from dlt.common.time import ensure_pendulum_datetime, to_py_datetime

 from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage
 from tests.load.utils import (
@@ -62,7 +62,9 @@ def naming(request) -> str:
 @pytest.mark.parametrize(
     "client",
     destinations_configs(
-        default_sql_configs=True, exclude=["mssql", "synapse", "dremio", "clickhouse"]
+        # Only databases that support search path or equivalent
+        default_sql_configs=True,
+        exclude=["mssql", "synapse", "dremio", "clickhouse", "sqlalchemy"],
     ),
     indirect=True,
     ids=lambda x: x.name,
@@ -212,10 +214,10 @@ def test_execute_sql(client: SqlJobClientBase) -> None:
     assert rows[0][0] == "event"
     # print(rows[0][1])
     # print(type(rows[0][1]))
-    # convert to pendulum to make sure it is supported by dbapi
+    # ensure datetime obj to make sure it is supported by dbapi
     rows = client.sql_client.execute_sql(
         f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s",
-        ensure_pendulum_datetime(rows[0][1]),
+        to_py_datetime(ensure_pendulum_datetime(rows[0][1])),
     )
     assert len(rows) == 1
     # use rows in subsequent test
@@ -620,7 +622,7 @@ def test_max_column_identifier_length(client: SqlJobClientBase) -> None:

 @pytest.mark.parametrize(
     "client",
-    destinations_configs(default_sql_configs=True, exclude=["databricks"]),
+    destinations_configs(default_sql_configs=True, exclude=["databricks", "sqlalchemy"]),
     indirect=True,
     ids=lambda x: x.name,
 )
diff --git a/tests/load/utils.py b/tests/load/utils.py
index 3c8cfd2914..a20049ae4e 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -268,7 +268,11 @@ def destinations_configs(
             DestinationTestConfiguration(destination="duckdb", file_format="parquet"),
             DestinationTestConfiguration(destination="motherduck", file_format="insert_values"),
         ]
-        destination_configs += [DestinationTestConfiguration(destination="sqlalchemy")]
+        destination_configs += [
+            DestinationTestConfiguration(
+                destination="sqlalchemy", supports_merge=False, supports_dbt=False
+            )
+        ]
         # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though.
         destination_configs += [
             DestinationTestConfiguration(