diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 045e0fbab5..122e937423 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -574,6 +574,10 @@ class Parser(parser.Parser): JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"} + PROCEDURE_OPTIONS = dict.fromkeys( + ("ENCRYPTION", "RECOMPILE", "SCHEMABINDING", "NATIVE_COMPILATION", "EXECUTE"), tuple() + ) + RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { TokenType.TABLE, *parser.Parser.TYPE_TOKENS, @@ -719,7 +723,11 @@ def _parse_user_defined_function( ): return this - expressions = self._parse_csv(self._parse_function_parameter) + if not self._match(TokenType.WITH, advance=False): + expressions = self._parse_csv(self._parse_function_parameter) + else: + expressions = None + return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) def _parse_id_var( diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 3d3dc42bb3..56ac65abf2 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2996,6 +2996,10 @@ class WithSystemVersioningProperty(Property): } +class WithProcedureOptions(Property): + arg_types = {"expressions": True} + + class Properties(Expression): arg_types = {"expressions": True} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index d50f992b30..4da80a13ff 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -201,6 +201,7 @@ class Generator(metaclass=_Generator): exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}", exp.VolatileProperty: lambda *_: "VOLATILE", exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", + exp.WithProcedureOptions: lambda self, e: f"WITH {self.expressions(e, flat=True)}", exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}", exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}", } @@ -565,6 +566,7 @@ class Generator(metaclass=_Generator): exp.VolatileProperty: exp.Properties.Location.POST_CREATE, exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, + exp.WithProcedureOptions: exp.Properties.Location.POST_SCHEMA, exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA, exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA, } @@ -3622,7 +3624,7 @@ def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: expressions = ( self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}" ) - return f"{this}{expressions}" + return f"{this}{expressions}" if expressions.strip() != "" else this def joinhint_sql(self, expression: exp.JoinHint) -> str: this = self.sql(expression, "this") diff --git a/sqlglot/parser.py b/sqlglot/parser.py index bbd6d1cce8..e93fe2c374 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1225,6 +1225,10 @@ class Parser(metaclass=_Parser): **dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()), } + PROCEDURE_OPTIONS: OPTIONS_TYPE = {} + + EXECUTE_AS_OPTIONS: OPTIONS_TYPE = dict.fromkeys(("CALLER", "SELF", "OWNER"), tuple()) + KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = { "NOT": ("ENFORCED",), "MATCH": ( @@ -2203,11 +2207,26 @@ def _parse_with_property(self) -> t.Optional[exp.Expression] | t.List[exp.Expres this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS), ) + if self._match_texts(self.PROCEDURE_OPTIONS, advance=False): + return self.expression( + exp.WithProcedureOptions, expressions=self._parse_csv(self._parse_procedure_option) + ) + if not self._next: return None return self._parse_withisolatedloading() + def _parse_procedure_option(self) -> exp.Expression | None: + if self._match_text_seq("EXECUTE", "AS"): + return self.expression( + exp.ExecuteAsProperty, + this=self._parse_var_from_options(self.EXECUTE_AS_OPTIONS, raise_unmatched=False) + or self._parse_string(), + ) + + return self._parse_var_from_options(self.PROCEDURE_OPTIONS) + # https://dev.mysql.com/doc/refman/8.0/en/create-view.html def _parse_definer(self) -> t.Optional[exp.DefinerProperty]: self._match(TokenType.EQ) @@ -5550,12 +5569,15 @@ def _parse_not_constraint(self) -> t.Optional[exp.Expression]: return None def _parse_column_constraint(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.CONSTRAINT): - this = self._parse_id_var() - else: - this = None + this = self._match(TokenType.CONSTRAINT) and self._parse_id_var() + + procedure_option_follows = ( + self._match(TokenType.WITH, advance=False) + and self._next + and self._next.text.upper() in self.PROCEDURE_OPTIONS + ) - if self._match_texts(self.CONSTRAINT_PARSERS): + if not procedure_option_follows and self._match_texts(self.CONSTRAINT_PARSERS): return self.expression( exp.ColumnConstraint, this=this, diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index e5b5b1867e..042891a4d4 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1001,6 +1001,17 @@ def test_udf(self): ) self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz") self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz") + + self.validate_identity("CREATE PROCEDURE foo WITH ENCRYPTION AS SELECT 1") + self.validate_identity("CREATE PROCEDURE foo WITH RECOMPILE AS SELECT 1") + self.validate_identity("CREATE PROCEDURE foo WITH SCHEMABINDING AS SELECT 1") + self.validate_identity("CREATE PROCEDURE foo WITH NATIVE_COMPILATION AS SELECT 1") + self.validate_identity("CREATE PROCEDURE foo WITH EXECUTE AS OWNER AS SELECT 1") + self.validate_identity("CREATE PROCEDURE foo WITH EXECUTE AS 'username' AS SELECT 1") + self.validate_identity( + "CREATE PROCEDURE foo WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION AS SELECT 1" + ) + self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1") self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER") @@ -1059,6 +1070,7 @@ def test_fullproc(self): CREATE procedure [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER ,@NumberOfRows INTEGER + WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION AS BEGIN SET XACT_ABORT ON; @@ -1074,7 +1086,7 @@ def test_fullproc(self): """ expected_sqls = [ - "CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON", + "CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION AS BEGIN SET XACT_ABORT ON", "DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)", "DECLARE @DWH_DateModified AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)", "DECLARE @DWH_IdUserCreated AS INTEGER = SUSER_ID(CURRENT_USER())",