Skip to content

Commit

Permalink
Add test that diagonal concat with mismatching schemas raises (#16006)
Browse files Browse the repository at this point in the history
Arguably polars should detect a schema mismatch in a diagonal concat during query optimization, but for now the error is only raised late, during compute, so we must validate the schemas on our side.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Thomas Li (https://github.com/lithomas1)

URL: #16006
  • Loading branch information
wence- authored Jun 12, 2024
1 parent 97518ac commit b35991c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,10 +933,10 @@ class Union(IR):
"""Optional slice to apply after concatenation."""

def __post_init__(self) -> None:
"""Validated preconditions."""
"""Validate preconditions."""
schema = self.dfs[0].schema
if not all(s.schema == schema for s in self.dfs[1:]):
raise ValueError("Schema mismatch")
raise NotImplementedError("Schema mismatch")

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
Expand Down
16 changes: 16 additions & 0 deletions python/cudf_polars/tests/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars import translate_ir
from cudf_polars.testing.asserts import assert_gpu_result_equal


Expand All @@ -19,6 +22,19 @@ def test_union():
assert_gpu_result_equal(query)


def test_union_schema_mismatch_raises():
    """Check that translating a diagonal concat of mismatched frames raises.

    Two lazy frames with different schemas are concatenated with
    ``how="diagonal"``; translating the resulting plan is expected to
    raise ``NotImplementedError``.
    """
    columns = {
        "a": [1, 2, 3, 4, 5, 6, 7],
        "b": [1, 1, 1, 1, 1, 1, 1],
    }
    left = pl.DataFrame(columns).lazy()
    # The second input keeps only column "a", cast to Float32, so its
    # schema does not match the first input's.
    right = left.select(pl.col("a").cast(pl.Float32))
    query = pl.concat([left, right], how="diagonal")
    with pytest.raises(NotImplementedError):
        _ = translate_ir(query._ldf.visit())


def test_concat_vertical():
ldf = pl.LazyFrame(
{
Expand Down

0 comments on commit b35991c

Please sign in to comment.