Skip to content

Commit

Permalink
Fix(mysql): improve parsing of INSERT .. SELECT statement (#1871)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Jun 30, 2023
1 parent 59da40e commit 8a19d7a
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 11 deletions.
13 changes: 12 additions & 1 deletion sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,13 @@ class Tokenizer(tokens.Tokenizer):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}

class Parser(parser.Parser):
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
TokenType.DATABASE,
TokenType.SCHEMA,
TokenType.VALUES,
}

TABLE_ALIAS_TOKENS = (
parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
)
Expand All @@ -207,6 +213,10 @@ class Parser(parser.Parser):
this=self._parse_lambda(),
separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
),
# https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
"VALUES": lambda self: self.expression(
exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
),
}

STATEMENT_PARSERS = {
Expand Down Expand Up @@ -399,6 +409,7 @@ class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
JOIN_HINTS = False
TABLE_HINTS = True
DUPLICATE_KEY_UPDATE_WITH_SET = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,7 @@ class Insert(Expression):
"partition": False,
"alternative": False,
"where": False,
"ignore": False,
}

def with_(
Expand Down
10 changes: 8 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ class Generator:
# Whether or not comparing against booleans (e.g. x IS TRUE) is supported
IS_BOOL_ALLOWED = True

# Whether or not to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement
DUPLICATE_KEY_UPDATE_WITH_SET = True

# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")

Expand Down Expand Up @@ -1105,6 +1108,8 @@ def insert_sql(self, expression: exp.Insert) -> str:

alternative = expression.args.get("alternative")
alternative = f" OR {alternative}" if alternative else ""
ignore = " IGNORE" if expression.args.get("ignore") else ""

this = f"{this} {self.sql(expression, 'this')}"

exists = " IF EXISTS" if expression.args.get("exists") else ""
Expand All @@ -1116,7 +1121,7 @@ def insert_sql(self, expression: exp.Insert) -> str:
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
conflict = self.sql(expression, "conflict")
returning = self.sql(expression, "returning")
sql = f"INSERT{alternative}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}"
sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}"
return self.prepend_ctes(expression, sql)

def intersect_sql(self, expression: exp.Intersect) -> str:
Expand All @@ -1143,8 +1148,9 @@ def onconflict_sql(self, expression: exp.OnConflict) -> str:
do = "" if expression.args.get("duplicate") else " DO "
nothing = "NOTHING" if expression.args.get("nothing") else ""
expressions = self.expressions(expression, flat=True)
set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
if expressions:
expressions = f"UPDATE SET {expressions}"
expressions = f"UPDATE {set_keyword}{expressions}"
return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"

def returning_sql(self, expression: exp.Returning) -> str:
Expand Down
18 changes: 11 additions & 7 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1679,6 +1679,7 @@ def _parse_describe(self) -> exp.Describe:

def _parse_insert(self) -> exp.Insert:
overwrite = self._match(TokenType.OVERWRITE)
ignore = self._match(TokenType.IGNORE)
local = self._match_text_seq("LOCAL")
alternative = None

Expand Down Expand Up @@ -1709,6 +1710,7 @@ def _parse_insert(self) -> exp.Insert:
returning=self._parse_returning(),
overwrite=overwrite,
alternative=alternative,
ignore=ignore,
)

def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]:
Expand All @@ -1734,7 +1736,8 @@ def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]:
nothing = True
else:
self._match(TokenType.UPDATE)
expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
self._match(TokenType.SET)
expressions = self._parse_csv(self._parse_equality)

return self.expression(
exp.OnConflict,
Expand Down Expand Up @@ -1917,7 +1920,7 @@ def _parse_select(
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")

limit = self._parse_limit(top=True)
expressions = self._parse_csv(self._parse_expression)
expressions = self._parse_expressions()

this = self.expression(
exp.Select,
Expand Down Expand Up @@ -2091,9 +2094,7 @@ def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]:

partition = self._parse_partition_by()
order = self._parse_order()
measures = (
self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None
)
measures = self._parse_expressions() if self._match_text_seq("MEASURES") else None

if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
rows = exp.var("ONE ROW PER MATCH")
Expand Down Expand Up @@ -3174,7 +3175,7 @@ def _parse_primary(self) -> t.Optional[exp.Expression]:
if query:
expressions = [query]
else:
expressions = self._parse_csv(self._parse_expression)
expressions = self._parse_expressions()

this = self._parse_query_modifiers(seq_get(expressions, 0))

Expand Down Expand Up @@ -4226,7 +4227,7 @@ def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_expression)
return self._parse_csv(self._parse_expression)
return self._parse_expressions()

def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
Expand Down Expand Up @@ -4276,6 +4277,9 @@ def _parse_wrapped(self, parse_method: t.Callable, optional: bool = False) -> t.
self._match_r_paren()
return parse_result

def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]:
return self._parse_csv(self._parse_expression)

def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
return self._parse_select() or self._parse_set_operations(
self._parse_expression() if alias else self._parse_conjunction()
Expand Down
14 changes: 13 additions & 1 deletion tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,19 @@ def test_ddl(self):
self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10")
self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10")
self.validate_identity(
"INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1"
"INSERT IGNORE INTO subscribers (email) VALUES ('[email protected]'), ('[email protected]')"
)
self.validate_identity(
"INSERT INTO t1 (a, b, c) VALUES (1, 2, 3), (4, 5, 6) ON DUPLICATE KEY UPDATE c = VALUES(a) + VALUES(b)"
)
self.validate_identity(
"INSERT INTO t1 (a, b) SELECT c, d FROM t2 UNION SELECT e, f FROM t3 ON DUPLICATE KEY UPDATE b = b + c"
)
self.validate_identity(
"INSERT INTO t1 (a, b, c) VALUES (1, 2, 3) ON DUPLICATE KEY UPDATE c = c + 1"
)
self.validate_identity(
"INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE x.id = 1"
)

self.validate_all(
Expand Down

0 comments on commit 8a19d7a

Please sign in to comment.