-
Notifications
You must be signed in to change notification settings - Fork 933
/
Copy pathjson_normalization.cu
456 lines (410 loc) · 20.4 KB
/
json_normalization.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "io/fst/lookup_tables.cuh"
#include <cudf/detail/device_scalar.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/io/detail/json.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
#include <cub/device/device_copy.cuh>
#include <cuda/atomic>
#include <thrust/binary_search.h>
#include <thrust/distance.h>
#include <thrust/gather.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/reduce.h>
#include <thrust/remove.h>
#include <cstdlib>
#include <string>
#include <vector>
namespace cudf::io::json {
// Type used to represent the atomic symbol type used within the finite-state machine
using SymbolT = char;
using StateT = char;
// Type sufficiently large to index symbols within the input and output (may be unsigned)
using SymbolOffsetT = uint32_t;
namespace normalize_quotes {
enum class dfa_states : StateT { TT_OOS = 0U, TT_DQS, TT_SQS, TT_DEC, TT_SEC, TT_NUM_STATES };
enum class dfa_symbol_group_id : uint32_t {
DOUBLE_QUOTE_CHAR, ///< Quote character SG: "
SINGLE_QUOTE_CHAR, ///< Quote character SG: '
ESCAPE_CHAR, ///< Escape character SG: '\'
NEWLINE_CHAR, ///< Newline character SG: '\n'
OTHER_SYMBOLS, ///< SG implicitly matching all other characters
NUM_SYMBOL_GROUPS ///< Total number of symbol groups
};
// Aliases for readability of the transition table
constexpr auto TT_OOS = dfa_states::TT_OOS;
constexpr auto TT_DQS = dfa_states::TT_DQS;
constexpr auto TT_SQS = dfa_states::TT_SQS;
constexpr auto TT_DEC = dfa_states::TT_DEC;
constexpr auto TT_SEC = dfa_states::TT_SEC;
constexpr auto TT_NUM_STATES = static_cast<StateT>(dfa_states::TT_NUM_STATES);
constexpr auto NUM_SYMBOL_GROUPS = static_cast<uint32_t>(dfa_symbol_group_id::NUM_SYMBOL_GROUPS);
// The i-th string representing all the characters of a symbol group
std::array<std::vector<SymbolT>, NUM_SYMBOL_GROUPS - 1> const qna_sgs{
{{'\"'}, {'\''}, {'\\'}, {'\n'}}};
// Transition table
std::array<std::array<dfa_states, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const qna_state_tt{{
/* IN_STATE " ' \ \n OTHER */
/* TT_OOS */ {{TT_DQS, TT_SQS, TT_OOS, TT_OOS, TT_OOS}},
/* TT_DQS */ {{TT_OOS, TT_DQS, TT_DEC, TT_OOS, TT_DQS}},
/* TT_SQS */ {{TT_SQS, TT_OOS, TT_SEC, TT_OOS, TT_SQS}},
/* TT_DEC */ {{TT_DQS, TT_DQS, TT_DQS, TT_OOS, TT_DQS}},
/* TT_SEC */ {{TT_SQS, TT_SQS, TT_SQS, TT_OOS, TT_SQS}},
}};
// The DFA's starting state
constexpr auto start_state = static_cast<StateT>(TT_OOS);
struct TransduceToNormalizedQuotes {
/**
* @brief Returns the <relative_offset>-th output symbol on the transition (state_id, match_id).
*/
template <typename StateT, typename SymbolGroupT, typename RelativeOffsetT, typename SymbolT>
constexpr CUDF_HOST_DEVICE SymbolT operator()(StateT const state_id,
SymbolGroupT const match_id,
RelativeOffsetT const relative_offset,
SymbolT const read_symbol) const
{
// -------- TRANSLATION TABLE ------------
// Let the alphabet set be Sigma
// ---------------------------------------
// ---------- NON-SPECIAL CASES: ----------
// Output symbol same as input symbol <s>
// state | read_symbol <s> -> output_symbol <s>
// DQS | Sigma -> Sigma
// DEC | Sigma -> Sigma
// OOS | Sigma\{'} -> Sigma\{'}
// SQS | Sigma\{', "} -> Sigma\{', "}
// ---------- SPECIAL CASES: --------------
// Input symbol translates to output symbol
// OOS | {'} -> {"}
// SQS | {'} -> {"}
// SQS | {"} -> {\"}
// SQS | {\} -> <nop>
// DQS | {\} -> <nop>
// SEC | {'} -> {'}
// SEC | Sigma\{'} -> {\*}
// DEC | {'} -> {'}
// DEC | Sigma\{'} -> {\*}
// Whether this transition translates to the escape sequence: \"
bool const outputs_escape_sequence =
(state_id == static_cast<StateT>(dfa_states::TT_SQS)) &&
(match_id == static_cast<SymbolGroupT>(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR));
// Case when a double quote needs to be replaced by the escape sequence: \"
if (outputs_escape_sequence) { return (relative_offset == 0) ? '\\' : '"'; }
// Case when a single quote needs to be replaced by a double quote
if ((match_id == static_cast<SymbolGroupT>(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) &&
((state_id == static_cast<StateT>(dfa_states::TT_SQS)) ||
(state_id == static_cast<StateT>(dfa_states::TT_OOS)))) {
return '"';
}
// Case when the read symbol is an escape character - the actual translation for \<s> for some
// symbol <s> is handled by transitions from SEC. The same logic applies for the transition from
// DEC. For now, there is no output for this transition
if (match_id == static_cast<SymbolGroupT>(dfa_symbol_group_id::ESCAPE_CHAR) &&
(state_id == static_cast<StateT>(dfa_states::TT_SQS) ||
state_id == static_cast<StateT>(dfa_states::TT_DQS))) {
return 0;
}
// Case when an escaped single quote in an input single-quoted or double-quoted string needs
// to be replaced by an unescaped single quote
if (match_id == static_cast<SymbolGroupT>(dfa_symbol_group_id::SINGLE_QUOTE_CHAR) &&
(state_id == static_cast<StateT>(dfa_states::TT_SEC) ||
state_id == static_cast<StateT>(dfa_states::TT_DEC))) {
return '\'';
}
// Case when an escaped symbol <s> that is not a single-quote needs to be replaced with \<s>
if (state_id == static_cast<StateT>(dfa_states::TT_SEC) ||
state_id == static_cast<StateT>(dfa_states::TT_DEC)) {
return (relative_offset == 0) ? '\\' : read_symbol;
}
// In all other cases we simply output the input symbol
return read_symbol;
}
/**
* @brief Returns the number of output characters for a given transition. During quote
* normalization, we always emit one output character (i.e., either the input character or the
* single quote-input replaced by a double quote), except when we need to escape a double quote
* that was previously inside a single-quoted string.
*/
template <typename StateT, typename SymbolGroupT, typename SymbolT>
constexpr CUDF_HOST_DEVICE int32_t operator()(StateT const state_id,
SymbolGroupT const match_id,
SymbolT const read_symbol) const
{
// Whether this transition translates to the escape sequence: \"
bool const sqs_outputs_escape_sequence =
(state_id == static_cast<StateT>(dfa_states::TT_SQS)) &&
(match_id == static_cast<SymbolGroupT>(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR));
// Number of characters to output on this transition
if (sqs_outputs_escape_sequence) { return 2; }
// Whether this transition translates to the escape sequence \<s> or unescaped '
bool const sec_dec_outputs_escape_sequence =
(state_id == static_cast<StateT>(dfa_states::TT_SEC) ||
state_id == static_cast<StateT>(dfa_states::TT_DEC)) &&
(match_id != static_cast<SymbolGroupT>(dfa_symbol_group_id::SINGLE_QUOTE_CHAR));
// Number of characters to output on this transition
if (sec_dec_outputs_escape_sequence) { return 2; }
// Whether this transition translates to no output <nop>
bool const sqs_dqs_outputs_nop =
(state_id == static_cast<StateT>(dfa_states::TT_SQS) ||
state_id == static_cast<StateT>(dfa_states::TT_DQS)) &&
(match_id == static_cast<SymbolGroupT>(dfa_symbol_group_id::ESCAPE_CHAR));
// Number of characters to output on this transition
if (sqs_dqs_outputs_nop) { return 0; }
return 1;
}
};
} // namespace normalize_quotes
namespace normalize_whitespace {
enum class dfa_symbol_group_id : uint32_t {
DOUBLE_QUOTE_CHAR, ///< Quote character SG: "
ESCAPE_CHAR, ///< Escape character SG: '\\'
NEWLINE_CHAR, ///< Newline character SG: '\n'
WHITESPACE_SYMBOLS, ///< Whitespace characters SG: '\t' or ' '
OTHER_SYMBOLS, ///< SG implicitly matching all other characters
NUM_SYMBOL_GROUPS ///< Total number of symbol groups
};
// Alias for readability of symbol group ids
constexpr auto NUM_SYMBOL_GROUPS = static_cast<uint32_t>(dfa_symbol_group_id::NUM_SYMBOL_GROUPS);
// The i-th string representing all the characters of a symbol group
std::array<std::vector<SymbolT>, NUM_SYMBOL_GROUPS - 1> const wna_sgs{
{{'"'}, {'\\'}, {'\n'}, {' ', '\t'}}};
/**
* -------- FST states ---------
* -----------------------------
* TT_OOS | Out-of-string state handling whitespace and non-whitespace chars outside double
* | quotes as well as any other character not enclosed by a string. Also handles
* | newline character present within a string
* TT_DQS | Double-quoted string state handling all characters within double quotes except
* | newline character
* TT_DEC | State handling escaped characters inside double-quoted string. Note that this
* | state is necessary to process escaped double-quote characters. Without this
* | state, whitespaces following escaped double quotes inside strings may be removed.
*
*/
enum class dfa_states : StateT { TT_OOS = 0U, TT_DQS, TT_DEC, TT_NUM_STATES };
// Aliases for readability of the transition table
constexpr auto TT_OOS = dfa_states::TT_OOS;
constexpr auto TT_DQS = dfa_states::TT_DQS;
constexpr auto TT_DEC = dfa_states::TT_DEC;
constexpr auto TT_NUM_STATES = static_cast<StateT>(dfa_states::TT_NUM_STATES);
// Transition table
std::array<std::array<dfa_states, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const wna_state_tt{
{/* IN_STATE " \ \n <SPC> OTHER */
/* TT_OOS */ {{TT_DQS, TT_OOS, TT_OOS, TT_OOS, TT_OOS}},
/* TT_DQS */ {{TT_OOS, TT_DEC, TT_OOS, TT_DQS, TT_DQS}},
/* TT_DEC */ {{TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS}}}};
// The DFA's starting state
constexpr StateT start_state = static_cast<StateT>(TT_OOS);
struct TransduceToNormalizedWS {
/**
* @brief Returns the <relative_offset>-th output symbol on the transition (state_id, match_id).
*/
template <typename StateT, typename SymbolGroupT, typename RelativeOffsetT, typename SymbolT>
constexpr CUDF_HOST_DEVICE SymbolT operator()(StateT const state_id,
SymbolGroupT const match_id,
RelativeOffsetT const relative_offset,
SymbolT const read_symbol) const
{
// -------- TRANSLATION TABLE ------------
// Let the alphabet set be Sigma
// ---------------------------------------
// ---------- NON-SPECIAL CASES: ----------
// Input symbol translates to output symbol
// state | read_symbol <s> -> output_symbol <s>
// DQS | Sigma -> <nop>
// OOS | Sigma\{<SPC>,\t} -> <nop>
// DEC | Sigma -> <nop>
// ---------- SPECIAL CASES: --------------
// Output symbol same as input symbol <s>
// OOS | {<SPC>} -> {<SPC>}
// OOS | {\t} -> {\t}
// Case when read symbol is not an unquoted space or tab
// This will be the same condition as in `operator()(state_id, match_id, read_symbol)` function
// However, since there is no output in this case i.e. the count returned by
// operator()(state_id, match_id, read_symbol) is zero, this function is never called.
// So skipping the check for this case.
// In all other cases, we have an output symbol for the input symbol.
// We simply output the input symbol
return read_symbol;
}
/**
* @brief Returns the number of output characters for a given transition.
* During whitespace normalization, we always emit one output character i.e., the input
* character, except when we need to remove the space/tab character
*/
template <typename StateT, typename SymbolGroupT, typename SymbolT>
constexpr CUDF_HOST_DEVICE uint32_t operator()(StateT const state_id,
SymbolGroupT const match_id,
SymbolT const read_symbol) const
{
// Case when read symbol is a space or tab but is unquoted
if (!(match_id == static_cast<SymbolGroupT>(dfa_symbol_group_id::WHITESPACE_SYMBOLS) &&
state_id == static_cast<StateT>(dfa_states::TT_OOS))) {
return 0;
}
return 1;
}
};
} // namespace normalize_whitespace
namespace detail {
void normalize_single_quotes(datasource::owning_buffer<rmm::device_buffer>& indata,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
static constexpr std::int32_t min_out = 0;
static constexpr std::int32_t max_out = 2;
auto parser =
fst::detail::make_fst(fst::detail::make_symbol_group_lut(normalize_quotes::qna_sgs),
fst::detail::make_transition_table(normalize_quotes::qna_state_tt),
fst::detail::make_translation_functor<SymbolT, min_out, max_out>(
normalize_quotes::TransduceToNormalizedQuotes{}),
stream);
rmm::device_buffer outbuf(indata.size() * 2, stream, mr);
cudf::detail::device_scalar<SymbolOffsetT> outbuf_size(stream, mr);
parser.Transduce(reinterpret_cast<SymbolT const*>(indata.data()),
static_cast<SymbolOffsetT>(indata.size()),
static_cast<SymbolT*>(outbuf.data()),
thrust::make_discard_iterator(),
outbuf_size.data(),
normalize_quotes::start_state,
stream);
outbuf.resize(outbuf_size.value(stream), stream);
datasource::owning_buffer<rmm::device_buffer> outdata(std::move(outbuf));
std::swap(indata, outdata);
}
std::
tuple<rmm::device_uvector<char>, rmm::device_uvector<size_type>, rmm::device_uvector<size_type>>
normalize_whitespace(device_span<char const> d_input,
device_span<size_type const> col_offsets,
device_span<size_type const> col_lengths,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
/*
* Algorithm:
1. Create a single buffer by concatenating the rows of the string column. Create segment offsets
and lengths array for concatenated buffer
2. Run a whitespace normalization FST that performs NOP for non-whitespace and quoted
whitespace characters, and outputs indices of unquoted whitespace characters
3. Update segment lengths based on the number of output indices between segment offsets
4. Remove characters at output indices from concatenated buffer.
5. Return updated buffer, segment lengths and updated segment offsets
*/
auto inbuf_lengths = cudf::detail::make_device_uvector_async(
col_lengths, stream, cudf::get_current_device_resource_ref());
size_t inbuf_lengths_size = inbuf_lengths.size();
size_type inbuf_size =
thrust::reduce(rmm::exec_policy_nosync(stream), inbuf_lengths.begin(), inbuf_lengths.end());
rmm::device_uvector<char> inbuf(inbuf_size, stream);
rmm::device_uvector<size_type> inbuf_offsets(inbuf_lengths_size, stream);
thrust::exclusive_scan(rmm::exec_policy_nosync(stream),
inbuf_lengths.begin(),
inbuf_lengths.end(),
inbuf_offsets.begin(),
0);
auto input_it = thrust::make_transform_iterator(
thrust::make_counting_iterator(0),
cuda::proclaim_return_type<char const*>(
[d_input = d_input.begin(), col_offsets = col_offsets.begin()] __device__(
size_t i) -> char const* { return &d_input[col_offsets[i]]; }));
auto output_it = thrust::make_transform_iterator(
thrust::make_counting_iterator(0),
cuda::proclaim_return_type<char*>(
[inbuf = inbuf.begin(), inbuf_offsets = inbuf_offsets.cbegin()] __device__(
size_t i) -> char* { return &inbuf[inbuf_offsets[i]]; }));
{
// cub device batched copy
size_t temp_storage_bytes = 0;
cub::DeviceCopy::Batched(nullptr,
temp_storage_bytes,
input_it,
output_it,
inbuf_lengths.begin(),
inbuf_lengths_size,
stream.value());
rmm::device_buffer temp_storage(temp_storage_bytes, stream);
cub::DeviceCopy::Batched(temp_storage.data(),
temp_storage_bytes,
input_it,
output_it,
inbuf_lengths.begin(),
inbuf_lengths_size,
stream.value());
}
// whitespace normalization : get the indices of the unquoted whitespace characters
auto parser =
fst::detail::make_fst(fst::detail::make_symbol_group_lut(normalize_whitespace::wna_sgs),
fst::detail::make_transition_table(normalize_whitespace::wna_state_tt),
fst::detail::make_translation_functor<SymbolT, 0, 2>(
normalize_whitespace::TransduceToNormalizedWS{}),
stream);
rmm::device_uvector<size_type> outbuf_indices(inbuf.size(), stream, mr);
cudf::detail::device_scalar<SymbolOffsetT> outbuf_indices_size(stream, mr);
parser.Transduce(inbuf.data(),
static_cast<SymbolOffsetT>(inbuf.size()),
thrust::make_discard_iterator(),
outbuf_indices.data(),
outbuf_indices_size.data(),
normalize_whitespace::start_state,
stream);
auto const num_deletions = outbuf_indices_size.value(stream);
outbuf_indices.resize(num_deletions, stream);
// now these indices need to be removed
// TODO: is there a better way to do this?
thrust::for_each(
rmm::exec_policy_nosync(stream),
outbuf_indices.begin(),
outbuf_indices.end(),
[inbuf_offsets_begin = inbuf_offsets.begin(),
inbuf_offsets_end = inbuf_offsets.end(),
inbuf_lengths = inbuf_lengths.begin()] __device__(size_type idx) {
auto it = thrust::upper_bound(thrust::seq, inbuf_offsets_begin, inbuf_offsets_end, idx);
auto pos = thrust::distance(inbuf_offsets_begin, it) - 1;
cuda::atomic_ref<size_type, cuda::thread_scope_device> ref{*(inbuf_lengths + pos)};
ref.fetch_add(-1, cuda::std::memory_order_relaxed);
});
auto stencil = cudf::detail::make_zeroed_device_uvector_async<bool>(
static_cast<std::size_t>(inbuf_size), stream, cudf::get_current_device_resource_ref());
thrust::scatter(rmm::exec_policy_nosync(stream),
thrust::make_constant_iterator(true),
thrust::make_constant_iterator(true) + num_deletions,
outbuf_indices.begin(),
stencil.begin());
thrust::remove_if(rmm::exec_policy_nosync(stream),
inbuf.begin(),
inbuf.end(),
stencil.begin(),
thrust::identity<int>());
inbuf.resize(inbuf_size - num_deletions, stream);
thrust::exclusive_scan(rmm::exec_policy_nosync(stream),
inbuf_lengths.begin(),
inbuf_lengths.end(),
inbuf_offsets.begin(),
0);
stream.synchronize();
return std::tuple{std::move(inbuf), std::move(inbuf_offsets), std::move(inbuf_lengths)};
}
} // namespace detail
} // namespace cudf::io::json