From 52115c5a92bde05b77205d41dc2bc75166e3dc04 Mon Sep 17 00:00:00 2001
From: Steinthor Palsson
Date: Mon, 15 Apr 2024 13:41:25 -0400
Subject: [PATCH] Save truncated tables in load package state

---
 dlt/common/storages/load_package.py  |  2 ++
 dlt/extract/extract.py               | 10 ++++++----
 dlt/load/load.py                     |  5 +++--
 dlt/load/utils.py                    | 20 ++++++++++++--------
 tests/pipeline/test_refresh_modes.py | 22 ++++++++++++++++------
 5 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py
index 2e24909a8a..59fe4b1823 100644
--- a/dlt/common/storages/load_package.py
+++ b/dlt/common/storages/load_package.py
@@ -64,6 +64,8 @@ class TLoadPackageState(TVersionedState, total=False):
     dropped_tables: NotRequired[List[TTableSchema]]
     """List of tables that are to be dropped from the schema and destination (i.e. when `refresh` mode is used)"""
+    truncated_tables: NotRequired[List[TTableSchema]]
+    """List of tables that are to be truncated in the destination (i.e. when `refresh='drop_data'` mode is used)"""
 
 
 class TLoadPackage(TypedDict, total=False):
diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py
index 20c8662f3d..344cf34d77 100644
--- a/dlt/extract/extract.py
+++ b/dlt/extract/extract.py
@@ -401,7 +401,6 @@ def extract(
                 _state,
                 resources=_resources_to_drop,
                 drop_all=self.refresh == "drop_dataset",
-                state_only=self.refresh == "drop_data",
                 state_paths="*" if self.refresh == "drop_dataset" else [],
             )
             _state.update(new_state)
@@ -411,9 +410,12 @@ def extract(
                 for table in source.schema.tables.values()
                 if table["name"] in drop_info["tables"]
             ]
-            load_package.state["dropped_tables"] = drop_tables
-            source.schema.tables.clear()
-            source.schema.tables.update(new_schema.tables)
+            if self.refresh == "drop_data":
+                load_package.state["truncated_tables"] = drop_tables
+            else:
+                source.schema.tables.clear()
+                source.schema.tables.update(new_schema.tables)
+                load_package.state["dropped_tables"] = drop_tables
 
         # reset resource states, the `extracted` list contains all the explicit resources and all their parents
         for resource in source.resources.extracted.values():
diff --git a/dlt/load/load.py b/dlt/load/load.py
index ca4b1f0245..a0e1b3f6f1 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -358,6 +358,7 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
         new_jobs = self.get_new_jobs_info(load_id)
 
         dropped_tables = current_load_package()["state"].get("dropped_tables", [])
+        truncated_tables = current_load_package()["state"].get("truncated_tables", [])
 
         # initialize analytical storage ie. create dataset required by passed schema
         with self.get_destination_client(schema) as job_client:
@@ -374,8 +375,8 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
                     if isinstance(job_client, WithStagingDataset)
                     else None
                 ),
-                refresh=self.refresh,
                 drop_tables=dropped_tables,
+                truncate_tables=truncated_tables,
             )
 
             # init staging client
@@ -394,8 +395,8 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
                         job_client.should_truncate_table_before_load_on_staging_destination,
                         # should_truncate_staging,
                         job_client.should_load_data_to_staging_dataset_on_staging_destination,
-                        refresh=self.refresh,
                         drop_tables=dropped_tables,
+                        truncate_tables=truncated_tables,
                     )
 
         self.load_storage.commit_schema_update(load_id, applied_update)
diff --git a/dlt/load/utils.py b/dlt/load/utils.py
index 6b1f8ceb92..ab2238b214 100644
--- a/dlt/load/utils.py
+++ b/dlt/load/utils.py
@@ -67,8 +67,8 @@ def init_client(
     expected_update: TSchemaTables,
     truncate_filter: Callable[[TTableSchema], bool],
     load_staging_filter: Callable[[TTableSchema], bool],
-    refresh: Optional[TRefreshMode] = None,
     drop_tables: Optional[List[TTableSchema]] = None,
+    truncate_tables: Optional[List[TTableSchema]] = None,
 ) -> TSchemaTables:
     """Initializes destination storage including staging dataset if supported
@@ -81,6 +81,8 @@ def init_client(
         expected_update (TSchemaTables): Schema update as in load package. Always present even if empty
         truncate_filter (Callable[[TTableSchema], bool]): A filter that tells which table in destination dataset should be truncated
         load_staging_filter (Callable[[TTableSchema], bool]): A filter which tell which table in the staging dataset may be loaded into
+        drop_tables (Optional[List[TTableSchema]]): List of tables to drop before initializing storage
+        truncate_tables (Optional[List[TTableSchema]]): List of tables to truncate before initializing storage
 
     Returns:
         TSchemaTables: Actual migrations done at destination
@@ -96,19 +98,21 @@ def init_client(
     tables_with_jobs = set(job.table_name for job in new_jobs) - tables_no_data
 
     # get tables to truncate by extending tables with jobs with all their child tables
-
-    if refresh == "drop_data":
-        truncate_filter = lambda t: t["name"] in tables_with_jobs - dlt_tables
-
-    truncate_tables = set(
-        _extend_tables_with_table_chain(schema, tables_with_jobs, tables_with_jobs, truncate_filter)
+    initial_truncate_names = set(t["name"] for t in truncate_tables) if truncate_tables else set()
+    truncate_table_names = set(
+        _extend_tables_with_table_chain(
+            schema,
+            tables_with_jobs,
+            tables_with_jobs,
+            lambda t: truncate_filter(t) or t["name"] in initial_truncate_names,
+        )
     )
 
     applied_update = _init_dataset_and_update_schema(
         job_client,
         expected_update,
         tables_with_jobs | dlt_tables,
-        truncate_tables,
+        truncate_table_names,
         drop_tables=drop_tables,
     )
diff --git a/tests/pipeline/test_refresh_modes.py b/tests/pipeline/test_refresh_modes.py
index 39569d435e..e18ed70e1e 100644
--- a/tests/pipeline/test_refresh_modes.py
+++ b/tests/pipeline/test_refresh_modes.py
@@ -1,8 +1,10 @@
-import pytest
+from unittest import mock
 
+import pytest
 import dlt
 from dlt.common.pipeline import resource_state
 from dlt.destinations.exceptions import DatabaseUndefinedRelation
+from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient
 
 from tests.utils import clean_test_storage, preserve_environ
 from tests.pipeline.utils import assert_load_info
@@ -198,9 +200,19 @@ def some_data_3():
 
     # Second run of pipeline with only selected resources
     first_run = False
-    info = pipeline.run(
-        my_source().with_resources("some_data_1", "some_data_2"), write_disposition="append"
-    )
+
+    # Mock wrap sql client to capture all queries executed
+    with mock.patch.object(
+        DuckDbSqlClient, "execute_query", side_effect=DuckDbSqlClient.execute_query, autospec=True
+    ) as mock_execute_query:
+        info = pipeline.run(
+            my_source().with_resources("some_data_1", "some_data_2"), write_disposition="append"
+        )
+
+    all_queries = [k[0][1] for k in mock_execute_query.call_args_list]
+    assert all_queries
+    for q in all_queries:
+        assert "drop table" not in q.lower()  # Tables are only truncated, never dropped
 
     # Tables selected in second run are truncated and should only have data from second run
     with pipeline.sql_client() as client:
@@ -211,8 +223,6 @@ def some_data_3():
         result = client.execute_sql("SELECT id FROM some_data_1 ORDER BY id")
     assert result == [(1,), (2,)]
 
-    # TODO: Test tables were truncated , not dropped
-
    # Tables not selected in second run are not truncated, still have data from first run
     with pipeline.sql_client() as client:
         result = client.execute_sql("SELECT id FROM some_data_3 ORDER BY id")
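
Reviewer note (not part of the patch): below is a minimal standalone sketch of the
truncate-selection logic this change introduces in init_client. The helper name
tables_to_truncate and the simplified table type are hypothetical stand-ins, not dlt's
API; it only illustrates that a table is now truncated when the destination's
truncate_filter matches OR when the table was recorded under truncated_tables in the
load package state by refresh="drop_data".

    # Hypothetical, simplified illustration -- not dlt's actual API.
    from typing import Callable, Dict, List, Set

    TableSchema = Dict[str, str]  # simplified stand-in for dlt's TTableSchema

    def tables_to_truncate(
        tables_with_jobs: Set[str],
        truncate_filter: Callable[[TableSchema], bool],
        truncated_tables: List[TableSchema],
    ) -> Set[str]:
        # Union of the destination's own truncate rule (e.g. a "replace" write
        # disposition) and the explicit list saved by refresh="drop_data".
        initial_truncate_names = {t["name"] for t in truncated_tables}
        return {
            name
            for name in tables_with_jobs
            if truncate_filter({"name": name}) or name in initial_truncate_names
        }

    # "users" is truncated because drop_data recorded it in the package state,
    # even though the append disposition alone would never truncate it.
    assert tables_to_truncate(
        {"users", "events"},
        truncate_filter=lambda t: False,  # append disposition: filter matches nothing
        truncated_tables=[{"name": "users"}],
    ) == {"users"}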