Optimization to decoding of parquet level streams #13203
Conversation
…, it was only 1 warp wide. Now it is block-wide. Only integrated into the gpuComputePageSizes() kernel. gpuDecodePages() will be a followup PR.
Bunch of minor comments, still need to fully understand the core algorithm.
…al with a performance issue introduced in gpuDecodePageData by previously changing them to be pointers instead of hardcoded arrays.
… buffer size from 4096 to 2048. Global scratch memory cost per page is now 8k instead of 32k. This will likely need to be tuned further as this optimization gets applied to the decode kernel.
…incorrectly sized (benign) run_buffer_size constexpr.
… cases, we only need 1 byte to store level information since size of the values is proportional to nesting depth.
std::review << few_small_comments << std::flush
This is good stuff! I've been using it in my own branch and it's really made a big difference. Looking forward to the sequel 😄.
A few const nits, but I can't see any other issues (and have not run into any in testing). My only question is the choice of 512 for `num_rle_stream_decode_threads`. Is there an existing corpus for tuning this value? In my own testing it has been the best overall, but there have been cases where a smaller value is optimal. I'm also curious what this choice means for Spark, where they'll have multiple decodes running concurrently, IIRC.
A first pass, still digesting the changes. 🔥
I've done benchmarking against some internal queries we use. Specifically, a case where we have 4 CPU threads running parquet jobs at the same time on the GPU. This has traditionally been sensitive to occupancy issues, so I kept an eye on that. I do suspect there will be tuning in the future: in particular because I think a useful post-optimization will be to balance out uneven run sizes across the decode warps - that'll certainly affect the # of useful warps.
Found a weird edge case, but otherwise good to go as far as I can see.
…ce condition when computing the number of skipped values during the preprocess step.
Looks great! 🔥
Thanks @nvdbaranec
Looks great, just a few questions/suggestions.
The current Parquet reader decodes string data into a list of {ptr, length} tuples, which are then used in a gather step by `make_strings_column`. This gather step can be time consuming, especially when there are a large number of string columns. This PR addresses this by changing the decode step to write char and offset data directly to the `column_buffer`, which can then be used directly, bypassing the gather step.

The image below compares the new approach to the old. The green arc at the top (82ms) is `gpuDecodePageData`, and the red arc (252ms) is the time spent in `make_strings_column`. The green arc below (25ms) is `gpuDecodePageData`, the amber arc (22ms) is a new kernel that computes string sizes for each page, and the magenta arc (106ms) is the kernel that decodes string columns.

![flat_edited](https://user-images.githubusercontent.com/25541553/236529570-f2d0d8d4-b2b5-4078-93ae-5123fa489c3c.png)

NVbench shows a good speedup for strings as well. There is a jump in time for the INTEGRAL benchmark, but little to no change for other data types. The INTEGRAL time seems to be affected by extra time spent in `malloc` allocating host memory for a `hostdevice_vector`. This `malloc` always occurs, but for some reason in this branch it takes much longer to return. This is comparing to @nvdbaranec's branch for #13203.
```
| data_type | io            | cardinality | run_length | Ref Time  | Cmp Time  | Diff         | %Diff   |
|-----------|---------------|-------------|------------|-----------|-----------|--------------|---------|
| INTEGRAL  | DEVICE_BUFFER | 0           | 1          | 14.288 ms | 14.729 ms | 440.423 us   | 3.08%   |
| INTEGRAL  | DEVICE_BUFFER | 1000        | 1          | 13.397 ms | 13.997 ms | 600.596 us   | 4.48%   |
| INTEGRAL  | DEVICE_BUFFER | 0           | 32         | 11.831 ms | 12.354 ms | 522.485 us   | 4.42%   |
| INTEGRAL  | DEVICE_BUFFER | 1000        | 32         | 11.335 ms | 11.854 ms | 518.791 us   | 4.58%   |
| FLOAT     | DEVICE_BUFFER | 0           | 1          | 8.681 ms  | 8.715 ms  | 34.846 us    | 0.40%   |
| FLOAT     | DEVICE_BUFFER | 1000        | 1          | 8.473 ms  | 8.472 ms  | -0.680 us    | -0.01%  |
| FLOAT     | DEVICE_BUFFER | 0           | 32         | 7.217 ms  | 7.192 ms  | -25.311 us   | -0.35%  |
| FLOAT     | DEVICE_BUFFER | 1000        | 32         | 7.425 ms  | 7.422 ms  | -3.162 us    | -0.04%  |
| STRING    | DEVICE_BUFFER | 0           | 1          | 50.079 ms | 42.566 ms | -7513.004 us | -15.00% |
| STRING    | DEVICE_BUFFER | 1000        | 1          | 16.813 ms | 14.989 ms | -1823.660 us | -10.85% |
| STRING    | DEVICE_BUFFER | 0           | 32         | 49.875 ms | 42.443 ms | -7432.718 us | -14.90% |
| STRING    | DEVICE_BUFFER | 1000        | 32         | 15.312 ms | 13.953 ms | -1358.910 us | -8.87%  |
| LIST      | DEVICE_BUFFER | 0           | 1          | 80.303 ms | 80.688 ms | 385.916 us   | 0.48%   |
| LIST      | DEVICE_BUFFER | 1000        | 1          | 71.921 ms | 72.356 ms | 435.153 us   | 0.61%   |
| LIST      | DEVICE_BUFFER | 0           | 32         | 61.658 ms | 62.129 ms | 471.022 us   | 0.76%   |
| LIST      | DEVICE_BUFFER | 1000        | 32         | 63.086 ms | 63.371 ms | 285.608 us   | 0.45%   |
| STRUCT    | DEVICE_BUFFER | 0           | 1          | 66.272 ms | 61.142 ms | -5130.639 us | -7.74%  |
| STRUCT    | DEVICE_BUFFER | 1000        | 1          | 40.217 ms | 39.328 ms | -888.781 us  | -2.21%  |
| STRUCT    | DEVICE_BUFFER | 0           | 32         | 63.660 ms | 58.837 ms | -4822.647 us | -7.58%  |
| STRUCT    | DEVICE_BUFFER | 1000        | 32         | 38.080 ms | 37.104 ms | -976.133 us  | -2.56%  |
```

May address #13024
~Depends on #13203~

Authors:
- Ed Seidl (https://github.com/etseidl)
- https://github.com/nvdbaranec
- Vukasin Milovanovic (https://github.com/vuule)
- Nghia Truong (https://github.com/ttnghia)

Approvers:
- Vukasin Milovanovic (https://github.com/vuule)
- Mike Wilson (https://github.com/hyperbolic2346)
- https://github.com/nvdbaranec
- Vyas Ramasubramani (https://github.com/vyasr)

URL: #13302
An optimization to the decoding of the definition and repetition level streams in Parquet files. Previously, we decoded these streams using 1 warp. With this optimization we do it arbitrarily wide (currently set to 512 threads). This gives a dramatic improvement.

The core of the work is in the new file `rle_stream.cuh`, which encapsulates the decoding into an `rle_stream` object. This PR only applies the optimization to the `gpuComputePageSizes` kernel, used for preprocessing list columns and for the chunked read case involving strings or lists. In addition, the `UpdatePageSizes` function has been improved to also work at the block level instead of just using a single warp. Testing with the cudf parquet reader list benchmarks results in as much as a 75% reduction in time in the `gpuComputePageSizes` kernel. Future PRs will apply this to the `gpuDecodePageData` kernel.

Leaving as a draft for the moment - more detailed benchmarks and numbers forthcoming, along with some possible parameter tuning.

Benchmark info: a before/after sample from the `parquet_reader_io_compression` suite on an A5000. The kernel goes from 427 milliseconds to 93 milliseconds. This seems to be a pretty typical situation, although it will definitely be affected by the encoded data (run lengths, etc). The reader benchmarks that involve this kernel yield some great improvements.