diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index 14082858b5..6bf5a9ef4f 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -2,7 +2,6 @@ import logging import re -import sys import typing as t from collections import defaultdict from functools import cached_property @@ -27,7 +26,14 @@ from sqlmesh.core.snapshot.definition import Interval, SnapshotId from sqlmesh.utils import columns_to_types_all_known, random_id from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday_ds, to_timestamp +from sqlmesh.utils.date import ( + TimeLike, + now, + to_datetime, + yesterday_ds, + to_timestamp, + time_like_to_str, +) from sqlmesh.utils.errors import NoChangesPlanError, PlanError, SQLMeshError logger = logging.getLogger(__name__) @@ -322,12 +328,19 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool: if not restate_models: return {} + start = self._start or earliest_interval_start + end = self._end or now() + # Add restate snapshots and their downstream snapshots - dummy_interval = (sys.maxsize, -sys.maxsize) for model_fqn in restate_models: - snapshot = self._model_fqn_to_snapshot.get(model_fqn) - if not snapshot: + if model_fqn not in self._model_fqn_to_snapshot: raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.") + + # Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands it's + # restatement range that it's downstream dependencies all expand their restatement ranges as well. + for s_id in dag: + snapshot = self._context_diff.snapshots[s_id] + if not forward_only_preview_needed: if self._is_dev and not snapshot.is_paused: self._console.log_warning( @@ -346,32 +359,33 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool: logger.info("Skipping restatement for model '%s'", model_fqn) continue - restatements[snapshot.snapshot_id] = dummy_interval - for downstream_s_id in dag.downstream(snapshot.snapshot_id): - if is_restateable_snapshot(self._context_diff.snapshots[downstream_s_id]): - restatements[downstream_s_id] = dummy_interval - - # Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands it's - # restatement range that it's downstream dependencies all expand their restatement ranges as well. - for s_id in dag: - if s_id not in restatements: - continue - snapshot = self._context_diff.snapshots[s_id] - interval = snapshot.get_removal_interval( - self._start or earliest_interval_start, - self._end or now(), + removal_interval = snapshot.get_removal_interval( + start, + end, self._execution_time, strict=False, is_preview=is_preview, ) + # Since we are traversing the graph in topological order and the largest interval range is pushed down # the graph we just have to check our immediate parents in the graph and not the whole upstream graph. - snapshot_dependencies = snapshot.parents - possible_intervals = [ - restatements.get(s, dummy_interval) - for s in snapshot_dependencies - if self._context_diff.snapshots[s].is_incremental - ] + [interval] + restating_parents = [ + self._context_diff.snapshots[s] for s in snapshot.parents if s in restatements + ] + + if not restating_parents and snapshot.name not in restate_models: + continue + if not removal_interval: + self._console.log_error( + f"Skipping restatement of {snapshot.name} because provided range" + f" [{time_like_to_str(start)} - {time_like_to_str(end)}]" + f" is not a complete {snapshot.node.interval_unit}." + ) + continue + + possible_intervals = { + restatements[p.snapshot_id] for p in restating_parents if p.is_incremental + } | {removal_interval} snapshot_start = min(i[0] for i in possible_intervals) snapshot_end = max(i[1] for i in possible_intervals) diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 8f2384375a..4171369452 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -709,7 +709,7 @@ def get_removal_interval( *, strict: bool = True, is_preview: bool = False, - ) -> Interval: + ) -> t.Optional[Interval]: """Get the interval that should be removed from the snapshot. Args: @@ -742,7 +742,9 @@ def get_removal_interval( removal_interval = expanded_removal_interval - return removal_interval + if removal_interval[0] < removal_interval[1]: + return removal_interval + return None def inclusive_exclusive( self, @@ -2039,10 +2041,11 @@ def apply_auto_restatements( interval_to_remove_start, interval_to_remove_end, execution_time=execution_time ) - auto_restated_intervals_per_snapshot[s_id] = removal_interval - snapshot.pending_restatement_intervals = merge_intervals( - [*snapshot.pending_restatement_intervals, removal_interval] - ) + if removal_interval: + auto_restated_intervals_per_snapshot[s_id] = removal_interval + snapshot.pending_restatement_intervals = merge_intervals( + [*snapshot.pending_restatement_intervals, removal_interval] + ) snapshot.apply_pending_restatement_intervals() snapshot.update_next_auto_restatement_ts(execution_time) diff --git a/sqlmesh/core/state_sync/engine_adapter.py b/sqlmesh/core/state_sync/engine_adapter.py index a93c2f4eae..85a28960bc 100644 --- a/sqlmesh/core/state_sync/engine_adapter.py +++ b/sqlmesh/core/state_sync/engine_adapter.py @@ -473,15 +473,14 @@ def unpause_snapshots( target_snapshot.snapshot_id, ) full_snapshot = snapshot.full_snapshot - self.remove_intervals( - [ - ( - full_snapshot, - full_snapshot.get_removal_interval(effective_from_ts, current_ts), - ) - ] + + removal_interval = full_snapshot.get_removal_interval( + effective_from_ts, current_ts ) + if removal_interval: + self.remove_intervals([(full_snapshot, removal_interval)]) + if snapshot.unpaused_ts: logger.info("Pausing snapshot %s", snapshot.snapshot_id) snapshot.set_unpaused_ts(None) diff --git a/tests/core/test_plan.py b/tests/core/test_plan.py index 423650600e..f0fcc888d5 100644 --- a/tests/core/test_plan.py +++ b/tests/core/test_plan.py @@ -2753,3 +2753,103 @@ def test_restate_production_model_in_dev(make_snapshot, mocker: MockerFixture): "Cannot restate model '\"test_model_b\"' because the current version is used in production. " "Run the restatement against the production environment instead to restate this model." ) + + +def test_restate_daily_to_monthly(make_snapshot, mocker: MockerFixture): + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1 as one"), + cron="@daily", + start="2025-01-01", + ), + ) + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select one from a"), + cron="@monthly", + start="2025-01-01", + ), + nodes={'"a"': snapshot_a.model}, + ) + + snapshot_c = make_snapshot( + SqlModel( + name="c", + query=parse_one("select one from b"), + cron="@daily", + start="2025-01-01", + ), + nodes={ + '"a"': snapshot_a.model, + '"b"': snapshot_b.model, + }, + ) + + snapshot_d = make_snapshot( + SqlModel( + name="d", + query=parse_one("select one from b union all select one from a"), + cron="@daily", + start="2025-01-01", + ), + nodes={ + '"a"': snapshot_a.model, + '"b"': snapshot_b.model, + }, + ) + snapshot_e = make_snapshot( + SqlModel( + name="e", + query=parse_one("select one from b"), + cron="@daily", + start="2025-01-01", + ), + nodes={ + '"a"': snapshot_a.model, + '"b"': snapshot_b.model, + }, + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=True, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={}, + snapshots={ + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_c.snapshot_id: snapshot_c, + snapshot_d.snapshot_id: snapshot_d, + snapshot_e.snapshot_id: snapshot_e, + }, + new_snapshots={}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + ) + + schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER + + plan = PlanBuilder( + context_diff, + schema_differ, + restate_models=[snapshot_a.name, snapshot_e.name], + start="2025-02-15", + end="2025-02-20", + ).build() + + # b is invalid because it's monthly and the date range is partial + # c is invalid because it only depends on b which is not restated + assert plan.restatements == { + snapshot_a.snapshot_id: (1739577600000, 1740096000000), + snapshot_d.snapshot_id: (1739577600000, 1740096000000), + snapshot_e.snapshot_id: (1739577600000, 1740096000000), + } diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 58a7eed061..5d8af18914 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -677,7 +677,7 @@ def test_missing_interval_smaller_than_interval_unit(make_snapshot): ] -def test_remove_intervals(snapshot: Snapshot): +def test_remove_intervals(snapshot): snapshot.add_interval("2020-01-01", "2020-01-01") snapshot.remove_interval(snapshot.get_removal_interval("2020-01-01", "2020-01-01")) assert snapshot.intervals == [] diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index e0949ad6fd..668d72bda8 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -5,7 +5,7 @@ import logging import pytest import pandas as pd -from pydantic import field_validator, ValidationInfo, ValidationError +from pydantic import field_validator, ValidationError from pathlib import Path from pytest_mock.plugin import MockerFixture from sqlglot import expressions as exp @@ -3181,7 +3181,7 @@ class TestCustomKind(CustomKind): @field_validator("primary_key", mode="before") @classmethod - def _validate_primary_key(cls, value: t.Any, info: ValidationInfo) -> t.Any: + def _validate_primary_key(cls, value, info): return list_of_fields_validator(value, info.data) class TestCustomMaterializationStrategy(CustomMaterialization[TestCustomKind]):