Write pipeline state with load_package_state
Pipeline state is written by the destination client in complete_load, without the "state_resource"

Schema migration for the state table; update some tests

Update more tests

Backwards-compatible state table

Extract the drop command's pipeline state in an empty load package
steinitzu committed Apr 24, 2024
1 parent 564c282 commit 9402090
Showing 23 changed files with 645 additions and 160 deletions.
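
As a rough sketch of the flow this commit introduces (field names follow the diff below; all values are invented placeholders): extract attaches a pipeline state document to the load package state, and the destination client later turns that document into a row in the pipeline state table when it completes the load.

# Illustrative sketch only: shape of the "pipeline_state" doc carried in the
# load package state, with placeholder values.
pipeline_state_doc = {
    "version": 3,
    "engine_version": 4,
    "pipeline_name": "my_pipeline",
    "state": "<compressed state blob>",
    "version_hash": "abc123",                  # hash of the serialized state
    "created_at": "2024-04-24T12:00:00+00:00",
    "dlt_load_id": "1713960000.000001",        # stamped on the doc when the package is committed
}
# On the destination side, complete_load -> _store_pipeline_state inserts these fields
# into the state table and generates a _dlt_id for compatibility with older datasets.
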
7 changes: 5 additions & 2 deletions dlt/common/schema/migrations.py
@@ -6,6 +6,7 @@
from dlt.common.schema.typing import (
LOADS_TABLE_NAME,
VERSION_TABLE_NAME,
STATE_TABLE_NAME,
TSimpleRegex,
TStoredSchema,
TTableSchemaColumns,
@@ -14,7 +15,7 @@
from dlt.common.schema.exceptions import SchemaEngineNoUpgradePathException

from dlt.common.normalizers.utils import import_normalizers
from dlt.common.schema.utils import new_table, version_table, load_table
from dlt.common.schema.utils import new_table, version_table, load_table, state_table


def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema:
@@ -118,7 +119,9 @@ def migrate_filters(group: str, filters: List[str]) -> None:
x_normalizer = table.setdefault("x-normalizer", {})
x_normalizer["seen-data"] = True
from_engine = 9

if from_engine == 9 and to_engine > 9:
schema_dict["tables"].setdefault(STATE_TABLE_NAME, {}).update(state_table())
from_engine = 10
schema_dict["engine_version"] = from_engine
if from_engine != to_engine:
raise SchemaEngineNoUpgradePathException(
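
A minimal sketch of the new engine 9 to 10 migration step, assuming migrate_schema can be called on a stripped-down stored schema dict (a real stored schema carries more fields) and that STATE_TABLE_NAME names the pipeline state table:

from dlt.common.schema.migrations import migrate_schema
from dlt.common.schema.typing import STATE_TABLE_NAME

# hypothetical, minimal v9 schema dict for illustration
schema_dict = {"name": "example", "engine_version": 9, "tables": {}}
migrated = migrate_schema(schema_dict, from_engine=9, to_engine=10)

assert migrated["engine_version"] == 10
assert STATE_TABLE_NAME in migrated["tables"]  # state table definition was merged in
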
3 changes: 3 additions & 0 deletions dlt/common/schema/schema.py
@@ -845,6 +845,9 @@ def _add_standard_tables(self) -> None:
self._schema_tables[self.loads_table_name] = self.normalize_table_identifiers(
utils.load_table()
)
self._schema_tables[self.state_table_name] = self.normalize_table_identifiers(
utils.state_table()
)

def _add_standard_hints(self) -> None:
default_hints = utils.standard_hints()
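
Correspondingly, a freshly created Schema should now carry the state table among its standard dlt tables. A quick sketch, assuming the standard tables are added when the schema is instantiated and that the tables property exposes the table dict:

from dlt.common.schema import Schema

schema = Schema("example")
print(schema.state_table_name)                   # normalized state table name
assert schema.state_table_name in schema.tables  # state table is part of the standard tables
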
2 changes: 1 addition & 1 deletion dlt/common/schema/typing.py
@@ -26,7 +26,7 @@


# current version of schema engine
SCHEMA_ENGINE_VERSION = 9
SCHEMA_ENGINE_VERSION = 10

# dlt tables
VERSION_TABLE_NAME = "_dlt_version"
20 changes: 20 additions & 0 deletions dlt/common/schema/utils.py
@@ -19,6 +19,7 @@
LOADS_TABLE_NAME,
SIMPLE_REGEX_PREFIX,
VERSION_TABLE_NAME,
STATE_TABLE_NAME,
TColumnName,
TPartialTableSchema,
TSchemaTables,
@@ -681,6 +682,25 @@ def load_table() -> TTableSchema:
return table


def state_table() -> TTableSchema:
table = new_table(
STATE_TABLE_NAME,
columns=[
{"name": "version", "data_type": "bigint", "nullable": False},
{"name": "engine_version", "data_type": "bigint", "nullable": False},
{"name": "pipeline_name", "data_type": "text", "nullable": False},
{"name": "state", "data_type": "text", "nullable": False},
{"name": "created_at", "data_type": "timestamp", "nullable": False},
{"name": "version_hash", "data_type": "text", "nullable": True},
{"name": "_dlt_load_id", "data_type": "text", "nullable": True},
{"name": "_dlt_id", "data_type": "text", "nullable": True},
],
)
table["write_disposition"] = "append" # TOOD: Legacy support for existing state load packages
table["description"] = "Created by DLT. Tracks pipeline state"
return table


def new_table(
table_name: str,
parent_table_name: str = None,
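
A small sketch of what the new helper produces, assuming the TTableSchema layout used elsewhere in this module, where "columns" is a dict keyed by column name:

from dlt.common.schema.utils import state_table

table = state_table()
print(table["name"])               # STATE_TABLE_NAME
print(table["write_disposition"])  # "append"
print(list(table["columns"]))
# ['version', 'engine_version', 'pipeline_name', 'state', 'created_at',
#  'version_hash', '_dlt_load_id', '_dlt_id']
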
60 changes: 48 additions & 12 deletions dlt/destinations/job_client_impl.py
@@ -24,6 +24,7 @@
import zlib
import re

from dlt.pipeline.current import load_package as current_load_package
from dlt.common import logger
from dlt.common.json import json
from dlt.common.pendulum import pendulum
@@ -36,6 +37,7 @@
TWriteDisposition,
TTableFormat,
)
from dlt.common.normalizers.utils import generate_dlt_id
from dlt.common.storages import FileStorage
from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME
@@ -146,9 +148,13 @@ class SqlJobClientBase(JobClientBase, WithStateSync):
"engine_version",
"pipeline_name",
"state",
"version_hash",
"created_at",
"_dlt_load_id",
)
_STATE_TABLE_SELECT_COLUMNS: ClassVar[Tuple[str, ...]] = tuple(
col for col in _STATE_TABLE_COLUMNS if col != "version_hash"
)

def __init__(
self,
@@ -162,6 +168,9 @@ def __init__(
self.state_table_columns = ", ".join(
sql_client.escape_column_name(col) for col in self._STATE_TABLE_COLUMNS
)
self.state_table_select_columns = ", ".join(
sql_client.escape_column_name(col) for col in self._STATE_TABLE_SELECT_COLUMNS
)

super().__init__(schema, config)
self.sql_client = sql_client
@@ -280,18 +289,45 @@ def restore_file_load(self, file_path: str) -> LoadJob:
return EmptyLoadJobWithoutFollowup.from_file_path(file_path, "completed")
return None

def _store_pipeline_state(self) -> None:
pipeline_state_doc = current_load_package()["state"].get("pipeline_state")
if not pipeline_state_doc:
# We're probably dealing with an old load package pre load_package_state
return
state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name)
stmt = """
INSERT INTO {state_table}({columns}, _dlt_id)
VALUES(%s, %s, %s, %s, %s, %s, %s, %s)
""".format(
state_table=state_table,
columns=self.state_table_columns,
)
self.sql_client.execute_sql(
stmt,
pipeline_state_doc["version"],
pipeline_state_doc["engine_version"],
pipeline_state_doc["pipeline_name"],
pipeline_state_doc["state"],
pipeline_state_doc["version_hash"],
pipeline_state_doc["created_at"],
pipeline_state_doc["dlt_load_id"],
generate_dlt_id(), # TODO: legacy, old datasets have a non-nullable _dlt_id column
)

def complete_load(self, load_id: str) -> None:
name = self.sql_client.make_qualified_table_name(self.schema.loads_table_name)
now_ts = pendulum.now()
self.sql_client.execute_sql(
f"INSERT INTO {name}(load_id, schema_name, status, inserted_at, schema_version_hash)"
" VALUES(%s, %s, %s, %s, %s);",
load_id,
self.schema.name,
0,
now_ts,
self.schema.version_hash,
)
with self.sql_client.begin_transaction():
self._store_pipeline_state()
self.sql_client.execute_sql(
f"INSERT INTO {name}(load_id, schema_name, status, inserted_at,"
" schema_version_hash) VALUES(%s, %s, %s, %s, %s);",
load_id,
self.schema.name,
0,
now_ts,
self.schema.version_hash,
)

def __enter__(self) -> "SqlJobClientBase":
self.sql_client.open_connection()
@@ -369,9 +405,9 @@ def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name)
loads_table = self.sql_client.make_qualified_table_name(self.schema.loads_table_name)
query = (
f"SELECT {self.state_table_columns} FROM {state_table} AS s JOIN {loads_table} AS l ON"
" l.load_id = s._dlt_load_id WHERE pipeline_name = %s AND l.status = 0 ORDER BY"
" l.load_id DESC"
f"SELECT {self.state_table_select_columns} FROM {state_table} AS s JOIN"
f" {loads_table} AS l ON l.load_id = s._dlt_load_id WHERE pipeline_name = %s AND"
" l.status = 0 ORDER BY l.load_id DESC"
)
with self.sql_client.execute_query(query, pipeline_name) as cur:
row = cur.fetchone()
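
Taken together, completing a load now persists the pipeline state and the loads record in one transaction. A hedged usage sketch; the client instance and load id below are placeholders, and in practice the loader calls this internally:

# `client` stands for any configured SqlJobClientBase subclass instance (placeholder)
with client:
    client.complete_load("1713960000.000001")
# within a single transaction this:
#   - inserts a row into the pipeline state table if the current load package
#     carries a "pipeline_state" doc (old packages without it are skipped)
#   - inserts the usual completion row into the loads table
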
45 changes: 30 additions & 15 deletions dlt/extract/extract.py
@@ -364,6 +364,17 @@ def gather_metrics(self, load_id: str, source: DltSource) -> None:
# NOTE: there may be more than one extract run per load id: ie. the resource and then dlt state
self.extract_storage.remove_closed_files(load_id)

@contextlib.contextmanager
def new_load_package(self, schema: Schema) -> Iterator[LoadPackageStateInjectableContext]:
load_id = self.extract_storage.create_load_package(schema)
with Container().injectable_context(
LoadPackageStateInjectableContext(
load_id=load_id, storage=self.extract_storage.new_packages
)
) as load_package:
yield load_package
commit_load_package_state()

def extract(
self,
source: DltSource,
@@ -373,15 +384,10 @@
) -> str:
# generate load package to be able to commit all the sources together later
load_id = self.extract_storage.create_load_package(source.discover_schema())
with Container().injectable_context(

with self.new_load_package(source.schema) as load_package, Container().injectable_context(
SourceSchemaInjectableContext(source.schema)
), Container().injectable_context(
SourceInjectableContext(source)
), Container().injectable_context(
LoadPackageStateInjectableContext(
load_id=load_id, storage=self.extract_storage.new_packages
)
) as load_package:
), Container().injectable_context(SourceInjectableContext(source)):
# inject the config section with the current source name
with inject_section(
ConfigSectionContext(
@@ -407,17 +413,26 @@
commit_load_package_state()
return load_id

def commit_package(
self,
load_id: str,
schema_name: str,
pipeline_state_doc: Optional[TPipelineStateDoc] = None,
cleanup: bool = False,
) -> None:
if pipeline_state_doc:
package_state = self.extract_storage.new_packages.get_load_package_state(load_id)
package_state["pipeline_state"] = {**pipeline_state_doc, "dlt_load_id": load_id}
self.extract_storage.new_packages.save_load_package_state(load_id, package_state)
self.extract_storage.commit_new_load_package(load_id, self.schema_storage[schema_name])
if cleanup:
self.extract_storage.delete_empty_extract_folder()

def commit_packages(self, pipline_state_doc: TPipelineStateDoc = None) -> None:
"""Commits all extracted packages to normalize storage, and adds the pipeline state to the load package"""
# commit load packages
for load_id, metrics in self._load_id_metrics.items():
if pipline_state_doc:
package_state = self.extract_storage.new_packages.get_load_package_state(load_id)
package_state["pipeline_state"] = {**pipline_state_doc, "dlt_load_id": load_id}
self.extract_storage.new_packages.save_load_package_state(load_id, package_state)
self.extract_storage.commit_new_load_package(
load_id, self.schema_storage[metrics[0]["schema_name"]]
)
self.commit_package(load_id, metrics[0]["schema_name"], pipline_state_doc)
# all load ids got processed, cleanup empty folder
self.extract_storage.delete_empty_extract_folder()

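
The effect of commit_packages on each package, sketched with an invented state doc (values are placeholders; the pipline_state_doc parameter name is spelled as in the diff):

# hypothetical pipeline state doc; real values come from the pipeline's state machinery
state_doc = {
    "version": 2,
    "engine_version": 4,
    "pipeline_name": "my_pipeline",
    "state": "<compressed state blob>",
    "version_hash": "abc123",
    "created_at": "2024-04-24T12:00:00+00:00",
}
extract.commit_packages(pipline_state_doc=state_doc)  # `extract` is an Extract instance (placeholder)
# each committed package's load package state now contains:
#   {"pipeline_state": {**state_doc, "dlt_load_id": "<that package's load_id>"}}
# which the destination client reads later in complete_load
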
10 changes: 5 additions & 5 deletions dlt/normalize/normalize.py
@@ -458,11 +458,11 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics:
logger.info(
f"Found {len(schema_files)} files in schema {schema.name} load_id {load_id}"
)
if len(schema_files) == 0:
# delete empty package
self.normalize_storage.extracted_packages.delete_package(load_id)
logger.info(f"Empty package {load_id} processed")
continue
# if len(schema_files) == 0:
# # delete empty package
# self.normalize_storage.extracted_packages.delete_package(load_id)
# logger.info(f"Empty package {load_id} processed")
# continue
with self.collector(f"Normalize {schema.name} in {load_id}"):
self.collector.update("Files", 0, len(schema_files))
self.collector.update("Items", 0)