Parquet reader optimization to address V100 regression. (#12577)
Addresses #12316

Some recent changes caused a performance regression in the parquet reader benchmarks for lists. The culprit ended up being slightly different code generation for arch 70 (V100). In several memory hotspots, the generated code was reading values from global memory, modifying them, and then storing them back. Previously it had done a better job of loading them once and keeping them in registers, and the L2 cache was helping keep things fast. The extra store caused twice as many L2 accesses in these places, resulting in many long scoreboard stalls.
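
To illustrate the difference, here is a minimal hypothetical sketch (not the cudf kernels themselves) of the two access patterns, written out at the source level. The regressed codegen effectively behaved like the first function; the fix restores the behavior of the second.

```cuda
// Hypothetical sketch of the two access patterns; not cudf code.

// Regressed pattern: every update of the counter round-trips through
// global memory (a load plus a store per update), doubling L2 traffic.
__device__ void count_in_global(int* global_count, int n)
{
  for (int i = 0; i < n; ++i) {
    *global_count += 1;  // effectively ld.global + st.global each iteration
  }
}

// Fast pattern: load once, accumulate in a register, store once.
__device__ void count_in_register(int* global_count, int n)
{
  int count = *global_count;  // single global load
  for (int i = 0; i < n; ++i) {
    count += 1;  // register-only update
  }
  *global_count = count;  // single global store
}
```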

Ultimately the issue is that these values shouldn't be kept in global memory.  The initial implementation did it this way because the data was variable in size (based on depth of column nesting).  But in practice, we never see more than 2 or 3 levels of nesting.  So the solution is:

- Keep these values (in a struct called `PageNestingDecodeInfo`) in a cache in shared memory with room for up to N nesting levels. N is currently 10.
- If the nesting information for the incoming column fits in the cache, use it. Otherwise, fall back to the arrays in global memory. In practice, it is exceedingly rare to see columns nested >= 10 levels deep. A simplified sketch of the scheme follows.
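
Here is a minimal sketch of the scheme, assuming simplified struct fields and a hypothetical helper name (`setup_nesting_info`); the real `PageNestingDecodeInfo` carries more per-depth decode state, and the actual setup lives in `setupLocalPageInfo` in the diff below.

```cuda
constexpr int max_cacheable_nesting_decode_info = 10;  // the "N" above

struct PageNestingDecodeInfo {
  int max_def_level;     // definition level at which this depth becomes non-null
  int page_start_value;  // precomputed output start offset for this page
  int value_count;       // running counts updated during decode
  int valid_count;
  int null_count;
  // ... simplified; the real struct holds additional decode state
};

// In the decode kernels, an instance of this lives in shared memory.
struct page_state_s {
  PageNestingDecodeInfo nesting_decode_cache[max_cacheable_nesting_decode_info];
  PageNestingDecodeInfo* nesting_info;  // cache when it fits, global otherwise
};

__device__ void setup_nesting_info(page_state_s* s,                    // shared
                                   PageNestingDecodeInfo* global_info,  // global
                                   int nesting_info_size,
                                   int t)  // thread rank within the block
{
  bool const can_use_cache =
    nesting_info_size <= max_cacheable_nesting_decode_info;
  if (can_use_cache) {
    // block-strided copy of the per-depth info from global into shared memory
    for (int depth = t; depth < nesting_info_size; depth += blockDim.x) {
      s->nesting_decode_cache[depth] = global_info[depth];
    }
  }
  if (t == 0) {
    s->nesting_info = can_use_cache ? s->nesting_decode_cache : global_info;
  }
  __syncthreads();  // all threads see the chosen pointer and the copied data
}
```

All decode paths then read and update state through `s->nesting_info`, so deeply nested columns transparently take the (rare) global-memory path. At the end of the decode kernel, the one cached field the rest of the reader consumes, `null_count`, is copied back out to the global array.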

This addresses the performance regression and actually gives a net performance increase: throughput is roughly 4-8% higher than the pre-regression 22.10 numbers. Some comparisons for LIST benchmarks:
```
cudf 22.10 (prior to regression)
| data_type | cardinality | run_length | bytes_per_second | 
|-----------|-------------|------------|------------------|
|      LIST |           0 |          1 |     892901208    | 
|      LIST |        1000 |          1 |     952863876    |  
|      LIST |           0 |         32 |    1246033395    |  
|      LIST |        1000 |         32 |    1232884866    |  
```

```
cudf 22.12 (where the regression occurred)
| data_type | cardinality | run_length | bytes_per_second | 
|-----------|-------------|------------|------------------|
|      LIST |           0 |          1 |     747758436    | 
|      LIST |        1000 |          1 |     827763260    |  
|      LIST |           0 |         32 |    1026048576    |  
|      LIST |        1000 |         32 |    1022928119    |  
```

```
This PR
| data_type | cardinality | run_length | bytes_per_second | 
|-----------|-------------|------------|------------------|
|      LIST |           0 |          1 |     927347737    | 
|      LIST |        1000 |          1 |     1024566150   |  
|      LIST |           0 |         32 |    1315972881    |  
|      LIST |        1000 |         32 |    1303995168    |  
```

Authors:
  - https://github.com/nvdbaranec

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Bradley Dice (https://github.com/bdice)

URL: #12577
nvdbaranec authored Jan 25, 2023
1 parent ed6daad commit 2784f58
Showing 5 changed files with 245 additions and 115 deletions.
204 changes: 129 additions & 75 deletions cpp/src/io/parquet/page_data.cu
@@ -90,6 +90,13 @@ struct page_state_s {
const uint8_t* lvl_start[NUM_LEVEL_TYPES]; // [def,rep]
int32_t lvl_count[NUM_LEVEL_TYPES]; // how many of each of the streams we've decoded
int32_t row_index_lower_bound; // lower bound of row indices we should process

+ // a shared-memory cache of frequently used data when decoding. The source of this data is
+ // normally stored in global memory which can yield poor performance. So, when possible
+ // we copy that info here prior to decoding
+ PageNestingDecodeInfo nesting_decode_cache[max_cacheable_nesting_decode_info];
+ // points to either nesting_decode_cache above when possible, or to the global source otherwise
+ PageNestingDecodeInfo* nesting_info;
};

/**
@@ -927,23 +934,49 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s,
int chunk_idx;

// Fetch page info
- if (t == 0) s->page = *p;
+ if (!t) s->page = *p;
__syncthreads();

if (s->page.flags & PAGEINFO_FLAGS_DICTIONARY) { return false; }
// Fetch column chunk info
chunk_idx = s->page.chunk_idx;
- if (t == 0) { s->col = chunks[chunk_idx]; }

- // zero nested value and valid counts
- int d = 0;
- while (d < s->page.num_output_nesting_levels) {
- if (d + t < s->page.num_output_nesting_levels) {
- s->page.nesting[d + t].valid_count = 0;
- s->page.nesting[d + t].value_count = 0;
- s->page.nesting[d + t].null_count = 0;
+ if (!t) { s->col = chunks[chunk_idx]; }

+ // if we can use the decode cache, set it up now
+ auto const can_use_decode_cache = s->page.nesting_info_size <= max_cacheable_nesting_decode_info;
+ if (can_use_decode_cache) {
+ int depth = 0;
+ while (depth < s->page.nesting_info_size) {
+ int const thread_depth = depth + t;
+ if (thread_depth < s->page.nesting_info_size) {
+ // these values need to be copied over from global
+ s->nesting_decode_cache[thread_depth].max_def_level =
+ s->page.nesting_decode[thread_depth].max_def_level;
+ s->nesting_decode_cache[thread_depth].page_start_value =
+ s->page.nesting_decode[thread_depth].page_start_value;
+ s->nesting_decode_cache[thread_depth].start_depth =
+ s->page.nesting_decode[thread_depth].start_depth;
+ s->nesting_decode_cache[thread_depth].end_depth =
+ s->page.nesting_decode[thread_depth].end_depth;
+ }
+ depth += blockDim.x;
+ }
+ }
+ if (!t) {
+ s->nesting_info = can_use_decode_cache ? s->nesting_decode_cache : s->page.nesting_decode;
+ }
+ __syncthreads();

+ // zero counts
+ int depth = 0;
+ while (depth < s->page.num_output_nesting_levels) {
+ int const thread_depth = depth + t;
+ if (thread_depth < s->page.num_output_nesting_levels) {
+ s->nesting_info[thread_depth].valid_count = 0;
+ s->nesting_info[thread_depth].value_count = 0;
+ s->nesting_info[thread_depth].null_count = 0;
}
- d += blockDim.x;
+ depth += blockDim.x;
}
__syncthreads();

@@ -1076,7 +1109,7 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s,
if (is_decode_step) {
int max_depth = s->col.max_nesting_depth;
for (int idx = 0; idx < max_depth; idx++) {
- PageNestingInfo* pni = &s->page.nesting[idx];
+ PageNestingDecodeInfo* nesting_info = &s->nesting_info[idx];

size_t output_offset;
// schemas without lists
@@ -1085,21 +1118,21 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s,
}
// for schemas with lists, we've already got the exact value precomputed
else {
- output_offset = pni->page_start_value;
+ output_offset = nesting_info->page_start_value;
}

- pni->data_out = static_cast<uint8_t*>(s->col.column_data_base[idx]);
+ nesting_info->data_out = static_cast<uint8_t*>(s->col.column_data_base[idx]);

- if (pni->data_out != nullptr) {
+ if (nesting_info->data_out != nullptr) {
// anything below max depth with a valid data pointer must be a list, so the
// element size is the size of the offset type.
uint32_t len = idx < max_depth - 1 ? sizeof(cudf::size_type) : s->dtype_len;
- pni->data_out += (output_offset * len);
+ nesting_info->data_out += (output_offset * len);
}
- pni->valid_map = s->col.valid_map_base[idx];
- if (pni->valid_map != nullptr) {
- pni->valid_map += output_offset >> 5;
- pni->valid_map_offset = (int32_t)(output_offset & 0x1f);
+ nesting_info->valid_map = s->col.valid_map_base[idx];
+ if (nesting_info->valid_map != nullptr) {
+ nesting_info->valid_map += output_offset >> 5;
+ nesting_info->valid_map_offset = (int32_t)(output_offset & 0x1f);
}
}
}
@@ -1217,26 +1250,26 @@ static __device__ bool setupLocalPageInfo(page_state_s* const s,
* @brief Store a validity mask containing value_count bits into the output validity buffer of the
* page.
*
- * @param[in,out] pni The page/nesting information to store the mask in. The validity map offset is
- * also updated
+ * @param[in,out] nesting_info The page/nesting information to store the mask in. The validity map
+ * offset is also updated
* @param[in] valid_mask The validity mask to be stored
* @param[in] value_count # of bits in the validity mask
*/
- static __device__ void store_validity(PageNestingInfo* pni,
+ static __device__ void store_validity(PageNestingDecodeInfo* nesting_info,
uint32_t valid_mask,
int32_t value_count)
{
- int word_offset = pni->valid_map_offset / 32;
- int bit_offset = pni->valid_map_offset % 32;
+ int word_offset = nesting_info->valid_map_offset / 32;
+ int bit_offset = nesting_info->valid_map_offset % 32;
// if we fit entirely in the output word
if (bit_offset + value_count <= 32) {
auto relevant_mask = static_cast<uint32_t>((static_cast<uint64_t>(1) << value_count) - 1);

if (relevant_mask == ~0) {
- pni->valid_map[word_offset] = valid_mask;
+ nesting_info->valid_map[word_offset] = valid_mask;
} else {
- atomicAnd(pni->valid_map + word_offset, ~(relevant_mask << bit_offset));
- atomicOr(pni->valid_map + word_offset, (valid_mask & relevant_mask) << bit_offset);
+ atomicAnd(nesting_info->valid_map + word_offset, ~(relevant_mask << bit_offset));
+ atomicOr(nesting_info->valid_map + word_offset, (valid_mask & relevant_mask) << bit_offset);
}
}
// we're going to spill over into the next word.
@@ -1250,17 +1283,17 @@ static __device__ void store_validity(PageNestingInfo* pni,
// first word. strip bits_left bits off the beginning and store that
uint32_t relevant_mask = ((1 << bits_left) - 1);
uint32_t mask_word0 = valid_mask & relevant_mask;
- atomicAnd(pni->valid_map + word_offset, ~(relevant_mask << bit_offset));
- atomicOr(pni->valid_map + word_offset, mask_word0 << bit_offset);
+ atomicAnd(nesting_info->valid_map + word_offset, ~(relevant_mask << bit_offset));
+ atomicOr(nesting_info->valid_map + word_offset, mask_word0 << bit_offset);

// second word. strip the remainder of the bits off the end and store that
relevant_mask = ((1 << (value_count - bits_left)) - 1);
uint32_t mask_word1 = valid_mask & (relevant_mask << bits_left);
- atomicAnd(pni->valid_map + word_offset + 1, ~(relevant_mask));
- atomicOr(pni->valid_map + word_offset + 1, mask_word1 >> bits_left);
+ atomicAnd(nesting_info->valid_map + word_offset + 1, ~(relevant_mask));
+ atomicOr(nesting_info->valid_map + word_offset + 1, mask_word1 >> bits_left);
}

- pni->valid_map_offset += value_count;
+ nesting_info->valid_map_offset += value_count;
}

/**
@@ -1294,8 +1327,8 @@ inline __device__ void get_nesting_bounds(int& start_depth,
// bound what nesting levels we apply values to
if (s->col.max_level[level_type::REPETITION] > 0) {
int r = s->rep[index];
- start_depth = s->page.nesting[r].start_depth;
- end_depth = s->page.nesting[d].end_depth;
+ start_depth = s->nesting_info[r].start_depth;
+ end_depth = s->nesting_info[d].end_depth;
}
// for columns without repetition (even ones involving structs) we always
// traverse the entire hierarchy.
@@ -1326,6 +1359,8 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count,
// how many rows we've processed in the page so far
int input_row_count = s->input_row_count;

+ PageNestingDecodeInfo* nesting_info_base = s->nesting_info;

// process until we've reached the target
while (input_value_count < target_input_value_count) {
// determine the nesting bounds for this thread (the range of nesting depths we
@@ -1367,14 +1402,14 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count,
// walk from 0 to max_depth
uint32_t next_thread_value_count, next_warp_value_count;
for (int s_idx = 0; s_idx < max_depth; s_idx++) {
- PageNestingInfo* pni = &s->page.nesting[s_idx];
+ PageNestingDecodeInfo* nesting_info = &nesting_info_base[s_idx];

// if we are within the range of nesting levels we should be adding value indices for
int const in_nesting_bounds =
((s_idx >= start_depth && s_idx <= end_depth) && in_row_bounds) ? 1 : 0;

// everything up to the max_def_level is a non-null value
- uint32_t const is_valid = d >= pni->max_def_level && in_nesting_bounds ? 1 : 0;
+ uint32_t const is_valid = d >= nesting_info->max_def_level && in_nesting_bounds ? 1 : 0;

// compute warp and thread valid counts
uint32_t const warp_valid_mask =
@@ -1395,8 +1430,8 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count,

// if this is the value column emit an index for value decoding
if (is_valid && s_idx == max_depth - 1) {
- int const src_pos = pni->valid_count + thread_valid_count;
- int const dst_pos = pni->value_count + thread_value_count;
+ int const src_pos = nesting_info->valid_count + thread_valid_count;
+ int const dst_pos = nesting_info->value_count + thread_value_count;
// nz_idx is a mapping of src buffer indices to destination buffer indices
s->nz_idx[rolling_index(src_pos)] = dst_pos;
}
@@ -1414,12 +1449,12 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count,
// if we're -not- at a leaf column and we're within nesting/row bounds
// and we have a valid data_out pointer, it implies this is a list column, so
// emit an offset.
- if (in_nesting_bounds && pni->data_out != nullptr) {
- int const idx = pni->value_count + thread_value_count;
- cudf::size_type const ofs = s->page.nesting[s_idx + 1].value_count +
+ if (in_nesting_bounds && nesting_info->data_out != nullptr) {
+ int const idx = nesting_info->value_count + thread_value_count;
+ cudf::size_type const ofs = nesting_info_base[s_idx + 1].value_count +
next_thread_value_count +
- s->page.nesting[s_idx + 1].page_start_value;
- (reinterpret_cast<cudf::size_type*>(pni->data_out))[idx] = ofs;
+ nesting_info_base[s_idx + 1].page_start_value;
+ (reinterpret_cast<cudf::size_type*>(nesting_info->data_out))[idx] = ofs;
}
}

@@ -1441,14 +1476,14 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count,

// increment count of valid values, count of total values, and update validity mask
if (!t) {
- if (pni->valid_map != nullptr && warp_valid_mask_bit_count > 0) {
+ if (nesting_info->valid_map != nullptr && warp_valid_mask_bit_count > 0) {
uint32_t const warp_output_valid_mask = warp_valid_mask >> first_thread_in_write_range;
- store_validity(pni, warp_output_valid_mask, warp_valid_mask_bit_count);
+ store_validity(nesting_info, warp_output_valid_mask, warp_valid_mask_bit_count);

- pni->null_count += warp_valid_mask_bit_count - __popc(warp_output_valid_mask);
+ nesting_info->null_count += warp_valid_mask_bit_count - __popc(warp_output_valid_mask);
}
- pni->valid_count += warp_valid_count;
- pni->value_count += warp_value_count;
+ nesting_info->valid_count += warp_valid_count;
+ nesting_info->value_count += warp_value_count;
}

// propagate value counts for the next level
@@ -1463,7 +1498,7 @@ static __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count,
// update
if (!t) {
// update valid value count for decoding and total # of values we've processed
- s->nz_count = s->page.nesting[max_depth - 1].valid_count;
+ s->nz_count = nesting_info_base[max_depth - 1].valid_count;
s->input_value_count = input_value_count;
s->input_row_count = input_row_count;
}
@@ -1545,7 +1580,7 @@ static __device__ void gpuUpdatePageSizes(page_state_s* s,
// count rows and leaf values
int const is_new_row = start_depth == 0 ? 1 : 0;
uint32_t const warp_row_count_mask = ballot(is_new_row);
- int const is_new_leaf = (d >= s->page.nesting[max_depth - 1].max_def_level) ? 1 : 0;
+ int const is_new_leaf = (d >= s->nesting_info[max_depth - 1].max_def_level) ? 1 : 0;
uint32_t const warp_leaf_count_mask = ballot(is_new_leaf);
// is this thread within row bounds? on the first pass we don't know the bounds, so we will be
// computing the full size of the column. on the second pass, we will know our actual row
@@ -1673,40 +1708,44 @@ __global__ void __launch_bounds__(block_size)
// to do the expensive work of traversing the level data to determine sizes. we can just compute
// it directly.
if (!has_repetition && !compute_string_sizes) {
- int d = 0;
- while (d < s->page.num_output_nesting_levels) {
- auto const i = d + t;
- if (i < s->page.num_output_nesting_levels) {
- if (is_base_pass) { pp->nesting[i].size = pp->num_input_values; }
- pp->nesting[i].batch_size = pp->num_input_values;
+ int depth = 0;
+ while (depth < s->page.num_output_nesting_levels) {
+ auto const thread_depth = depth + t;
+ if (thread_depth < s->page.num_output_nesting_levels) {
+ if (is_base_pass) { pp->nesting[thread_depth].size = pp->num_input_values; }
+ pp->nesting[thread_depth].batch_size = pp->num_input_values;
}
- d += blockDim.x;
+ depth += blockDim.x;
}
return;
}

// in the trim pass, for anything with lists, we only need to fully process bounding pages (those
// at the beginning or the end of the row bounds)
if (!is_base_pass && !is_bounds_page(s, min_row, num_rows)) {
- int d = 0;
- while (d < s->page.num_output_nesting_levels) {
- auto const i = d + t;
- if (i < s->page.num_output_nesting_levels) {
+ int depth = 0;
+ while (depth < s->page.num_output_nesting_levels) {
+ auto const thread_depth = depth + t;
+ if (thread_depth < s->page.num_output_nesting_levels) {
// if we are not a bounding page (as checked above) then we are either
// returning 0 rows from the page (completely outside the bounds) or all
// rows in the page (completely within the bounds)
- pp->nesting[i].batch_size = s->num_rows == 0 ? 0 : pp->nesting[i].size;
+ pp->nesting[thread_depth].batch_size =
+ s->num_rows == 0 ? 0 : pp->nesting[thread_depth].size;
}
- d += blockDim.x;
+ depth += blockDim.x;
}
return;
}

// zero sizes
- int d = 0;
- while (d < s->page.num_output_nesting_levels) {
- if (d + t < s->page.num_output_nesting_levels) { s->page.nesting[d + t].batch_size = 0; }
- d += blockDim.x;
+ int depth = 0;
+ while (depth < s->page.num_output_nesting_levels) {
+ auto const thread_depth = depth + t;
+ if (thread_depth < s->page.num_output_nesting_levels) {
+ s->page.nesting[thread_depth].batch_size = 0;
+ }
+ depth += blockDim.x;
}

__syncthreads();
@@ -1754,13 +1793,13 @@ __global__ void __launch_bounds__(block_size)
if (!t) { pp->num_rows = s->page.nesting[0].batch_size; }

// store off this batch size as the "full" size
- int d = 0;
- while (d < s->page.num_output_nesting_levels) {
- auto const i = d + t;
- if (i < s->page.num_output_nesting_levels) {
- pp->nesting[i].size = pp->nesting[i].batch_size;
+ int depth = 0;
+ while (depth < s->page.num_output_nesting_levels) {
+ auto const thread_depth = depth + t;
+ if (thread_depth < s->page.num_output_nesting_levels) {
+ pp->nesting[thread_depth].size = pp->nesting[thread_depth].batch_size;
}
- d += blockDim.x;
+ depth += blockDim.x;
}
}

@@ -1808,6 +1847,8 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData(
((s->col.data_type & 7) == BOOLEAN || (s->col.data_type & 7) == BYTE_ARRAY) ? 64 : 32;
}

+ PageNestingDecodeInfo* nesting_info_base = s->nesting_info;

// skipped_leaf_values will always be 0 for flat hierarchies.
uint32_t skipped_leaf_values = s->page.skipped_leaf_values;
while (!s->error && (s->input_value_count < s->num_input_values || s->src_pos < s->nz_count)) {
@@ -1876,7 +1917,7 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData(

uint32_t dtype_len = s->dtype_len;
void* dst =
- s->page.nesting[leaf_level_index].data_out + static_cast<size_t>(dst_pos) * dtype_len;
+ nesting_info_base[leaf_level_index].data_out + static_cast<size_t>(dst_pos) * dtype_len;
if (dtype == BYTE_ARRAY) {
if (s->col.converted_type == DECIMAL) {
auto const [ptr, len] = gpuGetStringData(s, val_src_pos);
@@ -1931,6 +1972,19 @@ __global__ void __launch_bounds__(block_size) gpuDecodePageData(
}
__syncthreads();
}

+ // if we are using the nesting decode cache, copy null count back
+ if (s->nesting_info == s->nesting_decode_cache) {
+ int depth = 0;
+ while (depth < s->page.num_output_nesting_levels) {
+ int const thread_depth = depth + t;
+ if (thread_depth < s->page.num_output_nesting_levels) {
+ s->page.nesting_decode[thread_depth].null_count =
+ s->nesting_decode_cache[thread_depth].null_count;
+ }
+ depth += blockDim.x;
+ }
+ }
}

} // anonymous namespace