
Add JNI for ORC/Parquet writer compression statistics #13376

Merged: 12 commits, May 18, 2023
150 changes: 60 additions & 90 deletions java/src/main/java/ai/rapids/cudf/Table.java
@@ -894,20 +894,19 @@ private static native long startWriteCSVToBuffer(String[] columnNames,

private static native void endWriteCSVToBuffer(long writerHandle);

- private static class CSVTableWriter implements TableWriter {
-   private long writerHandle;
+ private static class CSVTableWriter extends TableWriter {
private HostBufferConsumer consumer;

private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer) {
- this.writerHandle = startWriteCSVToBuffer(options.getColumnNames(),
- options.getIncludeHeader(),
- options.getRowDelimiter(),
- options.getFieldDelimiter(),
- options.getNullValue(),
- options.getTrueValue(),
- options.getFalseValue(),
- options.getQuoteStyle().nativeId,
- consumer);
+ super(startWriteCSVToBuffer(options.getColumnNames(),
+ options.getIncludeHeader(),
+ options.getRowDelimiter(),
+ options.getFieldDelimiter(),
+ options.getNullValue(),
+ options.getTrueValue(),
+ options.getFalseValue(),
+ options.getQuoteStyle().nativeId,
+ consumer));
this.consumer = consumer;
}

@@ -1287,82 +1286,61 @@ public static Table readORC(ORCOptions opts, HostMemoryBuffer buffer,
opts.getDecimal128Columns()));
}

- private static class ParquetTableWriter implements TableWriter {
-   private long handle;
+ private static class ParquetTableWriter extends TableWriter {
HostBufferConsumer consumer;

private ParquetTableWriter(ParquetWriterOptions options, File outputFile) {
- String[] columnNames = options.getFlatColumnNames();
- boolean[] columnNullabilities = options.getFlatIsNullable();
- boolean[] timeInt96Values = options.getFlatIsTimeTypeInt96();
- boolean[] isMapValues = options.getFlatIsMap();
- boolean[] isBinaryValues = options.getFlatIsBinary();
- int[] precisions = options.getFlatPrecision();
- boolean[] hasParquetFieldIds = options.getFlatHasParquetFieldId();
- int[] parquetFieldIds = options.getFlatParquetFieldId();
- int[] flatNumChildren = options.getFlatNumChildren();
-
- this.consumer = null;
- this.handle = writeParquetFileBegin(columnNames,
+ super(writeParquetFileBegin(options.getFlatColumnNames(),
options.getTopLevelChildren(),
- flatNumChildren,
- columnNullabilities,
+ options.getFlatNumChildren(),
+ options.getFlatIsNullable(),
options.getMetadataKeys(),
options.getMetadataValues(),
options.getCompressionType().nativeId,
options.getStatisticsFrequency().nativeId,
- timeInt96Values,
- precisions,
- isMapValues,
- isBinaryValues,
- hasParquetFieldIds,
- parquetFieldIds,
- outputFile.getAbsolutePath());
+ options.getFlatIsTimeTypeInt96(),
+ options.getFlatPrecision(),
+ options.getFlatIsMap(),
+ options.getFlatIsBinary(),
+ options.getFlatHasParquetFieldId(),
+ options.getFlatParquetFieldId(),
+ outputFile.getAbsolutePath()));
+ this.consumer = null;
}

private ParquetTableWriter(ParquetWriterOptions options, HostBufferConsumer consumer) {
- String[] columnNames = options.getFlatColumnNames();
- boolean[] columnNullabilities = options.getFlatIsNullable();
- boolean[] timeInt96Values = options.getFlatIsTimeTypeInt96();
- boolean[] isMapValues = options.getFlatIsMap();
- boolean[] isBinaryValues = options.getFlatIsBinary();
- int[] precisions = options.getFlatPrecision();
- boolean[] hasParquetFieldIds = options.getFlatHasParquetFieldId();
- int[] parquetFieldIds = options.getFlatParquetFieldId();
- int[] flatNumChildren = options.getFlatNumChildren();
-
- this.consumer = consumer;
- this.handle = writeParquetBufferBegin(columnNames,
+ super(writeParquetBufferBegin(options.getFlatColumnNames(),
options.getTopLevelChildren(),
- flatNumChildren,
- columnNullabilities,
+ options.getFlatNumChildren(),
+ options.getFlatIsNullable(),
options.getMetadataKeys(),
options.getMetadataValues(),
options.getCompressionType().nativeId,
options.getStatisticsFrequency().nativeId,
- timeInt96Values,
- precisions,
- isMapValues,
- isBinaryValues,
- hasParquetFieldIds,
- parquetFieldIds,
- consumer);
+ options.getFlatIsTimeTypeInt96(),
+ options.getFlatPrecision(),
+ options.getFlatIsMap(),
+ options.getFlatIsBinary(),
+ options.getFlatHasParquetFieldId(),
+ options.getFlatParquetFieldId(),
+ consumer));
+ this.consumer = consumer;
}

@Override
public void write(Table table) {
- if (handle == 0) {
+ if (writerHandle == 0) {
throw new IllegalStateException("Writer was already closed");
}
- writeParquetChunk(handle, table.nativeHandle, table.getDeviceMemorySize());
+ writeParquetChunk(writerHandle, table.nativeHandle, table.getDeviceMemorySize());
}

@Override
public void close() throws CudfException {
- if (handle != 0) {
-   writeParquetEnd(handle);
+ if (writerHandle != 0) {
+   writeParquetEnd(writerHandle);
}
- handle = 0;
+ writerHandle = 0;
if (consumer != null) {
consumer.done();
consumer = null;
@@ -1425,19 +1403,18 @@ public static void writeColumnViewsToParquet(ParquetWriterOptions options,
for (ColumnView cv : columnViews) {
total += cv.getDeviceMemorySize();
}
- writeParquetChunk(writer.handle, nativeHandle, total);
+ writeParquetChunk(writer.writerHandle, nativeHandle, total);
}
} finally {
deleteCudfTable(nativeHandle);
}
}

- private static class ORCTableWriter implements TableWriter {
-   private long handle;
+ private static class ORCTableWriter extends TableWriter {
HostBufferConsumer consumer;

private ORCTableWriter(ORCWriterOptions options, File outputFile) {
- this.handle = writeORCFileBegin(options.getFlatColumnNames(),
+ super(writeORCFileBegin(options.getFlatColumnNames(),
options.getTopLevelChildren(),
options.getFlatNumChildren(),
options.getFlatIsNullable(),
@@ -1446,12 +1423,12 @@ private ORCTableWriter(ORCWriterOptions options, File outputFile) {
options.getCompressionType().nativeId,
options.getFlatPrecision(),
options.getFlatIsMap(),
- outputFile.getAbsolutePath());
+ outputFile.getAbsolutePath()));
this.consumer = null;
}

private ORCTableWriter(ORCWriterOptions options, HostBufferConsumer consumer) {
- this.handle = writeORCBufferBegin(options.getFlatColumnNames(),
+ super(writeORCBufferBegin(options.getFlatColumnNames(),
options.getTopLevelChildren(),
options.getFlatNumChildren(),
options.getFlatIsNullable(),
@@ -1460,24 +1437,24 @@ private ORCTableWriter(ORCWriterOptions options, HostBufferConsumer consumer) {
options.getCompressionType().nativeId,
options.getFlatPrecision(),
options.getFlatIsMap(),
- consumer);
+ consumer));
this.consumer = consumer;
}

@Override
public void write(Table table) {
- if (handle == 0) {
+ if (writerHandle == 0) {
throw new IllegalStateException("Writer was already closed");
}
- writeORCChunk(handle, table.nativeHandle, table.getDeviceMemorySize());
+ writeORCChunk(writerHandle, table.nativeHandle, table.getDeviceMemorySize());
}

@Override
public void close() throws CudfException {
- if (handle != 0) {
-   writeORCEnd(handle);
+ if (writerHandle != 0) {
+   writeORCEnd(writerHandle);
}
- handle = 0;
+ writerHandle = 0;
if (consumer != null) {
consumer.done();
consumer = null;
@@ -1506,52 +1483,45 @@ public static TableWriter writeORCChunked(ORCWriterOptions options, HostBufferConsumer consumer) {
return new ORCTableWriter(options, consumer);
}

- private static class ArrowIPCTableWriter implements TableWriter {
+ private static class ArrowIPCTableWriter extends TableWriter {
private final ArrowIPCWriterOptions.DoneOnGpu callback;
- private long handle;
private HostBufferConsumer consumer;
private long maxChunkSize;

- private ArrowIPCTableWriter(ArrowIPCWriterOptions options,
-                             File outputFile) {
+ private ArrowIPCTableWriter(ArrowIPCWriterOptions options, File outputFile) {
+   super(writeArrowIPCFileBegin(options.getColumnNames(), outputFile.getAbsolutePath()));
this.callback = options.getCallback();
this.consumer = null;
this.maxChunkSize = options.getMaxChunkSize();
- this.handle = writeArrowIPCFileBegin(
-     options.getColumnNames(),
-     outputFile.getAbsolutePath());
}

- private ArrowIPCTableWriter(ArrowIPCWriterOptions options,
-                             HostBufferConsumer consumer) {
+ private ArrowIPCTableWriter(ArrowIPCWriterOptions options, HostBufferConsumer consumer) {
+   super(writeArrowIPCBufferBegin(options.getColumnNames(), consumer));
this.callback = options.getCallback();
this.consumer = consumer;
this.maxChunkSize = options.getMaxChunkSize();
- this.handle = writeArrowIPCBufferBegin(
-     options.getColumnNames(),
-     consumer);
}

@Override
public void write(Table table) {
- if (handle == 0) {
+ if (writerHandle == 0) {
throw new IllegalStateException("Writer was already closed");
}
- long arrowHandle = convertCudfToArrowTable(handle, table.nativeHandle);
+ long arrowHandle = convertCudfToArrowTable(writerHandle, table.nativeHandle);
try {
callback.doneWithTheGpu(table);
- writeArrowIPCArrowChunk(handle, arrowHandle, maxChunkSize);
+ writeArrowIPCArrowChunk(writerHandle, arrowHandle, maxChunkSize);
} finally {
closeArrowTable(arrowHandle);
}
}

@Override
public void close() throws CudfException {
- if (handle != 0) {
-   writeArrowIPCEnd(handle);
+ if (writerHandle != 0) {
+   writeArrowIPCEnd(writerHandle);
}
- handle = 0;
+ writerHandle = 0;
if (consumer != null) {
consumer.done();
consumer = null;
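Aside: every writer class touched above follows the same refactor pattern. The subclass calls its native begin function and hands the resulting handle to the TableWriter base class through super(...), and every later JNI call goes through the inherited writerHandle field. A minimal, self-contained sketch of that pattern (BaseWriter, FooWriter, and the nativeBegin/nativeWrite/nativeEnd stand-ins are hypothetical names for illustration, not cudf API):

// Hypothetical sketch of the handle-owning base class pattern used above.
abstract class BaseWriter implements AutoCloseable {
  protected long writerHandle; // native handle, owned by the base class

  BaseWriter(long writerHandle) { this.writerHandle = writerHandle; }

  public abstract void write();
}

class FooWriter extends BaseWriter {
  FooWriter() {
    super(nativeBegin()); // handle flows straight into the base class
  }

  @Override
  public void write() {
    if (writerHandle == 0) {
      throw new IllegalStateException("Writer was already closed");
    }
    nativeWrite(writerHandle);
  }

  @Override
  public void close() {
    if (writerHandle != 0) {
      nativeEnd(writerHandle);
    }
    writerHandle = 0; // makes close() idempotent
  }

  // Stand-ins for JNI calls; a real writer would declare these native.
  private static long nativeBegin() { return 1L; }
  private static void nativeWrite(long handle) {}
  private static void nativeEnd(long handle) {}
}

Centralizing the handle in the base class is what lets the getWriteStatistics method added below in TableWriter.java work uniformly across all writer implementations.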
76 changes: 57 additions & 19 deletions java/src/main/java/ai/rapids/cudf/TableWriter.java
@@ -1,37 +1,75 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.rapids.cudf;

/**
* Provides an interface for writing out Table information in multiple steps.
- * A TableWriter will be returned from one of various factory functions in Table that
+ * A TableWriter will be returned from one of various factory functions in the Table class that
* let you set the format of the data and its destination. After that write can be called one or
* more times. When you are done writing call close to finish.
*/
- public interface TableWriter extends AutoCloseable {
+ public abstract class TableWriter implements AutoCloseable {
+   protected long writerHandle;
+
+   TableWriter(long writerHandle) { this.writerHandle = writerHandle; }

/**
* Write out a table. Note that all columns must be in the same order each time this is called
* and the format of each table cannot change.
* @param table what to write out.
*/
- void write(Table table) throws CudfException;
+ abstract public void write(Table table) throws CudfException;

@Override
- void close() throws CudfException;
+ abstract public void close() throws CudfException;

+ public static class WriteStatistics {
+   public final long numCompressedBytes; // The number of bytes that were successfully compressed
+   public final long numFailedBytes; // The number of bytes that failed to compress
+   public final long numSkippedBytes; // The number of bytes that were skipped during compression
+   public final double compressionRatio; // The compression ratio for the successfully compressed data
+
+   public WriteStatistics(long numCompressedBytes, long numFailedBytes, long numSkippedBytes,
+       double compressionRatio) {
+     this.numCompressedBytes = numCompressedBytes;
+     this.numFailedBytes = numFailedBytes;
+     this.numSkippedBytes = numSkippedBytes;
+     this.compressionRatio = compressionRatio;
+   }
+ }
+
+ /**
+  * Get the write statistics for the writer up to the last write call.
+  * Currently, only ORC and Parquet writers support write statistics.
+  * Calling this method on other writers will return null.
+  * @return The write statistics.
+  */
+ public WriteStatistics getWriteStatistics() {
+   double[] statsData = getWriteStatistics(writerHandle);
+   assert statsData.length == 4 : "Unexpected write statistics data length";
+   return new WriteStatistics((long) statsData[0], (long) statsData[1], (long) statsData[2],
+       statsData[3]);
+ }
+
+ /**
+  * Get the write statistics for the writer up to the last write call.
+  * The data returned from the native method is encoded as an array of doubles.
+  * @param writerHandle The handle to the writer.
+  * @return The write statistics.
+  */
+ private static native double[] getWriteStatistics(long writerHandle);
}
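For context, a usage sketch for the new statistics API (not part of the PR: it assumes an already-populated Table named table, a configured ORCWriterOptions named options, and the file-based Table.writeORCChunked(options, outputFile) factory that pairs with the ORCTableWriter(options, outputFile) constructor shown above):

import java.io.File;

import ai.rapids.cudf.ORCWriterOptions;
import ai.rapids.cudf.Table;
import ai.rapids.cudf.TableWriter;

class WriteStatsExample {
  // Writes the table as ORC in one chunk, then prints the compression
  // statistics accumulated by the native writer so far.
  static void writeWithStats(Table table, ORCWriterOptions options) {
    try (TableWriter writer = Table.writeORCChunked(options, new File("/tmp/example.orc"))) {
      writer.write(table);
      TableWriter.WriteStatistics stats = writer.getWriteStatistics();
      System.out.println("compressed bytes:  " + stats.numCompressedBytes);
      System.out.println("failed bytes:      " + stats.numFailedBytes);
      System.out.println("skipped bytes:     " + stats.numSkippedBytes);
      System.out.println("compression ratio: " + stats.compressionRatio);
    }
  }
}

Because the statistics cover everything up to the last write call, the method can also be called between chunks to observe compression behavior incrementally.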