From 9d6670bcf8a87acd5709ee97011eb833ec670956 Mon Sep 17 00:00:00 2001
From: Vibhatha Abeykoon
Date: Wed, 5 Jun 2024 05:09:41 +0530
Subject: [PATCH] fix: adding stream tests for views

---
Review notes (this section is ignored by `git am`):
- BufferImportTypeVisitor: the JVM `assert` on the validity buffer becomes a
  checkState(), so the check no longer depends on assertions being enabled.
  The message uses the "%s" template form so it is only formatted when the
  check fails. NOTE(review): assumes the unqualified checkState is Arrow's
  Preconditions.checkState (template overloads available) - confirm against
  the file's static imports.
- StreamTest: adds stream round-trip tests for Utf8View and BinaryView
  columns. Each test sends two batches holding values of 0, 1, 14 and 18
  bytes, the second batch with nulls at rows 1 and 3.
- Leading whitespace of context lines was reconstructed; if a hunk is
  rejected, retry with `git apply --ignore-whitespace`.

 .../arrow/c/BufferImportTypeVisitor.java      |  8 +-
 .../java/org/apache/arrow/c/StreamTest.java   | 76 +++++++++++++++++++
 2 files changed, 81 insertions(+), 3 deletions(-)

diff --git a/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java b/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java
index 5f91dc48f247f..e41c209979cb1 100644
--- a/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java
+++ b/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java
@@ -233,15 +233,17 @@ private List visitVariableWidthView(ArrowType type) {
     final int elementSize = BaseVariableWidthViewVector.ELEMENT_SIZE;
     final int lengthWidth = BaseVariableWidthViewVector.LENGTH_WIDTH;
     final int prefixWidth = BaseVariableWidthViewVector.PREFIX_WIDTH;
+    final int lengthPrefixWidth = lengthWidth + prefixWidth;
     // Map to store the data buffer index and the total length of data in that buffer
     Map<Integer, Long> dataBufferInfo = new HashMap<>();
     for (int i = 0; i < fieldNode.getLength(); i++) {
       final int length = view.getInt((long) i * elementSize);
       if (length > BaseVariableWidthViewVector.INLINE_SIZE) {
-        assert maybeValidityBuffer != null;
+        checkState(maybeValidityBuffer != null,
+            "Validity buffer is required for data of type %s", type);
         if (BitVectorHelper.get(maybeValidityBuffer, i) == 1) {
           final int bufferIndex =
-              view.getInt(((long) i * elementSize) + lengthWidth + prefixWidth);
+              view.getInt(((long) i * elementSize) + lengthPrefixWidth);
           if (dataBufferInfo.containsKey(bufferIndex)) {
             dataBufferInfo.compute(bufferIndex, (key, value) -> value != null ? value + (long) length : 0);
           } else {
@@ -250,7 +252,7 @@ private List visitVariableWidthView(ArrowType type) {
         }
       }
     }
-    // fixed buffers for Utf8View or BinaryView are validity and view buffers
+    // fixed buffers for Utf8View or BinaryView are the validity buffer and the view buffer.
     final int fixedBufferCount = 2;
     // import data buffers
     for (Map.Entry<Integer, Long> entry : dataBufferInfo.entrySet()) {
diff --git a/java/c/src/test/java/org/apache/arrow/c/StreamTest.java b/java/c/src/test/java/org/apache/arrow/c/StreamTest.java
index 68d4fc2a81e68..913fd0be7d6a0 100644
--- a/java/c/src/test/java/org/apache/arrow/c/StreamTest.java
+++ b/java/c/src/test/java/org/apache/arrow/c/StreamTest.java
@@ -40,6 +40,8 @@
 import org.apache.arrow.vector.VectorLoader;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ViewVarBinaryVector;
+import org.apache.arrow.vector.ViewVarCharVector;
 import org.apache.arrow.vector.compare.Range;
 import org.apache.arrow.vector.compare.RangeEqualsVisitor;
 import org.apache.arrow.vector.dictionary.Dictionary;
@@ -132,6 +134,80 @@ public void roundtripStrings() throws Exception {
     }
   }
 
+  @Test
+  public void roundtripStringViews() throws Exception {
+    final Schema schema = new Schema(Arrays.asList(Field.nullable("ints", new ArrowType.Int(32, true)),
+        Field.nullable("string_views", new ArrowType.Utf8View())));
+    final List<ArrowRecordBatch> batches = new ArrayList<>();
+    try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      final IntVector ints = (IntVector) root.getVector(0);
+      final ViewVarCharVector strs = (ViewVarCharVector) root.getVector(1);
+      VectorUnloader unloader = new VectorUnloader(root);
+
+      root.allocateNew();
+      ints.setSafe(0, 1);
+      ints.setSafe(1, 2);
+      ints.setSafe(2, 4);
+      ints.setSafe(3, 8);
+      strs.setSafe(0, "".getBytes(StandardCharsets.UTF_8));
+      strs.setSafe(1, "a".getBytes(StandardCharsets.UTF_8));
+      strs.setSafe(2, "bc1234567890bc".getBytes(StandardCharsets.UTF_8));
+      strs.setSafe(3, "defg1234567890defg".getBytes(StandardCharsets.UTF_8));
+      root.setRowCount(4);
+      batches.add(unloader.getRecordBatch());
+
+      root.allocateNew();
+      ints.setSafe(0, 1);
+      ints.setNull(1);
+      ints.setSafe(2, 4);
+      ints.setNull(3);
+      strs.setSafe(0, "".getBytes(StandardCharsets.UTF_8));
+      strs.setNull(1);
+      strs.setSafe(2, "bc1234567890bc".getBytes(StandardCharsets.UTF_8));
+      strs.setNull(3);
+      root.setRowCount(4);
+      batches.add(unloader.getRecordBatch());
+      roundtrip(schema, batches);
+    }
+  }
+
+  @Test
+  public void roundtripBinaryViews() throws Exception {
+    final Schema schema = new Schema(Arrays.asList(Field.nullable("ints", new ArrowType.Int(32, true)),
+        Field.nullable("binary_views", new ArrowType.BinaryView())));
+    final List<ArrowRecordBatch> batches = new ArrayList<>();
+    try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      final IntVector ints = (IntVector) root.getVector(0);
+      final ViewVarBinaryVector binaries = (ViewVarBinaryVector) root.getVector(1);
+      VectorUnloader unloader = new VectorUnloader(root);
+
+      root.allocateNew();
+      ints.setSafe(0, 1);
+      ints.setSafe(1, 2);
+      ints.setSafe(2, 4);
+      ints.setSafe(3, 8);
+      binaries.setSafe(0, "".getBytes(StandardCharsets.UTF_8));
+      binaries.setSafe(1, "a".getBytes(StandardCharsets.UTF_8));
+      binaries.setSafe(2, "bc1234567890bc".getBytes(StandardCharsets.UTF_8));
+      binaries.setSafe(3, "defg1234567890defg".getBytes(StandardCharsets.UTF_8));
+      root.setRowCount(4);
+      batches.add(unloader.getRecordBatch());
+
+      root.allocateNew();
+      ints.setSafe(0, 1);
+      ints.setNull(1);
+      ints.setSafe(2, 4);
+      ints.setNull(3);
+      binaries.setSafe(0, "".getBytes(StandardCharsets.UTF_8));
+      binaries.setNull(1);
+      binaries.setSafe(2, "bc1234567890bc".getBytes(StandardCharsets.UTF_8));
+      binaries.setNull(3);
+      root.setRowCount(4);
+      batches.add(unloader.getRecordBatch());
+      roundtrip(schema, batches);
+    }
+  }
+
   @Test
   public void roundtripDictionary() throws Exception {
     final ArrowType.Int indexType = new ArrowType.Int(32, true);