Skip to content

Commit

Permalink
Fix decoding of dictionary encoded FIXED_LEN_BYTE_ARRAY data in Parquet reader (#15601)
Browse files Browse the repository at this point in the history

Reading Parquet files with dictionary encoded FIXED_LEN_BYTE_ARRAY data fails because the dictionary page is never parsed, leading to out-of-bounds memory accesses.

Authors:
  - Ed Seidl (https://github.com/etseidl)
  - Vukasin Milovanovic (https://github.com/vuule)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Karthikeyan (https://github.com/karthikeyann)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #15601
  • Loading branch information
etseidl authored May 7, 2024
1 parent 2e81857 commit 0cfdbc1
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 11 deletions.
10 changes: 7 additions & 3 deletions cpp/src/io/parquet/page_decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1298,9 +1298,13 @@ inline __device__ bool setupLocalPageInfo(page_state_s* const s,
// be made to is_supported_encoding() in reader_impl_preprocess.cu
switch (s->page.encoding) {
case Encoding::PLAIN_DICTIONARY:
case Encoding::RLE_DICTIONARY:
case Encoding::RLE_DICTIONARY: {
// RLE-packed dictionary indices, first byte indicates index length in bits
if (s->col.physical_type == BYTE_ARRAY && s->col.str_dict_index != nullptr) {
auto const is_decimal =
s->col.logical_type.has_value() and s->col.logical_type->type == LogicalType::DECIMAL;
if ((s->col.physical_type == BYTE_ARRAY or
s->col.physical_type == FIXED_LEN_BYTE_ARRAY) and
not is_decimal and s->col.str_dict_index != nullptr) {
// String dictionary: use index
s->dict_base = reinterpret_cast<uint8_t const*>(s->col.str_dict_index);
s->dict_size = s->col.dict_page->num_input_values * sizeof(string_index_pair);
Expand All @@ -1314,7 +1318,7 @@ inline __device__ bool setupLocalPageInfo(page_state_s* const s,
if (s->dict_bits > 32 || (!s->dict_base && s->col.dict_page->num_input_values > 0)) {
s->set_error_code(decode_error::INVALID_DICT_WIDTH);
}
break;
} break;
case Encoding::PLAIN:
case Encoding::BYTE_STREAM_SPLIT:
s->dict_size = static_cast<int32_t>(end - cur);
Expand Down
21 changes: 16 additions & 5 deletions cpp/src/io/parquet/page_hdr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -538,17 +538,28 @@ CUDF_KERNEL void __launch_bounds__(128)
int pos = 0, cur = 0;
for (int i = 0; i < num_entries; i++) {
int len = 0;
if (cur + 4 <= dict_size) {
len = dict[cur + 0] | (dict[cur + 1] << 8) | (dict[cur + 2] << 16) | (dict[cur + 3] << 24);
if (len >= 0 && cur + 4 + len <= dict_size) {
if (ck->physical_type == FIXED_LEN_BYTE_ARRAY) {
if (cur + ck->type_length <= dict_size) {
len = ck->type_length;
pos = cur;
cur = cur + 4 + len;
cur += len;
} else {
cur = dict_size;
}
} else {
if (cur + 4 <= dict_size) {
len =
dict[cur + 0] | (dict[cur + 1] << 8) | (dict[cur + 2] << 16) | (dict[cur + 3] << 24);
if (len >= 0 && cur + 4 + len <= dict_size) {
pos = cur + 4;
cur = pos + len;
} else {
cur = dict_size;
}
}
}
// TODO: Could store 8 entries in shared mem, then do a single warp-wide store
dict_index[i].first = reinterpret_cast<char const*>(dict + pos + 4);
dict_index[i].first = reinterpret_cast<char const*>(dict + pos);
dict_index[i].second = len;
}
}
Expand Down
15 changes: 12 additions & 3 deletions cpp/src/io/parquet/reader_impl_preprocess.cu
Original file line number Diff line number Diff line change
Expand Up @@ -636,15 +636,24 @@ void decode_page_headers(pass_intermediate_data& pass,
stream.synchronize();
}

// Returns true when a column chunk holds string-like data: its physical type is
// BYTE_ARRAY or FIXED_LEN_BYTE_ARRAY and it is not annotated as DECIMAL
// (decimals stored in binary types are decoded numerically, not as strings).
constexpr bool is_string_chunk(ColumnChunkDesc const& chunk)
{
  // Guard clause: only the two binary physical types can ever be strings.
  if (chunk.physical_type != BYTE_ARRAY and chunk.physical_type != FIXED_LEN_BYTE_ARRAY) {
    return false;
  }
  // A DECIMAL logical annotation overrides the binary physical type.
  return not(chunk.logical_type.has_value() and chunk.logical_type->type == LogicalType::DECIMAL);
}

struct set_str_dict_index_count {
device_span<size_t> str_dict_index_count;
device_span<const ColumnChunkDesc> chunks;

__device__ void operator()(PageInfo const& page)
{
auto const& chunk = chunks[page.chunk_idx];
if ((page.flags & PAGEINFO_FLAGS_DICTIONARY) && chunk.physical_type == BYTE_ARRAY &&
(chunk.num_dict_pages > 0)) {
if ((page.flags & PAGEINFO_FLAGS_DICTIONARY) != 0 and chunk.num_dict_pages > 0 and
is_string_chunk(chunk)) {
// there is only ever one dictionary page per chunk, so this is safe to do in parallel.
str_dict_index_count[page.chunk_idx] = page.num_input_values;
}
Expand All @@ -659,7 +668,7 @@ struct set_str_dict_index_ptr {
__device__ void operator()(size_t i)
{
auto& chunk = chunks[i];
if (chunk.physical_type == BYTE_ARRAY && (chunk.num_dict_pages > 0)) {
if (chunk.num_dict_pages > 0 and is_string_chunk(chunk)) {
chunk.str_dict_index = base + str_dict_index_offsets[i];
}
}
Expand Down
19 changes: 19 additions & 0 deletions python/cudf/cudf/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import glob
import hashlib
import math
import os
import pathlib
Expand Down Expand Up @@ -2807,6 +2808,24 @@ def test_parquet_reader_fixed_bin(datadir):
assert_eq(expect, got)


def test_parquet_reader_fixed_len_with_dict(tmpdir):
    # Regression test: dictionary-encoded FIXED_LEN_BYTE_ARRAY columns must
    # decode correctly (previously the dictionary page was never parsed).
    def fixed_width_value(seed):
        # Deterministic 32-byte binary value derived from an integer seed.
        digest = hashlib.sha256()
        digest.update(seed.to_bytes(4, "little"))
        return digest.digest()

    # use pyarrow to write table of fixed_len_byte_array
    num_rows = 200
    column = pa.array(
        [fixed_width_value(row) for row in range(num_rows)], type=pa.binary(32)
    )
    table = pa.Table.from_arrays([column], names=["flba"])
    out_path = tmpdir.join("padf.parquet")
    pq.write_table(table, out_path, use_dictionary=True)

    expect = pd.read_parquet(out_path)
    got = cudf.read_parquet(out_path)
    assert_eq(expect, got)


def test_parquet_reader_rle_boolean(datadir):
fname = datadir / "rle_boolean_encoding.parquet"

Expand Down

0 comments on commit 0cfdbc1

Please sign in to comment.