[HUDI-4213] Infer keygen clazz for Spark SQL
danny0405 committed Jun 9, 2022
1 parent 8ff17b0 commit 2389ccd
Showing 2 changed files with 25 additions and 18 deletions.
@@ -19,7 +19,7 @@ package org.apache.hudi

import org.apache.hudi.DataSourceReadOptions.{QUERY_TYPE, QUERY_TYPE_READ_OPTIMIZED_OPT_VAL, QUERY_TYPE_SNAPSHOT_OPT_VAL}
import org.apache.hudi.HoodieConversionUtils.toScalaOption
-import org.apache.hudi.common.config.{ConfigProperty, HoodieCommonConfig, HoodieConfig}
+import org.apache.hudi.common.config.{ConfigProperty, HoodieCommonConfig, HoodieConfig, TypedProperties}
import org.apache.hudi.common.fs.ConsistencyGuardConfig
import org.apache.hudi.common.model.{HoodieTableType, WriteOperationType}
import org.apache.hudi.common.table.HoodieTableConfig
@@ -323,22 +323,12 @@ object DataSourceWriteOptions {
  val HIVE_STYLE_PARTITIONING = KeyGeneratorOptions.HIVE_STYLE_PARTITIONING_ENABLE

  /**
-   * Key generator class, that implements will extract the key out of incoming record
-   *
+   * Key generator class, that implements will extract the key out of incoming record.
   */
  val keyGeneraterInferFunc = DataSourceOptionsHelper.scalaFunctionToJavaFunction((p: HoodieConfig) => {
-    if (!p.contains(PARTITIONPATH_FIELD)) {
-      Option.of(classOf[NonpartitionedKeyGenerator].getName)
-    } else {
-      val numOfPartFields = p.getString(PARTITIONPATH_FIELD).split(",").length
-      val numOfRecordKeyFields = p.getString(RECORDKEY_FIELD).split(",").length
-      if (numOfPartFields == 1 && numOfRecordKeyFields == 1) {
-        Option.of(classOf[SimpleKeyGenerator].getName)
-      } else {
-        Option.of(classOf[ComplexKeyGenerator].getName)
-      }
-    }
+    Option.of(DataSourceOptionsHelper.inferKeyGenClazz(p.getProps))
  })

  val KEYGENERATOR_CLASS_NAME: ConfigProperty[String] = ConfigProperty
    .key("hoodie.datasource.write.keygenerator.class")
    .defaultValue(classOf[SimpleKeyGenerator].getName)
@@ -804,6 +794,22 @@ object DataSourceOptionsHelper {
    ) ++ translateConfigurations(parameters)
  }

+  def inferKeyGenClazz(props: TypedProperties): String = {
+    val partitionFields = props.getString(DataSourceWriteOptions.PARTITIONPATH_FIELD.key(), null)
+    if (partitionFields != null) {
+      val numPartFields = partitionFields.split(",").length
+      val recordsKeyFields = props.getString(DataSourceWriteOptions.RECORDKEY_FIELD.key(), DataSourceWriteOptions.RECORDKEY_FIELD.defaultValue())
+      val numRecordKeyFields = recordsKeyFields.split(",").length
+      if (numPartFields == 1 && numRecordKeyFields == 1) {
+        classOf[SimpleKeyGenerator].getName
+      } else {
+        classOf[ComplexKeyGenerator].getName
+      }
+    } else {
+      classOf[NonpartitionedKeyGenerator].getName
+    }
+  }
+
  implicit def scalaFunctionToJavaFunction[From, To](function: (From) => To): JavaFunction[From, To] = {
    new JavaFunction[From, To] {
      override def apply (input: From): To = function (input)
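The inference rule consolidated above is straightforward: no partition path field means NonpartitionedKeyGenerator, exactly one partition field plus one record key field means SimpleKeyGenerator, and anything else means ComplexKeyGenerator. The sketch below is illustrative only and not part of the diff; the object name and the propsOf helper are made up for the example, which feeds hand-built TypedProperties through the new DataSourceOptionsHelper.inferKeyGenClazz to show which class name each combination yields.

import org.apache.hudi.{DataSourceOptionsHelper, DataSourceWriteOptions}
import org.apache.hudi.common.config.TypedProperties

object KeyGenInferenceSketch extends App {
  // Build a TypedProperties from key/value pairs (helper for this sketch only).
  def propsOf(kvs: (String, String)*): TypedProperties = {
    val props = new TypedProperties()
    kvs.foreach { case (k, v) => props.setProperty(k, v) }
    props
  }

  // One record key field and one partition field -> SimpleKeyGenerator.
  val simple = propsOf(
    DataSourceWriteOptions.RECORDKEY_FIELD.key() -> "id",
    DataSourceWriteOptions.PARTITIONPATH_FIELD.key() -> "dt")

  // Several partition fields -> ComplexKeyGenerator.
  val complex = propsOf(
    DataSourceWriteOptions.RECORDKEY_FIELD.key() -> "id",
    DataSourceWriteOptions.PARTITIONPATH_FIELD.key() -> "year,month")

  // No partition path field at all -> NonpartitionedKeyGenerator.
  val nonPartitioned = propsOf(DataSourceWriteOptions.RECORDKEY_FIELD.key() -> "id")

  Seq(simple, complex, nonPartitioned)
    .map(DataSourceOptionsHelper.inferKeyGenClazz)
    .foreach(println)
}

Running this should print the SimpleKeyGenerator, ComplexKeyGenerator and NonpartitionedKeyGenerator class names, in that order.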
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hudi.command

import org.apache.avro.generic.GenericRecord
+import org.apache.hudi.{DataSourceOptionsHelper, DataSourceWriteOptions}
import org.apache.hudi.common.config.TypedProperties
import org.apache.hudi.common.util.PartitionPathEncodeUtils
import org.apache.hudi.config.HoodieWriteConfig
@@ -113,14 +114,14 @@ class SqlKeyGenerator(props: TypedProperties) extends ComplexKeyGenerator(props)
    } else partitionPath
  }

-  override def getPartitionPath(record: GenericRecord) = {
+  override def getPartitionPath(record: GenericRecord): String = {
    val partitionPath = super.getPartitionPath(record)
-    convertPartitionPathToSqlType(partitionPath, false)
+    convertPartitionPathToSqlType(partitionPath, rowType = false)
  }

  override def getPartitionPath(row: Row): String = {
    val partitionPath = super.getPartitionPath(row)
-    convertPartitionPathToSqlType(partitionPath, true)
+    convertPartitionPathToSqlType(partitionPath, rowType = true)
  }
}

@@ -135,7 +136,7 @@ object SqlKeyGenerator {
    if (beforeKeyGenClassName != null && beforeKeyGenClassName.nonEmpty) {
      HoodieSparkKeyGeneratorFactory.convertToSparkKeyGenerator(beforeKeyGenClassName)
    } else {
-      classOf[ComplexKeyGenerator].getCanonicalName
+      DataSourceOptionsHelper.inferKeyGenClazz(props)
    }
  }
}
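For Spark SQL, the practical effect is that a table with no persisted key generator class no longer falls back to ComplexKeyGenerator unconditionally; SqlKeyGenerator now applies the same inference to the table properties. As a rough illustration only (the table and column names are hypothetical, and a SparkSession named spark with the Hudi SQL extensions enabled is assumed), a table declared with a single primary key and a single partition column should now resolve to SimpleKeyGenerator without any explicit keygenerator configuration:

// Illustrative only: a single-primary-key, single-partition-column table,
// which the new inference maps to SimpleKeyGenerator; no explicit
// hoodie.datasource.write.keygenerator.class is set.
spark.sql(
  """
    |create table if not exists hudi_infer_demo (
    |  id int,
    |  name string,
    |  ts long,
    |  dt string
    |) using hudi
    |tblproperties (
    |  primaryKey = 'id',
    |  preCombineField = 'ts'
    |)
    |partitioned by (dt)
    |""".stripMargin)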
