diff --git a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6565c4abf477c..7bfacf7cf0647 100644
--- a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,12 +45,13 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
 import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
@@ -64,8 +65,7 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.connector.catalog.CatalogManager
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -1614,28 +1614,12 @@ class SparkConnectPlanner(
       fun: proto.Expression.UnresolvedFunction): Expression = {
     if (fun.getIsUserDefinedFunction) {
       UnresolvedFunction(
-        parser.parseMultipartIdentifier(fun.getFunctionName),
+        parser.parseFunctionIdentifier(fun.getFunctionName),
         fun.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = fun.getIsDistinct)
     } else {
-      // In order to retain backwards compatibility we allow functions registered in the
-      // `system`.`internal` namespace to looked by their name (instead of their FQN).
-      val builtInName = FunctionIdentifier(fun.getFunctionName)
-      val functionRegistry = session.sessionState.functionRegistry
-      val internalName = builtInName.copy(
-        database = Option(CatalogManager.INTERNAL_NAMESPACE),
-        catalog = Option(CatalogManager.SYSTEM_CATALOG_NAME))
-      // We need to drop the global built-ins because we can't parse symbolic names
-      // (e.g. `+`, `-`, ...).
-      val names = if (functionRegistry.functionExists(builtInName)) {
-        builtInName.nameParts
-      } else if (functionRegistry.functionExists(internalName)) {
-        internalName.nameParts
-      } else {
-        parser.parseMultipartIdentifier(fun.getFunctionName)
-      }
       UnresolvedFunction(
-        names,
+        FunctionIdentifier(fun.getFunctionName),
         fun.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = fun.getIsDistinct)
     }
@@ -1848,6 +1832,18 @@ class SparkConnectPlanner(
   private def transformUnregisteredFunction(
       fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
     fun.getFunctionName match {
+      case "product" if fun.getArgumentsCount == 1 =>
+        Some(
+          aggregate
+            .Product(transformExpression(fun.getArgumentsList.asScala.head))
+            .toAggregateExpression())
+
+      case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
+        // [col, expectedNumItems: Long, numBits: Long]
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(
+          new BloomFilterAggregate(children(0), children(1), children(2))
+            .toAggregateExpression())
 
       case "timestampdiff" if fun.getArgumentsCount == 3 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
@@ -1868,6 +1864,21 @@ class SparkConnectPlanner(
             throw InvalidPlanInput(s"numBuckets should be a literal integer, but got $other")
         }
 
+      case "years" if fun.getArgumentsCount == 1 =>
+        Some(Years(transformExpression(fun.getArguments(0))))
+
+      case "months" if fun.getArgumentsCount == 1 =>
+        Some(Months(transformExpression(fun.getArguments(0))))
+
+      case "days" if fun.getArgumentsCount == 1 =>
+        Some(Days(transformExpression(fun.getArguments(0))))
+
+      case "hours" if fun.getArgumentsCount == 1 =>
+        Some(Hours(transformExpression(fun.getArguments(0))))
+
+      case "unwrap_udt" if fun.getArgumentsCount == 1 =>
+        Some(UnwrapUDT(transformExpression(fun.getArguments(0))))
+
       case "from_json" if Seq(2, 3).contains(fun.getArgumentsCount) =>
         // JsonToStructs constructor doesn't accept JSON-formatted schema.
         extractDataTypeFromJSON(fun.getArguments(1)).map { dataType =>
@@ -1917,6 +1928,9 @@ class SparkConnectPlanner(
         Some(CatalystDataToAvro(children.head, jsonFormatSchema))
 
       // PS(Pandas API on Spark)-specific functions
+      case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
+        Some(DistributedSequenceID())
+
       case "pandas_product" if fun.getArgumentsCount == 2 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val dropna = extractBoolean(children(1), "dropna")
@@ -1927,6 +1941,14 @@ class SparkConnectPlanner(
         val ddof = extractInteger(children(1), "ddof")
         Some(aggregate.PandasStddev(children(0), ddof).toAggregateExpression(false))
 
+      case "pandas_skew" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(aggregate.PandasSkewness(children(0)).toAggregateExpression(false))
+
+      case "pandas_kurt" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(aggregate.PandasKurtosis(children(0)).toAggregateExpression(false))
+
       case "pandas_var" if fun.getArgumentsCount == 2 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val ddof = extractInteger(children(1), "ddof")
@@ -1946,7 +1968,11 @@ class SparkConnectPlanner(
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val alpha = extractDouble(children(1), "alpha")
         val ignoreNA = extractBoolean(children(2), "ignoreNA")
-        Some(new EWM(children(0), alpha, ignoreNA))
+        Some(EWM(children(0), alpha, ignoreNA))
+
+      case "null_index" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(NullIndex(children(0)))
 
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ac0b1ea601e04..1b194da5ab0a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2320,26 +2320,42 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     }
 
     def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = {
-      v1SessionCatalog.lookupBuiltinOrTempFunction(FunctionIdentifier(name))
+      if (name.length == 1) {
+        v1SessionCatalog.lookupBuiltinOrTempFunction(name.head)
+      } else {
+        None
+      }
     }
 
     def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = {
-      v1SessionCatalog.lookupBuiltinOrTempTableFunction(FunctionIdentifier(name))
+      if (name.length == 1) {
+        v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head)
+      } else {
+        None
+      }
     }
 
     private def resolveBuiltinOrTempFunction(
        name: Seq[String],
        arguments: Seq[Expression],
        u: Option[UnresolvedFunction]): Option[Expression] = {
-      v1SessionCatalog.resolveBuiltinOrTempFunction(FunctionIdentifier(name), arguments).map {
-        func => if (u.isDefined) validateFunction(func, arguments.length, u.get) else func
+      if (name.length == 1) {
+        v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments).map { func =>
+          if (u.isDefined) validateFunction(func, arguments.length, u.get) else func
+        }
+      } else {
+        None
       }
     }
 
     private def resolveBuiltinOrTempTableFunction(
        name: Seq[String],
        arguments: Seq[Expression]): Option[LogicalPlan] = {
-      v1SessionCatalog.resolveBuiltinOrTempTableFunction(FunctionIdentifier(name), arguments)
+      if (name.length == 1) {
+        v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments)
+      } else {
+        None
+      }
     }
 
     private def resolveV1Function(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 39df8ba4ee461..48123254a8fe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions.variant._
 import org.apache.spark.sql.catalyst.expressions.xml._
 import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.util.ArrayImplicits._
@@ -203,7 +202,7 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging
   // Resolution of the function name is always case insensitive, but the database name
   // depends on the caller
   private def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
-    name.copy(funcName = name.funcName.toLowerCase(Locale.ROOT))
+    FunctionIdentifier(name.funcName.toLowerCase(Locale.ROOT), name.database)
   }
 
   override def registerFunction(
@@ -884,32 +883,6 @@ object FunctionRegistry {
 
   val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet
 
-  /**
-   * Expressions registered in the system.internal.
-   */
-  registerInternalExpression[Product]("product")
-  registerInternalExpression[BloomFilterAggregate]("bloom_filter_agg")
-  registerInternalExpression[Years]("years")
-  registerInternalExpression[Months]("months")
-  registerInternalExpression[Days]("days")
-  registerInternalExpression[Hours]("hours")
-  registerInternalExpression[UnwrapUDT]("unwrap_udt")
-  registerInternalExpression[DistributedSequenceID]("distributed_sequence_id")
-  registerInternalExpression[PandasSkewness]("pandas_skew")
-  registerInternalExpression[PandasKurtosis]("pandas_kurt")
-  registerInternalExpression[NullIndex]("null_index")
-
-  private def registerInternalExpression[T <: Expression : ClassTag](name: String): Unit = {
-    val (info, builder) = FunctionRegistryBase.build(name, None)
-    builtin.internalRegisterFunction(
-      FunctionIdentifier(
-        name,
-        Option(CatalogManager.INTERNAL_NAMESPACE),
-        Option(CatalogManager.SYSTEM_CATALOG_NAME)),
-      info,
-      builder)
-  }
-
   private def makeExprInfoForVirtualOperator(name: String, usage: String): ExpressionInfo = {
     new ExpressionInfo(
       null,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 91300ee6a7eb1..701c68684c346 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -1647,17 +1647,10 @@ class SessionCatalog(
   * Look up the `ExpressionInfo` of the given function by name if it's a built-in or temp function.
   * This only supports scalar functions.
   */
-  def lookupBuiltinOrTempFunction(funcIdent: FunctionIdentifier): Option[ExpressionInfo] = {
-    val operator = funcIdent match {
-      case FunctionIdentifier(name, None, None) =>
-        FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT))
-      case _ => None
-    }
-    operator.orElse {
+  def lookupBuiltinOrTempFunction(name: String): Option[ExpressionInfo] = {
+    FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT)).orElse {
       synchronized(lookupTempFuncWithViewContext(
-        funcIdent,
-        FunctionRegistry.builtin.functionExists,
-        functionRegistry.lookupFunction))
+        name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction))
     }
   }
 
@@ -1665,26 +1658,18 @@ class SessionCatalog(
   * Look up the `ExpressionInfo` of the given function by name if it's a built-in or
   * temp table function.
   */
-  def lookupBuiltinOrTempTableFunction(funcIdent: FunctionIdentifier): Option[ExpressionInfo] =
-    synchronized {
-      lookupTempFuncWithViewContext(
-        funcIdent,
-        TableFunctionRegistry.builtin.functionExists,
-        tableFunctionRegistry.lookupFunction)
-    }
+  def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = synchronized {
+    lookupTempFuncWithViewContext(
+      name, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry.lookupFunction)
+  }
 
   /**
   * Look up a built-in or temp scalar function by name and resolves it to an Expression if such
   * a function exists.
   */
-  def resolveBuiltinOrTempFunction(
-      funcIdent: FunctionIdentifier,
-      arguments: Seq[Expression]): Option[Expression] = {
+  def resolveBuiltinOrTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = {
     resolveBuiltinOrTempFunctionInternal(
-      funcIdent,
-      arguments,
-      FunctionRegistry.builtin.functionExists,
-      functionRegistry)
+      name, arguments, FunctionRegistry.builtin.functionExists, functionRegistry)
   }
 
   /**
@@ -1692,36 +1677,35 @@ class SessionCatalog(
   * Look up a built-in or temp table function by name and resolves it to a LogicalPlan if such
   * a function exists.
   */
   def resolveBuiltinOrTempTableFunction(
-      funcIdent: FunctionIdentifier,
-      arguments: Seq[Expression]): Option[LogicalPlan] = {
+      name: String, arguments: Seq[Expression]): Option[LogicalPlan] = {
     resolveBuiltinOrTempFunctionInternal(
-      funcIdent, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry)
+      name, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry)
   }
 
   private def resolveBuiltinOrTempFunctionInternal[T](
-      funcIdent: FunctionIdentifier,
+      name: String,
       arguments: Seq[Expression],
       isBuiltin: FunctionIdentifier => Boolean,
       registry: FunctionRegistryBase[T]): Option[T] = synchronized {
+    val funcIdent = FunctionIdentifier(name)
     if (!registry.functionExists(funcIdent)) {
       None
     } else {
       lookupTempFuncWithViewContext(
-        funcIdent, isBuiltin, ident => Option(registry.lookupFunction(ident, arguments)))
+        name, isBuiltin, ident => Option(registry.lookupFunction(ident, arguments)))
     }
   }
 
   private def lookupTempFuncWithViewContext[T](
-      funcIdent: FunctionIdentifier,
+      name: String,
       isBuiltin: FunctionIdentifier => Boolean,
       lookupFunc: FunctionIdentifier => Option[T]): Option[T] = {
+    val funcIdent = FunctionIdentifier(name)
     if (isBuiltin(funcIdent)) {
       lookupFunc(funcIdent)
-    } else if (funcIdent.catalog.isEmpty && funcIdent.database.isEmpty) {
-      val name = funcIdent.funcName
-      val context = AnalysisContext.get
-      val isResolvingView = context.catalogAndNamespace.nonEmpty
-      val referredTempFunctionNames = context.referredTempFunctionNames
+    } else {
+      val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty
+      val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames
       if (isResolvingView) {
         // When resolving a view, only return a temp function if it's referred by this view.
         if (referredTempFunctionNames.contains(name)) {
@@ -1735,12 +1719,10 @@ class SessionCatalog(
           // We are not resolving a view and the function is a temp one, add it to
           // `AnalysisContext`, so during the view creation, we can save all referred temp
           // functions to view metadata.
-          context.referredTempFunctionNames.add(name)
+          AnalysisContext.get.referredTempFunctionNames.add(name)
         }
         result
       }
-    } else {
-      None
     }
   }
 
@@ -1827,21 +1809,33 @@ class SessionCatalog(
   * Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists.
   */
  def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
-    lookupBuiltinOrTempFunction(name)
-      .orElse(lookupBuiltinOrTempTableFunction(name))
-      .getOrElse(lookupPersistentFunction(name))
+    if (name.database.isEmpty) {
+      lookupBuiltinOrTempFunction(name.funcName)
+        .orElse(lookupBuiltinOrTempTableFunction(name.funcName))
+        .getOrElse(lookupPersistentFunction(name))
+    } else {
+      lookupPersistentFunction(name)
+    }
  }
 
  // The actual function lookup logic looks up temp/built-in function first, then persistent
  // function from either v1 or v2 catalog. This method only look up v1 catalog.
  def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
-    resolveBuiltinOrTempFunction(name, children)
-      .getOrElse(resolvePersistentFunction(name, children))
+    if (name.database.isEmpty) {
+      resolveBuiltinOrTempFunction(name.funcName, children)
+        .getOrElse(resolvePersistentFunction(name, children))
+    } else {
+      resolvePersistentFunction(name, children)
+    }
  }
 
  def lookupTableFunction(name: FunctionIdentifier, children: Seq[Expression]): LogicalPlan = {
-    resolveBuiltinOrTempTableFunction(name, children)
-      .getOrElse(resolvePersistentTableFunction(name, children))
+    if (name.database.isEmpty) {
+      resolveBuiltinOrTempTableFunction(name.funcName, children)
+        .getOrElse(resolvePersistentTableFunction(name, children))
+    } else {
+      resolvePersistentTableFunction(name, children)
+    }
  }
 
  /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
index 66e78f12d5c37..2f818fecad93a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst
 
-import org.apache.spark.sql.errors.QueryCompilationErrors
-
 /**
 * An identifier that optionally specifies a database.
 *
@@ -138,15 +136,6 @@ case class FunctionIdentifier(funcName: String, database: Option[String], catalo
 
 object FunctionIdentifier {
  def apply(funcName: String): FunctionIdentifier = new FunctionIdentifier(funcName)
-  def apply(funcName: String, database: Option[String]): FunctionIdentifier = new FunctionIdentifier(funcName, database)
-
-  def apply(names: Seq[String]): FunctionIdentifier = names match {
-    case Seq() => throw QueryCompilationErrors.emptyMultipartIdentifierError()
-    case Seq(name) => new FunctionIdentifier(name)
-    case Seq(database, name) => FunctionIdentifier(name, Option(database))
-    case Seq(catalog, database, name) => FunctionIdentifier(name, Option(database), Option(catalog))
-    case _ => throw QueryCompilationErrors.identifierTooManyNamePartsError(names)
-  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
index d29a4e4a36483..16c387a82373b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
@@ -155,5 +155,4 @@ private[sql] object CatalogManager {
  val SESSION_CATALOG_NAME: String = "spark_catalog"
  val SYSTEM_CATALOG_NAME = "system"
  val SESSION_NAMESPACE = "session"
- val INTERNAL_NAMESPACE = "internal"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3f115fcb0a3cd..3108f1886c299 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
-import org.apache.spark.sql.connector.catalog.CatalogManager.{INTERNAL_NAMESPACE, SYSTEM_CATALOG_NAME}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
@@ -62,25 +61,21 @@ private[sql] object Column {
   }
 
   private[sql] def fn(name: String, inputs: Column*): Column = {
-    fn(name, isDistinct = false, inputs: _*)
+    fn(name, isDistinct = false, ignoreNulls = false, inputs: _*)
   }
 
   private[sql] def fn(name: String, isDistinct: Boolean, inputs: Column*): Column = {
-    fn(name :: Nil, isDistinct = isDistinct, inputs: _*)
+    fn(name, isDistinct = isDistinct, ignoreNulls = false, inputs: _*)
   }
 
-  private[sql] def internalFn(name: String, inputs: Column*): Column = {
-    fn(
-      SYSTEM_CATALOG_NAME :: INTERNAL_NAMESPACE :: name :: Nil,
-      isDistinct = false,
-      inputs: _*)
-  }
-
-  private def fn(
-      names: Seq[String],
+  private[sql] def fn(
+      name: String,
       isDistinct: Boolean,
+      ignoreNulls: Boolean,
       inputs: Column*): Column = withOrigin {
-    Column(UnresolvedFunction(names, inputs.map(_.expr), isDistinct))
+    Column {
+      UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls)
+    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 705891a92eb70..0e62e05900a54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
@@ -8464,7 +8465,9 @@ object functions {
   * @group udf_funcs
   * @since 3.4.0
   */
-  def unwrap_udt(column: Column): Column = Column.internalFn("unwrap_udt", column)
+  def unwrap_udt(column: Column): Column = withExpr {
+    UnwrapUDT(column.expr)
+  }
 
  // scalastyle:off
  // TODO(SPARK-45970): Use @static annotation so Java can access to those
@@ -8478,7 +8481,7 @@ object functions {
   * @group partition_transforms
   * @since 4.0.0
   */
-  def years(e: Column): Column = Column.internalFn("years", e)
+  def years(e: Column): Column = withExpr { Years(e.expr) }
 
  /**
   * (Scala-specific) A transform for timestamps and dates to partition data into months.
@@ -8486,7 +8489,7 @@ object functions {
   * @group partition_transforms
   * @since 4.0.0
   */
-  def months(e: Column): Column = Column.internalFn("months", e)
+  def months(e: Column): Column = withExpr { Months(e.expr) }
 
  /**
   * (Scala-specific) A transform for timestamps and dates to partition data into days.
@@ -8494,7 +8497,7 @@ object functions {
   * @group partition_transforms
   * @since 4.0.0
   */
-  def days(e: Column): Column = Column.internalFn("days", e)
+  def days(e: Column): Column = withExpr { Days(e.expr) }
 
  /**
   * (Scala-specific) A transform for timestamps to partition data into hours.
@@ -8502,7 +8505,7 @@ object functions {
   * @group partition_transforms
   * @since 4.0.0
   */
-  def hours(e: Column): Column = Column.internalFn("hours", e)
+  def hours(e: Column): Column = withExpr { Hours(e.expr) }
 
  /**
   * (Scala-specific) A transform for any type that partitions by a hash of the input column.
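
For reference, a minimal usage sketch of the name-resolution path this patch restores (not part of the diff; it assumes a running SparkSession bound to a value named `spark` and uses the public `call_function` API from `org.apache.spark.sql.functions`): a bare, single-part function name is still routed through `UnresolvedFunction(Seq(name), ...)` and resolved against the built-in/temp registries, while qualified names fall through to the persistent-function lookup.

// Illustrative sketch only; `spark` is an assumed SparkSession.
import org.apache.spark.sql.functions.{call_function, lit}

// "abs" is a single-part name, so the analyzer resolves it against the
// built-in registry via the String-based SessionCatalog lookups above.
val resolved = spark.range(1).select(call_function("abs", lit(-3)))
resolved.show()  // prints a single row containing 3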