diff --git a/java/pom.xml b/java/pom.xml
index ddd0d06a74f..387ef1cb65b 100755
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -132,6 +132,12 @@
2.25.0
test
+
+ org.apache.arrow
+ arrow-vector
+ ${arrow.version}
+ test
+
@@ -151,6 +157,7 @@
ALL
${project.build.directory}/cmake-build
1.7.30
+ 0.15.1
diff --git a/java/src/main/java/ai/rapids/cudf/ArrowColumnBuilder.java b/java/src/main/java/ai/rapids/cudf/ArrowColumnBuilder.java
new file mode 100644
index 00000000000..b3c97930d2a
--- /dev/null
+++ b/java/src/main/java/ai/rapids/cudf/ArrowColumnBuilder.java
@@ -0,0 +1,113 @@
+/*
+ *
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package ai.rapids.cudf;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+
+/**
+ * Column builder from Arrow data. This builder takes in byte buffers referencing
+ * Arrow data and allows efficient building of CUDF ColumnVectors from that Arrow data.
+ * The caller can add multiple batches where each batch corresponds to Arrow data
+ * and those batches get concatenated together after being converted to CUDF
+ * ColumnVectors.
+ * This currently only supports primitive types and Strings, Decimals and nested types
+ * such as list and struct are not supported.
+ */
+public final class ArrowColumnBuilder implements AutoCloseable {
+ private DType type;
+ private final ArrayList data = new ArrayList<>();
+ private final ArrayList validity = new ArrayList<>();
+ private final ArrayList offsets = new ArrayList<>();
+ private final ArrayList nullCount = new ArrayList<>();
+ private final ArrayList rows = new ArrayList<>();
+
+ public ArrowColumnBuilder(HostColumnVector.DataType type) {
+ this.type = type.getType();
+ }
+
+ /**
+ * Add an Arrow buffer. This API allows you to add multiple if you want them
+ * combined into a single ColumnVector.
+ * Note, this takes all data, validity, and offsets buffers, but they may not all
+ * be needed based on the data type. The buffer should be null if its not used
+ * for that type.
+ * This API only supports primitive types and Strings, Decimals and nested types
+ * such as list and struct are not supported.
+ * @param rows - number of rows in this Arrow buffer
+ * @param nullCount - number of null values in this Arrow buffer
+ * @param data - ByteBuffer of the Arrow data buffer
+ * @param validity - ByteBuffer of the Arrow validity buffer
+ * @param offsets - ByteBuffer of the Arrow offsets buffer
+ */
+ public void addBatch(long rows, long nullCount, ByteBuffer data, ByteBuffer validity,
+ ByteBuffer offsets) {
+ this.rows.add(rows);
+ this.nullCount.add(nullCount);
+ this.data.add(data);
+ this.validity.add(validity);
+ this.offsets.add(offsets);
+ }
+
+ /**
+ * Create the immutable ColumnVector, copied to the device based on the Arrow data.
+ * @return - new ColumnVector
+ */
+ public final ColumnVector buildAndPutOnDevice() {
+ int numBatches = rows.size();
+ ArrayList allVecs = new ArrayList<>(numBatches);
+ ColumnVector vecRet;
+ try {
+ for (int i = 0; i < numBatches; i++) {
+ allVecs.add(ColumnVector.fromArrow(type, rows.get(i), nullCount.get(i),
+ data.get(i), validity.get(i), offsets.get(i)));
+ }
+ if (numBatches == 1) {
+ vecRet = allVecs.get(0);
+ } else if (numBatches > 1) {
+ vecRet = ColumnVector.concatenate(allVecs.toArray(new ColumnVector[0]));
+ } else {
+ throw new IllegalStateException("Can't build a ColumnVector when no Arrow batches specified");
+ }
+ } finally {
+ // close the vectors that were concatenated
+ if (numBatches > 1) {
+ allVecs.forEach(cv -> cv.close());
+ }
+ }
+ return vecRet;
+ }
+
+ @Override
+ public void close() {
+ // memory buffers owned outside of this
+ }
+
+ @Override
+ public String toString() {
+ return "ArrowColumnBuilder{" +
+ "type=" + type +
+ ", data=" + data +
+ ", validity=" + validity +
+ ", offsets=" + offsets +
+ ", nullCount=" + nullCount +
+ ", rows=" + rows +
+ '}';
+ }
+}
diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java
index 88c024a437b..252f869a049 100644
--- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java
+++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java
@@ -25,6 +25,7 @@
import java.math.BigDecimal;
import java.math.RoundingMode;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
@@ -310,6 +311,50 @@ public BaseDeviceMemoryBuffer getDeviceBufferFor(BufferType type) {
return srcBuffer;
}
+ /**
+ * Ensures the ByteBuffer passed in is a direct byte buffer.
+ * If it is not then it creates one and copies the data in
+ * the byte buffer passed in to the direct byte buffer
+ * it created and returns it.
+ */
+ private static ByteBuffer bufferAsDirect(ByteBuffer buf) {
+ ByteBuffer bufferOut = buf;
+ if (bufferOut != null && !bufferOut.isDirect()) {
+ bufferOut = ByteBuffer.allocateDirect(buf.remaining());
+ bufferOut.put(buf);
+ bufferOut.flip();
+ }
+ return bufferOut;
+ }
+
+ /**
+ * Create a ColumnVector from the Apache Arrow byte buffers passed in.
+ * Any of the buffers not used for that datatype should be set to null.
+ * The buffers are expected to be off heap buffers, but if they are not,
+ * it will handle copying them to direct byte buffers.
+ * This only supports primitive types. Strings, Decimals and nested types
+ * such as list and struct are not supported.
+ * @param type - type of the column
+ * @param numRows - Number of rows in the arrow column
+ * @param nullCount - Null count
+ * @param data - ByteBuffer of the Arrow data buffer
+ * @param validity - ByteBuffer of the Arrow validity buffer
+ * @param offsets - ByteBuffer of the Arrow offsets buffer
+ * @return - new ColumnVector
+ */
+ public static ColumnVector fromArrow(
+ DType type,
+ long numRows,
+ long nullCount,
+ ByteBuffer data,
+ ByteBuffer validity,
+ ByteBuffer offsets) {
+ long columnHandle = fromArrow(type.typeId.getNativeId(), numRows, nullCount,
+ bufferAsDirect(data), bufferAsDirect(validity), bufferAsDirect(offsets));
+ ColumnVector vec = new ColumnVector(columnHandle);
+ return vec;
+ }
+
/**
* Create a new vector of length rows, where each row is filled with the Scalar's
* value
@@ -615,6 +660,10 @@ public ColumnVector castTo(DType type) {
private static native long sequence(long initialValue, long step, int rows);
+ private static native long fromArrow(int type, long col_length,
+ long null_count, ByteBuffer data, ByteBuffer validity,
+ ByteBuffer offsets) throws CudfException;
+
private static native long fromScalar(long scalarHandle, int rowCount) throws CudfException;
private static native long makeList(long[] handles, long typeHandle, int scale, long rows)
diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp
index 3bce4912fa4..a1e8517c646 100644
--- a/java/src/main/native/src/ColumnVectorJni.cpp
+++ b/java/src/main/native/src/ColumnVectorJni.cpp
@@ -14,12 +14,15 @@
* limitations under the License.
*/
+#include
#include
#include
#include
+#include
#include
#include
#include
+#include
#include
#include
#include
@@ -50,6 +53,78 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, j
CATCH_STD(env, 0);
}
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env, jclass,
+ jint j_type,
+ jlong j_col_length,
+ jlong j_null_count,
+ jobject j_data_obj,
+ jobject j_validity_obj,
+ jobject j_offsets_obj) {
+ try {
+ cudf::jni::auto_set_device(env);
+ cudf::type_id n_type = static_cast(j_type);
+ // not all the buffers are used for all types
+ void const *data_address = 0;
+ int data_length = 0;
+ if (j_data_obj != 0) {
+ data_address = env->GetDirectBufferAddress(j_data_obj);
+ data_length = env->GetDirectBufferCapacity(j_data_obj);
+ }
+ void const *validity_address = 0;
+ int validity_length = 0;
+ if (j_validity_obj != 0) {
+ validity_address = env->GetDirectBufferAddress(j_validity_obj);
+ validity_length = env->GetDirectBufferCapacity(j_validity_obj);
+ }
+ void const *offsets_address = 0;
+ int offsets_length = 0;
+ if (j_offsets_obj != 0) {
+ offsets_address = env->GetDirectBufferAddress(j_offsets_obj);
+ offsets_length = env->GetDirectBufferCapacity(j_offsets_obj);
+ }
+ auto data_buffer = arrow::Buffer::Wrap(static_cast(data_address), static_cast(data_length));
+ auto null_buffer = arrow::Buffer::Wrap(static_cast(validity_address), static_cast(validity_length));
+ auto offsets_buffer = arrow::Buffer::Wrap(static_cast(offsets_address), static_cast(offsets_length));
+
+ cudf::jni::native_jlongArray outcol_handles(env, 1);
+ std::shared_ptr arrow_array;
+ switch (n_type) {
+ case cudf::type_id::DECIMAL32:
+ JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL32 yet", 0);
+ break;
+ case cudf::type_id::DECIMAL64:
+ JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL64 yet", 0);
+ break;
+ case cudf::type_id::STRUCT:
+ JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting STRUCT yet", 0);
+ break;
+ case cudf::type_id::LIST:
+ JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting LIST yet", 0);
+ break;
+ case cudf::type_id::DICTIONARY32:
+ JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DICTIONARY32 yet", 0);
+ break;
+ case cudf::type_id::STRING:
+ arrow_array = std::make_shared(j_col_length, offsets_buffer, data_buffer, null_buffer, j_null_count);
+ break;
+ default:
+ // this handles the primitive types
+ arrow_array = cudf::detail::to_arrow_array(n_type, j_col_length, data_buffer, null_buffer, j_null_count);
+ }
+ auto name_and_type = arrow::field("col", arrow_array->type());
+ std::vector> fields = {name_and_type};
+ std::shared_ptr schema = std::make_shared(fields);
+ auto arrow_table = arrow::Table::Make(schema, std::vector>{arrow_array});
+ std::unique_ptr table_result = cudf::from_arrow(*(arrow_table));
+ std::vector> retCols = table_result->release();
+ if (retCols.size() != 1) {
+ JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Must result in one column", 0);
+ }
+ return reinterpret_cast(retCols[0].release());
+ }
+ CATCH_STD(env, 0);
+}
+
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object,
jlongArray handles,
jlong j_type,
diff --git a/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java
new file mode 100644
index 00000000000..d8ba4548b6d
--- /dev/null
+++ b/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java
@@ -0,0 +1,330 @@
+/*
+ *
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package ai.rapids.cudf;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+
+import ai.rapids.cudf.HostColumnVector.BasicType;
+import ai.rapids.cudf.HostColumnVector.ListType;
+import ai.rapids.cudf.HostColumnVector.StructType;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.util.Text;
+
+import org.junit.jupiter.api.Test;
+
+import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class ArrowColumnVectorTest extends CudfTestBase {
+
+ @Test
+ void testArrowIntMultiBatches() {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.INT32));
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ int numVecs = 4;
+ IntVector[] vectors = new IntVector[numVecs];
+ try {
+ ArrayList expectedArr = new ArrayList();
+ for (int j = 0; j < numVecs; j++) {
+ int pos = 0;
+ int count = 10000;
+ IntVector vector = new IntVector("intVec", allocator);
+ int start = count * j;
+ int end = count * (j + 1);
+ for (int i = start; i < end; i++) {
+ expectedArr.add(i);
+ ((IntVector) vector).setSafe(pos, i);
+ pos++;
+ }
+ vector.setValueCount(count);
+ vectors[j] = vector;
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null);
+ }
+ ColumnVector cv = builder.buildAndPutOnDevice();
+ ColumnVector expected = ColumnVector.fromBoxedInts(expectedArr.toArray(new Integer[0]));
+ assertEquals(cv.getType(), DType.INT32);
+ assertColumnsAreEqual(expected, cv, "ints");
+ } finally {
+ for (int i = 0; i < numVecs; i++) {
+ vectors[i].close();
+ }
+ }
+ }
+
+ @Test
+ void testArrowLong() {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.INT64));
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (BigIntVector vector = new BigIntVector("vec", allocator)) {
+ ArrayList expectedArr = new ArrayList();
+ int count = 10000;
+ for (int i = 0; i < count; i++) {
+ expectedArr.add(new Long(i));
+ ((BigIntVector) vector).setSafe(i, i);
+ }
+ vector.setValueCount(count);
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null);
+ ColumnVector cv = builder.buildAndPutOnDevice();
+ assertEquals(cv.getType(), DType.INT64);
+ ColumnVector expected = ColumnVector.fromBoxedLongs(expectedArr.toArray(new Long[0]));
+ assertColumnsAreEqual(expected, cv, "Longs");
+ }
+ }
+
+ @Test
+ void testArrowLongOnHeap() {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.INT64));
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (BigIntVector vector = new BigIntVector("vec", allocator)) {
+ ArrayList expectedArr = new ArrayList();
+ int count = 10000;
+ for (int i = 0; i < count; i++) {
+ expectedArr.add(new Long(i));
+ ((BigIntVector) vector).setSafe(i, i);
+ }
+ vector.setValueCount(count);
+ // test that we handle convert buffer to direct byte buffer if its on the heap
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer dataOnHeap = ByteBuffer.allocate(data.remaining());
+ dataOnHeap.put(data);
+ dataOnHeap.flip();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ ByteBuffer validOnHeap = ByteBuffer.allocate(valid.remaining());
+ validOnHeap.put(data);
+ validOnHeap.flip();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), dataOnHeap, validOnHeap, null);
+ ColumnVector cv = builder.buildAndPutOnDevice();
+ assertEquals(cv.getType(), DType.INT64);
+ ColumnVector expected = ColumnVector.fromBoxedLongs(expectedArr.toArray(new Long[0]));
+ assertColumnsAreEqual(expected, cv, "Longs");
+ }
+ }
+
+ @Test
+ void testArrowDouble() {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.FLOAT64));
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (Float8Vector vector = new Float8Vector("vec", allocator)) {
+ ArrayList expectedArr = new ArrayList();
+ int count = 10000;
+ for (int i = 0; i < count; i++) {
+ expectedArr.add(new Double(i));
+ ((Float8Vector) vector).setSafe(i, i);
+ }
+ vector.setValueCount(count);
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null);
+ ColumnVector cv = builder.buildAndPutOnDevice();
+ assertEquals(cv.getType(), DType.FLOAT64);
+ double[] array = expectedArr.stream().mapToDouble(i->i).toArray();
+ ColumnVector expected = ColumnVector.fromDoubles(array);
+ assertColumnsAreEqual(expected, cv, "doubles");
+ }
+ }
+
+ @Test
+ void testArrowFloat() {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.FLOAT32));
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (Float4Vector vector = new Float4Vector("vec", allocator)) {
+ ArrayList expectedArr = new ArrayList();
+ int count = 10000;
+ for (int i = 0; i < count; i++) {
+ expectedArr.add(new Float(i));
+ ((Float4Vector) vector).setSafe(i, i);
+ }
+ vector.setValueCount(count);
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null);
+ ColumnVector cv = builder.buildAndPutOnDevice();
+ assertEquals(cv.getType(), DType.FLOAT32);
+ float[] floatArray = new float[expectedArr.size()];
+ int i = 0;
+ for (Float f : expectedArr) {
+ floatArray[i++] = (f != null ? f : Float.NaN); // Or whatever default you want.
+ }
+ ColumnVector expected = ColumnVector.fromFloats(floatArray);
+ assertColumnsAreEqual(expected, cv, "floats");
+ }
+ }
+
+ @Test
+ void testArrowString() {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.STRING));
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (VarCharVector vector = new VarCharVector("vec", allocator)) {
+ ArrayList expectedArr = new ArrayList();
+ int count = 10000;
+ for (int i = 0; i < count; i++) {
+ String toAdd = i + "testString";
+ expectedArr.add(toAdd);
+ ((VarCharVector) vector).setSafe(i, new Text(toAdd));
+ }
+ vector.setValueCount(count);
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ ByteBuffer offsets = vector.getOffsetBuffer().nioBuffer();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, offsets);
+ ColumnVector cv = builder.buildAndPutOnDevice();
+ assertEquals(cv.getType(), DType.STRING);
+ ColumnVector expected = ColumnVector.fromStrings(expectedArr.toArray(new String[0]));
+ assertColumnsAreEqual(expected, cv, "Strings");
+ }
+ }
+
+ @Test
+ void testArrowStringOnHeap() {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.STRING));
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (VarCharVector vector = new VarCharVector("vec", allocator)) {
+ ArrayList expectedArr = new ArrayList();
+ int count = 10000;
+ for (int i = 0; i < count; i++) {
+ String toAdd = i + "testString";
+ expectedArr.add(toAdd);
+ ((VarCharVector) vector).setSafe(i, new Text(toAdd));
+ }
+ vector.setValueCount(count);
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ ByteBuffer offsets = vector.getOffsetBuffer().nioBuffer();
+ ByteBuffer dataOnHeap = ByteBuffer.allocate(data.remaining());
+ dataOnHeap.put(data);
+ dataOnHeap.flip();
+ ByteBuffer validOnHeap = ByteBuffer.allocate(valid.remaining());
+ validOnHeap.put(data);
+ validOnHeap.flip();
+ ByteBuffer offsetsOnHeap = ByteBuffer.allocate(offsets.remaining());
+ offsetsOnHeap.put(offsets);
+ offsetsOnHeap.flip();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), dataOnHeap, validOnHeap, offsetsOnHeap);
+ ColumnVector cv = builder.buildAndPutOnDevice();
+ assertEquals(cv.getType(), DType.STRING);
+ ColumnVector expected = ColumnVector.fromStrings(expectedArr.toArray(new String[0]));
+ assertColumnsAreEqual(expected, cv, "Strings");
+ }
+ }
+
+ @Test
+ void testArrowDays() {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.TIMESTAMP_DAYS));
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (DateDayVector vector = new DateDayVector("vec", allocator)) {
+ ArrayList expectedArr = new ArrayList();
+ int count = 10000;
+ for (int i = 0; i < count; i++) {
+ expectedArr.add(i);
+ ((DateDayVector) vector).setSafe(i, i);
+ }
+ vector.setValueCount(count);
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null);
+ ColumnVector cv = builder.buildAndPutOnDevice();
+ assertEquals(cv.getType(), DType.TIMESTAMP_DAYS);
+ int[] array = expectedArr.stream().mapToInt(i->i).toArray();
+ ColumnVector expected = ColumnVector.daysFromInts(array);
+ assertColumnsAreEqual(expected, cv, "timestamp days");
+ }
+ }
+
+ @Test
+ void testArrowDecimalThrows() {
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (DecimalVector vector = new DecimalVector("vec", allocator, 7, 3)) {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, 3)));
+ ((DecimalVector) vector).setSafe(0, -3);
+ ((DecimalVector) vector).setSafe(1, 1);
+ ((DecimalVector) vector).setSafe(2, 2);
+ ((DecimalVector) vector).setSafe(3, 3);
+ ((DecimalVector) vector).setSafe(4, 4);
+ ((DecimalVector) vector).setSafe(5, 5);
+ vector.setValueCount(6);
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null);
+ assertThrows(IllegalArgumentException.class, () -> {
+ builder.buildAndPutOnDevice();
+ });
+ }
+ }
+
+ @Test
+ void testArrowDecimal64Throws() {
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (DecimalVector vector = new DecimalVector("vec", allocator, 18, 0)) {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL64, -11)));
+ ((DecimalVector) vector).setSafe(0, -3);
+ ((DecimalVector) vector).setSafe(1, 1);
+ ((DecimalVector) vector).setSafe(2, 2);
+ vector.setValueCount(3);
+ ByteBuffer data = vector.getDataBuffer().nioBuffer();
+ ByteBuffer valid = vector.getValidityBuffer().nioBuffer();
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null);
+ assertThrows(IllegalArgumentException.class, () -> {
+ builder.buildAndPutOnDevice();
+ });
+ }
+ }
+
+ @Test
+ void testArrowListThrows() {
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (ListVector vector = ListVector.empty("list", allocator)) {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new ListType(true, new HostColumnVector.BasicType(true, DType.STRING)));
+ // buffer don't matter as we expect it to throw anyway
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), null, null, null);
+ assertThrows(IllegalArgumentException.class, () -> {
+ builder.buildAndPutOnDevice();
+ });
+ }
+ }
+
+ @Test
+ void testArrowStructThrows() {
+ BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ try (StructVector vector = StructVector.empty("struct", allocator)) {
+ ArrowColumnBuilder builder = new ArrowColumnBuilder(new StructType(true, new HostColumnVector.BasicType(true, DType.STRING)));
+ // buffer don't matter as we expect it to throw anyway
+ builder.addBatch(vector.getValueCount(), vector.getNullCount(), null, null, null);
+ assertThrows(IllegalArgumentException.class, () -> {
+ builder.buildAndPutOnDevice();
+ });
+ }
+ }
+}