From 64fbb9ee9402e579a25ea589c9dab9ebb3627db6 Mon Sep 17 00:00:00 2001 From: Suraj Aralihalli Date: Fri, 9 Feb 2024 16:06:12 -0800 Subject: [PATCH 1/2] use getJson Options Signed-off-by: Suraj Aralihalli --- .../scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala | 7 ++++--- .../main/scala/com/nvidia/spark/rapids/GpuJsonTuple.scala | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala index a113555d356..e15d8f90d74 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import ai.rapids.cudf.ColumnVector +import ai.rapids.cudf.{ColumnVector,GetJsonObjectOptions} import com.nvidia.spark.rapids.Arm.withResource import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} @@ -32,8 +32,9 @@ case class GpuGetJsonObject(json: Expression, path: Expression) override def nullable: Boolean = true override def prettyName: String = "get_json_object" - override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { - lhs.getBase().getJSONObject(rhs.getBase) + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { + lhs.getBase().getJSONObject(rhs.getBase, + GetJsonObjectOptions.builder().allowSingleQuotes(true).build()); } override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuJsonTuple.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuJsonTuple.scala index 0b6c839ca2b..ae539820331 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuJsonTuple.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuJsonTuple.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import ai.rapids.cudf.Scalar +import ai.rapids.cudf.{GetJsonObjectOptions,Scalar} import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRetry} @@ -69,7 +69,8 @@ case class GpuJsonTuple(children: Seq[Expression]) extends GpuGenerator } withResource(fieldScalars) { fieldScalars => - withResource(fieldScalars.safeMap(field => json.getJSONObject(field))) { resultCols => + withResource(fieldScalars.safeMap(field => json.getJSONObject(field, + GetJsonObjectOptions.builder().allowSingleQuotes(true).build()))) { resultCols => val generatorCols = resultCols.safeMap(_.incRefCount).zip(schema).safeMap { case (col, dataType) => GpuColumnVector.from(col, dataType) } From f0d392f234f6cda6e919642467fcbc862c8037e9 Mon Sep 17 00:00:00 2001 From: Suraj Aralihalli Date: Mon, 12 Feb 2024 13:40:49 -0800 Subject: [PATCH 2/2] add single quote get_json_object tests Signed-off-by: Suraj Aralihalli --- integration_tests/src/main/python/get_json_test.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/integration_tests/src/main/python/get_json_test.py b/integration_tests/src/main/python/get_json_test.py index 970617709a8..d7473119b13 100644 --- a/integration_tests/src/main/python/get_json_test.py +++ b/integration_tests/src/main/python/get_json_test.py @@ -50,6 +50,19 @@ def test_get_json_object_quoted_index(): f.get_json_object('jsonStr',r'''$['b']''').alias('sub_b')), conf={'spark.rapids.sql.expression.GetJsonObject': 'true'}) +def test_get_json_object_single_quotes(): + schema = StructType([StructField("jsonStr", StringType())]) + data = [[r'''{'a':'A'}'''], + [r'''{'b':'"B'}'''], + [r'''{"c":"'C"}''']] + + assert_gpu_and_cpu_are_equal_collect( + lambda spark: spark.createDataFrame(data,schema=schema).select( + f.get_json_object('jsonStr',r'''$['a']''').alias('sub_a'), + f.get_json_object('jsonStr',r'''$['b']''').alias('sub_b'), + f.get_json_object('jsonStr',r'''$['c']''').alias('sub_c')), + conf={'spark.rapids.sql.expression.GetJsonObject': 'true'}) + @pytest.mark.parametrize('query',["$.store.bicycle", "$['store'].bicycle", "$.store['bicycle']",