From 8f2f4820ad1e86451489d0f00781c049a3380e3e Mon Sep 17 00:00:00 2001
From: jackierwzhang
Date: Thu, 3 Feb 2022 17:01:50 +1100
Subject: [PATCH 01/13] Enable matching by field ids

---
 .../sql/errors/QueryExecutionErrors.scala     |   9 +
 .../apache/spark/sql/internal/SQLConf.scala   |  19 +
 .../parquet/ParquetFileFormat.scala           |   8 +
 .../parquet/ParquetReadSupport.scala          | 261 ++++++---
 .../parquet/ParquetSchemaConverter.scala      |  16 +-
 .../datasources/parquet/ParquetUtils.scala    |  38 +-
 .../parquet/ParquetFieldIdIOSuite.scala       | 166 ++++++
 .../parquet/ParquetFieldIdSchemaSuite.scala   | 501 ++++++++++++++++++
 .../parquet/ParquetSchemaSuite.scala          |  10 +-
 9 files changed, 956 insertions(+), 72 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 76eb4311e41bf..d3d4ec3dab1bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -805,6 +805,15 @@ object QueryExecutionErrors {
       """.stripMargin.replaceAll("\n", " "))
   }
 
+  def foundDuplicateFieldInFieldIdLookupModeError(
+      requiredId: Int, matchedFields: String): Throwable = {
+    new RuntimeException(
+      s"""
+         |Found duplicate field(s) "$requiredId": $matchedFields
+         |in id mapping mode
+       """.stripMargin.replaceAll("\n", " "))
+  }
+
   def failedToMergeIncompatibleSchemasError(
       left: StructType, right: StructType, e: Throwable): Throwable = {
     new SparkException(s"Failed to merge incompatible schemas $left and $right", e)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 42979a68d8578..14d6a3e2e5f68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -934,6 +934,23 @@ object SQLConf {
     .intConf
     .createWithDefault(4096)
 
+  val PARQUET_FIELD_ID_ENABLED =
+    buildConf("spark.sql.parquet.fieldId.enabled")
+      .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers" +
+        " will use field IDs (if present) in the requested Spark schema to look up Parquet" +
+        " fields instead of using column names; Parquet writers will also populate the field ID" +
+        " metadata (if present) from the Spark schema to the Parquet schema.")
+      .booleanConf
+      .createWithDefault(true)
+
+  val IGNORE_MISSING_PARQUET_FIELD_ID =
+    buildConf("spark.sql.parquet.fieldId.ignoreMissing")
+      .doc("When the Parquet file doesn't have any field IDs but the" +
+        " Spark read schema is using field IDs to read, we will silently return nulls" +
+        " when this flag is enabled, or error otherwise.")
+      .booleanConf
+      .createWithDefault(false)
+
   val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec")
     .doc("Sets the compression codec used when writing ORC files. 
If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -4251,6 +4268,8 @@ class SQLConf extends Serializable with Logging { def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT) + def parquetFieldIdEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_ENABLED) + def useV1Command: Boolean = getConf(SQLConf.LEGACY_USE_V1_COMMAND) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index aa6f9ee91656d..393a1e21999b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.io.IOException import java.net.URI import scala.collection.JavaConverters._ @@ -354,6 +355,13 @@ class ParquetFileFormat } } else { logDebug(s"Falling back to parquet-mr") + + if (SQLConf.get.parquetFieldIdEnabled && + ParquetUtils.hasFieldIds(requiredSchema)) { + throw new IOException("Parquet-mr reader does not support schema with field IDs." + + s" Please choose a different Parquet reader. Read schema: ${requiredSchema.json}") + } + // ParquetRecordReader returns InternalRow val readSupport = new ParquetReadSupport( convertTz, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index bdab0f7892f00..9db15b4076c36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.time.ZoneId -import java.util.{Locale, Map => JMap} +import java.util.{Locale, Map => JMap, UUID} import scala.collection.JavaConverters._ @@ -85,13 +85,70 @@ class ParquetReadSupport( StructType.fromString(schemaString) } + val parquetRequestedSchema = ParquetReadSupport.getRequestedSchema( + context.getFileSchema, catalystRequestedSchema, conf, enableVectorizedReader) + new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. + * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. 
+ */ + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + val parquetRequestedSchema = readContext.getRequestedSchema + new ParquetRecordMaterializer( + parquetRequestedSchema, + ParquetReadSupport.expandUDT(catalystRequestedSchema), + new ParquetToSparkSchemaConverter(conf), + convertTz, + datetimeRebaseSpec, + int96RebaseSpec) + } +} + +object ParquetReadSupport extends Logging { + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + def generateFakeColumnName: String = s"_fake_name_${UUID.randomUUID()}" + + def getRequestedSchema( + parquetFileSchema: MessageType, + catalystRequestedSchema: StructType, + conf: Configuration, + enableVectorizedReader: Boolean): MessageType = { val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key, SQLConf.CASE_SENSITIVE.defaultValue.get) val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get) - val parquetFileSchema = context.getFileSchema + val useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_ENABLED.key, + SQLConf.PARQUET_FIELD_ID_ENABLED.defaultValue.get) + val ignoreMissingIds = conf.getBoolean(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key, + SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.defaultValue.get) + + if (!ignoreMissingIds && + !containsFieldIds(parquetFileSchema) && + ParquetUtils.hasFieldIds(catalystRequestedSchema)) { + throw new RuntimeException( + s""" + |Spark read schema expects field Ids, but Parquet file schema doesn't contain field Ids. + | + |Spark read schema: + |${catalystRequestedSchema.prettyJson} + | + |Parquet file schema: + |${parquetFileSchema.toString} + |""".stripMargin) + } + val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema, - catalystRequestedSchema, caseSensitive) + catalystRequestedSchema, useFieldId, caseSensitive) // We pass two schema to ParquetRecordMaterializer: // - parquetRequestedSchema: the schema of the file data we want to read @@ -109,6 +166,7 @@ class ParquetReadSupport( // in parquetRequestedSchema which are not present in the file. parquetClippedSchema } + logDebug( s"""Going to read the following fields from the Parquet file with the following schema: |Parquet file schema: @@ -120,34 +178,9 @@ class ParquetReadSupport( |Catalyst requested schema: |${catalystRequestedSchema.treeString} """.stripMargin) - new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) - } - /** - * Called on executor side after [[init()]], before instantiating actual Parquet record readers. - * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet - * records to Catalyst [[InternalRow]]s. 
- */ - override def prepareForRead( - conf: Configuration, - keyValueMetaData: JMap[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[InternalRow] = { - val parquetRequestedSchema = readContext.getRequestedSchema - new ParquetRecordMaterializer( - parquetRequestedSchema, - ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetToSparkSchemaConverter(conf), - convertTz, - datetimeRebaseSpec, - int96RebaseSpec) + parquetRequestedSchema } -} - -object ParquetReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" /** * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist @@ -156,9 +189,10 @@ object ParquetReadSupport { def clipParquetSchema( parquetSchema: MessageType, catalystSchema: StructType, + useFieldId: Boolean, caseSensitive: Boolean = true): MessageType = { val clippedParquetFields = clipParquetGroupFields( - parquetSchema.asGroupType(), catalystSchema, caseSensitive) + parquetSchema.asGroupType(), catalystSchema, useFieldId, caseSensitive) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -170,20 +204,24 @@ object ParquetReadSupport { } private def clipParquetType( - parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { + parquetType: Type, + catalystType: DataType, + useFieldId: Boolean, + caseSensitive: Boolean): Type = { catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) + clipParquetListType(parquetType.asGroupType(), t.elementType, useFieldId, caseSensitive) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) + clipParquetMapType( + parquetType.asGroupType(), t.keyType, t.valueType, useFieldId, caseSensitive) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) + clipParquetGroup(parquetType.asGroupType(), t, useFieldId, caseSensitive) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able @@ -210,7 +248,10 @@ object ParquetReadSupport { * [[StructType]]. */ private def clipParquetListType( - parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { + parquetList: GroupType, + elementType: DataType, + useFieldId: Boolean, + caseSensitive: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. assert(!isPrimitiveCatalystType(elementType)) @@ -218,7 +259,7 @@ object ParquetReadSupport { // list element type is just the group itself. Clip it. if (parquetList.getLogicalTypeAnnotation == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType, caseSensitive) + clipParquetType(parquetList, elementType, useFieldId, caseSensitive) } else { assert( parquetList.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation], @@ -242,7 +283,7 @@ object ParquetReadSupport { // "_tuple" appended then the repeated type is the element type and elements are required. // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the // only field. 
- if ( + val newParquetList = if ( repeatedGroup.getFieldCount > 1 || repeatedGroup.getName == "array" || repeatedGroup.getName == parquetList.getName + "_tuple" @@ -250,21 +291,35 @@ object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) + .addField(clipParquetType(repeatedGroup, elementType, useFieldId, caseSensitive)) .named(parquetList.getName) } else { + val newRepeatedGroup = Types + .repeatedGroup() + .addField( + clipParquetType( + repeatedGroup.getType(0), elementType, useFieldId, caseSensitive)) + .named(repeatedGroup.getName) + + val newElementType = if (useFieldId && repeatedGroup.getId() != null) { + newRepeatedGroup.withId(repeatedGroup.getId().intValue()) + } else { + newRepeatedGroup + } + // Otherwise, the repeated field's type is the element type with the repeated field's // repetition. Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField( - Types - .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) - .named(repeatedGroup.getName)) + .addField(newElementType) .named(parquetList.getName) } + if (useFieldId && parquetList.getId() != null) { + newParquetList.withId(parquetList.getId().intValue()) + } else { + newParquetList + } } } @@ -277,6 +332,7 @@ object ParquetReadSupport { parquetMap: GroupType, keyType: DataType, valueType: DataType, + useFieldId: Boolean, caseSensitive: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -285,19 +341,31 @@ object ParquetReadSupport { val parquetKeyType = repeatedGroup.getType(0) val parquetValueType = repeatedGroup.getType(1) - val clippedRepeatedGroup = - Types + val clippedRepeatedGroup = { + val newRepeatedGroup = Types .repeatedGroup() .as(repeatedGroup.getLogicalTypeAnnotation) - .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) - .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) + .addField(clipParquetType(parquetKeyType, keyType, useFieldId, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, useFieldId, caseSensitive)) .named(repeatedGroup.getName) + if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) + } else { + newRepeatedGroup + } + } - Types + val newMap = Types .buildGroup(parquetMap.getRepetition) .as(parquetMap.getLogicalTypeAnnotation) .addField(clippedRepeatedGroup) .named(parquetMap.getName) + + if (useFieldId && parquetMap.getId() != null) { + newMap.withId(parquetMap.getId().intValue()) + } else { + newMap + } } /** @@ -309,13 +377,22 @@ object ParquetReadSupport { * pruning. 
*/ private def clipParquetGroup( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) - Types + parquetRecord: GroupType, + structType: StructType, + useFieldId: Boolean, + caseSensitive: Boolean): GroupType = { + val clippedParquetFields = + clipParquetGroupFields(parquetRecord, structType, useFieldId, caseSensitive) + val newRecord = Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getLogicalTypeAnnotation) .addFields(clippedParquetFields: _*) .named(parquetRecord.getName) + if (useFieldId && parquetRecord.getId() != null) { + newRecord.withId(parquetRecord.getId().intValue()) + } else { + newRecord + } } /** @@ -324,23 +401,29 @@ object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. */ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { - val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - if (caseSensitive) { - val caseSensitiveParquetFieldMap = + parquetRecord: GroupType, + structType: StructType, + useFieldId: Boolean, + caseSensitive: Boolean): Seq[Type] = { + val toParquet = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, useFieldId = useFieldId) + lazy val caseSensitiveParquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - structType.map { f => - caseSensitiveParquetFieldMap + lazy val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + lazy val idToParquetFieldMap = + parquetRecord.getFields.asScala.filter(_.getId() != null).groupBy(f => f.getId.intValue()) + + def matchCaseSensitiveField(f: StructField): Type = { + caseSensitiveParquetFieldMap .get(f.name) - .map(clipParquetType(_, f.dataType, caseSensitive)) + .map(clipParquetType(_, f.dataType, useFieldId, caseSensitive)) .getOrElse(toParquet.convertField(f)) - } - } else { + } + + def matchCaseInsensitiveField(f: StructField): Type = { // Do case-insensitive resolution only if in case-insensitive mode - val caseInsensitiveParquetFieldMap = - parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) - structType.map { f => - caseInsensitiveParquetFieldMap + caseInsensitiveParquetFieldMap .get(f.name.toLowerCase(Locale.ROOT)) .map { parquetTypes => if (parquetTypes.size > 1) { @@ -349,11 +432,50 @@ object ParquetReadSupport { throw QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError( f.name, parquetTypesString) } else { - clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + clipParquetType(parquetTypes.head, f.dataType, useFieldId, caseSensitive) } }.getOrElse(toParquet.convertField(f)) + } + + def matchIdField(f: StructField): Type = { + val fieldId = ParquetUtils.getFieldId(f) + idToParquetFieldMap + .get(fieldId) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. 
more than one field is matched
+            val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]")
+            throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError(
+              fieldId, parquetTypesString)
+          } else {
+            clipParquetType(parquetTypes.head, f.dataType, useFieldId, caseSensitive)
+          }
+        }.getOrElse {
+          // When there is no ID match, we use a fake name to avoid a name match by accident
+          // We need this name to be unique as well, otherwise there will be type conflicts
+          toParquet.convertField(f.copy(name = generateFakeColumnName))
+        }
+    }
+
+    if (useFieldId && ParquetUtils.hasFieldIds(structType)) {
+      structType.map { f =>
+        if (ParquetUtils.hasFieldId(f)) {
+          // try to match id if there's any
+          matchIdField(f)
+        } else {
+          // fall back to name matching
+          if (caseSensitive) {
+            matchCaseSensitiveField(f)
+          } else {
+            matchCaseInsensitiveField(f)
+          }
+        }
+      }
+    } else if (caseSensitive) {
+      structType.map(matchCaseSensitiveField)
+    } else {
+      structType.map(matchCaseInsensitiveField)
+    }
   }
 
   /**
@@ -410,4 +532,13 @@ object ParquetReadSupport {
     expand(schema).asInstanceOf[StructType]
   }
+
+  /**
+   * Whether the parquet schema contains any field IDs.
+   */
+  def containsFieldIds(schema: Type): Boolean = schema match {
+    case p: PrimitiveType => p.getId() != null
+    // We don't require all fields to have IDs, so we use `exists` here.
+    case g: GroupType => g.getId != null || g.getFields.asScala.exists(containsFieldIds)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index cb5d646f85e9e..49e5d4d1b26e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -438,16 +438,20 @@ class ParquetToSparkSchemaConverter(
 class SparkToParquetSchemaConverter(
     writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get,
     outputTimestampType: SQLConf.ParquetOutputTimestampType.Value =
-      SQLConf.ParquetOutputTimestampType.INT96) {
+      SQLConf.ParquetOutputTimestampType.INT96,
+    useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_ENABLED.defaultValue.get) {
 
   def this(conf: SQLConf) = this(
     writeLegacyParquetFormat = conf.writeLegacyParquetFormat,
-    outputTimestampType = conf.parquetOutputTimestampType)
+    outputTimestampType = conf.parquetOutputTimestampType,
+    useFieldId = conf.parquetFieldIdEnabled)
 
   def this(conf: Configuration) = this(
     writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean,
     outputTimestampType = SQLConf.ParquetOutputTimestampType.withName(
-      conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)))
+      conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)),
+    useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_ENABLED.key,
+      SQLConf.PARQUET_FIELD_ID_ENABLED.defaultValue.get))
 
   /**
    * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]].
@@ -463,7 +467,12 @@ class SparkToParquetSchemaConverter(
    * Converts a Spark SQL [[StructField]] to a Parquet [[Type]].
   */
   def convertField(field: StructField): Type = {
-    convertField(field, if (field.nullable) OPTIONAL else REQUIRED)
+    val converted = convertField(field, if (field.nullable) OPTIONAL else REQUIRED)
+    if (useFieldId && ParquetUtils.hasFieldId(field)) {
+      converted.withId(ParquetUtils.getFieldId(field))
+    } else {
+      converted
+    }
   }
 
   private def convertField(field: StructField, repetition: Type.Repetition): Type = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 87a0d9c860f31..1130a45137570 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
 import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
 
 object ParquetUtils {
   def inferSchema(
@@ -144,6 +144,42 @@ object ParquetUtils {
     file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
   }
 
+  /**
+   * A StructField metadata key used to set the field id of a column in the Parquet schema.
+   */
+  val FIELD_ID_METADATA_KEY = "parquet.field.id"
+
+  /**
+   * Whether there exists a field in the schema, whether inner or leaf, that has the Parquet
+   * field ID metadata.
+   */
+  def hasFieldIds(schema: StructType): Boolean = {
+    def recursiveCheck(schema: DataType): Boolean = {
+      schema match {
+        case st: StructType =>
+          st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType))
+
+        case at: ArrayType => recursiveCheck(at.elementType)
+
+        case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType)
+
+        case _ =>
+          // No need to really check primitive types, just to terminate the recursion
+          false
+      }
+    }
+    if (schema.isEmpty) false else recursiveCheck(schema)
+  }
+
+  def hasFieldId(field: StructField): Boolean =
+    field.metadata.contains(FIELD_ID_METADATA_KEY)
+
+  def getFieldId(field: StructField): Int = {
+    require(hasFieldId(field),
+      "The key `parquet.field.id` doesn't exist in the metadata of " + field)
+    field.metadata.getLong(FIELD_ID_METADATA_KEY).toInt
+  }
+
   /**
    * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to
    * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
new file mode 100644
index 0000000000000..ef3b3f8ec143a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StringType, StructType} + +class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkSession { + + private def withId(id: Int): Metadata = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + /** + * Field id is supported in OSS vectorized reader at the moment. + * parquet-mr support is coming soon. + */ + private def withAllSupportedReaders(code: => Unit): Unit = { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) + } + + test("general test") { + withTempDir { dir => + val readSchema = + new StructType().add( + "a", StringType, true, withId(0)) + .add("b", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("random", IntegerType, true, withId(1)) + .add("name", StringType, true, withId(0)) + + val readData = Seq(Row("text", 100), Row("more", 200)) + val writeData = Seq(Row(100, "text"), Row(200, "more")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllSupportedReaders { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("b < 50"), Seq.empty) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("a >= 'oh'"), Row("text", 100) :: Nil) + } + + // blocked for Parquet-mr reader + val e = intercept[SparkException] { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) + } + } + val cause = e.getCause + assert(cause.isInstanceOf[java.io.IOException] && + cause.getMessage.contains("Parquet-mr reader does not support schema with field IDs.")) + } + } + + test("absence of field ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("b", StringType, true, withId(2)) + .add("c", IntegerType, true, withId(3)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(3)) + .add("randomName", StringType, true) + + val writeData = Seq(Row(100, "text"), Row(200, "more")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllSupportedReaders { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), + // 3 different cases for the 3 columns to read: + // - a: ID 1 is not found, but there is column with name `a`, still return null + // - b: ID 2 is not found, return null + 
// - c: ID 3 is found, read it + Row(null, null, 100) :: Row(null, null, 200) :: Nil) + } + } + } + + test("multiple id matches") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("rand1", StringType, true, withId(2)) + .add("rand2", StringType, true, withId(1)) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllSupportedReaders { + val cause = intercept[SparkException] { + spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Found duplicate field(s)")) + } + } + } + + test("read parquet file without ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true) + .add("rand1", StringType, true) + .add("rand2", StringType, true) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllSupportedReaders { + Seq(readSchema, readSchema.add("b", StringType, true)).foreach { schema => + val cause = intercept[SparkException] { + spark.read.schema(schema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Parquet file schema doesn't contain field Ids")) + val expectedValues = (1 to schema.length).map(_ => null) + withSQLConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key -> "true") { + checkAnswer( + spark.read.schema(schema).parquet(dir.getCanonicalPath), + Row(expectedValues: _*) :: Row(expectedValues: _*) :: Nil) + } + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala new file mode 100644 index 0000000000000..2f964d1a483f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -0,0 +1,501 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.parquet.schema.{MessageType, MessageTypeParser} + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { + + private val FAKE_COLUMN_NAME = "_fake_name_" + private val UUID_REGEX = + "[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}".r + + private def withId(id: Int) = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String, + caseSensitive: Boolean = true, + useFieldId: Boolean = true): Unit = { + test(s"Clipping with field id - $testName") { + val fileSchema = MessageTypeParser.parseMessageType(parquetSchema) + val actual = ParquetReadSupport.clipParquetSchema( + fileSchema, + catalystSchema, + useFieldId = useFieldId, + caseSensitive = caseSensitive) + + // each fake name should be uniquely generated + val fakeColumnNames = actual.getPaths.asScala.flatten.filter(_.startsWith(FAKE_COLUMN_NAME)) + assert( + fakeColumnNames.distinct == fakeColumnNames, "Should generate unique fake column names") + + // replace the random part of all fake names with a fixed id generator + val ids1 = (1 to 100).iterator + val actualNormalized = MessageTypeParser.parseMessageType( + UUID_REGEX.replaceAllIn(actual.toString, _ => ids1.next().toString) + ) + val ids2 = (1 to 100).iterator + val expectedNormalized = MessageTypeParser.parseMessageType( + FAKE_COLUMN_NAME.r.replaceAllIn(expectedSchema, _ => s"$FAKE_COLUMN_NAME${ids2.next()}") + ) + + try { + expectedNormalized.checkContains(actualNormalized) + actualNormalized.checkContains(expectedNormalized) + } catch { case cause: Throwable => + fail( + s"""Expected clipped schema: + |$expectedSchema + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + checkEqual(actualNormalized, expectedNormalized) + // might be redundant but just to have some free tests for the utils + assert(ParquetReadSupport.containsFieldIds(fileSchema)) + assert(ParquetUtils.hasFieldIds(catalystSchema)) + } + } + + private def testSqlToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String): Unit = { + val converter = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, + outputTimestampType = SQLConf.ParquetOutputTimestampType.INT96, + useFieldId = true) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) + checkEqual(actual, expected) + } + } + + private def checkEqual(actual: MessageType, expected: MessageType): Unit = { + actual.checkContains(expected) + expected.checkContains(actual) + assert(actual.toString == expected.toString, + s""" + |Schema mismatch. 
+ |Expected schema: + |${expected.toString} + |Actual schema: + |${actual.toString} + """.stripMargin + ) + } + + test("check hasFieldIds for schema") { + val simpleSchemaMissingId = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true) + + assert(ParquetUtils.hasFieldIds(simpleSchemaMissingId)) + + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(8)) + + assert(ParquetUtils.hasFieldIds(f01ElementType)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + assert(ParquetUtils.hasFieldIds(f0Type)) + + assert(ParquetUtils.hasFieldIds( + new StructType().add("f0", f0Type, nullable = false, withId(1)))) + + assert(!ParquetUtils.hasFieldIds(new StructType().add("f0", IntegerType, nullable = true))) + assert(!ParquetUtils.hasFieldIds(new StructType())); + } + + test("check containsFieldIds for parquet schema") { + + // empty Parquet schema fails too + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 = 1 { + | optional int32 f00; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00 = 1; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list = 1 { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + } + + test("ID in Parquet Types is read as null when not set") { + val parquetSchemaString = + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin + + val parquetSchema = MessageTypeParser.parseMessageType(parquetSchemaString) + val f0 = parquetSchema.getFields().get(0) + assert(f0.getId() == null) + assert(f0.asGroupType().getFields.get(0).getId == null) + } + + testSqlToParquet( + "standard array", + sqlSchema = { + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("f0", f0Type, nullable = false, withId(1)) + }, + parquetSchema = + """message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + 
| } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f010 = 7; + | optional int64 f012 = 9; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 f01 = 3; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add( + "g00", IntegerType, nullable = true, withId(2)) + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(4)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 $FAKE_COLUMN_NAME = 4; + |} + """.stripMargin) + + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional int32 f010 = 7; + | optional double f011 = 8; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("g011", DoubleType, nullable = true, withId(8)) + .add("g012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("g00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("g01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("g0", f0Type, nullable = false, withId(1)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f011 = 8; + | optional int64 $FAKE_COLUMN_NAME = 9; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int32 value_f0 = 4; + | required int64 value_f1 = 6; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_g1", LongType, nullable = false, withId(6)) + .add("value_g2", DoubleType, nullable = false, withId(7)) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("g0", f0Type, nullable = false, withId(3)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int64 value_f1 = 6; + | required double $FAKE_COLUMN_NAME = 7; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "won't match field id if structure is different", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + // parquet has id 3, but won't use because structure is different + .add("g01", IntegerType, nullable = true, withId(3)) + new StructType() + 
.add("g0", f0Type, nullable = false, withId(1)) + }, + + // note that f1 is not picked up, even though it's Id is 3 + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 $FAKE_COLUMN_NAME = 3; + | } + |} + """.stripMargin) + + testSchemaClipping( + "Complex type with multiple mismatches should work", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + + new StructType() + .add("g0", f0Type, nullable = false, withId(999)) + .add("g1", IntegerType, nullable = true, withId(3)) + .add("g2", IntegerType, nullable = true, withId(888)) + }, + + expectedSchema = + s"""message spark_schema { + | required group $FAKE_COLUMN_NAME = 999 { + | optional int32 g00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 $FAKE_COLUMN_NAME = 888; + |} + """.stripMargin) + + testSchemaClipping( + "Should allow fall-back to name matching if id not found", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + | required group f4 = 5 { + | optional int32 f40 = 6; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + // nested f00 without id should also work + .add("f00", IntegerType, nullable = true) + + val f4Type = new StructType() + .add("g40", IntegerType, nullable = true, withId(6)) + + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(3)) + // f2 without id should be matched using name matching + .add("f2", IntegerType, nullable = true) + // name is not matched + .add("g2", IntegerType, nullable = true) + // f4 without id will do name matching, but g40 will be matched using id + .add("f4", f4Type, nullable = true) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + | optional int32 g2; + | required group f4 = 5 { + | optional int32 f40 = 6; + | } + |} + """.stripMargin) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 272f12e138b68..6d1d160bd264e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -2257,7 +2257,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { caseSensitive: Boolean): Unit = { test(s"Clipping - $testName") { val actual = ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive) + MessageTypeParser.parseMessageType(parquetSchema), + catalystSchema, + useFieldId = false, + caseSensitive) try { expectedSchema.checkContains(actual) @@ -2821,7 +2824,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } assertThrows[RuntimeException] { ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false) + MessageTypeParser.parseMessageType(parquetSchema), + catalystSchema, + useFieldId = false, 
+        caseSensitive = false)
     }
   }
 }
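
Example usage (a minimal sketch adapted from ParquetFieldIdIOSuite above; it assumes
an active SparkSession `spark` and a writable directory `dir`, and attaches field IDs
through the "parquet.field.id" metadata key, i.e. ParquetUtils.FIELD_ID_METADATA_KEY):

  import scala.collection.JavaConverters._
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.types._

  // Attach a Parquet field ID to a column via StructField metadata.
  def withId(id: Int): Metadata =
    new MetadataBuilder().putLong("parquet.field.id", id).build()

  // Write with one set of column names; the writer populates the IDs.
  val writeSchema = new StructType()
    .add("random", IntegerType, true, withId(1))
    .add("name", StringType, true, withId(0))
  val writeData = Seq(Row(100, "text"), Row(200, "more"))
  spark.createDataFrame(writeData.asJava, writeSchema)
    .write.mode("overwrite").parquet(dir.getCanonicalPath)

  // Read back under different names: with spark.sql.parquet.fieldId.enabled
  // (default true) the reader resolves columns by field ID, not by name,
  // so this returns ("text", 100) and ("more", 200).
  val readSchema = new StructType()
    .add("a", StringType, true, withId(0))
    .add("b", IntegerType, true, withId(1))
  spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect()

Reading such a schema through the parquet-mr path (vectorized reader disabled)
throws the IOException added in ParquetFileFormat above.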
If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -4251,6 +4268,8 @@ class SQLConf extends Serializable with Logging { def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT) + def parquetFieldIdEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_ENABLED) + def useV1Command: Boolean = getConf(SQLConf.LEGACY_USE_V1_COMMAND) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index aa6f9ee91656d..393a1e21999b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.io.IOException import java.net.URI import scala.collection.JavaConverters._ @@ -354,6 +355,13 @@ class ParquetFileFormat } } else { logDebug(s"Falling back to parquet-mr") + + if (SQLConf.get.parquetFieldIdEnabled && + ParquetUtils.hasFieldIds(requiredSchema)) { + throw new IOException("Parquet-mr reader does not support schema with field IDs." + + s" Please choose a different Parquet reader. Read schema: ${requiredSchema.json}") + } + // ParquetRecordReader returns InternalRow val readSupport = new ParquetReadSupport( convertTz, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index bdab0f7892f00..9db15b4076c36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.time.ZoneId -import java.util.{Locale, Map => JMap} +import java.util.{Locale, Map => JMap, UUID} import scala.collection.JavaConverters._ @@ -85,13 +85,70 @@ class ParquetReadSupport( StructType.fromString(schemaString) } + val parquetRequestedSchema = ParquetReadSupport.getRequestedSchema( + context.getFileSchema, catalystRequestedSchema, conf, enableVectorizedReader) + new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. + * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. 
+ */ + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + val parquetRequestedSchema = readContext.getRequestedSchema + new ParquetRecordMaterializer( + parquetRequestedSchema, + ParquetReadSupport.expandUDT(catalystRequestedSchema), + new ParquetToSparkSchemaConverter(conf), + convertTz, + datetimeRebaseSpec, + int96RebaseSpec) + } +} + +object ParquetReadSupport extends Logging { + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + def generateFakeColumnName: String = s"_fake_name_${UUID.randomUUID()}" + + def getRequestedSchema( + parquetFileSchema: MessageType, + catalystRequestedSchema: StructType, + conf: Configuration, + enableVectorizedReader: Boolean): MessageType = { val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key, SQLConf.CASE_SENSITIVE.defaultValue.get) val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get) - val parquetFileSchema = context.getFileSchema + val useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_ENABLED.key, + SQLConf.PARQUET_FIELD_ID_ENABLED.defaultValue.get) + val ignoreMissingIds = conf.getBoolean(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key, + SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.defaultValue.get) + + if (!ignoreMissingIds && + !containsFieldIds(parquetFileSchema) && + ParquetUtils.hasFieldIds(catalystRequestedSchema)) { + throw new RuntimeException( + s""" + |Spark read schema expects field Ids, but Parquet file schema doesn't contain field Ids. + | + |Spark read schema: + |${catalystRequestedSchema.prettyJson} + | + |Parquet file schema: + |${parquetFileSchema.toString} + |""".stripMargin) + } + val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema, - catalystRequestedSchema, caseSensitive) + catalystRequestedSchema, useFieldId, caseSensitive) // We pass two schema to ParquetRecordMaterializer: // - parquetRequestedSchema: the schema of the file data we want to read @@ -109,6 +166,7 @@ class ParquetReadSupport( // in parquetRequestedSchema which are not present in the file. parquetClippedSchema } + logDebug( s"""Going to read the following fields from the Parquet file with the following schema: |Parquet file schema: @@ -120,34 +178,9 @@ class ParquetReadSupport( |Catalyst requested schema: |${catalystRequestedSchema.treeString} """.stripMargin) - new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) - } - /** - * Called on executor side after [[init()]], before instantiating actual Parquet record readers. - * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet - * records to Catalyst [[InternalRow]]s. 
- */ - override def prepareForRead( - conf: Configuration, - keyValueMetaData: JMap[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[InternalRow] = { - val parquetRequestedSchema = readContext.getRequestedSchema - new ParquetRecordMaterializer( - parquetRequestedSchema, - ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetToSparkSchemaConverter(conf), - convertTz, - datetimeRebaseSpec, - int96RebaseSpec) + parquetRequestedSchema } -} - -object ParquetReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" /** * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist @@ -156,9 +189,10 @@ object ParquetReadSupport { def clipParquetSchema( parquetSchema: MessageType, catalystSchema: StructType, + useFieldId: Boolean, caseSensitive: Boolean = true): MessageType = { val clippedParquetFields = clipParquetGroupFields( - parquetSchema.asGroupType(), catalystSchema, caseSensitive) + parquetSchema.asGroupType(), catalystSchema, useFieldId, caseSensitive) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -170,20 +204,24 @@ object ParquetReadSupport { } private def clipParquetType( - parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { + parquetType: Type, + catalystType: DataType, + useFieldId: Boolean, + caseSensitive: Boolean): Type = { catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) + clipParquetListType(parquetType.asGroupType(), t.elementType, useFieldId, caseSensitive) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) + clipParquetMapType( + parquetType.asGroupType(), t.keyType, t.valueType, useFieldId, caseSensitive) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) + clipParquetGroup(parquetType.asGroupType(), t, useFieldId, caseSensitive) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able @@ -210,7 +248,10 @@ object ParquetReadSupport { * [[StructType]]. */ private def clipParquetListType( - parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { + parquetList: GroupType, + elementType: DataType, + useFieldId: Boolean, + caseSensitive: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. assert(!isPrimitiveCatalystType(elementType)) @@ -218,7 +259,7 @@ object ParquetReadSupport { // list element type is just the group itself. Clip it. if (parquetList.getLogicalTypeAnnotation == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType, caseSensitive) + clipParquetType(parquetList, elementType, useFieldId, caseSensitive) } else { assert( parquetList.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation], @@ -242,7 +283,7 @@ object ParquetReadSupport { // "_tuple" appended then the repeated type is the element type and elements are required. // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the // only field. 
- if ( + val newParquetList = if ( repeatedGroup.getFieldCount > 1 || repeatedGroup.getName == "array" || repeatedGroup.getName == parquetList.getName + "_tuple" @@ -250,21 +291,35 @@ object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) + .addField(clipParquetType(repeatedGroup, elementType, useFieldId, caseSensitive)) .named(parquetList.getName) } else { + val newRepeatedGroup = Types + .repeatedGroup() + .addField( + clipParquetType( + repeatedGroup.getType(0), elementType, useFieldId, caseSensitive)) + .named(repeatedGroup.getName) + + val newElementType = if (useFieldId && repeatedGroup.getId() != null) { + newRepeatedGroup.withId(repeatedGroup.getId().intValue()) + } else { + newRepeatedGroup + } + // Otherwise, the repeated field's type is the element type with the repeated field's // repetition. Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField( - Types - .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) - .named(repeatedGroup.getName)) + .addField(newElementType) .named(parquetList.getName) } + if (useFieldId && parquetList.getId() != null) { + newParquetList.withId(parquetList.getId().intValue()) + } else { + newParquetList + } } } @@ -277,6 +332,7 @@ object ParquetReadSupport { parquetMap: GroupType, keyType: DataType, valueType: DataType, + useFieldId: Boolean, caseSensitive: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -285,19 +341,31 @@ object ParquetReadSupport { val parquetKeyType = repeatedGroup.getType(0) val parquetValueType = repeatedGroup.getType(1) - val clippedRepeatedGroup = - Types + val clippedRepeatedGroup = { + val newRepeatedGroup = Types .repeatedGroup() .as(repeatedGroup.getLogicalTypeAnnotation) - .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) - .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) + .addField(clipParquetType(parquetKeyType, keyType, useFieldId, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, useFieldId, caseSensitive)) .named(repeatedGroup.getName) + if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) + } else { + newRepeatedGroup + } + } - Types + val newMap = Types .buildGroup(parquetMap.getRepetition) .as(parquetMap.getLogicalTypeAnnotation) .addField(clippedRepeatedGroup) .named(parquetMap.getName) + + if (useFieldId && parquetMap.getId() != null) { + newMap.withId(parquetMap.getId().intValue()) + } else { + newMap + } } /** @@ -309,13 +377,22 @@ object ParquetReadSupport { * pruning. 
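
[Reviewer note] The pattern above now repeats across the list, map, and struct clippers:
rebuild the type with parquet's Types builder, then re-attach the original field ID, since the
builder produces a fresh Type that does not carry the ID over. A condensed sketch of the idiom
(names illustrative):

    val rebuilt = Types
      .buildGroup(original.getRepetition)
      .as(original.getLogicalTypeAnnotation)
      .addFields(clippedFields: _*)
      .named(original.getName)
    // Type.withId returns a copy rather than mutating, hence the re-application on every rebuild.
    val result =
      if (useFieldId && original.getId != null) rebuilt.withId(original.getId.intValue())
      else rebuilt
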
*/ private def clipParquetGroup( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) - Types + parquetRecord: GroupType, + structType: StructType, + useFieldId: Boolean, + caseSensitive: Boolean): GroupType = { + val clippedParquetFields = + clipParquetGroupFields(parquetRecord, structType, useFieldId, caseSensitive) + val newRecord = Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getLogicalTypeAnnotation) .addFields(clippedParquetFields: _*) .named(parquetRecord.getName) + if (useFieldId && parquetRecord.getId() != null) { + newRecord.withId(parquetRecord.getId().intValue()) + } else { + newRecord + } } /** @@ -324,23 +401,29 @@ object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. */ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { - val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - if (caseSensitive) { - val caseSensitiveParquetFieldMap = + parquetRecord: GroupType, + structType: StructType, + useFieldId: Boolean, + caseSensitive: Boolean): Seq[Type] = { + val toParquet = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, useFieldId = useFieldId) + lazy val caseSensitiveParquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - structType.map { f => - caseSensitiveParquetFieldMap + lazy val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + lazy val idToParquetFieldMap = + parquetRecord.getFields.asScala.filter(_.getId() != null).groupBy(f => f.getId.intValue()) + + def matchCaseSensitiveField(f: StructField): Type = { + caseSensitiveParquetFieldMap .get(f.name) - .map(clipParquetType(_, f.dataType, caseSensitive)) + .map(clipParquetType(_, f.dataType, useFieldId, caseSensitive)) .getOrElse(toParquet.convertField(f)) - } - } else { + } + + def matchCaseInsensitiveField(f: StructField): Type = { // Do case-insensitive resolution only if in case-insensitive mode - val caseInsensitiveParquetFieldMap = - parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) - structType.map { f => - caseInsensitiveParquetFieldMap + caseInsensitiveParquetFieldMap .get(f.name.toLowerCase(Locale.ROOT)) .map { parquetTypes => if (parquetTypes.size > 1) { @@ -349,11 +432,50 @@ object ParquetReadSupport { throw QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError( f.name, parquetTypesString) } else { - clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + clipParquetType(parquetTypes.head, f.dataType, useFieldId, caseSensitive) } }.getOrElse(toParquet.convertField(f)) + } + + def matchIdField(f: StructField): Type = { + val fieldId = ParquetUtils.getFieldId(f) + idToParquetFieldMap + .get(fieldId) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. 
more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError( + fieldId, parquetTypesString) + } else { + clipParquetType(parquetTypes.head, f.dataType, useFieldId, caseSensitive) + } + }.getOrElse { + // When there is no ID match, we use a fake name to avoid a name match by accident + // We need this name to be unique as well, otherwise there will be type conflicts + toParquet.convertField(f.copy(name = generateFakeColumnName)) } } + + if (useFieldId && ParquetUtils.hasFieldIds(structType)) { + structType.map { f => + if (ParquetUtils.hasFieldId(f)) { + // try to match id if there's any + matchIdField(f) + } else { + // fall back to name matching + if (caseSensitive) { + matchCaseSensitiveField(f) + } else { + matchCaseInsensitiveField(f) + } + } + } + } else if (caseSensitive) { + structType.map(matchCaseSensitiveField) + } else { + structType.map(matchCaseInsensitiveField) + } } /** @@ -410,4 +532,13 @@ object ParquetReadSupport { expand(schema).asInstanceOf[StructType] } + + /** + * Whether the parquet schema contains any field IDs. + */ + def containsFieldIds(schema: Type): Boolean = schema match { + case p: PrimitiveType => p.getId() != null + // We don't require all fields to have IDs, so we use `exists` here. + case g: GroupType => g.getId != null || g.getFields.asScala.exists(containsFieldIds) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index cb5d646f85e9e..49e5d4d1b26e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -438,16 +438,19 @@ class ParquetToSparkSchemaConverter( class SparkToParquetSchemaConverter( writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = - SQLConf.ParquetOutputTimestampType.INT96) { + SQLConf.ParquetOutputTimestampType.INT96, + useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_ENABLED.defaultValue.get) { def this(conf: SQLConf) = this( writeLegacyParquetFormat = conf.writeLegacyParquetFormat, - outputTimestampType = conf.parquetOutputTimestampType) + outputTimestampType = conf.parquetOutputTimestampType, + useFieldId = conf.parquetFieldIdEnabled) def this(conf: Configuration) = this( writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( - conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)), + useFieldId = SQLConf.get.parquetFieldIdEnabled) /** * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. @@ -463,7 +466,12 @@ class SparkToParquetSchemaConverter( * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. 
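
[Reviewer note] On the fake-name fallback in matchIdField above, a small worked example of why
a random, unique name is needed (the file layout and requested field are invented):

    // file:      message root { optional int32 a = 1; }
    // requested: field "a" with field ID 2 -- no ID match in the file
    //
    // Converting the unmatched field under its real name would emit `optional int32 a`, which
    // could then be resolved against the file's `a = 1` by name and silently read the wrong
    // column; `optional int32 _fake_name_<uuid>` can never match, so the field reads as null.
    // The UUID also keeps multiple fallback fields from colliding with one another.
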
   */
   def convertField(field: StructField): Type = {
-    convertField(field, if (field.nullable) OPTIONAL else REQUIRED)
+    val converted = convertField(field, if (field.nullable) OPTIONAL else REQUIRED)
+    if (useFieldId && ParquetUtils.hasFieldId(field)) {
+      converted.withId(ParquetUtils.getFieldId(field))
+    } else {
+      converted
+    }
   }
 
   private def convertField(field: StructField, repetition: Type.Repetition): Type = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 87a0d9c860f31..1130a45137570 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
 import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
 
 object ParquetUtils {
   def inferSchema(
@@ -144,6 +144,42 @@ object ParquetUtils {
     file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
   }
 
+  /**
+   * A StructField metadata key used to set the field id of a column in the Parquet schema.
+   */
+  val FIELD_ID_METADATA_KEY = "parquet.field.id"
+
+  /**
+   * Whether there exists a field in the schema, whether inner or leaf, that has the parquet
+   * field ID metadata.
+   */
+  def hasFieldIds(schema: StructType): Boolean = {
+    def recursiveCheck(schema: DataType): Boolean = {
+      schema match {
+        case st: StructType =>
+          st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType))
+
+        case at: ArrayType => recursiveCheck(at.elementType)
+
+        case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType)
+
+        case _ =>
+          // No need to really check primitive types, just to terminate the recursion
+          false
+      }
+    }
+    if (schema.isEmpty) false else recursiveCheck(schema)
+  }
+
+  def hasFieldId(field: StructField): Boolean =
+    field.metadata.contains(FIELD_ID_METADATA_KEY)
+
+  def getFieldId(field: StructField): Int = {
+    require(hasFieldId(field),
+      "The key `parquet.field.id` doesn't exist in the metadata of " + field)
+    field.metadata.getLong(FIELD_ID_METADATA_KEY).toInt
+  }
+
   /**
    * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to
    * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
new file mode 100644
index 0000000000000..ef3b3f8ec143a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StringType, StructType} + +class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkSession { + + private def withId(id: Int): Metadata = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + /** + * Field id is supported in OSS vectorized reader at the moment. + * parquet-mr support is coming soon. + */ + private def withAllSupportedReaders(code: => Unit): Unit = { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) + } + + test("general test") { + withTempDir { dir => + val readSchema = + new StructType().add( + "a", StringType, true, withId(0)) + .add("b", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("random", IntegerType, true, withId(1)) + .add("name", StringType, true, withId(0)) + + val readData = Seq(Row("text", 100), Row("more", 200)) + val writeData = Seq(Row(100, "text"), Row(200, "more")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllSupportedReaders { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("b < 50"), Seq.empty) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("a >= 'oh'"), Row("text", 100) :: Nil) + } + + // blocked for Parquet-mr reader + val e = intercept[SparkException] { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) + } + } + val cause = e.getCause + assert(cause.isInstanceOf[java.io.IOException] && + cause.getMessage.contains("Parquet-mr reader does not support schema with field IDs.")) + } + } + + test("absence of field ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("b", StringType, true, withId(2)) + .add("c", IntegerType, true, withId(3)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(3)) + .add("randomName", StringType, true) + + val writeData = Seq(Row(100, "text"), Row(200, "more")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllSupportedReaders { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), + // 3 different cases for the 3 columns to read: + // - a: ID 1 is not found, but there is column with name `a`, still return null + // - b: ID 2 is not found, return null + 
// - c: ID 3 is found, read it + Row(null, null, 100) :: Row(null, null, 200) :: Nil) + } + } + } + + test("multiple id matches") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("rand1", StringType, true, withId(2)) + .add("rand2", StringType, true, withId(1)) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllSupportedReaders { + val cause = intercept[SparkException] { + spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Found duplicate field(s)")) + } + } + } + + test("read parquet file without ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true) + .add("rand1", StringType, true) + .add("rand2", StringType, true) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllSupportedReaders { + Seq(readSchema, readSchema.add("b", StringType, true)).foreach { schema => + val cause = intercept[SparkException] { + spark.read.schema(schema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Parquet file schema doesn't contain field Ids")) + val expectedValues = (1 to schema.length).map(_ => null) + withSQLConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key -> "true") { + checkAnswer( + spark.read.schema(schema).parquet(dir.getCanonicalPath), + Row(expectedValues: _*) :: Row(expectedValues: _*) :: Nil) + } + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala new file mode 100644 index 0000000000000..2f964d1a483f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -0,0 +1,501 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.parquet.schema.{MessageType, MessageTypeParser} + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { + + private val FAKE_COLUMN_NAME = "_fake_name_" + private val UUID_REGEX = + "[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}".r + + private def withId(id: Int) = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String, + caseSensitive: Boolean = true, + useFieldId: Boolean = true): Unit = { + test(s"Clipping with field id - $testName") { + val fileSchema = MessageTypeParser.parseMessageType(parquetSchema) + val actual = ParquetReadSupport.clipParquetSchema( + fileSchema, + catalystSchema, + useFieldId = useFieldId, + caseSensitive = caseSensitive) + + // each fake name should be uniquely generated + val fakeColumnNames = actual.getPaths.asScala.flatten.filter(_.startsWith(FAKE_COLUMN_NAME)) + assert( + fakeColumnNames.distinct == fakeColumnNames, "Should generate unique fake column names") + + // replace the random part of all fake names with a fixed id generator + val ids1 = (1 to 100).iterator + val actualNormalized = MessageTypeParser.parseMessageType( + UUID_REGEX.replaceAllIn(actual.toString, _ => ids1.next().toString) + ) + val ids2 = (1 to 100).iterator + val expectedNormalized = MessageTypeParser.parseMessageType( + FAKE_COLUMN_NAME.r.replaceAllIn(expectedSchema, _ => s"$FAKE_COLUMN_NAME${ids2.next()}") + ) + + try { + expectedNormalized.checkContains(actualNormalized) + actualNormalized.checkContains(expectedNormalized) + } catch { case cause: Throwable => + fail( + s"""Expected clipped schema: + |$expectedSchema + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + checkEqual(actualNormalized, expectedNormalized) + // might be redundant but just to have some free tests for the utils + assert(ParquetReadSupport.containsFieldIds(fileSchema)) + assert(ParquetUtils.hasFieldIds(catalystSchema)) + } + } + + private def testSqlToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String): Unit = { + val converter = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, + outputTimestampType = SQLConf.ParquetOutputTimestampType.INT96, + useFieldId = true) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) + checkEqual(actual, expected) + } + } + + private def checkEqual(actual: MessageType, expected: MessageType): Unit = { + actual.checkContains(expected) + expected.checkContains(actual) + assert(actual.toString == expected.toString, + s""" + |Schema mismatch. 
+ |Expected schema: + |${expected.toString} + |Actual schema: + |${actual.toString} + """.stripMargin + ) + } + + test("check hasFieldIds for schema") { + val simpleSchemaMissingId = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true) + + assert(ParquetUtils.hasFieldIds(simpleSchemaMissingId)) + + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(8)) + + assert(ParquetUtils.hasFieldIds(f01ElementType)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + assert(ParquetUtils.hasFieldIds(f0Type)) + + assert(ParquetUtils.hasFieldIds( + new StructType().add("f0", f0Type, nullable = false, withId(1)))) + + assert(!ParquetUtils.hasFieldIds(new StructType().add("f0", IntegerType, nullable = true))) + assert(!ParquetUtils.hasFieldIds(new StructType())); + } + + test("check containsFieldIds for parquet schema") { + + // empty Parquet schema fails too + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 = 1 { + | optional int32 f00; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00 = 1; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list = 1 { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + } + + test("ID in Parquet Types is read as null when not set") { + val parquetSchemaString = + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin + + val parquetSchema = MessageTypeParser.parseMessageType(parquetSchemaString) + val f0 = parquetSchema.getFields().get(0) + assert(f0.getId() == null) + assert(f0.asGroupType().getFields.get(0).getId == null) + } + + testSqlToParquet( + "standard array", + sqlSchema = { + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("f0", f0Type, nullable = false, withId(1)) + }, + parquetSchema = + """message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + 
| } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f010 = 7; + | optional int64 f012 = 9; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 f01 = 3; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add( + "g00", IntegerType, nullable = true, withId(2)) + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(4)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 $FAKE_COLUMN_NAME = 4; + |} + """.stripMargin) + + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional int32 f010 = 7; + | optional double f011 = 8; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("g011", DoubleType, nullable = true, withId(8)) + .add("g012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("g00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("g01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("g0", f0Type, nullable = false, withId(1)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f011 = 8; + | optional int64 $FAKE_COLUMN_NAME = 9; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int32 value_f0 = 4; + | required int64 value_f1 = 6; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_g1", LongType, nullable = false, withId(6)) + .add("value_g2", DoubleType, nullable = false, withId(7)) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("g0", f0Type, nullable = false, withId(3)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int64 value_f1 = 6; + | required double $FAKE_COLUMN_NAME = 7; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "won't match field id if structure is different", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + // parquet has id 3, but won't use because structure is different + .add("g01", IntegerType, nullable = true, withId(3)) + new StructType() + 
.add("g0", f0Type, nullable = false, withId(1))
+    },
+
+    // note that f1 is not picked up, even though its ID is 3
+    expectedSchema =
+      s"""message spark_schema {
+        |  required group f0 = 1 {
+        |    optional int32 f00 = 2;
+        |    optional int32 $FAKE_COLUMN_NAME = 3;
+        |  }
+        |}
+      """.stripMargin)
+
+  testSchemaClipping(
+    "Complex type with multiple mismatches should work",
+
+    parquetSchema =
+      """message root {
+        |  required group f0 = 1 {
+        |    optional int32 f00 = 2;
+        |  }
+        |  optional int32 f1 = 3;
+        |  optional int32 f2 = 4;
+        |}
+      """.stripMargin,
+
+    catalystSchema = {
+      val f0Type = new StructType()
+        .add("g00", IntegerType, nullable = true, withId(2))
+
+      new StructType()
+        .add("g0", f0Type, nullable = false, withId(999))
+        .add("g1", IntegerType, nullable = true, withId(3))
+        .add("g2", IntegerType, nullable = true, withId(888))
+    },
+
+    expectedSchema =
+      s"""message spark_schema {
+        |  required group $FAKE_COLUMN_NAME = 999 {
+        |    optional int32 g00 = 2;
+        |  }
+        |  optional int32 f1 = 3;
+        |  optional int32 $FAKE_COLUMN_NAME = 888;
+        |}
+      """.stripMargin)
+
+  testSchemaClipping(
+    "Should allow fall-back to name matching if id not found",
+
+    parquetSchema =
+      """message root {
+        |  required group f0 = 1 {
+        |    optional int32 f00 = 2;
+        |  }
+        |  optional int32 f1 = 3;
+        |  optional int32 f2 = 4;
+        |  required group f4 = 5 {
+        |    optional int32 f40 = 6;
+        |  }
+        |}
+      """.stripMargin,

+    catalystSchema = {
+      val f0Type = new StructType()
+        // nested f00 without id should also work
+        .add("f00", IntegerType, nullable = true)
+
+      val f4Type = new StructType()
+        .add("g40", IntegerType, nullable = true, withId(6))
+
+      new StructType()
+        .add("g0", f0Type, nullable = false, withId(1))
+        .add("g1", IntegerType, nullable = true, withId(3))
+        // f2 without id should be matched using name matching
+        .add("f2", IntegerType, nullable = true)
+        // name is not matched
+        .add("g2", IntegerType, nullable = true)
+        // f4 without id will do name matching, but g40 will be matched using id
+        .add("f4", f4Type, nullable = true)
+    },
+
+    expectedSchema =
+      s"""message spark_schema {
+        |  required group f0 = 1 {
+        |    optional int32 f00 = 2;
+        |  }
+        |  optional int32 f1 = 3;
+        |  optional int32 f2 = 4;
+        |  optional int32 g2;
+        |  required group f4 = 5 {
+        |    optional int32 f40 = 6;
+        |  }
+        |}
+      """.stripMargin)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 272f12e138b68..6d1d160bd264e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -2257,7 +2257,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
       caseSensitive: Boolean): Unit = {
     test(s"Clipping - $testName") {
       val actual = ParquetReadSupport.clipParquetSchema(
-        MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive)
+        MessageTypeParser.parseMessageType(parquetSchema),
+        catalystSchema,
+        useFieldId = false,
+        caseSensitive)
 
       try {
         expectedSchema.checkContains(actual)
@@ -2821,7 +2824,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
     }
     assertThrows[RuntimeException] {
       ParquetReadSupport.clipParquetSchema(
-        MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false)
+        MessageTypeParser.parseMessageType(parquetSchema),
+        catalystSchema,
+        useFieldId = false,
+ caseSensitive = false) } } } From 6d7e769a720bcafb632195c5c004b41461d8bfd3 Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Thu, 3 Feb 2022 17:21:27 +1100 Subject: [PATCH 03/13] retrigger test From 83a3184c5bf79b214c24e13abd43d4f6a1dba7e7 Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Fri, 4 Feb 2022 14:25:14 +1100 Subject: [PATCH 04/13] Address some comments --- .../datasources/parquet/ParquetReadSupport.scala | 8 ++++---- .../execution/datasources/parquet/ParquetUtils.scala | 2 +- .../datasources/parquet/ParquetFieldIdIOSuite.scala | 11 ++++++++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 9db15b4076c36..6b566d8121983 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -451,10 +451,10 @@ object ParquetReadSupport extends Logging { clipParquetType(parquetTypes.head, f.dataType, useFieldId, caseSensitive) } }.getOrElse { - // When there is no ID match, we use a fake name to avoid a name match by accident - // We need this name to be unique as well, otherwise there will be type conflicts - toParquet.convertField(f.copy(name = generateFakeColumnName)) - } + // When there is no ID match, we use a fake name to avoid a name match by accident + // We need this name to be unique as well, otherwise there will be type conflicts + toParquet.convertField(f.copy(name = generateFakeColumnName)) + } } if (useFieldId && ParquetUtils.hasFieldIds(structType)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 1130a45137570..8033c2cfea579 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -176,7 +176,7 @@ object ParquetUtils { def getFieldId(field: StructField): Int = { require(hasFieldId(field), - "The key `parquet.field.id` doesn't exist in the metadata of " + field) + s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field) field.metadata.getLong(FIELD_ID_METADATA_KEY).toInt } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala index ef3b3f8ec143a..d2ec3fda1e68d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala @@ -31,14 +31,14 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() /** - * Field id is supported in OSS vectorized reader at the moment. + * Field id is supported in vectorized reader at the moment. * parquet-mr support is coming soon. 
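
[Reviewer note] A quick sketch of the ParquetUtils helpers these suites exercise (the field
name and ID are illustrative):

    import org.apache.spark.sql.types._

    val meta = new MetadataBuilder()
      .putLong(ParquetUtils.FIELD_ID_METADATA_KEY, 7L) // key is "parquet.field.id"
      .build()
    val field = StructField("price", DoubleType, nullable = true, meta)

    ParquetUtils.hasFieldId(field)                        // true
    ParquetUtils.getFieldId(field)                        // 7 (stored as Long, narrowed to Int)
    ParquetUtils.hasFieldIds(new StructType().add(field)) // true -- recursive over nested types
    // getFieldId on a field without the key fails its require() with the message fixed above.
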
*/ private def withAllSupportedReaders(code: => Unit): Unit = { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) } - test("general test") { + test("Parquet reads infer fields using field ids correctly") { withTempDir { dir => val readSchema = new StructType().add( @@ -56,11 +56,16 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS .write.mode("overwrite").parquet(dir.getCanonicalPath) withAllSupportedReaders { + // read with schema checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) .where("b < 50"), Seq.empty) checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) .where("a >= 'oh'"), Row("text", 100) :: Nil) + // schema inference should pull into the schema with ids + val reader = spark.read.parquet(dir.getCanonicalPath) + assert(reader.schema == writeSchema) + checkAnswer(reader.where("name >= 'oh'"), Row(100, "text") :: Nil) } // blocked for Parquet-mr reader From bf7ddba2ae6e98f03213d5f7c65bbe8da47acaa1 Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Sat, 5 Feb 2022 11:48:33 +1100 Subject: [PATCH 05/13] typos & conf --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 14d6a3e2e5f68..7a2c1a4c6f9d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -945,8 +945,8 @@ object SQLConf { val IGNORE_MISSING_PARQUET_FIELD_ID = buildConf("spark.sql.parquet.fieldId.ignoreMissing") - .doc("When the Parquet file does't have any field IDs but the" + - " Spark read schema is using field IDs to read, we will return silently return nulls" + + .doc("When the Parquet file doesn't have any field IDs but the" + + " Spark read schema is using field IDs to read, we will silently return nulls" + "when this flag is enabled, or error otherwise.") .booleanConf .createWithDefault(false) @@ -4270,6 +4270,8 @@ class SQLConf extends Serializable with Logging { def parquetFieldIdEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_ENABLED) + def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID) + def useV1Command: Boolean = getConf(SQLConf.LEGACY_USE_V1_COMMAND) /** ********************** SQLConf functionality methods ************ */ From eb21dc53934dc0b9241abe19be48c83a3e7a30f3 Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Tue, 8 Feb 2022 16:18:40 +1100 Subject: [PATCH 06/13] Address comments --- .../apache/spark/sql/internal/SQLConf.scala | 19 +++++-- .../parquet/ParquetFileFormat.scala | 3 +- .../parquet/ParquetReadSupport.scala | 56 +++++++++---------- .../parquet/ParquetSchemaConverter.scala | 6 +- .../parquet/ParquetFieldIdIOSuite.scala | 20 ++++++- .../parquet/ParquetFieldIdSchemaSuite.scala | 4 +- .../parquet/ParquetSchemaSuite.scala | 8 +-- .../spark/sql/test/TestSQLContext.scala | 5 +- 8 files changed, 74 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7a2c1a4c6f9d5..a79024063fc89 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -934,12 +934,19 @@ object SQLConf { .intConf .createWithDefault(4096) - val PARQUET_FIELD_ID_ENABLED = - buildConf("spark.sql.parquet.fieldId.enabled") + val PARQUET_FIELD_ID_WRITE_ENABLED = + buildConf("spark.sql.parquet.fieldId.write.enabled") + .doc("Field ID is a native field of the Parquet schema spec. When enabled," + + " Parquet writers will populate the field Id" + + " metadata (if present) in the Spark schema to the Parquet schema.") + .booleanConf + .createWithDefault(true) + + val PARQUET_FIELD_ID_READ_ENABLED = + buildConf("spark.sql.parquet.fieldId.read.enabled") .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers" + " will use field IDs (if present) in the requested Spark schema to look up Parquet" + - " fields instead of using column names; Parquet writers will also populate the field Id" + - " metadata (if present) in the Spark schema to the Parquet schema.") + " fields instead of using column names") .booleanConf .createWithDefault(true) @@ -4268,7 +4275,9 @@ class SQLConf extends Serializable with Logging { def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT) - def parquetFieldIdEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_ENABLED) + def parquetFieldIdReadEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED) + + def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED) def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 393a1e21999b7..a1bca4e0335c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -356,8 +356,7 @@ class ParquetFileFormat } else { logDebug(s"Falling back to parquet-mr") - if (SQLConf.get.parquetFieldIdEnabled && - ParquetUtils.hasFieldIds(requiredSchema)) { + if (SQLConf.get.parquetFieldIdReadEnabled && ParquetUtils.hasFieldIds(requiredSchema)) { throw new IOException("Parquet-mr reader does not support schema with field IDs." + s" Please choose a different Parquet reader. 
Read schema: ${requiredSchema.json}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 6b566d8121983..2e1e4fccfc65c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -127,8 +127,8 @@ object ParquetReadSupport extends Logging { SQLConf.CASE_SENSITIVE.defaultValue.get) val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get) - val useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_ENABLED.key, - SQLConf.PARQUET_FIELD_ID_ENABLED.defaultValue.get) + val useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key, + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get) val ignoreMissingIds = conf.getBoolean(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key, SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.defaultValue.get) @@ -148,7 +148,7 @@ object ParquetReadSupport extends Logging { } val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema, - catalystRequestedSchema, useFieldId, caseSensitive) + catalystRequestedSchema, caseSensitive, useFieldId) // We pass two schema to ParquetRecordMaterializer: // - parquetRequestedSchema: the schema of the file data we want to read @@ -189,10 +189,10 @@ object ParquetReadSupport extends Logging { def clipParquetSchema( parquetSchema: MessageType, catalystSchema: StructType, - useFieldId: Boolean, - caseSensitive: Boolean = true): MessageType = { + caseSensitive: Boolean, + useFieldId: Boolean): MessageType = { val clippedParquetFields = clipParquetGroupFields( - parquetSchema.asGroupType(), catalystSchema, useFieldId, caseSensitive) + parquetSchema.asGroupType(), catalystSchema, caseSensitive, useFieldId) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -206,22 +206,22 @@ object ParquetReadSupport extends Logging { private def clipParquetType( parquetType: Type, catalystType: DataType, - useFieldId: Boolean, - caseSensitive: Boolean): Type = { + caseSensitive: Boolean, + useFieldId: Boolean): Type = { catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType, useFieldId, caseSensitive) + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive, useFieldId) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type clipParquetMapType( - parquetType.asGroupType(), t.keyType, t.valueType, useFieldId, caseSensitive) + parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t, useFieldId, caseSensitive) + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, useFieldId) case _ => // UDTs and primitive types are not clipped. 
For UDTs, a clipped version might not be able @@ -250,8 +250,8 @@ object ParquetReadSupport extends Logging { private def clipParquetListType( parquetList: GroupType, elementType: DataType, - useFieldId: Boolean, - caseSensitive: Boolean): Type = { + caseSensitive: Boolean, + useFieldId: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. assert(!isPrimitiveCatalystType(elementType)) @@ -259,7 +259,7 @@ object ParquetReadSupport extends Logging { // list element type is just the group itself. Clip it. if (parquetList.getLogicalTypeAnnotation == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType, useFieldId, caseSensitive) + clipParquetType(parquetList, elementType, caseSensitive, useFieldId) } else { assert( parquetList.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation], @@ -291,14 +291,14 @@ object ParquetReadSupport extends Logging { Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField(clipParquetType(repeatedGroup, elementType, useFieldId, caseSensitive)) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive, useFieldId)) .named(parquetList.getName) } else { val newRepeatedGroup = Types .repeatedGroup() .addField( clipParquetType( - repeatedGroup.getType(0), elementType, useFieldId, caseSensitive)) + repeatedGroup.getType(0), elementType, caseSensitive, useFieldId)) .named(repeatedGroup.getName) val newElementType = if (useFieldId && repeatedGroup.getId() != null) { @@ -332,8 +332,8 @@ object ParquetReadSupport extends Logging { parquetMap: GroupType, keyType: DataType, valueType: DataType, - useFieldId: Boolean, - caseSensitive: Boolean): GroupType = { + caseSensitive: Boolean, + useFieldId: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. 
assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -345,8 +345,8 @@ object ParquetReadSupport extends Logging { val newRepeatedGroup = Types .repeatedGroup() .as(repeatedGroup.getLogicalTypeAnnotation) - .addField(clipParquetType(parquetKeyType, keyType, useFieldId, caseSensitive)) - .addField(clipParquetType(parquetValueType, valueType, useFieldId, caseSensitive)) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive, useFieldId)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive, useFieldId)) .named(repeatedGroup.getName) if (useFieldId && repeatedGroup.getId != null) { newRepeatedGroup.withId(repeatedGroup.getId.intValue()) @@ -379,10 +379,10 @@ object ParquetReadSupport extends Logging { private def clipParquetGroup( parquetRecord: GroupType, structType: StructType, - useFieldId: Boolean, - caseSensitive: Boolean): GroupType = { + caseSensitive: Boolean, + useFieldId: Boolean): GroupType = { val clippedParquetFields = - clipParquetGroupFields(parquetRecord, structType, useFieldId, caseSensitive) + clipParquetGroupFields(parquetRecord, structType, caseSensitive, useFieldId) val newRecord = Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getLogicalTypeAnnotation) @@ -403,8 +403,8 @@ object ParquetReadSupport extends Logging { private def clipParquetGroupFields( parquetRecord: GroupType, structType: StructType, - useFieldId: Boolean, - caseSensitive: Boolean): Seq[Type] = { + caseSensitive: Boolean, + useFieldId: Boolean): Seq[Type] = { val toParquet = new SparkToParquetSchemaConverter( writeLegacyParquetFormat = false, useFieldId = useFieldId) lazy val caseSensitiveParquetFieldMap = @@ -417,7 +417,7 @@ object ParquetReadSupport extends Logging { def matchCaseSensitiveField(f: StructField): Type = { caseSensitiveParquetFieldMap .get(f.name) - .map(clipParquetType(_, f.dataType, useFieldId, caseSensitive)) + .map(clipParquetType(_, f.dataType, caseSensitive, useFieldId)) .getOrElse(toParquet.convertField(f)) } @@ -432,7 +432,7 @@ object ParquetReadSupport extends Logging { throw QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError( f.name, parquetTypesString) } else { - clipParquetType(parquetTypes.head, f.dataType, useFieldId, caseSensitive) + clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId) } }.getOrElse(toParquet.convertField(f)) } @@ -448,7 +448,7 @@ object ParquetReadSupport extends Logging { throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError( fieldId, parquetTypesString) } else { - clipParquetType(parquetTypes.head, f.dataType, useFieldId, caseSensitive) + clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId) } }.getOrElse { // When there is no ID match, we use a fake name to avoid a name match by accident diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 49e5d4d1b26e2..4010585e04f7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -439,18 +439,18 @@ class SparkToParquetSchemaConverter( writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = 
SQLConf.ParquetOutputTimestampType.INT96, - useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_ENABLED.defaultValue.get) { + useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.defaultValue.get) { def this(conf: SQLConf) = this( writeLegacyParquetFormat = conf.writeLegacyParquetFormat, outputTimestampType = conf.parquetOutputTimestampType, - useFieldId = conf.parquetFieldIdEnabled) + useFieldId = conf.parquetFieldIdWriteEnabled) def this(conf: Configuration) = this( writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)), - useFieldId = SQLConf.get.parquetFieldIdEnabled) + useFieldId = SQLConf.get.parquetFieldIdWriteEnabled) /** * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala index d2ec3fda1e68d..73ef56928fe39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala @@ -41,8 +41,18 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS test("Parquet reads infer fields using field ids correctly") { withTempDir { dir => val readSchema = - new StructType().add( - "a", StringType, true, withId(0)) + new StructType() + .add("a", StringType, true, withId(0)) + .add("b", IntegerType, true, withId(1)) + + val readSchemaMixed = + new StructType() + .add("name", StringType, true) + .add("b", IntegerType, true, withId(1)) + + val readSchemaMixedHalfMatched = + new StructType() + .add("unmatched", StringType, true) .add("b", IntegerType, true, withId(1)) val writeSchema = @@ -51,6 +61,7 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS .add("name", StringType, true, withId(0)) val readData = Seq(Row("text", 100), Row("more", 200)) + val readDataHalfMatched = Seq(Row(null, 100), Row(null, 200)) val writeData = Seq(Row(100, "text"), Row(200, "more")) spark.createDataFrame(writeData.asJava, writeSchema) .write.mode("overwrite").parquet(dir.getCanonicalPath) @@ -62,6 +73,11 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS .where("b < 50"), Seq.empty) checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) .where("a >= 'oh'"), Row("text", 100) :: Nil) + // read with mixed field-id/name schema + checkAnswer(spark.read.schema(readSchemaMixed).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchemaMixedHalfMatched) + .parquet(dir.getCanonicalPath), readDataHalfMatched) + // schema inference should pull into the schema with ids val reader = spark.read.parquet(dir.getCanonicalPath) assert(reader.schema == writeSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala index 2f964d1a483f5..15e0e11364a9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -45,8 
+45,8 @@ class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { val actual = ParquetReadSupport.clipParquetSchema( fileSchema, catalystSchema, - useFieldId = useFieldId, - caseSensitive = caseSensitive) + caseSensitive = caseSensitive, + useFieldId = useFieldId) // each fake name should be uniquely generated val fakeColumnNames = actual.getPaths.asScala.flatten.filter(_.startsWith(FAKE_COLUMN_NAME)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 6d1d160bd264e..2feea41d15656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -2259,8 +2259,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val actual = ParquetReadSupport.clipParquetSchema( MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, - useFieldId = false, - caseSensitive) + caseSensitive, + useFieldId = false) try { expectedSchema.checkContains(actual) @@ -2826,8 +2826,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { ParquetReadSupport.clipParquetSchema( MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, - useFieldId = false, - caseSensitive = false) + caseSensitive = false, + useFieldId = false) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 47a6f3617da63..2f65cb7d16c2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -61,7 +61,10 @@ private[sql] object TestSQLContext { val overrideConfs: Map[String, String] = Map( // Fewer shuffle partitions to speed up testing. - SQLConf.SHUFFLE_PARTITIONS.key -> "5") + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + // Enable parquet read field id for tests to ensure correctness + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true" + ) } private[sql] class TestSQLSessionStateBuilder( From 0495e0b578201fd412762e2e1d7ba020cefcda77 Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Tue, 8 Feb 2022 16:30:54 +1100 Subject: [PATCH 07/13] Enable parquet-mr code path --- .../parquet/ParquetFileFormat.scala | 6 --- .../parquet/ParquetRowConverter.scala | 38 ++++++++++++++++--- .../parquet/ParquetFieldIdIOSuite.scala | 25 ++---------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index a1bca4e0335c2..31ebee8e15e18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.io.IOException import java.net.URI import scala.collection.JavaConverters._ @@ -356,11 +355,6 @@ class ParquetFileFormat } else { logDebug(s"Falling back to parquet-mr") - if (SQLConf.get.parquetFieldIdReadEnabled && ParquetUtils.hasFieldIds(requiredSchema)) { - throw new IOException("Parquet-mr reader does not support schema with field IDs." 
+
-            s" Please choose a different Parquet reader. Read schema: ${requiredSchema.json}")
-      }
-
       // ParquetRecordReader returns InternalRow
       val readSupport = new ParquetReadSupport(
         convertTz,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index b12898360dcf4..2012370ab729e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -203,16 +203,42 @@ private[parquet] class ParquetRowConverter(
   private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = {
     // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false
     // to prevent throwing IllegalArgumentException when searching catalyst type's field index
-    val catalystFieldNameToIndex = if (SQLConf.get.caseSensitiveAnalysis) {
-      catalystType.fieldNames.zipWithIndex.toMap
+    def nameToIndex =
+      catalystType.fields.zipWithIndex.map { case (f, idx) =>
+        (f.name, idx)
+      }.toMap
+
+    val catalystFieldIdxByName = if (SQLConf.get.caseSensitiveAnalysis) {
+      nameToIndex
     } else {
-      CaseInsensitiveMap(catalystType.fieldNames.zipWithIndex.toMap)
+      CaseInsensitiveMap(nameToIndex)
     }
+
+    // (SPARK-38094) parquet field ids, if they exist, should be prioritized for matching
+    val catalystFieldIdxByFieldId =
+      if (SQLConf.get.parquetFieldIdReadEnabled && ParquetUtils.hasFieldIds(catalystType)) {
+        catalystType.fields
+          .zipWithIndex
+          .filter { case (f, _) => ParquetUtils.hasFieldId(f) }
+          .map { case (f, idx) => (ParquetUtils.getFieldId(f), idx) }
+          .toMap
+      } else {
+        Map.empty[Int, Int]
+      }
+
     parquetType.getFields.asScala.map { parquetField =>
-      val fieldIndex = catalystFieldNameToIndex(parquetField.getName)
-      val catalystField = catalystType(fieldIndex)
+      val catalystFieldIndex = Option(parquetField.getId).map { fieldId =>
+        // field has id, try to match by id first before falling back to match by name
+        catalystFieldIdxByFieldId
+          .getOrElse(fieldId.intValue(), catalystFieldIdxByName(parquetField.getName))
+      }.getOrElse {
+        // field doesn't have id, just match by name
+        catalystFieldIdxByName(parquetField.getName)
+      }
+      val catalystField = catalystType(catalystFieldIndex)
-      // Converted field value should be set to the `fieldIndex`-th cell of `currentRow`
-      newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex))
+      // Converted field value should be set to the `catalystFieldIndex`-th cell of `currentRow`
+      newConverter(parquetField,
+        catalystField.dataType, new RowUpdater(currentRow, catalystFieldIndex))
     }.toArray
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
index 73ef56928fe39..d13242bb83ade 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
@@ -30,13 +30,6 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS
   private def withId(id: Int): Metadata =
     new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build()
 
-  /**
-   * Field id is supported in vectorized reader at the moment.
-   * parquet-mr support is coming soon.
- */ - private def withAllSupportedReaders(code: => Unit): Unit = { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) - } test("Parquet reads infer fields using field ids correctly") { withTempDir { dir => @@ -66,7 +59,7 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS spark.createDataFrame(writeData.asJava, writeSchema) .write.mode("overwrite").parquet(dir.getCanonicalPath) - withAllSupportedReaders { + withAllParquetReaders { // read with schema checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) @@ -83,16 +76,6 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS assert(reader.schema == writeSchema) checkAnswer(reader.where("name >= 'oh'"), Row(100, "text") :: Nil) } - - // blocked for Parquet-mr reader - val e = intercept[SparkException] { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) - } - } - val cause = e.getCause - assert(cause.isInstanceOf[java.io.IOException] && - cause.getMessage.contains("Parquet-mr reader does not support schema with field IDs.")) } } @@ -114,7 +97,7 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS spark.createDataFrame(writeData.asJava, writeSchema) .write.mode("overwrite").parquet(dir.getCanonicalPath) - withAllSupportedReaders { + withAllParquetReaders { checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), // 3 different cases for the 3 columns to read: // - a: ID 1 is not found, but there is column with name `a`, still return null @@ -142,7 +125,7 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS spark.createDataFrame(writeData.asJava, writeSchema) .write.mode("overwrite").parquet(dir.getCanonicalPath) - withAllSupportedReaders { + withAllParquetReaders { val cause = intercept[SparkException] { spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() }.getCause @@ -167,7 +150,7 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) spark.createDataFrame(writeData.asJava, writeSchema) .write.mode("overwrite").parquet(dir.getCanonicalPath) - withAllSupportedReaders { + withAllParquetReaders { Seq(readSchema, readSchema.add("b", StringType, true)).foreach { schema => val cause = intercept[SparkException] { spark.read.schema(schema).parquet(dir.getCanonicalPath).collect() From b7f76f7e07bd3bddb8a9e2231bc478b85e906142 Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Tue, 8 Feb 2022 17:32:16 +1100 Subject: [PATCH 08/13] Address comments --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../parquet/ParquetFieldIdIOSuite.scala | 44 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a79024063fc89..406a8f8073ae3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -948,7 +948,7 @@ object SQLConf { " will use field IDs (if present) in the requested Spark schema to look up Parquet" + " fields instead of using column names") .booleanConf - .createWithDefault(true) + 
.createWithDefault(false)
 
   val IGNORE_MISSING_PARQUET_FIELD_ID =
     buildConf("spark.sql.parquet.fieldId.ignoreMissing")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
index d13242bb83ade..f944878cd35c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala
@@ -167,4 +167,48 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS
       }
     }
   }
+
+  test("global read/write flag should work correctly") {
+    withTempDir { dir =>
+      val readSchema =
+        new StructType()
+          .add("some", IntegerType, true, withId(1))
+          .add("other", StringType, true, withId(2))
+          .add("name", StringType, true, withId(3))
+
+      val writeSchema =
+        new StructType()
+          .add("a", IntegerType, true, withId(1))
+          .add("rand1", StringType, true, withId(2))
+          .add("rand2", StringType, true, withId(3))
+
+      val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr"))
+
+      val expectedResult = Seq(Row(null, null, null), Row(null, null, null))
+
+      withSQLConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> "false",
+        SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") {
+        spark.createDataFrame(writeData.asJava, writeSchema)
+          .write.mode("overwrite").parquet(dir.getCanonicalPath)
+        withAllParquetReaders {
+          // the file was written without field ids, so an id-based read should fail
+          val cause = intercept[SparkException] {
+            spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect()
+          }.getCause
+          assert(cause.isInstanceOf[RuntimeException] &&
+            cause.getMessage.contains("Parquet file schema doesn't contain field Ids"))
+        }
+      }
+
+      withSQLConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> "true",
+        SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "false") {
+        spark.createDataFrame(writeData.asJava, writeSchema)
+          .write.mode("overwrite").parquet(dir.getCanonicalPath)
+        withAllParquetReaders {
+          // ids are there, but we don't use them for matching, so names mismatch and nulls are read
+          checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), expectedResult)
+        }
+      }
+    }
+  }
 }

From 411158205a59e4f89678c19e7db14f6a1e4b86f4 Mon Sep 17 00:00:00 2001
From: jackierwzhang
Date: Wed, 9 Feb 2022 12:03:10 +1100
Subject: [PATCH 09/13] Address comments

---
 .../main/scala/org/apache/spark/sql/internal/SQLConf.scala   | 5 ++++-
 .../execution/datasources/parquet/ParquetFileFormat.scala    | 4 ++++
 .../execution/datasources/parquet/ParquetReadSupport.scala   | 3 ++-
 .../execution/datasources/parquet/ParquetRowConverter.scala  | 5 +----
 .../datasources/parquet/ParquetSchemaConverter.scala         | 2 +-
 .../sql/execution/datasources/v2/parquet/ParquetWrite.scala  | 3 +++
 6 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 406a8f8073ae3..cdb7e2fdd20da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -939,6 +939,7 @@ object SQLConf {
       .doc("Field ID is a native field of the Parquet schema spec. 
When enabled," + " Parquet writers will populate the field Id" + " metadata (if present) in the Spark schema to the Parquet schema.") + .version("3.3.0") .booleanConf .createWithDefault(true) @@ -947,6 +948,7 @@ object SQLConf { .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers" + " will use field IDs (if present) in the requested Spark schema to look up Parquet" + " fields instead of using column names") + .version("3.3.0") .booleanConf .createWithDefault(false) @@ -954,7 +956,8 @@ object SQLConf { buildConf("spark.sql.parquet.fieldId.ignoreMissing") .doc("When the Parquet file doesn't have any field IDs but the" + " Spark read schema is using field IDs to read, we will silently return nulls" + - "when this flag is enabled, or error otherwise.") + " when this flag is enabled, or error otherwise.") + .version("3.3.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 31ebee8e15e18..cd81f9242e736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -119,6 +119,10 @@ class ParquetFileFormat SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sparkSession.sessionState.conf.parquetOutputTimestampType.toString) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sparkSession.sessionState.conf.parquetFieldIdWriteEnabled.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 2e1e4fccfc65c..46c8c61082996 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.time.ZoneId +import java.util import java.util.{Locale, Map => JMap, UUID} import scala.collection.JavaConverters._ @@ -87,7 +88,7 @@ class ParquetReadSupport( val parquetRequestedSchema = ParquetReadSupport.getRequestedSchema( context.getFileSchema, catalystRequestedSchema, conf, enableVectorizedReader) - new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + new ReadContext(parquetRequestedSchema, new util.HashMap[String, String]()) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 2012370ab729e..76318ed53bfdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -203,10 +203,7 @@ private[parquet] class ParquetRowConverter( private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false // to prevent throwing IllegalArgumentException when 
searching catalyst type's field index - def nameToIndex = - catalystType.fields.zipWithIndex.map { case (f, idx) => - (f.name, idx) - }.toMap + def nameToIndex: Map[String, Int] = catalystType.fieldNames.zipWithIndex.toMap val catalystFieldIdxByName = if (SQLConf.get.caseSensitiveAnalysis) { nameToIndex diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 4010585e04f7a..e50743939e714 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -450,7 +450,7 @@ class SparkToParquetSchemaConverter( writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)), - useFieldId = SQLConf.get.parquetFieldIdWriteEnabled) + useFieldId = conf.get(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key).toBoolean) /** * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala index 0316d91f40732..78c75a14a0154 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala @@ -81,6 +81,9 @@ case class ParquetWrite( conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sqlConf.parquetOutputTimestampType.toString) + conf + .set(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, sqlConf.parquetFieldIdWriteEnabled.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) From 723dff7f9d00b4bdc9e0b6d0ea7ecac2820ffdbe Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Wed, 9 Feb 2022 17:26:21 +1100 Subject: [PATCH 10/13] Address more comments --- .../apache/spark/sql/internal/SQLConf.scala | 18 ++++++------- .../parquet/ParquetReadSupport.scala | 17 ++++++++++-- .../datasources/parquet/ParquetUtils.scala | 2 +- .../datasources/v2/parquet/ParquetWrite.scala | 5 ++-- .../parquet/ParquetFieldIdIOSuite.scala | 4 +-- .../parquet/ParquetFieldIdSchemaSuite.scala | 27 +++++++++++++++++++ 6 files changed, 57 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cdb7e2fdd20da..e52766225b645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -936,27 +936,27 @@ object SQLConf { val PARQUET_FIELD_ID_WRITE_ENABLED = buildConf("spark.sql.parquet.fieldId.write.enabled") - .doc("Field ID is a native field of the Parquet schema spec. When enabled," + - " Parquet writers will populate the field Id" + - " metadata (if present) in the Spark schema to the Parquet schema.") + .doc("Field ID is a native field of the Parquet schema spec. 
When enabled, " + + "Parquet writers will populate the field Id " + + "metadata (if present) in the Spark schema to the Parquet schema.") .version("3.3.0") .booleanConf .createWithDefault(true) val PARQUET_FIELD_ID_READ_ENABLED = buildConf("spark.sql.parquet.fieldId.read.enabled") - .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers" + - " will use field IDs (if present) in the requested Spark schema to look up Parquet" + - " fields instead of using column names") + .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers " + + "will use field IDs (if present) in the requested Spark schema to look up Parquet " + + "fields instead of using column names") .version("3.3.0") .booleanConf .createWithDefault(false) val IGNORE_MISSING_PARQUET_FIELD_ID = buildConf("spark.sql.parquet.fieldId.ignoreMissing") - .doc("When the Parquet file doesn't have any field IDs but the" + - " Spark read schema is using field IDs to read, we will silently return nulls" + - " when this flag is enabled, or error otherwise.") + .doc("When the Parquet file doesn't have any field IDs but the " + + "Spark read schema is using field IDs to read, we will silently return nulls " + + "when this flag is enabled, or error otherwise.") .version("3.3.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 46c8c61082996..f6bf0f2c3cc94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -137,9 +137,11 @@ object ParquetReadSupport extends Logging { !containsFieldIds(parquetFileSchema) && ParquetUtils.hasFieldIds(catalystRequestedSchema)) { throw new RuntimeException( + "Spark read schema expects field Ids, " + + "but Parquet file schema doesn't contain any field Ids.\n" + + "Please remove the field ids from Spark schema or ignore missing ids by " + + "setting `spark.sql.parquet.fieldId.ignoreMissing = true`\n" + s""" - |Spark read schema expects field Ids, but Parquet file schema doesn't contain field Ids. - | |Spark read schema: |${catalystRequestedSchema.prettyJson} | @@ -183,6 +185,17 @@ object ParquetReadSupport extends Logging { parquetRequestedSchema } + /** + * Overloaded method for backward compatibility with + * `caseSensitive` default to `true` and `useFieldId` default to `false` + */ + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + clipParquetSchema(parquetSchema, catalystSchema, caseSensitive, useFieldId = false) + } + /** * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist * in `catalystSchema`, and adding those only exist in `catalystSchema`. 
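
For context on how the id-based read path above gets exercised: the requested Spark schema opts in per column by carrying `parquet.field.id` in its field metadata (the value of `ParquetUtils.FIELD_ID_METADATA_KEY`). A minimal sketch of a caller-side schema follows, mirroring the `withId` helper used by the test suites in this series; the `spark` session and the path are assumptions for illustration, not part of the patch:

    import org.apache.spark.sql.types._

    // "parquet.field.id" is the value of ParquetUtils.FIELD_ID_METADATA_KEY
    def withId(id: Int): Metadata =
      new MetadataBuilder().putLong("parquet.field.id", id).build()

    // Columns carrying ids resolve against Parquet field ids; columns without
    // ids fall back to name matching (the mixed mode tested earlier).
    val readSchema = new StructType()
      .add("a", StringType, nullable = true, withId(1))
      .add("b", IntegerType, nullable = true)

    // Against a file that has no field ids at all, this read hits the
    // RuntimeException added above unless the ignore-missing flag is set.
    spark.read.schema(readSchema).parquet("/tmp/example")  // path is illustrative
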
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 8033c2cfea579..68c1c99eecf32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -177,7 +177,7 @@ object ParquetUtils { def getFieldId(field: StructField): Int = { require(hasFieldId(field), s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field) - field.metadata.getLong(FIELD_ID_METADATA_KEY).toInt + Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala index 78c75a14a0154..d84acedb962e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala @@ -81,8 +81,9 @@ case class ParquetWrite( conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sqlConf.parquetOutputTimestampType.toString) - conf - .set(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, sqlConf.parquetFieldIdWriteEnabled.toString) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sqlConf.parquetFieldIdWriteEnabled.toString) // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala index f944878cd35c7..cfdae6e4eba23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala @@ -156,7 +156,7 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS spark.read.schema(schema).parquet(dir.getCanonicalPath).collect() }.getCause assert(cause.isInstanceOf[RuntimeException] && - cause.getMessage.contains("Parquet file schema doesn't contain field Ids")) + cause.getMessage.contains("Parquet file schema doesn't contain any field Ids")) val expectedValues = (1 to schema.length).map(_ => null) withSQLConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key -> "true") { checkAnswer( @@ -196,7 +196,7 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() }.getCause assert(cause.isInstanceOf[RuntimeException] && - cause.getMessage.contains("Parquet file schema doesn't contain field Ids")) + cause.getMessage.contains("Parquet file schema doesn't contain any field Ids")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala index 15e0e11364a9a..06eb627f950de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -138,6 +138,33 @@ class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { assert(!ParquetUtils.hasFieldIds(new StructType())); } + test("check getFieldId for schema") { + val schema = new StructType() + .add("overflowId", DoubleType, nullable = true, + new MetadataBuilder() + .putLong(ParquetUtils.FIELD_ID_METADATA_KEY, 12345678987654321L).build()) + .add("stringId", StringType, nullable = true, + new MetadataBuilder() + .putString(ParquetUtils.FIELD_ID_METADATA_KEY, "lol").build()) + .add("negativeId", LongType, nullable = true, withId(-20)) + .add("noId", LongType, nullable = true) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("noId")).get._2) + }.getMessage.contains("doesn't exist")) + + intercept[ArithmeticException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("overflowId")).get._2) + } + + intercept[ClassCastException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("stringId")).get._2) + } + + // negative id allowed + assert(ParquetUtils.getFieldId(schema.findNestedField(Seq("negativeId")).get._2) == -20) + } + test("check containsFieldIds for parquet schema") { // empty Parquet schema fails too From 53f44d19ee0772e78bddb6c2494168089a6da9d7 Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Thu, 10 Feb 2022 18:24:01 +1100 Subject: [PATCH 11/13] comments & refactoring --- .../parquet/ParquetFileFormat.scala | 1 - .../parquet/ParquetReadSupport.scala | 64 +++++++------------ .../parquet/ParquetRowConverter.scala | 5 +- .../parquet/ParquetSchemaConverter.scala | 2 + .../datasources/parquet/ParquetUtils.scala | 8 ++- .../parquet/ParquetFieldIdIOSuite.scala | 7 ++ .../parquet/ParquetFieldIdSchemaSuite.scala | 14 ++-- .../datasources/parquet/ParquetTest.scala | 38 ++++++++++- .../spark/sql/test/TestSQLContext.scala | 5 +- 9 files changed, 87 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index cd81f9242e736..18876dedb951e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -358,7 +358,6 @@ class ParquetFileFormat } } else { logDebug(s"Falling back to parquet-mr") - // ParquetRecordReader returns InternalRow val readSupport = new ParquetReadSupport( convertTz, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index f6bf0f2c3cc94..97e691ff7c66c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -149,7 +149,6 @@ object ParquetReadSupport extends Logging { |${parquetFileSchema.toString} |""".stripMargin) } - val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema, catalystRequestedSchema, caseSensitive, useFieldId) @@ -222,7 +221,7 @@ object ParquetReadSupport extends Logging { catalystType: DataType, caseSensitive: Boolean, useFieldId: Boolean): Type = { - catalystType match { + val 
newParquetType = catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive, useFieldId) @@ -242,6 +241,12 @@ object ParquetReadSupport extends Logging { // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. parquetType } + + if (useFieldId && parquetType.getId != null) { + newParquetType.withId(parquetType.getId.intValue()) + } else { + newParquetType + } } /** @@ -297,7 +302,7 @@ object ParquetReadSupport extends Logging { // "_tuple" appended then the repeated type is the element type and elements are required. // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the // only field. - val newParquetList = if ( + if ( repeatedGroup.getFieldCount > 1 || repeatedGroup.getName == "array" || repeatedGroup.getName == parquetList.getName + "_tuple" @@ -315,8 +320,8 @@ object ParquetReadSupport extends Logging { repeatedGroup.getType(0), elementType, caseSensitive, useFieldId)) .named(repeatedGroup.getName) - val newElementType = if (useFieldId && repeatedGroup.getId() != null) { - newRepeatedGroup.withId(repeatedGroup.getId().intValue()) + val newElementType = if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) } else { newRepeatedGroup } @@ -329,11 +334,6 @@ object ParquetReadSupport extends Logging { .addField(newElementType) .named(parquetList.getName) } - if (useFieldId && parquetList.getId() != null) { - newParquetList.withId(parquetList.getId().intValue()) - } else { - newParquetList - } } } @@ -369,17 +369,11 @@ object ParquetReadSupport extends Logging { } } - val newMap = Types + Types .buildGroup(parquetMap.getRepetition) .as(parquetMap.getLogicalTypeAnnotation) .addField(clippedRepeatedGroup) .named(parquetMap.getName) - - if (useFieldId && parquetMap.getId() != null) { - newMap.withId(parquetMap.getId().intValue()) - } else { - newMap - } } /** @@ -397,16 +391,11 @@ object ParquetReadSupport extends Logging { useFieldId: Boolean): GroupType = { val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive, useFieldId) - val newRecord = Types + Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getLogicalTypeAnnotation) .addFields(clippedParquetFields: _*) .named(parquetRecord.getName) - if (useFieldId && parquetRecord.getId() != null) { - newRecord.withId(parquetRecord.getId().intValue()) - } else { - newRecord - } } /** @@ -426,7 +415,7 @@ object ParquetReadSupport extends Logging { lazy val caseInsensitiveParquetFieldMap = parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) lazy val idToParquetFieldMap = - parquetRecord.getFields.asScala.filter(_.getId() != null).groupBy(f => f.getId.intValue()) + parquetRecord.getFields.asScala.filter(_.getId != null).groupBy(f => f.getId.intValue()) def matchCaseSensitiveField(f: StructField): Type = { caseSensitiveParquetFieldMap @@ -471,24 +460,15 @@ object ParquetReadSupport extends Logging { } } - if (useFieldId && ParquetUtils.hasFieldIds(structType)) { - structType.map { f => - if (ParquetUtils.hasFieldId(f)) { - // try to match id if there's any - matchIdField(f) - } else { - // fall back to name matching - if (caseSensitive) { - matchCaseSensitiveField(f) - } else { - matchCaseInsensitiveField(f) - } - } + val shouldMatchById = useFieldId && ParquetUtils.hasFieldIds(structType) + 
structType.map { f =>
+      if (shouldMatchById && ParquetUtils.hasFieldId(f)) {
+        matchIdField(f)
+      } else if (caseSensitive) {
+        matchCaseSensitiveField(f)
+      } else {
+        matchCaseInsensitiveField(f)
       }
-    } else if (caseSensitive) {
-      structType.map(matchCaseSensitiveField)
-    } else {
-      structType.map(matchCaseInsensitiveField)
     }
   }
@@ -551,7 +531,7 @@ object ParquetReadSupport extends Logging {
    * Whether the parquet schema contains any field IDs.
    */
   def containsFieldIds(schema: Type): Boolean = schema match {
-    case p: PrimitiveType => p.getId() != null
+    case p: PrimitiveType => p.getId != null
     // We don't require all fields to have IDs, so we use `exists` here.
     case g: GroupType => g.getId != null || g.getFields.asScala.exists(containsFieldIds)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 76318ed53bfdb..63ad5ed6db82e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -224,10 +224,9 @@ private[parquet] class ParquetRowConverter(
       }
 
     parquetType.getFields.asScala.map { parquetField =>
-      val catalystFieldIndex = Option(parquetField.getId).map { fieldId =>
+      val catalystFieldIndex = Option(parquetField.getId).flatMap { fieldId =>
         // field has id, try to match by id first before falling back to match by name
-        catalystFieldIdxByFieldId
-          .getOrElse(fieldId.intValue(), catalystFieldIdxByName(parquetField.getName))
+        catalystFieldIdxByFieldId.get(fieldId.intValue())
       }.getOrElse {
         // field doesn't have id, just match by name
         catalystFieldIdxByName(parquetField.getName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index e50743939e714..34a4eb8c002d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -434,6 +434,8 @@ class ParquetToSparkSchemaConverter(
  *        When set to false, use standard format defined in parquet-format spec. This argument only
  *        affects Parquet write path.
  * @param outputTimestampType which parquet timestamp type to use when writing.
+ * @param useFieldId whether to write field ids from the Spark schema to the Parquet schema. Set this to false
+ *                   via `spark.sql.parquet.fieldId.write.enabled = false` to disable writing field ids.
*/ class SparkToParquetSchemaConverter( writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 68c1c99eecf32..812fa2224d284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -177,7 +177,13 @@ object ParquetUtils { def getFieldId(field: StructField): Int = { require(hasFieldId(field), s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field) - Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY)) + try { + Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY)) + } catch { + case _: ArithmeticException | _: ClassCastException => + throw new IllegalArgumentException( + s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer") + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala index cfdae6e4eba23..99c7ac110bfc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala @@ -30,6 +30,13 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS private def withId(id: Int): Metadata = new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + protected def test(testName: String)(testFun: => Any): Unit = { + super.test(testName, ParquetUseDefaultFieldIdConfigs()) { + withSQLConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") { + testFun + } + } + } test("Parquet reads infer fields using field ids correctly") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala index 06eb627f950de..08a02a3aeaea5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -30,6 +30,12 @@ class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { private val UUID_REGEX = "[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}".r + protected def test(testName: String)(testFun: => Any): Unit = { + withSQLConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") { + super.test(testName, ParquetUseDefaultFieldIdConfigs())(testFun) + } + } + private def withId(id: Int) = new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() @@ -153,13 +159,13 @@ class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { ParquetUtils.getFieldId(schema.findNestedField(Seq("noId")).get._2) }.getMessage.contains("doesn't exist")) - intercept[ArithmeticException] { + assert(intercept[IllegalArgumentException] { ParquetUtils.getFieldId(schema.findNestedField(Seq("overflowId")).get._2) - } + }.getMessage.contains("must be a 32-bit integer")) - intercept[ClassCastException] { + 
assert(intercept[IllegalArgumentException] { ParquetUtils.getFieldId(schema.findNestedField(Seq("stringId")).get._2) - } + }.getMessage.contains("must be a 32-bit integer")) // negative id allowed assert(ParquetUtils.getFieldId(schema.findNestedField(Seq("negativeId")).get._2) == -20) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 7a7957c67dce1..12496b19da7f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -32,6 +32,8 @@ import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter, import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.util.HadoopInputFile import org.apache.parquet.schema.MessageType +import org.scalactic.source.Position +import org.scalatest.Tag import org.apache.spark.TestUtils import org.apache.spark.sql.DataFrame @@ -52,6 +54,30 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { override protected val vectorizedReaderEnabledKey: String = SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key + case class ParquetUseDefaultFieldIdConfigs() + extends Tag("Use SQLConf default options for setting Parquet field id configs") + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + if (testTags.exists(_.isInstanceOf[ParquetUseDefaultFieldIdConfigs])) { + super.test(testName, testTags: _*)(testFun) + } else { + // grid test with different combination of parquet field id options + super.test(testName, testTags: _*) { + Seq(true, false).foreach { enableRead => + Seq(true, false).foreach { enableWrite => + withSQLConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> enableRead.toString, + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> enableWrite.toString) { + withClue(s"Field ID read enabled: $enableRead, write enabled: $enableWrite") { + testFun + } + } + } + } + } + } + } + /** * Reads the parquet file at `path` */ @@ -165,9 +191,17 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { def withAllParquetReaders(code: => Unit): Unit = { // test the row-based reader - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withClue("Parquet-mr reader") { + code + } + } // test the vectorized reader - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + withClue("Vectorized reader") { + code + } + } } def withAllParquetWriters(code: => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 2f65cb7d16c2a..47a6f3617da63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -61,10 +61,7 @@ private[sql] object TestSQLContext { val overrideConfs: Map[String, String] = Map( // Fewer shuffle partitions to speed up testing. 
- SQLConf.SHUFFLE_PARTITIONS.key -> "5", - // Enable parquet read field id for tests to ensure correctness - SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true" - ) + SQLConf.SHUFFLE_PARTITIONS.key -> "5") } private[sql] class TestSQLSessionStateBuilder( From 308212ca201f893963a390753744322b98786068 Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Thu, 10 Feb 2022 19:14:45 +1100 Subject: [PATCH 12/13] Conf name update --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e52766225b645..9e8d01a78c9ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -953,7 +953,7 @@ object SQLConf { .createWithDefault(false) val IGNORE_MISSING_PARQUET_FIELD_ID = - buildConf("spark.sql.parquet.fieldId.ignoreMissing") + buildConf("spark.sql.parquet.fieldId.read.ignoreMissing") .doc("When the Parquet file doesn't have any field IDs but the " + "Spark read schema is using field IDs to read, we will silently return nulls " + "when this flag is enabled, or error otherwise.") From 9c0b23952de46f581f01cd12a73109e327beaa3d Mon Sep 17 00:00:00 2001 From: jackierwzhang Date: Thu, 10 Feb 2022 23:07:27 +1100 Subject: [PATCH 13/13] Remove the grid testing --- .../parquet/ParquetFieldIdIOSuite.scala | 8 ------ .../parquet/ParquetFieldIdSchemaSuite.scala | 6 ----- .../datasources/parquet/ParquetTest.scala | 26 ------------------- .../spark/sql/test/TestSQLContext.scala | 8 +++++- 4 files changed, 7 insertions(+), 41 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala index 99c7ac110bfc1..ff0bb2f92d208 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala @@ -30,14 +30,6 @@ class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkS private def withId(id: Int): Metadata = new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() - protected def test(testName: String)(testFun: => Any): Unit = { - super.test(testName, ParquetUseDefaultFieldIdConfigs()) { - withSQLConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") { - testFun - } - } - } - test("Parquet reads infer fields using field ids correctly") { withTempDir { dir => val readSchema = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala index 08a02a3aeaea5..b3babdd3a0cff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -30,12 +30,6 @@ class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { private val UUID_REGEX = "[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}".r - protected def test(testName: String)(testFun: => Any): Unit = { - 
withSQLConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") { - super.test(testName, ParquetUseDefaultFieldIdConfigs())(testFun) - } - } - private def withId(id: Int) = new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 12496b19da7f3..18690844d484c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -32,8 +32,6 @@ import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter, import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.util.HadoopInputFile import org.apache.parquet.schema.MessageType -import org.scalactic.source.Position -import org.scalatest.Tag import org.apache.spark.TestUtils import org.apache.spark.sql.DataFrame @@ -54,30 +52,6 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { override protected val vectorizedReaderEnabledKey: String = SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key - case class ParquetUseDefaultFieldIdConfigs() - extends Tag("Use SQLConf default options for setting Parquet field id configs") - - override protected def test(testName: String, testTags: Tag*)(testFun: => Any) - (implicit pos: Position): Unit = { - if (testTags.exists(_.isInstanceOf[ParquetUseDefaultFieldIdConfigs])) { - super.test(testName, testTags: _*)(testFun) - } else { - // grid test with different combination of parquet field id options - super.test(testName, testTags: _*) { - Seq(true, false).foreach { enableRead => - Seq(true, false).foreach { enableWrite => - withSQLConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> enableRead.toString, - SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> enableWrite.toString) { - withClue(s"Field ID read enabled: $enableRead, write enabled: $enableWrite") { - testFun - } - } - } - } - } - } - } - /** * Reads the parquet file at `path` */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 47a6f3617da63..fb3d38f3b7b18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -61,7 +61,13 @@ private[sql] object TestSQLContext { val overrideConfs: Map[String, String] = Map( // Fewer shuffle partitions to speed up testing. - SQLConf.SHUFFLE_PARTITIONS.key -> "5") + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + // Enable parquet read field id for tests to ensure correctness + // By default, if Spark schema doesn't contain the `parquet.field.id` metadata, + // the underlying matching mechanism should behave exactly like name matching + // which is the existing behavior. Therefore, turning this on ensures that we didn't + // introduce any regression for such mixed matching mode. + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") } private[sql] class TestSQLSessionStateBuilder(
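
To summarize where the series lands, here is a sketch of the three knobs it introduces and their final defaults; an active `spark` session is assumed, and the keys and defaults are as defined in SQLConf above:

    // Writers populate parquet field ids from Spark schema metadata by default.
    spark.conf.set("spark.sql.parquet.fieldId.write.enabled", "true")      // default: true

    // Id-based matching on read stays opt-in for users; the test sessions above
    // force it on so that plain name matching is still exercised whenever the
    // schema carries no ids.
    spark.conf.set("spark.sql.parquet.fieldId.read.enabled", "true")       // default: false

    // With an id-carrying read schema against an id-less file: error by default,
    // or silently read the id-matched columns as nulls when this is enabled.
    spark.conf.set("spark.sql.parquet.fieldId.read.ignoreMissing", "true") // default: false
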