Skip to content

Commit

Permalink
Merge pull request #1304 from tobegit3hub/bugfix/fix_register_empty_d…
Browse files Browse the repository at this point in the history
…f_for_openmldb_batch

fix: support create empty dataframe with schema for openmldb-batch
  • Loading branch information
tobegit3hub authored Feb 24, 2022
2 parents 9791c32 + ae7283e commit 5cc52af
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
package com._4paradigm.openmldb.batch.api

import com._4paradigm.openmldb.batch.catalog.OpenmldbCatalogService
import com._4paradigm.openmldb.batch.utils.HybridseUtil
import com._4paradigm.openmldb.batch.{OpenmldbBatchConfig, SparkPlanner}
import org.apache.commons.io.IOUtils
import org.apache.spark.{SPARK_VERSION, SparkConf}
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.slf4j.LoggerFactory
import scala.collection.mutable
import scala.collection.JavaConverters.asScalaBufferConverter

/**
* The class to provide SparkSession-like API.
Expand Down Expand Up @@ -57,14 +60,17 @@ class OpenmldbSession {
this.setDefaultSparkConfig()

if (this.config.openmldbZkCluster.nonEmpty && this.config.openmldbZkRootPath.nonEmpty) {
logger.info(s"Try to connect OpenMLDB with zk ${this.config.openmldbZkCluster} and root path " +
s"${this.config.openmldbZkRootPath}")
try {
openmldbCatalogService = new OpenmldbCatalogService(this.config.openmldbZkCluster,
this.config.openmldbZkRootPath, config.openmldbJsdkLibraryPath)
registerOpenmldbOfflineTable(openmldbCatalogService)
} catch {
case e: Exception => logger.warn("Fail to connect OpenMLDB cluster and register tables, " + e.getMessage)
}

} else {
logger.warn("openmldb.zk.cluster or openmldb.zk.root.path is not set and do not register OpenMLDB tables")
}
}

Expand Down Expand Up @@ -266,9 +272,19 @@ class OpenmldbSession {
registerTable(dbName, tableName, df)
} else {
// Register empty df for table
logger.info(s"Register empty dataframe fof $dbName.$tableName")
// TODO: Create empty df with schema
registerTable(dbName, tableName, sparkSession.emptyDataFrame)
val tableInfo = catalogService.getTableInfo(dbName, tableName)
val columnDescList = tableInfo.getColumnDescList()

val schema = new StructType(columnDescList.asScala.map(colDesc => {
StructField(colDesc.getName, HybridseUtil.protoTypeToSparkType(colDesc.getDataType),
!colDesc.getNotNull)
}).toArray)

logger.info(s"Register empty dataframe fof $dbName.$tableName with schema ${schema}")
// Create empty df with schema
val emptyDf = sparkSession.createDataFrame(sparkSession.emptyDataFrame.rdd, schema)

registerTable(dbName, tableName, emptyDf)
}
} catch {
case e: Exception => logger.warn(s"Fail to register table $dbName.$tableName, error: ${e.getMessage}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ import org.apache.spark.sql.types.{
ShortType, StringType, StructField, StructType, TimestampType
}
import org.apache.spark.sql.{DataFrame, Row}

import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.mutable
import scala.reflect.{ClassTag, classTag}


object HybridseUtil {
Expand Down Expand Up @@ -102,6 +100,23 @@ object HybridseUtil {
}
}

/** Translate an OpenMLDB proto column type into the equivalent Spark SQL type.
 *
 * @param dtype the OpenMLDB proto data type of a table column
 * @return the corresponding Spark SQL [[DataType]]
 * @throws IllegalArgumentException if the proto type has no Spark equivalent
 */
def protoTypeToSparkType(dtype: com._4paradigm.openmldb.proto.Type.DataType): DataType = {
  // Alias the proto enum locally so each case stays readable.
  import com._4paradigm.openmldb.proto.Type.{DataType => ProtoType}
  dtype match {
    case ProtoType.kSmallInt => ShortType
    case ProtoType.kInt => IntegerType
    case ProtoType.kBigInt => LongType
    case ProtoType.kFloat => FloatType
    case ProtoType.kDouble => DoubleType
    case ProtoType.kBool => BooleanType
    // Spark has a single string type for both proto string variants.
    case ProtoType.kString | ProtoType.kVarchar => StringType
    case ProtoType.kDate => DateType
    case ProtoType.kTimestamp => TimestampType
    case _ => throw new IllegalArgumentException(
      s"HybridSE proto data type $dtype not supported")
  }
}

def getHybridseType(dtype: DataType): Type = {
dtype match {
case ShortType => Type.kInt16
Expand Down

0 comments on commit 5cc52af

Please sign in to comment.