Skip to content

Commit

Permalink
try to change inclusive_exclusive
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Feb 24, 2025
1 parent e9f37da commit b37ff4d
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 45 deletions.
6 changes: 3 additions & 3 deletions sqlmesh/core/plan/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,19 +344,19 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool:
if not forward_only_preview_needed:
if self._is_dev and not snapshot.is_paused:
self._console.log_warning(
f"Cannot restate model '{model_fqn}' because the current version is used in production. "
f"Cannot restate model '{snapshot.name}' because the current version is used in production. "
"Run the restatement against the production environment instead to restate this model."
)
continue
elif (not self._is_dev or not snapshot.is_paused) and snapshot.disable_restatement:
self._console.log_warning(
f"Cannot restate model '{model_fqn}'. "
f"Cannot restate model '{snapshot.name}'. "
"Restatement is disabled for this model to prevent possible data loss."
"If you want to restate this model, change the model's `disable_restatement` setting to `false`."
)
continue
elif snapshot.is_symbolic or snapshot.is_seed:
logger.info("Skipping restatement for model '%s'", model_fqn)
logger.info("Skipping restatement for model '%s'", snapshot.name)
continue

removal_interval = snapshot.get_removal_interval(
Expand Down
43 changes: 24 additions & 19 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,8 @@ def add_interval(self, start: TimeLike, end: TimeLike, is_dev: bool = False) ->
f"Attempted to add an Invalid interval ({start}, {end}) to snapshot {self.snapshot_id}"
)

start_ts, end_ts = self.inclusive_exclusive(start, end, strict=False)
start_ts, end_ts = self.inclusive_exclusive(start, end, strict=False, expand=False)

if start_ts >= end_ts:
# Skipping partial interval.
return
Expand Down Expand Up @@ -752,6 +753,7 @@ def inclusive_exclusive(
end: TimeLike,
strict: bool = True,
allow_partial: t.Optional[bool] = None,
expand: bool = True,
) -> Interval:
"""Transform the inclusive start and end into a [start, end) pair.
Expand All @@ -760,6 +762,7 @@ def inclusive_exclusive(
end: The end date/time of the interval (inclusive)
strict: Whether to fail when the inclusive start is the same as the exclusive end.
allow_partial: Whether the interval can be partial or not.
expand: Whether or not partial intervals are expanded outwards.
Returns:
A [start, end) pair.
Expand All @@ -770,9 +773,9 @@ def inclusive_exclusive(
start,
end,
self.node.interval_unit,
model_allow_partials=self.is_model and self.model.allow_partials,
strict=strict,
allow_partial=allow_partial,
expand=expand,
)

def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None:
Expand Down Expand Up @@ -867,12 +870,7 @@ def missing_intervals(
allow_partials = self.is_model and self.model.allow_partials
start_ts, end_ts = (
to_timestamp(ts)
for ts in self.inclusive_exclusive(
start,
end,
strict=False,
allow_partial=allow_partials,
)
for ts in self.inclusive_exclusive(start, end)
)

interval_unit = self.node.interval_unit
Expand Down Expand Up @@ -1867,7 +1865,6 @@ def inclusive_exclusive(
start: TimeLike,
end: TimeLike,
interval_unit: IntervalUnit,
model_allow_partials: bool,
strict: bool = True,
allow_partial: bool = False,
) -> Interval:
Expand All @@ -1877,26 +1874,34 @@ def inclusive_exclusive(
start: The start date/time of the interval (inclusive)
end: The end date/time of the interval (inclusive)
interval_unit: The interval unit.
model_allow_partials: Whether or not the model allows partials.
strict: Whether to fail when the inclusive start is the same as the exclusive end.
allow_partial: Whether the interval can be partial or not.
expand: Whether or not partial intervals are expanded outwards.
Returns:
A [start, end) pair.
"""
start_ts = to_timestamp(interval_unit.cron_floor(start))
if start_ts < to_timestamp(start) and not model_allow_partials:
start_ts = to_timestamp(interval_unit.cron_next(start_ts))
start_dt = interval_unit.cron_floor(start)

if not expand and not allow_partial and start_dt < to_datetime(start):
start_dt = interval_unit.cron_next(start_dt)

start_ts = to_timestamp(start_dt)

if is_date(end):
end = to_datetime(end) + timedelta(days=1)
end_ts = to_timestamp(interval_unit.cron_floor(end) if not allow_partial else end)
if end_ts < start_ts and to_timestamp(end) > to_timestamp(start) and not strict:
# This can happen when the interval unit is coarser than the size of the input interval.
# For example, if the interval unit is monthly, but the input interval is only 1 hour long.
return (start_ts, end_ts)

if (strict and start_ts >= end_ts) or (start_ts > end_ts):
if allow_partial:
end_dt = end
else:
end_dt = interval_unit.cron_floor(end)

if expand and end_dt != to_datetime(end):
end_dt = interval_unit.cron_next(end_dt)

end_ts = to_timestamp(end_dt)

if strict and start_ts >= end_ts:
raise ValueError(
f"`end` ({to_datetime(end_ts)}) must be greater than `start` ({to_datetime(start_ts)})"
)
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def add_interval(
end: The end of the interval to add.
is_dev: Indicates whether the given interval is being added while in development mode
"""
start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False)
start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False, expand=False)
if not snapshot.version:
raise SQLMeshError("Snapshot version must be set to add an interval.")
intervals = [(start_ts, end_ts)]
Expand Down
8 changes: 3 additions & 5 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2153,7 +2153,7 @@ def test_restatement_plan_ignores_changes(init_and_plan_context: t.Callable):
assert not plan.new_snapshots
assert plan.requires_backfill
assert plan.restatements == {
restated_snapshot.snapshot_id: (to_timestamp("2023-01-01"), to_timestamp("2023-01-08"))
restated_snapshot.snapshot_id: (to_timestamp("2023-01-01"), to_timestamp("2023-01-09"))
}
assert plan.missing_intervals == [
SnapshotIntervals(
Expand Down Expand Up @@ -4562,16 +4562,14 @@ def test_restatement_of_full_model_with_start(init_and_plan_context: t.Callable)
no_prompts=True,
)

restatement_end = to_timestamp("2023-01-08")

sushi_customer_interval = restatement_plan.restatements[
context.get_snapshot("sushi.customers").snapshot_id
]
assert sushi_customer_interval == (to_timestamp("2023-01-01"), restatement_end)
assert sushi_customer_interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-09"))
waiter_by_day_interval = restatement_plan.restatements[
context.get_snapshot("sushi.waiter_as_customer_by_day").snapshot_id
]
assert waiter_by_day_interval == (to_timestamp("2023-01-07"), restatement_end)
assert waiter_by_day_interval == (to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))


def initial_add(context: Context, environment: str):
Expand Down
34 changes: 19 additions & 15 deletions tests/core/test_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,30 +738,34 @@ def test_restate_models(sushi_context_pre_scheduling: Context):
plan = sushi_context_pre_scheduling.plan(
restate_models=["sushi.waiter_revenue_by_day", "tag:expensive"], no_prompts=True
)

start = to_timestamp(plan.start)
tomorrow = to_timestamp(to_date("tomorrow"))

assert plan.restatements == {
sushi_context_pre_scheduling.get_snapshot(
"sushi.waiter_revenue_by_day", raise_if_missing=True
).snapshot_id: (
to_timestamp(plan.start),
to_timestamp(to_date("today")),
start,
tomorrow,
),
sushi_context_pre_scheduling.get_snapshot(
"sushi.top_waiters", raise_if_missing=True
).snapshot_id: (
to_timestamp(plan.start),
to_timestamp(to_date("today")),
start,
tomorrow,
),
sushi_context_pre_scheduling.get_snapshot(
"sushi.customer_revenue_by_day", raise_if_missing=True
).snapshot_id: (
to_timestamp(plan.start),
to_timestamp(to_date("today")),
start,
tomorrow,
),
sushi_context_pre_scheduling.get_snapshot(
"sushi.customer_revenue_lifetime", raise_if_missing=True
).snapshot_id: (
to_timestamp(plan.start),
to_timestamp(to_date("today")),
start,
tomorrow,
),
}
assert plan.requires_backfill
Expand Down Expand Up @@ -828,7 +832,7 @@ def test_restate_models_with_existing_missing_intervals(init_and_plan_context: t
),
top_waiters_snapshot_id: (
plan_start_ts,
today_ts,
to_timestamp(to_date("tomorrow")),
),
}
assert plan.missing_intervals == [
Expand Down Expand Up @@ -1850,7 +1854,7 @@ def test_disable_restatement(make_snapshot, mocker: MockerFixture):
# Restatements should still be supported when in dev.
plan = PlanBuilder(context_diff, schema_differ, is_dev=True, restate_models=['"a"']).build()
assert plan.restatements == {
snapshot.snapshot_id: (to_timestamp(plan.start), to_timestamp(to_date("today")))
snapshot.snapshot_id: (to_timestamp(plan.start), to_timestamp(to_date("tomorrow")))
}

# We don't want to restate a disable_restatement model if it is unpaused since that would mean we are violating
Expand Down Expand Up @@ -2846,10 +2850,10 @@ def test_restate_daily_to_monthly(make_snapshot, mocker: MockerFixture):
end="2025-02-20",
).build()

# b is invalid because it's monthly and the date range is partial
# c is invalid because it only depends on b which is not restated
assert plan.restatements == {
snapshot_a.snapshot_id: (1739577600000, 1740096000000),
snapshot_d.snapshot_id: (1739577600000, 1740096000000),
snapshot_e.snapshot_id: (1739577600000, 1740096000000),
snapshot_a.snapshot_id: (1739577600000, 1740441600000),
snapshot_b.snapshot_id: (1738368000000, 1740787200000),
snapshot_c.snapshot_id: (1739577600000, 1740441600000),
snapshot_d.snapshot_id: (1739577600000, 1740441600000),
snapshot_e.snapshot_id: (1739577600000, 1740441600000),
}
5 changes: 3 additions & 2 deletions tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def test_missing_intervals(snapshot: Snapshot):
assert snapshot.missing_intervals("2020-01-03 00:00:01", "2020-01-05 00:00:02") == []
assert snapshot.missing_intervals("2020-01-03 00:00:01", "2020-01-07 00:00:02") == [
(to_timestamp("2020-01-06"), to_timestamp("2020-01-07")),
(to_timestamp("2020-01-07"), to_timestamp("2020-01-08")),
]


Expand Down Expand Up @@ -1517,12 +1518,12 @@ def test_inclusive_exclusive_monthly(make_snapshot):

assert snapshot.inclusive_exclusive("2023-01-01", "2023-07-01") == (
to_timestamp("2023-01-01"),
to_timestamp("2023-07-01"),
to_timestamp("2023-08-01"),
)

assert snapshot.inclusive_exclusive("2023-01-01", "2023-07-06") == (
to_timestamp("2023-01-01"),
to_timestamp("2023-07-01"),
to_timestamp("2023-08-01"),
)

assert snapshot.inclusive_exclusive("2023-01-01", "2023-07-31") == (
Expand Down

0 comments on commit b37ff4d

Please sign in to comment.