From 93d25fcfd8f73e251255f5e0001a62ef8ddf67a1 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 16 Mar 2021 11:13:34 -0500 Subject: [PATCH] Refactor Java host-side buffer concatenation to expose separate steps --- .../ai/rapids/cudf/JCudfSerialization.java | 94 ++++++++++++++----- 1 file changed, 71 insertions(+), 23 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/JCudfSerialization.java b/java/src/main/java/ai/rapids/cudf/JCudfSerialization.java index bf49fb59d52..6c52b8fe798 100644 --- a/java/src/main/java/ai/rapids/cudf/JCudfSerialization.java +++ b/java/src/main/java/ai/rapids/cudf/JCudfSerialization.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -353,6 +353,50 @@ static SerializedColumnHeader readFrom(DataInputStream din, long rowCount) throw } } + /** Class to hold the header and buffer pair result from host-side concatenation */ + public static final class HostConcatResult implements AutoCloseable { + private final SerializedTableHeader tableHeader; + private final HostMemoryBuffer hostBuffer; + + public HostConcatResult(SerializedTableHeader tableHeader, HostMemoryBuffer tableBuffer) { + this.tableHeader = tableHeader; + this.hostBuffer = tableBuffer; + } + + public SerializedTableHeader getTableHeader() { + return tableHeader; + } + + public HostMemoryBuffer getHostBuffer() { + return hostBuffer; + } + + /** Build a contiguous table in device memory from this host-concatenated result */ + public ContiguousTable toContiguousTable() { + DeviceMemoryBuffer devBuffer = DeviceMemoryBuffer.allocate(hostBuffer.length); + try { + if (hostBuffer.length > 0) { + devBuffer.copyFromHostBuffer(hostBuffer); + } + Table table = sliceUpColumnVectors(tableHeader, devBuffer, hostBuffer); + try { + return new ContiguousTable(table, devBuffer); + } catch (Exception e) { + table.close(); + throw e; + } + } catch (Exception e) { + devBuffer.close(); + throw e; + } + } + + @Override + public void close() { + hostBuffer.close(); + } + } + /** * Visible for testing */ @@ -1681,15 +1725,32 @@ public static Table readAndConcat(SerializedTableHeader[] headers, return ct.getTable(); } + /** + * Concatenate multiple tables in host memory into a contiguous table in device memory. + * @param headers table headers corresponding to the host table buffers + * @param dataBuffers host table buffer for each input table to be concatenated + * @return contiguous table in device memory + */ public static ContiguousTable concatToContiguousTable(SerializedTableHeader[] headers, HostMemoryBuffer[] dataBuffers) throws IOException { + try (HostConcatResult concatResult = concatToHostBuffer(headers, dataBuffers)) { + return concatResult.toContiguousTable(); + } + } + + /** + * Concatenate multiple tables in host memory into a single host table buffer. + * @param headers table headers corresponding to the host table buffers + * @param dataBuffers host table buffer for each input table to be concatenated + * @return host table header and buffer + */ + public static HostConcatResult concatToHostBuffer(SerializedTableHeader[] headers, + HostMemoryBuffer[] dataBuffers) throws IOException { ColumnBufferProvider[][] providersPerColumn = providersFrom(headers, dataBuffers); - DeviceMemoryBuffer devBuffer = null; - Table table = null; try { SerializedTableHeader combined = calcConcatHeader(providersPerColumn); - - try (HostMemoryBuffer hostBuffer = HostMemoryBuffer.allocate(combined.dataLen)) { + HostMemoryBuffer hostBuffer = HostMemoryBuffer.allocate(combined.dataLen); + try { try (NvtxRange range = new NvtxRange("Concat Host Side", NvtxColor.GREEN)) { DataWriter writer = writerFrom(hostBuffer); int numColumns = combined.getNumColumns(); @@ -1697,27 +1758,14 @@ public static ContiguousTable concatToContiguousTable(SerializedTableHeader[] he writeConcat(writer, combined.getColumnHeader(columnIdx), providersPerColumn[columnIdx]); } } - - devBuffer = DeviceMemoryBuffer.allocate(hostBuffer.length); - if (hostBuffer.length > 0) { - try (NvtxRange range = new NvtxRange("Copy Data To Device", NvtxColor.WHITE)) { - devBuffer.copyFromHostBuffer(hostBuffer); - } - } - table = sliceUpColumnVectors(combined, devBuffer, hostBuffer); - ContiguousTable result = new ContiguousTable(table, devBuffer); - table = null; - devBuffer = null; - return result; + } catch (Exception e) { + hostBuffer.close(); + throw e; } + + return new HostConcatResult(combined, hostBuffer); } finally { closeAll(providersPerColumn); - if (table != null) { - table.close(); - } - if (devBuffer != null) { - devBuffer.close(); - } } }