Skip to content

Commit

Permalink
Add JNI support for converting Arrow buffers to CUDF ColumnVectors (#…
Browse files Browse the repository at this point in the history
…7222)

This adds the JNI layer needed to build up Arrow column vectors, which are just references to off-heap Arrow buffers, and then convert them into CUDF ColumnVectors by directly copying the Arrow data to the GPU.

The way this works is you create an ArrowColumnBuilder for each column you need. You call addBatch for each separate Arrow buffer you want to add into that column, and then you call buildAndPutOnDevice() on the builder. That passes the Arrow pointers into CUDF, where an Arrow Table with one column is created and handed to cudf::from_arrow, which returns a CUDF Table; the single column is extracted from that table and returned.

Note this only supports primitive types and Strings for now. List, Struct, Dictionary, and Decimal are not supported yet.

Signed-off-by: Thomas Graves <[email protected]>

Authors:
  - Thomas Graves (@tgravescs)

Approvers:
  - Robert (Bobby) Evans (@revans2)
  - Jason Lowe (@jlowe)

URL: #7222
  • Loading branch information
tgravescs authored Jan 28, 2021
1 parent 9631660 commit cbc0394
Show file tree
Hide file tree
Showing 5 changed files with 574 additions and 0 deletions.
7 changes: 7 additions & 0 deletions java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@
<version>2.25.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
<version>${arrow.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<properties>
Expand All @@ -151,6 +157,7 @@
<GPU_ARCHS>ALL</GPU_ARCHS>
<native.build.path>${project.build.directory}/cmake-build</native.build.path>
<slf4j.version>1.7.30</slf4j.version>
<arrow.version>0.15.1</arrow.version>
</properties>

<profiles>
Expand Down
113 changes: 113 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ArrowColumnBuilder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

import java.nio.ByteBuffer;
import java.util.ArrayList;

/**
* Column builder from Arrow data. This builder takes in byte buffers referencing
* Arrow data and allows efficient building of CUDF ColumnVectors from that Arrow data.
* The caller can add multiple batches where each batch corresponds to Arrow data
* and those batches get concatenated together after being converted to CUDF
* ColumnVectors.
* This currently only supports primitive types and Strings, Decimals and nested types
* such as list and struct are not supported.
*/
public final class ArrowColumnBuilder implements AutoCloseable {
  // Fixed at construction time; every batch added must match this type.
  private final DType type;
  // Parallel lists: entry i of each list describes the i-th Arrow batch added.
  private final ArrayList<ByteBuffer> data = new ArrayList<>();
  private final ArrayList<ByteBuffer> validity = new ArrayList<>();
  private final ArrayList<ByteBuffer> offsets = new ArrayList<>();
  private final ArrayList<Long> nullCount = new ArrayList<>();
  private final ArrayList<Long> rows = new ArrayList<>();

  public ArrowColumnBuilder(HostColumnVector.DataType type) {
    this.type = type.getType();
  }

  /**
   * Add an Arrow buffer. This API allows you to add multiple if you want them
   * combined into a single ColumnVector.
   * Note, this takes all data, validity, and offsets buffers, but they may not all
   * be needed based on the data type. The buffer should be null if it's not used
   * for that type.
   * This API only supports primitive types and Strings, Decimals and nested types
   * such as list and struct are not supported.
   * @param rows - number of rows in this Arrow buffer
   * @param nullCount - number of null values in this Arrow buffer
   * @param data - ByteBuffer of the Arrow data buffer
   * @param validity - ByteBuffer of the Arrow validity buffer
   * @param offsets - ByteBuffer of the Arrow offsets buffer
   */
  public void addBatch(long rows, long nullCount, ByteBuffer data, ByteBuffer validity,
                       ByteBuffer offsets) {
    this.rows.add(rows);
    this.nullCount.add(nullCount);
    this.data.add(data);
    this.validity.add(validity);
    this.offsets.add(offsets);
  }

  /**
   * Create the immutable ColumnVector, copied to the device based on the Arrow data.
   * Each batch previously added is converted to a device ColumnVector; multiple
   * batches are concatenated into a single result.
   * @return - new ColumnVector
   * @throws IllegalStateException if no batches were added via addBatch
   */
  public ColumnVector buildAndPutOnDevice() {
    int numBatches = rows.size();
    // Fail fast before touching the device if there is nothing to build.
    if (numBatches == 0) {
      throw new IllegalStateException("Can't build a ColumnVector when no Arrow batches specified");
    }
    ArrayList<ColumnVector> allVecs = new ArrayList<>(numBatches);
    try {
      for (int i = 0; i < numBatches; i++) {
        allVecs.add(ColumnVector.fromArrow(type, rows.get(i), nullCount.get(i),
            data.get(i), validity.get(i), offsets.get(i)));
      }
      if (numBatches == 1) {
        // Single batch: ownership of the one vector transfers to the caller.
        return allVecs.get(0);
      }
      return ColumnVector.concatenate(allVecs.toArray(new ColumnVector[0]));
    } finally {
      // The concatenated result owns its own data, so the per-batch vectors
      // must be closed whether concatenate succeeded or an exception was thrown.
      // In the single-batch case the vector is returned, not closed.
      if (numBatches > 1) {
        allVecs.forEach(ColumnVector::close);
      }
    }
  }

  @Override
  public void close() {
    // memory buffers owned outside of this
  }

  @Override
  public String toString() {
    return "ArrowColumnBuilder{" +
        "type=" + type +
        ", data=" + data +
        ", validity=" + validity +
        ", offsets=" + offsets +
        ", nullCount=" + nullCount +
        ", rows=" + rows +
        '}';
  }
}
49 changes: 49 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -310,6 +311,50 @@ public BaseDeviceMemoryBuffer getDeviceBufferFor(BufferType type) {
return srcBuffer;
}

/**
 * Ensures the ByteBuffer passed in is a direct byte buffer.
 * If it is not, a direct byte buffer is allocated, the contents between the
 * input buffer's position and limit are copied into it, and it is returned
 * ready for reading. The input buffer's position and limit are left untouched.
 * A null input is passed through unchanged.
 */
private static ByteBuffer bufferAsDirect(ByteBuffer buf) {
  if (buf == null || buf.isDirect()) {
    return buf;
  }
  ByteBuffer direct = ByteBuffer.allocateDirect(buf.remaining());
  // Copy through a duplicate so the caller's buffer position is not consumed
  // as a side effect of this call.
  direct.put(buf.duplicate());
  direct.flip();
  return direct;
}

/**
 * Create a ColumnVector from the Apache Arrow byte buffers passed in.
 * Any of the buffers not used for that datatype should be set to null.
 * The buffers are expected to be off heap buffers, but if they are not,
 * it will handle copying them to direct byte buffers.
 * This only supports primitive types and Strings. Decimals and nested types
 * such as list and struct are not supported.
 * @param type - type of the column
 * @param numRows - Number of rows in the arrow column
 * @param nullCount - Null count
 * @param data - ByteBuffer of the Arrow data buffer
 * @param validity - ByteBuffer of the Arrow validity buffer
 * @param offsets - ByteBuffer of the Arrow offsets buffer
 * @return - new ColumnVector owning the device copy of the data
 */
public static ColumnVector fromArrow(
DType type,
long numRows,
long nullCount,
ByteBuffer data,
ByteBuffer validity,
ByteBuffer offsets) {
// The native call requires direct buffers; heap buffers are copied first.
long columnHandle = fromArrow(type.typeId.getNativeId(), numRows, nullCount,
bufferAsDirect(data), bufferAsDirect(validity), bufferAsDirect(offsets));
ColumnVector vec = new ColumnVector(columnHandle);
return vec;
}

/**
* Create a new vector of length rows, where each row is filled with the Scalar's
* value
Expand Down Expand Up @@ -615,6 +660,10 @@ public ColumnVector castTo(DType type) {

private static native long sequence(long initialValue, long step, int rows);

// Copies the Arrow buffers to the device and returns a native cudf column handle.
// Implemented by Java_ai_rapids_cudf_ColumnVector_fromArrow in ColumnVectorJni.cpp.
private static native long fromArrow(int type, long col_length,
long null_count, ByteBuffer data, ByteBuffer validity,
ByteBuffer offsets) throws CudfException;

private static native long fromScalar(long scalarHandle, int rowCount) throws CudfException;

private static native long makeList(long[] handles, long typeHandle, int scale, long rows)
Expand Down
75 changes: 75 additions & 0 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
* limitations under the License.
*/

#include <arrow/api.h>
#include <cudf/column/column_factories.hpp>
#include <cudf/concatenate.hpp>
#include <cudf/filling.hpp>
#include <cudf/interop.hpp>
#include <cudf/hashing.hpp>
#include <cudf/reshape.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/detail/interop.hpp>
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar_factories.hpp>
Expand Down Expand Up @@ -50,6 +53,78 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, j
CATCH_STD(env, 0);
}

// Converts Arrow host buffers (passed as direct ByteBuffers) into a cudf column
// on the device. The buffers are wrapped (not copied) into an arrow::Array, put
// into a single-column arrow::Table, and run through cudf::from_arrow, which
// performs the host-to-device copy. Returns the released cudf::column handle.
// Supports primitive types and STRING; decimal, nested, and dictionary types
// throw IllegalArgumentException.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env, jclass,
                                                                   jint j_type,
                                                                   jlong j_col_length,
                                                                   jlong j_null_count,
                                                                   jobject j_data_obj,
                                                                   jobject j_validity_obj,
                                                                   jobject j_offsets_obj) {
  try {
    cudf::jni::auto_set_device(env);
    cudf::type_id n_type = static_cast<cudf::type_id>(j_type);
    // Not all the buffers are used for all types; unused ones arrive as null.
    // GetDirectBufferCapacity returns a jlong, so keep the full width to avoid
    // silently truncating buffers of 2GB or more via a narrowing cast to int.
    void const *data_address = nullptr;
    jlong data_length = 0;
    if (j_data_obj != nullptr) {
      data_address = env->GetDirectBufferAddress(j_data_obj);
      data_length = env->GetDirectBufferCapacity(j_data_obj);
    }
    void const *validity_address = nullptr;
    jlong validity_length = 0;
    if (j_validity_obj != nullptr) {
      validity_address = env->GetDirectBufferAddress(j_validity_obj);
      validity_length = env->GetDirectBufferCapacity(j_validity_obj);
    }
    void const *offsets_address = nullptr;
    jlong offsets_length = 0;
    if (j_offsets_obj != nullptr) {
      offsets_address = env->GetDirectBufferAddress(j_offsets_obj);
      offsets_length = env->GetDirectBufferCapacity(j_offsets_obj);
    }
    // arrow::Buffer::Wrap takes the size as int64_t; no copy is made here.
    auto data_buffer = arrow::Buffer::Wrap(static_cast<const char *>(data_address),
                                           static_cast<int64_t>(data_length));
    auto null_buffer = arrow::Buffer::Wrap(static_cast<const char *>(validity_address),
                                           static_cast<int64_t>(validity_length));
    auto offsets_buffer = arrow::Buffer::Wrap(static_cast<const char *>(offsets_address),
                                              static_cast<int64_t>(offsets_length));

    std::shared_ptr<arrow::Array> arrow_array;
    switch (n_type) {
      case cudf::type_id::DECIMAL32:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL32 yet", 0);
        break;
      case cudf::type_id::DECIMAL64:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL64 yet", 0);
        break;
      case cudf::type_id::STRUCT:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting STRUCT yet", 0);
        break;
      case cudf::type_id::LIST:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting LIST yet", 0);
        break;
      case cudf::type_id::DICTIONARY32:
        JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DICTIONARY32 yet", 0);
        break;
      case cudf::type_id::STRING:
        arrow_array = std::make_shared<arrow::StringArray>(j_col_length, offsets_buffer,
                                                           data_buffer, null_buffer, j_null_count);
        break;
      default:
        // this handles the primitive types
        arrow_array = cudf::detail::to_arrow_array(n_type, j_col_length, data_buffer,
                                                   null_buffer, j_null_count);
    }
    // Wrap the single array in a one-column table because cudf::from_arrow
    // operates on tables, then take that one column back out of the result.
    auto name_and_type = arrow::field("col", arrow_array->type());
    std::vector<std::shared_ptr<arrow::Field>> fields = {name_and_type};
    std::shared_ptr<arrow::Schema> schema = std::make_shared<arrow::Schema>(fields);
    auto arrow_table = arrow::Table::Make(schema, std::vector<std::shared_ptr<arrow::Array>>{arrow_array});
    std::unique_ptr<cudf::table> table_result = cudf::from_arrow(*(arrow_table));
    std::vector<std::unique_ptr<cudf::column>> retCols = table_result->release();
    if (retCols.size() != 1) {
      JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Must result in one column", 0);
    }
    // Release ownership to the Java side; ColumnVector's close() frees it.
    return reinterpret_cast<jlong>(retCols[0].release());
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object,
jlongArray handles,
jlong j_type,
Expand Down
Loading

0 comments on commit cbc0394

Please sign in to comment.