Skip to content

Commit

Permalink
Revert "SPARK-49004"
Browse files Browse the repository at this point in the history
This reverts commit c8db813.
  • Loading branch information
hvanhovell committed Aug 1, 2024
1 parent a0f8de5 commit f15570d
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
Expand All @@ -64,8 +65,7 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
import org.apache.spark.sql.connect.utils.MetricGenerator
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.arrow.ArrowConverters
Expand Down Expand Up @@ -1614,28 +1614,12 @@ class SparkConnectPlanner(
fun: proto.Expression.UnresolvedFunction): Expression = {
if (fun.getIsUserDefinedFunction) {
UnresolvedFunction(
parser.parseMultipartIdentifier(fun.getFunctionName),
parser.parseFunctionIdentifier(fun.getFunctionName),
fun.getArgumentsList.asScala.map(transformExpression).toSeq,
isDistinct = fun.getIsDistinct)
} else {
// In order to retain backwards compatibility we allow functions registered in the
// `system`.`internal` namespace to looked by their name (instead of their FQN).
val builtInName = FunctionIdentifier(fun.getFunctionName)
val functionRegistry = session.sessionState.functionRegistry
val internalName = builtInName.copy(
database = Option(CatalogManager.INTERNAL_NAMESPACE),
catalog = Option(CatalogManager.SYSTEM_CATALOG_NAME))
// We need to drop the global built-ins because we can't parse symbolic names
// (e.g. `+`, `-`, ...).
val names = if (functionRegistry.functionExists(builtInName)) {
builtInName.nameParts
} else if (functionRegistry.functionExists(internalName)) {
internalName.nameParts
} else {
parser.parseMultipartIdentifier(fun.getFunctionName)
}
UnresolvedFunction(
names,
FunctionIdentifier(fun.getFunctionName),
fun.getArgumentsList.asScala.map(transformExpression).toSeq,
isDistinct = fun.getIsDistinct)
}
Expand Down Expand Up @@ -1848,6 +1832,18 @@ class SparkConnectPlanner(
private def transformUnregisteredFunction(
fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
fun.getFunctionName match {
case "product" if fun.getArgumentsCount == 1 =>
Some(
aggregate
.Product(transformExpression(fun.getArgumentsList.asScala.head))
.toAggregateExpression())

case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
// [col, expectedNumItems: Long, numBits: Long]
val children = fun.getArgumentsList.asScala.map(transformExpression)
Some(
new BloomFilterAggregate(children(0), children(1), children(2))
.toAggregateExpression())

case "timestampdiff" if fun.getArgumentsCount == 3 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
Expand All @@ -1868,6 +1864,21 @@ class SparkConnectPlanner(
throw InvalidPlanInput(s"numBuckets should be a literal integer, but got $other")
}

case "years" if fun.getArgumentsCount == 1 =>
Some(Years(transformExpression(fun.getArguments(0))))

case "months" if fun.getArgumentsCount == 1 =>
Some(Months(transformExpression(fun.getArguments(0))))

case "days" if fun.getArgumentsCount == 1 =>
Some(Days(transformExpression(fun.getArguments(0))))

case "hours" if fun.getArgumentsCount == 1 =>
Some(Hours(transformExpression(fun.getArguments(0))))

case "unwrap_udt" if fun.getArgumentsCount == 1 =>
Some(UnwrapUDT(transformExpression(fun.getArguments(0))))

case "from_json" if Seq(2, 3).contains(fun.getArgumentsCount) =>
// JsonToStructs constructor doesn't accept JSON-formatted schema.
extractDataTypeFromJSON(fun.getArguments(1)).map { dataType =>
Expand Down Expand Up @@ -1917,6 +1928,9 @@ class SparkConnectPlanner(
Some(CatalystDataToAvro(children.head, jsonFormatSchema))

// PS(Pandas API on Spark)-specific functions
case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
Some(DistributedSequenceID())

case "pandas_product" if fun.getArgumentsCount == 2 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val dropna = extractBoolean(children(1), "dropna")
Expand All @@ -1927,6 +1941,14 @@ class SparkConnectPlanner(
val ddof = extractInteger(children(1), "ddof")
Some(aggregate.PandasStddev(children(0), ddof).toAggregateExpression(false))

case "pandas_skew" if fun.getArgumentsCount == 1 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
Some(aggregate.PandasSkewness(children(0)).toAggregateExpression(false))

case "pandas_kurt" if fun.getArgumentsCount == 1 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
Some(aggregate.PandasKurtosis(children(0)).toAggregateExpression(false))

case "pandas_var" if fun.getArgumentsCount == 2 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val ddof = extractInteger(children(1), "ddof")
Expand All @@ -1946,7 +1968,11 @@ class SparkConnectPlanner(
val children = fun.getArgumentsList.asScala.map(transformExpression)
val alpha = extractDouble(children(1), "alpha")
val ignoreNA = extractBoolean(children(2), "ignoreNA")
Some(new EWM(children(0), alpha, ignoreNA))
Some(EWM(children(0), alpha, ignoreNA))

case "null_index" if fun.getArgumentsCount == 1 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
Some(NullIndex(children(0)))

// ML-specific functions
case "vector_to_array" if fun.getArgumentsCount == 2 =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2320,26 +2320,42 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}

def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = {
v1SessionCatalog.lookupBuiltinOrTempFunction(FunctionIdentifier(name))
if (name.length == 1) {
v1SessionCatalog.lookupBuiltinOrTempFunction(name.head)
} else {
None
}
}

def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = {
v1SessionCatalog.lookupBuiltinOrTempTableFunction(FunctionIdentifier(name))
if (name.length == 1) {
v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head)
} else {
None
}
}

private def resolveBuiltinOrTempFunction(
name: Seq[String],
arguments: Seq[Expression],
u: Option[UnresolvedFunction]): Option[Expression] = {
v1SessionCatalog.resolveBuiltinOrTempFunction(FunctionIdentifier(name), arguments).map {
func => if (u.isDefined) validateFunction(func, arguments.length, u.get) else func
if (name.length == 1) {
v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments).map { func =>
if (u.isDefined) validateFunction(func, arguments.length, u.get) else func
}
} else {
None
}
}

private def resolveBuiltinOrTempTableFunction(
name: Seq[String],
arguments: Seq[Expression]): Option[LogicalPlan] = {
v1SessionCatalog.resolveBuiltinOrTempTableFunction(FunctionIdentifier(name), arguments)
if (name.length == 1) {
v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments)
} else {
None
}
}

private def resolveV1Function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions.variant._
import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -203,7 +202,7 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging
// Resolution of the function name is always case insensitive, but the database name
// depends on the caller
private def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
name.copy(funcName = name.funcName.toLowerCase(Locale.ROOT))
FunctionIdentifier(name.funcName.toLowerCase(Locale.ROOT), name.database)
}

override def registerFunction(
Expand Down Expand Up @@ -884,32 +883,6 @@ object FunctionRegistry {

val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet

/**
* Expressions registered in the system.internal.
*/
registerInternalExpression[Product]("product")
registerInternalExpression[BloomFilterAggregate]("bloom_filter_agg")
registerInternalExpression[Years]("years")
registerInternalExpression[Months]("months")
registerInternalExpression[Days]("days")
registerInternalExpression[Hours]("hours")
registerInternalExpression[UnwrapUDT]("unwrap_udt")
registerInternalExpression[DistributedSequenceID]("distributed_sequence_id")
registerInternalExpression[PandasSkewness]("pandas_skew")
registerInternalExpression[PandasKurtosis]("pandas_kurt")
registerInternalExpression[NullIndex]("null_index")

private def registerInternalExpression[T <: Expression : ClassTag](name: String): Unit = {
val (info, builder) = FunctionRegistryBase.build(name, None)
builtin.internalRegisterFunction(
FunctionIdentifier(
name,
Option(CatalogManager.INTERNAL_NAMESPACE),
Option(CatalogManager.SYSTEM_CATALOG_NAME)),
info,
builder)
}

private def makeExprInfoForVirtualOperator(name: String, usage: String): ExpressionInfo = {
new ExpressionInfo(
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1647,81 +1647,65 @@ class SessionCatalog(
* Look up the `ExpressionInfo` of the given function by name if it's a built-in or temp function.
* This only supports scalar functions.
*/
def lookupBuiltinOrTempFunction(funcIdent: FunctionIdentifier): Option[ExpressionInfo] = {
val operator = funcIdent match {
case FunctionIdentifier(name, None, None) =>
FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT))
case _ => None
}
operator.orElse {
def lookupBuiltinOrTempFunction(name: String): Option[ExpressionInfo] = {
FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT)).orElse {
synchronized(lookupTempFuncWithViewContext(
funcIdent,
FunctionRegistry.builtin.functionExists,
functionRegistry.lookupFunction))
name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction))
}
}

/**
* Look up the `ExpressionInfo` of the given function by name if it's a built-in or
* temp table function.
*/
def lookupBuiltinOrTempTableFunction(funcIdent: FunctionIdentifier): Option[ExpressionInfo] =
synchronized {
lookupTempFuncWithViewContext(
funcIdent,
TableFunctionRegistry.builtin.functionExists,
tableFunctionRegistry.lookupFunction)
}
def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = synchronized {
lookupTempFuncWithViewContext(
name, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry.lookupFunction)
}

/**
* Look up a built-in or temp scalar function by name and resolves it to an Expression if such
* a function exists.
*/
def resolveBuiltinOrTempFunction(
funcIdent: FunctionIdentifier,
arguments: Seq[Expression]): Option[Expression] = {
def resolveBuiltinOrTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = {
resolveBuiltinOrTempFunctionInternal(
funcIdent,
arguments,
FunctionRegistry.builtin.functionExists,
functionRegistry)
name, arguments, FunctionRegistry.builtin.functionExists, functionRegistry)
}

/**
* Look up a built-in or temp table function by name and resolves it to a LogicalPlan if such
* a function exists.
*/
def resolveBuiltinOrTempTableFunction(
funcIdent: FunctionIdentifier,
arguments: Seq[Expression]): Option[LogicalPlan] = {
name: String, arguments: Seq[Expression]): Option[LogicalPlan] = {
resolveBuiltinOrTempFunctionInternal(
funcIdent, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry)
name, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry)
}

private def resolveBuiltinOrTempFunctionInternal[T](
funcIdent: FunctionIdentifier,
name: String,
arguments: Seq[Expression],
isBuiltin: FunctionIdentifier => Boolean,
registry: FunctionRegistryBase[T]): Option[T] = synchronized {
val funcIdent = FunctionIdentifier(name)
if (!registry.functionExists(funcIdent)) {
None
} else {
lookupTempFuncWithViewContext(
funcIdent, isBuiltin, ident => Option(registry.lookupFunction(ident, arguments)))
name, isBuiltin, ident => Option(registry.lookupFunction(ident, arguments)))
}
}

private def lookupTempFuncWithViewContext[T](
funcIdent: FunctionIdentifier,
name: String,
isBuiltin: FunctionIdentifier => Boolean,
lookupFunc: FunctionIdentifier => Option[T]): Option[T] = {
val funcIdent = FunctionIdentifier(name)
if (isBuiltin(funcIdent)) {
lookupFunc(funcIdent)
} else if (funcIdent.catalog.isEmpty && funcIdent.database.isEmpty) {
val name = funcIdent.funcName
val context = AnalysisContext.get
val isResolvingView = context.catalogAndNamespace.nonEmpty
val referredTempFunctionNames = context.referredTempFunctionNames
} else {
val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty
val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames
if (isResolvingView) {
// When resolving a view, only return a temp function if it's referred by this view.
if (referredTempFunctionNames.contains(name)) {
Expand All @@ -1735,12 +1719,10 @@ class SessionCatalog(
// We are not resolving a view and the function is a temp one, add it to
// `AnalysisContext`, so during the view creation, we can save all referred temp
// functions to view metadata.
context.referredTempFunctionNames.add(name)
AnalysisContext.get.referredTempFunctionNames.add(name)
}
result
}
} else {
None
}
}

Expand Down Expand Up @@ -1827,21 +1809,33 @@ class SessionCatalog(
* Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists.
*/
def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
lookupBuiltinOrTempFunction(name)
.orElse(lookupBuiltinOrTempTableFunction(name))
.getOrElse(lookupPersistentFunction(name))
if (name.database.isEmpty) {
lookupBuiltinOrTempFunction(name.funcName)
.orElse(lookupBuiltinOrTempTableFunction(name.funcName))
.getOrElse(lookupPersistentFunction(name))
} else {
lookupPersistentFunction(name)
}
}

// The actual function lookup logic looks up temp/built-in function first, then persistent
// function from either v1 or v2 catalog. This method only look up v1 catalog.
def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
resolveBuiltinOrTempFunction(name, children)
.getOrElse(resolvePersistentFunction(name, children))
if (name.database.isEmpty) {
resolveBuiltinOrTempFunction(name.funcName, children)
.getOrElse(resolvePersistentFunction(name, children))
} else {
resolvePersistentFunction(name, children)
}
}

def lookupTableFunction(name: FunctionIdentifier, children: Seq[Expression]): LogicalPlan = {
resolveBuiltinOrTempTableFunction(name, children)
.getOrElse(resolvePersistentTableFunction(name, children))
if (name.database.isEmpty) {
resolveBuiltinOrTempTableFunction(name.funcName, children)
.getOrElse(resolvePersistentTableFunction(name, children))
} else {
resolvePersistentTableFunction(name, children)
}
}

/**
Expand Down
Loading

0 comments on commit f15570d

Please sign in to comment.