Skip to content

Commit

Permalink
Fix duckdb & sqlite character_length scalar unparsing (apache#13428)
Browse files Browse the repository at this point in the history
* Fix duckdb & sqlite character_length scalar unparsing (#59)

* Fix duckdb & sqlite character_length scalar unparsing

* Add comments

* Update CharacterLengthStyle::SQLStandard to CharacterLengthExtractStyle::CharacterLength

* Fix clippy error
  • Loading branch information
Sevenannn authored Nov 17, 2024
1 parent a892101 commit cd013c7
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 12 deletions.
93 changes: 84 additions & 9 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use sqlparser::{

use datafusion_common::Result;

use super::{utils::date_part_to_sql, Unparser};
use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser};

/// `Dialect` to use for Unparsing
///
Expand Down Expand Up @@ -80,6 +80,11 @@ pub trait Dialect: Send + Sync {
DateFieldExtractStyle::DatePart
}

/// The function-name style to use when unparsing `character_length`: see `CharacterLengthStyle`.
///
/// The default is `CharacterLengthStyle::CharacterLength`; a dialect overrides this
/// method to select a different style.
fn character_length_style(&self) -> CharacterLengthStyle {
CharacterLengthStyle::CharacterLength
}

/// The SQL type to use for Arrow Int64 unparsing
/// Most dialects use BigInt, but some, like MySQL, require SIGNED
fn int64_cast_dtype(&self) -> ast::DataType {
Expand Down Expand Up @@ -176,6 +181,17 @@ pub enum DateFieldExtractStyle {
Strftime,
}

/// `CharacterLengthStyle` to use for unparsing
///
/// Different DBMSs use different names for the function that calculates the number of
/// characters in a string:
/// - `Length` style uses `length(x)`
/// - `CharacterLength` style uses `character_length(x)`
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CharacterLengthStyle {
    /// Unparse as `length(x)`
    Length,
    /// Unparse as `character_length(x)`
    CharacterLength,
}

pub struct DefaultDialect {}

impl Dialect for DefaultDialect {
Expand Down Expand Up @@ -271,6 +287,35 @@ impl PostgreSqlDialect {
}
}

/// `Dialect` implementation used when unparsing for DuckDB.
pub struct DuckDBDialect {}

impl Dialect for DuckDBDialect {
    /// Every identifier is quoted with double quotes, whatever its text.
    fn identifier_quote_style(&self, _identifier: &str) -> Option<char> {
        Some('"')
    }

    /// This dialect selects the `Length` style for character-length calls.
    fn character_length_style(&self) -> CharacterLengthStyle {
        CharacterLengthStyle::Length
    }

    /// Dialect-specific scalar function rewrites.
    ///
    /// Only `character_length` is overridden: it is delegated to
    /// `character_length_to_sql` with this dialect's style. Any other function
    /// name yields `Ok(None)`.
    fn scalar_function_to_sql_overrides(
        &self,
        unparser: &Unparser,
        func_name: &str,
        args: &[Expr],
    ) -> Result<Option<ast::Expr>> {
        match func_name {
            "character_length" => {
                character_length_to_sql(unparser, self.character_length_style(), args)
            }
            _ => Ok(None),
        }
    }
}

pub struct MySqlDialect {}

impl Dialect for MySqlDialect {
Expand Down Expand Up @@ -347,6 +392,10 @@ impl Dialect for SqliteDialect {
ast::DataType::Text
}

/// This dialect selects the `Length` style for character-length calls.
fn character_length_style(&self) -> CharacterLengthStyle {
CharacterLengthStyle::Length
}

/// Always returns `false` for this dialect.
// NOTE(review): judging by the method name, this signals that column aliases inside a
// table alias are not supported — confirm against the `Dialect` trait documentation.
fn supports_column_alias_in_table_alias(&self) -> bool {
false
}
Expand All @@ -357,11 +406,15 @@ impl Dialect for SqliteDialect {
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "date_part" {
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
match func_name {
"date_part" => {
date_part_to_sql(unparser, self.date_field_extract_style(), args)
}
"character_length" => {
character_length_to_sql(unparser, self.character_length_style(), args)
}
_ => Ok(None),
}

Ok(None)
}
}

Expand All @@ -374,6 +427,7 @@ pub struct CustomDialect {
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
date_field_extract_style: DateFieldExtractStyle,
character_length_style: CharacterLengthStyle,
int64_cast_dtype: ast::DataType,
int32_cast_dtype: ast::DataType,
timestamp_cast_dtype: ast::DataType,
Expand All @@ -395,6 +449,7 @@ impl Default for CustomDialect {
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
date_field_extract_style: DateFieldExtractStyle::DatePart,
character_length_style: CharacterLengthStyle::CharacterLength,
int64_cast_dtype: ast::DataType::BigInt(None),
int32_cast_dtype: ast::DataType::Integer(None),
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
Expand Down Expand Up @@ -454,6 +509,10 @@ impl Dialect for CustomDialect {
self.date_field_extract_style
}

/// Returns the `CharacterLengthStyle` stored on this dialect.
fn character_length_style(&self) -> CharacterLengthStyle {
self.character_length_style
}

/// Returns a clone of the `int64_cast_dtype` stored on this dialect.
fn int64_cast_dtype(&self) -> ast::DataType {
self.int64_cast_dtype.clone()
}
Expand Down Expand Up @@ -488,11 +547,15 @@ impl Dialect for CustomDialect {
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "date_part" {
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
match func_name {
"date_part" => {
date_part_to_sql(unparser, self.date_field_extract_style(), args)
}
"character_length" => {
character_length_to_sql(unparser, self.character_length_style(), args)
}
_ => Ok(None),
}

Ok(None)
}

fn requires_derived_table_alias(&self) -> bool {
Expand Down Expand Up @@ -527,6 +590,7 @@ pub struct CustomDialectBuilder {
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
date_field_extract_style: DateFieldExtractStyle,
character_length_style: CharacterLengthStyle,
int64_cast_dtype: ast::DataType,
int32_cast_dtype: ast::DataType,
timestamp_cast_dtype: ast::DataType,
Expand Down Expand Up @@ -554,6 +618,7 @@ impl CustomDialectBuilder {
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
date_field_extract_style: DateFieldExtractStyle::DatePart,
character_length_style: CharacterLengthStyle::CharacterLength,
int64_cast_dtype: ast::DataType::BigInt(None),
int32_cast_dtype: ast::DataType::Integer(None),
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
Expand All @@ -578,6 +643,7 @@ impl CustomDialectBuilder {
utf8_cast_dtype: self.utf8_cast_dtype,
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
date_field_extract_style: self.date_field_extract_style,
character_length_style: self.character_length_style,
int64_cast_dtype: self.int64_cast_dtype,
int32_cast_dtype: self.int32_cast_dtype,
timestamp_cast_dtype: self.timestamp_cast_dtype,
Expand Down Expand Up @@ -620,6 +686,15 @@ impl CustomDialectBuilder {
self
}

/// Customize the dialect with a specific character length style listed in `CharacterLengthStyle`.
///
/// Takes the builder by value, stores `character_length_style` on it, and returns the
/// builder.
pub fn with_character_length_style(
mut self,
character_length_style: CharacterLengthStyle,
) -> Self {
self.character_length_style = character_length_style;
self
}

/// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc.
pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self {
self.float64_ast_dtype = float64_ast_dtype;
Expand Down
31 changes: 29 additions & 2 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1525,8 +1525,8 @@ mod tests {
use datafusion_functions_window::row_number::row_number_udwf;

use crate::unparser::dialect::{
CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect,
PostgreSqlDialect,
CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle,
Dialect, PostgreSqlDialect,
};

use super::*;
Expand Down Expand Up @@ -2050,6 +2050,33 @@ mod tests {
Ok(())
}

#[test]
fn test_character_length_scalar_to_expr() {
    // Each case pairs a style with the SQL text expected for `character_length(x)`.
    let cases = [
        (CharacterLengthStyle::Length, "length(x)"),
        (CharacterLengthStyle::CharacterLength, "character_length(x)"),
    ];

    for (style, expected_sql) in cases {
        let dialect = CustomDialectBuilder::new()
            .with_character_length_style(style)
            .build();
        let unparser = Unparser::new(&dialect);

        let character_length = ScalarUDF::new_from_impl(
            datafusion_functions::unicode::character_length::CharacterLengthFunc::new(),
        );
        let expr = character_length.call(vec![col("x")]);

        let actual_sql = unparser
            .expr_to_sql(&expr)
            .expect("to be unparsed")
            .to_string();

        assert_eq!(actual_sql, expected_sql);
    }
}

#[test]
fn test_interval_scalar_to_expr() {
let tests = [
Expand Down
21 changes: 20 additions & 1 deletion datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use datafusion_expr::{
};
use sqlparser::ast;

use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser};
use super::{
dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle,
rewrite::TableAliasRewriter, Unparser,
};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
Expand Down Expand Up @@ -445,3 +448,19 @@ pub(crate) fn date_part_to_sql(

Ok(None)
}

/// Unparses a `character_length` call under the function name selected by `style`.
///
/// `CharacterLengthStyle::CharacterLength` keeps the name `character_length`, while
/// `CharacterLengthStyle::Length` emits `length`. The arguments are passed unchanged to
/// `Unparser::scalar_function_to_sql`; an error from that call is propagated, otherwise
/// the resulting expression is returned wrapped in `Some`.
pub(crate) fn character_length_to_sql(
    unparser: &Unparser,
    style: CharacterLengthStyle,
    character_length_args: &[Expr],
) -> Result<Option<ast::Expr>> {
    let target_name = match style {
        CharacterLengthStyle::Length => "length",
        CharacterLengthStyle::CharacterLength => "character_length",
    };

    let call = unparser.scalar_function_to_sql(target_name, character_length_args)?;
    Ok(Some(call))
}

0 comments on commit cd013c7

Please sign in to comment.