Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Generate per-dialect options on init project #3733

Merged
merged 7 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions sqlmesh/cli/example_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from sqlmesh.integrations.dlt import generate_dlt_models_and_settings
from sqlmesh.utils.date import yesterday_ds

from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE


PRIMITIVES = (str, int, bool, float)


class ProjectTemplate(Enum):
AIRFLOW = "airflow"
Expand All @@ -23,30 +28,66 @@ def _gen_config(
start: t.Optional[str],
template: ProjectTemplate,
) -> str:
connection_settings = (
settings
or """ type: duckdb
if not settings:
connection_settings = """ type: duckdb
database: db.db"""
)

doc_link = "# Visit https://sqlmesh.readthedocs.io/en/stable/integrations/engines{engine_link} for more information on configuring the connection to your execution engine."
engine_link = ""

engine = "mssql" if dialect == "tsql" else dialect

if engine in CONNECTION_CONFIG_TO_TYPE:
required_fields = []
non_required_fields = []

for name, field in CONNECTION_CONFIG_TO_TYPE[engine].model_fields.items():
field_name = field.alias or name
default_value = field.get_default()

if isinstance(default_value, Enum):
default_value = default_value.value
elif not isinstance(default_value, PRIMITIVES):
default_value = None

required = field.is_required() or field_name == "type"
option_str = (
f" {'# ' if not required else ''}{field_name}: {default_value or ''}\n"
)

if required:
required_fields.append(option_str)
else:
non_required_fields.append(option_str)

connection_settings = "".join(required_fields + non_required_fields)

engine_link = f"/{engine}/#connection-options"

connection_settings = (
f" {doc_link.format(engine_link=engine_link)}\n{connection_settings}"
)
else:
connection_settings = settings

default_configs = {
ProjectTemplate.DEFAULT: f"""gateways:
local:
dev:
connection:
{connection_settings}

default_gateway: local
default_gateway: dev

model_defaults:
dialect: {dialect}
start: {start or yesterday_ds()}
""",
ProjectTemplate.AIRFLOW: f"""gateways:
local:
dev:
connection:
{connection_settings}
{connection_settings}

default_gateway: local
default_gateway: dev

default_scheduler:
type: airflow
Expand Down
27 changes: 25 additions & 2 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,12 +793,12 @@ def test_plan_dlt(runner, tmp_path):
init_example_project(tmp_path, "duckdb", ProjectTemplate.DLT, "sushi")

expected_config = f"""gateways:
local:
dev:
connection:
type: duckdb
database: {dataset_path}

default_gateway: local
default_gateway: dev

model_defaults:
dialect: duckdb
Expand Down Expand Up @@ -948,3 +948,26 @@ def test_plan_dlt(runner, tmp_path):
assert dlt_sushi_twice_nested_model_path.exists()
finally:
remove(dataset_path)


def test_init_project_dialects(runner, tmp_path):
    """Verify that `init_example_project` renders per-dialect connection options
    into config.yaml: required fields first (uncommented), then optional fields
    commented out with their defaults, for each supported engine.
    """
    # Local import keeps this edit self-contained; stdlib only.
    from datetime import date, timedelta

    # Expected connection-option blocks as generated from each engine's
    # connection config model fields (required fields first, the rest
    # commented out). These are exact-match fixtures.
    dialect_to_config = {
        "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: ",
        "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ",
        "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ",
        "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: \n # force_databricks_connect: \n # disable_databricks_connect: \n # disable_spark_session: ",
        "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: ",
    }

    # The generated config's `start:` is yesterday's date (the project writes
    # `start or yesterday_ds()`), so compute it dynamically rather than
    # hard-coding a literal date that only matches on one calendar day.
    expected_start = (date.today() - timedelta(days=1)).isoformat()

    for dialect, expected_config in dialect_to_config.items():
        init_example_project(tmp_path, dialect=dialect)

        config_start = f"gateways:\n dev:\n connection:\n # Visit https://sqlmesh.readthedocs.io/en/stable/integrations/engines/{dialect}/#connection-options for more information on configuring the connection to your execution engine.\n type: {dialect}\n "
        config_end = f"\n\n\ndefault_gateway: dev\n\nmodel_defaults:\n dialect: {dialect}\n start: {expected_start}\n"

        with open(tmp_path / "config.yaml") as file:
            config = file.read()

        assert config == f"{config_start}{expected_config}{config_end}"

        # Remove the generated config so the next dialect regenerates it
        # from scratch instead of init skipping an existing file.
        remove(tmp_path / "config.yaml")