From b35991c366cf81b650fb79fc27604fd79468f132 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 12 Jun 2024 22:50:52 +0100 Subject: [PATCH] Add test that diagonal concat with mismatching schemas raises (#16006) Arguably this should be determined during query optimization by polars, but for now it is raised late during compute, so we must validate on our side. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Thomas Li (https://github.com/lithomas1) URL: https://github.com/rapidsai/cudf/pull/16006 --- python/cudf_polars/cudf_polars/dsl/ir.py | 4 ++-- python/cudf_polars/tests/test_union.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 0a6deb5698c..46241ab8e71 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -933,10 +933,10 @@ class Union(IR): """Optional slice to apply after concatenation.""" def __post_init__(self) -> None: - """Validated preconditions.""" + """Validate preconditions.""" schema = self.dfs[0].schema if not all(s.schema == schema for s in self.dfs[1:]): - raise ValueError("Schema mismatch") + raise NotImplementedError("Schema mismatch") def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" diff --git a/python/cudf_polars/tests/test_union.py b/python/cudf_polars/tests/test_union.py index 18cf4748692..6c9122bc260 100644 --- a/python/cudf_polars/tests/test_union.py +++ b/python/cudf_polars/tests/test_union.py @@ -2,8 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import pytest + import polars as pl +from cudf_polars import translate_ir from cudf_polars.testing.asserts import assert_gpu_result_equal @@ -19,6 +22,19 @@ def test_union(): assert_gpu_result_equal(query) +def test_union_schema_mismatch_raises(): + ldf = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6, 7], + "b": [1, 1, 1, 1, 1, 1, 1], + } + ).lazy() + ldf2 = ldf.select(pl.col("a").cast(pl.Float32)) + query = pl.concat([ldf, ldf2], how="diagonal") + with pytest.raises(NotImplementedError): + _ = translate_ir(query._ldf.visit()) + + def test_concat_vertical(): ldf = pl.LazyFrame( {