Skip to content

Commit

Permalink
[SPARK-49966][SQL][FOLLOWUP] Codegen Support for JsonToStructs(`fro…
Browse files Browse the repository at this point in the history
…m_json`) - remove Invoke

### What changes were proposed in this pull request?
The pr aims to restore the codegen implementation of `JsonToStructs`(`from_json`) in the way of `manually`, rather than in the way of `Invoke`.

### Why are the changes needed?
Based on cloud-fan's double-check, apache#48509 (comment)
I believe that restore to manual implementation will not result in regression.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Update existed UT.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48979 from panbingkun/SPARK-49966_REMOVE_INVOKE.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
panbingkun authored and cloud-fan committed Nov 27, 2024
1 parent 0138019 commit 6edcf43
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ case class JsonToStructsEvaluator(
}

final def evaluate(json: UTF8String): Any = {
if (json == null) return null
nullableSchema match {
case _: VariantType =>
VariantExpressionEvalUtils.parseJson(json,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,8 @@ case class JsonToStructs(
timeZoneId: Option[String] = None,
variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS))
extends UnaryExpression
with RuntimeReplaceable
with ExpectsInputTypes
with TimeZoneAwareExpression
with ExpectsInputTypes
with QueryErrorsBase {

// The JSON input data might be missing certain fields. We force the nullability
Expand All @@ -648,7 +647,9 @@ case class JsonToStructs(

override def nullable: Boolean = true

override def nodePatternsInternal(): Seq[TreePattern] = Seq(JSON_TO_STRUCT, RUNTIME_REPLACEABLE)
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(JSON_TO_STRUCT)

override def nullIntolerant: Boolean = true

// Used in `FunctionRegistry`
def this(child: Expression, schema: Expression, options: Map[String, String]) =
Expand Down Expand Up @@ -682,6 +683,32 @@ case class JsonToStructs(
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

@transient
private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)

@transient
private lazy val evaluator = new JsonToStructsEvaluator(
options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys)

override def nullSafeEval(json: Any): Any = evaluator.evaluate(json.asInstanceOf[UTF8String])

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
val eval = child.genCode(ctx)
val resultType = CodeGenerator.boxedType(dataType)
val resultTerm = ctx.freshName("result")
ev.copy(code =
code"""
|${eval.code}
|$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value});
|boolean ${ev.isNull} = $resultTerm == null;
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
|""".stripMargin)
}

override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil

override def sql: String = schema match {
Expand All @@ -691,21 +718,6 @@ case class JsonToStructs(

override def prettyName: String = "from_json"

@transient
private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)

@transient
lazy val evaluator: JsonToStructsEvaluator = JsonToStructsEvaluator(
options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys)

override def replacement: Expression = Invoke(
Literal.create(evaluator, ObjectType(classOf[JsonToStructsEvaluator])),
"evaluate",
dataType,
Seq(child),
Seq(child.dataType)
)

override protected def withNewChildInternal(newChild: Expression): JsonToStructs =
copy(child = newChild)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
test("from_json escaping") {
val schema = StructType(StructField("\"quote", IntegerType) :: Nil)
GenerateUnsafeProjection.generate(
JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT).replacement :: Nil)
JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT) :: Nil)
}

test("from_json") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper {

Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""", null).foreach(v => {
val row = create_row(v)
checkEvaluation(e1, replace(e2).eval(row), row)
checkEvaluation(e1, e2.eval(row), row)
})
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0]
Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0]
Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0]
Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
val complexTypeFactory = JsonToStructs(attr.dataType,
ioschema.outputSerdeProps.toMap, Literal(null), Some(conf.sessionLocalTimeZone))
wrapperConvertException(data =>
complexTypeFactory.evaluator.evaluate(UTF8String.fromString(data)), any => any)
complexTypeFactory.nullSafeEval(UTF8String.fromString(data)), any => any)
case udt: UserDefinedType[_] =>
wrapperConvertException(data => udt.deserialize(data), converter)
case dt =>
Expand Down

0 comments on commit 6edcf43

Please sign in to comment.