Skip to content

Commit

Permalink
#36 - replace sme with spark impl - cleanup with #37 safety check
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Jun 30, 2023
1 parent 8d39954 commit c2d8e1a
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 9 deletions.
202 changes: 195 additions & 7 deletions src/main/scala/com/sparkutils/quality/impl/util/StructFunctions.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package com.sparkutils.quality.impl.util

import com.sparkutils.quality.QualityException
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, GenericInternalRow, NamedExpression}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.QualitySparkUtils.{toSQLExpr, toSQLType}
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExpressionDescription, ExtractValue, GenericInternalRow, GetStructField, If, IsNull, LeafExpression, Literal, UnaryExpression, Unevaluable}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.qualityFunctions.utils.named
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

import scala.collection.mutable.ArrayBuffer

trait StructFunctionsImport {

Expand All @@ -16,10 +19,15 @@ trait StructFunctionsImport {
* @param update
* @param transformations
* @return a new copy of update with the changes applied
*/
*/ // TODO figure out optimisation to join fields
def update_field(update: Column, transformations: (String, Column)*): Column =
new Column( AddFields(update.expr +: transformations.flatMap{ transformation
/* new Column( AddFields(update.expr +: transformations.flatMap{ transformation
=> Seq(lit(transformation._1).expr, transformation._2.expr)} ) )
*/
transformations.foldRight(update){
case ((path, col), origin) =>
new Column( UpdateFields.apply(origin.expr, path, col.expr) )
}

}

Expand Down Expand Up @@ -216,3 +224,183 @@ case class AddFields(children: Seq[Expression]) extends Expression {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren)
}

// Below is lifted from 3.4.1 complexTypeCreator

/**
* Represents an operation to be applied to the fields of a struct.
*/
trait StructFieldsOperation extends Expression with Unevaluable {

override lazy val foldable = false

val resolver: Resolver = SQLConf.get.resolver

override def dataType: DataType = throw new IllegalStateException(
"StructFieldsOperation.dataType should not be called.")

override def nullable: Boolean = throw new IllegalStateException(
"StructFieldsOperation.nullable should not be called.")

/**
* Returns an updated list of StructFields and Expressions that will ultimately be used
* as the fields argument for [[StructType]] and as the children argument for
* [[CreateNamedStruct]] respectively inside of [[UpdateFields]].
*/
def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)]
}

/**
* Add or replace a field by name.
*
* We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its
* children, and thereby enable the analyzer to resolve and transform valExpr as necessary.
*/
case class WithField(name: String, child: Expression)
extends UnaryExpression with StructFieldsOperation {

override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = {
val newFieldExpr = (StructField(name, child.dataType, child.nullable), child)
val result = ArrayBuffer.empty[(StructField, Expression)]
var hasMatch = false
for (existingFieldExpr @ (existingField, _) <- values) {
if (resolver(existingField.name, name)) {
hasMatch = true
result += newFieldExpr
} else {
result += existingFieldExpr
}
}
if (!hasMatch) result += newFieldExpr
result.toSeq
}

override def prettyName: String = "WithField"

protected def withNewChildInternal(newChild: Expression): WithField =
copy(child = newChild)
}

/**
* Drop a field by name.
*/
case class DropField(name: String) extends LeafExpression with StructFieldsOperation {
override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] =
values.filterNot { case (field, _) => resolver(field.name, name) }
}

/**
* Updates fields in a struct.
*/
case class UpdateFields(children: Seq[Expression])
extends Expression with CodegenFallback {

val structExpr = children.head
val fieldOps: Seq[StructFieldsOperation] = children.drop(1).map(_.asInstanceOf[StructFieldsOperation])

override def checkInputDataTypes(): TypeCheckResult = {
val dataType = structExpr.dataType
if (!dataType.isInstanceOf[StructType]) {
TypeCheckResult.TypeCheckFailure( message =
s"UNEXPECTED_INPUT_TYPE, requiredType StructType, inputSql ${ toSQLExpr(structExpr) }, inputType ${ toSQLType(structExpr.dataType) }"
)
} else if (newExprs.isEmpty) {
TypeCheckResult.TypeCheckFailure( message =
s"errorSubClass = CANNOT_DROP_ALL_FIELDS"
)
} else {
TypeCheckResult.TypeCheckSuccess
}
}
/*
override def children: Seq[Expression] = structExpr +: fieldOps.collect {
case e: Expression => e
} */

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)

override def dataType: StructType = StructType(newFields)

override def nullable: Boolean = structExpr.nullable

override def prettyName: String = "update_fields"

private lazy val newFieldExprs: Seq[(StructField, Expression)] = {
def getFieldExpr(i: Int): Expression = structExpr match {
case c: CreateNamedStruct => c.valExprs(i)
case _ => GetStructField(structExpr, i)
}
val fieldsWithIndex = structExpr.dataType.asInstanceOf[StructType].fields.zipWithIndex
val existingFieldExprs: Seq[(StructField, Expression)] =
fieldsWithIndex.map { case (field, i) => (field, getFieldExpr(i)) }
fieldOps.foldLeft(existingFieldExprs)((exprs, op) => op(exprs))
}

private lazy val newFields: Seq[StructField] = newFieldExprs.map(_._1)

lazy val newExprs: Seq[Expression] = newFieldExprs.map(_._2)

lazy val evalExpr: Expression = {
val createNamedStructExpr = CreateNamedStruct(newFieldExprs.flatMap {
case (field, expr) => Seq(Literal(field.name), expr)
})

if (structExpr.nullable) {
If(IsNull(structExpr), Literal(null, dataType), createNamedStructExpr)
} else {
createNamedStructExpr
}
}

override def eval(input: InternalRow): Any = evalExpr.eval(input)

}

object UpdateFields {
private def nameParts(fieldName: String): Seq[String] = {
require(fieldName != null, "fieldName cannot be null")

if (fieldName.isEmpty) {
fieldName :: Nil
} else {
UnresolvedAttribute.parseAttributeName(fieldName)
//CatalystSqlParser.parseMultipartIdentifier(fieldName)
}
}

/**
* Adds/replaces field of `StructType` into `col` expression by name.
*/
def apply(col: Expression, fieldName: String, expr: Expression): UpdateFields =
updateFieldsHelper(col, nameParts(fieldName), name => WithField(name, expr))

/**
* Drops fields of `StructType` in `col` expression by name.
*/
def apply(col: Expression, fieldName: String): UpdateFields =
updateFieldsHelper(col, nameParts(fieldName), name => DropField(name))

private def updateFieldsHelper(
structExpr: Expression,
namePartsRemaining: Seq[String],
valueFunc: String => StructFieldsOperation) : UpdateFields = {
val fieldName = namePartsRemaining.head
if (namePartsRemaining.length == 1) {
UpdateFields(Seq(structExpr, valueFunc(fieldName)))
} else {
val newStruct = if (structExpr.resolved) {
val resolver = SQLConf.get.resolver
ExtractValue(structExpr, Literal(fieldName), resolver)
} else {
UnresolvedExtractValue(structExpr, Literal(fieldName))
}

val newValue = updateFieldsHelper(
structExpr = newStruct,
namePartsRemaining = namePartsRemaining.tail,
valueFunc = valueFunc)
UpdateFields(Seq(structExpr, WithField(fieldName, newValue) ))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.sparkutils.qualityTests
import com.sparkutils.quality
import com.sparkutils.quality._
import functions._
import com.sparkutils.quality.impl.util.{Arrays, PrintCode}
import com.sparkutils.quality.impl.util.{Arrays, PrintCode, UpdateFields}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.IntegerType
Expand Down Expand Up @@ -664,8 +664,8 @@ class BaseFunctionalityTest extends FunSuite with RowTools with TestUtils {
assertsbc(og, 4, "wot")

// val updated = og.select(update_field(col("s"), ("b", update_field(col("s.b"), ("c", lit(40))))) as "s")
//val updated = og.select(add_struct_field("s", "b.c", lit(40)) as "s")
val updated = og.select(update_field(col("s"), ("b.c", lit(40)), ("b.d.e", lit("mate"))) as "s")
//val updated = og.select(new Column(UpdateFields.apply(UpdateFields.apply(col("s").expr, "b.c", lit(40).expr), "b.d.e", lit("mate").expr)) as "s")
assertsbc(updated, 40, "mate")
}
}
Expand Down

0 comments on commit c2d8e1a

Please sign in to comment.