From d9f74cfe0dc827295fccfed6f512fedc302c4a19 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sat, 9 Oct 2021 09:50:07 +0800 Subject: [PATCH] add digest(utf8, method) function --- ballista/rust/core/proto/ballista.proto | 1 + .../core/src/serde/logical_plan/from_proto.rs | 7 +- .../core/src/serde/logical_plan/to_proto.rs | 1 + .../src/serde/physical_plan/from_proto.rs | 1 + datafusion/src/logical_plan/expr.rs | 4 +- datafusion/src/logical_plan/mod.rs | 17 +- .../src/physical_plan/crypto_expressions.rs | 304 ++++++++++-------- datafusion/src/physical_plan/functions.rs | 8 +- datafusion/src/prelude.rs | 10 +- 9 files changed, 198 insertions(+), 155 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 47cc80100a11f..8175156e30515 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -153,6 +153,7 @@ enum ScalarFunction { SHA512 = 33; LN = 34; TOTIMESTAMPMILLIS = 35; + DIGEST = 36; } message ScalarFunctionNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index c9ef97ee5c885..353be9a596426 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -27,8 +27,8 @@ use datafusion::logical_plan::window_frames::{ WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion::logical_plan::{ - abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, - sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinConstraint, JoinType, + abs, acos, asin, atan, ceil, cos, digest, exp, floor, ln, log10, log2, round, signum, + sin, sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinConstraint, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, }; use datafusion::physical_plan::aggregates::AggregateFunction; @@ -1152,6 +1152,9 @@ impl TryInto for &protobuf::LogicalExprNode { 
protobuf::ScalarFunction::Sha512 => { Ok(sha512((&args[0]).try_into()?)) } + protobuf::ScalarFunction::Digest => { + Ok(digest((&args[0]).try_into()?, (&args[1]).try_into()?)) + } _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", )), diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index bd7fc4d5bfc11..c3ffb1a2022e7 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1485,6 +1485,7 @@ impl TryInto for &BuiltinScalarFunction { BuiltinScalarFunction::SHA256 => Ok(protobuf::ScalarFunction::Sha256), BuiltinScalarFunction::SHA384 => Ok(protobuf::ScalarFunction::Sha384), BuiltinScalarFunction::SHA512 => Ok(protobuf::ScalarFunction::Sha512), + BuiltinScalarFunction::Digest => Ok(protobuf::ScalarFunction::Digest), BuiltinScalarFunction::ToTimestampMillis => { Ok(protobuf::ScalarFunction::Totimestampmillis) } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 0d233725fc9f9..5241e8b2bd5e0 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -559,6 +559,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Sha256 => BuiltinScalarFunction::SHA256, ScalarFunction::Sha384 => BuiltinScalarFunction::SHA384, ScalarFunction::Sha512 => BuiltinScalarFunction::SHA512, + ScalarFunction::Digest => BuiltinScalarFunction::Digest, ScalarFunction::Ln => BuiltinScalarFunction::Ln, ScalarFunction::Totimestampmillis => BuiltinScalarFunction::ToTimestampMillis, } diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 0fc00f3db14bf..56bf260d66818 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1508,7 +1508,7 @@ macro_rules! 
unary_scalar_expr { }; } -/// Create an convenience function representing a /binaryunary scalar function +/// Create a convenience function representing a binary scalar function macro_rules! binary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { #[doc = "this scalar function is not documented yet"] @@ -1581,6 +1581,7 @@ unary_scalar_expr!(Upper, upper); // date functions binary_scalar_expr!(DatePart, date_part); binary_scalar_expr!(DateTrunc, date_trunc); +binary_scalar_expr!(Digest, digest); /// returns an array of fixed size with each argument on it. pub fn array(args: Vec) -> Expr { @@ -2171,6 +2172,7 @@ mod tests { test_unary_scalar_expr!(SHA256, sha256); test_unary_scalar_expr!(SHA384, sha384); test_unary_scalar_expr!(SHA512, sha512); + test_unary_scalar_expr!(Digest, digest); test_unary_scalar_expr!(SplitPart, split_part); test_unary_scalar_expr!(StartsWith, starts_with); test_unary_scalar_expr!(Strpos, strpos); diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 84ee9e5b9fb89..3f0c7d253c938 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -38,14 +38,15 @@ pub use display::display_schema; pub use expr::{ abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, - cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, exp, - exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, - ln, log10, log2, lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, - now, octet_length, or, random, regexp_match, regexp_replace, repeat, replace, - replace_col, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, - signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, - translate, trim, trunc, unnormalize_col, unnormalize_cols, upper, when, Column, Expr, - ExprRewriter, ExpressionVisitor, 
Literal, Recursion, RewriteRecursion, + cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, + exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, + lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min, + normalize_col, normalize_cols, now, octet_length, or, random, regexp_match, + regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim, + sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, + substr, sum, tan, to_hex, translate, trim, trunc, unnormalize_col, unnormalize_cols, + upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, + RewriteRecursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/physical_plan/crypto_expressions.rs b/datafusion/src/physical_plan/crypto_expressions.rs index 8ad876b24d0ce..5dcdeb0718a85 100644 --- a/datafusion/src/physical_plan/crypto_expressions.rs +++ b/datafusion/src/physical_plan/crypto_expressions.rs @@ -16,14 +16,7 @@ // under the License. //! Crypto expressions -use std::sync::Arc; - -use md5::Md5; -use sha2::{ - digest::Output as SHA2DigestOutput, Digest as SHA2Digest, Sha224, Sha256, Sha384, - Sha512, -}; - +use super::ColumnarValue; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, @@ -32,167 +25,204 @@ use arrow::{ array::{Array, BinaryArray, GenericStringArray, StringOffsetSizeTrait}, datatypes::DataType, }; - -use super::{string_expressions::unary_string_function, ColumnarValue}; - -/// Computes the md5 of a string. 
-fn md5_process(input: &str) -> String { - let mut digest = Md5::default(); - digest.update(&input); - - let mut result = String::new(); - - for byte in &digest.finalize() { - result.push_str(&format!("{:02x}", byte)); - } - - result +use md5::Md5; +use sha2::{Digest as SHA2Digest, Sha224, Sha256, Sha384, Sha512}; +use std::any::type_name; +use std::sync::Arc; +use std::{fmt, str::FromStr}; + +/// digest algorithm +#[derive(Debug, Copy, Clone)] +enum DigestAlgorithm { + Md5, + Sha224, + Sha256, + Sha384, + Sha512, } -// It's not possible to return &[u8], because trait in trait without short lifetime -fn sha_process(input: &str) -> SHA2DigestOutput { - let mut digest = D::default(); - digest.update(&input); - - digest.finalize() +macro_rules! downcast_string_arg { + ($ARG:expr, $NAME:expr, $T:ident) => {{ + $ARG.as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::>() + )) + })? + }}; } -/// # Errors -/// This function errors when: -/// * the number of arguments is not 1 -/// * the first argument is not castable to a `GenericStringArray` -fn unary_binary_function( - args: &[&dyn Array], - op: F, - name: &str, +fn digest_string_array( + value: &dyn Array, + digest_algorithm: DigestAlgorithm, ) -> Result where - R: AsRef<[u8]>, T: StringOffsetSizeTrait, - F: Fn(&str) -> R, { - if args.len() != 1 { - return Err(DataFusionError::Internal(format!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name, - ))); - } - - let array = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("failed to downcast to string".to_string()) - })?; - + let array = downcast_string_arg!(value, "value", T); // first map is the iterator, second is for the `Option<_>` - Ok(array.iter().map(|x| x.map(|x| op(x))).collect()) + Ok(array + .iter() + .map(|x| x.map(|x| digest_algorithm.digest_str(x))) + .collect()) } -fn handle(args: &[ColumnarValue], 
op: F, name: &str) -> Result -where - R: AsRef<[u8]>, - F: Fn(&str) -> R, -{ - match &args[0] { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_binary_function::< - i32, - _, - _, - >( - &[a.as_ref()], op, name - )?))) +fn digest_process( + value: &ColumnarValue, + digest_algorithm: DigestAlgorithm, +) -> Result { + match value { + ColumnarValue::Array(a) => { + match a.data_type() { + DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new( + digest_string_array::(a.as_ref(), digest_algorithm)?, + ))), + DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( + digest_string_array::(a.as_ref(), digest_algorithm)?, + ))), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function {}", + other, + digest_algorithm.to_string(), + ))), } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_binary_function::< - i64, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function {}", - other, name, - ))), - }, + } ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_vec()); - Ok(ColumnarValue::Scalar(ScalarValue::Binary(result))) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_vec()); + // both cases resolve to binary + ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| digest_algorithm.digest_str(x).to_vec()); Ok(ColumnarValue::Scalar(ScalarValue::Binary(result))) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function {}", - other, name, + other, + digest_algorithm.to_string(), ))), }, } } -fn md5_array( - args: &[&dyn Array], -) -> Result> { - unary_string_function::(args, md5_process, "md5") +macro_rules! 
digest_str_process { + ($METHOD: ident, $INPUT:expr) => {{ + let mut digest = $METHOD::default(); + digest.update($INPUT); + + digest.finalize().to_vec() + }}; } -/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn md5(args: &[ColumnarValue]) -> Result { - match &args[0] { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new(md5_array::(&[ - a.as_ref() - ])?))), - DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(md5_array::(&[ - a.as_ref() - ])?))) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function md5", - other, - ))), - }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| md5_process(x)); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| md5_process(x)); - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) +impl DigestAlgorithm { + fn digest_str(self, input: &str) -> Vec { + match self { + Self::Md5 => { + let mut digest = Md5::default(); + digest.update(&input); + + let mut result = String::new(); + + for byte in &digest.finalize() { + result.push_str(&format!("{:02x}", byte)); + } + + result.into_bytes() } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function md5", - other, - ))), - }, + Self::Sha224 => digest_str_process!(Sha224, input), + Self::Sha256 => digest_str_process!(Sha256, input), + Self::Sha384 => digest_str_process!(Sha384, input), + Self::Sha512 => digest_str_process!(Sha512, input), + } } } -/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn sha224(args: &[ColumnarValue]) -> Result { - handle(args, sha_process::, "ssh224") +impl fmt::Display for DigestAlgorithm { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", format!("{:?}", self).to_lowercase()) + } } 
-/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn sha256(args: &[ColumnarValue]) -> Result { - handle(args, sha_process::, "sha256") +impl FromStr for DigestAlgorithm { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name { + "md5" => Self::Md5, + "sha224" => Self::Sha224, + "sha256" => Self::Sha256, + "sha384" => Self::Sha384, + "sha512" => Self::Sha512, + _ => { + return Err(DataFusionError::Plan(format!( + "There is no built-in digest algorithm named {}", + name + ))) + } + }) + } } -/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn sha384(args: &[ColumnarValue]) -> Result { - handle(args, sha_process::, "sha384") +macro_rules! define_digest_function { + ($NAME: ident, $METHOD: ident, $DOC: expr) => { + #[doc = $DOC] + pub fn $NAME(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal(format!( + "{:?} args were supplied but {} takes exactly one argument", + args.len(), + DigestAlgorithm::$METHOD.to_string(), + ))); + } + digest_process(&args[0], DigestAlgorithm::$METHOD) + } + }; } -/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn sha512(args: &[ColumnarValue]) -> Result { - handle(args, sha_process::, "sha512") +define_digest_function!(md5, Md5, "computes md5 hash digest of the given input"); +define_digest_function!( + sha224, + Sha224, + "computes sha224 hash digest of the given input" +); +define_digest_function!( + sha256, + Sha256, + "computes sha256 hash digest of the given input" +); +define_digest_function!( + sha384, + Sha384, + "computes sha384 hash digest of the given input" +); +define_digest_function!( + sha512, + Sha512, + "computes sha512 hash digest of the given input" +); + +/// Digest computes a binary hash of the given data, accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. +/// Second argument is the algorithm to use. 
+/// Standard algorithms are md5, sha224, sha256, sha384 and sha512. +pub fn digest(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "{:?} args were supplied but digest takes exactly two arguments", + args.len(), + ))); + } + let digest_algorithm = match &args[1] { + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + method.parse::() + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function digest", + other, + ))), + }, + ColumnarValue::Array(_) => Err(DataFusionError::Internal( + "Digest using dynamically decided method is not yet supported".into(), + )), + }?; + digest_process(&args[0], digest_algorithm) } diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index a1d7d451ce2a2..534ec94b6f9df 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -185,6 +185,8 @@ pub enum BuiltinScalarFunction { Ceil, /// cos Cos, + /// Digest + Digest, /// exp Exp, /// floor @@ -310,7 +312,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Random | BuiltinScalarFunction::Now ) } - /// Returns the [Volatility] of the builtin function. + /// Returns the [Volatility] of the builtin function. 
pub fn volatility(&self) -> Volatility { match self { //Immutable scalar builtins @@ -350,7 +352,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::MD5 => Volatility::Immutable, BuiltinScalarFunction::NullIf => Volatility::Immutable, BuiltinScalarFunction::OctetLength => Volatility::Immutable, - BuiltinScalarFunction::RegexpReplace => Volatility::Immutable, BuiltinScalarFunction::Repeat => Volatility::Immutable, BuiltinScalarFunction::Replace => Volatility::Immutable, @@ -362,6 +363,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SHA256 => Volatility::Immutable, BuiltinScalarFunction::SHA384 => Volatility::Immutable, BuiltinScalarFunction::SHA512 => Volatility::Immutable, + BuiltinScalarFunction::Digest => Volatility::Immutable, BuiltinScalarFunction::SplitPart => Volatility::Immutable, BuiltinScalarFunction::StartsWith => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, @@ -449,6 +451,7 @@ impl FromStr for BuiltinScalarFunction { "sha256" => BuiltinScalarFunction::SHA256, "sha384" => BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, + "digest" => BuiltinScalarFunction::Digest, "split_part" => BuiltinScalarFunction::SplitPart, "starts_with" => BuiltinScalarFunction::StartsWith, "strpos" => BuiltinScalarFunction::Strpos, @@ -554,6 +557,7 @@ pub fn return_type( BuiltinScalarFunction::SHA256 => utf8_to_binary_type(&arg_types[0], "sha256"), BuiltinScalarFunction::SHA384 => utf8_to_binary_type(&arg_types[0], "sha384"), BuiltinScalarFunction::SHA512 => utf8_to_binary_type(&arg_types[0], "sha512"), + BuiltinScalarFunction::Digest => utf8_to_binary_type(&arg_types[0], "digest"), BuiltinScalarFunction::SplitPart => utf8_to_str_type(&arg_types[0], "split_part"), BuiltinScalarFunction::StartsWith => Ok(DataType::Boolean), BuiltinScalarFunction::Strpos => utf8_to_int_type(&arg_types[0], "strpos"), diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index 168e1d5df41ac..02b9d4f3419eb 100644 --- 
a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -29,10 +29,10 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws, - count, create_udf, date_part, date_trunc, in_list, initcap, left, length, lit, lower, - lpad, ltrim, max, md5, min, now, octet_length, random, regexp_replace, repeat, - replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, sha512, split_part, - starts_with, strpos, substr, sum, to_hex, translate, trim, upper, Column, JoinType, - Partitioning, + count, create_udf, date_part, date_trunc, digest, in_list, initcap, left, length, + lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_replace, + repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, sha512, + split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper, Column, + JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions;