Skip to content

Commit

Permalink
fix: prevent get_removal_interval from returning invalid interval
Browse files Browse the repository at this point in the history
if your partial restate a daily model -> monthly, (2024-02-15,
2024-02-20), get_removal_interval returned an interval where start >
end. this ensures get_removal_interval only returns valid intervals.
  • Loading branch information
tobymao committed Feb 21, 2025
1 parent bb43e5f commit e9f37da
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 41 deletions.
64 changes: 39 additions & 25 deletions sqlmesh/core/plan/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import re
import sys
import typing as t
from collections import defaultdict
from functools import cached_property
Expand All @@ -27,7 +26,14 @@
from sqlmesh.core.snapshot.definition import Interval, SnapshotId
from sqlmesh.utils import columns_to_types_all_known, random_id
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday_ds, to_timestamp
from sqlmesh.utils.date import (
TimeLike,
now,
to_datetime,
yesterday_ds,
to_timestamp,
time_like_to_str,
)
from sqlmesh.utils.errors import NoChangesPlanError, PlanError, SQLMeshError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -322,12 +328,19 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool:
if not restate_models:
return {}

start = self._start or earliest_interval_start
end = self._end or now()

# Add restate snapshots and their downstream snapshots
dummy_interval = (sys.maxsize, -sys.maxsize)
for model_fqn in restate_models:
snapshot = self._model_fqn_to_snapshot.get(model_fqn)
if not snapshot:
if model_fqn not in self._model_fqn_to_snapshot:
raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.")

# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands it's
# restatement range that it's downstream dependencies all expand their restatement ranges as well.
for s_id in dag:
snapshot = self._context_diff.snapshots[s_id]

if not forward_only_preview_needed:
if self._is_dev and not snapshot.is_paused:
self._console.log_warning(
Expand All @@ -346,32 +359,33 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool:
logger.info("Skipping restatement for model '%s'", model_fqn)
continue

restatements[snapshot.snapshot_id] = dummy_interval
for downstream_s_id in dag.downstream(snapshot.snapshot_id):
if is_restateable_snapshot(self._context_diff.snapshots[downstream_s_id]):
restatements[downstream_s_id] = dummy_interval

# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands it's
# restatement range that it's downstream dependencies all expand their restatement ranges as well.
for s_id in dag:
if s_id not in restatements:
continue
snapshot = self._context_diff.snapshots[s_id]
interval = snapshot.get_removal_interval(
self._start or earliest_interval_start,
self._end or now(),
removal_interval = snapshot.get_removal_interval(
start,
end,
self._execution_time,
strict=False,
is_preview=is_preview,
)

# Since we are traversing the graph in topological order and the largest interval range is pushed down
# the graph we just have to check our immediate parents in the graph and not the whole upstream graph.
snapshot_dependencies = snapshot.parents
possible_intervals = [
restatements.get(s, dummy_interval)
for s in snapshot_dependencies
if self._context_diff.snapshots[s].is_incremental
] + [interval]
restating_parents = [
self._context_diff.snapshots[s] for s in snapshot.parents if s in restatements
]

if not restating_parents and snapshot.name not in restate_models:
continue
if not removal_interval:
self._console.log_error(
f"Skipping restatement of {snapshot.name} because provided range"
f" [{time_like_to_str(start)} - {time_like_to_str(end)}]"
f" is not a complete {snapshot.node.interval_unit}."
)
continue

possible_intervals = {
restatements[p.snapshot_id] for p in restating_parents if p.is_incremental
} | {removal_interval}
snapshot_start = min(i[0] for i in possible_intervals)
snapshot_end = max(i[1] for i in possible_intervals)

Expand Down
15 changes: 9 additions & 6 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def get_removal_interval(
*,
strict: bool = True,
is_preview: bool = False,
) -> Interval:
) -> t.Optional[Interval]:
"""Get the interval that should be removed from the snapshot.
Args:
Expand Down Expand Up @@ -742,7 +742,9 @@ def get_removal_interval(

removal_interval = expanded_removal_interval

return removal_interval
if removal_interval[0] < removal_interval[1]:
return removal_interval
return None

def inclusive_exclusive(
self,
Expand Down Expand Up @@ -2039,10 +2041,11 @@ def apply_auto_restatements(
interval_to_remove_start, interval_to_remove_end, execution_time=execution_time
)

auto_restated_intervals_per_snapshot[s_id] = removal_interval
snapshot.pending_restatement_intervals = merge_intervals(
[*snapshot.pending_restatement_intervals, removal_interval]
)
if removal_interval:
auto_restated_intervals_per_snapshot[s_id] = removal_interval
snapshot.pending_restatement_intervals = merge_intervals(
[*snapshot.pending_restatement_intervals, removal_interval]
)

snapshot.apply_pending_restatement_intervals()
snapshot.update_next_auto_restatement_ts(execution_time)
Expand Down
13 changes: 6 additions & 7 deletions sqlmesh/core/state_sync/engine_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,14 @@ def unpause_snapshots(
target_snapshot.snapshot_id,
)
full_snapshot = snapshot.full_snapshot
self.remove_intervals(
[
(
full_snapshot,
full_snapshot.get_removal_interval(effective_from_ts, current_ts),
)
]

removal_interval = full_snapshot.get_removal_interval(
effective_from_ts, current_ts
)

if removal_interval:
self.remove_intervals([(full_snapshot, removal_interval)])

if snapshot.unpaused_ts:
logger.info("Pausing snapshot %s", snapshot.snapshot_id)
snapshot.set_unpaused_ts(None)
Expand Down
100 changes: 100 additions & 0 deletions tests/core/test_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2753,3 +2753,103 @@ def test_restate_production_model_in_dev(make_snapshot, mocker: MockerFixture):
"Cannot restate model '\"test_model_b\"' because the current version is used in production. "
"Run the restatement against the production environment instead to restate this model."
)


def test_restate_daily_to_monthly(make_snapshot, mocker: MockerFixture):
snapshot_a = make_snapshot(
SqlModel(
name="a",
query=parse_one("select 1 as one"),
cron="@daily",
start="2025-01-01",
),
)

snapshot_b = make_snapshot(
SqlModel(
name="b",
query=parse_one("select one from a"),
cron="@monthly",
start="2025-01-01",
),
nodes={'"a"': snapshot_a.model},
)

snapshot_c = make_snapshot(
SqlModel(
name="c",
query=parse_one("select one from b"),
cron="@daily",
start="2025-01-01",
),
nodes={
'"a"': snapshot_a.model,
'"b"': snapshot_b.model,
},
)

snapshot_d = make_snapshot(
SqlModel(
name="d",
query=parse_one("select one from b union all select one from a"),
cron="@daily",
start="2025-01-01",
),
nodes={
'"a"': snapshot_a.model,
'"b"': snapshot_b.model,
},
)
snapshot_e = make_snapshot(
SqlModel(
name="e",
query=parse_one("select one from b"),
cron="@daily",
start="2025-01-01",
),
nodes={
'"a"': snapshot_a.model,
'"b"': snapshot_b.model,
},
)

context_diff = ContextDiff(
environment="prod",
is_new_environment=False,
is_unfinalized_environment=True,
normalize_environment_name=True,
create_from="prod",
create_from_env_exists=True,
added=set(),
removed_snapshots={},
modified_snapshots={},
snapshots={
snapshot_a.snapshot_id: snapshot_a,
snapshot_b.snapshot_id: snapshot_b,
snapshot_c.snapshot_id: snapshot_c,
snapshot_d.snapshot_id: snapshot_d,
snapshot_e.snapshot_id: snapshot_e,
},
new_snapshots={},
previous_plan_id=None,
previously_promoted_snapshot_ids=set(),
previous_finalized_snapshots=None,
)

schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER

plan = PlanBuilder(
context_diff,
schema_differ,
restate_models=[snapshot_a.name, snapshot_e.name],
start="2025-02-15",
end="2025-02-20",
).build()

# b is invalid because it's monthly and the date range is partial
# c is invalid because it only depends on b which is not restated
assert plan.restatements == {
snapshot_a.snapshot_id: (1739577600000, 1740096000000),
snapshot_d.snapshot_id: (1739577600000, 1740096000000),
snapshot_e.snapshot_id: (1739577600000, 1740096000000),
}
2 changes: 1 addition & 1 deletion tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def test_missing_interval_smaller_than_interval_unit(make_snapshot):
]


def test_remove_intervals(snapshot: Snapshot):
def test_remove_intervals(snapshot):
snapshot.add_interval("2020-01-01", "2020-01-01")
snapshot.remove_interval(snapshot.get_removal_interval("2020-01-01", "2020-01-01"))
assert snapshot.intervals == []
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import pytest
import pandas as pd
from pydantic import field_validator, ValidationInfo, ValidationError
from pydantic import field_validator, ValidationError
from pathlib import Path
from pytest_mock.plugin import MockerFixture
from sqlglot import expressions as exp
Expand Down Expand Up @@ -3181,7 +3181,7 @@ class TestCustomKind(CustomKind):

@field_validator("primary_key", mode="before")
@classmethod
def _validate_primary_key(cls, value: t.Any, info: ValidationInfo) -> t.Any:
def _validate_primary_key(cls, value, info):
return list_of_fields_validator(value, info.data)

class TestCustomMaterializationStrategy(CustomMaterialization[TestCustomKind]):
Expand Down

0 comments on commit e9f37da

Please sign in to comment.