diff --git a/src/app/__init__.py b/src/app/__init__.py index 24809fb..68cc911 100644 --- a/src/app/__init__.py +++ b/src/app/__init__.py @@ -38,7 +38,7 @@ def inject_vars(): """Inject arbitrary data into all templates.""" return dict( all_rules=config.VALID_RULES, - all_dialects=config.VALID_DIALECTS, + all_dialects=list(config.VALID_DIALECTS.values()), sqlfluff_version=config.SQLFLUFF_VERSION, ) diff --git a/src/app/config.py b/src/app/config.py index 65524da..2408f0e 100644 --- a/src/app/config.py +++ b/src/app/config.py @@ -4,7 +4,7 @@ SQLFLUFF_VERSION = sqlfluff.__version__ -VALID_DIALECTS = tuple(d.name for d in sqlfluff.list_dialects()) +VALID_DIALECTS = {d.label: d.name for d in sqlfluff.list_dialects()} # dict mapping string rule names to descriptions VALID_RULES = {r.code: r.description for r in sqlfluff.list_rules()} diff --git a/src/app/routes.py b/src/app/routes.py index 455c6c1..b65bf9f 100644 --- a/src/app/routes.py +++ b/src/app/routes.py @@ -3,6 +3,7 @@ from flask import Blueprint, redirect, render_template, request, url_for from sqlfluff.api import fix, lint +from .config import VALID_DIALECTS bp = Blueprint("routes", __name__) @@ -43,7 +44,15 @@ def fluff_results(): sql = sql_decode(request.args["sql"]).strip() sql = "\n".join(sql.splitlines()) + "\n" + # dialect must be a dialect label for `load_raw_dialect`. VALID_DIALECTS is a + # dictionary of dialect labels to dialect names. If we have a name, we need to + # get the label. dialect = request.args["dialect"] + if dialect in VALID_DIALECTS.values(): + dialect = next( + label for label, name in VALID_DIALECTS.items() if name == dialect + ) + try: linted = lint(sql, dialect=dialect) fixed_sql = fix(sql, dialect=dialect) diff --git a/test/test_app.py b/test/test_app.py index 2ff7340..3a7e72e 100644 --- a/test/test_app.py +++ b/test/test_app.py @@ -36,10 +36,15 @@ def test_post_redirect(client): assert rv.status_code == 302 and "/fluffed?sql" in rv.headers["location"] -def test_results_no_errors(client): - """Test that the results is good to go when there is no error.""" +@pytest.mark.parametrize("dialect", ["sparksql", "Apache Spark SQL"]) +def test_results_no_errors(client, dialect): + """Test that the results is good to go when there is no error. + + Parameterized dialect asserts that either the formatted name or label can be used + as the dialect parameter. + """ sql_encoded = sql_encode("select * from table") - rv = client.get("/fluffed", query_string=f"""dialect=ansi&sql={sql_encoded}""") + rv = client.get("/fluffed", query_string=f"""dialect={dialect}&sql={sql_encoded}""") html = rv.data.decode().lower() assert "sqlfluff online" in html assert "fixed sql" in html