Skip to content

Commit

Permalink
Fix tz_localize for dask_cudf Series (#13610)
Browse files Browse the repository at this point in the history
Closes #13602

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #13610
  • Loading branch information
shwina authored Jun 26, 2023
1 parent 9a3f3a9 commit 12210fd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
6 changes: 6 additions & 0 deletions python/dask_cudf/dask_cudf/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import pyarrow as pa
from pandas.api.types import is_scalar
from pandas.core.tools.datetimes import is_datetime64tz_dtype

import dask.dataframe as dd
from dask import config
Expand Down Expand Up @@ -122,6 +123,11 @@ def _get_non_empty_data(s):
data = cudf.core.column.as_column(data, dtype=s.dtype)
elif is_string_dtype(s.dtype):
data = pa.array(["cat", "dog"])
elif is_datetime64tz_dtype(s.dtype):
from cudf.utils.dtypes import get_time_unit

data = cudf.date_range("2001-01-01", periods=2, freq=get_time_unit(s))
data = data.tz_localize(str(s.dtype.tz))._column
else:
if pd.api.types.is_numeric_dtype(s.dtype):
data = cudf.core.column.as_column(
Expand Down
34 changes: 32 additions & 2 deletions python/dask_cudf/dask_cudf/tests/test_accessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2023, NVIDIA CORPORATION.

import numpy as np
import pandas as pd
Expand All @@ -7,7 +7,7 @@

from dask import dataframe as dd

from cudf import DataFrame, Series
from cudf import DataFrame, Series, date_range
from cudf.testing._utils import assert_eq, does_not_raise

import dask_cudf as dgd
Expand Down Expand Up @@ -527,3 +527,33 @@ def test_struct_explode(data):
got = dgd.from_cudf(Series(data), 2).struct.explode()
# Output index will not agree for >1 partitions
assert_eq(expect, got.compute().reset_index(drop=True))


def test_tz_localize():
    """Compare ``dt.tz_localize`` on a dask_cudf Series against cudf.

    An hourly, timezone-naive series is localized to ``US/Eastern`` both
    directly with cudf and through a two-partition dask_cudf collection,
    and the results are compared.  The localized results are then passed
    through ``tz_localize(None)`` and compared a second time.
    """
    naive = Series(date_range("2000-04-01", "2000-04-03", freq="H"))
    # Ambiguous and nonexistent wall-clock times are both mapped to NaT.
    nat_policy = {"ambiguous": "NaT", "nonexistent": "NaT"}

    expect = naive.dt.tz_localize("US/Eastern", **nat_policy)
    got = dgd.from_cudf(naive, 2).dt.tz_localize("US/Eastern", **nat_policy)
    dd.assert_eq(expect, got)

    # Round trip: remove the timezone from both results and compare again.
    dd.assert_eq(expect.dt.tz_localize(None), got.dt.tz_localize(None))


@pytest.mark.parametrize(
    "data",
    [
        date_range("2000-04-01", "2000-04-03", freq="H").tz_localize(zone)
        for zone in ("UTC", "US/Eastern")
    ],
)
def test_tz_convert(data):
    """Compare ``dt.tz_convert`` on a dask_cudf Series against cudf.

    ``data`` is an hourly range already localized to a timezone (one case
    per parametrized zone).  It is converted to ``US/Pacific`` both directly
    with cudf and through a two-partition dask_cudf collection, and the
    results are compared.
    """
    target_zone = "US/Pacific"
    expect = Series(data).dt.tz_convert(target_zone)
    partitioned = dgd.from_cudf(Series(data), 2)
    got = partitioned.dt.tz_convert(target_zone)
    dd.assert_eq(expect, got)

0 comments on commit 12210fd

Please sign in to comment.