diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index 6e6992d68..c0a1adb80 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -12,6 +12,8 @@ import caliban.wrappers.Wrapper.FieldWrapper import zio._ import zio.query.ZQuery +import scala.collection.mutable.ArrayBuffer + object Executor { /** @@ -51,7 +53,7 @@ object Executor { value match { case EnumValue(v) => // special case of an hybrid union containing case objects, those should return an object instead of a string - val obj = filterFields(currentField, v).collectFirst { + val obj = mergeFields(currentField, v).collectFirst { case f: Field if f.name == "__typename" => ObjectValue(List(f.alias.getOrElse(f.name) -> StringValue(v))) case f: Field if f.name == "_" => @@ -70,7 +72,7 @@ object Executor { Types.listOf(currentField.fieldType).fold(false)(_.isNullable) ) case ObjectStep(objectName, fields) => - val filteredFields = filterFields(currentField, objectName) + val filteredFields = mergeFields(currentField, objectName) val items = filteredFields.map { case f @ Field(name @ "__typename", _, _, alias, _, _, _, _, directives) => (alias.getOrElse(name), PureStep(StringValue(objectName)), fieldInfo(f, path, directives)) @@ -172,8 +174,31 @@ object Executor { private[caliban] def fail(error: CalibanError): UIO[GraphQLResponse[CalibanError]] = IO.succeed(GraphQLResponse(NullValue, List(error))) - private[caliban] def filterFields(field: Field, typeName: String): List[Field] = - field.fields.filter(_.condition.forall(_.contains(typeName))) + private[caliban] def mergeFields(field: Field, typeName: String): List[Field] = { + // ugly mutable code but it's worth it for the speed ;) + val array = ArrayBuffer.empty[Field] + val map = collection.mutable.Map.empty[String, Int] + var index = 0 + + field.fields.foreach { field => + if (field.condition.forall(_.contains(typeName))) { + val name = field.alias.getOrElse(field.name) + map.get(name) match { + case None => + // first time we see this field, add it to the array + array += field + map.update(name, index) + index = index + 1 + case Some(index) => + // field already existed, merge it + val f = array(index) + array(index) = f.copy(fields = f.fields ::: field.fields) + } + } + } + + array.toList + } private def fieldInfo(field: Field, path: List[Either[String, Int]], fieldDirectives: List[Directive]): FieldInfo = FieldInfo(field.alias.getOrElse(field.name), field, path, fieldDirectives) diff --git a/core/src/main/scala/caliban/execution/Field.scala b/core/src/main/scala/caliban/execution/Field.scala index 73b445083..664a7bfd6 100644 --- a/core/src/main/scala/caliban/execution/Field.scala +++ b/core/src/main/scala/caliban/execution/Field.scala @@ -37,16 +37,17 @@ object Field { ): Field = { def loop(selectionSet: List[Selection], fieldType: __Type): Field = { val fieldList = ArrayBuffer.empty[Field] - val map = collection.mutable.Map.empty[String, Int] + val map = collection.mutable.Map.empty[(String, String), Int] var fieldIndex = 0 - def addField(f: Field): Unit = { + def addField(f: Field, condition: Option[String]): Unit = { val name = f.alias.getOrElse(f.name) - map.get(name) match { + val key = (name, condition.getOrElse("")) + map.get(key) match { case None => // first time we see this field, add it to the array fieldList += f - map.update(name, fieldIndex) + map.update(key, fieldIndex) fieldIndex = fieldIndex + 1 case Some(index) => // field already existed, merge it @@ -88,7 +89,8 @@ object Field { resolveVariables(arguments, variableDefinitions, variableValues), () => sourceMapper.getLocation(index), directives ++ schemaDirectives - ) + ), + None ) case FragmentSpread(name, directives) if checkDirectives(directives, variableValues) => fragments @@ -101,7 +103,7 @@ object Field { if (field.condition.isDefined) field else field.copy(condition = subtypeNames(f.typeCondition.name, rootType)) ) - .foreach(addField) + .foreach(addField(_, Some(f.typeCondition.name))) } case InlineFragment(typeCondition, directives, selectionSet) if checkDirectives(directives, variableValues) => val t = innerType.possibleTypes @@ -116,7 +118,7 @@ object Field { if (field.condition.isDefined) field else field.copy(condition = subtypeNames(typeName.name, rootType)) ) - .foreach(addField) + .foreach(addField(_, Some(typeName.name))) } case _ => } diff --git a/core/src/test/scala/caliban/execution/ExecutionSpec.scala b/core/src/test/scala/caliban/execution/ExecutionSpec.scala index 6f531e865..04b0c8e92 100644 --- a/core/src/test/scala/caliban/execution/ExecutionSpec.scala +++ b/core/src/test/scala/caliban/execution/ExecutionSpec.scala @@ -20,6 +20,35 @@ import zio.test.environment.TestEnvironment object ExecutionSpec extends DefaultRunnableSpec { + @GQLInterface + sealed trait Base { + def id: String + def name: String + } + object Base { + @GQLName("BaseOne") + case class One( + id: String, + name: String, + inner: List[One.Inner] + ) extends Base + object One { + @GQLName("BaseOneInner") + case class Inner(a: String) + } + + @GQLName("BaseTwoOne") + case class Two( + id: String, + name: String, + inner: List[Two.Inner] + ) extends Base + object Two { + @GQLName("BaseTwoInner") + case class Inner(b: Int) + } + } + override def spec: ZSpec[TestEnvironment, Any] = suite("ExecutionSpec")( testM("skip directive") { @@ -1128,6 +1157,47 @@ object ExecutionSpec extends DefaultRunnableSpec { """{"test":2}""" ) ) + }, + testM("conflicting fragments selection merging") { + + val base1 = Base.One( + id = "1", + name = "base 1", + inner = List(Base.One.Inner(a = "a")) + ) + val base2 = Base.Two( + id = "2", + name = "base 2", + inner = List(Base.Two.Inner(b = 2)) + ) + case class Test(bases: List[Base]) + + implicit val baseSchema: Schema[Any, Base] = Schema.gen + + val api = graphQL(RootResolver(Test(List(base1, base2)))) + val query = """ + query { + bases { + id + ... on BaseOne { + id + name + inner { a } + } + ... on BaseTwoOne { + id + name + inner { b } + } + } + } + """ + + api.interpreter.flatMap(_.execute(query)).map { response => + assertTrue( + response.data.toString == """{"bases":[{"id":"1","name":"base 1","inner":[{"a":"a"}]},{"id":"2","name":"base 2","inner":[{"b":2}]}]}""" + ) + } } ) }