Skip to content

Commit

Permalink
Merge pull request #4 from AikidoSec/dollar
Browse files Browse the repository at this point in the history
SQLite: Allow dollar signs in placeholder names
  • Loading branch information
willem-delbare authored Dec 30, 2024
2 parents 4fdeb5c + f125b1f commit 3f5fdeb
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 20 deletions.
13 changes: 8 additions & 5 deletions src/ast/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1123,15 +1123,18 @@ pub enum ColumnOption {
/// `DEFAULT <restricted-expr>`
Default(Expr),

/// ClickHouse supports `MATERIALIZE`, `EPHEMERAL` and `ALIAS` expr to generate default values.
/// `MATERIALIZE <expr>`
/// Syntax: `b INT MATERIALIZE (a + 1)`
///
/// [ClickHouse](https://clickhouse.com/docs/en/sql-reference/statements/create/table#default_values)
/// `MATERIALIZE <expr>`
Materialized(Expr),
/// `EPHEMERAL [<expr>]`
///
/// [ClickHouse](https://clickhouse.com/docs/en/sql-reference/statements/create/table#default_values)
Ephemeral(Option<Expr>),
/// `ALIAS <expr>`
///
/// [ClickHouse](https://clickhouse.com/docs/en/sql-reference/statements/create/table#default_values)
Alias(Expr),

/// `{ PRIMARY KEY | UNIQUE } [<constraint_characteristics>]`
Expand Down Expand Up @@ -1330,7 +1333,7 @@ pub enum GeneratedExpressionMode {
#[must_use]
fn display_constraint_name(name: &'_ Option<Ident>) -> impl fmt::Display + '_ {
struct ConstraintName<'a>(&'a Option<Ident>);
impl<'a> fmt::Display for ConstraintName<'a> {
impl fmt::Display for ConstraintName<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(name) = self.0 {
write!(f, "CONSTRAINT {name} ")?;
Expand All @@ -1351,7 +1354,7 @@ fn display_option<'a, T: fmt::Display>(
option: &'a Option<T>,
) -> impl fmt::Display + 'a {
struct OptionDisplay<'a, T>(&'a str, &'a str, &'a Option<T>);
impl<'a, T: fmt::Display> fmt::Display for OptionDisplay<'a, T> {
impl<T: fmt::Display> fmt::Display for OptionDisplay<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(inner) = self.2 {
let (prefix, postfix) = (self.0, self.1);
Expand Down
2 changes: 1 addition & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ where
sep: &'static str,
}

impl<'a, T> fmt::Display for DisplaySeparated<'a, T>
impl<T> fmt::Display for DisplaySeparated<'_, T>
where
T: fmt::Display,
{
Expand Down
2 changes: 1 addition & 1 deletion src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,7 @@ impl fmt::Display for Join {
}
fn suffix(constraint: &'_ JoinConstraint) -> impl fmt::Display + '_ {
struct Suffix<'a>(&'a JoinConstraint);
impl<'a> fmt::Display for Suffix<'a> {
impl fmt::Display for Suffix<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.0 {
JoinConstraint::On(expr) => write!(f, " ON {expr}"),
Expand Down
6 changes: 3 additions & 3 deletions src/ast/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ pub struct EscapeQuotedString<'a> {
quote: char,
}

impl<'a> fmt::Display for EscapeQuotedString<'a> {
impl fmt::Display for EscapeQuotedString<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// EscapeQuotedString doesn't know which mode of escape was
// chosen by the user. So this code must to correctly display
Expand Down Expand Up @@ -325,7 +325,7 @@ pub fn escape_double_quote_string(s: &str) -> EscapeQuotedString<'_> {

pub struct EscapeEscapedStringLiteral<'a>(&'a str);

impl<'a> fmt::Display for EscapeEscapedStringLiteral<'a> {
impl fmt::Display for EscapeEscapedStringLiteral<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for c in self.0.chars() {
match c {
Expand Down Expand Up @@ -359,7 +359,7 @@ pub fn escape_escaped_string(s: &str) -> EscapeEscapedStringLiteral<'_> {

pub struct EscapeUnicodeStringLiteral<'a>(&'a str);

impl<'a> fmt::Display for EscapeUnicodeStringLiteral<'a> {
impl fmt::Display for EscapeUnicodeStringLiteral<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for c in self.0.chars() {
match c {
Expand Down
6 changes: 6 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,12 @@ pub trait Dialect: Debug + Any {
false
}

/// Returns true if this dialect allows dollar placeholders
/// e.g. `SELECT $var` (SQLite)
fn supports_dollar_placeholder(&self) -> bool {
false
}

/// Does the dialect support with clause in create index statement?
/// e.g. `CREATE INDEX idx ON t WITH (key = value, key2)`
fn supports_create_index_with_clause(&self) -> bool {
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,8 @@ impl Dialect for SQLiteDialect {
fn supports_asc_desc_in_column_definition(&self) -> bool {
true
}

fn supports_dollar_placeholder(&self) -> bool {
true
}
}
2 changes: 1 addition & 1 deletion src/parser/alter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::{
tokenizer::Token,
};

impl<'a> Parser<'a> {
impl Parser<'_> {
pub fn parse_alter_role(&mut self) -> Result<Statement, ParserError> {
if dialect_of!(self is PostgreSqlDialect) {
return self.parse_pg_alter_role();
Expand Down
6 changes: 2 additions & 4 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10458,13 +10458,12 @@ impl<'a> Parser<'a> {
Ok(ExprWithAlias { expr, alias })
}
/// Parses an expression with an optional alias
///
/// Examples:
///
/// ```sql
/// SUM(price) AS total_price
/// ```
/// ```sql
/// SUM(price)
/// ```
Expand All @@ -10480,7 +10479,6 @@ impl<'a> Parser<'a> {
/// assert_eq!(Some("b".to_string()), expr_with_alias.alias.map(|x|x.value));
/// # Ok(())
/// # }
pub fn parse_expr_with_alias(&mut self) -> Result<ExprWithAlias, ParserError> {
let expr = self.parse_expr()?;
let alias = if self.parse_keyword(Keyword::AS) {
Expand Down
39 changes: 34 additions & 5 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ struct State<'a> {
pub col: u64,
}

impl<'a> State<'a> {
impl State<'_> {
/// return the next character and advance the stream
pub fn next(&mut self) -> Option<char> {
match self.peekable.next() {
Expand Down Expand Up @@ -1278,7 +1278,8 @@ impl<'a> Tokenizer<'a> {

chars.next();

if let Some('$') = chars.peek() {
// If the dialect does not support dollar-quoted strings, then `$$` is rather a placeholder.
if matches!(chars.peek(), Some('$')) && !self.dialect.supports_dollar_placeholder() {
chars.next();

let mut is_terminated = false;
Expand Down Expand Up @@ -1312,10 +1313,14 @@ impl<'a> Tokenizer<'a> {
};
} else {
value.push_str(&peeking_take_while(chars, |ch| {
ch.is_alphanumeric() || ch == '_'
ch.is_alphanumeric()
|| ch == '_'
// Allow $ as a placeholder character if the dialect supports it
|| matches!(ch, '$' if self.dialect.supports_dollar_placeholder())
}));

if let Some('$') = chars.peek() {
// If the dialect does not support dollar-quoted strings, don't look for the end delimiter.
if matches!(chars.peek(), Some('$')) && !self.dialect.supports_dollar_placeholder() {
chars.next();

'searching_for_end: loop {
Expand Down Expand Up @@ -1885,7 +1890,7 @@ fn take_char_from_hex_digits(
mod tests {
use super::*;
use crate::dialect::{
BigQueryDialect, ClickHouseDialect, HiveDialect, MsSqlDialect, MySqlDialect,
BigQueryDialect, ClickHouseDialect, HiveDialect, MsSqlDialect, MySqlDialect, SQLiteDialect,
};
use core::fmt::Debug;

Expand Down Expand Up @@ -2321,6 +2326,30 @@ mod tests {
);
}

#[test]
fn tokenize_dollar_placeholder() {
let sql = String::from("SELECT $$, $$ABC$$, $ABC$, $ABC");
let dialect = SQLiteDialect {};
let tokens = Tokenizer::new(&dialect, &sql).tokenize().unwrap();
assert_eq!(
tokens,
vec![
Token::make_keyword("SELECT"),
Token::Whitespace(Whitespace::Space),
Token::Placeholder("$$".into()),
Token::Comma,
Token::Whitespace(Whitespace::Space),
Token::Placeholder("$$ABC$$".into()),
Token::Comma,
Token::Whitespace(Whitespace::Space),
Token::Placeholder("$ABC$".into()),
Token::Comma,
Token::Whitespace(Whitespace::Space),
Token::Placeholder("$ABC".into()),
]
);
}

#[test]
fn tokenize_dollar_quoted_string_untagged() {
let sql =
Expand Down
10 changes: 10 additions & 0 deletions tests/sqlparser_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,16 @@ fn test_dollar_identifier_as_placeholder() {
}
_ => unreachable!(),
}

// $$ is a valid placeholder in SQLite
match sqlite().verified_expr("id = $$") {
Expr::BinaryOp { op, left, right } => {
assert_eq!(op, BinaryOperator::Eq);
assert_eq!(left, Box::new(Expr::Identifier(Ident::new("id"))));
assert_eq!(right, Box::new(Expr::Value(Placeholder("$$".to_string()))));
}
_ => unreachable!(),
}
}

fn sqlite() -> TestedDialects {
Expand Down

0 comments on commit 3f5fdeb

Please sign in to comment.