Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Java API to deserialize a table to host columns #9402

Merged
merged 1 commit into from
Oct 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions java/src/main/java/ai/rapids/cudf/JCudfSerialization.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
* Serialize and deserialize CUDF tables and columns using a custom format. The goal of this is
Expand Down Expand Up @@ -1660,6 +1662,51 @@ public static void writeConcatedStream(SerializedTableHeader[] headers,
// COLUMN AND TABLE READ
/////////////////////////////////////////////

private static HostColumnVector buildHostColumn(SerializedColumnHeader column,
ArrayDeque<ColumnOffsets> columnOffsets,
HostMemoryBuffer buffer) {
ColumnOffsets offsetsInfo = columnOffsets.remove();
SerializedColumnHeader[] children = column.getChildren();
int numChildren = children != null ? children.length : 0;
List<HostColumnVectorCore> childColumns = new ArrayList<>(numChildren);
try {
if (children != null) {
for (SerializedColumnHeader child : children) {
childColumns.add(buildHostColumn(child, columnOffsets, buffer));
}
}
DType dtype = column.getType();
long rowCount = column.getRowCount();
long nullCount = column.getNullCount();
HostMemoryBuffer dataBuffer = null;
HostMemoryBuffer validityBuffer = null;
HostMemoryBuffer offsetsBuffer = null;
if (!dtype.isNestedType()) {
dataBuffer = buffer.slice(offsetsInfo.data, offsetsInfo.dataLen);
}
if (nullCount > 0) {
long validitySize = BitVectorHelper.getValidityLengthInBytes(rowCount);
validityBuffer = buffer.slice(offsetsInfo.validity, validitySize);
}
if (dtype.hasOffsets()) {
// one 32-bit integer offset per row plus one additional offset at the end
long offsetsSize = rowCount > 0 ? (rowCount + 1) * Integer.BYTES : 0;
offsetsBuffer = buffer.slice(offsetsInfo.offsets, offsetsSize);
}
HostColumnVector result = new HostColumnVector(dtype, column.getRowCount(),
Optional.of(column.getNullCount()), dataBuffer, validityBuffer, offsetsBuffer,
childColumns);
childColumns = null;
return result;
} finally {
if (childColumns != null) {
for (HostColumnVectorCore c : childColumns) {
c.close();
}
}
}
}

private static long buildColumnView(SerializedColumnHeader column,
ArrayDeque<ColumnOffsets> columnOffsets,
DeviceMemoryBuffer combinedBuffer) {
Expand Down Expand Up @@ -1769,6 +1816,38 @@ public static HostConcatResult concatToHostBuffer(SerializedTableHeader[] header
}
}

/**
* Deserialize a serialized contiguous table into an array of host columns.
*
* @param header serialized table header
* @param hostBuffer buffer containing the data for all columns in the serialized table
* @return array of host columns representing the data from the serialized table
*/
public static HostColumnVector[] unpackHostColumnVectors(SerializedTableHeader header,
HostMemoryBuffer hostBuffer) {
ArrayDeque<ColumnOffsets> columnOffsets = buildIndex(header, hostBuffer);
int numColumns = header.getNumColumns();
HostColumnVector[] columns = new HostColumnVector[numColumns];
boolean succeeded = false;
try {
for (int i = 0; i < numColumns; i++) {
SerializedColumnHeader column = header.getColumnHeader(i);
columns[i] = buildHostColumn(column, columnOffsets, hostBuffer);
}
assert columnOffsets.isEmpty();
succeeded = true;
} finally {
if (!succeeded) {
for (HostColumnVector c : columns) {
if (c != null) {
c.close();
}
}
}
}
return columns;
firestarman marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* After reading a header for a table read the data portion into a host side buffer.
* @param in the stream to read the data from.
Expand Down
61 changes: 61 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3147,6 +3147,28 @@ void testSerializationRoundTripConcatOnHostEmpty() throws IOException {
}
}

@Test
void testSerializationRoundTripToHostEmpty() throws IOException {
DataType listStringsType = new ListType(true, new BasicType(true, DType.STRING));
DataType mapType = new ListType(true,
new StructType(true,
new BasicType(false, DType.STRING),
new BasicType(false, DType.STRING)));
DataType structType = new StructType(true,
new BasicType(true, DType.INT8),
new BasicType(false, DType.FLOAT32));
try (ColumnVector emptyInt = ColumnVector.fromInts();
ColumnVector emptyDouble = ColumnVector.fromDoubles();
ColumnVector emptyString = ColumnVector.fromStrings();
ColumnVector emptyListString = ColumnVector.fromLists(listStringsType);
ColumnVector emptyMap = ColumnVector.fromLists(mapType);
ColumnVector emptyStruct = ColumnVector.fromStructs(structType);
Table t = new Table(emptyInt, emptyInt, emptyDouble, emptyString,
emptyListString, emptyMap, emptyStruct)) {
testSerializationRoundTripToHost(t);
}
}

@Test
void testRoundRobinPartition() {
try (Table t = new Table.TestBuilder()
Expand Down Expand Up @@ -3285,6 +3307,45 @@ void testSerializationRoundTripConcatHostSide() throws IOException {
}
}

@Test
void testSerializationRoundTripToHost() throws IOException {
try (Table t = buildTestTable()) {
testSerializationRoundTripToHost(t);
}
}

private void testSerializationRoundTripToHost(Table t) throws IOException {
long rowCount = t.getRowCount();
ByteArrayOutputStream bout = new ByteArrayOutputStream();
JCudfSerialization.writeToStream(t, bout, 0, rowCount);
ByteArrayInputStream bin = new ByteArrayInputStream(bout.toByteArray());
DataInputStream din = new DataInputStream(bin);

JCudfSerialization.SerializedTableHeader header =
new JCudfSerialization.SerializedTableHeader(din);
assertTrue(header.wasInitialized());
try (HostMemoryBuffer buffer = HostMemoryBuffer.allocate(header.getDataLen())) {
JCudfSerialization.readTableIntoBuffer(din, header, buffer);
assertTrue(header.wasDataRead());
HostColumnVector[] hostColumns =
JCudfSerialization.unpackHostColumnVectors(header, buffer);
try {
assertEquals(t.getNumberOfColumns(), hostColumns.length);
for (int i = 0; i < hostColumns.length; i++) {
HostColumnVector actual = hostColumns[i];
assertEquals(rowCount, actual.getRowCount());
try (HostColumnVector expected = t.getColumn(i).copyToHost()) {
assertPartialColumnsAreEqual(expected, 0, rowCount, actual, "COLUMN " + i, true, false);
}
}
} finally {
for (HostColumnVector c: hostColumns) {
c.close();
}
}
}
}

@Test
void testConcatHost() throws IOException {
try (Table t1 = new Table.TestBuilder()
Expand Down