Implement sqlalchemy loader
Begin implementing sqlalchemy loader

SQLA load job, factory, schema storage, POC

sqlalchemy tests attempt

Implement SqlJobClient interface

Parquet load, some tests running on mysql

update lockfile

Limit bulk insert chunk size, sqlite create/drop schema, fixes

Generate schema update

Get more tests running with mysql

More tests passing

Fix state, schema restore
steinitzu committed Aug 31, 2024
1 parent 1723faa commit e8d0ea8
Showing 25 changed files with 1,244 additions and 78 deletions.
6 changes: 6 additions & 0 deletions dlt/common/destination/capabilities.py
@@ -100,6 +100,12 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None
"""The destination can override the parallelism strategy"""

max_query_parameters: Optional[int] = None
"""The maximum number of parameters that can be supplied in a single parametrized query"""

supports_native_boolean: bool = True
"""The destination supports a native boolean type, otherwise bool columns are usually stored as integers"""

def generates_case_sensitive_identifiers(self) -> bool:
"""Tells if capabilities as currently adjusted, will generate case sensitive identifiers"""
# must have case sensitive support and folding function must preserve casing
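The two new fields are capability hints for SQL clients. One likely use, per the commit message's "Limit bulk insert chunk size": cap the rows per INSERT so a single statement never exceeds the driver's bound-parameter limit. A minimal sketch (the helper is illustrative, not code from this commit):

from typing import Optional

def max_rows_per_insert(max_query_parameters: Optional[int], num_columns: int) -> Optional[int]:
    # Each row binds one parameter per column, so a single INSERT can hold
    # at most max_query_parameters // num_columns rows.
    if max_query_parameters is None:
        return None  # no known driver limit
    return max(1, max_query_parameters // num_columns)

max_rows_per_insert(65535, 10)  # -> 6553 rows per chunk for a 10-column table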
17 changes: 16 additions & 1 deletion dlt/common/destination/reference.py
@@ -90,6 +90,20 @@ def from_normalized_mapping(
schema=normalized_doc[naming_convention.normalize_identifier("schema")],
)

def to_normalized_mapping(self, naming_convention: NamingConvention) -> Dict[str, Any]:
"""Convert this instance to mapping where keys are normalized according to given naming convention
Args:
naming_convention: Naming convention that should be used to normalize keys
Returns:
Dict[str, Any]: Mapping with normalized keys (e.g. {Version: ..., SchemaName: ...})
"""
return {
naming_convention.normalize_identifier(key): value
for key, value in self._asdict().items()
}


@dataclasses.dataclass
class StateInfo:
@@ -383,6 +397,7 @@ def run_managed(
self.run()
self._state = "completed"
except (DestinationTerminalException, TerminalValueError) as e:
logger.exception(f"Job {self.job_id()} failed terminally")
self._state = "failed"
self._exception = e
logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}")
@@ -443,7 +458,7 @@ def __init__(
self.capabilities = capabilities

@abstractmethod
-def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
+def initialize_storage(self, truncate_tables: Optional[Iterable[str]] = None) -> None:
"""Prepares storage to be used ie. creates database schema or file system folder. Truncates requested tables."""
pass

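A usage sketch of the round trip the new method enables (assuming the surrounding class is dlt's StorageSchemaInfo; schema_info is an illustrative instance, and snake_case is one of dlt's built-in naming modules):

from dlt.common.normalizers.naming.snake_case import NamingConvention

naming = NamingConvention()
# keys come out normalized for the destination, e.g. {"version": ..., "schema_name": ...}
doc = schema_info.to_normalized_mapping(naming)
# inverse of the from_normalized_mapping classmethod shown above
restored = StorageSchemaInfo.from_normalized_mapping(doc, naming)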
83 changes: 48 additions & 35 deletions dlt/common/json/__init__.py
@@ -19,35 +19,7 @@
from dlt.common.utils import map_nested_in_place


class SupportsJson(Protocol):
"""Minimum adapter for different json parser implementations"""

_impl_name: str
"""Implementation name"""

def dump(
self, obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False
) -> None: ...

def typed_dump(self, obj: Any, fp: IO[bytes], pretty: bool = False) -> None: ...

def typed_dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ...

def typed_loads(self, s: str) -> Any: ...

def typed_dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ...

def typed_loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ...

def dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ...

def dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ...

def load(self, fp: Union[IO[bytes], IO[str]]) -> Any: ...

def loads(self, s: str) -> Any: ...

def loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ...
TPuaDecoders = List[Callable[[Any], Any]]


def custom_encode(obj: Any) -> str:
@@ -104,7 +76,7 @@ def _datetime_decoder(obj: str) -> datetime:


# define decoder for each prefix
-DECODERS: List[Callable[[Any], Any]] = [
+DECODERS: TPuaDecoders = [
Decimal,
_datetime_decoder,
pendulum.Date.fromisoformat,
@@ -114,6 +86,11 @@ def _datetime_decoder(obj: str) -> datetime:
Wei,
pendulum.Time.fromisoformat,
]
# Alternate decoders that decode date/time/datetime to stdlib types instead of pendulum
PY_DATETIME_DECODERS = list(DECODERS)
PY_DATETIME_DECODERS[1] = datetime.fromisoformat
PY_DATETIME_DECODERS[2] = date.fromisoformat
PY_DATETIME_DECODERS[7] = time.fromisoformat
# how many decoders?
PUA_CHARACTER_MAX = len(DECODERS)

@@ -151,13 +128,13 @@ def custom_pua_encode(obj: Any) -> str:
raise TypeError(repr(obj) + " is not JSON serializable")


-def custom_pua_decode(obj: Any) -> Any:
+def custom_pua_decode(obj: Any, decoders: TPuaDecoders = DECODERS) -> Any:
if isinstance(obj, str) and len(obj) > 1:
c = ord(obj[0]) - PUA_START
# decode only the PUA space defined in DECODERS
if c >= 0 and c <= PUA_CHARACTER_MAX:
try:
-return DECODERS[c](obj[1:])
+return decoders[c](obj[1:])
except Exception:
# return strings that cannot be parsed
# this may be due
@@ -167,11 +144,11 @@ def custom_pua_decode(obj: Any) -> Any:
return obj


-def custom_pua_decode_nested(obj: Any) -> Any:
+def custom_pua_decode_nested(obj: Any, decoders: TPuaDecoders = DECODERS) -> Any:
     if isinstance(obj, str):
-        return custom_pua_decode(obj)
+        return custom_pua_decode(obj, decoders)
     elif isinstance(obj, (list, dict)):
-        return map_nested_in_place(custom_pua_decode, obj)
+        return map_nested_in_place(custom_pua_decode, obj, decoders=decoders)
     return obj


@@ -190,6 +167,39 @@ def may_have_pua(line: bytes) -> bool:
return PUA_START_UTF8_MAGIC in line


class SupportsJson(Protocol):
"""Minimum adapter for different json parser implementations"""

_impl_name: str
"""Implementation name"""

def dump(
self, obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False
) -> None: ...

def typed_dump(self, obj: Any, fp: IO[bytes], pretty: bool = False) -> None: ...

def typed_dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ...

def typed_loads(self, s: str) -> Any: ...

def typed_dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ...

def typed_loadb(
self, s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS
) -> Any: ...

def dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ...

def dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ...

def load(self, fp: Union[IO[bytes], IO[str]]) -> Any: ...

def loads(self, s: str) -> Any: ...

def loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ...


# pick the right impl
json: SupportsJson = None
if os.environ.get(known_env.DLT_USE_JSON) == "simplejson":
@@ -216,4 +226,7 @@ def may_have_pua(line: bytes) -> bool:
"custom_pua_remove",
"SupportsJson",
"may_have_pua",
"TPuaDecoders",
"DECODERS",
"PY_DATETIME_DECODERS",
]
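A short sketch of what the new decoder switch enables (presumably what lets the sqlalchemy loader hand stdlib datetimes to DBAPI drivers instead of pendulum objects):

from datetime import datetime

import pendulum

from dlt.common.json import json, PY_DATETIME_DECODERS

raw = json.typed_dumpb({"ts": pendulum.now()})
json.typed_loadb(raw)["ts"]                                 # pendulum.DateTime (default DECODERS)
json.typed_loadb(raw, decoders=PY_DATETIME_DECODERS)["ts"]  # stdlib datetime.datetime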
12 changes: 9 additions & 3 deletions dlt/common/json/_orjson.py
@@ -1,7 +1,13 @@
from typing import IO, Any, Union
import orjson

-from dlt.common.json import custom_pua_encode, custom_pua_decode_nested, custom_encode
+from dlt.common.json import (
+    custom_pua_encode,
+    custom_pua_decode_nested,
+    custom_encode,
+    TPuaDecoders,
+    DECODERS,
+)
from dlt.common.typing import AnyFun

_impl_name = "orjson"
@@ -38,8 +44,8 @@ def typed_loads(s: str) -> Any:
return custom_pua_decode_nested(loads(s))


-def typed_loadb(s: Union[bytes, bytearray, memoryview]) -> Any:
-    return custom_pua_decode_nested(loadb(s))
+def typed_loadb(s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS) -> Any:
+    return custom_pua_decode_nested(loadb(s), decoders)


def dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str:
12 changes: 9 additions & 3 deletions dlt/common/json/_simplejson.py
@@ -4,7 +4,13 @@
import simplejson
import platform

-from dlt.common.json import custom_pua_encode, custom_pua_decode_nested, custom_encode
+from dlt.common.json import (
+    custom_pua_encode,
+    custom_pua_decode_nested,
+    custom_encode,
+    TPuaDecoders,
+    DECODERS,
+)

if platform.python_implementation() == "PyPy":
# disable speedups on PyPy, it can be actually faster than Python C
@@ -73,8 +79,8 @@ def typed_dumpb(obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes:
return typed_dumps(obj, sort_keys, pretty).encode("utf-8")


-def typed_loadb(s: Union[bytes, bytearray, memoryview]) -> Any:
-    return custom_pua_decode_nested(loadb(s))
+def typed_loadb(s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS) -> Any:
+    return custom_pua_decode_nested(loadb(s), decoders)


def dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str:
1 change: 1 addition & 0 deletions dlt/common/libs/pyarrow.py
@@ -29,6 +29,7 @@
import pyarrow.parquet
import pyarrow.compute
import pyarrow.dataset
from pyarrow.parquet import ParquetFile
except ModuleNotFoundError:
raise MissingDependencyException(
"dlt pyarrow helpers",
8 changes: 8 additions & 0 deletions dlt/common/time.py
@@ -143,6 +143,14 @@ def ensure_pendulum_time(value: Union[str, datetime.time]) -> pendulum.Time:
return result
else:
raise ValueError(f"{value} is not a valid ISO time string.")
elif isinstance(value, timedelta):
# Assume timedelta is seconds passed since midnight. Some drivers (mysqlclient) return time in this format
return pendulum.time(
value.seconds // 3600,
(value.seconds // 60) % 60,
value.seconds % 60,
value.microseconds,
)
raise TypeError(f"Cannot coerce {value} to a pendulum.Time object.")


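A worked example of the new branch (mysqlclient returns TIME columns as a timedelta since midnight):

from datetime import timedelta

from dlt.common.time import ensure_pendulum_time

# 3661 seconds past midnight: 3661 // 3600 = 1h, (3661 // 60) % 60 = 1m, 3661 % 60 = 1s
ensure_pendulum_time(timedelta(seconds=3661))  # -> pendulum.Time(1, 1, 1)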
14 changes: 8 additions & 6 deletions dlt/common/utils.py
@@ -282,8 +282,10 @@ def clone_dict_nested(src: TDict) -> TDict:
return update_dict_nested({}, src, copy_src_dicts=True) # type: ignore[return-value]


-def map_nested_in_place(func: AnyFun, _complex: TAny) -> TAny:
-    """Applies `func` to all elements in `_dict` recursively, replacing elements in nested dictionaries and lists in place."""
+def map_nested_in_place(func: AnyFun, _complex: TAny, *args: Any, **kwargs: Any) -> TAny:
+    """Applies `func` to all elements in `_complex` recursively, replacing elements in nested dictionaries and lists in place.
+    Additional `*args` and `**kwargs` are passed to `func`.
+    """
if isinstance(_complex, tuple):
if hasattr(_complex, "_asdict"):
_complex = _complex._asdict()
@@ -293,15 +295,15 @@ def map_nested_in_place(func: AnyFun, _complex: TAny) -> TAny:
     if isinstance(_complex, dict):
         for k, v in _complex.items():
             if isinstance(v, (dict, list, tuple)):
-                _complex[k] = map_nested_in_place(func, v)
+                _complex[k] = map_nested_in_place(func, v, *args, **kwargs)
             else:
-                _complex[k] = func(v)
+                _complex[k] = func(v, *args, **kwargs)
     elif isinstance(_complex, list):
         for idx, _l in enumerate(_complex):
             if isinstance(_l, (dict, list, tuple)):
-                _complex[idx] = map_nested_in_place(func, _l)
+                _complex[idx] = map_nested_in_place(func, _l, *args, **kwargs)
             else:
-                _complex[idx] = func(_l)
+                _complex[idx] = func(_l, *args, **kwargs)
     else:
         raise ValueError(_complex, "Not a complex type")
     return _complex
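This pass-through is what lets custom_pua_decode_nested above forward its decoders argument. A tiny illustration with a hypothetical leaf function:

from dlt.common.utils import map_nested_in_place

data = {"a": [1, 2], "b": {"c": 3}}
# extra kwargs now reach `func` at every leaf
map_nested_in_place(lambda v, offset: v + offset, data, offset=10)
# data == {"a": [11, 12], "b": {"c": 13}}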
2 changes: 2 additions & 0 deletions dlt/destinations/__init__.py
@@ -16,6 +16,7 @@
from dlt.destinations.impl.databricks.factory import databricks
from dlt.destinations.impl.dremio.factory import dremio
from dlt.destinations.impl.clickhouse.factory import clickhouse
from dlt.destinations.impl.sqlalchemy.factory import sqlalchemy


__all__ = [
Expand All @@ -37,4 +38,5 @@
"dremio",
"clickhouse",
"destination",
"sqlalchemy",
]
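With the factory exported, the new destination can be selected like any other. A hedged sketch of pipeline wiring (the sqlite URL is an example value, and the credentials keyword follows the pattern of dlt's other destination factories rather than this commit's code):

import dlt

pipeline = dlt.pipeline(
    pipeline_name="sqlalchemy_poc",
    destination=dlt.destinations.sqlalchemy(credentials="sqlite:///dlt_poc.db"),
    dataset_name="poc_data",
)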
Empty file.
58 changes: 58 additions & 0 deletions dlt/destinations/impl/sqlalchemy/configuration.py
@@ -0,0 +1,58 @@
from typing import TYPE_CHECKING, Optional, Any, Final, Type, Dict
import dataclasses

from dlt.common.configuration import configspec
from dlt.common.configuration.specs import ConnectionStringCredentials
from dlt.common.destination.reference import DestinationClientDwhConfiguration

if TYPE_CHECKING:
from sqlalchemy.engine import Engine, Dialect


@configspec
class SqlalchemyCredentials(ConnectionStringCredentials):
if TYPE_CHECKING:
_engine: Optional["Engine"] = None

username: Optional[str] = None # e.g. sqlite doesn't need username

def parse_native_representation(self, native_value: Any) -> None:
from sqlalchemy.engine import Engine

if isinstance(native_value, Engine):
self.engine = native_value
super().parse_native_representation(
native_value.url.render_as_string(hide_password=False)
)
else:
super().parse_native_representation(native_value)

@property
def engine(self) -> Optional["Engine"]:
return getattr(self, "_engine", None) # type: ignore[no-any-return]

@engine.setter
def engine(self, value: "Engine") -> None:
self._engine = value

def get_dialect(self) -> Optional[Type["Dialect"]]:
if not self.drivername:
return None
# Type-ignore because the ported URL class has no get_dialect method,
# but sqlalchemy should be available here
if engine := self.engine:
return type(engine.dialect)
return self.to_url().get_dialect() # type: ignore[attr-defined,no-any-return]


@configspec
class SqlalchemyClientConfiguration(DestinationClientDwhConfiguration):
destination_type: Final[str] = dataclasses.field(default="sqlalchemy", init=False, repr=False, compare=False) # type: ignore
credentials: SqlalchemyCredentials = None
"""SQLAlchemy connection string"""

engine_args: Dict[str, Any] = dataclasses.field(default_factory=dict)
"""Additional arguments passed to `sqlalchemy.create_engine`"""

def get_dialect(self) -> Type["Dialect"]:
return self.credentials.get_dialect()
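The credentials accept either a plain connection string or a live Engine; a sketch of both paths (database URLs are example values):

from sqlalchemy import create_engine

from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials

creds = SqlalchemyCredentials()
creds.parse_native_representation("mysql+pymysql://user:secret@localhost/dlt_data")

engine_creds = SqlalchemyCredentials()
# an existing Engine is kept for reuse and its URL (password included) is parsed
engine_creds.parse_native_representation(create_engine("sqlite:///dlt_data.db"))
engine_creds.get_dialect()  # dialect class resolved from the stored engine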
