diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 730e585d6..fe890b11f 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -1765,11 +1765,13 @@ def load_sql_based_model( unrendered_merge_filter = None unrendered_signals = None + unrendered_audits = None for prop in meta.expressions: if prop.name.lower() == "signals": unrendered_signals = prop.args.get("value") - + if prop.name.lower() == "audits": + unrendered_audits = prop.args.get("value") if ( prop.name.lower() == "kind" and (value := prop.args.get("value")) @@ -1816,10 +1818,13 @@ def load_sql_based_model( **kwargs, } - # Signals and merge_filter must remain unrendered, so that they can be rendered later at evaluation runtime. + # signals, audits and merge_filter must remain unrendered, so that they can be rendered later at evaluation runtime if unrendered_signals: meta_fields["signals"] = unrendered_signals + if unrendered_audits: + meta_fields["audits"] = unrendered_audits + if unrendered_merge_filter: for idx, kind_prop in enumerate(meta_fields["kind"].expressions): if kind_prop.name.lower() == "merge_filter": diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 7f55bd82e..8de59f7f2 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1172,7 +1172,8 @@ def test_audits(): name db.seed, audits ( audit_a, - audit_b(key='value') + audit_b(key='value'), + audit_c(key=@start_ds) ), tags (foo) ); @@ -1184,6 +1185,7 @@ def test_audits(): assert model.audits == [ ("audit_a", {}), ("audit_b", {"key": exp.Literal.string("value")}), + ("audit_c", {"key": d.MacroVar(this="start_ds")}), ] assert model.tags == ["foo"] diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 86b457e8e..45027a248 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -28,6 +28,7 @@ IncrementalByTimeRangeKind, IncrementalUnmanagedKind, IncrementalByPartitionKind, + IncrementalByUniqueKeyKind, PythonModel, SqlModel, TimeColumn, @@ -2708,6 +2709,39 @@ def test_audit_wap(adapter_mock, make_snapshot): adapter_mock.wap_publish.assert_called_once_with(snapshot.table_name(), wap_id) +def test_audit_with_datetime_macros(adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=IncrementalByUniqueKeyKind(unique_key="a"), + query=parse_one("SELECT a, start_ds FROM tbl"), + audits=[ + ( + "unique_combination_of_columns", + { + "columns": exp.Array(expressions=[exp.to_column("a")]), + "condition": d.MacroVar(this="start_ds").neq("2020-01-01"), + }, + ), + ], + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.fetchone.return_value = (0,) + evaluator.audit(snapshot, snapshots={}, start="2020-01-01") + + call_args = adapter_mock.fetchone.call_args_list + assert len(call_args) == 1 + + unique_combination_of_columns_query = call_args[0][0][0] + assert ( + unique_combination_of_columns_query.sql(dialect="duckdb") + == """SELECT COUNT(*) FROM (SELECT "a" AS "a" FROM (SELECT * FROM "test_schema"."test_table" AS "test_table") AS "_q_0" WHERE '2020-01-01' <> '2020-01-01' GROUP BY "a" HAVING COUNT(*) > 1) AS audit""" + ) + + def test_audit_set_blocking_at_use_site(adapter_mock, make_snapshot): evaluator = SnapshotEvaluator(adapter_mock)