Skip to content

Commit

Permalink
fix dialect selector
Browse files Browse the repository at this point in the history
  • Loading branch information
nolanbconaway committed Sep 30, 2024
1 parent 2dfa0f1 commit 27a1379
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def inject_vars():
"""Inject arbitrary data into all templates."""
return dict(
all_rules=config.VALID_RULES,
all_dialects=config.VALID_DIALECTS,
all_dialects=list(config.VALID_DIALECTS.values()),
sqlfluff_version=config.SQLFLUFF_VERSION,
)

Expand Down
2 changes: 1 addition & 1 deletion src/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

SQLFLUFF_VERSION = sqlfluff.__version__

VALID_DIALECTS = tuple(d.name for d in sqlfluff.list_dialects())
VALID_DIALECTS = {d.label: d.name for d in sqlfluff.list_dialects()}

# dict mapping string rule names to descriptions
VALID_RULES = {r.code: r.description for r in sqlfluff.list_rules()}
9 changes: 9 additions & 0 deletions src/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from flask import Blueprint, redirect, render_template, request, url_for
from sqlfluff.api import fix, lint
from .config import VALID_DIALECTS

bp = Blueprint("routes", __name__)

Expand Down Expand Up @@ -43,7 +44,15 @@ def fluff_results():
sql = sql_decode(request.args["sql"]).strip()
sql = "\n".join(sql.splitlines()) + "\n"

# dialect must be a dialect label for `load_raw_dialect`. VALID_DIALECTS is a
# dictionary of dialect labels to dialect names. If we have a name, we need to
# get the label.
dialect = request.args["dialect"]
if dialect in VALID_DIALECTS.values():
dialect = next(
label for label, name in VALID_DIALECTS.items() if name == dialect
)

try:
linted = lint(sql, dialect=dialect)
fixed_sql = fix(sql, dialect=dialect)
Expand Down
11 changes: 8 additions & 3 deletions test/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,15 @@ def test_post_redirect(client):
assert rv.status_code == 302 and "/fluffed?sql" in rv.headers["location"]


def test_results_no_errors(client):
"""Test that the results is good to go when there is no error."""
@pytest.mark.parametrize("dialect", ["sparksql", "Apache Spark SQL"])
def test_results_no_errors(client, dialect):
"""Test that the results is good to go when there is no error.
Parameterized dialect asserts that either the formatted name or label can be used
as the dialect parameter.
"""
sql_encoded = sql_encode("select * from table")
rv = client.get("/fluffed", query_string=f"""dialect=ansi&sql={sql_encoded}""")
rv = client.get("/fluffed", query_string=f"""dialect={dialect}&sql={sql_encoded}""")
html = rv.data.decode().lower()
assert "sqlfluff online" in html
assert "fixed sql" in html
Expand Down

0 comments on commit 27a1379

Please sign in to comment.