Commit 9318697

feat: support create function, close #282

tconbeer committed Oct 31, 2022
1 parent b1520ed commit 9318697

Showing 19 changed files with 646 additions and 65 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,7 +8,9 @@ All notable changes to this project will be documented in this file.

- sqlfmt now supports `delete` statements and the associated keywords `using` and `returning` ([#281](https://github.com/tconbeer/sqlfmt/issues/281))
- sqlfmt now supports `grant` and `revoke` statements and all associated keywords ([#283](https://github.com/tconbeer/sqlfmt/issues/283))
- sqlfmt now supports `create function` statements and all associated keywords ([#282](https://github.com/tconbeer/sqlfmt/issues/282))
- sqlfmt now supports the `explain` keyword ([#280](https://github.com/tconbeer/sqlfmt/issues/280))
- sqlfmt now supports BigQuery typed table and struct definitions and literals, like `table<a int64, b bytes(5), c string>`

### Features

2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@ sqlfmt is not configurable, except for line length. It enforces a single style.

sqlfmt is not a linter. It does not parse your code into an AST; it just lexes it and tracks a small subset of tokens that impact formatting. This lets us "do one thing and do it well:" sqlfmt is very fast, and easier to maintain and extend than linters that need a full SQL grammar.

For now, sqlfmt only works on `select` statements (which is all you need if you use sqlfmt with a dbt project). In the future, it will be extended to DDL statements, as well.
For now, sqlfmt only works on `select`, `delete`, `grant`, `revoke`, and `create function` statements (which is all you need if you use sqlfmt with a dbt project). It is being extended to additional DDL and DML. Visit [this tracking issue](https://github.com/tconbeer/sqlfmt/issues/262) for more information.

## Documentation

35 changes: 33 additions & 2 deletions src/sqlfmt/actions.py
@@ -174,6 +174,38 @@ def handle_semicolon(
)


def handle_ddl_as(
analyzer: "Analyzer",
source_string: str,
match: re.Match,
) -> None:
"""
When we hit "as" in a create function or table statement,
the following syntax should be parsed using the main (select) rules,
unless the next token is a quoted name.
"""
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
match=match,
token_type=TokenType.UNTERM_KEYWORD,
)

quoted_name_rule = analyzer.get_rule("quoted_name")
comment_rule = analyzer.get_rule("comment")

quoted_name_pattern = rf"({comment_rule.pattern}|\s)*" + quoted_name_rule.pattern
quoted_name_match = re.match(
quoted_name_pattern, source_string[analyzer.pos :], re.IGNORECASE | re.DOTALL
)

if not quoted_name_match:
assert (
analyzer.rule_stack
), "Internal Error! Open an issue. Could not parse DDL 'AS'"
analyzer.pop_rules()


def handle_set_operator(
analyzer: "Analyzer", source_string: str, match: re.Match
) -> None:
@@ -268,8 +300,7 @@ def lex_ruleset(
"""
Makes a nested call to analyzer.lex, with the new ruleset activated.
"""
rules = sorted(new_ruleset, key=lambda rule: rule.priority)
analyzer.push_rules(rules)
analyzer.push_rules(new_ruleset)
try:
analyzer.lex(source_string)
except stop_exception:
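
A minimal sketch of the lookahead performed by handle_ddl_as above, using simplified stand-ins for the comment and quoted_name rule patterns (both regexes here are assumptions for illustration; the real ones come from analyzer.get_rule):

import re

COMMENT = r"(--[^\n]*|/\*.*?\*/)"   # stand-in for the "comment" rule
QUOTED_NAME = r"\$\w*\$.*?\$\w*\$"  # stand-in: pg dollar-quoted string

# After emitting "as" as an UNTERM_KEYWORD, handle_ddl_as peeks ahead: if the
# next token (past comments and whitespace) is a quoted name, the function body
# is opaque and the current ruleset stays active; otherwise the rules are
# popped so the body is lexed with the main (select) rules.
lookahead = rf"({COMMENT}|\s)*" + QUOTED_NAME

print(bool(re.match(lookahead, " $$ select 1 $$", re.DOTALL)))     # True: keep DDL rules
print(bool(re.match(lookahead, " select 1 from tbl", re.DOTALL)))  # False: pop back to MAIN
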
33 changes: 21 additions & 12 deletions src/sqlfmt/analyzer.py
@@ -100,7 +100,7 @@ def parse_query(self, source_string: str) -> Query:

def push_rules(self, new_rules: List[Rule]) -> None:
self.rule_stack.append(self.rules.copy())
self.rules = new_rules
self.rules = sorted(new_rules, key=lambda rule: rule.priority)

def pop_rules(self) -> List[Rule]:
old_rules = self.rules
@@ -117,6 +117,25 @@ def get_rule(self, rule_name: str) -> Rule:
except StopIteration:
raise ValueError(f"No rule '{rule_name}'")

def lex_one(self, source_string: str) -> None:
"""
Repeatedly match Rules to the source_string (at self.pos)
and apply the matched action.
Mutates the analyzer's buffers and pos
"""
for rule in self.rules:
match = rule.program.match(source_string, self.pos)
if match:
rule.action(self, source_string, match)
return
# nothing matched. Either whitespace or an error
else:
raise SqlfmtParsingError(
f"Could not parse SQL at position {self.pos}:"
f" '{source_string[self.pos:self.pos+50].strip()}'"
)

def lex(self, source_string: str, eof_pos: int = -1) -> None:
"""
Repeatedly match Rules to the source_string (until the source_string is
@@ -133,17 +152,7 @@ def lex(self, source_string: str, eof_pos: int = -1) -> None:
last_loop_pos = -1
while self.pos < eof_pos and self.pos > last_loop_pos:
last_loop_pos = self.pos
for rule in self.rules:
match = rule.program.match(source_string, self.pos)
if match:
rule.action(self, source_string, match)
break
# nothing matched. Either whitespace or an error
else:
raise SqlfmtParsingError(
f"Could not parse SQL at position {self.pos}:"
f" '{source_string[self.pos:self.pos+50].strip()}'"
)
self.lex_one(source_string)

def search_for_terminating_token(
self,
6 changes: 3 additions & 3 deletions src/sqlfmt/line.py
@@ -204,9 +204,9 @@ def starts_with_jinja_statement(self) -> bool:
return False

@property
def starts_with_square_bracket_operator(self) -> bool:
def starts_with_bracket_operator(self) -> bool:
try:
return self.nodes[0].is_square_bracket_operator
return self.nodes[0].is_bracket_operator
except IndexError:
return False

@@ -237,7 +237,7 @@ def is_standalone_jinja_statement(self) -> bool:
@property
def is_standalone_operator(self) -> bool:
return self._is_standalone_if(
self.starts_with_operator and not self.starts_with_square_bracket_operator
self.starts_with_operator and not self.starts_with_bracket_operator
)

@property
21 changes: 16 additions & 5 deletions src/sqlfmt/node.py
@@ -133,23 +133,34 @@ def is_opening_bracket(self) -> bool:
)

@property
def is_square_bracket_operator(self) -> bool:
def is_bracket_operator(self) -> bool:
"""
Node is an opening square bracket ("[")
that follows a token that could be a name
that follows a token that could be a name.
Alternatively, node is an open paren ("(")
that follows a closing angle bracket.
"""
if self.token.type != TokenType.BRACKET_OPEN or self.value != "[":
if self.token.type != TokenType.BRACKET_OPEN:
return False

prev_token, _ = get_previous_token(self.previous_node)
if not prev_token:
return False
else:
elif self.value == "[":
return prev_token.type in (
TokenType.NAME,
TokenType.QUOTED_NAME,
TokenType.BRACKET_CLOSE,
)
# BQ struct literals have parens that follow closing angle
# brackets
elif self.value == "(":
return (
prev_token.type == TokenType.BRACKET_CLOSE and ">" in prev_token.token
)
else:
return False

@property
def is_closing_bracket(self) -> bool:
@@ -205,7 +216,7 @@ def is_operator(self) -> bool:
TokenType.SEMICOLON,
)
or self.is_multiplication_star
or self.is_square_bracket_operator
or self.is_bracket_operator
)

@property
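
As a stand-alone illustration of the renamed is_bracket_operator check above, here is a condensed restatement with plain strings standing in for Token objects (the simplified signature and type names are assumptions for illustration only):

def is_bracket_operator(value: str, prev_type: str, prev_text: str) -> bool:
    # "[" is an operator when it indexes a name, quoted name, or closed bracket
    if value == "[":
        return prev_type in ("NAME", "QUOTED_NAME", "BRACKET_CLOSE")
    # "(" is an operator when it follows a closing angle bracket, as in the
    # BigQuery struct literal struct<int32, int64>(1, 2)
    if value == "(":
        return prev_type == "BRACKET_CLOSE" and ">" in prev_text
    return False

print(is_bracket_operator("[", "NAME", "my_array"))     # True
print(is_bracket_operator("(", "BRACKET_CLOSE", ">"))   # True
print(is_bracket_operator("(", "NAME", "sum"))          # False
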
8 changes: 8 additions & 0 deletions src/sqlfmt/node_manager.py
@@ -96,6 +96,9 @@ def raise_on_mismatched_bracket(self, token: Token, last_bracket: Node) -> None:
"(": ")",
"[": "]",
"case": "end",
"array<": ">",
"table<": ">",
"struct<": ">",
}
if (
last_bracket.token.type
@@ -155,6 +158,10 @@ def whitespace(
and previous_token.type == TokenType.COLON
):
return NO_SPACE
# open brackets that contain `<` are bq type definitions
# like `array<` in `array<int64>` and require a space
elif token.type == TokenType.BRACKET_OPEN and "<" in token.token:
return SPACE
# open brackets that follow names are function calls or array indexes.
# open brackets that follow closing brackets are array indexes.
# open brackets that follow open brackets are just nested brackets.
@@ -220,6 +227,7 @@ def standardize_value(self, token: Token) -> str:
"""
if token.type in (
TokenType.UNTERM_KEYWORD,
TokenType.BRACKET_OPEN,
TokenType.STATEMENT_START,
TokenType.STATEMENT_END,
TokenType.WORD_OPERATOR,
106 changes: 98 additions & 8 deletions src/sqlfmt/rule.py
@@ -48,7 +48,7 @@ def __post_init__(self) -> None:
r'(rb?|b|br|u&|@)?"([^"\\]*(\\.[^"\\]*|""[^"\\]*)*)"',
# possibly escaped single quotes
r"(rb?|b|br|u&|x)?'([^'\\]*(\\.[^'\\]*|''[^'\\]*)*)'",
r"\$\w*\$[^$]*?\$\w*\$", # pg dollar-delimited strings
r"\$\w*\$.*?\$\w*\$", # pg dollar-delimited strings
# possibly escaped backtick
r"`([^`\\]*(\\.[^`\\]*)*)`",
)
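
A quick check of what the relaxed dollar-quoted string pattern buys; the motivation stated here (function bodies that contain their own dollar signs, such as positional parameters) is an inference from the change, not text from the commit:

import re

OLD = re.compile(r"\$\w*\$[^$]*?\$\w*\$")
NEW = re.compile(r"\$\w*\$.*?\$\w*\$")

body = "$$ select $1 + 1 $$"
print(bool(OLD.match(body)))  # False: "[^$]*?" cannot cross the "$" in "$1"
print(bool(NEW.match(body)))  # True: ".*?" spans the inner dollar sign
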
@@ -325,6 +325,8 @@ def __post_init__(self) -> None:
r"\[",
r"\(",
r"\{",
# bq uses angle brackets for type definitions for compound types
r"(array|table|struct)\s*<",
),
action=partial(actions.add_node_to_buffer, token_type=TokenType.BRACKET_OPEN),
),
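
For reference, the new angle-bracket alternative in the bracket_open group can be exercised on its own with plain re (re-created inline here rather than imported from sqlfmt):

import re

BQ_TYPE_OPEN = re.compile(r"(array|table|struct)\s*<", re.IGNORECASE)

print(bool(BQ_TYPE_OPEN.match("array<int64>")))      # True: "array<" is lexed as BRACKET_OPEN
print(bool(BQ_TYPE_OPEN.match("STRUCT <a int64>")))  # True: whitespace before "<" is allowed
print(bool(BQ_TYPE_OPEN.match("a < b")))             # False: a bare "<" stays an operator
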
@@ -379,18 +381,22 @@ def __post_init__(self) -> None:
r"[*+?]?\?", # regex greedy/non-greedy, also ?
r"!!", # negate text match
r"%%", # psycopg escaped mod operator
r"[+\-*/%&|^=<>:#!]=?", # singles
r">=", # gte
r"[+\-*/%&|^=<:#!]=?", # singles
),
action=partial(actions.add_node_to_buffer, token_type=TokenType.OPERATOR),
),
Rule(
name="bq_typed_array",
priority=900,
name="angle_bracket_close",
priority=810,
pattern=group(
r"array<\w+>",
)
+ group(r"\[", r"$"),
action=partial(actions.add_node_to_buffer, token_type=TokenType.NAME),
r">",
),
action=partial(
actions.safe_add_node_to_buffer,
token_type=TokenType.BRACKET_CLOSE,
fallback_token_type=TokenType.OPERATOR,
),
),
Rule(
name="name",
@@ -427,6 +433,71 @@ def __post_init__(self) -> None:
),
]

CREATE_FUNCTION = [
*CORE,
Rule(
name="function_as",
priority=1100,
pattern=group(
r"as",
)
+ group(r"\W", r"$"),
action=actions.handle_ddl_as,
),
Rule(
name="word_operator",
priority=1100,
pattern=group(
r"to",
r"from",
# snowflake
r"runtime_version",
)
+ group(r"\W", r"$"),
action=partial(actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR),
),
Rule(
name="unterm_keyword",
priority=1300,
pattern=group(
(
r"create(\s+or\s+replace)?(\s+temp(orary)?)?(\s+secure)?(\s+table)?"
r"\s+function(\s+if\s+not\s+exists)?"
),
r"language",
r"transform",
r"immutable",
r"stable",
r"volatile",
r"(not\s+)?leakproof",
r"volatile",
r"called\s+on\s+null\s+input",
r"returns\s+null\s+on\s+null\s+input",
r"return(s)?(?!\s+null)",
r"strict",
r"(external\s+)?security\s+(invoker|definer)",
r"parallel\s+(unsafe|restricted|safe)",
r"cost",
r"rows",
r"support",
r"set",
r"as",
# snowflake
r"comment",
r"imports",
r"packages",
r"handler",
r"target_path",
r"(not\s+)?null",
# bq
r"options",
r"remote\s+with\s+connection",
)
+ group(r"\W", r"$"),
action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
),
]

MAIN = [
*CORE,
Rule(
@@ -608,6 +679,25 @@ def __post_init__(self) -> None:
),
),
),
Rule(
name="create_function",
priority=2020,
pattern=group(
(
r"create(\s+or\s+replace)?(\s+temp(orary)?)?(\s+secure)?(\s+table)?"
r"\s+function(\s+if\s+not\s+exists)?"
),
)
+ group(r"\W", r"$"),
action=partial(
actions.handle_nonreserved_keyword,
action=partial(
actions.lex_ruleset,
new_ruleset=CREATE_FUNCTION,
stop_exception=StopRulesetLexing,
),
),
),
Rule(
name="unsupported_ddl",
priority=2999,
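
The same create-function keyword pattern appears twice above: once as the unterm_keyword in the CREATE_FUNCTION ruleset and once as the create_function dispatch rule in MAIN. A stand-alone check of the variants it accepts (re-created inline with plain re; the trailing group is simplified):

import re

CREATE_FUNCTION_KW = re.compile(
    r"create(\s+or\s+replace)?(\s+temp(orary)?)?(\s+secure)?(\s+table)?"
    r"\s+function(\s+if\s+not\s+exists)?"
    r"(\W|$)",
    re.IGNORECASE,
)

for stmt in (
    "create function foo()",
    "CREATE OR REPLACE TEMP FUNCTION foo()",
    "create secure function foo()",  # snowflake secure functions
    "create table function foo()",   # bq table functions
    "create table foo (a int)",      # plain create table: not matched by this rule
):
    print(bool(CREATE_FUNCTION_KW.match(stmt)), stmt)
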
26 changes: 26 additions & 0 deletions tests/data/unformatted/124_bq_compound_types.sql
@@ -0,0 +1,26 @@
SELECT ARRAY<FLOAT64>[1, 2, 3] as floats,
STRUCT("Nathan" as name, ARRAY<FLOAT64>[] as laps),
STRUCT<INT32, INT64>(1, 2),
array<INT64>[3, 4, 5, 6, 1000000, 20000000, 30000000, 409000000, 5000000, 60000000, 700000] as ints,
array<string>['1', '2', '3'] as strings;
create function foo(bar struct<string, array<bytes(5)>, int64>)
returns struct<
string,
array<
bytes(
5
)>, int64>
as bar
)))))__SQLFMT_OUTPUT__(((((
select
array<float64>[1, 2, 3] as floats,
struct("Nathan" as name, array<float64>[] as laps),
struct<int32, int64>(1, 2),
array<int64>[
3, 4, 5, 6, 1000000, 20000000, 30000000, 409000000, 5000000, 60000000, 700000
] as ints,
array<string>['1', '2', '3'] as strings
;
create function foo(bar struct<string, array<bytes(5)>, int64>)
returns struct<string, array<bytes(5)>, int64>
as bar