Skip to content

Commit

Permalink
Address potential race conditions in Parquet reader (#14602)
Browse files Browse the repository at this point in the history
Related to #14597. Fixes reported errors by racecheck.

Authors:
  - Ed Seidl (https://github.com/etseidl)
  - Vukasin Milovanovic (https://github.com/vuule)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Nghia Truong (https://github.com/ttnghia)
  - https://github.com/nvdbaranec

URL: #14602
  • Loading branch information
etseidl authored Dec 15, 2023
1 parent 39386c2 commit 2cb8f3d
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 7 deletions.
2 changes: 2 additions & 0 deletions cpp/src/io/parquet/delta_binary.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ struct delta_binary_decoder {

// need to account for the first value from header on first pass
if (current_value_idx == 0) {
// make sure all threads access current_value_idx above before incrementing
__syncwarp();
if (lane_id == 0) { current_value_idx++; }
__syncwarp();
if (current_value_idx >= value_count) { return; }
Expand Down
6 changes: 5 additions & 1 deletion cpp/src/io/parquet/page_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ __global__ void __launch_bounds__(decode_block_size)
target_pos = min(s->nz_count, src_pos + decode_block_size - out_thread0);
if (out_thread0 > 32) { target_pos = min(target_pos, s->dict_pos); }
}
// TODO(ets): see if this sync can be removed
// this needs to be here to prevent warp 3 modifying src_pos before all threads have read it
__syncthreads();
if (t < 32) {
// decode repetition and definition levels.
Expand All @@ -495,6 +495,10 @@ __global__ void __launch_bounds__(decode_block_size)
uint32_t src_target_pos = target_pos + skipped_leaf_values;

// WARP1: Decode dictionary indices, booleans or string positions
// NOTE: racecheck complains of a RAW error involving the s->dict_pos assignment below.
// This is likely a false positive in practice, but could be solved by wrapping the next
// 9 lines in `if (s->dict_pos < src_target_pos) {}`. If that change is made here, it will
// be needed in the other DecodeXXX kernels.
if (s->dict_base) {
src_target_pos = gpuDecodeDictionaryIndices<false>(s, sb, src_target_pos, t & 0x1f).first;
} else if ((s->col.data_type & 7) == BOOLEAN) {
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/io/parquet/page_decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ __device__ cuda::std::pair<int, int> gpuDecodeDictionaryIndices(page_state_s* s,
int pos = s->dict_pos;
int str_len = 0;

// NOTE: racecheck warns about a RAW involving s->dict_pos, which is likely a false positive
// because the only path that does not include a sync will lead to s->dict_pos being overwritten
// with the same value

while (pos < target_pos) {
int is_literal, batch_len;
if (!t) {
Expand Down Expand Up @@ -357,6 +361,10 @@ inline __device__ int gpuDecodeRleBooleans(page_state_s* s, state_buf* sb, int t
uint8_t const* end = s->data_end;
int64_t pos = s->dict_pos;

// NOTE: racecheck warns about a RAW involving s->dict_pos, which is likely a false positive
// because the only path that does not include a sync will lead to s->dict_pos being overwritten
// with the same value

while (pos < target_pos) {
int is_literal, batch_len;
if (!t) {
Expand Down Expand Up @@ -549,6 +557,9 @@ __device__ void gpuDecodeStream(
batch_coded_count += batch_len;
value_count += batch_len;
}
// issue #14597
// racecheck reported race between reads at the start of this function and the writes below
__syncwarp();

// update the stream info
if (!t) {
Expand Down Expand Up @@ -681,6 +692,9 @@ __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value
level_t const* const def,
int t)
{
// exit early if there's no work to do
if (s->input_value_count >= target_input_value_count) { return; }

// max nesting depth of the column
int const max_depth = s->col.max_nesting_depth;
bool const has_repetition = s->col.max_level[level_type::REPETITION] > 0;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/io/parquet/page_delta_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ __global__ void __launch_bounds__(96)
} else { // warp2
target_pos = min(s->nz_count, src_pos + batch_size);
}
// TODO(ets): see if this sync can be removed
// this needs to be here to prevent warp 2 modifying src_pos before all threads have read it
__syncthreads();

// warp0 will decode the rep/def levels, warp1 will unpack a mini-batch of deltas.
Expand Down Expand Up @@ -507,7 +507,7 @@ __global__ void __launch_bounds__(decode_block_size)
} else { // warp 3
target_pos = min(s->nz_count, src_pos + batch_size);
}
// TODO(ets): see if this sync can be removed
// this needs to be here to prevent warp 3 modifying src_pos before all threads have read it
__syncthreads();

// warp0 will decode the rep/def levels, warp1 will unpack a mini-batch of prefixes, warp 2 will
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/io/parquet/page_string_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ __device__ thrust::pair<int, int> page_bounds(page_state_s* const s,

// can skip all this if we know there are no nulls
if (max_def == 0 && !is_bounds_pg) {
s->page.num_valids = s->num_input_values;
s->page.num_nulls = 0;
if (t == 0) {
s->page.num_valids = s->num_input_values;
s->page.num_nulls = 0;
}
return {0, s->num_input_values};
}

Expand Down Expand Up @@ -294,7 +296,6 @@ __device__ thrust::pair<int, int> page_bounds(page_state_s* const s,
pp->num_nulls = null_count;
pp->num_valids = pp->num_input_values - null_count;
}
__syncthreads();

end_value -= pp->num_nulls;
}
Expand Down Expand Up @@ -848,7 +849,7 @@ __global__ void __launch_bounds__(decode_block_size)
target_pos = min(s->nz_count, src_pos + decode_block_size - out_thread0);
if (out_thread0 > 32) { target_pos = min(target_pos, s->dict_pos); }
}
// TODO(ets): see if this sync can be removed
// this needs to be here to prevent warp 1/2 modifying src_pos before all threads have read it
__syncthreads();
if (t < 32) {
// decode repetition and definition levels.
Expand Down

0 comments on commit 2cb8f3d

Please sign in to comment.