diff --git a/dlt/load/load.py b/dlt/load/load.py
index 0db257b751..23010cb39c 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -502,7 +502,9 @@ def _maybe_trancate_staging_dataset(self, schema: Schema, job_client: JobClientB
             job_client (JobClientBase): Job client to use for the staging dataset.
         """
 
-        if not isinstance(job_client, WithStagingDataset):
+        if not (
+            isinstance(job_client, WithStagingDataset) and self.config.truncate_staging_dataset
+        ):
             return
 
         data_tables = schema.data_table_names()
@@ -511,11 +513,9 @@ def _maybe_trancate_staging_dataset(self, schema: Schema, job_client: JobClientB
         )
 
         try:
-            if self.config.truncate_staging_dataset:
-                with self.get_destination_client(schema) as client:
-                    if isinstance(client, WithStagingDataset):
-                        with client.with_staging_dataset():
-                            client.initialize_storage(truncate_tables=tables)
+            with self.get_destination_client(schema) as client:
+                with client.with_staging_dataset():
+                    client.initialize_storage(truncate_tables=tables)
         except Exception as exc:
             logger.warn(
diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py
index ee1118a64a..72289642bb 100644
--- a/tests/load/pipeline/test_pipelines.py
+++ b/tests/load/pipeline/test_pipelines.py
@@ -1001,11 +1001,12 @@ def table_3(make_data=False):
     if isinstance(job_client, WithStagingDataset):
         with pipeline.sql_client() as client:
             for i in range(1, 4):
-                if job_client.should_load_data_to_staging_dataset(
-                    pipeline.default_schema.tables[f"table_{i}"]
+                table_name = f"table_{i}"
+                with client.with_staging_dataset(
+                    job_client.should_load_data_to_staging_dataset(job_client.schema.tables[table_name])  # type: ignore[attr-defined]
                 ):
                     with client.execute_query(
-                        f"SELECT * FROM {pipeline.dataset_name}_staging.table_{i}"
+                        f"SELECT * FROM {client.make_qualified_table_name(table_name)}"
                     ) as cur:
                         assert len(cur.fetchall()) == 0