diff --git a/cpp/include/cudf/detail/indexalator.cuh b/cpp/include/cudf/detail/indexalator.cuh index d546162fc7a..d0fa4e02440 100644 --- a/cpp/include/cudf/detail/indexalator.cuh +++ b/cpp/include/cudf/detail/indexalator.cuh @@ -502,17 +502,32 @@ struct indexalator_factory { iter = make_input_iterator(col); } + __device__ thrust::pair operator()(size_type i) const + { + return {iter[i], (has_nulls ? bit_is_set(null_mask, i + offset) : true)}; + } + }; + + /** + * @brief An index accessor that returns a validity flag along with the index value. + * + * This is suitable as a `pair_iterator`. + */ + struct scalar_nullable_index_accessor { + input_indexalator iter; + bool const is_null; + /** * @brief Create an accessor from a scalar. */ - nullable_index_accessor(scalar const& input) : has_nulls{!input.is_valid()} + scalar_nullable_index_accessor(scalar const& input) : is_null{!input.is_valid()} { iter = indexalator_factory::make_input_iterator(input); } - __device__ thrust::pair operator()(size_type i) const + __device__ thrust::pair operator()(size_type) const { - return {iter[i], (has_nulls ? bit_is_set(null_mask, i + offset) : true)}; + return {*iter, is_null}; } }; @@ -530,7 +545,75 @@ struct indexalator_factory { static auto make_input_pair_iterator(scalar const& input) { return thrust::make_transform_iterator(thrust::make_constant_iterator(0), - nullable_index_accessor{input}); + scalar_nullable_index_accessor{input}); + } + + /** + * @brief An index accessor that returns an index value if corresponding validity flag is true. + * + * This is suitable as an `optional_iterator`. + */ + struct optional_index_accessor { + input_indexalator iter; + bitmask_type const* null_mask{}; + size_type const offset{}; + bool const has_nulls{}; + + /** + * @brief Create an accessor from a column_view. + */ + optional_index_accessor(column_view const& col, bool has_nulls = false) + : null_mask{col.null_mask()}, offset{col.offset()}, has_nulls{has_nulls} + { + if (has_nulls) { CUDF_EXPECTS(col.nullable(), "Unexpected non-nullable column."); } + iter = make_input_iterator(col); + } + + __device__ thrust::optional operator()(size_type i) const + { + return has_nulls && !bit_is_set(null_mask, i + offset) ? thrust::nullopt + : thrust::make_optional(iter[i]); + } + }; + + /** + * @brief An index accessor that returns an index value if corresponding validity flag is true. + * + * This is suitable as an `optional_iterator`. + */ + struct scalar_optional_index_accessor { + input_indexalator iter; + bool const is_null; + + /** + * @brief Create an accessor from a scalar. + */ + scalar_optional_index_accessor(scalar const& input) : is_null{!input.is_valid()} + { + iter = indexalator_factory::make_input_iterator(input); + } + + __device__ thrust::optional operator()(size_type) const + { + return is_null ? thrust::nullopt : thrust::make_optional(*iter); + } + }; + + /** + * @brief Create an index iterator with a nullable index accessor. + */ + static auto make_input_optional_iterator(column_view const& col) + { + return make_counting_transform_iterator(0, optional_index_accessor{col, col.has_nulls()}); + } + + /** + * @brief Create an index iterator with a nullable index accessor for a scalar. + */ + static auto make_input_optional_iterator(scalar const& input) + { + return thrust::make_transform_iterator(thrust::make_constant_iterator(0), + scalar_optional_index_accessor{input}); } }; diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 088b0b747fb..d3475cbbed2 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -266,6 +266,7 @@ ConfigureTest(ITERATOR_TEST iterator/scalar_iterator_test.cu iterator/optional_iterator_test_chrono.cu iterator/optional_iterator_test_numeric.cu + iterator/indexalator_test.cu ) ################################################################################################### diff --git a/cpp/tests/iterator/indexalator_test.cu b/cpp/tests/iterator/indexalator_test.cu new file mode 100644 index 00000000000..d5379b6dd30 --- /dev/null +++ b/cpp/tests/iterator/indexalator_test.cu @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +#include + +#include +#include + +#include + +using TestingTypes = cudf::test::IntegralTypesNotBool; + +template +struct IndexalatorTest : public IteratorTest { +}; + +TYPED_TEST_CASE(IndexalatorTest, TestingTypes); + +TYPED_TEST(IndexalatorTest, input_iterator) +{ + using T = TypeParam; + + auto host_values = cudf::test::make_type_param_vector({0, 6, 0, -14, 13, 64, -13, -20, 45}); + + auto d_col = cudf::test::fixed_width_column_wrapper(host_values.begin(), host_values.end()); + + auto expected_values = thrust::host_vector(host_values.size()); + std::transform(host_values.begin(), host_values.end(), expected_values.begin(), [](auto v) { + return static_cast(v); + }); + + auto it_dev = cudf::detail::indexalator_factory::make_input_iterator(d_col); + this->iterator_test_thrust(expected_values, it_dev, host_values.size()); +} + +TYPED_TEST(IndexalatorTest, pair_iterator) +{ + using T = TypeParam; + + auto host_values = cudf::test::make_type_param_vector({0, 6, 0, -14, 13, 64, -13, -120, 115}); + auto validity = std::vector({0, 1, 1, 1, 1, 1, 0, 1, 1}); + + auto d_col = cudf::test::fixed_width_column_wrapper( + host_values.begin(), host_values.end(), validity.begin()); + + auto expected_values = + thrust::host_vector>(host_values.size()); + std::transform(host_values.begin(), + host_values.end(), + validity.begin(), + expected_values.begin(), + [](T v, bool b) { return thrust::make_pair(static_cast(v), b); }); + + auto it_dev = cudf::detail::indexalator_factory::make_input_pair_iterator(d_col); + this->iterator_test_thrust(expected_values, it_dev, host_values.size()); +} + +TYPED_TEST(IndexalatorTest, optional_iterator) +{ + using T = TypeParam; + + auto host_values = cudf::test::make_type_param_vector({0, 6, 0, -104, 103, 64, -13, -20, 45}); + auto validity = std::vector({0, 1, 1, 1, 1, 1, 0, 1, 1}); + + auto d_col = cudf::test::fixed_width_column_wrapper( + host_values.begin(), host_values.end(), validity.begin()); + + auto expected_values = thrust::host_vector>(host_values.size()); + + std::transform(host_values.begin(), + host_values.end(), + validity.begin(), + expected_values.begin(), + [](T v, bool b) { + return (b) ? thrust::make_optional(static_cast(v)) + : thrust::nullopt; + }); + + auto it_dev = cudf::detail::indexalator_factory::make_input_optional_iterator(d_col); + this->iterator_test_thrust(expected_values, it_dev, host_values.size()); +}