diff --git a/java/src/main/java/ai/rapids/cudf/Schema.java b/java/src/main/java/ai/rapids/cudf/Schema.java index acf78564824..63446ca2305 100644 --- a/java/src/main/java/ai/rapids/cudf/Schema.java +++ b/java/src/main/java/ai/rapids/cudf/Schema.java @@ -28,27 +28,42 @@ public class Schema { public static final Schema INFERRED = new Schema(); + private static final int UNKNOWN_PRECISION = -1; + private final DType topLevelType; + private final int precision; // storing precision for decimal types private final List childNames; private final List childSchemas; private boolean flattened = false; private String[] flattenedNames; private DType[] flattenedTypes; private int[] flattenedCounts; + private int[] flattenedPrecisions; + + private Schema(DType topLevelType, + int precision, List childNames, List childSchemas) { this.topLevelType = topLevelType; + this.precision = precision; this.childNames = childNames; this.childSchemas = childSchemas; } + private Schema(DType topLevelType, + List childNames, + List childSchemas) { + this(topLevelType, UNKNOWN_PRECISION, childNames, childSchemas); + } + /** * Inferred schema. */ private Schema() { topLevelType = null; + precision = UNKNOWN_PRECISION; childNames = null; childSchemas = null; } @@ -105,14 +120,17 @@ private void flattenIfNeeded() { flattenedNames = null; flattenedTypes = null; flattenedCounts = null; + flattenedPrecisions = null; } else { String[] names = new String[flatLen]; DType[] types = new DType[flatLen]; int[] counts = new int[flatLen]; - collectFlattened(names, types, counts, 0); + int[] precisions = new int[flatLen]; + collectFlattened(names, types, counts, precisions, 0); flattenedNames = names; flattenedTypes = types; flattenedCounts = counts; + flattenedPrecisions = precisions; } flattened = true; } @@ -128,19 +146,20 @@ private int flattenedLength(int startingLength) { return startingLength; } - private int collectFlattened(String[] names, DType[] types, int[] counts, int offset) { + private int collectFlattened(String[] names, DType[] types, int[] counts, int[] precisions, int offset) { if (childSchemas != null) { for (int i = 0; i < childSchemas.size(); i++) { Schema child = childSchemas.get(i); names[offset] = childNames.get(i); types[offset] = child.topLevelType; + precisions[offset] = child.precision; if (child.childNames != null) { counts[offset] = child.childNames.size(); } else { counts[offset] = 0; } offset++; - offset = this.childSchemas.get(i).collectFlattened(names, types, counts, offset); + offset = this.childSchemas.get(i).collectFlattened(names, types, counts, precisions, offset); } } return offset; @@ -233,14 +252,7 @@ public int[] getFlattenedTypeScales() { */ public int[] getFlattenedDecimalPrecisions() { flattenIfNeeded(); - if (flattenedTypes == null) { - return null; - } - int[] ret = new int[flattenedTypes.length]; - for (int i = 0; i < flattenedTypes.length; i++) { - ret[i] = flattenedTypes[i].getDecimalPrecision(); - } - return ret; + return flattenedPrecisions; } /** @@ -324,11 +336,13 @@ public HostColumnVector.DataType asHostDataType() { public static class Builder { private final DType topLevelType; + private final int topLevelPrecision; private final List names; private final List types; - private Builder(DType topLevelType) { + private Builder(DType topLevelType, int precision) { this.topLevelType = topLevelType; + this.topLevelPrecision = precision; if (topLevelType == DType.STRUCT || topLevelType == DType.LIST) { // There can be children names = new ArrayList<>(); @@ -339,14 +353,19 @@ private Builder(DType topLevelType) { } } + private Builder(DType topLevelType) { + this(topLevelType, UNKNOWN_PRECISION); + } + /** * Add a new column * @param type the type of column to add * @param name the name of the column to add (Ignored for list types) + * @param precision the decimal precision, only applicable for decimal types * @return the builder for the new column. This should really only be used when the type * passed in is a LIST or a STRUCT. */ - public Builder addColumn(DType type, String name) { + public Builder addColumn(DType type, String name, int precision) { if (names == null) { throw new IllegalStateException("A column of type " + topLevelType + " cannot have children"); @@ -357,21 +376,31 @@ public Builder addColumn(DType type, String name) { if (names.contains(name)) { throw new IllegalStateException("Cannot add duplicate names to a schema"); } - Builder ret = new Builder(type); + Builder ret = new Builder(type, precision); types.add(ret); names.add(name); return ret; } + public Builder addColumn(DType type, String name) { + return addColumn(type, name, UNKNOWN_PRECISION); + } + /** * Adds a single column to the current schema. addColumn is preferred as it can be used * to support nested types. * @param type the type of the column. * @param name the name of the column. + * @param precision the decimal precision, only applicable for decimal types. * @return this for chaining. */ + public Builder column(DType type, String name, int precision) { + addColumn(type, name, precision); + return this; + } + public Builder column(DType type, String name) { - addColumn(type, name); + addColumn(type, name, UNKNOWN_PRECISION); return this; }