Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support storing precision of decimal types in Schema class #17176

Merged
merged 8 commits into from
Oct 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 70 additions & 7 deletions java/src/main/java/ai/rapids/cudf/Schema.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,52 @@ public class Schema {
public static final Schema INFERRED = new Schema();

private final DType topLevelType;

/**
* Default value for precision value, when it is not specified or the column type is not decimal.
*/
private static final int UNKNOWN_PRECISION = -1;

/**
* Store precision for the top level column, only applicable if the column is a decimal type.
* <p/>
* This variable is not designed to be used by any libcudf's APIs since libcudf does not support
* precisions for fixed point numbers.
* Instead, it is used only to pass down the precision values from Spark's DecimalType to the
* JNI level, where some JNI functions require these values to perform their operations.
*/
private final int topLevelPrecision;

private final List<String> childNames;
private final List<Schema> childSchemas;
private boolean flattened = false;
private String[] flattenedNames;
private DType[] flattenedTypes;
private int[] flattenedPrecisions;
private int[] flattenedCounts;

private Schema(DType topLevelType,
int topLevelPrecision,
List<String> childNames,
List<Schema> childSchemas) {
this.topLevelType = topLevelType;
this.topLevelPrecision = topLevelPrecision;
this.childNames = childNames;
this.childSchemas = childSchemas;
}

private Schema(DType topLevelType,
List<String> childNames,
List<Schema> childSchemas) {
this(topLevelType, UNKNOWN_PRECISION, childNames, childSchemas);
}

/**
* Inferred schema.
*/
private Schema() {
topLevelType = null;
topLevelPrecision = UNKNOWN_PRECISION;
childNames = null;
childSchemas = null;
}
Expand Down Expand Up @@ -104,14 +130,17 @@ private void flattenIfNeeded() {
if (flatLen == 0) {
flattenedNames = null;
flattenedTypes = null;
flattenedPrecisions = null;
flattenedCounts = null;
} else {
String[] names = new String[flatLen];
DType[] types = new DType[flatLen];
int[] precisions = new int[flatLen];
int[] counts = new int[flatLen];
collectFlattened(names, types, counts, 0);
collectFlattened(names, types, precisions, counts, 0);
flattenedNames = names;
flattenedTypes = types;
flattenedPrecisions = precisions;
flattenedCounts = counts;
}
flattened = true;
Expand All @@ -128,19 +157,20 @@ private int flattenedLength(int startingLength) {
return startingLength;
}

private int collectFlattened(String[] names, DType[] types, int[] counts, int offset) {
private int collectFlattened(String[] names, DType[] types, int[] precisions, int[] counts, int offset) {
if (childSchemas != null) {
for (int i = 0; i < childSchemas.size(); i++) {
Schema child = childSchemas.get(i);
names[offset] = childNames.get(i);
types[offset] = child.topLevelType;
precisions[offset] = child.topLevelPrecision;
if (child.childNames != null) {
counts[offset] = child.childNames.size();
} else {
counts[offset] = 0;
}
offset++;
offset = this.childSchemas.get(i).collectFlattened(names, types, counts, offset);
offset = this.childSchemas.get(i).collectFlattened(names, types, precisions, counts, offset);
}
}
return offset;
Expand Down Expand Up @@ -226,6 +256,22 @@ public int[] getFlattenedTypeScales() {
return ret;
}

/**
* Get decimal precisions of the columns' types flattened from all levels in schema by
* depth-first traversal.
* <p/>
* This is used to pass down the decimal precisions from Spark to only the JNI layer, where
* some JNI functions require precision values to perform their operations.
* Decimal precisions should not be consumed by any libcudf's APIs since libcudf does not
* support precisions for fixed point numbers.
*
* @return An array containing decimal precision of all columns in schema.
*/
public int[] getFlattenedDecimalPrecisions() {
flattenIfNeeded();
return flattenedPrecisions;
}

/**
* Get the types of the columns in schema flattened from all levels by depth-first traversal.
* @return An array containing types of all columns in schema.
Expand Down Expand Up @@ -307,11 +353,13 @@ public HostColumnVector.DataType asHostDataType() {

public static class Builder {
private final DType topLevelType;
private final int topLevelPrecision;
private final List<String> names;
private final List<Builder> types;

private Builder(DType topLevelType) {
private Builder(DType topLevelType, int topLevelPrecision) {
this.topLevelType = topLevelType;
this.topLevelPrecision = topLevelPrecision;
if (topLevelType == DType.STRUCT || topLevelType == DType.LIST) {
// There can be children
names = new ArrayList<>();
Expand All @@ -322,14 +370,19 @@ private Builder(DType topLevelType) {
}
}

private Builder(DType topLevelType) {
this(topLevelType, UNKNOWN_PRECISION);
}

/**
* Add a new column
* @param type the type of column to add
* @param name the name of the column to add (Ignored for list types)
* @param precision the decimal precision, only applicable for decimal types
* @return the builder for the new column. This should really only be used when the type
* passed in is a LIST or a STRUCT.
*/
public Builder addColumn(DType type, String name) {
public Builder addColumn(DType type, String name, int precision) {
if (names == null) {
throw new IllegalStateException("A column of type " + topLevelType +
" cannot have children");
Expand All @@ -340,21 +393,31 @@ public Builder addColumn(DType type, String name) {
if (names.contains(name)) {
throw new IllegalStateException("Cannot add duplicate names to a schema");
}
Builder ret = new Builder(type);
Builder ret = new Builder(type, precision);
types.add(ret);
names.add(name);
return ret;
}

public Builder addColumn(DType type, String name) {
return addColumn(type, name, UNKNOWN_PRECISION);
}

/**
* Adds a single column to the current schema. addColumn is preferred as it can be used
* to support nested types.
* @param type the type of the column.
* @param name the name of the column.
* @param precision the decimal precision, only applicable for decimal types.
* @return this for chaining.
*/
public Builder column(DType type, String name, int precision) {
addColumn(type, name, precision);
return this;
}

public Builder column(DType type, String name) {
addColumn(type, name);
addColumn(type, name, UNKNOWN_PRECISION);
return this;
}

Expand Down
Loading