Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pipeline): add an ability to auto truncate #1292

Merged
merged 26 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dlt/load/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class LoaderConfiguration(PoolRunnerConfiguration):
raise_on_max_retries: int = 5
"""When gt 0 will raise when job reaches raise_on_max_retries"""
_load_storage_config: LoadStorageConfiguration = None
# if set to `True`, the staging dataset will be
# truncated after loading the data
truncate_staging_dataset: bool = False

def on_resolved(self) -> None:
    """Derive the runner pool type from the configured worker count.

    A single worker needs no pool at all; anything more runs in a
    thread pool.
    """
    if self.workers == 1:
        self.pool_type = "none"
    else:
        self.pool_type = "thread"
35 changes: 34 additions & 1 deletion dlt/load/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
LoadClientUnsupportedWriteDisposition,
LoadClientUnsupportedFileFormats,
)
from dlt.load.utils import get_completed_table_chain, init_client
from dlt.load.utils import _extend_tables_with_table_chain, get_completed_table_chain, init_client


class Load(Runnable[Executor], WithStepInfo[LoadMetrics, LoadInfo]):
Expand Down Expand Up @@ -348,6 +348,8 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False)
)
):
job_client.complete_load(load_id)
self._maybe_trancate_staging_dataset(schema, job_client)

self.load_storage.complete_load_package(load_id, aborted)
# collect package info
self._loaded_packages.append(self.load_storage.get_load_package_info(load_id))
Expand Down Expand Up @@ -490,6 +492,37 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics:

return TRunMetrics(False, len(self.load_storage.list_normalized_packages()))

def _maybe_trancate_staging_dataset(self, schema: Schema, job_client: JobClientBase) -> None:
    """Truncate the staging dataset if one is used and configuration requests truncation.

    No-op unless the job client supports a staging dataset AND
    ``truncate_staging_dataset`` is enabled in the loader configuration.
    Truncation is best-effort: a failure is logged but never propagated,
    because the load itself has already completed successfully.

    Args:
        schema (Schema): Schema to use for the staging dataset.
        job_client (JobClientBase): Job client to use for the staging dataset.
    """
    if not (
        isinstance(job_client, WithStagingDataset) and self.config.truncate_staging_dataset
    ):
        return

    # extend data tables with their table chains so child/merge tables
    # that were loaded to staging are truncated as well
    data_tables = schema.data_table_names()
    tables = _extend_tables_with_table_chain(
        schema, data_tables, data_tables, job_client.should_load_data_to_staging_dataset
    )

    try:
        with self.get_destination_client(schema) as client:
            with client.with_staging_dataset():  # type: ignore
                # initialize_storage with truncate_tables clears the listed
                # tables in the staging dataset
                client.initialize_storage(truncate_tables=tables)
    except Exception as exc:
        # best-effort cleanup: the main dataset is already loaded, so a
        # failure here must not fail the load package.
        # NOTE: logger.warn is deprecated in favor of logger.warning
        logger.warning(
            f"Staging dataset truncate failed due to the following error: {exc}"
            " However, it didn't affect the data integrity."
        )

def get_step_info(
self,
pipeline: SupportsPipeline,
Expand Down
1 change: 1 addition & 0 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ def load(
with signals.delayed_signals():
runner.run_pool(load_step.config, load_step)
info: LoadInfo = self._get_step_info(load_step)

self.first_run = False
return info
except Exception as l_ex:
Expand Down
6 changes: 6 additions & 0 deletions docs/website/docs/running-in-production/running.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ behind. In `config.toml`:
load.delete_completed_jobs=true
```

Also, by default, `dlt` leaves data in the staging dataset, which is used during merge and replace loads for deduplication. To clear it, put the following line in `config.toml`:

```toml
load.truncate_staging_dataset=true
```

## Using slack to send messages

`dlt` provides basic support for sending slack messages. You can configure Slack incoming hook via
Expand Down
12 changes: 11 additions & 1 deletion tests/helpers/airflow_tests/test_airflow_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,17 @@ def dag_parallel():
with mock.patch("dlt.helpers.airflow_helper.logger.warn") as warn_mock:
dag_def = dag_parallel()
dag_def.test()
warn_mock.assert_called_once()
warn_mock.assert_has_calls(
[
mock.call(
"The resource resource2 in task"
" mock_data_incremental_source_resource1-resource2 is using incremental loading"
" and may modify the state. Resources that modify the state should not run in"
" parallel within the single pipeline as the state will not be correctly"
" merged. Please use 'serialize' or 'parallel-isolated' modes instead."
)
]
)


def test_parallel_isolated_run():
Expand Down
17 changes: 17 additions & 0 deletions tests/load/pipeline/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dlt.common.pipeline import SupportsPipeline
from dlt.common.destination import Destination
from dlt.common.destination.exceptions import DestinationHasFailedJobs
from dlt.common.destination.reference import WithStagingDataset
from dlt.common.schema.exceptions import CannotCoerceColumnException
from dlt.common.schema.schema import Schema
from dlt.common.schema.typing import VERSION_TABLE_NAME
Expand Down Expand Up @@ -896,6 +897,7 @@ def test_pipeline_upfront_tables_two_loads(

# use staging tables for replace
os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy
os.environ["TRUNCATE_STAGING_DATASET"] = "True"

pipeline = destination_config.setup_pipeline(
"test_pipeline_upfront_tables_two_loads",
Expand Down Expand Up @@ -1001,6 +1003,21 @@ def table_3(make_data=False):
is True
)

job_client, _ = pipeline._get_destination_clients(schema)

if destination_config.staging and isinstance(job_client, WithStagingDataset):
for i in range(1, 4):
with pipeline.sql_client() as client:
table_name = f"table_{i}"

if job_client.should_load_data_to_staging_dataset(
job_client.schema.tables[table_name]
):
with client.with_staging_dataset(staging=True):
tab_name = client.make_qualified_table_name(table_name)
with client.execute_query(f"SELECT * FROM {tab_name}") as cur:
assert len(cur.fetchall()) == 0


# @pytest.mark.skip(reason="Finalize the test: compare some_data values to values from database")
# @pytest.mark.parametrize(
Expand Down
32 changes: 31 additions & 1 deletion tests/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import logging
import os
import random
import threading
from time import sleep
from typing import Any, Tuple, cast
import threading
from tenacity import retry_if_exception, Retrying, stop_after_attempt

import pytest
Expand Down Expand Up @@ -2230,3 +2230,33 @@ def stateful_resource():
assert len(fs_client.list_table_files("_dlt_loads")) == 2
assert len(fs_client.list_table_files("_dlt_version")) == 1
assert len(fs_client.list_table_files("_dlt_pipeline_state")) == 1


@pytest.mark.parametrize("truncate", (True, False))
def test_staging_dataset_truncate(truncate) -> None:
    """Staging tables are emptied after load iff `truncate_staging_dataset` is set.

    Runs a merge load (which stages data for deduplication) against duckdb,
    then asserts the staging dataset is empty when truncation is enabled and
    still holds the loaded rows when it is not. The main dataset must keep
    its rows in both cases.
    """
    dlt.config["truncate_staging_dataset"] = truncate

    @dlt.resource(write_disposition="merge", merge_key="id")
    def test_data():
        yield [{"field": 1, "id": 1}, {"field": 2, "id": 2}, {"field": 3, "id": 3}]

    pipeline = dlt.pipeline(
        pipeline_name="test_staging_cleared",
        destination="duckdb",
        full_refresh=True,
    )

    info = pipeline.run(test_data, table_name="staging_cleared")
    assert_load_info(info)

    with pipeline.sql_client() as client:
        # staging dataset: empty only when truncation was requested
        with client.execute_query(
            f"SELECT * FROM {pipeline.dataset_name}_staging.staging_cleared"
        ) as cur:
            if truncate:
                assert len(cur.fetchall()) == 0
            else:
                assert len(cur.fetchall()) == 3

        # main dataset keeps the data regardless of truncation
        with client.execute_query(f"SELECT * FROM {pipeline.dataset_name}.staging_cleared") as cur:
            assert len(cur.fetchall()) == 3
Loading