-
Notifications
You must be signed in to change notification settings - Fork 915
/
page_data.cu
2129 lines (1961 loc) · 81.7 KB
/
page_data.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parquet_gpu.hpp"
#include <io/utilities/block_utils.cuh>
#include <io/utilities/column_buffer.hpp>
#include <cuda/std/tuple>
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/detail/utilities/hash_functions.cuh>
#include <cudf/detail/utilities/integer_utils.hpp>
#include <cudf/strings/string_view.hpp>
#include <cudf/utilities/bit.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/functional.h>
#include <thrust/iterator/iterator_categories.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>
namespace cudf {
namespace io {
namespace parquet {
namespace gpu {
namespace {
constexpr int block_size = 128;
// Number of slots in each circular decode buffer (levels, dictionary indices, string lengths).
constexpr int non_zero_buffer_size = block_size * 2;
// rolling_index() wraps positions with a bit mask instead of a modulo, which is only correct
// when the buffer size is a power of two. Fail the build if block_size is ever changed so that
// this no longer holds.
static_assert(non_zero_buffer_size > 0 &&
                (non_zero_buffer_size & (non_zero_buffer_size - 1)) == 0,
              "non_zero_buffer_size must be a power of two for rolling_index()");
/**
 * @brief Maps a position onto a slot of a circular buffer holding non_zero_buffer_size entries.
 *
 * @param index Position to wrap
 *
 * @return index modulo non_zero_buffer_size (computed with a mask)
 */
constexpr int rolling_index(int index) { return index & (non_zero_buffer_size - 1); }
/**
 * @brief Per-page decoding state.
 *
 * Holds the read positions into the page's level and value data, the RLE run state of the level
 * and dictionary/boolean decoders, and circular buffers (indexed through rolling_index()) of
 * decoded repetition/definition levels. Helpers in this file update shared fields from thread 0
 * and follow the update with __threadfence_block() so other threads of the block observe it.
 */
struct page_state_s {
  const uint8_t* data_start;  // current read position in the page's value data; advanced as
                              // runs / values are consumed
  const uint8_t* data_end;    // end of the page's value data (exclusive bound for reads)
  const uint8_t* lvl_end;     // end of the level data; bound used while decoding level streams
  const uint8_t* dict_base;   // ptr to dictionary page data
  int32_t dict_size;          // size of dictionary data
  int32_t first_row;          // First row in page to output
  int32_t num_rows;           // Rows in page to decode (including rows to be skipped)
  int32_t first_output_value;  // First value in page to output
  int32_t num_input_values;    // total # of input/level values in the page
  int32_t dtype_len;           // Output data type length
  int32_t dtype_len_in;        // Can be larger than dtype_len if truncating 32-bit into 8-bit
  int32_t dict_bits;           // # of bits to store dictionary indices
  uint32_t dict_run;  // current dictionary/boolean RLE run header: (count << 1) | is_literal;
                      // a value <= 1 means the run is exhausted
  int32_t dict_val;   // value of the current repeated run; gpuInitStringDescriptors reuses it
                      // as the byte offset of the next plain-encoded string
  uint32_t initial_rle_run[NUM_LEVEL_TYPES];   // [def,rep] current level run header,
                                               // (count << 1) | is_literal
  int32_t initial_rle_value[NUM_LEVEL_TYPES];  // [def,rep] value of the current repeated run
  int32_t error;  // nonzero on a decode error (values set in this file: 2, 3, 0x10)
  PageInfo page;
  ColumnChunkDesc col;
  // (leaf) value decoding
  int32_t nz_count;  // number of valid entries in nz_idx (write position in circular buffer)
  int32_t dict_pos;  // write position of dictionary indices
  int32_t src_pos;   // input read position of final output value
  int32_t ts_scale;  // timestamp scale: <0: divide by -ts_scale, >0: multiply by ts_scale
  // repetition/definition level decoding
  int32_t input_value_count;                   // how many values of the input we've processed
  int32_t input_row_count;                     // how many rows of the input we've processed
  int32_t input_leaf_count;                    // how many leaf values of the input we've processed
  uint32_t rep[non_zero_buffer_size];          // circular buffer of repetition level values
  uint32_t def[non_zero_buffer_size];          // circular buffer of definition level values
  const uint8_t* lvl_start[NUM_LEVEL_TYPES];   // [def,rep] current read position in each
                                               // level stream
  int32_t lvl_count[NUM_LEVEL_TYPES];          // how many of each of the streams we've decoded
  int32_t row_index_lower_bound;               // lower bound of row indices we should process
  // a shared-memory cache of frequently used data when decoding. The source of this data is
  // normally stored in global memory which can yield poor performance. So, when possible
  // we copy that info here prior to decoding
  PageNestingDecodeInfo nesting_decode_cache[max_cacheable_nesting_decode_info];
  // points to either nesting_decode_cache above when possible, or to the global source otherwise
  PageNestingDecodeInfo* nesting_info;
};
// Buffers used only by the decode kernel. They are kept separate from page_state_s so that
// shared memory usage stays lower in the other kernels (e.g. gpuComputePageSizes).
/**
 * @brief Per-value circular buffers filled by the value decoders; every array is indexed
 * through rolling_index().
 */
struct page_state_buffers_s {
  uint32_t nz_idx[non_zero_buffer_size];  // circular buffer of non-null value positions
  // Dictionary indices (gpuDecodeDictionaryIndices), boolean values (gpuDecodeRleBooleans), or
  // byte offsets of plain-encoded strings relative to data_start (gpuInitStringDescriptors)
  uint32_t dict_idx[non_zero_buffer_size];
  uint32_t str_len[non_zero_buffer_size];  // String length for plain encoding of strings
};
/**
 * @brief Determines whether a page straddles the first or the last row of a row range.
 *
 * A boundary counts as straddled when its row index lies in the closed interval
 * [page_first, page_last]. page_first is the page's absolute first row
 * (chunk start row + page offset within the chunk). page_last is page_first + page row count.
 *
 * @param s The page state whose page is examined
 * @param start_row First row index of the range
 * @param num_rows Number of rows in the range
 *
 * @return true if start_row or start_row + num_rows falls inside the page's row span
 */
inline __device__ bool is_bounds_page(page_state_s* const s, size_t start_row, size_t num_rows)
{
  size_t const page_first = s->col.start_row + s->page.chunk_row;
  size_t const page_last  = page_first + s->page.num_rows;
  size_t const range_end  = start_row + num_rows;

  bool const covers_range_start = page_first <= start_row && start_row <= page_last;
  bool const covers_range_end   = page_first <= range_end && range_end <= page_last;
  return covers_range_start || covers_range_end;
}
/**
 * @brief Determines whether a page lies entirely inside a row range.
 *
 * @param s The page state whose page is examined
 * @param start_row First row index of the range
 * @param num_rows Number of rows in the range
 *
 * @return true if the page neither begins before start_row nor ends after
 * start_row + num_rows
 */
inline __device__ bool is_page_contained(page_state_s* const s, size_t start_row, size_t num_rows)
{
  size_t const page_first = s->col.start_row + s->page.chunk_row;
  size_t const page_last  = page_first + s->page.num_rows;
  size_t const range_end  = start_row + num_rows;

  bool const starts_too_early = page_first < start_row;
  bool const ends_too_late    = page_last > range_end;
  return !(starts_too_early || ends_too_late);
}
/**
 * @brief Read an unsigned LEB128 varint of at most 5 bytes (a 32-bit value)
 *
 * Each byte contributes its low 7 bits, least-significant group first. A set high bit means
 * another byte follows. Reading stops early when `end` is reached. The first byte is read
 * unconditionally, so callers must ensure cur < end.
 *
 * All shifts are performed on uint32_t. Shifting a byte that was promoted to signed int left by
 * 28 overflows for byte values >= 8, which is undefined behavior.
 *
 * @param[in,out] cur The current data position, updated after the read
 * @param[in] end The end data position
 *
 * @return The 32-bit value read
 */
inline __host__ __device__ uint32_t get_vlq32(const uint8_t*& cur, const uint8_t* end)
{
  uint32_t v = *cur++;
  if (v >= 0x80u && cur < end) {
    v = (v & 0x7fu) | (static_cast<uint32_t>(*cur++) << 7);
    if (v >= (0x80u << 7) && cur < end) {
      v = (v & ((0x7fu << 7) | 0x7fu)) | (static_cast<uint32_t>(*cur++) << 14);
      if (v >= (0x80u << 14) && cur < end) {
        v = (v & ((0x7fu << 14) | (0x7fu << 7) | 0x7fu)) | (static_cast<uint32_t>(*cur++) << 21);
        if (v >= (0x80u << 21) && cur < end) {
          v = (v & ((0x7fu << 21) | (0x7fu << 14) | (0x7fu << 7) | 0x7fu)) |
              (static_cast<uint32_t>(*cur++) << 28);
        }
      }
    }
  }
  return v;
}
/**
 * @brief Parse the beginning of a level section (definition or repetition), initialize the
 * initial RLE run and value for that level, and return the section length.
 *
 * The RLE section length has two possible sources:
 * - When the page header supplies def_lvl_bytes / rep_lvl_bytes, that size is used and no prefix
 *   is consumed.
 * - Otherwise the section starts with a 4-byte little-endian length prefix. The returned length
 *   includes those 4 bytes.
 *
 * On malformed input, s->error is set: 2 for a missing prefix or an overrun while reading the
 * first run header, 3 for an unsupported level encoding.
 *
 * @param[in,out] s The page state; initial_rle_run, initial_rle_value, lvl_start and error may be
 * updated
 * @param[in] cur The current data position (start of the level section)
 * @param[in] end The end of the data
 * @param[in] lvl The level type being initialized (DEFINITION or REPETITION)
 *
 * @return The length of the section in bytes
 */
__device__ uint32_t InitLevelSection(page_state_s* s,
                                     const uint8_t* cur,
                                     const uint8_t* end,
                                     level_type lvl)
{
  uint32_t len;
  int level_bits    = s->col.level_bits[lvl];
  Encoding encoding = lvl == level_type::DEFINITION ? s->page.definition_level_encoding
                                                    : s->page.repetition_level_encoding;

  if (level_bits == 0) {
    // No level data is stored: treat every value as level 0, i.e. a single repeated run covering
    // all input values of the page.
    len                       = 0;
    s->initial_rle_run[lvl]   = s->page.num_input_values * 2;  // repeated value
    s->initial_rle_value[lvl] = 0;
    s->lvl_start[lvl]         = cur;
  } else if (encoding == Encoding::RLE) {
    // V2 only uses RLE encoding, so only perform check here
    if (s->page.def_lvl_bytes || s->page.rep_lvl_bytes) {
      // Section size comes from the page header; there is no length prefix to skip.
      len = lvl == level_type::DEFINITION ? s->page.def_lvl_bytes : s->page.rep_lvl_bytes;
    } else if (cur + 4 < end) {
      // 4-byte little-endian length prefix. Assemble it in unsigned arithmetic: shifting a byte
      // that was promoted to signed int left by 24 overflows when its top bit is set.
      len = 4 + (static_cast<uint32_t>(cur[0]) | (static_cast<uint32_t>(cur[1]) << 8) |
                 (static_cast<uint32_t>(cur[2]) << 16) | (static_cast<uint32_t>(cur[3]) << 24));
      cur += 4;
    } else {
      len      = 0;
      s->error = 2;
    }
    if (!s->error) {
      // Read the first run header; the level stream decoder continues from
      // initial_rle_run / initial_rle_value / lvl_start.
      uint32_t run            = get_vlq32(cur, end);
      s->initial_rle_run[lvl] = run;
      if (!(run & 1)) {
        // Repeated run: the value follows in 1 byte, or 2 bytes when level_bits > 8.
        int v = (cur < end) ? cur[0] : 0;
        cur++;
        if (level_bits > 8) {
          v |= ((cur < end) ? cur[0] : 0) << 8;
          cur++;
        }
        s->initial_rle_value[lvl] = v;
      }
      s->lvl_start[lvl] = cur;
      if (cur > end) { s->error = 2; }
    }
  } else if (encoding == Encoding::BIT_PACKED) {
    // Bit-packed levels carry no run headers: model the whole section as one literal run of
    // ceil(num_input_values / 8) groups of 8 values.
    len                       = (s->page.num_input_values * level_bits + 7) >> 3;
    s->initial_rle_run[lvl]   = ((s->page.num_input_values + 7) >> 3) * 2 + 1;  // literal run
    s->initial_rle_value[lvl] = 0;
    s->lvl_start[lvl]         = cur;
  } else {
    s->error = 3;
    len      = 0;
  }
  return len;
}
/**
 * @brief Decode values out of a definition or repetition stream
 *
 * Decodes the RLE / bit-packed hybrid level stream selected by `lvl` with one warp (t in 0..31).
 * Each loop iteration decodes up to 32 values into the circular `output` buffer.
 *
 * Decoding stops once at least `target_count` values (or all `s->num_input_values`) have been
 * produced, so the count may overshoot `target_count` by up to 31. Thread 0 writes the run state
 * back into `s` so that a later call resumes where this one stopped. A malformed run header sets
 * `s->error` to 0x10 and stops decoding.
 *
 * @param[out] output Circular buffer (indexed through rolling_index()) receiving decoded levels
 * @param[in,out] s Page state input/output
 * @param[in] target_count Target count of stream values on output
 * @param[in] t Warp0 thread ID (0..31)
 * @param[in] lvl The level type we are decoding - DEFINITION or REPETITION
 */
__device__ void gpuDecodeStream(
  uint32_t* output, page_state_s* s, int32_t target_count, int t, level_type lvl)
{
  const uint8_t* cur_def   = s->lvl_start[lvl];
  const uint8_t* end       = s->lvl_end;
  uint32_t level_run       = s->initial_rle_run[lvl];
  int32_t level_val        = s->initial_rle_value[lvl];
  int level_bits           = s->col.level_bits[lvl];
  int32_t num_input_values = s->num_input_values;
  int32_t value_count      = s->lvl_count[lvl];
  // NOTE(review): batch_coded_count is accumulated below but never read.
  int32_t batch_coded_count = 0;
  while (value_count < target_count && value_count < num_input_values) {
    int batch_len;
    // level_run holds (count << 1) | is_literal, where count is a number of values for repeated
    // runs and a number of 8-value groups for literal runs; <= 1 means the run is exhausted.
    if (level_run <= 1) {
      // Get a new run symbol from the byte stream
      int sym_len = 0;
      if (!t) {
        const uint8_t* cur = cur_def;
        if (cur < end) { level_run = get_vlq32(cur, end); }
        if (!(level_run & 1)) {
          if (cur < end) level_val = cur[0];
          cur++;
          if (level_bits > 8) {
            if (cur < end) level_val |= cur[0] << 8;
            cur++;
          }
        }
        if (cur > end || level_run <= 1) { s->error = 0x10; }
        sym_len = (int32_t)(cur - cur_def);
        __threadfence_block();
      }
      // Distribute thread 0's parsed header to the rest of the warp (shuffle() is defined
      // elsewhere; presumably it broadcasts lane 0's value).
      sym_len   = shuffle(sym_len);
      level_val = shuffle(level_val);
      level_run = shuffle(level_run);
      cur_def += sym_len;
    }
    if (s->error) { break; }
    batch_len = min(num_input_values - value_count, 32);
    if (level_run & 1) {
      // Literal run
      int batch_len8;
      batch_len  = min(batch_len, (level_run >> 1) * 8);
      batch_len8 = (batch_len + 7) >> 3;
      if (t < batch_len) {
        // Each lane extracts the level_bits-wide value starting at bit t * level_bits.
        int bitpos         = t * level_bits;
        const uint8_t* cur = cur_def + (bitpos >> 3);
        bitpos &= 7;
        if (cur < end) level_val = cur[0];
        cur++;
        if (level_bits > 8 - bitpos && cur < end) {
          level_val |= cur[0] << 8;
          cur++;
          if (level_bits > 16 - bitpos && cur < end) level_val |= cur[0] << 16;
        }
        level_val = (level_val >> bitpos) & ((1 << level_bits) - 1);
      }
      level_run -= batch_len8 * 2;
      // batch_len8 groups of 8 values occupy batch_len8 * level_bits bytes
      cur_def += batch_len8 * level_bits;
    } else {
      // Repeated value
      batch_len = min(batch_len, level_run >> 1);
      level_run -= batch_len * 2;
    }
    if (t < batch_len) {
      int idx                    = value_count + t;
      output[rolling_index(idx)] = level_val;
    }
    batch_coded_count += batch_len;
    value_count += batch_len;
  }
  // update the stream info
  if (!t) {
    s->lvl_start[lvl]         = cur_def;
    s->initial_rle_run[lvl]   = level_run;
    s->initial_rle_value[lvl] = level_val;
    s->lvl_count[lvl]         = value_count;
  }
}
/**
 * @brief Performs RLE decoding of dictionary indexes
 *
 * Decodes the RLE / bit-packed hybrid stream of dictionary indices starting at s->data_start.
 * Each loop iteration handles one batch of at most 32 indices until `target_pos` is reached.
 * Thread 0 parses run headers and stores the run state (s->dict_run, s->dict_val,
 * s->data_start); the batch length and literal flag are then broadcast to the warp.
 *
 * @tparam sizes_only If false, the decoded indices are stored in sb->dict_idx. If true, they are
 * not stored; instead the lengths of the referenced dictionary strings are summed.
 *
 * @param[in,out] s Page state input/output
 * @param[out] sb Page state buffer output (only written when sizes_only is false)
 * @param[in] target_pos Target index position in dict_idx buffer (may exceed this value by up to
 * 31)
 * @param[in] t Warp1 thread ID (0..31)
 *
 * @return A pair containing the new output position, and the total length of strings decoded.
 * The length is only valid on thread 0 and only if sizes_only is true. Positions decoded at or
 * beyond target_pos contribute 0 to the total, even though the returned position may exceed
 * target_pos.
 */
template <bool sizes_only>
__device__ cuda::std::pair<int, int> gpuDecodeDictionaryIndices(
  volatile page_state_s* s,
  [[maybe_unused]] volatile page_state_buffers_s* sb,
  int target_pos,
  int t)
{
  const uint8_t* end = s->data_end;
  int dict_bits      = s->dict_bits;
  int pos            = s->dict_pos;
  int str_len        = 0;
  while (pos < target_pos) {
    int is_literal, batch_len;
    if (!t) {
      // Thread 0 advances the run state serially.
      uint32_t run       = s->dict_run;
      const uint8_t* cur = s->data_start;
      if (run <= 1) {
        // Current run exhausted: read the next run header.
        run = (cur < end) ? get_vlq32(cur, end) : 0;
        if (!(run & 1)) {
          // Repeated value
          int bytecnt = (dict_bits + 7) >> 3;
          if (cur + bytecnt <= end) {
            int32_t run_val = cur[0];
            if (bytecnt > 1) {
              run_val |= cur[1] << 8;
              if (bytecnt > 2) {
                run_val |= cur[2] << 16;
                if (bytecnt > 3) { run_val |= cur[3] << 24; }
              }
            }
            s->dict_val = run_val & ((1 << dict_bits) - 1);
          }
          cur += bytecnt;
        }
      }
      if (run & 1) {
        // Literal batch: must output a multiple of 8, except for the last batch
        int batch_len_div8;
        batch_len      = max(min(32, (int)(run >> 1) * 8), 1);
        batch_len_div8 = (batch_len + 7) >> 3;
        run -= batch_len_div8 * 2;
        cur += batch_len_div8 * dict_bits;
      } else {
        batch_len = max(min(32, (int)(run >> 1)), 1);
        run -= batch_len * 2;
      }
      s->dict_run   = run;
      s->data_start = cur;
      is_literal    = run & 1;
      __threadfence_block();
    }
    __syncwarp();
    is_literal = shuffle(is_literal);
    batch_len  = shuffle(batch_len);
    // compute dictionary index.
    int dict_idx = 0;
    if (t < batch_len) {
      dict_idx = s->dict_val;
      if (is_literal) {
        // s->data_start already points past this literal batch, so the bit offset is negative.
        // Relative to the start of the batch, lane t reads the dict_bits-wide value at bit
        // t * dict_bits.
        int32_t ofs      = (t - ((batch_len + 7) & ~7)) * dict_bits;
        const uint8_t* p = s->data_start + (ofs >> 3);
        ofs &= 7;
        if (p < end) {
          uint32_t c = 8 - ofs;
          dict_idx   = (*p++) >> ofs;
          if (c < dict_bits && p < end) {
            dict_idx |= (*p++) << c;
            c += 8;
            if (c < dict_bits && p < end) {
              dict_idx |= (*p++) << c;
              c += 8;
              if (c < dict_bits && p < end) { dict_idx |= (*p++) << c; }
            }
          }
          dict_idx &= (1 << dict_bits) - 1;
        }
      }
      // if we're not computing sizes, store off the dictionary index
      if constexpr (!sizes_only) { sb->dict_idx[rolling_index(pos + t)] = dict_idx; }
    }
    // if we're computing sizes, add the length(s)
    if constexpr (sizes_only) {
      // Lanes outside the batch, or at/after target_pos, contribute 0.
      int const len = [&]() {
        if (t >= batch_len || (pos + t >= target_pos)) { return 0; }
        uint32_t const dict_pos = (s->dict_bits > 0) ? dict_idx * sizeof(string_index_pair) : 0;
        if (dict_pos < (uint32_t)s->dict_size) {
          const auto* src = reinterpret_cast<const string_index_pair*>(s->dict_base + dict_pos);
          return src->second;
        }
        return 0;
      }();
      using WarpReduce = cub::WarpReduce<size_type>;
      __shared__ typename WarpReduce::TempStorage temp_storage;
      // note: str_len will only be valid on thread 0.
      str_len += WarpReduce(temp_storage).Sum(len);
    }
    pos += batch_len;
  }
  return {pos, str_len};
}
/**
 * @brief Performs RLE decoding of boolean values (a 1-bit RLE / bit-packed hybrid stream)
 *
 * Uses the same batching scheme as dictionary index decoding, with a fixed bit width of 1:
 * - Repeated runs store their value in the low bit of a single byte.
 * - Literal runs pack 8 values per byte.
 *
 * Thread 0 parses run headers and updates s->dict_run, s->dict_val and s->data_start. Each lane
 * then writes one decoded value (0 or 1) to sb->dict_idx.
 *
 * @param[in,out] s Page state input/output
 * @param[out] sb Page state buffer output
 * @param[in] target_pos Target write position (may be exceeded by up to 31)
 * @param[in] t Thread ID
 *
 * @return The new output position
 */
__device__ int gpuDecodeRleBooleans(volatile page_state_s* s,
                                    volatile page_state_buffers_s* sb,
                                    int target_pos,
                                    int t)
{
  const uint8_t* end = s->data_end;
  int pos            = s->dict_pos;
  while (pos < target_pos) {
    int is_literal, batch_len;
    if (!t) {
      uint32_t run       = s->dict_run;
      const uint8_t* cur = s->data_start;
      if (run <= 1) {
        // Current run exhausted: read the next run header.
        run = (cur < end) ? get_vlq32(cur, end) : 0;
        if (!(run & 1)) {
          // Repeated value
          s->dict_val = (cur < end) ? cur[0] & 1 : 0;
          cur++;
        }
      }
      if (run & 1) {
        // Literal batch: must output a multiple of 8, except for the last batch
        int batch_len_div8;
        batch_len = max(min(32, (int)(run >> 1) * 8), 1);
        // NOTE(review): min(32, k * 8) is already a multiple of 8, so this mask appears to be a
        // no-op.
        if (batch_len >= 8) { batch_len &= ~7; }
        batch_len_div8 = (batch_len + 7) >> 3;
        run -= batch_len_div8 * 2;
        cur += batch_len_div8;
      } else {
        batch_len = max(min(32, (int)(run >> 1)), 1);
        run -= batch_len * 2;
      }
      s->dict_run   = run;
      s->data_start = cur;
      is_literal    = run & 1;
      __threadfence_block();
    }
    __syncwarp();
    is_literal = shuffle(is_literal);
    batch_len  = shuffle(batch_len);
    if (t < batch_len) {
      int dict_idx;
      if (is_literal) {
        // s->data_start already points past this literal batch, so the bit offset is negative;
        // relative to the batch start, lane t reads bit t.
        int32_t ofs      = t - ((batch_len + 7) & ~7);
        const uint8_t* p = s->data_start + (ofs >> 3);
        dict_idx         = (p < end) ? (p[0] >> (ofs & 7u)) & 1 : 0;
      } else {
        dict_idx = s->dict_val;
      }
      sb->dict_idx[rolling_index(pos + t)] = dict_idx;
    }
    pos += batch_len;
  }
  return pos;
}
/**
 * @brief Parses the length and position of plain-encoded strings (a 4-byte little-endian length
 * followed by the string bytes) and returns the total length of all strings processed
 *
 * The parse is serial and runs on thread 0 only. The running byte offset is kept in s->dict_val,
 * so a later call resumes where this one stopped.
 *
 * A string is recorded with length 0 in either of these cases:
 * - its length prefix does not fit inside s->dict_size;
 * - its declared length would run past s->dict_size (including corrupt prefixes with the top
 *   bit set).
 *
 * @tparam sizes_only If true, only the total length is computed and `sb` is not written
 *
 * @param[in,out] s Page state input/output
 * @param[out] sb Page state buffer output: byte offset (dict_idx) and length (str_len) per
 * position
 * @param[in] target_pos Target output position
 * @param[in] t Thread ID
 *
 * @return Total length of strings processed (meaningful on thread 0 only; other threads return 0)
 */
template <bool sizes_only>
__device__ size_type gpuInitStringDescriptors(volatile page_state_s* s,
                                              [[maybe_unused]] volatile page_state_buffers_s* sb,
                                              int target_pos,
                                              int t)
{
  int pos       = s->dict_pos;
  int total_len = 0;
  // This step is purely serial
  if (!t) {
    const uint8_t* cur = s->data_start;
    int dict_size      = s->dict_size;
    int k              = s->dict_val;
    while (pos < target_pos) {
      int len = 0;
      if (k + 4 <= dict_size) {
        // Assemble the prefix in unsigned arithmetic so that a corrupt prefix with the top bit
        // set can never turn into a negative length (which would move k backwards).
        uint32_t const raw_len = static_cast<uint32_t>(cur[k]) |
                                 (static_cast<uint32_t>(cur[k + 1]) << 8) |
                                 (static_cast<uint32_t>(cur[k + 2]) << 16) |
                                 (static_cast<uint32_t>(cur[k + 3]) << 24);
        k += 4;
        // Compare against the remaining byte count instead of computing k + len, which could
        // overflow for very large declared lengths. k <= dict_size holds here.
        if (raw_len <= static_cast<uint32_t>(dict_size - k)) { len = static_cast<int>(raw_len); }
      }
      if constexpr (!sizes_only) {
        sb->dict_idx[rolling_index(pos)] = k;
        sb->str_len[rolling_index(pos)]  = len;
      }
      k += len;
      total_len += len;
      pos++;
    }
    s->dict_val = k;
    __threadfence_block();
  }
  return total_len;
}
/**
 * @brief Looks up the pointer and length of the string at a given source position
 *
 * The lookup depends on whether the page uses a dictionary:
 * - With a dictionary (s->dict_base set), the buffered dictionary index selects a
 *   string_index_pair entry of the dictionary.
 * - Otherwise, the buffered byte offset and length (from plain-string parsing) are applied
 *   relative to s->data_start.
 *
 * Out-of-range entries yield {nullptr, 0}.
 *
 * @param[in] s Page state input
 * @param[in] sb Page state buffers holding the per-position index/offset and length
 * @param[in] src_pos Source position
 *
 * @return A pair containing a pointer to the string and its length
 */
inline __device__ cuda::std::pair<const char*, size_t> gpuGetStringData(
  volatile page_state_s* s, volatile page_state_buffers_s* sb, int src_pos)
{
  int const slot = rolling_index(src_pos);

  if (!s->dict_base) {
    // Plain encoding: dict_idx is a byte offset into the page data and str_len is the length.
    // An offset equal to the data size is accepted (it can only describe an empty string).
    uint32_t const offset = sb->dict_idx[slot];
    if (offset > (uint32_t)s->dict_size) { return {nullptr, 0}; }
    const char* const ptr = reinterpret_cast<const char*>(s->data_start + offset);
    size_t const len      = sb->str_len[slot];
    return {ptr, len};
  }

  // String dictionary: dict_idx is an entry index into an array of string_index_pair.
  uint32_t const entry_byte_pos =
    (s->dict_bits > 0) ? sb->dict_idx[slot] * sizeof(string_index_pair) : 0;
  if (entry_byte_pos >= (uint32_t)s->dict_size) { return {nullptr, 0}; }
  const auto* entry = reinterpret_cast<const string_index_pair*>(s->dict_base + entry_byte_pos);
  const char* const ptr = entry->first;
  size_t const len      = entry->second;
  return {ptr, len};
}
/**
 * @brief Output either a string descriptor or a 32-bit hash of the string at src_pos
 *
 * When the output element is 4 bytes wide (s->dtype_len == 4), a MurmurHash3_32 of the string is
 * written; this is used when strings are converted to categoricals. Otherwise the
 * (pointer, length) descriptor is written.
 *
 * @param[in,out] s Page state input/output
 * @param[in] sb Page state buffers
 * @param[in] src_pos Source position
 * @param[out] dstv Pointer to row output data (string descriptor or 32-bit hash)
 */
inline __device__ void gpuOutputString(volatile page_state_s* s,
                                       volatile page_state_buffers_s* sb,
                                       int src_pos,
                                       void* dstv)
{
  auto const [str_ptr, str_len] = gpuGetStringData(s, sb, src_pos);

  if (s->dtype_len != 4) {
    // Output string descriptor
    auto* descriptor   = static_cast<string_index_pair*>(dstv);
    descriptor->first  = str_ptr;
    descriptor->second = str_len;
    return;
  }

  // Output hash; the seed value is arbitrary but must stay fixed for consistent categories.
  uint32_t constexpr hash_seed = 33;
  cudf::string_view const view{str_ptr, static_cast<size_type>(str_len)};
  auto const hasher               = cudf::detail::MurmurHash3_32<cudf::string_view>{hash_seed};
  *static_cast<uint32_t*>(dstv)   = hasher(view);
}
/**
 * @brief Output a boolean by copying the value buffered for src_pos into one output byte
 *
 * @param[in] sb Page state buffers holding decoded boolean values
 * @param[in] src_pos Source position
 * @param[out] dst Pointer to row output data
 */
inline __device__ void gpuOutputBoolean(volatile page_state_buffers_s* sb,
                                        int src_pos,
                                        uint8_t* dst)
{
  uint32_t const decoded = sb->dict_idx[rolling_index(src_pos)];
  dst[0]                 = static_cast<uint8_t>(decoded);
}
/**
 * @brief Store a 32-bit data element read from a possibly unaligned source
 *
 * The source base pointer is rounded down to a 4-byte boundary so every load is an aligned 32-bit
 * load. Any misalignment is removed afterwards with a funnel shift. Positions at or beyond
 * dict_size store 0.
 *
 * @param[out] dst ptr to output
 * @param[in] src8 raw input bytes
 * @param[in] dict_pos byte position in dictionary
 * @param[in] dict_size size of dictionary
 */
inline __device__ void gpuStoreOutput(uint32_t* dst,
                                      const uint8_t* src8,
                                      uint32_t dict_pos,
                                      uint32_t dict_size)
{
  if (dict_pos >= dict_size) {
    *dst = 0;
    return;
  }

  unsigned int const misalign_bytes = 3 & reinterpret_cast<size_t>(src8);
  const uint8_t* const aligned_base = src8 - misalign_bytes;
  unsigned int const shift_bits     = misalign_bytes << 3;

  auto const* words = reinterpret_cast<const uint32_t*>(aligned_base + dict_pos);
  uint32_t value    = words[0];
  if (shift_bits != 0) { value = __funnelshift_r(value, words[1], shift_bits); }
  *dst = value;
}
/**
 * @brief Store a 64-bit data element (as two 32-bit words) read from a possibly unaligned source
 *
 * The source base pointer is rounded down to a 4-byte boundary so every load is an aligned 32-bit
 * load. Any misalignment is removed afterwards with funnel shifts. Positions at or beyond
 * dict_size store zeros.
 *
 * @param[out] dst ptr to output
 * @param[in] src8 raw input bytes
 * @param[in] dict_pos byte position in dictionary
 * @param[in] dict_size size of dictionary
 */
inline __device__ void gpuStoreOutput(uint2* dst,
                                      const uint8_t* src8,
                                      uint32_t dict_pos,
                                      uint32_t dict_size)
{
  uint2 out;
  out.x = 0;
  out.y = 0;

  if (dict_pos < dict_size) {
    unsigned int const misalign_bytes = 3 & reinterpret_cast<size_t>(src8);
    const uint8_t* const aligned_base = src8 - misalign_bytes;
    unsigned int const shift_bits     = misalign_bytes << 3;

    auto const* words = reinterpret_cast<const uint32_t*>(aligned_base + dict_pos);
    out.x             = words[0];
    out.y             = words[1];
    if (shift_bits != 0) {
      uint32_t const spill = words[2];
      out.x                = __funnelshift_r(out.x, out.y, shift_bits);
      out.y                = __funnelshift_r(out.y, spill, shift_bits);
    }
  }
  *dst = out;
}
/**
 * @brief Convert an INT96 Spark timestamp to 64-bit timestamp
 *
 * Reads the 12-byte value at src_pos, through the dictionary when s->dict_base is set and from
 * plain page data otherwise. The value is laid out as follows:
 * - bytes 0-7: `nanos` (low 32-bit word first);
 * - bytes 8-11: a Julian day number.
 *
 * The result is expressed in the column's clock rate (s->col.ts_clock_rate: 1, 1'000,
 * 1'000'000, or nanoseconds otherwise). Out-of-range positions produce 0.
 *
 * @param[in,out] s Page state input/output
 * @param[in] sb Page state buffers (dictionary indices)
 * @param[in] src_pos Source position
 * @param[out] dst Pointer to row output data
 */
inline __device__ void gpuOutputInt96Timestamp(volatile page_state_s* s,
                                               volatile page_state_buffers_s* sb,
                                               int src_pos,
                                               int64_t* dst)
{
  using cuda::std::chrono::duration_cast;
  const uint8_t* src8;
  uint32_t dict_pos, dict_size = s->dict_size, ofs;
  if (s->dict_base) {
    // Dictionary
    dict_pos = (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0;
    src8     = s->dict_base;
  } else {
    // Plain
    dict_pos = src_pos;
    src8     = s->data_start;
  }
  dict_pos *= (uint32_t)s->dtype_len_in;
  ofs = 3 & reinterpret_cast<size_t>(src8);
  src8 -= ofs;  // align to 32-bit boundary
  ofs <<= 3;    // bytes -> bits
  // NOTE(review): this only guarantees dict_pos + 5 <= dict_size, yet 12 bytes (16 when the
  // source is misaligned) are read below; confirm whether dict_pos + 12 > dict_size was intended.
  if (dict_pos + 4 >= dict_size) {
    *dst = 0;
    return;
  }
  uint3 v;
  int64_t nanos, days;
  v.x = *reinterpret_cast<const uint32_t*>(src8 + dict_pos + 0);
  v.y = *reinterpret_cast<const uint32_t*>(src8 + dict_pos + 4);
  v.z = *reinterpret_cast<const uint32_t*>(src8 + dict_pos + 8);
  if (ofs) {
    // Undo the misalignment: shift each word right, pulling bits in from the following word.
    uint32_t next = *reinterpret_cast<const uint32_t*>(src8 + dict_pos + 12);
    v.x           = __funnelshift_r(v.x, v.y, ofs);
    v.y           = __funnelshift_r(v.y, v.z, ofs);
    v.z           = __funnelshift_r(v.z, next, ofs);
  }
  nanos = v.y;
  nanos <<= 32;
  nanos |= v.x;
  // Convert the Julian day number to days since the Unix epoch (2440588 is presumably the Julian
  // day of 1970-01-01), then to the requested clock rate below.
  days = static_cast<int32_t>(v.z);
  cudf::duration_D d_d{
    days - 2440588};  // TBD: Should be noon instead of midnight, but this matches pyarrow
  *dst = [&]() {
    switch (s->col.ts_clock_rate) {
      case 1:  // seconds
        return duration_cast<duration_s>(d_d).count() +
               duration_cast<duration_s>(duration_ns{nanos}).count();
      case 1'000:  // milliseconds
        return duration_cast<duration_ms>(d_d).count() +
               duration_cast<duration_ms>(duration_ns{nanos}).count();
      case 1'000'000:  // microseconds
        return duration_cast<duration_us>(d_d).count() +
               duration_cast<duration_us>(duration_ns{nanos}).count();
      case 1'000'000'000:  // nanoseconds
      default: return duration_cast<cudf::duration_ns>(d_d).count() + nanos;
    }
  }();
}
/**
 * @brief Output a 64-bit timestamp, rescaled to the output clock rate
 *
 * Reads the 64-bit value at src_pos (low 32-bit word first), through the dictionary when
 * s->dict_base is set and from plain page data otherwise. It then applies s->ts_scale:
 * - a negative scale divides by -ts_scale, rounding towards negative infinity;
 * - otherwise the value is multiplied by ts_scale.
 *
 * Out-of-range positions produce 0.
 *
 * @param[in,out] s Page state input/output
 * @param[in] sb Page state buffers (dictionary indices)
 * @param[in] src_pos Source position
 * @param[out] dst Pointer to row output data
 */
inline __device__ void gpuOutputInt64Timestamp(volatile page_state_s* s,
                                               volatile page_state_buffers_s* sb,
                                               int src_pos,
                                               int64_t* dst)
{
  const uint8_t* src8;
  uint32_t dict_pos, dict_size = s->dict_size, ofs;
  int64_t ts;

  if (s->dict_base) {
    // Dictionary
    dict_pos = (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0;
    src8     = s->dict_base;
  } else {
    // Plain
    dict_pos = src_pos;
    src8     = s->data_start;
  }
  dict_pos *= (uint32_t)s->dtype_len_in;
  ofs = 3 & reinterpret_cast<size_t>(src8);
  src8 -= ofs;  // align to 32-bit boundary
  ofs <<= 3;    // bytes -> bits

  // NOTE(review): the bound only guarantees 5 readable bytes while 8 (12 when misaligned) are
  // read; confirm against how dict_size is set for this type.
  if (dict_pos + 4 < dict_size) {
    uint2 v;
    int64_t val;
    int32_t ts_scale;
    v.x = *reinterpret_cast<const uint32_t*>(src8 + dict_pos + 0);
    v.y = *reinterpret_cast<const uint32_t*>(src8 + dict_pos + 4);
    if (ofs) {
      uint32_t next = *reinterpret_cast<const uint32_t*>(src8 + dict_pos + 8);
      v.x           = __funnelshift_r(v.x, v.y, ofs);
      v.y           = __funnelshift_r(v.y, next, ofs);
    }
    val = v.y;
    val <<= 32;
    val |= v.x;
    // Output to desired clock rate
    ts_scale = s->ts_scale;
    if (ts_scale < 0) {
      // Round towards negative infinity. Integer division truncates towards zero, so for
      // val < 0 and divisor d > 0 use floor(val / d) == (val + 1) / d - 1.
      // Example: val = -1, d = 1000 -> 0 / 1000 - 1 == -1.
      int const sign = (val < 0);
      ts             = ((val + sign) / -ts_scale) - sign;
    } else {
      ts = val * ts_scale;
    }
  } else {
    ts = 0;
  }
  *dst = ts;
}
/**
 * @brief Output a big-endian two's-complement byte array as an integer of type T.
 *
 * Bytes are accumulated most-significant first. Values shorter than T are sign-extended from the
 * top bit of their first byte. An empty array produces 0. For arrays longer than sizeof(T), only
 * the low-order sizeof(T) bytes are kept.
 *
 * @param[in] ptr Pointer to the byte array
 * @param[in] len Byte array length
 * @param[out] dst Pointer to row output data
 */
template <typename T>
__host__ __device__ void gpuOutputByteArrayAsInt(char const* ptr, int32_t len, T* dst)
{
  T unscaled = 0;
  for (auto i = 0; i < len; i++) {
    uint8_t v = ptr[i];
    unscaled  = (unscaled << 8) | v;
  }
  // Sign-extend values shorter than T by shifting the top byte into T's sign position and
  // shifting back arithmetically. The guard matters:
  // - len == 0 would shift by the full width of T (undefined behavior);
  // - len > sizeof(T) or len < 0 would make (sizeof(T) - len) wrap to a huge shift count.
  if (len > 0 && static_cast<size_t>(len) < sizeof(T)) {
    auto const shift = (sizeof(T) - static_cast<size_t>(len)) * 8;
    unscaled <<= shift;
    unscaled >>= shift;
  }
  *dst = unscaled;
}
/**
 * @brief Output a fixed-length big-endian byte array as an integer of type T.
 *
 * The value is located in one of two places:
 * - in the dictionary page, when s->dict_base is set;
 * - otherwise, as the src_pos-th fixed-width value of the plain page data.
 *
 * Bytes beyond s->dict_size read as 0. Values narrower than T are sign-extended.
 *
 * @param[in,out] s Page state input/output
 * @param[in] sb Page state buffers (dictionary indices)
 * @param[in] src_pos Source position
 * @param[out] dst Pointer to row output data
 */
template <typename T>
__device__ void gpuOutputFixedLenByteArrayAsInt(volatile page_state_s* s,
                                                volatile page_state_buffers_s* sb,
                                                int src_pos,
                                                T* dst)
{
  uint32_t const width       = s->dtype_len_in;
  bool const from_dictionary = s->dict_base != nullptr;
  uint8_t const* const src   = from_dictionary ? s->dict_base : s->data_start;

  uint32_t element = src_pos;
  if (from_dictionary) {
    element = 0;
    if (s->dict_bits > 0) { element = sb->dict_idx[rolling_index(src_pos)]; }
  }
  uint32_t const first_byte = element * width;
  uint32_t const limit      = s->dict_size;

  // Accumulate most-significant byte first, zero-filling past the end of the source.
  T result = 0;
  for (uint32_t b = 0; b < width; ++b) {
    uint32_t const octet = (first_byte + b < limit) ? src[first_byte + b] : 0;
    result               = (result << 8) | octet;
  }

  // Sign-extend narrower values by moving the top byte into T's sign position and back.
  if (width < sizeof(T)) {
    auto const pad_bits = (sizeof(T) - width) * 8;
    result <<= pad_bits;
    result >>= pad_bits;
  }
  *dst = result;
}
/**
 * @brief Output a small fixed-length value (4 or 8 bytes, selected by the gpuStoreOutput overload
 * matching T)
 *
 * The element index is taken from the buffered dictionary index when s->dict_base is set, and is
 * src_pos for plain data. It is scaled by s->dtype_len_in to obtain a byte position.
 *
 * @param[in,out] s Page state input/output
 * @param[in] sb Page state buffers (dictionary indices)
 * @param[in] src_pos Source position
 * @param[out] dst Pointer to row output data
 */
template <typename T>
inline __device__ void gpuOutputFast(volatile page_state_s* s,
                                     volatile page_state_buffers_s* sb,
                                     int src_pos,
                                     T* dst)
{
  uint32_t const dict_size = s->dict_size;

  const uint8_t* base;
  uint32_t element;
  if (!s->dict_base) {
    // Plain: values are stored back to back in the page data.
    element = src_pos;
    base    = s->data_start;
  } else {
    // Dictionary: look up the buffered index (a single-entry dictionary has zero index bits).
    element = (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0;
    base    = s->dict_base;
  }

  uint32_t const byte_pos = element * (uint32_t)s->dtype_len_in;
  gpuStoreOutput(dst, base, byte_pos, dict_size);
}
/**
 * @brief Output an N-byte value
 *
 * There are two copy paths:
 * - When len is not a multiple of 4, bytes are copied one at a time, zero-filled past the end of
 *   the source.
 * - Otherwise whole 32-bit words are copied using aligned loads plus a funnel shift. Each word
 *   whose starting byte position is at or past dict_size is written as 0.
 *
 * @param[in,out] s Page state input/output
 * @param[in] sb Page state buffers (dictionary indices)
 * @param[in] src_pos Source position
 * @param[out] dst8 Pointer to row output data
 * @param[in] len Length of element
 */
static __device__ void gpuOutputGeneric(
  volatile page_state_s* s, volatile page_state_buffers_s* sb, int src_pos, uint8_t* dst8, int len)
{
  uint32_t const dict_size = s->dict_size;

  const uint8_t* src;
  uint32_t byte_pos;
  if (!s->dict_base) {
    // Plain
    byte_pos = src_pos;
    src      = s->data_start;
  } else {
    // Dictionary
    byte_pos = (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0;
    src      = s->dict_base;
  }
  byte_pos *= (uint32_t)s->dtype_len_in;

  if ((len & 3) != 0) {
    // Generic slow path: byte-by-byte copy.
    for (unsigned int i = 0; i < len; ++i) {
      dst8[i] = (byte_pos + i < dict_size) ? src[byte_pos + i] : 0;
    }
    return;
  }

  // Word path: round the base down to a 4-byte boundary, then fix up with a funnel shift.
  unsigned int const misalign_bytes = 3 & reinterpret_cast<size_t>(src);
  const uint8_t* const aligned_base = src - misalign_bytes;
  unsigned int const shift_bits     = misalign_bytes << 3;
  for (unsigned int i = 0; i < len; i += 4, byte_pos += 4) {
    uint32_t word = 0;
    if (byte_pos < dict_size) {
      auto const* words = reinterpret_cast<const uint32_t*>(aligned_base + byte_pos);
      word              = words[0];
      if (shift_bits != 0) { word = __funnelshift_r(word, words[1], shift_bits); }
    }
    *reinterpret_cast<uint32_t*>(dst8 + i) = word;
  }
}
/**
* @brief Sets up block-local page state information from the global pages.
*
* @param[in, out] s The local page state to be filled in
* @param[in] p The global page to be copied from
* @param[in] chunks The global list of chunks
* @param[in] min_row Crop all rows below min_row
* @param[in] num_rows Maximum number of rows to read
* @param[in] is_decode_step If we are setting up for the decode step (instead of the preprocess
* step)
*/
static __device__ bool setupLocalPageInfo(page_state_s* const s,
PageInfo const* p,
device_span<ColumnChunkDesc const> chunks,
size_t min_row,
size_t num_rows,
bool is_decode_step)
{
int t = threadIdx.x;
int chunk_idx;
// Fetch page info
if (!t) s->page = *p;
__syncthreads();
if (s->page.flags & PAGEINFO_FLAGS_DICTIONARY) { return false; }