diff --git a/cpp/include/cudf/dictionary/dictionary_column_view.hpp b/cpp/include/cudf/dictionary/dictionary_column_view.hpp index 1da52e67e06..42f8310040e 100644 --- a/cpp/include/cudf/dictionary/dictionary_column_view.hpp +++ b/cpp/include/cudf/dictionary/dictionary_column_view.hpp @@ -77,6 +77,11 @@ class dictionary_column_view : private column_view { */ column_view keys() const noexcept; + /** + * @brief Returns the `data_type` of the keys child column. + */ + data_type keys_type() const noexcept; + /** * @brief Returns the number of rows in the keys column. */ diff --git a/cpp/src/copying/copy.cu b/cpp/src/copying/copy.cu index 10af2ffb614..91fc5f02989 100644 --- a/cpp/src/copying/copy.cu +++ b/cpp/src/copying/copy.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -267,6 +268,22 @@ struct copy_if_else_functor_impl { } }; +template <> +struct copy_if_else_functor_impl { + template + std::unique_ptr operator()(Left const& lhs, + Right const& rhs, + size_type size, + bool, + bool, + Filter filter, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + return scatter_gather_based_if_else(lhs, rhs, size, filter, stream, mr); + } +}; + /** * @brief Functor called by the `type_dispatcher` to invoke copy_if_else on combinations * of column_view and scalar @@ -297,7 +314,6 @@ std::unique_ptr copy_if_else(Left const& lhs, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - CUDF_EXPECTS(lhs.type() == rhs.type(), "Both inputs must be of the same type"); CUDF_EXPECTS(boolean_mask.type() == data_type(type_id::BOOL8), "Boolean mask column must be of type type_id::BOOL8"); @@ -311,7 +327,11 @@ std::unique_ptr copy_if_else(Left const& lhs, return (!has_nulls || bool_mask_device.is_valid_nocheck(i)) and bool_mask_device.element(i); }; - return cudf::type_dispatcher(lhs.type(), + + // always dispatch on dictionary-type if either input is a dictionary + auto dispatch_type = cudf::is_dictionary(rhs.type()) ? rhs.type() : lhs.type(); + + return cudf::type_dispatcher(dispatch_type, copy_if_else_functor{}, lhs, rhs, @@ -334,6 +354,8 @@ std::unique_ptr copy_if_else(column_view const& lhs, CUDF_EXPECTS(boolean_mask.size() == lhs.size(), "Boolean mask column must be the same size as lhs and rhs columns"); CUDF_EXPECTS(lhs.size() == rhs.size(), "Both columns must be of the size"); + CUDF_EXPECTS(lhs.type() == rhs.type(), "Both inputs must be of the same type"); + return copy_if_else(lhs, rhs, lhs.has_nulls(), rhs.has_nulls(), boolean_mask, stream, mr); } @@ -345,6 +367,11 @@ std::unique_ptr copy_if_else(scalar const& lhs, { CUDF_EXPECTS(boolean_mask.size() == rhs.size(), "Boolean mask column must be the same size as rhs column"); + + auto rhs_type = + cudf::is_dictionary(rhs.type()) ? cudf::dictionary_column_view(rhs).keys_type() : rhs.type(); + CUDF_EXPECTS(lhs.type() == rhs_type, "Both inputs must be of the same type"); + return copy_if_else(lhs, rhs, !lhs.is_valid(stream), rhs.has_nulls(), boolean_mask, stream, mr); } @@ -356,6 +383,11 @@ std::unique_ptr copy_if_else(column_view const& lhs, { CUDF_EXPECTS(boolean_mask.size() == lhs.size(), "Boolean mask column must be the same size as lhs column"); + + auto lhs_type = + cudf::is_dictionary(lhs.type()) ? cudf::dictionary_column_view(lhs).keys_type() : lhs.type(); + CUDF_EXPECTS(lhs_type == rhs.type(), "Both inputs must be of the same type"); + return copy_if_else(lhs, rhs, lhs.has_nulls(), !rhs.is_valid(stream), boolean_mask, stream, mr); } @@ -365,6 +397,7 @@ std::unique_ptr copy_if_else(scalar const& lhs, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { + CUDF_EXPECTS(lhs.type() == rhs.type(), "Both inputs must be of the same type"); return copy_if_else( lhs, rhs, !lhs.is_valid(stream), !rhs.is_valid(stream), boolean_mask, stream, mr); } diff --git a/cpp/src/dictionary/dictionary_column_view.cpp b/cpp/src/dictionary/dictionary_column_view.cpp index d33fd6c548f..4906e5b4f9c 100644 --- a/cpp/src/dictionary/dictionary_column_view.cpp +++ b/cpp/src/dictionary/dictionary_column_view.cpp @@ -44,8 +44,12 @@ column_view dictionary_column_view::keys() const noexcept { return child(1); } size_type dictionary_column_view::keys_size() const noexcept { - if (size() == 0) return 0; - return keys().size(); + return (size() == 0) ? 0 : keys().size(); +} + +data_type dictionary_column_view::keys_type() const noexcept +{ + return (size() == 0) ? data_type{type_id::EMPTY} : keys().type(); } } // namespace cudf diff --git a/cpp/tests/copying/copy_tests.cpp b/cpp/tests/copying/copy_tests.cpp index 651a977050c..4468bc69640 100644 --- a/cpp/tests/copying/copy_tests.cpp +++ b/cpp/tests/copying/copy_tests.cpp @@ -18,11 +18,13 @@ #include #include #include +#include #include #include #include #include +#include #include template @@ -633,3 +635,85 @@ TYPED_TEST(FixedPointTypes, FixedPointScaleMismatch) EXPECT_THROW(cudf::copy_if_else(a, b, mask), cudf::logic_error); } + +struct DictionaryCopyIfElseTest : public cudf::test::BaseFixture { +}; + +TEST_F(DictionaryCopyIfElseTest, ColumnColumn) +{ + auto valids = cudf::test::iterators::null_at(2); + std::vector h_strings1{"eee", "bb", "", "aa", "bb", "ééé"}; + cudf::test::dictionary_column_wrapper input1( + h_strings1.begin(), h_strings1.end(), valids); + std::vector h_strings2{"zz", "bb", "", "aa", "ééé", "ooo"}; + cudf::test::dictionary_column_wrapper input2( + h_strings2.begin(), h_strings2.end(), valids); + + bool mask[] = {1, 1, 0, 1, 0, 1}; + bool mask_v[] = {1, 1, 1, 1, 1, 0}; + cudf::test::fixed_width_column_wrapper mask_w(mask, mask + 6, mask_v); + + auto results = cudf::copy_if_else(input1, input2, mask_w); + auto decoded = cudf::dictionary::decode(cudf::dictionary_column_view(results->view())); + + std::vector h_expected; + for (cudf::size_type idx = 0; idx < static_cast(h_strings1.size()); ++idx) { + if (mask[idx] and mask_v[idx]) + h_expected.push_back(h_strings1[idx]); + else + h_expected.push_back(h_strings2[idx]); + } + cudf::test::strings_column_wrapper expected(h_expected.begin(), h_expected.end(), valids); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(decoded->view(), expected); +} + +TEST_F(DictionaryCopyIfElseTest, ColumnScalar) +{ + std::string h_string{"eee"}; + cudf::string_scalar input1{h_string}; + std::vector h_strings{"zz", "", "yyy", "w", "ééé", "ooo"}; + auto valids = cudf::test::iterators::null_at(1); + cudf::test::dictionary_column_wrapper input2( + h_strings.begin(), h_strings.end(), valids); + + bool mask[] = {0, 1, 1, 1, 0, 1}; + cudf::test::fixed_width_column_wrapper mask_w(mask, mask + 6); + + auto results = cudf::copy_if_else(input2, input1, mask_w); + auto decoded = cudf::dictionary::decode(cudf::dictionary_column_view(results->view())); + + std::vector h_expected1; + std::vector h_expected2; + for (cudf::size_type idx = 0; idx < static_cast(h_strings.size()); ++idx) { + if (mask[idx]) { + h_expected1.push_back(h_strings[idx]); + h_expected2.push_back(h_string.c_str()); + } else { + h_expected1.push_back(h_string.c_str()); + h_expected2.push_back(h_strings[idx]); + } + } + + cudf::test::strings_column_wrapper expected1(h_expected1.begin(), h_expected1.end(), valids); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(decoded->view(), expected1); + + results = cudf::copy_if_else(input1, input2, mask_w); + decoded = cudf::dictionary::decode(cudf::dictionary_column_view(results->view())); + + cudf::test::strings_column_wrapper expected2(h_expected2.begin(), h_expected2.end()); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(decoded->view(), expected2); +} + +TEST_F(DictionaryCopyIfElseTest, TypeMismatch) +{ + cudf::test::dictionary_column_wrapper input1({1, 1, 1, 1}); + cudf::test::dictionary_column_wrapper input2({1.0, 1.0, 1.0, 1.0}); + cudf::test::fixed_width_column_wrapper mask({1, 0, 0, 1}); + + EXPECT_THROW(cudf::copy_if_else(input1, input2, mask), cudf::logic_error); + + cudf::string_scalar input3{"1"}; + EXPECT_THROW(cudf::copy_if_else(input1, input3, mask), cudf::logic_error); + EXPECT_THROW(cudf::copy_if_else(input3, input2, mask), cudf::logic_error); + EXPECT_THROW(cudf::copy_if_else(input2, input3, mask), cudf::logic_error); +}