Support destination name in tests
steinitzu committed Aug 31, 2024
1 parent e8d0ea8 commit 0d9c75a
Showing 12 changed files with 149 additions and 128 deletions.
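
The same pattern repeats across the hunks below: tests that branched on `destination_config.destination` now branch on `destination_config.destination_type`, and pipelines are created or run via `destination_config.destination_factory()` / `destination_config.setup_pipeline(...)` instead of passing the raw destination string. This keeps skip and branch logic keyed to the destination implementation while letting the test configuration carry a custom destination name. A minimal sketch of the assumed configuration shape follows; the attribute and method names are taken from the diff, while everything else (the class body, the `destination_name` field, the import path) is an illustrative guess, not the real `DestinationTestConfiguration` from `tests/load/utils.py`.

```python
# Illustrative sketch only: names mirror the diff, bodies are assumptions.
from dataclasses import dataclass
from typing import Optional

from dlt.common.destination import Destination  # assumed import path


@dataclass
class DestinationTestConfigurationSketch:
    destination_type: str                   # e.g. "duckdb", "redshift"; what tests now branch on
    destination_name: Optional[str] = None  # assumed: optional custom name for a named destination
    staging: Optional[str] = None
    file_format: Optional[str] = None

    def destination_factory(self) -> Destination:
        # Resolve a concrete destination from its type string, as the tests do with
        # Destination.from_reference(destination_config.destination_type); the real
        # helper presumably also applies the configured destination name and credentials.
        return Destination.from_reference(self.destination_type)
```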
16 changes: 8 additions & 8 deletions tests/load/pipeline/test_arrow_loading.py
@@ -41,23 +41,23 @@ def test_load_arrow_item(
# os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True"
os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True"
os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_ID"] = "True"
include_time = destination_config.destination not in (
include_time = destination_config.destination_type not in (
"athena",
"redshift",
"databricks",
"synapse",
"clickhouse",
) # athena/redshift can't load TIME columns
include_binary = not (
destination_config.destination in ("redshift", "databricks")
destination_config.destination_type in ("redshift", "databricks")
and destination_config.file_format == "jsonl"
)

include_decimal = not (
destination_config.destination == "databricks" and destination_config.file_format == "jsonl"
destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl"
)
include_date = not (
destination_config.destination == "databricks" and destination_config.file_format == "jsonl"
destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl"
)

item, records, _ = arrow_table_all_data_types(
@@ -77,7 +77,7 @@ def some_data():

# use csv for postgres to get native arrow processing
file_format = (
destination_config.file_format if destination_config.destination != "postgres" else "csv"
destination_config.file_format if destination_config.destination_type != "postgres" else "csv"
)

load_info = pipeline.run(some_data(), loader_file_format=file_format)
@@ -107,13 +107,13 @@ def some_data():
if isinstance(row[i], memoryview):
row[i] = row[i].tobytes()

if destination_config.destination == "redshift":
if destination_config.destination_type == "redshift":
# Redshift needs hex string
for record in records:
if "binary" in record:
record["binary"] = record["binary"].hex()

if destination_config.destination == "clickhouse":
if destination_config.destination_type == "clickhouse":
for record in records:
# Clickhouse needs base64 string for jsonl
if "binary" in record and destination_config.file_format == "jsonl":
@@ -128,7 +128,7 @@ def some_data():
row[i] = pendulum.instance(row[i])
# clickhouse produces rounding errors on double with jsonl, so we round the result coming from there
if (
destination_config.destination == "clickhouse"
destination_config.destination_type == "clickhouse"
and destination_config.file_format == "jsonl"
and isinstance(row[i], float)
):
6 changes: 3 additions & 3 deletions tests/load/pipeline/test_dbt_helper.py
@@ -36,7 +36,7 @@ def dbt_venv() -> Iterator[Venv]:
def test_run_jaffle_package(
destination_config: DestinationTestConfiguration, dbt_venv: Venv
) -> None:
if destination_config.destination == "athena":
if destination_config.destination_type == "athena":
pytest.skip(
"dbt-athena requires database to be created and we don't do it in case of Jaffle"
)
@@ -71,7 +71,7 @@ def test_run_jaffle_package(
ids=lambda x: x.name,
)
def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None:
if destination_config.destination == "mssql":
if destination_config.destination_type == "mssql":
pytest.skip(
"mssql requires non standard SQL syntax and we do not have specialized dbt package"
" for it"
@@ -130,7 +130,7 @@ def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_ven
def test_run_chess_dbt_to_other_dataset(
destination_config: DestinationTestConfiguration, dbt_venv: Venv
) -> None:
if destination_config.destination == "mssql":
if destination_config.destination_type == "mssql":
pytest.skip(
"mssql requires non standard SQL syntax and we do not have specialized dbt package"
" for it"
4 changes: 2 additions & 2 deletions tests/load/pipeline/test_merge_disposition.py
@@ -1116,7 +1116,7 @@ def data_resource(data):
assert sorted(observed, key=lambda d: d["id"]) == expected

# additional tests with two records, run only on duckdb to limit test load
if destination_config.destination == "duckdb":
if destination_config.destination_type == "duckdb":
# two records with same primary key
# record with highest value in sort column is a delete
# existing record is deleted and no record will be inserted
@@ -1194,7 +1194,7 @@ def r():
ids=lambda x: x.name,
)
def test_upsert_merge_strategy_config(destination_config: DestinationTestConfiguration) -> None:
if destination_config.destination == "filesystem":
if destination_config.destination_type == "filesystem":
# TODO: implement validation and remove this test exception
pytest.skip(
"`upsert` merge strategy configuration validation has not yet been"
33 changes: 16 additions & 17 deletions tests/load/pipeline/test_pipelines.py
@@ -100,7 +100,7 @@ def data_fun() -> Iterator[Any]:
# mock the correct destinations (never do that in normal code)
with p.managed_state():
p._set_destinations(
destination=Destination.from_reference(destination_config.destination),
destination=Destination.from_reference(destination_config.destination_type),
staging=(
Destination.from_reference(destination_config.staging)
if destination_config.staging
@@ -162,13 +162,12 @@ def test_default_schema_name(
for idx, alpha in [(0, "A"), (0, "B"), (0, "C")]
]

p = dlt.pipeline(
p = destination_config.setup_pipeline(
"test_default_schema_name",
TEST_STORAGE_ROOT,
destination=destination_config.destination,
staging=destination_config.staging,
dataset_name=dataset_name,
pipelines_dir=TEST_STORAGE_ROOT,
)

p.config.use_single_dataset = use_single_dataset
p.extract(data, table_name="test", schema=Schema("default"))
p.normalize()
@@ -207,7 +206,7 @@ def _data():
destination_config.setup()
info = dlt.run(
_data(),
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name="specific" + uniq_id(),
loader_file_format=destination_config.file_format,
@@ -283,7 +282,7 @@ def _data():
p = dlt.pipeline(dev_mode=True)
info = p.run(
_data(),
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name="iteration" + uniq_id(),
loader_file_format=destination_config.file_format,
@@ -373,7 +372,7 @@ def extended_rows():
assert "new_column" not in schema.get_table("simple_rows")["columns"]

# lets violate unique constraint on postgres, redshift and BQ ignore unique indexes
if destination_config.destination == "postgres":
if destination_config.destination_type == "postgres":
assert p.dataset_name == dataset_name
err_info = p.run(
source(1).with_resources("simple_rows"),
@@ -458,7 +457,7 @@ def complex_data():

info = dlt.run(
complex_data(),
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name="ds_" + uniq_id(),
loader_file_format=destination_config.file_format,
@@ -849,11 +848,11 @@ def other_data():
column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA)

# parquet on bigquery and clickhouse does not support JSON but we still want to run the test
if destination_config.destination in ["bigquery"]:
if destination_config.destination_type in ["bigquery"]:
column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text"

# duckdb 0.9.1 does not support TIME other than 6
if destination_config.destination in ["duckdb", "motherduck"]:
if destination_config.destination_type in ["duckdb", "motherduck"]:
column_schemas["col11_precision"]["precision"] = 0
# also we do not want to test col4_precision (datetime) because
# those timestamps are not TZ aware in duckdb and we'd need to
@@ -862,7 +861,7 @@ def other_data():
column_schemas["col4_precision"]["precision"] = 6

# drop TIME from databases not supporting it via parquet
if destination_config.destination in [
if destination_config.destination_type in [
"redshift",
"athena",
"synapse",
@@ -876,7 +875,7 @@ def other_data():
column_schemas.pop("col11_null")
column_schemas.pop("col11_precision")

if destination_config.destination in ("redshift", "dremio"):
if destination_config.destination_type in ("redshift", "dremio"):
data_types.pop("col7_precision")
column_schemas.pop("col7_precision")

@@ -923,10 +922,10 @@ def some_source():
assert_all_data_types_row(
db_row,
schema=column_schemas,
parse_complex_strings=destination_config.destination
parse_complex_strings=destination_config.destination_type
in ["snowflake", "bigquery", "redshift"],
allow_string_binary=destination_config.destination == "clickhouse",
timestamp_precision=3 if destination_config.destination in ("athena", "dremio") else 6,
allow_string_binary=destination_config.destination_type == "clickhouse",
timestamp_precision=3 if destination_config.destination_type in ("athena", "dremio") else 6,
)


@@ -1144,7 +1143,7 @@ def _data():
p = dlt.pipeline(
pipeline_name=f"pipeline_{dataset_name}",
dev_mode=dev_mode,
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name=dataset_name,
)
2 changes: 1 addition & 1 deletion tests/load/pipeline/test_refresh_modes.py
@@ -263,7 +263,7 @@ def test_refresh_drop_data_only(destination_config: DestinationTestConfiguration
data = load_tables_to_dicts(pipeline, "some_data_1", "some_data_2", "some_data_3")
# name column still remains when table was truncated instead of dropped
# (except on filesystem where truncate and drop are the same)
if destination_config.destination == "filesystem":
if destination_config.destination_type == "filesystem":
result = sorted([row["id"] for row in data["some_data_1"]])
assert result == [3, 4]

22 changes: 12 additions & 10 deletions tests/load/pipeline/test_restore_state.py
@@ -219,13 +219,13 @@ def test_get_schemas_from_destination(
use_single_dataset: bool,
naming_convention: str,
) -> None:
set_naming_env(destination_config.destination, naming_convention)
set_naming_env(destination_config.destination_type, naming_convention)

pipeline_name = "pipe_" + uniq_id()
dataset_name = "state_test_" + uniq_id()

p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name)
assert_naming_to_caps(destination_config.destination, p.destination.capabilities())
assert_naming_to_caps(destination_config.destination_type, p.destination.capabilities())
p.config.use_single_dataset = use_single_dataset

def _make_dn_name(schema_name: str) -> str:
@@ -318,13 +318,13 @@ def _make_dn_name(schema_name: str) -> str:
def test_restore_state_pipeline(
destination_config: DestinationTestConfiguration, naming_convention: str
) -> None:
set_naming_env(destination_config.destination, naming_convention)
set_naming_env(destination_config.destination_type, naming_convention)
# enable restoring from destination
os.environ["RESTORE_FROM_DESTINATION"] = "True"
pipeline_name = "pipe_" + uniq_id()
dataset_name = "state_test_" + uniq_id()
p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name)
assert_naming_to_caps(destination_config.destination, p.destination.capabilities())
assert_naming_to_caps(destination_config.destination_type, p.destination.capabilities())

def some_data_gen(param: str) -> Any:
dlt.current.source_state()[param] = param
@@ -552,7 +552,7 @@ def test_restore_schemas_while_import_schemas_exist(
)
# use run to get changes
p.run(
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name=dataset_name,
loader_file_format=destination_config.file_format,
@@ -605,7 +605,7 @@ def some_data(param: str) -> Any:
p.run(
[data1, some_data("state2")],
schema=Schema("default"),
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name=dataset_name,
loader_file_format=destination_config.file_format,
@@ -615,7 +615,7 @@ def some_data(param: str) -> Any:
# create a production pipeline in separate pipelines_dir
production_p = dlt.pipeline(pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT)
production_p.run(
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name=dataset_name,
loader_file_format=destination_config.file_format,
@@ -695,7 +695,7 @@ def some_data(param: str) -> Any:
[5, 4, 4, 3, 2],
)
except SqlClientNotAvailable:
pytest.skip(f"destination {destination_config.destination} does not support sql client")
pytest.skip(
f"destination {destination_config.destination_type} does not support sql client"
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -723,7 +725,7 @@ def some_data(param: str) -> Any:
p.run(
data4,
schema=Schema("sch1"),
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name=dataset_name,
loader_file_format=destination_config.file_format,
@@ -753,7 +755,7 @@ def some_data(param: str) -> Any:
p.run(
data4,
schema=Schema("sch1"),
destination=destination_config.destination,
destination=destination_config.destination_factory(),
staging=destination_config.staging,
dataset_name=dataset_name,
loader_file_format=destination_config.file_format,
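
Taken together, the hunks above converge on one call pattern in the load tests. Below is a condensed, hedged sketch of that pattern; the helper signatures are inferred from the hunks, not copied from the test utilities.

```python
from typing import Any, Optional


def run_against(destination_config: Any, data: Any, dataset_name: str) -> Optional[Any]:
    """Condensed illustration of the post-commit pattern; not actual test code."""
    # Skips and feature branches compare the destination *type* string...
    if destination_config.destination_type in ("athena", "mssql"):
        return None  # the real tests call pytest.skip(...) with a reason here
    # ...while the pipeline itself is built through the configuration helper, so a
    # named destination (name != type) is exercised end to end.
    pipeline = destination_config.setup_pipeline(
        pipeline_name="example_pipe", dataset_name=dataset_name
    )
    return pipeline.run(data, loader_file_format=destination_config.file_format)
```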