Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HostMemoryAllocator interface #13924

Merged
merged 8 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -5003,7 +5003,7 @@ private static NestedColumnVector createNestedColumnVector(DType type, long rows
/////////////////////////////////////////////////////////////////////////////

private static HostColumnVectorCore copyToHostNestedHelper(
ColumnView deviceCvPointer) {
ColumnView deviceCvPointer, HostMemoryAllocator hostMemoryAllocator) {
if (deviceCvPointer == null) {
return null;
}
Expand All @@ -5023,21 +5023,21 @@ private static HostColumnVectorCore copyToHostNestedHelper(
currOffsets = deviceCvPointer.getOffsets();
currValidity = deviceCvPointer.getValid();
if (currData != null) {
hostData = HostMemoryBuffer.allocate(currData.length);
hostData = hostMemoryAllocator.allocate(currData.length);
hostData.copyFromDeviceBuffer(currData);
}
if (currValidity != null) {
hostValid = HostMemoryBuffer.allocate(currValidity.length);
hostValid = hostMemoryAllocator.allocate(currValidity.length);
hostValid.copyFromDeviceBuffer(currValidity);
}
if (currOffsets != null) {
hostOffsets = HostMemoryBuffer.allocate(currOffsets.length);
hostOffsets = hostMemoryAllocator.allocate(currOffsets.length);
hostOffsets.copyFromDeviceBuffer(currOffsets);
}
int numChildren = deviceCvPointer.getNumChildren();
for (int i = 0; i < numChildren; i++) {
try(ColumnView childDevPtr = deviceCvPointer.getChildColumnView(i)) {
children.add(copyToHostNestedHelper(childDevPtr));
children.add(copyToHostNestedHelper(childDevPtr, hostMemoryAllocator));
}
}
currNullCount = deviceCvPointer.getNullCount();
Expand Down Expand Up @@ -5074,7 +5074,7 @@ private static HostColumnVectorCore copyToHostNestedHelper(
/**
* Copy the data to the host.
*/
public HostColumnVector copyToHost() {
public HostColumnVector copyToHost(HostMemoryAllocator hostMemoryAllocator) {
try (NvtxRange toHost = new NvtxRange("ensureOnHost", NvtxColor.BLUE)) {
HostMemoryBuffer hostDataBuffer = null;
HostMemoryBuffer hostValidityBuffer = null;
Expand All @@ -5094,16 +5094,16 @@ public HostColumnVector copyToHost() {
getNullCount();
if (!type.isNestedType()) {
if (valid != null) {
hostValidityBuffer = HostMemoryBuffer.allocate(valid.getLength());
hostValidityBuffer = hostMemoryAllocator.allocate(valid.getLength());
hostValidityBuffer.copyFromDeviceBuffer(valid);
}
if (offsets != null) {
hostOffsetsBuffer = HostMemoryBuffer.allocate(offsets.length);
hostOffsetsBuffer = hostMemoryAllocator.allocate(offsets.length);
hostOffsetsBuffer.copyFromDeviceBuffer(offsets);
}
// If a strings column is all null values there is no data buffer allocated
if (data != null) {
hostDataBuffer = HostMemoryBuffer.allocate(data.length);
hostDataBuffer = hostMemoryAllocator.allocate(data.length);
hostDataBuffer.copyFromDeviceBuffer(data);
}
HostColumnVector ret = new HostColumnVector(type, rows, Optional.of(nullCount),
Expand All @@ -5112,22 +5112,22 @@ public HostColumnVector copyToHost() {
return ret;
} else {
if (data != null) {
hostDataBuffer = HostMemoryBuffer.allocate(data.length);
hostDataBuffer = hostMemoryAllocator.allocate(data.length);
hostDataBuffer.copyFromDeviceBuffer(data);
}

if (valid != null) {
hostValidityBuffer = HostMemoryBuffer.allocate(valid.getLength());
hostValidityBuffer = hostMemoryAllocator.allocate(valid.getLength());
hostValidityBuffer.copyFromDeviceBuffer(valid);
}
if (offsets != null) {
hostOffsetsBuffer = HostMemoryBuffer.allocate(offsets.getLength());
hostOffsetsBuffer = hostMemoryAllocator.allocate(offsets.getLength());
hostOffsetsBuffer.copyFromDeviceBuffer(offsets);
}
List<HostColumnVectorCore> children = new ArrayList<>();
for (int i = 0; i < getNumChildren(); i++) {
try (ColumnView childDevPtr = getChildColumnView(i)) {
children.add(copyToHostNestedHelper(childDevPtr));
children.add(copyToHostNestedHelper(childDevPtr, hostMemoryAllocator));
}
}
HostColumnVector ret = new HostColumnVector(type, rows, Optional.of(nullCount),
Expand Down Expand Up @@ -5160,6 +5160,10 @@ public HostColumnVector copyToHost() {
}
}

public HostColumnVector copyToHost() {
return copyToHost(DefaultHostMemoryAllocator.get());
}

/**
* Calculate the total space required to copy the data to the host. This should be padded to
* the alignment that the CPU requires.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

public class DefaultHostMemoryAllocator implements HostMemoryAllocator {
private static final HostMemoryAllocator INSTANCE = new DefaultHostMemoryAllocator();
public static HostMemoryAllocator get() {
return INSTANCE;
}

@Override
public HostMemoryBuffer allocate(long bytes, boolean preferPinned) {
return HostMemoryBuffer.allocate(bytes, preferPinned);
}

@Override
public HostMemoryBuffer allocate(long bytes) {
return HostMemoryBuffer.allocate(bytes);
}
}
39 changes: 39 additions & 0 deletions java/src/main/java/ai/rapids/cudf/HostMemoryAllocator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.rapids.cudf;

public interface HostMemoryAllocator {

/**
* Allocate memory, but be sure to close the returned buffer to avoid memory leaks.
* @param bytes size in bytes to allocate
* @param preferPinned If set to true, the pinned memory pool will be used if possible with a
* fallback to off-heap memory. If set to false, the allocation will always
* be from off-heap memory.
* @return the newly created buffer
*/
HostMemoryBuffer allocate(long bytes, boolean preferPinned);

/**
* Allocate memory, but be sure to close the returned buffer to avoid memory leaks. Pinned memory
* for allocations preference is up to the implementor
*
* @param bytes size in bytes to allocate
* @return the newly created buffer
*/
HostMemoryBuffer allocate(long bytes);
}
23 changes: 19 additions & 4 deletions java/src/main/java/ai/rapids/cudf/JCudfSerialization.java
Original file line number Diff line number Diff line change
Expand Up @@ -1810,14 +1810,17 @@ public static ContiguousTable concatToContiguousTable(SerializedTableHeader[] he
* Concatenate multiple tables in host memory into a single host table buffer.
* @param headers table headers corresponding to the host table buffers
* @param dataBuffers host table buffer for each input table to be concatenated
* @param hostMemoryAllocator allocator for host memory buffers
* @return host table header and buffer
*/
public static HostConcatResult concatToHostBuffer(SerializedTableHeader[] headers,
HostMemoryBuffer[] dataBuffers) throws IOException {
HostMemoryBuffer[] dataBuffers,
HostMemoryAllocator hostMemoryAllocator
) throws IOException {
ColumnBufferProvider[][] providersPerColumn = providersFrom(headers, dataBuffers);
try {
SerializedTableHeader combined = calcConcatHeader(providersPerColumn);
HostMemoryBuffer hostBuffer = HostMemoryBuffer.allocate(combined.dataLen);
HostMemoryBuffer hostBuffer = hostMemoryAllocator.allocate(combined.dataLen);
try {
try (NvtxRange range = new NvtxRange("Concat Host Side", NvtxColor.GREEN)) {
DataWriter writer = writerFrom(hostBuffer);
Expand All @@ -1837,6 +1840,12 @@ public static HostConcatResult concatToHostBuffer(SerializedTableHeader[] header
}
}

public static HostConcatResult concatToHostBuffer(SerializedTableHeader[] headers,
HostMemoryBuffer[] dataBuffers
) throws IOException {
return concatToHostBuffer(headers, dataBuffers, DefaultHostMemoryAllocator.get());
}

/**
* Deserialize a serialized contiguous table into an array of host columns.
*
Expand Down Expand Up @@ -1916,12 +1925,14 @@ public static TableAndRowCountPair readTableFrom(SerializedTableHeader header,
/**
* Read a serialize table from the given InputStream.
* @param in the stream to read the table data from.
* @param hostMemoryAllocator a host memory allocator for an intermediate host memory buffer
* @return the deserialized table in device memory, or null if the stream has no table to read
* from, an end of the stream at the very beginning.
* @throws IOException on any error.
* @throws EOFException if the data stream ended unexpectedly in the middle of processing.
*/
public static TableAndRowCountPair readTableFrom(InputStream in) throws IOException {
public static TableAndRowCountPair readTableFrom(InputStream in,
HostMemoryAllocator hostMemoryAllocator) throws IOException {
DataInputStream din;
if (in instanceof DataInputStream) {
din = (DataInputStream) in;
Expand All @@ -1934,14 +1945,18 @@ public static TableAndRowCountPair readTableFrom(InputStream in) throws IOExcept
return new TableAndRowCountPair(0, null);
}

try (HostMemoryBuffer hostBuffer = HostMemoryBuffer.allocate(header.dataLen)) {
try (HostMemoryBuffer hostBuffer = hostMemoryAllocator.allocate(header.dataLen)) {
if (header.dataLen > 0) {
readTableIntoBuffer(din, header, hostBuffer);
}
return readTableFrom(header, hostBuffer);
}
}

public static TableAndRowCountPair readTableFrom(InputStream in) throws IOException {
return readTableFrom(in, DefaultHostMemoryAllocator.get());
}

/** Holds the result of deserializing a table. */
public static final class TableAndRowCountPair implements Closeable {
private final int numRows;
Expand Down
8 changes: 6 additions & 2 deletions java/src/main/java/ai/rapids/cudf/PinnedMemoryPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,18 @@ public static HostMemoryBuffer tryAllocate(long bytes) {
* @param bytes size in bytes to allocate
* @return newly created buffer
*/
public static HostMemoryBuffer allocate(long bytes) {
public static HostMemoryBuffer allocate(long bytes, HostMemoryAllocator hostMemoryAllocator) {
HostMemoryBuffer result = tryAllocate(bytes);
if (result == null) {
result = HostMemoryBuffer.allocate(bytes, false);
result = hostMemoryAllocator.allocate(bytes, false);
}
return result;
}

public static HostMemoryBuffer allocate(long bytes) {
return allocate(bytes, DefaultHostMemoryAllocator.get());
}

/**
* Get the number of bytes free in the pinned memory pool.
*
Expand Down
Loading