Skip to content

Commit

Permalink
Continued work on different issues and preparing ground for #70:
Browse files Browse the repository at this point in the history
 * Initial Null value support (#55)
 * Refactored marshalling-related classes in correspondent package
 * Unified input parsing between AST values and other types of input
 * Extract input value parsing and make it first-class feature. Closes #103
  • Loading branch information
OlegIlyenko committed Dec 1, 2015
1 parent 34eda99 commit 54959db
Show file tree
Hide file tree
Showing 45 changed files with 627 additions and 262 deletions.
6 changes: 5 additions & 1 deletion src/main/scala/sangria/ast/QueryAst.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,11 @@ case class StringValue(value: String, position: Option[Position] = None) extends
case class BooleanValue(value: Boolean, position: Option[Position] = None) extends ScalarValue
case class EnumValue(value: String, position: Option[Position] = None) extends Value
case class ListValue(values: List[Value], position: Option[Position] = None) extends Value
case class ObjectValue(fields: List[ObjectField], position: Option[Position] = None) extends Value
case class VariableValue(name: String, position: Option[Position] = None) extends Value
case class NullValue(position: Option[Position] = None) extends Value
case class ObjectValue(fields: List[ObjectField], position: Option[Position] = None) extends Value {
lazy val fieldsByName = fields groupBy (_.name) mapValues (_.head.value)
}

case class ObjectField(name: String, value: Value, position: Option[Position] = None) extends NameValue

Expand Down Expand Up @@ -178,6 +181,7 @@ object AstNode {
case n: BigDecimalValue n.copy(position = None).asInstanceOf[T]
case n: StringValue n.copy(position = None).asInstanceOf[T]
case n: BooleanValue n.copy(position = None).asInstanceOf[T]
case n: NullValue n.copy(position = None).asInstanceOf[T]
case n: EnumValue n.copy(position = None).asInstanceOf[T]
case n: ListValue
n.copy(
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/sangria/execution/Executor.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package sangria.execution

import sangria.ast
import sangria.integration.{ResultMarshaller, InputUnmarshaller}
import sangria.marshalling.{InputUnmarshaller, ResultMarshaller}
import sangria.parser.SourceMapper
import sangria.schema._
import sangria.validation.QueryValidator
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/sangria/execution/Resolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package sangria.execution

import org.parboiled2.Position
import sangria.ast
import sangria.integration.ResultMarshaller
import sangria.marshalling.ResultMarshaller
import sangria.parser.SourceMapper
import sangria.schema._

Expand Down Expand Up @@ -648,6 +648,7 @@ object Resolver {
case ast.FloatValue(f, _) marshaller.floatNode(f)
case ast.BigDecimalValue(f, _) marshaller.bigDecimalNode(f)
case ast.BooleanValue(b, _) marshaller.booleanNode(b)
case ast.NullValue(_) marshaller.nullNode
case ast.EnumValue(enum, _) marshaller.stringNode(enum)
case ast.ListValue(values, _) marshaller.arrayNode(values.toVector map (marshalValue(_, marshaller)))
case ast.ObjectValue(values, _) marshaller.mapNode(values map (v v.name marshalValue(v.value, marshaller)))
Expand Down
5 changes: 1 addition & 4 deletions src/main/scala/sangria/execution/ResultResolver.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package sangria.execution

import org.parboiled2.Position
import sangria.ast.AstNode
import sangria.integration.ResultMarshaller
import sangria.marshalling.ResultMarshaller
import sangria.validation.{Violation, AstNodeLocation}

import scala.collection.immutable.VectorBuilder

class ResultResolver(val marshaller: ResultMarshaller, exceptionHandler: PartialFunction[(ResultMarshaller, Throwable), HandledException]) {
def marshalErrors(errors: ErrorRegistry) = {
val marshalled = errors.errorList.map {
Expand Down
167 changes: 101 additions & 66 deletions src/main/scala/sangria/execution/ValueCoercionHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,96 +2,157 @@ package sangria.execution

import org.parboiled2.Position
import sangria.ast
import sangria.integration.{InputUnmarshaller, ToInput}
import sangria.marshalling.{InputUnmarshaller, ToInput}
import sangria.parser.SourceMapper
import sangria.renderer.{QueryRenderer, SchemaRenderer}
import sangria.schema._
import sangria.validation._

class ValueCoercionHelper[Ctx](sourceMapper: Option[SourceMapper] = None, deprecationTracker: DeprecationTracker = DeprecationTracker.empty, userContext: Option[Ctx] = None) {
def resolveListValue(ofType: InputType[_], fieldPath: List[String], value: Either[Vector[Violation], Option[Any]], pos: Option[Position] = None) = value match {
def resolveListValue(ofType: InputType[_], fieldPath: List[String], value: Either[Vector[Violation], Option[Any]], pos: List[Position] = Nil) = value match {
case r @ Right(None) if ofType.isInstanceOf[OptionInputType[_]] r
case Right(Some(v)) Right(v)
case Right(None) Left(Vector(NullValueForNotNullTypeViolation(fieldPath, SchemaRenderer.renderTypeName(ofType), sourceMapper, pos.toList)))
case Right(None) Left(Vector(NullValueForNotNullTypeViolation(fieldPath, SchemaRenderer.renderTypeName(ofType), sourceMapper, pos)))
case l @ Left(_) l
}

def resolveMapValue(ofType: InputType[_], fieldPath: List[String], default: Option[(_, ToInput[_, _])], fieldName: String, acc: Map[String, Either[Vector[Violation], Any]], value: Either[Vector[Violation], Option[Any]], pos: Option[Position] = None, allowErrorsOnDefault: Boolean = false) = {
def resolveMapValue(ofType: InputType[_], fieldPath: List[String], default: Option[(_, ToInput[_, _])], fieldName: String, acc: Map[String, Either[Vector[Violation], Any]], value: Either[Vector[Violation], Option[Any]], pos: List[Position] = Nil, allowErrorsOnDefault: Boolean = false) = {
def getDefault = {
val Some((defaultValue, toInput)) = default.asInstanceOf[Option[(Any, ToInput[Any, Any])]]
val (defaultInput, inputUnmarshaller) = toInput.toInput(defaultValue)

coerceInputValue(ofType, fieldPath, defaultInput)(inputUnmarshaller).right.map(_.get)
coerceInputValue(ofType, fieldPath, defaultInput, None)(inputUnmarshaller).right.map(_.get)
}

value match {
case Right(None) if default.isDefined acc.updated(fieldName, getDefault)
case r @ Right(None) if ofType.isInstanceOf[OptionInputType[_]] acc
case Right(Some(v)) acc.updated(fieldName, Right(v))
case Right(None) acc.updated(fieldName, Left(Vector(NullValueForNotNullTypeViolation(fieldPath, SchemaRenderer.renderTypeName(ofType), sourceMapper, pos.toList))))
case Right(None) acc.updated(fieldName, Left(Vector(NullValueForNotNullTypeViolation(fieldPath, SchemaRenderer.renderTypeName(ofType), sourceMapper, pos))))
case l @ Left(_) if allowErrorsOnDefault && default.isDefined acc.updated(fieldName, getDefault)
case l @ Left(_) acc.updated(fieldName, l)
}
}

def coerceInputValue[In](tpe: InputType[_], fieldPath: List[String], input: In, errorPrefix: String = "")(implicit iu: InputUnmarshaller[In]): Either[Vector[Violation], Option[Any]] = (tpe, input) match {
case (OptionInputType(ofType), value) coerceInputValue(ofType, fieldPath, value, errorPrefix)
def coerceInputValue[In](
tpe: InputType[_],
fieldPath: List[String],
input: In,
variables: Option[Map[String, Any]],
errorPrefix: String = "")(implicit iu: InputUnmarshaller[In]): Either[Vector[Violation], Option[Any]] = (tpe, input) match {
case (_, node) if iu.isVariableNode(node)
val varName = iu.getVariableName(node)

variables match {
case Some(vars)
Right(vars get varName)
case None
Left(Vector(VariableNotAllowedViolation(varName, sourceMapper, Nil)))
}

case (OptionInputType(ofType), value)
coerceInputValue(ofType, fieldPath, value, variables, errorPrefix)

case (ListInputType(ofType), values) if iu.isArrayNode(values)
val res = iu.getListValue(values).map {
case defined if iu.isDefined(defined) resolveListValue(ofType, fieldPath, coerceInputValue(ofType, fieldPath, defined, errorPrefix))
case _ resolveListValue(ofType, fieldPath, Right(None))
case defined if iu.isDefined(defined)
resolveListValue(ofType, fieldPath, coerceInputValue(ofType, fieldPath, defined, variables, errorPrefix), valuePosition(defined))
case v
resolveListValue(ofType, fieldPath, Right(None), valuePosition(v, values))
}

val (errors, successes) = res.partition(_.isLeft)

if (errors.nonEmpty) Left(errors.collect{case Left(errors) errors}.toVector.flatten)
if (errors.nonEmpty) Left(errors.collect{case Left(es) es}.toVector.flatten)
else Right(Some(successes.collect {case Right(v) v}))

case (ListInputType(ofType), value)
resolveListValue(ofType, fieldPath, coerceInputValue(ofType, fieldPath, value, errorPrefix)) match {
case Right(v) Right(Some(Seq(v)))
case l @ Left(violations) Left(violations)
val res = value match {
case defined if iu.isDefined(defined)
resolveListValue(ofType, fieldPath, coerceInputValue(ofType, fieldPath, defined, variables, errorPrefix), valuePosition(defined))
case v
resolveListValue(ofType, fieldPath, Right(None), valuePosition(v, value))
}

res match {
case Right(v) Right(Some(Vector(v)))
case Left(violations) Left(violations)
}

case (objTpe: InputObjectType[_], valueMap) if iu.isMapNode(valueMap)
val res = objTpe.fields.foldLeft(Map.empty[String, Either[Vector[Violation], Any]]) {
case (acc, field) iu.getMapValue(valueMap, field.name) match {
case Some(defined) if iu.isDefined(defined)
resolveMapValue(field.fieldType, fieldPath :+ field.name, field.defaultValue, field.name, acc,
coerceInputValue(field.fieldType, fieldPath :+ field.name, defined, errorPrefix))
case _ resolveMapValue(field.fieldType, fieldPath :+ field.name, field.defaultValue, field.name, acc, Right(None))
coerceInputValue(field.fieldType, fieldPath :+ field.name, defined, variables, errorPrefix), valuePosition(defined))
case v
resolveMapValue(field.fieldType, fieldPath :+ field.name, field.defaultValue, field.name, acc, Right(None), valuePosition(v, valueMap))
}
}

val errors = res.collect{case (_, Left(errors)) errors}.toVector.flatten

if (errors.nonEmpty) Left(errors)
else Right(Some(res mapValues (_.right.get)))

case (objTpe: InputObjectType[_], value)
Left(Vector(InputObjectTypeMismatchViolation(fieldPath, SchemaRenderer.renderTypeName(objTpe), iu.render(value), sourceMapper, valuePosition(value))))

case (scalar: ScalarType[_], value) if iu.isScalarNode(value)
scalar.coerceUserInput(iu.getScalarValue(value))
.fold(violation Left(Vector(FieldCoercionViolation(fieldPath, violation, None, Nil, errorPrefix))), v Right(Some(v)))
case (enum: EnumType[_], value) if iu.isScalarNode(value)
enum.coerceUserInput(iu.getScalarValue(value))
.fold(violation Left(Vector(FieldCoercionViolation(fieldPath, violation, None, Nil, errorPrefix))), {
val coerced = iu.getScalarValue(value) match {
case node: ast.Value scalar.coerceInput(node)
case other scalar.coerceUserInput(other)
}

coerced.fold(
violation Left(Vector(FieldCoercionViolation(fieldPath, violation, sourceMapper, valuePosition(value), errorPrefix))),
v Right(Some(v)))

case (enum: ScalarType[_], value)
Left(Vector(FieldCoercionViolation(fieldPath, GenericInvalidValueViolation(sourceMapper, valuePosition(value)), sourceMapper, valuePosition(value), errorPrefix)))

case (enum: EnumType[_], value) if iu.isEnumNode(value)
val coerced = iu.getScalarValue(value) match {
case node: ast.Value enum.coerceInput(node)
case other enum.coerceUserInput(other)
}

coerced.fold(violation Left(Vector(FieldCoercionViolation(fieldPath, violation, sourceMapper, valuePosition(value), errorPrefix))), {
case (v, deprecated)
if (deprecated && userContext.isDefined) deprecationTracker.deprecatedEnumValueUsed(enum, v, userContext.get)

Right(Some(v))
})

case (enum: EnumType[_], value)
Left(Vector(FieldCoercionViolation(fieldPath, EnumCoercionViolation, sourceMapper, valuePosition(value), errorPrefix)))
}

def valuePosition[T](value: T*): List[Position] = {
val values = value.view.collect {
case node: ast.AstNode if node.position.isDefined node.position.toList
}

values.headOption.fold(Nil: List[Position])(identity)
}

def isValidValue[In](tpe: InputType[_], input: Option[In])(implicit um: InputUnmarshaller[In]): Vector[Violation] = (tpe, input) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) isValidValue(ofType, Some(value))
case (OptionInputType(_), _) Vector.empty
case (_, None) Vector(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isArrayNode(values)
um.getListValue(values).toVector.flatMap(v isValidValue(ofType, v match {
case opt: Option[In @unchecked] opt
case other Option(other)
}) map (ListValueViolation(0, _, sourceMapper, Nil)))

case (ListInputType(ofType), Some(value)) if um.isDefined(value)
isValidValue(ofType, value match {
case opt: Option[In @unchecked] opt
case other Option(other)
}) map (ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap)
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f)
Expand All @@ -103,63 +164,37 @@ class ValueCoercionHelper[Ctx](sourceMapper: Option[SourceMapper] = None, deprec
objTpe.fields.toVector.flatMap(f
isValidValue(f.fieldType, um.getMapValue(valueMap, f.name)) map (MapValueViolation(f.name, _, sourceMapper, Nil)))
}

case (objTpe: InputObjectType[_], _)
Vector(InputObjectIsOfWrongTypeMissingViolation(SchemaRenderer.renderTypeName(objTpe, true), sourceMapper, Nil))

case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value)
scalar.coerceUserInput(um.getScalarValue(value)) match {
case Left(violation) Vector(violation)
case _ Vector.empty
val coerced = um.getScalarValue(value) match {
case node: ast.Value scalar.coerceInput(node)
case other scalar.coerceUserInput(other)
}
case (enum: EnumType[_], Some(value)) if um.isScalarNode(value)
enum.coerceUserInput(um.getScalarValue(value)) match {

coerced match {
case Left(violation) Vector(violation)
case _ Vector.empty
}
case _ Vector(GenericInvalidValueViolation(sourceMapper, Nil))
}

def coerceAstValue(tpe: InputType[_], fieldPath: List[String], input: ast.Value, variables: Map[String, Any]): Either[Vector[Violation], Option[Any]] = (tpe, input) match {
case (_, ast.VariableValue(name, _)) Right(variables get name)
case (OptionInputType(ofType), value) coerceAstValue(ofType, fieldPath, value, variables)
case (ListInputType(ofType), ast.ListValue(values, _))
val res = values.map {v resolveListValue(ofType, fieldPath, coerceAstValue(ofType, fieldPath, v, variables), v.position)}
val (errors, successes) = res.partition(_.isLeft)

if (errors.nonEmpty) Left(errors.collect{case Left(errors) errors}.toVector.flatten)
else Right(Some(successes.collect {case Right(v) v}))
case (ListInputType(ofType), value)
resolveListValue(ofType, fieldPath, coerceAstValue(ofType, fieldPath, value, variables), value.position) match {
case Right(v) Right(Some(Seq(v)))
case l @ Left(violations) Left(violations)
}
case (objTpe: InputObjectType[_], ast.ObjectValue(fieldList, objPos))
val astFields = fieldList groupBy (_.name) mapValues (_.head)
val res = objTpe.fields.foldLeft(Map.empty[String, Either[Vector[Violation], Any]]) {
case (acc, field) astFields get field.name match {
case Some(defined)
resolveMapValue(field.fieldType, fieldPath, field.defaultValue, field.name, acc,
coerceAstValue(field.fieldType, fieldPath :+ field.name, defined.value, variables), defined.value.position)
case _ resolveMapValue(field.fieldType, fieldPath, field.defaultValue, field.name, acc, Right(None), objPos)
}
case (enum: EnumType[_], Some(value)) if um.isEnumNode(value)
val coerced = um.getScalarValue(value) match {
case node: ast.Value enum.coerceInput(node)
case other enum.coerceUserInput(other)
}

val errors = res.collect{case (_, Left(errors)) errors}.toVector.flatten
coerced match {
case Left(violation) Vector(violation)
case _ Vector.empty
}

if (errors.nonEmpty) Left(errors)
else Right(Some(res mapValues (_.right.get)))
case (objTpe: InputObjectType[_], value)
Left(Vector(InputObjectTypeMismatchViolation(fieldPath, SchemaRenderer.renderTypeName(objTpe), QueryRenderer.render(value), sourceMapper, value.position.toList)))
case (scalar: ScalarType[_], value)
scalar.coerceInput(value)
.fold(violation Left(Vector(FieldCoercionViolation(fieldPath, violation, sourceMapper, value.position.toList, ""))), v Right(Some(v)))
case (enum: EnumType[_], value)
enum.coerceInput(value)
.fold(violation Left(Vector(FieldCoercionViolation(fieldPath, violation, sourceMapper, value.position.toList, ""))), {
case (v, deprecated)
if (deprecated && userContext.isDefined) deprecationTracker.deprecatedEnumValueUsed(enum, v, userContext.get)
case (enum: EnumType[_], Some(value))
Vector(EnumCoercionViolation)

Right(Some(v))
})
case _
Vector(GenericInvalidValueViolation(sourceMapper, Nil))
}
}

Expand Down
15 changes: 10 additions & 5 deletions src/main/scala/sangria/execution/ValueCollector.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package sangria.execution

import sangria.ast
import sangria.integration.InputUnmarshaller
import sangria.marshalling.InputUnmarshaller
import sangria.parser.SourceMapper
import sangria.renderer.QueryRenderer
import sangria.schema._
Expand Down Expand Up @@ -46,9 +46,12 @@ class ValueCollector[Ctx, Input](schema: Schema[_, _], inputVars: Input, sourceM
if (violations.isEmpty) {
val fieldPath = s"$$${definition.name}" :: Nil

if (input.isEmpty || !um.isDefined(input.get))
definition.defaultValue map (coerceAstValue(tpe, fieldPath, _, Map.empty)) getOrElse Right(None)
else coerceInputValue(tpe, fieldPath, input.get)
if (input.isEmpty || !um.isDefined(input.get)) {
import sangria.marshalling.queryAst.queryAstInputUnmarshaller

definition.defaultValue map (coerceInputValue(tpe, fieldPath, _, None)) getOrElse Right(None)
} else
coerceInputValue(tpe, fieldPath, input.get, None)
} else Left(violations.map(violation
VarTypeMismatchViolation(definition.name, QueryRenderer.render(definition.tpe), input map um.render, violation: Violation, sourceMapper, definition.position.toList)))
}
Expand All @@ -72,8 +75,10 @@ class ValueCollector[Ctx, Input](schema: Schema[_, _], inputVars: Input, sourceM
val argPath = argDef.name :: Nil
val astValue = astArgMap get argDef.name map (_.value)

import sangria.marshalling.queryAst.queryAstInputUnmarshaller

resolveMapValue(argDef.argumentType, argPath, argDef.defaultValue, argDef.name, acc,
astValue map (coerceAstValue(argDef.argumentType, argPath, _, variables)) getOrElse Right(None), allowErrorsOnDefault = true)
astValue map (coerceInputValue(argDef.argumentType, argPath, _, Some(variables))) getOrElse Right(None), allowErrorsOnDefault = true)
}

val errors = res.collect{case (_, Left(errors)) errors}.toVector.flatten
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/sangria/execution/middleware.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package sangria.execution

import sangria.marshalling.InputUnmarshaller

import language.{implicitConversions, existentials}

import sangria.ast
import sangria.integration.InputUnmarshaller
import sangria.schema.{Action, Context}


Expand Down
Loading

0 comments on commit 54959db

Please sign in to comment.