Fix Java Parquet write after writer API changes (#7655)
This fixes the Java build, but it required breaking changes to do it. I'll put up a corresponding change in the rapids plugin shortly.

#7654 was found as a part of this.

This is also not the final API that we will want. We need to redo how we configure the builders so that they can take advantage of the new APIs properly.
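
For reviewers, a rough before/after sketch of the builder change (the column names and precision values here are hypothetical, not taken from this diff):

    // Old API (removed in this commit): one precision per *decimal* column only
    ParquetWriterOptions oldOptions = ParquetWriterOptions.builder()
        .withPrecisionValues(5, 5)
        .build();

    // New API: column names are required, and withDecimalPrecisions takes one
    // value per column, with entries for non-decimal columns ignored
    ParquetWriterOptions newOptions = ParquetWriterOptions.builder()
        .withColumnNames("_c1", "_c2", "_c3")
        .withDecimalPrecisions(0, 0, 5)
        .build();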

Authors:
  - Robert (Bobby) Evans (@revans2)

Approvers:
  - Jason Lowe (@jlowe)
  - Raza Jafri (@razajafri)

URL: #7655
revans2 authored Mar 19, 2021
1 parent 8773a40 commit 8687182
Showing 5 changed files with 63 additions and 74 deletions.
27 changes: 6 additions & 21 deletions java/src/main/java/ai/rapids/cudf/ParquetWriterOptions.java
@@ -58,25 +58,12 @@ public Builder withTimestampInt96(boolean int96) {
}

/**
* Overwrite flattened precision values for all decimal columns that are expected to be in
* this Table. The list of precisions should be an in-order traversal of all Decimal columns,
* including nested columns. Please look at the example below.
*
* NOTE: The number of `precisionValues` should be equal to the number of Decimal columns;
* otherwise a CudfException will be thrown. Also note that the values will be overwritten
* every time this method is called.
*
* Example:
* Table0 : c0[type: INT32]
* c1[type: Decimal32(3, 1)]
* c2[type: Struct[col0[type: Decimal(2, 1)],
* col1[type: INT64],
* col2[type: Decimal(8, 6)]]
* c3[type: Decimal64(12, 5)]
*
* The flattened list of precisions from the above example will be {3, 2, 8, 12}.
* This is a temporary hack to make things work. This API will go away once we can update the
* parquet APIs properly.
* @param precisionValues a value for each column, non-decimal columns are ignored.
* @return this for chaining.
*/
public Builder withPrecisionValues(int... precisionValues) {
public Builder withDecimalPrecisions(int ... precisionValues) {
this.precisionValues = precisionValues;
return this;
}
@@ -86,8 +73,6 @@ public ParquetWriterOptions build() {
}
}

public static final ParquetWriterOptions DEFAULT = new ParquetWriterOptions(new Builder());

public static Builder builder() {
return new Builder();
}
@@ -107,7 +92,7 @@ public StatisticsFrequency getStatisticsFrequency() {

/**
* Return the flattened list of precisions if set; otherwise an empty array will be returned.
* For a definition of what `flattened` means please look at {@link Builder#withPrecisionValues}
* For a definition of what `flattened` means please look at {@link Builder#withDecimalPrecisions}
*/
public int[] getPrecisions() {
return precisions;
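
As a concrete sketch of the new flattened, per-column semantics (the table layout below is hypothetical): for a table with top-level columns [INT32, STRING, Decimal32(5, 1), Decimal64(6, 2)], each column gets one entry and non-decimal positions are ignored:

    ParquetWriterOptions options = ParquetWriterOptions.builder()
        .withColumnNames("id", "label", "price", "total")
        .withDecimalPrecisions(0, 0, 5, 6) // zeros are placeholders for non-decimal columns
        .build();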
17 changes: 6 additions & 11 deletions java/src/main/java/ai/rapids/cudf/Table.java
@@ -803,6 +803,12 @@ private static class ParquetTableWriter implements TableWriter {
HostBufferConsumer consumer;

private ParquetTableWriter(ParquetWriterOptions options, File outputFile) {
int numColumns = options.getColumnNames().length;
assert (numColumns == options.getColumnNullability().length);
int[] precisions = options.getPrecisions();
if (precisions != null) {
assert (numColumns >= options.getPrecisions().length);
}
this.consumer = null;
this.handle = writeParquetFileBegin(options.getColumnNames(),
options.getColumnNullability(),
@@ -871,17 +877,6 @@ public static TableWriter writeParquetChunked(ParquetWriterOptions options,
return new ParquetTableWriter(options, consumer);
}

/**
* Writes this table to a Parquet file on the host
*
* @param outputFile file to write the table to
* @deprecated please use writeParquetChunked instead
*/
@Deprecated
public void writeParquet(File outputFile) {
writeParquet(ParquetWriterOptions.DEFAULT, outputFile);
}

/**
* Writes this table to a Parquet file on the host
*
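
With the deprecated no-options writeParquet(File) overload removed, callers migrate to the chunked writer; a minimal sketch based on the tests in this PR (the table and outputFile variables are assumed to be in scope):

    ParquetWriterOptions options = ParquetWriterOptions.builder()
        .withColumnNames("_c1", "_c2") // names are required by the new API
        .build();
    try (TableWriter writer = Table.writeParquetChunked(options, outputFile)) {
      writer.write(table);
    }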
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/WriterOptions.java
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ protected static class WriterBuilder<T extends WriterBuilder> {
final List<Boolean> columnNullability = new ArrayList<>();

/**
* Add column name
* Add column name(s). For Parquet, column names are not optional.
* @param columnNames
*/
public T withColumnNames(String... columnNames) {
60 changes: 38 additions & 22 deletions java/src/main/native/src/TableJni.cpp
@@ -915,29 +915,38 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin(
cudf::jni::native_jbooleanArray col_nullability(env, j_col_nullability);
cudf::jni::native_jstringArray meta_keys(env, j_metadata_keys);
cudf::jni::native_jstringArray meta_values(env, j_metadata_values);
cudf::jni::native_jintArray precisions(env, j_precisions);

auto cpp_names = col_names.as_cpp_vector();
table_input_metadata metadata;
metadata.column_metadata.resize(col_nullability.size());
for (int i = 0; i < col_nullability.size(); i++) {
metadata.column_metadata[i]
.set_name(cpp_names[i])
.set_nullability(col_nullability[i])
.set_int96_timestamps(j_isInt96);
}

// Precisions are not always set
for (int i = 0; i < precisions.size(); i++) {
metadata.column_metadata[i]
.set_decimal_precision(precisions[i]);
}

auto d = col_nullability.data();
std::vector<bool> nullability(d, d + col_nullability.size());
table_metadata_with_nullability metadata;
metadata.column_nullable = nullability;
metadata.column_names = col_names.as_cpp_vector();
for (auto i = 0; i < meta_keys.size(); ++i) {
metadata.user_data[meta_keys[i].get()] = meta_values[i].get();
}

std::unique_ptr<cudf::jni::jni_writer_data_sink> data_sink(
new cudf::jni::jni_writer_data_sink(env, consumer));
sink_info sink{data_sink.get()};
cudf::jni::native_jintArray precisions(env, j_precisions);
std::vector<uint8_t> const v_precisions(
precisions.data(), precisions.data() + precisions.size());
chunked_parquet_writer_options opts =
chunked_parquet_writer_options::builder(sink)
.nullable_metadata(&metadata)
.metadata(&metadata)
.compression(static_cast<compression_type>(j_compression))
.stats_level(static_cast<statistics_freq>(j_stats_freq))
.int96_timestamps(static_cast<bool>(j_isInt96))
.decimal_precision(v_precisions)
.build();

auto writer_ptr = std::make_unique<cudf::io::parquet_chunked_writer>(opts);
@@ -965,27 +974,34 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetFileBegin(
cudf::jni::native_jstringArray meta_keys(env, j_metadata_keys);
cudf::jni::native_jstringArray meta_values(env, j_metadata_values);
cudf::jni::native_jstring output_path(env, j_output_path);
cudf::jni::native_jintArray precisions(env, j_precisions);

auto d = col_nullability.data();
std::vector<bool> nullability(d, d + col_nullability.size());
table_metadata_with_nullability metadata;
metadata.column_nullable = nullability;
metadata.column_names = col_names.as_cpp_vector();
for (int i = 0; i < meta_keys.size(); ++i) {
auto cpp_names = col_names.as_cpp_vector();
table_input_metadata metadata;
metadata.column_metadata.resize(col_nullability.size());
for (int i = 0; i < col_nullability.size(); i++) {
metadata.column_metadata[i]
.set_name(cpp_names[i])
.set_nullability(col_nullability[i])
.set_int96_timestamps(j_isInt96);
}

// Precisions are not always set
for (int i = 0; i < precisions.size(); i++) {
metadata.column_metadata[i]
.set_decimal_precision(precisions[i]);
}

for (auto i = 0; i < meta_keys.size(); ++i) {
metadata.user_data[meta_keys[i].get()] = meta_values[i].get();
}
cudf::jni::native_jintArray precisions(env, j_precisions);
std::vector<uint8_t> v_precisions(
precisions.data(), precisions.data() + precisions.size());


sink_info sink{output_path.get()};
chunked_parquet_writer_options opts =
chunked_parquet_writer_options::builder(sink)
.nullable_metadata(&metadata)
.metadata(&metadata)
.compression(static_cast<compression_type>(j_compression))
.stats_level(static_cast<statistics_freq>(j_stats_freq))
.int96_timestamps(static_cast<bool>(j_isInt96))
.decimal_precision(v_precisions)
.build();

auto writer_ptr = std::make_unique<cudf::io::parquet_chunked_writer>(opts);
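
Note that the native loops above apply precisions positionally to the leading columns only, and the Java side asserts numColumns >= precisions.length, so a shorter array is accepted; a hypothetical sketch:

    // Four top-level columns, but precision metadata supplied only for the
    // first two; the native loop stops after precisions.length entries
    ParquetWriterOptions options = ParquetWriterOptions.builder()
        .withColumnNames("d1", "d2", "x", "y")
        .withDecimalPrecisions(9, 18)
        .build();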
29 changes: 11 additions & 18 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -4148,19 +4148,6 @@ private Table getExpectedFileTableWithDecimals() {
.build();
}

@Test
void testParquetWriteToFileNoNames() throws IOException {
File tempFile = File.createTempFile("test-nonames", ".parquet");
try (Table table0 = getExpectedFileTable()) {
table0.writeParquet(tempFile.getAbsoluteFile());
try (Table table1 = Table.readParquet(tempFile.getAbsoluteFile())) {
assertTablesAreEqual(table0, table1);
}
} finally {
tempFile.delete();
}
}

private final class MyBufferConsumer implements HostBufferConsumer, AutoCloseable {
public final HostMemoryBuffer buffer;
long offset = 0;
@@ -4210,8 +4197,9 @@ void testParquetWriteToBufferChunkedInt96() {
try (Table table0 = getExpectedFileTableWithDecimals();
MyBufferConsumer consumer = new MyBufferConsumer()) {
ParquetWriterOptions options = ParquetWriterOptions.builder()
.withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7", "_c8", "_c9")
.withTimestampInt96(true)
.withPrecisionValues(5, 5)
.withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 5, 5)
.build();

try (TableWriter writer = Table.writeParquetChunked(options, consumer)) {
@@ -4228,9 +4216,13 @@

@Test
void testParquetWriteToBufferChunked() {
ParquetWriterOptions options = ParquetWriterOptions.builder()
.withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7")
.withTimestampInt96(true)
.build();
try (Table table0 = getExpectedFileTable();
MyBufferConsumer consumer = new MyBufferConsumer()) {
try (TableWriter writer = Table.writeParquetChunked(ParquetWriterOptions.DEFAULT, consumer)) {
try (TableWriter writer = Table.writeParquetChunked(options, consumer)) {
writer.write(table0);
writer.write(table0);
writer.write(table0);
@@ -4251,7 +4243,7 @@ void testParquetWriteToFileWithNames() throws IOException {
"eighth", "nineth")
.withCompressionType(CompressionType.NONE)
.withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
.withPrecisionValues(5, 6)
.withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 5, 6)
.build();
try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
writer.write(table0);
@@ -4274,7 +4266,7 @@ void testParquetWriteToFileWithNamesAndMetadata() throws IOException {
.withMetadata("somekey", "somevalue")
.withCompressionType(CompressionType.NONE)
.withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
.withPrecisionValues(6, 8)
.withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 6, 8)
.build();
try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
writer.write(table0);
@@ -4292,9 +4284,10 @@ void testParquetWriteToFileUncompressedNoStats() throws IOException {
File tempFile = File.createTempFile("test-uncompressed", ".parquet");
try (Table table0 = getExpectedFileTableWithDecimals()) {
ParquetWriterOptions options = ParquetWriterOptions.builder()
.withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7", "_c8", "_c9")
.withCompressionType(CompressionType.NONE)
.withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
.withPrecisionValues(4, 6)
.withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 4, 6)
.build();
try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
writer.write(table0);
