From 12210fd7891b9735531a03bf99ce7627e182aa6e Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Mon, 26 Jun 2023 10:52:56 -0400 Subject: [PATCH] Fix tz_localize for dask_cudf Series (#13610) Closes https://github.com/rapidsai/cudf/issues/13602 Authors: - Ashwin Srinath (https://github.com/shwina) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/13610 --- python/dask_cudf/dask_cudf/backends.py | 6 ++++ .../dask_cudf/tests/test_accessor.py | 34 +++++++++++++++++-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/python/dask_cudf/dask_cudf/backends.py b/python/dask_cudf/dask_cudf/backends.py index 378f46de22c..2470b4d50f1 100644 --- a/python/dask_cudf/dask_cudf/backends.py +++ b/python/dask_cudf/dask_cudf/backends.py @@ -8,6 +8,7 @@ import pandas as pd import pyarrow as pa from pandas.api.types import is_scalar +from pandas.core.tools.datetimes import is_datetime64tz_dtype import dask.dataframe as dd from dask import config @@ -122,6 +123,11 @@ def _get_non_empty_data(s): data = cudf.core.column.as_column(data, dtype=s.dtype) elif is_string_dtype(s.dtype): data = pa.array(["cat", "dog"]) + elif is_datetime64tz_dtype(s.dtype): + from cudf.utils.dtypes import get_time_unit + + data = cudf.date_range("2001-01-01", periods=2, freq=get_time_unit(s)) + data = data.tz_localize(str(s.dtype.tz))._column else: if pd.api.types.is_numeric_dtype(s.dtype): data = cudf.core.column.as_column( diff --git a/python/dask_cudf/dask_cudf/tests/test_accessor.py b/python/dask_cudf/dask_cudf/tests/test_accessor.py index 6b1627c91e8..bea0cbb48ae 100644 --- a/python/dask_cudf/dask_cudf/tests/test_accessor.py +++ b/python/dask_cudf/dask_cudf/tests/test_accessor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2023, NVIDIA CORPORATION. import numpy as np import pandas as pd @@ -7,7 +7,7 @@ from dask import dataframe as dd -from cudf import DataFrame, Series +from cudf import DataFrame, Series, date_range from cudf.testing._utils import assert_eq, does_not_raise import dask_cudf as dgd @@ -527,3 +527,33 @@ def test_struct_explode(data): got = dgd.from_cudf(Series(data), 2).struct.explode() # Output index will not agree for >1 partitions assert_eq(expect, got.compute().reset_index(drop=True)) + + +def test_tz_localize(): + data = Series(date_range("2000-04-01", "2000-04-03", freq="H")) + expect = data.dt.tz_localize( + "US/Eastern", ambiguous="NaT", nonexistent="NaT" + ) + got = dgd.from_cudf(data, 2).dt.tz_localize( + "US/Eastern", ambiguous="NaT", nonexistent="NaT" + ) + dd.assert_eq(expect, got) + + expect = expect.dt.tz_localize(None) + got = got.dt.tz_localize(None) + dd.assert_eq(expect, got) + + +@pytest.mark.parametrize( + "data", + [ + date_range("2000-04-01", "2000-04-03", freq="H").tz_localize("UTC"), + date_range("2000-04-01", "2000-04-03", freq="H").tz_localize( + "US/Eastern" + ), + ], +) +def test_tz_convert(data): + expect = Series(data).dt.tz_convert("US/Pacific") + got = dgd.from_cudf(Series(data), 2).dt.tz_convert("US/Pacific") + dd.assert_eq(expect, got)