Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed Feb 18, 2025
1 parent 82b068f commit de99e43
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 127 deletions.
28 changes: 1 addition & 27 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,6 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
from sqlmesh.core.model.kind import ModelKindName

expressions: t.List[exp.Expression] = []
gateway: t.Optional[exp.Expression] = None
blueprints: t.List[exp.Expression] = []

while True:
prev_property = seq_get(expressions, -1)
Expand Down Expand Up @@ -613,36 +611,12 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
else:
value = self._parse_bracket(self._parse_field(any_token=True))

if key == "gateway":
gateway = value
elif key == "blueprints":
if isinstance(value, exp.Paren):
blueprints = [value.this]
elif isinstance(value, (exp.Tuple, exp.Array)):
blueprints = value.expressions
else:
raise ConfigError(
"The 'blueprints' values need to be enclosed in "
f"parentheses or brackets, got {value} instead."
)

# We don't want to include blueprints in the property list
continue

if isinstance(value, exp.Expression):
value.meta["sql"] = self._find_sql(start, self._prev)

expressions.append(self.expression(exp.Property, this=key, value=value))

expression = self.expression(expression_type, expressions=expressions)

# We store these properties in the meta to provide quick access at load time
if blueprints:
expression.meta["blueprints"] = blueprints
if gateway:
expression.meta["gateway"] = gateway

return expression
return self.expression(expression_type, expressions=expressions)

return parse

Expand Down
135 changes: 40 additions & 95 deletions sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,20 @@
from dataclasses import dataclass
from pathlib import Path

from sqlglot import exp
from sqlglot.errors import SqlglotError
from sqlglot.helper import seq_get

from sqlmesh.core import constants as c
from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit, load_multiple_audits
from sqlmesh.core.dialect import parse, Model as ModelMeta
from sqlmesh.core.macros import MacroRegistry, MacroVar, macro
from sqlmesh.core.dialect import parse
from sqlmesh.core.macros import MacroRegistry, macro
from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl
from sqlmesh.core.model import (
Model,
ExternalModel,
ModelCache,
SeedModel,
create_external_model,
load_sql_based_model,
load_sql_based_models,
)
from sqlmesh.core.model import model as model_registry
from sqlmesh.core.signal import signal
Expand Down Expand Up @@ -294,7 +292,9 @@ def _track_file(self, path: Path) -> None:
"""Project file to track for modifications"""
self._path_mtimes[path] = path.stat().st_mtime

def _get_variables(self, gateway_name: str) -> t.Dict[str, t.Any]:
def _get_variables(self, gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]:
gateway_name = gateway_name or self.context.selected_gateway

if gateway_name not in self._variables_by_gateway:
try:
gateway = self.config.get_gateway(gateway_name)
Expand Down Expand Up @@ -411,56 +411,24 @@ def _load() -> t.List[Model]:
except SqlglotError as ex:
raise ConfigError(f"Failed to parse a model definition at '{path}': {ex}.")

model_meta = seq_get(expressions, 0)
blueprints = isinstance(model_meta, ModelMeta) and model_meta.meta.get("blueprints")

loaded_models = []
for blueprint in blueprints or [None]:
if not blueprint:
blueprint_variables = {}
elif isinstance(blueprint, (exp.Tuple, exp.Array)):
blueprint_variables = {
# Assumes the syntax: (k := v [, ...]) or [k := v [, ...]]
e.left.name: e.right.sql(dialect=self.config.model_defaults.dialect)
for e in blueprint.expressions
}
elif isinstance(blueprint, (exp.Column, exp.Literal)):
blueprint_variables = {
"blueprint": blueprint.sql(dialect=self.config.model_defaults.dialect)
}
else:
# TODO: update docs link in this error message
raise ConfigError(
f"Invalid value for 'blueprints' {blueprints} at '{path}', please refer to <link>."
)

gateway = t.cast(exp.Expression, model_meta).meta.get("gateway")
if not isinstance(gateway, MacroVar) or not (
gateway_name := blueprint_variables.get(gateway.name)
):
gateway_name = self.context.selected_gateway

loaded_model = load_sql_based_model(
expressions,
defaults=self.config.model_defaults.dict(),
macros=macros,
jinja_macros=jinja_macros,
audit_definitions=audits,
default_audits=self.config.model_defaults.audits,
path=Path(path).absolute(),
module_path=self.config_path,
dialect=self.config.model_defaults.dialect,
time_column_format=self.config.time_column_format,
physical_schema_mapping=self.config.physical_schema_mapping,
project=self.config.project,
default_catalog=self.context.default_catalog,
variables={**self._get_variables(gateway_name), **blueprint_variables},
infer_names=self.config.model_naming.infer_names,
signal_definitions=signals,
)
loaded_models.append(loaded_model)

return loaded_models
return load_sql_based_models(
expressions,
self._get_variables,
defaults=self.config.model_defaults.dict(),
macros=macros,
jinja_macros=jinja_macros,
audit_definitions=audits,
default_audits=self.config.model_defaults.audits,
path=Path(path).absolute(),
module_path=self.config_path,
dialect=self.config.model_defaults.dialect,
time_column_format=self.config.time_column_format,
physical_schema_mapping=self.config.physical_schema_mapping,
project=self.config.project,
default_catalog=self.context.default_catalog,
infer_names=self.config.model_naming.infer_names,
signal_definitions=signals,
)

for model in cache.get_or_load_models(path, _load):
if model.enabled:
Expand Down Expand Up @@ -500,44 +468,21 @@ def _load_python_models(
new = registry.keys() - registered
registered |= new
for name in new:
registered_entrypoint = registry[name]

gateway = registered_entrypoint.kwargs.get("gateway") or ""
blueprints = registered_entrypoint.kwargs.pop("blueprints", None)

for blueprint in blueprints or [None]:
if not blueprint:
blueprint_variables = {}
elif isinstance(blueprint, dict):
blueprint_variables = blueprint
elif isinstance(blueprint, str):
blueprint_variables = {"blueprint": blueprint}
else:
# TODO: update docs link in this error message
raise ConfigError(
f"Invalid value for 'blueprints' {blueprints} at '{path}', please refer to <link>."
)

if not gateway.startswith("@") or not (
gateway_name := blueprint_variables.get(gateway[1:])
):
gateway_name = self.context.selected_gateway

model = registry[name].model(
path=path,
module_path=self.config_path,
defaults=self.config.model_defaults.dict(),
macros=macros,
jinja_macros=jinja_macros,
dialect=self.config.model_defaults.dialect,
time_column_format=self.config.time_column_format,
physical_schema_mapping=self.config.physical_schema_mapping,
project=self.config.project,
default_catalog=self.context.default_catalog,
variables={**self._get_variables(gateway_name), **blueprint_variables},
infer_names=self.config.model_naming.infer_names,
audit_definitions=audits,
)
for model in registry[name].models(
self._get_variables,
path=path,
module_path=self.config_path,
defaults=self.config.model_defaults.dict(),
macros=macros,
jinja_macros=jinja_macros,
dialect=self.config.model_defaults.dialect,
time_column_format=self.config.time_column_format,
physical_schema_mapping=self.config.physical_schema_mapping,
project=self.config.project,
default_catalog=self.context.default_catalog,
infer_names=self.config.model_naming.infer_names,
audit_definitions=audits,
):
if model.enabled:
models[model.fqn] = model
finally:
Expand Down Expand Up @@ -588,7 +533,7 @@ def _load_audits(
"""Loads all the model audits."""
audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits")
audits_max_mtime: t.Optional[float] = None
variables = self._get_variables(self.context.selected_gateway)
variables = self._get_variables()

for path in self._glob_paths(
self.config_path / c.AUDITS,
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
create_seed_model as create_seed_model,
create_sql_model as create_sql_model,
load_sql_based_model as load_sql_based_model,
load_sql_based_models as load_sql_based_models,
)
from sqlmesh.core.model.kind import (
CustomKind as CustomKind,
Expand Down
14 changes: 14 additions & 0 deletions sqlmesh/core/model/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Model,
create_python_model,
create_sql_model,
create_model_blueprints,
get_model_name,
render_meta_fields,
)
Expand Down Expand Up @@ -84,6 +85,19 @@ def __call__(
self.name = get_model_name(Path(inspect.getfile(func)))
return super().__call__(func)

def models(
self,
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
**loader_kwargs: t.Any,
) -> t.List[Model]:
return create_model_blueprints(
gateway=self.kwargs.get("gateway"),
blueprints=self.kwargs.pop("blueprints", None),
get_variables=get_variables,
loader=self.model,
**loader_kwargs,
)

def model(
self,
*,
Expand Down
81 changes: 80 additions & 1 deletion sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import types
import re
import typing as t
from functools import cached_property
from functools import cached_property, partial
from pathlib import Path

import pandas as pd
import numpy as np
from pydantic import Field
from sqlglot import diff, exp
from sqlglot.diff import Insert
from sqlglot.helper import seq_get
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import gen
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
Expand Down Expand Up @@ -1784,6 +1785,84 @@ class AuditResult(PydanticModel):
blocking: bool = True


def create_model_blueprints(
gateway: t.Optional[str | exp.Expression],
blueprints: t.Any,
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
loader: t.Callable[..., Model],
**loader_kwargs: t.Any,
) -> t.List[Model]:
path = loader_kwargs.get("path")
dialect = loader_kwargs.get("dialect")

if not blueprints:
blueprints = [None]
elif isinstance(blueprints, exp.Paren):
blueprints = [blueprints.this]
elif isinstance(blueprints, (exp.Tuple, exp.Array)):
blueprints = blueprints.expressions
elif not isinstance(blueprints, list):
# TODO: put docs link here
raise ConfigError(
f"Invalid 'blueprints' property '{blueprints}' at '{path}', please refer to <link>."
)

model_blueprints: t.List[Model] = []
for blueprint in blueprints:
if not blueprint:
variables = {}
elif isinstance(blueprint, exp.Paren):
blueprint = blueprint.unnest()
variables = {blueprint.left.name: blueprint.right.sql(dialect=dialect)}
elif isinstance(blueprint, (exp.Tuple, exp.Array)):
variables = {e.left.name: e.right.sql(dialect=dialect) for e in blueprint.expressions}
elif isinstance(blueprint, dict):
variables = blueprint
else:
# TODO: put docs link here
raise ConfigError(
f"Invalid blueprint value: {blueprint} at '{path}', please refer to <link>."
)

if isinstance(gateway, d.MacroVar):
gateway_name = variables.get(gateway.name)
elif isinstance(gateway, str) and gateway.startswith("@"):
gateway_name = variables.get(gateway[1:])
else:
gateway_name = None

model_blueprints.append(
loader(variables={**get_variables(gateway_name), **variables}, **loader_kwargs)
)

return model_blueprints


def load_sql_based_models(
expressions: t.List[exp.Expression],
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
**loader_kwargs: t.Any,
) -> t.List[Model]:
gateway: t.Optional[exp.Expression] = None
blueprints: t.Optional[t.List[t.Optional[exp.Expression]]] = None

model_meta = seq_get(expressions, 0)
for prop in (isinstance(model_meta, d.Model) and model_meta.expressions) or []:
if prop.name == "gateway":
gateway = prop.args["value"]
elif prop.name == "blueprints":
# We pop the blueprints property here because it shouldn't be part of the model
blueprints = prop.pop().args["value"]

return create_model_blueprints(
gateway=gateway,
blueprints=blueprints,
get_variables=get_variables,
loader=partial(load_sql_based_model, expressions),
**loader_kwargs,
)


def load_sql_based_model(
expressions: t.List[exp.Expression],
*,
Expand Down
Loading

0 comments on commit de99e43

Please sign in to comment.