From de99e43129d22e76f5af3a51e08797db61a1918a Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 18 Feb 2025 02:48:58 +0200 Subject: [PATCH] Refactor --- sqlmesh/core/dialect.py | 28 +------ sqlmesh/core/loader.py | 135 +++++++++---------------------- sqlmesh/core/model/__init__.py | 1 + sqlmesh/core/model/decorator.py | 14 ++++ sqlmesh/core/model/definition.py | 81 ++++++++++++++++++- tests/core/test_model.py | 11 ++- 6 files changed, 143 insertions(+), 127 deletions(-) diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 491dcb90f7..c5999755e0 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -548,8 +548,6 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: from sqlmesh.core.model.kind import ModelKindName expressions: t.List[exp.Expression] = [] - gateway: t.Optional[exp.Expression] = None - blueprints: t.List[exp.Expression] = [] while True: prev_property = seq_get(expressions, -1) @@ -613,36 +611,12 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: else: value = self._parse_bracket(self._parse_field(any_token=True)) - if key == "gateway": - gateway = value - elif key == "blueprints": - if isinstance(value, exp.Paren): - blueprints = [value.this] - elif isinstance(value, (exp.Tuple, exp.Array)): - blueprints = value.expressions - else: - raise ConfigError( - "The 'blueprints' values need to be enclosed in " - f"parentheses or brackets, got {value} instead." - ) - - # We don't want to include blueprints in the property list - continue - if isinstance(value, exp.Expression): value.meta["sql"] = self._find_sql(start, self._prev) expressions.append(self.expression(exp.Property, this=key, value=value)) - expression = self.expression(expression_type, expressions=expressions) - - # We store these properties in the meta to provide quick access at load time - if blueprints: - expression.meta["blueprints"] = blueprints - if gateway: - expression.meta["gateway"] = gateway - - return expression + return self.expression(expression_type, expressions=expressions) return parse diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index baca71276e..9e7efbb56a 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -10,14 +10,12 @@ from dataclasses import dataclass from pathlib import Path -from sqlglot import exp from sqlglot.errors import SqlglotError -from sqlglot.helper import seq_get from sqlmesh.core import constants as c from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit, load_multiple_audits -from sqlmesh.core.dialect import parse, Model as ModelMeta -from sqlmesh.core.macros import MacroRegistry, MacroVar, macro +from sqlmesh.core.dialect import parse +from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl from sqlmesh.core.model import ( Model, @@ -25,7 +23,7 @@ ModelCache, SeedModel, create_external_model, - load_sql_based_model, + load_sql_based_models, ) from sqlmesh.core.model import model as model_registry from sqlmesh.core.signal import signal @@ -294,7 +292,9 @@ def _track_file(self, path: Path) -> None: """Project file to track for modifications""" self._path_mtimes[path] = path.stat().st_mtime - def _get_variables(self, gateway_name: str) -> t.Dict[str, t.Any]: + def _get_variables(self, gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]: + gateway_name = gateway_name or self.context.selected_gateway + if gateway_name not in self._variables_by_gateway: try: gateway = self.config.get_gateway(gateway_name) @@ -411,56 +411,24 @@ def _load() -> t.List[Model]: except SqlglotError as ex: raise ConfigError(f"Failed to parse a model definition at '{path}': {ex}.") - model_meta = seq_get(expressions, 0) - blueprints = isinstance(model_meta, ModelMeta) and model_meta.meta.get("blueprints") - - loaded_models = [] - for blueprint in blueprints or [None]: - if not blueprint: - blueprint_variables = {} - elif isinstance(blueprint, (exp.Tuple, exp.Array)): - blueprint_variables = { - # Assumes the syntax: (k := v [, ...]) or [k := v [, ...]] - e.left.name: e.right.sql(dialect=self.config.model_defaults.dialect) - for e in blueprint.expressions - } - elif isinstance(blueprint, (exp.Column, exp.Literal)): - blueprint_variables = { - "blueprint": blueprint.sql(dialect=self.config.model_defaults.dialect) - } - else: - # TODO: update docs link in this error message - raise ConfigError( - f"Invalid value for 'blueprints' {blueprints} at '{path}', please refer to ." - ) - - gateway = t.cast(exp.Expression, model_meta).meta.get("gateway") - if not isinstance(gateway, MacroVar) or not ( - gateway_name := blueprint_variables.get(gateway.name) - ): - gateway_name = self.context.selected_gateway - - loaded_model = load_sql_based_model( - expressions, - defaults=self.config.model_defaults.dict(), - macros=macros, - jinja_macros=jinja_macros, - audit_definitions=audits, - default_audits=self.config.model_defaults.audits, - path=Path(path).absolute(), - module_path=self.config_path, - dialect=self.config.model_defaults.dialect, - time_column_format=self.config.time_column_format, - physical_schema_mapping=self.config.physical_schema_mapping, - project=self.config.project, - default_catalog=self.context.default_catalog, - variables={**self._get_variables(gateway_name), **blueprint_variables}, - infer_names=self.config.model_naming.infer_names, - signal_definitions=signals, - ) - loaded_models.append(loaded_model) - - return loaded_models + return load_sql_based_models( + expressions, + self._get_variables, + defaults=self.config.model_defaults.dict(), + macros=macros, + jinja_macros=jinja_macros, + audit_definitions=audits, + default_audits=self.config.model_defaults.audits, + path=Path(path).absolute(), + module_path=self.config_path, + dialect=self.config.model_defaults.dialect, + time_column_format=self.config.time_column_format, + physical_schema_mapping=self.config.physical_schema_mapping, + project=self.config.project, + default_catalog=self.context.default_catalog, + infer_names=self.config.model_naming.infer_names, + signal_definitions=signals, + ) for model in cache.get_or_load_models(path, _load): if model.enabled: @@ -500,44 +468,21 @@ def _load_python_models( new = registry.keys() - registered registered |= new for name in new: - registered_entrypoint = registry[name] - - gateway = registered_entrypoint.kwargs.get("gateway") or "" - blueprints = registered_entrypoint.kwargs.pop("blueprints", None) - - for blueprint in blueprints or [None]: - if not blueprint: - blueprint_variables = {} - elif isinstance(blueprint, dict): - blueprint_variables = blueprint - elif isinstance(blueprint, str): - blueprint_variables = {"blueprint": blueprint} - else: - # TODO: update docs link in this error message - raise ConfigError( - f"Invalid value for 'blueprints' {blueprints} at '{path}', please refer to ." - ) - - if not gateway.startswith("@") or not ( - gateway_name := blueprint_variables.get(gateway[1:]) - ): - gateway_name = self.context.selected_gateway - - model = registry[name].model( - path=path, - module_path=self.config_path, - defaults=self.config.model_defaults.dict(), - macros=macros, - jinja_macros=jinja_macros, - dialect=self.config.model_defaults.dialect, - time_column_format=self.config.time_column_format, - physical_schema_mapping=self.config.physical_schema_mapping, - project=self.config.project, - default_catalog=self.context.default_catalog, - variables={**self._get_variables(gateway_name), **blueprint_variables}, - infer_names=self.config.model_naming.infer_names, - audit_definitions=audits, - ) + for model in registry[name].models( + self._get_variables, + path=path, + module_path=self.config_path, + defaults=self.config.model_defaults.dict(), + macros=macros, + jinja_macros=jinja_macros, + dialect=self.config.model_defaults.dialect, + time_column_format=self.config.time_column_format, + physical_schema_mapping=self.config.physical_schema_mapping, + project=self.config.project, + default_catalog=self.context.default_catalog, + infer_names=self.config.model_naming.infer_names, + audit_definitions=audits, + ): if model.enabled: models[model.fqn] = model finally: @@ -588,7 +533,7 @@ def _load_audits( """Loads all the model audits.""" audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits") audits_max_mtime: t.Optional[float] = None - variables = self._get_variables(self.context.selected_gateway) + variables = self._get_variables() for path in self._glob_paths( self.config_path / c.AUDITS, diff --git a/sqlmesh/core/model/__init__.py b/sqlmesh/core/model/__init__.py index 301524425d..c2ab47d9e7 100644 --- a/sqlmesh/core/model/__init__.py +++ b/sqlmesh/core/model/__init__.py @@ -15,6 +15,7 @@ create_seed_model as create_seed_model, create_sql_model as create_sql_model, load_sql_based_model as load_sql_based_model, + load_sql_based_models as load_sql_based_models, ) from sqlmesh.core.model.kind import ( CustomKind as CustomKind, diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 267e493533..bfbbe723e8 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -16,6 +16,7 @@ Model, create_python_model, create_sql_model, + create_model_blueprints, get_model_name, render_meta_fields, ) @@ -84,6 +85,19 @@ def __call__( self.name = get_model_name(Path(inspect.getfile(func))) return super().__call__(func) + def models( + self, + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], + **loader_kwargs: t.Any, + ) -> t.List[Model]: + return create_model_blueprints( + gateway=self.kwargs.get("gateway"), + blueprints=self.kwargs.pop("blueprints", None), + get_variables=get_variables, + loader=self.model, + **loader_kwargs, + ) + def model( self, *, diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 56bb45ece2..958b2a3701 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -5,7 +5,7 @@ import types import re import typing as t -from functools import cached_property +from functools import cached_property, partial from pathlib import Path import pandas as pd @@ -13,6 +13,7 @@ from pydantic import Field from sqlglot import diff, exp from sqlglot.diff import Insert +from sqlglot.helper import seq_get from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlglot.optimizer.simplify import gen from sqlglot.optimizer.normalize_identifiers import normalize_identifiers @@ -1784,6 +1785,84 @@ class AuditResult(PydanticModel): blocking: bool = True +def create_model_blueprints( + gateway: t.Optional[str | exp.Expression], + blueprints: t.Any, + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], + loader: t.Callable[..., Model], + **loader_kwargs: t.Any, +) -> t.List[Model]: + path = loader_kwargs.get("path") + dialect = loader_kwargs.get("dialect") + + if not blueprints: + blueprints = [None] + elif isinstance(blueprints, exp.Paren): + blueprints = [blueprints.this] + elif isinstance(blueprints, (exp.Tuple, exp.Array)): + blueprints = blueprints.expressions + elif not isinstance(blueprints, list): + # TODO: put docs link here + raise ConfigError( + f"Invalid 'blueprints' property '{blueprints}' at '{path}', please refer to ." + ) + + model_blueprints: t.List[Model] = [] + for blueprint in blueprints: + if not blueprint: + variables = {} + elif isinstance(blueprint, exp.Paren): + blueprint = blueprint.unnest() + variables = {blueprint.left.name: blueprint.right.sql(dialect=dialect)} + elif isinstance(blueprint, (exp.Tuple, exp.Array)): + variables = {e.left.name: e.right.sql(dialect=dialect) for e in blueprint.expressions} + elif isinstance(blueprint, dict): + variables = blueprint + else: + # TODO: put docs link here + raise ConfigError( + f"Invalid blueprint value: {blueprint} at '{path}', please refer to ." + ) + + if isinstance(gateway, d.MacroVar): + gateway_name = variables.get(gateway.name) + elif isinstance(gateway, str) and gateway.startswith("@"): + gateway_name = variables.get(gateway[1:]) + else: + gateway_name = None + + model_blueprints.append( + loader(variables={**get_variables(gateway_name), **variables}, **loader_kwargs) + ) + + return model_blueprints + + +def load_sql_based_models( + expressions: t.List[exp.Expression], + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], + **loader_kwargs: t.Any, +) -> t.List[Model]: + gateway: t.Optional[exp.Expression] = None + blueprints: t.Optional[t.List[t.Optional[exp.Expression]]] = None + + model_meta = seq_get(expressions, 0) + for prop in (isinstance(model_meta, d.Model) and model_meta.expressions) or []: + if prop.name == "gateway": + gateway = prop.args["value"] + elif prop.name == "blueprints": + # We pop the blueprints property here because it shouldn't be part of the model + blueprints = prop.pop().args["value"] + + return create_model_blueprints( + gateway=gateway, + blueprints=blueprints, + get_variables=get_variables, + loader=partial(load_sql_based_model, expressions), + **loader_kwargs, + ) + + def load_sql_based_model( expressions: t.List[exp.Expression], *, diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 752bcb3801..82078f4a1c 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -7600,7 +7600,7 @@ def test_model_blueprinting(tmp_path: Path) -> None: MODEL ( name @{blueprint}.test_model_sql, gateway @blueprint, - blueprints (gw1, gw2), + blueprints ((blueprint := gw1), (blueprint := gw2)), kind FULL ); @@ -7619,7 +7619,7 @@ def test_model_blueprinting(tmp_path: Path) -> None: @model( "@{blueprint}.test_model_pydf", gateway="@blueprint", - blueprints=["gw1", "gw2"], + blueprints=[{"blueprint": "gw1"}, {"blueprint": "gw2"}], kind="FULL", columns={"x": "INT"}, ) @@ -7637,7 +7637,7 @@ def entrypoint(context, *args, **kwargs): @model( "@{blueprint}.test_model_pysql", gateway="@blueprint", - blueprints=["gw1", "gw2"], + blueprints=[{"blueprint": "gw1"}, {"blueprint": "gw2"}], kind="FULL", is_sql=True, ) @@ -7669,7 +7669,10 @@ def entrypoint(evaluator): """ MODEL ( name @{customer}.my_table, - blueprints [(customer=customer1, foo='bar'), (customer=customer2, foo=qux)], + blueprints ( + (customer := customer1, foo := 'bar'), + (customer := customer2, foo := qux), + ), kind FULL );