class DatetimeSubOperation(Operation):
    """
    Datetime subtraction is a special case of the `minus` operation in calcite
    which also specifies a sql interval return type for the operation.

    The interval unit is read from the type reported by the rex node and
    decides how the raw difference of the operands is converted.
    """

    # NOTE(review): presumably tells the caller to pass the rex node as the
    # ``rex`` keyword argument -- confirm against the ``Operation`` base class.
    needs_rex = True

    def __init__(self):
        super().__init__(self.datetime_sub)

    def datetime_sub(self, *operands, rex=None):
        """
        Subtract the operands and convert the difference to the interval unit
        named by ``rex``.

        Parameters:
            *operands: the values to subtract; they are forwarded unchanged to
                a ``ReduceOperation`` built around ``operator.sub``.
            rex: node whose ``getType()`` string has the form
                ``"INTERVAL <unit> ..."``.

        Returns:
            For the ``year``, ``quarter`` and ``month`` units the difference
            expressed in months; for every other unit the difference cast to
            ``timedelta64[ms]``.

        Raises:
            ValueError: if the type string of ``rex`` does not start with
                ``INTERVAL`` or does not name a unit.
        """
        output_type = str(rex.getType())
        type_tokens = output_type.split()
        # Validate explicitly instead of using ``assert``: assertions are
        # stripped under ``python -O``, which would turn a malformed type into
        # an unrelated IndexError or a silently wrong millisecond result.
        if not output_type.startswith("INTERVAL") or len(type_tokens) < 2:
            raise ValueError(
                "Datetime subtraction expects an 'INTERVAL <unit>' return type, "
                f"got {output_type!r}"
            )
        interval_unit = type_tokens[1].lower()

        subtraction_op = ReduceOperation(
            operation=operator.sub, unary_operation=lambda x: -x
        )
        intermediate_res = subtraction_op(*operands)

        # Special case output_type for datetime operations
        if interval_unit in {"year", "quarter", "month"}:
            # if interval_unit is INTERVAL YEAR, Calcite will convert to months
            if not is_frame(intermediate_res):
                # Numpy doesn't allow division by month time unit
                result = intermediate_res.astype("timedelta64[M]")
                # numpy -ve timedelta's are off by one vs sql when casted to month
                result = result + 1 if result < 0 else result
            else:
                result = intermediate_res / np.timedelta64(1, "M")
        else:
            result = intermediate_res.astype("timedelta64[ms]")

        return result
def check_special_operator(operator: "org.apache.calcite.sql.fun"):
    """
    Look up the dedicated operation name for an operator whose Java class
    shares its name/kind with another operator.

    eg: SqlDatetimeSubtractionOperator has the sqltype and kind of the `-` or
    `minus` operation, so it is identified here by its Java class instead.

    Returns the mapped operation name, or None when the operator's Java class
    has no special-case entry.
    """
    overloaded_class_names = {
        "org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator": "datetime_subtraction"
    }
    java_class_name = get_java_class(operator)
    return overloaded_class_names.get(java_class_name)
+ "res9": [0], + } + ) + assert_eq(df, expected_df) + # dataframe test + + test = pd.DataFrame( + { + "a": [ + "2002-06-05 02:01:05.200", + "2002-09-01 00:00:00", + "1970-12-03 00:00:00", + ], + "b": [ + "2002-06-07 01:00:02.100", + "2003-06-05 00:00:00", + "2038-06-05 00:00:00", + ], + } + ) + + c.create_table("test", test, gpu=gpu) + query = ( + "SELECT timestampdiff(NANOSECOND, CAST(a AS TIMESTAMP), CAST(b AS TIMESTAMP)) as nanoseconds," + "timestampdiff(MICROSECOND, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as microseconds," + "timestampdiff(SECOND, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as seconds," + "timestampdiff(MINUTE, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as minutes," + "timestampdiff(HOUR, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as hours," + "timestampdiff(DAY, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as days," + "timestampdiff(WEEK, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as weeks," + "timestampdiff(MONTH, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as months," + "timestampdiff(QUARTER, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as quarters," + "timestampdiff(YEAR, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as years" + " FROM test" + ) + + ddf = c.sql(query) + + expected_df = pd.DataFrame( + { + "nanoseconds": [ + 169136_000_000_000, + 23932_800_000_000_000, + 2_130_278_400_000_000_000, + ], + "microseconds": [169136_000_000, 23932_800_000_000, 2_130_278_400_000_000], + "seconds": [169136, 23932_800, 2_130_278_400], + "minutes": [2818, 398880, 35504640], + "hours": [46, 6648, 591744], + "days": [1, 277, 24656], + "weeks": [0, 39, 3522], + "months": [0, 9, 810], + "quarters": [0, 3, 270], + "years": [0, 0, 67], + } + ) + + assert_eq(ddf, expected_df, check_dtype=False)