-
Notifications
You must be signed in to change notification settings - Fork 915
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
276 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from polars.testing.asserts import assert_frame_equal | ||
|
||
|
||
def assert_gpu_result_equal(
    lazydf,
    *,
    check_row_order: bool = True,
    check_column_order: bool = True,
    check_dtype: bool = True,
    check_exact: bool = True,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    categorical_as_str: bool = False,
):
    """
    Check that GPU collection of a lazyframe matches CPU collection.

    The frame is collected twice: first with ``use_gpu=False`` to
    obtain the reference result, then with ``use_gpu=True`` and
    ``cpu_fallback=False``. The two results are compared using
    ``assert_frame_equal`` with the reference as the first argument.

    Parameters
    ----------
    lazydf
        Frame to collect.
    check_row_order
        Require rows to appear in the same order.
    check_column_order
        Require columns to appear in the same order.
    check_dtype
        Require dtypes to match.
    check_exact
        Require exact equality for floats; if `False`, compare using
        rtol and atol.
    rtol
        Relative tolerance for float comparisons.
    atol
        Absolute tolerance for float comparisons.
    categorical_as_str
        Convert categoricals to strings before comparing.

    Raises
    ------
    AssertionError
        If the GPU and CPU collection do not match.
    NotImplementedError
        If GPU collection failed in some way. This is not raised
        directly here; presumably it propagates out of ``collect``
        because CPU fallback is disabled -- confirm against the patch.
    """
    # Every comparison knob is forwarded unchanged to assert_frame_equal.
    comparison_options = {
        "check_row_order": check_row_order,
        "check_column_order": check_column_order,
        "check_dtype": check_dtype,
        "check_exact": check_exact,
        "rtol": rtol,
        "atol": atol,
        "categorical_as_str": categorical_as_str,
    }
    # Collect the CPU reference first, then the GPU result.
    cpu_frame = lazydf.collect(use_gpu=False)
    gpu_frame = lazydf.collect(use_gpu=True, cpu_fallback=False)
    assert_frame_equal(cpu_frame, gpu_frame, **comparison_options)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
def pytest_sessionstart(session):
    """
    Refuse to start the test session if the patch was not applied.

    Parameters
    ----------
    session
        The pytest session object (unused).

    Raises
    ------
    RuntimeError
        If ``cudf_polars.patch._WAS_PATCHED`` is falsy.
    """
    # Imported here rather than at module level, so the flag is read
    # when the session starts.
    from cudf_polars.patch import _WAS_PATCHED

    if _WAS_PATCHED:
        return

    # We could also just patch in the test, but this approach
    # provides a canary for failures with patching that we might
    # observe in trying this with other tests.
    raise RuntimeError("Patch was not applied")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import polars as pl | ||
|
||
from cudf_polars.testing.asserts import assert_gpu_result_equal | ||
|
||
|
||
def test_extcontext():
    """Compare CPU and GPU collection of a query using ``with_context``."""
    columns = {
        "a": [1, 2, 3, 4, 5, 6, 7],
        "b": [1, 1, 1, 1, 1, 1, 1],
    }
    base = pl.DataFrame(columns).lazy()

    # The extra context supplies column "c", computed as b + a.
    summed = (pl.col("b") + pl.col("a")).alias("c")
    context = base.select(summed)

    # Select one column from the base frame and one from the context.
    query = base.with_context(context).select(pl.col("b"), pl.col("c"))
    assert_gpu_result_equal(query)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import polars as pl | ||
|
||
from cudf_polars.testing.asserts import assert_gpu_result_equal | ||
|
||
|
||
def test_filter():
    """Compare CPU and GPU collection of a filter applied after a group-by."""
    columns = {
        "a": [1, 2, 3, 4, 5, 6, 7],
        "b": [1, 1, 1, 1, 1, 1, 1],
    }
    frame = pl.DataFrame(columns).lazy()

    # group-by is just to avoid the filter being pushed into the scan.
    aggregated = frame.group_by(pl.col("a")).agg(pl.col("b").sum())

    # NOTE(review): every key in "a" is unique and every "b" is 1, so each
    # group sum is 1 and this predicate appears to select no rows --
    # confirm that an empty result is the intended coverage.
    query = aggregated.filter(pl.col("b") < 1)
    assert_gpu_result_equal(query)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import polars as pl | ||
|
||
from cudf_polars.testing.asserts import assert_gpu_result_equal | ||
|
||
|
||
def test_hconcat():
    """Compare CPU and GPU collection of a horizontal concatenation."""
    columns = {
        "a": [1, 2, 3, 4, 5, 6, 7],
        "b": [1, 1, 1, 1, 1, 1, 1],
    }
    base = pl.DataFrame(columns).lazy()

    # Second frame holds a single column "c" computed as a + b.
    total = (pl.col("a") + pl.col("b")).alias("c")
    derived = base.select(total)

    query = pl.concat([base, derived], how="horizontal")
    assert_gpu_result_equal(query)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import polars as pl | ||
|
||
from cudf_polars.testing.asserts import assert_gpu_result_equal | ||
|
||
|
||
def test_hstack():
    """Compare CPU and GPU collection of a ``with_columns`` query."""
    columns = {
        "a": [1, 2, 3, 4, 5, 6, 7],
        "b": [1, 1, 1, 1, 1, 1, 1],
    }
    frame = pl.DataFrame(columns).lazy()

    # Unaliased a + b expression passed to with_columns.
    total = pl.col("a") + pl.col("b")
    query = frame.with_columns(total)
    assert_gpu_result_equal(query)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import polars as pl | ||
import pytest | ||
|
||
from cudf_polars.testing.asserts import assert_gpu_result_equal | ||
|
||
|
||
@pytest.mark.parametrize(
    "how",
    [
        "inner",
        "left",
        pytest.param(
            "outer",
            marks=pytest.mark.xfail(reason="non-coalescing join not implemented"),
        ),
        "semi",
        "anti",
        pytest.param(
            "cross",
            marks=pytest.mark.xfail(reason="cross join not implemented"),
        ),
        "outer_coalesce",
    ],
)
@pytest.mark.parametrize(
    "join_nulls",
    [False, True],
    ids=["nulls_not_equal", "nulls_equal"],
)
@pytest.mark.parametrize(
    "join_expr",
    [
        pl.col("a"),
        pytest.param(
            pl.col("a") * 2,
            marks=pytest.mark.xfail(reason="Taking key columns from wrong table"),
        ),
        pytest.param(
            [pl.col("a"), pl.col("a") + 1],
            marks=pytest.mark.xfail(reason="Taking key columns from wrong table"),
        ),
        ["c", "a"],
    ],
)
def test_join(how, join_nulls, join_expr):
    """
    Compare CPU and GPU collection of a join between two small frames.

    Parameters
    ----------
    how
        Join strategy passed to ``join``; "outer" and "cross" are
        marked as expected failures.
    join_nulls
        Passed through as the ``join_nulls`` argument of ``join``.
    join_expr
        Key specification passed as ``on``; the computed-key variants
        are marked as expected failures.
    """
    # Both key columns contain nulls so that both join_nulls settings
    # have null keys to work with.
    left_columns = {
        "a": [1, 2, 3, 1, None],
        "b": [1, 2, 3, 4, 5],
        "c": [2, 3, 4, 5, 6],
    }
    right_columns = {
        "a": [1, 4, 3, 7, None, None],
        "c": [2, 3, 4, 5, 6, 7],
    }
    left = pl.DataFrame(left_columns).lazy()
    right = pl.DataFrame(right_columns).lazy()

    joined = left.join(right, on=join_expr, how=how, join_nulls=join_nulls)
    # Row order is not compared for join results.
    assert_gpu_result_equal(joined, check_row_order=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import polars as pl | ||
import pytest | ||
|
||
from cudf_polars.testing.asserts import assert_gpu_result_equal | ||
|
||
|
||
@pytest.mark.parametrize(
    "offset",
    [0, 1, 2],
)
@pytest.mark.parametrize(
    "length",
    [0, 2, 12],
)
def test_slice(offset, length):
    """
    Compare CPU and GPU collection of a slice over a sorted aggregation.

    Parameters
    ----------
    offset
        Start offset passed to ``slice``.
    length
        Number of rows passed to ``slice``. The values include 0 and 12,
        which is larger than the seven-row input. Named ``length``
        rather than ``len`` so the builtin is not shadowed.
    """
    ldf = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7],
            "b": [1, 1, 1, 1, 1, 1, 1],
        }
    ).lazy()

    query = (
        ldf.group_by(pl.col("a"))
        .agg(pl.col("b").sum())
        .sort(by=pl.col("a"))
        .slice(offset, length)
    )
    # NOTE(review): the frame is sorted before slicing, so the row-order
    # check could presumably be enabled here -- confirm before tightening.
    assert_gpu_result_equal(query, check_row_order=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import polars as pl | ||
import pytest | ||
|
||
from cudf_polars.testing.asserts import assert_gpu_result_equal | ||
|
||
|
||
@pytest.mark.xfail(reason="Need handling of null scalars that are cast")
def test_union():
    """Compare CPU and GPU collection of a diagonal concatenation."""
    columns = {
        "a": [1, 2, 3, 4, 5, 6, 7],
        "b": [1, 1, 1, 1, 1, 1, 1],
    }
    base = pl.DataFrame(columns).lazy()

    # Second frame has column "c" (a + b) plus the original "a".
    total = (pl.col("a") + pl.col("b")).alias("c")
    derived = base.select(total, pl.col("a"))

    query = pl.concat([base, derived], how="diagonal")
    # Plan for this produces a `None`.astype(Int64) which we don't
    # handle correctly right now
    assert_gpu_result_equal(query)