From 821f4dea107db6a51fcbffff997fa6844ab5565f Mon Sep 17 00:00:00 2001
From: nvdbaranec <56695930+nvdbaranec@users.noreply.github.com>
Date: Thu, 25 Jan 2024 17:44:58 -0600
Subject: [PATCH] Fixed an issue with output chunking computation stemming from
 input chunking. (#14889)

Fixes https://github.com/rapidsai/cudf/issues/14883

The core issue was that the output chunking code expected every column's terminal page to end at the same row. Previously this was always true because we processed entire row groups at a time. With the sub-rowgroup reader, however, we can now split on page boundaries, which produces a jagged maximum row index across columns. Example:

```
             0       100             200
Col A     [-----------][--------------]      300
Col B     [-----------][----------------------]
```

The input chunking computes a max row index of 200 for this subpass. But when computing the _output_ chunks, the code would then try to find where row 300 (the end of column B's terminal page) falls in column A, resulting in an out-of-bounds read.
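
To make the failure concrete, here is a minimal host-side sketch (the names are illustrative, not cudf's actual internals) of how a split search walks off the end of column A's page list when asked for a row only column B contains:

```
#include <algorithm>
#include <cstddef>
#include <vector>

int main()
{
  // Column A's pages end at rows 100 and 200; column B's last page ends
  // at row 300, so a naive split computation may search for row 300.
  std::vector<std::size_t> col_a_page_end_rows{100, 200};
  std::size_t const split_row = 300;  // derived from column B's terminal page

  // No page in column A ends at or after row 300, so the search returns
  // the end iterator and the derived page index is one past the last page.
  auto const it       = std::lower_bound(
    col_a_page_end_rows.begin(), col_a_page_end_rows.end(), split_row);
  auto const page_idx =
    static_cast<std::size_t>(it - col_a_page_end_rows.begin());  // == 2

  // Indexing col_a_page_end_rows[page_idx] here is the out-of-bounds
  // read described above.
  return static_cast<int>(page_idx);
}
```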

The fix is simply to cap each column's terminal row index (here, column B's) at the maximum row expected for the subpass.
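
Continuing the example, a minimal sketch of the cap (this mirrors the `set_row_index` change in the diff below; the helper name is illustrative):

```
#include <algorithm>
#include <cstddef>

// Clamp a page's terminal row index to the subpass maximum. A max_row of
// 0 means "no cap", matching the pass-level call site in the diff below.
std::size_t capped_end_row(std::size_t page_end_row, std::size_t max_row)
{
  return max_row > 0 ? std::min(max_row, page_end_row) : page_end_row;
}

// With subpass_max_row = skip_rows + num_rows = 200, column B's dangling
// end row of 300 is clamped to 200, so every column terminates at the
// same row and the split search stays in bounds.
```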

Authors:
  - https://github.com/nvdbaranec

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Vukasin Milovanovic (https://github.com/vuule)
  - Mike Wilson (https://github.com/hyperbolic2346)

URL: https://github.com/rapidsai/cudf/pull/14889
---
 cpp/src/io/parquet/reader_impl_chunking.cu | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/cpp/src/io/parquet/reader_impl_chunking.cu b/cpp/src/io/parquet/reader_impl_chunking.cu
index 1bfe5745b9e..e0cb2fbb4f4 100644
--- a/cpp/src/io/parquet/reader_impl_chunking.cu
+++ b/cpp/src/io/parquet/reader_impl_chunking.cu
@@ -253,13 +253,15 @@ struct set_row_index {
   device_span<ColumnChunkDesc const> chunks;
   device_span<PageInfo const> pages;
   device_span<cumulative_page_info> c_info;
+  size_t max_row;
 
   __device__ void operator()(size_t i)
   {
-    auto const& page            = pages[i];
-    auto const& chunk           = chunks[page.chunk_idx];
-    size_t const page_start_row = chunk.start_row + page.chunk_row + page.num_rows;
-    c_info[i].row_index         = page_start_row;
+    auto const& page          = pages[i];
+    auto const& chunk         = chunks[page.chunk_idx];
+    size_t const page_end_row = chunk.start_row + page.chunk_row + page.num_rows;
+    // if we have been passed in a cap, apply it
+    c_info[i].row_index = max_row > 0 ? min(max_row, page_end_row) : page_end_row;
   }
 };
 
@@ -1288,7 +1290,7 @@ void reader::impl::setup_next_subpass(bool uses_custom_row_bounds)
     thrust::for_each(rmm::exec_policy_nosync(_stream),
                      iter,
                      iter + pass.pages.size(),
-                     set_row_index{pass.chunks, pass.pages, c_info});
+                     set_row_index{pass.chunks, pass.pages, c_info, 0});
     // print_cumulative_page_info(pass.pages, pass.chunks, c_info, _stream);
 
     // get the next batch of pages
@@ -1533,10 +1535,15 @@ void reader::impl::compute_output_chunks_for_subpass()
                                 thrust::equal_to{},
                                 cumulative_page_sum{});
   auto iter = thrust::make_counting_iterator(0);
+  // cap the max row in all pages by the max row we expect in the subpass. input chunking
+  // can cause "dangling" row counts where for example, only 1 column has a page whose
+  // maximum row is beyond our expected subpass max row, which will cause an out of
+  // bounds index in compute_page_splits_by_row.
+  auto const subpass_max_row = subpass.skip_rows + subpass.num_rows;
   thrust::for_each(rmm::exec_policy_nosync(_stream),
                    iter,
                    iter + subpass.pages.size(),
-                   set_row_index{pass.chunks, subpass.pages, c_info});
+                   set_row_index{pass.chunks, subpass.pages, c_info, subpass_max_row});
   // print_cumulative_page_info(subpass.pages, c_info, _stream);
 
   // compute the splits