Skip to content

Commit

Permalink
[SPARK-49137][SQL] When the Boolean condition in the if statement i…
Browse files Browse the repository at this point in the history
…s invalid, an exception should be thrown

### What changes were proposed in this pull request?
The pr aims to throw an exception to the end-user when the `Boolean condition` in the `if statement` is unexpected, instead of quietly returning `false`.

### Why are the changes needed?
Reduce unexpected behavior when end-users make errors in writing Boolean statements.

### Does this PR introduce _any_ user-facing change?
Yes, when the Boolean condition in the if statement is illegal, an exception should be thrown instead of returning false directly.

### How was this patch tested?
Add new UT.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47648 from panbingkun/SPARK-49137.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
panbingkun authored and dongjoon-hyun committed Aug 11, 2024
1 parent 11e0d2d commit 3ab97c1
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 7 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,12 @@
],
"sqlState" : "22003"
},
"INVALID_BOOLEAN_STATEMENT" : {
"message" : [
"Boolean statement is expected in the condition, but <invalidStatement> was found."
],
"sqlState" : "22546"
},
"INVALID_BOUNDARY" : {
"message" : [
"The boundary <boundary> is invalid: <invalidValue>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.errors

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLStmt
import org.apache.spark.sql.exceptions.SqlScriptingException

/**
Expand Down Expand Up @@ -63,4 +64,14 @@ private[sql] object SqlScriptingErrors {
cause = null,
messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber))
}

def invalidBooleanStatement(
origin: Origin,
stmt: String): Throwable = {
new SqlScriptingException(
origin = origin,
errorClass = "INVALID_BOOLEAN_STATEMENT",
cause = null,
messageParameters = Map("invalidStatement" -> toSQLStmt(stmt)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.errors.SqlScriptingErrors
import org.apache.spark.sql.types.BooleanType

/**
Expand Down Expand Up @@ -62,7 +63,10 @@ trait NonLeafStatementExec extends CompoundStatementExec {
* Evaluate the boolean condition represented by the statement.
* @param session SparkSession that SQL script is executed within.
* @param statement Statement representing the boolean condition to evaluate.
* @return Whether the condition evaluates to True.
* @return
* The value (`true` or `false`) of condition evaluation;
* or throw the error during the evaluation (eg: returning multiple rows of data
* or non-boolean statement).
*/
protected def evaluateBooleanCondition(
session: SparkSession,
Expand All @@ -78,11 +82,15 @@ trait NonLeafStatementExec extends CompoundStatementExec {
case Array(field) if field.dataType == BooleanType =>
df.limit(2).collect() match {
case Array(row) => row.getBoolean(0)
case _ => false
case _ =>
throw SparkException.internalError(
s"Boolean statement ${statement.getText} is invalid. It returns more than one row.")
}
case _ => false
case _ =>
throw SqlScriptingErrors.invalidBooleanStatement(statement.origin, statement.getText)
}
case _ => false
case _ =>
throw SparkException.internalError("Boolean condition must be SingleStatementExec")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.sql.scripting

import org.apache.spark.sql.{AnalysisException, Dataset, QueryTest, Row}
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row}
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.test.SharedSparkSession

/**
Expand All @@ -29,11 +31,11 @@ import org.apache.spark.sql.test.SharedSparkSession
*/
class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
// Helpers
private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = {
private def runSqlScript(sqlText: String): Array[DataFrame] = {
val interpreter = SqlScriptingInterpreter()
val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText)
val executionPlan = interpreter.buildExecutionPlan(compoundBody, spark)
val result = executionPlan.flatMap {
executionPlan.flatMap {
case statement: SingleStatementExec =>
if (statement.isExecuted) {
None
Expand All @@ -42,7 +44,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
}
case _ => None
}.toArray
}

private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = {
val result = runSqlScript(sqlText)
assert(result.length == expected.length)
result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) }
}
Expand Down Expand Up @@ -362,4 +367,48 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
verifySqlScriptResult(commands, expected)
}
}

test("if's condition must be a boolean statement") {
withTable("t") {
val commands =
"""
|BEGIN
| IF 1 THEN
| SELECT 45;
| END IF;
|END
|""".stripMargin
checkError(
exception = intercept[SqlScriptingException] (
runSqlScript(commands)
),
errorClass = "INVALID_BOOLEAN_STATEMENT",
parameters = Map("invalidStatement" -> "1")
)
}
}

test("if's condition must return a single row data") {
withTable("t") {
val commands =
"""
|BEGIN
| CREATE TABLE t (a BOOLEAN) USING parquet;
| INSERT INTO t VALUES (true);
| INSERT INTO t VALUES (true);
| IF (select * from t) THEN
| SELECT 46;
| END IF;
|END
|""".stripMargin
checkError(
exception = intercept[SparkException] (
runSqlScript(commands)
),
errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
parameters = Map.empty,
context = ExpectedContext(fragment = "(select * from t)", start = 118, stop = 134)
)
}
}
}

0 comments on commit 3ab97c1

Please sign in to comment.