[SPARK-37627][SQL][FOLLOWUP] Separate SortedBucketTransform from BucketTransform

### What changes were proposed in this pull request?

1. Currently `BucketTransform` only supports a single bucket column; fix the code so that multiple bucket columns work.
2. Separate `SortedBucketTransform` from `BucketTransform`, and lay out the `arguments` of `SortedBucketTransform` in the format `columns numBuckets sortedColumns`, so that the `columns` and `sortedColumns` can be recovered from the arguments (see the sketch after this list).
3. Add more test coverage.
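
As a rough illustration of that argument layout, here is a minimal standalone sketch. `SortedBucketArgLayout`, `Arg`, `Col`, and `IntLit` are made-up stand-ins for this example, not Spark's `Expression`/`NamedReference`/`Literal` types; the point is only that the single integer literal in the flattened argument list is what separates the bucket columns from the sort columns:

```scala
// Illustrative sketch only -- these types are stand-ins, not the Spark classes.
object SortedBucketArgLayout {
  sealed trait Arg
  final case class Col(name: String) extends Arg    // stands in for a NamedReference
  final case class IntLit(value: Int) extends Arg   // stands in for Literal[Int] (numBuckets)

  // sorted_bucket arguments are flattened as: columns ++ Seq(numBuckets) ++ sortedColumns.
  // Locate the integer literal, then split the list around it.
  def split(args: Seq[Arg]): (Int, Seq[Col], Seq[Col]) = {
    val pos = args.indexWhere(_.isInstanceOf[IntLit])   // assumes exactly one literal is present
    val numBuckets = args(pos).asInstanceOf[IntLit].value
    val columns = args.take(pos).collect { case c: Col => c }
    val sortedColumns = args.drop(pos + 1).collect { case c: Col => c }
    (numBuckets, columns, sortedColumns)
  }

  def main(argv: Array[String]): Unit = {
    // sorted_bucket(a, b, 16, c, d)  ==>  numBuckets = 16, columns = [a, b], sortedColumns = [c, d]
    val (n, cols, sortCols) =
      split(Seq(Col("a"), Col("b"), IntLit(16), Col("c"), Col("d")))
    println(s"numBuckets=$n columns=$cols sortedColumns=$sortCols")
  }
}
```

This mirrors the approach taken in the new `BucketTransform.unapply` for the `sorted_bucket` case in the diff below, which finds the position of the integer literal and splits the arguments around it.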

### Why are the changes needed?

Fix bugs in `BucketTransform` and `SortedBucketTransform`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
New tests
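
For reference, a hedged sketch of the DDL path the new tests exercise. The column names, the `testcat` catalog, and the `foo` provider mirror the `DataSourceV2SQLSuite` changes below; `SortedBucketDdlSketch` and the local-mode session setup are illustrative additions, and the sketch assumes Spark's catalyst test jar is on the classpath so that the in-memory V2 catalog class is available:

```scala
import org.apache.spark.sql.SparkSession

object SortedBucketDdlSketch {
  def main(args: Array[String]): Unit = {
    // Register an in-memory V2 catalog as `testcat`, the way DataSourceV2SQLSuite does.
    // InMemoryTableCatalog lives in Spark's catalyst test jar.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("sorted-bucket-sketch")
      .config("spark.sql.catalog.testcat",
        "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
      .getOrCreate()

    spark.sql(
      """CREATE TABLE testcat.table_name (a INT, b STRING, c INT, d INT, e INT, f INT)
        |USING foo
        |PARTITIONED BY (a, b)
        |CLUSTERED BY (c, d) SORTED BY (e, f) INTO 4 BUCKETS""".stripMargin)

    // DESCRIBE now reports the bucketing as a separate sorted_bucket transform, e.g.:
    //   Part 0 -> a
    //   Part 1 -> b
    //   Part 2 -> sorted_bucket(c, d, 4, e, f)
    spark.sql("DESCRIBE testcat.table_name").show(truncate = false)

    spark.stop()
  }
}
```

With the two transforms separated, `DESCRIBE` reports a plain `bucket(...)` entry when only `CLUSTERED BY` is given and a `sorted_bucket(...)` entry when `SORTED BY` is also present.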

Closes #34914 from huaxingao/sorted_followup.

Lead-authored-by: Huaxin Gao <[email protected]>
Co-authored-by: huaxingao <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
2 people authored and cloud-fan committed Jan 14, 2022
1 parent 31d8489 commit 2ed827a
Showing 6 changed files with 144 additions and 49 deletions.
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.connector.expressions.{IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors

/**
@@ -37,7 +37,7 @@ private[sql] object CatalogV2Implicits {
}

implicit class BucketSpecHelper(spec: BucketSpec) {
def asTransform: BucketTransform = {
def asTransform: Transform = {
val references = spec.bucketColumnNames.map(col => reference(Seq(col)))
if (spec.sortColumnNames.nonEmpty) {
val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col)))
@@ -17,6 +17,7 @@

package org.apache.spark.sql.connector.expressions

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
@@ -48,8 +49,8 @@ private[sql] object LogicalExpressions {
def bucket(
numBuckets: Int,
references: Array[NamedReference],
sortedCols: Array[NamedReference]): BucketTransform =
BucketTransform(literal(numBuckets, IntegerType), references, sortedCols)
sortedCols: Array[NamedReference]): SortedBucketTransform =
SortedBucketTransform(literal(numBuckets, IntegerType), references, sortedCols)

def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference)

@@ -101,8 +102,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R

private[sql] final case class BucketTransform(
numBuckets: Literal[Int],
columns: Seq[NamedReference],
sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {
columns: Seq[NamedReference]) extends RewritableTransform {

override val name: String = "bucket"

@@ -112,46 +112,62 @@ private[sql] final case class BucketTransform(

override def arguments: Array[Expression] = numBuckets +: columns.toArray

override def toString: String =
if (sortedColumns.nonEmpty) {
s"bucket(${arguments.map(_.describe).mkString(", ")}," +
s" ${sortedColumns.map(_.describe).mkString(", ")})"
} else {
s"bucket(${arguments.map(_.describe).mkString(", ")})"
}
override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"

override def toString: String = describe

override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences)
}
}

private[sql] object BucketTransform {
def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] =
expr match {
case transform: Transform =>
def unapply(transform: Transform): Option[(Int, Seq[NamedReference], Seq[NamedReference])] =
transform match {
case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) =>
Some((n, FieldReference(parts), FieldReference(sortCols)))
case NamedTransform("sorted_bucket", arguments) =>
var posOfLit: Int = -1
var numOfBucket: Int = -1
arguments.zipWithIndex.foreach {
case (Lit(value: Int, IntegerType), i) =>
numOfBucket = value
posOfLit = i
case _ =>
None
}
Some(numOfBucket, arguments.take(posOfLit).map(_.asInstanceOf[NamedReference]),
arguments.drop(posOfLit + 1).map(_.asInstanceOf[NamedReference]))
case NamedTransform("bucket", arguments) =>
var numOfBucket: Int = -1
arguments(0) match {
case Lit(value: Int, IntegerType) =>
numOfBucket = value
case _ => throw new SparkException("The first element in BucketTransform arguments " +
"should be an Integer Literal.")
}
Some(numOfBucket, arguments.drop(1).map(_.asInstanceOf[NamedReference]),
Seq.empty[FieldReference])
case _ =>
None
}
}

def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] =
transform match {
case NamedTransform("bucket", Seq(
Lit(value: Int, IntegerType),
Ref(partCols: Seq[String]),
Ref(sortCols: Seq[String]))) =>
Some((value, FieldReference(partCols), FieldReference(sortCols)))
case NamedTransform("bucket", Seq(
Lit(value: Int, IntegerType),
Ref(partCols: Seq[String]))) =>
Some((value, FieldReference(partCols), FieldReference(Seq.empty[String])))
case _ =>
None
private[sql] final case class SortedBucketTransform(
numBuckets: Literal[Int],
columns: Seq[NamedReference],
sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {

override val name: String = "sorted_bucket"

override def references: Array[NamedReference] = {
arguments.collect { case named: NamedReference => named }
}

override def arguments: Array[Expression] = (columns.toArray :+ numBuckets) ++ sortedColumns

override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"

override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences.take(columns.length),
sortedColumns = newReferences.drop(columns.length))
}
}

@@ -80,6 +80,7 @@ class InMemoryTable(
case _: DaysTransform =>
case _: HoursTransform =>
case _: BucketTransform =>
case _: SortedBucketTransform =>
case t if !allowUnsupportedTransforms =>
throw new IllegalArgumentException(s"Transform $t is not a supported transform")
}
@@ -161,10 +162,15 @@ class InMemoryTable(
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
case BucketTransform(numBuckets, ref, _) =>
val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row)
val valueHashCode = if (value == null) 0 else value.hashCode
((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets
case BucketTransform(numBuckets, cols, _) =>
val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
var valueHashCode = 0
valueTypePairs.foreach( pair =>
if ( pair._1 != null) valueHashCode += pair._1.hashCode()
)
var dataTypeHashCode = 0
valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
}
}

@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket
import org.apache.spark.sql.types.DataType

class TransformExtractorSuite extends SparkFunSuite {
@@ -139,9 +140,9 @@
}

bucketTransform match {
case BucketTransform(numBuckets, FieldReference(seq), _) =>
case BucketTransform(numBuckets, cols, _) =>
assert(numBuckets === 16)
assert(seq === Seq("a", "b"))
assert(cols(0).fieldNames === Seq("a", "b"))
case _ =>
fail("Did not match BucketTransform extractor")
}
@@ -153,4 +154,61 @@
// expected
}
}

test("Sorted Bucket extractor") {
val col = Array(ref("a"), ref("b"))
val sortedCol = Array(ref("c"), ref("d"))

val sortedBucketTransform = new Transform {
override def name: String = "sorted_bucket"
override def references: Array[NamedReference] = col ++ sortedCol
override def arguments: Array[Expression] = (col :+ lit(16)) ++ sortedCol
override def describe: String = s"bucket(16, ${col(0).describe}, ${col(1).describe} " +
s"${sortedCol(0).describe} ${sortedCol(1).describe})"
}

sortedBucketTransform match {
case BucketTransform(numBuckets, cols, sortCols) =>
assert(numBuckets === 16)
assert(cols.flatMap(c => c.fieldNames()) === Seq("a", "b"))
assert(sortCols.flatMap(c => c.fieldNames()) === Seq("c", "d"))
case _ =>
fail("Did not match BucketTransform extractor")
}
}

test("test bucket") {
val col = Array(ref("a"), ref("b"))
val sortedCol = Array(ref("c"), ref("d"))

val bucketTransform = bucket(16, col)
val reference1 = bucketTransform.references
assert(reference1.length == 2)
assert(reference1(0).fieldNames() === Seq("a"))
assert(reference1(1).fieldNames() === Seq("b"))
val arguments1 = bucketTransform.arguments
assert(arguments1.length == 3)
assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16)
assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
assert(arguments1(2).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
val copied1 = bucketTransform.withReferences(reference1)
assert(copied1.equals(bucketTransform))

val sortedBucketTransform = bucket(16, col, sortedCol)
val reference2 = sortedBucketTransform.references
assert(reference2.length == 4)
assert(reference2(0).fieldNames() === Seq("a"))
assert(reference2(1).fieldNames() === Seq("b"))
assert(reference2(2).fieldNames() === Seq("c"))
assert(reference2(3).fieldNames() === Seq("d"))
val arguments2 = sortedBucketTransform.arguments
assert(arguments2.length == 5)
assert(arguments2(0).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
assert(arguments2(2).asInstanceOf[LiteralValue[Integer]].value === 16)
assert(arguments2(3).asInstanceOf[NamedReference].fieldNames() === Seq("c"))
assert(arguments2(4).asInstanceOf[NamedReference].fieldNames() === Seq("d"))
val copied2 = sortedBucketTransform.withReferences(reference2)
assert(copied2.equals(sortedBucketTransform))
}
}
@@ -342,8 +342,13 @@ private[sql] object V2SessionCatalog {
case IdentityTransform(FieldReference(Seq(col))) =>
identityCols += col

case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) =>
bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil))
case BucketTransform(numBuckets, col, sortCol) =>
if (sortCol.isEmpty) {
bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), Nil))
} else {
bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")),
sortCol.map(_.fieldNames.mkString("."))))
}

case transform =>
throw QueryExecutionErrors.unsupportedPartitionTransformError(transform)
@@ -410,9 +410,12 @@ class DataSourceV2SQLSuite
test("SPARK-36850: CreateTableAsSelect partitions can be specified using " +
"PARTITIONED BY and/or CLUSTERED BY") {
val identifier = "testcat.table_name"
val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"),
(3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4")
df.createOrReplaceTempView("source_table")
withTable(identifier) {
spark.sql(s"CREATE TABLE $identifier USING foo PARTITIONED BY (id) " +
s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source")
s"CLUSTERED BY (data1, data2, data3, data4) INTO 4 BUCKETS AS SELECT * FROM source_table")
val describe = spark.sql(s"DESCRIBE $identifier")
val part1 = describe
.filter("col_name = 'Part 0'")
@@ -421,18 +424,22 @@
val part2 = describe
.filter("col_name = 'Part 1'")
.select("data_type").head.getString(0)
assert(part2 === "bucket(4, data)")
assert(part2 === "bucket(4, data1, data2, data3, data4)")
}
}

test("SPARK-36850: ReplaceTableAsSelect partitions can be specified using " +
"PARTITIONED BY and/or CLUSTERED BY") {
val identifier = "testcat.table_name"
val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"),
(3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4")
df.createOrReplaceTempView("source_table")
withTable(identifier) {
spark.sql(s"CREATE TABLE $identifier USING foo " +
"AS SELECT id FROM source")
spark.sql(s"REPLACE TABLE $identifier USING foo PARTITIONED BY (id) " +
s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source")
s"CLUSTERED BY (data1, data2) SORTED by (data3, data4) INTO 4 BUCKETS " +
s"AS SELECT * FROM source_table")
val describe = spark.sql(s"DESCRIBE $identifier")
val part1 = describe
.filter("col_name = 'Part 0'")
@@ -441,7 +448,7 @@
val part2 = describe
.filter("col_name = 'Part 1'")
.select("data_type").head.getString(0)
assert(part2 === "bucket(4, data)")
assert(part2 === "sorted_bucket(data1, data2, 4, data3, data4)")
}
}

@@ -1479,18 +1486,21 @@
test("create table using - with sorted bucket") {
val identifier = "testcat.table_name"
withTable(identifier) {
sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" +
s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS")
val table = getTableMetadata(identifier)
sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" +
s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS")
val describe = spark.sql(s"DESCRIBE $identifier")
val part1 = describe
.filter("col_name = 'Part 0'")
.select("data_type").head.getString(0)
assert(part1 === "c")
assert(part1 === "a")
val part2 = describe
.filter("col_name = 'Part 1'")
.select("data_type").head.getString(0)
assert(part2 === "bucket(4, b, a)")
assert(part2 === "b")
val part3 = describe
.filter("col_name = 'Part 2'")
.select("data_type").head.getString(0)
assert(part3 === "sorted_bucket(c, d, 4, e, f)")
}
}

