Save truncated tables in load package state
steinitzu committed Apr 16, 2024
1 parent 800c3cd commit 52115c5
Showing 5 changed files with 39 additions and 20 deletions.
2 changes: 2 additions & 0 deletions dlt/common/storages/load_package.py
@@ -64,6 +64,8 @@ class TLoadPackageState(TVersionedState, total=False):

     dropped_tables: NotRequired[List[TTableSchema]]
     """List of tables that are to be dropped from the schema and destination (i.e. when `refresh` mode is used)"""
+    truncated_tables: NotRequired[List[TTableSchema]]
+    """List of tables that are to be truncated in the destination (i.e. when `refresh='drop_data'` mode is used)"""


 class TLoadPackage(TypedDict, total=False):
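The two state keys mirror each other: `dropped_tables` records tables to be dropped outright, while the new `truncated_tables` records tables whose data should be cleared with the schema left in place. A minimal standalone sketch of reading the new key, with `TTableSchema` stubbed as a plain dict (illustration only, not dlt's actual API):

    from typing import Any, Dict, List

    def tables_to_truncate(state: Dict[str, Any]) -> List[str]:
        """Return names of tables marked for truncation; a missing key means none."""
        return [t["name"] for t in state.get("truncated_tables", [])]

    state = {"truncated_tables": [{"name": "some_data_1"}, {"name": "some_data_2"}]}
    print(tables_to_truncate(state))  # ['some_data_1', 'some_data_2']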
10 changes: 6 additions & 4 deletions dlt/extract/extract.py
@@ -401,7 +401,6 @@ def extract(
             _state,
             resources=_resources_to_drop,
             drop_all=self.refresh == "drop_dataset",
-            state_only=self.refresh == "drop_data",
             state_paths="*" if self.refresh == "drop_dataset" else [],
         )
         _state.update(new_state)
@@ -411,9 +410,12 @@
             for table in source.schema.tables.values()
             if table["name"] in drop_info["tables"]
         ]
-        load_package.state["dropped_tables"] = drop_tables
-        source.schema.tables.clear()
-        source.schema.tables.update(new_schema.tables)
+        if self.refresh == "drop_data":
+            load_package.state["truncated_tables"] = drop_tables
+        else:
+            source.schema.tables.clear()
+            source.schema.tables.update(new_schema.tables)
+            load_package.state["dropped_tables"] = drop_tables

         # reset resource states, the `extracted` list contains all the explicit resources and all their parents
         for resource in source.resources.extracted.values():
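This branch is the heart of the change: with `refresh='drop_data'` the schema is left untouched and the affected tables are only recorded for truncation; any other refresh mode swaps in the pruned schema and records the tables for dropping. A condensed, self-contained sketch of that decision (simplified names, not the actual extract internals):

    from typing import Any, Dict, List

    def record_refresh_tables(
        package_state: Dict[str, Any],
        schema_tables: Dict[str, Any],
        new_tables: Dict[str, Any],
        drop_tables: List[Dict[str, Any]],
        refresh: str,
    ) -> None:
        if refresh == "drop_data":
            # Keep the schema as-is; only mark the tables for truncation.
            package_state["truncated_tables"] = drop_tables
        else:
            # Swap in the pruned schema and mark the tables for dropping.
            schema_tables.clear()
            schema_tables.update(new_tables)
            package_state["dropped_tables"] = drop_tables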
5 changes: 3 additions & 2 deletions dlt/load/load.py
@@ -358,6 +358,7 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
         new_jobs = self.get_new_jobs_info(load_id)

         dropped_tables = current_load_package()["state"].get("dropped_tables", [])
+        truncated_tables = current_load_package()["state"].get("truncated_tables", [])

         # initialize analytical storage ie. create dataset required by passed schema
         with self.get_destination_client(schema) as job_client:
@@ -374,8 +375,8 @@
                 if isinstance(job_client, WithStagingDataset)
                 else None
             ),
-            refresh=self.refresh,
             drop_tables=dropped_tables,
+            truncate_tables=truncated_tables,
         )

         # init staging client
@@ -394,8 +395,8 @@
                 job_client.should_truncate_table_before_load_on_staging_destination,
                 # should_truncate_staging,
                 job_client.should_load_data_to_staging_dataset_on_staging_destination,
-                refresh=self.refresh,
                 drop_tables=dropped_tables,
+                truncate_tables=truncated_tables,
             )

         self.load_storage.commit_schema_update(load_id, applied_update)
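Rather than passing the refresh mode down, `load_single_package` now reads both table lists from the package state once and forwards them to `init_client` for the main and staging datasets alike. A hedged sketch of that lookup (the `.get(..., [])` defaults matter, since older packages will not carry the new key):

    from typing import Any, Dict, List, Tuple

    def refresh_tables(state: Dict[str, Any]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Read drop/truncate lists from a load package state, defaulting to empty."""
        return state.get("dropped_tables", []), state.get("truncated_tables", [])

    dropped, truncated = refresh_tables({"truncated_tables": [{"name": "events"}]})
    # Both init_client calls would receive drop_tables=dropped, truncate_tables=truncated.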
20 changes: 12 additions & 8 deletions dlt/load/utils.py
@@ -67,8 +67,8 @@ def init_client(
     expected_update: TSchemaTables,
     truncate_filter: Callable[[TTableSchema], bool],
     load_staging_filter: Callable[[TTableSchema], bool],
-    refresh: Optional[TRefreshMode] = None,
     drop_tables: Optional[List[TTableSchema]] = None,
+    truncate_tables: Optional[List[TTableSchema]] = None,
 ) -> TSchemaTables:
     """Initializes destination storage including staging dataset if supported
@@ -81,6 +81,8 @@
         expected_update (TSchemaTables): Schema update as in load package. Always present even if empty
         truncate_filter (Callable[[TTableSchema], bool]): A filter that tells which table in the destination dataset should be truncated
         load_staging_filter (Callable[[TTableSchema], bool]): A filter that tells which table in the staging dataset may be loaded into
+        drop_tables (Optional[List[TTableSchema]]): List of tables to drop before initializing storage
+        truncate_tables (Optional[List[TTableSchema]]): List of tables to truncate before initializing storage

     Returns:
         TSchemaTables: Actual migrations done at destination
@@ -96,19 +98,21 @@
     tables_with_jobs = set(job.table_name for job in new_jobs) - tables_no_data

     # get tables to truncate by extending tables with jobs with all their child tables
-
-    if refresh == "drop_data":
-        truncate_filter = lambda t: t["name"] in tables_with_jobs - dlt_tables
-
-    truncate_tables = set(
-        _extend_tables_with_table_chain(schema, tables_with_jobs, tables_with_jobs, truncate_filter)
+    initial_truncate_names = set(t["name"] for t in truncate_tables) if truncate_tables else set()
+    truncate_table_names = set(
+        _extend_tables_with_table_chain(
+            schema,
+            tables_with_jobs,
+            tables_with_jobs,
+            lambda t: truncate_filter(t) or t["name"] in initial_truncate_names,
+        )
     )

     applied_update = _init_dataset_and_update_schema(
         job_client,
         expected_update,
         tables_with_jobs | dlt_tables,
-        truncate_tables,
+        truncate_table_names,
         drop_tables=drop_tables,
     )

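The filter composition is the subtle part: the destination's own `truncate_filter` predicate is OR-ed with the explicit names from `truncate_tables`, so a table is truncated if either the destination's rules or the load package ask for it. A standalone sketch of that composition (stub types, not dlt's internals):

    from typing import Callable, Dict, List, Set

    TableSchema = Dict[str, str]  # stand-in for dlt's TTableSchema

    def combine_truncate_filters(
        truncate_filter: Callable[[TableSchema], bool],
        truncate_tables: List[TableSchema],
    ) -> Callable[[TableSchema], bool]:
        """Truncate a table if the destination's filter OR the package requests it."""
        requested: Set[str] = {t["name"] for t in truncate_tables} if truncate_tables else set()
        return lambda t: truncate_filter(t) or t["name"] in requested

    flt = combine_truncate_filters(
        lambda t: t.get("write_disposition") == "replace",
        [{"name": "events"}],
    )
    print(flt({"name": "events", "write_disposition": "append"}))  # True
    print(flt({"name": "users", "write_disposition": "append"}))   # False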
22 changes: 16 additions & 6 deletions tests/pipeline/test_refresh_modes.py
@@ -1,8 +1,10 @@
-import pytest
+from unittest import mock

+import pytest
 import dlt
 from dlt.common.pipeline import resource_state
 from dlt.destinations.exceptions import DatabaseUndefinedRelation
+from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient

 from tests.utils import clean_test_storage, preserve_environ
 from tests.pipeline.utils import assert_load_info
@@ -198,9 +200,19 @@ def some_data_3():

     # Second run of pipeline with only selected resources
     first_run = False
-    info = pipeline.run(
-        my_source().with_resources("some_data_1", "some_data_2"), write_disposition="append"
-    )
+
+    # Mock wrap sql client to capture all queries executed
+    with mock.patch.object(
+        DuckDbSqlClient, "execute_query", side_effect=DuckDbSqlClient.execute_query, autospec=True
+    ) as mock_execute_query:
+        info = pipeline.run(
+            my_source().with_resources("some_data_1", "some_data_2"), write_disposition="append"
+        )
+
+    all_queries = [k[0][1] for k in mock_execute_query.call_args_list]
+    assert all_queries
+    for q in all_queries:
+        assert "drop table" not in q.lower()  # Tables are only truncated, never dropped

     # Tables selected in second run are truncated and should only have data from second run
     with pipeline.sql_client() as client:
@@ -211,8 +223,6 @@ def some_data_3():
         result = client.execute_sql("SELECT id FROM some_data_1 ORDER BY id")
         assert result == [(1,), (2,)]

-    # TODO: Test tables were truncated , not dropped
-
     # Tables not selected in second run are not truncated, still have data from first run
     with pipeline.sql_client() as client:
         result = client.execute_sql("SELECT id FROM some_data_3 ORDER BY id")
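The test replaces the earlier TODO by spying on every SQL statement: `mock.patch.object` with `side_effect` set to the original method keeps behavior intact while recording calls, and `autospec=True` makes `self` appear as the first positional argument (hence `k[0][1]` for the query text). The same pattern in isolation:

    from unittest import mock

    class Client:
        def execute_query(self, query: str) -> str:
            return f"executed: {query}"

    # side_effect delegates to the real method, so the call still executes;
    # autospec=True records `self` as args[0] and the query as args[1].
    with mock.patch.object(
        Client, "execute_query", side_effect=Client.execute_query, autospec=True
    ) as spy:
        Client().execute_query("SELECT 1")

    queries = [call[0][1] for call in spy.call_args_list]
    assert queries == ["SELECT 1"]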
