Skip to content

Commit

Permalink
converted _make_copy_replacing_NaT_with_null() into a free function
Browse files Browse the repository at this point in the history
  • Loading branch information
skirui-source committed Dec 16, 2021
1 parent 4de3b95 commit a04647e
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,25 +306,6 @@ def _get_mask_as_column(self) -> ColumnBase:
self.base_mask, self.offset, self.offset + len(self)
)

def _make_copy_replacing_NaT_with_null(self):
    """Return a copy with NaT values replaced with nulls.

    Raises
    ------
    ValueError
        If ``self.dtype`` is neither a timedelta64 nor a datetime64
        dtype, since no NaT sentinel can be built for it.
    """
    if np.issubdtype(self.dtype, np.timedelta64):
        na_value = np.timedelta64("NaT", self.time_unit)
    elif np.issubdtype(self.dtype, np.datetime64):
        na_value = np.datetime64("NaT", self.time_unit)
    else:
        # Without this branch ``na_value`` would be unbound further down
        # and the caller would see an opaque UnboundLocalError.
        raise ValueError("This type does not support replacing NaT with null.")

    # Single-element masked column passed to ``replace`` as the
    # replacement value.
    null = column_empty_like(self, masked=True, newsize=1)
    # Single-element column holding the NaT sentinel; the scalar is
    # reinterpreted as raw bytes ("|u1") so it can back a Buffer.
    nat_column = build_column(
        Buffer(np.array([na_value], dtype=self.dtype).view("|u1")),
        dtype=self.dtype,
    )
    out_col = cudf._lib.replace.replace(self, nat_column, null)
    return out_col

def memory_usage(self) -> int:
n = 0
if self.data is not None:
Expand Down Expand Up @@ -1671,6 +1652,27 @@ def build_struct_column(
return cast("cudf.core.column.StructColumn", result)


def _make_copy_replacing_NaT_with_null(column):
    """Build a copy of ``column`` in which every NaT entry becomes null.

    Only timedelta64 and datetime64 columns are handled; the NaT sentinel
    is created with the column's own ``time_unit``.

    Raises
    ------
    ValueError
        If ``column.dtype`` is neither a timedelta64 nor a datetime64 dtype.
    """
    dtype = column.dtype

    # Try timedelta64 first, then datetime64; the ``else`` clause of the
    # loop runs only when neither scalar type matched.
    for scalar_type in (np.timedelta64, np.datetime64):
        if np.issubdtype(dtype, scalar_type):
            nat = scalar_type("NaT", column.time_unit)
            break
    else:
        raise ValueError("This type does not support replacing NaT with null.")

    # Single-element masked column handed to ``replace`` as the new value.
    replacement = column_empty_like(column, masked=True, newsize=1)

    # Single-element column containing the NaT sentinel, backed by the
    # scalar's raw bytes.
    nat_bytes = np.array([nat], dtype=dtype).view("|u1")
    to_replace = build_column(Buffer(nat_bytes), dtype=dtype)

    return cudf._lib.replace.replace(column, to_replace, replacement)


def as_column(
arbitrary: Any,
nan_as_null: bool = None,
Expand Down Expand Up @@ -1772,9 +1774,7 @@ def as_column(
col = col.set_mask(mask)
elif np.issubdtype(col.dtype, np.datetime64):
if nan_as_null or (mask is None and nan_as_null is None):
# Ignore typing error since this method is only defined for
# DatetimeColumn, not the ColumnBase class.
col = col._make_copy_replacing_NaT_with_null() # type: ignore
col = _make_copy_replacing_NaT_with_null(col)
return col

elif isinstance(arbitrary, (pa.Array, pa.ChunkedArray)):
Expand Down Expand Up @@ -1905,7 +1905,7 @@ def as_column(
mask = None
if nan_as_null is None or nan_as_null is True:
data = build_column(buffer, dtype=arbitrary.dtype)
data = data._make_copy_replacing_NaT_with_null()
data = _make_copy_replacing_NaT_with_null(data)
mask = data.mask

data = cudf.core.column.datetime.DatetimeColumn(
Expand All @@ -1923,7 +1923,7 @@ def as_column(
mask = None
if nan_as_null is None or nan_as_null is True:
data = build_column(buffer, dtype=arbitrary.dtype)
data = data._make_copy_replacing_NaT_with_null()
data = _make_copy_replacing_NaT_with_null(data)
mask = data.mask

data = cudf.core.column.timedelta.TimeDeltaColumn(
Expand Down

0 comments on commit a04647e

Please sign in to comment.