-
Notifications
You must be signed in to change notification settings - Fork 242
/
Copy pathCudfUnsafeRow.java
400 lines (358 loc) · 13.5 KB
/
CudfUnsafeRow.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import java.util.Arrays;
/**
 * An {@link InternalRow} modeled on UnsafeRow that reads rows laid out in the row format
 * supported by cudf. In that layout every column is aligned to its own width, and the validity
 * bits follow the column data, one bit per column packed into bytes.
 *
 * The row can also translate ordinals: when the columns of the backing row were re-ordered, a
 * remapping table maps each user facing ordinal back to the column's position in the backing row.
 *
 * This class is likely to go away once we move to code generation when going directly to an
 * UnsafeRow. That is rather difficult because of some details in how UnsafeRow works.
 */
public final class CudfUnsafeRow extends InternalRow {
  /** Message used by every accessor that has not been implemented for this format. */
  private static final String NOT_IMPLEMENTED = "NOT IMPLEMENTED YET";

  /** Seed handed to the Murmur3 hash in {@link #hashCode()}. */
  private static final int HASH_SEED = 42;

  /**
   * Round an offset up to the next multiple of an alignment.
   *
   * @param offset the offset to align.
   * @param alignment the required alignment; the bit trick used here assumes a power of two.
   * @return the smallest aligned offset that is not below {@code offset}.
   */
  public static int alignOffset(int offset, int alignment) {
    final int alignmentMask = -alignment;
    return (offset + alignment - 1) & alignmentMask;
  }

  /**
   * Compute how many bytes are needed to hold one validity bit per field.
   *
   * @param numFields the number of fields in the row.
   * @return the number of validity bytes.
   */
  public static int calculateBitSetWidthInBytes(int numFields) {
    final int roundedUp = numFields + 7;
    return roundedUp / 8;
  }

  /**
   * Estimate the size in bytes of one row with the given schema.
   *
   * @param attributes the schema of the row.
   * @return the size of the aligned column data plus validity bytes, padded to 8 bytes.
   */
  public static int getRowSizeEstimate(Attribute[] attributes) {
    // This needs to match what is in cudf and what is in the constructor.
    int dataEnd = 0;
    for (int col = 0; col < attributes.length; col++) {
      final int width =
          GpuColumnVector.getNonNestedRapidsType(attributes[col].dataType()).getSizeInBytes();
      dataEnd = alignOffset(dataEnd, width) + width;
    }
    final int validityBytes = calculateBitSetWidthInBytes(attributes.length);
    // Each row is 64-bit aligned
    return alignOffset(dataEnd + validityBytes, 8);
  }

  //////////////////////////////////////////////////////////////////////////////
  // Private fields and methods
  //////////////////////////////////////////////////////////////////////////////

  /**
   * Address of where the row is stored in off heap memory.
   */
  private long address;

  /**
   * For each column the starting location to read from. The index is the position of the column
   * in the row bytes, not the user facing ordinal.
   */
  private int[] startOffsets;

  /**
   * At what point validity data starts.
   */
  private int fixedWidthSizeInBytes;

  /**
   * The size of this row's backing data, in bytes.
   */
  private int sizeInBytes;

  /**
   * A mapping from the user facing ordinal to the index in the underlying row.
   */
  private int[] remapping;

  /**
   * Get the address where a field is stored.
   * @param ordinal the user facing ordinal.
   * @return the address of the field.
   */
  private long getFieldAddressFromOrdinal(int ordinal) {
    assertIndexIsValid(ordinal);
    return address + startOffsets[remapping[ordinal]];
  }

  /**
   * Verify that index is valid for this row. Only checked when assertions are enabled.
   * @param index in this case the index can be either the user facing ordinal or the index into
   *              the row.
   */
  private void assertIndexIsValid(int index) {
    assert index >= 0 : "index (" + index + ") should >= 0";
    assert index < startOffsets.length : "index (" + index + ") should < " + startOffsets.length;
  }

  /**
   * Translate a user facing ordinal into the column's index in the backing row and verify it.
   * @param ordinal the user facing ordinal.
   * @return the index of the column in the backing row.
   */
  private int checkedRowIndex(int ordinal) {
    final int rowIndex = remapping[ordinal];
    assertIndexIsValid(rowIndex);
    return rowIndex;
  }

  /**
   * Get the address of the validity byte that holds the bit for a column.
   * @param rowIndex the index of the column in the backing row.
   * @return the address of the byte containing that column's validity bit.
   */
  private long validityByteAddress(int rowIndex) {
    return address + fixedWidthSizeInBytes + rowIndex / 8;
  }

  //////////////////////////////////////////////////////////////////////////////
  // Public methods
  //////////////////////////////////////////////////////////////////////////////

  /**
   * Construct a new Row. The resulting row won't be usable until `pointTo()` has been called,
   * since the value returned by this constructor is equivalent to a null pointer.
   *
   * @param attributes the schema of what this will hold. This is the schema of the underlying
   *                   row, so if columns were re-ordered it is the attributes of the reordered
   *                   data.
   * @param remapping a mapping from the user requested column to the underlying column in the
   *                  backing row.
   */
  public CudfUnsafeRow(Attribute[] attributes, int[] remapping) {
    final int columnCount = attributes.length;
    final int[] offsets = new int[columnCount];
    int cursor = 0;
    for (int col = 0; col < columnCount; col++) {
      final int width =
          GpuColumnVector.getNonNestedRapidsType(attributes[col].dataType()).getSizeInBytes();
      assert width > 0 : "Only fixed width types are currently supported.";
      // Each column starts at the next offset that is a multiple of its own width.
      cursor = alignOffset(cursor, width);
      offsets[col] = cursor;
      cursor += width;
    }
    this.startOffsets = offsets;
    // The validity bytes begin right after the last column.
    this.fixedWidthSizeInBytes = cursor;
    this.remapping = remapping;
    assert offsets.length == remapping.length;
  }

  // for serializer
  public CudfUnsafeRow() {}

  /**
   * @return the number of columns in this row.
   */
  @Override
  public int numFields() {
    return startOffsets.length;
  }

  /**
   * Update this CudfUnsafeRow to point to different backing data.
   *
   * @param address the address in host memory for this. We should change this to be a
   *                MemoryBuffer class or something like that.
   * @param sizeInBytes the size of this row's backing data, in bytes
   */
  public void pointTo(long address, int sizeInBytes) {
    assert startOffsets != null && startOffsets.length > 0 : "startOffsets not properly initialized";
    assert sizeInBytes % 8 == 0 : "sizeInBytes (" + sizeInBytes + ") should be a multiple of 8";
    this.sizeInBytes = sizeInBytes;
    this.address = address;
  }

  /**
   * Generic updates are not supported by this row.
   */
  @Override
  public void update(int ordinal, Object value) {
    throw new UnsupportedOperationException();
  }

  /**
   * Read a value of the given type through the specialized getters of this row.
   */
  @Override
  public Object get(int ordinal, DataType dataType) {
    // Don't remap the ordinal because it will be remapped in each of the other backing APIs
    return SpecializedGettersReader.read(this, ordinal, dataType, true, true);
  }

  /**
   * @return true if the validity bit of the column is cleared.
   */
  @Override
  public boolean isNullAt(int ordinal) {
    final int rowIndex = checkedRowIndex(ordinal);
    final int bitMask = 1 << (rowIndex % 8);
    final byte validity = Platform.getByte(null, validityByteAddress(rowIndex));
    return (bitMask & validity) == 0;
  }

  /**
   * Mark a column as null by clearing its validity bit in the backing data.
   */
  @Override
  public void setNullAt(int ordinal) {
    final int rowIndex = checkedRowIndex(ordinal);
    final int bitMask = 1 << (rowIndex % 8);
    final long byteAddress = validityByteAddress(rowIndex);
    final byte validity = Platform.getByte(null, byteAddress);
    // The narrowing cast keeps only the low 8 bits of the masked value.
    Platform.putByte(null, byteAddress, (byte) (validity & ~bitMask));
  }

  @Override
  public boolean getBoolean(int ordinal) {
    final long fieldAddress = getFieldAddressFromOrdinal(ordinal);
    return Platform.getBoolean(null, fieldAddress);
  }

  @Override
  public byte getByte(int ordinal) {
    final long fieldAddress = getFieldAddressFromOrdinal(ordinal);
    return Platform.getByte(null, fieldAddress);
  }

  @Override
  public short getShort(int ordinal) {
    final long fieldAddress = getFieldAddressFromOrdinal(ordinal);
    return Platform.getShort(null, fieldAddress);
  }

  @Override
  public int getInt(int ordinal) {
    final long fieldAddress = getFieldAddressFromOrdinal(ordinal);
    return Platform.getInt(null, fieldAddress);
  }

  @Override
  public long getLong(int ordinal) {
    final long fieldAddress = getFieldAddressFromOrdinal(ordinal);
    return Platform.getLong(null, fieldAddress);
  }

  @Override
  public float getFloat(int ordinal) {
    final long fieldAddress = getFieldAddressFromOrdinal(ordinal);
    return Platform.getFloat(null, fieldAddress);
  }

  @Override
  public double getDouble(int ordinal) {
    final long fieldAddress = getFieldAddressFromOrdinal(ordinal);
    return Platform.getDouble(null, fieldAddress);
  }

  /**
   * Read a decimal that is stored as an unscaled int or long, depending on the precision.
   *
   * @return the decimal, or null if the column is null.
   * @throws IllegalArgumentException if the precision does not fit in a long.
   */
  @Override
  public Decimal getDecimal(int ordinal, int precision, int scale) {
    if (isNullAt(ordinal)) {
      return null;
    }
    if (precision <= Decimal.MAX_INT_DIGITS()) {
      return Decimal.createUnsafe(getInt(ordinal), precision, scale);
    }
    if (precision <= Decimal.MAX_LONG_DIGITS()) {
      return Decimal.createUnsafe(getLong(ordinal), precision, scale);
    }
    // Decimals that need more than a long are not handled yet.
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }

  /**
   * Not implemented yet; the constructor only accepts fixed width types.
   */
  @Override
  public UTF8String getUTF8String(int ordinal) {
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }

  /**
   * Not implemented yet; the constructor only accepts fixed width types.
   */
  @Override
  public byte[] getBinary(int ordinal) {
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }

  /**
   * Not implemented yet; always throws.
   */
  @Override
  public CalendarInterval getInterval(int ordinal) {
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }

  /**
   * Not implemented yet; always throws.
   */
  @Override
  public CudfUnsafeRow getStruct(int ordinal, int numFields) {
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }

  /**
   * Not implemented yet; always throws.
   */
  @Override
  public ArrayData getArray(int ordinal) {
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }

  /**
   * Not implemented yet; always throws.
   */
  @Override
  public MapData getMap(int ordinal) {
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }

  /**
   * Copying a row into self-contained storage is not implemented yet; always throws.
   */
  @Override
  public CudfUnsafeRow copy() {
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }

  /**
   * Hash the bytes of the backing data.
   */
  @Override
  public int hashCode() {
    return Murmur3_x86_32.hashUnsafeWords(null, address, sizeInBytes, HASH_SEED);
  }

  /**
   * Two rows are equal when their backing bytes match and they use the same remapping.
   */
  @Override
  public boolean equals(Object other) {
    if (!(other instanceof CudfUnsafeRow)) {
      return false;
    }
    final CudfUnsafeRow that = (CudfUnsafeRow) other;
    if (sizeInBytes != that.sizeInBytes) {
      return false;
    }
    if (!ByteArrayMethods.arrayEquals(null, address, null, that.address, sizeInBytes)) {
      return false;
    }
    return Arrays.equals(remapping, that.remapping);
  }

  // This is for debugging
  @Override
  public String toString() {
    final StringBuilder text = new StringBuilder("[");
    // Dump the backing data as hex, one 8 byte word at a time.
    for (int byteOffset = 0; byteOffset < sizeInBytes; byteOffset += 8) {
      if (byteOffset != 0) {
        text.append(',');
      }
      final long word = Platform.getLong(null, address + byteOffset);
      text.append(java.lang.Long.toHexString(word));
    }
    return text.append(']')
        .append(" remapped with ")
        .append(Arrays.toString(remapping))
        .toString();
  }

  /**
   * Not implemented yet; always throws.
   */
  @Override
  public boolean anyNull() {
    throw new IllegalArgumentException(NOT_IMPLEMENTED);
  }
}