diff --git a/tools/src/main/scala/caliban/tools/ClientWriter.scala b/tools/src/main/scala/caliban/tools/ClientWriter.scala index 9dbbdd1159..f7c8887760 100644 --- a/tools/src/main/scala/caliban/tools/ClientWriter.scala +++ b/tools/src/main/scala/caliban/tools/ClientWriter.scala @@ -297,7 +297,9 @@ object ClientWriter { ): String = s"""object ${typedef.name} { | ${typedef.fields - .map(writeField(_, "_root_.caliban.client.Operations.RootQuery", optionalUnion = false)) + .map( + writeField(_, "_root_.caliban.client.Operations.RootQuery", optionalUnion = false, optionalInterface = false) + ) .mkString("\n ")} |} |""".stripMargin @@ -314,7 +316,9 @@ object ClientWriter { ): String = s"""object ${typedef.name} { | ${typedef.fields - .map(writeField(_, "_root_.caliban.client.Operations.RootMutation", optionalUnion = false)) + .map( + writeField(_, "_root_.caliban.client.Operations.RootMutation", optionalUnion = false, optionalInterface = false) + ) .mkString("\n ")} |} |""".stripMargin @@ -331,7 +335,14 @@ object ClientWriter { ): String = s"""object ${typedef.name} { | ${typedef.fields - .map(writeField(_, "_root_.caliban.client.Operations.RootSubscription", optionalUnion = false)) + .map( + writeField( + _, + "_root_.caliban.client.Operations.RootSubscription", + optionalUnion = false, + optionalInterface = false + ) + ) .mkString("\n ")} |} |""".stripMargin @@ -352,17 +363,28 @@ object ClientWriter { scalarMappings: ScalarMappings ): String = { - val objectName: String = safeTypeName(typedef.name) + val objectName: String = safeTypeName(typedef.name) + val unionTypes = typesMap.typesMap.collect { case (key, _: UnionTypeDefinition) => key } val optionalUnionTypeFields = typedef.fields.flatMap { field => val isOptionalUnionType = unionTypes.exists(_.compareToIgnoreCase(field.ofType.toString) == 0) - if (isOptionalUnionType) Some(collectFieldInfo(field, objectName, optionalUnion = true)) + if (isOptionalUnionType) + Some(collectFieldInfo(field, objectName, optionalUnion = true, optionalInterface = false)) else None } - val fields = typedef.fields.map(collectFieldInfo(_, objectName, optionalUnion = false)) - val view = if (genView) "\n " + writeView(typedef.name, fields.map(_.typeInfo)) else "" - val allFields = fields ++ optionalUnionTypeFields + val interfaceTypes = typesMap.typesMap.collect { case (key, _: InterfaceTypeDefinition) => key } + val optionalInterfaceTypeFields = typedef.fields.flatMap { field => + val isOptionalInterfaceType = interfaceTypes.exists(_.compareToIgnoreCase(field.ofType.toString) == 0) + if (isOptionalInterfaceType) + Some(collectFieldInfo(field, objectName, optionalUnion = false, optionalInterface = true)) + else None + } + + val fields = typedef.fields.map(collectFieldInfo(_, objectName, optionalUnion = false, optionalInterface = false)) + val view = if (genView) "\n " + writeView(typedef.name, fields.map(_.typeInfo)) else "" + + val allFields = fields ++ optionalUnionTypeFields ++ optionalInterfaceTypeFields s"""object $objectName {$view | ${allFields.map(writeFieldInfo).mkString("\n ")} @@ -612,12 +634,12 @@ object ClientWriter { private val tripleQuotes = "\"\"\"" private val doubleQuotes = "\"" - def writeField(field: FieldDefinition, typeName: String, optionalUnion: Boolean)(implicit + def writeField(field: FieldDefinition, typeName: String, optionalUnion: Boolean, optionalInterface: Boolean)(implicit typesMap: TypesMap, mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings ): String = - writeFieldInfo(collectFieldInfo(field, typeName, optionalUnion)) + writeFieldInfo(collectFieldInfo(field, typeName, optionalUnion, optionalInterface)) def writeFieldInfo(fieldInfo: FieldInfo): String = { val FieldInfo( @@ -640,7 +662,8 @@ object ClientWriter { s"""$description${deprecated}def $safeName$typeParam$args$innerSelection$implicits: SelectionBuilder[$typeName, $outputType] = _root_.caliban.client.SelectionBuilder.Field("$name", $builder$argBuilder)""" } - def collectFieldInfo(field: FieldDefinition, typeName: String, optionalUnion: Boolean)(implicit + def collectFieldInfo(field: FieldDefinition, typeName: String, optionalUnion: Boolean, optionalInterface: Boolean)( + implicit typesMap: TypesMap, mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings @@ -723,15 +746,27 @@ object ClientWriter { ) } } else if (interfaceTypes.nonEmpty) { - ( - s"[$typeLetter]", - s"(${interfaceTypes.map(t => s"""on${t.name}: Option[SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]] = None""").mkString(", ")})", - writeType(field.ofType).replace(fieldType, typeLetter), - writeTypeBuilder( - field.ofType, - s"ChoiceOf(Map(${interfaceTypes.map(t => s""""${t.name}" -> on${t.name}""").mkString(", ")}).collect { case (k, Some(v)) => k -> Obj(v)})" + if (optionalInterface) { + ( + s"[$typeLetter]", + s"(${interfaceTypes.map(t => s"""on${t.name}: Option[SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]] = None""").mkString(", ")})", + s"Option[${writeType(field.ofType).replace(fieldType, typeLetter)}]", + writeTypeBuilder( + field.ofType, + s"ChoiceOf(Map(${interfaceTypes.map(t => s""""${t.name}" -> on${t.name}.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a)))""").mkString(", ")}))" + ) ) - ) + } else { + ( + s"[$typeLetter]", + s"(${interfaceTypes.map(t => s"""on${t.name}: SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]""").mkString(", ")})", + writeType(field.ofType).replace(fieldType, typeLetter), + writeTypeBuilder( + field.ofType, + s"ChoiceOf(Map(${interfaceTypes.map(t => s""""${t.name}" -> Obj(on${t.name})""").mkString(", ")}))" + ) + ) + } } else { ( s"[$typeLetter]", @@ -759,7 +794,10 @@ object ClientWriter { }.mkString(", ")})" } - val name = if (optionalUnion && unionTypes.nonEmpty) safeName(field.name + "Option") else safeName(field.name) + val name = + if ((optionalUnion && unionTypes.nonEmpty) || (optionalInterface && interfaceTypes.nonEmpty)) + safeName(field.name + "Option") + else safeName(field.name) val owner = if (typeParam.nonEmpty) Some(fieldType) else None val fieldTypeInfo = FieldTypeInfo( field.name, diff --git a/tools/src/test/scala/caliban/tools/ClientWriterSpec.scala b/tools/src/test/scala/caliban/tools/ClientWriterSpec.scala index 57dc40b423..615386612f 100644 --- a/tools/src/test/scala/caliban/tools/ClientWriterSpec.scala +++ b/tools/src/test/scala/caliban/tools/ClientWriterSpec.scala @@ -843,6 +843,51 @@ object Client { ) ) ) + }, + testM("interface") { + val schema = + """ + interface Order { + name: String! + } + type Ascending implements Order { + name: String! + } + type Sort { + order: Order + } + """.stripMargin + + assertM(gen(schema, Map.empty, List.empty)) { + equalTo( + """import caliban.client.FieldBuilder._ +import caliban.client._ + +object Client { + + type Ascending + object Ascending { + def name: SelectionBuilder[Ascending, String] = _root_.caliban.client.SelectionBuilder.Field("name", Scalar()) + } + + type Sort + object Sort { + def order[A](onAscending: SelectionBuilder[Ascending, A]): SelectionBuilder[Sort, Option[A]] = + _root_.caliban.client.SelectionBuilder.Field("order", OptionOf(ChoiceOf(Map("Ascending" -> Obj(onAscending))))) + def orderOption[A]( + onAscending: Option[SelectionBuilder[Ascending, A]] = None + ): SelectionBuilder[Sort, Option[Option[A]]] = _root_.caliban.client.SelectionBuilder.Field( + "order", + OptionOf( + ChoiceOf(Map("Ascending" -> onAscending.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a))))) + ) + ) + } + +} +""" + ) + } } ) @@ TestAspect.sequential }