Skip to content

Commit

Permalink
Annoying
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Jul 9, 2024
1 parent 49442c5 commit f85a899
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
45 changes: 32 additions & 13 deletions python/cudf_polars/cudf_polars/containers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,29 @@ def copy(self) -> Self:
)

def mask_nans(self) -> Self:
"""Return a copy of self with nans masked out."""
if self.nan_count > 0:
raise NotImplementedError("Need to port transform.hpp to pylibcudf")
"""Return a shallow copy of self with nans masked out."""
if plc.traits.is_floating_point(self.obj.type()):
old_count = self.obj.null_count()
mask, new_count = plc.transform.nans_to_nulls(self.obj)
result = type(self)(self.obj.with_mask(mask, new_count))
if old_count == new_count:
return result.sorted_like(self)
return result
return self.copy()

@functools.cached_property
def nan_count(self) -> int:
"""Return the number of NaN values in the column."""
if self.obj.type().id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64):
return 0
return plc.interop.to_arrow(
plc.reduce.reduce(
plc.unary.is_nan(self.obj),
plc.aggregation.sum(),
# TODO: pylibcudf needs to have a SizeType DataType singleton
plc.DataType(plc.TypeId.INT32),
)
).as_py()
if plc.traits.is_floating_point(self.obj.type()):
return plc.interop.to_arrow(
plc.reduce.reduce(
plc.unary.is_nan(self.obj),
plc.aggregation.sum(),
# TODO: pylibcudf needs to have a SizeType DataType singleton
plc.DataType(plc.TypeId.INT32),
)
).as_py()
return 0


class NamedColumn(Column):
Expand Down Expand Up @@ -187,3 +192,17 @@ def copy(self, *, new_name: str | None = None) -> Self:
order=self.order,
null_order=self.null_order,
)

def mask_nans(self) -> Self:
"""Return a shallow copy of self with nans masked out."""
# Annoying, the inheritance is not right (can't call the
# super-type mask_nans), but will sort that by refactoring
# later.
if plc.traits.is_floating_point(self.obj.type()):
old_count = self.obj.null_count()
mask, new_count = plc.transform.nans_to_nulls(self.obj)
result = type(self)(self.obj.with_mask(mask, new_count), self.name)
if old_count == new_count:
return result.sorted_like(self)
return result
return self.copy()
9 changes: 6 additions & 3 deletions python/cudf_polars/tests/containers/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

from __future__ import annotations

from functools import partial

import pyarrow
import pytest

import cudf._lib.pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.containers import Column, NamedColumn


def test_non_scalar_access_raises():
Expand Down Expand Up @@ -54,10 +56,11 @@ def test_shallow_copy():


@pytest.mark.parametrize("typeid", [plc.TypeId.INT8, plc.TypeId.FLOAT32])
def test_mask_nans(typeid):
@pytest.mark.parametrize("constructor", [Column, partial(NamedColumn, name="name")])
def test_mask_nans(typeid, constructor):
dtype = plc.DataType(typeid)
values = pyarrow.array([0, 0, 0], type=plc.interop.to_arrow(dtype))
column = Column(plc.interop.from_arrow(values))
column = constructor(plc.interop.from_arrow(values))
masked = column.mask_nans()
assert column.obj is masked.obj

Expand Down

0 comments on commit f85a899

Please sign in to comment.