diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 00000000..d090dbab --- /dev/null +++ b/.tool-versions @@ -0,0 +1 @@ +rust 1.73.0 \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index e9ec3cc7..540b55dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,29 @@ Given that the parser produces a typed AST, any changes to the AST will technica ## [Unreleased] Check https://github.com/sqlparser-rs/sqlparser-rs/commits/main for undocumented changes. + +## [0.40.0] 2023-11-27 + +### Added +* Add `{pre,post}_visit_query` to `Visitor` (#1044) - Thanks @jmhain +* Support generated virtual columns with expression (#1051) - Thanks @takluyver +* Support PostgreSQL `END` (#1035) - Thanks @tobyhede +* Support `INSERT INTO ... DEFAULT VALUES ...` (#1036) - Thanks @CDThomas +* Support `RELEASE` and `ROLLBACK TO SAVEPOINT` (#1045) - Thanks @CDThomas +* Support `CONVERT` expressions (#1048) - Thanks @lovasoa +* Support `GLOBAL` and `SESSION` parts in `SHOW VARIABLES` for mysql and generic - Thanks @emin100 +* Support snowflake `PIVOT` on derived table factors (#1027) - Thanks @lustefaniak +* Support mssql json and xml extensions (#1043) - Thanks @lovasoa +* Support for `MAX` as a character length (#1038) - Thanks @lovasoa +* Support `IN ()` syntax of SQLite (#1028) - Thanks @alamb + +### Fixed +* Fix extra whitespace printed before `ON CONFLICT` (#1037) - Thanks @CDThomas + +### Changed +* Document round trip ability (#1052) - Thanks @alamb +* Add PRQL to list of users (#1031) - Thanks @vanillajonathan + ## [0.39.0] 2023-10-27 ### Added diff --git a/Cargo.toml b/Cargo.toml index 0af32759..f709a93e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sqlparser" description = "Extensible SQL Lexer and Parser with support for ANSI SQL:2011" -version = "0.39.0" +version = "0.40.0" authors = ["Andy Grove "] homepage = "https://github.com/sqlparser-rs/sqlparser-rs" documentation = 
"https://docs.rs/sqlparser/" @@ -34,7 +34,7 @@ serde = { version = "1.0", features = ["derive"], optional = true } # of dev-dependencies because of # https://github.com/rust-lang/cargo/issues/1596 serde_json = { version = "1.0", optional = true } -sqlparser_derive = { version = "0.1.1", path = "derive", optional = true } +sqlparser_derive = { version = "0.2.0", path = "derive", optional = true } [dev-dependencies] simple_logger = "4.0" diff --git a/README.md b/README.md index 58f5b8d4..6a551d0f 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,28 @@ This crate avoids semantic analysis because it varies drastically between dialects and implementations. If you want to do semantic analysis, feel free to use this project as a base. +## Preserves Syntax Round Trip + +This crate allows users to recover the original SQL text (with normalized +whitespace and keyword capitalization), which is useful for tools that +analyze and manipulate SQL. + +This means that other than whitespace and the capitalization of keywords, the +following should hold true for all SQL: + +```rust +// Parse SQL +let ast = Parser::parse_sql(&GenericDialect, sql).unwrap(); + +// The original SQL text can be generated from the AST +assert_eq!(ast[0].to_string(), sql); +``` + +There are still some cases in this crate where different SQL with seemingly +similar semantics are represented with the same AST. We welcome PRs to fix such +issues and distinguish different syntaxes in the AST. + + ## SQL compliance SQL was first standardized in 1987, and revisions of the standard have been @@ -93,7 +115,7 @@ $ cargo run --features json_example --example cli FILENAME.sql [--dialectname] ## Users This parser is currently being used by the [DataFusion] query engine, -[LocustDB], [Ballista], [GlueSQL], [Opteryx], and [JumpWire]. +[LocustDB], [Ballista], [GlueSQL], [Opteryx], [PRQL], and [JumpWire]. If your project is using sqlparser-rs feel free to make a PR to add it to this list. 
@@ -188,6 +210,7 @@ licensed as above, without any additional terms or conditions. [Ballista]: https://github.com/apache/arrow-ballista [GlueSQL]: https://github.com/gluesql/gluesql [Opteryx]: https://github.com/mabel-dev/opteryx +[PRQL]: https://github.com/PRQL/prql [JumpWire]: https://github.com/extragoodlabs/jumpwire [Pratt Parser]: https://tdop.github.io/ [sql-2016-grammar]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 58e1fbf5..43f75a5e 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sqlparser_derive" description = "proc macro for sqlparser" -version = "0.1.1" +version = "0.2.0" authors = ["sqlparser-rs authors"] homepage = "https://github.com/sqlparser-rs/sqlparser-rs" documentation = "https://docs.rs/sqlparser_derive/" @@ -18,6 +18,6 @@ edition = "2021" proc-macro = true [dependencies] -syn = "1.0" +syn = { version = "2.0", default-features = false, features = ["printing", "parsing", "derive", "proc-macro"] } proc-macro2 = "1.0" quote = "1.0" diff --git a/derive/README.md b/derive/README.md index ec0fcb6f..ad4978a8 100644 --- a/derive/README.md +++ b/derive/README.md @@ -48,33 +48,102 @@ impl Visit for Bar { } ``` -Additionally certain types may wish to call a corresponding method on visitor before recursing +Some types may wish to call a corresponding method on the visitor: ```rust #[derive(Visit, VisitMut)] #[visit(with = "visit_expr")] enum Expr { - A(), - B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool), + IsNull(Box), + ..
} ``` -Will generate +This will result in the following sequence of visitor calls when an `IsNull` +expression is visited + +``` +visitor.pre_visit_expr() +visitor.pre_visit_expr() +visitor.post_visit_expr() +visitor.post_visit_expr() +``` + +For some types it is only appropriate to call a particular visitor method in +some contexts. For example, not every `ObjectName` refers to a relation. + +In these cases, the `visit` attribute can be used on the field for which we'd +like to call the method: ```rust -impl Visit for Bar { +#[derive(Visit, VisitMut)] +#[visit(with = "visit_table_factor")] +pub enum TableFactor { + Table { + #[visit(with = "visit_relation")] + name: ObjectName, + alias: Option, + }, + .. +} +``` + +This will generate + +```rust +impl Visit for TableFactor { fn visit(&self, visitor: &mut V) -> ControlFlow { - visitor.visit_expr(self)?; + visitor.pre_visit_table_factor(self)?; match self { - Self::A() => {} - Self::B(_1, _2, _3) => { - _1.visit(visitor)?; - visitor.visit_relation(_3)?; - _2.visit(visitor)?; - _3.visit(visitor)?; + Self::Table { name, alias } => { + visitor.pre_visit_relation(name)?; + name.visit(visitor)?; + visitor.post_visit_relation(name)?; + alias.visit(visitor)?; + } } + visitor.post_visit_table_factor(self)?; ControlFlow::Continue(()) } } ``` + +Note that annotating both the type and the field is incorrect as it will result +in redundant calls to the method. For example + +```rust +#[derive(Visit, VisitMut)] +#[visit(with = "visit_expr")] +enum Expr { + IsNull(#[visit(with = "visit_expr")] Box), + .. +} +``` + +will result in these calls to the visitor + + +``` +visitor.pre_visit_expr() +visitor.pre_visit_expr() +visitor.pre_visit_expr() +visitor.post_visit_expr() +visitor.post_visit_expr() +visitor.post_visit_expr() +``` + +## Releasing + +This crate's release is not automated. Instead it is released manually as needed + +Steps: +1. Update the version in `Cargo.toml` +2. Update the corresponding version in `../Cargo.toml` +3. 
Commit via PR +4. Publish to crates.io: + +```shell +# update to latest checked in main branch and publish via +cargo publish +``` + diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 43f40664..009e704d 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -2,8 +2,9 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; use syn::{ - parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, - Ident, Index, Lit, Meta, MetaNameValue, NestedMeta, + parse::{Parse, ParseStream}, + parse_macro_input, parse_quote, Attribute, Data, DeriveInput, + Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, Token }; /// Implementation of `[#derive(Visit)]` @@ -84,38 +85,43 @@ struct Attributes { with: Option, } +struct WithIdent { + with: Option, +} +impl Parse for WithIdent { + fn parse(input: ParseStream) -> Result { + let mut result = WithIdent { with: None }; + let ident = input.parse::()?; + if ident != "with" { + return Err(syn::Error::new(ident.span(), "Expected identifier to be `with`")); + } + input.parse::()?; + let s = input.parse::()?; + result.with = Some(format_ident!("{}", s.value(), span = s.span())); + Ok(result) + } +} + impl Attributes { fn parse(attrs: &[Attribute]) -> Self { let mut out = Self::default(); - for attr in attrs.iter().filter(|a| a.path.is_ident("visit")) { - let meta = attr.parse_meta().expect("visit attribute"); - match meta { - Meta::List(l) => { - for nested in &l.nested { - match nested { - NestedMeta::Meta(Meta::NameValue(v)) => out.parse_name_value(v), - _ => panic!("Expected #[visit(key = \"value\")]"), + for attr in attrs { + if let Meta::List(ref metalist) = attr.meta { + if metalist.path.is_ident("visit") { + match syn::parse2::(metalist.tokens.clone()) { + Ok(with_ident) => { + out.with = with_ident.with; + } + Err(e) => { + panic!("{}", e); } } } - _ => panic!("Expected #[visit(...)]"), } } out } - /// Updates self with a 
name value attribute - fn parse_name_value(&mut self, v: &MetaNameValue) { - if v.path.is_ident("with") { - match &v.lit { - Lit::Str(s) => self.with = Some(format_ident!("{}", s.value(), span = s.span())), - _ => panic!("Expected a string value, got {}", v.lit.to_token_stream()), - } - return; - } - panic!("Unrecognised kv attribute {}", v.path.to_token_stream()) - } - /// Returns the pre and post visit token streams fn visit(&self, s: TokenStream) -> (Option, Option) { let pre_visit = self.with.as_ref().map(|m| { diff --git a/src/ast/data_type.rs b/src/ast/data_type.rs index c0dc8f88..54781efb 100644 --- a/src/ast/data_type.rs +++ b/src/ast/data_type.rs @@ -374,7 +374,6 @@ impl fmt::Display for DataType { } write!(f, ")") } - DataType::SnowflakeTimestamp => write!(f, "TIMESTAMP_NTZ"), DataType::Struct(fields) => { if !fields.is_empty() { write!(f, "STRUCT<{}>", display_comma_separated(fields)) @@ -382,6 +381,7 @@ impl fmt::Display for DataType { write!(f, "STRUCT") } } + DataType::SnowflakeTimestamp => write!(f, "TIMESTAMP_NTZ"), } } } @@ -521,18 +521,29 @@ impl fmt::Display for ExactNumberInfo { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct CharacterLength { - /// Default (if VARYING) or maximum (if not VARYING) length - pub length: u64, - /// Optional unit. If not informed, the ANSI handles it as CHARACTERS implicitly - pub unit: Option, +pub enum CharacterLength { + IntegerLength { + /// Default (if VARYING) or maximum (if not VARYING) length + length: u64, + /// Optional unit. 
If not informed, the ANSI handles it as CHARACTERS implicitly + unit: Option, + }, + /// VARCHAR(MAX) or NVARCHAR(MAX), used in T-SQL (Miscrosoft SQL Server) + Max, } impl fmt::Display for CharacterLength { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.length)?; - if let Some(unit) = &self.unit { - write!(f, " {unit}")?; + match self { + CharacterLength::IntegerLength { length, unit } => { + write!(f, "{}", length)?; + if let Some(unit) = unit { + write!(f, " {unit}")?; + } + } + CharacterLength::Max => { + write!(f, "MAX")?; + } } Ok(()) } diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index da2c8c9e..3192af8b 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -599,6 +599,7 @@ pub enum ColumnOption { generated_as: GeneratedAs, sequence_options: Option>, generation_expr: Option, + generation_expr_mode: Option, }, } @@ -639,25 +640,25 @@ impl fmt::Display for ColumnOption { generated_as, sequence_options, generation_expr, - } => match generated_as { - GeneratedAs::Always => { - write!(f, "GENERATED ALWAYS AS IDENTITY")?; - if sequence_options.is_some() { - let so = sequence_options.as_ref().unwrap(); - if !so.is_empty() { - write!(f, " (")?; - } - for sequence_option in so { - write!(f, "{sequence_option}")?; - } - if !so.is_empty() { - write!(f, " )")?; - } - } + generation_expr_mode, + } => { + if let Some(expr) = generation_expr { + let modifier = match generation_expr_mode { + None => "", + Some(GeneratedExpressionMode::Virtual) => " VIRTUAL", + Some(GeneratedExpressionMode::Stored) => " STORED", + }; + write!(f, "GENERATED ALWAYS AS ({expr}){modifier}")?; Ok(()) - } - GeneratedAs::ByDefault => { - write!(f, "GENERATED BY DEFAULT AS IDENTITY")?; + } else { + // Like Postgres - generated from sequence + let when = match generated_as { + GeneratedAs::Always => "ALWAYS", + GeneratedAs::ByDefault => "BY DEFAULT", + // ExpStored goes with an expression, handled above + GeneratedAs::ExpStored => unreachable!(), + }; + write!(f, 
"GENERATED {when} AS IDENTITY")?; if sequence_options.is_some() { let so = sequence_options.as_ref().unwrap(); if !so.is_empty() { @@ -672,17 +673,13 @@ impl fmt::Display for ColumnOption { } Ok(()) } - GeneratedAs::ExpStored => { - let expr = generation_expr.as_ref().unwrap(); - write!(f, "GENERATED ALWAYS AS ({expr}) STORED") - } - }, + } } } } /// `GeneratedAs`s are modifiers that follow a column option in a `generated`. -/// 'ExpStored' is PostgreSQL specific +/// 'ExpStored' is used for a column generated from an expression and stored. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] @@ -692,6 +689,16 @@ pub enum GeneratedAs { ExpStored, } +/// `GeneratedExpressionMode`s are modifiers that follow an expression in a `generated`. +/// No modifier is typically the same as Virtual. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum GeneratedExpressionMode { + Virtual, + Stored, +} + fn display_constraint_name(name: &'_ Option) -> impl fmt::Display + '_ { struct ConstraintName<'a>(&'a Option); impl<'a> fmt::Display for ConstraintName<'a> { diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 3d73918c..b0b15371 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -31,18 +31,18 @@ pub use self::data_type::{ pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue}; pub use self::ddl::{ AlterColumnOperation, AlterIndexOperation, AlterTableOperation, ColumnDef, ColumnOption, - ColumnOptionDef, GeneratedAs, IndexType, KeyOrIndexDisplay, Partition, ProcedureParam, - ReferentialAction, TableConstraint, UserDefinedTypeCompositeAttributeDef, + ColumnOptionDef, GeneratedAs, GeneratedExpressionMode, IndexType, KeyOrIndexDisplay, Partition, + ProcedureParam, ReferentialAction, 
TableConstraint, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeRepresentation, }; pub use self::operator::{BinaryOperator, UnaryOperator}; pub use self::query::{ - Cte, Distinct, ExceptSelectItem, ExcludeSelectItem, Fetch, GroupByExpr, IdentWithAlias, Join, - JoinConstraint, JoinOperator, LateralView, LockClause, LockType, NamedWindowDefinition, - NonBlock, Offset, OffsetRows, OrderByExpr, Query, RenameSelectItem, ReplaceSelectElement, - ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier, Table, - TableAlias, TableFactor, TableVersion, TableWithJoins, Top, Values, WildcardAdditionalOptions, - With, + Cte, Distinct, ExceptSelectItem, ExcludeSelectItem, Fetch, ForClause, ForJson, ForXml, + GroupByExpr, IdentWithAlias, Join, JoinConstraint, JoinOperator, LateralView, LockClause, + LockType, NamedWindowDefinition, NonBlock, Offset, OffsetRows, OrderByExpr, Query, + RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, + SetExpr, SetOperator, SetQuantifier, Table, TableAlias, TableFactor, TableVersion, + TableWithJoins, Top, Values, WildcardAdditionalOptions, With, }; pub use self::value::{ escape_quoted_string, DateTimeField, DollarQuotedString, TrimWhereField, Value, @@ -473,6 +473,17 @@ pub enum Expr { }, /// Unary operation e.g. `NOT foo` UnaryOp { op: UnaryOperator, expr: Box }, + /// CONVERT a value to a different data type or character encoding `CONVERT(foo USING utf8mb4)` + Convert { + /// The expression to convert + expr: Box, + /// The target data type + data_type: Option, + /// The target character encoding + charset: Option, + /// whether the target comes before the expr (MSSQL syntax) + target_before_value: bool, + }, /// CAST an expression to a different data type e.g. 
`CAST(foo AS VARCHAR(123))` Cast { expr: Box, @@ -844,6 +855,28 @@ impl fmt::Display for Expr { write!(f, "{op}{expr}") } } + Expr::Convert { + expr, + target_before_value, + data_type, + charset, + } => { + write!(f, "CONVERT(")?; + if let Some(data_type) = data_type { + if let Some(charset) = charset { + write!(f, "{expr}, {data_type} CHARACTER SET {charset}") + } else if *target_before_value { + write!(f, "{data_type}, {expr}") + } else { + write!(f, "{expr}, {data_type}") + } + } else if let Some(charset) = charset { + write!(f, "{expr} USING {charset}") + } else { + write!(f, "{expr}") // This should never happen + }?; + write!(f, ")") + } Expr::Cast { expr, data_type, @@ -1481,7 +1514,7 @@ pub enum Statement { /// Overwrite (Hive) overwrite: bool, /// A SQL query that specifies what to insert - source: Box, + source: Option>, /// partitioned insert (Hive) partitioned: Option>, /// Columns defined after PARTITION @@ -1846,6 +1879,8 @@ pub enum Statement { /// Note: this is a MySQL-specific statement. ShowVariables { filter: Option, + global: bool, + session: bool, }, /// SHOW CREATE TABLE /// @@ -1915,9 +1950,10 @@ pub enum Statement { Commit { chain: bool, }, - /// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]` + /// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]` Rollback { chain: bool, + savepoint: Option, }, /// CREATE SCHEMA CreateSchema { @@ -2063,6 +2099,10 @@ pub enum Statement { Savepoint { name: Ident, }, + /// RELEASE \[ SAVEPOINT \] savepoint_name + ReleaseSavepoint { + name: Ident, + }, // MERGE INTO statement, based on Snowflake. 
See Merge { // optional INTO keyword @@ -2447,7 +2487,14 @@ impl fmt::Display for Statement { if !after_columns.is_empty() { write!(f, "({}) ", display_comma_separated(after_columns))?; } - write!(f, "{source}")?; + + if let Some(source) = source { + write!(f, "{source}")?; + } + + if source.is_none() && columns.is_empty() { + write!(f, "DEFAULT VALUES")?; + } if let Some(on) = on { write!(f, "{on}")?; @@ -3183,8 +3230,19 @@ impl fmt::Display for Statement { } Ok(()) } - Statement::ShowVariables { filter } => { - write!(f, "SHOW VARIABLES")?; + Statement::ShowVariables { + filter, + global, + session, + } => { + write!(f, "SHOW")?; + if *global { + write!(f, " GLOBAL")?; + } + if *session { + write!(f, " SESSION")?; + } + write!(f, " VARIABLES")?; if filter.is_some() { write!(f, " {}", filter.as_ref().unwrap())?; } @@ -3285,8 +3343,18 @@ impl fmt::Display for Statement { Statement::Commit { chain } => { write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },) } - Statement::Rollback { chain } => { - write!(f, "ROLLBACK{}", if *chain { " AND CHAIN" } else { "" },) + Statement::Rollback { chain, savepoint } => { + write!(f, "ROLLBACK")?; + + if *chain { + write!(f, " AND CHAIN")?; + } + + if let Some(savepoint) = savepoint { + write!(f, " TO SAVEPOINT {savepoint}")?; + } + + Ok(()) } Statement::CreateSchema { schema_name, @@ -3383,6 +3451,9 @@ impl fmt::Display for Statement { write!(f, "SAVEPOINT ")?; write!(f, "{name}") } + Statement::ReleaseSavepoint { name } => { + write!(f, "RELEASE SAVEPOINT {name}") + } Statement::Merge { into, table, @@ -3843,7 +3914,7 @@ impl fmt::Display for OnInsert { " ON DUPLICATE KEY UPDATE {}", display_comma_separated(expr) ), - Self::OnConflict(o) => write!(f, " {o}"), + Self::OnConflict(o) => write!(f, "{o}"), } } } diff --git a/src/ast/query.rs b/src/ast/query.rs index bd92e35d..b0ce66fa 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -26,6 +26,7 @@ use crate::ast::*; #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, 
Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] pub struct Query { /// WITH (common table expressions, or CTEs) pub with: Option, @@ -45,6 +46,10 @@ pub struct Query { pub fetch: Option, /// `FOR { UPDATE | SHARE } [ OF table_name ] [ SKIP LOCKED | NOWAIT ]` pub locks: Vec, + /// `FOR XML { RAW | AUTO | EXPLICIT | PATH } [ , ELEMENTS ]` + /// `FOR JSON { AUTO | PATH } [ , INCLUDE_NULL_VALUES ]` + /// (MSSQL-specific) + pub for_clause: Option, } impl fmt::Display for Query { @@ -71,6 +76,9 @@ impl fmt::Display for Query { if !self.locks.is_empty() { write!(f, " {}", display_separated(&self.locks, " "))?; } + if let Some(ref for_clause) = self.for_clause { + write!(f, " {}", for_clause)?; + } Ok(()) } } @@ -736,7 +744,6 @@ pub enum TableFactor { /// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))` /// See Pivot { - #[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))] table: Box, aggregate_function: Expr, // Function expression value_column: Vec, @@ -752,7 +759,6 @@ pub enum TableFactor { /// /// See . 
Unpivot { - #[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))] table: Box, value: Ident, name: Ident, @@ -1319,3 +1325,125 @@ impl fmt::Display for GroupByExpr { } } } + +/// FOR XML or FOR JSON clause, specific to MSSQL +/// (formats the output of a query as XML or JSON) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum ForClause { + Browse, + Json { + for_json: ForJson, + root: Option, + include_null_values: bool, + without_array_wrapper: bool, + }, + Xml { + for_xml: ForXml, + elements: bool, + binary_base64: bool, + root: Option, + r#type: bool, + }, +} + +impl fmt::Display for ForClause { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ForClause::Browse => write!(f, "FOR BROWSE"), + ForClause::Json { + for_json, + root, + include_null_values, + without_array_wrapper, + } => { + write!(f, "FOR JSON ")?; + write!(f, "{}", for_json)?; + if let Some(root) = root { + write!(f, ", ROOT('{}')", root)?; + } + if *include_null_values { + write!(f, ", INCLUDE_NULL_VALUES")?; + } + if *without_array_wrapper { + write!(f, ", WITHOUT_ARRAY_WRAPPER")?; + } + Ok(()) + } + ForClause::Xml { + for_xml, + elements, + binary_base64, + root, + r#type, + } => { + write!(f, "FOR XML ")?; + write!(f, "{}", for_xml)?; + if *binary_base64 { + write!(f, ", BINARY BASE64")?; + } + if *r#type { + write!(f, ", TYPE")?; + } + if let Some(root) = root { + write!(f, ", ROOT('{}')", root)?; + } + if *elements { + write!(f, ", ELEMENTS")?; + } + Ok(()) + } + } + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum ForXml { + Raw(Option), + Auto, + Explicit, + Path(Option), +} + +impl fmt::Display for ForXml { + fn fmt(&self, f: &mut fmt::Formatter) -> 
fmt::Result { + match self { + ForXml::Raw(root) => { + write!(f, "RAW")?; + if let Some(root) = root { + write!(f, "('{}')", root)?; + } + Ok(()) + } + ForXml::Auto => write!(f, "AUTO"), + ForXml::Explicit => write!(f, "EXPLICIT"), + ForXml::Path(root) => { + write!(f, "PATH")?; + if let Some(root) = root { + write!(f, "('{}')", root)?; + } + Ok(()) + } + } + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ForJson { + Auto, + Path, +} + +impl fmt::Display for ForJson { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ForJson::Auto => write!(f, "AUTO"), + ForJson::Path => write!(f, "PATH"), + } + } +} diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 1fb64526..d26f4110 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -12,7 +12,7 @@ //! Recursive visitors for ast Nodes. See [`Visitor`] for more details. -use crate::ast::{Expr, FunctionArgExpr, ObjectName, SetExpr, Statement, TableFactor}; +use crate::ast::{Expr, FunctionArgExpr, ObjectName, Query, SetExpr, Statement, TableFactor}; use core::ops::ControlFlow; /// A type that can be visited by a [`Visitor`]. See [`Visitor`] for @@ -179,6 +179,16 @@ pub trait Visitor { /// Type returned when the recursion returns early. type Break; + /// Invoked for any queries that appear in the AST before visiting children + fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any queries that appear in the AST after visiting children + fn post_visit_query(&mut self, _query: &Query) -> ControlFlow { + ControlFlow::Continue(()) + } + /// Invoked for any relations (e.g. 
tables) that appear in the AST before visiting children fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow { ControlFlow::Continue(()) @@ -286,6 +296,16 @@ pub trait VisitorMut { /// Type returned when the recursion returns early. type Break; + /// Invoked for any queries that appear in the AST before visiting children + fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any queries that appear in the AST after visiting children + fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow { + ControlFlow::Continue(()) + } + /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow { ControlFlow::Continue(()) @@ -736,6 +756,18 @@ mod tests { impl Visitor for TestVisitor { type Break = (); + /// Invoked for any queries that appear in the AST before visiting children + fn pre_visit_query(&mut self, query: &Query) -> ControlFlow { + self.visited.push(format!("PRE: QUERY: {query}")); + ControlFlow::Continue(()) + } + + /// Invoked for any queries that appear in the AST after visiting children + fn post_visit_query(&mut self, query: &Query) -> ControlFlow { + self.visited.push(format!("POST: QUERY: {query}")); + ControlFlow::Continue(()) + } + fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow { self.visited.push(format!("PRE: RELATION: {relation}")); ControlFlow::Continue(()) @@ -805,10 +837,12 @@ mod tests { "SELECT * from table_name as my_table", vec![ "PRE: STATEMENT: SELECT * FROM table_name AS my_table", + "PRE: QUERY: SELECT * FROM table_name AS my_table", "PRE: TABLE FACTOR: table_name AS my_table", "PRE: RELATION: table_name", "POST: RELATION: table_name", "POST: TABLE FACTOR: table_name AS my_table", + "POST: QUERY: SELECT * FROM table_name AS my_table", "POST: STATEMENT: SELECT * FROM table_name AS my_table", ], ), @@ -816,6 +850,7 @@ mod 
tests { "SELECT * from t1 join t2 on t1.id = t2.t1_id", vec![ "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", + "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", @@ -830,6 +865,7 @@ mod tests { "PRE: EXPR: t2.t1_id", "POST: EXPR: t2.t1_id", "POST: EXPR: t1.id = t2.t1_id", + "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", ], ), @@ -837,18 +873,22 @@ mod tests { "SELECT * from t1 where EXISTS(SELECT column from t2)", vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", + "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", ], ), @@ -856,18 +896,22 @@ mod tests { "SELECT * from t1 where EXISTS(SELECT column from t2)", vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", + "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column 
FROM t2)", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", ], ), @@ -875,25 +919,54 @@ mod tests { "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3", vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", + "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t3", "PRE: RELATION: t3", "POST: RELATION: t3", "POST: TABLE FACTOR: t3", + "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", ], ), + ( + concat!( + "SELECT * FROM monthly_sales ", + "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ", + "ORDER BY EMPID" + ), + vec![ + "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)", + "PRE: TABLE FACTOR: monthly_sales", + "PRE: RELATION: monthly_sales", + "POST: RELATION: monthly_sales", + "POST: TABLE FACTOR: monthly_sales", + "PRE: EXPR: SUM(a.amount)", + "PRE: EXPR: a.amount", + "POST: EXPR: a.amount", + "POST: EXPR: SUM(a.amount)", + "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 
'MAR', 'APR')) AS p (c, d)", + "PRE: EXPR: EMPID", + "POST: EXPR: EMPID", + "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + ] + ) ]; for (sql, expected) in tests { let actual = do_visit(sql); diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 856cfe1c..53bb891d 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -124,6 +124,15 @@ pub trait Dialect: Debug + Any { fn supports_substring_from_for_expr(&self) -> bool { true } + /// Returns true if the dialect supports `(NOT) IN ()` expressions + fn supports_in_empty_list(&self) -> bool { + false + } + /// Returns true if the dialect has a CONVERT function which accepts a type first + /// and an expression second, e.g. `CONVERT(varchar, 1)` + fn convert_type_before_value(&self) -> bool { + false + } /// Dialect-specific prefix parser override fn parse_prefix(&self, _parser: &mut Parser) -> Option> { // return None to fall back to the default behavior diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 26ecd478..c7bf1186 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -35,6 +35,12 @@ impl Dialect for MsSqlDialect { || ch == '_' } + /// SQL Server has `CONVERT(type, value)` instead of `CONVERT(value, type)` + /// + fn convert_type_before_value(&self) -> bool { + true + } + fn supports_substring_from_for_expr(&self) -> bool { false } diff --git a/src/dialect/redshift.rs b/src/dialect/redshift.rs index 73457ab3..8dc7d573 100644 --- a/src/dialect/redshift.rs +++ b/src/dialect/redshift.rs @@ -53,4 +53,10 @@ impl Dialect for RedshiftSqlDialect { // Extends Postgres dialect with sharp PostgreSqlDialect {}.is_identifier_part(ch) || ch == '#' } + + /// redshift has `CONVERT(type, value)` instead of `CONVERT(value, type)` + /// + fn 
convert_type_before_value(&self) -> bool { + true + } } diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs index 68515d24..c9e9ab18 100644 --- a/src/dialect/sqlite.rs +++ b/src/dialect/sqlite.rs @@ -52,4 +52,8 @@ impl Dialect for SQLiteDialect { None } } + + fn supports_in_empty_list(&self) -> bool { + true + } } diff --git a/src/keywords.rs b/src/keywords.rs index 9f5d11ce..e80d91a9 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -97,11 +97,13 @@ define_keywords!( ATOMIC, ATTACH, AUTHORIZATION, + AUTO, AUTOINCREMENT, AUTO_INCREMENT, AVG, AVRO, BACKWARD, + BASE64, BEGIN, BEGIN_FRAME, BEGIN_PARTITION, @@ -117,6 +119,7 @@ define_keywords!( BOOL, BOOLEAN, BOTH, + BROWSE, BTREE, BY, BYPASSRLS, @@ -233,6 +236,7 @@ define_keywords!( DYNAMIC, EACH, ELEMENT, + ELEMENTS, ELSE, ENCODING, ENCRYPTION, @@ -259,6 +263,7 @@ define_keywords!( EXP, EXPANSION, EXPLAIN, + EXPLICIT, EXTENDED, EXTERNAL, EXTRACT, @@ -322,6 +327,7 @@ define_keywords!( IMMUTABLE, IN, INCLUDE, + INCLUDE_NULL_VALUES, INCREMENT, INDEX, INDICATOR, @@ -470,6 +476,7 @@ define_keywords!( PARTITIONED, PARTITIONS, PASSWORD, + PATH, PATTERN, PAUSE, PEER, @@ -504,6 +511,7 @@ define_keywords!( QUOTE, RANGE, RANK, + RAW, RCFILE, READ, READS, @@ -547,6 +555,7 @@ define_keywords!( ROLE, ROLLBACK, ROLLUP, + ROOT, ROW, ROWID, ROWS, @@ -697,8 +706,10 @@ define_keywords!( WITH, WITHIN, WITHOUT, + WITHOUT_ARRAY_WRAPPER, WORK, WRITE, + XML, XOR, YEAR, ZONE, @@ -747,6 +758,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[ Keyword::QUALIFY, Keyword::WINDOW, Keyword::END, + Keyword::FOR, // for MYSQL PARTITION SELECTION Keyword::PARTITION, ]; diff --git a/src/lib.rs b/src/lib.rs index 5bcd3294..5afdfbc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ //! 2. [`ast`] for the AST structure //! 3. [`Dialect`] for supported SQL dialects //! -//! # Example +//! # Example parsing SQL text //! //! ``` //! use sqlparser::dialect::GenericDialect; @@ -39,6 +39,24 @@ //! println!("AST: {:?}", ast); //! 
``` //! +//! # Creating SQL text from AST +//! +//! This crate allows users to recover the original SQL text (with normalized +//! whitespace and identifier capitalization), which is useful for tools that +//! analyze and manipulate SQL. +//! +//! ``` +//! # use sqlparser::dialect::GenericDialect; +//! # use sqlparser::parser::Parser; +//! let sql = "SELECT a FROM table_1"; +//! +//! // parse to a Vec +//! let ast = Parser::parse_sql(&GenericDialect, sql).unwrap(); +//! +//! // The original SQL text can be generated from the AST +//! assert_eq!(ast[0].to_string(), sql); +//! ``` +//! //! [sqlparser crates.io page]: https://crates.io/crates/sqlparser //! [`Parser::parse_sql`]: crate::parser::Parser::parse_sql //! [`Parser::new`]: crate::parser::Parser::new diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 0b9c9aa4..49760435 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -420,10 +420,10 @@ impl<'a> Parser<'a> { Token::EOF => break, // end of statement - Token::Word(word) - if word.keyword == Keyword::END && !dialect_of!(self is PostgreSqlDialect) => - { - break + Token::Word(word) => { + if expecting_statement_delimiter && word.keyword == Keyword::END { + break; + } } _ => {} } @@ -505,9 +505,13 @@ impl<'a> Parser<'a> { // standard `START TRANSACTION` statement. It is supported // by at least PostgreSQL and MySQL. Keyword::BEGIN => Ok(self.parse_begin()?), + // `END` is a nonstandard but common alias for the + // standard `COMMIT TRANSACTION` statement. It is supported + // by PostgreSQL. + Keyword::END => Ok(self.parse_end()?), Keyword::SAVEPOINT => Ok(self.parse_savepoint()?), + Keyword::RELEASE => Ok(self.parse_release()?), Keyword::COMMIT => Ok(self.parse_commit()?), - Keyword::END if dialect_of!(self is PostgreSqlDialect) => Ok(self.parse_commit()?), Keyword::ROLLBACK => Ok(self.parse_rollback()?), Keyword::ABORT if dialect_of!(self is PostgreSqlDialect) => { Ok(self.parse_rollback()?) 
@@ -758,6 +762,13 @@ impl<'a> Parser<'a> { Ok(Statement::Savepoint { name }) } + pub fn parse_release(&mut self) -> Result { + let _ = self.parse_keyword(Keyword::SAVEPOINT); + let name = self.parse_identifier()?; + + Ok(Statement::ReleaseSavepoint { name }) + } + /// Parse an expression prefix pub fn parse_prefix(&mut self) -> Result { // allow the dialect to override prefix parsing @@ -832,6 +843,7 @@ impl<'a> Parser<'a> { self.parse_time_functions(ObjectName(vec![w.to_ident()])) } Keyword::CASE => self.parse_case_expr(), + Keyword::CONVERT => self.parse_convert_expr(), Keyword::CAST => self.parse_cast_expr(), Keyword::TRY_CAST => self.parse_try_cast_expr(), Keyword::SAFE_CAST => self.parse_safe_cast_expr(), @@ -1238,6 +1250,57 @@ impl<'a> Parser<'a> { } } + /// mssql-like convert function + fn parse_mssql_convert(&mut self) -> Result { + self.expect_token(&Token::LParen)?; + let data_type = self.parse_data_type()?; + self.expect_token(&Token::Comma)?; + let expr = self.parse_expr()?; + self.expect_token(&Token::RParen)?; + Ok(Expr::Convert { + expr: Box::new(expr), + data_type: Some(data_type), + charset: None, + target_before_value: true, + }) + } + + /// Parse a SQL CONVERT function: + /// - `CONVERT('héhé' USING utf8mb4)` (MySQL) + /// - `CONVERT('héhé', CHAR CHARACTER SET utf8mb4)` (MySQL) + /// - `CONVERT(DECIMAL(10, 5), 42)` (MSSQL) - the type comes first + pub fn parse_convert_expr(&mut self) -> Result { + if self.dialect.convert_type_before_value() { + return self.parse_mssql_convert(); + } + self.expect_token(&Token::LParen)?; + let expr = self.parse_expr()?; + if self.parse_keyword(Keyword::USING) { + let charset = self.parse_object_name()?; + self.expect_token(&Token::RParen)?; + return Ok(Expr::Convert { + expr: Box::new(expr), + data_type: None, + charset: Some(charset), + target_before_value: false, + }); + } + self.expect_token(&Token::Comma)?; + let data_type = self.parse_data_type()?; + let charset = if self.parse_keywords(&[Keyword::CHARACTER, 
Keyword::SET]) { + Some(self.parse_object_name()?) + } else { + None + }; + self.expect_token(&Token::RParen)?; + Ok(Expr::Convert { + expr: Box::new(expr), + data_type: Some(data_type), + charset, + target_before_value: false, + }) + } + /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` pub fn parse_cast_expr(&mut self) -> Result { self.expect_token(&Token::LParen)?; @@ -2355,7 +2418,11 @@ impl<'a> Parser<'a> { } else { Expr::InList { expr: Box::new(expr), - list: self.parse_comma_separated(Parser::parse_expr)?, + list: if self.dialect.supports_in_empty_list() { + self.parse_comma_separated0(Parser::parse_expr)? + } else { + self.parse_comma_separated(Parser::parse_expr)? + }, negated, } }; @@ -2721,6 +2788,27 @@ impl<'a> Parser<'a> { Ok(values) } + /// Parse a comma-separated list of 0+ items accepted by `F` + pub fn parse_comma_separated0(&mut self, f: F) -> Result, ParserError> + where + F: FnMut(&mut Parser<'a>) -> Result, + { + // () + if matches!(self.peek_token().token, Token::RParen) { + return Ok(vec![]); + } + // (,) + if self.options.trailing_commas + && matches!(self.peek_nth_token(0).token, Token::Comma) + && matches!(self.peek_nth_token(1).token, Token::RParen) + { + let _ = self.consume_token(&Token::Comma); + return Ok(vec![]); + } + + self.parse_comma_separated(f) + } + /// Run a parser method `f`, reverting back to the current position /// if unsuccessful. 
#[must_use] @@ -4249,6 +4337,7 @@ impl<'a> Parser<'a> { generated_as: GeneratedAs::Always, sequence_options: Some(sequence_options), generation_expr: None, + generation_expr_mode: None, })) } else if self.parse_keywords(&[ Keyword::BY, @@ -4265,16 +4354,31 @@ impl<'a> Parser<'a> { generated_as: GeneratedAs::ByDefault, sequence_options: Some(sequence_options), generation_expr: None, + generation_expr_mode: None, })) } else if self.parse_keywords(&[Keyword::ALWAYS, Keyword::AS]) { if self.expect_token(&Token::LParen).is_ok() { let expr = self.parse_expr()?; self.expect_token(&Token::RParen)?; - let _ = self.parse_keywords(&[Keyword::STORED]); + let (gen_as, expr_mode) = if self.parse_keywords(&[Keyword::STORED]) { + Ok(( + GeneratedAs::ExpStored, + Some(GeneratedExpressionMode::Stored), + )) + } else if dialect_of!(self is PostgreSqlDialect) { + // Postgres' AS IDENTITY branches are above, this one needs STORED + self.expected("STORED", self.peek_token()) + } else if self.parse_keywords(&[Keyword::VIRTUAL]) { + Ok((GeneratedAs::Always, Some(GeneratedExpressionMode::Virtual))) + } else { + Ok((GeneratedAs::Always, None)) + }?; + Ok(Some(ColumnOption::Generated { - generated_as: GeneratedAs::ExpStored, + generated_as: gen_as, sequence_options: None, generation_expr: Some(expr), + generation_expr_mode: expr_mode, })) } else { Ok(None) @@ -5642,6 +5746,9 @@ impl<'a> Parser<'a> { } pub fn parse_character_length(&mut self) -> Result { + if self.parse_keyword(Keyword::MAX) { + return Ok(CharacterLength::Max); + } let length = self.parse_literal_uint()?; let unit = if self.parse_keyword(Keyword::CHARACTERS) { Some(CharLengthUnits::Characters) @@ -5650,8 +5757,7 @@ impl<'a> Parser<'a> { } else { None }; - - Ok(CharacterLength { length, unit }) + Ok(CharacterLength::IntegerLength { length, unit }) } pub fn parse_optional_precision_scale( @@ -5848,6 +5954,7 @@ impl<'a> Parser<'a> { offset: None, fetch: None, locks: vec![], + for_clause: None, }) } else if 
self.parse_keyword(Keyword::UPDATE) { let update = self.parse_update()?; @@ -5860,6 +5967,7 @@ impl<'a> Parser<'a> { offset: None, fetch: None, locks: vec![], + for_clause: None, }) } else { let body = Box::new(self.parse_query_body(0)?); @@ -5911,9 +6019,15 @@ impl<'a> Parser<'a> { None }; + let mut for_clause = None; let mut locks = Vec::new(); while self.parse_keyword(Keyword::FOR) { - locks.push(self.parse_lock()?); + if let Some(parsed_for_clause) = self.parse_for_clause()? { + for_clause = Some(parsed_for_clause); + break; + } else { + locks.push(self.parse_lock()?); + } } Ok(Query { @@ -5925,10 +6039,113 @@ impl<'a> Parser<'a> { offset, fetch, locks, + for_clause, }) } } + /// Parse a mssql `FOR [XML | JSON | BROWSE]` clause + pub fn parse_for_clause(&mut self) -> Result, ParserError> { + if self.parse_keyword(Keyword::XML) { + Ok(Some(self.parse_for_xml()?)) + } else if self.parse_keyword(Keyword::JSON) { + Ok(Some(self.parse_for_json()?)) + } else if self.parse_keyword(Keyword::BROWSE) { + Ok(Some(ForClause::Browse)) + } else { + Ok(None) + } + } + + /// Parse a mssql `FOR XML` clause + pub fn parse_for_xml(&mut self) -> Result { + let for_xml = if self.parse_keyword(Keyword::RAW) { + let mut element_name = None; + if self.peek_token().token == Token::LParen { + self.expect_token(&Token::LParen)?; + element_name = Some(self.parse_literal_string()?); + self.expect_token(&Token::RParen)?; + } + ForXml::Raw(element_name) + } else if self.parse_keyword(Keyword::AUTO) { + ForXml::Auto + } else if self.parse_keyword(Keyword::EXPLICIT) { + ForXml::Explicit + } else if self.parse_keyword(Keyword::PATH) { + let mut element_name = None; + if self.peek_token().token == Token::LParen { + self.expect_token(&Token::LParen)?; + element_name = Some(self.parse_literal_string()?); + self.expect_token(&Token::RParen)?; + } + ForXml::Path(element_name) + } else { + return Err(ParserError::ParserError( + "Expected FOR XML [RAW | AUTO | EXPLICIT | PATH ]".to_string(), + )); + 
}; + let mut elements = false; + let mut binary_base64 = false; + let mut root = None; + let mut r#type = false; + while self.peek_token().token == Token::Comma { + self.next_token(); + if self.parse_keyword(Keyword::ELEMENTS) { + elements = true; + } else if self.parse_keyword(Keyword::BINARY) { + self.expect_keyword(Keyword::BASE64)?; + binary_base64 = true; + } else if self.parse_keyword(Keyword::ROOT) { + self.expect_token(&Token::LParen)?; + root = Some(self.parse_literal_string()?); + self.expect_token(&Token::RParen)?; + } else if self.parse_keyword(Keyword::TYPE) { + r#type = true; + } + } + Ok(ForClause::Xml { + for_xml, + elements, + binary_base64, + root, + r#type, + }) + } + + /// Parse a mssql `FOR JSON` clause + pub fn parse_for_json(&mut self) -> Result { + let for_json = if self.parse_keyword(Keyword::AUTO) { + ForJson::Auto + } else if self.parse_keyword(Keyword::PATH) { + ForJson::Path + } else { + return Err(ParserError::ParserError( + "Expected FOR JSON [AUTO | PATH ]".to_string(), + )); + }; + let mut root = None; + let mut include_null_values = false; + let mut without_array_wrapper = false; + while self.peek_token().token == Token::Comma { + self.next_token(); + if self.parse_keyword(Keyword::ROOT) { + self.expect_token(&Token::LParen)?; + root = Some(self.parse_literal_string()?); + self.expect_token(&Token::RParen)?; + } else if self.parse_keyword(Keyword::INCLUDE_NULL_VALUES) { + include_null_values = true; + } else if self.parse_keyword(Keyword::WITHOUT_ARRAY_WRAPPER) { + without_array_wrapper = true; + } + } + Ok(ForClause::Json { + for_json, + root, + include_null_values, + without_array_wrapper, + }) + } + /// Parse a CTE (`alias [( col1, col2, ... 
)] AS (subquery)`) pub fn parse_cte(&mut self) -> Result { let name = self.parse_identifier()?; @@ -6355,6 +6572,8 @@ impl<'a> Parser<'a> { pub fn parse_show(&mut self) -> Result { let extended = self.parse_keyword(Keyword::EXTENDED); let full = self.parse_keyword(Keyword::FULL); + let session = self.parse_keyword(Keyword::SESSION); + let global = self.parse_keyword(Keyword::GLOBAL); if self .parse_one_of_keywords(&[Keyword::COLUMNS, Keyword::FIELDS]) .is_some() @@ -6375,9 +6594,10 @@ impl<'a> Parser<'a> { } else if self.parse_keyword(Keyword::VARIABLES) && dialect_of!(self is MySqlDialect | GenericDialect) { - // TODO: Support GLOBAL|SESSION Ok(Statement::ShowVariables { filter: self.parse_show_statement_filter()?, + session, + global, }) } else { Ok(Statement::ShowVariable { @@ -6649,9 +6869,20 @@ impl<'a> Parser<'a> { // `parse_derived_table_factor` below will return success after parsing the // subquery, followed by the closing ')', and the alias of the derived table. // In the example above this is case (3). - return_ok_if_some!( + if let Some(mut table) = self.maybe_parse(|parser| parser.parse_derived_table_factor(NotLateral)) - ); + { + while let Some(kw) = self.parse_one_of_keywords(&[Keyword::PIVOT, Keyword::UNPIVOT]) + { + table = match kw { + Keyword::PIVOT => self.parse_pivot_table_factor(table)?, + Keyword::UNPIVOT => self.parse_unpivot_table_factor(table)?, + _ => unreachable!(), + } + } + return Ok(table); + } + // A parsing error from `parse_derived_table_factor` indicates that the '(' we've // recently consumed does not start a derived table (cases 1, 2, or 4). // `maybe_parse` will ignore such an error and rewind to be after the opening '('. 
@@ -7115,21 +7346,23 @@ impl<'a> Parser<'a> { let table = self.parse_keyword(Keyword::TABLE); let table_name = self.parse_object_name()?; let is_mysql = dialect_of!(self is MySqlDialect); - let columns = self.parse_parenthesized_column_list(Optional, is_mysql)?; - let partitioned = if self.parse_keyword(Keyword::PARTITION) { - self.expect_token(&Token::LParen)?; - let r = Some(self.parse_comma_separated(Parser::parse_expr)?); - self.expect_token(&Token::RParen)?; - r - } else { - None - }; + let (columns, partitioned, after_columns, source) = + if self.parse_keywords(&[Keyword::DEFAULT, Keyword::VALUES]) { + (vec![], None, vec![], None) + } else { + let columns = self.parse_parenthesized_column_list(Optional, is_mysql)?; - // Hive allows you to specify columns after partitions as well if you want. - let after_columns = self.parse_parenthesized_column_list(Optional, false)?; + let partitioned = self.parse_insert_partition()?; + + // Hive allows you to specify columns after partitions as well if you want. 
+ let after_columns = self.parse_parenthesized_column_list(Optional, false)?; + + let source = Some(Box::new(self.parse_query()?)); + + (columns, partitioned, after_columns, source) + }; - let source = Box::new(self.parse_query()?); let on = if self.parse_keyword(Keyword::ON) { if self.parse_keyword(Keyword::CONFLICT) { let conflict_target = @@ -7200,6 +7433,17 @@ impl<'a> Parser<'a> { } } + pub fn parse_insert_partition(&mut self) -> Result>, ParserError> { + if self.parse_keyword(Keyword::PARTITION) { + self.expect_token(&Token::LParen)?; + let partition_cols = Some(self.parse_comma_separated(Parser::parse_expr)?); + self.expect_token(&Token::RParen)?; + Ok(partition_cols) + } else { + Ok(None) + } + } + pub fn parse_update(&mut self) -> Result { let table = self.parse_table_and_joins()?; self.expect_keyword(Keyword::SET)?; @@ -7646,6 +7890,12 @@ impl<'a> Parser<'a> { }) } + pub fn parse_end(&mut self) -> Result { + Ok(Statement::Commit { + chain: self.parse_commit_rollback_chain()?, + }) + } + pub fn parse_transaction_modes(&mut self) -> Result, ParserError> { let mut modes = vec![]; let mut required = false; @@ -7689,9 +7939,10 @@ impl<'a> Parser<'a> { } pub fn parse_rollback(&mut self) -> Result { - Ok(Statement::Rollback { - chain: self.parse_commit_rollback_chain()?, - }) + let chain = self.parse_commit_rollback_chain()?; + let savepoint = self.parse_rollback_savepoint()?; + + Ok(Statement::Rollback { chain, savepoint }) } pub fn parse_commit_rollback_chain(&mut self) -> Result { @@ -7705,6 +7956,17 @@ impl<'a> Parser<'a> { } } + pub fn parse_rollback_savepoint(&mut self) -> Result, ParserError> { + if self.parse_keyword(Keyword::TO) { + let _ = self.parse_keyword(Keyword::SAVEPOINT); + let savepoint = self.parse_identifier()?; + + Ok(Some(savepoint)) + } else { + Ok(None) + } + } + pub fn parse_deallocate(&mut self) -> Result { let prepare = self.parse_keyword(Keyword::PREPARE); let name = self.parse_identifier()?; @@ -8385,7 +8647,7 @@ mod tests { 
test_parse_data_type!( dialect, "CHARACTER(20)", - DataType::Character(Some(CharacterLength { + DataType::Character(Some(CharacterLength::IntegerLength { length: 20, unit: None })) @@ -8394,7 +8656,7 @@ mod tests { test_parse_data_type!( dialect, "CHARACTER(20 CHARACTERS)", - DataType::Character(Some(CharacterLength { + DataType::Character(Some(CharacterLength::IntegerLength { length: 20, unit: Some(CharLengthUnits::Characters) })) @@ -8403,7 +8665,7 @@ mod tests { test_parse_data_type!( dialect, "CHARACTER(20 OCTETS)", - DataType::Character(Some(CharacterLength { + DataType::Character(Some(CharacterLength::IntegerLength { length: 20, unit: Some(CharLengthUnits::Octets) })) @@ -8414,7 +8676,7 @@ mod tests { test_parse_data_type!( dialect, "CHAR(20)", - DataType::Char(Some(CharacterLength { + DataType::Char(Some(CharacterLength::IntegerLength { length: 20, unit: None })) @@ -8423,7 +8685,7 @@ mod tests { test_parse_data_type!( dialect, "CHAR(20 CHARACTERS)", - DataType::Char(Some(CharacterLength { + DataType::Char(Some(CharacterLength::IntegerLength { length: 20, unit: Some(CharLengthUnits::Characters) })) @@ -8432,7 +8694,7 @@ mod tests { test_parse_data_type!( dialect, "CHAR(20 OCTETS)", - DataType::Char(Some(CharacterLength { + DataType::Char(Some(CharacterLength::IntegerLength { length: 20, unit: Some(CharLengthUnits::Octets) })) @@ -8441,7 +8703,7 @@ mod tests { test_parse_data_type!( dialect, "CHARACTER VARYING(20)", - DataType::CharacterVarying(Some(CharacterLength { + DataType::CharacterVarying(Some(CharacterLength::IntegerLength { length: 20, unit: None })) @@ -8450,7 +8712,7 @@ mod tests { test_parse_data_type!( dialect, "CHARACTER VARYING(20 CHARACTERS)", - DataType::CharacterVarying(Some(CharacterLength { + DataType::CharacterVarying(Some(CharacterLength::IntegerLength { length: 20, unit: Some(CharLengthUnits::Characters) })) @@ -8459,7 +8721,7 @@ mod tests { test_parse_data_type!( dialect, "CHARACTER VARYING(20 OCTETS)", - 
DataType::CharacterVarying(Some(CharacterLength { + DataType::CharacterVarying(Some(CharacterLength::IntegerLength { length: 20, unit: Some(CharLengthUnits::Octets) })) @@ -8468,7 +8730,7 @@ mod tests { test_parse_data_type!( dialect, "CHAR VARYING(20)", - DataType::CharVarying(Some(CharacterLength { + DataType::CharVarying(Some(CharacterLength::IntegerLength { length: 20, unit: None })) @@ -8477,7 +8739,7 @@ mod tests { test_parse_data_type!( dialect, "CHAR VARYING(20 CHARACTERS)", - DataType::CharVarying(Some(CharacterLength { + DataType::CharVarying(Some(CharacterLength::IntegerLength { length: 20, unit: Some(CharLengthUnits::Characters) })) @@ -8486,7 +8748,7 @@ mod tests { test_parse_data_type!( dialect, "CHAR VARYING(20 OCTETS)", - DataType::CharVarying(Some(CharacterLength { + DataType::CharVarying(Some(CharacterLength::IntegerLength { length: 20, unit: Some(CharLengthUnits::Octets) })) @@ -8495,7 +8757,7 @@ mod tests { test_parse_data_type!( dialect, "VARCHAR(20)", - DataType::Varchar(Some(CharacterLength { + DataType::Varchar(Some(CharacterLength::IntegerLength { length: 20, unit: None })) diff --git a/src/test_utils.rs b/src/test_utils.rs index 76a3e073..26cfec46 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -102,6 +102,11 @@ impl TestedDialects { /// Ensures that `sql` parses as a single [Statement] for all tested /// dialects. /// + /// In general, the canonical SQL should be the same (see crate + /// documentation for rationale) and you should prefer the `verified_` + /// variants in testing, such as [`verified_statement`] or + /// [`verified_query`]. 
+ /// /// If `canonical` is non empty,this function additionally asserts /// that: /// diff --git a/src/tokenizer.rs b/src/tokenizer.rs index ee350a1c..0400b21c 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -534,12 +534,7 @@ impl<'a> Tokenizer<'a> { /// Tokenize the statement and produce a vector of tokens pub fn tokenize(&mut self) -> Result, TokenizerError> { let twl = self.tokenize_with_location()?; - - let mut tokens: Vec = Vec::with_capacity(twl.len()); - for token_with_location in twl { - tokens.push(token_with_location.token); - } - Ok(tokens) + Ok(twl.into_iter().map(|t| t.token).collect()) } /// Tokenize the statement and produce a vector of tokens with location information diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index befdf512..1d0923b4 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -85,7 +85,7 @@ fn parse_insert_values() { Statement::Insert { table_name, columns, - source, + source: Some(source), .. } => { assert_eq!(table_name.to_string(), expected_table_name); @@ -93,7 +93,7 @@ fn parse_insert_values() { for (index, column) in columns.iter().enumerate() { assert_eq!(column, &Ident::new(expected_columns[index].clone())); } - match &*source.body { + match *source.body { SetExpr::Values(Values { rows, .. }) => { assert_eq!(rows.as_slice(), expected_rows) } @@ -107,6 +107,111 @@ fn parse_insert_values() { verified_stmt("INSERT INTO customer WITH foo AS (SELECT 1) SELECT * FROM foo UNION VALUES (1)"); } +#[test] +fn parse_insert_default_values() { + let insert_with_default_values = verified_stmt("INSERT INTO test_table DEFAULT VALUES"); + + match insert_with_default_values { + Statement::Insert { + after_columns, + columns, + on, + partitioned, + returning, + source, + table_name, + .. 
+ } => { + assert_eq!(columns, vec![]); + assert_eq!(after_columns, vec![]); + assert_eq!(on, None); + assert_eq!(partitioned, None); + assert_eq!(returning, None); + assert_eq!(source, None); + assert_eq!(table_name, ObjectName(vec!["test_table".into()])); + } + _ => unreachable!(), + } + + let insert_with_default_values_and_returning = + verified_stmt("INSERT INTO test_table DEFAULT VALUES RETURNING test_column"); + + match insert_with_default_values_and_returning { + Statement::Insert { + after_columns, + columns, + on, + partitioned, + returning, + source, + table_name, + .. + } => { + assert_eq!(after_columns, vec![]); + assert_eq!(columns, vec![]); + assert_eq!(on, None); + assert_eq!(partitioned, None); + assert!(returning.is_some()); + assert_eq!(source, None); + assert_eq!(table_name, ObjectName(vec!["test_table".into()])); + } + _ => unreachable!(), + } + + let insert_with_default_values_and_on_conflict = + verified_stmt("INSERT INTO test_table DEFAULT VALUES ON CONFLICT DO NOTHING"); + + match insert_with_default_values_and_on_conflict { + Statement::Insert { + after_columns, + columns, + on, + partitioned, + returning, + source, + table_name, + .. 
+ } => { + assert_eq!(after_columns, vec![]); + assert_eq!(columns, vec![]); + assert!(on.is_some()); + assert_eq!(partitioned, None); + assert_eq!(returning, None); + assert_eq!(source, None); + assert_eq!(table_name, ObjectName(vec!["test_table".into()])); + } + _ => unreachable!(), + } + + let insert_with_columns_and_default_values = "INSERT INTO test_table (test_col) DEFAULT VALUES"; + assert_eq!( + ParserError::ParserError( + "Expected SELECT, VALUES, or a subquery in the query body, found: DEFAULT".to_string() + ), + parse_sql_statements(insert_with_columns_and_default_values).unwrap_err() + ); + + let insert_with_default_values_and_hive_after_columns = + "INSERT INTO test_table DEFAULT VALUES (some_column)"; + assert_eq!( + ParserError::ParserError("Expected end of statement, found: (".to_string()), + parse_sql_statements(insert_with_default_values_and_hive_after_columns).unwrap_err() + ); + + let insert_with_default_values_and_hive_partition = + "INSERT INTO test_table DEFAULT VALUES PARTITION (some_column)"; + assert_eq!( + ParserError::ParserError("Expected end of statement, found: PARTITION".to_string()), + parse_sql_statements(insert_with_default_values_and_hive_partition).unwrap_err() + ); + + let insert_with_default_values_and_values_list = "INSERT INTO test_table DEFAULT VALUES (1)"; + assert_eq!( + ParserError::ParserError("Expected end of statement, found: (".to_string()), + parse_sql_statements(insert_with_default_values_and_values_list).unwrap_err() + ); +} + #[test] fn parse_insert_sqlite() { let dialect = SQLiteDialect {}; @@ -270,6 +375,7 @@ fn parse_update_set_from() { offset: None, fetch: None, locks: vec![], + for_clause: None, }), alias: Some(TableAlias { name: Ident::new("t2"), @@ -2381,7 +2487,7 @@ fn parse_create_table() { vec![ ColumnDef { name: "name".into(), - data_type: DataType::Varchar(Some(CharacterLength { + data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { length: 100, unit: None, })), @@ -2756,6 +2862,7 @@ fn 
parse_create_table_as_table() { offset: None, fetch: None, locks: vec![], + for_clause: None, }); match verified_stmt(sql1) { @@ -2780,6 +2887,7 @@ fn parse_create_table_as_table() { offset: None, fetch: None, locks: vec![], + for_clause: None, }); match verified_stmt(sql2) { @@ -2929,7 +3037,7 @@ fn parse_create_external_table() { vec![ ColumnDef { name: "name".into(), - data_type: DataType::Varchar(Some(CharacterLength { + data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { length: 100, unit: None, })), @@ -3000,7 +3108,7 @@ fn parse_create_or_replace_external_table() { columns, vec![ColumnDef { name: "name".into(), - data_type: DataType::Varchar(Some(CharacterLength { + data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { length: 100, unit: None, })), @@ -4082,6 +4190,7 @@ fn parse_interval_and_or_xor() { offset: None, fetch: None, locks: vec![], + for_clause: None, }))]; assert_eq!(actual_ast, expected_ast); @@ -6227,15 +6336,52 @@ fn parse_commit() { one_statement_parses_to("COMMIT TRANSACTION", "COMMIT"); } +#[test] +fn parse_end() { + one_statement_parses_to("END AND NO CHAIN", "COMMIT"); + one_statement_parses_to("END WORK AND NO CHAIN", "COMMIT"); + one_statement_parses_to("END TRANSACTION AND NO CHAIN", "COMMIT"); + one_statement_parses_to("END WORK AND CHAIN", "COMMIT AND CHAIN"); + one_statement_parses_to("END TRANSACTION AND CHAIN", "COMMIT AND CHAIN"); + one_statement_parses_to("END WORK", "COMMIT"); + one_statement_parses_to("END TRANSACTION", "COMMIT"); +} + #[test] fn parse_rollback() { match verified_stmt("ROLLBACK") { - Statement::Rollback { chain: false } => (), + Statement::Rollback { + chain: false, + savepoint: None, + } => (), _ => unreachable!(), } match verified_stmt("ROLLBACK AND CHAIN") { - Statement::Rollback { chain: true } => (), + Statement::Rollback { + chain: true, + savepoint: None, + } => (), + _ => unreachable!(), + } + + match verified_stmt("ROLLBACK TO SAVEPOINT test1") { + Statement::Rollback { + 
chain: false, + savepoint, + } => { + assert_eq!(savepoint, Some(Ident::new("test1"))); + } + _ => unreachable!(), + } + + match verified_stmt("ROLLBACK AND CHAIN TO SAVEPOINT test1") { + Statement::Rollback { + chain: true, + savepoint, + } => { + assert_eq!(savepoint, Some(Ident::new("test1"))); + } _ => unreachable!(), } @@ -6246,6 +6392,11 @@ fn parse_rollback() { one_statement_parses_to("ROLLBACK TRANSACTION AND CHAIN", "ROLLBACK AND CHAIN"); one_statement_parses_to("ROLLBACK WORK", "ROLLBACK"); one_statement_parses_to("ROLLBACK TRANSACTION", "ROLLBACK"); + one_statement_parses_to("ROLLBACK TO test1", "ROLLBACK TO SAVEPOINT test1"); + one_statement_parses_to( + "ROLLBACK AND CHAIN TO test1", + "ROLLBACK AND CHAIN TO SAVEPOINT test1", + ); } #[test] @@ -6656,6 +6807,7 @@ fn parse_merge() { offset: None, fetch: None, locks: vec![], + for_clause: None, }), alias: Some(TableAlias { name: Ident { @@ -7859,3 +8011,25 @@ fn parse_binary_operators_without_whitespace() { "SELECT tbl1.field % tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id", ); } + +#[test] +fn test_savepoint() { + match verified_stmt("SAVEPOINT test1") { + Statement::Savepoint { name } => { + assert_eq!(Ident::new("test1"), name); + } + _ => unreachable!(), + } +} + +#[test] +fn test_release_savepoint() { + match verified_stmt("RELEASE SAVEPOINT test1") { + Statement::ReleaseSavepoint { name } => { + assert_eq!(Ident::new("test1"), name); + } + _ => unreachable!(), + } + + one_statement_parses_to("RELEASE test1", "RELEASE SAVEPOINT test1"); +} diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 4aa993fa..7d5beca9 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -96,6 +96,7 @@ fn parse_create_procedure() { offset: None, fetch: None, locks: vec![], + for_clause: None, order_by: vec![], body: Box::new(SetExpr::Select(Box::new(Select { distinct: None, @@ -127,7 +128,7 @@ fn parse_create_procedure() { value: "@bar".into(), quote_style: None }, - 
data_type: DataType::Varchar(Some(CharacterLength { + data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { length: 256, unit: None })) @@ -431,6 +432,56 @@ fn parse_like() { chk(true); } +#[test] +fn parse_for_clause() { + ms_and_generic().verified_stmt("SELECT a FROM t FOR JSON PATH"); + ms_and_generic().verified_stmt("SELECT b FROM t FOR JSON AUTO"); + ms_and_generic().verified_stmt("SELECT c FROM t FOR JSON AUTO, WITHOUT_ARRAY_WRAPPER"); + ms_and_generic().verified_stmt("SELECT 1 FROM t FOR JSON PATH, ROOT('x'), INCLUDE_NULL_VALUES"); + ms_and_generic().verified_stmt("SELECT 2 FROM t FOR XML AUTO"); + ms_and_generic().verified_stmt("SELECT 3 FROM t FOR XML AUTO, TYPE, ELEMENTS"); + ms_and_generic().verified_stmt("SELECT * FROM t WHERE x FOR XML AUTO, ELEMENTS"); + ms_and_generic().verified_stmt("SELECT x FROM t ORDER BY y FOR XML AUTO, ELEMENTS"); + ms_and_generic().verified_stmt("SELECT y FROM t FOR XML PATH('x'), ROOT('y'), ELEMENTS"); + ms_and_generic().verified_stmt("SELECT z FROM t FOR XML EXPLICIT, BINARY BASE64"); + ms_and_generic().verified_stmt("SELECT * FROM t FOR XML RAW('x')"); + ms_and_generic().verified_stmt("SELECT * FROM t FOR BROWSE"); +} + +#[test] +fn dont_parse_trailing_for() { + assert!(ms() + .run_parser_method("SELECT * FROM foo FOR", |p| p.parse_query()) + .is_err()); +} + +#[test] +fn parse_for_json_expect_ast() { + assert_eq!( + ms().verified_query("SELECT * FROM t FOR JSON PATH, ROOT('root')") + .for_clause + .unwrap(), + ForClause::Json { + for_json: ForJson::Path, + root: Some("root".into()), + without_array_wrapper: false, + include_null_values: false, + } + ); +} + +#[test] +fn parse_cast_varchar_max() { + ms_and_generic().verified_expr("CAST('foo' AS VARCHAR(MAX))"); +} + +#[test] +fn parse_convert() { + ms().verified_expr("CONVERT(VARCHAR(MAX), 'foo')"); + ms().verified_expr("CONVERT(VARCHAR(10), 'foo')"); + ms().verified_expr("CONVERT(DECIMAL(10,5), 12.55)"); +} + #[test] fn parse_similar_to() { fn chk(negated: 
bool) { @@ -540,6 +591,7 @@ fn parse_substring_in_select() { offset: None, fetch: None, locks: vec![], + for_clause: None, }), query ); diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index f1acf3c0..1c03c5c4 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -507,6 +507,18 @@ fn parse_create_table_comment_character_set() { } } +#[test] +fn parse_create_table_gencol() { + let sql_default = "CREATE TABLE t1 (a INT, b INT GENERATED ALWAYS AS (a * 2))"; + mysql_and_generic().verified_stmt(sql_default); + + let sql_virt = "CREATE TABLE t1 (a INT, b INT GENERATED ALWAYS AS (a * 2) VIRTUAL)"; + mysql_and_generic().verified_stmt(sql_virt); + + let sql_stored = "CREATE TABLE t1 (a INT, b INT GENERATED ALWAYS AS (a * 2) STORED)"; + mysql_and_generic().verified_stmt(sql_stored); +} + #[test] fn parse_quote_identifiers() { let sql = "CREATE TABLE `PRIMARY` (`BEGIN` INT PRIMARY KEY)"; @@ -566,6 +578,7 @@ fn parse_escaped_quote_identifiers_with_escape() { offset: None, fetch: None, locks: vec![], + for_clause: None, })) ); } @@ -609,6 +622,7 @@ fn parse_escaped_quote_identifiers_with_no_escape() { offset: None, fetch: None, locks: vec![], + for_clause: None, })) ); } @@ -649,6 +663,7 @@ fn parse_escaped_backticks_with_escape() { offset: None, fetch: None, locks: vec![], + for_clause: None, })) ); } @@ -689,6 +704,7 @@ fn parse_escaped_backticks_with_no_escape() { offset: None, fetch: None, locks: vec![], + for_clause: None, })) ); } @@ -937,7 +953,7 @@ fn parse_simple_insert() { assert_eq!(vec![Ident::new("title"), Ident::new("priority")], columns); assert!(on.is_none()); assert_eq!( - Box::new(Query { + Some(Box::new(Query { with: None, body: Box::new(SetExpr::Values(Values { explicit_row: false, @@ -964,7 +980,8 @@ fn parse_simple_insert() { offset: None, fetch: None, locks: vec![], - }), + for_clause: None, + })), source ); } @@ -990,7 +1007,7 @@ fn parse_ignore_insert() { assert!(on.is_none()); assert!(ignore); assert_eq!( - 
Box::new(Query { + Some(Box::new(Query { with: None, body: Box::new(SetExpr::Values(Values { explicit_row: false, @@ -1004,8 +1021,9 @@ fn parse_ignore_insert() { limit_by: vec![], offset: None, fetch: None, - locks: vec![] - }), + locks: vec![], + for_clause: None, + })), source ); } @@ -1029,7 +1047,7 @@ fn parse_empty_row_insert() { assert!(columns.is_empty()); assert!(on.is_none()); assert_eq!( - Box::new(Query { + Some(Box::new(Query { with: None, body: Box::new(SetExpr::Values(Values { explicit_row: false, @@ -1041,7 +1059,8 @@ fn parse_empty_row_insert() { offset: None, fetch: None, locks: vec![], - }), + for_clause: None, + })), source ); } @@ -1077,7 +1096,7 @@ fn parse_insert_with_on_duplicate_update() { columns ); assert_eq!( - Box::new(Query { + Some(Box::new(Query { with: None, body: Box::new(SetExpr::Values(Values { explicit_row: false, @@ -1100,7 +1119,8 @@ fn parse_insert_with_on_duplicate_update() { offset: None, fetch: None, locks: vec![], - }), + for_clause: None, + })), source ); assert_eq!( @@ -1490,6 +1510,7 @@ fn parse_substring_in_select() { offset: None, fetch: None, locks: vec![], + for_clause: None, }), query ); @@ -1503,6 +1524,12 @@ fn parse_show_variables() { mysql_and_generic().verified_stmt("SHOW VARIABLES"); mysql_and_generic().verified_stmt("SHOW VARIABLES LIKE 'admin%'"); mysql_and_generic().verified_stmt("SHOW VARIABLES WHERE value = '3306'"); + mysql_and_generic().verified_stmt("SHOW GLOBAL VARIABLES"); + mysql_and_generic().verified_stmt("SHOW GLOBAL VARIABLES LIKE 'admin%'"); + mysql_and_generic().verified_stmt("SHOW GLOBAL VARIABLES WHERE value = '3306'"); + mysql_and_generic().verified_stmt("SHOW SESSION VARIABLES"); + mysql_and_generic().verified_stmt("SHOW SESSION VARIABLES LIKE 'admin%'"); + mysql_and_generic().verified_stmt("SHOW GLOBAL VARIABLES WHERE value = '3306'"); } #[test] @@ -1785,6 +1812,7 @@ fn parse_hex_string_introducer() { offset: None, fetch: None, locks: vec![], + for_clause: None, })) ) } @@ -1827,3 
+1855,18 @@ fn parse_drop_temporary_table() { _ => unreachable!(), } } + +#[test] +fn parse_convert_using() { + // https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_convert + + // CONVERT(expr USING transcoding_name) + mysql().verified_only_select("SELECT CONVERT('x' USING latin1)"); + mysql().verified_only_select("SELECT CONVERT(my_column USING utf8mb4) FROM my_table"); + + // CONVERT(expr, type) + mysql().verified_only_select("SELECT CONVERT('abc', CHAR(60))"); + mysql().verified_only_select("SELECT CONVERT(123.456, DECIMAL(5,2))"); + // with a type + a charset + mysql().verified_only_select("SELECT CONVERT('test', CHAR CHARACTER SET utf8mb4)"); +} diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index d854db9b..c72696af 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -350,10 +350,12 @@ fn parse_create_table_with_defaults() { }, ColumnDef { name: "first_name".into(), - data_type: DataType::CharacterVarying(Some(CharacterLength { - length: 45, - unit: None - })), + data_type: DataType::CharacterVarying(Some( + CharacterLength::IntegerLength { + length: 45, + unit: None + } + )), collation: None, options: vec![ColumnOptionDef { name: None, @@ -362,10 +364,12 @@ fn parse_create_table_with_defaults() { }, ColumnDef { name: "last_name".into(), - data_type: DataType::CharacterVarying(Some(CharacterLength { - length: 45, - unit: None - })), + data_type: DataType::CharacterVarying(Some( + CharacterLength::IntegerLength { + length: 45, + unit: None + } + )), collation: Some(ObjectName(vec![Ident::with_quote('"', "es_ES")])), options: vec![ColumnOptionDef { name: None, @@ -374,10 +378,12 @@ fn parse_create_table_with_defaults() { }, ColumnDef { name: "email".into(), - data_type: DataType::CharacterVarying(Some(CharacterLength { - length: 50, - unit: None - })), + data_type: DataType::CharacterVarying(Some( + CharacterLength::IntegerLength { + length: 50, + unit: None + } + )), collation: None, options: 
vec![], }, @@ -1005,6 +1011,7 @@ fn parse_copy_to() { offset: None, fetch: None, locks: vec![], + for_clause: None, })), to: true, target: CopyTarget::File { @@ -1373,7 +1380,7 @@ fn parse_prepare() { Statement::Insert { table_name, columns, - source, + source: Some(source), .. } => { assert_eq!(table_name.to_string(), "customers"); @@ -1422,7 +1429,7 @@ fn parse_prepare() { fn parse_pg_on_conflict() { let stmt = pg_and_generic().verified_stmt( "INSERT INTO distributors (did, dname) \ - VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \ + VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \ ON CONFLICT(did) \ DO UPDATE SET dname = EXCLUDED.dname", ); @@ -1452,7 +1459,7 @@ fn parse_pg_on_conflict() { let stmt = pg_and_generic().verified_stmt( "INSERT INTO distributors (did, dname, area) \ - VALUES (5, 'Gizmo Transglobal', 'Mars'), (6, 'Associated Computing, Inc', 'Venus') \ + VALUES (5, 'Gizmo Transglobal', 'Mars'), (6, 'Associated Computing, Inc', 'Venus') \ ON CONFLICT(did, area) \ DO UPDATE SET dname = EXCLUDED.dname, area = EXCLUDED.area", ); @@ -1491,7 +1498,7 @@ fn parse_pg_on_conflict() { let stmt = pg_and_generic().verified_stmt( "INSERT INTO distributors (did, dname) \ - VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \ + VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \ ON CONFLICT DO NOTHING", ); match stmt { @@ -1510,7 +1517,7 @@ fn parse_pg_on_conflict() { let stmt = pg_and_generic().verified_stmt( "INSERT INTO distributors (did, dname, dsize) \ - VALUES (5, 'Gizmo Transglobal', 1000), (6, 'Associated Computing, Inc', 1010) \ + VALUES (5, 'Gizmo Transglobal', 1000), (6, 'Associated Computing, Inc', 1010) \ ON CONFLICT(did) \ DO UPDATE SET dname = $1 WHERE dsize > $2", ); @@ -1547,7 +1554,7 @@ fn parse_pg_on_conflict() { let stmt = pg_and_generic().verified_stmt( "INSERT INTO distributors (did, dname, dsize) \ - VALUES (5, 'Gizmo Transglobal', 1000), (6, 'Associated Computing, Inc', 
1010) \ + VALUES (5, 'Gizmo Transglobal', 1000), (6, 'Associated Computing, Inc', 1010) \ ON CONFLICT ON CONSTRAINT distributors_did_pkey \ DO UPDATE SET dname = $1 WHERE dsize > $2", ); @@ -2055,6 +2062,7 @@ fn parse_array_subquery_expr() { offset: None, fetch: None, locks: vec![], + for_clause: None, })), expr_from_projection(only(&select.projection)), ); @@ -2086,16 +2094,6 @@ fn test_transaction_statement() { ); } -#[test] -fn test_savepoint() { - match pg().verified_stmt("SAVEPOINT test1") { - Statement::Savepoint { name } => { - assert_eq!(Ident::new("test1"), name); - } - _ => unreachable!(), - } -} - #[test] fn test_json() { let sql = "SELECT params ->> 'name' FROM events"; @@ -3945,7 +3943,7 @@ fn parse_mirror_table_mapping_v2_missing() { #[test] fn parse_abort() { match pg().verified_stmt("ROLLBACK") { - Statement::Rollback { chain: false } => (), + Statement::Rollback { chain: false, .. } => (), _ => unreachable!(), } pg().one_statement_parses_to("ABORT", "ROLLBACK"); diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index 2b7e34cc..906327b1 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -1140,3 +1140,10 @@ fn parse_division_correctly() { "SELECT tbl1.field / tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id", ); } + +#[test] +fn parse_pivot_of_table_factor_derived() { + snowflake().verified_stmt( + "SELECT * FROM (SELECT place_id, weekday, open FROM times AS p) PIVOT(max(open) FOR weekday IN (0, 1, 2, 3, 4, 5, 6)) AS p (place_id, open_sun, open_mon, open_tue, open_wed, open_thu, open_fri, open_sat)" + ); +} diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 4935f1f5..cc0d53b1 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -22,6 +22,7 @@ use test_utils::*; use sqlparser::ast::SelectItem::UnnamedExpr; use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, SQLiteDialect}; +use sqlparser::parser::ParserOptions; use 
sqlparser::tokenizer::Token; #[test] @@ -204,6 +205,18 @@ fn parse_create_sqlite_quote() { } } +#[test] +fn parse_create_table_gencol() { + let sql_default = "CREATE TABLE t1 (a INT, b INT GENERATED ALWAYS AS (a * 2))"; + sqlite_and_generic().verified_stmt(sql_default); + + let sql_virt = "CREATE TABLE t1 (a INT, b INT GENERATED ALWAYS AS (a * 2) VIRTUAL)"; + sqlite_and_generic().verified_stmt(sql_virt); + + let sql_stored = "CREATE TABLE t1 (a INT, b INT GENERATED ALWAYS AS (a * 2) STORED)"; + sqlite_and_generic().verified_stmt(sql_stored); +} + #[test] fn test_placeholder() { // In postgres, this would be the absolute value operator '@' applied to the column 'xxx' @@ -392,6 +405,32 @@ fn parse_attach_database() { } } +#[test] +fn parse_where_in_empty_list() { + let sql = "SELECT * FROM t1 WHERE a IN ()"; + let select = sqlite().verified_only_select(sql); + if let Expr::InList { list, .. } = select.selection.as_ref().unwrap() { + assert_eq!(list.len(), 0); + } else { + unreachable!() + } + + sqlite_with_options(ParserOptions::new().with_trailing_commas(true)).one_statement_parses_to( + "SELECT * FROM t1 WHERE a IN (,)", + "SELECT * FROM t1 WHERE a IN ()", + ); +} + +#[test] +fn invalid_empty_list() { + let sql = "SELECT * FROM t1 WHERE a IN (,,)"; + let sqlite = sqlite_with_options(ParserOptions::new().with_trailing_commas(true)); + assert_eq!( + "sql parser error: Expected an expression:, found: ,", + sqlite.parse_sql_statements(sql).unwrap_err().to_string() + ); +} + fn sqlite() -> TestedDialects { TestedDialects { dialects: vec![Box::new(SQLiteDialect {})], @@ -399,9 +438,15 @@ fn sqlite() -> TestedDialects { } } +fn sqlite_with_options(options: ParserOptions) -> TestedDialects { + TestedDialects { + dialects: vec![Box::new(SQLiteDialect {})], + options: Some(options), + } +} + fn sqlite_and_generic() -> TestedDialects { TestedDialects { - // we don't have a separate SQLite dialect, so test only the generic dialect for now dialects: vec![Box::new(SQLiteDialect 
{}), Box::new(GenericDialect {})], options: None, }