diff --git a/java/src/main/java/ai/rapids/cudf/ParquetWriterOptions.java b/java/src/main/java/ai/rapids/cudf/ParquetWriterOptions.java index 1203fc25931..2e793494b7b 100644 --- a/java/src/main/java/ai/rapids/cudf/ParquetWriterOptions.java +++ b/java/src/main/java/ai/rapids/cudf/ParquetWriterOptions.java @@ -58,25 +58,12 @@ public Builder withTimestampInt96(boolean int96) { } /** - * Overwrite flattened precision values for all decimal columns that are expected to be in - * this Table. The list of precisions should be an in-order traversal of all Decimal columns, - * including nested columns. Please look at the example below. - * - * NOTE: The number of `precisionValues` should be equal to the numbers of Decimal columns - * otherwise a CudfException will be thrown. Also note that the values will be overwritten - * every time this method is called - * - * Example: - * Table0 : c0[type: INT32] - * c1[type: Decimal32(3, 1)] - * c2[type: Struct[col0[type: Decimal(2, 1)], - * col1[type: INT64], - * col2[type: Decimal(8, 6)]] - * c3[type: Decimal64(12, 5)] - * - * Flattened list of precision from the above example will be {3, 2, 8, 12} + * Set the decimal precisions to use when writing. This is a temporary API that will go away + * once the parquet APIs can be updated properly. Each call replaces any earlier values. + * @param precisionValues one value per column; values for non-decimal columns are ignored. + * @return this for chaining. */ - public Builder withPrecisionValues(int... precisionValues) { + public Builder withDecimalPrecisions(int ... precisionValues) { this.precisionValues = precisionValues; return this; } @@ -86,8 +73,6 @@ public ParquetWriterOptions build() { } } - public static final ParquetWriterOptions DEFAULT = new ParquetWriterOptions(new Builder()); - public static Builder builder() { return new Builder(); } @@ -107,7 +92,7 @@ public StatisticsFrequency getStatisticsFrequency() { /** * Return the flattened list of precisions if set otherwise empty array will be returned. 
- * For a definition of what `flattened` means please look at {@link Builder#withPrecisionValues} + * See {@link Builder#withDecimalPrecisions} for the expected layout: one value per column. */ public int[] getPrecisions() { return precisions; diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 0dc529d423f..4da99d811f2 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -803,6 +803,13 @@ private static class ParquetTableWriter implements TableWriter { HostBufferConsumer consumer; private ParquetTableWriter(ParquetWriterOptions options, File outputFile) { + int numColumns = options.getColumnNames().length; + assert (numColumns == options.getColumnNullability().length); + int[] precisions = options.getPrecisions(); + // Asserts are off by default, so reject oversized input before it reaches native code. + if (precisions != null && precisions.length > numColumns) { + throw new IllegalArgumentException("Too many decimal precisions: " + precisions.length); + } this.consumer = null; this.handle = writeParquetFileBegin(options.getColumnNames(), options.getColumnNullability(), @@ -871,17 +878,6 @@ public static TableWriter writeParquetChunked(ParquetWriterOptions options, return new ParquetTableWriter(options, consumer); } - /** - * Writes this table to a Parquet file on the host - * - * @param outputFile file to write the table to - * @deprecated please use writeParquetChunked instead - */ - @Deprecated - public void writeParquet(File outputFile) { - writeParquet(ParquetWriterOptions.DEFAULT, outputFile); - } - /** * Writes this table to a Parquet file on the host * diff --git a/java/src/main/java/ai/rapids/cudf/WriterOptions.java b/java/src/main/java/ai/rapids/cudf/WriterOptions.java index 60f7fb03459..5d5af3006a3 100644 --- a/java/src/main/java/ai/rapids/cudf/WriterOptions.java +++ b/java/src/main/java/ai/rapids/cudf/WriterOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ protected static class WriterBuilder { final List columnNullability = new ArrayList<>(); /** - * Add column name + * Add column name(s). For Parquet, column names are not optional. * @param columnNames */ public T withColumnNames(String... columnNames) { diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 81b9882104f..249a3d9b55b 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -915,12 +915,24 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin( cudf::jni::native_jbooleanArray col_nullability(env, j_col_nullability); cudf::jni::native_jstringArray meta_keys(env, j_metadata_keys); cudf::jni::native_jstringArray meta_values(env, j_metadata_values); + cudf::jni::native_jintArray precisions(env, j_precisions); + + auto cpp_names = col_names.as_cpp_vector(); + table_input_metadata metadata; + metadata.column_metadata.resize(col_nullability.size()); + for (int i = 0; i < col_nullability.size(); i++) { + metadata.column_metadata[i] + .set_name(cpp_names[i]) + .set_nullability(col_nullability[i]) + .set_int96_timestamps(j_isInt96); + } + + // Precisions are optional. NOTE(review): assumes precisions.size() <= col_nullability.size() + for (int i = 0; i < precisions.size(); i++) { + metadata.column_metadata[i] + .set_decimal_precision(precisions[i]); + } - auto d = col_nullability.data(); - std::vector nullability(d, d + col_nullability.size()); - table_metadata_with_nullability metadata; - metadata.column_nullable = nullability; - metadata.column_names = col_names.as_cpp_vector(); for (auto i = 0; i < meta_keys.size(); ++i) { metadata.user_data[meta_keys[i].get()] = meta_values[i].get(); } @@ -928,16 +940,11 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin( std::unique_ptr data_sink( new cudf::jni::jni_writer_data_sink(env, consumer)); 
sink_info sink{data_sink.get()}; - cudf::jni::native_jintArray precisions(env, j_precisions); - std::vector const v_precisions( - precisions.data(), precisions.data() + precisions.size()); chunked_parquet_writer_options opts = chunked_parquet_writer_options::builder(sink) - .nullable_metadata(&metadata) + .metadata(&metadata) .compression(static_cast(j_compression)) .stats_level(static_cast(j_stats_freq)) - .int96_timestamps(static_cast(j_isInt96)) - .decimal_precision(v_precisions) .build(); auto writer_ptr = std::make_unique(opts); @@ -965,27 +972,34 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetFileBegin( cudf::jni::native_jstringArray meta_keys(env, j_metadata_keys); cudf::jni::native_jstringArray meta_values(env, j_metadata_values); cudf::jni::native_jstring output_path(env, j_output_path); + cudf::jni::native_jintArray precisions(env, j_precisions); - auto d = col_nullability.data(); - std::vector nullability(d, d + col_nullability.size()); - table_metadata_with_nullability metadata; - metadata.column_nullable = nullability; - metadata.column_names = col_names.as_cpp_vector(); - for (int i = 0; i < meta_keys.size(); ++i) { + auto cpp_names = col_names.as_cpp_vector(); + table_input_metadata metadata; + metadata.column_metadata.resize(col_nullability.size()); + for (int i = 0; i < col_nullability.size(); i++) { + metadata.column_metadata[i] + .set_name(cpp_names[i]) + .set_nullability(col_nullability[i]) + .set_int96_timestamps(j_isInt96); + } + + // Precisions are optional. NOTE(review): assumes precisions.size() <= col_nullability.size() + for (int i = 0; i < precisions.size(); i++) { + metadata.column_metadata[i] + .set_decimal_precision(precisions[i]); + } + + for (auto i = 0; i < meta_keys.size(); ++i) { metadata.user_data[meta_keys[i].get()] = meta_values[i].get(); } - cudf::jni::native_jintArray precisions(env, j_precisions); - std::vector v_precisions( - precisions.data(), precisions.data() + precisions.size()); - + sink_info sink{output_path.get()}; chunked_parquet_writer_options opts = 
chunked_parquet_writer_options::builder(sink) - .nullable_metadata(&metadata) + .metadata(&metadata) .compression(static_cast(j_compression)) .stats_level(static_cast(j_stats_freq)) - .int96_timestamps(static_cast(j_isInt96)) - .decimal_precision(v_precisions) .build(); auto writer_ptr = std::make_unique(opts); diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index c075f074068..02b64f0b713 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4148,19 +4148,6 @@ private Table getExpectedFileTableWithDecimals() { .build(); } - @Test - void testParquetWriteToFileNoNames() throws IOException { - File tempFile = File.createTempFile("test-nonames", ".parquet"); - try (Table table0 = getExpectedFileTable()) { - table0.writeParquet(tempFile.getAbsoluteFile()); - try (Table table1 = Table.readParquet(tempFile.getAbsoluteFile())) { - assertTablesAreEqual(table0, table1); - } - } finally { - tempFile.delete(); - } - } - private final class MyBufferConsumer implements HostBufferConsumer, AutoCloseable { public final HostMemoryBuffer buffer; long offset = 0; @@ -4210,8 +4197,9 @@ void testParquetWriteToBufferChunkedInt96() { try (Table table0 = getExpectedFileTableWithDecimals(); MyBufferConsumer consumer = new MyBufferConsumer()) { ParquetWriterOptions options = ParquetWriterOptions.builder() + .withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7", "_c8", "_c9") .withTimestampInt96(true) - .withPrecisionValues(5, 5) + .withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 5, 5) .build(); try (TableWriter writer = Table.writeParquetChunked(options, consumer)) { @@ -4228,9 +4216,13 @@ void testParquetWriteToBufferChunkedInt96() { @Test void testParquetWriteToBufferChunked() { + ParquetWriterOptions options = ParquetWriterOptions.builder() + .withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7") + .withTimestampInt96(true) + .build(); try 
(Table table0 = getExpectedFileTable(); MyBufferConsumer consumer = new MyBufferConsumer()) { - try (TableWriter writer = Table.writeParquetChunked(ParquetWriterOptions.DEFAULT, consumer)) { + try (TableWriter writer = Table.writeParquetChunked(options, consumer)) { writer.write(table0); writer.write(table0); writer.write(table0); @@ -4251,7 +4243,7 @@ void testParquetWriteToFileWithNames() throws IOException { "eighth", "nineth") .withCompressionType(CompressionType.NONE) .withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE) - .withPrecisionValues(5, 6) + .withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 5, 6) .build(); try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) { writer.write(table0); @@ -4274,7 +4266,7 @@ void testParquetWriteToFileWithNamesAndMetadata() throws IOException { .withMetadata("somekey", "somevalue") .withCompressionType(CompressionType.NONE) .withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE) - .withPrecisionValues(6, 8) + .withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 6, 8) .build(); try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) { writer.write(table0); @@ -4292,9 +4284,10 @@ void testParquetWriteToFileUncompressedNoStats() throws IOException { File tempFile = File.createTempFile("test-uncompressed", ".parquet"); try (Table table0 = getExpectedFileTableWithDecimals()) { ParquetWriterOptions options = ParquetWriterOptions.builder() + .withColumnNames("_c1", "_c2", "_c3", "_c4", "_c5", "_c6", "_c7", "_c8", "_c9") .withCompressionType(CompressionType.NONE) .withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE) - .withPrecisionValues(4, 6) + .withDecimalPrecisions(0, 0, 0, 0, 0, 0, 0, 4, 6) .build(); try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) { writer.write(table0);