diff --git a/dask_planner/src/dialect.rs b/dask_planner/src/dialect.rs index b27c81ec3..262b5223c 100644 --- a/dask_planner/src/dialect.rs +++ b/dask_planner/src/dialect.rs @@ -75,6 +75,33 @@ impl Dialect for DaskDialect { special: false, }))) } + Token::Word(w) if w.value.to_lowercase() == "timestampdiff" => { + parser.next_token(); // skip timestampdiff + parser.expect_token(&Token::LParen)?; + let time_unit = parser.next_token(); + parser.expect_token(&Token::Comma)?; + let expr1 = parser.parse_expr()?; + parser.expect_token(&Token::Comma)?; + let expr2 = parser.parse_expr()?; + parser.expect_token(&Token::RParen)?; + + // convert to function args + let args = vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString(time_unit.to_string()), + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr1)), + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr2)), + ]; + + Ok(Some(Expr::Function(Function { + name: ObjectName(vec![Ident::new("timestampdiff")]), + args, + over: None, + distinct: false, + special: false, + }))) + } Token::Word(w) if w.value.to_lowercase() == "to_timestamp" => { // TO_TIMESTAMP(d, "%d/%m/%Y") parser.next_token(); // skip to_timestamp diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index bf6ce16ab..aa1e3d091 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -11,7 +11,7 @@ pub mod types; use std::{collections::HashMap, sync::Arc}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion_common::{DFSchema, DataFusionError}; use datafusion_expr::{ logical_plan::Extension, @@ -152,6 +152,18 @@ impl ContextProvider for DaskSQLContext { let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } + "timestampdiff" => { + let sig = Signature::exact( + vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ], + Volatility::Immutable, + ); + let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); + return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); + } "dsql_totimestamp" => { let sig = Signature::one_of( vec![ diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 6a5b01c17..b0eb5332a 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -683,6 +683,51 @@ def timestampadd(self, unit, interval, df: SeriesOrScalar): raise NotImplementedError(f"Extraction of {unit} is not (yet) implemented.") +class DatetimeSubOperation(Operation): + """ + Datetime subtraction is a special case of the `minus` operation + which also specifies a sql interval return type for the operation. + """ + + def __init__(self): + super().__init__(self.datetime_sub) + + def datetime_sub(self, unit, df1, df2): + subtraction_op = ReduceOperation( + operation=operator.sub, unary_operation=lambda x: -x + ) + result = subtraction_op(df2, df1) + + if unit in {"NANOSECOND", "NANOSECONDS"}: + return result + elif unit in {"MICROSECOND", "MICROSECONDS"}: + return result // 1_000 + elif unit in {"SECOND", "SECONDS"}: + return result // 1_000_000_000 + elif unit in {"MINUTE", "MINUTES"}: + return (result / 1_000_000_000) // 60 + elif unit in {"HOUR", "HOURS"}: + return (result / 1_000_000_000) // 3600 + elif unit in {"DAY", "DAYS"}: + return ((result / 1_000_000_000) / 3600) // 24 + elif unit in {"WEEK", "WEEKS"}: + return (((result / 1_000_000_000) / 3600) / 24) // 7 + elif unit in {"MONTH", "MONTHS"}: + day_result = ((result / 1_000_000_000) / 3600) // 24 + avg_days_in_month = ((30 * 4) + 28 + (31 * 7)) / 12 + return day_result / avg_days_in_month + elif unit in {"QUARTER", "QUARTERS"}: + day_result = ((result / 1_000_000_000) / 3600) // 24 + avg_days_in_quarter = 3 * ((30 * 4) + 28 + (31 * 7)) / 12 + return day_result / avg_days_in_quarter + elif unit in {"YEAR", "YEARS"}: + return (((result / 1_000_000_000) / 3600) / 24) // 365 + else: + raise NotImplementedError( + f"Timestamp difference with {unit} is not supported." + ) + + class CeilFloorOperation(PredicateBasedOperation): """ Apply ceil/floor operations on a series depending on its dtype (datetime like vs normal) @@ -1031,6 +1076,7 @@ class RexCallPlugin(BaseRexPlugin): "datepart": DatePartOperation(), "year": YearOperation(), "timestampadd": TimeStampAddOperation(), + "timestampdiff": DatetimeSubOperation(), } def convert( diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py index 510bf953b..045b97cda 100644 --- a/tests/integration/test_rex.py +++ b/tests/integration/test_rex.py @@ -679,6 +679,100 @@ def test_date_functions(c): ) +def test_timestampdiff(c): + ts_literal1 = datetime(2002, 3, 7, 9, 10, 5, 123) + ts_literal2 = datetime(2001, 6, 5, 10, 11, 6, 234) + df = dd.from_pandas( + pd.DataFrame({"ts_literal1": [ts_literal1], "ts_literal2": [ts_literal2]}), + npartitions=1, + ) + c.register_dask_table(df, "df") + + query = """ + SELECT timestampdiff(NANOSECOND, ts_literal1, ts_literal2) as res0, + timestampdiff(MICROSECOND, ts_literal1, ts_literal2) as res1, + timestampdiff(SECOND, ts_literal1, ts_literal2) as res2, + timestampdiff(MINUTE, ts_literal1, ts_literal2) as res3, + timestampdiff(HOUR, ts_literal1, ts_literal2) as res4, + timestampdiff(DAY, ts_literal1, ts_literal2) as res5, + timestampdiff(WEEK, ts_literal1, ts_literal2) as res6, + timestampdiff(MONTH, ts_literal1, ts_literal2) as res7, + timestampdiff(QUARTER, ts_literal1, ts_literal2) as res8, + timestampdiff(YEAR, ts_literal1, ts_literal2) as res9 + FROM df + """ + df = c.sql(query) + + expected_df = pd.DataFrame( + { + "res0": [-23756338999889000], + "res1": [-23756338999889], + "res2": [-23756338], + "res3": [-395938], + "res4": [-6598], + "res5": [-274], + "res6": [-39], + "res7": [-9], + "res8": [-3], + "res9": [0], + } + ) + + assert_eq(df, expected_df, check_dtype=False) + + test = pd.DataFrame( + { + "a": [ + datetime(2002, 6, 5, 2, 1, 5, 200), + datetime(2002, 9, 1), + datetime(1970, 12, 3), + ], + "b": [ + datetime(2002, 6, 7, 1, 0, 2, 100), + datetime(2003, 6, 5), + datetime(2038, 6, 5), + ], + } + ) + c.create_table("test", test) + + query = ( + "SELECT timestampdiff(NANOSECOND, a, b) as nanoseconds," + "timestampdiff(MICROSECOND, a, b) as microseconds," + "timestampdiff(SECOND, a, b) as seconds," + "timestampdiff(MINUTE, a, b) as minutes," + "timestampdiff(HOUR, a, b) as hours," + "timestampdiff(DAY, a, b) as days," + "timestampdiff(WEEK, a, b) as weeks," + "timestampdiff(MONTH, a, b) as months," + "timestampdiff(QUARTER, a, b) as quarters," + "timestampdiff(YEAR, a, b) as years" + " FROM test" + ) + ddf = c.sql(query) + + expected_df = pd.DataFrame( + { + "nanoseconds": [ + 169136999900000, + 23932800000000000, + 2130278400000000000, + ], + "microseconds": [169136999900, 23932800000000, 2130278400000000], + "seconds": [169136, 23932800, 2130278400], + "minutes": [2818, 398880, 35504640], + "hours": [46, 6648, 591744], + "days": [1, 277, 24656], + "weeks": [0, 39, 3522], + "months": [0, 9, 810], + "quarters": [0, 3, 270], + "years": [0, 0, 67], + } + ) + + assert_eq(ddf, expected_df, check_dtype=False) + + @pytest.mark.parametrize( "gpu", [