From 256a3790651e4c99d70928aaeee0c9f1ec73fc5e Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 31 May 2023 17:40:43 +0300 Subject: [PATCH 1/2] Feat(mysql): add support for the UNIQUE KEY constraint --- sqlglot/expressions.py | 2 +- sqlglot/generator.py | 4 +++- sqlglot/parser.py | 4 +++- tests/dialects/test_mysql.py | 14 ++++++++++---- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index db1fa5005a..b74e369a23 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1451,7 +1451,7 @@ class PrimaryKey(Expression): class Unique(Expression): - arg_types = {"expressions": True} + arg_types = {"this": False, "expressions": True} # https://www.postgresql.org/docs/9.1/sql-selectinto.html diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 1d5f5f3eb3..1515e5d1bc 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1772,8 +1772,10 @@ def primarykey_sql(self, expression: exp.ForeignKey) -> str: return f"PRIMARY KEY ({expressions}){options}" def unique_sql(self, expression: exp.Unique) -> str: + this = self.sql(expression, "this") + this = f"{this} " if this else "" columns = self.expressions(expression, key="expressions") - return f"UNIQUE ({columns})" + return f"UNIQUE {this}({columns})" def if_sql(self, expression: exp.If) -> str: return self.case_sql( diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 6afc11c01b..49c02b41cc 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -3371,9 +3371,11 @@ def _parse_unnamed_constraint( return self.CONSTRAINT_PARSERS[constraint](self) def _parse_unique(self) -> exp.Expression: + this = self._match_text_seq("KEY") and self._parse_id_var() if not self._match(TokenType.L_PAREN, advance=False): return self.expression(exp.UniqueColumnConstraint) - return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) + + return self.expression(exp.Unique, this=this, expressions=self._parse_wrapped_id_vars()) def _parse_key_constraint_options(self) -> t.List[str]: options = [] diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index be7887e8b3..a80153b552 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -6,6 +6,10 @@ class TestMySQL(Validator): dialect = "mysql" def test_ddl(self): + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" + ) + self.validate_all( "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", write={ @@ -21,10 +25,6 @@ def test_ddl(self): "mysql": "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC", }, ) - self.validate_identity( - "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" - ) - self.validate_all( "CREATE TABLE x (id int not null auto_increment, primary key (id))", write={ @@ -37,6 +37,12 @@ def test_ddl(self): "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)", }, ) + self.validate_all( + "CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))", + write={ + "mysql": "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))", + }, + ) def test_identity(self): self.validate_identity("SELECT CURRENT_TIMESTAMP(6)") From 86fa81ccca5a37a888aa78406039abaa8cd03a25 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 31 May 2023 18:57:43 +0300 Subject: [PATCH 2/2] Cleanup --- sqlglot/expressions.py | 6 +----- sqlglot/generator.py | 12 ++++-------- sqlglot/parser.py | 9 ++++----- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index b74e369a23..441cef5c89 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1285,7 +1285,7 @@ class TitleColumnConstraint(ColumnConstraintKind): class UniqueColumnConstraint(ColumnConstraintKind): - arg_types: t.Dict[str, t.Any] = {} + arg_types = {"this": False} class UppercaseColumnConstraint(ColumnConstraintKind): @@ -1450,10 +1450,6 @@ class PrimaryKey(Expression): arg_types = {"expressions": True, "options": False} -class Unique(Expression): - arg_types = {"this": False, "expressions": True} - - # https://www.postgresql.org/docs/9.1/sql-selectinto.html # https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples class Into(Expression): diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 1515e5d1bc..4661706494 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -629,8 +629,10 @@ def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstra return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" return f"PRIMARY KEY" - def uniquecolumnconstraint_sql(self, _) -> str: - return "UNIQUE" + def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + return f"UNIQUE{this}" def create_sql(self, expression: exp.Create) -> str: kind = self.sql(expression, "kind").upper() @@ -1771,12 +1773,6 @@ def primarykey_sql(self, expression: exp.ForeignKey) -> str: options = f" {options}" if options else "" return f"PRIMARY KEY ({expressions}){options}" - def unique_sql(self, expression: exp.Unique) -> str: - this = self.sql(expression, "this") - this = f"{this} " if this else "" - columns = self.expressions(expression, key="expressions") - return f"UNIQUE {this}({columns})" - def if_sql(self, expression: exp.If) -> str: return self.case_sql( exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 49c02b41cc..3c71b577ed 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -3371,11 +3371,10 @@ def _parse_unnamed_constraint( return self.CONSTRAINT_PARSERS[constraint](self) def _parse_unique(self) -> exp.Expression: - this = self._match_text_seq("KEY") and self._parse_id_var() - if not self._match(TokenType.L_PAREN, advance=False): - return self.expression(exp.UniqueColumnConstraint) - - return self.expression(exp.Unique, this=this, expressions=self._parse_wrapped_id_vars()) + self._match_text_seq("KEY") + return self.expression( + exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False)) + ) def _parse_key_constraint_options(self) -> t.List[str]: options = []