Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pick smallest decimal type with required precision in ORC reader #9775

Merged
merged 19 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions cpp/include/cudf/io/orc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class orc_reader_options {

// Columns that should be read as Decimal128
std::vector<std::string> _decimal128_columns;
bool _enable_decimal128 = true;

friend orc_reader_options_builder;

Expand Down Expand Up @@ -151,6 +152,11 @@ class orc_reader_options {
*/
std::vector<std::string> const& get_decimal128_columns() const { return _decimal128_columns; }

/**
* @brief Whether to use row index to speed-up reading.
*/
bool is_enabled_decimal128() const { return _enable_decimal128; }

// Setters

/**
Expand Down Expand Up @@ -225,6 +231,13 @@ class orc_reader_options {
_decimal_cols_as_float = std::move(val);
}

/**
* @brief Enable/Disable the use of decimal128 type
*
* @param use Boolean value to enable/disable.
*/
void enable_decimal128(bool enable) { _enable_decimal128 = enable; }

/**
* @brief Set columns that should be read as 128-bit Decimal
*
Expand Down Expand Up @@ -362,6 +375,17 @@ class orc_reader_options_builder {
return *this;
}

/**
* @brief Enable/Disable use of decimal128 type
*
* @param use Boolean value to enable/disable.
*/
orc_reader_options_builder& decimal128(bool enable)
{
options.enable_decimal128(enable);
return *this;
}

/**
* @brief move orc_reader_options member once it's built.
*/
Expand Down
37 changes: 23 additions & 14 deletions cpp/src/io/orc/reader_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -230,24 +230,35 @@ size_t gather_stream_info(const size_t stripe_index,
/**
* @brief Determines cuDF type of an ORC Decimal column.
*/
auto decimal_column_type(const std::vector<std::string>& float64_columns,
const std::vector<std::string>& decimal128_columns,
cudf::io::orc::metadata& metadata,
auto decimal_column_type(std::vector<std::string> const& float64_columns,
std::vector<std::string> const& decimal128_columns,
bool is_decimal128_enabled,
cudf::io::orc::detail::aggregate_orc_metadata const& metadata,
int column_index)
{
auto const& column_path = metadata.column_path(column_index);
if (metadata.get_col_type(column_index).kind != DECIMAL) return type_id::EMPTY;

auto const& column_path = metadata.column_path(0, column_index);
auto is_column_in = [&](const std::vector<std::string>& cols) {
return std::find(cols.cbegin(), cols.cend(), column_path) != cols.end();
};

auto const user_selected_float64 = is_column_in(float64_columns);
auto const user_selected_decimal128 = is_column_in(decimal128_columns);
auto const user_selected_decimal128 = is_decimal128_enabled and is_column_in(decimal128_columns);
CUDF_EXPECTS(not user_selected_float64 or not user_selected_decimal128,
"Both decimal128 and float64 types selected for column " + column_path);

if (user_selected_float64) return type_id::FLOAT64;
if (user_selected_decimal128) return type_id::DECIMAL128;
return type_id::DECIMAL64;

auto const precision = metadata.get_col_type(column_index)
.precision.value_or(cuda::std::numeric_limits<int64_t>::digits10);
if (precision <= cuda::std::numeric_limits<int32_t>::digits10) return type_id::DECIMAL32;
if (precision <= cuda::std::numeric_limits<int64_t>::digits10) return type_id::DECIMAL64;
CUDF_EXPECTS(is_decimal128_enabled,
"Decimal precision too high for decimal64, use `decimal_cols_as_float` or enable "
"decimal128 use");
return type_id::DECIMAL128;
}

} // namespace
Expand Down Expand Up @@ -744,7 +755,7 @@ std::unique_ptr<column> reader::impl::create_empty_column(const size_type orc_co
_use_np_dtypes,
_timestamp_type.id(),
decimal_column_type(
_decimal_cols_as_float, decimal128_columns, _metadata.per_file_metadata[0], orc_col_id));
_decimal_cols_as_float, decimal128_columns, is_decimal128_enabled, _metadata, orc_col_id));
int32_t scale = 0;
std::vector<std::unique_ptr<column>> child_columns;
std::unique_ptr<column> out_col = nullptr;
Expand Down Expand Up @@ -795,7 +806,7 @@ std::unique_ptr<column> reader::impl::create_empty_column(const size_type orc_co
break;

case orc::DECIMAL:
if (type == type_id::DECIMAL64 or type == type_id::DECIMAL128) {
if (type == type_id::DECIMAL32 or type == type_id::DECIMAL64 or type == type_id::DECIMAL128) {
scale = -static_cast<int32_t>(_metadata.get_types()[orc_col_id].scale.value_or(0));
}
out_col = make_empty_column(data_type(type, scale));
Expand Down Expand Up @@ -889,6 +900,7 @@ reader::impl::impl(std::vector<std::unique_ptr<datasource>>&& sources,
// Control decimals conversion
_decimal_cols_as_float = options.get_decimal_cols_as_float();
decimal128_columns = options.get_decimal128_columns();
is_decimal128_enabled = options.is_enabled_decimal128();
}

timezone_table reader::impl::compute_timezone_table(
Expand Down Expand Up @@ -953,13 +965,10 @@ table_with_metadata reader::impl::read(size_type skip_rows,
_use_np_dtypes,
_timestamp_type.id(),
decimal_column_type(
_decimal_cols_as_float, decimal128_columns, _metadata.per_file_metadata[0], col.id));
_decimal_cols_as_float, decimal128_columns, is_decimal128_enabled, _metadata, col.id));
CUDF_EXPECTS(col_type != type_id::EMPTY, "Unknown type");
CUDF_EXPECTS(
(col_type != type_id::DECIMAL64) or (_metadata.get_col_type(col.id).precision <= 18),
"Precision of column " + std::string{_metadata.column_name(0, col.id)} +
" is over 18, use 128-bit Decimal.");
if (col_type == type_id::DECIMAL64 or col_type == type_id::DECIMAL128) {
if (col_type == type_id::DECIMAL32 or col_type == type_id::DECIMAL64 or
col_type == type_id::DECIMAL128) {
// sign of the scale is changed since cuDF follows c++ libraries like CNL
// which uses negative scaling, but liborc and other libraries
// follow positive scaling.
Expand Down
1 change: 1 addition & 0 deletions cpp/src/io/orc/reader_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ class reader::impl {
bool _use_np_dtypes = true;
std::vector<std::string> _decimal_cols_as_float;
std::vector<std::string> decimal128_columns;
bool is_decimal128_enabled;
data_type _timestamp_type{type_id::EMPTY};
reader_column_meta _col_meta;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistency in what is initialized and how things are initialized here. Two bool members above are initialized to true using =. This bool is not initialized. data_type is initialized using {}. Suggest adding initialization for everything and doing them all the same way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, only left the vectors without "visible" initialization.

};
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/io/orc/stripe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1066,12 +1066,12 @@ static __device__ int Decode_Decimals(orc_bytestream_s* bs,
return (v / kPow5i[scale]) >> scale;
}
}();
if (dtype_id == type_id::DECIMAL64) {
if (dtype_id == type_id::DECIMAL32) {
vals.i32[t] = scaled_value;
} else if (dtype_id == type_id::DECIMAL64) {
vals.i64[t] = scaled_value;
} else {
{
vals.i128[t] = scaled_value;
}
vals.i128[t] = scaled_value;
}
}
}
Expand Down Expand Up @@ -1708,8 +1708,10 @@ __global__ void __launch_bounds__(block_size)
case DOUBLE:
case LONG: static_cast<uint64_t*>(data_out)[row] = s->vals.u64[t + vals_skipped]; break;
case DECIMAL:
if (s->chunk.dtype_id == type_id::FLOAT64 or
s->chunk.dtype_id == type_id::DECIMAL64) {
if (s->chunk.dtype_id == type_id::DECIMAL32) {
static_cast<uint32_t*>(data_out)[row] = s->vals.u32[t + vals_skipped];
} else if (s->chunk.dtype_id == type_id::FLOAT64 or
s->chunk.dtype_id == type_id::DECIMAL64) {
static_cast<uint64_t*>(data_out)[row] = s->vals.u64[t + vals_skipped];
} else {
// decimal128
Expand Down
22 changes: 8 additions & 14 deletions cpp/tests/io/orc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,7 @@ TEST_F(OrcWriterTest, MultiColumn)
cudf_io::write_orc(out_opts);

cudf_io::orc_reader_options in_opts =
cudf_io::orc_reader_options::builder(cudf_io::source_info{filepath})
.use_index(false)
.decimal128_columns({"decimal_pos_scale", "decimal_neg_scale"});
cudf_io::orc_reader_options::builder(cudf_io::source_info{filepath}).use_index(false);
auto result = cudf_io::read_orc(in_opts);

CUDF_TEST_EXPECT_TABLES_EQUAL(expected, result.tbl->view());
Expand Down Expand Up @@ -1178,9 +1176,9 @@ TEST_F(OrcWriterTest, Decimal32)
auto data = cudf::detail::make_counting_transform_iterator(0, [&vals](auto i) {
return numeric::decimal32{vals[i], numeric::scale_type{2}};
});
auto mask = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 13 == 0; });
auto mask = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 13; });
column_wrapper<numeric::decimal32> col{data, data + num_rows, mask};
cudf::table_view expected({static_cast<cudf::column_view>(col)});
cudf::table_view expected({col});

auto filepath = temp_env->get_temp_filepath("Decimal32.orc");
cudf_io::orc_writer_options out_opts =
Expand All @@ -1192,12 +1190,7 @@ TEST_F(OrcWriterTest, Decimal32)
cudf_io::orc_reader_options::builder(cudf_io::source_info{filepath});
auto result = cudf_io::read_orc(in_opts);

auto data64 = cudf::detail::make_counting_transform_iterator(0, [&vals](auto i) {
return numeric::decimal64{vals[i], numeric::scale_type{2}};
});
column_wrapper<numeric::decimal64> col64{data64, data64 + num_rows, mask};

CUDF_TEST_EXPECT_COLUMNS_EQUAL(col64, result.tbl->view().column(0));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(col, result.tbl->view().column(0));
}

TEST_F(OrcStatisticsTest, Overflow)
Expand Down Expand Up @@ -1438,7 +1431,7 @@ TEST_F(OrcReaderTest, DecimalOptions)
cudf_io::orc_reader_options::builder(cudf_io::source_info{filepath})
.decimal128_columns({"dec", "fake_name"})
.decimal_cols_as_float({"decc", "fake_name"});
// Should not throw
// Should not throw, even with "fake name" in both options
EXPECT_NO_THROW(cudf_io::read_orc(valid_opts));

cudf_io::orc_reader_options invalid_opts =
Expand Down Expand Up @@ -1493,10 +1486,11 @@ TEST_F(OrcWriterTest, DecimalOptionsNested)
cudf_io::orc_reader_options in_opts =
cudf_io::orc_reader_options::builder(cudf_io::source_info{filepath})
.use_index(false)
.decimal128_columns({"lists.1.dec128"});
.decimal128_columns({"lists.1.dec64"});
auto result = cudf_io::read_orc(in_opts);

CUDF_TEST_EXPECT_TABLES_EQUAL(expected, result.tbl->view());
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result.tbl->view().column(0).child(1).child(0),
result.tbl->view().column(0).child(1).child(1));
}

CUDF_TEST_PROGRAM_MAIN()
89 changes: 4 additions & 85 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -1091,20 +1091,6 @@ public static void writeColumnViewsToParquet(ParquetWriterOptions options,
}
}

/**
* Writes this table to a Parquet file on the host
*
* @param options parameters for the writer
* @param outputFile file to write the table to
* @deprecated please use writeParquetChunked instead
*/
@Deprecated
public void writeParquet(ParquetWriterOptions options, File outputFile) {
try (TableWriter writer = writeParquetChunked(options, outputFile)) {
writer.write(this);
}
}

private static class ORCTableWriter implements TableWriter {
private long handle;
HostBufferConsumer consumer;
Expand Down Expand Up @@ -1179,33 +1165,6 @@ public static TableWriter writeORCChunked(ORCWriterOptions options, HostBufferCo
return new ORCTableWriter(options, consumer);
}

/**
* Writes this table to a file on the host.
* @param outputFile - File to write the table to
* @deprecated please use writeORCChunked instead
*/
@Deprecated
public void writeORC(File outputFile) {
// Need to specify the number of columns but leave all column names undefined
String[] names = new String[getNumberOfColumns()];
Arrays.fill(names, "");
ORCWriterOptions opts = ORCWriterOptions.builder().withColumns(true, names).build();
writeORC(opts, outputFile);
}

/**
* Writes this table to a file on the host.
* @param outputFile - File to write the table to
* @deprecated please use writeORCChunked instead
*/
@Deprecated
public void writeORC(ORCWriterOptions options, File outputFile) {
assert options.getTopLevelChildren() == getNumberOfColumns() : "must specify names for all columns";
try (TableWriter writer = Table.writeORCChunked(options, outputFile)) {
writer.write(this);
}
}

private static class ArrowIPCTableWriter implements TableWriter {
private final ArrowIPCWriterOptions.DoneOnGpu callback;
private long handle;
Expand Down Expand Up @@ -2082,26 +2041,6 @@ public Table gather(ColumnView gatherMap) {
return gather(gatherMap, OutOfBoundsPolicy.NULLIFY);
}

/**
* Gathers the rows of this table according to `gatherMap` such that row "i"
* in the resulting table's columns will contain row "gatherMap[i]" from this table.
* The number of rows in the result table will be equal to the number of elements in
* `gatherMap`.
*
* A negative value `i` in the `gatherMap` is interpreted as `i+n`, where
* `n` is the number of rows in this table.
*
* @deprecated Use {@link #gather(ColumnView, OutOfBoundsPolicy)}
* @param gatherMap the map of indexes. Must be non-nullable and integral type.
* @param checkBounds if true bounds checking is performed on the value. Be very careful
* when setting this to false.
* @return the resulting Table.
*/
@Deprecated
public Table gather(ColumnView gatherMap, boolean checkBounds) {
return new Table(gather(nativeHandle, gatherMap.getNativeView(), checkBounds));
}

/**
* Gathers the rows of this table according to `gatherMap` such that row "i"
* in the resulting table's columns will contain row "gatherMap[i]" from this table.
Expand Down Expand Up @@ -2256,7 +2195,7 @@ public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable,
* the left and right tables, respectively, to produce the result of the left join.
* It is the responsibility of the caller to close the resulting gather map instances.
* This interface allows passing an output row count that was previously computed from
* {@link #conditionalLeftJoinRowCount(Table, CompiledExpression, boolean)}.
* {@link #conditionalLeftJoinRowCount(Table, CompiledExpression)}.
* WARNING: Passing a row count that is smaller than the actual row count will result
* in undefined behavior.
* @param rightTable the right side table of the join in the join
Expand Down Expand Up @@ -2396,7 +2335,7 @@ public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable,
* the left and right tables, respectively, to produce the result of the inner join.
* It is the responsibility of the caller to close the resulting gather map instances.
* This interface allows passing an output row count that was previously computed from
* {@link #conditionalInnerJoinRowCount(Table, CompiledExpression, boolean)}.
* {@link #conditionalInnerJoinRowCount(Table, CompiledExpression)}.
* WARNING: Passing a row count that is smaller than the actual row count will result
* in undefined behavior.
* @param rightTable the right side table of the join in the join
Expand Down Expand Up @@ -2588,7 +2527,7 @@ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable,
* to produce the result of the left semi join.
* It is the responsibility of the caller to close the resulting gather map instance.
* This interface allows passing an output row count that was previously computed from
* {@link #conditionalLeftSemiJoinRowCount(Table, CompiledExpression, boolean)}.
* {@link #conditionalLeftSemiJoinRowCount(Table, CompiledExpression)}.
* WARNING: Passing a row count that is smaller than the actual row count will result
* in undefined behavior.
* @param rightTable the right side table of the join
Expand Down Expand Up @@ -2667,7 +2606,7 @@ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable,
* to produce the result of the left anti join.
* It is the responsibility of the caller to close the resulting gather map instance.
* This interface allows passing an output row count that was previously computed from
* {@link #conditionalLeftAntiJoinRowCount(Table, CompiledExpression, boolean)}.
* {@link #conditionalLeftAntiJoinRowCount(Table, CompiledExpression)}.
* WARNING: Passing a row count that is smaller than the actual row count will result
* in undefined behavior.
* @param rightTable the right side table of the join
Expand Down Expand Up @@ -3449,14 +3388,6 @@ public ContiguousTable[] contiguousSplitGroups() {
groupByOptions.getKeysDescending(),
groupByOptions.getKeysNullSmallest());
}

/**
* @deprecated use aggregateWindowsOverRanges
*/
@Deprecated
public Table aggregateWindowsOverTimeRanges(AggregationOverWindow... windowAggregates) {
return aggregateWindowsOverRanges(windowAggregates);
}
}

public static final class TableOperation {
Expand Down Expand Up @@ -3651,18 +3582,6 @@ public PartitionedTable hashPartition(HashType type, int numberOfPartitions) {
partitionOffsets.length,
partitionOffsets)), partitionOffsets);
}

/**
* Hash partition a table into the specified number of partitions.
* @deprecated Use {@link #hashPartition(int)}
* @param numberOfPartitions - number of partitions to use
* @return - {@link PartitionedTable} - Table that exposes a limited functionality of the
* {@link Table} class
*/
@Deprecated
public PartitionedTable partition(int numberOfPartitions) {
return hashPartition(numberOfPartitions);
}
}

/////////////////////////////////////////////////////////////////////////////
Expand Down
Loading