Skip to content

Commit

Permalink
Fix typo bug in gather implementation (#16000)
Browse files Browse the repository at this point in the history
Pylibcudf calls the datatype accessor type(). Add tests to cover this case, and raising on out of bounds accesses.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Thomas Li (https://github.com/lithomas1)
  - Bradley Dice (https://github.com/bdice)

URL: #16000
  • Loading branch information
wence- authored Jun 12, 2024
1 parent 0891c5d commit 97518ac
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ def do_evaluate(
obj = plc.replace.replace_nulls(
indices.obj,
plc.interop.from_arrow(
pa.scalar(n, type=plc.interop.to_arrow(indices.obj.data_type()))
pa.scalar(n, type=plc.interop.to_arrow(indices.obj.type()))
),
)
else:
Expand Down
31 changes: 31 additions & 0 deletions python/cudf_polars/tests/expressions/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars import execute_with_cudf
from cudf_polars.testing.asserts import assert_gpu_result_equal


Expand All @@ -17,3 +20,31 @@ def test_gather():

query = ldf.select(pl.col("a").gather(pl.col("b")))
assert_gpu_result_equal(query)


def test_gather_with_nulls():
ldf = pl.LazyFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7],
"b": [0, None, 1, None, 6, 1, 0],
}
)

query = ldf.select(pl.col("a").gather(pl.col("b")))

assert_gpu_result_equal(query)


@pytest.mark.parametrize("negative", [False, True])
def test_gather_out_of_bounds(negative):
ldf = pl.LazyFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7],
"b": [0, -10 if negative else 10, 1, 2, 6, 1, 0],
}
)

query = ldf.select(pl.col("a").gather(pl.col("b")))

with pytest.raises(pl.exceptions.ComputeError):
query.collect(post_opt_callback=execute_with_cudf)

0 comments on commit 97518ac

Please sign in to comment.