From 6c3b941c10d422f700d5616c00fb4bd6b13ad888 Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Tue, 16 Jul 2024 17:34:40 -0700 Subject: [PATCH] Configurable data type instead of flag for Utf8 unparsing --- datafusion/sql/src/unparser/dialect.rs | 59 ++++++++++++++++++-------- datafusion/sql/src/unparser/expr.rs | 33 ++++++-------- 2 files changed, 56 insertions(+), 36 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 1e44ee9a929a..7dadcb8b0afd 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -16,7 +16,7 @@ // under the License. use regex::Regex; -use sqlparser::keywords::ALL_KEYWORDS; +use sqlparser::{ast, keywords::ALL_KEYWORDS}; /// `Dialect` to use for Unparsing /// @@ -46,11 +46,15 @@ pub trait Dialect { IntervalStyle::PostgresVerbose } - // Does the dialect use CHAR to cast Utf8 rather than TEXT? - // E.g. MySQL requires CHAR instead of TEXT and automatically produces a string with - // the VARCHAR, TEXT or LONGTEXT data type based on the length of the string - fn use_char_for_utf8_cast(&self) -> bool { - false + // The SQL type to use for for Arrow Utf8 unparsing + // Most dialects use VARCHAR, but some, like MySQL, require CHAR + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Varchar(None) + } + // The SQL type to use for Arrow LargeUtf8 unparsing + // Most dialects use TEXT, but some, like MySQL, require CHAR + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Text } } @@ -111,8 +115,12 @@ impl Dialect for MySqlDialect { IntervalStyle::MySQL } - fn use_char_for_utf8_cast(&self) -> bool { - true + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Char(None) + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Char(None) } } @@ -129,7 +137,8 @@ pub struct CustomDialect { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, - use_char_for_utf8_cast: bool, + utf8_cast_dtype: ast::DataType, + large_utf8_cast_dtype: ast::DataType, } impl Default for CustomDialect { @@ -139,7 +148,8 @@ impl Default for CustomDialect { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::SQLStandard, - use_char_for_utf8_cast: false, + utf8_cast_dtype: ast::DataType::Varchar(None), + large_utf8_cast_dtype: ast::DataType::Text, } } } @@ -172,8 +182,12 @@ impl Dialect for CustomDialect { self.interval_style } - fn use_char_for_utf8_cast(&self) -> bool { - self.use_char_for_utf8_cast + fn utf8_cast_dtype(&self) -> ast::DataType { + self.utf8_cast_dtype.clone() + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + self.large_utf8_cast_dtype.clone() } } @@ -196,7 +210,8 @@ pub struct CustomDialectBuilder { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, - use_char_for_utf8_cast: bool, + utf8_cast_dtype: ast::DataType, + large_utf8_cast_dtype: ast::DataType, } impl Default for CustomDialectBuilder { @@ -212,7 +227,8 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::PostgresVerbose, - use_char_for_utf8_cast: false, + utf8_cast_dtype: ast::DataType::Varchar(None), + large_utf8_cast_dtype: ast::DataType::Text, } } @@ -222,7 +238,8 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: self.supports_nulls_first_in_sort, use_timestamp_for_date64: self.use_timestamp_for_date64, interval_style: self.interval_style, - use_char_for_utf8_cast: self.use_char_for_utf8_cast, + utf8_cast_dtype: self.utf8_cast_dtype, + large_utf8_cast_dtype: self.large_utf8_cast_dtype, } } @@ -256,8 +273,16 @@ impl CustomDialectBuilder { self } - pub fn with_use_char_for_utf8_cast(mut self, use_char_for_utf8_cast: bool) -> Self { - self.use_char_for_utf8_cast = use_char_for_utf8_cast; + pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self { + self.utf8_cast_dtype = utf8_cast_dtype; + self + } + + pub fn with_large_utf8_cast_dtype( + mut self, + large_utf8_cast_dtype: ast::DataType, + ) -> Self { + self.large_utf8_cast_dtype = large_utf8_cast_dtype; self } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index f5d0412df711..2dd828aaebd2 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1272,16 +1272,8 @@ impl Unparser<'_> { DataType::BinaryView => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } - DataType::Utf8 => Ok(if self.dialect.use_char_for_utf8_cast() { - ast::DataType::Char(None) - } else { - ast::DataType::Varchar(None) - }), - DataType::LargeUtf8 => Ok(if self.dialect.use_char_for_utf8_cast() { - ast::DataType::Char(None) - } else { - ast::DataType::Text - }), + DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()), + DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()), DataType::Utf8View => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } @@ -1944,16 +1936,19 @@ mod tests { #[test] fn custom_dialect_use_char_for_utf8_cast() -> Result<()> { - for (use_char_for_utf8_cast, data_type, identifier) in [ - (false, DataType::Utf8, "VARCHAR"), - (true, DataType::Utf8, "CHAR"), - (false, DataType::LargeUtf8, "TEXT"), - (true, DataType::LargeUtf8, "CHAR"), + let default_dialect = CustomDialectBuilder::default().build(); + let mysql_custom_dialect = CustomDialectBuilder::new() + .with_utf8_cast_dtype(ast::DataType::Char(None)) + .with_large_utf8_cast_dtype(ast::DataType::Char(None)) + .build(); + + for (dialect, data_type, identifier) in [ + (&default_dialect, DataType::Utf8, "VARCHAR"), + (&default_dialect, DataType::LargeUtf8, "TEXT"), + (&mysql_custom_dialect, DataType::Utf8, "CHAR"), + (&mysql_custom_dialect, DataType::LargeUtf8, "CHAR"), ] { - let dialect = CustomDialectBuilder::new() - .with_use_char_for_utf8_cast(use_char_for_utf8_cast) - .build(); - let unparser = Unparser::new(&dialect); + let unparser = Unparser::new(dialect); let expr = Expr::Cast(Cast { expr: Box::new(col("a")),