Skip to content

Commit

Permalink
feat(tsql): Support for stored procedure options (#4260)
Browse files Browse the repository at this point in the history
* added support for tsql procedure options

* fixed return type

* ran style

* adressed comment refactorings

* modified _parse_column_constraint method in Parser class

* removed match and retreat in _parse_column_constraint

* added check if next
  • Loading branch information
rsanchez-xtillion authored Oct 18, 2024
1 parent 48be3d8 commit 04dccf3
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 8 deletions.
10 changes: 9 additions & 1 deletion sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,10 @@ class Parser(parser.Parser):

JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"}

PROCEDURE_OPTIONS = dict.fromkeys(
("ENCRYPTION", "RECOMPILE", "SCHEMABINDING", "NATIVE_COMPILATION", "EXECUTE"), tuple()
)

RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE,
*parser.Parser.TYPE_TOKENS,
Expand Down Expand Up @@ -719,7 +723,11 @@ def _parse_user_defined_function(
):
return this

expressions = self._parse_csv(self._parse_function_parameter)
if not self._match(TokenType.WITH, advance=False):
expressions = self._parse_csv(self._parse_function_parameter)
else:
expressions = None

return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)

def _parse_id_var(
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2996,6 +2996,10 @@ class WithSystemVersioningProperty(Property):
}


class WithProcedureOptions(Property):
arg_types = {"expressions": True}


class Properties(Expression):
arg_types = {"expressions": True}

Expand Down
4 changes: 3 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class Generator(metaclass=_Generator):
exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
exp.VolatileProperty: lambda *_: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.WithProcedureOptions: lambda self, e: f"WITH {self.expressions(e, flat=True)}",
exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}",
exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
}
Expand Down Expand Up @@ -565,6 +566,7 @@ class Generator(metaclass=_Generator):
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
exp.WithProcedureOptions: exp.Properties.Location.POST_SCHEMA,
exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA,
exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
}
Expand Down Expand Up @@ -3622,7 +3624,7 @@ def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
expressions = (
self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
)
return f"{this}{expressions}"
return f"{this}{expressions}" if expressions.strip() != "" else this

def joinhint_sql(self, expression: exp.JoinHint) -> str:
this = self.sql(expression, "this")
Expand Down
32 changes: 27 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,10 @@ class Parser(metaclass=_Parser):
**dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()),
}

PROCEDURE_OPTIONS: OPTIONS_TYPE = {}

EXECUTE_AS_OPTIONS: OPTIONS_TYPE = dict.fromkeys(("CALLER", "SELF", "OWNER"), tuple())

KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = {
"NOT": ("ENFORCED",),
"MATCH": (
Expand Down Expand Up @@ -2203,11 +2207,26 @@ def _parse_with_property(self) -> t.Optional[exp.Expression] | t.List[exp.Expres
this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS),
)

if self._match_texts(self.PROCEDURE_OPTIONS, advance=False):
return self.expression(
exp.WithProcedureOptions, expressions=self._parse_csv(self._parse_procedure_option)
)

if not self._next:
return None

return self._parse_withisolatedloading()

def _parse_procedure_option(self) -> exp.Expression | None:
if self._match_text_seq("EXECUTE", "AS"):
return self.expression(
exp.ExecuteAsProperty,
this=self._parse_var_from_options(self.EXECUTE_AS_OPTIONS, raise_unmatched=False)
or self._parse_string(),
)

return self._parse_var_from_options(self.PROCEDURE_OPTIONS)

# https://dev.mysql.com/doc/refman/8.0/en/create-view.html
def _parse_definer(self) -> t.Optional[exp.DefinerProperty]:
self._match(TokenType.EQ)
Expand Down Expand Up @@ -5550,12 +5569,15 @@ def _parse_not_constraint(self) -> t.Optional[exp.Expression]:
return None

def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
else:
this = None
this = self._match(TokenType.CONSTRAINT) and self._parse_id_var()

procedure_option_follows = (
self._match(TokenType.WITH, advance=False)
and self._next
and self._next.text.upper() in self.PROCEDURE_OPTIONS
)

if self._match_texts(self.CONSTRAINT_PARSERS):
if not procedure_option_follows and self._match_texts(self.CONSTRAINT_PARSERS):
return self.expression(
exp.ColumnConstraint,
this=this,
Expand Down
14 changes: 13 additions & 1 deletion tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,17 @@ def test_udf(self):
)
self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz")
self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz")

self.validate_identity("CREATE PROCEDURE foo WITH ENCRYPTION AS SELECT 1")
self.validate_identity("CREATE PROCEDURE foo WITH RECOMPILE AS SELECT 1")
self.validate_identity("CREATE PROCEDURE foo WITH SCHEMABINDING AS SELECT 1")
self.validate_identity("CREATE PROCEDURE foo WITH NATIVE_COMPILATION AS SELECT 1")
self.validate_identity("CREATE PROCEDURE foo WITH EXECUTE AS OWNER AS SELECT 1")
self.validate_identity("CREATE PROCEDURE foo WITH EXECUTE AS 'username' AS SELECT 1")
self.validate_identity(
"CREATE PROCEDURE foo WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION AS SELECT 1"
)

self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1")
self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER")

Expand Down Expand Up @@ -1059,6 +1070,7 @@ def test_fullproc(self):
CREATE procedure [TRANSF].[SP_Merge_Sales_Real]
@Loadid INTEGER
,@NumberOfRows INTEGER
WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION
AS
BEGIN
SET XACT_ABORT ON;
Expand All @@ -1074,7 +1086,7 @@ def test_fullproc(self):
"""

expected_sqls = [
"CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON",
"CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION AS BEGIN SET XACT_ABORT ON",
"DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
"DECLARE @DWH_DateModified AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
"DECLARE @DWH_IdUserCreated AS INTEGER = SUSER_ID(CURRENT_USER())",
Expand Down

0 comments on commit 04dccf3

Please sign in to comment.