Support DATE Extraction (#984)
* support date extraction

* style fix

* remove extractdateoperation

* use strftime

* convert str to datetime

* cudf behavior

* apply Charles suggestions

* revert rust change

* Update dialect.rs

* add test

* style

* add check_index

* handle scalar input

* add RuntimeError

* style fix

* use xfail

---------

Co-authored-by: Ayush Dattagupta <[email protected]>
Co-authored-by: Charles Blackmon-Luca <[email protected]>
3 people authored Jun 23, 2023
1 parent f8bf06c commit 5421bbf
Showing 6 changed files with 142 additions and 15 deletions.
29 changes: 29 additions & 0 deletions dask_planner/src/dialect.rs
@@ -194,6 +194,35 @@ impl Dialect for DaskDialect {
special: false,
})))
}
Token::Word(w) if w.value.to_lowercase() == "extract" => {
// EXTRACT(DATE FROM d)
parser.next_token(); // skip extract
parser.expect_token(&Token::LParen)?;
if !parser.parse_keywords(&[Keyword::DATE, Keyword::FROM]) {
// Parse EXTRACT(x FROM d) as normal
parser.prev_token();
parser.prev_token();
return Ok(None);
}
let expr = parser.parse_expr()?;
parser.expect_token(&Token::RParen)?;

// convert to function args
let args = vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
Value::SingleQuotedString("DATE".to_string()),
))),
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
];

Ok(Some(Expr::Function(Function {
name: ObjectName(vec![Ident::new("extract_date")]),
args,
over: None,
distinct: false,
special: false,
})))
}
_ => Ok(None),
}
}
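The dialect change above rewrites EXTRACT(DATE FROM d) into a call to an extract_date scalar function with a literal 'DATE' as its first argument, leaving every other EXTRACT form untouched. A minimal end-to-end sketch of the resulting user-facing behavior (assuming a dask_sql Context named c, mirroring the new integration test below):

import pandas as pd
from datetime import datetime
from dask_sql import Context

c = Context()
df = pd.DataFrame({"t": [datetime(2021, 1, 1, 12, 30), datetime(2022, 2, 2, 8, 15)]})
c.create_table("df", df)

# Parsed as extract_date('DATE', t); RexCallPlugin maps extract_date to
# ExtractOperation, which truncates the time-of-day component
result = c.sql("SELECT EXTRACT(DATE FROM t) AS e FROM df")
print(result.compute())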
14 changes: 14 additions & 0 deletions dask_planner/src/sql.rs
@@ -310,6 +310,20 @@ impl ContextProvider for DaskSQLContext {
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
}
"extract_date" => {
let sig = Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Date64]),
TypeSignature::Exact(vec![
DataType::Utf8,
DataType::Timestamp(TimeUnit::Nanosecond, None),
]),
],
Volatility::Immutable,
);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64)));
return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
}
_ => (),
}

2 changes: 2 additions & 0 deletions dask_sql/_compat.py
@@ -18,3 +18,5 @@
# Parquet predicate-support version checks
PQ_NOT_IN_SUPPORT = parseVersion(dask.__version__) > parseVersion("2023.5.1")
PQ_IS_SUPPORT = parseVersion(dask.__version__) >= parseVersion("2023.3.1")

DASK_CUDF_TODATETIME_SUPPORT = _dask_version >= parseVersion("2023.5.1")
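The new flag gates the cuDF code path on a Dask release that supports dask-cudf to_datetime. A standalone sketch of the same check, assuming parseVersion is packaging.version.parse and _dask_version is the parsed dask.__version__, as the file's existing checks suggest:

import dask
from packaging.version import parse as parseVersion

_dask_version = parseVersion(dask.__version__)

# True on Dask >= 2023.5.1, where dask-cudf gained to_datetime support
DASK_CUDF_TODATETIME_SUPPORT = _dask_version >= parseVersion("2023.5.1")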
12 changes: 12 additions & 0 deletions dask_sql/physical/rex/core/call.py
@@ -15,6 +15,7 @@
from dask.utils import random_state_data

from dask_planner.rust import SqlTypeName
from dask_sql._compat import DASK_CUDF_TODATETIME_SUPPORT
from dask_sql.datacontainer import DataContainer
from dask_sql.mappings import (
cast_column_to_type,
@@ -929,6 +930,16 @@ def date_part(self, what, df: SeriesOrScalar):
return df.week
elif what in {"YEAR", "YEARS"}:
return df.year
elif what == "DATE":
if isinstance(df, pd.Timestamp):
return df.date()
else:
if is_cudf_type(df) and not DASK_CUDF_TODATETIME_SUPPORT:
raise RuntimeError(
"Dask-cuDF to_datetime support requires Dask version >= 2023.5.1"
)
else:
return dd.to_datetime(df.strftime("%Y-%m-%d"))
else:
raise NotImplementedError(f"Extraction of {what} is not (yet) implemented.")
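In the new DATE branch, a scalar pd.Timestamp is reduced with .date(), while series are round-tripped through a date-only string so the time-of-day is dropped (with the cuDF path guarded by DASK_CUDF_TODATETIME_SUPPORT). A self-contained sketch of that series path on plain dask data, using illustrative names that are not part of the diff:

import pandas as pd
import dask.dataframe as dd
from datetime import datetime

s = pd.Series([datetime(2021, 10, 3, 15, 53, 42), datetime(2022, 2, 2, 1, 2, 3)])
ddf = dd.from_pandas(s, npartitions=1)

# Format to a date-only string and parse back, dropping the time component,
# as the dd.to_datetime(df.strftime("%Y-%m-%d")) call above does
truncated = dd.to_datetime(ddf.dt.strftime("%Y-%m-%d"))
print(truncated.compute())  # 2021-10-03 and 2022-02-02 at midnight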

@@ -1070,6 +1081,7 @@ class RexCallPlugin(BaseRexPlugin):
"coalesce": CoalesceOperation(),
"replace": ReplaceOperation(),
# date/time operations
"extract_date": ExtractOperation(),
"localtime": Operation(lambda *args: pd.Timestamp.now()),
"localtimestamp": Operation(lambda *args: pd.Timestamp.now()),
"current_time": Operation(lambda *args: pd.Timestamp.now()),
71 changes: 69 additions & 2 deletions tests/integration/test_rex.py
@@ -5,6 +5,7 @@
import pandas as pd
import pytest

from dask_sql._compat import DASK_CUDF_TODATETIME_SUPPORT
from tests.utils import assert_eq


@@ -759,10 +760,11 @@ def test_date_functions(c):
EXTRACT(SECOND FROM d) AS "second",
EXTRACT(WEEK FROM d) AS "week",
EXTRACT(YEAR FROM d) AS "year",
EXTRACT(DATE FROM d) AS "date",
LAST_DAY(d) as "last_day",
TIMESTAMPADD(YEAR, 2, d) as "plus_1_year",
TIMESTAMPADD(YEAR, 1, d) as "plus_1_year",
TIMESTAMPADD(MONTH, 1, d) as "plus_1_month",
TIMESTAMPADD(WEEK, 1, d) as "plus_1_week",
TIMESTAMPADD(DAY, 1, d) as "plus_1_day",
@@ -806,8 +808,9 @@ def test_date_functions(c):
"second": [42],
"week": [39],
"year": [2021],
"date": [datetime(2021, 10, 3)],
"last_day": [datetime(2021, 10, 31, 15, 53, 42, 47)],
"plus_1_year": [datetime(2023, 10, 3, 15, 53, 42, 47)],
"plus_1_year": [datetime(2022, 10, 3, 15, 53, 42, 47)],
"plus_1_month": [datetime(2021, 11, 3, 15, 53, 42, 47)],
"plus_1_week": [datetime(2021, 10, 10, 15, 53, 42, 47)],
"plus_1_day": [datetime(2021, 10, 4, 15, 53, 42, 47)],
@@ -1054,3 +1057,67 @@ def test_totimestamp(c, gpu):
}
)
assert_eq(df, expected_df, check_dtype=False)


@pytest.mark.parametrize(
"gpu",
[
False,
pytest.param(
True,
marks=(
pytest.mark.gpu,
pytest.mark.xfail(
not DASK_CUDF_TODATETIME_SUPPORT,
reason="Requires https://github.com/dask/dask/pull/9881",
raises=RuntimeError,
),
),
),
],
)
def test_extract_date(c, gpu):
df = pd.DataFrame(
{
"a": [1, 2, 3],
"b": [4, 5, 6],
}
)
df["t"] = [datetime(2021, 1, 1), datetime(2022, 2, 2), datetime(2023, 3, 3)]
c.create_table("df", df, gpu=gpu)

result = c.sql("SELECT EXTRACT(DATE FROM t) AS e FROM df")
expected_df = pd.DataFrame(
{"e": [datetime(2021, 1, 1), datetime(2022, 2, 2), datetime(2023, 3, 3)]}
)
assert_eq(result, expected_df)

result = c.sql("SELECT * FROM df WHERE EXTRACT(DATE FROM t) > '2021-02-01'")
expected_df = pd.DataFrame(
{
"a": [2, 3],
"b": [5, 6],
"t": [datetime(2022, 2, 2), datetime(2023, 3, 3)],
}
)
assert_eq(result, expected_df, check_index=False)

result = c.sql(
"SELECT * FROM df WHERE EXTRACT(DATE FROM t) BETWEEN '2020-10-01' AND '2022-10-10'"
)
expected_df = pd.DataFrame(
{"a": [1, 2], "b": [4, 5], "t": [datetime(2021, 1, 1), datetime(2022, 2, 2)]}
)
assert_eq(result, expected_df)

result = c.sql("SELECT TIMESTAMPADD(YEAR, 1, EXTRACT(DATE FROM t)) AS ta FROM df")
expected_df = pd.DataFrame(
{"ta": [datetime(2022, 1, 1), datetime(2023, 2, 2), datetime(2024, 3, 3)]}
)
assert_eq(result, expected_df)

result = c.sql("SELECT EXTRACT(DATE FROM t) + INTERVAL '2 days' AS i FROM df")
expected_df = pd.DataFrame(
{"i": [datetime(2021, 1, 3), datetime(2022, 2, 4), datetime(2023, 3, 5)]}
)
assert_eq(result, expected_df)
29 changes: 16 additions & 13 deletions tests/unit/test_call.py
@@ -1,5 +1,5 @@
import datetime
import operator
from datetime import datetime
from unittest.mock import MagicMock

import dask.dataframe as dd
@@ -190,7 +190,7 @@ def test_string_operations():
def test_dates():
op = call.ExtractOperation()

date = datetime(2021, 10, 3, 15, 53, 42, 47)
date = datetime.datetime(2021, 10, 3, 15, 53, 42, 47)
assert int(op("CENTURY", date)) == 20
assert op("DAY", date) == 3
assert int(op("DECADE", date)) == 202
@@ -206,18 +206,21 @@ def test_dates():
assert op("SECOND", date) == 42
assert op("WEEK", date) == 39
assert op("YEAR", date) == 2021
assert op("DATE", date) == datetime.date(2021, 10, 3)

ceil_op = call.CeilFloorOperation("ceil")
floor_op = call.CeilFloorOperation("floor")

assert ceil_op(date, "DAY") == datetime(2021, 10, 4)
assert ceil_op(date, "HOUR") == datetime(2021, 10, 3, 16)
assert ceil_op(date, "MINUTE") == datetime(2021, 10, 3, 15, 54)
assert ceil_op(date, "SECOND") == datetime(2021, 10, 3, 15, 53, 43)
assert ceil_op(date, "MILLISECOND") == datetime(2021, 10, 3, 15, 53, 42, 1000)

assert floor_op(date, "DAY") == datetime(2021, 10, 3)
assert floor_op(date, "HOUR") == datetime(2021, 10, 3, 15)
assert floor_op(date, "MINUTE") == datetime(2021, 10, 3, 15, 53)
assert floor_op(date, "SECOND") == datetime(2021, 10, 3, 15, 53, 42)
assert floor_op(date, "MILLISECOND") == datetime(2021, 10, 3, 15, 53, 42)
assert ceil_op(date, "DAY") == datetime.datetime(2021, 10, 4)
assert ceil_op(date, "HOUR") == datetime.datetime(2021, 10, 3, 16)
assert ceil_op(date, "MINUTE") == datetime.datetime(2021, 10, 3, 15, 54)
assert ceil_op(date, "SECOND") == datetime.datetime(2021, 10, 3, 15, 53, 43)
assert ceil_op(date, "MILLISECOND") == datetime.datetime(
2021, 10, 3, 15, 53, 42, 1000
)

assert floor_op(date, "DAY") == datetime.datetime(2021, 10, 3)
assert floor_op(date, "HOUR") == datetime.datetime(2021, 10, 3, 15)
assert floor_op(date, "MINUTE") == datetime.datetime(2021, 10, 3, 15, 53)
assert floor_op(date, "SECOND") == datetime.datetime(2021, 10, 3, 15, 53, 42)
assert floor_op(date, "MILLISECOND") == datetime.datetime(2021, 10, 3, 15, 53, 42)
