Skip to content

Commit

Permalink
Update JNI for contiguous_split packed results (#7127)
Browse files Browse the repository at this point in the history
This PR requires the libcudf changes in #7096, fixing the Java bindings to `contiguous_split` that are broken by that change.

This also adds the ability to create a `ContiguousTable` instance without manifesting a `Table` instance and all `ColumnVector` instances underneath it which should prove useful during Spark's shuffle.

Authors:
  - Jason Lowe (@jlowe)

Approvers:
  - Robert (Bobby) Evans (@revans2)
  - Alessandro Bellina (@abellina)

URL: #7127
  • Loading branch information
jlowe authored Feb 4, 2021
1 parent 4f87a59 commit 110ef3e
Show file tree
Hide file tree
Showing 9 changed files with 337 additions and 110 deletions.
110 changes: 76 additions & 34 deletions java/src/main/java/ai/rapids/cudf/ContiguousTable.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,63 +18,96 @@

package ai.rapids.cudf;

import java.util.Arrays;
import java.nio.ByteBuffer;

/**
* A table that is backed by a single contiguous device buffer. This makes transfers of the data
* much simpler.
*/
public final class ContiguousTable implements AutoCloseable {
private Table table;
private long metadataHandle = 0;
private Table table = null;
private DeviceMemoryBuffer buffer;
private ByteBuffer metadataBuffer = null;
private final long rowCount;

//Will be called from JNI
static ContiguousTable fromContiguousColumnViews(long[] columnViewAddresses,
long address, long lengthInBytes, long rmmBufferAddress) {
Table table = null;
ColumnVector[] vectors = new ColumnVector[columnViewAddresses.length];
DeviceMemoryBuffer buffer = DeviceMemoryBuffer.fromRmm(address, lengthInBytes, rmmBufferAddress);
try {
for (int i = 0; i < vectors.length; i++) {
vectors[i] = ColumnVector.fromViewWithContiguousAllocation(columnViewAddresses[i], buffer);
}
table = new Table(vectors);
ContiguousTable ret = new ContiguousTable(table, buffer);
buffer = null;
table = null;
return ret;
} finally {
if (buffer != null) {
buffer.close();
}

for (int i = 0; i < vectors.length; i++) {
if (vectors[i] != null) {
vectors[i].close();
}
}

if (table != null) {
table.close();
}
}
// This method is invoked by JNI
static ContiguousTable fromPackedTable(long metadataHandle,
long dataAddress,
long dataLength,
long rmmBufferAddress,
long rowCount) {
DeviceMemoryBuffer buffer = DeviceMemoryBuffer.fromRmm(dataAddress, dataLength, rmmBufferAddress);
return new ContiguousTable(metadataHandle, buffer, rowCount);
}

/** Construct a contiguous table instance given a table and the device buffer backing it. */
ContiguousTable(Table table, DeviceMemoryBuffer buffer) {
this.metadataHandle = createPackedMetadata(table.getNativeView(),
buffer.getAddress(), buffer.getLength());
this.table = table;
this.buffer = buffer;
this.rowCount = table.getRowCount();
}

public Table getTable() {
/**
* Construct a contiguous table
* @param metadataHandle address of the cudf packed_table host-based metadata instance
* @param buffer buffer containing the packed table data
* @param rowCount number of rows in the table
*/
ContiguousTable(long metadataHandle, DeviceMemoryBuffer buffer, long rowCount) {
this.metadataHandle = metadataHandle;
this.buffer = buffer;
this.rowCount = rowCount;
}

/**
* Returns the number of rows in the table. This accessor avoids manifesting
* the Table instance if only the row count is needed.
*/
public long getRowCount() {
return rowCount;
}

/** Get the table instance, reconstructing it from the metadata if necessary. */
public synchronized Table getTable() {
if (table == null) {
table = Table.fromPackedTable(getMetadataDirectBuffer(), buffer);
}
return table;
}

/** Get the device buffer backing the contiguous table data. */
public DeviceMemoryBuffer getBuffer() {
return buffer;
}

/**
* Get the byte buffer containing the host metadata describing the schema and layout of the
* contiguous table.
* <p>
* NOTE: This is a direct byte buffer that is backed by the underlying native metadata instance
* and therefore is only valid to be used while this contiguous table instance is valid.
* Attempts to cache and access the resulting buffer after this instance has been destroyed
* will result in undefined behavior including the possibility of segmentation faults
* or data corruption.
*/
public ByteBuffer getMetadataDirectBuffer() {
if (metadataBuffer == null) {
metadataBuffer = createMetadataDirectBuffer(metadataHandle);
}
return metadataBuffer.asReadOnlyBuffer();
}

/** Close the contiguous table instance and its underlying resources. */
@Override
public void close() {
if (metadataHandle != 0) {
closeMetadata(metadataHandle);
metadataHandle = 0;
}

if (table != null) {
table.close();
table = null;
Expand All @@ -85,4 +118,13 @@ public void close() {
buffer = null;
}
}

// create packed metadata for a table backed by a single data buffer
private static native long createPackedMetadata(long tableView, long dataAddress, long dataSize);

// create a DirectByteBuffer for the packed table metadata
private static native ByteBuffer createMetadataDirectBuffer(long metadataHandle);

// release the native metadata resources for a packed table
private static native void closeMetadata(long metadataHandle);
}
54 changes: 53 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.File;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -117,6 +118,11 @@ ColumnVector[] getColumns() {
return columns;
}

/** Return the native table view handle for this table */
long getNativeView() {
return nativeHandle;
}

/**
* Return the {@link ColumnVector} at the specified index. If you want to keep a reference to
* the column around past the life time of the table, you will need to increment the reference
Expand Down Expand Up @@ -503,7 +509,9 @@ private static native long[] repeatColumnCount(long tableHandle,

private static native long[] explode(long tableHandle, int index);

private native long createCudfTableView(long[] nativeColumnViewHandles);
private static native long createCudfTableView(long[] nativeColumnViewHandles);

private static native long[] columnViewsFromPacked(ByteBuffer metadata, long dataAddress);

/////////////////////////////////////////////////////////////////////////////
// TABLE CREATION APIs
Expand Down Expand Up @@ -1796,6 +1804,50 @@ public static Table convertFromRows(ColumnVector vec, DType ... schema) {
return new Table(convertFromRows(vec.getNativeView(), types, scale));
}

/**
* Construct a table from a packed representation.
* @param metadata host-based metadata for the table
* @param data GPU data buffer for the table
* @return table which is zero-copy reconstructed from the packed-form
*/
public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data) {
// Ensure the metadata buffer is direct so it can be passed to JNI
ByteBuffer directBuffer = metadata;
if (!directBuffer.isDirect()) {
directBuffer = ByteBuffer.allocateDirect(metadata.remaining());
directBuffer.put(metadata);
directBuffer.flip();
}

long[] columnViewAddresses = columnViewsFromPacked(directBuffer, data.getAddress());
ColumnVector[] columns = new ColumnVector[columnViewAddresses.length];
Table result = null;
try {
for (int i = 0; i < columns.length; i++) {
columns[i] = ColumnVector.fromViewWithContiguousAllocation(columnViewAddresses[i], data);
columnViewAddresses[i] = 0;
}
result = new Table(columns);
} catch (Throwable t) {
for (int i = 0; i < columns.length; i++) {
if (columns[i] != null) {
columns[i].close();
}
if (columnViewAddresses[i] != 0) {
ColumnView.deleteColumnView(columnViewAddresses[i]);
}
}
throw t;
}

// close columns to leave the resulting table responsible for freeing underlying columns
for (ColumnVector column : columns) {
column.close();
}

return result;
}

/////////////////////////////////////////////////////////////////////////////
// HELPER CLASSES
/////////////////////////////////////////////////////////////////////////////
Expand Down
1 change: 1 addition & 0 deletions java/src/main/native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ set(SOURCE_FILES
"src/CudaJni.cpp"
"src/ColumnVectorJni.cpp"
"src/ColumnViewJni.cpp"
"src/ContiguousTableJni.cpp"
"src/HostMemoryBufferNativeUtilsJni.cpp"
"src/NvcompJni.cpp"
"src/NvtxRangeJni.cpp"
Expand Down
122 changes: 122 additions & 0 deletions java/src/main/native/src/ContiguousTableJni.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cudf_jni_apis.hpp"

namespace {

#define CONTIGUOUS_TABLE_CLASS "ai/rapids/cudf/ContiguousTable"
#define CONTIGUOUS_TABLE_FACTORY_SIG(param_sig) "(" param_sig ")L" CONTIGUOUS_TABLE_CLASS ";"

jclass Contiguous_table_jclass;
jmethodID From_packed_table_method;

} // anonymous namespace

namespace cudf {
namespace jni {

bool cache_contiguous_table_jni(JNIEnv *env) {
jclass cls = env->FindClass(CONTIGUOUS_TABLE_CLASS);
if (cls == nullptr) {
return false;
}

From_packed_table_method =
env->GetStaticMethodID(cls, "fromPackedTable", CONTIGUOUS_TABLE_FACTORY_SIG("JJJJJ"));
if (From_packed_table_method == nullptr) {
return false;
}

// Convert local reference to global so it cannot be garbage collected.
Contiguous_table_jclass = static_cast<jclass>(env->NewGlobalRef(cls));
if (Contiguous_table_jclass == nullptr) {
return false;
}
return true;
}

void release_contiguous_table_jni(JNIEnv *env) {
if (Contiguous_table_jclass != nullptr) {
env->DeleteGlobalRef(Contiguous_table_jclass);
Contiguous_table_jclass = nullptr;
}
}

jobject contiguous_table_from(JNIEnv *env, cudf::packed_columns &split, long row_count) {
jlong metadata_address = reinterpret_cast<jlong>(split.metadata_.get());
jlong data_address = reinterpret_cast<jlong>(split.gpu_data->data());
jlong data_size = static_cast<jlong>(split.gpu_data->size());
jlong rmm_buffer_address = reinterpret_cast<jlong>(split.gpu_data.get());

jobject contig_table_obj = env->CallStaticObjectMethod(
Contiguous_table_jclass, From_packed_table_method, metadata_address, data_address, data_size,
rmm_buffer_address, row_count);

if (contig_table_obj != nullptr) {
split.metadata_.release();
split.gpu_data.release();
}

return contig_table_obj;
}

native_jobjectArray<jobject> contiguous_table_array(JNIEnv *env, jsize length) {
return native_jobjectArray<jobject>(
env, env->NewObjectArray(length, Contiguous_table_jclass, nullptr));
}

} // namespace jni
} // namespace cudf

extern "C" {

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ContiguousTable_createPackedMetadata(
JNIEnv *env, jclass, jlong j_table, jlong j_buffer_addr, jlong j_buffer_length) {
JNI_NULL_CHECK(env, j_table, "input table is null", 0);
try {
cudf::jni::auto_set_device(env);
auto table = reinterpret_cast<cudf::table_view const *>(j_table);
auto data_addr = reinterpret_cast<uint8_t const *>(j_buffer_addr);
auto data_size = static_cast<size_t>(j_buffer_length);
auto metadata_ptr =
new cudf::packed_columns::metadata(cudf::pack_metadata(*table, data_addr, data_size));
return reinterpret_cast<jlong>(metadata_ptr);
}
CATCH_STD(env, 0);
}

JNIEXPORT jobject JNICALL Java_ai_rapids_cudf_ContiguousTable_createMetadataDirectBuffer(
JNIEnv *env, jclass, jlong j_metadata_ptr) {
JNI_NULL_CHECK(env, j_metadata_ptr, "metadata is null", nullptr);
try {
auto metadata = reinterpret_cast<cudf::packed_columns::metadata *>(j_metadata_ptr);
return env->NewDirectByteBuffer(const_cast<uint8_t *>(metadata->data()), metadata->size());
}
CATCH_STD(env, nullptr);
}

JNIEXPORT void JNICALL Java_ai_rapids_cudf_ContiguousTable_closeMetadata(JNIEnv *env, jclass,
jlong j_metadata_ptr) {
JNI_NULL_CHECK(env, j_metadata_ptr, "metadata is null", );
try {
auto metadata = reinterpret_cast<cudf::packed_columns::metadata *>(j_metadata_ptr);
delete metadata;
}
CATCH_STD(env, );
}

} // extern "C"
Loading

0 comments on commit 110ef3e

Please sign in to comment.