diff --git a/.github/workflows/test_destinations.yml b/.github/workflows/test_destinations.yml index a034ac7eb0..7fae69ff9e 100644 --- a/.github/workflows/test_destinations.yml +++ b/.github/workflows/test_destinations.yml @@ -29,6 +29,7 @@ env: # Test redshift and filesystem with all buckets # postgres runs again here so we can test on mac/windows ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\"]" + # note that all buckets are enabled for testing jobs: get_docs_changes: diff --git a/Makefile b/Makefile index 15fb895a9f..f47047a3fe 100644 --- a/Makefile +++ b/Makefile @@ -52,7 +52,7 @@ lint: poetry run mypy --config-file mypy.ini dlt tests poetry run flake8 --max-line-length=200 dlt poetry run flake8 --max-line-length=200 tests --exclude tests/reflection/module_cases - poetry run black dlt docs tests --diff --extend-exclude=".*syntax_error.py" + poetry run black dlt docs tests --check --diff --color --extend-exclude=".*syntax_error.py" # poetry run isort ./ --diff # $(MAKE) lint-security diff --git a/dlt/common/configuration/specs/azure_credentials.py b/dlt/common/configuration/specs/azure_credentials.py index 7fa34fa00f..6794b581ce 100644 --- a/dlt/common/configuration/specs/azure_credentials.py +++ b/dlt/common/configuration/specs/azure_credentials.py @@ -32,6 +32,8 @@ def to_object_store_rs_credentials(self) -> Dict[str, str]: creds = self.to_adlfs_credentials() if creds["sas_token"] is None: creds.pop("sas_token") + if creds["account_key"] is None: + creds.pop("account_key") return creds def create_sas_token(self) -> None: diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py index 945e74a37b..9966590c06 100644 --- a/dlt/common/data_writers/__init__.py +++ b/dlt/common/data_writers/__init__.py @@ -1,6 +1,5 @@ from dlt.common.data_writers.writers import ( DataWriter, - DataWriterMetrics, TDataItemFormat, FileWriterSpec, create_import_spec, @@ -22,7 +21,6 @@ "resolve_best_writer_spec", "get_best_writer_spec", "is_native_writer", - "DataWriterMetrics", "TDataItemFormat", "BufferedDataWriter", "new_file_id", diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 8077007edb..945fca6580 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -3,6 +3,7 @@ import contextlib from typing import ClassVar, Iterator, List, IO, Any, Optional, Type, Generic +from dlt.common.metrics import DataWriterMetrics from dlt.common.typing import TDataItem, TDataItems from dlt.common.data_writers.exceptions import ( BufferedDataWriterClosed, @@ -10,7 +11,7 @@ FileImportNotFound, InvalidFileNameTemplateException, ) -from dlt.common.data_writers.writers import TWriter, DataWriter, DataWriterMetrics, FileWriterSpec +from dlt.common.data_writers.writers import TWriter, DataWriter, FileWriterSpec from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.configuration import with_config, known_sections, configspec from dlt.common.configuration.specs import BaseConfiguration diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index d324792a83..abd3343ea1 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -34,6 +34,7 @@ TLoaderFileFormat, ALL_SUPPORTED_FILE_FORMATS, ) +from dlt.common.metrics import DataWriterMetrics from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.typing import StrAny @@ -59,25 +60,6 @@ class FileWriterSpec(NamedTuple): supports_compression: 
bool = False -class DataWriterMetrics(NamedTuple): - file_path: str - items_count: int - file_size: int - created: float - last_modified: float - - def __add__(self, other: Tuple[object, ...], /) -> Tuple[object, ...]: - if isinstance(other, DataWriterMetrics): - return DataWriterMetrics( - "", # path is not known - self.items_count + other.items_count, - self.file_size + other.file_size, - min(self.created, other.created), - max(self.last_modified, other.last_modified), - ) - return NotImplemented - - EMPTY_DATA_WRITER_METRICS = DataWriterMetrics("", 0, 0, 2**32, 0.0) diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index be71cb50e9..52e7d74833 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -76,6 +76,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): # use naming convention in the schema naming_convention: TNamingConventionReferenceArg = None alter_add_multi_column: bool = True + supports_create_table_if_not_exists: bool = True supports_truncate_command: bool = True schema_supports_numeric_precision: bool = True timestamp_precision: int = 6 diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 3af7dcff13..e7bba266df 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -24,10 +24,11 @@ from copy import deepcopy import inspect -from dlt.common import logger +from dlt.common import logger, pendulum from dlt.common.configuration.specs.base_configuration import extract_inner_hint from dlt.common.destination.utils import verify_schema_capabilities from dlt.common.exceptions import TerminalValueError +from dlt.common.metrics import LoadJobMetrics from dlt.common.normalizers.naming import NamingConvention from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.utils import ( @@ -268,6 +269,8 @@ class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfigura staging_config: Optional[DestinationClientStagingConfiguration] = None """configuration of the staging, if present, injected at runtime""" + truncate_tables_on_staging_destination_before_load: bool = True + """If dlt should truncate the tables on staging destination before loading data.""" TLoadJobState = Literal["ready", "running", "failed", "retry", "completed"] @@ -284,6 +287,8 @@ def __init__(self, file_path: str) -> None: # NOTE: we only accept a full filepath in the constructor assert self._file_name != self._file_path self._parsed_file_name = ParsedLoadJobFileName.parse(self._file_name) + self._started_at: pendulum.DateTime = None + self._finished_at: pendulum.DateTime = None def job_id(self) -> str: """The job id that is derived from the file name and does not changes during job lifecycle""" @@ -306,6 +311,18 @@ def exception(self) -> str: """The exception associated with failed or retry states""" pass + def metrics(self) -> Optional[LoadJobMetrics]: + """Returns job execution metrics""" + return LoadJobMetrics( + self._parsed_file_name.job_id(), + self._file_path, + self._parsed_file_name.table_name, + self._started_at, + self._finished_at, + self.state(), + None, + ) + class RunnableLoadJob(LoadJob, ABC): """Represents a runnable job that loads a single file @@ -361,16 +378,22 @@ def run_managed( # filepath is now moved to running try: self._state = "running" + self._started_at = pendulum.now() self._job_client.prepare_load_job_execution(self) self.run() self._state = "completed" except 
(DestinationTerminalException, TerminalValueError) as e: self._state = "failed" self._exception = e + logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}") except (DestinationTransientException, Exception) as e: self._state = "retry" self._exception = e + logger.exception( + f"Transient exception in job {self.job_id()} in file {self._file_path}" + ) finally: + self._finished_at = pendulum.now() # sanity check assert self._state in ("completed", "retry", "failed") @@ -391,7 +414,7 @@ def exception(self) -> str: return str(self._exception) -class FollowupJob: +class FollowupJobRequest: """Base class for follow up jobs that should be created""" @abstractmethod @@ -403,8 +426,8 @@ def new_file_path(self) -> str: class HasFollowupJobs: """Adds a trait that allows to create single or table chain followup jobs""" - def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: - """Return list of new jobs. `final_state` is state to which this job transits""" + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRequest]: + """Return list of jobs requests for jobs that should be created. `final_state` is state to which this job transits""" return [] @@ -479,7 +502,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] @@ -557,17 +580,28 @@ def with_staging_dataset(self) -> ContextManager["JobClientBase"]: return self # type: ignore -class SupportsStagingDestination: +class SupportsStagingDestination(ABC): """Adds capability to support a staging destination for the load""" def should_load_data_to_staging_dataset_on_staging_destination( self, table: TTableSchema ) -> bool: + """If set to True, and staging destination is configured, the data will be loaded to staging dataset on staging destination + instead of a regular dataset on staging destination. Currently it is used by Athena Iceberg which uses staging dataset + on staging destination to copy data to iceberg tables stored on regular dataset on staging destination. + The default is to load data to regular dataset on staging destination from where warehouses like Snowflake (that have their + own storage) will copy data. + """ return False + @abstractmethod def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: - # the default is to truncate the tables on the staging destination... - return True + """If set to True, data in `table` will be truncated on staging destination (regular dataset). This is the default behavior which + can be changed with a config flag. + For Athena + Iceberg this setting is always False - Athena uses regular dataset to store Iceberg tables and we avoid touching it. + For Athena we truncate those tables only on "replace" write disposition. 
+ """ + pass # TODO: type Destination properly diff --git a/dlt/common/libs/deltalake.py b/dlt/common/libs/deltalake.py index d98795d07c..38b23ea27a 100644 --- a/dlt/common/libs/deltalake.py +++ b/dlt/common/libs/deltalake.py @@ -5,13 +5,15 @@ from dlt.common import logger from dlt.common.libs.pyarrow import pyarrow as pa from dlt.common.libs.pyarrow import cast_arrow_schema_types -from dlt.common.schema.typing import TWriteDisposition +from dlt.common.schema.typing import TWriteDisposition, TTableSchema +from dlt.common.schema.utils import get_first_column_name_with_prop, get_columns_names_with_prop from dlt.common.exceptions import MissingDependencyException from dlt.common.storages import FilesystemConfiguration from dlt.common.utils import assert_min_pkg_version from dlt.destinations.impl.filesystem.filesystem import FilesystemClient try: + import deltalake from deltalake import write_deltalake, DeltaTable from deltalake.writer import try_get_deltatable except ModuleNotFoundError: @@ -74,7 +76,7 @@ def write_delta_table( partition_by: Optional[Union[List[str], str]] = None, storage_options: Optional[Dict[str, str]] = None, ) -> None: - """Writes in-memory Arrow table to on-disk Delta table. + """Writes in-memory Arrow data to on-disk Delta table. Thin wrapper around `deltalake.write_deltalake`. """ @@ -93,31 +95,73 @@ def write_delta_table( ) -def get_delta_tables(pipeline: Pipeline, *tables: str) -> Dict[str, DeltaTable]: - """Returns Delta tables in `pipeline.default_schema` as `deltalake.DeltaTable` objects. +def merge_delta_table( + table: DeltaTable, + data: Union[pa.Table, pa.RecordBatchReader], + schema: TTableSchema, +) -> None: + """Merges in-memory Arrow data into on-disk Delta table.""" + + strategy = schema["x-merge-strategy"] # type: ignore[typeddict-item] + if strategy == "upsert": + # `DeltaTable.merge` does not support automatic schema evolution + # https://github.com/delta-io/delta-rs/issues/2282 + _evolve_delta_table_schema(table, data.schema) + + if "parent" in schema: + unique_column = get_first_column_name_with_prop(schema, "unique") + predicate = f"target.{unique_column} = source.{unique_column}" + else: + primary_keys = get_columns_names_with_prop(schema, "primary_key") + predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys]) + + qry = ( + table.merge( + source=ensure_delta_compatible_arrow_data(data), + predicate=predicate, + source_alias="source", + target_alias="target", + ) + .when_matched_update_all() + .when_not_matched_insert_all() + ) + + qry.execute() + else: + ValueError(f'Merge strategy "{strategy}" not supported.') + + +def get_delta_tables( + pipeline: Pipeline, *tables: str, schema_name: str = None +) -> Dict[str, DeltaTable]: + """Returns Delta tables in `pipeline.default_schema (default)` as `deltalake.DeltaTable` objects. Returned object is a dictionary with table names as keys and `DeltaTable` objects as values. Optionally filters dictionary by table names specified as `*tables*`. - Raises ValueError if table name specified as `*tables` is not found. + Raises ValueError if table name specified as `*tables` is not found. You may try to switch to other + schemas via `schema_name` argument. """ from dlt.common.schema.utils import get_table_format - with pipeline.destination_client() as client: + with pipeline.destination_client(schema_name=schema_name) as client: assert isinstance( client, FilesystemClient ), "The `get_delta_tables` function requires a `filesystem` destination." 
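As an aside on the `merge_delta_table` upsert path added above: when the table has no parent, the merge predicate is just an equality join over the schema's primary-key columns (for child tables the `parent` branch falls back to the single column carrying the `unique` hint). A minimal sketch of that predicate construction, using hypothetical column names that are not taken from this diff:

# sketch only: reproduces the kind of predicate string merge_delta_table builds above
primary_keys = ["id", "region"]  # hypothetical primary-key columns
predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys])
assert predicate == "target.id = source.id AND target.region = source.region"
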
schema_delta_tables = [ t["name"] - for t in pipeline.default_schema.tables.values() - if get_table_format(pipeline.default_schema.tables, t["name"]) == "delta" + for t in client.schema.tables.values() + if get_table_format(client.schema.tables, t["name"]) == "delta" ] if len(tables) > 0: invalid_tables = set(tables) - set(schema_delta_tables) if len(invalid_tables) > 0: + available_schemas = "" + if len(pipeline.schema_names) > 1: + available_schemas = f" Available schemas are {pipeline.schema_names}" raise ValueError( - "Schema does not contain Delta tables with these names: " - f"{', '.join(invalid_tables)}." + f"Schema {client.schema.name} does not contain Delta tables with these names: " + f"{', '.join(invalid_tables)}.{available_schemas}" ) schema_delta_tables = [t for t in schema_delta_tables if t in tables] table_dirs = client.get_table_dirs(schema_delta_tables, remote=True) @@ -132,7 +176,8 @@ def _deltalake_storage_options(config: FilesystemConfiguration) -> Dict[str, str """Returns dict that can be passed as `storage_options` in `deltalake` library.""" creds = {} extra_options = {} - if config.protocol in ("az", "gs", "s3"): + # TODO: create a mixin with to_object_store_rs_credentials for a proper discovery + if hasattr(config.credentials, "to_object_store_rs_credentials"): creds = config.credentials.to_object_store_rs_credentials() if config.deltalake_storage_options is not None: extra_options = config.deltalake_storage_options @@ -145,3 +190,16 @@ def _deltalake_storage_options(config: FilesystemConfiguration) -> Dict[str, str + ". dlt will use the values in `deltalake_storage_options`." ) return {**creds, **extra_options} + + +def _evolve_delta_table_schema(delta_table: DeltaTable, arrow_schema: pa.Schema) -> None: + """Evolves `delta_table` schema if different from `arrow_schema`. + + Adds column(s) to `delta_table` present in `arrow_schema` but not in `delta_table`. 
+ """ + new_fields = [ + deltalake.Field.from_pyarrow(field) + for field in ensure_delta_compatible_arrow_schema(arrow_schema) + if field not in delta_table.to_pyarrow_dataset().schema + ] + delta_table.alter.add_columns(new_fields) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 9d3e97421c..e9dcfaf095 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -54,7 +54,10 @@ def get_py_arrow_datatype( elif column_type == "bool": return pyarrow.bool_() elif column_type == "timestamp": - return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz) + # sets timezone to None when timezone hint is false + timezone = tz if column.get("timezone", True) else None + precision = column.get("precision") or caps.timestamp_precision + return get_py_arrow_timestamp(precision, timezone) elif column_type == "bigint": return get_pyarrow_int(column.get("precision")) elif column_type == "binary": @@ -139,6 +142,10 @@ def get_column_type_from_py_arrow(dtype: pyarrow.DataType) -> TColumnType: precision = 6 else: precision = 9 + + if dtype.tz is None: + return dict(data_type="timestamp", precision=precision, timezone=False) + return dict(data_type="timestamp", precision=precision) elif pyarrow.types.is_date(dtype): return dict(data_type="date") diff --git a/dlt/common/metrics.py b/dlt/common/metrics.py new file mode 100644 index 0000000000..d6acf19d0d --- /dev/null +++ b/dlt/common/metrics.py @@ -0,0 +1,71 @@ +import datetime # noqa: I251 +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypedDict # noqa: 251 + + +class DataWriterMetrics(NamedTuple): + file_path: str + items_count: int + file_size: int + created: float + last_modified: float + + def __add__(self, other: Tuple[object, ...], /) -> Tuple[object, ...]: + if isinstance(other, DataWriterMetrics): + return DataWriterMetrics( + self.file_path if self.file_path == other.file_path else "", + # self.table_name if self.table_name == other.table_name else "", + self.items_count + other.items_count, + self.file_size + other.file_size, + min(self.created, other.created), + max(self.last_modified, other.last_modified), + ) + return NotImplemented + + +class StepMetrics(TypedDict): + """Metrics for particular package processed in particular pipeline step""" + + started_at: datetime.datetime + """Start of package processing""" + finished_at: datetime.datetime + """End of package processing""" + + +class ExtractDataInfo(TypedDict): + name: str + data_type: str + + +class ExtractMetrics(StepMetrics): + schema_name: str + job_metrics: Dict[str, DataWriterMetrics] + """Metrics collected per job id during writing of job file""" + table_metrics: Dict[str, DataWriterMetrics] + """Job metrics aggregated by table""" + resource_metrics: Dict[str, DataWriterMetrics] + """Job metrics aggregated by resource""" + dag: List[Tuple[str, str]] + """A resource dag where elements of the list are graph edges""" + hints: Dict[str, Dict[str, Any]] + """Hints passed to the resources""" + + +class NormalizeMetrics(StepMetrics): + job_metrics: Dict[str, DataWriterMetrics] + """Metrics collected per job id during writing of job file""" + table_metrics: Dict[str, DataWriterMetrics] + """Job metrics aggregated by table""" + + +class LoadJobMetrics(NamedTuple): + job_id: str + file_path: str + table_name: str + started_at: datetime.datetime + finished_at: datetime.datetime + state: Optional[str] + remote_url: Optional[str] + + +class LoadMetrics(StepMetrics): + job_metrics: Dict[str, LoadJobMetrics] diff 
--git a/dlt/common/normalizers/json/__init__.py b/dlt/common/normalizers/json/__init__.py index a13bab15f4..725f6a8355 100644 --- a/dlt/common/normalizers/json/__init__.py +++ b/dlt/common/normalizers/json/__init__.py @@ -54,9 +54,9 @@ class SupportsDataItemNormalizer(Protocol): """A class with a name DataItemNormalizer deriving from normalizers.json.DataItemNormalizer""" -def wrap_in_dict(item: Any) -> DictStrAny: +def wrap_in_dict(label: str, item: Any) -> DictStrAny: """Wraps `item` that is not a dictionary into dictionary that can be json normalized""" - return {"value": item} + return {label: item} __all__ = [ diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 8e296445eb..33184640f0 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -184,11 +184,10 @@ def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> # and all child tables must be lists return digest128(f"{parent_row_id}_{child_table}_{list_idx}", DLT_ID_LENGTH_BYTES) - @staticmethod - def _link_row(row: DictStrAny, parent_row_id: str, list_idx: int) -> DictStrAny: + def _link_row(self, row: DictStrAny, parent_row_id: str, list_idx: int) -> DictStrAny: assert parent_row_id - row["_dlt_parent_id"] = parent_row_id - row["_dlt_list_idx"] = list_idx + row[self.c_dlt_parent_id] = parent_row_id + row[self.c_dlt_list_idx] = list_idx return row @@ -227,7 +226,7 @@ def _add_row_id( if row_id_type == "row_hash": row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos) # link to parent table - DataItemNormalizer._link_row(flattened_row, parent_row_id, pos) + self._link_row(flattened_row, parent_row_id, pos) flattened_row[self.c_dlt_id] = row_id return row_id @@ -260,7 +259,6 @@ def _normalize_list( parent_row_id: Optional[str] = None, _r_lvl: int = 0, ) -> TNormalizedRowIterator: - v: DictStrAny = None table = self.schema.naming.shorten_fragments(*parent_path, *ident_path) for idx, v in enumerate(seq): @@ -283,9 +281,9 @@ def _normalize_list( else: # list of simple types child_row_hash = DataItemNormalizer._get_child_row_hash(parent_row_id, table, idx) - wrap_v = wrap_in_dict(v) + wrap_v = wrap_in_dict(self.c_value, v) wrap_v[self.c_dlt_id] = child_row_hash - e = DataItemNormalizer._link_row(wrap_v, parent_row_id, idx) + e = self._link_row(wrap_v, parent_row_id, idx) DataItemNormalizer._extend_row(extend, e) yield (table, self.schema.naming.shorten_fragments(*parent_path)), e @@ -389,7 +387,7 @@ def normalize_data_item( ) -> TNormalizedRowIterator: # wrap items that are not dictionaries in dictionary, otherwise they cannot be processed by the JSON normalizer if not isinstance(item, dict): - item = wrap_in_dict(item) + item = wrap_in_dict(self.c_value, item) # we will extend event with all the fields necessary to load it as root row row = cast(DictStrAny, item) # identify load id if loaded data must be processed after loading incrementally diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 1e1416eb53..8a07ddbd33 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -16,7 +16,6 @@ Optional, Protocol, Sequence, - TYPE_CHECKING, Tuple, TypeVar, TypedDict, @@ -36,6 +35,14 @@ from dlt.common.destination import TDestinationReferenceArg, TDestination from dlt.common.destination.exceptions import DestinationHasFailedJobs from dlt.common.exceptions import PipelineStateNotAvailable, SourceSectionNotAvailable +from dlt.common.metrics import ( + DataWriterMetrics, + 
ExtractDataInfo, + ExtractMetrics, + LoadMetrics, + NormalizeMetrics, + StepMetrics, +) from dlt.common.schema import Schema from dlt.common.schema.typing import ( TColumnNames, @@ -44,11 +51,12 @@ TSchemaContract, ) from dlt.common.source import get_current_pipe_name +from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.storages.load_storage import LoadPackageInfo from dlt.common.time import ensure_pendulum_datetime, precise_time from dlt.common.typing import DictStrAny, REPattern, StrAny, SupportsHumanize from dlt.common.jsonpath import delete_matches, TAnyJsonPath -from dlt.common.data_writers.writers import DataWriterMetrics, TLoaderFileFormat +from dlt.common.data_writers.writers import TLoaderFileFormat from dlt.common.utils import RowCounts, merge_row_counts from dlt.common.versioned_state import TVersionedState @@ -68,15 +76,6 @@ class _StepInfo(NamedTuple): finished_at: datetime.datetime -class StepMetrics(TypedDict): - """Metrics for particular package processed in particular pipeline step""" - - started_at: datetime.datetime - """Start of package processing""" - finished_at: datetime.datetime - """End of package processing""" - - TStepMetricsCo = TypeVar("TStepMetricsCo", bound=StepMetrics, covariant=True) @@ -154,17 +153,20 @@ def _load_packages_asstr(load_packages: List[LoadPackageInfo], verbosity: int) - return msg @staticmethod - def job_metrics_asdict( + def writer_metrics_asdict( job_metrics: Dict[str, DataWriterMetrics], key_name: str = "job_id", extend: StrAny = None ) -> List[DictStrAny]: - jobs = [] - for job_id, metrics in job_metrics.items(): + entities = [] + for entity_id, metrics in job_metrics.items(): d = metrics._asdict() if extend: d.update(extend) - d[key_name] = job_id - jobs.append(d) - return jobs + d[key_name] = entity_id + # add job-level info if known + if metrics.file_path: + d["table_name"] = ParsedLoadJobFileName.parse(metrics.file_path).table_name + entities.append(d) + return entities def _astuple(self) -> _StepInfo: return _StepInfo( @@ -177,25 +179,6 @@ def _astuple(self) -> _StepInfo: ) -class ExtractDataInfo(TypedDict): - name: str - data_type: str - - -class ExtractMetrics(StepMetrics): - schema_name: str - job_metrics: Dict[str, DataWriterMetrics] - """Metrics collected per job id during writing of job file""" - table_metrics: Dict[str, DataWriterMetrics] - """Job metrics aggregated by table""" - resource_metrics: Dict[str, DataWriterMetrics] - """Job metrics aggregated by resource""" - dag: List[Tuple[str, str]] - """A resource dag where elements of the list are graph edges""" - hints: Dict[str, Dict[str, Any]] - """Hints passed to the resources""" - - class _ExtractInfo(NamedTuple): """NamedTuple cannot be part of the derivation chain so we must re-declare all fields to use it as mixin later""" @@ -228,16 +211,8 @@ def asdict(self) -> DictStrAny: for load_id, metrics_list in self.metrics.items(): for idx, metrics in enumerate(metrics_list): extend = {"load_id": load_id, "extract_idx": idx} - load_metrics["job_metrics"].extend( - self.job_metrics_asdict(metrics["job_metrics"], extend=extend) - ) - load_metrics["table_metrics"].extend( - self.job_metrics_asdict( - metrics["table_metrics"], key_name="table_name", extend=extend - ) - ) load_metrics["resource_metrics"].extend( - self.job_metrics_asdict( + self.writer_metrics_asdict( metrics["resource_metrics"], key_name="resource_name", extend=extend ) ) @@ -253,6 +228,15 @@ def asdict(self) -> DictStrAny: for name, hints in metrics["hints"].items() ] ) + 
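For context on the aggregation these `asdict` methods flatten: `DataWriterMetrics` (moved into the new `dlt/common/metrics.py` earlier in this diff) sums as a NamedTuple and keeps `file_path` only when both operands refer to the same file. A quick illustrative check, assuming the module as introduced above:

from dlt.common.metrics import DataWriterMetrics

a = DataWriterMetrics("file_1.jsonl", 10, 100, 1.0, 2.0)
b = DataWriterMetrics("file_2.jsonl", 5, 50, 0.5, 3.0)
# file_path collapses to "" because the two metrics refer to different files;
# counts and sizes are summed, created/last_modified take min/max respectively
assert (a + b) == DataWriterMetrics("", 15, 150, 0.5, 3.0)
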
load_metrics["job_metrics"].extend( + self.writer_metrics_asdict(metrics["job_metrics"], extend=extend) + ) + load_metrics["table_metrics"].extend( + self.writer_metrics_asdict( + metrics["table_metrics"], key_name="table_name", extend=extend + ) + ) + d.update(load_metrics) return d @@ -260,13 +244,6 @@ def asstr(self, verbosity: int = 0) -> str: return self._load_packages_asstr(self.load_packages, verbosity) -class NormalizeMetrics(StepMetrics): - job_metrics: Dict[str, DataWriterMetrics] - """Metrics collected per job id during writing of job file""" - table_metrics: Dict[str, DataWriterMetrics] - """Job metrics aggregated by table""" - - class _NormalizeInfo(NamedTuple): pipeline: "SupportsPipeline" metrics: Dict[str, List[NormalizeMetrics]] @@ -305,10 +282,10 @@ def asdict(self) -> DictStrAny: for idx, metrics in enumerate(metrics_list): extend = {"load_id": load_id, "extract_idx": idx} load_metrics["job_metrics"].extend( - self.job_metrics_asdict(metrics["job_metrics"], extend=extend) + self.writer_metrics_asdict(metrics["job_metrics"], extend=extend) ) load_metrics["table_metrics"].extend( - self.job_metrics_asdict( + self.writer_metrics_asdict( metrics["table_metrics"], key_name="table_name", extend=extend ) ) @@ -326,10 +303,6 @@ def asstr(self, verbosity: int = 0) -> str: return msg -class LoadMetrics(StepMetrics): - pass - - class _LoadInfo(NamedTuple): pipeline: "SupportsPipeline" metrics: Dict[str, List[LoadMetrics]] @@ -354,7 +327,19 @@ class LoadInfo(StepInfo[LoadMetrics], _LoadInfo): # type: ignore[misc] def asdict(self) -> DictStrAny: """A dictionary representation of LoadInfo that can be loaded with `dlt`""" - return super().asdict() + d = super().asdict() + # transform metrics + d.pop("metrics") + load_metrics: Dict[str, List[Any]] = {"job_metrics": []} + for load_id, metrics_list in self.metrics.items(): + # one set of metrics per package id + assert len(metrics_list) == 1 + metrics = metrics_list[0] + for job_metrics in metrics["job_metrics"].values(): + load_metrics["job_metrics"].append({"load_id": load_id, **job_metrics._asdict()}) + + d.update(load_metrics) + return d def asstr(self, verbosity: int = 0) -> str: msg = f"Pipeline {self.pipeline.pipeline_name} load step completed in " diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 9a4dd51d4b..a81e9046a9 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -94,6 +94,7 @@ class TColumnType(TypedDict, total=False): data_type: Optional[TDataType] precision: Optional[int] scale: Optional[int] + timezone: Optional[bool] class TColumnSchemaBase(TColumnType, total=False): @@ -187,6 +188,7 @@ class TMergeDispositionDict(TWriteDispositionDict, total=False): strategy: Optional[TLoaderMergeStrategy] validity_column_names: Optional[List[str]] active_record_timestamp: Optional[TAnyDateTime] + boundary_timestamp: Optional[TAnyDateTime] row_version_column_name: Optional[str] diff --git a/dlt/common/storages/__init__.py b/dlt/common/storages/__init__.py index 7bb3c0cf97..50876a01cd 100644 --- a/dlt/common/storages/__init__.py +++ b/dlt/common/storages/__init__.py @@ -8,7 +8,7 @@ LoadJobInfo, LoadPackageInfo, PackageStorage, - TJobState, + TPackageJobState, create_load_id, ) from .data_item_storage import DataItemStorage @@ -40,7 +40,7 @@ "LoadJobInfo", "LoadPackageInfo", "PackageStorage", - "TJobState", + "TPackageJobState", "create_load_id", "fsspec_from_config", "fsspec_filesystem", diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index 
b2bdb3a7b6..04780528c4 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -1,7 +1,7 @@ import os import pathlib from typing import Any, Literal, Optional, Type, get_args, ClassVar, Dict, Union -from urllib.parse import urlparse, unquote +from urllib.parse import urlparse, unquote, urlunparse from dlt.common.configuration import configspec, resolve_type from dlt.common.configuration.exceptions import ConfigurationValueError @@ -52,6 +52,53 @@ class LoadStorageConfiguration(BaseConfiguration): ] +def _make_az_url(scheme: str, fs_path: str, bucket_url: str) -> str: + parsed_bucket_url = urlparse(bucket_url) + if parsed_bucket_url.username: + # az://@.dfs.core.windows.net/ + # fs_path always starts with container + split_path = fs_path.split("/", maxsplit=1) + if len(split_path) == 1: + split_path.append("") + container, path = split_path + netloc = f"{container}@{parsed_bucket_url.hostname}" + return urlunparse(parsed_bucket_url._replace(path=path, scheme=scheme, netloc=netloc)) + return f"{scheme}://{fs_path}" + + +def _make_file_url(scheme: str, fs_path: str, bucket_url: str) -> str: + """Creates a normalized file:// url from a local path + + netloc is never set. UNC paths are represented as file://host/path + """ + p_ = pathlib.Path(fs_path) + p_ = p_.expanduser().resolve() + return p_.as_uri() + + +MAKE_URI_DISPATCH = {"az": _make_az_url, "file": _make_file_url} + +MAKE_URI_DISPATCH["adl"] = MAKE_URI_DISPATCH["az"] +MAKE_URI_DISPATCH["abfs"] = MAKE_URI_DISPATCH["az"] +MAKE_URI_DISPATCH["azure"] = MAKE_URI_DISPATCH["az"] +MAKE_URI_DISPATCH["abfss"] = MAKE_URI_DISPATCH["az"] +MAKE_URI_DISPATCH["local"] = MAKE_URI_DISPATCH["file"] + + +def make_fsspec_url(scheme: str, fs_path: str, bucket_url: str) -> str: + """Creates url from `fs_path` and `scheme` using bucket_url as an `url` template + + Args: + scheme (str): scheme of the resulting url + fs_path (str): kind of absolute path that fsspec uses to locate resources for particular filesystem. + bucket_url (str): an url template. the structure of url will be preserved if possible + """ + _maker = MAKE_URI_DISPATCH.get(scheme) + if _maker: + return _maker(scheme, fs_path, bucket_url) + return f"{scheme}://{fs_path}" + + @configspec class FilesystemConfiguration(BaseConfiguration): """A configuration defining filesystem location and access credentials. @@ -59,7 +106,7 @@ class FilesystemConfiguration(BaseConfiguration): When configuration is resolved, `bucket_url` is used to extract a protocol and request corresponding credentials class. * s3 * gs, gcs - * az, abfs, adl + * az, abfs, adl, abfss, azure * file, memory * gdrive """ @@ -72,6 +119,8 @@ class FilesystemConfiguration(BaseConfiguration): "az": AnyAzureCredentials, "abfs": AnyAzureCredentials, "adl": AnyAzureCredentials, + "abfss": AnyAzureCredentials, + "azure": AnyAzureCredentials, } bucket_url: str = None @@ -93,17 +142,21 @@ def protocol(self) -> str: else: return urlparse(self.bucket_url).scheme + @property + def is_local_filesystem(self) -> bool: + return self.protocol == "file" + def on_resolved(self) -> None: - uri = urlparse(self.bucket_url) - if not uri.path and not uri.netloc: + url = urlparse(self.bucket_url) + if not url.path and not url.netloc: raise ConfigurationValueError( "File path and netloc are missing. Field bucket_url of" - " FilesystemClientConfiguration must contain valid uri with a path or host:password" + " FilesystemClientConfiguration must contain valid url with a path or host:password" " component." 
) # this is just a path in a local file system if self.is_local_path(self.bucket_url): - self.bucket_url = self.make_file_uri(self.bucket_url) + self.bucket_url = self.make_file_url(self.bucket_url) @resolve_type("credentials") def resolve_credentials_type(self) -> Type[CredentialsConfiguration]: @@ -122,44 +175,50 @@ def fingerprint(self) -> str: if self.is_local_path(self.bucket_url): return digest128("") - uri = urlparse(self.bucket_url) - return digest128(self.bucket_url.replace(uri.path, "")) + url = urlparse(self.bucket_url) + return digest128(self.bucket_url.replace(url.path, "")) + + def make_url(self, fs_path: str) -> str: + """Makes a full url (with scheme) form fs_path which is kind-of absolute path used by fsspec to identify resources. + This method will use `bucket_url` to infer the original form of the url. + """ + return make_fsspec_url(self.protocol, fs_path, self.bucket_url) def __str__(self) -> str: """Return displayable destination location""" - uri = urlparse(self.bucket_url) + url = urlparse(self.bucket_url) # do not show passwords - if uri.password: - new_netloc = f"{uri.username}:****@{uri.hostname}" - if uri.port: - new_netloc += f":{uri.port}" - return uri._replace(netloc=new_netloc).geturl() + if url.password: + new_netloc = f"{url.username}:****@{url.hostname}" + if url.port: + new_netloc += f":{url.port}" + return url._replace(netloc=new_netloc).geturl() return self.bucket_url @staticmethod - def is_local_path(uri: str) -> bool: - """Checks if `uri` is a local path, without a schema""" - uri_parsed = urlparse(uri) + def is_local_path(url: str) -> bool: + """Checks if `url` is a local path, without a schema""" + url_parsed = urlparse(url) # this prevents windows absolute paths to be recognized as schemas - return not uri_parsed.scheme or os.path.isabs(uri) + return not url_parsed.scheme or os.path.isabs(url) @staticmethod - def make_local_path(file_uri: str) -> str: + def make_local_path(file_url: str) -> str: """Gets a valid local filesystem path from file:// scheme. Supports POSIX/Windows/UNC paths Returns: str: local filesystem path """ - uri = urlparse(file_uri) - if uri.scheme != "file": - raise ValueError(f"Must be file scheme but is {uri.scheme}") - if not uri.path and not uri.netloc: + url = urlparse(file_url) + if url.scheme != "file": + raise ValueError(f"Must be file scheme but is {url.scheme}") + if not url.path and not url.netloc: raise ConfigurationValueError("File path and netloc are missing.") - local_path = unquote(uri.path) - if uri.netloc: + local_path = unquote(url.path) + if url.netloc: # or UNC file://localhost/path - local_path = "//" + unquote(uri.netloc) + local_path + local_path = "//" + unquote(url.netloc) + local_path else: # if we are on windows, strip the POSIX root from path which is always absolute if os.path.sep != local_path[0]: @@ -172,11 +231,9 @@ def make_local_path(file_uri: str) -> str: return str(pathlib.Path(local_path)) @staticmethod - def make_file_uri(local_path: str) -> str: - """Creates a normalized file:// uri from a local path + def make_file_url(local_path: str) -> str: + """Creates a normalized file:// url from a local path netloc is never set. 
UNC paths are represented as file://host/path """ - p_ = pathlib.Path(local_path) - p_ = p_.expanduser().resolve() - return p_.as_uri() + return make_fsspec_url("file", local_path, None) diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index 29a9da8acf..0f70c04bc5 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -1,14 +1,13 @@ -from pathlib import Path -from typing import Dict, Any, List, Sequence +from typing import Dict, Any, List from abc import ABC, abstractmethod from dlt.common import logger +from dlt.common.metrics import DataWriterMetrics from dlt.common.schema import TTableSchemaColumns -from dlt.common.typing import StrAny, TDataItems +from dlt.common.typing import TDataItems from dlt.common.data_writers import ( BufferedDataWriter, DataWriter, - DataWriterMetrics, FileWriterSpec, ) diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index 7d14b8f7f7..f26cc060a3 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -3,7 +3,6 @@ import re import stat import errno -import tempfile import shutil import pathvalidate from typing import IO, Any, Optional, List, cast @@ -29,10 +28,8 @@ def save(self, relative_path: str, data: Any) -> str: @staticmethod def save_atomic(storage_path: str, relative_path: str, data: Any, file_type: str = "t") -> str: mode = "w" + file_type - with tempfile.NamedTemporaryFile( - dir=storage_path, mode=mode, delete=False, encoding=encoding_for_mode(mode) - ) as f: - tmp_path = f.name + tmp_path = os.path.join(storage_path, uniq_id(8)) + with open(tmp_path, mode=mode, encoding=encoding_for_mode(mode)) as f: f.write(data) try: dest_path = os.path.join(storage_path, relative_path) @@ -116,11 +113,11 @@ def open_file(self, relative_path: str, mode: str = "r") -> IO[Any]: return FileStorage.open_zipsafe_ro(self.make_full_path(relative_path), mode) return open(self.make_full_path(relative_path), mode, encoding=encoding_for_mode(mode)) - def open_temp(self, delete: bool = False, mode: str = "w", file_type: str = None) -> IO[Any]: - mode = mode + file_type or self.file_type - return tempfile.NamedTemporaryFile( - dir=self.storage_path, mode=mode, delete=delete, encoding=encoding_for_mode(mode) - ) + # def open_temp(self, delete: bool = False, mode: str = "w", file_type: str = None) -> IO[Any]: + # mode = mode + file_type or self.file_type + # return tempfile.NamedTemporaryFile( + # dir=self.storage_path, mode=mode, delete=delete, encoding=encoding_for_mode(mode) + # ) def has_file(self, relative_path: str) -> bool: return os.path.isfile(self.make_full_path(relative_path)) diff --git a/dlt/common/storages/fsspec_filesystem.py b/dlt/common/storages/fsspec_filesystem.py index be9ae2bbb1..7da5ebabef 100644 --- a/dlt/common/storages/fsspec_filesystem.py +++ b/dlt/common/storages/fsspec_filesystem.py @@ -21,7 +21,7 @@ ) from urllib.parse import urlparse -from fsspec import AbstractFileSystem, register_implementation +from fsspec import AbstractFileSystem, register_implementation, get_filesystem_class from fsspec.core import url_to_fs from dlt import version @@ -32,7 +32,11 @@ AzureCredentials, ) from dlt.common.exceptions import MissingDependencyException -from dlt.common.storages.configuration import FileSystemCredentials, FilesystemConfiguration +from dlt.common.storages.configuration import ( + FileSystemCredentials, + FilesystemConfiguration, + make_fsspec_url, +) from dlt.common.time import 
ensure_pendulum_datetime from dlt.common.typing import DictStrAny @@ -65,18 +69,20 @@ class FileItem(TypedDict, total=False): MTIME_DISPATCH["gs"] = MTIME_DISPATCH["gcs"] MTIME_DISPATCH["s3a"] = MTIME_DISPATCH["s3"] MTIME_DISPATCH["abfs"] = MTIME_DISPATCH["az"] +MTIME_DISPATCH["abfss"] = MTIME_DISPATCH["az"] # Map of protocol to a filesystem type CREDENTIALS_DISPATCH: Dict[str, Callable[[FilesystemConfiguration], DictStrAny]] = { "s3": lambda config: cast(AwsCredentials, config.credentials).to_s3fs_credentials(), - "adl": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), "az": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), - "gcs": lambda config: cast(GcpCredentials, config.credentials).to_gcs_credentials(), "gs": lambda config: cast(GcpCredentials, config.credentials).to_gcs_credentials(), "gdrive": lambda config: {"credentials": cast(GcpCredentials, config.credentials)}, - "abfs": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), - "azure": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), } +CREDENTIALS_DISPATCH["adl"] = CREDENTIALS_DISPATCH["az"] +CREDENTIALS_DISPATCH["abfs"] = CREDENTIALS_DISPATCH["az"] +CREDENTIALS_DISPATCH["azure"] = CREDENTIALS_DISPATCH["az"] +CREDENTIALS_DISPATCH["abfss"] = CREDENTIALS_DISPATCH["az"] +CREDENTIALS_DISPATCH["gcs"] = CREDENTIALS_DISPATCH["gs"] def fsspec_filesystem( @@ -90,7 +96,7 @@ def fsspec_filesystem( Please supply credentials instance corresponding to the protocol. The `protocol` is just the code name of the filesystem i.e.: * s3 - * az, abfs + * az, abfs, abfss, adl, azure * gcs, gs also see filesystem_from_config @@ -136,7 +142,7 @@ def fsspec_from_config(config: FilesystemConfiguration) -> Tuple[AbstractFileSys Authenticates following filesystems: * s3 - * az, abfs + * az, abfs, abfss, adl, azure * gcs, gs All other filesystems are not authenticated @@ -146,8 +152,14 @@ def fsspec_from_config(config: FilesystemConfiguration) -> Tuple[AbstractFileSys fs_kwargs = prepare_fsspec_args(config) try: + # first get the class to check the protocol + fs_cls = get_filesystem_class(config.protocol) + if fs_cls.protocol == "abfs": + # if storage account is present in bucket_url and in credentials, az fsspec will fail + if urlparse(config.bucket_url).username: + fs_kwargs.pop("account_name") return url_to_fs(config.bucket_url, **fs_kwargs) # type: ignore - except ModuleNotFoundError as e: + except ImportError as e: raise MissingDependencyException( "filesystem", [f"{version.DLT_PKG_NAME}[{config.protocol}]"] ) from e @@ -291,10 +303,8 @@ def glob_files( """ is_local_fs = "file" in fs_client.protocol if is_local_fs and FilesystemConfiguration.is_local_path(bucket_url): - bucket_url = FilesystemConfiguration.make_file_uri(bucket_url) - bucket_url_parsed = urlparse(bucket_url) - else: - bucket_url_parsed = urlparse(bucket_url) + bucket_url = FilesystemConfiguration.make_file_url(bucket_url) + bucket_url_parsed = urlparse(bucket_url) if is_local_fs: root_dir = FilesystemConfiguration.make_local_path(bucket_url) @@ -302,7 +312,8 @@ def glob_files( files = glob.glob(str(pathlib.Path(root_dir).joinpath(file_glob)), recursive=True) glob_result = {file: fs_client.info(file) for file in files} else: - root_dir = bucket_url_parsed._replace(scheme="", query="").geturl().lstrip("/") + # convert to fs_path + root_dir = fs_client._strip_protocol(bucket_url) filter_url = posixpath.join(root_dir, file_glob) glob_result = 
fs_client.glob(filter_url, detail=True) if isinstance(glob_result, list): @@ -314,20 +325,23 @@ def glob_files( for file, md in glob_result.items(): if md["type"] != "file": continue + scheme = bucket_url_parsed.scheme + # relative paths are always POSIX if is_local_fs: - rel_path = pathlib.Path(file).relative_to(root_dir).as_posix() - file_url = FilesystemConfiguration.make_file_uri(file) + # use OS pathlib for local paths + loc_path = pathlib.Path(file) + file_name = loc_path.name + rel_path = loc_path.relative_to(root_dir).as_posix() + file_url = FilesystemConfiguration.make_file_url(file) else: - rel_path = posixpath.relpath(file.lstrip("/"), root_dir) - file_url = bucket_url_parsed._replace( - path=posixpath.join(bucket_url_parsed.path, rel_path) - ).geturl() + file_name = posixpath.basename(file) + rel_path = posixpath.relpath(file, root_dir) + file_url = make_fsspec_url(scheme, file, bucket_url) - scheme = bucket_url_parsed.scheme mime_type, encoding = guess_mime_type(rel_path) yield FileItem( - file_name=posixpath.basename(rel_path), + file_name=file_name, relative_path=rel_path, file_url=file_url, mime_type=mime_type, diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index b0ed93f734..d569fbe662 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -143,8 +143,8 @@ def create_load_id() -> str: # folders to manage load jobs in a single load package -TJobState = Literal["new_jobs", "failed_jobs", "started_jobs", "completed_jobs"] -WORKING_FOLDERS: Set[TJobState] = set(get_args(TJobState)) +TPackageJobState = Literal["new_jobs", "failed_jobs", "started_jobs", "completed_jobs"] +WORKING_FOLDERS: Set[TPackageJobState] = set(get_args(TPackageJobState)) TLoadPackageStatus = Literal["new", "extracted", "normalized", "loaded", "aborted"] @@ -191,7 +191,7 @@ def __str__(self) -> str: class LoadJobInfo(NamedTuple): - state: TJobState + state: TPackageJobState file_path: str file_size: int created_at: datetime.datetime @@ -204,6 +204,7 @@ def asdict(self) -> DictStrAny: # flatten del d["job_file_info"] d.update(self.job_file_info._asdict()) + d["job_id"] = self.job_file_info.job_id() return d def asstr(self, verbosity: int = 0) -> str: @@ -241,7 +242,7 @@ class _LoadPackageInfo(NamedTuple): schema: Schema schema_update: TSchemaTables completed_at: datetime.datetime - jobs: Dict[TJobState, List[LoadJobInfo]] + jobs: Dict[TPackageJobState, List[LoadJobInfo]] class LoadPackageInfo(SupportsHumanize, _LoadPackageInfo): @@ -298,10 +299,10 @@ def __str__(self) -> str: class PackageStorage: - NEW_JOBS_FOLDER: ClassVar[TJobState] = "new_jobs" - FAILED_JOBS_FOLDER: ClassVar[TJobState] = "failed_jobs" - STARTED_JOBS_FOLDER: ClassVar[TJobState] = "started_jobs" - COMPLETED_JOBS_FOLDER: ClassVar[TJobState] = "completed_jobs" + NEW_JOBS_FOLDER: ClassVar[TPackageJobState] = "new_jobs" + FAILED_JOBS_FOLDER: ClassVar[TPackageJobState] = "failed_jobs" + STARTED_JOBS_FOLDER: ClassVar[TPackageJobState] = "started_jobs" + COMPLETED_JOBS_FOLDER: ClassVar[TPackageJobState] = "completed_jobs" SCHEMA_FILE_NAME: ClassVar[str] = "schema.json" SCHEMA_UPDATES_FILE_NAME = ( # updates to the tables in schema created by normalizer @@ -330,11 +331,11 @@ def get_package_path(self, load_id: str) -> str: """Gets path of the package relative to storage root""" return load_id - def get_job_state_folder_path(self, load_id: str, state: TJobState) -> str: + def get_job_state_folder_path(self, load_id: str, state: TPackageJobState) -> str: """Gets path to 
the jobs in `state` in package `load_id`, relative to the storage root""" return os.path.join(self.get_package_path(load_id), state) - def get_job_file_path(self, load_id: str, state: TJobState, file_name: str) -> str: + def get_job_file_path(self, load_id: str, state: TPackageJobState, file_name: str) -> str: """Get path to job with `file_name` in `state` in package `load_id`, relative to the storage root""" return os.path.join(self.get_job_state_folder_path(load_id, state), file_name) @@ -369,12 +370,12 @@ def list_failed_jobs(self, load_id: str) -> Sequence[str]: def list_job_with_states_for_table( self, load_id: str, table_name: str - ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + ) -> Sequence[Tuple[TPackageJobState, ParsedLoadJobFileName]]: return self.filter_jobs_for_table(self.list_all_jobs_with_states(load_id), table_name) def list_all_jobs_with_states( self, load_id: str - ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + ) -> Sequence[Tuple[TPackageJobState, ParsedLoadJobFileName]]: info = self.get_load_package_jobs(load_id) state_jobs = [] for state, jobs in info.items(): @@ -413,7 +414,7 @@ def is_package_completed(self, load_id: str) -> bool: # def import_job( - self, load_id: str, job_file_path: str, job_state: TJobState = "new_jobs" + self, load_id: str, job_file_path: str, job_state: TPackageJobState = "new_jobs" ) -> None: """Adds new job by moving the `job_file_path` into `new_jobs` of package `load_id`""" self.storage.atomic_import( @@ -568,12 +569,14 @@ def get_load_package_state_path(self, load_id: str) -> str: # Get package info # - def get_load_package_jobs(self, load_id: str) -> Dict[TJobState, List[ParsedLoadJobFileName]]: + def get_load_package_jobs( + self, load_id: str + ) -> Dict[TPackageJobState, List[ParsedLoadJobFileName]]: """Gets all jobs in a package and returns them as lists assigned to a particular state.""" package_path = self.get_package_path(load_id) if not self.storage.has_folder(package_path): raise LoadPackageNotFound(load_id) - all_jobs: Dict[TJobState, List[ParsedLoadJobFileName]] = {} + all_jobs: Dict[TPackageJobState, List[ParsedLoadJobFileName]] = {} for state in WORKING_FOLDERS: jobs: List[ParsedLoadJobFileName] = [] with contextlib.suppress(FileNotFoundError): @@ -616,7 +619,7 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: schema = Schema.from_dict(self._load_schema(load_id)) # read jobs with all statuses - all_job_infos: Dict[TJobState, List[LoadJobInfo]] = {} + all_job_infos: Dict[TPackageJobState, List[LoadJobInfo]] = {} for state, jobs in package_jobs.items(): all_job_infos[state] = [ self._read_job_file_info(load_id, state, job, package_created_at) for job in jobs @@ -643,7 +646,7 @@ def get_job_failed_message(self, load_id: str, job: ParsedLoadJobFileName) -> st return failed_message def job_to_job_info( - self, load_id: str, state: TJobState, job: ParsedLoadJobFileName + self, load_id: str, state: TPackageJobState, job: ParsedLoadJobFileName ) -> LoadJobInfo: """Creates partial job info by converting job object. 
size, mtime and failed message will not be populated""" full_path = os.path.join( @@ -660,7 +663,11 @@ def job_to_job_info( ) def _read_job_file_info( - self, load_id: str, state: TJobState, job: ParsedLoadJobFileName, now: DateTime = None + self, + load_id: str, + state: TPackageJobState, + job: ParsedLoadJobFileName, + now: DateTime = None, ) -> LoadJobInfo: """Creates job info by reading additional props from storage""" failed_message = None @@ -687,8 +694,8 @@ def _read_job_file_info( def _move_job( self, load_id: str, - source_folder: TJobState, - dest_folder: TJobState, + source_folder: TPackageJobState, + dest_folder: TPackageJobState, file_name: str, new_file_name: str = None, ) -> str: @@ -736,8 +743,8 @@ def _job_elapsed_time_seconds(file_path: str, now_ts: float = None) -> float: @staticmethod def filter_jobs_for_table( - all_jobs: Iterable[Tuple[TJobState, ParsedLoadJobFileName]], table_name: str - ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + all_jobs: Iterable[Tuple[TPackageJobState, ParsedLoadJobFileName]], table_name: str + ) -> Sequence[Tuple[TPackageJobState, ParsedLoadJobFileName]]: return [job for job in all_jobs if job[1].table_name == table_name] diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 00e95fbad9..8ac1d74e9a 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -17,7 +17,7 @@ LoadPackageInfo, PackageStorage, ParsedLoadJobFileName, - TJobState, + TPackageJobState, TLoadPackageState, TJobFileFormat, ) @@ -141,16 +141,16 @@ def commit_schema_update(self, load_id: str, applied_update: TSchemaTables) -> N """Marks schema update as processed and stores the update that was applied at the destination""" load_path = self.get_normalized_package_path(load_id) schema_update_file = join(load_path, PackageStorage.SCHEMA_UPDATES_FILE_NAME) - processed_schema_update_file = join( + applied_schema_update_file = join( load_path, PackageStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME ) # delete initial schema update self.storage.delete(schema_update_file) # save applied update - self.storage.save(processed_schema_update_file, json.dumps(applied_update)) + self.storage.save(applied_schema_update_file, json.dumps(applied_update)) def import_new_job( - self, load_id: str, job_file_path: str, job_state: TJobState = "new_jobs" + self, load_id: str, job_file_path: str, job_state: TPackageJobState = "new_jobs" ) -> None: """Adds new job by moving the `job_file_path` into `new_jobs` of package `load_id`""" # TODO: use normalize storage and add file type checks diff --git a/dlt/destinations/fs_client.py b/dlt/destinations/fs_client.py index 3233446594..14e77b6b4e 100644 --- a/dlt/destinations/fs_client.py +++ b/dlt/destinations/fs_client.py @@ -3,9 +3,12 @@ from abc import ABC, abstractmethod from fsspec import AbstractFileSystem +from dlt.common.schema import Schema + class FSClientBase(ABC): fs_client: AbstractFileSystem + schema: Schema @property @abstractmethod diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 371c1bae22..c4a9bab212 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -34,7 +34,6 @@ from dlt.common import logger from dlt.common.exceptions import TerminalValueError -from dlt.common.storages.fsspec_filesystem import fsspec_from_config from dlt.common.utils import uniq_id, without_none from dlt.common.schema import TColumnSchema, Schema, TTableSchema from dlt.common.schema.typing import ( 
@@ -46,7 +45,7 @@ from dlt.common.schema.utils import table_schema_has_type from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import LoadJob -from dlt.common.destination.reference import FollowupJob, SupportsStagingDestination +from dlt.common.destination.reference import FollowupJobRequest, SupportsStagingDestination from dlt.common.data_writers.escape import escape_hive_identifier from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlMergeFollowupJob @@ -105,9 +104,9 @@ class AthenaTypeMapper(TypeMapper): def __init__(self, capabilities: DestinationCapabilitiesContext): super().__init__(capabilities) - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") + table_format = table.get("table_format") if precision is None: return "bigint" if precision <= 8: @@ -404,9 +403,9 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(hive_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: return ( - f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" + f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table)}" ) def _iceberg_partition_clause(self, partition_hints: Optional[Dict[str, str]]) -> str: @@ -430,9 +429,9 @@ def _get_table_update_sql( # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries # or if we are in iceberg mode, we create iceberg tables for all tables table = self.prepare_load_table(table_name, self.in_staging_mode) - table_format = table.get("table_format") + is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" - columns = ", ".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + columns = ", ".join([self._get_column_def_sql(c, table) for c in new_columns]) # create unique tag for iceberg table so it is never recreated in the same folder # athena requires some kind of special cleaning (or that is a bug) so we cannot refresh @@ -452,7 +451,7 @@ def _get_table_update_sql( partition_clause = self._iceberg_partition_clause( cast(Optional[Dict[str, str]], table.get(PARTITION_HINT)) ) - sql.append(f"""CREATE TABLE {qualified_table_name} + sql.append(f"""{self._make_create_table(qualified_table_name, table)} ({columns}) {partition_clause} LOCATION '{location.rstrip('/')}' @@ -490,7 +489,7 @@ def create_load_job( def _create_append_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ SqlStagingCopyFollowupJob.from_table_chain( @@ -501,7 +500,7 @@ def _create_append_followup_jobs( def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ SqlStagingCopyFollowupJob.from_table_chain( @@ -510,7 +509,9 @@ def _create_replace_followup_jobs( ] return super()._create_replace_followup_jobs(table_chain) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> 
List[FollowupJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJobRequest]: return [AthenaMergeJob.from_table_chain(table_chain, self.sql_client)] def _is_iceberg_table(self, table: TTableSchema) -> bool: diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index c6bf2e7654..9bc555bd0d 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -16,7 +16,7 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( HasFollowupJobs, - FollowupJob, + FollowupJobRequest, TLoadJobState, RunnableLoadJob, SupportsStagingDestination, @@ -51,7 +51,7 @@ from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.utils import parse_db_data_type_str_with_precision @@ -90,9 +90,9 @@ class BigQueryTypeMapper(TypeMapper): "TIME": "time", } - def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> str: + def to_db_decimal_type(self, column: TColumnSchema) -> str: # Use BigQuery's BIGNUMERIC for large precision decimals - precision, scale = self.decimal_precision(precision, scale) + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) if precision > 38 or scale > 9: return "BIGNUMERIC(%i,%i)" % (precision, scale) return "NUMERIC(%i,%i)" % (precision, scale) @@ -234,7 +234,9 @@ def __init__( self.sql_client: BigQuerySqlClient = sql_client # type: ignore self.type_mapper = BigQueryTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJobRequest]: return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)] def create_load_job( @@ -415,10 +417,10 @@ def _get_info_schema_columns_query( return query, folded_table_names - def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, column: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(column["name"]) column_def_sql = ( - f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(column, table)} {self._gen_not_null(column.get('nullable', True))}" ) if column.get(ROUND_HALF_EVEN_HINT, False): column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')" @@ -430,11 +432,11 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load # append to table for merge loads (append to stage) and regular appends. 
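# [Illustrative sketch, not part of the patch] The type mapper hooks switched from bare
# precision/table_format arguments to the full column and table schemas, so overrides now
# read hints straight from the schema dicts. A hypothetical mapper using the new signature
# (class name and returned SQL types are assumptions, not dlt code):

from dlt.common.schema import TColumnSchema
from dlt.common.schema.typing import TTableSchema
from dlt.destinations.type_mapping import TypeMapper


class ExampleIntegerMapper(TypeMapper):
    def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
        # precision now comes from the column schema instead of a separate argument
        precision = column.get("precision")
        if precision is None or precision > 32:
            return "BIGINT"
        return "INTEGER"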
table_name = table["name"] - # determine whether we load from local or uri + # determine whether we load from local or url bucket_path = None ext: str = os.path.splitext(file_path)[1][1:] - if ReferenceFollowupJob.is_reference_job(file_path): - bucket_path = ReferenceFollowupJob.resolve_reference(file_path) + if ReferenceFollowupJobRequest.is_reference_job(file_path): + bucket_path = ReferenceFollowupJobRequest.resolve_reference(file_path) ext = os.path.splitext(bucket_path)[1][1:] # Select a correct source format @@ -501,6 +503,9 @@ def _should_autodetect_schema(self, table_name: str) -> bool: self.schema._schema_tables, table_name, AUTODETECT_SCHEMA_HINT, allow_none=True ) or (self.config.autodetect_schema and table_name not in self.schema.dlt_table_names()) + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load + def _streaming_load( items: List[Dict[Any, Any]], table: Dict[str, Any], job_client: BigQueryClient diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 5bd34e0e0d..038735a84b 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -20,7 +20,7 @@ TLoadJobState, HasFollowupJobs, RunnableLoadJob, - FollowupJob, + FollowupJobRequest, LoadJob, ) from dlt.common.schema import Schema, TColumnSchema @@ -52,7 +52,7 @@ SqlJobClientBase, SqlJobClientWithStaging, ) -from dlt.destinations.job_impl import ReferenceFollowupJob, FinalizedLoadJobWithFollowupJobs +from dlt.destinations.job_impl import ReferenceFollowupJobRequest, FinalizedLoadJobWithFollowupJobs from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper @@ -141,8 +141,8 @@ def run(self) -> None: bucket_path = None file_name = self._file_name - if ReferenceFollowupJob.is_reference_job(self._file_path): - bucket_path = ReferenceFollowupJob.resolve_reference(self._file_path) + if ReferenceFollowupJobRequest.is_reference_job(self._file_path): + bucket_path = ReferenceFollowupJobRequest.resolve_reference(self._file_path) file_name = FileStorage.get_file_name_from_file_path(bucket_path) bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme @@ -288,10 +288,12 @@ def __init__( self.active_hints = deepcopy(HINT_TO_CLICKHOUSE_ATTR) self.type_mapper = ClickHouseTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJobRequest]: return [ClickHouseMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: # Build column definition. # The primary key and sort order definition is defined outside column specification. hints_ = " ".join( @@ -305,9 +307,9 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non # Alter table statements only accept `Nullable` modifiers. # JSON type isn't nullable in ClickHouse. 
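# [Illustrative sketch, not part of the patch] Staging-aware clients now gate staging-table
# truncation on config.truncate_tables_on_staging_destination_before_load through
# should_truncate_table_before_load_on_staging_destination(). Assuming standard dlt config
# resolution, truncation could presumably be disabled per destination, e.g.:

import dlt

pipeline = dlt.pipeline(
    pipeline_name="bigquery_staging_demo",  # hypothetical pipeline name
    destination=dlt.destinations.bigquery(
        # extra factory kwargs are forwarded to the destination configuration
        truncate_tables_on_staging_destination_before_load=False,
    ),
    staging="filesystem",
)
# or the equivalent entry in config.toml (assumed layout):
# [destination.bigquery]
# truncate_tables_on_staging_destination_before_load = false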
type_with_nullability_modifier = ( - f"Nullable({self.type_mapper.to_db_type(c)})" + f"Nullable({self.type_mapper.to_db_type(c,table)})" if c.get("nullable", True) - else self.type_mapper.to_db_type(c) + else self.type_mapper.to_db_type(c, table) ) return ( @@ -370,3 +372,6 @@ def _from_db_type( self, ch_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: return self.type_mapper.from_db_type(ch_t, precision, scale) + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 3bd2d12a5a..789dbedae9 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -43,6 +43,10 @@ def to_connector_params(self) -> Dict[str, Any]: class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): destination_type: Final[str] = dataclasses.field(default="databricks", init=False, repr=False, compare=False) # type: ignore[misc] credentials: DatabricksCredentials = None + staging_credentials_name: Optional[str] = None + "If set, credentials with given name will be used in copy command" + is_staging_external_location: bool = False + """If true, the temporary credentials are not propagated to the COPY command""" def __str__(self) -> str: """Return displayable destination location""" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 0a203c21b6..0c19984b4c 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,25 +1,22 @@ -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any, Iterable, Type, cast +from typing import Optional, Sequence, List, cast from urllib.parse import urlparse, urlunparse from dlt import config from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( HasFollowupJobs, - FollowupJob, - TLoadJobState, + FollowupJobRequest, RunnableLoadJob, - CredentialsConfiguration, SupportsStagingDestination, LoadJob, ) from dlt.common.configuration.specs import ( AwsCredentialsWithoutDefaults, - AzureCredentials, AzureCredentialsWithoutDefaults, ) from dlt.common.exceptions import TerminalValueError from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns +from dlt.common.schema import TColumnSchema, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat from dlt.common.schema.utils import table_schema_has_type from dlt.common.storages import FilesystemConfiguration, fsspec_from_config @@ -31,10 +28,13 @@ from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient from dlt.destinations.sql_jobs import SqlMergeFollowupJob -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper +AZURE_BLOB_STORAGE_PROTOCOLS = ["az", "abfss", "abfs"] + + class DatabricksTypeMapper(TypeMapper): sct_to_unbound_dbt = { "complex": "STRING", # Databricks supports complex types like ARRAY @@ -68,9 +68,8 @@ class DatabricksTypeMapper(TypeMapper): "wei": 
"DECIMAL(%i,%i)", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "BIGINT" if precision <= 8: @@ -121,8 +120,8 @@ def run(self) -> None: staging_credentials = self._staging_config.credentials # extract and prepare some vars bucket_path = orig_bucket_path = ( - ReferenceFollowupJob.resolve_reference(self._file_path) - if ReferenceFollowupJob.is_reference_job(self._file_path) + ReferenceFollowupJobRequest.resolve_reference(self._file_path) + if ReferenceFollowupJobRequest.is_reference_job(self._file_path) else "" ) file_name = ( @@ -137,41 +136,51 @@ def run(self) -> None: if bucket_path: bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme - # referencing an staged files via a bucket URL requires explicit AWS credentials - if bucket_scheme == "s3" and isinstance( - staging_credentials, AwsCredentialsWithoutDefaults - ): - s3_creds = staging_credentials.to_session_credentials() - credentials_clause = f"""WITH(CREDENTIAL( - AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}', - AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}', - - AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}' - )) - """ - from_clause = f"FROM '{bucket_path}'" - elif bucket_scheme in ["az", "abfs"] and isinstance( - staging_credentials, AzureCredentialsWithoutDefaults - ): - # Explicit azure credentials are needed to load from bucket without a named stage - credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" - # Converts an az:/// to abfss://@.dfs.core.windows.net/ - # as required by snowflake - _path = bucket_url.path - bucket_path = urlunparse( - bucket_url._replace( - scheme="abfss", - netloc=f"{bucket_url.netloc}@{staging_credentials.azure_storage_account_name}.dfs.core.windows.net", - path=_path, - ) - ) - from_clause = f"FROM '{bucket_path}'" - else: + + if bucket_scheme not in AZURE_BLOB_STORAGE_PROTOCOLS + ["s3"]: raise LoadJobTerminalException( self._file_path, f"Databricks cannot load data from staging bucket {bucket_path}. 
Only s3 and" " azure buckets are supported", ) + + if self._job_client.config.is_staging_external_location: + # just skip the credentials clause for external location + # https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location + pass + elif self._job_client.config.staging_credentials_name: + # add named credentials + credentials_clause = ( + f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )" + ) + else: + # referencing an staged files via a bucket URL requires explicit AWS credentials + if bucket_scheme == "s3": + assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults) + s3_creds = staging_credentials.to_session_credentials() + credentials_clause = f"""WITH(CREDENTIAL( + AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}', + AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}', + + AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}' + )) + """ + elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: + assert isinstance(staging_credentials, AzureCredentialsWithoutDefaults) + # Explicit azure credentials are needed to load from bucket without a named stage + credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" + bucket_path = self.ensure_databricks_abfss_url( + bucket_path, staging_credentials.azure_storage_account_name + ) + + if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: + assert isinstance(staging_credentials, AzureCredentialsWithoutDefaults) + bucket_path = self.ensure_databricks_abfss_url( + bucket_path, staging_credentials.azure_storage_account_name + ) + + # always add FROM clause + from_clause = f"FROM '{bucket_path}'" else: raise LoadJobTerminalException( self._file_path, @@ -231,6 +240,34 @@ def run(self) -> None: """ self._sql_client.execute_sql(statement) + @staticmethod + def ensure_databricks_abfss_url( + bucket_path: str, azure_storage_account_name: str = None + ) -> str: + bucket_url = urlparse(bucket_path) + # Converts an az:/// to abfss://@.dfs.core.windows.net/ + if bucket_url.username: + # has the right form, ensure abfss schema + return urlunparse(bucket_url._replace(scheme="abfss")) + + if not azure_storage_account_name: + raise TerminalValueError( + f"Could not convert azure blob storage url {bucket_path} into form required by" + " Databricks" + " (abfss://@.dfs.core.windows.net/)" + " because storage account name is not known. 
Please use Databricks abfss://" + " canonical url as bucket_url in staging credentials" + ) + # as required by databricks + _path = bucket_url.path + return urlunparse( + bucket_url._replace( + scheme="abfss", + netloc=f"{bucket_url.netloc}@{azure_storage_account_name}.dfs.core.windows.net", + path=_path, + ) + ) + class DatabricksMergeJob(SqlMergeFollowupJob): @classmethod @@ -279,14 +316,18 @@ def create_load_job( ) return job - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJobRequest]: return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because databricks requires multiple columns in a single ADD COLUMN clause - return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] + return [ + "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) + ] def _get_table_update_sql( self, @@ -311,10 +352,10 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def _get_storage_table_query_columns(self) -> List[str]: @@ -323,3 +364,6 @@ def _get_storage_table_query_columns(self) -> List[str]: "full_data_type" ) return fields + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index 409d3bc4be..6108b69da9 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -54,6 +54,8 @@ def client_class(self) -> t.Type["DatabricksClient"]: def __init__( self, credentials: t.Union[DatabricksCredentials, t.Dict[str, t.Any], str] = None, + is_staging_external_location: t.Optional[bool] = False, + staging_credentials_name: t.Optional[str] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -65,10 +67,14 @@ def __init__( Args: credentials: Credentials to connect to the databricks database. 
Can be an instance of `DatabricksCredentials` or a connection string in the format `databricks://user:password@host:port/database` + is_staging_external_location: If true, the temporary credentials are not propagated to the COPY command + staging_credentials_name: If set, credentials with given name will be used in copy command **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, + is_staging_external_location=is_staging_external_location, + staging_credentials_name=staging_credentials_name, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index 3611665f6c..91dc64f113 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -7,7 +7,7 @@ TLoadJobState, RunnableLoadJob, SupportsStagingDestination, - FollowupJob, + FollowupJobRequest, LoadJob, ) from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns @@ -19,7 +19,7 @@ from dlt.destinations.impl.dremio.sql_client import DremioSqlClient from dlt.destinations.job_client_impl import SqlJobClientWithStaging from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.sql_client import SqlClientBase @@ -101,8 +101,8 @@ def run(self) -> None: # extract and prepare some vars bucket_path = ( - ReferenceFollowupJob.resolve_reference(self._file_path) - if ReferenceFollowupJob.is_reference_job(self._file_path) + ReferenceFollowupJobRequest.resolve_reference(self._file_path) + if ReferenceFollowupJobRequest.is_reference_job(self._file_path) else "" ) @@ -195,16 +195,25 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJobRequest]: return [DremioMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: - return ["ADD COLUMNS (" + ", ".join(self._get_column_def_sql(c) for c in new_columns) + ")"] + return [ + "ADD COLUMNS (" + + ", ".join(self._get_column_def_sql(c, table) for c in new_columns) + + ")" + ] + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 3d5905ff40..d5065f5bdd 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -62,9 +62,8 @@ class 
DuckDbTypeMapper(TypeMapper): "TIMESTAMP_NS": "timestamp", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "BIGINT" # Precision is number of bits @@ -83,19 +82,39 @@ def to_db_integer_type( ) def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: TTableSchema = None, ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone and precision is not None: + raise TerminalValueError( + f"DuckDB does not support both timezone and precision for column '{column_name}' in" + f" table '{table_name}'. To resolve this issue, either set timezone to False or" + " None, or use the default precision." + ) + + if timezone: + return "TIMESTAMP WITH TIME ZONE" + elif timezone is not None: # condition for when timezone is False given that none is falsy + return "TIMESTAMP" + if precision is None or precision == 6: - return super().to_db_datetime_type(precision, table_format) - if precision == 0: + return None + elif precision == 0: return "TIMESTAMP_S" - if precision == 3: + elif precision == 3: return "TIMESTAMP_MS" - if precision == 9: + elif precision == 9: return "TIMESTAMP_NS" + raise TerminalValueError( - f"timestamp with {precision} decimals after seconds cannot be mapped into duckdb" - " TIMESTAMP type" + f"DuckDB does not support precision '{precision}' for '{column_name}' in table" + f" '{table_name}'" ) def from_db_type( @@ -162,7 +181,7 @@ def create_load_job( job = DuckDbCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() @@ -170,7 +189,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def _from_db_type( diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index 7bc1d9e943..a066479294 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -25,7 +25,7 @@ class DummyClientConfiguration(DestinationClientConfiguration): retry_prob: float = 0.0 """probability of job retry""" completed_prob: float = 0.0 - """probablibitly of successful job completion""" + """probability of successful job completion""" exception_prob: float = 0.0 """probability of exception transient exception when running job""" timeout: float = 10.0 @@ -34,6 +34,8 @@ class DummyClientConfiguration(DestinationClientConfiguration): """raise terminal exception in job init""" fail_transiently_in_init: bool = False """raise transient exception in job init""" + truncate_tables_on_staging_destination_before_load: bool = True + """truncate tables on staging destination""" # new jobs workflows create_followup_jobs: bool = False diff --git 
a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index 7d406c969f..fc87faaf5a 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -14,6 +14,7 @@ ) import os import time +from dlt.common.metrics import LoadJobMetrics from dlt.common.pendulum import pendulum from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.storages import FileStorage @@ -25,7 +26,7 @@ ) from dlt.common.destination.reference import ( HasFollowupJobs, - FollowupJob, + FollowupJobRequest, SupportsStagingDestination, TLoadJobState, RunnableLoadJob, @@ -37,10 +38,9 @@ from dlt.destinations.exceptions import ( LoadJobNotExistsException, - LoadJobInvalidStateTransitionException, ) from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest class LoadDummyBaseJob(RunnableLoadJob): @@ -78,18 +78,25 @@ def run(self) -> None: c_r = random.random() if self.config.retry_prob >= c_r: # this will make the job go to a retry state - raise DestinationTransientException("a random retry occured") + raise DestinationTransientException("a random retry occurred") # fail prob c_r = random.random() if self.config.fail_prob >= c_r: # this will make the the job go to a failed state - raise DestinationTerminalException("a random fail occured") + raise DestinationTerminalException("a random fail occurred") time.sleep(0.1) + def metrics(self) -> Optional[LoadJobMetrics]: + m = super().metrics() + # add remote url if there's followup job + if self.config.create_followup_jobs: + m = m._replace(remote_url=self._file_name) + return m -class DummyFollowupJob(ReferenceFollowupJob): + +class DummyFollowupJobRequest(ReferenceFollowupJobRequest): def __init__( self, original_file_name: str, remote_paths: List[str], config: DummyClientConfiguration ) -> None: @@ -100,9 +107,9 @@ def __init__( class LoadDummyJob(LoadDummyBaseJob, HasFollowupJobs): - def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRequest]: if self.config.create_followup_jobs and final_state == "completed": - new_job = DummyFollowupJob( + new_job = DummyFollowupJobRequest( original_file_name=self.file_name(), remote_paths=[self._file_name], config=self.config, @@ -113,8 +120,8 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: JOBS: Dict[str, LoadDummyBaseJob] = {} -CREATED_FOLLOWUP_JOBS: Dict[str, FollowupJob] = {} -CREATED_TABLE_CHAIN_FOLLOWUP_JOBS: Dict[str, FollowupJob] = {} +CREATED_FOLLOWUP_JOBS: Dict[str, FollowupJobRequest] = {} +CREATED_TABLE_CHAIN_FOLLOWUP_JOBS: Dict[str, FollowupJobRequest] = {} RETRIED_JOBS: Dict[str, LoadDummyBaseJob] = {} @@ -173,7 +180,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: """Creates a list of followup jobs that should be executed after a table chain is completed""" # if sql job follow up is configure we schedule a merge job that will always fail @@ -184,7 +191,7 @@ def create_table_chain_completed_followup_jobs( if self.config.create_followup_table_chain_reference_jobs: table_job_paths = [job.file_path for job in completed_table_chain_jobs] file_name = 
FileStorage.get_file_name_from_file_path(table_job_paths[0]) - job = ReferenceFollowupJob(file_name, table_job_paths) + job = ReferenceFollowupJobRequest(file_name, table_job_paths) CREATED_TABLE_CHAIN_FOLLOWUP_JOBS[job.job_id()] = job return [job] return [] @@ -195,6 +202,9 @@ def complete_load(self, load_id: str) -> None: def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: return super().should_load_data_to_staging_dataset(table) + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load + @contextmanager def with_staging_dataset(self) -> Iterator[JobClientBase]: try: @@ -212,7 +222,7 @@ def __exit__( pass def _create_job(self, job_id: str) -> LoadDummyBaseJob: - if ReferenceFollowupJob.is_reference_job(job_id): + if ReferenceFollowupJobRequest.is_reference_job(job_id): return LoadDummyBaseJob(job_id, config=self.config) else: return LoadDummyJob(job_id, config=self.config) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index f2466f25a2..ac5ffb9ef3 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -3,16 +3,17 @@ import base64 from types import TracebackType -from typing import ClassVar, List, Type, Iterable, Iterator, Optional, Tuple, Sequence, cast +from typing import Dict, List, Type, Iterable, Iterator, Optional, Tuple, Sequence, cast from fsspec import AbstractFileSystem from contextlib import contextmanager import dlt from dlt.common import logger, time, json, pendulum +from dlt.common.metrics import LoadJobMetrics from dlt.common.storages.fsspec_filesystem import glob_files from dlt.common.typing import DictStrAny from dlt.common.schema import Schema, TSchemaTables, TTableSchema -from dlt.common.schema.utils import get_first_column_name_with_prop, get_columns_names_with_prop +from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, fsspec_from_config from dlt.common.storages.load_package import ( LoadJobInfo, @@ -21,7 +22,7 @@ ) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - FollowupJob, + FollowupJobRequest, TLoadJobState, RunnableLoadJob, JobClientBase, @@ -34,7 +35,7 @@ ) from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.destinations.job_impl import ( - ReferenceFollowupJob, + ReferenceFollowupJobRequest, FinalizedLoadJob, FinalizedLoadJobWithFollowupJobs, ) @@ -55,129 +56,136 @@ def __init__( self._job_client: FilesystemClient = None def run(self) -> None: - # pick local filesystem pathlib or posix for buckets - self.is_local_filesystem = self._job_client.config.protocol == "file" - self.pathlib = os.path if self.is_local_filesystem else posixpath - - self.destination_file_name = path_utils.create_path( - self._job_client.config.layout, - self._file_name, - self._job_client.schema.name, - self._load_id, - current_datetime=self._job_client.config.current_datetime, - load_package_timestamp=dlt.current.load_package()["state"]["created_at"], - extra_placeholders=self._job_client.config.extra_placeholders, - ) + self.__is_local_filesystem = self._job_client.config.is_local_filesystem # We would like to avoid failing for local filesystem where # deeply nested directory will not exist before writing a file. 
# It `auto_mkdir` is disabled by default in fsspec so we made some # trade offs between different options and decided on this. # remote_path = f"{client.config.protocol}://{posixpath.join(dataset_path, destination_file_name)}" remote_path = self.make_remote_path() - if self.is_local_filesystem: - self._job_client.fs_client.makedirs(self.pathlib.dirname(remote_path), exist_ok=True) + if self.__is_local_filesystem: + # use os.path for local file name + self._job_client.fs_client.makedirs(os.path.dirname(remote_path), exist_ok=True) self._job_client.fs_client.put_file(self._file_path, remote_path) def make_remote_path(self) -> str: """Returns path on the remote filesystem to which copy the file, without scheme. For local filesystem a native path is used""" + destination_file_name = path_utils.create_path( + self._job_client.config.layout, + self._file_name, + self._job_client.schema.name, + self._load_id, + current_datetime=self._job_client.config.current_datetime, + load_package_timestamp=dlt.current.load_package()["state"]["created_at"], + extra_placeholders=self._job_client.config.extra_placeholders, + ) + # pick local filesystem pathlib or posix for buckets + pathlib = os.path if self.__is_local_filesystem else posixpath # path.join does not normalize separators and available # normalization functions are very invasive and may string the trailing separator - return self.pathlib.join( # type: ignore[no-any-return] + return pathlib.join( # type: ignore[no-any-return] self._job_client.dataset_path, - path_utils.normalize_path_sep(self.pathlib, self.destination_file_name), + path_utils.normalize_path_sep(pathlib, destination_file_name), ) + def make_remote_url(self) -> str: + """Returns path on a remote filesystem as a full url including scheme.""" + return self._job_client.make_remote_url(self.make_remote_path()) + + def metrics(self) -> Optional[LoadJobMetrics]: + m = super().metrics() + return m._replace(remote_url=self.make_remote_url()) + class DeltaLoadFilesystemJob(FilesystemLoadJob): def __init__(self, file_path: str) -> None: - super().__init__( - file_path=file_path, - ) + super().__init__(file_path=file_path) - def run(self) -> None: + # create Arrow dataset from Parquet files from dlt.common.libs.pyarrow import pyarrow as pa - from dlt.common.libs.deltalake import ( - DeltaTable, - write_delta_table, - ensure_delta_compatible_arrow_schema, - _deltalake_storage_options, - try_get_deltatable, - ) - # create Arrow dataset from Parquet files - file_paths = ReferenceFollowupJob.resolve_references(self._file_path) - arrow_ds = pa.dataset.dataset(file_paths) + self.file_paths = ReferenceFollowupJobRequest.resolve_references(self._file_path) + self.arrow_ds = pa.dataset.dataset(self.file_paths) - # create Delta table object - dt_path = self._job_client.make_remote_uri( - self._job_client.get_table_dir(self.load_table_name) - ) - storage_options = _deltalake_storage_options(self._job_client.config) - dt = try_get_deltatable(dt_path, storage_options=storage_options) + def make_remote_path(self) -> str: + # remote path is table dir - delta will create its file structure inside it + return self._job_client.get_table_dir(self.load_table_name) - # get partition columns - part_cols = get_columns_names_with_prop(self._load_table, "partition") + def run(self) -> None: + logger.info(f"Will copy file(s) {self.file_paths} to delta table {self.make_remote_url()}") + + from dlt.common.libs.deltalake import write_delta_table, merge_delta_table # explicitly check if there is data # 
(https://github.com/delta-io/delta-rs/issues/2686) - if arrow_ds.head(1).num_rows == 0: - if dt is None: - # create new empty Delta table with schema from Arrow table - DeltaTable.create( - table_uri=dt_path, - schema=ensure_delta_compatible_arrow_schema(arrow_ds.schema), - mode="overwrite", - partition_by=part_cols, - storage_options=storage_options, - ) + if self.arrow_ds.head(1).num_rows == 0: + self._create_or_evolve_delta_table() return - arrow_rbr = arrow_ds.scanner().to_reader() # RecordBatchReader - - if self._load_table["write_disposition"] == "merge" and dt is not None: - assert self._load_table["x-merge-strategy"] in self._job_client.capabilities.supported_merge_strategies # type: ignore[typeddict-item] - - if self._load_table["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item] - if "parent" in self._load_table: - unique_column = get_first_column_name_with_prop(self._load_table, "unique") - predicate = f"target.{unique_column} = source.{unique_column}" - else: - primary_keys = get_columns_names_with_prop(self._load_table, "primary_key") - predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys]) - - qry = ( - dt.merge( - source=arrow_rbr, - predicate=predicate, - source_alias="source", - target_alias="target", - ) - .when_matched_update_all() - .when_not_matched_insert_all() + with self.arrow_ds.scanner().to_reader() as arrow_rbr: # RecordBatchReader + if self._load_table["write_disposition"] == "merge" and self._delta_table is not None: + assert self._load_table["x-merge-strategy"] in self._job_client.capabilities.supported_merge_strategies # type: ignore[typeddict-item] + merge_delta_table( + table=self._delta_table, + data=arrow_rbr, + schema=self._load_table, + ) + else: + write_delta_table( + table_or_uri=( + self.make_remote_url() if self._delta_table is None else self._delta_table + ), + data=arrow_rbr, + write_disposition=self._load_table["write_disposition"], + partition_by=self._partition_columns, + storage_options=self._storage_options, ) - qry.execute() + @property + def _storage_options(self) -> Dict[str, str]: + from dlt.common.libs.deltalake import _deltalake_storage_options + + return _deltalake_storage_options(self._job_client.config) - else: - write_delta_table( - table_or_uri=dt_path if dt is None else dt, - data=arrow_rbr, - write_disposition=self._load_table["write_disposition"], - partition_by=part_cols, - storage_options=storage_options, + @property + def _delta_table(self) -> Optional["DeltaTable"]: # type: ignore[name-defined] # noqa: F821 + from dlt.common.libs.deltalake import try_get_deltatable + + return try_get_deltatable(self.make_remote_url(), storage_options=self._storage_options) + + @property + def _partition_columns(self) -> List[str]: + return get_columns_names_with_prop(self._load_table, "partition") + + def _create_or_evolve_delta_table(self) -> None: + from dlt.common.libs.deltalake import ( + DeltaTable, + ensure_delta_compatible_arrow_schema, + _evolve_delta_table_schema, + ) + + if self._delta_table is None: + DeltaTable.create( + table_uri=self.make_remote_url(), + schema=ensure_delta_compatible_arrow_schema(self.arrow_ds.schema), + mode="overwrite", + partition_by=self._partition_columns, + storage_options=self._storage_options, ) + else: + _evolve_delta_table_schema(self._delta_table, self.arrow_ds.schema) class FilesystemLoadJobWithFollowup(HasFollowupJobs, FilesystemLoadJob): - def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: + def create_followup_jobs(self, 
final_state: TLoadJobState) -> List[FollowupJobRequest]: jobs = super().create_followup_jobs(final_state) if self._load_table.get("table_format") == "delta": # delta table jobs only require table chain followup jobs pass elif final_state == "completed": - ref_job = ReferenceFollowupJob( + ref_job = ReferenceFollowupJobRequest( original_file_name=self.file_name(), - remote_paths=[self._job_client.make_remote_uri(self.make_remote_path())], + remote_paths=[self._job_client.make_remote_url(self.make_remote_path())], ) jobs.append(ref_job) return jobs @@ -200,7 +208,7 @@ def __init__( ) -> None: super().__init__(schema, config, capabilities) self.fs_client, fs_path = fsspec_from_config(config) - self.is_local_filesystem = config.protocol == "file" + self.is_local_filesystem = config.is_local_filesystem self.bucket_path = ( config.make_local_path(config.bucket_url) if self.is_local_filesystem else fs_path ) @@ -288,6 +296,7 @@ def update_stored_schema( only_tables: Iterable[str] = None, expected_update: TSchemaTables = None, ) -> TSchemaTables: + applied_update = super().update_stored_schema(only_tables, expected_update) # create destination dirs for all tables table_names = only_tables or self.schema.tables.keys() dirs_to_create = self.get_table_dirs(table_names) @@ -301,14 +310,16 @@ def update_stored_schema( if not self.config.as_staging: self._store_current_schema() - return expected_update + # we assume that expected_update == applied_update so table schemas in dest were not + # externally changed + return applied_update def get_table_dir(self, table_name: str, remote: bool = False) -> str: # dlt tables do not respect layout (for now) table_prefix = self.get_table_prefix(table_name) table_dir: str = self.pathlib.dirname(table_prefix) if remote: - table_dir = self.make_remote_uri(table_dir) + table_dir = self.make_remote_url(table_dir) return table_dir def get_table_prefix(self, table_name: str) -> str: @@ -342,7 +353,7 @@ def list_files_with_prefixes(self, table_dir: str, prefixes: List[str]) -> List[ # we fallback to our own glob implementation that is tested to return consistent results for # filesystems we support. 
we were not able to use `find` or `walk` because they were selecting # files wrongly (on azure walk on path1/path2/ would also select files from path1/path2_v2/ but returning wrong dirs) - for details in glob_files(self.fs_client, self.make_remote_uri(table_dir), "**"): + for details in glob_files(self.fs_client, self.make_remote_url(table_dir), "**"): file = details["file_name"] filepath = self.pathlib.join(table_dir, details["relative_path"]) # skip INIT files @@ -369,7 +380,7 @@ def create_load_job( import dlt.common.libs.deltalake # assert dependencies are installed # a reference job for a delta table indicates a table chain followup job - if ReferenceFollowupJob.is_reference_job(file_path): + if ReferenceFollowupJobRequest.is_reference_job(file_path): return DeltaLoadFilesystemJob(file_path) # otherwise just continue return FinalizedLoadJobWithFollowupJobs(file_path) @@ -377,12 +388,12 @@ def create_load_job( cls = FilesystemLoadJobWithFollowup if self.config.as_staging else FilesystemLoadJob return cls(file_path) - def make_remote_uri(self, remote_path: str) -> str: + def make_remote_url(self, remote_path: str) -> str: """Returns uri to the remote filesystem to which copy the file""" if self.is_local_filesystem: - return self.config.make_file_uri(remote_path) + return self.config.make_file_url(remote_path) else: - return f"{self.config.protocol}://{remote_path}" + return self.config.make_url(remote_path) def __enter__(self) -> "FilesystemClient": return self @@ -578,7 +589,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: assert completed_table_chain_jobs is not None jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs @@ -591,5 +602,5 @@ def create_table_chain_completed_followup_jobs( if job.job_file_info.table_name == table["name"] ] file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) - jobs.append(ReferenceFollowupJob(file_name, table_job_paths)) + jobs.append(ReferenceFollowupJobRequest(file_name, table_job_paths)) return jobs diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index 5aa4ba714f..8f6a192bb0 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -91,10 +91,8 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): but it is configurable in rare cases. Make sure it corresponds with the associated embedding model's dimensionality.""" - vector_field_name: str = "vector__" + vector_field_name: str = "vector" """Name of the special field to store the vector embeddings.""" - id_field_name: str = "id__" - """Name of the special field to manage deduplication.""" sentinel_table_name: str = "dltSentinelTable" """Name of the sentinel table that encapsulates datasets. 
Since LanceDB has no concept of schemas, this table serves as a proxy to group related dlt tables together.""" diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index d9b92e02b9..d99f0fa6ee 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -32,6 +32,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.recommended_file_size = 128_000_000 + caps.supported_merge_strategies = ["upsert"] + return caps @property diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 0daba7a651..8f4fbb091d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -1,18 +1,20 @@ -from typing import Any +from typing import Any, Dict from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource +from dlt.extract.items import TTableHintTemplate VECTORIZE_HINT = "x-lancedb-embed" -DOCUMENT_ID_HINT = "x-lancedb-doc-id" +NO_REMOVE_ORPHANS_HINT = "x-lancedb-remove-orphans" def lancedb_adapter( data: Any, embed: TColumnNames = None, - document_id: TColumnNames = None, + merge_key: TColumnNames = None, + no_remove_orphans: bool = False, ) -> DltResource: """Prepares data for the LanceDB destination by specifying which columns should be embedded. @@ -22,8 +24,10 @@ def lancedb_adapter( object. embed (TColumnNames, optional): Specify columns to generate embeddings for. It can be a single column name as a string, or a list of column names. - document_id (TColumnNames, optional): Specify columns which represenet the document - and which will be appended to primary/merge keys. + merge_key (TColumnNames, optional): Specify columns to merge on. + It can be a single column name as a string, or a list of column names. + no_remove_orphans (bool): Specify whether to remove orphaned records in child + tables with no parent records after merges to maintain referential integrity. Returns: DltResource: A resource with applied LanceDB-specific hints. @@ -38,6 +42,7 @@ def lancedb_adapter( """ resource = get_resource_for_adapter(data) + additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} column_hints: TTableSchemaColumns = {} if embed: @@ -54,23 +59,28 @@ def lancedb_adapter( VECTORIZE_HINT: True, # type: ignore[misc] } - if document_id: - if isinstance(document_id, str): - document_id = [document_id] - if not isinstance(document_id, list): + if merge_key: + if isinstance(merge_key, str): + merge_key = [merge_key] + if not isinstance(merge_key, list): raise ValueError( - "'document_id' must be a list of column names or a single column name as a string." + "'merge_key' must be a list of column names or a single column name as a string." 
) - for column_name in document_id: + for column_name in merge_key: column_hints[column_name] = { "name": column_name, - DOCUMENT_ID_HINT: True, # type: ignore[misc] + "merge_key": True, } - if not column_hints: - raise ValueError("At least one of 'embed' or 'document_id' must be specified.") + additional_table_hints[NO_REMOVE_ORPHANS_HINT] = no_remove_orphans + + if column_hints or additional_table_hints: + resource.apply_hints(columns=column_hints, additional_table_hints=additional_table_hints) else: - resource.apply_hints(columns=column_hints) + raise ValueError( + "You must must provide at least either the 'embed' or 'merge_key' or 'remove_orphans'" + " argument if using the adapter." + ) return resource diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index ecdd22ca56..11249d0f97 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -17,9 +17,9 @@ import lancedb # type: ignore import lancedb.table # type: ignore import pyarrow as pa -import pyarrow.compute as pc import pyarrow.parquet as pq from lancedb import DBConnection +from lancedb.common import DATA # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore from numpy import ndarray @@ -38,23 +38,20 @@ RunnableLoadJob, StorageSchemaInfo, StateInfo, - FollowupJob, LoadJob, HasFollowupJobs, + FollowupJobRequest, ) -from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import ( TColumnType, - TTableFormat, TTableSchemaColumns, TWriteDisposition, TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName -from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, ) @@ -63,7 +60,7 @@ ) from dlt.destinations.impl.lancedb.lancedb_adapter import ( VECTORIZE_HINT, - DOCUMENT_ID_HINT, + NO_REMOVE_ORPHANS_HINT, ) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, @@ -72,13 +69,17 @@ NULL_SCHEMA, TArrowField, arrow_datatype_to_fusion_datatype, + TTableLineage, + TableJob, ) from dlt.destinations.impl.lancedb.utils import ( - get_unique_identifiers_from_table_schema, set_non_standard_providers_environment_variables, - generate_arrow_uuid_column, + EMPTY_STRING_PLACEHOLDER, + fill_empty_source_column_values_with_placeholder, + get_canonical_vector_database_doc_id_merge_key, + create_filter_condition, ) -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -113,21 +114,27 @@ class LanceDBTypeMapper(TypeMapper): pa.date32(): "date", } - def to_db_decimal_type( - self, precision: Optional[int], scale: Optional[int] - ) -> pa.Decimal128Type: - precision, scale = self.decimal_precision(precision, scale) + def to_db_decimal_type(self, column: TColumnSchema) -> pa.Decimal128Type: + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) return pa.decimal128(precision, scale) def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: 
TTableSchema = None, ) -> pa.TimestampType: + column_name = column.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + if timezone is not None or precision is not None: + logger.warning( + "LanceDB does not currently support column flags for timezone or precision." + f" These flags were used in column '{column_name}'." + ) unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.timestamp(unit, "UTC") - def to_db_time_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> pa.Time64Type: + def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> pa.Time64Type: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.time64(unit) @@ -157,14 +164,16 @@ def from_db_type( return super().from_db_type(cast(str, db_type), precision, scale) -def write_to_db( - records: Union[pa.Table, List[DictStrAny]], +def write_records( + records: DATA, /, *, db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", - id_field_name: Optional[str] = None, + merge_key: Optional[str] = None, + remove_orphans: Optional[bool] = False, + filter_condition: Optional[str] = None, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -172,8 +181,11 @@ def write_to_db( records: The data to be inserted as payload. db_client: The LanceDB client connection. table_name: The name of the table to insert into. - id_field_name: The name of the ID field for update/merge operations. + merge_key: Keys for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. + remove_orphans (bool): Whether to remove orphans after insertion or not (only merge disposition). + filter_condition (str): If None, then all such rows will be deleted. + Otherwise, the condition will be used as an SQL filter to limit what rows are deleted. 
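# [Illustrative sketch, not part of the patch] write_records() replaces write_to_db and keys
# merges on an explicit merge_key rather than the removed id_field_name. A direct call could
# look roughly like this; the database URI and table name are made up, and the target table
# is assumed to already exist with a compatible schema:

import lancedb
import pyarrow as pa

from dlt.destinations.impl.lancedb.lancedb_client import write_records

db = lancedb.connect("./.lancedb")  # hypothetical local LanceDB database
payload = pa.table({"_dlt_id": ["a1", "b2"], "title": ["first", "second"]})

write_records(
    payload,
    db_client=db,
    table_name="docs_dataset___documents",  # fully qualified dlt table name (assumed)
    write_disposition="merge",
    merge_key="_dlt_id",  # upsert on the deterministic dlt row id, mirroring the load job in this module
)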
Raises: ValueError: If the write disposition is unsupported, or `id_field_name` is not @@ -194,11 +206,14 @@ def write_to_db( elif write_disposition == "replace": tbl.add(records, mode="overwrite") elif write_disposition == "merge": - if not id_field_name: - raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert( - id_field_name - ).when_matched_update_all().when_not_matched_insert_all().execute(records) + if remove_orphans: + tbl.merge_insert(merge_key).when_not_matched_by_source_delete( + filter_condition + ).execute(records) + else: + tbl.merge_insert( + merge_key + ).when_matched_update_all().when_not_matched_insert_all().execute(records) else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -214,6 +229,8 @@ class LanceDBClient(JobClientBase, WithStateSync): """LanceDB destination handler.""" model_func: TextEmbeddingFunction + """The embedder callback used for each chunk.""" + dataset_name: str def __init__( self, @@ -231,6 +248,7 @@ def __init__( self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = LanceDBTypeMapper(self.capabilities) self.sentinel_table_name = config.sentinel_table_name + self.dataset_name = self.config.normalize_dataset_name(self.schema) embedding_model_provider = self.config.embedding_model_provider @@ -241,27 +259,13 @@ def __init__( embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) - # Use the monkey-patched implementation if openai was chosen. - if embedding_model_provider == "openai": - from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings - - self.model_func = PatchedOpenAIEmbeddings( - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) - else: - self.model_func = self.registry.get(embedding_model_provider).create( - name=self.config.embedding_model, - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) + self.model_func = self.registry.get(embedding_model_provider).create( + name=self.config.embedding_model, + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) self.vector_field_name = self.config.vector_field_name - self.id_field_name = self.config.id_field_name - - @property - def dataset_name(self) -> str: - return self.config.normalize_dataset_name(self.schema) @property def sentinel_table(self) -> str: @@ -442,7 +446,7 @@ def extend_lancedb_table_schema(self, table_name: str, field_schemas: List[pa.Fi try: # Use DataFusion SQL syntax to alter fields without loading data into client memory. - # Currently, the most efficient way to modify column values is in LanceDB. + # Now, the most efficient way to modify column values is in LanceDB. 
new_fields = { field.name: f"CAST(NULL AS {arrow_datatype_to_fusion_datatype(field.type)})" for field in field_schemas @@ -484,13 +488,11 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: self.schema.get_table(table_name=table_name), VECTORIZE_HINT ) vector_field_name = self.vector_field_name - id_field_name = self.id_field_name embedding_model_func = self.model_func embedding_model_dimensions = self.config.embedding_model_dimensions else: embedding_fields = None vector_field_name = None - id_field_name = None embedding_model_func = None embedding_model_dimensions = None @@ -502,7 +504,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: embedding_model_func=embedding_model_func, embedding_model_dimensions=embedding_model_dimensions, vector_field_name=vector_field_name, - id_field_name=id_field_name, ) fq_table_name = self.make_qualified_table_name(table_name) self.create_table(fq_table_name, table_schema) @@ -532,7 +533,7 @@ def update_schema_in_storage(self) -> None: "write_disposition" ) - write_to_db( + write_records( records, db_client=self.db_client, table_name=fq_version_table_name, @@ -553,7 +554,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: # normalize property names p_load_id = self.schema.naming.normalize_identifier("load_id") - p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id") + p_dlt_load_id = self.schema.naming.normalize_identifier( + self.schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined] + ) p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") p_status = self.schema.naming.normalize_identifier("status") p_version = self.schema.naming.normalize_identifier("version") @@ -681,7 +684,7 @@ def complete_load(self, load_id: str) -> None: write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) - write_to_db( + write_records( records, db_client=self.db_client, table_name=fq_loads_table_name, @@ -691,7 +694,7 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if ReferenceFollowupJob.is_reference_job(file_path): + if ReferenceFollowupJobRequest.is_reference_job(file_path): return LanceDBRemoveOrphansJob(file_path) else: return LanceDBLoadJob(file_path, table) @@ -700,12 +703,15 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) - if table_chain[0].get("write_disposition") == "merge": - # TODO: Use staging to write deletion records. For now we use only one job. + # Orphan removal is only supported for upsert strategy because we need a deterministic key hash. 
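# [Illustrative sketch, not part of the patch] lancedb_adapter() now accepts merge_key and
# no_remove_orphans instead of document_id: merge_key columns become regular dlt merge keys
# and the orphan-removal followup job is skipped when no_remove_orphans is set. Resource and
# column names below are made up:

import dlt

from dlt.destinations.impl.lancedb.lancedb_adapter import lancedb_adapter


@dlt.resource
def chunks():
    yield {"document_id": "doc-1", "chunk_text": "some text to embed"}


adapted = lancedb_adapter(
    chunks,
    embed="chunk_text",       # mark the column for embedding (x-lancedb-embed)
    merge_key="document_id",  # stored as an ordinary merge_key column hint
    no_remove_orphans=True,   # keep orphaned child rows, i.e. skip the followup job
)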
+ first_table_in_chain = table_chain[0] + if first_table_in_chain.get( + "write_disposition" + ) == "merge" and not first_table_in_chain.get(NO_REMOVE_ORPHANS_HINT): all_job_paths_ordered = [ job.file_path for table in table_chain @@ -715,7 +721,7 @@ def create_table_chain_completed_followup_jobs( root_table_file_name = FileStorage.get_file_name_from_file_path( all_job_paths_ordered[0] ) - jobs.append(ReferenceFollowupJob(root_table_file_name, all_job_paths_ordered)) + jobs.append(ReferenceFollowupJobRequest(root_table_file_name, all_job_paths_ordered)) return jobs def table_exists(self, table_name: str) -> bool: @@ -737,10 +743,6 @@ def __init__( def run(self) -> None: db_client: DBConnection = self._job_client.db_client fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) - id_field_name: str = self._job_client.config.id_field_name - unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( - self._load_table - ) write_disposition: TWriteDisposition = cast( TWriteDisposition, self._load_table.get("write_disposition", "append") ) @@ -748,24 +750,36 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - if self._load_table["name"] not in self._schema.dlt_table_names(): - arrow_table = generate_arrow_uuid_column( - arrow_table, - unique_identifiers=unique_identifiers, - table_name=fq_table_name, - id_field_name=id_field_name, + # Replace empty strings with placeholder string if OpenAI is used. + # https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218. + if (self._job_client.config.embedding_model_provider == "openai") and ( + source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) + ): + arrow_table = fill_empty_source_column_values_with_placeholder( + arrow_table, source_columns, EMPTY_STRING_PLACEHOLDER + ) + + # We need upsert merge's deterministic _dlt_id to perform orphan removal. + # Hence, we require at least a primary key on the root table if the merge disposition is chosen. + if ( + (self._load_table not in self._schema.dlt_table_names()) + and not self._load_table.get("parent") # Is root table. + and (write_disposition == "merge") + and (not get_columns_names_with_prop(self._load_table, "primary_key")) + ): + raise DestinationTerminalException( + "LanceDB's write disposition requires at least one explicit primary key." ) - write_to_db( + write_records( arrow_table, db_client=db_client, table_name=fq_table_name, write_disposition=write_disposition, - id_field_name=id_field_name, + merge_key=self._schema.data_item_normalizer.C_DLT_ID, # type: ignore[attr-defined] ) -# TODO: Implement staging for this step with insert deletes. 
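# Hedged usage sketch for the check above: the merge write disposition on the LanceDB
# destination needs an explicit primary key on the root table so the deterministic _dlt_id
# can serve as the merge key. Resource, pipeline and column names are illustrative.
import dlt

@dlt.resource(primary_key="doc_id", write_disposition="merge")
def documents():
    yield {"doc_id": 1, "text": "first chapter"}
    yield {"doc_id": 2, "text": "second chapter"}

pipeline = dlt.pipeline(pipeline_name="lancedb_merge_demo", destination="lancedb")
pipeline.run(documents())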
class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] @@ -775,93 +789,54 @@ def __init__( ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None - self.references = ReferenceFollowupJob.resolve_references(file_path) + self.references = ReferenceFollowupJobRequest.resolve_references(file_path) def run(self) -> None: + dlt_load_id = self._schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined] + dlt_id = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] + dlt_root_id = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined] + db_client: DBConnection = self._job_client.db_client - id_field_name: str = self._job_client.config.id_field_name - - # We don't all insert jobs for each table using this method. - table_lineage: List[TTableSchema] = [] - for file_path_ in self.references: - table = self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name) - if table["name"] not in [table_["name"] for table_ in table_lineage]: - table_lineage.append(table) - - for table in table_lineage: - fq_table_name: str = self._job_client.make_qualified_table_name(table["name"]) - try: - fq_parent_table_name: str = self._job_client.make_qualified_table_name( - table["parent"] - ) - except KeyError: - fq_parent_table_name = None # The table is a root table. - - try: - child_table = db_client.open_table(fq_table_name) - child_table.checkout_latest() - if fq_parent_table_name: - parent_table = db_client.open_table(fq_parent_table_name) - parent_table.checkout_latest() - except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" - ) from e - - try: - if fq_parent_table_name: - # Chunks and embeddings in child table. - parent_ids = set( - pc.unique( - parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] - ).to_pylist() - ) - child_ids = set( - pc.unique( - child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ - "_dlt_parent_id" - ] - ).to_pylist() - ) + table_lineage: TTableLineage = [ + TableJob( + table_schema=self._schema.get_table( + ParsedLoadJobFileName.parse(file_path_).table_name + ), + table_name=ParsedLoadJobFileName.parse(file_path_).table_name, + file_path=file_path_, + ) + for file_path_ in self.references + ] - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") + for job in table_lineage: + target_is_root_table = "parent" not in job.table_schema + fq_table_name = self._job_client.make_qualified_table_name(job.table_name) + file_path = job.file_path + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) - else: - # Chunks and embeddings in the root table. - document_id_field = get_columns_names_with_prop(table, DOCUMENT_ID_HINT) - if document_id_field and get_columns_names_with_prop(table, "primary_key"): - raise SystemConfigurationException( - "You CANNOT specify a primary key AND a document ID hint for the same" - " resource when using merge disposition." 
- ) + if target_is_root_table: + canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( + job.table_schema + ) + filter_condition = create_filter_condition( + canonical_doc_id_field, payload_arrow_table[canonical_doc_id_field] + ) + merge_key = dlt_load_id - # If document ID is defined, we use this as the sole grouping key to identify stale chunks, - # else fallback to the compound `id_field_name`. - grouping_key = document_id_field or id_field_name - grouping_key = ( - grouping_key if isinstance(grouping_key, list) else [grouping_key] - ) - child_table_arrow: pa.Table = child_table.to_lance().to_table( - columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] - ) + else: + filter_condition = create_filter_condition( + dlt_root_id, + payload_arrow_table[dlt_root_id], + ) + merge_key = dlt_id - grouped = child_table_arrow.group_by(grouping_key).aggregate( - [("_dlt_load_id", "max")] - ) - joined = child_table_arrow.join(grouped, keys=grouping_key) - orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) - orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() - - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") - - except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." - ) from e + write_records( + payload_arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition="merge", + merge_key=merge_key, + remove_orphans=True, + filter_condition=filter_condition, + ) diff --git a/dlt/destinations/impl/lancedb/models.py b/dlt/destinations/impl/lancedb/models.py deleted file mode 100644 index d90adb62bd..0000000000 --- a/dlt/destinations/impl/lancedb/models.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Union, List - -import numpy as np -from lancedb.embeddings import OpenAIEmbeddings # type: ignore -from lancedb.embeddings.registry import register # type: ignore -from lancedb.embeddings.utils import TEXT # type: ignore - - -@register("openai_patched") -class PatchedOpenAIEmbeddings(OpenAIEmbeddings): - EMPTY_STRING_PLACEHOLDER: str = "___EMPTY___" - - def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: # type: ignore[type-arg] - """ - Replace empty strings with a placeholder value. - """ - - sanitized_texts = super().sanitize_input(texts) - return [self.EMPTY_STRING_PLACEHOLDER if item == "" else item for item in sanitized_texts] - - def generate_embeddings( - self, - texts: Union[List[str], np.ndarray], # type: ignore[type-arg] - ) -> List[np.array]: # type: ignore[valid-type] - """ - Generate embeddings, treating the placeholder as an empty result. 
- """ - embeddings: List[np.array] = super().generate_embeddings(texts) # type: ignore[valid-type] - - for i, text in enumerate(texts): - if text == self.EMPTY_STRING_PLACEHOLDER: - embeddings[i] = np.zeros(self.ndims()) - - return embeddings diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index db624aeb12..2fa3251ede 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -1,9 +1,10 @@ """Utilities for creating arrow schemas from table schemas.""" - +from collections import namedtuple from typing import ( List, cast, Optional, + Tuple, ) import pyarrow as pa @@ -11,7 +12,7 @@ from typing_extensions import TypeAlias from dlt.common.json import json -from dlt.common.schema import Schema, TColumnSchema +from dlt.common.schema import Schema, TColumnSchema, TTableSchema from dlt.common.typing import DictStrAny from dlt.destinations.type_mapping import TypeMapper @@ -21,6 +22,8 @@ TArrowField: TypeAlias = pa.Field NULL_SCHEMA: TArrowSchema = pa.schema([]) """Empty pyarrow Schema with no fields.""" +TableJob = namedtuple("TableJob", ["table_schema", "table_name", "file_path"]) +TTableLineage: TypeAlias = List[TableJob] def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: @@ -41,7 +44,6 @@ def make_arrow_table_schema( table_name: str, schema: Schema, type_mapper: TypeMapper, - id_field_name: Optional[str] = None, vector_field_name: Optional[str] = None, embedding_fields: Optional[List[str]] = None, embedding_model_func: Optional[TextEmbeddingFunction] = None, @@ -50,9 +52,6 @@ def make_arrow_table_schema( """Creates a PyArrow schema from a dlt schema.""" arrow_schema: List[TArrowField] = [] - if id_field_name: - arrow_schema.append(pa.field(id_field_name, pa.string())) - if embedding_fields: # User's provided dimension config, if provided, takes precedence. vec_size = embedding_model_dimensions or embedding_model_func.ndims() diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 37303686df..f07f2754d2 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,15 +1,15 @@ import os -import uuid -from typing import Sequence, Union, Dict, List +from typing import Union, Dict, List import pyarrow as pa -import pyarrow.compute as pc +from dlt.common import logger +from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider - +EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { "cohere": "COHERE_API_KEY", "gemini-text": "GOOGLE_API_KEY", @@ -18,59 +18,65 @@ } -# TODO: Update `generate_arrow_uuid_column` when pyarrow 17.0.0 becomes available with vectorized operations (batched + memory-mapped) -def generate_arrow_uuid_column( - table: pa.Table, unique_identifiers: Sequence[str], id_field_name: str, table_name: str -) -> pa.Table: - """Generates deterministic UUID - used for deduplication, returning a new arrow - table with added UUID column. - - Args: - table (pa.Table): PyArrow table to generate UUIDs for. - unique_identifiers (Sequence[str]): A list of unique identifier column names. - id_field_name (str): Name of the new UUID column. - table_name (str): Name of the table. - - Returns: - pa.Table: New PyArrow table with the new UUID column. 
- """ - - unique_identifiers_columns = [] - for col in unique_identifiers: - column = pc.fill_null(pc.cast(table[col], pa.string()), "") - unique_identifiers_columns.append(column.to_pylist()) +def set_non_standard_providers_environment_variables( + embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] +) -> None: + if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: + os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" - uuids = pa.array( - [ - str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) - for x in ["".join(x) for x in zip(*unique_identifiers_columns)] - ] - ) - return table.append_column(id_field_name, uuids) +def get_canonical_vector_database_doc_id_merge_key( + load_table: TTableSchema, +) -> str: + if merge_key := get_columns_names_with_prop(load_table, "merge_key"): + if len(merge_key) > 1: + raise DestinationTerminalException( + "You cannot specify multiple merge keys with LanceDB orphan remove enabled:" + f" {merge_key}" + ) + else: + return merge_key[0] + elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): + # No merge key defined, warn and assume the first element of the primary key is `doc_id`. + logger.warning( + "Merge strategy selected without defined merge key - using the first element of the" + f" primary key ({primary_key}) as merge key." + ) + return primary_key[0] + else: + raise DestinationTerminalException( + "You must specify at least a primary key in order to perform orphan removal." + ) -def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List[str]: - """Returns a list of merge keys for a table used for either merging or deduplication. +def fill_empty_source_column_values_with_placeholder( + table: pa.Table, source_columns: List[str], placeholder: str +) -> pa.Table: + """ + Replaces empty strings and null values in the specified source columns of an Arrow table with a placeholder string. Args: - table_schema (TTableSchema): a dlt table schema. + table (pa.Table): The input Arrow table. + source_columns (List[str]): A list of column names to replace empty strings and null values in. + placeholder (str): The placeholder string to use for replacement. Returns: - Sequence[str]: A list of unique column identifiers. + pa.Table: The modified Arrow table with empty strings and null values replaced in the specified columns. 
""" - primary_keys = get_columns_names_with_prop(table_schema, "primary_key") - merge_keys = [] - if table_schema.get("write_disposition") == "merge": - merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys)): - return join_keys - else: - return get_columns_names_with_prop(table_schema, "unique") - - -def set_non_standard_providers_environment_variables( - embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] -) -> None: - if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: - os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" + for col_name in source_columns: + column = table[col_name] + filled_column = pa.compute.fill_null(column, fill_value=placeholder) + new_column = pa.compute.replace_substring_regex( + filled_column, pattern=r"^$", replacement=placeholder + ) + table = table.set_column(table.column_names.index(col_name), col_name, new_column) + return table + + +def create_filter_condition(field_name: str, array: pa.Array) -> str: + def format_value(element: Union[str, int, float, pa.Scalar]) -> str: + if isinstance(element, pa.Scalar): + element = element.as_py() + return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element) + + return f"{field_name} IN ({', '.join(map(format_value, array))})" diff --git a/dlt/destinations/impl/mssql/factory.py b/dlt/destinations/impl/mssql/factory.py index 85c94c21b7..f1a8bb136a 100644 --- a/dlt/destinations/impl/mssql/factory.py +++ b/dlt/destinations/impl/mssql/factory.py @@ -37,6 +37,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.max_text_data_type_length = 2**30 - 1 caps.is_max_text_data_type_length_in_bytes = False caps.supports_ddl_transactions = True + caps.supports_create_table_if_not_exists = False # IF NOT EXISTS not supported caps.max_rows_per_insert = 1000 caps.timestamp_precision = 7 caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index a67423a873..a7e796b2d8 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,7 +1,7 @@ from typing import Dict, Optional, Sequence, List, Any from dlt.common.exceptions import TerminalValueError -from dlt.common.destination.reference import FollowupJob +from dlt.common.destination.reference import FollowupJobRequest from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat @@ -59,9 +59,8 @@ class MsSqlTypeMapper(TypeMapper): "int": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" if precision <= 8: @@ -160,24 +159,24 @@ def __init__( self.active_hints = HINT_TO_MSSQL_ATTR if self.config.create_indexes else {} self.type_mapper = MsSqlTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJobRequest]: return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, 
new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because mssql requires multiple columns in a single ADD COLUMN clause - return [ - "ADD \n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) - ] + return ["ADD \n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: sc_type = c["data_type"] if sc_type == "text" and c.get("unique"): # MSSQL does not allow index on large TEXT columns db_type = "nvarchar(%i)" % (c.get("precision") or 900) else: - db_type = self.type_mapper.to_db_type(c) + db_type = self.type_mapper.to_db_type(c, table) hints_str = " ".join( self.active_hints.get(h, "") @@ -189,7 +188,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: if self.config.replace_strategy == "staging-optimized": return [MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 5ae5f27a6e..5777e46c90 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -9,7 +9,7 @@ from dlt.common.destination.reference import ( HasFollowupJobs, RunnableLoadJob, - FollowupJob, + FollowupJobRequest, LoadJob, TLoadJobState, ) @@ -66,9 +66,8 @@ class PostgresTypeMapper(TypeMapper): "integer": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" # Precision is number of bits @@ -82,6 +81,39 @@ def to_db_integer_type( f"bigint with {precision} bits precision cannot be mapped into postgres integer type" ) + def to_db_datetime_type( + self, + column: TColumnSchema, + table: TTableSchema = None, + ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is None and precision is None: + return None + + timestamp = "timestamp" + + # append precision if specified and valid + if precision is not None: + if 0 <= precision <= 6: + timestamp += f" ({precision})" + else: + raise TerminalValueError( + f"Postgres does not support precision '{precision}' for '{column_name}' in" + f" table '{table_name}'" + ) + + # append timezone part + if timezone is None or timezone: # timezone True and None + timestamp += " with time zone" + else: # timezone is explicitly False + timestamp += " without time zone" + + return timestamp + def from_db_type( self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None ) -> TColumnType: @@ -233,7 +265,7 @@ def create_load_job( job = PostgresCsvCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in 
self.active_hints.keys() @@ -241,12 +273,12 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: if self.config.replace_strategy == "staging-optimized": return [PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 81abd57803..9bba60af07 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -14,7 +14,7 @@ from dlt.common.destination.reference import ( - FollowupJob, + FollowupJobRequest, CredentialsConfiguration, SupportsStagingDestination, LoadJob, @@ -33,7 +33,7 @@ from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.redshift.configuration import RedshiftClientConfiguration -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -82,9 +82,8 @@ class RedshiftTypeMapper(TypeMapper): "integer": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" if precision <= 16: @@ -238,10 +237,12 @@ def __init__( self.config: RedshiftClientConfiguration = config self.type_mapper = RedshiftTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJobRequest]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() @@ -249,7 +250,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def create_load_job( @@ -258,7 +259,7 @@ def create_load_job( """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" job = super().create_load_job(table, file_path, load_id, restore) if not job: - assert ReferenceFollowupJob.is_reference_job( + assert ReferenceFollowupJobRequest.is_reference_job( file_path ), "Redshift must use staging to load files" job = RedshiftCopyFileLoadJob( @@ 
-272,3 +273,6 @@ def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: return self.type_mapper.from_db_type(pq_t, precision, scale) + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 904b524791..247b3233d0 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -18,7 +18,7 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat - +from dlt.common.exceptions import TerminalValueError from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.typing import TLoaderFileFormat @@ -29,7 +29,7 @@ from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -77,6 +77,36 @@ def from_db_type( return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(db_type, precision, scale) + def to_db_datetime_type( + self, + column: TColumnSchema, + table: TTableSchema = None, + ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is None and precision is None: + return None + + timestamp = "TIMESTAMP_TZ" + + if timezone is not None and not timezone: # explicitly handles timezone False + timestamp = "TIMESTAMP_NTZ" + + # append precision if specified and valid + if precision is not None: + if 0 <= precision <= 9: + timestamp += f"({precision})" + else: + raise TerminalValueError( + f"Snowflake does not support precision '{precision}' for '{column_name}' in" + f" table '{table_name}'" + ) + + return timestamp + class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( @@ -98,11 +128,11 @@ def run(self) -> None: self._sql_client = self._job_client.sql_client # resolve reference - is_local_file = not ReferenceFollowupJob.is_reference_job(self._file_path) + is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path) file_url = ( self._file_path if is_local_file - else ReferenceFollowupJob.resolve_reference(self._file_path) + else ReferenceFollowupJobRequest.resolve_reference(self._file_path) ) # take file name file_name = FileStorage.get_file_name_from_file_path(file_url) @@ -289,12 +319,11 @@ def create_load_job( return job def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because snowflake requires multiple columns in a single ADD COLUMN clause return [ - "ADD COLUMN\n" - + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) + "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) ] def _get_table_update_sql( @@ -320,8 +349,11 @@ def _from_db_type( ) -> TColumnType:
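# Hedged sketch of how the new column-level flags reach to_db_datetime_type above: a timestamp
# column declared with precision=3 and timezone=False is expected to map to TIMESTAMP_NTZ(3)
# on Snowflake and "timestamp (3) without time zone" on Postgres. The resource below only
# shows where such hints are declared; names and values are illustrative.
import dlt

@dlt.resource(
    columns={"created_at": {"data_type": "timestamp", "precision": 3, "timezone": False}}
)
def events():
    yield {"id": 1, "created_at": "2024-07-30T10:00:00"}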
return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index bb117e48d2..d5a0281bec 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -63,6 +63,10 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_transactions = True caps.supports_ddl_transactions = False + caps.supports_create_table_if_not_exists = ( + False # IF NOT EXISTS on CREATE TABLE not supported + ) + # Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries." # if number of records exceeds a certain number. Which exact number that is seems not deterministic: # in tests, I've seen a query with 12230 records run succesfully on one run, but fail on a subsequent run, while the query remained exactly the same. diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index d1b38f73bd..750a4895f0 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import SupportsStagingDestination, FollowupJob, LoadJob +from dlt.common.destination.reference import SupportsStagingDestination, FollowupJobRequest, LoadJob from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint from dlt.common.schema.utils import ( @@ -19,7 +19,7 @@ AzureServicePrincipalCredentialsWithoutDefaults, ) -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.job_client_impl import ( SqlJobClientBase, @@ -131,7 +131,7 @@ def _get_columstore_valid_column(self, c: TColumnSchema) -> TColumnSchema: def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: return SqlJobClientBase._create_replace_followup_jobs(self, table_chain) def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: @@ -163,7 +163,7 @@ def create_load_job( ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) if not job: - assert ReferenceFollowupJob.is_reference_job( + assert ReferenceFollowupJobRequest.is_reference_job( file_path ), "Synapse must use staging to load files" job = SynapseCopyFileLoadJob( @@ -173,6 +173,9 @@ def create_load_job( ) return job + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load + class SynapseCopyFileLoadJob(CopyRemoteFileLoadJob): def __init__( diff --git 
a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 7fdd979c5d..3026baf753 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -42,7 +42,7 @@ WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, - FollowupJob, + FollowupJobRequest, WithStagingDataset, RunnableLoadJob, LoadJob, @@ -53,7 +53,7 @@ from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations.job_impl import ( - ReferenceFollowupJob, + ReferenceFollowupJobRequest, ) from dlt.destinations.sql_jobs import SqlMergeFollowupJob, SqlStagingCopyFollowupJob from dlt.destinations.typing import TNativeConn @@ -118,7 +118,7 @@ def __init__( super().__init__(file_path) self._job_client: "SqlJobClientBase" = None self._staging_credentials = staging_credentials - self._bucket_path = ReferenceFollowupJob.resolve_reference(file_path) + self._bucket_path = ReferenceFollowupJobRequest.resolve_reference(file_path) class SqlJobClientBase(JobClientBase, WithStateSync): @@ -216,16 +216,18 @@ def should_truncate_table_before_load(self, table: TTableSchema) -> bool: def _create_append_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: return [] - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJobRequest]: return [SqlMergeFollowupJob.from_table_chain(table_chain, self.sql_client)] def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[FollowupJob]: - jobs: List[FollowupJob] = [] + ) -> List[FollowupJobRequest]: + jobs: List[FollowupJobRequest] = [] if self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]: jobs.append( SqlStagingCopyFollowupJob.from_table_chain( @@ -238,7 +240,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[FollowupJob]: + ) -> List[FollowupJobRequest]: """Creates a list of followup jobs for merge write disposition and staging replace strategies""" jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs @@ -515,28 +517,36 @@ def _build_schema_update_sql( return sql_updates, schema_update def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)""" - return [f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" for c in new_columns] + return [f"ADD COLUMN {self._get_column_def_sql(c, table)}" for c in new_columns] + + def _make_create_table(self, qualified_name: str, table: TTableSchema) -> str: + not_exists_clause = " " + if ( + table["name"] in self.schema.dlt_table_names() + and self.capabilities.supports_create_table_if_not_exists + ): + not_exists_clause = " IF NOT EXISTS " + return f"CREATE TABLE{not_exists_clause}{qualified_name}" def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: # build sql - canonical_name = self.sql_client.make_qualified_table_name(table_name) + qualified_name = self.sql_client.make_qualified_table_name(table_name) table = self.prepare_load_table(table_name) 
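# Standalone sketch of the clause selection in _make_create_table above: IF NOT EXISTS is
# emitted only for dlt system tables and only when the destination capability allows it
# (mssql and synapse turn it off). Plain values replace the client/capability objects here.
def make_create_table(
    qualified_name: str, is_dlt_table: bool, supports_if_not_exists: bool
) -> str:
    not_exists_clause = " "
    if is_dlt_table and supports_if_not_exists:
        not_exists_clause = " IF NOT EXISTS "
    return f"CREATE TABLE{not_exists_clause}{qualified_name}"

assert make_create_table('"ds"."_dlt_loads"', True, True).startswith("CREATE TABLE IF NOT EXISTS")
assert make_create_table('"ds"."events"', False, True).startswith('CREATE TABLE "ds"')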
- table_format = table.get("table_format") sql_result: List[str] = [] if not generate_alter: # build CREATE - sql = f"CREATE TABLE {canonical_name} (\n" - sql += ",\n".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + sql = self._make_create_table(qualified_name, table) + " (\n" + sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns]) sql += ")" sql_result.append(sql) else: - sql_base = f"ALTER TABLE {canonical_name}\n" - add_column_statements = self._make_add_column_sql(new_columns, table_format) + sql_base = f"ALTER TABLE {qualified_name}\n" + add_column_statements = self._make_add_column_sql(new_columns, table) if self.capabilities.alter_add_multi_column: column_sql = ",\n" sql_result.append(sql_base + column_sql.join(add_column_statements)) @@ -559,19 +569,19 @@ def _get_table_update_sql( if hint == "not_null": logger.warning( f"Column(s) {hint_columns} with NOT NULL are being added to existing" - f" table {canonical_name}. If there's data in the table the operation" + f" table {qualified_name}. If there's data in the table the operation" " will fail." ) else: logger.warning( f"Column(s) {hint_columns} with hint {hint} are being added to existing" - f" table {canonical_name}. Several hint types may not be added to" + f" table {qualified_name}. Several hint types may not be added to" " existing tables." ) return sql_result @abstractmethod - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: pass @staticmethod diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index 41c939f482..1f54913064 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import os import tempfile # noqa: 251 -from typing import Dict, Iterable, List +from typing import Dict, Iterable, List, Optional from dlt.common.json import json from dlt.common.destination.reference import ( @@ -9,9 +9,10 @@ TLoadJobState, RunnableLoadJob, JobClientBase, - FollowupJob, + FollowupJobRequest, LoadJob, ) +from dlt.common.metrics import LoadJobMetrics from dlt.common.storages.load_package import commit_load_package_state from dlt.common.schema import Schema, TTableSchema from dlt.common.storages import FileStorage @@ -56,7 +57,7 @@ class FinalizedLoadJobWithFollowupJobs(FinalizedLoadJob, HasFollowupJobs): pass -class FollowupJobImpl(FollowupJob): +class FollowupJobRequestImpl(FollowupJobRequest): """ Class to create a new loadjob, not stateful and not runnable """ @@ -79,7 +80,7 @@ def job_id(self) -> str: return self._parsed_file_name.job_id() -class ReferenceFollowupJob(FollowupJobImpl): +class ReferenceFollowupJobRequest(FollowupJobRequestImpl): def __init__(self, original_file_name: str, remote_paths: List[str]) -> None: file_name = os.path.splitext(original_file_name)[0] + "." 
+ "reference" self._remote_paths = remote_paths @@ -98,7 +99,7 @@ def resolve_references(file_path: str) -> List[str]: @staticmethod def resolve_reference(file_path: str) -> str: - refs = ReferenceFollowupJob.resolve_references(file_path) + refs = ReferenceFollowupJobRequest.resolve_references(file_path) assert len(refs) == 1 return refs[0] diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 5ffbc4510f..7f74d70208 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Sequence, Tuple, cast, TypedDict, Optional, Callable, Union import yaml -from dlt.common.logger import pretty_format_exception +from dlt.common.time import ensure_pendulum_datetime from dlt.common.schema.typing import ( TTableSchema, @@ -21,7 +21,7 @@ from dlt.common.utils import uniq_id from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.destinations.exceptions import MergeDispositionException -from dlt.destinations.job_impl import FollowupJobImpl +from dlt.destinations.job_impl import FollowupJobRequestImpl from dlt.destinations.sql_client import SqlClientBase from dlt.common.destination.exceptions import DestinationTransientException @@ -45,7 +45,7 @@ def __init__(self, original_exception: Exception, table_chain: Sequence[TTableSc ) -class SqlFollowupJob(FollowupJobImpl): +class SqlFollowupJob(FollowupJobRequestImpl): """Sql base job for jobs that rely on the whole tablechain""" @classmethod @@ -54,7 +54,7 @@ def from_table_chain( table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None, - ) -> FollowupJobImpl: + ) -> FollowupJobRequestImpl: """Generates a list of sql statements, that will be executed by the sql client when the job is executed in the loader. The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). 
@@ -720,10 +720,18 @@ def gen_scd2_sql( format_datetime_literal = ( DestinationCapabilitiesContext.generic_capabilities().format_datetime_literal ) - boundary_ts = format_datetime_literal( - current_load_package()["state"]["created_at"], + + boundary_ts = ensure_pendulum_datetime( + root_table.get( # type: ignore[arg-type] + "x-boundary-timestamp", + current_load_package()["state"]["created_at"], + ) + ) + boundary_literal = format_datetime_literal( + boundary_ts, caps.timestamp_precision, ) + active_record_timestamp = get_active_record_timestamp(root_table) if active_record_timestamp is None: active_record_literal = "NULL" @@ -736,7 +744,7 @@ def gen_scd2_sql( # retire updated and deleted records sql.append(f""" - {cls.gen_update_table_prefix(root_table_name)} {to} = {boundary_ts} + {cls.gen_update_table_prefix(root_table_name)} {to} = {boundary_literal} WHERE {is_active_clause} AND {hash_} NOT IN (SELECT {hash_} FROM {staging_root_table_name}); """) @@ -746,7 +754,7 @@ def gen_scd2_sql( col_str = ", ".join([c for c in columns if c not in (from_, to)]) sql.append(f""" INSERT INTO {root_table_name} ({col_str}, {from_}, {to}) - SELECT {col_str}, {boundary_ts} AS {from_}, {active_record_literal} AS {to} + SELECT {col_str}, {boundary_literal} AS {from_}, {active_record_literal} AS {to} FROM {staging_root_table_name} AS s WHERE {hash_} NOT IN (SELECT {hash_} FROM {root_table_name} WHERE {is_active_clause}); """) diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py index dcd938b33c..5ac43e4f1f 100644 --- a/dlt/destinations/type_mapping.py +++ b/dlt/destinations/type_mapping.py @@ -1,6 +1,13 @@ from typing import Tuple, ClassVar, Dict, Optional -from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType, TTableFormat +from dlt.common import logger +from dlt.common.schema.typing import ( + TColumnSchema, + TDataType, + TColumnType, + TTableFormat, + TTableSchema, +) from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.utils import without_none @@ -20,39 +27,54 @@ class TypeMapper: def __init__(self, capabilities: DestinationCapabilitiesContext) -> None: self.capabilities = capabilities - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: # Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.) return self.sct_to_unbound_dbt["bigint"] def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: TTableSchema = None, ) -> str: # Override in subclass if db supports other timestamp types (e.g. with different time resolutions) + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is not None or precision is not None: + message = ( + "Column flags for timezone or precision are not yet supported in this" + " destination. One or both of these flags were used in column" + f" '{column.get('name')}'" + ) + # TODO: refactor lancedb and weaviate to make table object required + if table: + message += f" in table '{table.get('name')}'." + + logger.warning(message) + return None - def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: # Override in subclass if db supports other time types (e.g.
with different time resolutions) return None - def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> str: - precision_tup = self.decimal_precision(precision, scale) + def to_db_decimal_type(self, column: TColumnSchema) -> str: + precision_tup = self.decimal_precision(column.get("precision"), column.get("scale")) if not precision_tup or "decimal" not in self.sct_to_dbt: return self.sct_to_unbound_dbt["decimal"] return self.sct_to_dbt["decimal"] % (precision_tup[0], precision_tup[1]) - def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: - precision, scale = column.get("precision"), column.get("scale") + # TODO: refactor lancedb and wevavite to make table object required + def to_db_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: sc_t = column["data_type"] if sc_t == "bigint": - db_t = self.to_db_integer_type(precision, table_format) + db_t = self.to_db_integer_type(column, table) elif sc_t == "timestamp": - db_t = self.to_db_datetime_type(precision, table_format) + db_t = self.to_db_datetime_type(column, table) elif sc_t == "time": - db_t = self.to_db_time_type(precision, table_format) + db_t = self.to_db_time_type(column, table) elif sc_t == "decimal": - db_t = self.to_db_decimal_type(precision, scale) + db_t = self.to_db_decimal_type(column) else: db_t = None if db_t: @@ -61,14 +83,16 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) - bounded_template = self.sct_to_dbt.get(sc_t) if not bounded_template: return self.sct_to_unbound_dbt[sc_t] - precision_tuple = self.precision_tuple_or_default(sc_t, precision, scale) + precision_tuple = self.precision_tuple_or_default(sc_t, column) if not precision_tuple: return self.sct_to_unbound_dbt[sc_t] return self.sct_to_dbt[sc_t] % precision_tuple def precision_tuple_or_default( - self, data_type: TDataType, precision: Optional[int], scale: Optional[int] + self, data_type: TDataType, column: TColumnSchema ) -> Optional[Tuple[int, ...]]: + precision = column.get("precision") + scale = column.get("scale") if data_type in ("timestamp", "time"): if precision is None: return None # Use default which is usually the max diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index 4a1de2517d..8a91dd7477 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -4,9 +4,9 @@ from dlt.common.configuration import known_sections, resolve_configuration, with_config from dlt.common import logger from dlt.common.configuration.specs import BaseConfiguration, configspec -from dlt.common.data_writers import DataWriterMetrics from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import MissingDependencyException +from dlt.common.metrics import DataWriterMetrics from dlt.common.runtime.collector import Collector, NULL_COLLECTOR from dlt.common.typing import TDataItems, TDataItem, TLoaderFileFormat from dlt.common.schema import Schema, utils diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 123a8455e1..67a6b3e83a 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -26,6 +26,7 @@ new_table, ) from dlt.common.typing import TDataItem +from dlt.common.time import ensure_pendulum_datetime from dlt.common.utils import clone_dict_nested from dlt.common.normalizers.json.relational import DataItemNormalizer from dlt.common.validation import validate_dict_ignoring_xkeys @@ -444,6 +445,8 @@ def _merge_merge_disposition_dict(dict_: Dict[str, Any]) -> None: mddict: 
TMergeDispositionDict = deepcopy(dict_["write_disposition"]) if mddict is not None: dict_["x-merge-strategy"] = mddict.get("strategy", DEFAULT_MERGE_STRATEGY) + if "boundary_timestamp" in mddict: + dict_["x-boundary-timestamp"] = mddict["boundary_timestamp"] # add columns for `scd2` merge strategy if dict_.get("x-merge-strategy") == "scd2": if mddict.get("validity_column_names") is None: @@ -512,3 +515,14 @@ def validate_write_disposition_hint(wd: TTableHintTemplate[TWriteDispositionConf f'`{wd["strategy"]}` is not a valid merge strategy. ' f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""" ) + + for ts in ("active_record_timestamp", "boundary_timestamp"): + if ts == "active_record_timestamp" and wd.get("active_record_timestamp") is None: + continue # None is allowed for active_record_timestamp + if ts in wd: + try: + ensure_pendulum_datetime(wd[ts]) # type: ignore[literal-required] + except Exception: + raise ValueError( + f'could not parse `{ts}` value "{wd[ts]}"' # type: ignore[literal-required] + ) diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index c1117370b5..343a737c07 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -35,7 +35,12 @@ IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, ) -from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc +from dlt.extract.incremental.typing import ( + IncrementalColumnState, + TCursorValue, + LastValueFunc, + OnCursorValueMissing, +) from dlt.extract.pipe import Pipe from dlt.extract.items import SupportsPipe, TTableHintTemplate, ItemTransform from dlt.extract.incremental.transform import ( @@ -81,7 +86,7 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa >>> info = p.run(r, destination="duckdb") Args: - cursor_path: The name or a JSON path to an cursor field. Uses the same names of fields as in your JSON document, before they are normalized to store in the database. + cursor_path: The name or a JSON path to a cursor field. Uses the same names of fields as in your JSON document, before they are normalized to store in the database. initial_value: Optional value used for `last_value` when no state is available, e.g. on the first run of the pipeline. If not provided `last_value` will be `None` on the first run. last_value_func: Callable used to determine which cursor value to save in state. It is called with a list of the stored state value and all cursor vals from currently processing items. Default is `max` primary_key: Optional primary key used to deduplicate data. If not provided, a primary key defined by the resource will be used. Pass a tuple to define a compound key. Pass empty tuple to disable unique checks @@ -95,6 +100,7 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa specified range of data. Currently Airflow scheduler is detected: "data_interval_start" and "data_interval_end" are taken from the context and passed Incremental class. The values passed explicitly to Incremental will be ignored. 
Note that if logical "end date" is present then also "end_value" will be set which means that resource state is not used and exactly this range of date will be loaded + on_cursor_value_missing: Specify what happens when the cursor_path does not exist in a record or a record has `None` at the cursor_path: raise, include, exclude """ # this is config/dataclass so declare members @@ -104,6 +110,7 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa end_value: Optional[Any] = None row_order: Optional[TSortOrder] = None allow_external_schedulers: bool = False + on_cursor_value_missing: OnCursorValueMissing = "raise" # incremental acting as empty EMPTY: ClassVar["Incremental[Any]"] = None @@ -118,6 +125,7 @@ def __init__( end_value: Optional[TCursorValue] = None, row_order: Optional[TSortOrder] = None, allow_external_schedulers: bool = False, + on_cursor_value_missing: OnCursorValueMissing = "raise", ) -> None: # make sure that path is valid if cursor_path: @@ -133,6 +141,11 @@ def __init__( self._primary_key: Optional[TTableHintTemplate[TColumnNames]] = primary_key self.row_order = row_order self.allow_external_schedulers = allow_external_schedulers + if on_cursor_value_missing not in ["raise", "include", "exclude"]: + raise ValueError( + f"Unexpected argument for on_cursor_value_missing. Got {on_cursor_value_missing}" + ) + self.on_cursor_value_missing = on_cursor_value_missing self._cached_state: IncrementalColumnState = None """State dictionary cached on first access""" @@ -171,6 +184,7 @@ def _make_transforms(self) -> None: self.last_value_func, self._primary_key, set(self._cached_state["unique_hashes"]), + self.on_cursor_value_missing, ) @classmethod diff --git a/dlt/extract/incremental/exceptions.py b/dlt/extract/incremental/exceptions.py index e318a028dc..973d3b6585 100644 --- a/dlt/extract/incremental/exceptions.py +++ b/dlt/extract/incremental/exceptions.py @@ -1,14 +1,55 @@ +from typing import Any + from dlt.extract.exceptions import PipeException from dlt.common.typing import TDataItem class IncrementalCursorPathMissing(PipeException): - def __init__(self, pipe_name: str, json_path: str, item: TDataItem, msg: str = None) -> None: + def __init__( + self, pipe_name: str, json_path: str, item: TDataItem = None, msg: str = None + ) -> None: self.json_path = json_path self.item = item msg = ( msg - or f"Cursor element with JSON path {json_path} was not found in extracted data item. All data items must contain this path. Use the same names of fields as in your JSON document - if those are different from the names you see in database." + or f"Cursor element with JSON path `{json_path}` was not found in extracted data item. All data items must contain this path. Use the same names of fields as in your JSON document because they can be different from the names you see in database." + ) + super().__init__(pipe_name, msg) + + +class IncrementalCursorPathHasValueNone(PipeException): + def __init__( + self, pipe_name: str, json_path: str, item: TDataItem = None, msg: str = None + ) -> None: + self.json_path = json_path + self.item = item + msg = ( + msg + or f"Cursor element with JSON path `{json_path}` has the value `None` in extracted data item. All data items must contain a value != None. 
Construct the incremental with on_cursor_value_missing='include' if you want to include such rows" + ) + super().__init__(pipe_name, msg) + + +class IncrementalCursorInvalidCoercion(PipeException): + def __init__( + self, + pipe_name: str, + cursor_path: str, + cursor_value: TDataItem, + cursor_value_type: str, + item: TDataItem, + item_type: Any, + details: str, + ) -> None: + self.cursor_path = cursor_path + self.cursor_value = cursor_value + self.cursor_value_type = cursor_value_type + self.item = item + msg = ( + f"Could not coerce {cursor_value_type} with value {cursor_value} and type" + f" {type(cursor_value)} to actual data item {item} at path {cursor_path} with type" + f" {item_type}: {details}. You need to use a different data type for" + f" {cursor_value_type} or cast your data, i.e. by using `add_map` on this resource." ) super().__init__(pipe_name, msg) diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index 947e21f7b8..eb448d4266 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -1,5 +1,5 @@ -from datetime import datetime, date # noqa: I251 -from typing import Any, Optional, Set, Tuple, List +from datetime import datetime # noqa: I251 +from typing import Any, Optional, Set, Tuple, List, Type from dlt.common.exceptions import MissingDependencyException from dlt.common.utils import digest128 @@ -8,10 +8,12 @@ from dlt.common.typing import TDataItem from dlt.common.jsonpath import find_values, JSONPathFields, compile_path from dlt.extract.incremental.exceptions import ( + IncrementalCursorInvalidCoercion, IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, + IncrementalCursorPathHasValueNone, ) -from dlt.extract.incremental.typing import TCursorValue, LastValueFunc +from dlt.extract.incremental.typing import TCursorValue, LastValueFunc, OnCursorValueMissing from dlt.extract.utils import resolve_column_value from dlt.extract.items import TTableHintTemplate from dlt.common.schema.typing import TColumnNames @@ -54,6 +56,7 @@ def __init__( last_value_func: LastValueFunc[TCursorValue], primary_key: Optional[TTableHintTemplate[TColumnNames]], unique_hashes: Set[str], + on_cursor_value_missing: OnCursorValueMissing = "raise", ) -> None: self.resource_name = resource_name self.cursor_path = cursor_path @@ -66,6 +69,7 @@ def __init__( self.primary_key = primary_key self.unique_hashes = unique_hashes self.start_unique_hashes = set(unique_hashes) + self.on_cursor_value_missing = on_cursor_value_missing # compile jsonpath self._compiled_cursor_path = compile_path(cursor_path) @@ -115,21 +119,39 @@ class JsonIncremental(IncrementalTransform): def find_cursor_value(self, row: TDataItem) -> Any: """Finds value in row at cursor defined by self.cursor_path. - Will use compiled JSONPath if present, otherwise it reverts to column search if row is dict + Will use compiled JSONPath if present. + Otherwise, reverts to field access if row is dict, Pydantic model, or another class. """ - row_value: Any = None + key_exc: Type[Exception] = IncrementalCursorPathHasValueNone if self._compiled_cursor_path: - row_values = find_values(self._compiled_cursor_path, row) - if row_values: - row_value = row_values[0] + # ignores the other found values, e.g.
when the path is $data.items[*].created_at + try: + row_value = find_values(self._compiled_cursor_path, row)[0] + except IndexError: + # empty list so raise a proper exception + row_value = None + key_exc = IncrementalCursorPathMissing else: try: - row_value = row[self.cursor_path] - except Exception: - pass - if row_value is None: - raise IncrementalCursorPathMissing(self.resource_name, self.cursor_path, row) - return row_value + try: + row_value = row[self.cursor_path] + except TypeError: + # supports Pydantic models and other classes + row_value = getattr(row, self.cursor_path) + except (KeyError, AttributeError): + # attr not found so raise a proper exception + row_value = None + key_exc = IncrementalCursorPathMissing + + # if we have a value - return it + if row_value is not None: + return row_value + + if self.on_cursor_value_missing == "raise": + # raise missing path or None value exception + raise key_exc(self.resource_name, self.cursor_path, row) + elif self.on_cursor_value_missing == "exclude": + return None def __call__( self, @@ -143,6 +165,12 @@ def __call__( return row, False, False row_value = self.find_cursor_value(row) + if row_value is None: + if self.on_cursor_value_missing == "exclude": + return None, False, False + else: + return row, False, False + last_value = self.last_value last_value_func = self.last_value_func @@ -158,14 +186,36 @@ def __call__( # Check whether end_value has been reached # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value - if self.end_value is not None and ( - last_value_func((row_value, self.end_value)) != self.end_value - or last_value_func((row_value,)) == self.end_value - ): - return None, False, True - + if self.end_value is not None: + try: + if ( + last_value_func((row_value, self.end_value)) != self.end_value + or last_value_func((row_value,)) == self.end_value + ): + return None, False, True + except Exception as ex: + raise IncrementalCursorInvalidCoercion( + self.resource_name, + self.cursor_path, + self.end_value, + "end_value", + row_value, + type(row_value).__name__, + str(ex), + ) from ex check_values = (row_value,) + ((last_value,) if last_value is not None else ()) - new_value = last_value_func(check_values) + try: + new_value = last_value_func(check_values) + except Exception as ex: + raise IncrementalCursorInvalidCoercion( + self.resource_name, + self.cursor_path, + last_value, + "start_value/initial_value", + row_value, + type(row_value).__name__, + str(ex), + ) from ex # new_value is "less" or equal to last_value (the actual max) if last_value == new_value: # use func to compute row_value into last_value compatible @@ -276,6 +326,7 @@ def __call__( # TODO: Json path support. For now assume the cursor_path is a column name cursor_path = self.cursor_path + # The new max/min value try: # NOTE: datetimes are always pendulum in UTC @@ -287,21 +338,48 @@ def __call__( self.resource_name, cursor_path, tbl, - f"Column name {cursor_path} was not found in the arrow table. Not nested JSON paths" + f"Column name `{cursor_path}` was not found in the arrow table. 
Nested JSON paths" " are not supported for arrow tables and dataframes, the incremental cursor_path" " must be a column name.", ) from e + if tbl.schema.field(cursor_path).nullable: + tbl_without_null, tbl_with_null = self._process_null_at_cursor_path(tbl) + + tbl = tbl_without_null + # If end_value is provided, filter to include table rows that are "less" than end_value if self.end_value is not None: - end_value_scalar = to_arrow_scalar(self.end_value, cursor_data_type) + try: + end_value_scalar = to_arrow_scalar(self.end_value, cursor_data_type) + except Exception as ex: + raise IncrementalCursorInvalidCoercion( + self.resource_name, + cursor_path, + self.end_value, + "end_value", + "", + cursor_data_type, + str(ex), + ) from ex tbl = tbl.filter(end_compare(tbl[cursor_path], end_value_scalar)) # Is max row value higher than end value? # NOTE: pyarrow bool *always* evaluates to python True. `as_py()` is necessary end_out_of_range = not end_compare(row_value_scalar, end_value_scalar).as_py() if self.start_value is not None: - start_value_scalar = to_arrow_scalar(self.start_value, cursor_data_type) + try: + start_value_scalar = to_arrow_scalar(self.start_value, cursor_data_type) + except Exception as ex: + raise IncrementalCursorInvalidCoercion( + self.resource_name, + cursor_path, + self.start_value, + "start_value/initial_value", + "", + cursor_data_type, + str(ex), + ) from ex # Remove rows lower or equal than the last start value keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar) start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py()) @@ -351,12 +429,28 @@ def __call__( ) ) + # drop the temp unique index before concat and returning + if "_dlt_index" in tbl.schema.names: + tbl = pyarrow.remove_columns(tbl, ["_dlt_index"]) + + if self.on_cursor_value_missing == "include": + if isinstance(tbl, pa.RecordBatch): + assert isinstance(tbl_with_null, pa.RecordBatch) + tbl = pa.Table.from_batches([tbl, tbl_with_null]) + else: + tbl = pa.concat_tables([tbl, tbl_with_null]) + if len(tbl) == 0: return None, start_out_of_range, end_out_of_range - try: - tbl = pyarrow.remove_columns(tbl, ["_dlt_index"]) - except KeyError: - pass if is_pandas: - return tbl.to_pandas(), start_out_of_range, end_out_of_range + tbl = tbl.to_pandas() return tbl, start_out_of_range, end_out_of_range + + def _process_null_at_cursor_path(self, tbl: "pa.Table") -> Tuple["pa.Table", "pa.Table"]: + mask = pa.compute.is_valid(tbl[self.cursor_path]) + rows_without_null = tbl.filter(mask) + rows_with_null = tbl.filter(pa.compute.invert(mask)) + if self.on_cursor_value_missing == "raise": + if rows_with_null.num_rows > 0: + raise IncrementalCursorPathHasValueNone(self.resource_name, self.cursor_path) + return rows_without_null, rows_with_null diff --git a/dlt/extract/incremental/typing.py b/dlt/extract/incremental/typing.py index 9cec97d34d..a5e2612db4 100644 --- a/dlt/extract/incremental/typing.py +++ b/dlt/extract/incremental/typing.py @@ -1,8 +1,9 @@ -from typing import TypedDict, Optional, Any, List, TypeVar, Callable, Sequence +from typing import TypedDict, Optional, Any, List, Literal, TypeVar, Callable, Sequence TCursorValue = TypeVar("TCursorValue", bound=Any) LastValueFunc = Callable[[Sequence[TCursorValue]], Any] +OnCursorValueMissing = Literal["raise", "include", "exclude"] class IncrementalColumnState(TypedDict): diff --git a/dlt/extract/storage.py b/dlt/extract/storage.py index de777ad60e..395366b09e 100644 --- a/dlt/extract/storage.py +++ b/dlt/extract/storage.py @@ -1,7 
+1,8 @@ import os from typing import Dict, List -from dlt.common.data_writers import TDataItemFormat, DataWriterMetrics, DataWriter, FileWriterSpec +from dlt.common.data_writers import TDataItemFormat, DataWriter, FileWriterSpec +from dlt.common.metrics import DataWriterMetrics from dlt.common.schema import Schema from dlt.common.storages import ( NormalizeStorageConfiguration, diff --git a/dlt/load/load.py b/dlt/load/load.py index 99a12d69ee..f084c9d3d9 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -5,12 +5,17 @@ import os from dlt.common import logger +from dlt.common.metrics import LoadJobMetrics from dlt.common.runtime.signals import sleep from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo from dlt.common.schema.utils import get_top_level_table -from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState +from dlt.common.storages.load_storage import ( + LoadPackageInfo, + ParsedLoadJobFileName, + TPackageJobState, +) from dlt.common.storages.load_package import ( LoadPackageStateInjectableContext, load_package as current_load_package, @@ -29,7 +34,7 @@ Destination, RunnableLoadJob, LoadJob, - FollowupJob, + FollowupJobRequest, TLoadJobState, DestinationClientConfiguration, SupportsStagingDestination, @@ -84,6 +89,7 @@ def __init__( self.pool = NullExecutor() self.load_storage: LoadStorage = self.create_storage(is_storage_owner) self._loaded_packages: List[LoadPackageInfo] = [] + self._job_metrics: Dict[str, LoadJobMetrics] = {} self._run_loop_sleep_duration: float = ( 1.0 # amount of time to sleep between querying completed jobs ) @@ -308,7 +314,7 @@ def create_followup_jobs( where they will be picked up for execution """ - jobs: List[FollowupJob] = [] + jobs: List[FollowupJobRequest] = [] if isinstance(starting_job, HasFollowupJobs): # check for merge jobs only for jobs executing on the destination, the staging destination jobs must be excluded # NOTE: we may move that logic to the interface @@ -392,6 +398,11 @@ def complete_jobs( # create followup jobs self.create_followup_jobs(load_id, state, job, schema) + # preserve metrics + metrics = job.metrics() + if metrics: + self._job_metrics[job.job_id()] = metrics + # try to get exception message from job failed_message = job.exception() self.load_storage.normalized_packages.fail_job( @@ -423,7 +434,7 @@ def complete_jobs( if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: pending_exception = LoadClientJobRetry( load_id, - job.job_file_info().job_id(), + job.job_id(), r_c, self.config.raise_on_max_retries, retry_message=retry_message, @@ -431,6 +442,15 @@ def complete_jobs( elif state == "completed": # create followup jobs self.create_followup_jobs(load_id, state, job, schema) + + # preserve metrics + # TODO: metrics should be persisted. this is different vs. 
all other steps because load step + # may be restarted in the middle of execution + # NOTE: we could use package state but cases with 100k jobs must be tested + metrics = job.metrics() + if metrics: + self._job_metrics[job.job_id()] = metrics + # move to completed folder after followup jobs are created # in case of exception when creating followup job, the loader will retry operation and try to complete again self.load_storage.normalized_packages.complete_job(load_id, job.file_name()) @@ -464,14 +484,18 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) self.load_storage.complete_load_package(load_id, aborted) # collect package info self._loaded_packages.append(self.load_storage.get_load_package_info(load_id)) - self._step_info_complete_load_id(load_id, metrics={"started_at": None, "finished_at": None}) + # TODO: job metrics must be persisted + self._step_info_complete_load_id( + load_id, + metrics={"started_at": None, "finished_at": None, "job_metrics": self._job_metrics}, + ) # delete jobs only now self.load_storage.maybe_remove_completed_jobs(load_id) logger.info( f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" ) - def update_load_package_info(self, load_id: str) -> None: + def init_jobs_counter(self, load_id: str) -> None: # update counter we only care about the jobs that are scheduled to be loaded package_jobs = self.load_storage.normalized_packages.get_load_package_jobs(load_id) total_jobs = reduce(lambda p, c: p + len(c), package_jobs.values(), 0) @@ -492,7 +516,7 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: dropped_tables = current_load_package()["state"].get("dropped_tables", []) truncated_tables = current_load_package()["state"].get("truncated_tables", []) - self.update_load_package_info(load_id) + self.init_jobs_counter(load_id) # initialize analytical storage ie. 
create dataset required by passed schema with self.get_destination_client(schema) as job_client: @@ -606,7 +630,8 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: ) ): # the same load id may be processed across multiple runs - if not self.current_load_id: + if self.current_load_id is None: + self._job_metrics = {} self._step_info_start_load_id(load_id) self.load_single_package(load_id, schema) diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 9750f89d4b..e3a2ebcd79 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -2,7 +2,7 @@ from itertools import groupby from dlt.common import logger -from dlt.common.storages.load_package import LoadJobInfo, PackageStorage, TJobState +from dlt.common.storages.load_package import LoadJobInfo, PackageStorage, TPackageJobState from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, get_child_tables, @@ -19,7 +19,7 @@ def get_completed_table_chain( schema: Schema, - all_jobs: Iterable[Tuple[TJobState, ParsedLoadJobFileName]], + all_jobs: Iterable[Tuple[TPackageJobState, ParsedLoadJobFileName]], top_merged_table: TTableSchema, being_completed_job_id: str = None, ) -> List[TTableSchema]: @@ -179,9 +179,10 @@ def _init_dataset_and_update_schema( applied_update = job_client.update_stored_schema( only_tables=update_tables, expected_update=expected_update ) - logger.info( - f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" - ) + if truncate_tables: + logger.info( + f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" + ) job_client.initialize_storage(truncate_tables=truncate_tables) return applied_update diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index 5f84d57d7a..650d10c268 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -3,9 +3,9 @@ from dlt.common import logger from dlt.common.json import json -from dlt.common.data_writers import DataWriterMetrics from dlt.common.data_writers.writers import ArrowToObjectAdapter from dlt.common.json import custom_pua_decode, may_have_pua +from dlt.common.metrics import DataWriterMetrics from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer from dlt.common.runtime import signals from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index e80931605c..3df060b141 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -4,10 +4,10 @@ from concurrent.futures import Future, Executor from dlt.common import logger +from dlt.common.metrics import DataWriterMetrics from dlt.common.runtime.signals import sleep from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config -from dlt.common.data_writers import DataWriterMetrics from dlt.common.data_writers.writers import EMPTY_DATA_WRITER_METRICS from dlt.common.runners import TRunMetrics, Runnable, NullExecutor from dlt.common.runtime import signals diff --git a/dlt/normalize/worker.py b/dlt/normalize/worker.py index 10d0a00eb1..b8969f64a3 100644 --- a/dlt/normalize/worker.py +++ b/dlt/normalize/worker.py @@ -4,12 +4,12 @@ from dlt.common.configuration.container import Container from dlt.common.data_writers import ( DataWriter, - DataWriterMetrics, create_import_spec, resolve_best_writer_spec, get_best_writer_spec, is_native_writer, ) +from dlt.common.metrics 
import DataWriterMetrics from dlt.common.utils import chunks from dlt.common.schema.typing import TStoredSchema, TTableSchema from dlt.common.storages import ( diff --git a/dlt/pipeline/trace.py b/dlt/pipeline/trace.py index 29770966a6..2f857e5fd5 100644 --- a/dlt/pipeline/trace.py +++ b/dlt/pipeline/trace.py @@ -168,7 +168,7 @@ def asdict(self) -> DictStrAny: """A dictionary representation of PipelineTrace that can be loaded with `dlt`""" d = self._asdict() # run step is the same as load step - d["steps"] = [step.asdict() for step in self.steps] # if step.step != "run" + d["steps"] = [step.asdict() for step in self.steps if step.step != "run"] return d @property diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index 632c93d0c7..872d4f34e8 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -123,7 +123,8 @@ def __init__( super().__init__() if total_path is None and maximum_value is None and not stop_after_empty_page: raise ValueError( - "Either `total_path` or `maximum_value` or `stop_after_empty_page` must be provided." + "Either `total_path` or `maximum_value` or `stop_after_empty_page` must be" + " provided." ) self.param_name = param_name self.current_value = initial_value @@ -164,7 +165,7 @@ def update_state(self, response: Response, data: Optional[List[Any]] = None) -> ): self._has_next_page = False - def _stop_after_this_page(self, data: Optional[List[Any]]=None) -> bool: + def _stop_after_this_page(self, data: Optional[List[Any]] = None) -> bool: return self.stop_after_empty_page and not data def _handle_missing_total(self, response_json: Dict[str, Any]) -> None: @@ -371,7 +372,8 @@ def __init__( """ if total_path is None and maximum_offset is None and not stop_after_empty_page: raise ValueError( - "Either `total_path` or `maximum_offset` or `stop_after_empty_page` must be provided." + "Either `total_path` or `maximum_offset` or `stop_after_empty_page` must be" + " provided." ) super().__init__( param_name=offset_param, diff --git a/docs/examples/parent_child_relationship/__init__.py b/docs/examples/parent_child_relationship/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/parent_child_relationship/parent_child_relationship.py b/docs/examples/parent_child_relationship/parent_child_relationship.py new file mode 100644 index 0000000000..6de00ffb28 --- /dev/null +++ b/docs/examples/parent_child_relationship/parent_child_relationship.py @@ -0,0 +1,68 @@ +""" +--- +title: Load parent table records into child table +description: Learn how to integrate custom parent keys into child records +keywords: [parent child relationship, parent key] +--- + +This example demonstrates handling data with parent-child relationships using the `dlt` library. +You learn how to integrate specific fields (e.g., primary, foreign keys) from a parent record into each child record. + +In this example, we'll explore how to: + +- Add `parent_id` into each child record using `add_parent_id` function +- Use the [`add_map` function](https://dlthub.com/docs/api_reference/extract/resource#add_map) to apply this +custom logic to every record in the dataset + +:::note important +Please note that dlt metadata, including `_dlt_id` and `_dlt_load_id`, will still be loaded into the tables. 
+::: +""" + +from typing import List, Dict, Any, Generator +import dlt + + +# Define a dlt resource with write disposition to 'merge' +@dlt.resource(name="parent_with_children", write_disposition={"disposition": "merge"}) +def data_source() -> Generator[List[Dict[str, Any]], None, None]: + # Example data + data = [ + { + "parent_id": 1, + "parent_name": "Alice", + "children": [ + {"child_id": 1, "child_name": "Child 1"}, + {"child_id": 2, "child_name": "Child 2"}, + ], + }, + { + "parent_id": 2, + "parent_name": "Bob", + "children": [{"child_id": 3, "child_name": "Child 3"}], + }, + ] + + yield data + + +# Function to add parent_id to each child record within a parent record +def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]: + parent_id_key = "parent_id" + for child in record["children"]: + child[parent_id_key] = record[parent_id_key] + return record + + +if __name__ == "__main__": + # Create and configure the dlt pipeline + pipeline = dlt.pipeline( + pipeline_name="generic_pipeline", + destination="duckdb", + dataset_name="dataset", + ) + + # Run the pipeline + load_info = pipeline.run(data_source().add_map(add_parent_id), primary_key="parent_id") + # Output the load information after pipeline execution + print(load_info) diff --git a/docs/examples/parent_child_relationship/test_parent_child_relationship.py b/docs/examples/parent_child_relationship/test_parent_child_relationship.py new file mode 100644 index 0000000000..95d1bade97 --- /dev/null +++ b/docs/examples/parent_child_relationship/test_parent_child_relationship.py @@ -0,0 +1,76 @@ +import pytest + +from tests.utils import skipifgithubfork + + +""" +--- +title: Load parent table records into child table +description: Learn how to integrate custom parent keys into child records +keywords: [parent child relationship, parent key] +--- + +This example demonstrates handling data with parent-child relationships using +the `dlt` library. You learn how to integrate specific fields (e.g., primary, +foreign keys) from a parent record into each child record. + +In this example, we'll explore how to: + +- Add `parent_id` into each child record using `add_parent_id` function +- Use the [`add_map` function](https://dlthub.com/docs/api_reference/extract/resource#add_map) to apply this +custom logic to every record in the dataset + +:::note important +Please note that dlt metadata, including `_dlt_id` and `_dlt_load_id`, will still be loaded into the tables. 
+::: +""" + +from typing import List, Dict, Any, Generator +import dlt + + +# Define a dlt resource with write disposition to 'merge' +@dlt.resource(name="parent_with_children", write_disposition={"disposition": "merge"}) +def data_source() -> Generator[List[Dict[str, Any]], None, None]: + # Example data + data = [ + { + "parent_id": 1, + "parent_name": "Alice", + "children": [ + {"child_id": 1, "child_name": "Child 1"}, + {"child_id": 2, "child_name": "Child 2"}, + ], + }, + { + "parent_id": 2, + "parent_name": "Bob", + "children": [{"child_id": 3, "child_name": "Child 3"}], + }, + ] + + yield data + + +# Function to add parent_id to each child record within a parent record +def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]: + parent_id_key = "parent_id" + for child in record["children"]: + child[parent_id_key] = record[parent_id_key] + return record + + +@skipifgithubfork +@pytest.mark.forked +def test_parent_child_relationship(): + # Create and configure the dlt pipeline + pipeline = dlt.pipeline( + pipeline_name="generic_pipeline", + destination="duckdb", + dataset_name="dataset", + ) + + # Run the pipeline + load_info = pipeline.run(data_source().add_map(add_parent_id), primary_key="parent_id") + # Output the load information after pipeline execution + print(load_info) diff --git a/docs/examples/postgres_to_postgres/postgres_to_postgres.py b/docs/examples/postgres_to_postgres/postgres_to_postgres.py index c6502f236a..3e88cb7ee8 100644 --- a/docs/examples/postgres_to_postgres/postgres_to_postgres.py +++ b/docs/examples/postgres_to_postgres/postgres_to_postgres.py @@ -33,7 +33,7 @@ Install `dlt` with `duckdb` as extra, also `connectorx`, Postgres adapter and progress bar tool: ```sh -pip install dlt[duckdb] connectorx pyarrow psycopg2-binary alive-progress +pip install "dlt[duckdb]" connectorx pyarrow psycopg2-binary alive-progress ``` Run the example: diff --git a/docs/technical/general_usage.md b/docs/technical/general_usage.md index 336c892c66..2df903b062 100644 --- a/docs/technical/general_usage.md +++ b/docs/technical/general_usage.md @@ -47,7 +47,7 @@ Pipeline can be explicitly created and configured via `dlt.pipeline()` that retu 4. dataset_name - name of the dataset where the data goes (see later the default names) 5. import_schema_path - default is None 6. export_schema_path - default is None -7. full_refresh - if set to True the pipeline working dir will be erased and the dataset name will get the unique suffix (current timestamp). ie the `my_data` becomes `my_data_20221107164856`. +7. dev_mode - if set to True the pipeline working dir will be erased and the dataset name will get the unique suffix (current timestamp). ie the `my_data` becomes `my_data_20221107164856`. > **Achtung** as per `secrets_and_config.md` the arguments passed to `dlt.pipeline` are configurable and if skipped will be injected by the config providers. **the values provided explicitly in the code have a full precedence over all config providers** @@ -101,7 +101,7 @@ In case **there are more schemas in the pipeline**, the data will be loaded into 1. `spotify` tables and `labels` will load into `spotify_data_1` 2. `mel` resource will load into `spotify_data_1_echonest` -The `full_refresh` option: dataset name receives a prefix with the current timestamp: ie the `my_data` becomes `my_data_20221107164856`. This allows a non destructive full refresh. Nothing is being deleted/dropped from the destination. 
+The `dev_mode` option: the dataset name receives a suffix with the current timestamp, i.e. `my_data` becomes `my_data_20221107164856`. This allows a non-destructive full refresh. Nothing is deleted/dropped from the destination. ## pipeline working directory and state Another fundamental concept is the pipeline working directory. This directory keeps the following information: @@ -117,7 +117,7 @@ The `restore_from_destination` argument to `dlt.pipeline` let's the user restore The state is being stored in the destination together with other data. So only when all pipeline stages are completed the state is available for restoration. -The pipeline cannot be restored if `full_refresh` flag is set. +The pipeline cannot be restored if `dev_mode` flag is set. The other way to trigger full refresh is to drop destination dataset. `dlt` detects that and resets the pipeline local working folder. @@ -155,8 +155,8 @@ The default json normalizer will convert json documents into tables. All the key ❗ [more here](working_with_schemas.md) -### Full refresh mode -If `full_refresh` flag is passed to `dlt.pipeline` then +### Dev mode +If the `dev_mode` flag is passed to `dlt.pipeline` then 1. the pipeline working dir is fully wiped out (state, schemas, temp files) 2. dataset name receives a prefix with the current timestamp: ie the `my_data` becomes `my_data_20221107164856`. 3. pipeline will not be restored from the destination diff --git a/docs/website/blog/2024-01-10-dlt-mode.md b/docs/website/blog/2024-01-10-dlt-mode.md index 1d6bf8ca0e..232124df45 100644 --- a/docs/website/blog/2024-01-10-dlt-mode.md +++ b/docs/website/blog/2024-01-10-dlt-mode.md @@ -124,7 +124,7 @@ With the model we just created, called Products, a chart can be instantly create In this demo, we’ll forego the authentication issues of connecting to a data warehouse, and choose the DuckDB destination to show how the Python environment within Mode can be used to initialize a data pipeline and dump normalized data into a destination. In order to see how it works, we first install dlt[duckdb] into the Python environment. ```sh -!pip install dlt[duckdb] +!pip install "dlt[duckdb]" ``` Next, we initialize the dlt pipeline: diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index 6cd5767dcb..ddb82c95b2 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -117,6 +117,8 @@ access_token = "MY_ACCESS_TOKEN" catalog = "my_catalog" ``` +See [staging support](#staging-support) for authentication options when `dlt` copies files from buckets. + ## Write disposition All write dispositions are supported @@ -166,6 +168,11 @@ pipeline = dlt.pipeline( Refer to the [Azure Blob Storage filesystem documentation](./filesystem.md#azure-blob-storage) for details on connecting your Azure Blob Storage container with the bucket_url and credentials. +Databricks requires that you use ABFS URLs in the following format: +**abfss://container_name@storage_account_name.dfs.core.windows.net/path** + +`dlt` is able to adapt the other representation (i.e. **az://container-name/path**), but we recommend that you use the correct ABFS form.
+ Example to set up Databricks with Azure as a staging destination: ```py @@ -175,10 +182,34 @@ Example to set up Databricks with Azure as a staging destination: pipeline = dlt.pipeline( pipeline_name='chess_pipeline', destination='databricks', - staging=dlt.destinations.filesystem('az://your-container-name'), # add this to activate the staging location + staging=dlt.destinations.filesystem('abfss://dlt-ci-data@dltdata.dfs.core.windows.net'), # add this to activate the staging location dataset_name='player_data' ) + ``` + +### Use external locations and stored credentials +`dlt` forwards bucket credentials to `COPY INTO` SQL command by default. You may prefer to use [external locations or stored credentials instead](https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location) that are stored on the Databricks side. + +If you set up external location for your staging path, you can tell `dlt` to use it: +```toml +[destination.databricks] +is_staging_external_location=true +``` + +If you set up Databricks credential named ie. **credential_x**, you can tell `dlt` to use it: +```toml +[destination.databricks] +staging_credentials_name="credential_x" +``` + +Both options are available from code: +```py +import dlt + +bricks = dlt.destinations.databricks(staging_credentials_name="credential_x") +``` + ### dbt support This destination [integrates with dbt](../transformations/dbt/dbt.md) via [dbt-databricks](https://github.com/databricks/dbt-databricks) diff --git a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md index 19cef92f9d..4b8ecec4ca 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md @@ -35,6 +35,42 @@ All write dispositions are supported. ## Data loading `dlt` will load data using large INSERT VALUES statements by default. Loading is multithreaded (20 threads by default). If you are okay with installing `pyarrow`, we suggest switching to `parquet` as the file format. Loading is faster (and also multithreaded). +### Data types +`duckdb` supports various [timestamp types](https://duckdb.org/docs/sql/data_types/timestamp.html). These can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: supported precision values are 0, 3, 6, and 9 for fractional seconds. Note that `timezone` and `precision` cannot be used together; attempting to combine them will result in an error. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP WITH TIME ZONE` (`TIMESTAMPTZ`). 
+ +#### Example precision: TIMESTAMP_MS + +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="duckdb") +pipeline.run(events()) +``` + +#### Example timezone: TIMESTAMP + +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + +pipeline = dlt.pipeline(destination="duckdb") +pipeline.run(events()) +``` + ### Names normalization `dlt` uses the standard **snake_case** naming convention to keep identical table and column identifiers across all destinations. If you want to use the **duckdb** wide range of characters (i.e., emojis) for table and column names, you can switch to the **duck_case** naming convention, which accepts almost any string as an identifier: * `\n` `\r` and `"` are translated to `_` @@ -77,7 +113,8 @@ to disable tz adjustments. ::: ## Supported column hints -`duckdb` may create unique indexes for all columns with `unique` hints, but this behavior **is disabled by default** because it slows the loading down significantly. + +`duckdb` can create unique indexes for columns with `unique` hints. However, **this feature is disabled by default** as it can significantly slow down data loading. ## Destination Configuration diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index d8ec8e0490..c6dcd16862 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -144,7 +144,22 @@ lancedb_adapter( ) ``` -Bear in mind that you can't use an adapter on a [dlt source](../../general-usage/source.md), only a [dlt resource](../../general-usage/resource.md). +When using the `lancedb_adapter`, it's important to apply it directly to resources, not to the whole source. Here's an example: + +```py +products_tables = sql_database().with_resources("products", "customers") + +pipeline = dlt.pipeline( + pipeline_name="postgres_to_lancedb_pipeline", + destination="lancedb", + ) + +# apply adapter to the needed resources +lancedb_adapter(products_tables.products, embed="description") +lancedb_adapter(products_tables.customers, embed="bio") + +info = pipeline.run(products_tables) +``` ## Write disposition @@ -205,8 +220,7 @@ This is the default disposition. It will append the data to the existing data in ## Additional Destination Options - `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___". -- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector__". -- `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__". +- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector". - `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3. @@ -220,11 +234,21 @@ The LanceDB destination supports syncing of the `dlt` state. ## Current Limitations +### In-Memory Tables + Adding new fields to an existing LanceDB table requires loading the entire table data into memory as a PyArrow table. 
This is because PyArrow tables are immutable, so adding fields requires creating a new table with the updated schema. For huge tables, this may impact performance and memory usage since the full table must be loaded into memory to add the new fields. Keep these considerations in mind when working with large datasets and monitor memory usage if adding fields to sizable existing tables. +### Null string handling for OpenAI embeddings + +OpenAI embedding service doesn't accept empty string bodies. We deal with this by replacing empty strings with a placeholder that should be very semantically dissimilar to 99.9% of queries. + +If your source column (column which is embedded) has empty values, it is important to consider the impact of this. There might be a _slight_ change that semantic queries can hit these empty strings. + +We reported this issue to LanceDB: https://github.com/lancedb/lancedb/issues/1577. + diff --git a/docs/website/docs/dlt-ecosystem/destinations/postgres.md b/docs/website/docs/dlt-ecosystem/destinations/postgres.md index 1281298312..e506eb79fe 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/postgres.md +++ b/docs/website/docs/dlt-ecosystem/destinations/postgres.md @@ -82,6 +82,27 @@ If you set the [`replace` strategy](../../general-usage/full-loading.md) to `sta ## Data loading `dlt` will load data using large INSERT VALUES statements by default. Loading is multithreaded (20 threads by default). +### Data types +`postgres` supports various timestamp types, which can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: allows you to specify the number of decimal places for fractional seconds, ranging from 0 to 6. It can be used in combination with the `timezone` flag. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP WITHOUT TIME ZONE`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP WITH TIME ZONE`. + +#### Example precision and timezone: TIMESTAMP (3) WITHOUT TIME ZONE +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3, "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="postgres") +pipeline.run(events()) +``` + ### Fast loading with arrow tables and csv You can use [arrow tables](../verified-sources/arrow-pandas.md) and [csv](../file-formats/csv.md) to quickly load tabular data. Pick the `csv` loader file format like below diff --git a/docs/website/docs/dlt-ecosystem/destinations/qdrant.md b/docs/website/docs/dlt-ecosystem/destinations/qdrant.md index 9f19007227..5fc8097440 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/qdrant.md +++ b/docs/website/docs/dlt-ecosystem/destinations/qdrant.md @@ -106,10 +106,25 @@ qdrant_adapter( ) ``` -:::tip +When using the `qdrant_adapter`, it's important to apply it directly to resources, not to the whole source. Here's an example: -A more comprehensive pipeline would load data from some API or use one of dlt's [verified sources](../verified-sources/). 
+```py +products_tables = sql_database().with_resources("products", "customers") + +pipeline = dlt.pipeline( + pipeline_name="postgres_to_qdrant_pipeline", + destination="qdrant", + ) +# apply adapter to the needed resources +qdrant_adapter(products_tables.products, embed="description") +qdrant_adapter(products_tables.customers, embed="bio") + +info = pipeline.run(products_tables) +``` + +:::tip +A more comprehensive pipeline would load data from some API or use one of dlt's [verified sources](../verified-sources/). ::: ## Write disposition diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index 181d024a2f..f4d5a53d36 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -136,7 +136,33 @@ If you set the [`replace` strategy](../../general-usage/full-loading.md) to `sta recreated with a [clone command](https://docs.snowflake.com/en/sql-reference/sql/create-clone) from the staging tables. ## Data loading -The data is loaded using an internal Snowflake stage. We use the `PUT` command and per-table built-in stages by default. Stage files are immediately removed (if not specified otherwise). +The data is loaded using an internal Snowflake stage. We use the `PUT` command and per-table built-in stages by default. Stage files are kept by default, unless specified otherwise via the `keep_staged_files` parameter: + +```toml +[destination.snowflake] +keep_staged_files = false +``` + +### Data types +`snowflake` supports various timestamp types, which can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: allows you to specify the number of decimal places for fractional seconds, ranging from 0 to 9. It can be used in combination with the `timezone` flag. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP_NTZ`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP_TZ`. + +#### Example precision and timezone: TIMESTAMP_NTZ(3) +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3, "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="snowflake") +pipeline.run(events()) +``` ## Supported file formats * [insert-values](../file-formats/insert-format.md) is used by default @@ -171,7 +197,7 @@ Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and Snowflake supports the following [column hints](https://dlthub.com/docs/general-usage/schema#tables-and-columns): * `cluster` - creates a cluster column(s). Many columns per table are supported and only when a new table is created. -### Table and column identifiers +## Table and column identifiers Snowflake supports both case sensitive and case insensitive identifiers. All unquoted and uppercase identifiers resolve case-insensitively in SQL statements. Case insensitive [naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) like the default **snake_case** will generate case insensitive identifiers. Case sensitive (like **sql_cs_v1**) will generate case sensitive identifiers that must be quoted in SQL statements. 
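For illustration, here is a minimal sketch of how the `cluster` column hint mentioned above can be declared on a resource, following the same `columns` pattern as the timestamp examples earlier on the Snowflake page (the resource and column names are made up; verify the hint behavior against your `dlt` version):

```py
import dlt


@dlt.resource(
    # assumption: `cluster` is set like any other column hint in the `columns` dict
    columns={"event_date": {"data_type": "date", "cluster": True}},
    primary_key="event_id",
)
def events():
    yield [{"event_id": 1, "event_date": "2024-07-30"}]


pipeline = dlt.pipeline(destination="snowflake")
pipeline.run(events())
```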
diff --git a/docs/website/docs/dlt-ecosystem/destinations/weaviate.md b/docs/website/docs/dlt-ecosystem/destinations/weaviate.md index c6597fadce..43bd85ce41 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/weaviate.md +++ b/docs/website/docs/dlt-ecosystem/destinations/weaviate.md @@ -116,6 +116,22 @@ weaviate_adapter( tokenization={"title": "word", "description": "whitespace"}, ) ``` +When using the `weaviate_adapter`, it's important to apply it directly to resources, not to the whole source. Here's an example: + +```py +products_tables = sql_database().with_resources("products", "customers") + +pipeline = dlt.pipeline( + pipeline_name="postgres_to_weaviate_pipeline", + destination="weaviate", + ) + +# apply adapter to the needed resources +weaviate_adapter(products_tables.products, vectorize="description") +weaviate_adapter(products_tables.customers, vectorize="bio") + +info = pipeline.run(products_tables) +``` :::tip diff --git a/docs/website/docs/dlt-ecosystem/staging.md b/docs/website/docs/dlt-ecosystem/staging.md index 05e31a574b..789189b7dd 100644 --- a/docs/website/docs/dlt-ecosystem/staging.md +++ b/docs/website/docs/dlt-ecosystem/staging.md @@ -1,36 +1,33 @@ --- title: Staging -description: Configure an s3 or gcs bucket for staging before copying into the destination +description: Configure an S3 or GCS bucket for staging before copying into the destination keywords: [staging, destination] --- # Staging -The goal of staging is to bring the data closer to the database engine so the modification of the destination (final) dataset happens faster and without errors. `dlt`, when asked, creates two -staging areas: +The goal of staging is to bring the data closer to the database engine so that the modification of the destination (final) dataset happens faster and without errors. `dlt`, when asked, creates two staging areas: 1. A **staging dataset** used by the [merge and replace loads](../general-usage/incremental-loading.md#merge-incremental_loading) to deduplicate and merge data with the destination. -2. A **staging storage** which is typically a s3/gcp bucket where [loader files](file-formats/) are copied before they are loaded by the destination. +2. A **staging storage** which is typically an S3/GCP bucket where [loader files](file-formats/) are copied before they are loaded by the destination. ## Staging dataset -`dlt` creates a staging dataset when write disposition of any of the loaded resources requires it. It creates and migrates required tables exactly like for the -main dataset. Data in staging tables is truncated when load step begins and only for tables that will participate in it. -Such staging dataset has the same name as the dataset passed to `dlt.pipeline` but with `_staging` suffix in the name. Alternatively, you can provide your own staging dataset pattern or use a fixed name, identical for all the -configured datasets. +`dlt` creates a staging dataset when the write disposition of any of the loaded resources requires it. It creates and migrates required tables exactly like for the main dataset. Data in staging tables is truncated when the load step begins and only for tables that will participate in it. +Such a staging dataset has the same name as the dataset passed to `dlt.pipeline` but with a `_staging` suffix in the name. Alternatively, you can provide your own staging dataset pattern or use a fixed name, identical for all the configured datasets. 
```toml [destination.postgres] staging_dataset_name_layout="staging_%s" ``` -Entry above switches the pattern to `staging_` prefix and for example for dataset with name **github_data** `dlt` will create **staging_github_data**. +The entry above switches the pattern to `staging_` prefix and for example, for a dataset with the name **github_data**, `dlt` will create **staging_github_data**. -To configure static staging dataset name, you can do the following (we use destination factory) +To configure a static staging dataset name, you can do the following (we use the destination factory) ```py import dlt dest_ = dlt.destinations.postgres(staging_dataset_name_layout="_dlt_staging") ``` -All pipelines using `dest_` as destination will use **staging_dataset** to store staging tables. Make sure that your pipelines are not overwriting each other's tables. +All pipelines using `dest_` as the destination will use the **staging_dataset** to store staging tables. Make sure that your pipelines are not overwriting each other's tables. -### Cleanup up staging dataset automatically -`dlt` does not truncate tables in staging dataset at the end of the load. Data that is left after contains all the extracted data and may be useful for debugging. +### Cleanup staging dataset automatically +`dlt` does not truncate tables in the staging dataset at the end of the load. Data that is left after contains all the extracted data and may be useful for debugging. If you prefer to truncate it, put the following line in `config.toml`: ```toml @@ -39,19 +36,23 @@ truncate_staging_dataset=true ``` ## Staging storage -`dlt` allows to chain destinations where the first one (`staging`) is responsible for uploading the files from local filesystem to the remote storage. It then generates followup jobs for the second destination that (typically) copy the files from remote storage into destination. +`dlt` allows chaining destinations where the first one (`staging`) is responsible for uploading the files from the local filesystem to the remote storage. It then generates follow-up jobs for the second destination that (typically) copy the files from remote storage into the destination. -Currently, only one destination the [filesystem](destinations/filesystem.md) can be used as a staging. Following destinations can copy remote files: -1. [Redshift.](destinations/redshift.md#staging-support) -2. [Bigquery.](destinations/bigquery.md#staging-support) -3. [Snowflake.](destinations/snowflake.md#staging-support) +Currently, only one destination, the [filesystem](destinations/filesystem.md), can be used as staging. The following destinations can copy remote files: + +1. [Azure Synapse](destinations/synapse#staging-support) +1. [Athena](destinations/athena#staging-support) +1. [Bigquery](destinations/bigquery.md#staging-support) +1. [Dremio](destinations/dremio#staging-support) +1. [Redshift](destinations/redshift.md#staging-support) +1. [Snowflake](destinations/snowflake.md#staging-support) ### How to use -In essence, you need to set up two destinations and then pass them to `dlt.pipeline`. Below we'll use `filesystem` staging with `parquet` files to load into `Redshift` destination. +In essence, you need to set up two destinations and then pass them to `dlt.pipeline`. Below we'll use `filesystem` staging with `parquet` files to load into the `Redshift` destination. -1. **Set up the s3 bucket and filesystem staging.** +1. 
**Set up the S3 bucket and filesystem staging.** - Please follow our guide in [filesystem destination documentation](destinations/filesystem.md). Test the staging as standalone destination to make sure that files go where you want them. In your `secrets.toml` you should now have a working `filesystem` configuration: + Please follow our guide in the [filesystem destination documentation](destinations/filesystem.md). Test the staging as a standalone destination to make sure that files go where you want them. In your `secrets.toml`, you should now have a working `filesystem` configuration: ```toml [destination.filesystem] bucket_url = "s3://[your_bucket_name]" # replace with your bucket name, @@ -63,15 +64,15 @@ In essence, you need to set up two destinations and then pass them to `dlt.pipel 2. **Set up the Redshift destination.** - Please follow our guide in [redshift destination documentation](destinations/redshift.md). In your `secrets.toml` you added: + Please follow our guide in the [redshift destination documentation](destinations/redshift.md). In your `secrets.toml`, you added: ```toml # keep it at the top of your toml file! before any section starts destination.redshift.credentials="redshift://loader:@localhost/dlt_data?connect_timeout=15" ``` -3. **Authorize Redshift cluster to access the staging bucket.** +3. **Authorize the Redshift cluster to access the staging bucket.** - By default `dlt` will forward the credentials configured for `filesystem` to the `Redshift` COPY command. If you are fine with this, move to the next step. + By default, `dlt` will forward the credentials configured for `filesystem` to the `Redshift` COPY command. If you are fine with this, move to the next step. 4. **Chain staging to destination and request `parquet` file format.** @@ -79,7 +80,7 @@ In essence, you need to set up two destinations and then pass them to `dlt.pipel ```py # Create a dlt pipeline that will load # chess player data to the redshift destination - # via staging on s3 + # via staging on S3 pipeline = dlt.pipeline( pipeline_name='chess_pipeline', destination='redshift', @@ -87,7 +88,7 @@ In essence, you need to set up two destinations and then pass them to `dlt.pipel dataset_name='player_data' ) ``` - `dlt` will automatically select an appropriate loader file format for the staging files. Below we explicitly specify `parquet` file format (just to demonstrate how to do it): + `dlt` will automatically select an appropriate loader file format for the staging files. Below we explicitly specify the `parquet` file format (just to demonstrate how to do it): ```py info = pipeline.run(chess(), loader_file_format="parquet") ``` @@ -96,4 +97,21 @@ In essence, you need to set up two destinations and then pass them to `dlt.pipel Run the pipeline script as usual. -> 💡 Please note that `dlt` does not delete loaded files from the staging storage after the load is complete. +:::tip +Please note that `dlt` does not delete loaded files from the staging storage after the load is complete, but it truncates previously loaded files. +::: + +### How to prevent staging files truncation + +Before `dlt` loads data to the staging storage, it truncates previously loaded files. To prevent it and keep the whole history +of loaded files, you can use the following parameter: + +```toml +[destination.redshift] +truncate_table_before_load_on_staging_destination=false +``` + +:::caution +The [Athena](destinations/athena#staging-support) destination only truncates not iceberg tables with `replace` merge_disposition. 
+Therefore, the parameter `truncate_table_before_load_on_staging_destination` only controls the truncation of corresponding files for these tables. +::: diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md index eeb717515a..c89a63a524 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md @@ -652,6 +652,6 @@ resource. Below we show you an example on how to pseudonymize the data before it print(info) ``` -1. Remember to keep the pipeline name and destination dataset name consistent. The pipeline name is crucial for retrieving the [state](https://dlthub.com/docs/general-usage/state) from the last run, which is essential for incremental loading. Altering these names could initiate a "[full_refresh](https://dlthub.com/docs/general-usage/pipeline#do-experiments-with-full-refresh)", interfering with the metadata tracking necessary for [incremental loads](https://dlthub.com/docs/general-usage/incremental-loading). +1. Remember to keep the pipeline name and destination dataset name consistent. The pipeline name is crucial for retrieving the [state](https://dlthub.com/docs/general-usage/state) from the last run, which is essential for incremental loading. Altering these names could initiate a "[dev_mode](https://dlthub.com/docs/general-usage/pipeline#do-experiments-with-dev-mode)", interfering with the metadata tracking necessary for [incremental loads](https://dlthub.com/docs/general-usage/incremental-loading). diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md b/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md index 8c39a5090e..fdbefeddf1 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md @@ -232,6 +232,6 @@ verified source. load_info = pipeline.run(data=[source_single, source_incremental]) print(load_info) ``` - > To load data, maintain the pipeline name and destination dataset name. The pipeline name is vital for accessing the last run's [state](../../general-usage/state), which determines the incremental data load's end date. Altering these names can trigger a [“full_refresh”](../../general-usage/pipeline#do-experiments-with-full-refresh), disrupting the metadata (state) tracking for [incremental data loading](../../general-usage/incremental-loading). + > To load data, maintain the pipeline name and destination dataset name. The pipeline name is vital for accessing the last run's [state](../../general-usage/state), which determines the incremental data load's end date. Altering these names can trigger a [“dev_mode”](../../general-usage/pipeline#do-experiments-with-dev-mode), disrupting the metadata (state) tracking for [incremental data loading](../../general-usage/incremental-loading). diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/workable.md b/docs/website/docs/dlt-ecosystem/verified-sources/workable.md index 472f48a28f..9229ddca7e 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/workable.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/workable.md @@ -272,7 +272,7 @@ To create your data pipeline using single loading and destination dataset names. The pipeline name helps retrieve the [state](https://dlthub.com/docs/general-usage/state) of the last run, essential for incremental data loading. 
Changing these names might trigger a - [“full_refresh”](https://dlthub.com/docs/general-usage/pipeline#do-experiments-with-full-refresh), + [“dev_mode”](https://dlthub.com/docs/general-usage/pipeline#do-experiments-with-dev-mode), disrupting metadata tracking for [incremental data loading](https://dlthub.com/docs/general-usage/incremental-loading). diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index 8eb1002dcf..5ff587f20e 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -348,7 +348,23 @@ You can configure the literal used to indicate an active record with `active_rec write_disposition={ "disposition": "merge", "strategy": "scd2", - "active_record_timestamp": "9999-12-31", # e.g. datetime.datetime(9999, 12, 31) is also accepted + # accepts various types of date/datetime objects + "active_record_timestamp": "9999-12-31", + } +) +def dim_customer(): + ... +``` + +#### Example: configure boundary timestamp +You can configure the "boundary timestamp" used for record validity windows with `boundary_timestamp`. The provided date(time) value is used as "valid from" for new records and as "valid to" for retired records. The timestamp at which a load package is created is used if `boundary_timestamp` is omitted. +```py +@dlt.resource( + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + # accepts various types of date/datetime objects + "boundary_timestamp": "2024-08-21T12:15:00+00:00", } ) def dim_customer(): @@ -673,7 +689,7 @@ than `end_value`. :::caution In rare cases when you use Incremental with a transformer, `dlt` will not be able to automatically close -generator associated with a row that is out of range. You can still use still call `can_close()` method on +generator associated with a row that is out of range. You can still call the `can_close()` method on incremental and exit yield loop when true. ::: @@ -891,22 +907,75 @@ Consider the example below for reading incremental loading parameters from "conf ``` `id_after` incrementally stores the latest `cursor_path` value for future pipeline runs. -### Loading NULL values in the incremental cursor field +### Loading when incremental cursor path is missing or value is None/NULL -When loading incrementally with a cursor field, each row is expected to contain a value at the cursor field that is not `None`. -For example, the following source data will raise an error: +You can customize the incremental processing of dlt by setting the parameter `on_cursor_value_missing`. + +When loading incrementally with the default settings, there are two assumptions: +1. each row contains the cursor path +2. each row is expected to contain a value at the cursor path that is not `None`. 
+ +For example, the two following sources will raise an error: ```py @dlt.resource -def some_data(updated_at=dlt.sources.incremental("updated_at")): +def some_data_without_cursor_path(updated_at=dlt.sources.incremental("updated_at")): yield [ {"id": 1, "created_at": 1, "updated_at": 1}, - {"id": 2, "created_at": 2, "updated_at": 2}, + {"id": 2, "created_at": 2}, # cursor field is missing + ] + +list(some_data_without_cursor_path()) + +@dlt.resource +def some_data_without_cursor_value(updated_at=dlt.sources.incremental("updated_at")): + yield [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 3, "created_at": 4, "updated_at": None}, # value at cursor field is None + ] + +list(some_data_without_cursor_value()) +``` + + +To process a data set where some records do not include the incremental cursor path or where the values at the cursor path are `None`, there are the following four options: + +1. Configure the incremental load to raise an exception in case there is a row where the cursor path is missing or has the value `None` using `incremental(..., on_cursor_value_missing="raise")`. This is the default behavior. +2. Configure the incremental load to tolerate the missing cursor path and `None` values using `incremental(..., on_cursor_value_missing="include")`. +3. Configure the incremental load to exclude the missing cursor path and `None` values using `incremental(..., on_cursor_value_missing="exclude")`. +4. Before the incremental processing begins: ensure that the incremental field is present and transform the values at the incremental cursor to a value different from `None`. [See docs below](#transform-records-before-incremental-processing) + +Here is an example of including rows where the incremental cursor value is missing or `None`: +```py +@dlt.resource +def some_data(updated_at=dlt.sources.incremental("updated_at", on_cursor_value_missing="include")): + yield [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 2, "created_at": 2}, + {"id": 3, "created_at": 4, "updated_at": None}, + ] + +result = list(some_data()) +assert len(result) == 3 +assert result[1] == {"id": 2, "created_at": 2} +assert result[2] == {"id": 3, "created_at": 4, "updated_at": None} +``` + +If you do not want to import records without the cursor path or where the value at the cursor path is `None`, use the following incremental configuration: + +```py +@dlt.resource +def some_data(updated_at=dlt.sources.incremental("updated_at", on_cursor_value_missing="exclude")): + yield [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 2, "created_at": 2}, {"id": 3, "created_at": 4, "updated_at": None}, ] -list(some_data()) +result = list(some_data()) +assert len(result) == 1 ``` +### Transform records before incremental processing If you want to load data that includes `None` values you can transform the records before the incremental processing. You can add steps to the pipeline that [filter, transform, or pivot your data](../general-usage/resource.md#filter-transform-and-pivot-data). @@ -1146,4 +1215,4 @@ sources: } ``` -Verify that the `last_value` is updated between pipeline runs. \ No newline at end of file +Verify that the `last_value` is updated between pipeline runs.
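As a quick illustration of option 4 above (transforming records before the incremental step runs), here is a minimal sketch that is not part of the diff itself; it mirrors the `add_map(..., insert_at=1)` pattern used in the tests added later in this change, and the resource and field names (`some_data`, `updated_at`, `created_at`) are illustrative only:

```py
import dlt


@dlt.resource
def some_data(updated_at=dlt.sources.incremental("updated_at")):
    yield [
        {"id": 1, "created_at": 1, "updated_at": 1},
        {"id": 2, "created_at": 2, "updated_at": None},  # would fail with the default "raise" setting
    ]


def set_default_updated_at(record):
    # fall back to created_at when the cursor value is missing or None
    if record.get("updated_at") is None:
        record["updated_at"] = record["created_at"]
    return record


# insert the map step at position 1 so it runs before the incremental step sees the row
rows = list(some_data().add_map(set_default_updated_at, insert_at=1))
assert all(r["updated_at"] is not None for r in rows)
```

With the map step in place, the default `on_cursor_value_missing="raise"` behavior never encounters a `None` cursor value.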
diff --git a/docs/website/docs/general-usage/pipeline.md b/docs/website/docs/general-usage/pipeline.md index f21d6f0686..40f9419bc2 100644 --- a/docs/website/docs/general-usage/pipeline.md +++ b/docs/website/docs/general-usage/pipeline.md @@ -85,6 +85,20 @@ You can inspect stored artifacts using the command > 💡 You can attach `Pipeline` instance to an existing working folder, without creating a new > pipeline with `dlt.attach`. +### Separate working environments with `pipelines_dir` +You can run several pipelines with the same name but with different configurations, e.g. to target development / staging / production environments. +Set the `pipelines_dir` argument to store all the working folders in a specific place. For example: +```py +import os +import dlt +from dlt.common.pipeline import get_dlt_pipelines_dir + +dev_pipelines_dir = os.path.join(get_dlt_pipelines_dir(), "dev") +pipeline = dlt.pipeline(destination="duckdb", dataset_name="sequence", pipelines_dir=dev_pipelines_dir) +``` +stores the pipeline working folder in `~/.dlt/pipelines/dev/`. Note that you need to pass this `~/.dlt/pipelines/dev/` +path to all cli commands to get info/trace for that pipeline. + ## Do experiments with dev mode If you [create a new pipeline script](../walkthroughs/create-a-pipeline.md) you will be diff --git a/docs/website/docs/general-usage/source.md b/docs/website/docs/general-usage/source.md index 936a3160f0..98c7a13b81 100644 --- a/docs/website/docs/general-usage/source.md +++ b/docs/website/docs/general-usage/source.md @@ -187,6 +187,26 @@ Several data sources are prone to contain semi-structured documents with very de MongoDB databases. Our practical experience is that setting the `max_nesting_level` to 2 or 3 produces the clearest and human-readable schemas. +:::tip +The `max_table_nesting` parameter at the source level doesn't automatically apply to individual +resources when accessed directly (e.g., using `source.resources["resource_1"]`). To make sure it +works, either use `source.with_resources("resource_1")` or set the parameter directly on the resource. +::: + + +You can directly configure the `max_table_nesting` parameter on the resource level as: + +```py +@dlt.resource(max_table_nesting=0) +def my_resource(): + ... +``` +or +```py +my_source = source() +my_source.my_resource.max_table_nesting = 0 +``` + ### Modify schema The schema is available via `schema` property of the source. diff --git a/docs/website/docs/reference/performance.md b/docs/website/docs/reference/performance.md index 075d351553..0ee62acec7 100644 --- a/docs/website/docs/reference/performance.md +++ b/docs/website/docs/reference/performance.md @@ -62,7 +62,7 @@ Several [text file formats](../dlt-ecosystem/file-formats/) have `gzip` compress Keep in mind load packages are buffered to disk and are left for any troubleshooting, so you can [clear disk space by setting the `delete_completed_jobs` option](../running-in-production/running.md#data-left-behind). ### Observing cpu and memory usage -Please make sure that you have the `psutils` package installed (note that Airflow installs it by default). Then you can dump the stats periodically by setting the [progress](../general-usage/pipeline.md#display-the-loading-progress) to `log` in `config.toml`: +Please make sure that you have the `psutil` package installed (note that Airflow installs it by default).
Then you can dump the stats periodically by setting the [progress](../general-usage/pipeline.md#display-the-loading-progress) to `log` in `config.toml`: ```toml progress="log" ``` @@ -258,4 +258,4 @@ DLT_USE_JSON=simplejson ## Using the built in requests wrapper or RESTClient for API calls -Instead of using Python Requests directly, you can use the built-in [requests wrapper](../general-usage/http/requests) or [`RESTClient`](../general-usage/http/rest-client) for API calls. This will make your pipeline more resilient to intermittent network errors and other random glitches. \ No newline at end of file +Instead of using Python Requests directly, you can use the built-in [requests wrapper](../general-usage/http/requests) or [`RESTClient`](../general-usage/http/rest-client) for API calls. This will make your pipeline more resilient to intermittent network errors and other random glitches. diff --git a/docs/website/docs/running-in-production/running.md b/docs/website/docs/running-in-production/running.md index 3b5762612c..cc089a1393 100644 --- a/docs/website/docs/running-in-production/running.md +++ b/docs/website/docs/running-in-production/running.md @@ -271,7 +271,7 @@ load_info.raise_on_failed_jobs() ``` You may also abort the load package with `LoadClientJobFailed` (terminal exception) on a first -failed job. Such package is immediately moved to completed but its load id is not added to the +failed job. Such a package will be completed but its load id is not added to the `_dlt_loads` table. All the jobs that were running in parallel are completed before raising. The dlt state, if present, will not be visible to `dlt`. Here's example `config.toml` to enable this option: @@ -282,6 +282,20 @@ load.workers=1 load.raise_on_failed_jobs=true ``` +:::caution +Note that certain write dispositions will irreversibly modify your data: +1. `replace` write disposition with the default `truncate-and-insert` [strategy](../general-usage/full-loading.md) will truncate tables before loading. +2. `merge` write disposition will merge staging dataset tables into the destination dataset. This will happen only when all data for this table (and its nested tables) has been loaded. + +Here's what you can do to deal with partially loaded packages: +1. Retry the load step in case of transient errors. +2. Use the replace strategy with a staging dataset so the replace happens only when data for the table (and all nested tables) is fully loaded and is an atomic operation (if possible). +3. Use only the "append" write disposition. When your load package fails, you can use `_dlt_load_id` to remove all unprocessed data. +4. Use "staging append" (the `merge` disposition without a primary key or merge key defined). + +::: + + ### What `run` does inside Before adding retry to pipeline steps, note how `run` method actually works: diff --git a/docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md b/docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md index 0e342a3fea..41ba5926c4 100644 --- a/docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md +++ b/docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md @@ -12,7 +12,7 @@ We'll use the [GitHub API](https://docs.github.com/en/rest) to fetch the events 1. Install dlt with duckdb support: ```sh -pip install dlt[duckdb] +pip install "dlt[duckdb]" ``` 2.
Create a new a new file `github_events_dispatch.py` and paste the following code: diff --git a/poetry.lock b/poetry.lock index d54a73a2ef..0ce139d08f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2102,32 +2102,33 @@ typing-extensions = ">=3.10.0" [[package]] name = "databricks-sql-connector" -version = "3.1.2" +version = "2.9.6" description = "Databricks SQL Connector for Python" optional = true -python-versions = "<4.0.0,>=3.8.0" +python-versions = "<4.0.0,>=3.7.1" files = [ - {file = "databricks_sql_connector-3.1.2-py3-none-any.whl", hash = "sha256:5292bc25b4d8d58d301079b55086331764f067e24862c9365698b2eeddedb737"}, - {file = "databricks_sql_connector-3.1.2.tar.gz", hash = "sha256:da0df114e0824d49ccfea36c4679c95689fe359191b056ad516446a058307c37"}, + {file = "databricks_sql_connector-2.9.6-py3-none-any.whl", hash = "sha256:d830abf86e71d2eb83c6a7b7264d6c03926a8a83cec58541ddd6b83d693bde8f"}, + {file = "databricks_sql_connector-2.9.6.tar.gz", hash = "sha256:e55f5b8ede8ae6c6f31416a4cf6352f0ac019bf6875896c668c7574ceaf6e813"}, ] [package.dependencies] +alembic = ">=1.0.11,<2.0.0" lz4 = ">=4.0.2,<5.0.0" numpy = [ - {version = ">=1.16.6", markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, + {version = ">=1.16.6", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, {version = ">=1.23.4", markers = "python_version >= \"3.11\""}, ] oauthlib = ">=3.1.0,<4.0.0" openpyxl = ">=3.0.10,<4.0.0" -pandas = {version = ">=1.2.5,<2.2.0", markers = "python_version >= \"3.8\""} -pyarrow = ">=14.0.1,<15.0.0" +pandas = {version = ">=1.2.5,<3.0.0", markers = "python_version >= \"3.8\""} +pyarrow = [ + {version = ">=6.0.0", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=10.0.1", markers = "python_version >= \"3.11\""}, +] requests = ">=2.18.1,<3.0.0" +sqlalchemy = ">=1.3.24,<2.0.0" thrift = ">=0.16.0,<0.17.0" -urllib3 = ">=1.26" - -[package.extras] -alembic = ["alembic (>=1.0.11,<2.0.0)", "sqlalchemy (>=2.0.21)"] -sqlalchemy = ["sqlalchemy (>=2.0.21)"] +urllib3 = ">=1.0" [[package]] name = "dbt-athena-community" @@ -2377,25 +2378,24 @@ files = [ [[package]] name = "deltalake" -version = "0.17.4" +version = "0.19.1" description = "Native Delta Lake Python binding based on delta-rs with Pandas integration" optional = true python-versions = ">=3.8" files = [ - {file = "deltalake-0.17.4-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3f048bd4cdd3500fbb0d1b34046966ca4b7cefd1e9df71460b881ee8ad7f844a"}, - {file = "deltalake-0.17.4-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:b539265d8293794872e1dc3b2daad50abe05ab425e961824b3ac1155bb294604"}, - {file = "deltalake-0.17.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55e6be5f5ab8d5d34d2ea58d86e93eec2da5d2476e3c15e9520239457618bca4"}, - {file = "deltalake-0.17.4-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94dde6c2d0a07e9ce47be367d016541d3a499839350852205819353441e1a9c1"}, - {file = "deltalake-0.17.4-cp38-abi3-win_amd64.whl", hash = "sha256:f51f499d50dad88bdc18c5ed7c2319114759f3220f83aa2d32166c19accee4ce"}, - {file = "deltalake-0.17.4.tar.gz", hash = "sha256:c3c10577afc46d4b10ed16246d814a8c40b3663099066681eeba89f908373814"}, + {file = "deltalake-0.19.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:ddaaaa9c85a17791c3997cf320ac11dc1725d16cf4b6f0ff1b130853e7b56cd0"}, + {file = "deltalake-0.19.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e0184d5a3f0d4f4f1fb992c3bdc8736329b78b6a4faf1a278109ec35d9945c1d"}, + {file = 
"deltalake-0.19.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec9d117fcf6c198f3d554be2f3a6291ca3838530650db236741ff48d4d47abb4"}, + {file = "deltalake-0.19.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:447ef721319ed15f7b5f6da507efd5fed0e6172e5ae55ac044d5b8fc9b812e47"}, + {file = "deltalake-0.19.1-cp38-abi3-win_amd64.whl", hash = "sha256:b15bc343a9f8f3de80fbedcebd5d9472b539eb0f538a71739c7fcf699089127e"}, + {file = "deltalake-0.19.1.tar.gz", hash = "sha256:5e09fabb221fb81e989c283c16278eaffb6e85706d98364abcda5c0c6ca73598"}, ] [package.dependencies] -pyarrow = ">=8" -pyarrow-hotfix = "*" +pyarrow = ">=16" [package.extras] -devel = ["mypy (>=1.8.0,<1.9.0)", "packaging (>=20)", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-timeout", "ruff (>=0.3.0,<0.4.0)", "sphinx (<=4.5)", "sphinx-rtd-theme", "toml", "wheel"] +devel = ["azure-storage-blob (==12.20.0)", "mypy (==1.10.1)", "packaging (>=20)", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-timeout", "ruff (==0.5.2)", "sphinx (<=4.5)", "sphinx-rtd-theme", "toml", "wheel"] pandas = ["pandas"] pyspark = ["delta-spark", "numpy (==1.22.2)", "pyspark"] @@ -4567,17 +4567,17 @@ testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", [[package]] name = "lancedb" -version = "0.9.0" +version = "0.13.0b1" description = "lancedb" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "lancedb-0.9.0-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:b1ca08797c72c93ae512aa1078f1891756da157d910fbae8e194fac3528fc1ac"}, - {file = "lancedb-0.9.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:15129791f03c2c04b95f914ced2c1556b43d73a24710207b9af77b6e4008bdeb"}, - {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f093d89447a2039b820d2540a0b64df3024e4549b6808ebd26b44fbe0345cc6"}, - {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:a8c1f6777e217d2277451038866d280fa5fb38bd161795e51703b043c26dd345"}, - {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:78dd5800a1148f89d33b7e98d1c8b1c42dee146f03580abc1ca83cb05273ff7f"}, - {file = "lancedb-0.9.0-cp38-abi3-win_amd64.whl", hash = "sha256:ba5bdc727d3bc131f17414f42372acde5817073feeb553793a3d20003caa1658"}, + {file = "lancedb-0.13.0b1-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:687b9a08be55e6fa9520255b1b06dcd2e6ba6c64c947410821e9a3a52b2f48ec"}, + {file = "lancedb-0.13.0b1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ac00684f7e90ffc1b386298670e2c4ddaea8c0b61b6eb1b51dbd4e74feb87a86"}, + {file = "lancedb-0.13.0b1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbe8fc15bfeec89b6b2a4a42b4b919b6d3e138cf8684af35f77f361d73fe90cd"}, + {file = "lancedb-0.13.0b1-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:231e1f00d724c468922f7951d902622d4ccb21c2db2a148b845beaebee5d35b3"}, + {file = "lancedb-0.13.0b1-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:fecdd71f137e52193bfb5843610f32fe025a60a1edf5f80530704de879706c6b"}, + {file = "lancedb-0.13.0b1-cp38-abi3-win_amd64.whl", hash = "sha256:7852d9c04a4402407af06bbbf78bf339a169f1df2bf5c70da586ca733ec40a68"}, ] [package.dependencies] @@ -4587,7 +4587,7 @@ deprecation = "*" overrides = ">=0.7" packaging = "*" pydantic = ">=1.10" -pylance = "0.13.0" +pylance = "0.16.1" ratelimiter = ">=1.0,<2.0" requests = ">=2.31.0" retry = ">=0.9.2" @@ -4598,8 +4598,8 @@ azure = ["adlfs (>=2024.2.0)"] clip 
= ["open-clip", "pillow", "torch"] dev = ["pre-commit", "ruff"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] -embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "ollama", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] -tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] +embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "ibm-watsonx-ai (>=1.1.2)", "instructorembedding", "ollama", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] +tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19,<=1.3.0)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] [[package]] name = "lazy-object-proxy" @@ -6660,62 +6660,54 @@ files = [ [[package]] name = "pyarrow" -version = "14.0.2" +version = "17.0.0" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-14.0.2-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:ba9fe808596c5dbd08b3aeffe901e5f81095baaa28e7d5118e01354c64f22807"}, - {file = "pyarrow-14.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:22a768987a16bb46220cef490c56c671993fbee8fd0475febac0b3e16b00a10e"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dbba05e98f247f17e64303eb876f4a80fcd32f73c7e9ad975a83834d81f3fda"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a898d134d00b1eca04998e9d286e19653f9d0fcb99587310cd10270907452a6b"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:87e879323f256cb04267bb365add7208f302df942eb943c93a9dfeb8f44840b1"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:76fc257559404ea5f1306ea9a3ff0541bf996ff3f7b9209fc517b5e83811fa8e"}, - {file = "pyarrow-14.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0c4a18e00f3a32398a7f31da47fefcd7a927545b396e1f15d0c85c2f2c778cd"}, - {file = "pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b"}, - {file = "pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02"}, - {file = "pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b"}, - {file = "pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944"}, - {file = "pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash 
= "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379"}, - {file = "pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d"}, - {file = "pyarrow-14.0.2-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:e354fba8490de258be7687f341bc04aba181fc8aa1f71e4584f9890d9cb2dec2"}, - {file = "pyarrow-14.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:20e003a23a13da963f43e2b432483fdd8c38dc8882cd145f09f21792e1cf22a1"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc0de7575e841f1595ac07e5bc631084fd06ca8b03c0f2ecece733d23cd5102a"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66e986dc859712acb0bd45601229021f3ffcdfc49044b64c6d071aaf4fa49e98"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f7d029f20ef56673a9730766023459ece397a05001f4e4d13805111d7c2108c0"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:209bac546942b0d8edc8debda248364f7f668e4aad4741bae58e67d40e5fcf75"}, - {file = "pyarrow-14.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:1e6987c5274fb87d66bb36816afb6f65707546b3c45c44c28e3c4133c010a881"}, - {file = "pyarrow-14.0.2-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:a01d0052d2a294a5f56cc1862933014e696aa08cc7b620e8c0cce5a5d362e976"}, - {file = "pyarrow-14.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a51fee3a7db4d37f8cda3ea96f32530620d43b0489d169b285d774da48ca9785"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64df2bf1ef2ef14cee531e2dfe03dd924017650ffaa6f9513d7a1bb291e59c15"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c0fa3bfdb0305ffe09810f9d3e2e50a2787e3a07063001dcd7adae0cee3601a"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c65bf4fd06584f058420238bc47a316e80dda01ec0dfb3044594128a6c2db794"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:63ac901baec9369d6aae1cbe6cca11178fb018a8d45068aaf5bb54f94804a866"}, - {file = "pyarrow-14.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:75ee0efe7a87a687ae303d63037d08a48ef9ea0127064df18267252cfe2e9541"}, - {file = "pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + 
{file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = 
"pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] numpy = ">=1.16.6" -[[package]] -name = "pyarrow-hotfix" -version = "0.6" -description = "" -optional = true -python-versions = ">=3.5" -files = [ - {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, - {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, -] +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] [[package]] name = "pyasn1" @@ -6993,22 +6985,22 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] [[package]] name = "pylance" -version = "0.13.0" +version = "0.16.1" description = "python wrapper for Lance columnar format" optional = false python-versions = ">=3.9" files = [ - {file = "pylance-0.13.0-cp39-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2f3d6f9eec1f59f45dccb01075ba79868b8d37c8371d6210bcf6418217a0dd8b"}, - {file = "pylance-0.13.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f4861ab466c94b0f9a4b4e6de6e1dfa02f40e7242d8db87447bc7bb7d89606ac"}, - {file = "pylance-0.13.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3cb92547e145f5bfb0ea7d6f483953913b9bdd44c45bea84fc95a18da9f5853"}, - {file = "pylance-0.13.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:d1ddd7700924bc6b6b0774ea63d2aa23f9210a86cd6d6af0cdfa987df776d50d"}, - {file = "pylance-0.13.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c51d4b6e59cf4dc97c11a35b299f11e80dbdf392e2d8dc498573c26474a3c19e"}, - {file = "pylance-0.13.0-cp39-abi3-win_amd64.whl", hash = "sha256:4018ba016f1445874960a4ba2ad5c80cb380f3116683282ee8beabd38fa8989d"}, + {file = "pylance-0.16.1-cp39-abi3-macosx_10_15_x86_64.whl", hash = "sha256:7092303ae21bc162edd98e20fc39785fa1ec6b67f04132977ac0fd63110ba16f"}, + {file = "pylance-0.16.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:7c2ebdf89928c68f053ab9e369a5477da0a2ba70d47c00075dc10a37039d9e90"}, + {file = "pylance-0.16.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:b4525c2fd8095830b753a3efb7285f358b016836086683fe977f9f1de8e6866c"}, + {file = "pylance-0.16.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:645f0ab338bc4bd42bf3321bbb4053261979117aefd8477c2192ba624de27778"}, + {file = "pylance-0.16.1-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3a7464d60aca51e89196a79c638bcbff0bddb77158946e2ea6b5fcbc6cfc63e1"}, + {file = "pylance-0.16.1-cp39-abi3-win_amd64.whl", hash = "sha256:d12c628dfbd49efde15a5512247065341f3efb29989dd08fb5a7023f013471ee"}, ] [package.dependencies] -numpy = ">=1.22" -pyarrow = ">=12,<15.0.1" +numpy = ">=1.22,<2" +pyarrow = ">=12" [package.extras] benchmarks = ["pytest-benchmark"] @@ -8659,6 +8651,44 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tantivy" +version = "0.22.0" +description = "" +optional = true +python-versions = ">=3.8" +files = [ + {file = "tantivy-0.22.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:732ec74c4dd531253af4c14756b7650527f22c7fab244e83b42d76a0a1437219"}, + {file = "tantivy-0.22.0-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:bf1da07b7e1003af4260b1ef3c3db7cb05db1578606092a6ca7a3cff2a22858a"}, + {file = "tantivy-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689ed52985e914c531eadd8dd2df1b29f0fa684687b6026206dbdc57cf9297b2"}, + {file = "tantivy-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5f2885c8e98d1efcc4836c3e9d327d6ba2bc6b5e2cd8ac9b0356af18f571070"}, + {file = "tantivy-0.22.0-cp310-none-win_amd64.whl", hash = "sha256:4543cc72f4fec30f50fed5cd503c13d0da7cffda47648c7b72c1759103309e41"}, + {file = "tantivy-0.22.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:ec693abf38f229bc1361b0d34029a8bb9f3ee5bb956a3e745e0c4a66ea815bec"}, + {file = "tantivy-0.22.0-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e385839badc12b81e38bf0a4d865ee7c3a992fea9f5ce4117adae89369e7d1eb"}, + {file = "tantivy-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6c097d94be1af106676c86c02b185f029484fdbd9a2b9f17cb980e840e7bdad"}, + {file = "tantivy-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c47a5cdec306ea8594cb6e7effd4b430932ebfd969f9e8f99e343adf56a79bc9"}, + {file = "tantivy-0.22.0-cp311-none-win_amd64.whl", hash = "sha256:ba0ca878ed025d79edd9c51cda80b0105be8facbaec180fea64a17b80c74e7db"}, + {file = "tantivy-0.22.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:925682f3acb65c85c2a5a5b131401b9f30c184ea68aa73a8cc7c2ea6115e8ae3"}, + {file = "tantivy-0.22.0-cp312-cp312-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d75760e45a329313001354d6ca415ff12d9d812343792ae133da6bfbdc4b04a5"}, + {file = "tantivy-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd909d122b5af457d955552c304f8d5d046aee7024c703c62652ad72af89f3c7"}, + {file = "tantivy-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99266ffb204721eb2bd5b3184aa87860a6cff51b4563f808f78fa22d85a8093"}, + {file = "tantivy-0.22.0-cp312-none-win_amd64.whl", hash = "sha256:9ed6b813a1e7769444e33979b46b470b2f4c62d983c2560ce9486fb9be1491c9"}, + {file = "tantivy-0.22.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:97eb05f8585f321dbc733b64e7e917d061dc70c572c623730b366c216540d149"}, + {file = "tantivy-0.22.0-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = 
"sha256:cc74748b6b886475c12bf47c8814861b79f850fb8a528f37ae0392caae0f6f14"}, + {file = "tantivy-0.22.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7059c51c25148e07a20bd73efc8b51c015c220f141f3638489447b99229c8c0"}, + {file = "tantivy-0.22.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f88d05f55e2c3e581de70c5c7f46e94e5869d1c0fd48c5db33be7e56b6b88c9a"}, + {file = "tantivy-0.22.0-cp38-none-win_amd64.whl", hash = "sha256:09bf6de2fa08aac1a7133bee3631c1123de05130fd2991ceb101f2abac51b9d2"}, + {file = "tantivy-0.22.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:9de1a7497d377477dc09029c343eb9106c2c5fdb2e399f8dddd624cd9c7622a2"}, + {file = "tantivy-0.22.0-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e81e47edd0faffb5ad20f52ae75c3a2ed680f836e72bc85c799688d3a2557502"}, + {file = "tantivy-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27333518dbc309299dafe79443ee80eede5526a489323cdb0506b95eb334f985"}, + {file = "tantivy-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c9452d05e42450be53a9a58a9cf13f9ff8d3605c73bdc38a34ce5e167a25d77"}, + {file = "tantivy-0.22.0-cp39-none-win_amd64.whl", hash = "sha256:51e4ec0d44637562bf23912d18d12850c4b3176c0719e7b019d43b59199a643c"}, + {file = "tantivy-0.22.0.tar.gz", hash = "sha256:dce07fa2910c94934aa3d96c91087936c24e4a5802d839625d67edc6d1c95e5c"}, +] + +[package.extras] +dev = ["nox"] + [[package]] name = "tblib" version = "2.0.0" @@ -9681,7 +9711,7 @@ duckdb = ["duckdb"] filesystem = ["botocore", "s3fs"] gcp = ["gcsfs", "google-cloud-bigquery", "grpcio"] gs = ["gcsfs"] -lancedb = ["lancedb", "pyarrow"] +lancedb = ["lancedb", "pyarrow", "tantivy"] motherduck = ["duckdb", "pyarrow"] mssql = ["pyodbc"] parquet = ["pyarrow"] @@ -9696,4 +9726,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "a64fdd2845d27c9abc344809be68cba08f46641aabdc07416c37c802450fe4f3" +content-hash = "1d8fa59c9ef876d699cb5b5a2fcadb9a78c4c4d28a9fca7ca0e83147c08feaae" diff --git a/pyproject.toml b/pyproject.toml index f33bbbefcf..53ef7a5d94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.5.4a0" +version = "0.5.4" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. 
"] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Anton Burnashev ", "David Scharf " ] @@ -80,7 +80,8 @@ databricks-sql-connector = {version = ">=2.9.3", optional = true} clickhouse-driver = { version = ">=0.2.7", optional = true } clickhouse-connect = { version = ">=0.7.7", optional = true } lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'", allow-prereleases = true } -deltalake = { version = ">=0.17.4", optional = true } +tantivy = { version = ">= 0.22.0", optional = true } +deltalake = { version = ">=0.19.0", optional = true } [tool.poetry.extras] gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"] @@ -105,7 +106,7 @@ qdrant = ["qdrant-client"] databricks = ["databricks-sql-connector"] clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"] dremio = ["pyarrow"] -lancedb = ["lancedb", "pyarrow"] +lancedb = ["lancedb", "pyarrow", "tantivy"] deltalake = ["deltalake", "pyarrow"] @@ -218,7 +219,7 @@ dbt-duckdb = ">=1.2.0" pymongo = ">=4.3.3" pandas = ">2" alive-progress = ">=3.0.1" -pyarrow = ">=14.0.0" +pyarrow = ">=17.0.0" psycopg2-binary = ">=2.9" lancedb = { version = ">=0.8.2", markers = "python_version >= '3.9'", allow-prereleases = true } openai = ">=1.35" diff --git a/tests/.dlt/config.toml b/tests/.dlt/config.toml index ba86edf417..292175569b 100644 --- a/tests/.dlt/config.toml +++ b/tests/.dlt/config.toml @@ -6,7 +6,8 @@ bucket_url_gs="gs://ci-test-bucket" bucket_url_s3="s3://dlt-ci-test-bucket" bucket_url_file="_storage" bucket_url_az="az://dlt-ci-test-bucket" +bucket_url_abfss="abfss://dlt-ci-test-bucket@dltdata.dfs.core.windows.net" bucket_url_r2="s3://dlt-ci-test-bucket" # use "/" as root path bucket_url_gdrive="gdrive://15eC3e5MNew2XAIefWNlG8VlEa0ISnnaG" -memory="memory://m" \ No newline at end of file +memory="memory:///m" \ No newline at end of file diff --git a/tests/common/cases/normalizers/sql_upper.py b/tests/common/cases/normalizers/sql_upper.py index f2175f06ad..eb88775f95 100644 --- a/tests/common/cases/normalizers/sql_upper.py +++ b/tests/common/cases/normalizers/sql_upper.py @@ -1,5 +1,3 @@ -from typing import Any, Sequence - from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention diff --git a/tests/common/data_writers/test_data_writers.py b/tests/common/data_writers/test_data_writers.py index 9b4e61a2f7..03723b7b55 100644 --- a/tests/common/data_writers/test_data_writers.py +++ b/tests/common/data_writers/test_data_writers.py @@ -5,6 +5,7 @@ from dlt.common import pendulum, json from dlt.common.data_writers.exceptions import DataWriterNotFound, SpecLookupFailed +from dlt.common.metrics import DataWriterMetrics from dlt.common.typing import AnyFun from dlt.common.data_writers.escape import ( @@ -25,7 +26,6 @@ ArrowToTypedJsonlListWriter, CsvWriter, DataWriter, - DataWriterMetrics, EMPTY_DATA_WRITER_METRICS, ImportFileWriter, InsertValuesWriter, @@ -180,12 +180,13 @@ def test_data_writer_metrics_add() -> None: metrics = DataWriterMetrics("file", 10, 100, now, now + 10) add_m: DataWriterMetrics = metrics + EMPTY_DATA_WRITER_METRICS # type: ignore[assignment] assert add_m == DataWriterMetrics("", 10, 100, now, now + 10) - assert metrics + metrics == DataWriterMetrics("", 20, 200, now, now + 10) + # will keep "file" because it is in both + assert metrics + metrics == DataWriterMetrics("file", 20, 200, now, now + 10) assert sum((metrics, metrics, metrics), EMPTY_DATA_WRITER_METRICS) == DataWriterMetrics( "", 30, 300, now, now + 10 ) # time range 
extends when added - add_m = metrics + DataWriterMetrics("file", 99, 120, now - 10, now + 20) # type: ignore[assignment] + add_m = metrics + DataWriterMetrics("fileX", 99, 120, now - 10, now + 20) # type: ignore[assignment] assert add_m == DataWriterMetrics("", 109, 220, now - 10, now + 20) diff --git a/tests/common/storages/test_local_filesystem.py b/tests/common/storages/test_local_filesystem.py index 14e3cc23d4..1bfe6c0b5b 100644 --- a/tests/common/storages/test_local_filesystem.py +++ b/tests/common/storages/test_local_filesystem.py @@ -45,7 +45,7 @@ ) def test_local_path_win_configuration(bucket_url: str, file_url: str) -> None: assert FilesystemConfiguration.is_local_path(bucket_url) is True - assert FilesystemConfiguration.make_file_uri(bucket_url) == file_url + assert FilesystemConfiguration.make_file_url(bucket_url) == file_url c = resolve_configuration(FilesystemConfiguration(bucket_url)) assert c.protocol == "file" @@ -66,7 +66,7 @@ def test_local_path_win_configuration(bucket_url: str, file_url: str) -> None: def test_local_user_win_path_configuration(bucket_url: str) -> None: file_url = "file:///" + pathlib.Path(bucket_url).expanduser().as_posix().lstrip("/") assert FilesystemConfiguration.is_local_path(bucket_url) is True - assert FilesystemConfiguration.make_file_uri(bucket_url) == file_url + assert FilesystemConfiguration.make_file_url(bucket_url) == file_url c = resolve_configuration(FilesystemConfiguration(bucket_url)) assert c.protocol == "file" @@ -99,7 +99,7 @@ def test_file_win_configuration() -> None: ) def test_file_posix_configuration(bucket_url: str, file_url: str) -> None: assert FilesystemConfiguration.is_local_path(bucket_url) is True - assert FilesystemConfiguration.make_file_uri(bucket_url) == file_url + assert FilesystemConfiguration.make_file_url(bucket_url) == file_url c = resolve_configuration(FilesystemConfiguration(bucket_url)) assert c.protocol == "file" @@ -117,7 +117,7 @@ def test_file_posix_configuration(bucket_url: str, file_url: str) -> None: def test_local_user_posix_path_configuration(bucket_url: str) -> None: file_url = "file:///" + pathlib.Path(bucket_url).expanduser().as_posix().lstrip("/") assert FilesystemConfiguration.is_local_path(bucket_url) is True - assert FilesystemConfiguration.make_file_uri(bucket_url) == file_url + assert FilesystemConfiguration.make_file_url(bucket_url) == file_url c = resolve_configuration(FilesystemConfiguration(bucket_url)) assert c.protocol == "file" @@ -166,7 +166,7 @@ def test_file_filesystem_configuration( assert FilesystemConfiguration.make_local_path(bucket_url) == str( pathlib.Path(local_path).resolve() ) - assert FilesystemConfiguration.make_file_uri(local_path) == norm_bucket_url + assert FilesystemConfiguration.make_file_url(local_path) == norm_bucket_url if local_path == "": with pytest.raises(ConfigurationValueError): diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index baac3b7af5..a1334ba1da 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -16,7 +16,7 @@ LoadStorageConfiguration, FilesystemConfiguration, LoadPackageInfo, - TJobState, + TPackageJobState, LoadStorage, ) from dlt.common.storages import DataItemStorage, FileStorage @@ -195,7 +195,7 @@ def assert_package_info( storage: LoadStorage, load_id: str, package_state: str, - job_state: TJobState, + job_state: TPackageJobState, jobs_count: int = 1, ) -> LoadPackageInfo: package_info = storage.get_load_package_info(load_id) diff --git 
a/tests/destinations/test_destination_name_and_config.py b/tests/destinations/test_destination_name_and_config.py index 11de706722..1e432a7803 100644 --- a/tests/destinations/test_destination_name_and_config.py +++ b/tests/destinations/test_destination_name_and_config.py @@ -60,7 +60,7 @@ def test_set_name_and_environment() -> None: def test_preserve_destination_instance() -> None: dummy1 = dummy(destination_name="dummy1", environment="dev/null/1") filesystem1 = filesystem( - FilesystemConfiguration.make_file_uri(TEST_STORAGE_ROOT), + FilesystemConfiguration.make_file_url(TEST_STORAGE_ROOT), destination_name="local_fs", environment="devel", ) @@ -210,7 +210,7 @@ def test_destination_config_in_name(environment: DictStrStr) -> None: with pytest.raises(ConfigFieldMissingException): p.destination_client() - environment["DESTINATION__FILESYSTEM-PROD__BUCKET_URL"] = FilesystemConfiguration.make_file_uri( + environment["DESTINATION__FILESYSTEM-PROD__BUCKET_URL"] = FilesystemConfiguration.make_file_url( "_storage" ) assert p._fs_client().dataset_path.endswith(p.dataset_name) diff --git a/tests/extract/data_writers/test_buffered_writer.py b/tests/extract/data_writers/test_buffered_writer.py index 5cad5a35b9..205e3f83dc 100644 --- a/tests/extract/data_writers/test_buffered_writer.py +++ b/tests/extract/data_writers/test_buffered_writer.py @@ -7,12 +7,12 @@ from dlt.common.data_writers.exceptions import BufferedDataWriterClosed from dlt.common.data_writers.writers import ( DataWriter, - DataWriterMetrics, InsertValuesWriter, JsonlWriter, ALL_WRITERS, ) from dlt.common.destination.capabilities import TLoaderFileFormat, DestinationCapabilitiesContext +from dlt.common.metrics import DataWriterMetrics from dlt.common.schema.utils import new_column from dlt.common.storages.file_storage import FileStorage diff --git a/tests/extract/data_writers/test_data_item_storage.py b/tests/extract/data_writers/test_data_item_storage.py index feda51c229..558eeec79e 100644 --- a/tests/extract/data_writers/test_data_item_storage.py +++ b/tests/extract/data_writers/test_data_item_storage.py @@ -3,8 +3,9 @@ import pytest from dlt.common.configuration.container import Container -from dlt.common.data_writers.writers import DataWriterMetrics, DataWriter +from dlt.common.data_writers.writers import DataWriter from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.metrics import DataWriterMetrics from dlt.common.schema.utils import new_column from dlt.common.storages.data_item_storage import DataItemStorage diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index f4082a7d86..a9867aa54b 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -30,8 +30,10 @@ from dlt.sources.helpers.transform import take_first from dlt.extract.incremental import IncrementalResourceWrapper, Incremental from dlt.extract.incremental.exceptions import ( + IncrementalCursorInvalidCoercion, IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, + IncrementalCursorPathHasValueNone, ) from dlt.pipeline.exceptions import PipelineStepFailed @@ -43,6 +45,10 @@ ALL_TEST_DATA_ITEM_FORMATS, ) +from tests.pipeline.utils import assert_query_data + +import pyarrow as pa + @pytest.fixture(autouse=True) def switch_to_fifo(): @@ -166,8 +172,9 @@ def some_data(created_at=dlt.sources.incremental("created_at")): p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - p.extract(some_data()) + assert values == [None] + p.extract(some_data()) assert 
values == [None, 5] @@ -634,6 +641,458 @@ def some_data(last_timestamp=dlt.sources.incremental("item.timestamp")): assert pip_ex.value.__context__.json_path == "item.timestamp" +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_includes_records_and_updates_incremental_cursor_1( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": None}, + {"id": 2, "created_at": 1}, + {"id": 3, "created_at": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(id) from some_data", [3]) + assert_query_data(p, "select count(created_at) from some_data", [2]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 2 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_does_not_include_overlapping_records( + item_type: TestDataItemFormat, +) -> None: + @dlt.resource + def some_data( + invocation: int, + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include"), + ): + if invocation == 1: + yield data_to_item_format( + item_type, + [ + {"id": 1, "created_at": None}, + {"id": 2, "created_at": 1}, + {"id": 3, "created_at": 2}, + ], + ) + elif invocation == 2: + yield data_to_item_format( + item_type, + [ + {"id": 4, "created_at": 1}, + {"id": 5, "created_at": None}, + {"id": 6, "created_at": 3}, + ], + ) + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(1), destination="duckdb") + p.run(some_data(2), destination="duckdb") + + assert_query_data(p, "select id from some_data order by id", [1, 2, 3, 5, 6]) + assert_query_data( + p, "select created_at from some_data order by created_at", [1, 2, 3, None, None] + ) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 3 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_includes_records_and_updates_incremental_cursor_2( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2, "created_at": None}, + {"id": 3, "created_at": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(id) from some_data", [3]) + assert_query_data(p, "select count(created_at) from some_data", [2]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 2 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_includes_records_and_updates_incremental_cursor_3( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2, "created_at": 2}, + {"id": 3, "created_at": None}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include") + ): + yield source_items + + p = 
dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + assert_query_data(p, "select count(id) from some_data", [3]) + assert_query_data(p, "select count(created_at) from some_data", [2]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 2 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_includes_records_without_cursor_path( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + assert_query_data(p, "select count(id) from some_data", [2]) + assert_query_data(p, "select count(created_at) from some_data", [1]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 1 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_excludes_records_and_updates_incremental_cursor( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2, "created_at": 2}, + {"id": 3, "created_at": None}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="exclude") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + assert_query_data(p, "select count(id) from some_data", [2]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 2 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_can_raise_on_none_1(item_type: TestDataItemFormat) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2, "created_at": None}, + {"id": 3, "created_at": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="raise") + ): + yield source_items + + with pytest.raises(IncrementalCursorPathHasValueNone) as py_ex: + list(some_data()) + assert py_ex.value.json_path == "created_at" + + # same thing when run in pipeline + with pytest.raises(PipelineStepFailed) as pip_ex: + p = dlt.pipeline(pipeline_name=uniq_id()) + p.extract(some_data()) + + assert isinstance(pip_ex.value.__context__, IncrementalCursorPathHasValueNone) + assert pip_ex.value.__context__.json_path == "created_at" + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_can_raise_on_none_2(item_type: TestDataItemFormat) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2}, + {"id": 3, "created_at": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="raise") + ): + yield source_items + + # there is no fixed, error because cursor path is missing + if item_type == "object": + with pytest.raises(IncrementalCursorPathMissing) as ex: + list(some_data()) + assert ex.value.json_path == "created_at" + # there is a fixed schema, error 
because value is null + else: + with pytest.raises(IncrementalCursorPathHasValueNone) as e: + list(some_data()) + assert e.value.json_path == "created_at" + + # same thing when run in pipeline + with pytest.raises(PipelineStepFailed) as e: # type: ignore[assignment] + p = dlt.pipeline(pipeline_name=uniq_id()) + p.extract(some_data()) + if item_type == "object": + assert isinstance(e.value.__context__, IncrementalCursorPathMissing) + else: + assert isinstance(e.value.__context__, IncrementalCursorPathHasValueNone) + assert e.value.__context__.json_path == "created_at" # type: ignore[attr-defined] + + +@pytest.mark.parametrize("item_type", ["arrow-table", "arrow-batch", "pandas"]) +def test_cursor_path_none_can_raise_on_column_missing(item_type: TestDataItemFormat) -> None: + data = [ + {"id": 1}, + {"id": 2}, + {"id": 3}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="raise") + ): + yield source_items + + with pytest.raises(IncrementalCursorPathMissing) as py_ex: + list(some_data()) + assert py_ex.value.json_path == "created_at" + + # same thing when run in pipeline + with pytest.raises(PipelineStepFailed) as pip_ex: + p = dlt.pipeline(pipeline_name=uniq_id()) + p.extract(some_data()) + assert pip_ex.value.__context__.json_path == "created_at" # type: ignore[attr-defined] + assert isinstance(pip_ex.value.__context__, IncrementalCursorPathMissing) + + +def test_cursor_path_none_nested_can_raise_on_none_1() -> None: + # No nested json path support for pandas and arrow. See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[0].created_at", on_cursor_value_missing="raise" + ) + ): + yield {"data": {"items": [{"created_at": None}, {"created_at": 1}]}} + + with pytest.raises(IncrementalCursorPathHasValueNone) as e: + list(some_data()) + assert e.value.json_path == "data.items[0].created_at" + + +def test_cursor_path_none_nested_can_raise_on_none_2() -> None: + # No pandas and arrow. See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[*].created_at", on_cursor_value_missing="raise" + ) + ): + yield {"data": {"items": [{"created_at": None}, {"created_at": 1}]}} + + with pytest.raises(IncrementalCursorPathHasValueNone) as e: + list(some_data()) + assert e.value.json_path == "data.items[*].created_at" + + +def test_cursor_path_none_nested_can_include_on_none_1() -> None: + # No nested json path support for pandas and arrow. See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[*].created_at", on_cursor_value_missing="include" + ) + ): + yield { + "data": { + "items": [ + {"created_at": None}, + {"created_at": 1}, + ] + } + } + + results = list(some_data()) + assert results[0]["data"]["items"] == [ + {"created_at": None}, + {"created_at": 1}, + ] + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(*) from some_data__data__items", [2]) + + +def test_cursor_path_none_nested_can_include_on_none_2() -> None: + # No nested json path support for pandas and arrow. 
See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[0].created_at", on_cursor_value_missing="include" + ) + ): + yield { + "data": { + "items": [ + {"created_at": None}, + {"created_at": 1}, + ] + } + } + + results = list(some_data()) + assert results[0]["data"]["items"] == [ + {"created_at": None}, + {"created_at": 1}, + ] + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(*) from some_data__data__items", [2]) + + +def test_cursor_path_none_nested_includes_rows_without_cursor_path() -> None: + # No nested json path support for pandas and arrow. See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[*].created_at", on_cursor_value_missing="include" + ) + ): + yield { + "data": { + "items": [ + {"id": 1}, + {"id": 2, "created_at": 2}, + ] + } + } + + results = list(some_data()) + assert results[0]["data"]["items"] == [ + {"id": 1}, + {"id": 2, "created_at": 2}, + ] + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(*) from some_data__data__items", [2]) + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_set_default_value_for_incremental_cursor(item_type: TestDataItemFormat) -> None: + @dlt.resource + def some_data(created_at=dlt.sources.incremental("updated_at")): + yield data_to_item_format( + item_type, + [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 2, "created_at": 4, "updated_at": None}, + {"id": 3, "created_at": 3, "updated_at": 3}, + ], + ) + + def set_default_updated_at(record): + if record.get("updated_at") is None: + record["updated_at"] = record.get("created_at", pendulum.now().int_timestamp) + return record + + def set_default_updated_at_pandas(df): + df["updated_at"] = df["updated_at"].fillna(df["created_at"]) + return df + + def set_default_updated_at_arrow(records): + updated_at_is_null = pa.compute.is_null(records.column("updated_at")) + updated_at_filled = pa.compute.if_else( + updated_at_is_null, records.column("created_at"), records.column("updated_at") + ) + if item_type == "arrow-table": + records = records.set_column( + records.schema.get_field_index("updated_at"), + pa.field("updated_at", records.column("updated_at").type), + updated_at_filled, + ) + elif item_type == "arrow-batch": + columns = [records.column(i) for i in range(records.num_columns)] + columns[2] = updated_at_filled + records = pa.RecordBatch.from_arrays(columns, schema=records.schema) + return records + + if item_type == "object": + func = set_default_updated_at + elif item_type == "pandas": + func = set_default_updated_at_pandas + elif item_type in ["arrow-table", "arrow-batch"]: + func = set_default_updated_at_arrow + + result = list(some_data().add_map(func, insert_at=1)) + values = data_item_to_list(item_type, result) + assert data_item_length(values) == 3 + assert values[1]["updated_at"] == 4 + + # same for pipeline run + p = dlt.pipeline(pipeline_name=uniq_id()) + p.extract(some_data().add_map(func, insert_at=1)) + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "updated_at" + ] + assert s["last_value"] == 4 + + def test_json_path_cursor() -> None: @dlt.resource def some_data(last_timestamp=dlt.sources.incremental("item.timestamp|modifiedAt")): @@ -1303,7 +1762,7 @@ def some_data( ) # will cause invalid comparison if 
item_type == "object": - with pytest.raises(InvalidStepFunctionArguments): + with pytest.raises(IncrementalCursorInvalidCoercion): list(resource) else: data = data_item_to_list(item_type, list(resource)) @@ -2065,3 +2524,21 @@ def test_source(): incremental_steps = test_source_incremental().table_name._pipe._steps assert isinstance(incremental_steps[-2], ValidateItem) assert isinstance(incremental_steps[-1], IncrementalResourceWrapper) + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_date_coercion(item_type: TestDataItemFormat) -> None: + today = datetime.today().date() + + @dlt.resource() + def updated_is_int(updated_at=dlt.sources.incremental("updated_at", initial_value=today)): + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) + + pip_1_name = "test_pydantic_columns_validator_" + uniq_id() + pipeline = dlt.pipeline(pipeline_name=pip_1_name, destination="duckdb") + + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run(updated_is_int()) + assert isinstance(pip_ex.value.__cause__, IncrementalCursorInvalidCoercion) + assert pip_ex.value.__cause__.cursor_path == "updated_at" diff --git a/tests/libs/test_deltalake.py b/tests/libs/test_deltalake.py index 3e2d7cc3f6..dc5586eb32 100644 --- a/tests/libs/test_deltalake.py +++ b/tests/libs/test_deltalake.py @@ -95,21 +95,9 @@ def arrow_data( # type: ignore[return] client = cast(FilesystemClient, client) storage_options = _deltalake_storage_options(client.config) - with pytest.raises(Exception): - # bug in `delta-rs` causes error when writing big decimal values - # https://github.com/delta-io/delta-rs/issues/2510 - # if this test fails, the bug has been fixed and we should remove this - # note from the docs: - write_delta_table( - remote_dir + "/corrupt_delta_table", - arrow_table_all_data_types("arrow-table", include_decimal_default_precision=True)[0], - write_disposition="append", - storage_options=storage_options, - ) - arrow_table = arrow_table_all_data_types( "arrow-table", - include_decimal_default_precision=False, + include_decimal_default_precision=True, include_decimal_arrow_max_precision=True, num_rows=2, )[0] diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index f6a06180c9..bb989a887c 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -3,9 +3,12 @@ pytest.importorskip("databricks") +from dlt.common.exceptions import TerminalValueError +from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob +from dlt.common.configuration import resolve_configuration +from dlt.destinations import databricks from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration -from dlt.common.configuration import resolve_configuration # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -34,3 +37,48 @@ def test_databricks_credentials_to_connector_params(): assert params["extra_a"] == "a" assert params["extra_b"] == "b" assert params["_socket_timeout"] == credentials.socket_timeout + + +def test_databricks_configuration() -> None: + bricks = databricks() + config = bricks.configuration(None, accept_partial=True) + assert config.is_staging_external_location is False + assert config.staging_credentials_name is None + + os.environ["IS_STAGING_EXTERNAL_LOCATION"] = "true" + os.environ["STAGING_CREDENTIALS_NAME"] = "credential" + config = 
bricks.configuration(None, accept_partial=True) + assert config.is_staging_external_location is True + assert config.staging_credentials_name == "credential" + + # explicit params + bricks = databricks(is_staging_external_location=None, staging_credentials_name="credential2") + config = bricks.configuration(None, accept_partial=True) + assert config.staging_credentials_name == "credential2" + assert config.is_staging_external_location is None + + +def test_databricks_abfss_converter() -> None: + with pytest.raises(TerminalValueError): + DatabricksLoadJob.ensure_databricks_abfss_url("az://dlt-ci-test-bucket") + + abfss_url = DatabricksLoadJob.ensure_databricks_abfss_url( + "az://dlt-ci-test-bucket", "my_account" + ) + assert abfss_url == "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net" + + abfss_url = DatabricksLoadJob.ensure_databricks_abfss_url( + "az://dlt-ci-test-bucket/path/to/file.parquet", "my_account" + ) + assert ( + abfss_url + == "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet" + ) + + abfss_url = DatabricksLoadJob.ensure_databricks_abfss_url( + "az://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet" + ) + assert ( + abfss_url + == "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet" + ) diff --git a/tests/load/filesystem/test_filesystem_common.py b/tests/load/filesystem/test_filesystem_common.py index 3cad7dda2c..29ca1a2b57 100644 --- a/tests/load/filesystem/test_filesystem_common.py +++ b/tests/load/filesystem/test_filesystem_common.py @@ -3,8 +3,8 @@ from typing import Tuple, Union, Dict from urllib.parse import urlparse - -from fsspec import AbstractFileSystem +from fsspec import AbstractFileSystem, get_filesystem_class, register_implementation +from fsspec.core import filesystem as fs_filesystem import pytest from tenacity import retry, stop_after_attempt, wait_fixed @@ -15,6 +15,7 @@ from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs import AnyAzureCredentials from dlt.common.storages import fsspec_from_config, FilesystemConfiguration +from dlt.common.storages.configuration import make_fsspec_url from dlt.common.storages.fsspec_filesystem import MTIME_DISPATCH, glob_files from dlt.common.utils import custom_environ, uniq_id from dlt.destinations import filesystem @@ -22,11 +23,12 @@ FilesystemDestinationClientConfiguration, ) from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders + +from tests.common.configuration.utils import environment from tests.common.storages.utils import TEST_SAMPLE_FILES, assert_sample_files -from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AWS_BUCKET +from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AWS_BUCKET, WITH_GDRIVE_BUCKETS from tests.utils import autouse_test_storage -from .utils import self_signed_cert -from tests.common.configuration.utils import environment +from tests.load.filesystem.utils import self_signed_cert # mark all tests as essential, do not remove @@ -53,6 +55,24 @@ def test_filesystem_configuration() -> None: } +@pytest.mark.parametrize("bucket_url", WITH_GDRIVE_BUCKETS) +def test_remote_url(bucket_url: str) -> None: + # make absolute urls out of paths + scheme = urlparse(bucket_url).scheme + if not scheme: + scheme = "file" + bucket_url = FilesystemConfiguration.make_file_url(bucket_url) + if scheme == "gdrive": + from dlt.common.storages.fsspecs.google_drive import GoogleDriveFileSystem + + register_implementation("gdrive", GoogleDriveFileSystem, 
"GoogleDriveFileSystem") + + fs_class = get_filesystem_class(scheme) + fs_path = fs_class._strip_protocol(bucket_url) + # reconstitute url + assert make_fsspec_url(scheme, fs_path, bucket_url) == bucket_url + + def test_filesystem_instance(with_gdrive_buckets_env: str) -> None: @retry(stop=stop_after_attempt(10), wait=wait_fixed(1), reraise=True) def check_file_exists(filedir_: str, file_url_: str): @@ -72,10 +92,8 @@ def check_file_changed(file_url_: str): bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] config = get_config() # we do not add protocol to bucket_url (we need relative path) - assert bucket_url.startswith(config.protocol) or config.protocol == "file" + assert bucket_url.startswith(config.protocol) or config.is_local_filesystem filesystem, url = fsspec_from_config(config) - if config.protocol != "file": - assert bucket_url.endswith(url) # do a few file ops now = pendulum.now() filename = f"filesystem_common_{uniq_id()}" @@ -113,7 +131,9 @@ def test_glob_overlapping_path_files(with_gdrive_buckets_env: str) -> None: # "standard_source/sample" overlaps with a real existing "standard_source/samples". walk operation on azure # will return all files from "standard_source/samples" and report the wrong "standard_source/sample" path to the user # here we test we do not have this problem with out glob - bucket_url, _, filesystem = glob_test_setup(bucket_url, "standard_source/sample") + bucket_url, config, filesystem = glob_test_setup(bucket_url, "standard_source/sample") + if config.protocol in ["file"]: + pytest.skip(f"{config.protocol} not supported in this test") # use glob to get data all_file_items = list(glob_files(filesystem, bucket_url)) assert len(all_file_items) == 0 @@ -272,18 +292,18 @@ def glob_test_setup( config = get_config() # enable caches config.read_only = True - if config.protocol in ["file"]: - pytest.skip(f"{config.protocol} not supported in this test") # may contain query string - bucket_url_parsed = urlparse(bucket_url) - bucket_url = bucket_url_parsed._replace( - path=posixpath.join(bucket_url_parsed.path, glob_folder) - ).geturl() - filesystem, _ = fsspec_from_config(config) + filesystem, fs_path = fsspec_from_config(config) + bucket_url = make_fsspec_url(config.protocol, posixpath.join(fs_path, glob_folder), bucket_url) if config.protocol == "memory": - mem_path = os.path.join("m", "standard_source") + mem_path = os.path.join("/m", "standard_source") if not filesystem.isdir(mem_path): filesystem.mkdirs(mem_path) filesystem.upload(TEST_SAMPLE_FILES, mem_path, recursive=True) + if config.protocol == "file": + file_path = os.path.join("_storage", "standard_source") + if not filesystem.isdir(file_path): + filesystem.mkdirs(file_path) + filesystem.upload(TEST_SAMPLE_FILES, file_path, recursive=True) return bucket_url, config, filesystem diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_merge.py similarity index 52% rename from tests/load/lancedb/test_remove_orphaned_records.py rename to tests/load/lancedb/test_merge.py index 37bca8ffb0..f04c846df7 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_merge.py @@ -6,18 +6,17 @@ from lancedb.table import Table # type: ignore from pandas import DataFrame from pandas.testing import assert_frame_equal -from pyarrow import Table import dlt -from dlt.common.typing import DictStrAny +from dlt.common.typing import DictStrAny, DictStrStr from dlt.common.utils import uniq_id from dlt.destinations.impl.lancedb.lancedb_adapter import ( 
- DOCUMENT_ID_HINT, lancedb_adapter, ) from tests.load.lancedb.utils import chunk_document from tests.load.utils import ( drop_active_pipeline_data, + sequence_generator, ) from tests.pipeline.utils import ( assert_load_info, @@ -34,7 +33,7 @@ def drop_lancedb_data() -> Iterator[None]: drop_active_pipeline_data() -def test_lancedb_remove_orphaned_records() -> None: +def test_lancedb_remove_nested_orphaned_records() -> None: pipeline = dlt.pipeline( pipeline_name="test_lancedb_remove_orphaned_records", destination="lancedb", @@ -42,10 +41,11 @@ def test_lancedb_remove_orphaned_records() -> None: dev_mode=True, ) - @dlt.resource( # type: ignore[call-overload] + @dlt.resource( table_name="parent", - write_disposition="merge", - columns={"id": {DOCUMENT_ID_HINT: True}}, + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key="id", + merge_key="id", ) def identity_resource( data: List[DictStrAny], @@ -76,16 +76,24 @@ def identity_resource( { "id": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}], - }, # Removed one child and one grandchild + }, # Removes bar_2, baz_2 and baz_3. { "id": 2, "child": [{"bar": 4, "grandchild": [{"baz": 8}]}], - }, # Changed child and grandchild + }, # Removes bar_3, baz_4. ] info = pipeline.run(identity_resource(run_2)) assert_load_info(info) with pipeline.destination_client() as client: + expected_parent_data = pd.DataFrame( + data=[ + {"id": 1}, + {"id": 2}, + {"id": 3}, + ] + ) + expected_child_data = pd.DataFrame( data=[ {"bar": 1}, @@ -105,32 +113,29 @@ def identity_resource( ] ) + parent_table_name = client.make_qualified_table_name("parent") # type: ignore[attr-defined] child_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] grandchild_table_name = client.make_qualified_table_name( # type: ignore[attr-defined] "parent__child__grandchild" ) + parent_tbl = client.db_client.open_table(parent_table_name) # type: ignore[attr-defined] child_tbl = client.db_client.open_table(child_table_name) # type: ignore[attr-defined] grandchild_tbl = client.db_client.open_table(grandchild_table_name) # type: ignore[attr-defined] - actual_child_df = ( - child_tbl.to_pandas() - .sort_values(by="bar") - .reset_index(drop=True) - .reset_index(drop=True) - ) + actual_parent_df = parent_tbl.to_pandas().sort_values(by="id").reset_index(drop=True) + actual_child_df = child_tbl.to_pandas().sort_values(by="bar").reset_index(drop=True) actual_grandchild_df = ( - grandchild_tbl.to_pandas() - .sort_values(by="baz") - .reset_index(drop=True) - .reset_index(drop=True) + grandchild_tbl.to_pandas().sort_values(by="baz").reset_index(drop=True) ) + expected_parent_data = expected_parent_data.sort_values(by="id").reset_index(drop=True) expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( drop=True ) + assert_frame_equal(actual_parent_df[["id"]], expected_parent_data) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -143,17 +148,19 @@ def test_lancedb_remove_orphaned_records_root_table() -> None: dev_mode=True, ) - @dlt.resource( # type: ignore[call-overload] + @dlt.resource( table_name="root", - write_disposition="merge", - merge_key=["chunk_hash"], - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + 
merge_key=["doc_id"], ) def identity_resource( data: List[DictStrAny], ) -> Generator[List[DictStrAny], None, None]: yield data + lancedb_adapter(identity_resource) + run_1 = [ {"doc_id": 1, "chunk_hash": "1a"}, {"doc_id": 2, "chunk_hash": "2a"}, @@ -197,11 +204,76 @@ def identity_resource( assert_frame_equal(actual_root_df, expected_root_table_df) +def test_lancedb_remove_orphaned_records_root_table_string_doc_id() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_root_table", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource) + + run_1 = [ + {"doc_id": "A", "chunk_hash": "1a"}, + {"doc_id": "B", "chunk_hash": "2a"}, + {"doc_id": "B", "chunk_hash": "2b"}, + {"doc_id": "B", "chunk_hash": "2c"}, + {"doc_id": "C", "chunk_hash": "3a"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": "B", "chunk_hash": "2d"}, + {"doc_id": "B", "chunk_hash": "2e"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": "A", "chunk_hash": "1a"}, + {"doc_id": "B", "chunk_hash": "2d"}, + {"doc_id": "B", "chunk_hash": "2e"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) + )[["doc_id", "chunk_hash"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: @dlt.resource( - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, table_name="document", - merge_key=["chunk"], + primary_key=["doc_id", "chunk"], + merge_key="doc_id", ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: @@ -215,7 +287,10 @@ def documents_source( ) -> Any: return documents(docs) - lancedb_adapter(documents, embed=["chunk"], document_id="doc_id") + lancedb_adapter( + documents, + embed=["chunk"], + ) pipeline = dlt.pipeline( pipeline_name="test_lancedb_remove_orphaned_records_with_embeddings", @@ -262,7 +337,89 @@ def documents_source( # Check (non-empty) embeddings as present, and that orphaned embeddings have been discarded. 
assert len(df) == 21 - assert "vector__" in df.columns - for _, vector in enumerate(df["vector__"]): + assert "vector" in df.columns + for _, vector in enumerate(df["vector"]): assert isinstance(vector, np.ndarray) assert vector.size > 0 + + +def test_lancedb_compound_merge_key_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_compound_merge_key", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id", "chunk_hash"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource, no_remove_orphans=True) + + run_1 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash", "foo"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash", "foo"]).reset_index(drop=True) + )[["doc_id", "chunk_hash", "foo"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_must_provide_at_least_primary_key_on_merge_disposition() -> None: + """We need upsert merge's deterministic _dlt_id to perform orphan removal. + Hence, we require at least the primary key required (raises exception if missing). 
+ Specify a merge key for custom orphan identification.""" + generator_instance1 = sequence_generator() + + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "upsert"}) + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + pipeline = dlt.pipeline( + pipeline_name="test_must_provide_both_primary_and_merge_key_on_merge_disposition", + destination="lancedb", + dataset_name=( + f"test_must_provide_both_primary_and_merge_key_on_merge_disposition{uniq_id()}" + ), + ) + with pytest.raises(Exception): + load_info = pipeline.run( + some_data(), + ) + assert_load_info(load_info) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 4b964604e6..dcbe0eb04e 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,7 +1,11 @@ +import multiprocessing from typing import Iterator, Generator, Any, List +from typing import Mapping from typing import Union, Dict import pytest +from lancedb import DBConnection # type: ignore +from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore from lancedb.table import Table # type: ignore import dlt @@ -12,11 +16,9 @@ from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, VECTORIZE_HINT, - DOCUMENT_ID_HINT, ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient from dlt.extract import DltResource -from dlt.pipeline.exceptions import PipelineStepFailed from tests.load.lancedb.utils import assert_table, chunk_document, mock_embed from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info @@ -27,7 +29,7 @@ @pytest.fixture(autouse=True) -def drop_lancedb_data() -> Iterator[None]: +def drop_lancedb_data() -> Iterator[Any]: yield drop_active_pipeline_data() @@ -54,14 +56,14 @@ def some_data() -> Generator[DictStrStr, Any, None]: lancedb_adapter( some_data, - document_id=["content"], + merge_key=["content"], ) assert some_data.columns["content"] == { # type: ignore "name": "content", "data_type": "text", "x-lancedb-embed": True, - "x-lancedb-doc-id": True, + "merge_key": True, } @@ -133,14 +135,13 @@ def some_data() -> Generator[DictStrStr, Any, None]: def test_explicit_append() -> None: - """Append should work even when the primary key is specified.""" data = [ {"doc_id": 1, "content": "1"}, {"doc_id": 2, "content": "2"}, {"doc_id": 3, "content": "3"}, ] - @dlt.resource(primary_key="doc_id") + @dlt.resource() def some_data() -> Generator[List[DictStrAny], Any, None]: yield data @@ -157,6 +158,7 @@ def some_data() -> Generator[List[DictStrAny], Any, None]: info = pipeline.run( some_data(), ) + assert_load_info(info) assert_table(pipeline, "some_data", items=data) @@ -278,23 +280,11 @@ def test_pipeline_merge() -> None: }, ] - @dlt.resource(primary_key="doc_id") + @dlt.resource(primary_key=["doc_id"]) def movies_data() -> Any: yield data - @dlt.resource(primary_key="doc_id", merge_key=["merge_id", "title"]) - def movies_data_explicit_merge_keys() -> Any: - yield data - - lancedb_adapter( - movies_data, - embed=["description"], - ) - - lancedb_adapter( - movies_data_explicit_merge_keys, - embed=["description"], - ) + lancedb_adapter(movies_data, embed=["description"], no_remove_orphans=True) pipeline = dlt.pipeline( pipeline_name="movies", @@ -303,7 +293,7 @@ def movies_data_explicit_merge_keys() -> Any: ) info = pipeline.run( movies_data(), - write_disposition="merge", + write_disposition={"disposition": "merge", 
"strategy": "upsert"}, dataset_name=f"MoviesDataset{uniq_id()}", ) assert_load_info(info) @@ -314,26 +304,11 @@ def movies_data_explicit_merge_keys() -> Any: info = pipeline.run( movies_data(), - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, ) assert_load_info(info) assert_table(pipeline, "movies_data", items=data) - info = pipeline.run( - movies_data(), - write_disposition="merge", - ) - assert_load_info(info) - assert_table(pipeline, "movies_data", items=data) - - # Test with explicit merge keys. - info = pipeline.run( - movies_data_explicit_merge_keys(), - write_disposition="merge", - ) - assert_load_info(info) - assert_table(pipeline, "movies_data_explicit_merge_keys", items=data) - def test_pipeline_with_schema_evolution() -> None: data = [ @@ -403,9 +378,9 @@ def test_merge_github_nested() -> None: data = json.load(f) info = pipe.run( - lancedb_adapter(data[:17], embed=["title", "body"]), + lancedb_adapter(data[:17], embed=["title", "body"], no_remove_orphans=True), table_name="issues", - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, primary_key="id", ) assert_load_info(info) @@ -441,23 +416,23 @@ def test_merge_github_nested() -> None: def test_empty_dataset_allowed() -> None: # dataset_name is optional so dataset name won't be autogenerated when not explicitly passed. pipe = dlt.pipeline(destination="lancedb", dev_mode=True) - client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) # Dataset in load info is empty. assert info.dataset_name is None - client = pipe.destination_client() # type: ignore[assignment] - assert client.dataset_name is None - assert client.sentinel_table == "dltSentinelTable" + client = pipe.destination_client() + assert client.dataset_name is None # type: ignore + assert client.sentinel_table == "dltSentinelTable" # type: ignore assert_table(pipe, "content", expected_items_count=3) -def test_merge_no_orphans() -> None: +def test_lancedb_remove_nested_orphaned_records_with_chunks() -> None: @dlt.resource( - write_disposition="merge", - primary_key=["doc_id"], + write_disposition={"disposition": "merge", "strategy": "upsert"}, table_name="document", + primary_key=["doc_id"], + merge_key=["doc_id"], ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: @@ -551,151 +526,124 @@ def documents_source( assert set(df["chunk_text"]) == expected_text -def test_merge_no_orphans_with_doc_id() -> None: - @dlt.resource( # type: ignore - write_disposition="merge", - table_name="document", - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, - ) - def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: - for doc in docs: - doc_id = doc["doc_id"] - chunks = chunk_document(doc["text"]) - embeddings = [ - { - "chunk_hash": digest128(chunk), - "chunk_text": chunk, - "embedding": mock_embed(), - } - for chunk in chunks - ] - yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} +search_data = [ + {"text": "Frodo was a happy puppy"}, + {"text": "There are several kittens playing"}, +] - @dlt.source(max_table_nesting=1) - def documents_source( - docs: List[DictStrAny], - ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: - return documents(docs) + +def test_fts_query() -> None: + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, 
object], Any, None]: + yield from search_data pipeline = dlt.pipeline( - pipeline_name="chunked_docs", + pipeline_name="test_fts_query", destination="lancedb", - dataset_name="chunked_documents", - dev_mode=True, + dataset_name=f"test_pipeline_append{uniq_id()}", + ) + info = pipeline.run( + search_data_resource(), ) + assert_load_info(info) - initial_docs = [ - { - "text": ( - "This is the first document. It contains some text that will be chunked and" - " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" - ), - "doc_id": 1, - }, - { - "text": "Here's another document. It's a bit different from the first one.", - "doc_id": 2, - }, - ] + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore[assignment] + db_client: DBConnection = client.db_client - info = pipeline.run(documents_source(initial_docs)) - assert_load_info(info) + table_name = client.make_qualified_table_name("search_data_resource") + tbl = db_client[table_name] + tbl.checkout_latest() - updated_docs = [ - { - "text": "This is the first document, but it has been updated with new content.", - "doc_id": 1, - }, - { - "text": "This is a completely new document that wasn't in the initial set.", - "doc_id": 3, - }, - ] + tbl.create_fts_index("text") + results = tbl.search("kittens", query_type="fts").select(["text"]).to_list() + assert results[0]["text"] == "There are several kittens playing" - info = pipeline.run(documents_source(updated_docs)) + +def test_semantic_query() -> None: + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, object], Any, None]: + yield from search_data + + lancedb_adapter( + search_data_resource, + embed=["text"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_fts_query", + destination="lancedb", + dataset_name=f"test_pipeline_append{uniq_id()}", + ) + info = pipeline.run( + search_data_resource(), + ) assert_load_info(info) - with pipeline.destination_client() as client: - # Orphaned chunks/documents must have been discarded. - # Shouldn't contain any text from `initial_docs' where doc_id=1. - expected_text = { - "Here's ano", - "ther docum", - "ent. 
It's ", - "a bit diff", - "erent from", - " the first", - " one.", - "This is th", - "e first do", - "cument, bu", - "t it has b", - "een update", - "d with new", - " content.", - "This is a ", - "completely", - " new docum", - "ent that w", - "asn't in t", - "he initial", - " set.", - } + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore[assignment] + db_client: DBConnection = client.db_client - embeddings_table_name = client.make_qualified_table_name("document__embeddings") # type: ignore[attr-defined] + table_name = client.make_qualified_table_name("search_data_resource") + tbl = db_client[table_name] + tbl.checkout_latest() - tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] - df = tbl.to_pandas() - assert set(df["chunk_text"]) == expected_text + results = ( + tbl.search("puppy", query_type="vector", ordering_field_name="_distance") + .select(["text"]) + .to_list() + ) + assert results[0]["text"] == "Frodo was a happy puppy" -def test_primary_key_not_compatible_with_doc_id_hint_on_merge_disposition() -> None: - @dlt.resource( # type: ignore - write_disposition="merge", - table_name="document", - primary_key="doc_id", - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, - ) - def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: - for doc in docs: - doc_id = doc["doc_id"] - chunks = chunk_document(doc["text"]) - embeddings = [ - { - "chunk_hash": digest128(chunk), - "chunk_text": chunk, - "embedding": mock_embed(), - } - for chunk in chunks - ] - yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} +def test_semantic_query_custom_embedding_functions_registered() -> None: + """Test the LanceDB registry registered custom embedding functions defined in models, if any. + See: https://github.com/dlt-hub/dlt/issues/1765""" - @dlt.source(max_table_nesting=1) - def documents_source( - docs: List[DictStrAny], - ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: - return documents(docs) + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, object], Any, None]: + yield from search_data + + lancedb_adapter( + search_data_resource, + embed=["text"], + ) pipeline = dlt.pipeline( - pipeline_name="test_mandatory_doc_id_hint_on_merge_disposition", + pipeline_name="test_fts_query", destination="lancedb", - dataset_name="test_mandatory_doc_id_hint_on_merge_disposition", - dev_mode=True, + dataset_name=f"test_pipeline_append{uniq_id()}", ) + info = pipeline.run( + search_data_resource(), + ) + assert_load_info(info) - initial_docs = [ - { - "text": ( - "This is the first document. It contains some text that will be chunked and" - " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" - ), - "doc_id": 1, - }, - { - "text": "Here's another document. It's a bit different from the first one.", - "doc_id": 2, - }, - ] + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore[assignment] + db_client_uri = client.db_client.uri + table_name = client.make_qualified_table_name("search_data_resource") + + # A new python process doesn't seem to correctly deserialize the custom embedding + # functions into global __REGISTRY__. + # We make sure to reset it as well to make sure no globals are propagated to the spawned process. 
+ EmbeddingFunctionRegistry().reset() + with multiprocessing.get_context("spawn").Pool(1) as pool: + results = pool.apply(run_lance_search_in_separate_process, (db_client_uri, table_name)) - with pytest.raises(PipelineStepFailed): - pipeline.run(documents(initial_docs)) + assert results[0]["text"] == "Frodo was a happy puppy" + + +def run_lance_search_in_separate_process(db_client_uri: str, table_name: str) -> Any: + import lancedb + + # Must read into __REGISTRY__ here. + db = lancedb.connect(db_client_uri) + tbl = db[table_name] + tbl.checkout_latest() + + return ( + tbl.search("puppy", query_type="vector", ordering_field_name="_distance") + .select(["text"]) + .to_list() + ) diff --git a/tests/load/lancedb/test_utils.py b/tests/load/lancedb/test_utils.py new file mode 100644 index 0000000000..2f517aac8e --- /dev/null +++ b/tests/load/lancedb/test_utils.py @@ -0,0 +1,32 @@ +import pyarrow as pa +import pytest + +from dlt.destinations.impl.lancedb.utils import fill_empty_source_column_values_with_placeholder + + +# Mark all tests as essential, don't remove. +pytestmark = pytest.mark.essential + + +def test_fill_empty_source_column_values_with_placeholder() -> None: + data = [ + pa.array(["", "hello", ""]), + pa.array(["hello", None, ""]), + pa.array([1, 2, 3]), + pa.array(["world", "", "arrow"]), + ] + table = pa.Table.from_arrays(data, names=["A", "B", "C", "D"]) + + source_columns = ["A", "B"] + placeholder = "placeholder" + + new_table = fill_empty_source_column_values_with_placeholder(table, source_columns, placeholder) + + expected_data = [ + pa.array(["placeholder", "hello", "placeholder"]), + pa.array(["hello", "placeholder", "placeholder"]), + pa.array([1, 2, 3]), + pa.array(["world", "", "arrow"]), + ] + expected_table = pa.Table.from_arrays(expected_data, names=["A", "B", "C", "D"]) + assert new_table.equals(expected_table) diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index 8dd56d22aa..8e2fddfba5 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -51,8 +51,7 @@ def assert_table( drop_keys = [ "_dlt_id", "_dlt_load_id", - dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", - dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector", ] objects_without_dlt_or_special_keys = [ {k: v for k, v in record.items() if k not in drop_keys} for record in records diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index d6cf3ec3e8..3f3896de6c 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -55,8 +55,8 @@ def test_alter_table(client: MsSqlJobClient) -> None: # existing table has no columns sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)[0] sqlfluff.parse(sql, dialect="tsql") - canonical_name = client.sql_client.make_qualified_table_name("event_test_table") - assert sql.count(f"ALTER TABLE {canonical_name}\nADD") == 1 + qualified_name = client.sql_client.make_qualified_table_name("event_test_table") + assert sql.count(f"ALTER TABLE {qualified_name}\nADD") == 1 assert "event_test_table" in sql assert '"col1" bigint NOT NULL' in sql assert '"col2" float NOT NULL' in sql @@ -75,3 +75,11 @@ def test_alter_table(client: MsSqlJobClient) -> None: assert '"col6_precision" decimal(6,2) NOT NULL' in sql assert '"col7_precision" varbinary(19)' in sql assert 
'"col11_precision" time(3) NOT NULL' in sql + + +def test_create_dlt_table(client: MsSqlJobClient) -> None: + # non existing table + sql = client._get_table_update_sql("_dlt_version", TABLE_UPDATE, False)[0] + sqlfluff.parse(sql, dialect="tsql") + qualified_name = client.sql_client.make_qualified_table_name("_dlt_version") + assert f"CREATE TABLE {qualified_name}" in sql diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py new file mode 100644 index 0000000000..5f8641f9fa --- /dev/null +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -0,0 +1,85 @@ +import pytest +import os + +from dlt.common.utils import uniq_id +from tests.load.utils import DestinationTestConfiguration, destinations_configs, AZ_BUCKET +from tests.pipeline.utils import assert_load_info + + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, bucket_subset=(AZ_BUCKET), subset=("databricks",) + ), + ids=lambda x: x.name, +) +def test_databricks_external_location(destination_config: DestinationTestConfiguration) -> None: + # do not interfere with state + os.environ["RESTORE_FROM_DESTINATION"] = "False" + dataset_name = "test_databricks_external_location" + uniq_id() + + from dlt.destinations import databricks, filesystem + from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob + + abfss_bucket_url = DatabricksLoadJob.ensure_databricks_abfss_url(AZ_BUCKET, "dltdata") + stage = filesystem(abfss_bucket_url) + + # should load abfss formatted url just fine + bricks = databricks(is_staging_external_location=False) + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", + dataset_name=dataset_name, + destination=bricks, + staging=stage, + ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert_load_info(info) + # get metrics + metrics = info.metrics[info.loads_ids[0]][0] + remote_url = list(metrics["job_metrics"].values())[0].remote_url + # abfss form was preserved + assert remote_url.startswith(abfss_bucket_url) + + # should fail on internal config error as external location is not configured + bricks = databricks(is_staging_external_location=True) + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", + dataset_name=dataset_name, + destination=bricks, + staging=stage, + ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert info.has_failed_jobs is True + assert ( + "Invalid configuration value detected" + in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message + ) + + # should fail on non existing stored credentials + bricks = databricks(is_staging_external_location=False, staging_credentials_name="CREDENTIAL_X") + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", + dataset_name=dataset_name, + destination=bricks, + staging=stage, + ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert info.has_failed_jobs is True + assert ( + "credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message + ) + + # should fail on non existing stored credentials + # auto stage with regular az:// used + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", dataset_name=dataset_name, destination=bricks + ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert info.has_failed_jobs is True + assert ( + "credential_x" in 
pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message + ) diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 759f443546..bc6cbd9848 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -12,9 +12,10 @@ from dlt.common import json from dlt.common import pendulum +from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.utils import uniq_id -from dlt.common.exceptions import DependencyVersionException +from dlt.common.schema.typing import TWriteDisposition from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders @@ -299,6 +300,17 @@ def data_types(): assert len(rows) == 10 assert_all_data_types_row(rows[0], schema=column_schemas) + # make sure remote_url is in metrics + metrics = info.metrics[info.loads_ids[0]][0] + # TODO: only final copy job has remote_url. not the initial (empty) job for particular files + # we could implement an empty job for delta that generates correct remote_url + remote_url = list(metrics["job_metrics"].values())[-1].remote_url + assert remote_url.endswith("data_types") + bucket_url = destination_config.bucket_url + if FilesystemConfiguration.is_local_path(bucket_url): + bucket_url = FilesystemConfiguration.make_file_url(bucket_url) + assert remote_url.startswith(bucket_url) + # another run should append rows to the table info = pipeline.run(data_types()) assert_load_info(info) @@ -567,6 +579,104 @@ def two_part(): assert dt.metadata().partition_columns == [] +@pytest.mark.essential +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) +@pytest.mark.parametrize( + "write_disposition", + ( + "append", + "replace", + pytest.param({"disposition": "merge", "strategy": "upsert"}, id="upsert"), + ), +) +def test_delta_table_schema_evolution( + destination_config: DestinationTestConfiguration, + write_disposition: TWriteDisposition, +) -> None: + """Tests schema evolution (adding new columns) for `delta` table format.""" + from dlt.common.libs.deltalake import get_delta_tables, ensure_delta_compatible_arrow_data + from dlt.common.libs.pyarrow import pyarrow + + @dlt.resource( + write_disposition=write_disposition, + primary_key="pk", + table_format="delta", + ) + def delta_table(data): + yield data + + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + # create Arrow table with one column, one row + pk_field = pyarrow.field("pk", pyarrow.int64(), nullable=False) + schema = pyarrow.schema([pk_field]) + arrow_table = pyarrow.Table.from_pydict({"pk": [1]}, schema=schema) + assert arrow_table.shape == (1, 1) + + # initial load + info = pipeline.run(delta_table(arrow_table)) + assert_load_info(info) + dt = get_delta_tables(pipeline, "delta_table")["delta_table"] + expected = ensure_delta_compatible_arrow_data(arrow_table) + actual = dt.to_pyarrow_table() + assert actual.equals(expected) + + # create Arrow table with many columns, two rows + arrow_table = arrow_table_all_data_types( + "arrow-table", + include_decimal_default_precision=True, + include_decimal_arrow_max_precision=True, + include_not_normalized_name=False, + include_null=False, + 
num_rows=2, + )[0] + arrow_table = arrow_table.add_column(0, pk_field, [[1, 2]]) + + # second load — this should evolve the schema (i.e. add the new columns) + info = pipeline.run(delta_table(arrow_table)) + assert_load_info(info) + dt = get_delta_tables(pipeline, "delta_table")["delta_table"] + actual = dt.to_pyarrow_table() + expected = ensure_delta_compatible_arrow_data(arrow_table) + if write_disposition == "append": + # just check shape and schema for `append`, because table comparison is + # more involved than with the other dispositions + assert actual.num_rows == 3 + actual.schema.equals(expected.schema) + else: + assert actual.sort_by("pk").equals(expected.sort_by("pk")) + + # create empty Arrow table with additional column + arrow_table = arrow_table.append_column( + pyarrow.field("another_new_column", pyarrow.string()), + [["foo", "foo"]], + ) + empty_arrow_table = arrow_table.schema.empty_table() + + # load 3 — this should evolve the schema without changing data + info = pipeline.run(delta_table(empty_arrow_table)) + assert_load_info(info) + dt = get_delta_tables(pipeline, "delta_table")["delta_table"] + actual = dt.to_pyarrow_table() + expected_schema = ensure_delta_compatible_arrow_data(arrow_table).schema + assert actual.schema.equals(expected_schema) + expected_num_rows = 3 if write_disposition == "append" else 2 + assert actual.num_rows == expected_num_rows + # new column should have NULLs only + assert ( + actual.column("another_new_column").combine_chunks().to_pylist() + == [None] * expected_num_rows + ) + + @pytest.mark.parametrize( "destination_config", destinations_configs( @@ -594,7 +704,7 @@ def delta_table(data): # create empty Arrow table with schema arrow_table = arrow_table_all_data_types( "arrow-table", - include_decimal_default_precision=False, + include_decimal_default_precision=True, include_decimal_arrow_max_precision=True, include_not_normalized_name=False, include_null=False, @@ -630,22 +740,6 @@ def delta_table(data): ensure_delta_compatible_arrow_data(empty_arrow_table).schema ) - # run 3: empty Arrow table with different schema - # this should not alter the Delta table - empty_arrow_table_2 = pa.schema( - [pa.field("foo", pa.int64()), pa.field("bar", pa.string())] - ).empty_table() - - info = pipeline.run(delta_table(empty_arrow_table_2)) - assert_load_info(info) - dt = get_delta_tables(pipeline, "delta_table")["delta_table"] - assert dt.version() == 1 # still 1, no new commit was done - dt_arrow_table = dt.to_pyarrow_table() - assert dt_arrow_table.shape == (2, empty_arrow_table.num_columns) # shape did not change - assert dt_arrow_table.schema.equals( # schema did not change - ensure_delta_compatible_arrow_data(empty_arrow_table).schema - ) - # test `dlt.mark.materialize_table_schema()` users_materialize_table_schema.apply_hints(table_format="delta") info = pipeline.run(users_materialize_table_schema()) @@ -797,6 +891,67 @@ def parent_delta(): with pytest.raises(ValueError): get_delta_tables(pipeline, "non_existing_table") + # test unknown schema + with pytest.raises(FileNotFoundError): + get_delta_tables(pipeline, "non_existing_table", schema_name="aux_2") + + # load to a new schema and under new name + aux_schema = dlt.Schema("aux_2") + # NOTE: you cannot have a file with name + info = pipeline.run(parent_delta().with_name("aux_delta"), schema=aux_schema) + # also state in seprate package + assert_load_info(info, expected_load_packages=2) + delta_tables = get_delta_tables(pipeline, schema_name="aux_2") + assert "aux_delta__child" in 
delta_tables.keys() + get_delta_tables(pipeline, "aux_delta", schema_name="aux_2") + with pytest.raises(ValueError): + get_delta_tables(pipeline, "aux_delta") + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET,), + ), + ids=lambda x: x.name, +) +def test_parquet_to_delta_upgrade(destination_config: DestinationTestConfiguration): + # change the resource to start creating delta tables + from dlt.common.libs.deltalake import get_delta_tables + + @dlt.resource() + def foo(): + yield [{"foo": 1}, {"foo": 2}] + + pipeline = destination_config.setup_pipeline("fs_pipe") + + info = pipeline.run(foo()) + assert_load_info(info) + delta_tables = get_delta_tables(pipeline) + assert set(delta_tables.keys()) == set() + + # drop the pipeline + pipeline.deactivate() + + # redefine the resource + + @dlt.resource(table_format="delta") # type: ignore + def foo(): + yield [{"foo": 1}, {"foo": 2}] + + pipeline = destination_config.setup_pipeline("fs_pipe") + + info = pipeline.run(foo()) + assert_load_info(info) + delta_tables = get_delta_tables(pipeline) + assert set(delta_tables.keys()) == {"foo"} + + # optimize all delta tables to make sure storage is there + for table in delta_tables.values(): + table.vacuum() + TEST_LAYOUTS = ( "{schema_name}/{table_name}/{load_id}.{file_id}.{ext}", diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 81c9292570..2792cec085 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -17,6 +17,7 @@ from dlt.common.schema.utils import new_table from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id +from dlt.common.exceptions import TerminalValueError from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations import filesystem, redshift @@ -1146,3 +1147,150 @@ def _data(): dataset_name=dataset_name, ) return p, _data + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb", "postgres", "snowflake"]), + ids=lambda x: x.name, +) +def test_dest_column_invalid_timestamp_precision( + destination_config: DestinationTestConfiguration, +) -> None: + invalid_precision = 10 + + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": invalid_precision}}, + primary_key="event_id", + ) + def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + + pipeline = destination_config.setup_pipeline(uniq_id()) + + with pytest.raises((TerminalValueError, PipelineStepFailed)): + pipeline.run(events()) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb", "snowflake", "postgres"]), + ids=lambda x: x.name, +) +def test_dest_column_hint_timezone(destination_config: DestinationTestConfiguration) -> None: + destination = destination_config.destination + + input_data = [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + output_values = [ + "2024-07-30T10:00:00.123000", + "2024-07-30T08:00:00.123456", + "2024-07-30T10:00:00.123456", + ] + + output_map = { + "postgres": { + "tables": { + "events_timezone_off": { + "timestamp_type": "timestamp without time zone", + "timestamp_values": output_values, + }, + 
"events_timezone_on": { + "timestamp_type": "timestamp with time zone", + "timestamp_values": output_values, + }, + "events_timezone_unset": { + "timestamp_type": "timestamp with time zone", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='experiments'" + " AND table_name = '%s' AND column_name = 'event_tstamp'" + ), + }, + "snowflake": { + "tables": { + "EVENTS_TIMEZONE_OFF": { + "timestamp_type": "TIMESTAMP_NTZ", + "timestamp_values": output_values, + }, + "EVENTS_TIMEZONE_ON": { + "timestamp_type": "TIMESTAMP_TZ", + "timestamp_values": output_values, + }, + "EVENTS_TIMEZONE_UNSET": { + "timestamp_type": "TIMESTAMP_TZ", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='EXPERIMENTS'" + " AND table_name = '%s' AND column_name = 'EVENT_TSTAMP'" + ), + }, + "duckdb": { + "tables": { + "events_timezone_off": { + "timestamp_type": "TIMESTAMP", + "timestamp_values": output_values, + }, + "events_timezone_on": { + "timestamp_type": "TIMESTAMP WITH TIME ZONE", + "timestamp_values": output_values, + }, + "events_timezone_unset": { + "timestamp_type": "TIMESTAMP WITH TIME ZONE", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='experiments'" + " AND table_name = '%s' AND column_name = 'event_tstamp'" + ), + }, + } + + # table: events_timezone_off + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", + ) + def events_timezone_off(): + yield input_data + + # table: events_timezone_on + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True}}, + primary_key="event_id", + ) + def events_timezone_on(): + yield input_data + + # table: events_timezone_unset + @dlt.resource( + primary_key="event_id", + ) + def events_timezone_unset(): + yield input_data + + pipeline = destination_config.setup_pipeline( + f"{destination}_" + uniq_id(), dataset_name="experiments" + ) + + pipeline.run([events_timezone_off(), events_timezone_on(), events_timezone_unset()]) + + with pipeline.sql_client() as client: + for t in output_map[destination]["tables"].keys(): # type: ignore + # check data type + column_info = client.execute_sql(output_map[destination]["query_data_type"] % t) + assert column_info[0][0] == output_map[destination]["tables"][t]["timestamp_type"] # type: ignore + # check timestamp data + rows = client.execute_sql(f"SELECT event_tstamp FROM {t} ORDER BY event_id") + + values = [r[0].strftime("%Y-%m-%dT%H:%M:%S.%f") for r in rows] + assert values == output_map[destination]["tables"][t]["timestamp_values"] # type: ignore diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index a4001b7faa..5cadf701a2 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -42,3 +42,18 @@ def test_postgres_encoded_binary( # print(bytes(data["table"][0]["hash"])) # data in postgres equals unencoded blob assert data["table"][0]["hash"].tobytes() == blob + + +# TODO: uncomment and finalize when we implement encoding for psycopg2 +# @pytest.mark.parametrize( +# "destination_config", +# destinations_configs(default_sql_configs=True, subset=["postgres"]), +# ids=lambda x: x.name, +# ) +# def test_postgres_encoding(destination_config: DestinationTestConfiguration): +# from 
dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient +# pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) +# client: Psycopg2SqlClient = pipeline.sql_client() +# # client.credentials.query["encoding"] = "ru" +# with client: +# print(client.native_connection.encoding) diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py index 8f2c0c2486..065da5ce94 100644 --- a/tests/load/pipeline/test_scd2.py +++ b/tests/load/pipeline/test_scd2.py @@ -3,6 +3,7 @@ import pytest from typing import List, Dict, Any, Optional from datetime import date, datetime, timezone # noqa: I251 +from contextlib import nullcontext as does_not_raise import dlt from dlt.common.typing import TAnyDateTime @@ -45,15 +46,16 @@ def get_load_package_created_at(pipeline: dlt.Pipeline, load_info: LoadInfo) -> return reduce_pendulum_datetime_precision(created_at, caps.timestamp_precision) +def strip_timezone(ts: TAnyDateTime) -> pendulum.DateTime: + """Converts timezone of datetime object to UTC and removes timezone awareness.""" + return ensure_pendulum_datetime(ts).astimezone(tz=timezone.utc).replace(tzinfo=None) + + def get_table( pipeline: dlt.Pipeline, table_name: str, sort_column: str = None, include_root_id: bool = True ) -> List[Dict[str, Any]]: """Returns destination table contents as list of dictionaries.""" - def strip_timezone(ts: datetime) -> datetime: - """Converts timezone of datetime object to UTC and removes timezone awareness.""" - return ensure_pendulum_datetime(ts).astimezone(tz=timezone.utc).replace(tzinfo=None) - table = [ { k: strip_timezone(v) if isinstance(v, datetime) else v @@ -69,20 +71,6 @@ def strip_timezone(ts: datetime) -> datetime: return table return sorted(table, key=lambda d: d[sort_column]) - return sorted( - [ - { - k: strip_timezone(v) if isinstance(v, datetime) else v - for k, v in r.items() - if not k.startswith("_dlt") - or k in DEFAULT_VALIDITY_COLUMN_NAMES - or (k == "_dlt_root_id" if include_root_id else False) - } - for r in load_tables_to_dicts(pipeline, table_name)[table_name] - ], - key=lambda d: d[sort_column], - ) - @pytest.mark.essential @pytest.mark.parametrize( @@ -596,6 +584,7 @@ def r(data): "9999-12-31T00:00:00", "9999-12-31T00:00:00+00:00", "9999-12-31T00:00:00+01:00", + "i_am_not_a_timestamp", ], ) def test_active_record_timestamp( @@ -604,22 +593,126 @@ def test_active_record_timestamp( ) -> None: p = destination_config.setup_pipeline("abstract", dev_mode=True) + context = does_not_raise() + if active_record_timestamp == "i_am_not_a_timestamp": + context = pytest.raises(ValueError) # type: ignore[assignment] + + with context: + + @dlt.resource( + table_name="dim_test", + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + "active_record_timestamp": active_record_timestamp, + }, + ) + def r(): + yield {"foo": "bar"} + + p.run(r()) + actual_active_record_timestamp = ensure_pendulum_datetime( + load_tables_to_dicts(p, "dim_test")["dim_test"][0]["_dlt_valid_to"] + ) + assert actual_active_record_timestamp == ensure_pendulum_datetime(active_record_timestamp) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_boundary_timestamp( + destination_config: DestinationTestConfiguration, +) -> None: + p = destination_config.setup_pipeline("abstract", dev_mode=True) + + ts1 = "2024-08-21T12:15:00+00:00" + ts2 = "2024-08-22" + ts3 = date(2024, 8, 20) # earlier than ts1 and ts2 + ts4 = 
"i_am_not_a_timestamp" + @dlt.resource( table_name="dim_test", write_disposition={ "disposition": "merge", "strategy": "scd2", - "active_record_timestamp": active_record_timestamp, + "boundary_timestamp": ts1, }, ) - def r(): - yield {"foo": "bar"} + def r(data): + yield data - p.run(r()) - actual_active_record_timestamp = ensure_pendulum_datetime( - load_tables_to_dicts(p, "dim_test")["dim_test"][0]["_dlt_valid_to"] + # load 1 — initial load + dim_snap = [ + l1_1 := {"nk": 1, "foo": "foo"}, + l1_2 := {"nk": 2, "foo": "foo"}, + ] + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 2 + from_, to = DEFAULT_VALIDITY_COLUMN_NAMES + expected = [ + {**{from_: strip_timezone(ts1), to: None}, **l1_1}, + {**{from_: strip_timezone(ts1), to: None}, **l1_2}, + ] + assert get_table(p, "dim_test", "nk") == expected + + # load 2 — different source records, different boundary timestamp + r.apply_hints( + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + "boundary_timestamp": ts2, + } ) - assert actual_active_record_timestamp == ensure_pendulum_datetime(active_record_timestamp) + dim_snap = [ + l2_1 := {"nk": 1, "foo": "bar"}, # natural key 1 updated + # l1_2, # natural key 2 no longer present + l2_3 := {"nk": 3, "foo": "foo"}, # new natural key + ] + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 4 + expected = [ + {**{from_: strip_timezone(ts1), to: strip_timezone(ts2)}, **l1_1}, # retired + {**{from_: strip_timezone(ts1), to: strip_timezone(ts2)}, **l1_2}, # retired + {**{from_: strip_timezone(ts2), to: None}, **l2_1}, # new + {**{from_: strip_timezone(ts2), to: None}, **l2_3}, # new + ] + assert_records_as_set(get_table(p, "dim_test"), expected) + + # load 3 — earlier boundary timestamp + # we naively apply any valid timestamp + # may lead to "valid from" > "valid to", as in this test case + r.apply_hints( + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + "boundary_timestamp": ts3, + } + ) + dim_snap = [l2_1] # natural key 3 no longer present + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 4 + expected = [ + {**{from_: strip_timezone(ts1), to: strip_timezone(ts2)}, **l1_1}, # unchanged + {**{from_: strip_timezone(ts1), to: strip_timezone(ts2)}, **l1_2}, # unchanged + {**{from_: strip_timezone(ts2), to: None}, **l2_1}, # unchanged + {**{from_: strip_timezone(ts2), to: strip_timezone(ts3)}, **l2_3}, # retired + ] + assert_records_as_set(get_table(p, "dim_test"), expected) + + # invalid boundary timestamp should raise error + with pytest.raises(ValueError): + r.apply_hints( + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + "boundary_timestamp": ts4, + } + ) @pytest.mark.parametrize( diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index 7f1427f20f..6c4f6dfec8 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -1,11 +1,12 @@ import pytest -from typing import Dict, Any, List +from typing import List import dlt, os -from dlt.common import json, sleep -from copy import deepcopy +from dlt.common import json +from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.utils import uniq_id from dlt.common.schema.typing import TDataType +from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from 
tests.load.pipeline.test_merge_disposition import github from tests.pipeline.utils import load_table_counts, assert_load_info @@ -16,6 +17,9 @@ ) from tests.cases import table_update_and_row +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + @dlt.resource( table_name="issues", write_disposition="merge", primary_key="id", merge_key=("node_id", "url") @@ -36,6 +40,13 @@ def load_modified_issues(): yield from issues +@dlt.resource(table_name="events", write_disposition="append", primary_key="timestamp") +def event_many_load_2(): + with open("tests/normalize/cases/event.event.many_load_2.json", "r", encoding="utf-8") as f: + events = json.load(f) + yield from events + + @pytest.mark.parametrize( "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name ) @@ -46,13 +57,31 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: info = pipeline.run(github(), loader_file_format=destination_config.file_format) assert_load_info(info) + # checks if remote_url is set correctly on copy jobs + metrics = info.metrics[info.loads_ids[0]][0] + for job_metrics in metrics["job_metrics"].values(): + remote_url = job_metrics.remote_url + job_ext = os.path.splitext(job_metrics.job_id)[1] + if job_ext not in (".reference", ".sql"): + assert remote_url.endswith(job_ext) + bucket_uri = destination_config.bucket_url + if FilesystemConfiguration.is_local_path(bucket_uri): + bucket_uri = FilesystemConfiguration.make_file_url(bucket_uri) + assert remote_url.startswith(bucket_uri) + package_info = pipeline.get_load_package_info(info.loads_ids[0]) assert package_info.state == "loaded" assert len(package_info.jobs["failed_jobs"]) == 0 # we have 4 parquet and 4 reference jobs plus one merge job - num_jobs = 4 + 4 + 1 if destination_config.supports_merge else 4 + 4 - assert len(package_info.jobs["completed_jobs"]) == num_jobs + num_jobs = 4 + 4 + num_sql_jobs = 0 + if destination_config.supports_merge: + num_sql_jobs += 1 + # sql job is used to copy parquet to Athena Iceberg table (_dlt_pipeline_state) + if destination_config.destination == "athena" and destination_config.table_format == "iceberg": + num_sql_jobs += 1 + assert len(package_info.jobs["completed_jobs"]) == num_jobs + num_sql_jobs assert ( len( [ @@ -87,7 +116,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: if x.job_file_info.file_format == "sql" ] ) - == 1 + == num_sql_jobs ) initial_counts = load_table_counts( @@ -167,6 +196,69 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: assert replace_counts == initial_counts +@pytest.mark.parametrize( + "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name +) +def test_truncate_staging_dataset(destination_config: DestinationTestConfiguration) -> None: + """This test checks if tables truncation on staging destination done according to the configuration. 
+ + Test loads data to the destination three times: + * with truncation + * without truncation (after this 2 staging files should be left) + * with truncation (after this 1 staging file should be left) + """ + pipeline = destination_config.setup_pipeline( + pipeline_name="test_stage_loading", dataset_name="test_staging_load" + uniq_id() + ) + resource = event_many_load_2() + table_name: str = resource.table_name # type: ignore[assignment] + + # load the data, files stay on the stage after the load + info = pipeline.run(resource) + assert_load_info(info) + + # load the data without truncating of the staging, should see two files on staging + pipeline.destination.config_params["truncate_tables_on_staging_destination_before_load"] = False + info = pipeline.run(resource) + assert_load_info(info) + # check there are two staging files + _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) + with staging_client: + # except Athena + Iceberg which does not store tables in staging dataset + if ( + destination_config.destination == "athena" + and destination_config.table_format == "iceberg" + ): + table_count = 0 + # but keeps them in staging dataset on staging destination - but only the last one + with staging_client.with_staging_dataset(): # type: ignore[attr-defined] + assert len(staging_client.list_table_files(table_name)) == 1 # type: ignore[attr-defined] + else: + table_count = 2 + assert len(staging_client.list_table_files(table_name)) == table_count # type: ignore[attr-defined] + + # load the data with truncating, so only new file is on the staging + pipeline.destination.config_params["truncate_tables_on_staging_destination_before_load"] = True + info = pipeline.run(resource) + assert_load_info(info) + # check that table exists in the destination + with pipeline.sql_client() as sql_client: + qual_name = sql_client.make_qualified_table_name + assert len(sql_client.execute_sql(f"SELECT * from {qual_name(table_name)}")) > 4 + # check there is only one staging file + _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) + with staging_client: + # except for Athena which does not delete staging destination tables + if destination_config.destination == "athena": + if destination_config.table_format == "iceberg": + table_count = 0 + else: + table_count = 3 + else: + table_count = 1 + assert len(staging_client.list_table_files(table_name)) == table_count # type: ignore[attr-defined] + + @pytest.mark.parametrize( "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name ) diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 86bd67db9a..28fd4eec9d 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -57,7 +57,8 @@ def test_create_table(client: PostgresClient) -> None: # non existing table sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] sqlfluff.parse(sql, dialect="postgres") - assert "event_test_table" in sql + qualified_name = client.sql_client.make_qualified_table_name("event_test_table") + assert f"CREATE TABLE {qualified_name}" in sql assert '"col1" bigint NOT NULL' in sql assert '"col2" double precision NOT NULL' in sql assert '"col3" boolean NOT NULL' in sql @@ -173,3 +174,11 @@ def test_create_table_case_sensitive(cs_client: PostgresClient) -> None: # every line starts with "Col" for line in sql.split("\n")[1:]: assert line.startswith('"Col') + 
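+
+# A minimal sketch (hypothetical name, not in the module yet): the DDL checks from
+# test_create_table and test_create_dlt_table below condensed into one parametrized test.
+# It assumes the `client` fixture, TABLE_UPDATE, and the pytest/sqlfluff imports already
+# present in this module; data tables get plain CREATE TABLE, while the _dlt_version
+# system table is created with CREATE TABLE IF NOT EXISTS.
+@pytest.mark.parametrize(
+    "table_name,create_clause",
+    [
+        ("event_test_table", "CREATE TABLE"),
+        ("_dlt_version", "CREATE TABLE IF NOT EXISTS"),
+    ],
+)
+def test_create_table_clause(
+    client: PostgresClient, table_name: str, create_clause: str
+) -> None:
+    sql = client._get_table_update_sql(table_name, TABLE_UPDATE, False)[0]
+    sqlfluff.parse(sql, dialect="postgres")
+    qualified_name = client.sql_client.make_qualified_table_name(table_name)
+    # the generated statement must start with the clause expected for this kind of table
+    assert f"{create_clause} {qualified_name}" in sql
+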
+ +def test_create_dlt_table(client: PostgresClient) -> None: + # non existing table + sql = client._get_table_update_sql("_dlt_version", TABLE_UPDATE, False)[0] + sqlfluff.parse(sql, dialect="postgres") + qualified_name = client.sql_client.make_qualified_table_name("_dlt_version") + assert f"CREATE TABLE IF NOT EXISTS {qualified_name}" in sql diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index b55f4ceece..72c5772668 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -8,7 +8,8 @@ from dlt.common.exceptions import TerminalException, TerminalValueError from dlt.common.storages import FileStorage, PackageStorage, ParsedLoadJobFileName -from dlt.common.storages.load_package import LoadJobInfo, TJobState +from dlt.common.storages.configuration import FilesystemConfiguration +from dlt.common.storages.load_package import LoadJobInfo, TPackageJobState from dlt.common.storages.load_storage import JobFileFormatUnsupported from dlt.common.destination.reference import RunnableLoadJob, TDestination from dlt.common.schema.utils import ( @@ -32,6 +33,7 @@ from dlt.load.utils import get_completed_table_chain, init_client, _extend_tables_with_table_chain from tests.utils import ( + MockPipeline, clean_test_storage, init_test_logging, TEST_DICT_CONFIG_PROVIDER, @@ -78,10 +80,14 @@ def test_spool_job_started() -> None: load_id, PackageStorage.STARTED_JOBS_FOLDER, job.file_name() ) ) + assert_job_metrics(job, "completed") jobs.append(job) remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 assert len(finalized_jobs) == 2 + assert len(load._job_metrics) == 2 + for job in jobs: + assert load._job_metrics[job.job_id()] == job.metrics() def test_unsupported_writer_type() -> None: @@ -199,7 +205,9 @@ def test_spool_job_failed() -> None: load_id, PackageStorage.STARTED_JOBS_FOLDER, job.file_name() ) ) + assert_job_metrics(job, "failed") jobs.append(job) + assert len(jobs) == 2 # complete files remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 @@ -215,6 +223,8 @@ def test_spool_job_failed() -> None: load_id, PackageStorage.FAILED_JOBS_FOLDER, job.file_name() + ".exception" ) ) + # load should collect two jobs + assert load._job_metrics[job.job_id()] == job.metrics() started_files = load.load_storage.normalized_packages.list_started_jobs(load_id) assert len(started_files) == 0 @@ -226,6 +236,13 @@ def test_spool_job_failed() -> None: assert package_info.state == "loaded" # all jobs failed assert len(package_info.jobs["failed_jobs"]) == 2 + # check metrics + load_info = load.get_step_info(MockPipeline("pipe", True)) # type: ignore[abstract] + metrics = load_info.metrics[load_id][0]["job_metrics"] + assert len(metrics) == 2 + for job in jobs: + assert job.job_id() in metrics + assert metrics[job.job_id()].state == "failed" def test_spool_job_failed_terminally_exception_init() -> None: @@ -244,6 +261,11 @@ def test_spool_job_failed_terminally_exception_init() -> None: assert len(package_info.jobs["started_jobs"]) == 0 # load id was never committed complete_load.assert_not_called() + # metrics can be gathered + assert len(load._job_metrics) == 2 + load_info = load.get_step_info(MockPipeline("pipe", True)) # type: ignore[abstract] + metrics = load_info.metrics[load_id][0]["job_metrics"] + assert len(metrics) == 2 def test_spool_job_failed_transiently_exception_init() -> None: @@ -264,6 +286,10 @@ def 
test_spool_job_failed_transiently_exception_init() -> None: # load id was never committed complete_load.assert_not_called() + # no metrics were gathered + assert len(load._job_metrics) == 0 + load_info = load.get_step_info(MockPipeline("pipe", True)) # type: ignore[abstract] + assert len(load_info.metrics) == 0 def test_spool_job_failed_exception_complete() -> None: @@ -279,6 +305,11 @@ def test_spool_job_failed_exception_complete() -> None: # both failed - we wait till the current loop is completed and then raise assert len(package_info.jobs["failed_jobs"]) == 2 assert len(package_info.jobs["started_jobs"]) == 0 + # metrics can be gathered + assert len(load._job_metrics) == 2 + load_info = load.get_step_info(MockPipeline("pipe", True)) # type: ignore[abstract] + metrics = load_info.metrics[load_id][0]["job_metrics"] + assert len(metrics) == 2 def test_spool_job_retry_new() -> None: @@ -328,6 +359,7 @@ def test_spool_job_retry_started() -> None: remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 assert len(finalized_jobs) == 0 + assert len(load._job_metrics) == 0 # clear retry flag dummy_impl.JOBS = {} files = load.load_storage.normalized_packages.list_new_jobs(load_id) @@ -407,6 +439,8 @@ def test_failing_followup_jobs() -> None: assert len(dummy_impl.JOBS) == 2 assert len(dummy_impl.RETRIED_JOBS) == 0 assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + # no metrics were collected + assert len(load._job_metrics) == 0 # now we can retry the same load, it will restart the two jobs and successfully create the followup jobs load.initial_client_config.fail_followup_job_creation = False # type: ignore @@ -436,6 +470,8 @@ def test_failing_table_chain_followup_jobs() -> None: assert len(dummy_impl.JOBS) == 2 assert len(dummy_impl.RETRIED_JOBS) == 0 assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + # no metrics were collected + assert len(load._job_metrics) == 0 # now we can retry the same load, it will restart the two jobs and successfully create the table chain followup jobs load.initial_client_config.fail_table_chain_followup_job_creation = False # type: ignore @@ -512,6 +548,23 @@ def test_completed_loop_with_delete_completed() -> None: assert_complete_job(load, should_delete_completed=True) +@pytest.mark.parametrize("to_truncate", [True, False]) +def test_truncate_table_before_load_on_stanging(to_truncate) -> None: + load = setup_loader( + client_config=DummyClientConfiguration( + truncate_tables_on_staging_destination_before_load=to_truncate + ) + ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) + destination_client = load.get_destination_client(schema) + assert ( + destination_client.should_truncate_table_before_load_on_staging_destination( # type: ignore + schema.tables["_dlt_version"] + ) + == to_truncate + ) + + def test_retry_on_new_loop() -> None: # test job that retries sitting in new jobs load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) @@ -662,11 +715,11 @@ def test_get_completed_table_chain_cases() -> None: # child completed, parent not event_user = schema.get_table("event_user") event_user_entities = schema.get_table("event_user__parse_data__entities") - event_user_job: Tuple[TJobState, ParsedLoadJobFileName] = ( + event_user_job: Tuple[TPackageJobState, ParsedLoadJobFileName] = ( "started_jobs", ParsedLoadJobFileName("event_user", "event_user_id", 0, "jsonl"), ) - event_user_entities_job: Tuple[TJobState, ParsedLoadJobFileName] = ( + event_user_entities_job: 
Tuple[TPackageJobState, ParsedLoadJobFileName] = ( "completed_jobs", ParsedLoadJobFileName( "event_user__parse_data__entities", "event_user__parse_data__entities_id", 0, "jsonl" @@ -857,6 +910,33 @@ def test_dummy_staging_filesystem() -> None: assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 +def test_load_multiple_packages() -> None: + load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) + load.config.pool_type = "none" + load_id_1, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) + sleep(0.1) + load_id_2, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) + run_metrics = load.run(None) + assert run_metrics.pending_items == 1 + # assert load._current_load_id is None + metrics_id_1 = load._job_metrics + assert len(metrics_id_1) == 2 + assert load._step_info_metrics(load_id_1)[0]["job_metrics"] == metrics_id_1 + run_metrics = load.run(None) + assert run_metrics.pending_items == 0 + metrics_id_2 = load._job_metrics + assert len(metrics_id_2) == 2 + assert load._step_info_metrics(load_id_2)[0]["job_metrics"] == metrics_id_2 + load_info = load.get_step_info(MockPipeline("pipe", True)) # type: ignore[abstract] + assert load_id_1 in load_info.metrics + assert load_id_2 in load_info.metrics + assert load_info.metrics[load_id_1][0]["job_metrics"] == metrics_id_1 + assert load_info.metrics[load_id_2][0]["job_metrics"] == metrics_id_2 + # execute empty run + load.run(None) + assert len(load_info.metrics) == 2 + + def test_terminal_exceptions() -> None: try: raise TerminalValueError("a") @@ -866,6 +946,15 @@ def test_terminal_exceptions() -> None: raise AssertionError() +def assert_job_metrics(job: RunnableLoadJob, expected_state: str) -> None: + metrics = job.metrics() + assert metrics.state == expected_state + assert metrics.started_at <= metrics.finished_at + assert metrics.job_id == job.job_id() + assert metrics.table_name == job._parsed_file_name.table_name + assert metrics.file_path == job._file_path + + def assert_complete_job( load: Load, should_delete_completed: bool = False, load_id: str = None, jobs_per_case: int = 1 ) -> None: @@ -910,6 +999,32 @@ def assert_complete_job( assert load.load_storage.loaded_packages.storage.has_folder(completed_path) # complete load on client was called complete_load.assert_called_once_with(load_id) + # assert if all jobs in final state have metrics + metrics = load.get_step_info(MockPipeline("pipe", True)).metrics[load_id][0] # type: ignore[abstract] + package_info = load.load_storage.loaded_packages.get_load_package_jobs(load_id) + for state, jobs in package_info.items(): + for job in jobs: + job_metrics = metrics["job_metrics"].get(job.job_id()) + if state in ("failed_jobs", "completed_jobs"): + assert job_metrics is not None + assert ( + metrics["job_metrics"][job.job_id()].state == "failed" + if state == "failed_jobs" + else "completed" + ) + remote_url = job_metrics.remote_url + if load.initial_client_config.create_followup_jobs: # type: ignore + assert remote_url.endswith(job.file_name()) + elif load.is_staging_destination_job(job.file_name()): + # staging destination should contain reference to remote filesystem + assert ( + FilesystemConfiguration.make_file_url(REMOTE_FILESYSTEM) + in remote_url + ) + else: + assert remote_url is None + else: + assert job_metrics is None def run_all(load: Load) -> None: @@ -941,9 +1056,9 @@ def setup_loader( staging = None if filesystem_staging: # do not accept jsonl to not conflict with filesystem destination - client_config = client_config or 
DummyClientConfiguration( - loader_file_format="reference", completed_prob=1 - ) + # client_config = client_config or DummyClientConfiguration( + # loader_file_format="reference", completed_prob=1 + # ) staging_system_config = FilesystemDestinationClientConfiguration()._bind_dataset_name( dataset_name="dummy" ) diff --git a/tests/load/utils.py b/tests/load/utils.py index d649343c63..5427904d52 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -45,6 +45,7 @@ from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration from dlt.common.schema.utils import new_table, normalize_table_identifiers from dlt.common.storages import ParsedLoadJobFileName, LoadStorage, PackageStorage +from dlt.common.storages.load_package import create_load_id from dlt.common.typing import StrAny from dlt.common.utils import uniq_id @@ -69,6 +70,7 @@ AWS_BUCKET = dlt.config.get("tests.bucket_url_s3", str) GCS_BUCKET = dlt.config.get("tests.bucket_url_gs", str) AZ_BUCKET = dlt.config.get("tests.bucket_url_az", str) +ABFS_BUCKET = dlt.config.get("tests.bucket_url_abfss", str) GDRIVE_BUCKET = dlt.config.get("tests.bucket_url_gdrive", str) FILE_BUCKET = dlt.config.get("tests.bucket_url_file", str) R2_BUCKET = dlt.config.get("tests.bucket_url_r2", str) @@ -78,6 +80,7 @@ "s3", "gs", "az", + "abfss", "gdrive", "file", "memory", @@ -85,7 +88,15 @@ ] # Filter out buckets not in all filesystem drivers -WITH_GDRIVE_BUCKETS = [GCS_BUCKET, AWS_BUCKET, FILE_BUCKET, MEMORY_BUCKET, AZ_BUCKET, GDRIVE_BUCKET] +WITH_GDRIVE_BUCKETS = [ + GCS_BUCKET, + AWS_BUCKET, + FILE_BUCKET, + MEMORY_BUCKET, + ABFS_BUCKET, + AZ_BUCKET, + GDRIVE_BUCKET, +] WITH_GDRIVE_BUCKETS = [ bucket for bucket in WITH_GDRIVE_BUCKETS @@ -246,6 +257,27 @@ def destinations_configs( # build destination configs destination_configs: List[DestinationTestConfiguration] = [] + # default sql configs that are also default staging configs + default_sql_configs_with_staging = [ + # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. + DestinationTestConfiguration( + destination="athena", + file_format="parquet", + supports_merge=False, + bucket_url=AWS_BUCKET, + ), + DestinationTestConfiguration( + destination="athena", + file_format="parquet", + bucket_url=AWS_BUCKET, + force_iceberg=True, + supports_merge=True, + supports_dbt=False, + table_format="iceberg", + extra_info="iceberg", + ), + ] + # default non staging sql based configs, one per destination if default_sql_configs: destination_configs += [ @@ -257,26 +289,10 @@ def destinations_configs( DestinationTestConfiguration(destination="duckdb", file_format="parquet"), DestinationTestConfiguration(destination="motherduck", file_format="insert_values"), ] - # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. 
- destination_configs += [ - DestinationTestConfiguration( - destination="athena", - file_format="parquet", - supports_merge=False, - bucket_url=AWS_BUCKET, - ) - ] - destination_configs += [ - DestinationTestConfiguration( - destination="athena", - file_format="parquet", - bucket_url=AWS_BUCKET, - force_iceberg=True, - supports_merge=True, - supports_dbt=False, - extra_info="iceberg", - ) - ] + + # add Athena staging configs + destination_configs += default_sql_configs_with_staging + destination_configs += [ DestinationTestConfiguration( destination="clickhouse", file_format="jsonl", supports_dbt=False @@ -321,6 +337,10 @@ def destinations_configs( DestinationTestConfiguration(destination="qdrant", extra_info="server"), ] + if (default_sql_configs or all_staging_configs) and not default_sql_configs: + # athena default configs not added yet + destination_configs += default_sql_configs_with_staging + if default_staging_configs or all_staging_configs: destination_configs += [ DestinationTestConfiguration( @@ -712,7 +732,7 @@ def expect_load_file( query = query.encode("utf-8") # type: ignore[assignment] file_storage.save(file_name, query) table = client.prepare_load_table(table_name) - load_id = uniq_id() + load_id = create_load_id() job = client.create_load_job(table, file_storage.make_full_path(file_name), load_id) if isinstance(job, RunnableLoadJob): @@ -873,7 +893,7 @@ def prepare_load_package( Create a load package with explicitely provided files job_per_case multiplies the amount of load jobs, for big packages use small files """ - load_id = uniq_id() + load_id = create_load_id() load_storage.new_packages.create_package(load_id) for case in cases: path = f"./tests/load/cases/loading/{case}" diff --git a/tests/pipeline/cases/contracts/trace.schema.yaml b/tests/pipeline/cases/contracts/trace.schema.yaml new file mode 100644 index 0000000000..c324818338 --- /dev/null +++ b/tests/pipeline/cases/contracts/trace.schema.yaml @@ -0,0 +1,772 @@ +version: 4 +version_hash: JE62zVwqT2T/qHTi2Qdnn2d1A/JzCzyGtDwc+qUmbTs= +engine_version: 9 +name: trace +tables: + _dlt_version: + columns: + version: + data_type: bigint + nullable: false + engine_version: + data_type: bigint + nullable: false + inserted_at: + data_type: timestamp + nullable: false + schema_name: + data_type: text + nullable: false + version_hash: + data_type: text + nullable: false + schema: + data_type: text + nullable: false + write_disposition: skip + description: Created by DLT. Tracks schema updates + _dlt_loads: + columns: + load_id: + data_type: text + nullable: false + schema_name: + data_type: text + nullable: true + status: + data_type: bigint + nullable: false + inserted_at: + data_type: timestamp + nullable: false + schema_version_hash: + data_type: text + nullable: true + write_disposition: skip + description: Created by DLT. 
Tracks completed loads + trace: + columns: + transaction_id: + data_type: text + nullable: true + pipeline_name: + data_type: text + nullable: true + execution_context__ci_run: + data_type: bool + nullable: true + execution_context__python: + data_type: text + nullable: true + execution_context__cpu: + data_type: bigint + nullable: true + execution_context__os__name: + data_type: text + nullable: true + execution_context__os__version: + data_type: text + nullable: true + execution_context__library__name: + data_type: text + nullable: true + execution_context__library__version: + data_type: text + nullable: true + started_at: + data_type: timestamp + nullable: true + finished_at: + data_type: timestamp + nullable: true + engine_version: + data_type: bigint + nullable: true + _dlt_load_id: + data_type: text + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + write_disposition: append + trace__execution_context__exec_info: + columns: + value: + data_type: text + nullable: true + _dlt_id: + data_type: text + nullable: false + unique: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + parent: trace + trace__steps: + columns: + span_id: + data_type: text + nullable: true + step: + data_type: text + nullable: true + started_at: + data_type: timestamp + nullable: true + finished_at: + data_type: timestamp + nullable: true + step_info__pipeline__pipeline_name: + data_type: text + nullable: true + step_info__first_run: + data_type: bool + nullable: true + step_info__started_at: + data_type: timestamp + nullable: true + step_info__finished_at: + data_type: timestamp + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + load_info__destination_type: + data_type: text + nullable: true + load_info__destination_displayable_credentials: + data_type: text + nullable: true + load_info__destination_name: + data_type: text + nullable: true + load_info__staging_type: + data_type: text + nullable: true + load_info__staging_name: + data_type: text + nullable: true + load_info__staging_displayable_credentials: + data_type: text + nullable: true + load_info__destination_fingerprint: + data_type: text + nullable: true + step_exception: + data_type: text + nullable: true + parent: trace + trace__steps__extract_info__job_metrics: + columns: + file_path: + data_type: text + nullable: true + items_count: + data_type: bigint + nullable: true + file_size: + data_type: bigint + nullable: true + created: + data_type: double + nullable: true + last_modified: + data_type: double + nullable: true + load_id: + data_type: text + nullable: true + extract_idx: + data_type: bigint + nullable: true + job_id: + data_type: text + nullable: true + table_name: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps + trace__steps__extract_info__table_metrics: + columns: + file_path: + data_type: text + nullable: true + items_count: + data_type: bigint + nullable: true + file_size: + data_type: bigint + nullable: true + created: + data_type: double + nullable: true + last_modified: + data_type: double + nullable: true + load_id: + data_type: text + nullable: true + extract_idx: + 
data_type: bigint + nullable: true + table_name: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps + trace__steps__extract_info__resource_metrics: + columns: + file_path: + data_type: text + nullable: true + items_count: + data_type: bigint + nullable: true + file_size: + data_type: bigint + nullable: true + created: + data_type: double + nullable: true + last_modified: + data_type: double + nullable: true + load_id: + data_type: text + nullable: true + extract_idx: + data_type: bigint + nullable: true + resource_name: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps + trace__steps__extract_info__dag: + columns: + load_id: + data_type: text + nullable: true + extract_idx: + data_type: bigint + nullable: true + parent_name: + data_type: text + nullable: true + resource_name: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps + trace__steps__extract_info__hints: + columns: + load_id: + data_type: text + nullable: true + extract_idx: + data_type: bigint + nullable: true + resource_name: + data_type: text + nullable: true + columns: + data_type: text + nullable: true + write_disposition: + data_type: text + nullable: true + schema_contract: + data_type: text + nullable: true + table_format: + data_type: text + nullable: true + file_format: + data_type: text + nullable: true + original_columns: + data_type: text + nullable: true + primary_key: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps + trace__steps__step_info__loads_ids: + columns: + value: + data_type: text + nullable: true + _dlt_id: + data_type: text + nullable: false + unique: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + parent: trace__steps + trace__steps__step_info__load_packages: + columns: + load_id: + data_type: text + nullable: true + package_path: + data_type: text + nullable: true + state: + data_type: text + nullable: true + schema_hash: + data_type: text + nullable: true + schema_name: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + completed_at: + data_type: timestamp + nullable: true + parent: trace__steps + trace__steps__step_info__load_packages__jobs: + columns: + state: + data_type: text + nullable: true + file_path: + data_type: text + nullable: true + file_size: + data_type: bigint + nullable: true + created_at: + data_type: timestamp + nullable: true + elapsed: + data_type: double + nullable: true + table_name: + data_type: text + nullable: true + file_id: + data_type: text + nullable: true + retry_count: + data_type: bigint + nullable: true + file_format: + 
data_type: text + nullable: true + job_id: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps__step_info__load_packages + trace__steps__normalize_info__job_metrics: + columns: + file_path: + data_type: text + nullable: true + items_count: + data_type: bigint + nullable: true + file_size: + data_type: bigint + nullable: true + created: + data_type: double + nullable: true + last_modified: + data_type: double + nullable: true + load_id: + data_type: text + nullable: true + extract_idx: + data_type: bigint + nullable: true + job_id: + data_type: text + nullable: true + table_name: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps + trace__steps__normalize_info__table_metrics: + columns: + file_path: + data_type: text + nullable: true + items_count: + data_type: bigint + nullable: true + file_size: + data_type: bigint + nullable: true + created: + data_type: double + nullable: true + last_modified: + data_type: double + nullable: true + load_id: + data_type: text + nullable: true + extract_idx: + data_type: bigint + nullable: true + table_name: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps + trace__steps__load_info__job_metrics: + columns: + load_id: + data_type: text + nullable: true + job_id: + data_type: text + nullable: true + file_path: + data_type: text + nullable: true + table_name: + data_type: text + nullable: true + state: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + started_at: + data_type: timestamp + nullable: true + finished_at: + data_type: timestamp + nullable: true + remote_url: + data_type: text + nullable: true + parent: trace__steps + trace__steps__step_info__load_packages__tables: + columns: + write_disposition: + data_type: text + nullable: true + schema_contract: + data_type: text + nullable: true + table_format: + data_type: text + nullable: true + file_format: + data_type: text + nullable: true + name: + data_type: text + nullable: true + resource: + data_type: text + nullable: true + schema_name: + data_type: text + nullable: true + load_id: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: + data_type: text + nullable: true + x_normalizer__seen_data: + data_type: bool + nullable: true + parent: trace__steps__step_info__load_packages + trace__steps__step_info__load_packages__tables__columns: + columns: + name: + data_type: text + nullable: true + data_type: + data_type: text + nullable: true + nullable: + data_type: bool + nullable: true + primary_key: + data_type: bool + nullable: true + table_name: + data_type: text + nullable: true + schema_name: + data_type: text + nullable: true + load_id: + data_type: text + 
nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + unique: + data_type: bool + nullable: true + foreign_key: + data_type: bool + nullable: true + parent: trace__steps__step_info__load_packages__tables + trace__resolved_config_values: + columns: + key: + data_type: text + nullable: true + is_secret_hint: + data_type: bool + nullable: true + provider_name: + data_type: text + nullable: true + config_type_name: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace + trace__resolved_config_values__sections: + columns: + value: + data_type: text + nullable: true + _dlt_id: + data_type: text + nullable: false + unique: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + parent: trace__resolved_config_values + trace__steps__exception_traces: + columns: + message: + data_type: text + nullable: true + exception_type: + data_type: text + nullable: true + is_terminal: + data_type: bool + nullable: true + docstring: + data_type: text + nullable: true + load_id: + data_type: text + nullable: true + pipeline_name: + data_type: text + nullable: true + exception_attrs: + data_type: text + nullable: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + _dlt_id: + data_type: text + nullable: false + unique: true + parent: trace__steps + trace__steps__exception_traces__stack_trace: + columns: + value: + data_type: text + nullable: true + _dlt_id: + data_type: text + nullable: false + unique: true + _dlt_parent_id: + data_type: text + nullable: false + foreign_key: true + _dlt_list_idx: + data_type: bigint + nullable: false + parent: trace__steps__exception_traces +settings: + detections: + - iso_timestamp + default_hints: + not_null: + - _dlt_id + - _dlt_root_id + - _dlt_parent_id + - _dlt_list_idx + - _dlt_load_id + foreign_key: + - _dlt_parent_id + root_key: + - _dlt_root_id + unique: + - _dlt_id +normalizers: + names: snake_case + json: + module: dlt.common.normalizers.json.relational +previous_hashes: +- 9Ysjq/W0xpxkI/vBiYm8Qbr2nDP3JMt6KvGKUS/FCyI= +- NYeAxJ2r+T+dKFnXFhBEPzBP6SO+ORdhOfgQRo/XqBU= +- RV9jvZSD5dM+ZGjEL3HqokLvtf22K4zMNc3zWRahEw4= diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 0ab1f61d72..918f9beab9 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -29,7 +29,7 @@ DestinationTerminalException, UnknownDestinationModule, ) -from dlt.common.exceptions import PipelineStateNotAvailable +from dlt.common.exceptions import PipelineStateNotAvailable, TerminalValueError from dlt.common.pipeline import LoadInfo, PipelineContext from dlt.common.runtime.collector import LogCollector from dlt.common.schema.exceptions import TableIdentifiersFrozen @@ -39,7 +39,7 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema -from dlt.destinations import filesystem, redshift, dummy +from dlt.destinations import filesystem, redshift, dummy, duckdb from dlt.destinations.impl.filesystem.filesystem import INIT_FILE_NAME from dlt.extract.exceptions import InvalidResourceDataTypeBasic, PipeGenInvalid, SourceExhausted from 
dlt.extract.extract import ExtractStorage @@ -2600,6 +2600,20 @@ def ids(_id=dlt.sources.incremental("_id", initial_value=2)): assert pipeline.last_trace.last_normalize_info.row_counts["_ids"] == 2 +def test_dlt_columns_nested_table_collisions() -> None: + # we generate all identifiers in upper case to test for a bug where dlt columns for nested tables were hardcoded to + # small caps. they got normalized to upper case after the first run and then added again as small caps + # generating duplicate columns and raising collision exception as duckdb is ci destination + duck = duckdb(naming_convention="tests.common.cases.normalizers.sql_upper") + pipeline = dlt.pipeline("test_dlt_columns_child_table_collisions", destination=duck) + customers = [ + {"id": 1, "name": "dave", "orders": [1, 2, 3]}, + ] + assert_load_info(pipeline.run(customers, table_name="CUSTOMERS")) + # this one would fail without bugfix + assert_load_info(pipeline.run(customers, table_name="CUSTOMERS")) + + def test_access_pipeline_in_resource() -> None: pipeline = dlt.pipeline("test_access_pipeline_in_resource", destination="duckdb") @@ -2637,6 +2651,57 @@ def comments(user_id: str): assert pipeline.last_trace.last_normalize_info.row_counts["user_comments"] == 3 +def test_exceed_job_file_name_length() -> None: + # use very long table name both for parent and for a child + data = { + "id": 1, + "child use very long table name both for parent and for a child use very long table name both for parent and for a child use very long table name both for parent and for a child use very long table name both for parent and for a child use very long table name both for parent and for a child": [ + 1, + 2, + 3, + ], + "col use very long table name both for parent and for a child use very long table name both for parent and for a child use very long table name both for parent and for a child use very long table name both for parent and for a child use very long table name both for parent and for a child": ( + "data" + ), + } + + table_name = ( + "parent use very long table name both for parent and for a child use very long table name" + " both for parent and for a child use very long table name both for parent and for a child" + " use very long table name both for parent and for a child use very long table name both" + " for parent and for a child use very long table name both for parent and for a child " + ) + + pipeline = dlt.pipeline( + pipeline_name="test_exceed_job_file_name_length", + destination="duckdb", + ) + # path too long + with pytest.raises(PipelineStepFailed) as os_err: + pipeline.run([data], table_name=table_name) + assert isinstance(os_err.value.__cause__, OSError) + + # fit into 255 + 1 + suffix_len = len(".b61d3af76c.0.insert-values") + pipeline = dlt.pipeline( + pipeline_name="test_exceed_job_file_name_length", + destination=duckdb( + max_identifier_length=255 - suffix_len + 1, + ), + ) + # path too long + with pytest.raises(PipelineStepFailed): + pipeline.run([data], table_name=table_name) + + pipeline = dlt.pipeline( + pipeline_name="test_exceed_job_file_name_length", + destination=duckdb( + max_identifier_length=255 - suffix_len, + ), + ) + pipeline.run([data], table_name=table_name) + + def assert_imported_file( pipeline: Pipeline, table_name: str, @@ -2664,3 +2729,18 @@ def assert_imported_file( extract_info.metrics[extract_info.loads_ids[0]][0]["table_metrics"][table_name].items_count == expected_rows ) + + +def test_duckdb_column_invalid_timestamp() -> None: + # DuckDB does not have timestamps with timezone and 
precision + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True, "precision": 3}}, + primary_key="event_id", + ) + def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + + pipeline = dlt.pipeline(destination="duckdb") + + with pytest.raises((TerminalValueError, PipelineStepFailed)): + pipeline.run(events()) diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index d3e44198b4..af3a6c239e 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -22,6 +22,7 @@ class BaseModel: # type: ignore[no-redef] from dlt.common import json, pendulum from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.common.runtime.collector import ( AliveCollector, EnlightenCollector, @@ -245,7 +246,6 @@ class TestRow(BaseModel): example_string: str # yield model in resource so incremental fails when looking for "id" - # TODO: support pydantic models in incremental @dlt.resource(name="table_name", primary_key="id", write_disposition="replace") def generate_rows_incremental( @@ -599,3 +599,82 @@ def test_pick_matching_file_format(test_storage: FileStorage) -> None: files = test_storage.list_folder_files("user_data_csv/object") assert len(files) == 1 assert files[0].endswith("csv") + + +def test_filesystem_column_hint_timezone() -> None: + import pyarrow.parquet as pq + import posixpath + + os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "_storage" + + # talbe: events_timezone_off + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", + ) + def events_timezone_off(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + # talbe: events_timezone_on + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True}}, + primary_key="event_id", + ) + def events_timezone_on(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + # talbe: events_timezone_unset + @dlt.resource( + primary_key="event_id", + ) + def events_timezone_unset(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + pipeline = dlt.pipeline(destination="filesystem") + + pipeline.run( + [events_timezone_off(), events_timezone_on(), events_timezone_unset()], + loader_file_format="parquet", + ) + + client: FilesystemClient = pipeline.destination_client() # type: ignore[assignment] + + expected_results = { + "events_timezone_off": None, + "events_timezone_on": "UTC", + "events_timezone_unset": "UTC", + } + + for t in expected_results.keys(): + events_glob = posixpath.join(client.dataset_path, f"{t}/*") + events_files = client.fs_client.glob(events_glob) + + with open(events_files[0], "rb") as f: + table = pq.read_table(f) + + # convert the timestamps to strings + timestamps = [ + ts.as_py().strftime("%Y-%m-%dT%H:%M:%S.%f") for ts in table.column("event_tstamp") + ] + 
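+        # values are written in UTC regardless of the timezone hint: the +02:00 input shifts
+        # to 08:00 and the naive input stays as-is; only the Parquet field's timezone metadata
+        # (checked below via field.type.tz) differs between the three tables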
assert timestamps == [ + "2024-07-30T10:00:00.123000", + "2024-07-30T08:00:00.123456", + "2024-07-30T10:00:00.123456", + ] + + # check if the Parquet file contains timezone information + schema = table.schema + field = schema.field("event_tstamp") + assert field.type.tz == expected_results[t] diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index 3239e01bab..d2bb035a17 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -7,6 +7,7 @@ from unittest.mock import patch import pytest import requests_mock +import yaml import dlt @@ -19,6 +20,8 @@ from dlt.common.typing import DictStrAny, StrStr, DictStrStr, TSecretValue from dlt.common.utils import digest128 +from dlt.destinations import dummy, filesystem + from dlt.pipeline.exceptions import PipelineStepFailed from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.trace import ( @@ -31,7 +34,8 @@ from dlt.extract.extract import describe_extract_data from dlt.extract.pipe import Pipe -from tests.utils import start_test_telemetry +from tests.pipeline.utils import PIPELINE_TEST_CASES_PATH +from tests.utils import TEST_STORAGE_ROOT, start_test_telemetry from tests.common.configuration.utils import toml_providers, environment @@ -122,7 +126,7 @@ def data(): resolved = _find_resolved_value(trace.resolved_config_values, "credentials", ["databricks"]) assert resolved.is_secret_hint is True assert resolved.value == databricks_creds - assert_trace_printable(trace) + assert_trace_serializable(trace) # activate pipeline because other was running in assert trace p.activate() @@ -153,7 +157,7 @@ def data(): assert isinstance(step.step_info, ExtractInfo) assert len(step.exception_traces) > 0 assert step.step_info.extract_data_info == [{"name": "async_exception", "data_type": "source"}] - assert_trace_printable(trace) + assert_trace_serializable(trace) extract_info = step.step_info # only new (unprocessed) package is present, all other metrics are empty, state won't be extracted @@ -174,7 +178,7 @@ def data(): step = trace.steps[2] assert step.step == "normalize" assert step.step_info is norm_info - assert_trace_printable(trace) + assert_trace_serializable(trace) assert isinstance(p.last_trace.last_normalize_info, NormalizeInfo) assert p.last_trace.last_normalize_info.row_counts == {"_dlt_pipeline_state": 1, "data": 3} @@ -216,7 +220,7 @@ def data(): assert resolved.is_secret_hint is False assert resolved.value == "1.0" assert resolved.config_type_name == "DummyClientConfiguration" - assert_trace_printable(trace) + assert_trace_serializable(trace) assert isinstance(p.last_trace.last_load_info, LoadInfo) p.activate() @@ -234,12 +238,157 @@ def data(): assert step.step == "load" assert step.step_info is load_info # same load info assert trace.steps[0].step_info is not extract_info - assert_trace_printable(trace) + assert_trace_serializable(trace) assert isinstance(p.last_trace.last_load_info, LoadInfo) assert isinstance(p.last_trace.last_normalize_info, NormalizeInfo) assert isinstance(p.last_trace.last_extract_info, ExtractInfo) +def test_trace_schema() -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + os.environ["RESTORE_FROM_DESTINATION"] = "False" + + # mock runtime env + os.environ["CIRCLECI"] = "1" + os.environ["AWS_LAMBDA_FUNCTION_NAME"] = "lambda" + + @dlt.source(section="many_hints") + def many_hints( + api_type=dlt.config.value, + credentials: str = dlt.secrets.value, + secret_value: TSecretValue = TSecretValue("123"), # noqa: B008 + ): + # 
TODO: create table / column schema from typed dicts, not explicitly + @dlt.resource( + write_disposition="replace", + primary_key="id", + table_format="delta", + file_format="jsonl", + schema_contract="evolve", + columns=[ + { + "name": "multi", + "data_type": "decimal", + "nullable": True, + "cluster": True, + "description": "unknown", + "merge_key": True, + "precision": 9, + "scale": 3, + "sort": True, + "variant": True, + "partition": True, + } + ], + ) + def data(): + yield [{"id": 1, "multi": "1.2"}, {"id": 2}, {"id": 3}] + + return data() + + @dlt.source + def github(): + @dlt.resource + def get_shuffled_events(): + for _ in range(1): + with open( + "tests/normalize/cases/github.events.load_page_1_duck.json", + "r", + encoding="utf-8", + ) as f: + issues = json.load(f) + yield issues + + return get_shuffled_events() + + @dlt.source + def async_exception(max_range=1): + async def get_val(v): + await asyncio.sleep(0.1) + if v % 3 == 0: + raise ValueError(v) + return v + + @dlt.resource + def data(): + yield from [get_val(v) for v in range(1, max_range)] + + return data() + + # create pipeline with staging to get remote_url in load step job_metrics + dummy_dest = dummy(completed_prob=1.0) + pipeline = dlt.pipeline( + pipeline_name="test_trace_schema", + destination=dummy_dest, + staging=filesystem(os.path.abspath(os.path.join(TEST_STORAGE_ROOT, "_remote_filesystem"))), + dataset_name="various", + ) + + # mock config + os.environ["API_TYPE"] = "REST" + os.environ["SOURCES__MANY_HINTS__CREDENTIALS"] = "CREDS" + + info = pipeline.run([many_hints(), github()]) + info.raise_on_failed_jobs() + + trace = pipeline.last_trace + pipeline._schema_storage.storage.save("trace.json", json.dumps(trace, pretty=True)) + + schema = dlt.Schema("trace") + trace_pipeline = dlt.pipeline( + pipeline_name="test_trace_schema_traces", destination=dummy(completed_prob=1.0) + ) + info = trace_pipeline.run([trace], table_name="trace", schema=schema) + info.raise_on_failed_jobs() + + # add exception trace + with pytest.raises(PipelineStepFailed): + pipeline.extract(async_exception(max_range=4)) + + trace_exception = pipeline.last_trace + pipeline._schema_storage.storage.save( + "trace_exception.json", json.dumps(trace_exception, pretty=True) + ) + + info = trace_pipeline.run([trace_exception], table_name="trace") + info.raise_on_failed_jobs() + inferred_trace_contract = trace_pipeline.schemas["trace"] + inferred_contract_str = inferred_trace_contract.to_pretty_yaml(remove_processing_hints=True) + + # NOTE: this saves actual inferred contract (schema) to schema storage, move it to test cases if you update + # trace shapes + # TODO: create a proper schema for dlt trace and tables/columns + pipeline._schema_storage.storage.save("trace.schema.yaml", inferred_contract_str) + # print(pipeline._schema_storage.storage.storage_path) + + # load the schema and use it as contract + with open(f"{PIPELINE_TEST_CASES_PATH}/contracts/trace.schema.yaml", encoding="utf-8") as f: + imported_schema = yaml.safe_load(f) + trace_contract = Schema.from_dict(imported_schema, remove_processing_hints=True) + # compare pretty forms of the schemas, they must be identical + # NOTE: if this fails you can comment this out and use contract run below to find first offending difference + # assert trace_contract.to_pretty_yaml() == inferred_contract_str + + # use trace contract to load data again + contract_trace_pipeline = dlt.pipeline( + pipeline_name="test_trace_schema_traces_contract", destination=dummy(completed_prob=1.0) + ) + info = 
contract_trace_pipeline.run( + [trace_exception, trace], + table_name="trace", + schema=trace_contract, + schema_contract="freeze", + ) + + # assert inferred_trace_contract.version_hash == trace_contract.version_hash + + # print(trace_pipeline.schemas["trace"].to_pretty_yaml()) + # print(pipeline._schema_storage.storage.storage_path) + + +# def test_trace_schema_contract() -> None: + + def test_save_load_trace() -> None: os.environ["COMPLETED_PROB"] = "1.0" info = dlt.pipeline().run([1, 2, 3], table_name="data", destination="dummy") @@ -255,7 +404,7 @@ def test_save_load_trace() -> None: assert resolved.is_secret_hint is False assert resolved.value == "1.0" assert resolved.config_type_name == "DummyClientConfiguration" - assert_trace_printable(trace) + assert_trace_serializable(trace) # check row counts assert pipeline.last_trace.last_normalize_info.row_counts == { "_dlt_pipeline_state": 1, @@ -296,7 +445,7 @@ def data(): assert run_step.step == "run" assert run_step.step_exception is not None assert step.step_exception == run_step.step_exception - assert_trace_printable(trace) + assert_trace_serializable(trace) assert pipeline.last_trace.last_normalize_info is None @@ -306,7 +455,7 @@ def test_save_load_empty_trace() -> None: pipeline = dlt.pipeline() pipeline.run([], table_name="data", destination="dummy") trace = pipeline.last_trace - assert_trace_printable(trace) + assert_trace_serializable(trace) assert len(trace.steps) == 4 pipeline.activate() @@ -402,7 +551,7 @@ def test_trace_telemetry() -> None: for item in SENTRY_SENT_ITEMS: # print(item) print(item["logentry"]["message"]) - assert len(SENTRY_SENT_ITEMS) == 2 + assert len(SENTRY_SENT_ITEMS) == 4 # trace with exception @dlt.resource @@ -529,7 +678,7 @@ def _mock_sentry_before_send(event: DictStrAny, _unused_hint: Any = None) -> Dic return event -def assert_trace_printable(trace: PipelineTrace) -> None: +def assert_trace_serializable(trace: PipelineTrace) -> None: str(trace) trace.asstr(0) trace.asstr(1) diff --git a/tests/pipeline/test_platform_connection.py b/tests/pipeline/test_platform_connection.py index fa5b143ff5..aa46019382 100644 --- a/tests/pipeline/test_platform_connection.py +++ b/tests/pipeline/test_platform_connection.py @@ -65,7 +65,8 @@ def data(): # basic check of trace result assert trace_result, "no trace" assert trace_result["pipeline_name"] == "platform_test_pipeline" - assert len(trace_result["steps"]) == 4 + # just extract, normalize and load steps. 
run step is not serialized to trace (it was just a copy of load) + assert len(trace_result["steps"]) == 3 assert trace_result["execution_context"]["library"]["name"] == "dlt" # basic check of state result diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index dfdb9c8e40..dfb5f3f82d 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -98,6 +98,9 @@ def users_materialize_table_schema(): def assert_load_info(info: LoadInfo, expected_load_packages: int = 1) -> None: """Asserts that expected number of packages was loaded and there are no failed jobs""" + # make sure we can serialize + info.asstr(verbosity=2) + info.asdict() assert len(info.loads_ids) == expected_load_packages # all packages loaded assert all(p.completed_at is not None for p in info.load_packages) is True @@ -174,24 +177,27 @@ def _load_file(client: FSClientBase, filepath) -> List[Dict[str, Any]]: # -def _load_tables_to_dicts_fs(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[Dict[str, Any]]]: +def _load_tables_to_dicts_fs( + p: dlt.Pipeline, *table_names: str, schema_name: str = None +) -> Dict[str, List[Dict[str, Any]]]: """For now this will expect the standard layout in the filesystem destination, if changed the results will not be correct""" - client = p._fs_client() + client = p._fs_client(schema_name=schema_name) + assert isinstance(client, FilesystemClient) + result: Dict[str, Any] = {} delta_table_names = [ table_name for table_name in table_names - if get_table_format(p.default_schema.tables, table_name) == "delta" + if get_table_format(client.schema.tables, table_name) == "delta" ] if len(delta_table_names) > 0: from dlt.common.libs.deltalake import get_delta_tables - delta_tables = get_delta_tables(p, *table_names) + delta_tables = get_delta_tables(p, *table_names, schema_name=schema_name) for table_name in table_names: - if table_name in p.default_schema.data_table_names() and table_name in delta_table_names: - assert isinstance(client, FilesystemClient) + if table_name in client.schema.data_table_names() and table_name in delta_table_names: dt = delta_tables[table_name] result[table_name] = dt.to_pyarrow_table().to_pylist() else: @@ -241,7 +247,7 @@ def _sort_list_of_dicts(list_: List[Dict[str, Any]], sortkey: str) -> List[Dict[ return sorted(list_, key=lambda d: d[sortkey]) if _is_filesystem(p): - result = _load_tables_to_dicts_fs(p, *table_names) + result = _load_tables_to_dicts_fs(p, *table_names, schema_name=schema_name) else: result = _load_tables_to_dicts_sql(p, *table_names, schema_name=schema_name) diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index 5c9f484bbc..39e3d767a0 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -347,7 +347,9 @@ def test_guarantee_termination(self): total_path=None, stop_after_empty_page=False, ) - assert e.match("`total_path` or `maximum_offset` or `stop_after_empty_page` must be provided") + assert e.match( + "`total_path` or `maximum_offset` or `stop_after_empty_page` must be provided" + ) with pytest.raises(ValueError) as e: OffsetPaginator( @@ -356,7 +358,9 @@ def test_guarantee_termination(self): stop_after_empty_page=False, maximum_offset=None, ) - assert e.match("`total_path` or `maximum_offset` or `stop_after_empty_page` must be provided") + assert e.match( + "`total_path` or `maximum_offset` or `stop_after_empty_page` must be provided" + ) @pytest.mark.usefixtures("mock_api_server") 
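The reflowed assertions in the test_paginators.py hunk above revolve around OffsetPaginator's termination guarantee: with total_path=None the paginator cannot learn the total record count, so either maximum_offset or stop_after_empty_page has to bound the iteration, otherwise the constructor raises ValueError. A minimal sketch of a configuration that satisfies this check (the limit value, base URL, and endpoint path are illustrative assumptions, not taken from the patch):

from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.paginators import OffsetPaginator

# total_path=None disables total-count detection, so maximum_offset caps the run;
# leaving both out (with stop_after_empty_page=False) would raise ValueError
paginator = OffsetPaginator(
    limit=100,            # page size, illustrative value
    total_path=None,
    maximum_offset=1000,  # hard upper bound on fetched records
)

# hypothetical usage against an assumed endpoint
client = RESTClient(base_url="https://api.example.com", paginator=paginator)
for page in client.paginate("/items"):
    print(len(page))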
diff --git a/tests/utils.py b/tests/utils.py index 976a623c0b..1b81881470 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -189,8 +189,9 @@ def wipe_pipeline(preserve_environ) -> Iterator[None]: yield if container[PipelineContext].is_active(): # take existing pipeline - p = dlt.pipeline() - p._wipe_working_folder() + # NOTE: no longer needed, the test storage is fully wiped when each test starts + # p = dlt.pipeline() + # p._wipe_working_folder() # deactivate context container[PipelineContext].deactivate()
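For orientation, the new test_trace_schema test added to test_pipeline_trace.py above reduces to the pattern sketched below: run a pipeline, take pipeline.last_trace, and load that trace object into a second pipeline so the inferred "trace" schema can later be frozen as a contract. This is a minimal sketch under stated assumptions, not part of the patch; the pipeline names and sample rows are illustrative.

import dlt
from dlt.destinations import dummy

# run any pipeline to produce a trace
pipeline = dlt.pipeline(pipeline_name="demo", destination=dummy(completed_prob=1.0))
pipeline.run([{"id": 1}, {"id": 2}], table_name="data")

# the trace of the last run is itself loadable data
trace = pipeline.last_trace

# load the trace into a second pipeline; the schema it infers for the "trace"
# table is what the test exports and freezes as trace.schema.yaml
trace_pipeline = dlt.pipeline(
    pipeline_name="demo_traces", destination=dummy(completed_prob=1.0)
)
info = trace_pipeline.run([trace], table_name="trace", schema=dlt.Schema("trace"))
info.raise_on_failed_jobs()

# the inferred schema, in the same pretty YAML form the test compares against
print(trace_pipeline.schemas["trace"].to_pretty_yaml())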