From 4ff56da2ca9c6593707318b0ad33b72cfba05599 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Fri, 26 May 2023 12:02:59 -0400 Subject: [PATCH] Fix cudf::repeat logic when count is zero --- cpp/src/filling/repeat.cu | 4 ++-- cpp/tests/filling/repeat_tests.cpp | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/cpp/src/filling/repeat.cu b/cpp/src/filling/repeat.cu index 9c14ccca1f9..2be15a06c0d 100644 --- a/cpp/src/filling/repeat.cu +++ b/cpp/src/filling/repeat.cu @@ -137,13 +137,13 @@ std::unique_ptr repeat(table_view const& input_table, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { + if ((input_table.num_rows() == 0) || (count == 0)) { return cudf::empty_like(input_table); } + CUDF_EXPECTS(count >= 0, "count value should be non-negative"); CUDF_EXPECTS(input_table.num_rows() <= std::numeric_limits::max() / count, "The resulting table exceeds the column size limit", std::overflow_error); - if ((input_table.num_rows() == 0) || (count == 0)) { return cudf::empty_like(input_table); } - auto output_size = input_table.num_rows() * count; auto map_begin = cudf::detail::make_counting_transform_iterator( 0, [count] __device__(auto i) { return i / count; }); diff --git a/cpp/tests/filling/repeat_tests.cpp b/cpp/tests/filling/repeat_tests.cpp index 8fb28fb3390..4f74523ec7c 100644 --- a/cpp/tests/filling/repeat_tests.cpp +++ b/cpp/tests/filling/repeat_tests.cpp @@ -174,6 +174,21 @@ TYPED_TEST(RepeatTypedTestFixture, ZeroSizeInput) CUDF_TEST_EXPECT_COLUMNS_EQUAL(p_ret->view().column(0), expected); } +TYPED_TEST(RepeatTypedTestFixture, ZeroCount) +{ + using T = TypeParam; + cudf::test::fixed_width_column_wrapper input(thrust::make_counting_iterator(0), + thrust::make_counting_iterator(10)); + + auto expected = cudf::make_empty_column(cudf::type_to_id()); + + cudf::table_view input_table{{input}}; + auto p_ret = cudf::repeat(input_table, 0); + + EXPECT_EQ(p_ret->num_columns(), 1); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(p_ret->view().column(0), expected->view()); +} + class RepeatStringTestFixture : public cudf::test::BaseFixture, cudf::test::UniformRandomGenerator { public: