More tests passing
steinitzu committed Aug 29, 2024
1 parent cbea334 commit 6b62d37
Showing 8 changed files with 135 additions and 52 deletions.
8 changes: 8 additions & 0 deletions dlt/common/time.py
@@ -143,6 +143,14 @@ def ensure_pendulum_time(value: Union[str, datetime.time]) -> pendulum.Time:
             return result
         else:
             raise ValueError(f"{value} is not a valid ISO time string.")
+    elif isinstance(value, timedelta):
+        # Assume the timedelta is seconds elapsed since midnight. Some drivers (mysqlclient) return time in this format
+        return pendulum.time(
+            value.seconds // 3600,
+            (value.seconds // 60) % 60,
+            value.seconds % 60,
+            value.microseconds,
+        )
     raise TypeError(f"Cannot coerce {value} to a pendulum.Time object.")
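For context, this is how the new branch behaves in isolation — a minimal standalone sketch that mirrors the added logic (the helper name here is illustrative, not part of the commit):

```python
from datetime import timedelta

import pendulum


def time_from_timedelta(value: timedelta) -> pendulum.Time:
    # Interpret the timedelta as time elapsed since midnight, the format
    # mysqlclient uses for TIME columns
    return pendulum.time(
        value.seconds // 3600,
        (value.seconds // 60) % 60,
        value.seconds % 60,
        value.microseconds,
    )


# 13:05:30.000123 since midnight round-trips to the expected wall-clock time
assert time_from_timedelta(
    timedelta(hours=13, minutes=5, seconds=30, microseconds=123)
) == pendulum.time(13, 5, 30, 123)
```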


89 changes: 69 additions & 20 deletions dlt/destinations/impl/sqlalchemy/db_api_client.py
@@ -1,4 +1,15 @@
-from typing import Optional, Iterator, Any, Sequence, ContextManager, AnyStr, Union, Tuple, List, Dict
+from typing import (
+    Optional,
+    Iterator,
+    Any,
+    Sequence,
+    ContextManager,
+    AnyStr,
+    Union,
+    Tuple,
+    List,
+    Dict,
+)
 from contextlib import contextmanager
 from functools import wraps
 import inspect
@@ -11,11 +22,12 @@
 from dlt.destinations.exceptions import (
     DatabaseUndefinedRelation,
     DatabaseTerminalException,
+    DatabaseTransientException,
     LoadClientNotConnected,
     DatabaseException,
 )
-from dlt.destinations.typing import DBTransaction
-from dlt.destinations.sql_client import SqlClientBase
+from dlt.destinations.typing import DBTransaction, DBApiCursor
+from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl
 from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials
 from dlt.common.typing import TFun
 
@@ -25,10 +37,12 @@ def __init__(self, sqla_transaction: sa.engine.Transaction) -> None:
         self.sqla_transaction = sqla_transaction
 
     def commit_transaction(self) -> None:
-        self.sqla_transaction.commit()
+        if self.sqla_transaction.is_active:
+            self.sqla_transaction.commit()
 
     def rollback_transaction(self) -> None:
-        self.sqla_transaction.rollback()
+        if self.sqla_transaction.is_active:
+            self.sqla_transaction.rollback()
 
 
 def raise_database_error(f: TFun) -> TFun:
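The `is_active` guards matter because SQLAlchemy deactivates a `Transaction` after the first `commit()` or `rollback()`, and a second call raises `InvalidRequestError`. A small sketch (in-memory SQLite assumed purely for illustration):

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")
conn = engine.connect()
tx = conn.begin()
conn.execute(sa.text("CREATE TABLE t (x INTEGER)"))
tx.commit()

# Without the guard, a second commit would raise
# sqlalchemy.exc.InvalidRequestError; with it, this is a silent no-op
if tx.is_active:
    tx.commit()
conn.close()
```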
@@ -51,16 +65,46 @@ def _wrap(self: "SqlalchemyClient", *args: Any, **kwargs: Any) -> Any:
     return _wrap  # type: ignore[return-value]
 
 
+class SqlaDbApiCursor(DBApiCursorImpl):
+    def __init__(self, curr: sa.engine.CursorResult) -> None:
+        # Sqlalchemy CursorResult is *mostly* compatible with DB-API cursor
+        self.native_cursor = curr  # type: ignore[assignment]
+        curr.columns
+
+        self.fetchall = curr.fetchall  # type: ignore[method-assign]
+        self.fetchone = curr.fetchone  # type: ignore[method-assign]
+        self.fetchmany = curr.fetchmany  # type: ignore[method-assign]
+
+    def _get_columns(self) -> List[str]:
+        return list(self.native_cursor.keys())  # type: ignore[attr-defined]
+
+    # @property
+    # def description(self) -> Any:
+    #     # Get the underlying driver's cursor description, this is mostly used in tests
+    #     return self.native_cursor.cursor.description  # type: ignore[attr-defined]
+
+    def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None:
+        raise NotImplementedError("execute not implemented")
+
+
+class DbApiProps:
+    # Only needed for some tests
+    paramstyle = "named"
+
+
 class SqlalchemyClient(SqlClientBase[Connection]):
     external_engine: bool = False
+    dialect: sa.engine.interfaces.Dialect
+    dialect_name: str
+    dbapi = DbApiProps  # type: ignore[assignment]
 
     def __init__(
         self,
         dataset_name: str,
         staging_dataset_name: str,
         credentials: SqlalchemyCredentials,
         capabilities: DestinationCapabilitiesContext,
-        engine_args: Optional[Dict[str, Any]] = None
+        engine_args: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities)
         self.credentials = credentials
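The wrapper above works because `CursorResult` already exposes a DB-API-like fetch interface; a quick illustration of the surface being re-exported (in-memory SQLite assumed for the demo):

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")
with engine.connect() as conn:
    conn.execute(sa.text("CREATE TABLE items (name TEXT)"))
    conn.execute(sa.text("INSERT INTO items VALUES ('a'), ('b')"))
    result = conn.execute(sa.text("SELECT name FROM items"))
    # keys() backs _get_columns(); the fetch* methods are what
    # SqlaDbApiCursor re-exports directly
    print(list(result.keys()))   # ['name']
    print(result.fetchmany(1))   # [('a',)]
    print(result.fetchall())     # [('b',)]
```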
@@ -76,14 +120,8 @@ def __init__(
         self._current_connection: Optional[Connection] = None
         self._current_transaction: Optional[SqlaTransactionWrapper] = None
         self.metadata = sa.MetaData()
-
-    @property
-    def dialect(self) -> sa.engine.interfaces.Dialect:
-        return self.engine.dialect
-
-    @property
-    def dialect_name(self) -> str:
-        return self.dialect.name  # type: ignore[attr-defined]
+        self.dialect = self.engine.dialect
+        self.dialect_name = self.dialect.name  # type: ignore[attr-defined]
 
     def open_connection(self) -> Connection:
         if self._current_connection is None:
@@ -191,7 +229,7 @@ def drop_dataset(self) -> None:
             return self._sqlite_drop_dataset(self.dataset_name)
         try:
             self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True))
-        except DatabaseTerminalException as e:
+        except DatabaseTransientException as e:
             if isinstance(e.__cause__, sa.exc.ProgrammingError):
                 # May not support CASCADE
                 self.execute_sql(sa.schema.DropSchema(self.dataset_name))
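For reference, `DropSchema` renders a `CASCADE` clause that some dialects reject at execution time, which is what triggers the retry without it — a sketch (output shown for the generic DDL compiler):

```python
import sqlalchemy as sa

# The generic compiler renders the CASCADE clause; dialects that do not
# support it raise at execution time, hence the fallback above
print(sa.schema.DropSchema("my_dataset", cascade=True))  # DROP SCHEMA my_dataset CASCADE
print(sa.schema.DropSchema("my_dataset"))                # DROP SCHEMA my_dataset
```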
@@ -220,15 +258,15 @@ def execute_sql(
     @contextmanager
     def execute_query(
         self, query: Union[AnyStr, sa.sql.Executable], *args: Any, **kwargs: Any
-    ) -> Iterator[sa.engine.CursorResult]:
+    ) -> Iterator[DBApiCursor]:
         if isinstance(query, str):
             if args:
                 # Sqlalchemy text supports :named paramstyle for all dialects
                 query, kwargs = self._to_named_paramstyle(query, args)  # type: ignore[assignment]
                 args = ()
             query = sa.text(query)
         with self._transaction():
-            yield self._current_connection.execute(query, *args, **kwargs)
+            yield SqlaDbApiCursor(self._current_connection.execute(query, *args, **kwargs))  # type: ignore[abstract]
 
     def get_existing_table(self, table_name: str) -> Optional[sa.Table]:
         """Get a table object from metadata if it exists"""
@@ -316,16 +354,27 @@ def compare_storage_table(self, table_name: str) -> Tuple[sa.Table, List[sa.Colu
         return existing, missing_columns, reflected is not None
 
     @staticmethod
-    def _make_database_exception(e: Exception) -> DatabaseException:
+    def _make_database_exception(e: Exception) -> Exception:
         if isinstance(e, sa.exc.NoSuchTableError):
             return DatabaseUndefinedRelation(e)
+        msg = str(e).lower()
         if isinstance(e, (sa.exc.ProgrammingError, sa.exc.OperationalError)):
-            msg = str(e)
             if "exist" in msg:  # TODO: Hack
                 return DatabaseUndefinedRelation(e)
             elif "no such" in msg:  # sqlite # TODO: Hack
                 return DatabaseUndefinedRelation(e)
-            return DatabaseTerminalException(e)
+            elif "unknown table" in msg:
+                return DatabaseUndefinedRelation(e)
+            elif "unknown database" in msg:
+                return DatabaseUndefinedRelation(e)
+            elif isinstance(e, (sa.exc.OperationalError, sa.exc.IntegrityError)):
+                raise DatabaseTerminalException(e)
+            return DatabaseTransientException(e)
+        elif isinstance(e, sa.exc.SQLAlchemyError):
+            return DatabaseTransientException(e)
+        else:
+            return e
+            # return DatabaseTerminalException(e)
 
     def _ensure_native_conn(self) -> None:
         if not self.native_connection:
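The string matching is admittedly a hack (as the TODOs note), but it is grounded in real driver messages — SQLite, for example, reports a missing relation as "no such table". A quick check:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")
with engine.connect() as conn:
    try:
        conn.execute(sa.text("SELECT * FROM missing_table"))
    except sa.exc.OperationalError as e:
        # This is the substring the "no such" branch above matches on
        assert "no such table" in str(e).lower()
```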
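One more note on this file: the parameter handling in `execute_query` above relies on `sa.text()` understanding `:named` placeholders on every dialect, translating them to the driver's native paramstyle. A minimal sketch:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")
with engine.connect() as conn:
    conn.execute(sa.text("CREATE TABLE users (id INTEGER)"))
    conn.execute(sa.text("INSERT INTO users VALUES (1)"))
    # :id works the same on every dialect; sqlalchemy rewrites it to the
    # driver's native paramstyle (qmark, format, pyformat, ...)
    rows = conn.execute(sa.text("SELECT id FROM users WHERE id = :id"), {"id": 1})
    print(rows.fetchall())  # [(1,)]
```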
2 changes: 1 addition & 1 deletion dlt/destinations/impl/sqlalchemy/factory.py
@@ -20,7 +20,7 @@ class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient
 
     def _raw_capabilities(self) -> DestinationCapabilitiesContext:
         # https://www.postgresql.org/docs/current/limits.html
-        caps = DestinationCapabilitiesContext()
+        caps = DestinationCapabilitiesContext.generic_capabilities()
         caps.preferred_loader_file_format = "typed-jsonl"
         caps.supported_loader_file_formats = ["typed-jsonl", "parquet"]
         caps.preferred_staging_file_format = None
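`generic_capabilities()` returns a context pre-populated with conservative defaults rather than an empty one, so the factory only has to override the destination-specific fields. A sketch, assuming a dlt version that ships this classmethod:

```python
from dlt.common.destination import DestinationCapabilitiesContext

caps = DestinationCapabilitiesContext.generic_capabilities()
# Pre-filled defaults; the factory then overrides file formats etc.
print(caps.preferred_loader_file_format)
print(caps.max_identifier_length)
```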
18 changes: 14 additions & 4 deletions dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -87,13 +87,19 @@ def _create_double_type(self) -> sa.types.TypeEngine:
             from sqlalchemy.dialects.mysql import DOUBLE
         return sa.Float(precision=53)  # Otherwise use float
 
+    def _to_db_decimal_type(self, column: TColumnSchema) -> sa.types.TypeEngine:
+        precision, scale = column.get("precision"), column.get("scale")
+        if precision is None and scale is None:
+            precision, scale = self.capabilities.decimal_precision
+        return sa.Numeric(precision, scale)
+
     def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.types.TypeEngine:
         sc_t = column["data_type"]
         precision = column.get("precision")
         # TODO: Precision and scale for supported types
         if sc_t == "text":
             length = precision
-            if length is None and (column.get("unique") or column.get("primary_key")):
+            if length is None and column.get("unique"):
                 length = 128
             if length is None:
                 return sa.Text()  # type: ignore[no-any-return]
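A standalone sketch of the fallback in `_to_db_decimal_type` — when the column carries neither precision nor scale, the destination's default pair is used (the `(38, 9)` default here is illustrative, not taken from the commit):

```python
from typing import Optional, Tuple

import sqlalchemy as sa


def to_db_decimal_type(
    precision: Optional[int], scale: Optional[int], default: Tuple[int, int] = (38, 9)
) -> sa.Numeric:
    # Mirrors the helper above: fall back to the capabilities' default
    # precision/scale only when the column specifies neither
    if precision is None and scale is None:
        precision, scale = default
    return sa.Numeric(precision, scale)


print(to_db_decimal_type(None, None))  # NUMERIC(38, 9)
print(to_db_decimal_type(10, 2))       # NUMERIC(10, 2)
```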
@@ -111,7 +117,7 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.ty
         elif sc_t == "complex":
             return sa.JSON(none_as_null=True)
         elif sc_t == "decimal":
-            return sa.Numeric(precision, column.get("scale", 0))
+            return self._to_db_decimal_type(column)
         elif sc_t == "wei":
             wei_precision, wei_scale = self.capabilities.wei_precision
             return sa.Numeric(precision=wei_precision, scale=wei_scale)
@@ -134,6 +140,7 @@ def _from_db_decimal_type(self, db_type: sa.Numeric) -> TColumnType:
         precision, scale = db_type.precision, db_type.scale
         if (precision, scale) == self.capabilities.wei_precision:
             return dict(data_type="wei")
+
         return dict(data_type="decimal", precision=precision, scale=scale)
 
     def from_db_type(self, db_type: sa.types.TypeEngine) -> TColumnType:
@@ -249,7 +256,11 @@ def __init__(
         capabilities: DestinationCapabilitiesContext,
     ) -> None:
         self.sql_client = SqlalchemyClient(
-            config.normalize_dataset_name(schema), None, config.credentials, capabilities, engine_args=config.engine_args
+            config.normalize_dataset_name(schema),
+            None,
+            config.credentials,
+            capabilities,
+            engine_args=config.engine_args,
         )
 
         self.schema = schema
@@ -283,7 +294,6 @@ def _to_column_object(
             schema_column["name"],
             self.type_mapper.to_db_type(schema_column, table_format),
             nullable=schema_column.get("nullable", True),
-            primary_key=schema_column.get("primary_key", False),
             unique=schema_column.get("unique", False),
         )
 
9 changes: 2 additions & 7 deletions tests/load/pipeline/test_arrow_loading.py
@@ -9,7 +9,7 @@
 
 import dlt
 from dlt.common import pendulum
-from dlt.common.time import reduce_pendulum_datetime_precision
+from dlt.common.time import reduce_pendulum_datetime_precision, ensure_pendulum_time
 from dlt.common.utils import uniq_id
 
 from tests.load.utils import destinations_configs, DestinationTestConfiguration
@@ -135,12 +135,7 @@ def some_data():
                 row[i] = round(row[i], 4)
             if isinstance(row[i], timedelta) and isinstance(first_record[i], dt_time):
                 # Some drivers (mysqlclient) return TIME columns as timedelta (seconds since midnight)
-                row[i] = dt_time(
-                    hour=row[i].seconds // 3600,
-                    minute=(row[i].seconds // 60) % 60,
-                    second=row[i].seconds % 60,
-                    microsecond=row[i].microseconds,
-                )
+                row[i] = ensure_pendulum_time(row[i])
 
     expected = sorted([list(r.values()) for r in records])
 
43 changes: 29 additions & 14 deletions tests/load/test_job_client.py
@@ -586,18 +586,23 @@ def test_load_with_all_types(
     client.schema._bump_version()
     client.update_stored_schema()
 
-    should_load_to_staging = client.should_load_data_to_staging_dataset(client.schema.tables[table_name])  # type: ignore[attr-defined]
-    if should_load_to_staging:
-        with client.with_staging_dataset():  # type: ignore[attr-defined]
-            # create staging for merge dataset
-            client.initialize_storage()
-            client.update_stored_schema()
+    if isinstance(client, WithStagingDataset):
+        should_load_to_staging = client.should_load_data_to_staging_dataset(
+            client.schema.tables[table_name]
+        )
+        if should_load_to_staging:
+            with client.with_staging_dataset():
+                # create staging for merge dataset
+                client.initialize_storage()
+                client.update_stored_schema()
 
-    with client.sql_client.with_alternative_dataset_name(
-        client.sql_client.staging_dataset_name
-        if should_load_to_staging
-        else client.sql_client.dataset_name
-    ):
+        with client.sql_client.with_alternative_dataset_name(
+            client.sql_client.staging_dataset_name
+            if should_load_to_staging
+            else client.sql_client.dataset_name
+        ):
+            canonical_name = client.sql_client.make_qualified_table_name(table_name)
+    else:
         canonical_name = client.sql_client.make_qualified_table_name(table_name)
     # write row
     print(data_row)
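The `isinstance()` checks added here replace the old `# type: ignore[attr-defined]` comments: narrowing the client to `WithStagingDataset` lets mypy verify the staging-dataset calls instead of being silenced. A sketch of the pattern (the import path is assumed and the signature simplified):

```python
from typing import Any

from dlt.common.destination.reference import WithStagingDataset  # assumed path


def should_use_staging(client: Any, table: Any) -> bool:
    if isinstance(client, WithStagingDataset):
        # mypy narrows `client` here, so no ignore comment is needed
        return client.should_load_data_to_staging_dataset(table)
    return False
```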
@@ -656,6 +661,8 @@ def test_write_dispositions(
     client.update_stored_schema()
 
     if write_disposition == "merge":
+        if not client.capabilities.supported_merge_strategies:
+            pytest.skip("destination does not support merge")
         # add root key
         client.schema.tables[table_name]["columns"]["col1"]["root_key"] = True
         # create staging for merge dataset
@@ -675,9 +682,11 @@
             with io.BytesIO() as f:
                 write_dataset(client, f, [data_row], column_schemas)
                 query = f.getvalue()
-            if client.should_load_data_to_staging_dataset(client.schema.tables[table_name]):  # type: ignore[attr-defined]
+            if isinstance(
+                client, WithStagingDataset
+            ) and client.should_load_data_to_staging_dataset(client.schema.tables[table_name]):
                 # load to staging dataset on merge
-                with client.with_staging_dataset():  # type: ignore[attr-defined]
+                with client.with_staging_dataset():
                     expect_load_file(client, file_storage, query, t)
             else:
                 # load directly on other
@@ -824,6 +833,8 @@ def test_get_stored_state(
 
     # get state
     stored_state = client.get_stored_state("pipeline")
+    # Ensure timezone aware datetime for comparing
+    stored_state.created_at = pendulum.instance(stored_state.created_at)
     assert doc == stored_state.as_doc()
 
 
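`pendulum.instance()` attaches UTC to a naive datetime (and keeps an existing timezone), which is what makes the comparison stable across drivers that return naive values — a quick sketch:

```python
from datetime import datetime

import pendulum

naive = datetime(2024, 8, 29, 12, 0, 0)  # what some drivers hand back
aware = pendulum.instance(naive)         # same wall-clock time, now tz-aware
print(aware.timezone_name)               # UTC
print(naive == aware)                    # False: naive vs aware never compare equal
```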
@@ -919,7 +930,11 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None:
         "mandatory_column", "text", nullable=False
     )
     client.schema._bump_version()
-    if destination_config.destination == "clickhouse":
+    if destination_config.destination == "clickhouse" or (
+        # mysql allows adding not-null columns (they have an implicit default)
+        destination_config.destination == "sqlalchemy"
+        and client.sql_client.dialect_name == "mysql"
+    ):
         client.update_stored_schema()
     else:
         with pytest.raises(DatabaseException) as py_ex:
12 changes: 7 additions & 5 deletions tests/load/test_sql_client.py
@@ -20,7 +20,7 @@
 from dlt.destinations.sql_client import DBApiCursor, SqlClientBase
 from dlt.destinations.job_client_impl import SqlJobClientBase
 from dlt.destinations.typing import TNativeConn
-from dlt.common.time import ensure_pendulum_datetime
+from dlt.common.time import ensure_pendulum_datetime, to_py_datetime
 
 from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage
 from tests.load.utils import (
@@ -62,7 +62,9 @@ def naming(request) -> str:
 @pytest.mark.parametrize(
     "client",
     destinations_configs(
-        default_sql_configs=True, exclude=["mssql", "synapse", "dremio", "clickhouse"]
+        # Only databases that support search path or equivalent
+        default_sql_configs=True,
+        exclude=["mssql", "synapse", "dremio", "clickhouse", "sqlalchemy"],
     ),
     indirect=True,
     ids=lambda x: x.name,
@@ -212,10 +214,10 @@ def test_execute_sql(client: SqlJobClientBase) -> None:
     assert rows[0][0] == "event"
     # print(rows[0][1])
     # print(type(rows[0][1]))
-    # convert to pendulum to make sure it is supported by dbapi
+    # ensure datetime obj to make sure it is supported by dbapi
     rows = client.sql_client.execute_sql(
         f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s",
-        ensure_pendulum_datetime(rows[0][1]),
+        to_py_datetime(ensure_pendulum_datetime(rows[0][1])),
     )
     assert len(rows) == 1
     # use rows in subsequent test
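`to_py_datetime()` downgrades a `pendulum.DateTime` to a plain `datetime.datetime`, which some DB-API drivers require for bind parameters — a sketch, assuming a dlt version that ships both helpers:

```python
from dlt.common.time import ensure_pendulum_datetime, to_py_datetime

pdt = ensure_pendulum_datetime("2024-08-29T12:00:00+00:00")
param = to_py_datetime(pdt)
# Plain stdlib datetime: safe to pass to drivers that reject
# pendulum.DateTime subclasses as query parameters
print(type(param))  # <class 'datetime.datetime'>
```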
@@ -620,7 +622,7 @@ def test_max_column_identifier_length(client: SqlJobClientBase) -> None:
 
 @pytest.mark.parametrize(
     "client",
-    destinations_configs(default_sql_configs=True, exclude=["databricks"]),
+    destinations_configs(default_sql_configs=True, exclude=["databricks", "sqlalchemy"]),
     indirect=True,
     ids=lambda x: x.name,
 )
(The diff for the eighth changed file did not render on the page; by the totals above it accounts for the remaining 5 additions and 1 deletion.)
