From 2b7f4c6d2a72399123a2c20da05d136b76b9f177 Mon Sep 17 00:00:00 2001
From: Marcel Coetzee
Date: Tue, 3 Sep 2024 01:13:52 +0200
Subject: [PATCH] Minor format and cleanup

Signed-off-by: Marcel Coetzee
---
 .../impl/lancedb/lancedb_client.py            | 19 +++---
 dlt/destinations/impl/lancedb/utils.py        | 60 -------------------
 2 files changed, 8 insertions(+), 71 deletions(-)

diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py
index 7892a8c2c7..11249d0f97 100644
--- a/dlt/destinations/impl/lancedb/lancedb_client.py
+++ b/dlt/destinations/impl/lancedb/lancedb_client.py
@@ -78,7 +78,6 @@
     fill_empty_source_column_values_with_placeholder,
     get_canonical_vector_database_doc_id_merge_key,
     create_filter_condition,
-    add_missing_columns_to_arrow_table,
 )
 from dlt.destinations.job_impl import ReferenceFollowupJobRequest
 from dlt.destinations.type_mapping import TypeMapper
@@ -793,6 +792,10 @@ def __init__(
         self.references = ReferenceFollowupJobRequest.resolve_references(file_path)
 
     def run(self) -> None:
+        dlt_load_id = self._schema.data_item_normalizer.C_DLT_LOAD_ID  # type: ignore[attr-defined]
+        dlt_id = self._schema.data_item_normalizer.C_DLT_ID  # type: ignore[attr-defined]
+        dlt_root_id = self._schema.data_item_normalizer.C_DLT_ROOT_ID  # type: ignore[attr-defined]
+
         db_client: DBConnection = self._job_client.db_client
         table_lineage: TTableLineage = [
             TableJob(
@@ -811,28 +814,22 @@ def run(self) -> None:
             file_path = job.file_path
             with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f:
                 payload_arrow_table: pa.Table = pq.read_table(f)
-                target_table_schema: pa.Schema = pq.read_schema(f)
-
-            payload_arrow_table = add_missing_columns_to_arrow_table(
-                payload_arrow_table, target_table_schema
-            )
 
 
             if target_is_root_table:
                 canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key(
                     job.table_schema
                 )
-                # TODO: Guard against edge cases. For example, if `doc_id` field has escape characters in it.
                 filter_condition = create_filter_condition(
                     canonical_doc_id_field, payload_arrow_table[canonical_doc_id_field]
                 )
-                merge_key = self._schema.data_item_normalizer.C_DLT_LOAD_ID  # type: ignore[attr-defined]
+                merge_key = dlt_load_id
             else:
                 filter_condition = create_filter_condition(
-                    self._schema.data_item_normalizer.C_DLT_ROOT_ID,  # type: ignore[attr-defined]
-                    payload_arrow_table[self._schema.data_item_normalizer.C_DLT_ROOT_ID],  # type: ignore[attr-defined]
+                    dlt_root_id,
+                    payload_arrow_table[dlt_root_id],
                 )
-                merge_key = self._schema.data_item_normalizer.C_DLT_ID  # type: ignore[attr-defined]
+                merge_key = dlt_id
 
             write_records(
                 payload_arrow_table,
diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py
index 525f1cec7a..f07f2754d2 100644
--- a/dlt/destinations/impl/lancedb/utils.py
+++ b/dlt/destinations/impl/lancedb/utils.py
@@ -3,14 +3,11 @@
 
 import pyarrow as pa
 
-from dlt import Schema
 from dlt.common import logger
 from dlt.common.destination.exceptions import DestinationTerminalException
-from dlt.common.pendulum import __utcnow
 from dlt.common.schema import TTableSchema
 from dlt.common.schema.utils import get_columns_names_with_prop
 from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider
-from dlt.destinations.impl.lancedb.schema import TArrowDataType
 
 EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB"
 PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = {
@@ -28,23 +25,6 @@ def set_non_standard_providers_environment_variables(
     os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or ""
 
 
-def get_default_arrow_value(field_type: TArrowDataType) -> object:
-    if pa.types.is_integer(field_type):
-        return 0
-    elif pa.types.is_floating(field_type):
-        return 0.0
-    elif pa.types.is_string(field_type):
-        return ""
-    elif pa.types.is_boolean(field_type):
-        return False
-    elif pa.types.is_date(field_type):
-        return __utcnow().today()
-    elif pa.types.is_timestamp(field_type):
-        return __utcnow()
-    else:
-        raise ValueError(f"Unsupported data type: {field_type}")
-
-
 def get_canonical_vector_database_doc_id_merge_key(
     load_table: TTableSchema,
 ) -> str:
@@ -100,43 +80,3 @@ def format_value(element: Union[str, int, float, pa.Scalar]) -> str:
         return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element)
 
     return f"{field_name} IN ({', '.join(map(format_value, array))})"
-
-
-def add_missing_columns_to_arrow_table(
-    payload_arrow_table: pa.Table,
-    target_table_schema: pa.Schema,
-) -> pa.Table:
-    """Add missing columns from the target schema to the payload Arrow table.
-
-    LanceDB requires the payload to have all fields populated, even if we
-    don't intend to use them in our merge operation.
-
-    Unfortunately, we can't just create NULL fields; else LanceDB always truncates
-    the target using `when_not_matched_by_source_delete`.
-    This function identifies columns present in the target schema but missing from
-    the payload table and adds them with either default or null values.
-
-    Args:
-        payload_arrow_table: The input Arrow table.
-        target_table_schema: The schema of the target table.
-
-    Returns:
-        The modified Arrow table with added columns.
-
-    """
-    schema_difference = pa.schema(set(target_table_schema) - set(payload_arrow_table.schema))
-
-    for field in schema_difference:
-        try:
-            default_value = get_default_arrow_value(field.type)
-            default_array = pa.array(
-                [default_value] * payload_arrow_table.num_rows, type=field.type
-            )
-            payload_arrow_table = payload_arrow_table.append_column(field, default_array)
-        except ValueError as e:
-            logger.warning(f"{e}. Using null values for field '{field.name}'.")
-            payload_arrow_table = payload_arrow_table.append_column(
-                field, pa.nulls(size=payload_arrow_table.num_rows, type=field.type)
-            )
-
-    return payload_arrow_table
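
Reviewer note, not part of the patch: the dropped TODO worried about `doc_id`
values containing escape characters. The retained format_value() helper in
utils.py (visible as context in the last hunk) already doubles embedded single
quotes when create_filter_condition() builds the IN filter. Below is a minimal
sketch of that behavior; the pa.Scalar unwrap via as_py() is assumed from the
helper's Union[str, int, float, pa.Scalar] signature, and the doc_id field
name is illustrative:

    import pyarrow as pa

    def format_value(element):
        # Unwrap Arrow scalars first (assumed from the pa.Scalar case in the
        # helper's signature), then quote strings by doubling single quotes;
        # non-strings pass through str().
        if isinstance(element, pa.Scalar):
            element = element.as_py()
        return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element)

    doc_ids = pa.array(["doc-1", "o'reilly"])
    print(f"doc_id IN ({', '.join(map(format_value, doc_ids))})")
    # prints: doc_id IN ('doc-1', 'o''reilly')

Only single-quote escaping is covered; backslashes or other dialect-specific
escape sequences would still pass through unchanged.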