[SPARK-48821][SQL] Support Update in DataFrameWriterV2
szehon-ho committed Aug 6, 2024
1 parent c3985ac commit 056492b
Showing 4 changed files with 208 additions and 1 deletion.
@@ -301,7 +301,16 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.PartitionTransform$ExtractTransform")) ++
"org.apache.spark.sql.PartitionTransform$ExtractTransform"),

// Update Writer
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SparkSession.update"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWriter"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWriter$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithAssignment"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithAssignment$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithCondition"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithCondition$")) ++
mergeIntoWriterExcludeRules

checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
31 changes: 31 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution._
@@ -832,6 +833,36 @@ class SparkSession private(
ret
}

/**
* Update rows in a table that match a condition.
*
* Scala Example:
* {{{
* spark.update("source")
* .set(
* Map("salary" -> lit(200))
* )
* .where($"salary" === 100)
* .execute()
*
* }}}
* @param tableName is either a qualified or unqualified name that designates a table or view.
*                  If a database is specified, it identifies the table/view from that database.
*                  Otherwise, it first attempts to find a temporary view with the given name
*                  and then matches the table/view from the current database.
*                  Note that the global temporary view database is also valid here.
* @since 4.0.0
*/
def update(tableName: String): UpdateWriter = {
val tableDF = table(tableName)
if (tableDF.isStreaming) {
throw new AnalysisException(
errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
messageParameters = Map("methodName" -> toSQLId("update")))
}
new UpdateWriter(tableDF)
}

// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
/**
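Taken together, the new entry point reads like the existing DataFrame APIs: `set` is required, `where` is optional, and `execute` runs the statement eagerly. A minimal usage sketch (the table and column names below are illustrative and not part of this change; the target must be a v2 table whose catalog supports row-level UPDATE):

import org.apache.spark.sql.functions.{col, lit}

// Assumes `spark` is an existing SparkSession and `people` is a table in a
// catalog that supports row-level UPDATE (illustrative name).
spark.update("people")
  .set(Map("status" -> lit("inactive")))          // columns to overwrite
  .where(col("last_login") < lit("2023-01-01"))   // optional; omit to update every row
  .execute()                                      // builds an UpdateTable plan and runs it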
102 changes: 102 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
import org.apache.spark.sql.functions.expr

/**
* `UpdateWriter` provides methods to define and execute an update action on a target table.
*
* @param tableDF the DataFrame representing the table to update.
*
* @since 4.0.0
*/
@Experimental
class UpdateWriter(tableDF: DataFrame) {

/**
* Specifies the new values to assign to rows matched by this update.
*
* @param assignments A Map of column names to Column expressions representing the updates
* to be applied.
*/
def set(assignments: Map[String, Column]): UpdateWithAssignment = {
new UpdateWithAssignment(tableDF, assignments)
}
}

/**
* A class for defining a condition on an update operation or directly executing it.
*
* @param tableDF the DataFrame representing the table to update.
* @param assignment A Map of column names to Column expressions representing the updates
* to be applied.
*
* @since 4.0.0
*/
@Experimental
class UpdateWithAssignment(tableDF: DataFrame, assignment: Map[String, Column]) {

/**
* Limits the update to rows matching the specified condition.
*
* @param condition the update condition
* @return an `UpdateWithCondition` that executes the update only on rows matching the condition
*/
def where(condition: Column): UpdateWithCondition = {
new UpdateWithCondition(tableDF, assignment, Some(condition))
}

/**
* Executes the update operation.
*/
def execute(): Unit = {
// With no condition, the update applies to every row of the table.
new UpdateWithCondition(tableDF, assignment, None).execute()
}
}

/**
* A class for executing an update operation.
*
* @param tableDF the DataFrame representing the table to update.
* @param assignments A Map of column names to Column expressions representing the updates
* to be applied.
* @param condition the update condition
* @since 4.0.0
*/
@Experimental
class UpdateWithCondition(
tableDF: DataFrame,
assignments: Map[String, Column],
condition: Option[Column]) {

private val sparkSession = tableDF.sparkSession
private val logicalPlan = tableDF.queryExecution.logical

/**
* Executes the update operation.
*/
def execute(): Unit = {
val update = UpdateTable(
logicalPlan,
assignments.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
condition.map(_.expr))
val qe = sparkSession.sessionState.executePlan(update)
qe.assertCommandExecuted()
}
}
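The builder deliberately stays thin: `set` captures the assignments, `where` captures an optional filter, and `execute` hands an `UpdateTable` logical plan to the session for eager execution. As a rough sketch of the correspondence (table and column names are illustrative, not taken from this change), the chained form should resolve to the same plan as the SQL statement:

import org.apache.spark.sql.functions.{col, lit}

// DataFrame-style update on a hypothetical `employees` table:
spark.update("employees")
  .set(Map("salary" -> lit(-1)))
  .where(col("pk") >= 2)
  .execute()

// ...which is expected to produce the same UpdateTable plan as the SQL form:
spark.sql("UPDATE employees SET salary = -1 WHERE pk >= 2")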
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {

import testImplicits._

test("Basic Update") {
createAndInitTable("pk INT, salary INT, dep STRING",
"""{ "pk": 1, "salary": 300, "dep": 'hr' }
|{ "pk": 2, "salary": 150, "dep": 'software' }
|{ "pk": 3, "salary": 120, "dep": 'hr' }
|""".stripMargin)

spark.update(tableNameAsString)
.set(Map("salary" -> lit(-1)))
.where($"pk" >= 2)
.execute()

checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(
Row(1, 300, "hr"),
Row(2, -1, "software"),
Row(3, -1, "hr")))
}

test("Update without where clause") {
createAndInitTable("pk INT, salary INT, dep STRING",
"""{ "pk": 1, "salary": 300, "dep": 'hr' }
|{ "pk": 2, "salary": 150, "dep": 'software' }
|{ "pk": 3, "salary": 120, "dep": 'hr' }
|""".stripMargin)

spark.update(tableNameAsString)
.set(Map("dep" -> lit("software")))
.execute()

checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(
Row(1, 300, "software"),
Row(2, 150, "software"),
Row(3, 120, "software")))
}
}
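One behavior the suite above does not exercise: SparkSession.update rejects streaming Datasets before any writer is constructed. A hedged, test-style sketch of that failure mode (the view name is hypothetical; `intercept` comes from ScalaTest, and the error class is the one raised by the new guard in SparkSession.update):

import org.apache.spark.sql.AnalysisException

// Assuming `events` has been registered as a *streaming* temporary view,
// the call should fail fast rather than return an UpdateWriter.
val e = intercept[AnalysisException] {
  spark.update("events")
}
assert(e.getErrorClass == "CALL_ON_STREAMING_DATASET_UNSUPPORTED")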
