Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Timestampdiff support #495

Merged
merged 13 commits into from
May 17, 2022
Merged
9 changes: 8 additions & 1 deletion dask_sql/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def sql_to_python_type(sql_type: str) -> type:
return pd.StringDtype()
elif sql_type.startswith("INTERVAL"):
return np.dtype("<m8[ns]")

elif sql_type.startswith("TIMESTAMP(") or sql_type.startswith("TIME("):
return np.dtype("<M8[ns]")
elif sql_type.startswith("TIMESTAMP_WITH_LOCAL_TIME_ZONE("):
Expand Down Expand Up @@ -289,6 +290,7 @@ def cast_column_to_type(col: dd.Series, expected_type: str):

current_float = pd.api.types.is_float_dtype(current_type)
expected_integer = pd.api.types.is_integer_dtype(expected_type)
current_timedelta_type = pd.api.types.is_timedelta64_dtype(current_type)
if current_float and expected_integer:
logger.debug("...truncating...")
# Currently "trunc" can not be applied to NA (the pandas missing value type),
Expand All @@ -297,5 +299,10 @@ def cast_column_to_type(col: dd.Series, expected_type: str):
# will convert both NA and np.NaN to NA.
col = da.trunc(col.fillna(value=np.NaN))

if current_timedelta_type and expected_integer:
res = col.astype(np.int64)
else:
res = col.astype(expected_type)
ayushdg marked this conversation as resolved.
Show resolved Hide resolved

logger.debug(f"Need to cast from {current_type} to {expected_type}")
return col.astype(expected_type)
return res
77 changes: 76 additions & 1 deletion dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dask.utils import random_state_data

from dask_sql.datacontainer import DataContainer
from dask_sql.java import get_java_class
from dask_sql.mappings import cast_column_to_type, sql_to_python_type
from dask_sql.physical.rex import RexConverter
from dask_sql.physical.rex.base import BaseRexPlugin
Expand Down Expand Up @@ -168,10 +169,11 @@ def div(self, lhs, rhs):
# We do not need to truncate in this case
# So far, I did not spot any other occurrence
# of this function.
if isinstance(result, datetime.timedelta):
if isinstance(result, (datetime.timedelta, np.timedelta64)):
return result
else: # pragma: no cover
result = da.trunc(result)
result = result.astype(np.int64)
return result
ayushdg marked this conversation as resolved.
Show resolved Hide resolved


Expand Down Expand Up @@ -240,6 +242,25 @@ def cast(self, operand, rex=None) -> SeriesOrScalar:
return return_column


class ReinterpretOperation(Operation):
"""The cast operator"""

needs_rex = True

def __init__(self):
super().__init__(self.cast)

def cast(self, operand, rex=None) -> SeriesOrScalar:
output_type = str(rex.getType())
python_type = sql_to_python_type(output_type.upper())

return_column = cast_column_to_type(operand, python_type)
if return_column is None:
return operand
else:
return return_column
ayushdg marked this conversation as resolved.
Show resolved Hide resolved


class IsFalseOperation(Operation):
"""The is false operator"""

Expand Down Expand Up @@ -715,6 +736,43 @@ def search(self, series: dd.Series, sarg: SargPythonImplementation):
return conditions[0]


class DatetimeSubOperation(Operation):
"""
Datetime subtraction is a special case of the `minus` operation in calcite
which also specifies a sql interval return type for the operation.
"""

needs_rex = True

def __init__(self):
super().__init__(self.datetime_sub)

def datetime_sub(self, *operands, rex=None):
output_type = str(rex.getType())
assert output_type.startswith("INTERVAL")
interval_unit = output_type.split()[1].lower()

subtraction_op = ReduceOperation(
operation=operator.sub, unary_operation=lambda x: -x
)
intermediate_res = subtraction_op(*operands)

# Special case output_type for datetime operations
if interval_unit in {"year", "quarter", "month"}:
# if interval_unit is INTERVAL YEAR, Calcite will covert to months
if not is_frame(intermediate_res):
# Numpy doesn't allow divsion by month time unit
result = intermediate_res.astype("timedelta64[M]")
# numpy -ve timedelta's are off by one vs sql when casted to month
result = result + 1 if result < 0 else result
else:
result = intermediate_res / np.timedelta64(1, "M")
else:
result = intermediate_res.astype("timedelta64[ms]")

return result


class RexCallPlugin(BaseRexPlugin):
"""
RexCall is used for expressions, which calculate something.
Expand Down Expand Up @@ -752,6 +810,7 @@ class RexCallPlugin(BaseRexPlugin):
"/int": IntDivisionOperator(),
# special operations
"cast": CastOperation(),
"reinterpret": ReinterpretOperation(),
"case": CaseOperation(),
"like": LikeOperation(),
"similar to": SimilarOperation(),
Expand Down Expand Up @@ -812,6 +871,7 @@ class RexCallPlugin(BaseRexPlugin):
lambda x: x + pd.tseries.offsets.MonthEnd(1),
lambda x: convert_to_datetime(x) + pd.tseries.offsets.MonthEnd(1),
),
"datetime_subtraction": DatetimeSubOperation(),
}

def convert(
Expand All @@ -827,6 +887,8 @@ def convert(

# Now use the operator name in the mapping
schema_name, operator_name = context.fqn(rex.getOperator().getNameAsId())
if special_op := check_special_operator(rex.getOperator()):
operator_name = special_op
operator_name = operator_name.lower()

try:
Expand All @@ -850,3 +912,16 @@ def convert(

return operation(*operands, **kwargs)
# TODO: We have information on the typing here - we should use it


def check_special_operator(operator: "org.apache.calcite.sql.fun"):
"""
Check for special operator classes that have an overloaded name with other
operator type/kinds.

eg: sqlDatetimeSubtractionOperator has the sqltype and kind of the `-` or `minus` operation.
"""
special_op_to_name = {
"org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator": "datetime_subtraction"
}
return special_op_to_name.get(get_java_class(operator), None)
89 changes: 89 additions & 0 deletions tests/integration/test_rex.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,92 @@ def test_date_functions(c):
FROM df
"""
)


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_timestampdiff(c, gpu):
# single value test
ts_literal1 = "2002-03-07 09:10:05.123"
ts_literal2 = "2001-06-05 10:11:06.234"
query = (
f"SELECT timestampdiff(NANOSECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res0,"
f"timestampdiff(MICROSECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res1,"
f"timestampdiff(SECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res2,"
f"timestampdiff(MINUTE, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res3,"
f"timestampdiff(HOUR, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res4,"
f"timestampdiff(DAY, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res5,"
f"timestampdiff(WEEK, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res6,"
f"timestampdiff(MONTH, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res7,"
f"timestampdiff(QUARTER, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res8,"
f"timestampdiff(YEAR, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res9"
)
df = c.sql(query)
expected_df = pd.DataFrame(
{
"res0": [-23756339_000_000_000],
"res1": [-23756339_000_000],
"res2": [-23756339],
"res3": [-395938],
"res4": [-6598],
"res5": [-274],
"res6": [-39],
"res7": [-9],
"res8": [-3],
"res9": [0],
}
)
assert_eq(df, expected_df)
# dataframe test

test = pd.DataFrame(
{
"a": [
"2002-06-05 02:01:05.200",
"2002-09-01 00:00:00",
"1970-12-03 00:00:00",
],
"b": [
"2002-06-07 01:00:02.100",
"2003-06-05 00:00:00",
"2038-06-05 00:00:00",
],
}
)

c.create_table("test", test, gpu=gpu)
query = (
"SELECT timestampdiff(NANOSECOND, CAST(a AS TIMESTAMP), CAST(b AS TIMESTAMP)) as nanoseconds,"
"timestampdiff(MICROSECOND, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as microseconds,"
"timestampdiff(SECOND, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as seconds,"
"timestampdiff(MINUTE, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as minutes,"
"timestampdiff(HOUR, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as hours,"
"timestampdiff(DAY, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as days,"
"timestampdiff(WEEK, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as weeks,"
"timestampdiff(MONTH, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as months,"
"timestampdiff(QUARTER, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as quarters,"
"timestampdiff(YEAR, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as years"
" FROM test"
)

ddf = c.sql(query)

expected_df = pd.DataFrame(
{
"nanoseconds": [
169136_000_000_000,
23932_800_000_000_000,
2_130_278_400_000_000_000,
],
"microseconds": [169136_000_000, 23932_800_000_000, 2_130_278_400_000_000],
"seconds": [169136, 23932_800, 2_130_278_400],
"minutes": [2818, 398880, 35504640],
"hours": [46, 6648, 591744],
"days": [1, 277, 24656],
"weeks": [0, 39, 3522],
"months": [0, 9, 810],
"quarters": [0, 3, 270],
"years": [0, 0, 67],
}
)

assert_eq(ddf, expected_df, check_dtype=False)