Skip to content

Commit

Permalink
Add TIMESTAMPDIFF support (#876)
Browse files Browse the repository at this point in the history
* initial commit

* style fix

* WIP month, quarter, year

* lint and check_dtype

* update month/quarter/year logic

* lint

* simplify month/quarter/year logic

Co-authored-by: Ayush Dattagupta <[email protected]>
  • Loading branch information
sarahyurick and ayushdg authored Nov 30, 2022
1 parent 9cce9be commit d2896fa
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 1 deletion.
27 changes: 27 additions & 0 deletions dask_planner/src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,33 @@ impl Dialect for DaskDialect {
special: false,
})))
}
Token::Word(w) if w.value.to_lowercase() == "timestampdiff" => {
    // TIMESTAMPDIFF(<unit>, <expr>, <expr>)
    parser.next_token(); // consume the `timestampdiff` keyword itself
    parser.expect_token(&Token::LParen)?;
    // The unit is whatever single token follows the opening parenthesis;
    // it is not validated here, only forwarded as text.
    let unit_token = parser.next_token();
    parser.expect_token(&Token::Comma)?;
    let first_operand = parser.parse_expr()?;
    parser.expect_token(&Token::Comma)?;
    let second_operand = parser.parse_expr()?;
    parser.expect_token(&Token::RParen)?;

    // Re-express the call as an ordinary three-argument function whose
    // first argument is the unit rendered as a string literal.
    let unit_literal = Expr::Value(Value::SingleQuotedString(unit_token.to_string()));
    let args: Vec<FunctionArg> = vec![unit_literal, first_operand, second_operand]
        .into_iter()
        .map(|operand| FunctionArg::Unnamed(FunctionArgExpr::Expr(operand)))
        .collect();

    let call = Function {
        name: ObjectName(vec![Ident::new("timestampdiff")]),
        args,
        over: None,
        distinct: false,
        special: false,
    };
    Ok(Some(Expr::Function(call)))
}
Token::Word(w) if w.value.to_lowercase() == "to_timestamp" => {
// TO_TIMESTAMP(d, "%d/%m/%Y")
parser.next_token(); // skip to_timestamp
Expand Down
14 changes: 13 additions & 1 deletion dask_planner/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod types;

use std::{collections::HashMap, sync::Arc};

use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion_common::{DFSchema, DataFusionError};
use datafusion_expr::{
logical_plan::Extension,
Expand Down Expand Up @@ -152,6 +152,18 @@ impl ContextProvider for DaskSQLContext {
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
}
"timestampdiff" => {
    // Accepts exactly (Utf8, Timestamp(ns, no tz), Timestamp(ns, no tz))
    // and is declared to return Int64 whatever the inputs are.
    let nanosecond_timestamp = || DataType::Timestamp(TimeUnit::Nanosecond, None);
    let arg_types = vec![
        DataType::Utf8,
        nanosecond_timestamp(),
        nanosecond_timestamp(),
    ];
    let sig = Signature::exact(arg_types, Volatility::Immutable);
    let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
    let udf = ScalarUDF::new(name, &sig, &rtf, &fun);
    return Some(Arc::new(udf));
}
"dsql_totimestamp" => {
let sig = Signature::one_of(
vec![
Expand Down
46 changes: 46 additions & 0 deletions dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,51 @@ def timestampadd(self, unit, interval, df: SeriesOrScalar):
raise NotImplementedError(f"Extraction of {unit} is not (yet) implemented.")


class DatetimeSubOperation(Operation):
    """
    Implementation of ``TIMESTAMPDIFF(unit, ts1, ts2)``.

    The difference is computed as ``ts2 - ts1`` and then scaled from a
    nanosecond count down to the requested ``unit``.
    """

    def __init__(self):
        super().__init__(self.datetime_sub)

    def datetime_sub(self, unit, df1, df2):
        """
        Return ``df2 - df1`` expressed in ``unit``.

        ``unit`` is one of NANOSECOND, MICROSECOND, SECOND, MINUTE, HOUR,
        DAY, WEEK, MONTH, QUARTER or YEAR (singular or plural). String
        units are matched case-insensitively.

        Raises ``NotImplementedError`` for any other unit.
        """
        # SQL keywords are case-insensitive, but the unit reaches us exactly
        # as the user typed it, so normalise before matching. Non-string
        # values are left untouched and fall through to the error below.
        unit_key = unit.upper() if isinstance(unit, str) else unit

        subtraction_op = ReduceOperation(
            operation=operator.sub, unary_operation=lambda x: -x
        )
        # Operand order matters: the second timestamp minus the first.
        result = subtraction_op(df2, df1)

        # NOTE(review): the exact sequence of `/` and `//` below is kept
        # deliberately; regrouping the divisors could change how partial
        # units are rounded. Confirm before simplifying.
        if unit_key in {"NANOSECOND", "NANOSECONDS"}:
            return result
        elif unit_key in {"MICROSECOND", "MICROSECONDS"}:
            return result // 1_000
        elif unit_key in {"SECOND", "SECONDS"}:
            return result // 1_000_000_000
        elif unit_key in {"MINUTE", "MINUTES"}:
            return (result / 1_000_000_000) // 60
        elif unit_key in {"HOUR", "HOURS"}:
            return (result / 1_000_000_000) // 3600
        elif unit_key in {"DAY", "DAYS"}:
            return ((result / 1_000_000_000) / 3600) // 24
        elif unit_key in {"WEEK", "WEEKS"}:
            return (((result / 1_000_000_000) / 3600) / 24) // 7
        elif unit_key in {"MONTH", "MONTHS"}:
            day_result = ((result / 1_000_000_000) / 3600) // 24
            # Average month length over a 365-day year (365 / 12 days).
            avg_days_in_month = ((30 * 4) + 28 + (31 * 7)) / 12
            # True division: the value is not rounded here.
            # presumably the caller's cast to the SQL return type does it -- TODO confirm
            return day_result / avg_days_in_month
        elif unit_key in {"QUARTER", "QUARTERS"}:
            day_result = ((result / 1_000_000_000) / 3600) // 24
            # A quarter is three average months.
            avg_days_in_quarter = 3 * ((30 * 4) + 28 + (31 * 7)) / 12
            return day_result / avg_days_in_quarter
        elif unit_key in {"YEAR", "YEARS"}:
            # Years are fixed at 365 days; leap days are not accounted for.
            return (((result / 1_000_000_000) / 3600) / 24) // 365
        else:
            raise NotImplementedError(
                f"Timestamp difference with {unit} is not supported."
            )


class CeilFloorOperation(PredicateBasedOperation):
"""
Apply ceil/floor operations on a series depending on its dtype (datetime like vs normal)
Expand Down Expand Up @@ -1031,6 +1076,7 @@ class RexCallPlugin(BaseRexPlugin):
"datepart": DatePartOperation(),
"year": YearOperation(),
"timestampadd": TimeStampAddOperation(),
"timestampdiff": DatetimeSubOperation(),
}

def convert(
Expand Down
94 changes: 94 additions & 0 deletions tests/integration/test_rex.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,100 @@ def test_date_functions(c):
)


def test_timestampdiff(c):
    """
    Check TIMESTAMPDIFF for every supported unit.

    The first query subtracts a later timestamp from an earlier one, so all
    non-zero results are negative; the second covers three rows with
    positive differences ranging from days to decades.
    """
    # --- single-row frame, first timestamp later than the second ---
    later_ts = datetime(2002, 3, 7, 9, 10, 5, 123)
    earlier_ts = datetime(2001, 6, 5, 10, 11, 6, 234)
    literal_pdf = pd.DataFrame(
        {"ts_literal1": [later_ts], "ts_literal2": [earlier_ts]}
    )
    literal_ddf = dd.from_pandas(literal_pdf, npartitions=1)
    c.register_dask_table(literal_ddf, "df")

    query = """
SELECT timestampdiff(NANOSECOND, ts_literal1, ts_literal2) as res0,
timestampdiff(MICROSECOND, ts_literal1, ts_literal2) as res1,
timestampdiff(SECOND, ts_literal1, ts_literal2) as res2,
timestampdiff(MINUTE, ts_literal1, ts_literal2) as res3,
timestampdiff(HOUR, ts_literal1, ts_literal2) as res4,
timestampdiff(DAY, ts_literal1, ts_literal2) as res5,
timestampdiff(WEEK, ts_literal1, ts_literal2) as res6,
timestampdiff(MONTH, ts_literal1, ts_literal2) as res7,
timestampdiff(QUARTER, ts_literal1, ts_literal2) as res8,
timestampdiff(YEAR, ts_literal1, ts_literal2) as res9
FROM df
"""
    literal_result = c.sql(query)

    literal_expected = pd.DataFrame(
        {
            "res0": [-23756338999889000],
            "res1": [-23756338999889],
            "res2": [-23756338],
            "res3": [-395938],
            "res4": [-6598],
            "res5": [-274],
            "res6": [-39],
            "res7": [-9],
            "res8": [-3],
            "res9": [0],
        }
    )

    assert_eq(literal_result, literal_expected, check_dtype=False)

    # --- multi-row frame, second timestamp later than the first ---
    start_values = [
        datetime(2002, 6, 5, 2, 1, 5, 200),
        datetime(2002, 9, 1),
        datetime(1970, 12, 3),
    ]
    end_values = [
        datetime(2002, 6, 7, 1, 0, 2, 100),
        datetime(2003, 6, 5),
        datetime(2038, 6, 5),
    ]
    test_pdf = pd.DataFrame({"a": start_values, "b": end_values})
    c.create_table("test", test_pdf)

    query = (
        "SELECT timestampdiff(NANOSECOND, a, b) as nanoseconds,"
        "timestampdiff(MICROSECOND, a, b) as microseconds,"
        "timestampdiff(SECOND, a, b) as seconds,"
        "timestampdiff(MINUTE, a, b) as minutes,"
        "timestampdiff(HOUR, a, b) as hours,"
        "timestampdiff(DAY, a, b) as days,"
        "timestampdiff(WEEK, a, b) as weeks,"
        "timestampdiff(MONTH, a, b) as months,"
        "timestampdiff(QUARTER, a, b) as quarters,"
        "timestampdiff(YEAR, a, b) as years"
        " FROM test"
    )
    column_result = c.sql(query)

    column_expected = pd.DataFrame(
        {
            "nanoseconds": [
                169136999900000,
                23932800000000000,
                2130278400000000000,
            ],
            "microseconds": [169136999900, 23932800000000, 2130278400000000],
            "seconds": [169136, 23932800, 2130278400],
            "minutes": [2818, 398880, 35504640],
            "hours": [46, 6648, 591744],
            "days": [1, 277, 24656],
            "weeks": [0, 39, 3522],
            "months": [0, 9, 810],
            "quarters": [0, 3, 270],
            "years": [0, 0, 67],
        }
    )

    assert_eq(column_result, column_expected, check_dtype=False)


@pytest.mark.parametrize(
"gpu",
[
Expand Down

0 comments on commit d2896fa

Please sign in to comment.