Skip to content

Commit

Permalink
fix: initial update
Browse files Browse the repository at this point in the history
  • Loading branch information
vibhatha committed Oct 20, 2023
1 parent ccf4387 commit 8868127
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
package org.apache.arrow.compression;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -62,7 +66,7 @@ public void init() {
allocator = new RootAllocator(Long.MAX_VALUE);

dictionaryVector1 = (VarCharVector)
FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("D1", allocator, null);
FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1", allocator, null);

setVector(dictionaryVector1,
"foo".getBytes(StandardCharsets.UTF_8),
Expand All @@ -78,9 +82,7 @@ public void terminate() throws Exception {
dictionaryVector1.close();
allocator.close();
}




@Test
public void testArrowFileZstdRoundTrip() throws Exception {
// Prepare sample data
Expand Down Expand Up @@ -117,7 +119,6 @@ public void testArrowFileZstdRoundTrip() throws Exception {
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
allocator, NoCompressionCodec.Factory.INSTANCE)) {
Assert.assertEquals(1, reader.getRecordBlocks().size());

Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch());
String expectedMessage = "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD";
Assert.assertEquals(expectedMessage, exception.getMessage());
Expand Down Expand Up @@ -170,9 +171,8 @@ public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
allocator, NoCompressionCodec.Factory.INSTANCE)) {
Assert.assertEquals(1, reader.getRecordBlocks().size());
Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch());
String expectedMessage = "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD";
Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch());
Assert.assertEquals(expectedMessage, exception.getMessage());
}
}
Expand All @@ -196,34 +196,38 @@ public void testArrowStreamZstdRoundTrip() throws Exception {
fields.add(encodedVector1.getField());

VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(fields), allocator);
final int rowCount = 10;
final int rowCount = 3;
GenerateSampleData.generateTestData(root.getVector(0), rowCount);
root.setRowCount(rowCount);

// Write a compressed arrow file to a temporary file on disk
ByteArrayOutputStream out = new ByteArrayOutputStream();
File tempFile = File.createTempFile("dictionary_compression", ".arrow");
FileOutputStream fileOut = new FileOutputStream(tempFile);
try (final ArrowStreamWriter writer =
new ArrowStreamWriter(root, provider, Channels.newChannel(out), IpcOption.DEFAULT,
CommonsCompressionFactory.INSTANCE, CompressionUtil.CodecType.ZSTD, Optional.of(7))) {
new ArrowStreamWriter(root, provider, Channels.newChannel(fileOut), IpcOption.DEFAULT,
CommonsCompressionFactory.INSTANCE, CompressionUtil.CodecType.ZSTD,
Optional.of(7))) {
writer.start();
writer.writeBatch();
writer.end();
}

// Read the in-memory compressed arrow file with CommonsCompressionFactory provided
try (ArrowStreamReader reader =
new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
allocator, CommonsCompressionFactory.INSTANCE)) {
// Read the on-disk compressed arrow file with CommonsCompressionFactory provided
try (SeekableByteChannel channel = FileChannel.open(tempFile.toPath());
ArrowStreamReader reader =
new ArrowStreamReader(channel, allocator, CommonsCompressionFactory.INSTANCE)) {
org.apache.arrow.vector.types.pojo.Schema schema = reader.getVectorSchemaRoot().getSchema();
Assert.assertTrue(reader.loadNextBatch());
Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assert.assertFalse(reader.loadNextBatch());
}

// Read the in-memory compressed arrow file without CompressionFactory provided
try (ArrowStreamReader reader =
new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
allocator, NoCompressionCodec.Factory.INSTANCE)) {
}

// Read the on-disk compressed arrow file without CompressionFactory provided
try (SeekableByteChannel channel = FileChannel.open(tempFile.toPath());
ArrowStreamReader reader =
new ArrowStreamReader(channel, allocator, NoCompressionCodec.Factory.INSTANCE)) {
Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch());
String expectedMessage = "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD";
Assert.assertEquals(expectedMessage, exception.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ private void load(ArrowDictionaryBatch dictionaryBatch, FieldVector vector) {
VectorSchemaRoot root = new VectorSchemaRoot(
Collections.singletonList(vector.getField()),
Collections.singletonList(vector), 0);
VectorLoader loader = new VectorLoader(root);
VectorLoader loader = new VectorLoader(root, this.compressionFactory);
try {
loader.load(dictionaryBatch.getDictionary());
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ public abstract class ArrowWriter implements AutoCloseable {

protected IpcOption option;

private CompressionCodec.Factory compressionFactory;

private CompressionUtil.CodecType codecType;

private Optional<Integer> compressionLevel;

/**
 * Convenience constructor that applies the default IPC options.
 *
 * <p>Forwards all three arguments unchanged to the four-argument constructor,
 * supplying {@code IpcOption.DEFAULT} as the IPC option.
 *
 * @param root the vector schema root passed through to the delegated constructor
 * @param provider the dictionary provider passed through to the delegated constructor
 * @param out the writable channel passed through to the delegated constructor
 */
protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
this(root, provider, out, IpcOption.DEFAULT);
}
Expand Down Expand Up @@ -99,6 +105,10 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab
this.option = option;
this.dictionaryProvider = provider;

this.compressionFactory = compressionFactory;
this.codecType = codecType;
this.compressionLevel = compressionLevel;

List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());

MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion);
Expand Down Expand Up @@ -133,7 +143,11 @@ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
Collections.singletonList(vector.getField()),
Collections.singletonList(vector),
count);
VectorUnloader unloader = new VectorUnloader(dictRoot);
VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true,
this.compressionLevel.isPresent() ?
this.compressionFactory.createCodec(this.codecType, this.compressionLevel.get()) :
this.compressionFactory.createCodec(this.codecType),
/*alignBuffers*/ true);
ArrowRecordBatch batch = unloader.getRecordBatch();
ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
try {
Expand Down

0 comments on commit 8868127

Please sign in to comment.