Commit 2b7f4c6
Minor format and cleanup
Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Sep 2, 2024
1 parent 0eba25e commit 2b7f4c6
Showing 2 changed files with 8 additions and 71 deletions.
dlt/destinations/impl/lancedb/lancedb_client.py (8 additions, 11 deletions)
@@ -78,7 +78,6 @@
fill_empty_source_column_values_with_placeholder,
get_canonical_vector_database_doc_id_merge_key,
create_filter_condition,
add_missing_columns_to_arrow_table,
)
from dlt.destinations.job_impl import ReferenceFollowupJobRequest
from dlt.destinations.type_mapping import TypeMapper
@@ -793,6 +792,10 @@ def __init__(
self.references = ReferenceFollowupJobRequest.resolve_references(file_path)

def run(self) -> None:
dlt_load_id = self._schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined]
dlt_id = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined]
dlt_root_id = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined]

db_client: DBConnection = self._job_client.db_client
table_lineage: TTableLineage = [
TableJob(
@@ -811,28 +814,22 @@ def run(self) -> None:
file_path = job.file_path
with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f:
payload_arrow_table: pa.Table = pq.read_table(f)
target_table_schema: pa.Schema = pq.read_schema(f)

payload_arrow_table = add_missing_columns_to_arrow_table(
payload_arrow_table, target_table_schema
)

if target_is_root_table:
canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key(
job.table_schema
)
# TODO: Guard against edge cases. For example, if `doc_id` field has escape characters in it.
filter_condition = create_filter_condition(
canonical_doc_id_field, payload_arrow_table[canonical_doc_id_field]
)
merge_key = self._schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined]
merge_key = dlt_load_id

else:
filter_condition = create_filter_condition(
self._schema.data_item_normalizer.C_DLT_ROOT_ID, # type: ignore[attr-defined]
payload_arrow_table[self._schema.data_item_normalizer.C_DLT_ROOT_ID], # type: ignore[attr-defined]
dlt_root_id,
payload_arrow_table[dlt_root_id],
)
merge_key = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined]
merge_key = dlt_id

write_records(
payload_arrow_table,
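Note on the run() change above: the three normalizer column-name constants are now read into locals once at the top of the method, so the merge-key branches below use plain names and the # type: ignore[attr-defined] hints appear once instead of on every use. A minimal standalone sketch of the pattern follows; the StubNormalizer class and pick_merge_key function are illustrative stand-ins, since in dlt the constants live on self._schema.data_item_normalizer.

class StubNormalizer:
    # Stand-in for dlt's relational normalizer constants.
    C_DLT_LOAD_ID = "_dlt_load_id"
    C_DLT_ID = "_dlt_id"
    C_DLT_ROOT_ID = "_dlt_root_id"

def pick_merge_key(normalizer: StubNormalizer, target_is_root_table: bool) -> str:
    # Hoist the attribute lookups once, as run() now does ...
    dlt_load_id = normalizer.C_DLT_LOAD_ID
    dlt_id = normalizer.C_DLT_ID
    # ... so each branch below reads as a plain local variable.
    return dlt_load_id if target_is_root_table else dlt_id

assert pick_merge_key(StubNormalizer(), True) == "_dlt_load_id"
assert pick_merge_key(StubNormalizer(), False) == "_dlt_id"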
dlt/destinations/impl/lancedb/utils.py (0 additions, 60 deletions)
@@ -3,14 +3,11 @@

import pyarrow as pa

from dlt import Schema
from dlt.common import logger
from dlt.common.destination.exceptions import DestinationTerminalException
from dlt.common.pendulum import __utcnow
from dlt.common.schema import TTableSchema
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider
from dlt.destinations.impl.lancedb.schema import TArrowDataType

EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB"
PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = {
@@ -28,23 +25,6 @@ def set_non_standard_providers_environment_variables(
os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or ""


def get_default_arrow_value(field_type: TArrowDataType) -> object:
if pa.types.is_integer(field_type):
return 0
elif pa.types.is_floating(field_type):
return 0.0
elif pa.types.is_string(field_type):
return ""
elif pa.types.is_boolean(field_type):
return False
elif pa.types.is_date(field_type):
return __utcnow().today()
elif pa.types.is_timestamp(field_type):
return __utcnow()
else:
raise ValueError(f"Unsupported data type: {field_type}")


def get_canonical_vector_database_doc_id_merge_key(
load_table: TTableSchema,
) -> str:
@@ -100,43 +80,3 @@ def format_value(element: Union[str, int, float, pa.Scalar]) -> str:
return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element)

return f"{field_name} IN ({', '.join(map(format_value, array))})"


def add_missing_columns_to_arrow_table(
payload_arrow_table: pa.Table,
target_table_schema: pa.Schema,
) -> pa.Table:
"""Add missing columns from the target schema to the payload Arrow table.
LanceDB requires the payload to have all fields populated, even if we
don't intend to use them in our merge operation.
Unfortunately, we can't just create NULL fields; else LanceDB always truncates
the target using `when_not_matched_by_source_delete`.
This function identifies columns present in the target schema but missing from
the payload table and adds them with either default or null values.
Args:
payload_arrow_table: The input Arrow table.
target_table_schema: The schema of the target table.
Returns:
The modified Arrow table with added columns.
"""
schema_difference = pa.schema(set(target_table_schema) - set(payload_arrow_table.schema))

for field in schema_difference:
try:
default_value = get_default_arrow_value(field.type)
default_array = pa.array(
[default_value] * payload_arrow_table.num_rows, type=field.type
)
payload_arrow_table = payload_arrow_table.append_column(field, default_array)
except ValueError as e:
logger.warning(f"{e}. Using null values for field '{field.name}'.")
payload_arrow_table = payload_arrow_table.append_column(
field, pa.nulls(size=payload_arrow_table.num_rows, type=field.type)
)

return payload_arrow_table
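A standalone sketch of what this now-removed helper did, using only pyarrow (the table contents and field names are illustrative): fields present in the target schema but missing from the payload are appended as new columns, filled with a type-appropriate default rather than nulls.

import pyarrow as pa

payload = pa.table({"id": pa.array([1, 2], type=pa.int64())})
target = pa.schema([pa.field("id", pa.int64()), pa.field("score", pa.float64())])

# Fields in the target schema that the payload lacks.
missing = pa.schema(set(target) - set(payload.schema))

for field in missing:
    # Append each missing column filled with a default for its type
    # (0.0 for the float column here; the removed helper picked per-type
    # defaults and fell back to null columns for unsupported types).
    default_array = pa.array([0.0] * payload.num_rows, type=field.type)
    payload = payload.append_column(field, default_array)

print(payload.column_names)  # ['id', 'score']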
