Commit
Update tests
steinitzu committed May 3, 2024
1 parent 47bb759 commit 965a87c
Showing 7 changed files with 46 additions and 50 deletions.
7 changes: 6 additions & 1 deletion dlt/destinations/job_client_impl.py
@@ -25,6 +25,7 @@
import re

from dlt.pipeline.current import load_package as current_load_package
from dlt.common.storages.exceptions import CurrentLoadPackageStateNotAvailable
from dlt.common import logger
from dlt.common.json import json
from dlt.common.pendulum import pendulum
@@ -287,7 +288,11 @@ def restore_file_load(self, file_path: str) -> LoadJob:
return None

def _store_pipeline_state(self) -> None:
pipeline_state_doc = current_load_package()["state"].get("pipeline_state")
try:
pipeline_state_doc = current_load_package()["state"].get("pipeline_state")
except CurrentLoadPackageStateNotAvailable:
# Not in load package context, nothing to do
return
if not pipeline_state_doc:
# We're probably dealing with an old load package pre load_package_state
return
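In short, _store_pipeline_state now treats "no load package in scope" as "nothing to persist". A hedged sketch of the behaviour the new except branch covers — only current_load_package and CurrentLoadPackageStateNotAvailable come from the diff, the client class and its construction are illustrative:

    # hypothetical: a job client used outside of a running load
    client = SomeSqlJobClient(schema, config)   # no load package context is active here
    client._store_pipeline_state()
    # before this change the call presumably propagated CurrentLoadPackageStateNotAvailable;
    # now it simply returns without writing anything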
4 changes: 2 additions & 2 deletions tests/load/pipeline/test_merge_disposition.py
@@ -464,8 +464,8 @@ def _updated_event(node_id):
_get_shuffled_events(True) | github_resource,
loader_file_format=destination_config.file_format,
)
# no packages were loaded
assert len(info.loads_ids) == 0
# empty load package loaded
assert_load_info(info, expected_total_jobs=0)

# load one more event with a new id
info = p.run(
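The expectation flips because a run that extracts no new rows now still appears to create a load package — one that carries only state and no jobs. A hedged sketch of the new helper call used above (its extended signature is added to tests/pipeline/utils.py later in this diff):

    # one package was produced, but it contains no jobs of any kind
    assert_load_info(info, expected_load_packages=1, expected_total_jobs=0)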
16 changes: 10 additions & 6 deletions tests/load/pipeline/test_pipelines.py
@@ -110,10 +110,14 @@ def data_fun() -> Iterator[Any]:
# set no dataset name -> if destination does not support it we revert to default
p._set_dataset_name(None)
assert p.dataset_name in possible_dataset_names

# Extract the state to new load package
with p.managed_state(extract_state=True):
pass
# the last package contains just the state (we added a new schema)
last_load_id = p.list_extracted_load_packages()[-1]
state_package = p.get_load_package_info(last_load_id)
assert len(state_package.jobs["new_jobs"]) == 1
assert len(state_package.jobs["new_jobs"]) == 0 # State package has no jobs, only state
assert state_package.schema_name == p.default_schema_name
p.normalize()
info = p.load(dataset_name="d" + uniq_id())
@@ -944,14 +948,14 @@ def table_3(make_data=False):
with pytest.raises(DatabaseUndefinedRelation):
load_table_counts(pipeline, "table_3")
assert "x-normalizer" not in pipeline.default_schema.tables["table_3"]
assert (
pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item]
is True
)
# assert (
# pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item]
# is True
# )

# load with one empty job, table 3 not created
load_info = pipeline.run(source.table_3, loader_file_format=destination_config.file_format)
assert_load_info(load_info, expected_load_packages=0)
assert_load_info(load_info, expected_load_packages=1, expected_total_jobs=0)
with pytest.raises(DatabaseUndefinedRelation):
load_table_counts(pipeline, "table_3")
# print(pipeline.default_schema.to_pretty_yaml())
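Both hunks reflect the same behaviour change: a state-only package (created here explicitly via managed_state(extract_state=True)) appears to no longer materialise a _dlt_pipeline_state job, so the tests check that the package exists but is empty of jobs. A hedged sketch of how that emptiness can be inspected, reusing only names that appear in the diff:

    # every job list of the last extracted package is expected to be empty
    pkg = p.get_load_package_info(p.list_extracted_load_packages()[-1])
    assert sum(len(jobs) for jobs in pkg.jobs.values()) == 0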
24 changes: 6 additions & 18 deletions tests/load/pipeline/test_replace_disposition.py
@@ -37,15 +37,6 @@ def test_replace_disposition(
# make duckdb to reuse database in working folder
os.environ["DESTINATION__DUCKDB__CREDENTIALS"] = "duckdb:///test_replace_disposition.duckdb"

increase_state_loads = lambda info: len(
[
job
for job in info.load_packages[0].jobs["completed_jobs"]
if job.job_file_info.table_name == "_dlt_pipeline_state"
and job.job_file_info.file_format not in ["sql", "reference"]
]
)

# filesystem does not have child tables, prepend defaults
def norm_table_counts(counts: Dict[str, int], *child_tables: str) -> Dict[str, int]:
return {**{t: 0 for t in child_tables}, **counts}
@@ -99,7 +90,7 @@ def append_items():
)
assert_load_info(info)
# count state records that got extracted
state_records = increase_state_loads(info)
state_records: int = 1
dlt_loads: int = 1
dlt_versions: int = 1

@@ -109,7 +100,6 @@ def append_items():
[load_items, append_items], loader_file_format=destination_config.file_format
)
assert_load_info(info)
state_records += increase_state_loads(info)
dlt_loads += 1

# we should have all items loaded
@@ -122,6 +112,7 @@ def append_items():
"_dlt_pipeline_state": state_records,
"_dlt_loads": dlt_loads,
"_dlt_version": dlt_versions,
"_dlt_pipeline_state": state_records,
}

# check trace
@@ -157,11 +148,11 @@ def load_items_none():
[load_items_none, append_items], loader_file_format=destination_config.file_format
)
assert_load_info(info)
state_records += increase_state_loads(info)
dlt_loads += 1

# table and child tables should be cleared
table_counts = load_table_counts(pipeline, *pipeline.default_schema.tables.keys())

assert norm_table_counts(
table_counts, "items__sub_items", "items__sub_items__sub_sub_items"
) == {
@@ -173,6 +164,7 @@ def load_items_none():
"_dlt_loads": dlt_loads,
"_dlt_version": dlt_versions,
}

# check trace
assert pipeline.last_trace.last_normalize_info.row_counts == {
"append_items": 12,
@@ -190,22 +182,19 @@ def load_items_none():
load_items, table_name="items_copy", loader_file_format=destination_config.file_format
)
assert_load_info(info)
new_state_records = increase_state_loads(info)
assert new_state_records == 1
new_state_records = 1
dlt_loads += 1
dlt_versions += 1
# check trace
assert pipeline_2.last_trace.last_normalize_info.row_counts == {
"items_copy": 120,
"items_copy__sub_items": 240,
"items_copy__sub_items__sub_sub_items": 120,
"_dlt_pipeline_state": 1,
}

info = pipeline_2.run(append_items, loader_file_format=destination_config.file_format)
assert_load_info(info)
new_state_records = increase_state_loads(info)
assert new_state_records == 0
new_state_records += 1
dlt_loads += 1

# new pipeline
@@ -341,7 +330,6 @@ def yield_empty_list():
"other_items__sub_items": 2,
"static_items": 1,
"static_items__sub_items": 2,
"_dlt_pipeline_state": 1,
}

# see if child table gets cleared
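The removed increase_state_loads lambda counted completed _dlt_pipeline_state jobs inside the first load package; presumably that count is always zero now that state no longer surfaces as ordinary jobs, so the test tracks the expected number of state rows by hand. A hedged sketch of the resulting bookkeeping (counter values are illustrative):

    # state rows still end up in the destination table, they are just not separate load jobs
    state_records = 1        # the first run of a pipeline writes one state record
    counts = load_table_counts(pipeline, *pipeline.default_schema.tables.keys())
    assert counts["_dlt_pipeline_state"] == state_records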
12 changes: 6 additions & 6 deletions tests/load/test_dummy_client.py
@@ -214,7 +214,7 @@ def test_spool_job_failed_exception_init() -> None:
package_info = load.load_storage.get_load_package_info(load_id)
assert package_info.state == "aborted"
# both failed - we wait till the current loop is completed and then raise
assert len(package_info.jobs["failed_jobs"]) == 1
assert len(package_info.jobs["failed_jobs"]) == 2
assert len(package_info.jobs["started_jobs"]) == 0
# load id was never committed
complete_load.assert_not_called()
@@ -232,7 +232,7 @@ def test_spool_job_failed_exception_complete() -> None:
package_info = load.load_storage.get_load_package_info(load_id)
assert package_info.state == "aborted"
# both failed - we wait till the current loop is completed and then raise
assert len(package_info.jobs["failed_jobs"]) == 1
assert len(package_info.jobs["failed_jobs"]) == 2
assert len(package_info.jobs["started_jobs"]) == 0


@@ -254,8 +254,8 @@ def test_spool_job_retry_spool_new() -> None:
with ThreadPoolExecutor() as pool:
load.pool = pool
jobs_count, jobs = load.spool_new_jobs(load_id, schema)
assert jobs_count == 1
assert len(jobs) == 1
assert jobs_count == 2
assert len(jobs) == 2


def test_spool_job_retry_started() -> None:
@@ -315,11 +315,11 @@ def test_try_retrieve_job() -> None:
load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES)
load.pool = ThreadPoolExecutor()
jobs_count, jobs = load.spool_new_jobs(load_id, schema)
assert jobs_count == 1
assert jobs_count == 2
# now jobs are known
with load.destination.client(schema, load.initial_client_config) as c:
job_count, jobs = load.retrieve_jobs(c, load_id)
assert job_count == 1
assert job_count == 2
for j in jobs:
assert j.state() == "running"

16 changes: 2 additions & 14 deletions tests/load/test_job_client.py
@@ -42,6 +42,7 @@
cm_yield_client_with_storage,
write_dataset,
prepare_table,
prepare_load_package,
)
from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration

@@ -170,6 +171,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None:
def test_complete_load(client: SqlJobClientBase) -> None:
client.update_stored_schema()
load_id = "182879721.182912"

client.complete_load(load_id)
load_table = client.sql_client.make_qualified_table_name(LOADS_TABLE_NAME)
load_rows = list(client.sql_client.execute_sql(f"SELECT * FROM {load_table}"))
@@ -349,20 +351,6 @@ def test_drop_tables(client: SqlJobClientBase) -> None:
exists, _ = client.get_storage_table(tbl)
assert not exists

# Verify _dlt_version schema is updated and old versions deleted
table_name = client.sql_client.make_qualified_table_name(VERSION_TABLE_NAME)
rows = client.sql_client.execute_sql(
f"SELECT version_hash FROM {table_name} WHERE schema_name = %s", schema.name
)
assert len(rows) == 1
assert rows[0][0] == schema.version_hash

# Other schema is not replaced
rows = client.sql_client.execute_sql(
f"SELECT version_hash FROM {table_name} WHERE schema_name = %s", schema_2.name
)
assert len(rows) == 2


@pytest.mark.parametrize(
"client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name
17 changes: 14 additions & 3 deletions tests/pipeline/utils.py
@@ -1,5 +1,5 @@
import posixpath
from typing import Any, Dict, List, Tuple, Callable, Sequence
from typing import Any, Dict, List, Tuple, Callable, Sequence, Optional
import pytest
import random
from os import environ
@@ -79,14 +79,25 @@ def many_delayed(many, iters):
#


def assert_load_info(info: LoadInfo, expected_load_packages: int = 1) -> None:
"""Asserts that expected number of packages was loaded and there are no failed jobs"""
def assert_load_info(
info: LoadInfo, expected_load_packages: int = 1, expected_total_jobs: Optional[int] = None
) -> None:
"""Asserts that expected number of packages was loaded and there are no failed jobs
Expected total jobs is optional; if provided, it asserts that the total number of jobs (sum of failed/new/completed/started jobs) matches this value
"""
assert len(info.loads_ids) == expected_load_packages
# all packages loaded
assert all(p.completed_at is not None for p in info.load_packages) is True
# no failed jobs in any of the packages
assert all(len(p.jobs["failed_jobs"]) == 0 for p in info.load_packages) is True

if expected_total_jobs is not None:
s = 0
for package in info.load_packages:
s += sum(len(val) for val in package.jobs.values())
assert s == expected_total_jobs


#
# Load utils
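For reference, a hedged usage sketch of the extended helper; the numbers in the second call are purely illustrative:

    # a state-only package: one package, zero jobs in total
    assert_load_info(info, expected_total_jobs=0)
    # a regular load: two packages, five jobs summed across failed/new/completed/started
    assert_load_info(info, expected_load_packages=2, expected_total_jobs=5)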
