diff --git a/java/pom.xml b/java/pom.xml index ddd0d06a74f..387ef1cb65b 100755 --- a/java/pom.xml +++ b/java/pom.xml @@ -132,6 +132,12 @@ 2.25.0 test + + org.apache.arrow + arrow-vector + ${arrow.version} + test + @@ -151,6 +157,7 @@ ALL ${project.build.directory}/cmake-build 1.7.30 + 0.15.1 diff --git a/java/src/main/java/ai/rapids/cudf/ArrowColumnBuilder.java b/java/src/main/java/ai/rapids/cudf/ArrowColumnBuilder.java new file mode 100644 index 00000000000..b3c97930d2a --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ArrowColumnBuilder.java @@ -0,0 +1,113 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +import java.nio.ByteBuffer; +import java.util.ArrayList; + +/** + * Column builder from Arrow data. This builder takes in byte buffers referencing + * Arrow data and allows efficient building of CUDF ColumnVectors from that Arrow data. + * The caller can add multiple batches where each batch corresponds to Arrow data + * and those batches get concatenated together after being converted to CUDF + * ColumnVectors. + * This currently only supports primitive types and Strings; Decimals and nested types + * such as list and struct are not supported.
+ */ +public final class ArrowColumnBuilder implements AutoCloseable { + private final DType type; + private final ArrayList<ByteBuffer> data = new ArrayList<>(); + private final ArrayList<ByteBuffer> validity = new ArrayList<>(); + private final ArrayList<ByteBuffer> offsets = new ArrayList<>(); + private final ArrayList<Long> nullCount = new ArrayList<>(); + private final ArrayList<Long> rows = new ArrayList<>(); + + public ArrowColumnBuilder(HostColumnVector.DataType type) { + this.type = type.getType(); + } + + /** + * Add a batch of Arrow buffers. This API can be called multiple times if you want the + * batches combined into a single ColumnVector. + * Note, this takes all data, validity, and offsets buffers, but they may not all + * be needed based on the data type. The buffer should be null if it's not used + * for that type. + * This API only supports primitive types and Strings; Decimals and nested types + * such as list and struct are not supported. + * @param rows - number of rows in this Arrow buffer + * @param nullCount - number of null values in this Arrow buffer + * @param data - ByteBuffer of the Arrow data buffer + * @param validity - ByteBuffer of the Arrow validity buffer + * @param offsets - ByteBuffer of the Arrow offsets buffer + */ + public void addBatch(long rows, long nullCount, ByteBuffer data, ByteBuffer validity, + ByteBuffer offsets) { + this.rows.add(rows); + this.nullCount.add(nullCount); + this.data.add(data); + this.validity.add(validity); + this.offsets.add(offsets); + } + + /** + * Create the immutable ColumnVector on the device from the added Arrow batches (concatenated if several).
+ * @return - new ColumnVector; the caller is responsible for closing it + */ + public final ColumnVector buildAndPutOnDevice() { + int numBatches = rows.size(); + ArrayList<ColumnVector> allVecs = new ArrayList<>(numBatches); + ColumnVector vecRet; + try { + for (int i = 0; i < numBatches; i++) { + allVecs.add(ColumnVector.fromArrow(type, rows.get(i), nullCount.get(i), + data.get(i), validity.get(i), offsets.get(i))); + } + if (numBatches == 1) { + vecRet = allVecs.get(0); + } else if (numBatches > 1) { + vecRet = ColumnVector.concatenate(allVecs.toArray(new ColumnVector[0])); + } else { + throw new IllegalStateException("Can't build a ColumnVector when no Arrow batches specified"); + } + } finally { + // concatenate() returns a new vector, so the per-batch vectors are closed here (also on failure) + if (numBatches > 1) { + allVecs.forEach(ColumnVector::close); + } + } + return vecRet; + } + + @Override + public void close() { + // nothing to release: the Arrow memory buffers are owned outside of this builder + } + + @Override + public String toString() { + return "ArrowColumnBuilder{" + + "type=" + type + + ", data=" + data + + ", validity=" + validity + + ", offsets=" + offsets + + ", nullCount=" + nullCount + + ", rows=" + rows + + '}'; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 88c024a437b..252f869a049 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -25,6 +25,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -310,6 +311,50 @@ public BaseDeviceMemoryBuffer getDeviceBufferFor(BufferType type) { return srcBuffer; } + /** + * Ensures the ByteBuffer passed in is a direct byte buffer. + * If it is not then it creates one and copies the data in + * the byte buffer passed in to the direct byte buffer + * it created and returns it.
+ */ + private static ByteBuffer bufferAsDirect(ByteBuffer buf) { + ByteBuffer bufferOut = buf; + if (bufferOut != null && !bufferOut.isDirect()) { + bufferOut = ByteBuffer.allocateDirect(buf.remaining()); + bufferOut.put(buf.duplicate()); // copy via a duplicate so the caller's buffer position is not consumed + bufferOut.flip(); + } + return bufferOut; + } + + /** + * Create a ColumnVector from the Apache Arrow byte buffers passed in. + * Any of the buffers not used for that datatype should be set to null. + * The buffers are expected to be off heap buffers, but if they are not, + * it will handle copying them to direct byte buffers. + * This only supports primitive types and Strings. Decimals and nested types + * such as list and struct are not supported. + * @param type - type of the column + * @param numRows - Number of rows in the arrow column + * @param nullCount - Null count + * @param data - ByteBuffer of the Arrow data buffer + * @param validity - ByteBuffer of the Arrow validity buffer + * @param offsets - ByteBuffer of the Arrow offsets buffer + * @return - new ColumnVector + */ + public static ColumnVector fromArrow( + DType type, + long numRows, + long nullCount, + ByteBuffer data, + ByteBuffer validity, + ByteBuffer offsets) { + long columnHandle = fromArrow(type.typeId.getNativeId(), numRows, nullCount, + bufferAsDirect(data), bufferAsDirect(validity), bufferAsDirect(offsets)); + ColumnVector vec = new ColumnVector(columnHandle); + return vec; + } + + /** * Create a new vector of length rows, where each row is filled with the Scalar's * value @@ -615,6 +660,10 @@ public ColumnVector castTo(DType type) { private static native long sequence(long initialValue, long step, int rows); + private static native long fromArrow(int type, long col_length, + long null_count, ByteBuffer data, ByteBuffer validity, + ByteBuffer offsets) throws CudfException; + private static native long fromScalar(long scalarHandle, int rowCount) throws CudfException; private static native long makeList(long[] handles, long typeHandle, int scale, long rows) diff --git
a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 3bce4912fa4..a1e8517c646 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -14,12 +14,15 @@ * limitations under the License. */ +#include #include #include #include +#include #include #include #include +#include #include #include #include @@ -50,6 +53,77 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, j CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env, jclass, + jint j_type, + jlong j_col_length, + jlong j_null_count, + jobject j_data_obj, + jobject j_validity_obj, + jobject j_offsets_obj) { + try { + cudf::jni::auto_set_device(env); + cudf::type_id n_type = static_cast(j_type); + // not all the buffers are used for all types. NOTE(review): GetDirectBufferCapacity returns jlong; the int lengths below would truncate buffers larger than 2 GiB - confirm this cannot happen + void const *data_address = 0; + int data_length = 0; + if (j_data_obj != 0) { + data_address = env->GetDirectBufferAddress(j_data_obj); + data_length = env->GetDirectBufferCapacity(j_data_obj); + } + void const *validity_address = 0; + int validity_length = 0; + if (j_validity_obj != 0) { + validity_address = env->GetDirectBufferAddress(j_validity_obj); + validity_length = env->GetDirectBufferCapacity(j_validity_obj); + } + void const *offsets_address = 0; + int offsets_length = 0; + if (j_offsets_obj != 0) { + offsets_address = env->GetDirectBufferAddress(j_offsets_obj); + offsets_length = env->GetDirectBufferCapacity(j_offsets_obj); + } + auto data_buffer = arrow::Buffer::Wrap(static_cast(data_address), static_cast(data_length)); + auto null_buffer = arrow::Buffer::Wrap(static_cast(validity_address), static_cast(validity_length)); + auto offsets_buffer = arrow::Buffer::Wrap(static_cast(offsets_address), static_cast(offsets_length)); + + std::shared_ptr arrow_array; + switch (n_type) { + case cudf::type_id::DECIMAL32: + JNI_THROW_NEW(env,
cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL32 yet", 0); + break; + case cudf::type_id::DECIMAL64: + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL64 yet", 0); + break; + case cudf::type_id::STRUCT: + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting STRUCT yet", 0); + break; + case cudf::type_id::LIST: + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting LIST yet", 0); + break; + case cudf::type_id::DICTIONARY32: + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DICTIONARY32 yet", 0); + break; + case cudf::type_id::STRING: + arrow_array = std::make_shared(j_col_length, offsets_buffer, data_buffer, null_buffer, j_null_count); + break; + default: + // this handles the primitive types + arrow_array = cudf::detail::to_arrow_array(n_type, j_col_length, data_buffer, null_buffer, j_null_count); + } + auto name_and_type = arrow::field("col", arrow_array->type()); + std::vector> fields = {name_and_type}; + std::shared_ptr schema = std::make_shared(fields); + auto arrow_table = arrow::Table::Make(schema, std::vector>{arrow_array}); + std::unique_ptr table_result = cudf::from_arrow(*(arrow_table)); + std::vector> retCols = table_result->release(); + if (retCols.size() != 1) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Must result in one column", 0); + } + return reinterpret_cast(retCols[0].release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object, jlongArray handles, jlong j_type, diff --git a/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java new file mode 100644 index 00000000000..d8ba4548b6d --- /dev/null +++ b/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java @@ -0,0 +1,330 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +import java.nio.ByteBuffer; +import java.util.ArrayList; + +import ai.rapids.cudf.HostColumnVector.BasicType; +import ai.rapids.cudf.HostColumnVector.ListType; +import ai.rapids.cudf.HostColumnVector.StructType; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.util.Text; + +import org.junit.jupiter.api.Test; + +import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ArrowColumnVectorTest extends CudfTestBase { + + @Test + void testArrowIntMultiBatches() { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.INT32)); + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + int numVecs = 4; + IntVector[] vectors = new IntVector[numVecs]; + try { + ArrayList expectedArr = new 
ArrayList(); + for (int j = 0; j < numVecs; j++) { + int pos = 0; + int count = 10000; + IntVector vector = new IntVector("intVec", allocator); + int start = count * j; + int end = count * (j + 1); + for (int i = start; i < end; i++) { + expectedArr.add(i); + ((IntVector) vector).setSafe(pos, i); + pos++; + } + vector.setValueCount(count); + vectors[j] = vector; + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null); + } + ColumnVector cv = builder.buildAndPutOnDevice(); + ColumnVector expected = ColumnVector.fromBoxedInts(expectedArr.toArray(new Integer[0])); + assertEquals(cv.getType(), DType.INT32); + assertColumnsAreEqual(expected, cv, "ints"); + } finally { + for (int i = 0; i < numVecs; i++) { + vectors[i].close(); + } + } + } + + @Test + void testArrowLong() { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.INT64)); + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (BigIntVector vector = new BigIntVector("vec", allocator)) { + ArrayList expectedArr = new ArrayList(); + int count = 10000; + for (int i = 0; i < count; i++) { + expectedArr.add(new Long(i)); + ((BigIntVector) vector).setSafe(i, i); + } + vector.setValueCount(count); + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null); + ColumnVector cv = builder.buildAndPutOnDevice(); + assertEquals(cv.getType(), DType.INT64); + ColumnVector expected = ColumnVector.fromBoxedLongs(expectedArr.toArray(new Long[0])); + assertColumnsAreEqual(expected, cv, "Longs"); + } + } + + @Test + void testArrowLongOnHeap() { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.INT64)); + BufferAllocator allocator = new 
RootAllocator(Long.MAX_VALUE); + try (BigIntVector vector = new BigIntVector("vec", allocator)) { + ArrayList expectedArr = new ArrayList(); + int count = 10000; + for (int i = 0; i < count; i++) { + expectedArr.add(new Long(i)); + ((BigIntVector) vector).setSafe(i, i); + } + vector.setValueCount(count); + // test that on-heap buffers get converted to direct byte buffers; each heap copy must come from its own source buffer + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer dataOnHeap = ByteBuffer.allocate(data.remaining()); + dataOnHeap.put(data); + dataOnHeap.flip(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + ByteBuffer validOnHeap = ByteBuffer.allocate(valid.remaining()); + validOnHeap.put(valid); + validOnHeap.flip(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), dataOnHeap, validOnHeap, null); + ColumnVector cv = builder.buildAndPutOnDevice(); + assertEquals(cv.getType(), DType.INT64); + ColumnVector expected = ColumnVector.fromBoxedLongs(expectedArr.toArray(new Long[0])); + assertColumnsAreEqual(expected, cv, "Longs"); + } + } + + @Test + void testArrowDouble() { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.FLOAT64)); + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (Float8Vector vector = new Float8Vector("vec", allocator)) { + ArrayList expectedArr = new ArrayList(); + int count = 10000; + for (int i = 0; i < count; i++) { + expectedArr.add(new Double(i)); + ((Float8Vector) vector).setSafe(i, i); + } + vector.setValueCount(count); + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null); + ColumnVector cv = builder.buildAndPutOnDevice(); + assertEquals(cv.getType(), DType.FLOAT64); + double[] array = expectedArr.stream().mapToDouble(i->i).toArray(); + ColumnVector expected = ColumnVector.fromDoubles(array); +
assertColumnsAreEqual(expected, cv, "doubles"); + } + } + + @Test + void testArrowFloat() { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.FLOAT32)); + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (Float4Vector vector = new Float4Vector("vec", allocator)) { + ArrayList expectedArr = new ArrayList(); + int count = 10000; + for (int i = 0; i < count; i++) { + expectedArr.add(new Float(i)); + ((Float4Vector) vector).setSafe(i, i); + } + vector.setValueCount(count); + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null); + ColumnVector cv = builder.buildAndPutOnDevice(); + assertEquals(cv.getType(), DType.FLOAT32); + float[] floatArray = new float[expectedArr.size()]; + int i = 0; + for (Float f : expectedArr) { + floatArray[i++] = (f != null ? f : Float.NaN); // Or whatever default you want. 
+ } + ColumnVector expected = ColumnVector.fromFloats(floatArray); + assertColumnsAreEqual(expected, cv, "floats"); + } + } + + @Test + void testArrowString() { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.STRING)); + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (VarCharVector vector = new VarCharVector("vec", allocator)) { + ArrayList expectedArr = new ArrayList(); + int count = 10000; + for (int i = 0; i < count; i++) { + String toAdd = i + "testString"; + expectedArr.add(toAdd); + ((VarCharVector) vector).setSafe(i, new Text(toAdd)); + } + vector.setValueCount(count); + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + ByteBuffer offsets = vector.getOffsetBuffer().nioBuffer(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, offsets); + ColumnVector cv = builder.buildAndPutOnDevice(); + assertEquals(cv.getType(), DType.STRING); + ColumnVector expected = ColumnVector.fromStrings(expectedArr.toArray(new String[0])); + assertColumnsAreEqual(expected, cv, "Strings"); + } + } + + @Test + void testArrowStringOnHeap() { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.STRING)); + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (VarCharVector vector = new VarCharVector("vec", allocator)) { + ArrayList expectedArr = new ArrayList(); + int count = 10000; + for (int i = 0; i < count; i++) { + String toAdd = i + "testString"; + expectedArr.add(toAdd); + ((VarCharVector) vector).setSafe(i, new Text(toAdd)); + } + vector.setValueCount(count); + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + ByteBuffer offsets = vector.getOffsetBuffer().nioBuffer(); + ByteBuffer dataOnHeap = ByteBuffer.allocate(data.remaining()); + dataOnHeap.put(data); + dataOnHeap.flip(); + ByteBuffer 
validOnHeap = ByteBuffer.allocate(valid.remaining()); + validOnHeap.put(valid); + validOnHeap.flip(); + ByteBuffer offsetsOnHeap = ByteBuffer.allocate(offsets.remaining()); + offsetsOnHeap.put(offsets); + offsetsOnHeap.flip(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), dataOnHeap, validOnHeap, offsetsOnHeap); + ColumnVector cv = builder.buildAndPutOnDevice(); + assertEquals(cv.getType(), DType.STRING); + ColumnVector expected = ColumnVector.fromStrings(expectedArr.toArray(new String[0])); + assertColumnsAreEqual(expected, cv, "Strings"); + } + } + + @Test + void testArrowDays() { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.TIMESTAMP_DAYS)); + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (DateDayVector vector = new DateDayVector("vec", allocator)) { + ArrayList expectedArr = new ArrayList(); + int count = 10000; + for (int i = 0; i < count; i++) { + expectedArr.add(i); + ((DateDayVector) vector).setSafe(i, i); + } + vector.setValueCount(count); + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null); + ColumnVector cv = builder.buildAndPutOnDevice(); + assertEquals(cv.getType(), DType.TIMESTAMP_DAYS); + int[] array = expectedArr.stream().mapToInt(i->i).toArray(); + ColumnVector expected = ColumnVector.daysFromInts(array); + assertColumnsAreEqual(expected, cv, "timestamp days"); + } + } + + @Test + void testArrowDecimalThrows() { + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (DecimalVector vector = new DecimalVector("vec", allocator, 7, 3)) { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, 3))); + ((DecimalVector) vector).setSafe(0, -3); + ((DecimalVector) vector).setSafe(1, 1); + ((DecimalVector) vector).setSafe(2, 2);
+ ((DecimalVector) vector).setSafe(3, 3); + ((DecimalVector) vector).setSafe(4, 4); + ((DecimalVector) vector).setSafe(5, 5); + vector.setValueCount(6); + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null); + assertThrows(IllegalArgumentException.class, () -> { + builder.buildAndPutOnDevice(); + }); + } + } + + @Test + void testArrowDecimal64Throws() { + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (DecimalVector vector = new DecimalVector("vec", allocator, 18, 0)) { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL64, -11))); + ((DecimalVector) vector).setSafe(0, -3); + ((DecimalVector) vector).setSafe(1, 1); + ((DecimalVector) vector).setSafe(2, 2); + vector.setValueCount(3); + ByteBuffer data = vector.getDataBuffer().nioBuffer(); + ByteBuffer valid = vector.getValidityBuffer().nioBuffer(); + builder.addBatch(vector.getValueCount(), vector.getNullCount(), data, valid, null); + assertThrows(IllegalArgumentException.class, () -> { + builder.buildAndPutOnDevice(); + }); + } + } + + @Test + void testArrowListThrows() { + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (ListVector vector = ListVector.empty("list", allocator)) { + ArrowColumnBuilder builder = new ArrowColumnBuilder(new ListType(true, new HostColumnVector.BasicType(true, DType.STRING))); + // buffer don't matter as we expect it to throw anyway + builder.addBatch(vector.getValueCount(), vector.getNullCount(), null, null, null); + assertThrows(IllegalArgumentException.class, () -> { + builder.buildAndPutOnDevice(); + }); + } + } + + @Test + void testArrowStructThrows() { + BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + try (StructVector vector = StructVector.empty("struct", allocator)) { + ArrowColumnBuilder builder = 
new ArrowColumnBuilder(new StructType(true, new HostColumnVector.BasicType(true, DType.STRING))); + // buffer don't matter as we expect it to throw anyway + builder.addBatch(vector.getValueCount(), vector.getNullCount(), null, null, null); + assertThrows(IllegalArgumentException.class, () -> { + builder.buildAndPutOnDevice(); + }); + } + } +}