Fix Java Parquet write after writer API changes [skip ci] #7655

Merged (1 commit, Mar 19, 2021)
27 changes: 6 additions & 21 deletions java/src/main/java/ai/rapids/cudf/ParquetWriterOptions.java
@@ -58,25 +58,12 @@ public Builder withTimestampInt96(boolean int96) {
     }
 
     /**
-     * Overwrite flattened precision values for all decimal columns that are expected to be in
-     * this Table. The list of precisions should be an in-order traversal of all Decimal columns,
-     * including nested columns. Please look at the example below.
-     *
-     * NOTE: The number of `precisionValues` should be equal to the numbers of Decimal columns
-     * otherwise a CudfException will be thrown. Also note that the values will be overwritten
-     * every time this method is called
-     *
-     * Example:
-     * Table0 : c0[type: INT32]
-     *          c1[type: Decimal32(3, 1)]
-     *          c2[type: Struct[col0[type: Decimal(2, 1)],
-     *                          col1[type: INT64],
-     *                          col2[type: Decimal(8, 6)]]
-     *          c3[type: Decimal64(12, 5)]
-     *
-     * Flattened list of precision from the above example will be {3, 2, 8, 12}
+     * This is a temporary hack to make things work. This API will go away once we can update the
+     * parquet APIs properly.
+     * @param precisionValues a value for each column, non-decimal columns are ignored.
+     * @return this for chaining.
      */
-    public Builder withPrecisionValues(int... precisionValues) {
+    public Builder withDecimalPrecisions(int ... precisionValues) {
       this.precisionValues = precisionValues;
       return this;
     }

Review comment (Contributor) on the new signature: nit: `int ...` => `int...` (extra space after `int`).
@@ -86,8 +73,6 @@ public ParquetWriterOptions build() {
     }
   }
 
-  public static final ParquetWriterOptions DEFAULT = new ParquetWriterOptions(new Builder());
-
   public static Builder builder() {
     return new Builder();
   }
@@ -107,7 +92,7 @@ public StatisticsFrequency getStatisticsFrequency() {

   /**
    * Return the flattened list of precisions if set otherwise empty array will be returned.
-   * For a definition of what `flattened` means please look at {@link Builder#withPrecisionValues}
+   * For a definition of what `flattened` means please look at {@link Builder#withDecimalPrecisions}
    */
   public int[] getPrecisions() {
     return precisions;
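The net effect of this file's changes: precision values are now supplied per column (non-decimal columns are ignored) through the renamed withDecimalPrecisions, and the options object no longer ships a usable DEFAULT, so callers must name every column. A minimal sketch of the new call shape, patterned on the updated tests further down; the nine column names and the precision values are illustrative placeholders:

import ai.rapids.cudf.ParquetWriterOptions;

class ParquetOptionsSketch {
  // Builds writer options for a nine-column table where only the last two
  // columns are decimals; zeros fill the slots of the non-decimal columns.
  static ParquetWriterOptions newStyleOptions() {
    return ParquetWriterOptions.builder()
        .withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7", "_c8", "_c9")
        .withTimestampInt96(true)
        .withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 5, 5)
        .build();
  }
}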
17 changes: 6 additions & 11 deletions java/src/main/java/ai/rapids/cudf/Table.java
@@ -803,6 +803,12 @@ private static class ParquetTableWriter implements TableWriter {
     HostBufferConsumer consumer;
 
     private ParquetTableWriter(ParquetWriterOptions options, File outputFile) {
+      int numColumns = options.getColumnNames().length;
+      assert (numColumns == options.getColumnNullability().length);
+      int[] precisions = options.getPrecisions();
+      if (precisions != null) {
+        assert (numColumns >= options.getPrecisions().length);
+      }
       this.consumer = null;
       this.handle = writeParquetFileBegin(options.getColumnNames(),
           options.getColumnNullability(),
@@ -871,17 +877,6 @@ public static TableWriter writeParquetChunked(ParquetWriterOptions options,
     return new ParquetTableWriter(options, consumer);
   }
 
-  /**
-   * Writes this table to a Parquet file on the host
-   *
-   * @param outputFile file to write the table to
-   * @deprecated please use writeParquetChunked instead
-   */
-  @Deprecated
-  public void writeParquet(File outputFile) {
-    writeParquet(ParquetWriterOptions.DEFAULT, outputFile);
-  }
-
   /**
    * Writes this table to a Parquet file on the host
    *
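With the deprecated no-options overload removed (it relied on ParquetWriterOptions.DEFAULT, which is also gone and carried no column names), file output goes through the chunked writer. A hedged sketch of the replacement path; the WriteSketch class, the table argument, and the two column names are placeholders, not cudf API:

import java.io.File;

import ai.rapids.cudf.ParquetWriterOptions;
import ai.rapids.cudf.Table;
import ai.rapids.cudf.TableWriter;

class WriteSketch {
  // Stand-in for the removed table.writeParquet(outputFile) call: build named
  // options explicitly, then stream the table through a chunked TableWriter.
  static void writeToFile(Table table, File outputFile) {
    ParquetWriterOptions options = ParquetWriterOptions.builder()
        .withColumnNames("_c0", "_c1") // Parquet now requires explicit names
        .build();
    try (TableWriter writer = Table.writeParquetChunked(options, outputFile)) {
      writer.write(table); // may be called repeatedly to append more chunks
    }
  }
}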
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/WriterOptions.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ protected static class WriterBuilder<T extends WriterBuilder> {
   final List<Boolean> columnNullability = new ArrayList<>();
 
   /**
-   * Add column name
+   * Add column name(s). For Parquet column names are not optional.
    * @param columnNames
    */
   public T withColumnNames(String... columnNames) {
60 changes: 38 additions & 22 deletions java/src/main/native/src/TableJni.cpp
@@ -915,29 +915,38 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin(
   cudf::jni::native_jbooleanArray col_nullability(env, j_col_nullability);
   cudf::jni::native_jstringArray meta_keys(env, j_metadata_keys);
   cudf::jni::native_jstringArray meta_values(env, j_metadata_values);
+  cudf::jni::native_jintArray precisions(env, j_precisions);
 
+  auto cpp_names = col_names.as_cpp_vector();
+  table_input_metadata metadata;
+  metadata.column_metadata.resize(col_nullability.size());
+  for (int i = 0; i < col_nullability.size(); i++) {
+    metadata.column_metadata[i]
+        .set_name(cpp_names[i])
+        .set_nullability(col_nullability[i])
+        .set_int96_timestamps(j_isInt96);
+  }
+
+  // Precisions is not always set
+  for (int i = 0; i < precisions.size(); i++) {
+    metadata.column_metadata[i]
+        .set_decimal_precision(precisions[i]);
+  }
 
-  auto d = col_nullability.data();
-  std::vector<bool> nullability(d, d + col_nullability.size());
-  table_metadata_with_nullability metadata;
-  metadata.column_nullable = nullability;
-  metadata.column_names = col_names.as_cpp_vector();
   for (auto i = 0; i < meta_keys.size(); ++i) {
     metadata.user_data[meta_keys[i].get()] = meta_values[i].get();
   }
 
   std::unique_ptr<cudf::jni::jni_writer_data_sink> data_sink(
       new cudf::jni::jni_writer_data_sink(env, consumer));
   sink_info sink{data_sink.get()};
-  cudf::jni::native_jintArray precisions(env, j_precisions);
-  std::vector<uint8_t> const v_precisions(
-      precisions.data(), precisions.data() + precisions.size());
   chunked_parquet_writer_options opts =
       chunked_parquet_writer_options::builder(sink)
-          .nullable_metadata(&metadata)
+          .metadata(&metadata)
           .compression(static_cast<compression_type>(j_compression))
           .stats_level(static_cast<statistics_freq>(j_stats_freq))
           .int96_timestamps(static_cast<bool>(j_isInt96))
-          .decimal_precision(v_precisions)
           .build();
 
   auto writer_ptr = std::make_unique<cudf::io::parquet_chunked_writer>(opts);
@@ -965,27 +974,34 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetFileBegin(
   cudf::jni::native_jstringArray meta_keys(env, j_metadata_keys);
   cudf::jni::native_jstringArray meta_values(env, j_metadata_values);
   cudf::jni::native_jstring output_path(env, j_output_path);
+  cudf::jni::native_jintArray precisions(env, j_precisions);
 
-  auto d = col_nullability.data();
-  std::vector<bool> nullability(d, d + col_nullability.size());
-  table_metadata_with_nullability metadata;
-  metadata.column_nullable = nullability;
-  metadata.column_names = col_names.as_cpp_vector();
-  for (int i = 0; i < meta_keys.size(); ++i) {
+  auto cpp_names = col_names.as_cpp_vector();
+  table_input_metadata metadata;
+  metadata.column_metadata.resize(col_nullability.size());
+  for (int i = 0; i < col_nullability.size(); i++) {
+    metadata.column_metadata[i]
+        .set_name(cpp_names[i])
+        .set_nullability(col_nullability[i])
+        .set_int96_timestamps(j_isInt96);
+  }
+
+  // Precisions is not always set
+  for (int i = 0; i < precisions.size(); i++) {
+    metadata.column_metadata[i]
+        .set_decimal_precision(precisions[i]);
+  }
+
+  for (auto i = 0; i < meta_keys.size(); ++i) {
     metadata.user_data[meta_keys[i].get()] = meta_values[i].get();
   }
-  cudf::jni::native_jintArray precisions(env, j_precisions);
-  std::vector<uint8_t> v_precisions(
-      precisions.data(), precisions.data() + precisions.size());
-
 
   sink_info sink{output_path.get()};
   chunked_parquet_writer_options opts =
       chunked_parquet_writer_options::builder(sink)
-          .nullable_metadata(&metadata)
+          .metadata(&metadata)
           .compression(static_cast<compression_type>(j_compression))
          .stats_level(static_cast<statistics_freq>(j_stats_freq))
          .int96_timestamps(static_cast<bool>(j_isInt96))
-          .decimal_precision(v_precisions)
          .build();
 
   auto writer_ptr = std::make_unique<cudf::io::parquet_chunked_writer>(opts);
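Both JNI entry points now populate libcudf's table_input_metadata, which holds name, nullability, the INT96 flag, and an optional decimal precision for each column, replacing the old table_metadata_with_nullability plus a flat decimal_precision vector. As a rough illustration of that per-column mapping, a hypothetical Java model follows; the ColumnMeta type is invented for this sketch and exists nowhere in cudf:

import java.util.ArrayList;
import java.util.List;

// Hypothetical mirror of the per-column metadata the new native code builds.
final class ColumnMeta {
  final String name;
  final boolean nullable;
  final boolean int96Timestamps;
  final int decimalPrecision; // meaningful only for decimal columns

  ColumnMeta(String name, boolean nullable, boolean int96, int precision) {
    this.name = name;
    this.nullable = nullable;
    this.int96Timestamps = int96;
    this.decimalPrecision = precision;
  }

  // Mirrors the two loops above: every column gets a name, nullability, and
  // the INT96 flag; precisions apply only to the prefix that was supplied.
  static List<ColumnMeta> fromArrays(String[] names, boolean[] nullability,
                                     boolean int96, int[] precisions) {
    List<ColumnMeta> out = new ArrayList<>(names.length);
    for (int i = 0; i < names.length; i++) {
      int precision = (precisions != null && i < precisions.length) ? precisions[i] : 0;
      out.add(new ColumnMeta(names[i], nullability[i], int96, precision));
    }
    return out;
  }
}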
29 changes: 11 additions & 18 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -4148,19 +4148,6 @@ private Table getExpectedFileTableWithDecimals() {
         .build();
   }
 
-  @Test
-  void testParquetWriteToFileNoNames() throws IOException {
-    File tempFile = File.createTempFile("test-nonames", ".parquet");
-    try (Table table0 = getExpectedFileTable()) {
-      table0.writeParquet(tempFile.getAbsoluteFile());
-      try (Table table1 = Table.readParquet(tempFile.getAbsoluteFile())) {
-        assertTablesAreEqual(table0, table1);
-      }
-    } finally {
-      tempFile.delete();
-    }
-  }
-
   private final class MyBufferConsumer implements HostBufferConsumer, AutoCloseable {
     public final HostMemoryBuffer buffer;
     long offset = 0;
@@ -4210,8 +4197,9 @@ void testParquetWriteToBufferChunkedInt96() {
     try (Table table0 = getExpectedFileTableWithDecimals();
          MyBufferConsumer consumer = new MyBufferConsumer()) {
       ParquetWriterOptions options = ParquetWriterOptions.builder()
+          .withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7", "_c8", "_c9")
           .withTimestampInt96(true)
-          .withPrecisionValues(5, 5)
+          .withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 5, 5)
           .build();
 
       try (TableWriter writer = Table.writeParquetChunked(options, consumer)) {
@@ -4228,9 +4216,13 @@

   @Test
   void testParquetWriteToBufferChunked() {
+    ParquetWriterOptions options = ParquetWriterOptions.builder()
+        .withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7")
+        .withTimestampInt96(true)
+        .build();
     try (Table table0 = getExpectedFileTable();
          MyBufferConsumer consumer = new MyBufferConsumer()) {
-      try (TableWriter writer = Table.writeParquetChunked(ParquetWriterOptions.DEFAULT, consumer)) {
+      try (TableWriter writer = Table.writeParquetChunked(options, consumer)) {
         writer.write(table0);
         writer.write(table0);
         writer.write(table0);
@@ -4251,7 +4243,7 @@ void testParquetWriteToFileWithNames() throws IOException {
"eighth", "nineth")
.withCompressionType(CompressionType.NONE)
.withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
.withPrecisionValues(5, 6)
.withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 5, 6)
.build();
try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
writer.write(table0);
@@ -4274,7 +4266,7 @@ void testParquetWriteToFileWithNamesAndMetadata() throws IOException {
         .withMetadata("somekey", "somevalue")
         .withCompressionType(CompressionType.NONE)
         .withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
-        .withPrecisionValues(6, 8)
+        .withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 6, 8)
         .build();
     try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
       writer.write(table0);
@@ -4292,9 +4284,10 @@ void testParquetWriteToFileUncompressedNoStats() throws IOException {
     File tempFile = File.createTempFile("test-uncompressed", ".parquet");
     try (Table table0 = getExpectedFileTableWithDecimals()) {
       ParquetWriterOptions options = ParquetWriterOptions.builder()
+          .withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7", "_c8", "_c9")
          .withCompressionType(CompressionType.NONE)
          .withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
-          .withPrecisionValues(4, 6)
+          .withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 4, 6)
           .build();
       try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
         writer.write(table0);