Skip to content

Commit

Permalink
[Client] Allow query to not specify every subtype of an union (#1099)
Browse files Browse the repository at this point in the history
* [Client] Allow query to not specify every subtype of an union

* Make it work with Scala 3
  • Loading branch information
ghostdogpr authored Oct 16, 2021
1 parent e8f5f26 commit fa3b7ba
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 58 deletions.
14 changes: 12 additions & 2 deletions client/src/main/scala/caliban/client/FieldBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,18 @@ object FieldBuilder {
case _ => Left(DecodingError(s"Field $value is not an object"))
}

override def toSelectionSet: List[Selection] =
override def toSelectionSet: List[Selection] = {
val filteredBuilderMap = builderMap.filter {
case (_, _: NullField.type) => false
case _ => true
}
Selection.Field(None, "__typename", Nil, Nil, Nil, 0) ::
builderMap.map { case (k, v) => Selection.InlineFragment(k, v.toSelectionSet) }.toList
filteredBuilderMap.map { case (k, v) => Selection.InlineFragment(k, v.toSelectionSet) }.toList
}
}

case object NullField extends FieldBuilder[Option[Nothing]] {
override def fromGraphQL(value: __Value): Either[DecodingError, Option[Nothing]] = Right(None)
override def toSelectionSet: List[Selection] = Nil
}
}
44 changes: 44 additions & 0 deletions client/src/test/scala/caliban/client/SelectionBuilderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ object SelectionBuilderSpec extends DefaultRunnableSpec {
)
)
},
test("union type optional") {
val query =
Queries.characters() {
Character.name ~
Character.nicknames ~
Character.roleOption(onCaptain = Some(Role.Captain.shipName))
}
val (s, _) = SelectionBuilder.toGraphQL(query.toSelectionSet, useVariables = false)
assert(s)(equalTo("characters{name nicknames role{__typename ... on Captain{shipName}}}"))
},
test("argument") {
val query =
Queries.characters(Some(Origin.MARS)) {
Expand All @@ -54,6 +64,40 @@ object SelectionBuilderSpec extends DefaultRunnableSpec {
val (s, _) = SelectionBuilder.toGraphQL(query.toSelectionSet, useVariables = false)
assert(s)(equalTo("""characters(origin:"MARS"){name}"""))
},
test("union type with optional parameters") {
case class CharacterView(name: String, nicknames: List[String], role: Option[Option[String]])
val query =
Queries.characters() {
(Character.name ~
Character.nicknames ~
Character.roleOption(onMechanic = Some(Role.Mechanic.shipName))).mapN(CharacterView(_, _, _))
}

val response =
__ObjectValue(
List(
"characters" -> __ListValue(
List(
__ObjectValue(
List(
"name" -> __StringValue("Amos Burton"),
"nicknames" -> __ListValue(List(__StringValue("Amos"))),
"role" -> __ObjectValue(
List(
"__typename" -> __StringValue("Mechanic"),
"shipName" -> __StringValue("Rocinante")
)
)
)
)
)
)
)
)
assert(query.fromGraphQL(response))(
isRight(equalTo(List(CharacterView("Amos Burton", List("Amos"), Some(Some("Rocinante"))))))
)
},
test("aliases") {
val query =
Queries
Expand Down
19 changes: 19 additions & 0 deletions client/src/test/scala/caliban/client/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,25 @@ object TestData {
)
)
)
def roleOption[A](
onCaptain: Option[SelectionBuilder[Captain, A]] = None,
onEngineer: Option[SelectionBuilder[Engineer, A]] = None,
onMechanic: Option[SelectionBuilder[Mechanic, A]] = None,
onPilot: Option[SelectionBuilder[Pilot, A]] = None
): SelectionBuilder[Character, Option[Option[A]]] =
Field(
"role",
OptionOf(
ChoiceOf(
Map(
"Captain" -> onCaptain.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a))),
"Engineer" -> onEngineer.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a))),
"Mechanic" -> onMechanic.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a))),
"Pilot" -> onPilot.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a)))
)
)
)
)
}

// Auto-generated query
Expand Down
114 changes: 58 additions & 56 deletions tools/src/main/scala/caliban/tools/ClientWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,15 @@ object ClientWriter {
): String =
s"type ${typedef.name} = _root_.caliban.client.Operations.RootQuery"

def writeRootQuery(
typedef: ObjectTypeDefinition
)(implicit
def writeRootQuery(typedef: ObjectTypeDefinition)(implicit
typesMap: TypesMap,
mappingClashedTypeNames: MappingClashedTypeNames,
scalarMappings: ScalarMappings
): String =
s"""object ${typedef.name} {
| ${typedef.fields.map(writeField(_, "_root_.caliban.client.Operations.RootQuery")).mkString("\n ")}
| ${typedef.fields
.map(writeField(_, "_root_.caliban.client.Operations.RootQuery", optionalUnion = false))
.mkString("\n ")}
|}
|""".stripMargin

Expand All @@ -307,15 +307,15 @@ object ClientWriter {
): String =
s"type ${typedef.name} = _root_.caliban.client.Operations.RootMutation"

def writeRootMutation(
typedef: ObjectTypeDefinition
)(implicit
def writeRootMutation(typedef: ObjectTypeDefinition)(implicit
typesMap: TypesMap,
mappingClashedTypeNames: MappingClashedTypeNames,
scalarMappings: ScalarMappings
): String =
s"""object ${typedef.name} {
| ${typedef.fields.map(writeField(_, "_root_.caliban.client.Operations.RootMutation")).mkString("\n ")}
| ${typedef.fields
.map(writeField(_, "_root_.caliban.client.Operations.RootMutation", optionalUnion = false))
.mkString("\n ")}
|}
|""".stripMargin

Expand All @@ -324,21 +324,19 @@ object ClientWriter {
): String =
s"type ${typedef.name} = _root_.caliban.client.Operations.RootSubscription"

def writeRootSubscription(
typedef: ObjectTypeDefinition
)(implicit
def writeRootSubscription(typedef: ObjectTypeDefinition)(implicit
typesMap: TypesMap,
mappingClashedTypeNames: MappingClashedTypeNames,
scalarMappings: ScalarMappings
): String =
s"""object ${typedef.name} {
| ${typedef.fields.map(writeField(_, "_root_.caliban.client.Operations.RootSubscription")).mkString("\n ")}
| ${typedef.fields
.map(writeField(_, "_root_.caliban.client.Operations.RootSubscription", optionalUnion = false))
.mkString("\n ")}
|}
|""".stripMargin

def writeObjectType(
typedef: ObjectTypeDefinition
)(implicit
def writeObjectType(typedef: ObjectTypeDefinition)(implicit
mappingClashedTypeNames: MappingClashedTypeNames,
scalarMappings: ScalarMappings
): String = {
Expand All @@ -348,32 +346,34 @@ object ClientWriter {
s"type $objectName"
}

def writeObject(
typedef: ObjectTypeDefinition,
genView: Boolean
)(implicit
def writeObject(typedef: ObjectTypeDefinition, genView: Boolean)(implicit
typesMap: TypesMap,
mappingClashedTypeNames: MappingClashedTypeNames,
scalarMappings: ScalarMappings
): String = {

val objectName: String = safeTypeName(typedef.name)
val fields = typedef.fields.map(collectFieldInfo(_, objectName))
val view =
if (genView)
"\n " + writeView(typedef.name, fields.map(_.typeInfo))
else ""
val objectName: String = safeTypeName(typedef.name)
val unionTypes = typesMap.typesMap.collect { case (key, _: UnionTypeDefinition) => key }
val optionalUnionTypeFields = typedef.fields.flatMap { field =>
val isOptionalUnionType = unionTypes.exists(_.compareToIgnoreCase(field.ofType.toString) == 0)
if (isOptionalUnionType) Some(collectFieldInfo(field, objectName, optionalUnion = true))
else None
}
val fields = typedef.fields.map(collectFieldInfo(_, objectName, optionalUnion = false))
val view = if (genView) "\n " + writeView(typedef.name, fields.map(_.typeInfo)) else ""

val allFields = fields ++ optionalUnionTypeFields

s"""object $objectName {$view
| ${fields.map(writeFieldInfo).mkString("\n ")}
| ${allFields.map(writeFieldInfo).mkString("\n ")}
|}
|""".stripMargin
}

def writeView(
objectName: String,
fields: List[FieldTypeInfo]
)(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = {
def writeView(objectName: String, fields: List[FieldTypeInfo])(implicit
mappingClashedTypeNames: MappingClashedTypeNames,
scalarMappings: ScalarMappings
): String = {
val viewName = s"${objectName}View"
val safeObjectName = safeTypeName(objectName)

Expand Down Expand Up @@ -582,7 +582,7 @@ object ClientWriter {
${encoderCases.mkString("\n")}
}

val values: Vector[${enumName}] = Vector(${typedef.enumValuesDefinition
val values: Vector[$enumName] = Vector(${typedef.enumValuesDefinition
.map(v => safeEnumValue(v.enumValue))
.mkString(", ")})
}
Expand Down Expand Up @@ -611,15 +611,12 @@ object ClientWriter {
private val tripleQuotes = "\"\"\""
private val doubleQuotes = "\""

def writeField(
field: FieldDefinition,
typeName: String
)(implicit
def writeField(field: FieldDefinition, typeName: String, optionalUnion: Boolean)(implicit
typesMap: TypesMap,
mappingClashedTypeNames: MappingClashedTypeNames,
scalarMappings: ScalarMappings
): String =
writeFieldInfo(collectFieldInfo(field, typeName))
writeFieldInfo(collectFieldInfo(field, typeName, optionalUnion))

def writeFieldInfo(fieldInfo: FieldInfo): String = {
val FieldInfo(
Expand All @@ -642,15 +639,11 @@ object ClientWriter {
s"""$description${deprecated}def $safeName$typeParam$args$innerSelection$implicits: SelectionBuilder[$typeName, $outputType] = _root_.caliban.client.SelectionBuilder.Field("$name", $builder$argBuilder)"""
}

def collectFieldInfo(
field: FieldDefinition,
typeName: String
)(implicit
def collectFieldInfo(field: FieldDefinition, typeName: String, optionalUnion: Boolean)(implicit
typesMap: TypesMap,
mappingClashedTypeNames: MappingClashedTypeNames,
scalarMappings: ScalarMappings
): FieldInfo = {
val name = safeName(field.name)
val description = field.description match {
case Some(d) if d.trim.nonEmpty => s"/**\n * ${d.trim}\n */\n"
case _ => ""
Expand Down Expand Up @@ -684,15 +677,11 @@ object ClientWriter {
memberTypes.flatMap(name => typesMap.get(safeTypeName(name)))
}
.getOrElse(Nil)
.collect { case o: ObjectTypeDefinition =>
o
}
.collect { case o: ObjectTypeDefinition => o }
.sortBy(_.name)
val interfaceTypes = typesMap
.get(fieldType)
.collect { case InterfaceTypeDefinition(_, name, _, _) =>
name
}
.collect { case InterfaceTypeDefinition(_, name, _, _) => name }
.map(interface =>
typesMap.values.collect {
case o @ ObjectTypeDefinition(_, _, implements, _, _) if implements.exists(_.name == interface) => o
Expand All @@ -711,15 +700,27 @@ object ClientWriter {
writeTypeBuilder(field.ofType, "Scalar()")
)
} else if (unionTypes.nonEmpty) {
(
s"[$typeLetter]",
s"(${unionTypes.map(t => s"""on${t.name}: SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]""").mkString(", ")})",
writeType(field.ofType).replace(fieldType, typeLetter),
writeTypeBuilder(
field.ofType,
s"ChoiceOf(Map(${unionTypes.map(t => s""""${t.name}" -> Obj(on${t.name})""").mkString(", ")}))"
if (optionalUnion) {
(
s"[$typeLetter]",
s"(${unionTypes.map(t => s"""on${t.name}: Option[SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]] = None""").mkString(", ")})",
s"Option[${writeType(field.ofType).replace(fieldType, typeLetter)}]",
writeTypeBuilder(
field.ofType,
s"ChoiceOf(Map(${unionTypes.map(t => s""""${t.name}" -> on${t.name}.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a)))""").mkString(", ")}))"
)
)
)
} else {
(
s"[$typeLetter]",
s"(${unionTypes.map(t => s"""on${t.name}: SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]""").mkString(", ")})",
writeType(field.ofType).replace(fieldType, typeLetter),
writeTypeBuilder(
field.ofType,
s"ChoiceOf(Map(${unionTypes.map(t => s""""${t.name}" -> Obj(on${t.name})""").mkString(", ")}))"
)
)
}
} else if (interfaceTypes.nonEmpty) {
(
s"[$typeLetter]",
Expand Down Expand Up @@ -757,12 +758,13 @@ object ClientWriter {
}.mkString(", ")})"
}

val name = if (optionalUnion && unionTypes.nonEmpty) safeName(field.name + "Option") else safeName(field.name)
val owner = if (typeParam.nonEmpty) Some(fieldType) else None
val fieldTypeInfo = FieldTypeInfo(
field.name,
name,
outputType,
interfaceTypes.map(_.name).toList,
interfaceTypes.map(_.name),
unionTypes.map(_.name),
field.args,
owner
Expand Down
14 changes: 14 additions & 0 deletions tools/src/test/scala/caliban/tools/ClientWriterSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,20 @@ object Client {
onPilot: SelectionBuilder[Pilot, A]
): SelectionBuilder[Character, Option[A]] = _root_.caliban.client.SelectionBuilder
.Field("role", OptionOf(ChoiceOf(Map("Captain" -> Obj(onCaptain), "Pilot" -> Obj(onPilot)))))
def roleOption[A](
onCaptain: Option[SelectionBuilder[Captain, A]] = None,
onPilot: Option[SelectionBuilder[Pilot, A]] = None
): SelectionBuilder[Character, Option[Option[A]]] = _root_.caliban.client.SelectionBuilder.Field(
"role",
OptionOf(
ChoiceOf(
Map(
"Captain" -> onCaptain.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a))),
"Pilot" -> onPilot.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a)))
)
)
)
)
}
}
Expand Down
14 changes: 14 additions & 0 deletions tools/src/test/scala/caliban/tools/ClientWriterViewSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,20 @@ object Client {
onPilot: SelectionBuilder[Pilot, A]
): SelectionBuilder[Character, Option[A]] = _root_.caliban.client.SelectionBuilder
.Field("role", OptionOf(ChoiceOf(Map("Captain" -> Obj(onCaptain), "Pilot" -> Obj(onPilot)))))
def roleOption[A](
onCaptain: Option[SelectionBuilder[Captain, A]] = None,
onPilot: Option[SelectionBuilder[Pilot, A]] = None
): SelectionBuilder[Character, Option[Option[A]]] = _root_.caliban.client.SelectionBuilder.Field(
"role",
OptionOf(
ChoiceOf(
Map(
"Captain" -> onCaptain.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a))),
"Pilot" -> onPilot.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a)))
)
)
)
)
}
type Captain
Expand Down

0 comments on commit fa3b7ba

Please sign in to comment.