support year/month/day functions
zhangli20 committed Sep 11, 2024
1 parent 23d102b · commit 4f9211f
Showing 4 changed files with 110 additions and 1 deletion.
4 changes: 4 additions & 0 deletions native-engine/datafusion-ext-functions/src/lib.rs
@@ -21,6 +21,7 @@ use datafusion_ext_commons::df_unimplemented_err;

mod brickhouse;
mod spark_check_overflow;
mod spark_dates;
pub mod spark_get_json_object;
mod spark_make_array;
mod spark_make_decimal;
@@ -51,6 +52,9 @@ pub fn create_spark_ext_function(name: &str) -> Result<ScalarFunctionImplementat
"StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
"StringLower" => Arc::new(spark_strings::string_lower),
"StringUpper" => Arc::new(spark_strings::string_upper),
"Year" => Arc::new(spark_dates::spark_year),
"Month" => Arc::new(spark_dates::spark_month),
"Day" => Arc::new(spark_dates::spark_day),
"BrickhouseArrayUnion" => Arc::new(brickhouse::array_union::array_union),
_ => df_unimplemented_err!("spark ext function not implemented: {name}")?,
})
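
The dispatch table above returns a DataFusion ScalarFunctionImplementation (in DataFusion versions of this era, typically an Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>), so a registered entry can be looked up by name and called directly. A minimal sketch of exercising the new "Year" entry — the demo function below is illustrative, not part of this commit:

use std::sync::Arc;
use arrow::array::Date32Array;
use datafusion::{common::Result, physical_plan::ColumnarValue};

fn demo() -> Result<()> {
    // Look up the ext function registered under "Year".
    let year_fn = create_spark_ext_function("Year")?;
    // Date32 holds days since 1970-01-01, so day 0 is 1970-01-01.
    let dates = Arc::new(Date32Array::from(vec![Some(0), None]));
    let out = year_fn(&[ColumnarValue::Array(dates)])?;
    // `out` is an Int32 array: [1970, NULL].
    let _ = out;
    Ok(())
}
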
98 changes: 98 additions & 0 deletions native-engine/datafusion-ext-functions/src/spark_dates.rs
@@ -0,0 +1,98 @@
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow::compute::{day_dyn, month_dyn, year_dyn};
use datafusion::{common::Result, physical_plan::ColumnarValue};

pub fn spark_year(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let input = args[0].clone().into_array(1)?;
Ok(ColumnarValue::Array(year_dyn(&input)?))
}

pub fn spark_month(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let input = args[0].clone().into_array(1)?;
Ok(ColumnarValue::Array(month_dyn(&input)?))
}

pub fn spark_day(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let input = args[0].clone().into_array(1)?;
Ok(ColumnarValue::Array(day_dyn(&input)?))
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow::array::{ArrayRef, Date32Array, Int32Array};

use super::*;

#[test]
fn test_spark_year() {
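// Date32 values are days since the Unix epoch (1970-01-01):
// 0 -> 1970-01-01, 1000 -> 1972-09-27, 2000 -> 1975-06-24.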
let input = Arc::new(Date32Array::from(vec![
Some(0),
Some(1000),
Some(2000),
None,
]));
let args = vec![ColumnarValue::Array(input)];
let expected_ret: ArrayRef = Arc::new(Int32Array::from(vec![
Some(1970),
Some(1972),
Some(1975),
None,
]));
assert_eq!(
&spark_year(&args).unwrap().into_array(1).unwrap(),
&expected_ret
);
}

#[test]
fn test_spark_month() {
let input = Arc::new(Date32Array::from(vec![Some(0), Some(35), Some(65), None]));
let args = vec![ColumnarValue::Array(input)];
let expected_ret: ArrayRef =
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3), None]));
assert_eq!(
&spark_month(&args).unwrap().into_array(1).unwrap(),
&expected_ret
);
}

#[test]
fn test_spark_day() {
let input = Arc::new(Date32Array::from(vec![
Some(0),
Some(10),
Some(20),
Some(30),
Some(40),
None,
]));
let args = vec![ColumnarValue::Array(input)];
let expected_ret: ArrayRef = Arc::new(Int32Array::from(vec![
Some(1),
Some(11),
Some(21),
Some(31),
Some(10),
None,
]));
assert_eq!(
&spark_day(&args).unwrap().into_array(1).unwrap(),
&expected_ret
);
}
}
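
Each of the three wrappers materializes its single argument with into_array(1), which also broadcasts a scalar input to a one-element array, then delegates to arrow's *_dyn temporal kernels. A minimal sketch of a scalar invocation, assuming DataFusion's ScalarValue API — scalar_demo is illustrative, not part of this commit:

use datafusion::{common::Result, physical_plan::ColumnarValue, scalar::ScalarValue};

fn scalar_demo() -> Result<()> {
    // A Date32 scalar: 1000 days after 1970-01-01, i.e. 1972-09-27.
    let arg = ColumnarValue::Scalar(ScalarValue::Date32(Some(1000)));
    let _years = spark_year(&[arg])?.into_array(1)?;
    // `_years` is a one-element Int32 array containing 1972.
    Ok(())
}
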
2 changes: 1 addition & 1 deletion native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs
@@ -16,7 +16,7 @@ use std::{
any::Any,
fmt::{Debug, Formatter},
fs::File,
io::{BufReader, Read, Seek, SeekFrom},
io::{BufReader, Cursor, Read, Seek, SeekFrom},
sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Arc,
7 changes: 7 additions & 0 deletions spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala
@@ -51,9 +51,12 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.RightOuter
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Days
import org.apache.spark.sql.catalyst.expressions.GetJsonObject
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.Month
import org.apache.spark.sql.catalyst.expressions.XxHash64
import org.apache.spark.sql.catalyst.expressions.Year
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
import org.apache.spark.sql.execution.blaze.plan.Util
import org.apache.spark.sql.execution.ScalarSubquery
@@ -870,6 +873,10 @@ object NativeConverters extends Logging {
case XxHash64(children, 42L) =>
buildExtScalarFunction("XxHash64", children, LongType)

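// Spark's Year/Month/Days expressions map to the native Year/Month/Day
// ext functions registered in datafusion-ext-functions/src/lib.rs.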
case Year(child) => buildExtScalarFunction("Year", child :: Nil, DateType)
case Month(child) => buildExtScalarFunction("Month", child :: Nil, DateType)
case Days(child) => buildExtScalarFunction("Day", child :: Nil, DateType)

// startswith is converted to scalar function in pruning-expr mode
case StartsWith(expr, Literal(prefix, StringType)) if isPruningExpr =>
buildExprNode(
