[SPARK-31563][SQL] Fix failure of InSet.sql for collections of Catalyst's internal types

### What changes were proposed in this pull request?
In the PR, I propose to fix the `InSet.sql` method for the cases when the input collection contains values of Catalyst's internal types, for instance `UTF8String`. Elements of the input set `hset` are converted to external Scala types and wrapped in `Literal` to properly form the SQL view of the input collection.
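
As a rough illustration of the conversion step, here is a minimal sketch (not the patch itself; it hard-codes `StringType` and a hypothetical set of `UTF8String` elements, whereas the real method passes `child.dataType`):

```scala
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// The IN-list of an InSet holds Catalyst-internal values,
// e.g. UTF8String rather than java.lang.String.
val hset: Set[Any] = Set("a", "b").map(UTF8String.fromString)

// Convert each element back to its external Scala type before building a Literal,
// so that Literal(...) receives a plain String it knows how to render.
val listSQL = hset.toSeq
  .map(elem => Literal(convertToScala(elem, StringType)).sql)
  .mkString(", ")
// listSQL: "'a', 'b'" (element order follows the set's iteration order)
```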

### Why are the changes needed?
The changes fix a bug in `InSet.sql` that made a wrong assumption about the types of the collection elements. See SPARK-31563 for more details.
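
For reference, the failing case can be reproduced roughly as follows (a sketch; the exact error depends on the Spark version, but before this patch `Literal(_)` did not accept Catalyst-internal values such as `UTF8String`, so `sql` threw instead of returning the string shown below):

```scala
import org.apache.spark.sql.catalyst.expressions.{InSet, Literal}
import org.apache.spark.unsafe.types.UTF8String

// An optimized IN list carries its elements in internal form (UTF8String).
val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))

// Before the fix this call failed; with the fix it returns "('a' IN ('a', 'b'))".
inSet.sql
```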

### Does this PR introduce any user-facing change?
Most likely, no.

### How was this patch tested?
Added a test to `ColumnExpressionSuite`.

Closes #28343 from MaxGekk/fix-InSet-sql.

Authored-by: Max Gekk <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
MaxGekk authored and dongjoon-hyun committed Apr 25, 2020
1 parent ab8cada commit 7d8216a
Showing 2 changed files with 11 additions and 2 deletions.
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.immutable.TreeSet
 
+import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
@@ -519,7 +520,9 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
 
   override def sql: String = {
     val valueSQL = child.sql
-    val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ")
+    val listSQL = hset.toSeq
+      .map(elem => Literal(convertToScala(elem, child.dataType)).sql)
+      .mkString(", ")
     s"($valueSQL IN ($listSQL))"
   }
 }
@@ -26,12 +26,13 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.Matchers._
 
-import org.apache.spark.sql.catalyst.expressions.{In, InSet, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{In, InSet, Literal, NamedExpression}
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -869,4 +870,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       df.select(typedLit(("a", 2, 1.0))),
       Row(Row("a", 2, 1.0)) :: Nil)
   }
+
+  test("SPARK-31563: sql of InSet for UTF8String collection") {
+    val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))
+    assert(inSet.sql === "('a' IN ('a', 'b'))")
+  }
 }
