From b97c5e6306576bc0377182ee4cc85e8922e11a68 Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Mon, 8 Mar 2021 17:11:41 -0600 Subject: [PATCH 01/12] row_bit_count() support. --- cpp/CMakeLists.txt | 3 +- cpp/include/cudf/detail/transform.hpp | 12 +- cpp/include/cudf/transform.hpp | 34 +- cpp/src/jit/type.cpp | 1 + cpp/src/transform/row_bit_count.cu | 502 +++++++++++++++++++++ cpp/tests/CMakeLists.txt | 3 +- cpp/tests/transform/row_bit_count_test.cu | 508 ++++++++++++++++++++++ cpp/tests/utilities/column_utilities.cu | 34 +- 8 files changed, 1080 insertions(+), 17 deletions(-) create mode 100644 cpp/src/transform/row_bit_count.cu create mode 100644 cpp/tests/transform/row_bit_count_test.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d9d4c6bfd79..1e33757aaf6 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -372,7 +372,8 @@ add_library(cudf src/transform/mask_to_bools.cu src/transform/nans_to_nulls.cu src/transform/transform.cpp - src/transpose/transpose.cu + src/transform/row_bit_count.cu + src/transpose/transpose.cu src/unary/cast_ops.cu src/unary/math_ops.cu src/unary/nan_ops.cu diff --git a/cpp/include/cudf/detail/transform.hpp b/cpp/include/cudf/detail/transform.hpp index bea480d85cd..b94223cdabe 100644 --- a/cpp/include/cudf/detail/transform.hpp +++ b/cpp/include/cudf/detail/transform.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -77,5 +77,15 @@ std::unique_ptr mask_to_bools( rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +/** + * @copydoc cudf::row_bit_count + * + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr row_bit_count( + table_view const& t, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + } // namespace detail } // namespace cudf diff --git a/cpp/include/cudf/transform.hpp b/cpp/include/cudf/transform.hpp index 9b740d207e1..c3fc37f6758 100644 --- a/cpp/include/cudf/transform.hpp +++ b/cpp/include/cudf/transform.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -142,5 +142,37 @@ std::unique_ptr mask_to_bools( size_type end_bit, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +/** + * @brief Returns the cumulative size in bits of all columns in the `table_view` for + * each row. + * + * Each row in the returned column is the sum of the per-row size for each column in + * the table. + * + * In some cases, this is an inexact approximation. Specifically, with + * lists or strings, the cost of a row includes 32 bits for a single offset. However, two + * offsets is required to represent an entire row. But this presents a problem, because to + * represent 2 rows, you need 3 offsets. 3 rows 4 offsets, etc. Therefore it would not + * be accurate to say each row of a string column costs 2 offsets because summing multiple row + * sizes would give you a number too large. It is up to the caller to understand the schema + * of the input column to be able to calculate the small additional overhead of the + * terminating offset for any group of rows being considered. + * + * This function returns the per-row sizes as the columns are currently formed. This can + * end up being different than the number you would get by gathering the rows under + * certain circumstances. Specifically, the pushdown of validity masks by struct + * columns can nullify rows that actually contain underlying data for string or list + * columns. In these cases, the sized returned will be strictly: + * + * row_bit_count(column(x)) >= row_bit_count(gather(column(x))) + * + * @param t The table view to perform the computation on. + * @param mr Device memory resource used to allocate the returned columns's device memory + * @return A 32-bit integer column containing the per-row byte counts. + */ +std::unique_ptr row_bit_count( + table_view const& t, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + /** @} */ // end of group } // namespace cudf diff --git a/cpp/src/jit/type.cpp b/cpp/src/jit/type.cpp index d71e2eb4df8..6b1e8c57c3d 100644 --- a/cpp/src/jit/type.cpp +++ b/cpp/src/jit/type.cpp @@ -71,6 +71,7 @@ std::string get_type_name(data_type type) // TODO: Remove in JIT type utils PR switch (type.id()) { case type_id::LIST: return CUDF_STRINGIFY(List); + case type_id::STRUCT: return CUDF_STRINGIFY(Struct); case type_id::DECIMAL32: return CUDF_STRINGIFY(int32_t); case type_id::DECIMAL64: return CUDF_STRINGIFY(int64_t); diff --git a/cpp/src/transform/row_bit_count.cu b/cpp/src/transform/row_bit_count.cu new file mode 100644 index 00000000000..fdf6600789c --- /dev/null +++ b/cpp/src/transform/row_bit_count.cu @@ -0,0 +1,502 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cudf { +namespace detail { + +namespace { + +using offset_type = int32_t; + +/** + * @brief Struct which contains per-column information necessary to + * traverse a column hierarchy on the gpu. + * + * When `row_bit_count` is called, the input column hierarchy is flattened into a + * vector of column_device_views. For each one of them, we store a column_info + * struct. The `depth` field represents the depth of the column in the original + * hierarchy. + * + * As we traverse the hierarchy for each input row, we maintain a span representing + * the start and end rows for the current nesting depth. At depth 0, this span is + * always just 1 row. As we cross list boundaries int the hierarchy, this span + * grows. So for each column we visit we always know how many rows of it are relevant + * and can compute it's contribution to the overall size. + * + * An example using a list> column, computing the size of row 1. + * + * { {{1, 2}, {3, 4}, {5, 6}}, {{7}, {8, 9, 10}, {11, 12, 13, 14}} } + * + * L0 = List>: + * Length : 2 + * Offsets : 0, 3, 6 + * L1 = List: + * Length : 6 + * Offsets : 0, 2, 4, 6, 7, 10, 14 + * I = 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 + * + * + * span0 = [1, 2] row 1 is represented by the span [1, 2] + * span1 = [L0.offsets[span0[0]], L0.offsets[span0[1]]] expand by the offsets of L0 + * span1 = [3, 6] the range of offsets + * span2 = [L1.offsets[span1[0]], L1.offsets[span1[1]]] expand by the offsets of L1 + * span2 = [6, 14] + * + * The total size of our row is computed as: + * (span0[1] - span0[0]) * sizeof(int) the cost of the offsets for L0 + * + + * (span1[1] - span1[0]) * sizeof(int) the cost of the offsets for L1 + * + + * (span2[1] - span2[0]) * sizeof(int) the cost of the integers in I + * + * `depth` represents our depth in the source column hierarchy. + * + * "branches" within the spans can occur when we have lists inside of structs. + * consider a case where we are entering a struct with a span of [4, 8]. + * The internal list column will change that span to something else, say [5, 9]. + * But when we finish processing the list column, the final float column wants to + * go back and use the original span [4, 8]. + * + * [4, 8] [5, 9] [4, 8] + * struct< list<> float> + * + * To accomplish this we mantain a stack of spans. Pushing the current span + * whenever we enter a branch, and popping a span whenever we leave a branch. + * + * `branch_depth_start` represents the branch depth as we reach a new column. + * if `branch_depth_start` is < the last branch depth we saw, we are returning + * from a branch and should pop off the stack. + * + * `branch_depth_end` represents the new branch depth caused by this column. + * if branch_depth_end > branch_depth_start, we are branching and need to + * push the current span on the stack. + * + */ +struct column_info { + size_type depth; + size_type branch_depth_start, branch_depth_end; +}; + +/** + * @brief Struct which contains hierarchy information precomputed on the host. + * + * If the input data contains only fixed-width types, this preprocess step + * produces the value `simple_per_row_size` which is a constant for every + * row in the output. We can use thie value and skip the more complicated + * processing for lists, structs and strings entirely if `complex_type_count` + * is 0. + * + */ +struct hierarchy_info { + hierarchy_info() : simple_per_row_size(0), complex_type_count(0), max_branch_depth(0) {} + + // these two fields act as an optimization. if we find that the entire table + // is just fixed-width types, we do not need to do the more expensive kernel call that + // traverses the individual columns. so if complex_type_count is 0, we can just + // return a column where every row contains the value simple_per_row_size + size_type simple_per_row_size; // in bits + size_type complex_type_count; + + // max depth of span branches present in the hierarchy. + size_type max_branch_depth; +}; + +/** + * @brief Function which flattens the incoming column hierarchy into a vector + * of column_views and produces accompanying column_info and hierarchy_info + * metadata. + * + */ +template +void flatten_hierarchy(ColIter begin, + ColIter end, + std::vector& out, + std::vector& info, + hierarchy_info& h_info, + rmm::cuda_stream_view stream, + size_type cur_depth = 0, + size_type cur_branch_depth = 0, + int parent_index = -1); + +/** + * @brief Type-dispatched functor called by flatten_hierarchy. + * + */ +struct flatten_functor { + rmm::cuda_stream_view stream; + + // fixed width + template ()>* = nullptr> + void operator()(column_view const& col, + std::vector& out, + std::vector& info, + hierarchy_info& h_info, + rmm::cuda_stream_view stream, + size_type cur_depth, + size_type cur_branch_depth, + int parent_index) + { + out.push_back(col); + info.push_back({cur_depth, cur_branch_depth, cur_branch_depth}); + h_info.simple_per_row_size += (sizeof(device_storage_type_t) * 8) + (col.nullable() ? 1 : 0); + } + + // strings + template ::value>* = nullptr> + void operator()(column_view const& col, + std::vector& out, + std::vector& info, + hierarchy_info& h_info, + rmm::cuda_stream_view stream, + size_type cur_depth, + size_type cur_branch_depth, + int parent_index) + { + out.push_back(col); + info.push_back({cur_depth, cur_branch_depth, cur_branch_depth}); + h_info.complex_type_count++; + } + + // lists + template ::value>* = nullptr> + void operator()(column_view const& col, + std::vector& out, + std::vector& info, + hierarchy_info& h_info, + rmm::cuda_stream_view stream, + size_type cur_depth, + size_type cur_branch_depth, + int parent_index) + { + // track branch depth as we reach this list and after we pass it + size_type const branch_depth_start = cur_branch_depth; + if (parent_index >= 0 && out[parent_index].type().id() == type_id::STRUCT) { + cur_branch_depth++; + if (cur_branch_depth > h_info.max_branch_depth) { + h_info.max_branch_depth = cur_branch_depth; + } + } + size_type const branch_depth_end = cur_branch_depth; + + out.push_back(col); + info.push_back({cur_depth, branch_depth_start, branch_depth_end}); + + lists_column_view lcv(col); + auto iter = cudf::detail::make_counting_transform_iterator( + 0, [col = lcv.get_sliced_child(stream)](auto i) { return col; }); + h_info.complex_type_count++; + + flatten_hierarchy( + iter, iter + 1, out, info, h_info, stream, cur_depth + 1, cur_branch_depth, out.size() - 1); + } + + // structs + template ::value>* = nullptr> + void operator()(column_view const& col, + std::vector& out, + std::vector& info, + hierarchy_info& h_info, + rmm::cuda_stream_view stream, + size_type cur_depth, + size_type cur_branch_depth, + int parent_index) + { + out.push_back(col); + info.push_back({cur_depth, cur_branch_depth, cur_branch_depth}); + + h_info.simple_per_row_size += col.nullable() ? 1 : 0; + + structs_column_view scv(col); + auto iter = cudf::detail::make_counting_transform_iterator( + 0, [&scv](auto i) { return scv.get_sliced_child(i); }); + flatten_hierarchy(iter, + iter + scv.num_children(), + out, + info, + h_info, + stream, + cur_depth + 1, + cur_branch_depth, + out.size() - 1); + } + + // everything else + template () && !std::is_same::value && + !std::is_same::value && + !std::is_same::value>* = nullptr> + void operator()(column_view const& col, + std::vector& out, + std::vector& info, + hierarchy_info& h_info, + rmm::cuda_stream_view stream, + size_type cur_depth, + size_type cur_branch_depth, + int parent_index) + { + CUDF_FAIL("Unsupported column type in row_bit_count"); + } +}; + +template +void flatten_hierarchy(ColIter begin, + ColIter end, + std::vector& out, + std::vector& info, + hierarchy_info& h_info, + rmm::cuda_stream_view stream, + size_type cur_depth, + size_type cur_branch_depth, + int parent_index) +{ + std::for_each(begin, end, [&](column_view const& col) { + cudf::type_dispatcher(col.type(), + flatten_functor{stream}, + col, + out, + info, + h_info, + stream, + cur_depth, + cur_branch_depth, + parent_index); + }); +} + +/** + * @brief Struct representing a span of rows. + * + */ +struct row_span { + size_type row_start, row_end; +}; + +/** + * @brief Functor for computing the size, in bits, of a `span` of rows for a given + * column_device_view + * + */ +struct row_size_functor { + template + __device__ size_type operator()(column_device_view const& col, row_span const& span) + { + auto const num_rows{span.row_end - span.row_start}; + return ((sizeof(device_storage_type_t) * 8) + (col.nullable() ? 1 : 0)) * num_rows; + } +}; + +template <> +__device__ size_type row_size_functor::operator()(column_device_view const& col, + row_span const& span) +{ + column_device_view const& offsets = col.child(strings_column_view::offsets_column_index); + auto const num_rows{span.row_end - span.row_start}; + auto const row_start{span.row_start + col.offset()}; + auto const row_end{span.row_end + col.offset()}; + + return (((sizeof(offset_type) * 8) + (col.nullable() ? 1 : 0)) * + num_rows) + // cost of offsets + validity + ((offsets.data()[row_end] - offsets.data()[row_start]) * + 8); // cost of chars +} + +template <> +__device__ size_type row_size_functor::operator()(column_device_view const& col, + row_span const& span) +{ + column_device_view const& offsets = col.child(lists_column_view::offsets_column_index); + auto const num_rows{span.row_end - span.row_start}; + return ((sizeof(offset_type) * 8) + (col.nullable() ? 1 : 0)) * + num_rows; // cost of offsets + validity +} + +template <> +__device__ size_type row_size_functor::operator()(column_device_view const& col, + row_span const& span) +{ + auto const num_rows{span.row_end - span.row_start}; + return (col.nullable() ? 1 : 0) * num_rows; // cost of validity +} + +/** + * @brief Kernel for computing per-row sizes in bits. + * + * @param cols An array of column_device_views represeting a column hierarcy + * @param info An array of column_info structs corresponding the elements in `cols` + * @param num_columns The number of columns + * @param num_rows The number of rows in the root column + * @param output Output buffer of size `num_rows` where per-row bit sizes are stored + * @param max_branch_depth Maximum depth of the span stack needed per-thread + * + */ +__global__ void compute_row_sizes(column_device_view* cols, + column_info* info, + size_type num_columns, + size_type num_rows, + size_type* output, + size_type max_branch_depth) +{ + extern __shared__ row_span branch_shared[]; + int const tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (tid >= num_rows) { return; } + + // branch stack. points to the last list prior to branching. + row_span* branch = branch_shared + (tid * max_branch_depth); + size_type branch_depth{0}; + + // current row span - always starts at 1 row. + row_span cur_span{tid, tid + 1}; + + // output size + size_type& size = *(output + tid); + size = 0; + + size_type last_branch_depth{0}; + for (size_type idx = 0; idx < num_columns; idx++) { + column_device_view const& col = cols[idx]; + + // if we've returned from a branch + if (info[idx].branch_depth_start < last_branch_depth) { cur_span = branch[--branch_depth]; } + // if we're entering a new branch. + // NOTE: this case can happen (a pop and a push by the same column) + // when we have a struct + if (info[idx].branch_depth_end > info[idx].branch_depth_start) { + branch[branch_depth++] = cur_span; + } + + // if we're back at depth 0, this is a new top-level column, so reset + // span info + if (info[idx].depth == 0) { + branch_depth = 0; + last_branch_depth = 0; + cur_span = row_span{tid, tid + 1}; + } + + // add the contributing size of this row + size += cudf::type_dispatcher(col.type(), row_size_functor{}, col, cur_span); + + // if this is a list column, update the working span from our offsets + if (col.type().id() == type_id::LIST) { + column_device_view const& offsets = col.child(lists_column_view::offsets_column_index); + auto const base_offset = offsets.data()[col.offset()]; + cur_span.row_start = + offsets.data()[cur_span.row_start + col.offset()] - base_offset; + cur_span.row_end = offsets.data()[cur_span.row_end + col.offset()] - base_offset; + } + + last_branch_depth = info[idx].branch_depth_end; + } +} + +} // anonymous namespace + +/** + * @copydoc cudf::detail::row_bit_count + * + */ +std::unique_ptr row_bit_count(table_view const& t, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // flatten the hierarchy and determine some information about it. + std::vector cols; + std::vector info; + hierarchy_info h_info; + flatten_hierarchy(t.begin(), t.end(), cols, info, h_info, stream); + + // create output buffer and view + auto output = cudf::make_fixed_width_column( + data_type{type_id::INT32}, t.num_rows(), mask_state::UNALLOCATED, stream, mr); + mutable_column_view mcv = output->mutable_view(); + + // simple case. if we have no complex types (lists, strings, etc), the per-row size is already + // trivially computed + if (h_info.complex_type_count <= 0) { + thrust::fill(rmm::exec_policy(stream), + mcv.begin(), + mcv.end(), + h_info.simple_per_row_size); + return output; + } + + // create a contiguous block of column_device_views + auto d_cols = contiguous_copy_column_device_views(cols, stream); + + // move stack info to the gpu + rmm::device_uvector d_info(info.size(), stream); + CUDA_TRY(cudaMemcpyAsync(d_info.data(), + info.data(), + sizeof(column_info) * info.size(), + cudaMemcpyHostToDevice, + stream.value())); + CUDF_EXPECTS(info.size() == cols.size(), "Size/info mismatch"); + + // each thread needs to maintain a stack of row spans of size max_branch_depth. we will use + // shared memory to do this rather than allocating a potentially gigantic temporary buffer + // of memory of size (# input rows * sizeof(row_span) * max_branch_depth). + auto const shmem_per_thread = sizeof(row_span) * h_info.max_branch_depth; + int device_id; + CUDA_TRY(cudaGetDevice(&device_id)); + int shmem_limit_per_block; + CUDA_TRY( + cudaDeviceGetAttribute(&shmem_limit_per_block, cudaDevAttrMaxSharedMemoryPerBlock, device_id)); + constexpr int max_block_size = 256; + auto const block_size = + shmem_per_thread != 0 + ? std::min(max_block_size, shmem_limit_per_block / static_cast(shmem_per_thread)) + : max_block_size; + auto const shared_mem_size = shmem_per_thread * block_size; + // should we be aborting if we reach some extremely small block size, or just if we hit 0? + CUDF_EXPECTS(block_size > 0, "Encountered a column hierarchy too complex for row_bit_count"); + + cudf::detail::grid_1d grid{t.num_rows(), block_size, 1}; + compute_row_sizes<<>>( + std::get<1>(d_cols), + d_info.data(), + info.size(), + t.num_rows(), + mcv.data(), + h_info.max_branch_depth); + + return output; +} + +} // namespace detail + +/** + * @copydoc cudf::row_bit_count + * + */ +std::unique_ptr row_bit_count(table_view const& t, rmm::mr::device_memory_resource* mr) +{ + return detail::row_bit_count(t, rmm::cuda_stream_default, mr); +} + +} // namespace cudf diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 492767c5d2f..55c9f6ef68c 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -189,7 +189,8 @@ ConfigureTest(TRANSFORM_TEST transform/integration/unary-transform-test.cpp transform/nans_to_null_test.cpp transform/mask_to_bools_test.cpp - transform/bools_to_mask_test.cpp) + transform/bools_to_mask_test.cpp + transform/row_bit_count_test.cu) ################################################################################################### # - interop tests ------------------------------------------------------------------------- diff --git a/cpp/tests/transform/row_bit_count_test.cu b/cpp/tests/transform/row_bit_count_test.cu new file mode 100644 index 00000000000..f91e5658ab8 --- /dev/null +++ b/cpp/tests/transform/row_bit_count_test.cu @@ -0,0 +1,508 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace cudf; + +using offset_type = int32_t; + +template +struct RowBitCountTyped : public cudf::test::BaseFixture { +}; + +TYPED_TEST_CASE(RowBitCountTyped, cudf::test::FixedWidthTypes); + +TYPED_TEST(RowBitCountTyped, SimpleTypes) +{ + using T = TypeParam; + + // no nulls + { + auto col = cudf::make_fixed_width_column(data_type{type_to_id()}, 16); + + table_view t({*col}); + auto result = cudf::row_bit_count(t); + + // expect size of the type per row + auto expected = make_fixed_width_column(data_type{type_id::INT32}, 16); + cudf::mutable_column_view mcv(*expected); + thrust::fill(rmm::exec_policy(0), + mcv.begin(), + mcv.end(), + sizeof(device_storage_type_t) * 8); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); + } + + // nulls + { + auto iter = thrust::make_counting_iterator(0); + auto valids = cudf::detail::make_counting_transform_iterator( + 0, [](int i) { return i % 2 == 0 ? true : false; }); + cudf::test::fixed_width_column_wrapper col(iter, iter + 16, valids); + + table_view t({col}); + auto result = cudf::row_bit_count(t); + + // expect size of the type + 1 bit per row + auto expected = make_fixed_width_column(data_type{type_id::INT32}, 16); + cudf::mutable_column_view mcv(*expected); + thrust::fill(rmm::exec_policy(0), + mcv.begin(), + mcv.end(), + (sizeof(device_storage_type_t) * 8) + 1); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); + } +} + +template +std::pair, std::unique_ptr> build_list_column() +{ + using LCW = cudf::test::lists_column_wrapper; + constexpr size_type type_size = sizeof(device_storage_type_t) * 8; + + // clang-format off + cudf::test::lists_column_wrapper col{ {{1, 2}, {3, 4, 5}}, + LCW{LCW{}}, + {LCW{10}}, + {{6, 7, 8}, {9}}, + {{-1, -2}, {-3, -4}}, + {{-5, -6, -7}, {-8, -9}} }; + + // expected size = (num rows at level 1 + num_rows at level 2) + # values in the leaf + cudf::test::fixed_width_column_wrapper expected{((4 + 8) * 8) + (type_size * 5), + ((4 + 0) * 8) + (type_size * 0), + ((4 + 4) * 8) + (type_size * 1), + ((4 + 8) * 8) + (type_size * 4), + ((4 + 8) * 8) + (type_size * 4), + ((4 + 8) * 8) + (type_size * 5)}; + + return {col.release(), expected.release()}; +} + +TYPED_TEST(RowBitCountTyped, Lists) +{ + using T = TypeParam; + + // no nulls + { + std::unique_ptr col; + std::unique_ptr expected_sizes; + std::tie(col, expected_sizes) = build_list_column(); + + // clang-format on + table_view t({*col}); + auto result = cudf::row_bit_count(t); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_sizes, *result); + } + + // nulls + { + using LCW = cudf::test::lists_column_wrapper; + constexpr size_type type_size = sizeof(device_storage_type_t) * 8; + + std::vector valids{true, false, true}; + std::vector valids2{false, true, false}; + std::vector valids3{true, false}; + + // clang-format off + cudf::test::lists_column_wrapper col{ {{1, 2}, {{3, 4, 5}, valids.begin()}}, + LCW{LCW{}}, + {LCW{10}}, + {{{{6, 7, 8}, valids2.begin()}, {9}}, valids3.begin()} }; + // clang-format on + + table_view t({col}); + auto result = cudf::row_bit_count(t); + + // expected size = (num rows at level 1 + num_rows at level 2) + # values in the leaf + validity + // where applicable + cudf::test::fixed_width_column_wrapper expected{((4 + 8) * 8) + (type_size * 5) + 7, + ((4 + 0) * 8) + (type_size * 0), + ((4 + 4) * 8) + (type_size * 1) + 2, + ((4 + 8) * 8) + (type_size * 3) + 5}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } +} + +struct RowBitCount : public cudf::test::BaseFixture { +}; + +TEST_F(RowBitCount, Strings) +{ + // no nulls + { + std::vector strings{"abc", "def", "", "z", "bananas", "warp", "", "zing"}; + + cudf::test::strings_column_wrapper col(strings.begin(), strings.end()); + + table_view t({col}); + auto result = cudf::row_bit_count(t); + + // expect 1 offset (4 bytes) + length of string per row + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { + return (static_cast(strings[i].size()) + sizeof(offset_type)) * 8; + }); + cudf::test::fixed_width_column_wrapper expected(size_iter, + size_iter + strings.size()); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } + + // nulls + { + // clang-format off + std::vector strings { "abc", "def", "", "z", "bananas", "warp", "", "zing" }; + std::vector valids { 1, 0, 0, 1, 0, 1, 1, 1 }; + // clang-format on + + cudf::test::strings_column_wrapper col(strings.begin(), strings.end(), valids.begin()); + + table_view t({col}); + auto result = cudf::row_bit_count(t); + + // expect 1 offset (4 bytes) + (length of string, or 0 if null) + 1 validity bit per row + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings, &valids](int i) { + return ((static_cast(valids[i] ? strings[i].size() : 0) + sizeof(offset_type)) * + 8) + + 1; + }); + cudf::test::fixed_width_column_wrapper expected(size_iter, + size_iter + strings.size()); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } +} + +std::pair, std::unique_ptr> build_struct_column() +{ + std::vector struct_validity{0, 1, 1, 1, 1, 0}; + std::vector strings{"abc", "def", "", "z", "bananas", "warp"}; + + cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; + cudf::test::fixed_width_column_wrapper col1{{8, 9, 10, 11, 12, 13}, {1, 0, 1, 1, 1, 1}}; + cudf::test::strings_column_wrapper col2(strings.begin(), strings.end()); + + // creating a struct column will cause all child columns to be promoted to have validity + cudf::test::structs_column_wrapper struct_col({col0, col1, col2}, struct_validity); + + // expect (1 offset (4 bytes) + (length of string if row is valid) + 1 validity bit) + + // (1 float + 1 validity bit) + + // (1 int16_t + 1 validity bit) + + // (1 validity bit) + auto size_iter = + cudf::detail::make_counting_transform_iterator(0, [&strings, &struct_validity](int i) { + return (sizeof(float) * 8) + 1 + (sizeof(int16_t) * 8) + 1 + + (static_cast(strings[i].size()) * 8) + (sizeof(offset_type) * 8) + 1 + 1; + }); + cudf::test::fixed_width_column_wrapper expected_sizes(size_iter, + size_iter + strings.size()); + + return {struct_col.release(), expected_sizes.release()}; +} + +TEST_F(RowBitCount, Structs) +{ + // no nulls + { + std::vector strings{"abc", "def", "", "z", "bananas", "warp"}; + + cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; + cudf::test::fixed_width_column_wrapper col1{8, 9, 10, 11, 12, 13}; + cudf::test::strings_column_wrapper col2(strings.begin(), strings.end()); + + cudf::test::structs_column_wrapper struct_col({col0, col1, col2}); + + table_view t({struct_col}); + auto result = cudf::row_bit_count(t); + + // expect 1 offset (4 bytes) + (length of string) + 1 float + 1 int16_t + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { + return ((sizeof(float) + sizeof(int16_t)) * 8) + + ((static_cast(strings[i].size()) + sizeof(offset_type)) * 8); + }); + cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + t.num_rows()); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } + + // nulls + { + // creating a struct column with validity will cause all child columns to be promoted to have + // validity + std::unique_ptr struct_col; + std::unique_ptr expected_sizes; + std::tie(struct_col, expected_sizes) = build_struct_column(); + table_view t({*struct_col}); + auto result = cudf::row_bit_count(t); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_sizes, *result); + } + + // struct, int16> + { + cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; + cudf::test::structs_column_wrapper inner_struct({col0}); + + cudf::test::fixed_width_column_wrapper col1{8, 9, 10, 11, 12, 13}; + cudf::test::structs_column_wrapper struct_col({inner_struct, col1}); + + table_view t({struct_col}); + auto result = cudf::row_bit_count(t); + + // expect num_rows * (4 + 2) bytes + auto size_iter = + cudf::detail::make_counting_transform_iterator(0, [&](int i) { return (4 + 2) * 8; }); + cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + t.num_rows()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } +} + +std::pair, std::unique_ptr> build_nested_column( + std::vector const& struct_validity) +{ + // tests the "branching" case -> list ...>>> + + // List, float, int16> + + // Inner list column + // clang-format off + cudf::test::lists_column_wrapper list{ + {1, 2, 3, 4, 5}, + {6, 7, 8}, + {33, 34, 35, 36, 37, 38, 39}, + {-1, -2}, + {-10, -11, -1, -20}, + {40, 41, 42}, + {100, 200, 300}, + {-100, -200, -300}}; + // clang-format on + + // floats + std::vector ages{5, 10, 15, 20, 4, 75, 16, -16}; + std::vector ages_validity = {1, 1, 1, 1, 0, 1, 0, 1}; + auto ages_column = + cudf::test::fixed_width_column_wrapper(ages.begin(), ages.end(), ages_validity.begin()); + + // int16 values + std::vector vals{-1, -2, -3, 1, 2, 3, 8, 9}; + auto i16_column = cudf::test::fixed_width_column_wrapper(vals.begin(), vals.end()); + + // Assemble struct column + auto struct_column = + cudf::test::structs_column_wrapper({list, ages_column, i16_column}, struct_validity); + + // wrap in a list + std::vector outer_offsets{0, 1, 1, 3, 6, 7, 8}; + cudf::test::fixed_width_column_wrapper outer_offsets_col(outer_offsets.begin(), + outer_offsets.end()); + auto const size = static_cast(outer_offsets_col).size() - 1; + + cudf::test::fixed_width_column_wrapper expected_sizes{276, 32, 520, 572, 212, 212}; + + return {cudf::make_lists_column(static_cast(size), + outer_offsets_col.release(), + struct_column.release(), + cudf::UNKNOWN_NULL_COUNT, + rmm::device_buffer{}), + expected_sizes.release()}; +} + +std::unique_ptr build_nested_column2(std::vector const& struct_validity) +{ + // List>, Struct>> + + // Inner list column + // clang-format off + cudf::test::lists_column_wrapper list{ + {{1, 2, 3, 4, 5}, {2, 3}}, + {{6, 7, 8}, {8, 9}}, + {{1, 2}, {3, 4, 5}, {33, 34, 35, 36, 37, 38, 39}}}; + // clang-format on + + // Inner struct + std::vector vals{-1, -2, -3}; + auto i16_column = cudf::test::fixed_width_column_wrapper(vals.begin(), vals.end()); + auto inner_struct = cudf::test::structs_column_wrapper({i16_column}); + + // outer struct + auto outer_struct = cudf::test::structs_column_wrapper({list, inner_struct}, struct_validity); + + // wrap in a list + std::vector outer_offsets{0, 1, 1, 3}; + cudf::test::fixed_width_column_wrapper outer_offsets_col(outer_offsets.begin(), + outer_offsets.end()); + auto const size = static_cast(outer_offsets_col).size() - 1; + return make_lists_column(static_cast(size), + outer_offsets_col.release(), + outer_struct.release(), + cudf::UNKNOWN_NULL_COUNT, + rmm::device_buffer{}); +} + +TEST_F(RowBitCount, NestedTypes) +{ + // List, float, List, int16> + { + std::unique_ptr col_no_nulls; + std::unique_ptr expected_sizes; + std::tie(col_no_nulls, expected_sizes) = build_nested_column({1, 1, 1, 1, 1, 1, 1, 1}); + table_view no_nulls_t({*col_no_nulls}); + auto no_nulls_result = cudf::row_bit_count(no_nulls_t); + + auto col_nulls = build_nested_column({0, 0, 1, 1, 1, 1, 1, 1}).first; + table_view nulls_t({*col_nulls}); + auto nulls_result = cudf::row_bit_count(nulls_t); + + // List, float, int16> + // + // this illustrates the difference between a row_bit_count + // returning a pre-gather result, or a post-gather result. + // + // in a post-gather situation, the nulls in the struct would result in the values + // nested in the list below to be dropped, resulting in smaller row sizes. + // + // however, for performance reasons, row_bit_count simply walks the data that is + // currently there. so list rows that are null, but have a real span of + // offsets (X, Y) instead of (X, X) will end up getting the child data for those + // rows included. + // + // if row_bit_count() is changed to return a post-gather result (which may be desirable), + // the nulls_result case below will start failing and will need to be changed. + // + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_sizes, *no_nulls_result); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_sizes, *nulls_result); + } + + // List>, Struct>> + { + auto col_no_nulls = build_nested_column2({1, 1, 1}); + table_view no_nulls_t({*col_no_nulls}); + auto no_nulls_result = cudf::row_bit_count(no_nulls_t); + + auto col_nulls = build_nested_column2({1, 0, 1}); + table_view nulls_t({*col_nulls}); + auto nulls_result = cudf::row_bit_count(nulls_t); + + cudf::test::fixed_width_column_wrapper expected_sizes{372, 32, 840}; + + // same explanation as above + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_sizes, *no_nulls_result); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_sizes, *nulls_result); + } + + // test pushing/popping multiple times within one struct, and branch depth > 1 + // + // Struct, float, List>, Struct, List, + // float>>, int8_t>> + { + cudf::test::lists_column_wrapper l0{{1, 2, 3}, {4, 5}, {6, 7, 8, 9}, {5}}; + cudf::test::lists_column_wrapper l1{ + {{-1, -2}, {3, 4}}, {{4, 5}, {6, 7, 8}}, {{-6, -7}, {2}}, {{-11, -11}, {-12, -12}, {3}}}; + cudf::test::lists_column_wrapper l2{{-1, -2}, {4, 5}, {-6, -7}, {1}}; + cudf::test::lists_column_wrapper l3{{-1, -2, 0}, {5}, {-1, -6, -7}, {1, 2}}; + + cudf::test::fixed_width_column_wrapper c0{1, 2, 3, 4}; + cudf::test::fixed_width_column_wrapper c1{1, 2, 3, 4}; + cudf::test::fixed_width_column_wrapper c2{1, 2, 3, 4}; + cudf::test::fixed_width_column_wrapper c3{11, 12, 13, 14}; + + // innermost List>> + auto innermost_struct = cudf::test::structs_column_wrapper({l3, c3}); + std::vector l4_offsets{0, 1, 2, 3, 4}; + cudf::test::fixed_width_column_wrapper l4_offsets_col(l4_offsets.begin(), + l4_offsets.end()); + auto const l4_size = l4_offsets.size() - 1; + auto l4 = cudf::make_lists_column(static_cast(l4_size), + l4_offsets_col.release(), + innermost_struct.release(), + cudf::UNKNOWN_NULL_COUNT, + rmm::device_buffer{}); + + // inner struct + std::vector> inner_struct_children; + inner_struct_children.push_back(l2.release()); + inner_struct_children.push_back(std::move(l4)); + auto inner_struct = cudf::test::structs_column_wrapper(std::move(inner_struct_children)); + + // outer struct + auto struct_col = cudf::test::structs_column_wrapper({c0, l0, c1, l1, inner_struct, c2}); + + table_view t({struct_col}); + auto result = cudf::row_bit_count(t); + + cudf::test::fixed_width_column_wrapper expected_sizes{648, 568, 664, 568}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_sizes, *result); + } +} + +struct sum_functor { + size_type const* s0; + size_type const* s1; + size_type const* s2; + + size_type operator() __device__(int i) { return s0[i] + s1[i] + s2[i]; } +}; + +TEST_F(RowBitCount, Table) +{ + // complex nested column + std::unique_ptr col0; + std::unique_ptr col0_sizes; + std::tie(col0, col0_sizes) = build_nested_column({1, 1, 1, 1, 1, 1, 1, 1}); + + // struct column + std::unique_ptr col1; + std::unique_ptr col1_sizes; + std::tie(col1, col1_sizes) = build_struct_column(); + + // list column + std::unique_ptr col2; + std::unique_ptr col2_sizes; + std::tie(col2, col2_sizes) = build_list_column(); + + table_view t({*col0, *col1, *col2}); + auto result = cudf::row_bit_count(t); + + // sum all column sizes + column_view cv0 = static_cast(*col0_sizes); + column_view cv1 = static_cast(*col1_sizes); + column_view cv2 = static_cast(*col2_sizes); + auto expected = cudf::make_fixed_width_column(data_type{type_id::INT32}, t.num_rows()); + cudf::mutable_column_view mcv(*expected); + thrust::transform( + rmm::exec_policy(0), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(0) + t.num_rows(), + mcv.begin(), + sum_functor{cv0.data(), cv1.data(), cv2.data()}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); +} \ No newline at end of file diff --git a/cpp/tests/utilities/column_utilities.cu b/cpp/tests/utilities/column_utilities.cu index cea66eced11..78a67464654 100644 --- a/cpp/tests/utilities/column_utilities.cu +++ b/cpp/tests/utilities/column_utilities.cu @@ -694,12 +694,13 @@ struct column_view_printer { get_nested_type_str(col) + (is_sliced ? "(sliced)" : "") + ":\n" + indent + "Length : " + std::to_string(lcv.size()) + "\n" + indent + "Offsets : " + (lcv.size() > 0 ? nested_offsets_to_string(lcv) : "") + "\n" + - (lcv.has_nulls() ? indent + "Null count: " + std::to_string(lcv.null_count()) + "\n" + - detail::to_string(bitmask_to_host(col), col.size(), indent) + "\n" - : "") + - indent + "Children :\n" + - (child.type().id() != type_id::LIST && child.has_nulls() - ? indent + detail::to_string(bitmask_to_host(child), child.size(), indent) + "\n" + (lcv.parent().nullable() + ? indent + "Null count: " + std::to_string(lcv.null_count()) + "\n" + + detail::to_string(bitmask_to_host(col), col.size(), indent) + "\n" + : "") + + // non-nested types don't typically display their null masks, so do it here for convenience. + (!is_nested(child.type()) && child.nullable() + ? " " + detail::to_string(bitmask_to_host(child), child.size(), indent) + "\n" : "") + (detail::to_string(child, ", ", indent + " ")) + "\n"; @@ -718,18 +719,25 @@ struct column_view_printer { out_stream << get_nested_type_str(col) << ":\n" << indent << "Length : " << view.size() << ":\n"; - if (view.has_nulls()) { + if (view.nullable()) { out_stream << indent << "Null count: " << view.null_count() << "\n" << detail::to_string(bitmask_to_host(col), col.size(), indent) << "\n"; } auto iter = thrust::make_counting_iterator(0); - std::transform(iter, - iter + view.num_children(), - std::ostream_iterator(out_stream, "\n"), - [&](size_type index) { - return detail::to_string(view.get_sliced_child(index), ", ", indent + " "); - }); + std::transform( + iter, + iter + view.num_children(), + std::ostream_iterator(out_stream, "\n"), + [&](size_type index) { + auto child = view.get_sliced_child(index); + + // non-nested types don't typically display their null masks, so do it here for convenience. + return (!is_nested(child.type()) && child.nullable() + ? " " + detail::to_string(bitmask_to_host(child), child.size(), indent) + "\n" + : "") + + detail::to_string(child, ", ", indent + " "); + }); out.push_back(out_stream.str()); } From 415377817611ef847ac0843c02c61e35d43f50c9 Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Mon, 8 Mar 2021 17:24:58 -0600 Subject: [PATCH 02/12] Newline at the end of row_bit_count_test.cu --- cpp/tests/transform/row_bit_count_test.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tests/transform/row_bit_count_test.cu b/cpp/tests/transform/row_bit_count_test.cu index f91e5658ab8..cef858b6368 100644 --- a/cpp/tests/transform/row_bit_count_test.cu +++ b/cpp/tests/transform/row_bit_count_test.cu @@ -505,4 +505,4 @@ TEST_F(RowBitCount, Table) mcv.begin(), sum_functor{cv0.data(), cv1.data(), cv2.data()}); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); -} \ No newline at end of file +} From d1f51c68b22e083bbf7b3de7ac7a6737c581e476 Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Mon, 8 Mar 2021 17:34:59 -0600 Subject: [PATCH 03/12] Comment fixes --- cpp/src/transform/row_bit_count.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/src/transform/row_bit_count.cu b/cpp/src/transform/row_bit_count.cu index fdf6600789c..ff488d4da8a 100644 --- a/cpp/src/transform/row_bit_count.cu +++ b/cpp/src/transform/row_bit_count.cu @@ -64,9 +64,9 @@ using offset_type = int32_t; * * span0 = [1, 2] row 1 is represented by the span [1, 2] * span1 = [L0.offsets[span0[0]], L0.offsets[span0[1]]] expand by the offsets of L0 - * span1 = [3, 6] the range of offsets + * span1 = [3, 6] span applied to children of L0 * span2 = [L1.offsets[span1[0]], L1.offsets[span1[1]]] expand by the offsets of L1 - * span2 = [6, 14] + * span2 = [6, 14] span applied to children of L1 * * The total size of our row is computed as: * (span0[1] - span0[0]) * sizeof(int) the cost of the offsets for L0 @@ -108,7 +108,7 @@ struct column_info { * * If the input data contains only fixed-width types, this preprocess step * produces the value `simple_per_row_size` which is a constant for every - * row in the output. We can use thie value and skip the more complicated + * row in the output. We can use this value and skip the more complicated * processing for lists, structs and strings entirely if `complex_type_count` * is 0. * @@ -298,8 +298,8 @@ struct row_span { }; /** - * @brief Functor for computing the size, in bits, of a `span` of rows for a given - * column_device_view + * @brief Functor for computing the size, in bits, of a `row_span` of rows for a given + * `column_device_view` * */ struct row_size_functor { From 801ad1022df7830d9caf34bed711ef1fc422a952 Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Tue, 9 Mar 2021 17:57:00 -0600 Subject: [PATCH 04/12] Small review comment cleanup. Add tests for sliced columns. --- cpp/src/transform/row_bit_count.cu | 22 +++--- cpp/tests/transform/row_bit_count_test.cu | 82 +++++++++++++++++++++++ 2 files changed, 96 insertions(+), 8 deletions(-) diff --git a/cpp/src/transform/row_bit_count.cu b/cpp/src/transform/row_bit_count.cu index ff488d4da8a..116a60245d9 100644 --- a/cpp/src/transform/row_bit_count.cu +++ b/cpp/src/transform/row_bit_count.cu @@ -100,7 +100,8 @@ using offset_type = int32_t; */ struct column_info { size_type depth; - size_type branch_depth_start, branch_depth_end; + size_type branch_depth_start; + size_type branch_depth_end; }; /** @@ -307,7 +308,9 @@ struct row_size_functor { __device__ size_type operator()(column_device_view const& col, row_span const& span) { auto const num_rows{span.row_end - span.row_start}; - return ((sizeof(device_storage_type_t) * 8) + (col.nullable() ? 1 : 0)) * num_rows; + auto const element_size = sizeof(device_storage_type_t) * 8; + auto const validity_size = col.nullable() ? 1 : 0; + return (element_size + validity_size) * num_rows; } }; @@ -320,10 +323,11 @@ __device__ size_type row_size_functor::operator()(column_device_vie auto const row_start{span.row_start + col.offset()}; auto const row_end{span.row_end + col.offset()}; - return (((sizeof(offset_type) * 8) + (col.nullable() ? 1 : 0)) * - num_rows) + // cost of offsets + validity - ((offsets.data()[row_end] - offsets.data()[row_start]) * - 8); // cost of chars + auto const offsets_size = sizeof(offset_type) * 8; + auto const validity_size = col.nullable() ? 1 : 0; + auto const chars_size = + (offsets.data()[row_end] - offsets.data()[row_start]) * 8; + return ((offsets_size + validity_size) * num_rows) + chars_size; } template <> @@ -332,8 +336,10 @@ __device__ size_type row_size_functor::operator()(column_device_view { column_device_view const& offsets = col.child(lists_column_view::offsets_column_index); auto const num_rows{span.row_end - span.row_start}; - return ((sizeof(offset_type) * 8) + (col.nullable() ? 1 : 0)) * - num_rows; // cost of offsets + validity + + auto const offsets_size = sizeof(offset_type) * 8; + auto const validity_size = col.nullable() ? 1 : 0; + return (offsets_size + validity_size) * num_rows; } template <> diff --git a/cpp/tests/transform/row_bit_count_test.cu b/cpp/tests/transform/row_bit_count_test.cu index cef858b6368..7b871053ec4 100644 --- a/cpp/tests/transform/row_bit_count_test.cu +++ b/cpp/tests/transform/row_bit_count_test.cu @@ -506,3 +506,85 @@ TEST_F(RowBitCount, Table) sum_functor{cv0.data(), cv1.data(), cv2.data()}); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); } + +TEST_F(RowBitCount, SlicedColumns) +{ + // fixed width + { + auto const slice_size = 7; + cudf::test::fixed_width_column_wrapper c0_unsliced{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + auto c0 = cudf::slice(c0_unsliced, {2, 2 + slice_size}); + + table_view t({c0}); + auto result = cudf::row_bit_count(t); + + cudf::test::fixed_width_column_wrapper expected{16, 16, 16, 16, 16, 16, 16}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } + + // strings + { + auto const slice_size = 7; + std::vector strings{ + "banana", "metric", "imperial", "abc", "pears", "", "fire", "def", "cudf", "xyzw"}; + cudf::test::strings_column_wrapper c0_unsliced(strings.begin(), strings.end()); + auto c0 = cudf::slice(c0_unsliced, {3, 3 + slice_size}); + + table_view t({c0}); + auto result = cudf::row_bit_count(t); + + // expect 1 offset (4 bytes) + length of string per row + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { + return (static_cast(strings[i].size()) + sizeof(offset_type)) * 8; + }); + cudf::test::fixed_width_column_wrapper expected(size_iter + 3, + size_iter + 3 + slice_size); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } + + // lists + { + auto const slice_size = 2; + cudf::test::lists_column_wrapper c0_unsliced{ + {{"banana", "v"}, {"cats"}}, + {{"dogs", "yay"}, {"xyz", ""}, {"ultra"}}, + {{"fast", "parrot"}, {"orange"}}, + {{"blue"}, {"red", "yellow"}, {"ultraviolet", "", "green"}}}; + auto c0 = cudf::slice(c0_unsliced, {1, 1 + slice_size}); + + table_view t({c0}); + auto result = cudf::row_bit_count(t); + + cudf::test::fixed_width_column_wrapper expected{408, 320}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } + + // structs + { + auto const slice_size = 7; + + cudf::test::fixed_width_column_wrapper c0{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + std::vector strings{ + "banana", "metric", "imperial", "abc", "pears", "", "fire", "def", "cudf", "xyzw"}; + cudf::test::strings_column_wrapper c1(strings.begin(), strings.end()); + + auto struct_col_unsliced = cudf::test::structs_column_wrapper({c0, c1}); + auto struct_col = cudf::slice(struct_col_unsliced, {3, 3 + slice_size}); + + table_view t({struct_col}); + auto result = cudf::row_bit_count(t); + + // expect 1 offset (4 bytes) + length of string per row + 1 int16_t per row + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { + return (static_cast(strings[i].size()) + sizeof(offset_type) + sizeof(int16_t)) * + 8; + }); + cudf::test::fixed_width_column_wrapper expected(size_iter + 3, + size_iter + 3 + slice_size); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); + } +} \ No newline at end of file From 63e33410ea384407c11512754cb6026bcd078361 Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Thu, 18 Mar 2021 17:36:36 -0500 Subject: [PATCH 05/12] PR review changes. Of note: added cudf::offset_type to types.hpp --- cpp/include/cudf/lists/lists_column_view.hpp | 1 - cpp/include/cudf/transform.hpp | 7 +- cpp/include/cudf/types.hpp | 1 + cpp/src/lists/drop_list_duplicates.cu | 2 +- cpp/src/transform/row_bit_count.cu | 117 ++++++---- cpp/tests/transform/row_bit_count_test.cu | 227 +++++++++---------- 6 files changed, 190 insertions(+), 165 deletions(-) diff --git a/cpp/include/cudf/lists/lists_column_view.hpp b/cpp/include/cudf/lists/lists_column_view.hpp index f8facb83975..768dde2c280 100644 --- a/cpp/include/cudf/lists/lists_column_view.hpp +++ b/cpp/include/cudf/lists/lists_column_view.hpp @@ -56,7 +56,6 @@ class lists_column_view : private column_view { using column_view::null_mask; using column_view::offset; using column_view::size; - using offset_type = int32_t; static_assert(std::is_same::value, "offset_type is expected to be the same as size_type."); using offset_iterator = offset_type const*; diff --git a/cpp/include/cudf/transform.hpp b/cpp/include/cudf/transform.hpp index c3fc37f6758..4c6804b3d36 100644 --- a/cpp/include/cudf/transform.hpp +++ b/cpp/include/cudf/transform.hpp @@ -143,9 +143,12 @@ std::unique_ptr mask_to_bools( rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** - * @brief Returns the cumulative size in bits of all columns in the `table_view` for + * @brief Returns an approximate cumulative size in bits of all columns in the `table_view` for * each row. * + * This function counts bits instead of bytes to account for the null mask which only has one + * bit per row. + * * Each row in the returned column is the sum of the per-row size for each column in * the table. * @@ -168,7 +171,7 @@ std::unique_ptr mask_to_bools( * * @param t The table view to perform the computation on. * @param mr Device memory resource used to allocate the returned columns's device memory - * @return A 32-bit integer column containing the per-row byte counts. + * @return A 32-bit integer column containing the per-row bit counts. */ std::unique_ptr row_bit_count( table_view const& t, diff --git a/cpp/include/cudf/types.hpp b/cpp/include/cudf/types.hpp index 48e5d9543b8..a37a17f728e 100644 --- a/cpp/include/cudf/types.hpp +++ b/cpp/include/cudf/types.hpp @@ -91,6 +91,7 @@ class mutable_table_view; using size_type = int32_t; using bitmask_type = uint32_t; using valid_type = uint8_t; +using offset_type = int32_t; /** * @brief Similar to `std::distance` but returns `cudf::size_type` and performs `static_cast` diff --git a/cpp/src/lists/drop_list_duplicates.cu b/cpp/src/lists/drop_list_duplicates.cu index 1eb105d296d..b244d7c1d35 100644 --- a/cpp/src/lists/drop_list_duplicates.cu +++ b/cpp/src/lists/drop_list_duplicates.cu @@ -34,7 +34,7 @@ namespace cudf { namespace lists { namespace detail { namespace { -using offset_type = lists_column_view::offset_type; + /** * @brief Copy list entries and entry list offsets ignoring duplicates * diff --git a/cpp/src/transform/row_bit_count.cu b/cpp/src/transform/row_bit_count.cu index 116a60245d9..292fe46ae57 100644 --- a/cpp/src/transform/row_bit_count.cu +++ b/cpp/src/transform/row_bit_count.cu @@ -32,8 +32,6 @@ namespace detail { namespace { -using offset_type = int32_t; - /** * @brief Struct which contains per-column information necessary to * traverse a column hierarchy on the gpu. @@ -86,7 +84,7 @@ using offset_type = int32_t; * [4, 8] [5, 9] [4, 8] * struct< list<> float> * - * To accomplish this we mantain a stack of spans. Pushing the current span + * To accomplish this we maintain a stack of spans. Pushing the current span * whenever we enter a branch, and popping a span whenever we leave a branch. * * `branch_depth_start` represents the branch depth as we reach a new column. @@ -133,6 +131,14 @@ struct hierarchy_info { * of column_views and produces accompanying column_info and hierarchy_info * metadata. * + * @param begin: Beginning of a range of column views + * @param end: End of a range of column views + * @param out: (output) Flattened vector of output column_views + * @param info: (output) Additional per-output column_view metadata needed by the gpu + * @param h_info: (output) Information about the hierarchy + * @param cur_depth: Current absolute depth in the hierarchy + * @param cur_branch_depth: Current branch depth + * @param parent_index: Index into `out` representing our owning parent column */ template void flatten_hierarchy(ColIter begin, @@ -141,9 +147,9 @@ void flatten_hierarchy(ColIter begin, std::vector& info, hierarchy_info& h_info, rmm::cuda_stream_view stream, - size_type cur_depth = 0, - size_type cur_branch_depth = 0, - int parent_index = -1); + size_type cur_depth = 0, + size_type cur_branch_depth = 0, + thrust::optional parent_index = {}); /** * @brief Type-dispatched functor called by flatten_hierarchy. @@ -161,11 +167,12 @@ struct flatten_functor { rmm::cuda_stream_view stream, size_type cur_depth, size_type cur_branch_depth, - int parent_index) + thrust::optional parent_index) { out.push_back(col); info.push_back({cur_depth, cur_branch_depth, cur_branch_depth}); - h_info.simple_per_row_size += (sizeof(device_storage_type_t) * 8) + (col.nullable() ? 1 : 0); + h_info.simple_per_row_size += + (sizeof(device_storage_type_t) * CHAR_BIT) + (col.nullable() ? 1 : 0); } // strings @@ -177,7 +184,7 @@ struct flatten_functor { rmm::cuda_stream_view stream, size_type cur_depth, size_type cur_branch_depth, - int parent_index) + thrust::optional parent_index) { out.push_back(col); info.push_back({cur_depth, cur_branch_depth, cur_branch_depth}); @@ -193,15 +200,15 @@ struct flatten_functor { rmm::cuda_stream_view stream, size_type cur_depth, size_type cur_branch_depth, - int parent_index) + thrust::optional parent_index) { // track branch depth as we reach this list and after we pass it size_type const branch_depth_start = cur_branch_depth; - if (parent_index >= 0 && out[parent_index].type().id() == type_id::STRUCT) { + auto const is_list_inside_struct = + parent_index && out[parent_index.value()].type().id() == type_id::STRUCT; + if (is_list_inside_struct) { cur_branch_depth++; - if (cur_branch_depth > h_info.max_branch_depth) { - h_info.max_branch_depth = cur_branch_depth; - } + h_info.max_branch_depth = max(h_info.max_branch_depth, cur_branch_depth); } size_type const branch_depth_end = cur_branch_depth; @@ -226,7 +233,7 @@ struct flatten_functor { rmm::cuda_stream_view stream, size_type cur_depth, size_type cur_branch_depth, - int parent_index) + thrust::optional parent_index) { out.push_back(col); info.push_back({cur_depth, cur_branch_depth, cur_branch_depth}); @@ -259,7 +266,7 @@ struct flatten_functor { rmm::cuda_stream_view stream, size_type cur_depth, size_type cur_branch_depth, - int parent_index) + thrust::optional parent_index) { CUDF_FAIL("Unsupported column type in row_bit_count"); } @@ -274,7 +281,7 @@ void flatten_hierarchy(ColIter begin, rmm::cuda_stream_view stream, size_type cur_depth, size_type cur_branch_depth, - int parent_index) + thrust::optional parent_index) { std::for_each(begin, end, [&](column_view const& col) { cudf::type_dispatcher(col.type(), @@ -304,16 +311,30 @@ struct row_span { * */ struct row_size_functor { + /** + * @brief Computes size in bits of a span of rows in a fixed-width column. + * + * Computed as : ((# of rows) * sizeof(data type) * 8) + * + + * 1 bit per row for validity if applicable. + */ template __device__ size_type operator()(column_device_view const& col, row_span const& span) { auto const num_rows{span.row_end - span.row_start}; - auto const element_size = sizeof(device_storage_type_t) * 8; + auto const element_size = sizeof(device_storage_type_t) * CHAR_BIT; auto const validity_size = col.nullable() ? 1 : 0; return (element_size + validity_size) * num_rows; } }; +/** + * @brief Computes size in bits of a span of rows in a strings column. + * + * Computed as : ((# of rows) * sizeof(offset) * 8) + (total # of characters * 8)) + * + + * 1 bit per row for validity if applicable. + */ template <> __device__ size_type row_size_functor::operator()(column_device_view const& col, row_span const& span) @@ -323,13 +344,20 @@ __device__ size_type row_size_functor::operator()(column_device_vie auto const row_start{span.row_start + col.offset()}; auto const row_end{span.row_end + col.offset()}; - auto const offsets_size = sizeof(offset_type) * 8; + auto const offsets_size = sizeof(offset_type) * CHAR_BIT; auto const validity_size = col.nullable() ? 1 : 0; auto const chars_size = - (offsets.data()[row_end] - offsets.data()[row_start]) * 8; + (offsets.data()[row_end] - offsets.data()[row_start]) * CHAR_BIT; return ((offsets_size + validity_size) * num_rows) + chars_size; } +/** + * @brief Computes size in bits of a span of rows in a list column. + * + * Computed as : ((# of rows) * sizeof(offset) * 8) + * + + * 1 bit per row for validity if applicable. + */ template <> __device__ size_type row_size_functor::operator()(column_device_view const& col, row_span const& span) @@ -337,11 +365,16 @@ __device__ size_type row_size_functor::operator()(column_device_view column_device_view const& offsets = col.child(lists_column_view::offsets_column_index); auto const num_rows{span.row_end - span.row_start}; - auto const offsets_size = sizeof(offset_type) * 8; + auto const offsets_size = sizeof(offset_type) * CHAR_BIT; auto const validity_size = col.nullable() ? 1 : 0; return (offsets_size + validity_size) * num_rows; } +/** + * @brief Computes size in bits of a span of rows in a struct column. + * + * Computed as : 1 bit per row for validity if applicable. + */ template <> __device__ size_type row_size_functor::operator()(column_device_view const& col, row_span const& span) @@ -353,48 +386,46 @@ __device__ size_type row_size_functor::operator()(column_device_vie /** * @brief Kernel for computing per-row sizes in bits. * - * @param cols An array of column_device_views represeting a column hierarcy - * @param info An array of column_info structs corresponding the elements in `cols` - * @param num_columns The number of columns - * @param num_rows The number of rows in the root column - * @param output Output buffer of size `num_rows` where per-row bit sizes are stored + * @param cols An span of column_device_views represeting a column hierarcy + * @param info An span of column_info structs corresponding the elements in `cols` + * @param output Output span of size (# rows) where per-row bit sizes are stored * @param max_branch_depth Maximum depth of the span stack needed per-thread - * */ -__global__ void compute_row_sizes(column_device_view* cols, - column_info* info, - size_type num_columns, - size_type num_rows, - size_type* output, +__global__ void compute_row_sizes(device_span cols, + device_span info, + device_span output, size_type max_branch_depth) { - extern __shared__ row_span branch_shared[]; + extern __shared__ row_span thread_branch_stacks[]; int const tid = threadIdx.x + blockIdx.x * blockDim.x; + auto const num_rows = output.size(); if (tid >= num_rows) { return; } // branch stack. points to the last list prior to branching. - row_span* branch = branch_shared + (tid * max_branch_depth); + row_span* my_branch_stack = thread_branch_stacks + (tid * max_branch_depth); size_type branch_depth{0}; // current row span - always starts at 1 row. row_span cur_span{tid, tid + 1}; // output size - size_type& size = *(output + tid); + size_type& size = output[tid]; size = 0; size_type last_branch_depth{0}; - for (size_type idx = 0; idx < num_columns; idx++) { + for (size_type idx = 0; idx < cols.size(); idx++) { column_device_view const& col = cols[idx]; // if we've returned from a branch - if (info[idx].branch_depth_start < last_branch_depth) { cur_span = branch[--branch_depth]; } + if (info[idx].branch_depth_start < last_branch_depth) { + cur_span = my_branch_stack[--branch_depth]; + } // if we're entering a new branch. // NOTE: this case can happen (a pop and a push by the same column) // when we have a struct if (info[idx].branch_depth_end > info[idx].branch_depth_start) { - branch[branch_depth++] = cur_span; + my_branch_stack[branch_depth++] = cur_span; } // if we're back at depth 0, this is a new top-level column, so reset @@ -436,6 +467,7 @@ std::unique_ptr row_bit_count(table_view const& t, std::vector info; hierarchy_info h_info; flatten_hierarchy(t.begin(), t.end(), cols, info, h_info, stream); + CUDF_EXPECTS(info.size() == cols.size(), "Size/info mismatch"); // create output buffer and view auto output = cudf::make_fixed_width_column( @@ -462,7 +494,6 @@ std::unique_ptr row_bit_count(table_view const& t, sizeof(column_info) * info.size(), cudaMemcpyHostToDevice, stream.value())); - CUDF_EXPECTS(info.size() == cols.size(), "Size/info mismatch"); // each thread needs to maintain a stack of row spans of size max_branch_depth. we will use // shared memory to do this rather than allocating a potentially gigantic temporary buffer @@ -484,11 +515,9 @@ std::unique_ptr row_bit_count(table_view const& t, cudf::detail::grid_1d grid{t.num_rows(), block_size, 1}; compute_row_sizes<<>>( - std::get<1>(d_cols), - d_info.data(), - info.size(), - t.num_rows(), - mcv.data(), + {std::get<1>(d_cols), cols.size()}, + {d_info.data(), info.size()}, + {mcv.data(), static_cast(t.num_rows())}, h_info.max_branch_depth); return output; diff --git a/cpp/tests/transform/row_bit_count_test.cu b/cpp/tests/transform/row_bit_count_test.cu index 7b871053ec4..a99ae0ddf9a 100644 --- a/cpp/tests/transform/row_bit_count_test.cu +++ b/cpp/tests/transform/row_bit_count_test.cu @@ -28,8 +28,6 @@ using namespace cudf; -using offset_type = int32_t; - template struct RowBitCountTyped : public cudf::test::BaseFixture { }; @@ -40,51 +38,50 @@ TYPED_TEST(RowBitCountTyped, SimpleTypes) { using T = TypeParam; - // no nulls - { - auto col = cudf::make_fixed_width_column(data_type{type_to_id()}, 16); + auto col = cudf::make_fixed_width_column(data_type{type_to_id()}, 16); - table_view t({*col}); - auto result = cudf::row_bit_count(t); + table_view t({*col}); + auto result = cudf::row_bit_count(t); - // expect size of the type per row - auto expected = make_fixed_width_column(data_type{type_id::INT32}, 16); - cudf::mutable_column_view mcv(*expected); - thrust::fill(rmm::exec_policy(0), - mcv.begin(), - mcv.end(), - sizeof(device_storage_type_t) * 8); + // expect size of the type per row + auto expected = make_fixed_width_column(data_type{type_id::INT32}, 16); + cudf::mutable_column_view mcv(*expected); + thrust::fill(rmm::exec_policy(0), + mcv.begin(), + mcv.end(), + sizeof(device_storage_type_t) * CHAR_BIT); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); +} - // nulls - { - auto iter = thrust::make_counting_iterator(0); - auto valids = cudf::detail::make_counting_transform_iterator( - 0, [](int i) { return i % 2 == 0 ? true : false; }); - cudf::test::fixed_width_column_wrapper col(iter, iter + 16, valids); +TYPED_TEST(RowBitCountTyped, SimpleTypesWithNulls) +{ + using T = TypeParam; - table_view t({col}); - auto result = cudf::row_bit_count(t); + auto iter = thrust::make_counting_iterator(0); + auto valids = cudf::detail::make_counting_transform_iterator( + 0, [](int i) { return i % 2 == 0 ? true : false; }); + cudf::test::fixed_width_column_wrapper col(iter, iter + 16, valids); - // expect size of the type + 1 bit per row - auto expected = make_fixed_width_column(data_type{type_id::INT32}, 16); - cudf::mutable_column_view mcv(*expected); - thrust::fill(rmm::exec_policy(0), - mcv.begin(), - mcv.end(), - (sizeof(device_storage_type_t) * 8) + 1); + table_view t({col}); + auto result = cudf::row_bit_count(t); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); - } + // expect size of the type + 1 bit per row + auto expected = make_fixed_width_column(data_type{type_id::INT32}, 16); + cudf::mutable_column_view mcv(*expected); + thrust::fill(rmm::exec_policy(0), + mcv.begin(), + mcv.end(), + (sizeof(device_storage_type_t) * CHAR_BIT) + 1); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); } template std::pair, std::unique_ptr> build_list_column() { using LCW = cudf::test::lists_column_wrapper; - constexpr size_type type_size = sizeof(device_storage_type_t) * 8; + constexpr size_type type_size = sizeof(device_storage_type_t) * CHAR_BIT; // clang-format off cudf::test::lists_column_wrapper col{ {{1, 2}, {3, 4, 5}}, @@ -92,64 +89,64 @@ std::pair, std::unique_ptr> build_list_column() {LCW{10}}, {{6, 7, 8}, {9}}, {{-1, -2}, {-3, -4}}, - {{-5, -6, -7}, {-8, -9}} }; + {{-5, -6, -7}, {-8, -9}} }; + // clang-format on // expected size = (num rows at level 1 + num_rows at level 2) + # values in the leaf - cudf::test::fixed_width_column_wrapper expected{((4 + 8) * 8) + (type_size * 5), - ((4 + 0) * 8) + (type_size * 0), - ((4 + 4) * 8) + (type_size * 1), - ((4 + 8) * 8) + (type_size * 4), - ((4 + 8) * 8) + (type_size * 4), - ((4 + 8) * 8) + (type_size * 5)}; + cudf::test::fixed_width_column_wrapper expected{ + ((4 + 8) * CHAR_BIT) + (type_size * 5), + ((4 + 0) * CHAR_BIT) + (type_size * 0), + ((4 + 4) * CHAR_BIT) + (type_size * 1), + ((4 + 8) * CHAR_BIT) + (type_size * 4), + ((4 + 8) * CHAR_BIT) + (type_size * 4), + ((4 + 8) * CHAR_BIT) + (type_size * 5)}; return {col.release(), expected.release()}; } TYPED_TEST(RowBitCountTyped, Lists) -{ - using T = TypeParam; +{ + using T = TypeParam; - // no nulls - { - std::unique_ptr col; - std::unique_ptr expected_sizes; - std::tie(col, expected_sizes) = build_list_column(); + std::unique_ptr col; + std::unique_ptr expected_sizes; + std::tie(col, expected_sizes) = build_list_column(); - // clang-format on - table_view t({*col}); - auto result = cudf::row_bit_count(t); + table_view t({*col}); + auto result = cudf::row_bit_count(t); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_sizes, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_sizes, *result); +} - // nulls - { - using LCW = cudf::test::lists_column_wrapper; - constexpr size_type type_size = sizeof(device_storage_type_t) * 8; +TYPED_TEST(RowBitCountTyped, ListsWithNulls) +{ + using T = TypeParam; + using LCW = cudf::test::lists_column_wrapper; + constexpr size_type type_size = sizeof(device_storage_type_t) * CHAR_BIT; - std::vector valids{true, false, true}; - std::vector valids2{false, true, false}; - std::vector valids3{true, false}; + std::vector valids{true, false, true}; + std::vector valids2{false, true, false}; + std::vector valids3{true, false}; - // clang-format off - cudf::test::lists_column_wrapper col{ {{1, 2}, {{3, 4, 5}, valids.begin()}}, - LCW{LCW{}}, - {LCW{10}}, - {{{{6, 7, 8}, valids2.begin()}, {9}}, valids3.begin()} }; - // clang-format on + // clang-format off + cudf::test::lists_column_wrapper col{ {{1, 2}, {{3, 4, 5}, valids.begin()}}, + LCW{LCW{}}, + {LCW{10}}, + {{{{6, 7, 8}, valids2.begin()}, {9}}, valids3.begin()} }; + // clang-format on - table_view t({col}); - auto result = cudf::row_bit_count(t); + table_view t({col}); + auto result = cudf::row_bit_count(t); - // expected size = (num rows at level 1 + num_rows at level 2) + # values in the leaf + validity - // where applicable - cudf::test::fixed_width_column_wrapper expected{((4 + 8) * 8) + (type_size * 5) + 7, - ((4 + 0) * 8) + (type_size * 0), - ((4 + 4) * 8) + (type_size * 1) + 2, - ((4 + 8) * 8) + (type_size * 3) + 5}; + // expected size = (num rows at level 1 + num_rows at level 2) + # values in the leaf + validity + // where applicable + cudf::test::fixed_width_column_wrapper expected{ + ((4 + 8) * CHAR_BIT) + (type_size * 5) + 7, + ((4 + 0) * CHAR_BIT) + (type_size * 0), + ((4 + 4) * CHAR_BIT) + (type_size * 1) + 2, + ((4 + 8) * CHAR_BIT) + (type_size * 3) + 5}; - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); } struct RowBitCount : public cudf::test::BaseFixture { @@ -157,48 +154,43 @@ struct RowBitCount : public cudf::test::BaseFixture { TEST_F(RowBitCount, Strings) { - // no nulls - { - std::vector strings{"abc", "def", "", "z", "bananas", "warp", "", "zing"}; + std::vector strings{"abc", "def", "", "z", "bananas", "warp", "", "zing"}; - cudf::test::strings_column_wrapper col(strings.begin(), strings.end()); + cudf::test::strings_column_wrapper col(strings.begin(), strings.end()); - table_view t({col}); - auto result = cudf::row_bit_count(t); + table_view t({col}); + auto result = cudf::row_bit_count(t); - // expect 1 offset (4 bytes) + length of string per row - auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { - return (static_cast(strings[i].size()) + sizeof(offset_type)) * 8; - }); - cudf::test::fixed_width_column_wrapper expected(size_iter, - size_iter + strings.size()); + // expect 1 offset (4 bytes) + length of string per row + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { + return (static_cast(strings[i].size()) + sizeof(offset_type)) * CHAR_BIT; + }); + cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + strings.size()); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); +} - // nulls - { - // clang-format off - std::vector strings { "abc", "def", "", "z", "bananas", "warp", "", "zing" }; - std::vector valids { 1, 0, 0, 1, 0, 1, 1, 1 }; - // clang-format on +TEST_F(RowBitCount, StringsWithNulls) +{ + // clang-format off + std::vector strings { "abc", "def", "", "z", "bananas", "warp", "", "zing" }; + std::vector valids { 1, 0, 0, 1, 0, 1, 1, 1 }; + // clang-format on - cudf::test::strings_column_wrapper col(strings.begin(), strings.end(), valids.begin()); + cudf::test::strings_column_wrapper col(strings.begin(), strings.end(), valids.begin()); - table_view t({col}); - auto result = cudf::row_bit_count(t); + table_view t({col}); + auto result = cudf::row_bit_count(t); - // expect 1 offset (4 bytes) + (length of string, or 0 if null) + 1 validity bit per row - auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings, &valids](int i) { - return ((static_cast(valids[i] ? strings[i].size() : 0) + sizeof(offset_type)) * - 8) + - 1; - }); - cudf::test::fixed_width_column_wrapper expected(size_iter, - size_iter + strings.size()); + // expect 1 offset (4 bytes) + (length of string, or 0 if null) + 1 validity bit per row + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings, &valids](int i) { + return ((static_cast(valids[i] ? strings[i].size() : 0) + sizeof(offset_type)) * + CHAR_BIT) + + 1; + }); + cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + strings.size()); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); } std::pair, std::unique_ptr> build_struct_column() @@ -219,8 +211,9 @@ std::pair, std::unique_ptr> build_struct_column( // (1 validity bit) auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings, &struct_validity](int i) { - return (sizeof(float) * 8) + 1 + (sizeof(int16_t) * 8) + 1 + - (static_cast(strings[i].size()) * 8) + (sizeof(offset_type) * 8) + 1 + 1; + return (sizeof(float) * CHAR_BIT) + 1 + (sizeof(int16_t) * CHAR_BIT) + 1 + + (static_cast(strings[i].size()) * CHAR_BIT) + + (sizeof(offset_type) * CHAR_BIT) + 1 + 1; }); cudf::test::fixed_width_column_wrapper expected_sizes(size_iter, size_iter + strings.size()); @@ -245,8 +238,8 @@ TEST_F(RowBitCount, Structs) // expect 1 offset (4 bytes) + (length of string) + 1 float + 1 int16_t auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { - return ((sizeof(float) + sizeof(int16_t)) * 8) + - ((static_cast(strings[i].size()) + sizeof(offset_type)) * 8); + return ((sizeof(float) + sizeof(int16_t)) * CHAR_BIT) + + ((static_cast(strings[i].size()) + sizeof(offset_type)) * CHAR_BIT); }); cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + t.num_rows()); @@ -279,7 +272,7 @@ TEST_F(RowBitCount, Structs) // expect num_rows * (4 + 2) bytes auto size_iter = - cudf::detail::make_counting_transform_iterator(0, [&](int i) { return (4 + 2) * 8; }); + cudf::detail::make_counting_transform_iterator(0, [&](int i) { return (4 + 2) * CHAR_BIT; }); cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + t.num_rows()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); } @@ -536,7 +529,7 @@ TEST_F(RowBitCount, SlicedColumns) // expect 1 offset (4 bytes) + length of string per row auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { - return (static_cast(strings[i].size()) + sizeof(offset_type)) * 8; + return (static_cast(strings[i].size()) + sizeof(offset_type)) * CHAR_BIT; }); cudf::test::fixed_width_column_wrapper expected(size_iter + 3, size_iter + 3 + slice_size); @@ -580,11 +573,11 @@ TEST_F(RowBitCount, SlicedColumns) // expect 1 offset (4 bytes) + length of string per row + 1 int16_t per row auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { return (static_cast(strings[i].size()) + sizeof(offset_type) + sizeof(int16_t)) * - 8; + CHAR_BIT; }); cudf::test::fixed_width_column_wrapper expected(size_iter + 3, size_iter + 3 + slice_size); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); } -} \ No newline at end of file +} From 3effc9f78cd85d39be4174d0f678dc3ec4d004ad Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Fri, 19 Mar 2021 10:06:34 -0500 Subject: [PATCH 06/12] Tweak Cmakelists.txt --- cpp/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 60664564453..74c028b9b6b 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -383,9 +383,9 @@ add_library(cudf src/transform/jit/code/kernel.cpp src/transform/mask_to_bools.cu src/transform/nans_to_nulls.cu - src/transform/transform.cpp src/transform/row_bit_count.cu - src/transpose/transpose.cu + src/transform/transform.cpp + src/transpose/transpose.cu src/unary/cast_ops.cu src/unary/math_ops.cu src/unary/nan_ops.cu From 9709d605f5ffb38fdbdccd0c3c557ae41ed10472 Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Fri, 19 Mar 2021 10:35:24 -0500 Subject: [PATCH 07/12] Split up tests a little more. Doc tweaks --- cpp/src/transform/row_bit_count.cu | 4 +- cpp/tests/transform/row_bit_count_test.cu | 203 +++++++++++----------- 2 files changed, 100 insertions(+), 107 deletions(-) diff --git a/cpp/src/transform/row_bit_count.cu b/cpp/src/transform/row_bit_count.cu index 292fe46ae57..fc08417c9af 100644 --- a/cpp/src/transform/row_bit_count.cu +++ b/cpp/src/transform/row_bit_count.cu @@ -115,9 +115,9 @@ struct column_info { struct hierarchy_info { hierarchy_info() : simple_per_row_size(0), complex_type_count(0), max_branch_depth(0) {} - // these two fields act as an optimization. if we find that the entire table + // These two fields act as an optimization. If we find that the entire table // is just fixed-width types, we do not need to do the more expensive kernel call that - // traverses the individual columns. so if complex_type_count is 0, we can just + // traverses the individual columns. So if complex_type_count is 0, we can just // return a column where every row contains the value simple_per_row_size size_type simple_per_row_size; // in bits size_type complex_type_count; diff --git a/cpp/tests/transform/row_bit_count_test.cu b/cpp/tests/transform/row_bit_count_test.cu index a99ae0ddf9a..c82022de0b2 100644 --- a/cpp/tests/transform/row_bit_count_test.cu +++ b/cpp/tests/transform/row_bit_count_test.cu @@ -221,61 +221,57 @@ std::pair, std::unique_ptr> build_struct_column( return {struct_col.release(), expected_sizes.release()}; } -TEST_F(RowBitCount, Structs) +TEST_F(RowBitCount, StructsNoNulls) { - // no nulls - { - std::vector strings{"abc", "def", "", "z", "bananas", "warp"}; + std::vector strings{"abc", "def", "", "z", "bananas", "warp"}; - cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; - cudf::test::fixed_width_column_wrapper col1{8, 9, 10, 11, 12, 13}; - cudf::test::strings_column_wrapper col2(strings.begin(), strings.end()); + cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; + cudf::test::fixed_width_column_wrapper col1{8, 9, 10, 11, 12, 13}; + cudf::test::strings_column_wrapper col2(strings.begin(), strings.end()); - cudf::test::structs_column_wrapper struct_col({col0, col1, col2}); + cudf::test::structs_column_wrapper struct_col({col0, col1, col2}); - table_view t({struct_col}); - auto result = cudf::row_bit_count(t); + table_view t({struct_col}); + auto result = cudf::row_bit_count(t); - // expect 1 offset (4 bytes) + (length of string) + 1 float + 1 int16_t - auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { - return ((sizeof(float) + sizeof(int16_t)) * CHAR_BIT) + - ((static_cast(strings[i].size()) + sizeof(offset_type)) * CHAR_BIT); - }); - cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + t.num_rows()); + // expect 1 offset (4 bytes) + (length of string) + 1 float + 1 int16_t + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { + return ((sizeof(float) + sizeof(int16_t)) * CHAR_BIT) + + ((static_cast(strings[i].size()) + sizeof(offset_type)) * CHAR_BIT); + }); + cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + t.num_rows()); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); +} - // nulls - { - // creating a struct column with validity will cause all child columns to be promoted to have - // validity - std::unique_ptr struct_col; - std::unique_ptr expected_sizes; - std::tie(struct_col, expected_sizes) = build_struct_column(); - table_view t({*struct_col}); - auto result = cudf::row_bit_count(t); +TEST_F(RowBitCount, StructsNulls) +{ + std::unique_ptr struct_col; + std::unique_ptr expected_sizes; + std::tie(struct_col, expected_sizes) = build_struct_column(); + table_view t({*struct_col}); + auto result = cudf::row_bit_count(t); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_sizes, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_sizes, *result); +} +TEST_F(RowBitCount, StructsNested) +{ // struct, int16> - { - cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; - cudf::test::structs_column_wrapper inner_struct({col0}); + cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; + cudf::test::structs_column_wrapper inner_struct({col0}); - cudf::test::fixed_width_column_wrapper col1{8, 9, 10, 11, 12, 13}; - cudf::test::structs_column_wrapper struct_col({inner_struct, col1}); + cudf::test::fixed_width_column_wrapper col1{8, 9, 10, 11, 12, 13}; + cudf::test::structs_column_wrapper struct_col({inner_struct, col1}); - table_view t({struct_col}); - auto result = cudf::row_bit_count(t); + table_view t({struct_col}); + auto result = cudf::row_bit_count(t); - // expect num_rows * (4 + 2) bytes - auto size_iter = - cudf::detail::make_counting_transform_iterator(0, [&](int i) { return (4 + 2) * CHAR_BIT; }); - cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + t.num_rows()); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + // expect num_rows * (4 + 2) bytes + auto size_iter = + cudf::detail::make_counting_transform_iterator(0, [&](int i) { return (4 + 2) * CHAR_BIT; }); + cudf::test::fixed_width_column_wrapper expected(size_iter, size_iter + t.num_rows()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); } std::pair, std::unique_ptr> build_nested_column( @@ -500,84 +496,81 @@ TEST_F(RowBitCount, Table) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected, *result); } -TEST_F(RowBitCount, SlicedColumns) +TEST_F(RowBitCount, SlicedColumnsFixedWidth) { - // fixed width - { - auto const slice_size = 7; - cudf::test::fixed_width_column_wrapper c0_unsliced{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - auto c0 = cudf::slice(c0_unsliced, {2, 2 + slice_size}); + auto const slice_size = 7; + cudf::test::fixed_width_column_wrapper c0_unsliced{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + auto c0 = cudf::slice(c0_unsliced, {2, 2 + slice_size}); - table_view t({c0}); - auto result = cudf::row_bit_count(t); + table_view t({c0}); + auto result = cudf::row_bit_count(t); - cudf::test::fixed_width_column_wrapper expected{16, 16, 16, 16, 16, 16, 16}; + cudf::test::fixed_width_column_wrapper expected{16, 16, 16, 16, 16, 16, 16}; - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); +} - // strings - { - auto const slice_size = 7; - std::vector strings{ - "banana", "metric", "imperial", "abc", "pears", "", "fire", "def", "cudf", "xyzw"}; - cudf::test::strings_column_wrapper c0_unsliced(strings.begin(), strings.end()); - auto c0 = cudf::slice(c0_unsliced, {3, 3 + slice_size}); +TEST_F(RowBitCount, SlicedColumnsStrings) +{ + auto const slice_size = 7; + std::vector strings{ + "banana", "metric", "imperial", "abc", "pears", "", "fire", "def", "cudf", "xyzw"}; + cudf::test::strings_column_wrapper c0_unsliced(strings.begin(), strings.end()); + auto c0 = cudf::slice(c0_unsliced, {3, 3 + slice_size}); - table_view t({c0}); - auto result = cudf::row_bit_count(t); + table_view t({c0}); + auto result = cudf::row_bit_count(t); - // expect 1 offset (4 bytes) + length of string per row - auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { - return (static_cast(strings[i].size()) + sizeof(offset_type)) * CHAR_BIT; - }); - cudf::test::fixed_width_column_wrapper expected(size_iter + 3, - size_iter + 3 + slice_size); + // expect 1 offset (4 bytes) + length of string per row + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { + return (static_cast(strings[i].size()) + sizeof(offset_type)) * CHAR_BIT; + }); + cudf::test::fixed_width_column_wrapper expected(size_iter + 3, + size_iter + 3 + slice_size); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); +} - // lists - { - auto const slice_size = 2; - cudf::test::lists_column_wrapper c0_unsliced{ - {{"banana", "v"}, {"cats"}}, - {{"dogs", "yay"}, {"xyz", ""}, {"ultra"}}, - {{"fast", "parrot"}, {"orange"}}, - {{"blue"}, {"red", "yellow"}, {"ultraviolet", "", "green"}}}; - auto c0 = cudf::slice(c0_unsliced, {1, 1 + slice_size}); - - table_view t({c0}); - auto result = cudf::row_bit_count(t); +TEST_F(RowBitCount, SlicedColumnsLists) +{ + auto const slice_size = 2; + cudf::test::lists_column_wrapper c0_unsliced{ + {{"banana", "v"}, {"cats"}}, + {{"dogs", "yay"}, {"xyz", ""}, {"ultra"}}, + {{"fast", "parrot"}, {"orange"}}, + {{"blue"}, {"red", "yellow"}, {"ultraviolet", "", "green"}}}; + auto c0 = cudf::slice(c0_unsliced, {1, 1 + slice_size}); + + table_view t({c0}); + auto result = cudf::row_bit_count(t); - cudf::test::fixed_width_column_wrapper expected{408, 320}; + cudf::test::fixed_width_column_wrapper expected{408, 320}; - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); +} - // structs - { - auto const slice_size = 7; +TEST_F(RowBitCount, SlicedColumnsStructs) +{ + auto const slice_size = 7; - cudf::test::fixed_width_column_wrapper c0{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - std::vector strings{ - "banana", "metric", "imperial", "abc", "pears", "", "fire", "def", "cudf", "xyzw"}; - cudf::test::strings_column_wrapper c1(strings.begin(), strings.end()); + cudf::test::fixed_width_column_wrapper c0{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + std::vector strings{ + "banana", "metric", "imperial", "abc", "pears", "", "fire", "def", "cudf", "xyzw"}; + cudf::test::strings_column_wrapper c1(strings.begin(), strings.end()); - auto struct_col_unsliced = cudf::test::structs_column_wrapper({c0, c1}); - auto struct_col = cudf::slice(struct_col_unsliced, {3, 3 + slice_size}); + auto struct_col_unsliced = cudf::test::structs_column_wrapper({c0, c1}); + auto struct_col = cudf::slice(struct_col_unsliced, {3, 3 + slice_size}); - table_view t({struct_col}); - auto result = cudf::row_bit_count(t); + table_view t({struct_col}); + auto result = cudf::row_bit_count(t); - // expect 1 offset (4 bytes) + length of string per row + 1 int16_t per row - auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { - return (static_cast(strings[i].size()) + sizeof(offset_type) + sizeof(int16_t)) * - CHAR_BIT; - }); - cudf::test::fixed_width_column_wrapper expected(size_iter + 3, - size_iter + 3 + slice_size); + // expect 1 offset (4 bytes) + length of string per row + 1 int16_t per row + auto size_iter = cudf::detail::make_counting_transform_iterator(0, [&strings](int i) { + return (static_cast(strings[i].size()) + sizeof(offset_type) + sizeof(int16_t)) * + CHAR_BIT; + }); + cudf::test::fixed_width_column_wrapper expected(size_iter + 3, + size_iter + 3 + slice_size); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); - } + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); } From 2985b4de7863d55205331e107a08e2b0bdf4264b Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Tue, 23 Mar 2021 14:17:16 -0500 Subject: [PATCH 08/12] Cleaned up some test function names. Sprinkled utf8 characters around various tests. --- cpp/tests/transform/row_bit_count_test.cu | 31 ++++++++++++----------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/cpp/tests/transform/row_bit_count_test.cu b/cpp/tests/transform/row_bit_count_test.cu index c82022de0b2..c2a56e980cf 100644 --- a/cpp/tests/transform/row_bit_count_test.cu +++ b/cpp/tests/transform/row_bit_count_test.cu @@ -154,7 +154,7 @@ struct RowBitCount : public cudf::test::BaseFixture { TEST_F(RowBitCount, Strings) { - std::vector strings{"abc", "def", "", "z", "bananas", "warp", "", "zing"}; + std::vector strings{"abc", "ï", "", "z", "bananas", "warp", "", "zing"}; cudf::test::strings_column_wrapper col(strings.begin(), strings.end()); @@ -173,8 +173,8 @@ TEST_F(RowBitCount, Strings) TEST_F(RowBitCount, StringsWithNulls) { // clang-format off - std::vector strings { "abc", "def", "", "z", "bananas", "warp", "", "zing" }; - std::vector valids { 1, 0, 0, 1, 0, 1, 1, 1 }; + std::vector strings { "daïs", "def", "", "z", "bananas", "warp", "", "zing" }; + std::vector valids { 1, 0, 0, 1, 0, 1, 1, 1 }; // clang-format on cudf::test::strings_column_wrapper col(strings.begin(), strings.end(), valids.begin()); @@ -196,7 +196,7 @@ TEST_F(RowBitCount, StringsWithNulls) std::pair, std::unique_ptr> build_struct_column() { std::vector struct_validity{0, 1, 1, 1, 1, 0}; - std::vector strings{"abc", "def", "", "z", "bananas", "warp"}; + std::vector strings{"abc", "def", "", "z", "bananas", "daïs"}; cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; cudf::test::fixed_width_column_wrapper col1{{8, 9, 10, 11, 12, 13}, {1, 0, 1, 1, 1, 1}}; @@ -223,7 +223,7 @@ std::pair, std::unique_ptr> build_struct_column( TEST_F(RowBitCount, StructsNoNulls) { - std::vector strings{"abc", "def", "", "z", "bananas", "warp"}; + std::vector strings{"abc", "daïs", "", "z", "bananas", "warp"}; cudf::test::fixed_width_column_wrapper col0{0, 1, 2, 3, 4, 5}; cudf::test::fixed_width_column_wrapper col1{8, 9, 10, 11, 12, 13}; @@ -274,7 +274,7 @@ TEST_F(RowBitCount, StructsNested) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); } -std::pair, std::unique_ptr> build_nested_column( +std::pair, std::unique_ptr> build_nested_and_expected_column( std::vector const& struct_validity) { // tests the "branching" case -> list ...>>> @@ -324,7 +324,7 @@ std::pair, std::unique_ptr> build_nested_column( expected_sizes.release()}; } -std::unique_ptr build_nested_column2(std::vector const& struct_validity) +std::unique_ptr build_nested_column(std::vector const& struct_validity) { // List>, Struct>> @@ -362,11 +362,12 @@ TEST_F(RowBitCount, NestedTypes) { std::unique_ptr col_no_nulls; std::unique_ptr expected_sizes; - std::tie(col_no_nulls, expected_sizes) = build_nested_column({1, 1, 1, 1, 1, 1, 1, 1}); + std::tie(col_no_nulls, expected_sizes) = + build_nested_and_expected_column({1, 1, 1, 1, 1, 1, 1, 1}); table_view no_nulls_t({*col_no_nulls}); auto no_nulls_result = cudf::row_bit_count(no_nulls_t); - auto col_nulls = build_nested_column({0, 0, 1, 1, 1, 1, 1, 1}).first; + auto col_nulls = build_nested_and_expected_column({0, 0, 1, 1, 1, 1, 1, 1}).first; table_view nulls_t({*col_nulls}); auto nulls_result = cudf::row_bit_count(nulls_t); @@ -392,11 +393,11 @@ TEST_F(RowBitCount, NestedTypes) // List>, Struct>> { - auto col_no_nulls = build_nested_column2({1, 1, 1}); + auto col_no_nulls = build_nested_column({1, 1, 1}); table_view no_nulls_t({*col_no_nulls}); auto no_nulls_result = cudf::row_bit_count(no_nulls_t); - auto col_nulls = build_nested_column2({1, 0, 1}); + auto col_nulls = build_nested_column({1, 0, 1}); table_view nulls_t({*col_nulls}); auto nulls_result = cudf::row_bit_count(nulls_t); @@ -466,7 +467,7 @@ TEST_F(RowBitCount, Table) // complex nested column std::unique_ptr col0; std::unique_ptr col0_sizes; - std::tie(col0, col0_sizes) = build_nested_column({1, 1, 1, 1, 1, 1, 1, 1}); + std::tie(col0, col0_sizes) = build_nested_and_expected_column({1, 1, 1, 1, 1, 1, 1, 1}); // struct column std::unique_ptr col1; @@ -514,7 +515,7 @@ TEST_F(RowBitCount, SlicedColumnsStrings) { auto const slice_size = 7; std::vector strings{ - "banana", "metric", "imperial", "abc", "pears", "", "fire", "def", "cudf", "xyzw"}; + "banana", "metric", "imperial", "abc", "daïs", "", "fire", "def", "cudf", "xyzw"}; cudf::test::strings_column_wrapper c0_unsliced(strings.begin(), strings.end()); auto c0 = cudf::slice(c0_unsliced, {3, 3 + slice_size}); @@ -536,7 +537,7 @@ TEST_F(RowBitCount, SlicedColumnsLists) auto const slice_size = 2; cudf::test::lists_column_wrapper c0_unsliced{ {{"banana", "v"}, {"cats"}}, - {{"dogs", "yay"}, {"xyz", ""}, {"ultra"}}, + {{"dogs", "yay"}, {"xyz", ""}, {"daïs"}}, {{"fast", "parrot"}, {"orange"}}, {{"blue"}, {"red", "yellow"}, {"ultraviolet", "", "green"}}}; auto c0 = cudf::slice(c0_unsliced, {1, 1 + slice_size}); @@ -555,7 +556,7 @@ TEST_F(RowBitCount, SlicedColumnsStructs) cudf::test::fixed_width_column_wrapper c0{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; std::vector strings{ - "banana", "metric", "imperial", "abc", "pears", "", "fire", "def", "cudf", "xyzw"}; + "banana", "metric", "imperial", "abc", "daïs", "", "fire", "def", "cudf", "xyzw"}; cudf::test::strings_column_wrapper c1(strings.begin(), strings.end()); auto struct_col_unsliced = cudf::test::structs_column_wrapper({c0, c1}); From 7c7d6746458c546d5952ef94d91a0ed7a2575558 Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Wed, 24 Mar 2021 10:51:46 -0500 Subject: [PATCH 09/12] Add missing header for thrust::optional --- cpp/src/transform/row_bit_count.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/transform/row_bit_count.cu b/cpp/src/transform/row_bit_count.cu index fc08417c9af..5d8f1a17504 100644 --- a/cpp/src/transform/row_bit_count.cu +++ b/cpp/src/transform/row_bit_count.cu @@ -23,6 +23,8 @@ #include #include +#include + #include #include #include From 764d3eb0e43bc6a3009ddd2a56ea60d17c6982ea Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Thu, 25 Mar 2021 10:11:09 -0500 Subject: [PATCH 10/12] Doc changes/improvements from PR review. --- cpp/include/cudf/transform.hpp | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/cpp/include/cudf/transform.hpp b/cpp/include/cudf/transform.hpp index 4c6804b3d36..e99e0db21c5 100644 --- a/cpp/include/cudf/transform.hpp +++ b/cpp/include/cudf/transform.hpp @@ -152,20 +152,14 @@ std::unique_ptr mask_to_bools( * Each row in the returned column is the sum of the per-row size for each column in * the table. * - * In some cases, this is an inexact approximation. Specifically, with - * lists or strings, the cost of a row includes 32 bits for a single offset. However, two - * offsets is required to represent an entire row. But this presents a problem, because to - * represent 2 rows, you need 3 offsets. 3 rows 4 offsets, etc. Therefore it would not - * be accurate to say each row of a string column costs 2 offsets because summing multiple row - * sizes would give you a number too large. It is up to the caller to understand the schema - * of the input column to be able to calculate the small additional overhead of the - * terminating offset for any group of rows being considered. - * - * This function returns the per-row sizes as the columns are currently formed. This can - * end up being different than the number you would get by gathering the rows under - * certain circumstances. Specifically, the pushdown of validity masks by struct - * columns can nullify rows that actually contain underlying data for string or list - * columns. In these cases, the sized returned will be strictly: + * In some cases, this is an inexact approximation. Specifically, columns of lists and strings + * require N+1 offsets to represent N rows. It is up to the caller to calculate the small + * additional overhead of the terminating offset for any group of rows being considered. + * + * This function returns the per-row sizes as the columns are currently formed. This can + * end up being larger than the number you would get by gathering the rows. Specifically, + * the push-down of struct column validity masks can nullify rows that contain data for + * string or list columns. In these cases, the size returned is conservative: * * row_bit_count(column(x)) >= row_bit_count(gather(column(x))) * From a04f67af2edee48385860e1a9abe0a81b9a0896b Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Mon, 29 Mar 2021 10:48:40 -0500 Subject: [PATCH 11/12] Handle empty table edge case. --- cpp/src/transform/row_bit_count.cu | 6 ++++++ cpp/tests/transform/row_bit_count_test.cu | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/cpp/src/transform/row_bit_count.cu b/cpp/src/transform/row_bit_count.cu index 5d8f1a17504..88a249f6eff 100644 --- a/cpp/src/transform/row_bit_count.cu +++ b/cpp/src/transform/row_bit_count.cu @@ -464,6 +464,12 @@ std::unique_ptr row_bit_count(table_view const& t, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { + // no rows + if (t.num_rows() <= 0) { + return cudf::make_fixed_width_column( + data_type{type_id::INT32}, 0, mask_state::UNALLOCATED, stream, mr); + } + // flatten the hierarchy and determine some information about it. std::vector cols; std::vector info; diff --git a/cpp/tests/transform/row_bit_count_test.cu b/cpp/tests/transform/row_bit_count_test.cu index c2a56e980cf..c0288d9b73e 100644 --- a/cpp/tests/transform/row_bit_count_test.cu +++ b/cpp/tests/transform/row_bit_count_test.cu @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -575,3 +576,21 @@ TEST_F(RowBitCount, SlicedColumnsStructs) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *result); } + +TEST_F(RowBitCount, EmptyTable) +{ + { + cudf::table_view empty; + auto result = cudf::row_bit_count(empty); + CUDF_EXPECTS(result != nullptr && result->size() == 0, "Expected an empty column"); + } + + { + auto strings = cudf::strings::detail::make_empty_strings_column(0); + auto ints = cudf::make_fixed_width_column(data_type{type_id::INT32}, 0); + cudf::table_view empty({*strings, *ints}); + + auto result = cudf::row_bit_count(empty); + CUDF_EXPECTS(result != nullptr && result->size() == 0, "Expected an empty column"); + } +} \ No newline at end of file From b4ce7d1078711e15664fffd09ef47e169486a5a0 Mon Sep 17 00:00:00 2001 From: Dave Baranec Date: Mon, 29 Mar 2021 11:23:32 -0500 Subject: [PATCH 12/12] Use make_empty_column() instead of make_fixed_width_column(). --- cpp/src/transform/row_bit_count.cu | 5 +---- cpp/tests/transform/row_bit_count_test.cu | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/cpp/src/transform/row_bit_count.cu b/cpp/src/transform/row_bit_count.cu index 88a249f6eff..e36fa36596f 100644 --- a/cpp/src/transform/row_bit_count.cu +++ b/cpp/src/transform/row_bit_count.cu @@ -465,10 +465,7 @@ std::unique_ptr row_bit_count(table_view const& t, rmm::mr::device_memory_resource* mr) { // no rows - if (t.num_rows() <= 0) { - return cudf::make_fixed_width_column( - data_type{type_id::INT32}, 0, mask_state::UNALLOCATED, stream, mr); - } + if (t.num_rows() <= 0) { return cudf::make_empty_column(data_type{type_id::INT32}); } // flatten the hierarchy and determine some information about it. std::vector cols; diff --git a/cpp/tests/transform/row_bit_count_test.cu b/cpp/tests/transform/row_bit_count_test.cu index c0288d9b73e..21e5c818197 100644 --- a/cpp/tests/transform/row_bit_count_test.cu +++ b/cpp/tests/transform/row_bit_count_test.cu @@ -587,7 +587,7 @@ TEST_F(RowBitCount, EmptyTable) { auto strings = cudf::strings::detail::make_empty_strings_column(0); - auto ints = cudf::make_fixed_width_column(data_type{type_id::INT32}, 0); + auto ints = cudf::make_empty_column(data_type{type_id::INT32}); cudf::table_view empty({*strings, *ints}); auto result = cudf::row_bit_count(empty);