Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed May 8, 2015
1 parent 7ea5b31 commit 715c589
Show file tree
Hide file tree
Showing 13 changed files with 96 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val primary: PackratParser[Expression] =
( literal
| expression ~ ("[" ~> expression <~ "]") ^^
{ case base ~ ordinal => UnresolvedGetField(base, ordinal) }
{ case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
| (expression <~ ".") ~ ident ^^
{ case base ~ fieldName => UnresolvedGetField(base, Literal(fieldName)) }
{ case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
| cast
| "(" ~> expression <~ ")"
| function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ class Analyzer(
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedGetField(child, fieldExpr) if child.resolved =>
GetField(child, fieldExpr, resolver)
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
ExtractValue(child, fieldExpr, resolver)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,16 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
}

/**
* Get field of an expression
* Extracts a value or values from an Expression
*
* @param child The expression to get field of, can be Map, Array, Struct or array of Struct.
* @param fieldExpr The expression to describe the field,
* can be key of Map, index of Array, field name of Struct.
* @param child The expression to extract value from,
* can be Map, Array, Struct or array of Structs.
* @param extraction The expression to describe the extraction,
* can be key of Map, index of Array, field name of Struct.
*/
case class UnresolvedGetField(child: Expression, fieldExpr: Expression) extends UnaryExpression {
case class UnresolvedExtractValue(child: Expression, extraction: Expression)
extends UnaryExpression {

override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
Expand All @@ -200,5 +203,5 @@ case class UnresolvedGetField(child: Expression, fieldExpr: Expression) extends
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString: String = s"$child[$fieldExpr]"
override def toString: String = s"$child[$extraction]"
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
Expand Down Expand Up @@ -100,9 +100,9 @@ package object dsl {
def isNull: Predicate = IsNull(expr)
def isNotNull: Predicate = IsNotNull(expr)

def getItem(ordinal: Expression): UnresolvedGetField = UnresolvedGetField(expr, ordinal)
def getField(fieldName: String): UnresolvedGetField =
UnresolvedGetField(expr, Literal(fieldName))
def getItem(ordinal: Expression): UnresolvedExtractValue = UnresolvedExtractValue(expr, ordinal)
def getField(fieldName: String): UnresolvedExtractValue =
UnresolvedExtractValue(expr, Literal(fieldName))

def cast(to: DataType): Expression = Cast(expr, to)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,50 +23,50 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._

object GetField {
object ExtractValue {
/**
* Returns the resolved `GetField`. It will return one kind of concrete `GetField`,
* depend on the type of `child` and `fieldExpr`.
* Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`,
* depending on the types of `child` and `extraction`.
*
* `child` | `fieldExpr` | concrete `GetField`
* -------------------------------------------------------------
* Struct | Literal String | SimpleStructGetField
* Array[Struct] | Literal String | ArrayStructGetField
* Array | Integral type | ArrayOrdinalGetField
* Map | Any type | MapOrdinalGetField
* `child` | `extraction` | concrete `ExtractValue`
* ----------------------------------------------------------------
* Struct | Literal String | GetStructField
* Array[Struct] | Literal String | GetArrayStructFields
* Array | Integral type | GetArrayItem
* Map | Any type | GetMapValue
*/
def apply(
child: Expression,
fieldExpr: Expression,
resolver: Resolver): GetField = {
extraction: Expression,
resolver: Resolver): ExtractValue = {

(child.dataType, fieldExpr) match {
(child.dataType, extraction) match {
case (StructType(fields), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
SimpleStructGetField(child, fields(ordinal), ordinal)
GetStructField(child, fields(ordinal), ordinal)
case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
ArrayStructGetField(child, fields(ordinal), ordinal, containsNull)
case (_: ArrayType, _) if fieldExpr.dataType.isInstanceOf[IntegralType] =>
ArrayOrdinalGetField(child, fieldExpr)
GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
case (_: MapType, _) =>
MapOrdinalGetField(child, fieldExpr)
GetMapValue(child, extraction)
case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
s"Field name should be String Literal, but it's $fieldExpr"
s"Field name should be String Literal, but it's $extraction"
case _: ArrayType =>
s"Array index should be integral type, but it's ${fieldExpr.dataType}"
s"Array index should be integral type, but it's ${extraction.dataType}"
case other =>
s"Can't get field on $child"
s"Can't extract value from $child"
}
throw new AnalysisException(errorMsg)
}
}

def unapply(g: GetField): Option[(Expression, Expression)] = {
def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
g match {
case o: OrdinalGetField => Some((o.child, o.ordinal))
case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
case _ => Some((g.child, null))
}
}
Expand All @@ -90,7 +90,7 @@ object GetField {
}
}

trait GetField extends UnaryExpression {
trait ExtractValue extends UnaryExpression {
self: Product =>

type EvaluatedType = Any
Expand All @@ -99,8 +99,8 @@ trait GetField extends UnaryExpression {
/**
* Returns the value of the field `field` in the Struct `child`.
*/
case class SimpleStructGetField(child: Expression, field: StructField, ordinal: Int)
extends GetField {
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends ExtractValue {

override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
Expand All @@ -116,11 +116,11 @@ case class SimpleStructGetField(child: Expression, field: StructField, ordinal:
/**
* Returns an array holding the value of `field` for each Struct in the Array of Struct `child`.
*/
case class ArrayStructGetField(
case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
containsNull: Boolean) extends GetField {
containsNull: Boolean) extends ExtractValue {

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable
Expand All @@ -137,7 +137,7 @@ case class ArrayStructGetField(
}
}

abstract class OrdinalGetField extends GetField {
abstract class ExtractValueWithOrdinal extends ExtractValue {
self: Product =>

def ordinal: Expression
Expand Down Expand Up @@ -168,8 +168,8 @@ abstract class OrdinalGetField extends GetField {
/**
* Returns the field at `ordinal` in the Array `child`
*/
case class ArrayOrdinalGetField(child: Expression, ordinal: Expression)
extends OrdinalGetField {
case class GetArrayItem(child: Expression, ordinal: Expression)
extends ExtractValueWithOrdinal {

override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

Expand All @@ -192,8 +192,8 @@ case class ArrayOrdinalGetField(child: Expression, ordinal: Expression)
/**
* Returns the value of key `ordinal` in Map `child`
*/
case class MapOrdinalGetField(child: Expression, ordinal: Expression)
extends OrdinalGetField {
case class GetMapValue(child: Expression, ordinal: Expression)
extends ExtractValueWithOrdinal {

override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
case e @ GetField(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetField(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ ExtractValue(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ ExtractValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ object PartialAggregation {
// resolving struct field accesses, because `ExtractValue` is not a `NamedExpression`.
// (Should we just turn `ExtractValue` into a `NamedExpression`?)
namedGroupingExpressions
.get(e.transform { case Alias(g: GetField, _) => g })
.get(e.transform { case Alias(g: ExtractValue, _) => g })
.map(_.toAttribute)
.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// Then this will add ExtractValue(ExtractValue(a, "b"), "c"), and alias
// the final expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
GetField(expr, Literal(fieldName), resolver))
ExtractValue(expr, Literal(fieldName), resolver))
val aliasName = nestedFields.last
Some(Alias(fieldExprs, aliasName)())
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ abstract class DataType {
/**
* Enables matching against DataType for expressions:
* {{{
* case Cast(child @ DataType(), StringType) =>
* case Cast(child @ BinaryType(), StringType) =>
* ...
* }}}
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.scalatest.FunSuite
import org.scalatest.Matchers._

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.mathfuncs._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -891,57 +891,55 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)

checkEvaluation(MapOrdinalGetField(BoundReference(3, typeMap, true),
checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
Literal("aa")), "bb", row)
checkEvaluation(MapOrdinalGetField(Literal.create(null, typeMap), Literal("aa")), null, row)
checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row)
checkEvaluation(
MapOrdinalGetField(Literal.create(null, typeMap),
Literal.create(null, StringType)), null, row)
checkEvaluation(MapOrdinalGetField(BoundReference(3, typeMap, true),
GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
Literal.create(null, StringType)), null, row)

checkEvaluation(ArrayOrdinalGetField(BoundReference(4, typeArray, true),
checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
Literal(1)), "bb", row)
checkEvaluation(ArrayOrdinalGetField(Literal.create(null, typeArray), Literal(1)), null, row)
checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row)
checkEvaluation(
ArrayOrdinalGetField(Literal.create(null, typeArray),
Literal.create(null, IntegerType)), null, row)
checkEvaluation(ArrayOrdinalGetField(BoundReference(4, typeArray, true),
GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
Literal.create(null, IntegerType)), null, row)

def quickBuildGetField(expr: Expression, fieldName: String): GetField = {
def getStructField(expr: Expression, fieldName: String): ExtractValue = {
expr.dataType match {
case StructType(fields) =>
val field = fields.find(_.name == fieldName).get
SimpleStructGetField(expr, field, fields.indexOf(field))
GetStructField(expr, field, fields.indexOf(field))
}
}

def resolveGetField(u: UnresolvedGetField): GetField = {
GetField(u.child, u.fieldExpr, _ == _)
def quickResolve(u: UnresolvedExtractValue): ExtractValue = {
ExtractValue(u.child, u.extraction, _ == _)
}

checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row)
checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row)

val typeS_notNullable = StructType(
StructField("a", StringType, nullable = false)
:: StructField("b", StringType, nullable = false) :: Nil
)

assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
assert(getStructField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
=== false)

assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true)
assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true)
assert(getStructField(Literal.create(null, typeS), "a").nullable === true)
assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true)

checkEvaluation(resolveGetField('c.map(typeMap).at(3).getItem("aa")), "bb", row)
checkEvaluation(resolveGetField('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
checkEvaluation(resolveGetField('c.struct(typeS).at(2).getField("a")), "aa", row)
checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row)
checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
}

test("error message of GetField") {
test("error message of ExtractValue") {
val structType = StructType(StructField("a", StringType, true) :: Nil)
val arrayStructType = ArrayType(structType)
val arrayType = ArrayType(StringType)
Expand All @@ -952,7 +950,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
fieldDataType: DataType,
errorMesage: String): Unit = {
val e = intercept[org.apache.spark.sql.AnalysisException] {
GetField(
ExtractValue(
Literal.create(null, childDataType),
Literal.create(null, fieldDataType),
_ == _)
Expand All @@ -963,7 +961,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
checkErrorMessage(arrayType, StringType, "Array index should be integral type")
checkErrorMessage(otherType, StringType, "Can't get field on")
checkErrorMessage(otherType, StringType, "Can't extract value from")
}

test("arithmetic") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
Expand Down Expand Up @@ -180,10 +180,10 @@ class ConstantFoldingSuite extends PlanTest {
IsNull(Literal(null)) as 'c1,
IsNotNull(Literal(null)) as 'c2,

UnresolvedGetField(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
UnresolvedGetField(
UnresolvedExtractValue(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
UnresolvedExtractValue(
Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4,
UnresolvedGetField(
UnresolvedExtractValue(
Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))),
"a") as 'c5,

Expand Down
Loading

0 comments on commit 715c589

Please sign in to comment.