diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py
index 97cf23fdfc..b3fc2fbcd4 100644
--- a/dlt/load/configuration.py
+++ b/dlt/load/configuration.py
@@ -15,6 +15,9 @@ class LoaderConfiguration(PoolRunnerConfiguration):
     raise_on_max_retries: int = 5
     """When gt 0 will raise when job reaches raise_on_max_retries"""
     _load_storage_config: LoadStorageConfiguration = None
+    # if set to `True`, the staging dataset will be
+    # truncated after loading the data
+    truncate_staging_dataset: bool = False
 
     def on_resolved(self) -> None:
         self.pool_type = "none" if self.workers == 1 else "thread"
diff --git a/dlt/load/load.py b/dlt/load/load.py
index 66ddb1c308..9d898bc54d 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -53,7 +53,7 @@
     LoadClientUnsupportedWriteDisposition,
     LoadClientUnsupportedFileFormats,
 )
-from dlt.load.utils import get_completed_table_chain, init_client
+from dlt.load.utils import _extend_tables_with_table_chain, get_completed_table_chain, init_client
 
 
 class Load(Runnable[Executor], WithStepInfo[LoadMetrics, LoadInfo]):
@@ -348,6 +348,8 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False)
                     )
                 ):
                     job_client.complete_load(load_id)
+                    self._maybe_truncate_staging_dataset(schema, job_client)
+
         self.load_storage.complete_load_package(load_id, aborted)
         # collect package info
         self._loaded_packages.append(self.load_storage.get_load_package_info(load_id))
@@ -490,6 +492,40 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics:
 
         return TRunMetrics(False, len(self.load_storage.list_normalized_packages()))
 
+    def _maybe_truncate_staging_dataset(self, schema: Schema, job_client: JobClientBase) -> None:
+        """
+        Truncate the staging dataset if one is used
+        and the configuration requests truncation.
+
+        Truncation is best-effort: any failure is logged as a warning
+        and is never propagated to the caller.
+
+        Args:
+            schema (Schema): Schema to use for the staging dataset.
+            job_client (JobClientBase):
+                Job client to use for the staging dataset.
+        """
+        if not (
+            isinstance(job_client, WithStagingDataset) and self.config.truncate_staging_dataset
+        ):
+            return
+
+        data_tables = schema.data_table_names()
+        tables = _extend_tables_with_table_chain(
+            schema, data_tables, data_tables, job_client.should_load_data_to_staging_dataset
+        )
+
+        try:
+            with self.get_destination_client(schema) as client:
+                with client.with_staging_dataset():  # type: ignore
+                    client.initialize_storage(truncate_tables=tables)
+
+        except Exception as exc:
+            logger.warn(
+                f"Staging dataset truncate failed due to the following error: {exc}"
+                " However, it didn't affect the data integrity."
+            )
+
     def get_step_info(
         self,
         pipeline: SupportsPipeline,
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index a2ea1936a9..53770f332d 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -554,6 +554,7 @@ def load(
             with signals.delayed_signals():
                 runner.run_pool(load_step.config, load_step)
             info: LoadInfo = self._get_step_info(load_step)
+
             self.first_run = False
             return info
         except Exception as l_ex:
diff --git a/docs/website/docs/running-in-production/running.md b/docs/website/docs/running-in-production/running.md
index dc49cf7659..25a964afc4 100644
--- a/docs/website/docs/running-in-production/running.md
+++ b/docs/website/docs/running-in-production/running.md
@@ -108,6 +108,12 @@ behind. In `config.toml`:
 load.delete_completed_jobs=true
 ```
 
+Also, by default, `dlt` leaves data in the staging dataset, which is used during merge and replace loads for deduplication. To clear it, add the following line to `config.toml`:
+
+```toml
+load.truncate_staging_dataset=true
+```
+
 ## Using slack to send messages
 
 `dlt` provides basic support for sending slack messages. You can configure Slack incoming hook via
diff --git a/tests/helpers/airflow_tests/test_airflow_wrapper.py b/tests/helpers/airflow_tests/test_airflow_wrapper.py
index 845800e47f..533d16c998 100644
--- a/tests/helpers/airflow_tests/test_airflow_wrapper.py
+++ b/tests/helpers/airflow_tests/test_airflow_wrapper.py
@@ -384,7 +384,17 @@ def dag_parallel():
     with mock.patch("dlt.helpers.airflow_helper.logger.warn") as warn_mock:
         dag_def = dag_parallel()
         dag_def.test()
-        warn_mock.assert_called_once()
+        warn_mock.assert_has_calls(
+            [
+                mock.call(
+                    "The resource resource2 in task"
+                    " mock_data_incremental_source_resource1-resource2 is using incremental loading"
+                    " and may modify the state. Resources that modify the state should not run in"
+                    " parallel within the single pipeline as the state will not be correctly"
+                    " merged. Please use 'serialize' or 'parallel-isolated' modes instead."
+                )
+            ]
+        )
 
 
 def test_parallel_isolated_run():
diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py
index a498b570a0..d98f335d16 100644
--- a/tests/load/pipeline/test_pipelines.py
+++ b/tests/load/pipeline/test_pipelines.py
@@ -10,6 +10,7 @@
 from dlt.common.pipeline import SupportsPipeline
 from dlt.common.destination import Destination
 from dlt.common.destination.exceptions import DestinationHasFailedJobs
+from dlt.common.destination.reference import WithStagingDataset
 from dlt.common.schema.exceptions import CannotCoerceColumnException
 from dlt.common.schema.schema import Schema
 from dlt.common.schema.typing import VERSION_TABLE_NAME
@@ -896,6 +897,7 @@ def test_pipeline_upfront_tables_two_loads(
 
     # use staging tables for replace
     os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy
+    os.environ["TRUNCATE_STAGING_DATASET"] = "True"
 
     pipeline = destination_config.setup_pipeline(
         "test_pipeline_upfront_tables_two_loads",
@@ -1001,6 +1003,21 @@ def table_3(make_data=False):
         is True
     )
 
+    job_client, _ = pipeline._get_destination_clients(schema)
+
+    if destination_config.staging and isinstance(job_client, WithStagingDataset):
+        for i in range(1, 4):
+            with pipeline.sql_client() as client:
+                table_name = f"table_{i}"
+
+                if job_client.should_load_data_to_staging_dataset(
+                    job_client.schema.tables[table_name]
+                ):
+                    with client.with_staging_dataset(staging=True):
+                        tab_name = client.make_qualified_table_name(table_name)
+                        with client.execute_query(f"SELECT * FROM {tab_name}") as cur:
+                            assert len(cur.fetchall()) == 0
+
 
 # @pytest.mark.skip(reason="Finalize the test: compare some_data values to values from database")
 # @pytest.mark.parametrize(
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
index a828de40fd..1c4383405b 100644
--- a/tests/pipeline/test_pipeline.py
+++ b/tests/pipeline/test_pipeline.py
@@ -5,9 +5,9 @@
 import logging
 import os
 import random
+import threading
 from time import sleep
 from typing import Any, Tuple, cast
-import threading
 
 from tenacity import retry_if_exception, Retrying, stop_after_attempt
 import pytest
@@ -2230,3 +2230,34 @@ def stateful_resource():
     assert len(fs_client.list_table_files("_dlt_loads")) == 2
     assert len(fs_client.list_table_files("_dlt_version")) == 1
     assert len(fs_client.list_table_files("_dlt_pipeline_state")) == 1
+
+
+@pytest.mark.parametrize("truncate", (True, False))
+def test_staging_dataset_truncate(truncate: bool) -> None:
+    """Staging tables are emptied after a merge load only when truncation is enabled."""
+    dlt.config["truncate_staging_dataset"] = truncate
+
+    @dlt.resource(write_disposition="merge", merge_key="id")
+    def test_data():
+        yield [{"field": 1, "id": 1}, {"field": 2, "id": 2}, {"field": 3, "id": 3}]
+
+    pipeline = dlt.pipeline(
+        pipeline_name="test_staging_cleared",
+        destination="duckdb",
+        full_refresh=True,
+    )
+
+    info = pipeline.run(test_data, table_name="staging_cleared")
+    assert_load_info(info)
+
+    with pipeline.sql_client() as client:
+        with client.execute_query(
+            f"SELECT * FROM {pipeline.dataset_name}_staging.staging_cleared"
+        ) as cur:
+            if truncate:
+                assert len(cur.fetchall()) == 0
+            else:
+                assert len(cur.fetchall()) == 3
+
+        with client.execute_query(f"SELECT * FROM {pipeline.dataset_name}.staging_cleared") as cur:
+            assert len(cur.fetchall()) == 3