Skip to content
/ cudf Public
forked from rapidsai/cudf

Commit

Permalink
Showing 11 changed files with 1,044 additions and 114 deletions.
58 changes: 43 additions & 15 deletions cpp/src/io/parquet/delta_binary.cuh
Original file line number Diff line number Diff line change
@@ -39,15 +39,15 @@ namespace cudf::io::parquet::detail {
// per mini-block. While encoding, the lowest delta value is subtracted from all the deltas in the
// block to ensure that all encoded values are positive. The deltas for each mini-block are bit
// packed using the same encoding as the RLE/Bit-Packing Hybrid encoder.
//
// DELTA_BYTE_ARRAY encoding (incremental encoding or front compression), is used for BYTE_ARRAY
// columns. For each element in a sequence of strings, a prefix length from the preceding string
// and a suffix is stored. The prefix lengths are DELTA_BINARY_PACKED encoded. The suffixes are
// encoded with DELTA_LENGTH_BYTE_ARRAY encoding, which is a DELTA_BINARY_PACKED list of suffix
// lengths, followed by the concatenated suffix data.

// we decode one mini-block at a time. max mini-block size seen is 64.
constexpr int delta_rolling_buf_size = 128;
// The largest mini-block size we can currently support.
constexpr int max_delta_mini_block_size = 64;

// The first pass decodes `values_per_mb` values, and then the second pass does another
// batch of size `values_per_mb`. The largest value for values_per_miniblock among the
// major writers seems to be 64, so 2 * 64 should be good. We save the first value separately
// since it is not encoded in the first mini-block.
constexpr int delta_rolling_buf_size = 2 * max_delta_mini_block_size;

/**
* @brief Read a ULEB128 varint integer
@@ -90,7 +90,8 @@ struct delta_binary_decoder {
uleb128_t mini_block_count; // usually 4, chosen such that block_size/mini_block_count is a
// multiple of 32
uleb128_t value_count; // total values encoded in the block
zigzag128_t last_value; // last value decoded, initialized to first_value from header
zigzag128_t first_value; // initial value, stored in the header
zigzag128_t last_value; // last value decoded

uint32_t values_per_mb; // block_size / mini_block_count, must be multiple of 32
uint32_t current_value_idx; // current value index, initialized to 0 at start of block
@@ -102,6 +103,13 @@ struct delta_binary_decoder {

uleb128_t value[delta_rolling_buf_size]; // circular buffer of delta values

// Look up the decoded value for logical position `idx`.
// Position 0 is the block's first value, which is kept in `first_value` rather than in the
// rolling buffer; every other position is read from `value` at the slot given by
// `rolling_index<delta_rolling_buf_size>(idx)`.
constexpr zigzag128_t value_at(size_type idx)
{
  if (idx == 0) { return first_value; }
  return value[rolling_index<delta_rolling_buf_size>(idx)];
}

// returns the number of values encoded in the block data. when all_values is true,
// account for the first value in the header. otherwise just count the values encoded
// in the mini-block data.
@@ -145,7 +153,8 @@ struct delta_binary_decoder {
block_size = get_uleb128(d_start, d_end);
mini_block_count = get_uleb128(d_start, d_end);
value_count = get_uleb128(d_start, d_end);
last_value = get_zz128(d_start, d_end);
first_value = get_zz128(d_start, d_end);
last_value = first_value;

current_value_idx = 0;
values_per_mb = block_size / mini_block_count;
@@ -179,19 +188,38 @@ struct delta_binary_decoder {
}
}

// Locate the end of the binary encoded block that begins at `start`, where `end` is an upper
// bound on the data. On return `this` has been (re)initialized from `start`; when mini-block
// data was walked, the re-initialization uses the computed end rather than `end`.
// The returned pointer is the end of this block, i.e. the start of the data/next block.
// Should only be called from thread 0.
inline __device__ uint8_t const* find_end_of_block(uint8_t const* start, uint8_t const* end)
{
  // parse the block header, provisionally bounded by `end`
  init_binary_block(start, end);

  // zero or one value means nothing is encoded after the header (a single value is stored
  // in the block header itself)
  bool const has_encoded_values = value_count > 1;
  if (not has_encoded_values) { return block_start; }

  // step through the mini-block headers, skipping over their data
  while (current_value_idx < num_encoded_values(false)) {
    setup_next_mini_block(false);
  }

  // pick the true end of the block: `block_start` when `cur_mb` is 0, `cur_mb_start` otherwise
  uint8_t const* actual_end = block_start;
  if (cur_mb != 0) { actual_end = cur_mb_start; }

  // re-run initialization so the decoder state reflects the correct end
  init_binary_block(start, actual_end);
  return actual_end;
}

// decode the current mini-batch of deltas, and convert to values.
// called by all threads in a warp, currently only one warp supported.
inline __device__ void calc_mini_block_values(int lane_id)
{
using cudf::detail::warp_size;
if (current_value_idx >= value_count) { return; }

// need to save first value from header on first pass
// need to account for the first value from header on first pass
if (current_value_idx == 0) {
if (lane_id == 0) {
current_value_idx++;
value[0] = last_value;
}
if (lane_id == 0) { current_value_idx++; }
__syncwarp();
if (current_value_idx >= value_count) { return; }
}
12 changes: 9 additions & 3 deletions cpp/src/io/parquet/page_data.cu
Original file line number Diff line number Diff line change
@@ -449,8 +449,13 @@ __global__ void __launch_bounds__(decode_block_size)
int out_thread0;
[[maybe_unused]] null_count_back_copier _{s, t};

if (!setupLocalPageInfo(
s, &pages[page_idx], chunks, min_row, num_rows, mask_filter{KERNEL_MASK_GENERAL}, true)) {
if (!setupLocalPageInfo(s,
&pages[page_idx],
chunks,
min_row,
num_rows,
mask_filter{decode_kernel_mask::GENERAL},
true)) {
return;
}

@@ -486,6 +491,7 @@ __global__ void __launch_bounds__(decode_block_size)
target_pos = min(s->nz_count, src_pos + decode_block_size - out_thread0);
if (out_thread0 > 32) { target_pos = min(target_pos, s->dict_pos); }
}
// TODO(ets): see if this sync can be removed
__syncthreads();
if (t < 32) {
// decode repetition and definition levels.
@@ -603,7 +609,7 @@ __global__ void __launch_bounds__(decode_block_size)
}

struct mask_tform {
__device__ uint32_t operator()(PageInfo const& p) { return p.kernel_mask; }
__device__ uint32_t operator()(PageInfo const& p) { return static_cast<uint32_t>(p.kernel_mask); }
};

} // anonymous namespace
12 changes: 10 additions & 2 deletions cpp/src/io/parquet/page_decode.cuh
Original file line number Diff line number Diff line change
@@ -991,8 +991,15 @@ struct all_types_filter {
* @brief Functor for setupLocalPageInfo that takes a mask of allowed types.
*/
struct mask_filter {
int mask;
__device__ inline bool operator()(PageInfo const& page) { return (page.kernel_mask & mask) != 0; }
uint32_t mask;

__device__ mask_filter(uint32_t m) : mask(m) {}
__device__ mask_filter(decode_kernel_mask m) : mask(static_cast<uint32_t>(m)) {}

__device__ inline bool operator()(PageInfo const& page)
{
return BitAnd(mask, page.kernel_mask) != 0;
}
};

/**
@@ -1306,6 +1313,7 @@ inline __device__ bool setupLocalPageInfo(page_state_s* const s,
s->dict_run = 0;
} break;
case Encoding::DELTA_BINARY_PACKED:
case Encoding::DELTA_BYTE_ARRAY:
// nothing to do, just don't error
break;
default: {
490 changes: 467 additions & 23 deletions cpp/src/io/parquet/page_delta_decode.cu

Large diffs are not rendered by default.

17 changes: 11 additions & 6 deletions cpp/src/io/parquet/page_hdr.cu
Original file line number Diff line number Diff line change
@@ -146,18 +146,21 @@ __device__ void skip_struct_field(byte_stream_s* bs, int field_type)
* @param chunk Column chunk the page belongs to
* @return `decode_kernel_mask` value for the given page
*/
__device__ uint32_t kernel_mask_for_page(PageInfo const& page, ColumnChunkDesc const& chunk)
__device__ decode_kernel_mask kernel_mask_for_page(PageInfo const& page,
ColumnChunkDesc const& chunk)
{
if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { return 0; }
if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { return decode_kernel_mask::NONE; }

if (page.encoding == Encoding::DELTA_BINARY_PACKED) {
return KERNEL_MASK_DELTA_BINARY;
return decode_kernel_mask::DELTA_BINARY;
} else if (page.encoding == Encoding::DELTA_BYTE_ARRAY) {
return decode_kernel_mask::DELTA_BYTE_ARRAY;
} else if (is_string_col(chunk)) {
return KERNEL_MASK_STRING;
return decode_kernel_mask::STRING;
}

// non-string, non-delta
return KERNEL_MASK_GENERAL;
return decode_kernel_mask::GENERAL;
}

/**
@@ -380,7 +383,9 @@ __global__ void __launch_bounds__(128)
bs->page.skipped_values = -1;
bs->page.skipped_leaf_values = 0;
bs->page.str_bytes = 0;
bs->page.kernel_mask = 0;
bs->page.temp_string_size = 0;
bs->page.temp_string_buf = nullptr;
bs->page.kernel_mask = decode_kernel_mask::NONE;
}
num_values = bs->ck.num_values;
page_info = bs->ck.page_info;
344 changes: 311 additions & 33 deletions cpp/src/io/parquet/page_string_decode.cu

Large diffs are not rendered by default.

79 changes: 64 additions & 15 deletions cpp/src/io/parquet/parquet_gpu.hpp
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@

#include <cuda_runtime.h>

#include <type_traits>
#include <vector>

namespace cudf::io::parquet::detail {
@@ -64,7 +65,8 @@ constexpr bool is_supported_encoding(Encoding enc)
case Encoding::PLAIN_DICTIONARY:
case Encoding::RLE:
case Encoding::RLE_DICTIONARY:
case Encoding::DELTA_BINARY_PACKED: return true;
case Encoding::DELTA_BINARY_PACKED:
case Encoding::DELTA_BYTE_ARRAY: return true;
default: return false;
}
}
@@ -86,13 +88,15 @@ constexpr void set_error(int32_t error, int32_t* error_code)
* These values are used as bitmasks, so they must be powers of 2.
*/
enum class decode_error : int32_t {
DATA_STREAM_OVERRUN = 0x1,
LEVEL_STREAM_OVERRUN = 0x2,
UNSUPPORTED_ENCODING = 0x4,
INVALID_LEVEL_RUN = 0x8,
INVALID_DATA_TYPE = 0x10,
EMPTY_PAGE = 0x20,
INVALID_DICT_WIDTH = 0x40,
DATA_STREAM_OVERRUN = 0x1,
LEVEL_STREAM_OVERRUN = 0x2,
UNSUPPORTED_ENCODING = 0x4,
INVALID_LEVEL_RUN = 0x8,
INVALID_DATA_TYPE = 0x10,
EMPTY_PAGE = 0x20,
INVALID_DICT_WIDTH = 0x40,
DELTA_PARAM_MISMATCH = 0x80,
DELTA_PARAMS_UNSUPPORTED = 0x100,
};

/**
@@ -145,6 +149,17 @@ constexpr uint32_t BitAnd(T1 a, T2 b)
return static_cast<uint32_t>(a) & static_cast<uint32_t>(b);
}

template <class T1,
class T2,
typename std::enable_if_t<(is_scoped_enum<T1>::value and std::is_same_v<T1, T2>) or
(is_scoped_enum<T1>::value and std::is_same_v<uint32_t, T2>) or
(is_scoped_enum<T2>::value and std::is_same_v<uint32_t, T1>)>* =
nullptr>
constexpr uint32_t BitOr(T1 a, T2 b)
{
return static_cast<uint32_t>(a) | static_cast<uint32_t>(b);
}

/**
* @brief Enums for the flags in the page header
*/
@@ -168,10 +183,12 @@ enum level_type {
*
* Used to control which decode kernels to run.
*/
enum kernel_mask_bits {
KERNEL_MASK_GENERAL = (1 << 0), // Run catch-all decode kernel
KERNEL_MASK_STRING = (1 << 1), // Run decode kernel for string data
KERNEL_MASK_DELTA_BINARY = (1 << 2) // Run decode kernel for DELTA_BINARY_PACKED data
enum class decode_kernel_mask {
NONE = 0,
GENERAL = (1 << 0), // Run catch-all decode kernel
STRING = (1 << 1), // Run decode kernel for string data
DELTA_BINARY = (1 << 2), // Run decode kernel for DELTA_BINARY_PACKED data
DELTA_BYTE_ARRAY = (1 << 3) // Run decode kernel for DELTA_BYTE_ARRAY encoded data
};

/**
@@ -252,9 +269,11 @@ struct PageInfo {
int32_t num_input_values;
int32_t chunk_row; // starting row of this page relative to the start of the chunk
int32_t num_rows; // number of rows in this page
// the next two are calculated in gpuComputePageStringSizes
// the next four are calculated in gpuComputePageStringSizes
int32_t num_nulls; // number of null values (V2 header), but recalculated for string cols
int32_t num_valids; // number of non-null values, taking into account skip_rows/num_rows
int32_t start_val; // index of first value of the string data stream to use
int32_t end_val; // index of last value in string data stream
int32_t chunk_idx; // column chunk this page belongs to
int32_t src_col_schema; // schema index of this column
uint8_t flags; // PAGEINFO_FLAGS_XXX
@@ -291,7 +310,11 @@ struct PageInfo {
// level decode buffers
uint8_t* lvl_decode_buf[level_type::NUM_LEVEL_TYPES];

uint32_t kernel_mask;
// temporary space for decoding DELTA_BYTE_ARRAY encoded strings
int64_t temp_string_size;
uint8_t* temp_string_buf;

decode_kernel_mask kernel_mask;
};

/**
@@ -597,16 +620,20 @@ void ComputePageSizes(cudf::detail::hostdevice_vector<PageInfo>& pages,
*
* @param[in,out] pages All pages to be decoded
* @param[in] chunks All chunks to be decoded
* @param[out] temp_string_buf Temporary space needed for decoding DELTA_BYTE_ARRAY strings
* @param[in] min_row Crop all rows below min_row
* @param[in] num_rows Maximum number of rows to read
* @param[in] level_type_size Size in bytes of the type for level decoding
* @param[in] kernel_mask Mask of kernels to run
* @param[in] stream CUDA stream to use
*/
void ComputePageStringSizes(cudf::detail::hostdevice_vector<PageInfo>& pages,
cudf::detail::hostdevice_vector<ColumnChunkDesc> const& chunks,
rmm::device_uvector<uint8_t>& temp_string_buf,
size_t min_row,
size_t num_rows,
int level_type_size,
uint32_t kernel_mask,
rmm::cuda_stream_view stream);

/**
@@ -665,7 +692,7 @@ void DecodeStringPageData(cudf::detail::hostdevice_vector<PageInfo>& pages,
* @param[in] min_row Minimum number of rows to read
* @param[in] level_type_size Size in bytes of the type for level decoding
* @param[out] error_code Error code for kernel failures
* @param[in] stream CUDA stream to use, default 0
* @param[in] stream CUDA stream to use
*/
void DecodeDeltaBinary(cudf::detail::hostdevice_vector<PageInfo>& pages,
cudf::detail::hostdevice_vector<ColumnChunkDesc> const& chunks,
@@ -675,6 +702,28 @@ void DecodeDeltaBinary(cudf::detail::hostdevice_vector<PageInfo>& pages,
int32_t* error_code,
rmm::cuda_stream_view stream);

/**
* @brief Launches kernel for reading the DELTA_BYTE_ARRAY column data stored in the pages
*
* The page data will be written to the output pointed to in the page's
* associated column chunk.
*
* @param[in,out] pages All pages to be decoded
* @param[in] chunks All chunks to be decoded
* @param[in] num_rows Total number of rows to read
* @param[in] min_row Index of the first row to read; rows below it are skipped
* @param[in] level_type_size Size in bytes of the type for level decoding
* @param[out] error_code Error code for kernel failures
* @param[in] stream CUDA stream to use
*/
void DecodeDeltaByteArray(cudf::detail::hostdevice_vector<PageInfo>& pages,
cudf::detail::hostdevice_vector<ColumnChunkDesc> const& chunks,
size_t num_rows,
size_t min_row,
int level_type_size,
int32_t* error_code,
rmm::cuda_stream_view stream);

/**
* @brief Launches kernel for initializing encoder row group fragments
*
38 changes: 23 additions & 15 deletions cpp/src/io/parquet/reader_impl.cpp
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@
#include <cudf/detail/transform.hpp>
#include <cudf/detail/utilities/stream_pool.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <rmm/cuda_stream_pool.hpp>

#include <bitset>
#include <numeric>
@@ -30,10 +29,15 @@ namespace cudf::io::parquet::detail {

void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows)
{
auto& chunks = _pass_itm_data->chunks;
auto& pages = _pass_itm_data->pages_info;
auto& page_nesting = _pass_itm_data->page_nesting_info;
auto& page_nesting_decode = _pass_itm_data->page_nesting_decode_info;
auto& chunks = _pass_itm_data->chunks;
auto& pages = _pass_itm_data->pages_info;
auto& page_nesting = _pass_itm_data->page_nesting_info;
auto& page_nesting_decode = _pass_itm_data->page_nesting_decode_info;
auto const level_type_size = _pass_itm_data->level_type_size;

// temporary space for DELTA_BYTE_ARRAY decoding. this only needs to live until
// DecodeDeltaByteArray returns.
rmm::device_uvector<uint8_t> delta_temp_buf(0, _stream);

// Should not reach here if there is no page data.
CUDF_EXPECTS(pages.size() > 0, "There is no page to decode");
@@ -52,11 +56,12 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows)
// doing a gather operation later on.
// TODO: This step is somewhat redundant if size info has already been calculated (nested schema,
// chunked reader).
auto const has_strings = (kernel_mask & KERNEL_MASK_STRING) != 0;
auto const has_strings =
(kernel_mask & BitOr(decode_kernel_mask::STRING, decode_kernel_mask::DELTA_BYTE_ARRAY)) != 0;
std::vector<size_t> col_sizes(_input_columns.size(), 0L);
if (has_strings) {
ComputePageStringSizes(
pages, chunks, skip_rows, num_rows, _pass_itm_data->level_type_size, _stream);
pages, chunks, delta_temp_buf, skip_rows, num_rows, level_type_size, kernel_mask, _stream);

col_sizes = calculate_page_string_offsets();

@@ -163,6 +168,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows)
chunks.host_to_device_async(_stream);
chunk_nested_valids.host_to_device_async(_stream);
chunk_nested_data.host_to_device_async(_stream);
if (has_strings) { chunk_nested_str_data.host_to_device_async(_stream); }

// create this before we fork streams
kernel_error error_code(_stream);
@@ -171,25 +177,27 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows)
int const nkernels = std::bitset<32>(kernel_mask).count();
auto streams = cudf::detail::fork_streams(_stream, nkernels);

auto const level_type_size = _pass_itm_data->level_type_size;

// launch string decoder
int s_idx = 0;
if (has_strings) {
auto& stream = streams[s_idx++];
chunk_nested_str_data.host_to_device_async(stream);
if (BitAnd(kernel_mask, decode_kernel_mask::STRING) != 0) {
DecodeStringPageData(
pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), stream);
pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]);
}

// launch delta byte array decoder
if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_BYTE_ARRAY) != 0) {
DecodeDeltaByteArray(
pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]);
}

// launch delta binary decoder
if ((kernel_mask & KERNEL_MASK_DELTA_BINARY) != 0) {
if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_BINARY) != 0) {
DecodeDeltaBinary(
pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]);
}

// launch the catch-all page decoder
if ((kernel_mask & KERNEL_MASK_GENERAL) != 0) {
if (BitAnd(kernel_mask, decode_kernel_mask::GENERAL) != 0) {
DecodePageData(
pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]);
}
4 changes: 2 additions & 2 deletions cpp/src/io/parquet/reader_impl_preprocess.cu
Original file line number Diff line number Diff line change
@@ -1416,15 +1416,15 @@ std::vector<size_t> reader::impl::calculate_page_string_offsets()
page_index.begin(), page_to_string_size{pages.device_ptr(), chunks.device_ptr()});

// do scan by key to calculate string offsets for each page
thrust::exclusive_scan_by_key(rmm::exec_policy(_stream),
thrust::exclusive_scan_by_key(rmm::exec_policy_nosync(_stream),
page_keys.begin(),
page_keys.end(),
val_iter,
page_offset_output_iter{pages.device_ptr(), page_index.data()});

// now sum up page sizes
rmm::device_uvector<int> reduce_keys(col_sizes.size(), _stream);
thrust::reduce_by_key(rmm::exec_policy(_stream),
thrust::reduce_by_key(rmm::exec_policy_nosync(_stream),
page_keys.begin(),
page_keys.end(),
val_iter,
Binary file not shown.
104 changes: 104 additions & 0 deletions python/cudf/cudf/tests/test_parquet.py
Original file line number Diff line number Diff line change
@@ -1284,6 +1284,15 @@ def test_parquet_reader_v2(tmpdir, simple_pdf):
assert_eq(cudf.read_parquet(pdf_fname), simple_pdf)


def test_parquet_delta_byte_array(datadir):
    """Read the ``delta_byte_arr.parquet`` fixture with cudf and with pandas
    and require the two results to compare equal."""
    parquet_path = datadir / "delta_byte_arr.parquet"
    got = cudf.read_parquet(parquet_path)
    expected = pd.read_parquet(parquet_path)
    assert_eq(got, expected)


def delta_num_rows():
    """Return the row counts used to parametrize the delta-encoding tests.

    A fresh list is built on every call.
    """
    small_counts = [1, 2, 23, 32, 33, 34, 64, 65, 66, 128, 129, 130]
    large_counts = [20000, 50000]
    return small_counts + large_counts


@pytest.mark.parametrize("nrows", [1, 100000])
@pytest.mark.parametrize("add_nulls", [True, False])
@pytest.mark.parametrize(
@@ -1320,6 +1329,7 @@ def test_delta_binary(nrows, add_nulls, dtype, tmpdir):
version="2.6",
column_encoding="DELTA_BINARY_PACKED",
data_page_version="2.0",
data_page_size=64 * 1024,
engine="pyarrow",
use_dictionary=False,
)
@@ -1350,6 +1360,100 @@ def test_delta_binary(nrows, add_nulls, dtype, tmpdir):
assert_eq(cdf2, cdf)


@pytest.mark.parametrize("nrows", delta_num_rows())
@pytest.mark.parametrize("add_nulls", [True, False])
@pytest.mark.parametrize("str_encoding", ["DELTA_BYTE_ARRAY"])
def test_delta_byte_array_roundtrip(nrows, add_nulls, str_encoding, tmpdir):
    """Write two random string columns with pandas/pyarrow using
    ``str_encoding`` and check that cudf reads back the same frame."""
    null_frequency = 0.25 if add_nulls else 0

    def string_column_meta(max_length):
        # Metadata for one string column; the columns differ only in the
        # maximum string length.
        return {
            "dtype": "str",
            "null_frequency": null_frequency,
            "cardinality": nrows,
            "max_string_length": max_length,
        }

    # Create a pandas dataframe with random data of mixed lengths
    generated = dg.rand_dataframe(
        dtypes_meta=[string_column_meta(10), string_column_meta(100)],
        rows=nrows,
        seed=0,
        use_threads=False,
    )
    test_pdf = generated.to_pandas()

    pdf_fname = tmpdir.join("pdfdeltaba.parquet")
    writer_kwargs = {
        "version": "2.6",
        "column_encoding": str_encoding,
        "data_page_version": "2.0",
        "data_page_size": 64 * 1024,
        "engine": "pyarrow",
        "use_dictionary": False,
    }
    test_pdf.to_parquet(pdf_fname, **writer_kwargs)

    read_back = cudf.read_parquet(pdf_fname)
    expected = cudf.from_pandas(test_pdf)
    assert_eq(read_back, expected)


@pytest.mark.parametrize("nrows", delta_num_rows())
@pytest.mark.parametrize("add_nulls", [True, False])
@pytest.mark.parametrize("str_encoding", ["DELTA_BYTE_ARRAY"])
def test_delta_struct_list(tmpdir, nrows, add_nulls, str_encoding):
    """Write a generated struct column (``sol``) whose four children use delta
    encodings, then check that cudf reads back what pandas wrote."""
    # Struct<List<List>>
    lists_per_row = 3
    list_size = 4

    def int_lists(x, y):
        # integer nested lists; validity is left at list_row_gen's default
        return list_row_gen(
            int_gen, x * list_size * lists_per_row, list_size, lists_per_row
        )

    def string_lists(x, y):
        # string nested lists; validity follows ``add_nulls``
        return list_row_gen(
            string_gen,
            x * list_size * lists_per_row,
            list_size,
            lists_per_row,
            add_nulls,
        )

    field_generators = [int_gen, string_gen, int_lists, string_lists]
    data = struct_gen(field_generators, 0, nrows, add_nulls)
    test_pdf = pa.Table.from_pydict({"sol": data}).to_pandas()

    # per-column encodings: integer leaves use DELTA_BINARY_PACKED, string
    # leaves use the parametrized string encoding
    encodings = {
        "sol.col0": "DELTA_BINARY_PACKED",
        "sol.col1": str_encoding,
        "sol.col2.list.element.list.element": "DELTA_BINARY_PACKED",
        "sol.col3.list.element.list.element": str_encoding,
    }

    pdf_fname = tmpdir.join("pdfdeltaba.parquet")
    test_pdf.to_parquet(
        pdf_fname,
        version="2.6",
        column_encoding=encodings,
        data_page_version="2.0",
        data_page_size=64 * 1024,
        engine="pyarrow",
        use_dictionary=False,
    )

    # sanity check to verify file is written properly
    assert_eq(test_pdf, pd.read_parquet(pdf_fname))

    read_back = cudf.read_parquet(pdf_fname)
    assert_eq(read_back, cudf.from_pandas(test_pdf))


@pytest.mark.parametrize(
"data",
[

0 comments on commit a51ab18

Please sign in to comment.