/*
 * Copyright (c) 2021, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids

import java.util.Optional

import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions

import ai.rapids.cudf._
import ai.rapids.cudf.ColumnWriterOptions._
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
import org.apache.orc.TypeDescription

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types._

object SchemaUtils extends Arm {
  /**
   * Convert a TypeDescription to a Catalyst StructType.
   */
  implicit def toCatalystSchema(schema: TypeDescription): StructType = {
    // This just follows the implementation of Spark 3.0.x, so it does not replace
    // CharType/VarcharType with StringType. That is OK because the GPU does not support
    // these two char types yet.
    CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]
  }
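
  // Illustrative sketch only (not part of the original file): thanks to the implicit
  // conversion above, an ORC TypeDescription can be used directly wherever a Catalyst
  // StructType is expected. The schema string below is just an example.
  private def exampleToCatalystSchema(): StructType = {
    // Parses to StructType(StructField("id", IntegerType), StructField("name", StringType)).
    val schema: StructType = TypeDescription.fromString("struct<id:int,name:string>")
    schema
  }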

  private def getPrecisionsList(dt: DataType): Seq[Int] = dt match {
    case ArrayType(et, _) => getPrecisionsList(et)
    case MapType(kt, vt, _) => getPrecisionsList(kt) ++ getPrecisionsList(vt)
    case StructType(fields) => fields.flatMap(f => getPrecisionsList(f.dataType))
    case d: DecimalType => Seq(d.precision)
    case _ => Seq.empty[Int]
  }

  private def buildTypeIdMapFromSchema(schema: StructType,
      isCaseSensitive: Boolean): Map[String, (DataType, Int)] = {
    val typeIdSeq = schema.map(_.dataType).zipWithIndex
    val name2TypeIdSensitiveMap = schema.map(_.name).zip(typeIdSeq).toMap
    if (isCaseSensitive) {
      name2TypeIdSensitiveMap
    } else {
      CaseInsensitiveMap[(DataType, Int)](name2TypeIdSensitiveMap)
    }
  }

  /**
   * For now the schema evolution covers only two things (full type casting is not
   * supported yet):
   * 1) Cast decimal columns whose precision can be stored in an int to DECIMAL32.
   *    The plugin requires decimals to be stored as DECIMAL32 if the precision is small
   *    enough to fit in an int, and getting this wrong may lead to a number of problems
   *    later on. For example, the cuDF ORC reader always reads decimals as DECIMAL64.
   * 2) Add columns whose names are in the "readSchema" but not in the "tableSchema".
   *    A new column filled with nulls is created for each missing name instead of
   *    throwing an exception.
   * Column pruning is done implicitly by iterating over the readSchema.
   *
   * (This is mainly used by the GPU Parquet/ORC readers to partially support schema
   * evolution. An illustrative usage sketch follows this method.)
   *
   * @param table The input table, will be closed after returning
   * @param tableSchema The schema of the table
   * @param readSchema The read schema from Spark
   * @param isCaseSensitive Whether the name check should be case sensitive or not
   * @return a new table mapping to the "readSchema". Users should close it if no longer needed.
   */
  private[rapids] def evolveSchemaIfNeededAndClose(
      table: Table,
      tableSchema: StructType,
      readSchema: StructType,
      isCaseSensitive: Boolean): Table = {
    assert(table.getNumberOfColumns == tableSchema.length)
    // Check if schema evolution is needed. It is needed when
    // there are decimal columns whose precision can be stored in an int, or
    // the "readSchema" is not equal to the "tableSchema".
    val isDecCastNeeded = getPrecisionsList(tableSchema).exists(p => p <= Decimal.MAX_INT_DIGITS)
    val isAddOrPruneColNeeded = readSchema != tableSchema
    if (isDecCastNeeded || isAddOrPruneColNeeded) {
      val name2TypeIdMap = buildTypeIdMapFromSchema(tableSchema, isCaseSensitive)
      withResource(table) { t =>
        val newColumns = readSchema.safeMap { rf =>
          if (name2TypeIdMap.contains(rf.name)) {
            // Found the column in the table, so start the column evolution.
            val typeAndId = name2TypeIdMap(rf.name)
            val cv = t.getColumn(typeAndId._2)
            withResource(new ArrayBuffer[ColumnView]) { toClose =>
              val newCol =
                evolveColumnRecursively(cv, typeAndId._1, rf.dataType, isCaseSensitive, toClose)
              if (newCol == cv) {
                cv.incRefCount()
              } else {
                toClose += newCol
                newCol.copyToColumnVector()
              }
            }
          } else {
            // Return a null column if the name is not found in the table.
            GpuColumnVector.columnVectorFromNull(t.getRowCount.toInt, rf.dataType)
          }
        }
        withResource(newColumns) { newCols =>
          new Table(newCols: _*)
        }
      }
    } else {
      table
    }
  }
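
  // Illustrative sketch only (not part of the original file): a hypothetical call into
  // evolveSchemaIfNeededAndClose. "decoded" is assumed to be a table just produced by the
  // cuDF Parquet/ORC reader whose columns match "fileSchema". "sparkReadSchema" asks for an
  // extra column the file does not contain, so it comes back filled with nulls, and the
  // small-precision decimal column is cast to DECIMAL32.
  private def exampleEvolveSchema(decoded: Table): Table = {
    val fileSchema = StructType(Seq(
      StructField("price", DecimalType(7, 2)),
      StructField("name", StringType)))
    val sparkReadSchema = StructType(Seq(
      StructField("price", DecimalType(7, 2)),
      StructField("name", StringType),
      StructField("comment", StringType)))
    evolveSchemaIfNeededAndClose(decoded, fileSchema, sparkReadSchema, isCaseSensitive = false)
  }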

  private def evolveColumnRecursively(col: ColumnView, colType: DataType, targetType: DataType,
      isCaseSensitive: Boolean, toClose: ArrayBuffer[ColumnView]): ColumnView = {
    // A utility function to add a view to the buffer "toClose".
    val addToClose = (v: ColumnView) => {
      toClose += v
      v
    }
    // Type casting is not supported yet.
    assert(colType.getClass == targetType.getClass)
    colType match {
      case st: StructType =>
        // This is for the case of nested columns.
        val typeIdMap = buildTypeIdMapFromSchema(st, isCaseSensitive)
        var changed = false
        val newViews = targetType.asInstanceOf[StructType].safeMap { f =>
          if (typeIdMap.contains(f.name)) {
            val typeAndId = typeIdMap(f.name)
            val cv = addToClose(col.getChildColumnView(typeAndId._2))
            val newChild =
              evolveColumnRecursively(cv, typeAndId._1, f.dataType, isCaseSensitive, toClose)
            if (newChild != cv) {
              addToClose(newChild)
              changed = true
            }
            newChild
          } else {
            changed = true
            // Return a null column if the name is not found in the table.
            addToClose(GpuColumnVector.columnVectorFromNull(col.getRowCount.toInt, f.dataType))
          }
        }
        if (changed) {
          // Create a new struct column view that differs only in its children.
          // It would be better to add a dedicated API in cuDF for this.
          val opNullCount = Optional.of(col.getNullCount.asInstanceOf[java.lang.Long])
          new ColumnView(col.getType, col.getRowCount, opNullCount, col.getValid,
            col.getOffsets, newViews.toArray)
        } else {
          col
        }
      case at: ArrayType =>
        val targetElemType = targetType.asInstanceOf[ArrayType].elementType
        val child = addToClose(col.getChildColumnView(0))
        val newChild =
          evolveColumnRecursively(child, at.elementType, targetElemType, isCaseSensitive, toClose)
        if (child == newChild) {
          col
        } else {
          col.replaceListChild(addToClose(newChild))
        }
      case mt: MapType =>
        val targetMapType = targetType.asInstanceOf[MapType]
        val listChild = addToClose(col.getChildColumnView(0))
        // listChild is a struct with two fields: key and value.
        val newStructChildren = new ArrayBuffer[ColumnView](2)
        val newStructIndices = new ArrayBuffer[Int](2)
        // A utility function to handle the key and value views.
        val processView = (id: Int, srcType: DataType, distType: DataType) => {
          val view = addToClose(listChild.getChildColumnView(id))
          val newView = evolveColumnRecursively(view, srcType, distType, isCaseSensitive, toClose)
          if (newView != view) {
            newStructChildren += addToClose(newView)
            newStructIndices += id
          }
        }
        // key and value
        processView(0, mt.keyType, targetMapType.keyType)
        processView(1, mt.valueType, targetMapType.valueType)
        if (newStructChildren.nonEmpty) {
          // Have a new key or value, or both.
          col.replaceListChild(
            addToClose(listChild.replaceChildrenWithViews(newStructIndices.toArray,
              newStructChildren.toArray))
          )
        } else {
          col
        }
      case dt: DecimalType if !GpuColumnVector.getNonNestedRapidsType(dt).equals(col.getType) =>
        col.castTo(DecimalUtil.createCudfDecimal(dt))
      case _ => col
    }
  }

  private def writerOptionsFromField[T <: NestedBuilder[_, _], V <: ColumnWriterOptions](
      builder: NestedBuilder[T, V],
      dataType: DataType,
      name: String,
      nullable: Boolean,
      writeInt96: Boolean): T = {
    dataType match {
      case dt: DecimalType =>
        builder.withDecimalColumn(name, dt.precision, nullable)
      case TimestampType =>
        builder.withTimestampColumn(name, writeInt96, nullable)
      case s: StructType =>
        builder.withStructColumn(
          writerOptionsFromSchema(
            structBuilder(name, nullable),
            s,
            writeInt96).build())
      case a: ArrayType =>
        builder.withListColumn(
          writerOptionsFromField(
            listBuilder(name, nullable),
            a.elementType,
            name,
            a.containsNull,
            writeInt96).build())
      case m: MapType =>
        // It is OK to use `structBuilder` here for the key and value, since both
        // `OrcWriterOptions.Builder` and `ParquetWriterOptions.Builder` are actually
        // `AbstractStructBuilder`s, and this only handles the common column metadata.
        builder.withMapColumn(
          mapColumn(name,
            writerOptionsFromField(
              structBuilder(name, nullable),
              m.keyType,
              "key",
              nullable = false,
              writeInt96).build().getChildColumnOptions()(0),
            writerOptionsFromField(
              structBuilder(name, nullable),
              m.valueType,
              "value",
              m.valueContainsNull,
              writeInt96).build().getChildColumnOptions()(0)))
      case _ =>
        builder.withColumns(nullable, name)
    }
    builder.asInstanceOf[T]
  }

  /**
   * Build writer options from a schema for both the ORC and Parquet writers.
   *
   * (There is an open issue "https://github.com/rapidsai/cudf/issues/7654" for the Parquet
   * writer, but it is circumvented by https://github.com/rapidsai/cudf/pull/9061, so nullable
   * can now reflect the actual setting instead of the previously hard-coded nullable=true.)
   */
  def writerOptionsFromSchema[T <: NestedBuilder[_, _], V <: ColumnWriterOptions](
      builder: NestedBuilder[T, V],
      schema: StructType,
      writeInt96: Boolean = false): T = {
    schema.foreach(field =>
      writerOptionsFromField(builder, field.dataType, field.name, field.nullable, writeInt96)
    )
    builder.asInstanceOf[T]
  }
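
  // Illustrative sketch only (not part of the original file): building cuDF Parquet writer
  // options from a Spark schema, assuming the cuDF Java ParquetWriterOptions.builder() API.
  // A real write path would typically also set compression and metadata before build().
  private def exampleParquetWriterOptions(schema: StructType): ParquetWriterOptions = {
    val builder =
      writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, writeInt96 = false)
    builder.build()
  }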
}