Timestampdiff support (#495)
* Add timestampdiff

* Initial work for timestampdiff

* Add test cases for timestampdiff

* Update interval month dtype mapping

* Add DatetimeSubOperation

* Uncomment timestampdiff literal tests

* Update logic for handling interval_months for pandas/cudf series and scalars

* Add negative diff test cases and GPU tests

* Update reinterpret and timedelta to explicitly cast to int64 instead of int

* Simplify cast_column_to_type mapping logic

* Add scalar handling to CastOperation and reuse it for reinterpret

Co-authored-by: rajagurunath <[email protected]>
ayushdg and rajagurunath authored May 17, 2022
1 parent b58989f commit cb3d903
Showing 3 changed files with 159 additions and 16 deletions.
21 changes: 12 additions & 9 deletions dask_sql/mappings.py
@@ -197,6 +197,7 @@ def sql_to_python_type(sql_type: str) -> type:
return pd.StringDtype()
elif sql_type.startswith("INTERVAL"):
return np.dtype("<m8[ns]")

elif sql_type.startswith("TIMESTAMP(") or sql_type.startswith("TIME("):
return np.dtype("<M8[ns]")
elif sql_type.startswith("TIMESTAMP_WITH_LOCAL_TIME_ZONE("):
@@ -287,15 +288,17 @@ def cast_column_to_type(col: dd.Series, expected_type: str):
logger.debug("...not converting.")
return None

-    current_float = pd.api.types.is_float_dtype(current_type)
-    expected_integer = pd.api.types.is_integer_dtype(expected_type)
-    if current_float and expected_integer:
-        logger.debug("...truncating...")
-        # Currently "trunc" can not be applied to NA (the pandas missing value type),
-        # because NA is a different type. It works with np.NaN though.
-        # For our use case, that does not matter, as the conversion to integer later
-        # will convert both NA and np.NaN to NA.
-        col = da.trunc(col.fillna(value=np.NaN))
+    if pd.api.types.is_integer_dtype(expected_type):
+        if pd.api.types.is_float_dtype(current_type):
+            logger.debug("...truncating...")
+            # Currently "trunc" can not be applied to NA (the pandas missing value type),
+            # because NA is a different type. It works with np.NaN though.
+            # For our use case, that does not matter, as the conversion to integer later
+            # will convert both NA and np.NaN to NA.
+            col = da.trunc(col.fillna(value=np.NaN))
+        elif pd.api.types.is_timedelta64_dtype(current_type):
+            logger.debug(f"Explicitly casting from {current_type} to np.int64")
+            return col.astype(np.int64)

logger.debug(f"Need to cast from {current_type} to {expected_type}")
return col.astype(expected_type)
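
For context, a minimal sketch of the behavior the new timedelta branch relies on, using plain pandas/NumPy outside of dask-sql: SQL INTERVAL values map to timedelta64[ns] (see the mapping above), and casting such a series to int64 reinterprets it as a raw nanosecond count.

import numpy as np
import pandas as pd

s = pd.Series(pd.to_timedelta(["1s", "2s"]))  # timedelta64[ns], the dtype INTERVAL maps to
print(s.astype(np.int64))                     # 1000000000, 2000000000 -- raw nanosecond counts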
65 changes: 58 additions & 7 deletions dask_sql/physical/rex/core/call.py
@@ -15,6 +15,7 @@
from dask.utils import random_state_data

from dask_sql.datacontainer import DataContainer
from dask_sql.java import get_java_class
from dask_sql.mappings import cast_column_to_type, sql_to_python_type
from dask_sql.physical.rex import RexConverter
from dask_sql.physical.rex.base import BaseRexPlugin
@@ -168,11 +169,10 @@ def div(self, lhs, rhs):
# We do not need to truncate in this case
# So far, I did not spot any other occurrence
# of this function.
-        if isinstance(result, datetime.timedelta):
-            return result
-        else:  # pragma: no cover
-            result = da.trunc(result)
+        if isinstance(result, (datetime.timedelta, np.timedelta64)):
+            return result
+        else:
+            return da.trunc(result).astype(np.int64)
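
The else branch mirrors SQL integer division, which truncates toward zero; a plain pandas/NumPy equivalent of the dask code above, as a sketch:

import numpy as np
import pandas as pd

result = pd.Series([7, -7]) / 2           # 3.5, -3.5
print(np.trunc(result).astype(np.int64))  # 3, -3 -- truncated toward zero, like SQL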


class CaseOperation(Operation):
@@ -220,9 +220,6 @@ def __init__(self):
super().__init__(self.cast)

def cast(self, operand, rex=None) -> SeriesOrScalar:
-        if not is_frame(operand):  # pragma: no cover
-            return operand
-
output_type = str(rex.getType())
python_type = sql_to_python_type(output_type.upper())

@@ -715,6 +712,43 @@ def search(self, series: dd.Series, sarg: SargPythonImplementation):
return conditions[0]


class DatetimeSubOperation(Operation):
"""
Datetime subtraction is a special case of the `minus` operation in calcite
which also specifies a sql interval return type for the operation.
"""

needs_rex = True

def __init__(self):
super().__init__(self.datetime_sub)

def datetime_sub(self, *operands, rex=None):
output_type = str(rex.getType())
assert output_type.startswith("INTERVAL")
interval_unit = output_type.split()[1].lower()

subtraction_op = ReduceOperation(
operation=operator.sub, unary_operation=lambda x: -x
)
intermediate_res = subtraction_op(*operands)

# Special case output_type for datetime operations
if interval_unit in {"year", "quarter", "month"}:
# if interval_unit is INTERVAL YEAR, Calcite will convert to months
if not is_frame(intermediate_res):
# NumPy doesn't allow division by the month time unit
result = intermediate_res.astype("timedelta64[M]")
# negative numpy timedeltas are off by one vs SQL when cast to months
result = result + 1 if result < 0 else result
else:
result = intermediate_res / np.timedelta64(1, "M")
else:
result = intermediate_res.astype("timedelta64[ms]")

return result
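
The month branch compensates for NumPy flooring negative timedeltas toward -inf; a standalone sketch of the same arithmetic, using day-precision dates borrowed from the literal test below:

import numpy as np

delta = np.datetime64("2001-06-05") - np.datetime64("2002-03-07")  # -275 days
months = delta.astype("timedelta64[M]").astype(np.int64)           # -10: NumPy floors toward -inf
if months < 0:
    months += 1  # SQL truncates toward zero, hence the +1 correction above
print(months)    # -9, matching TIMESTAMPDIFF(MONTH, ...) in the tests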


class RexCallPlugin(BaseRexPlugin):
"""
RexCall is used for expressions, which calculate something.
@@ -752,6 +786,7 @@ class RexCallPlugin(BaseRexPlugin):
"/int": IntDivisionOperator(),
# special operations
"cast": CastOperation(),
"reinterpret": CastOperation(),
"case": CaseOperation(),
"like": LikeOperation(),
"similar to": SimilarOperation(),
@@ -812,6 +847,7 @@ class RexCallPlugin(BaseRexPlugin):
lambda x: x + pd.tseries.offsets.MonthEnd(1),
lambda x: convert_to_datetime(x) + pd.tseries.offsets.MonthEnd(1),
),
"datetime_subtraction": DatetimeSubOperation(),
}

def convert(
@@ -827,6 +863,8 @@

# Now use the operator name in the mapping
schema_name, operator_name = context.fqn(rex.getOperator().getNameAsId())
if special_op := check_special_operator(rex.getOperator()):
operator_name = special_op
operator_name = operator_name.lower()

try:
@@ -850,3 +888,16 @@

return operation(*operands, **kwargs)
# TODO: We have information on the typing here - we should use it


def check_special_operator(operator: "org.apache.calcite.sql.fun"):
"""
Check for special operator classes that have an overloaded name with other
operator type/kinds.
eg: sqlDatetimeSubtractionOperator has the sqltype and kind of the `-` or `minus` operation.
"""
special_op_to_name = {
"org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator": "datetime_subtraction"
}
return special_op_to_name.get(get_java_class(operator), None)
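
Calcite registers datetime subtraction under the same name and kind as the ordinary `-` operator, so dispatch keys on the Java class instead; a simplified, hypothetical illustration of the lookup (resolve_operator_name is invented for this sketch):

special_op_to_name = {
    "org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator": "datetime_subtraction"
}

def resolve_operator_name(java_class: str, sql_name: str) -> str:
    # Fall back to the plain SQL name (e.g. "-") when the class is not special-cased
    return special_op_to_name.get(java_class, sql_name)

print(resolve_operator_name("org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator", "-"))
# -> datetime_subtraction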
89 changes: 89 additions & 0 deletions tests/integration/test_rex.py
@@ -587,3 +587,92 @@ def test_date_functions(c):
FROM df
"""
)


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_timestampdiff(c, gpu):
# single value test
ts_literal1 = "2002-03-07 09:10:05.123"
ts_literal2 = "2001-06-05 10:11:06.234"
query = (
f"SELECT timestampdiff(NANOSECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res0,"
f"timestampdiff(MICROSECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res1,"
f"timestampdiff(SECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res2,"
f"timestampdiff(MINUTE, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res3,"
f"timestampdiff(HOUR, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res4,"
f"timestampdiff(DAY, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res5,"
f"timestampdiff(WEEK, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res6,"
f"timestampdiff(MONTH, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res7,"
f"timestampdiff(QUARTER, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res8,"
f"timestampdiff(YEAR, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res9"
)
df = c.sql(query)
expected_df = pd.DataFrame(
{
"res0": [-23756339_000_000_000],
"res1": [-23756339_000_000],
"res2": [-23756339],
"res3": [-395938],
"res4": [-6598],
"res5": [-274],
"res6": [-39],
"res7": [-9],
"res8": [-3],
"res9": [0],
}
)
assert_eq(df, expected_df)

    # dataframe test
test = pd.DataFrame(
{
"a": [
"2002-06-05 02:01:05.200",
"2002-09-01 00:00:00",
"1970-12-03 00:00:00",
],
"b": [
"2002-06-07 01:00:02.100",
"2003-06-05 00:00:00",
"2038-06-05 00:00:00",
],
}
)

c.create_table("test", test, gpu=gpu)
query = (
"SELECT timestampdiff(NANOSECOND, CAST(a AS TIMESTAMP), CAST(b AS TIMESTAMP)) as nanoseconds,"
"timestampdiff(MICROSECOND, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as microseconds,"
"timestampdiff(SECOND, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as seconds,"
"timestampdiff(MINUTE, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as minutes,"
"timestampdiff(HOUR, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as hours,"
"timestampdiff(DAY, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as days,"
"timestampdiff(WEEK, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as weeks,"
"timestampdiff(MONTH, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as months,"
"timestampdiff(QUARTER, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as quarters,"
"timestampdiff(YEAR, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as years"
" FROM test"
)

ddf = c.sql(query)

expected_df = pd.DataFrame(
{
"nanoseconds": [
169136_000_000_000,
23932_800_000_000_000,
2_130_278_400_000_000_000,
],
"microseconds": [169136_000_000, 23932_800_000_000, 2_130_278_400_000_000],
"seconds": [169136, 23932_800, 2_130_278_400],
"minutes": [2818, 398880, 35504640],
"hours": [46, 6648, 591744],
"days": [1, 277, 24656],
"weeks": [0, 39, 3522],
"months": [0, 9, 810],
"quarters": [0, 3, 270],
"years": [0, 0, 67],
}
)

assert_eq(ddf, expected_df, check_dtype=False)
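
As a quick cross-check of one row of the expected values, using plain pandas/NumPy (second row of the dataframe test; not part of the change itself):

import numpy as np
import pandas as pd

delta = pd.Timestamp("2003-06-05") - pd.Timestamp("2002-09-01")
assert delta.days == 277                         # "days" column
assert int(delta.total_seconds()) == 23_932_800  # "seconds" column
assert delta.to_timedelta64().astype("timedelta64[M]").astype(np.int64) == 9  # "months" column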
