Skip to content

Commit

Permalink
Refactor Java host-side buffer concatenation to expose separate steps (
Browse files Browse the repository at this point in the history
…#7610)

This refactors `JCudfSerialization.concatToContiguousTable` to expose the separate steps of concatenating to a single host-side buffer and constructing a device-side contiguous table from that host buffer.  This allows application code to perform other operations in-between those two steps.

Authors:
  - Jason Lowe (@jlowe)

Approvers:
  - Robert (Bobby) Evans (@revans2)

URL: #7610
  • Loading branch information
jlowe authored Mar 17, 2021
1 parent 39ad863 commit 9c6e1ba
Showing 1 changed file with 71 additions and 23 deletions.
94 changes: 71 additions & 23 deletions java/src/main/java/ai/rapids/cudf/JCudfSerialization.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -353,6 +353,50 @@ static SerializedColumnHeader readFrom(DataInputStream din, long rowCount) throw
}
}

/** Class to hold the header and buffer pair result from host-side concatenation */
public static final class HostConcatResult implements AutoCloseable {
private final SerializedTableHeader tableHeader;
private final HostMemoryBuffer hostBuffer;

public HostConcatResult(SerializedTableHeader tableHeader, HostMemoryBuffer tableBuffer) {
this.tableHeader = tableHeader;
this.hostBuffer = tableBuffer;
}

public SerializedTableHeader getTableHeader() {
return tableHeader;
}

public HostMemoryBuffer getHostBuffer() {
return hostBuffer;
}

/** Build a contiguous table in device memory from this host-concatenated result */
public ContiguousTable toContiguousTable() {
DeviceMemoryBuffer devBuffer = DeviceMemoryBuffer.allocate(hostBuffer.length);
try {
if (hostBuffer.length > 0) {
devBuffer.copyFromHostBuffer(hostBuffer);
}
Table table = sliceUpColumnVectors(tableHeader, devBuffer, hostBuffer);
try {
return new ContiguousTable(table, devBuffer);
} catch (Exception e) {
table.close();
throw e;
}
} catch (Exception e) {
devBuffer.close();
throw e;
}
}

@Override
public void close() {
hostBuffer.close();
}
}

/**
* Visible for testing
*/
Expand Down Expand Up @@ -1681,43 +1725,47 @@ public static Table readAndConcat(SerializedTableHeader[] headers,
return ct.getTable();
}

/**
* Concatenate multiple tables in host memory into a contiguous table in device memory.
* @param headers table headers corresponding to the host table buffers
* @param dataBuffers host table buffer for each input table to be concatenated
* @return contiguous table in device memory
*/
public static ContiguousTable concatToContiguousTable(SerializedTableHeader[] headers,
HostMemoryBuffer[] dataBuffers) throws IOException {
try (HostConcatResult concatResult = concatToHostBuffer(headers, dataBuffers)) {
return concatResult.toContiguousTable();
}
}

/**
* Concatenate multiple tables in host memory into a single host table buffer.
* @param headers table headers corresponding to the host table buffers
* @param dataBuffers host table buffer for each input table to be concatenated
* @return host table header and buffer
*/
public static HostConcatResult concatToHostBuffer(SerializedTableHeader[] headers,
HostMemoryBuffer[] dataBuffers) throws IOException {
ColumnBufferProvider[][] providersPerColumn = providersFrom(headers, dataBuffers);
DeviceMemoryBuffer devBuffer = null;
Table table = null;
try {
SerializedTableHeader combined = calcConcatHeader(providersPerColumn);

try (HostMemoryBuffer hostBuffer = HostMemoryBuffer.allocate(combined.dataLen)) {
HostMemoryBuffer hostBuffer = HostMemoryBuffer.allocate(combined.dataLen);
try {
try (NvtxRange range = new NvtxRange("Concat Host Side", NvtxColor.GREEN)) {
DataWriter writer = writerFrom(hostBuffer);
int numColumns = combined.getNumColumns();
for (int columnIdx = 0; columnIdx < numColumns; columnIdx++) {
writeConcat(writer, combined.getColumnHeader(columnIdx), providersPerColumn[columnIdx]);
}
}

devBuffer = DeviceMemoryBuffer.allocate(hostBuffer.length);
if (hostBuffer.length > 0) {
try (NvtxRange range = new NvtxRange("Copy Data To Device", NvtxColor.WHITE)) {
devBuffer.copyFromHostBuffer(hostBuffer);
}
}
table = sliceUpColumnVectors(combined, devBuffer, hostBuffer);
ContiguousTable result = new ContiguousTable(table, devBuffer);
table = null;
devBuffer = null;
return result;
} catch (Exception e) {
hostBuffer.close();
throw e;
}

return new HostConcatResult(combined, hostBuffer);
} finally {
closeAll(providersPerColumn);
if (table != null) {
table.close();
}
if (devBuffer != null) {
devBuffer.close();
}
}
}

Expand Down

0 comments on commit 9c6e1ba

Please sign in to comment.