Skip to content

Commit

Permalink
Merge branch 'remove-lancedb-doc-id-hints' into 1587-lancedb-support-…
Browse files Browse the repository at this point in the history
…efficient-update-strategy-for-chunked-documents

# Conflicts:
#	dlt/destinations/impl/lancedb/lancedb_adapter.py
#	tests/load/lancedb/test_merge.py
  • Loading branch information
Pipboyguy committed Sep 2, 2024
2 parents 470315e + 2b7f4c6 commit 105b388
Show file tree
Hide file tree
Showing 125 changed files with 4,919 additions and 1,338 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_destinations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ env:
# Test redshift and filesystem with all buckets
# postgres runs again here so we can test on mac/windows
ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\"]"
# note that all buckets are enabled for testing

jobs:
get_docs_changes:
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ lint:
poetry run mypy --config-file mypy.ini dlt tests
poetry run flake8 --max-line-length=200 dlt
poetry run flake8 --max-line-length=200 tests --exclude tests/reflection/module_cases
poetry run black dlt docs tests --diff --extend-exclude=".*syntax_error.py"
poetry run black dlt docs tests --check --diff --color --extend-exclude=".*syntax_error.py"
# poetry run isort ./ --diff
# $(MAKE) lint-security

Expand Down
2 changes: 2 additions & 0 deletions dlt/common/configuration/specs/azure_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def to_object_store_rs_credentials(self) -> Dict[str, str]:
creds = self.to_adlfs_credentials()
if creds["sas_token"] is None:
creds.pop("sas_token")
if creds["account_key"] is None:
creds.pop("account_key")
return creds

def create_sas_token(self) -> None:
Expand Down
2 changes: 0 additions & 2 deletions dlt/common/data_writers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dlt.common.data_writers.writers import (
DataWriter,
DataWriterMetrics,
TDataItemFormat,
FileWriterSpec,
create_import_spec,
Expand All @@ -22,7 +21,6 @@
"resolve_best_writer_spec",
"get_best_writer_spec",
"is_native_writer",
"DataWriterMetrics",
"TDataItemFormat",
"BufferedDataWriter",
"new_file_id",
Expand Down
3 changes: 2 additions & 1 deletion dlt/common/data_writers/buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import contextlib
from typing import ClassVar, Iterator, List, IO, Any, Optional, Type, Generic

from dlt.common.metrics import DataWriterMetrics
from dlt.common.typing import TDataItem, TDataItems
from dlt.common.data_writers.exceptions import (
BufferedDataWriterClosed,
DestinationCapabilitiesRequired,
FileImportNotFound,
InvalidFileNameTemplateException,
)
from dlt.common.data_writers.writers import TWriter, DataWriter, DataWriterMetrics, FileWriterSpec
from dlt.common.data_writers.writers import TWriter, DataWriter, FileWriterSpec
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.configuration import with_config, known_sections, configspec
from dlt.common.configuration.specs import BaseConfiguration
Expand Down
20 changes: 1 addition & 19 deletions dlt/common/data_writers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TLoaderFileFormat,
ALL_SUPPORTED_FILE_FORMATS,
)
from dlt.common.metrics import DataWriterMetrics
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.typing import StrAny

Expand All @@ -59,25 +60,6 @@ class FileWriterSpec(NamedTuple):
supports_compression: bool = False


class DataWriterMetrics(NamedTuple):
    """Tuple of writer metrics for a file; two instances can be summed with `+`.

    Summing yields a new `DataWriterMetrics` with an empty `file_path`, added
    `items_count` and `file_size`, the smaller `created` and the larger
    `last_modified`.
    """

    file_path: str
    items_count: int
    file_size: int
    created: float
    last_modified: float

    def __add__(self, other: Tuple[object, ...], /) -> Tuple[object, ...]:
        """Aggregate `self` and `other`; returns `NotImplemented` for foreign types."""
        # only metrics can be aggregated with metrics
        if not isinstance(other, DataWriterMetrics):
            return NotImplemented

        total_items = self.items_count + other.items_count
        total_size = self.file_size + other.file_size
        earliest_created = min(self.created, other.created)
        latest_modified = max(self.last_modified, other.last_modified)
        # path is not known for an aggregate, so it is left empty
        return DataWriterMetrics(
            "",
            total_items,
            total_size,
            earliest_created,
            latest_modified,
        )


EMPTY_DATA_WRITER_METRICS = DataWriterMetrics("", 0, 0, 2**32, 0.0)


Expand Down
1 change: 1 addition & 0 deletions dlt/common/destination/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
# use naming convention in the schema
naming_convention: TNamingConventionReferenceArg = None
alter_add_multi_column: bool = True
supports_create_table_if_not_exists: bool = True
supports_truncate_command: bool = True
schema_supports_numeric_precision: bool = True
timestamp_precision: int = 6
Expand Down
50 changes: 42 additions & 8 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from copy import deepcopy
import inspect

from dlt.common import logger
from dlt.common import logger, pendulum
from dlt.common.configuration.specs.base_configuration import extract_inner_hint
from dlt.common.destination.utils import verify_schema_capabilities
from dlt.common.exceptions import TerminalValueError
from dlt.common.metrics import LoadJobMetrics
from dlt.common.normalizers.naming import NamingConvention
from dlt.common.schema import Schema, TTableSchema, TSchemaTables
from dlt.common.schema.utils import (
Expand Down Expand Up @@ -268,6 +269,8 @@ class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfigura

staging_config: Optional[DestinationClientStagingConfiguration] = None
"""configuration of the staging, if present, injected at runtime"""
truncate_tables_on_staging_destination_before_load: bool = True
"""If dlt should truncate the tables on staging destination before loading data."""


TLoadJobState = Literal["ready", "running", "failed", "retry", "completed"]
Expand All @@ -284,6 +287,8 @@ def __init__(self, file_path: str) -> None:
# NOTE: we only accept a full filepath in the constructor
assert self._file_name != self._file_path
self._parsed_file_name = ParsedLoadJobFileName.parse(self._file_name)
self._started_at: pendulum.DateTime = None
self._finished_at: pendulum.DateTime = None

def job_id(self) -> str:
"""The job id that is derived from the file name and does not changes during job lifecycle"""
Expand All @@ -306,6 +311,18 @@ def exception(self) -> str:
"""The exception associated with failed or retry states"""
pass

def metrics(self) -> Optional[LoadJobMetrics]:
    """Returns job execution metrics

    Builds a `LoadJobMetrics` from the parsed job id, the job file path, the parsed
    table name, the recorded start and finish times and the current job state.
    The last positional field is always passed as `None` here.
    """
    parsed_name = self._parsed_file_name
    # values are collected in the same order in which they are passed below
    job_id = parsed_name.job_id()
    file_path = self._file_path
    table_name = parsed_name.table_name
    started_at = self._started_at
    finished_at = self._finished_at
    current_state = self.state()
    return LoadJobMetrics(
        job_id, file_path, table_name, started_at, finished_at, current_state, None
    )


class RunnableLoadJob(LoadJob, ABC):
"""Represents a runnable job that loads a single file
Expand Down Expand Up @@ -361,16 +378,22 @@ def run_managed(
# filepath is now moved to running
try:
self._state = "running"
self._started_at = pendulum.now()
self._job_client.prepare_load_job_execution(self)
self.run()
self._state = "completed"
except (DestinationTerminalException, TerminalValueError) as e:
self._state = "failed"
self._exception = e
logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}")
except (DestinationTransientException, Exception) as e:
self._state = "retry"
self._exception = e
logger.exception(
f"Transient exception in job {self.job_id()} in file {self._file_path}"
)
finally:
self._finished_at = pendulum.now()
# sanity check
assert self._state in ("completed", "retry", "failed")

Expand All @@ -391,7 +414,7 @@ def exception(self) -> str:
return str(self._exception)


class FollowupJob:
class FollowupJobRequest:
"""Base class for follow up jobs that should be created"""

@abstractmethod
Expand All @@ -403,8 +426,8 @@ def new_file_path(self) -> str:
class HasFollowupJobs:
"""Adds a trait that allows to create single or table chain followup jobs"""

def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]:
"""Return list of new jobs. `final_state` is state to which this job transits"""
def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRequest]:
"""Return list of jobs requests for jobs that should be created. `final_state` is state to which this job transits"""
return []


Expand Down Expand Up @@ -479,7 +502,7 @@ def create_table_chain_completed_followup_jobs(
self,
table_chain: Sequence[TTableSchema],
completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None,
) -> List[FollowupJob]:
) -> List[FollowupJobRequest]:
"""Creates a list of followup jobs that should be executed after a table chain is completed"""
return []

Expand Down Expand Up @@ -557,17 +580,28 @@ def with_staging_dataset(self) -> ContextManager["JobClientBase"]:
return self # type: ignore


class SupportsStagingDestination:
class SupportsStagingDestination(ABC):
"""Adds capability to support a staging destination for the load"""

def should_load_data_to_staging_dataset_on_staging_destination(
self, table: TTableSchema
) -> bool:
"""If set to True, and staging destination is configured, the data will be loaded to staging dataset on staging destination
instead of a regular dataset on staging destination. Currently it is used by Athena Iceberg which uses staging dataset
on staging destination to copy data to iceberg tables stored on regular dataset on staging destination.
The default is to load data to regular dataset on staging destination from where warehouses like Snowflake (that have their
own storage) will copy data.
"""
return False

@abstractmethod
def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool:
# the default is to truncate the tables on the staging destination...
return True
"""If set to True, data in `table` will be truncated on staging destination (regular dataset). This is the default behavior which
can be changed with a config flag.
For Athena + Iceberg this setting is always False - Athena uses regular dataset to store Iceberg tables and we avoid touching it.
For Athena we truncate those tables only on "replace" write disposition.
"""
pass


# TODO: type Destination properly
Expand Down
80 changes: 69 additions & 11 deletions dlt/common/libs/deltalake.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from dlt.common import logger
from dlt.common.libs.pyarrow import pyarrow as pa
from dlt.common.libs.pyarrow import cast_arrow_schema_types
from dlt.common.schema.typing import TWriteDisposition
from dlt.common.schema.typing import TWriteDisposition, TTableSchema
from dlt.common.schema.utils import get_first_column_name_with_prop, get_columns_names_with_prop
from dlt.common.exceptions import MissingDependencyException
from dlt.common.storages import FilesystemConfiguration
from dlt.common.utils import assert_min_pkg_version
from dlt.destinations.impl.filesystem.filesystem import FilesystemClient

try:
import deltalake
from deltalake import write_deltalake, DeltaTable
from deltalake.writer import try_get_deltatable
except ModuleNotFoundError:
Expand Down Expand Up @@ -74,7 +76,7 @@ def write_delta_table(
partition_by: Optional[Union[List[str], str]] = None,
storage_options: Optional[Dict[str, str]] = None,
) -> None:
"""Writes in-memory Arrow table to on-disk Delta table.
"""Writes in-memory Arrow data to on-disk Delta table.
Thin wrapper around `deltalake.write_deltalake`.
"""
Expand All @@ -93,31 +95,73 @@ def write_delta_table(
)


def get_delta_tables(pipeline: Pipeline, *tables: str) -> Dict[str, DeltaTable]:
"""Returns Delta tables in `pipeline.default_schema` as `deltalake.DeltaTable` objects.
def merge_delta_table(
    table: DeltaTable,
    data: Union[pa.Table, pa.RecordBatchReader],
    schema: TTableSchema,
) -> None:
    """Merges in-memory Arrow data into on-disk Delta table.

    Only the "upsert" merge strategy is supported: rows matching the merge predicate
    are updated, all other rows are inserted.

    Args:
        table: Delta table that is the merge target.
        data: Arrow data that is the merge source.
        schema: Table schema; its `x-merge-strategy` entry selects the strategy. When it
            has a `parent` key the first `unique` column is used as the merge key,
            otherwise all `primary_key` columns are used.

    Raises:
        ValueError: If the merge strategy is not "upsert".
    """
    strategy = schema["x-merge-strategy"]  # type: ignore[typeddict-item]
    # fail loudly on unknown strategies instead of silently skipping the merge
    if strategy != "upsert":
        raise ValueError(f'Merge strategy "{strategy}" not supported.')

    # `DeltaTable.merge` does not support automatic schema evolution
    # https://github.com/delta-io/delta-rs/issues/2282
    _evolve_delta_table_schema(table, data.schema)

    if "parent" in schema:
        unique_column = get_first_column_name_with_prop(schema, "unique")
        predicate = f"target.{unique_column} = source.{unique_column}"
    else:
        primary_keys = get_columns_names_with_prop(schema, "primary_key")
        predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys])

    qry = (
        table.merge(
            source=ensure_delta_compatible_arrow_data(data),
            predicate=predicate,
            source_alias="source",
            target_alias="target",
        )
        .when_matched_update_all()
        .when_not_matched_insert_all()
    )

    qry.execute()


def get_delta_tables(
pipeline: Pipeline, *tables: str, schema_name: str = None
) -> Dict[str, DeltaTable]:
"""Returns Delta tables in `pipeline.default_schema (default)` as `deltalake.DeltaTable` objects.
Returned object is a dictionary with table names as keys and `DeltaTable` objects as values.
Optionally filters dictionary by table names specified as `*tables`.
Raises ValueError if table name specified as `*tables` is not found.
Raises ValueError if table name specified as `*tables` is not found. You may try to switch to other
schemas via `schema_name` argument.
"""
from dlt.common.schema.utils import get_table_format

with pipeline.destination_client() as client:
with pipeline.destination_client(schema_name=schema_name) as client:
assert isinstance(
client, FilesystemClient
), "The `get_delta_tables` function requires a `filesystem` destination."

schema_delta_tables = [
t["name"]
for t in pipeline.default_schema.tables.values()
if get_table_format(pipeline.default_schema.tables, t["name"]) == "delta"
for t in client.schema.tables.values()
if get_table_format(client.schema.tables, t["name"]) == "delta"
]
if len(tables) > 0:
invalid_tables = set(tables) - set(schema_delta_tables)
if len(invalid_tables) > 0:
available_schemas = ""
if len(pipeline.schema_names) > 1:
available_schemas = f" Available schemas are {pipeline.schema_names}"
raise ValueError(
"Schema does not contain Delta tables with these names: "
f"{', '.join(invalid_tables)}."
f"Schema {client.schema.name} does not contain Delta tables with these names: "
f"{', '.join(invalid_tables)}.{available_schemas}"
)
schema_delta_tables = [t for t in schema_delta_tables if t in tables]
table_dirs = client.get_table_dirs(schema_delta_tables, remote=True)
Expand All @@ -132,7 +176,8 @@ def _deltalake_storage_options(config: FilesystemConfiguration) -> Dict[str, str
"""Returns dict that can be passed as `storage_options` in `deltalake` library."""
creds = {}
extra_options = {}
if config.protocol in ("az", "gs", "s3"):
# TODO: create a mixin with to_object_store_rs_credentials for a proper discovery
if hasattr(config.credentials, "to_object_store_rs_credentials"):
creds = config.credentials.to_object_store_rs_credentials()
if config.deltalake_storage_options is not None:
extra_options = config.deltalake_storage_options
Expand All @@ -145,3 +190,16 @@ def _deltalake_storage_options(config: FilesystemConfiguration) -> Dict[str, str
+ ". dlt will use the values in `deltalake_storage_options`."
)
return {**creds, **extra_options}


def _evolve_delta_table_schema(delta_table: DeltaTable, arrow_schema: pa.Schema) -> None:
    """Evolves `delta_table` schema if different from `arrow_schema`.

    Adds column(s) to `delta_table` present in `arrow_schema` but not in `delta_table`.
    Does not alter the table when there are no columns to add.
    """
    # read the current table schema once instead of once per incoming field
    existing_schema = delta_table.to_pyarrow_dataset().schema
    new_fields = [
        deltalake.Field.from_pyarrow(field)
        for field in ensure_delta_compatible_arrow_schema(arrow_schema)
        if field not in existing_schema
    ]
    # avoid issuing an empty "add columns" operation when the schemas already match
    if new_fields:
        delta_table.alter.add_columns(new_fields)
9 changes: 8 additions & 1 deletion dlt/common/libs/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def get_py_arrow_datatype(
elif column_type == "bool":
return pyarrow.bool_()
elif column_type == "timestamp":
return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz)
# sets timezone to None when timezone hint is false
timezone = tz if column.get("timezone", True) else None
precision = column.get("precision") or caps.timestamp_precision
return get_py_arrow_timestamp(precision, timezone)
elif column_type == "bigint":
return get_pyarrow_int(column.get("precision"))
elif column_type == "binary":
Expand Down Expand Up @@ -139,6 +142,10 @@ def get_column_type_from_py_arrow(dtype: pyarrow.DataType) -> TColumnType:
precision = 6
else:
precision = 9

if dtype.tz is None:
return dict(data_type="timestamp", precision=precision, timezone=False)

return dict(data_type="timestamp", precision=precision)
elif pyarrow.types.is_date(dtype):
return dict(data_type="date")
Expand Down
Loading

0 comments on commit 105b388

Please sign in to comment.