diff --git a/codegen-sbt/src/main/scala/caliban/codegen/CalibanSettings.scala b/codegen-sbt/src/main/scala/caliban/codegen/CalibanSettings.scala index 25c30c036..d732adb6c 100644 --- a/codegen-sbt/src/main/scala/caliban/codegen/CalibanSettings.scala +++ b/codegen-sbt/src/main/scala/caliban/codegen/CalibanSettings.scala @@ -1,6 +1,7 @@ package caliban.codegen import caliban.tools.CalibanCommonSettings +import caliban.tools.Codegen.GenType import java.io.File import java.net.URL @@ -19,6 +20,7 @@ final case class CalibanFileSettings(file: File, settings: CalibanCommonSettings def splitFiles(value: Boolean): CalibanFileSettings = this.copy(settings = this.settings.splitFiles(value)) def enableFmt(value: Boolean): CalibanFileSettings = this.copy(settings = this.settings.enableFmt(value)) def extensibleEnums(value: Boolean): CalibanFileSettings = this.copy(settings = this.settings.extensibleEnums(value)) + def genType(genType: GenType): CalibanFileSettings = this.copy(settings = this.settings.genType(genType)) } final case class CalibanUrlSettings(url: URL, settings: CalibanCommonSettings) extends CalibanSettings { @@ -34,4 +36,5 @@ final case class CalibanUrlSettings(url: URL, settings: CalibanCommonSettings) e def splitFiles(value: Boolean): CalibanUrlSettings = this.copy(settings = this.settings.splitFiles(value)) def enableFmt(value: Boolean): CalibanUrlSettings = this.copy(settings = this.settings.enableFmt(value)) def extensibleEnums(value: Boolean): CalibanUrlSettings = this.copy(settings = this.settings.extensibleEnums(value)) + def genType(genType: GenType): CalibanUrlSettings = this.copy(settings = this.settings.genType(genType)) } diff --git a/codegen-sbt/src/main/scala/caliban/codegen/CalibanSourceGenerator.scala b/codegen-sbt/src/main/scala/caliban/codegen/CalibanSourceGenerator.scala index 754a85f6e..e6b76b369 100644 --- a/codegen-sbt/src/main/scala/caliban/codegen/CalibanSourceGenerator.scala +++ b/codegen-sbt/src/main/scala/caliban/codegen/CalibanSourceGenerator.scala @@ -1,6 +1,5 @@ package caliban.codegen -import _root_.caliban.tools.Codegen.GenType import _root_.caliban.tools._ import sbt._ import sjsonnew.IsoLList.Aux @@ -30,7 +29,7 @@ object CalibanSourceGenerator { implicit val analysisIso: Aux[TrackedSettings, Seq[String] :*: LNil] = LList.iso[TrackedSettings, Seq[String] :*: LNil]( { case TrackedSettings(arguments) => ("args", arguments) :*: LNil }, - { case ((_, args) :*: LNil) => TrackedSettings(args) } + { case (_, args) :*: LNil => TrackedSettings(args) } ) } @@ -78,7 +77,7 @@ object CalibanSourceGenerator { generatedSource <- ZIO.succeed(transformFile(sourceRoot, sourceManaged, settings)(graphql)) _ <- Task(sbt.IO.createDirectory(generatedSource.toPath.getParent.toFile)).asSomeError opts <- ZIO.fromOption(Some(settings.toOptions(graphql.toString, generatedSource.toString))) - files <- Codegen.generate(opts, GenType.Client).asSomeError + files <- Codegen.generate(opts, settings.genType).asSomeError } yield files def generateUrlSource( @@ -92,7 +91,7 @@ object CalibanSourceGenerator { ) _ <- Task(sbt.IO.createDirectory(generatedSource.toPath.getParent.toFile)).asSomeError opts <- ZIO.fromOption(Some(settings.toOptions(graphql.toString, generatedSource.toString))) - files <- Codegen.generate(opts, GenType.Client).asSomeError + files <- Codegen.generate(opts, settings.genType).asSomeError } yield files Runtime.default diff --git a/tools/src/main/scala/caliban/tools/CalibanCommonSettings.scala b/tools/src/main/scala/caliban/tools/CalibanCommonSettings.scala index b1b30d961..94d949368 100644 --- a/tools/src/main/scala/caliban/tools/CalibanCommonSettings.scala +++ b/tools/src/main/scala/caliban/tools/CalibanCommonSettings.scala @@ -1,5 +1,7 @@ package caliban.tools +import caliban.tools.Codegen.GenType + final case class CalibanCommonSettings( clientName: Option[String], scalafmtPath: Option[String], @@ -10,7 +12,8 @@ final case class CalibanCommonSettings( imports: Seq[String], splitFiles: Option[Boolean], enableFmt: Option[Boolean], - extensibleEnums: Option[Boolean] + extensibleEnums: Option[Boolean], + genType: GenType ) { private[caliban] def toOptions(schemaPath: String, toPath: String): Options = @@ -42,7 +45,8 @@ final case class CalibanCommonSettings( imports = this.imports ++ r.imports, splitFiles = r.splitFiles.orElse(this.splitFiles), enableFmt = r.enableFmt.orElse(this.enableFmt), - extensibleEnums = r.extensibleEnums.orElse(this.extensibleEnums) + extensibleEnums = r.extensibleEnums.orElse(this.extensibleEnums), + genType = r.genType ) def clientName(value: String): CalibanCommonSettings = this.copy(clientName = Some(value)) @@ -56,6 +60,7 @@ final case class CalibanCommonSettings( def splitFiles(value: Boolean): CalibanCommonSettings = this.copy(splitFiles = Some(value)) def enableFmt(value: Boolean): CalibanCommonSettings = this.copy(enableFmt = Some(value)) def extensibleEnums(value: Boolean): CalibanCommonSettings = this.copy(extensibleEnums = Some(value)) + def genType(genType: GenType): CalibanCommonSettings = this.copy(genType = genType) } object CalibanCommonSettings { @@ -70,6 +75,7 @@ object CalibanCommonSettings { imports = Seq.empty, splitFiles = None, enableFmt = None, - extensibleEnums = None + extensibleEnums = None, + GenType.Client ) } diff --git a/tools/src/main/scala/caliban/tools/ClientWriter.scala b/tools/src/main/scala/caliban/tools/ClientWriter.scala index 94a9fe74e..a4509f48d 100644 --- a/tools/src/main/scala/caliban/tools/ClientWriter.scala +++ b/tools/src/main/scala/caliban/tools/ClientWriter.scala @@ -5,14 +5,12 @@ import caliban.parsing.adt.Definition.TypeSystemDefinition.TypeDefinition import caliban.parsing.adt.Definition.TypeSystemDefinition.TypeDefinition._ import caliban.parsing.adt.Type.{ ListType, NamedType } import caliban.parsing.adt.{ Document, Type } -import caliban.tools.implicits.Implicits._ -import caliban.tools.implicits.{ MappingClashedTypeNames, ScalarMappings, TypesMap } import scala.annotation.tailrec object ClientWriter { - private val MaxTupleLength = 22 + val MaxTupleLength = 22 def write( schema: Document, @@ -21,27 +19,65 @@ object ClientWriter { genView: Boolean = false, additionalImports: Option[List[String]] = None, splitFiles: Boolean = false, - extensibleEnums: Boolean = false - )(implicit scalarMappings: ScalarMappings): List[(String, String)] = { + extensibleEnums: Boolean = false, + scalarMappings: Option[Map[String, String]] = None + ): List[(String, String)] = { require(packageName.isDefined || !splitFiles, "splitFiles option requires a package name") - val schemaDef = schema.schemaDefinition + def getMappingsClashedNames(typeNames: List[String], reservedNames: List[String] = Nil): Map[String, String] = + (reservedNames ::: typeNames) + .map(name => name.toLowerCase -> name) + .groupBy(_._1) + .collect { + case (_, _ :: typeNamesToRename) if typeNamesToRename.nonEmpty => + typeNamesToRename.zipWithIndex.map { case ((_, originalTypeName), index) => + val suffix = "_" + (index + 1) + originalTypeName -> s"$originalTypeName$suffix" + }.toMap + } + .reduceOption(_ ++ _) + .getOrElse(Map.empty) + + val mappingClashedTypeNames: Map[String, String] = getMappingsClashedNames( + schema.definitions.collect { + case ObjectTypeDefinition(_, name, _, _, _) => name + case InputObjectTypeDefinition(_, name, _, _) => name + case EnumTypeDefinition(_, name, _, _) => name + case UnionTypeDefinition(_, name, _, _) => name + case ScalarTypeDefinition(_, name, _) => name + case InterfaceTypeDefinition(_, name, _, _) => name + }, + if (splitFiles) List("package") else Nil + ) + + def safeUnapplyName(name: String): String = + if (reservedKeywords.contains(name) || name.endsWith("_") || isCapital(name)) s"${decapitalize(name)}$$" + else name + + def isCapital(name: String): Boolean = name.nonEmpty && name.charAt(0).isUpper + + def decapitalize(name: String): String = if (isCapital(name)) { + val chars = name.toCharArray + chars(0) = chars(0).toLower + new String(chars) + } else { + name + } - implicit val mappingClashedTypeNames: MappingClashedTypeNames = MappingClashedTypeNames( - getMappingsClashedNames( - schema.definitions.collect { - case ObjectTypeDefinition(_, name, _, _, _) => name - case InputObjectTypeDefinition(_, name, _, _) => name - case EnumTypeDefinition(_, name, _, _) => name - case UnionTypeDefinition(_, name, _, _) => name - case ScalarTypeDefinition(_, name, _) => name - case InterfaceTypeDefinition(_, name, _, _) => name - }, - if (splitFiles) List("package") else Nil + def safeName(name: String): String = + if (reservedKeywords.contains(name) || name.endsWith("_")) s"`$name`" + else if (caseClassReservedFields.contains(name)) s"$name$$" + else name + + def safeTypeName( + typeName: String + ): String = + mappingClashedTypeNames.getOrElse( + typeName, + scalarMappings.flatMap(m => m.get(typeName)).getOrElse(safeName(typeName)) ) - ) - implicit val typesMap: TypesMap = TypesMap(schema.definitions.collect { + val typesMap: Map[String, TypeDefinition] = schema.definitions.collect { case op @ ObjectTypeDefinition(_, name, _, _, _) => name -> op case op @ InputObjectTypeDefinition(_, name, _, _) => name -> op case op @ EnumTypeDefinition(_, name, _, _) => name -> op @@ -50,7 +86,517 @@ object ClientWriter { case op @ InterfaceTypeDefinition(_, name, _, _) => name -> op }.map { case (name, op) => safeTypeName(name) -> op - }.toMap) + }.toMap + + def writeFieldInfo(fieldInfo: FieldInfo): String = { + val FieldInfo( + name, + safeName, + description, + deprecated, + typeName, + typeParam, + args, + implicits, + innerSelection, + outputType, + builder, + argBuilder, + _ + ) = + fieldInfo + + s"""$description${deprecated}def $safeName$typeParam$args$innerSelection$implicits: SelectionBuilder[$typeName, $outputType] = _root_.caliban.client.SelectionBuilder.Field("$name", $builder$argBuilder)""" + } + + def collectFieldInfo(field: FieldDefinition, typeName: String, optionalUnion: Boolean): FieldInfo = { + val description = field.description match { + case Some(d) if d.trim.nonEmpty => s"/**\n * ${d.trim}\n */\n" + case _ => "" + } + val deprecated = field.directives.find(_.name == "deprecated") match { + case None => "" + case Some(directive) => + val body = + directive.arguments.collectFirst { case ("reason", StringValue(reason)) => + reason + }.getOrElse("") + + val quotes = + if (body.contains("\n")) tripleQuotes + else doubleQuotes + + "@deprecated(" + quotes + body + quotes + """, "")""" + "\n" + } + val fieldType = safeTypeName(getTypeName(field.ofType)) + val isScalar = typesMap + .get(fieldType) + .collect { + case _: ScalarTypeDefinition => true + case _: EnumTypeDefinition => true + case _ => false + } + .getOrElse(true) + val unionTypes = typesMap + .get(fieldType) + .collect { case UnionTypeDefinition(_, _, _, memberTypes) => + memberTypes.flatMap(name => typesMap.get(safeTypeName(name))) + } + .getOrElse(Nil) + .collect { case o: ObjectTypeDefinition => o } + .sortBy(_.name) + val interfaceTypes = typesMap + .get(fieldType) + .collect { case InterfaceTypeDefinition(_, name, _, _) => name } + .map(interface => + typesMap.values.collect { + case o @ ObjectTypeDefinition(_, _, implements, _, _) if implements.exists(_.name == interface) => o + } + ) + .getOrElse(Nil) + .toList + .sortBy(_.name) + val typeLetter = getTypeLetter(typesMap) + val (typeParam, innerSelection, outputType, builder) = + if (isScalar) { + ( + "", + "", + writeType(field.ofType), + writeTypeBuilder(field.ofType, "Scalar()") + ) + } else if (unionTypes.nonEmpty) { + if (optionalUnion) { + ( + s"[$typeLetter]", + s"(${unionTypes.map(t => s"""on${t.name}: Option[SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]] = None""").mkString(", ")})", + s"Option[${writeType(field.ofType).replace(fieldType, typeLetter)}]", + writeTypeBuilder( + field.ofType, + s"ChoiceOf(Map(${unionTypes.map(t => s""""${t.name}" -> on${t.name}.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a)))""").mkString(", ")}))" + ) + ) + } else { + ( + s"[$typeLetter]", + s"(${unionTypes.map(t => s"""on${t.name}: SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]""").mkString(", ")})", + writeType(field.ofType).replace(fieldType, typeLetter), + writeTypeBuilder( + field.ofType, + s"ChoiceOf(Map(${unionTypes.map(t => s""""${t.name}" -> Obj(on${t.name})""").mkString(", ")}))" + ) + ) + } + } else if (interfaceTypes.nonEmpty) { + ( + s"[$typeLetter]", + s"(${interfaceTypes.map(t => s"""on${t.name}: Option[SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]] = None""").mkString(", ")})", + writeType(field.ofType).replace(fieldType, typeLetter), + writeTypeBuilder( + field.ofType, + s"ChoiceOf(Map(${interfaceTypes.map(t => s""""${t.name}" -> on${t.name}""").mkString(", ")}).collect { case (k, Some(v)) => k -> Obj(v)})" + ) + ) + } else { + ( + s"[$typeLetter]", + s"(innerSelection: SelectionBuilder[$fieldType, $typeLetter])", + writeType(field.ofType).replace(fieldType, typeLetter), + writeTypeBuilder(field.ofType, "Obj(innerSelection)") + ) + } + val args = field.args match { + case Nil => "" + case list => s"(${writeArgumentFields(list)})" + } + val argBuilder = field.args match { + case Nil => "" + case list => + s", arguments = List(${list.zipWithIndex.map { case (arg, idx) => + s"""Argument("${arg.name}", ${safeName(arg.name)}, "${arg.ofType.toString}")(encoder$idx)""" + }.mkString(", ")})" + } + val implicits = field.args match { + case Nil => "" + case list => + s"(implicit ${list.zipWithIndex.map { case (arg, idx) => + s"""encoder$idx: ArgEncoder[${writeType(arg.ofType)}]""" + }.mkString(", ")})" + } + + val name = if (optionalUnion && unionTypes.nonEmpty) safeName(field.name + "Option") else safeName(field.name) + val owner = if (typeParam.nonEmpty) Some(fieldType) else None + val fieldTypeInfo = FieldTypeInfo( + field.name, + name, + outputType, + interfaceTypes.map(_.name), + unionTypes.map(_.name), + field.args, + owner + ) + FieldInfo( + field.name, + name, + description, + deprecated, + typeName, + typeParam, + args, + implicits, + innerSelection, + outputType, + builder, + argBuilder, + fieldTypeInfo + ) + } + + def mapTypeName( + s: String + ): String = s match { + case "Float" => "Double" + case "ID" => "String" + case other => safeTypeName(other) + } + + def writeField(field: FieldDefinition, typeName: String, optionalUnion: Boolean): String = + writeFieldInfo(collectFieldInfo(field, typeName, optionalUnion)) + + def reservedType(typeDefinition: ObjectTypeDefinition): Boolean = + typeDefinition.name == "Query" || typeDefinition.name == "Mutation" || typeDefinition.name == "Subscription" + + def writeRootQueryType( + typedef: ObjectTypeDefinition + ): String = + s"type ${typedef.name} = _root_.caliban.client.Operations.RootQuery" + + def writeRootQuery(typedef: ObjectTypeDefinition): String = + s"""object ${typedef.name} { + | ${typedef.fields + .map(writeField(_, "_root_.caliban.client.Operations.RootQuery", optionalUnion = false)) + .mkString("\n ")} + |} + |""".stripMargin + + def writeRootMutationType( + typedef: ObjectTypeDefinition + ): String = + s"type ${typedef.name} = _root_.caliban.client.Operations.RootMutation" + + def writeRootMutation(typedef: ObjectTypeDefinition): String = + s"""object ${typedef.name} { + | ${typedef.fields + .map(writeField(_, "_root_.caliban.client.Operations.RootMutation", optionalUnion = false)) + .mkString("\n ")} + |} + |""".stripMargin + + def writeRootSubscriptionType( + typedef: ObjectTypeDefinition + ): String = + s"type ${typedef.name} = _root_.caliban.client.Operations.RootSubscription" + + def writeRootSubscription(typedef: ObjectTypeDefinition): String = + s"""object ${typedef.name} { + | ${typedef.fields + .map(writeField(_, "_root_.caliban.client.Operations.RootSubscription", optionalUnion = false)) + .mkString("\n ")} + |} + |""".stripMargin + + def writeObjectType(typedef: ObjectTypeDefinition): String = { + + val objectName: String = safeTypeName(typedef.name) + + s"type $objectName" + } + + def writeObject(typedef: ObjectTypeDefinition, genView: Boolean): String = { + + val objectName: String = safeTypeName(typedef.name) + val unionTypes = typesMap.collect { case (key, _: UnionTypeDefinition) => key } + val optionalUnionTypeFields = typedef.fields.flatMap { field => + val isOptionalUnionType = unionTypes.exists(_.compareToIgnoreCase(field.ofType.toString) == 0) + if (isOptionalUnionType) Some(collectFieldInfo(field, objectName, optionalUnion = true)) + else None + } + val fields = typedef.fields.map(collectFieldInfo(_, objectName, optionalUnion = false)) + val view = if (genView) "\n " + writeView(typedef.name, fields.map(_.typeInfo)) else "" + + val allFields = fields ++ optionalUnionTypeFields + + s"""object $objectName {$view + | ${allFields.map(writeFieldInfo).mkString("\n ")} + |} + |""".stripMargin + } + + def writeView(objectName: String, fields: List[FieldTypeInfo]): String = { + val viewName = s"${objectName}View" + val safeObjectName = safeTypeName(objectName) + + def argumentName(fieldName: String, argName: String): String = + fieldName + argName.capitalize + + def withRoundBrackets(input: List[String]): String = + if (input.nonEmpty) input.mkString("(", ", ", ")") else "" + + val genericSelectionFields = + fields.collect { + case field if field.owner.nonEmpty => + field -> s"${field.rawName}Selection" + } + + val genericSelectionFieldTypes = + genericSelectionFields.map { case (field, name) => (field, name.capitalize) } + + val genericSelectionFieldsMap = genericSelectionFields.toMap + val genericSelectionFieldTypesMap = genericSelectionFieldTypes.toMap + + val viewFunctionArgumentsCount: Int = fields.map(_.arguments.length).sum + val needsCaseClassForArguments = viewFunctionArgumentsCount > MaxTupleLength + val viewFunctionArguments: List[String] = + fields.collect { + case field if field.arguments.nonEmpty => + writeArgumentFields( + field.arguments.map(a => a.copy(name = argumentName(field.name, a.name))) + ) + } + + val viewFunctionSelectionArgumentsCount: Int = genericSelectionFields.collect { + case (FieldTypeInfo(_, _, _, Nil, Nil, _, Some(_)), _) => 1 + case (FieldTypeInfo(_, _, _, _, unionTypes, _, Some(_)), _) if unionTypes.nonEmpty => unionTypes.length + case (FieldTypeInfo(_, _, _, interfaceTypes, _, _, Some(_)), _) if interfaceTypes.nonEmpty => + interfaceTypes.length + }.sum + val needsCaseClassForSelectionArguments = viewFunctionSelectionArgumentsCount > MaxTupleLength + val viewFunctionSelectionArguments: List[String] = + genericSelectionFields.collect { + case (field @ FieldTypeInfo(_, _, _, Nil, Nil, _, Some(owner)), fieldName) => + val tpe = genericSelectionFieldTypesMap(field) + List(s"$fieldName: SelectionBuilder[$owner, $tpe]") + + case (field @ FieldTypeInfo(_, _, _, _, unionTypes, _, Some(_)), fieldName) if unionTypes.nonEmpty => + val tpe = genericSelectionFieldTypesMap(field) + unionTypes.map(unionType => s"${fieldName}On$unionType: SelectionBuilder[$unionType, $tpe]") + + case (field @ FieldTypeInfo(_, _, _, interfaceTypes, _, _, Some(_)), fieldName) if interfaceTypes.nonEmpty => + val tpe = genericSelectionFieldTypesMap(field) + interfaceTypes.map(intType => s"${fieldName}On$intType: Option[SelectionBuilder[$intType, $tpe]] = None") + }.flatten + + val viewClassFields: List[String] = + fields.map { + case field @ FieldTypeInfo(_, _, outputType, _, _, _, Some(_)) => + val tpeName = genericSelectionFieldTypesMap(field) + val tpe = + if (outputType.contains("[A]")) outputType.replace("[A]", s"[$tpeName]") + else outputType.dropRight(1) + tpeName + s"${field.name}: $tpe" + + case field @ FieldTypeInfo(_, _, outputType, _, _, _, _) => + s"${field.name}: $outputType" + } + + val viewFunctionBody: String = + fields.map { case field @ FieldTypeInfo(_, _, _, interfaceTypes, unionTypes, _, _) => + val argsPart = withRoundBrackets( + field.arguments + .map(a => argumentName(field.name, a.name)) + .map(name => if (needsCaseClassForArguments) s"args.$name" else name) + ) + val selectionType = genericSelectionFieldsMap + .get(field) + .map(name => if (needsCaseClassForSelectionArguments) s"selectionArgs.$name" else name) + val selectionPart = { + val parts = + if (unionTypes.nonEmpty) unionTypes.map(unionType => s"${selectionType.head}On$unionType") + else if (interfaceTypes.nonEmpty) interfaceTypes.map(tpe => s"${selectionType.head}On$tpe") + else selectionType.toList + + withRoundBrackets(parts) + } + + s"${field.name}$argsPart$selectionPart" + }.mkString(" ~ ") + + val viewClassFieldParams: String = withRoundBrackets(viewClassFields) + + val viewFunction: String = + fields match { + case Nil => throw new Exception("Invalid GraphQL Schema: an object must have at least one field") + case head :: Nil => + s"$viewFunctionBody.map(${head.name} => $viewName(${head.name}))" + + case other => + val unapply = fields.tail.foldLeft(safeUnapplyName(fields.head.rawName)) { case (acc, field) => + "(" + acc + ", " + safeUnapplyName(field.rawName) + ")" + } + s"($viewFunctionBody).map { case $unapply => $viewName(${other.map(f => safeUnapplyName(f.rawName)).mkString(", ")}) }" + } + + val typeParams = + if (genericSelectionFieldTypes.nonEmpty) genericSelectionFieldTypes.map(_._2).mkString("[", ", ", "]") else "" + + val viewFunctionArgs = + if (needsCaseClassForArguments) s"(args: ${viewName}Args)" else withRoundBrackets(viewFunctionArguments) + val viewFunctionSelectionArgs = + if (needsCaseClassForSelectionArguments) s"(selectionArgs: ${viewName}SelectionArgs$typeParams)" + else withRoundBrackets(viewFunctionSelectionArguments) + + val caseClassForArguments = + if (needsCaseClassForArguments) s"final case class ${viewName}Args${withRoundBrackets(viewFunctionArguments)}" + else "" + val caseClassForSelectionArguments = + if (needsCaseClassForSelectionArguments) + s"final case class ${viewName}SelectionArgs$typeParams${withRoundBrackets(viewFunctionSelectionArguments)}" + else "" + + s""" + |final case class $viewName$typeParams$viewClassFieldParams + | + |$caseClassForArguments + |$caseClassForSelectionArguments + | + |type ViewSelection$typeParams = SelectionBuilder[$safeObjectName, $viewName$typeParams] + | + |def view$typeParams$viewFunctionArgs$viewFunctionSelectionArgs: ViewSelection$typeParams = $viewFunction + |""".stripMargin + } + + def writeInputObject( + typedef: InputObjectTypeDefinition + ): String = { + val inputObjectName = safeTypeName(typedef.name) + s"""final case class $inputObjectName(${writeArgumentFields(typedef.fields)}) + |object $inputObjectName { + | implicit val encoder: ArgEncoder[$inputObjectName] = new ArgEncoder[$inputObjectName] { + | override def encode(value: $inputObjectName): __Value = + | __ObjectValue(List(${typedef.fields + .map(f => + s""""${f.name}" -> ${writeInputValue( + f.ofType, + s"value.${safeName(f.name)}", + inputObjectName + )}""" + ) + .mkString(", ")})) + | } + |}""".stripMargin + } + + def writeInputValue( + t: Type, + fieldName: String, + typeName: String + ): String = t match { + case NamedType(name, true) => + if (name == typeName) s"encode($fieldName)" + else s"implicitly[ArgEncoder[${mapTypeName(name)}]].encode($fieldName)" + case NamedType(name, false) => + s"$fieldName.fold(__NullValue: __Value)(value => ${writeInputValue(NamedType(name, nonNull = true), "value", typeName)})" + case ListType(ofType, true) => + s"__ListValue($fieldName.map(value => ${writeInputValue(ofType, "value", typeName)}))" + case ListType(ofType, false) => + s"$fieldName.fold(__NullValue: __Value)(value => ${writeInputValue(ListType(ofType, nonNull = true), "value", typeName)})" + } + + def writeEnum( + typedef: EnumTypeDefinition, + extensibleEnums: Boolean + ): String = { + + val enumName = safeTypeName(typedef.name) + + val mappingClashedEnumValues = getMappingsClashedNames( + typedef.enumValuesDefinition.map(_.enumValue) + ) + + def safeEnumValue(enumValue: String): String = + safeName(mappingClashedEnumValues.getOrElse(enumValue, enumValue)) + + val enumCases = typedef.enumValuesDefinition + .map(v => + s"case object ${safeEnumValue(v.enumValue)} extends $enumName { val value: String = ${"\"" + safeEnumValue(v.enumValue) + "\""} }" + ) ++ + (if (extensibleEnums) Some(s"final case class __Unknown(value: String) extends $enumName") else None) + + val decoderCases = typedef.enumValuesDefinition + .map(v => s"""case __StringValue ("${v.enumValue}") => Right($enumName.${safeEnumValue(v.enumValue)})""") ++ + (if (extensibleEnums) Some(s"case __StringValue (other) => Right($enumName.__Unknown(other))") else None) + + val encoderCases = typedef.enumValuesDefinition + .map(v => s"""case ${typedef.name}.${safeEnumValue(v.enumValue)} => __EnumValue("${v.enumValue}")""") ++ + (if (extensibleEnums) Some(s"case ${typedef.name}.__Unknown (value) => __EnumValue(value)") else None) + + s"""sealed trait $enumName extends scala.Product with scala.Serializable { def value: String } + object $enumName { + ${enumCases.mkString("\n")} + + implicit val decoder: ScalarDecoder[$enumName] = { + ${decoderCases.mkString("\n")} + case other => Left(DecodingError(s"Can't build ${typedef.name} from input $$other")) + } + implicit val encoder: ArgEncoder[${typedef.name}] = { + ${encoderCases.mkString("\n")} + } + + val values: Vector[$enumName] = Vector(${typedef.enumValuesDefinition + .map(v => safeEnumValue(v.enumValue)) + .mkString(", ")}) + } + """ + } + + def writeScalar( + typedef: ScalarTypeDefinition + ): String = + s"""type ${safeTypeName(typedef.name)} = String + """ + + @tailrec + def getTypeLetter(typesMap: Map[String, TypeDefinition], letter: String = "A"): String = + if (!typesMap.contains(letter)) letter else getTypeLetter(typesMap, letter + "A") + + def writeArgumentFields( + args: List[InputValueDefinition] + ): String = + s"${args.map(arg => s"${safeName(arg.name)} : ${writeType(arg.ofType)}${writeDefaultArgument(arg)}").mkString(", ")}" + + def writeDefaultArgument(arg: InputValueDefinition): String = + arg.ofType match { + case t if t.nullable => " = None" + case ListType(_, _) => " = Nil" + case _ => "" + } + + def writeType( + t: Type + ): String = t match { + case NamedType(name, true) => mapTypeName(name) + case NamedType(name, false) => s"Option[${mapTypeName(name)}]" + case ListType(ofType, true) => s"List[${writeType(ofType)}]" + case ListType(ofType, false) => s"Option[List[${writeType(ofType)}]]" + } + + def writeTypeBuilder(t: Type, inner: String): String = t match { + case NamedType(_, true) => inner + case NamedType(_, false) => s"OptionOf($inner)" + case ListType(of, true) => s"ListOf(${writeTypeBuilder(of, inner)})" + case ListType(of, false) => s"OptionOf(ListOf(${writeTypeBuilder(of, inner)}))" + } + + @tailrec + def getTypeName(t: Type): String = t match { + case NamedType(name, _) => name + case ListType(ofType, _) => getTypeName(ofType) + } + + def isScalarSupported(scalar: String): Boolean = + supportedScalars.contains(scalar) || scalarMappings.exists(_.contains(scalar)) + + val schemaDef = schema.schemaDefinition val objectTypes = if (splitFiles) @@ -102,7 +648,7 @@ object ClientWriter { } val enums = schema.enumTypeDefinitions - .filter(e => !scalarMappings.scalarMap.exists(_.contains(e.name))) + .filter(e => !scalarMappings.exists(_.contains(e.name))) .map { typedef => val content = writeEnum(typedef, extensibleEnums = extensibleEnums) val fullContent = @@ -260,654 +806,6 @@ object ClientWriter { } } - private def getMappingsClashedNames(typeNames: List[String], reservedNames: List[String] = Nil): Map[String, String] = - (reservedNames ::: typeNames) - .map(name => name.toLowerCase -> name) - .groupBy(_._1) - .collect { - case (_, _ :: typeNamesToRename) if typeNamesToRename.nonEmpty => - typeNamesToRename.zipWithIndex.map { case ((_, originalTypeName), index) => - val suffix = "_" + (index + 1) - originalTypeName -> s"$originalTypeName$suffix" - }.toMap - } - .reduceOption(_ ++ _) - .getOrElse(Map.empty) - - private def safeTypeName( - typeName: String - )(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = - mappingClashedTypeNames.getOrElse( - typeName, - scalarMappings.scalarMap.flatMap(m => m.get(typeName)).getOrElse(safeName(typeName)) - ) - - def reservedType(typeDefinition: ObjectTypeDefinition): Boolean = - typeDefinition.name == "Query" || typeDefinition.name == "Mutation" || typeDefinition.name == "Subscription" - - def writeRootQueryType( - typedef: ObjectTypeDefinition - ): String = - s"type ${typedef.name} = _root_.caliban.client.Operations.RootQuery" - - def writeRootQuery(typedef: ObjectTypeDefinition)(implicit - typesMap: TypesMap, - mappingClashedTypeNames: MappingClashedTypeNames, - scalarMappings: ScalarMappings - ): String = - s"""object ${typedef.name} { - | ${typedef.fields - .map(writeField(_, "_root_.caliban.client.Operations.RootQuery", optionalUnion = false)) - .mkString("\n ")} - |} - |""".stripMargin - - def writeRootMutationType( - typedef: ObjectTypeDefinition - ): String = - s"type ${typedef.name} = _root_.caliban.client.Operations.RootMutation" - - def writeRootMutation(typedef: ObjectTypeDefinition)(implicit - typesMap: TypesMap, - mappingClashedTypeNames: MappingClashedTypeNames, - scalarMappings: ScalarMappings - ): String = - s"""object ${typedef.name} { - | ${typedef.fields - .map(writeField(_, "_root_.caliban.client.Operations.RootMutation", optionalUnion = false)) - .mkString("\n ")} - |} - |""".stripMargin - - def writeRootSubscriptionType( - typedef: ObjectTypeDefinition - ): String = - s"type ${typedef.name} = _root_.caliban.client.Operations.RootSubscription" - - def writeRootSubscription(typedef: ObjectTypeDefinition)(implicit - typesMap: TypesMap, - mappingClashedTypeNames: MappingClashedTypeNames, - scalarMappings: ScalarMappings - ): String = - s"""object ${typedef.name} { - | ${typedef.fields - .map(writeField(_, "_root_.caliban.client.Operations.RootSubscription", optionalUnion = false)) - .mkString("\n ")} - |} - |""".stripMargin - - def writeObjectType(typedef: ObjectTypeDefinition)(implicit - mappingClashedTypeNames: MappingClashedTypeNames, - scalarMappings: ScalarMappings - ): String = { - - val objectName: String = safeTypeName(typedef.name) - - s"type $objectName" - } - - def writeObject(typedef: ObjectTypeDefinition, genView: Boolean)(implicit - typesMap: TypesMap, - mappingClashedTypeNames: MappingClashedTypeNames, - scalarMappings: ScalarMappings - ): String = { - - val objectName: String = safeTypeName(typedef.name) - val unionTypes = typesMap.typesMap.collect { case (key, _: UnionTypeDefinition) => key } - val optionalUnionTypeFields = typedef.fields.flatMap { field => - val isOptionalUnionType = unionTypes.exists(_.compareToIgnoreCase(field.ofType.toString) == 0) - if (isOptionalUnionType) Some(collectFieldInfo(field, objectName, optionalUnion = true)) - else None - } - val fields = typedef.fields.map(collectFieldInfo(_, objectName, optionalUnion = false)) - val view = if (genView) "\n " + writeView(typedef.name, fields.map(_.typeInfo)) else "" - - val allFields = fields ++ optionalUnionTypeFields - - s"""object $objectName {$view - | ${allFields.map(writeFieldInfo).mkString("\n ")} - |} - |""".stripMargin - } - - def writeView(objectName: String, fields: List[FieldTypeInfo])(implicit - mappingClashedTypeNames: MappingClashedTypeNames, - scalarMappings: ScalarMappings - ): String = { - val viewName = s"${objectName}View" - val safeObjectName = safeTypeName(objectName) - - def argumentName(fieldName: String, argName: String): String = - fieldName + argName.capitalize - - def withRoundBrackets(input: List[String]): String = - if (input.nonEmpty) input.mkString("(", ", ", ")") else "" - - val genericSelectionFields = - fields.collect { - case field if field.owner.nonEmpty => - field -> s"${field.rawName}Selection" - } - - val genericSelectionFieldTypes = - genericSelectionFields.map { case (field, name) => (field, name.capitalize) } - - val genericSelectionFieldsMap = genericSelectionFields.toMap - val genericSelectionFieldTypesMap = genericSelectionFieldTypes.toMap - - val viewFunctionArgumentsCount: Int = fields.map(_.arguments.length).sum - val needsCaseClassForArguments = viewFunctionArgumentsCount > MaxTupleLength - val viewFunctionArguments: List[String] = - fields.collect { - case field if field.arguments.nonEmpty => - writeArgumentFields( - field.arguments.map(a => a.copy(name = argumentName(field.name, a.name))) - ) - } - - val viewFunctionSelectionArgumentsCount: Int = genericSelectionFields.collect { - case (FieldTypeInfo(_, _, _, Nil, Nil, _, Some(_)), _) => 1 - case (FieldTypeInfo(_, _, _, _, unionTypes, _, Some(_)), _) if unionTypes.nonEmpty => unionTypes.length - case (FieldTypeInfo(_, _, _, interfaceTypes, _, _, Some(_)), _) if interfaceTypes.nonEmpty => - interfaceTypes.length - }.sum - val needsCaseClassForSelectionArguments = viewFunctionSelectionArgumentsCount > MaxTupleLength - val viewFunctionSelectionArguments: List[String] = - genericSelectionFields.collect { - case (field @ FieldTypeInfo(_, _, _, Nil, Nil, _, Some(owner)), fieldName) => - val tpe = genericSelectionFieldTypesMap(field) - List(s"$fieldName: SelectionBuilder[$owner, $tpe]") - - case (field @ FieldTypeInfo(_, _, _, _, unionTypes, _, Some(_)), fieldName) if unionTypes.nonEmpty => - val tpe = genericSelectionFieldTypesMap(field) - unionTypes.map(unionType => s"${fieldName}On$unionType: SelectionBuilder[$unionType, $tpe]") - - case (field @ FieldTypeInfo(_, _, _, interfaceTypes, _, _, Some(_)), fieldName) if interfaceTypes.nonEmpty => - val tpe = genericSelectionFieldTypesMap(field) - interfaceTypes.map(intType => s"${fieldName}On$intType: Option[SelectionBuilder[$intType, $tpe]] = None") - }.flatten - - val viewClassFields: List[String] = - fields.map { - case field @ FieldTypeInfo(_, _, outputType, _, _, _, Some(_)) => - val tpeName = genericSelectionFieldTypesMap(field) - val tpe = - if (outputType.contains("[A]")) outputType.replace("[A]", s"[$tpeName]") - else outputType.dropRight(1) + tpeName - s"${field.name}: $tpe" - - case field @ FieldTypeInfo(_, _, outputType, _, _, _, _) => - s"${field.name}: $outputType" - } - - val viewFunctionBody: String = - fields.map { case field @ FieldTypeInfo(_, _, _, interfaceTypes, unionTypes, _, _) => - val argsPart = withRoundBrackets( - field.arguments - .map(a => argumentName(field.name, a.name)) - .map(name => if (needsCaseClassForArguments) s"args.$name" else name) - ) - val selectionType = genericSelectionFieldsMap - .get(field) - .map(name => if (needsCaseClassForSelectionArguments) s"selectionArgs.$name" else name) - val selectionPart = { - val parts = - if (unionTypes.nonEmpty) unionTypes.map(unionType => s"${selectionType.head}On$unionType") - else if (interfaceTypes.nonEmpty) interfaceTypes.map(tpe => s"${selectionType.head}On$tpe") - else selectionType.toList - - withRoundBrackets(parts) - } - - s"${field.name}$argsPart$selectionPart" - }.mkString(" ~ ") - - val viewClassFieldParams: String = withRoundBrackets(viewClassFields) - - val viewFunction: String = - fields match { - case Nil => throw new Exception("Invalid GraphQL Schema: an object must have at least one field") - case head :: Nil => - s"$viewFunctionBody.map(${head.name} => $viewName(${head.name}))" - - case other => - val unapply = fields.tail.foldLeft(safeUnapplyName(fields.head.rawName)) { case (acc, field) => - "(" + acc + ", " + safeUnapplyName(field.rawName) + ")" - } - s"($viewFunctionBody).map { case $unapply => $viewName(${other.map(f => safeUnapplyName(f.rawName)).mkString(", ")}) }" - } - - val typeParams = - if (genericSelectionFieldTypes.nonEmpty) genericSelectionFieldTypes.map(_._2).mkString("[", ", ", "]") else "" - - val viewFunctionArgs = - if (needsCaseClassForArguments) s"(args: ${viewName}Args)" else withRoundBrackets(viewFunctionArguments) - val viewFunctionSelectionArgs = - if (needsCaseClassForSelectionArguments) s"(selectionArgs: ${viewName}SelectionArgs$typeParams)" - else withRoundBrackets(viewFunctionSelectionArguments) - - val caseClassForArguments = - if (needsCaseClassForArguments) s"final case class ${viewName}Args${withRoundBrackets(viewFunctionArguments)}" - else "" - val caseClassForSelectionArguments = - if (needsCaseClassForSelectionArguments) - s"final case class ${viewName}SelectionArgs$typeParams${withRoundBrackets(viewFunctionSelectionArguments)}" - else "" - - s""" - |final case class $viewName$typeParams$viewClassFieldParams - | - |$caseClassForArguments - |$caseClassForSelectionArguments - | - |type ViewSelection$typeParams = SelectionBuilder[$safeObjectName, $viewName$typeParams] - | - |def view$typeParams$viewFunctionArgs$viewFunctionSelectionArgs: ViewSelection$typeParams = $viewFunction - |""".stripMargin - } - - def writeInputObject( - typedef: InputObjectTypeDefinition - )(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = { - val inputObjectName = safeTypeName(typedef.name) - s"""final case class $inputObjectName(${writeArgumentFields(typedef.fields)}) - |object $inputObjectName { - | implicit val encoder: ArgEncoder[$inputObjectName] = new ArgEncoder[$inputObjectName] { - | override def encode(value: $inputObjectName): __Value = - | __ObjectValue(List(${typedef.fields - .map(f => - s""""${f.name}" -> ${writeInputValue( - f.ofType, - s"value.${safeName(f.name)}", - inputObjectName - )}""" - ) - .mkString(", ")})) - | } - |}""".stripMargin - } - - def writeInputValue( - t: Type, - fieldName: String, - typeName: String - )(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = t match { - case NamedType(name, true) => - if (name == typeName) s"encode($fieldName)" - else s"implicitly[ArgEncoder[${mapTypeName(name)}]].encode($fieldName)" - case NamedType(name, false) => - s"$fieldName.fold(__NullValue: __Value)(value => ${writeInputValue(NamedType(name, nonNull = true), "value", typeName)})" - case ListType(ofType, true) => - s"__ListValue($fieldName.map(value => ${writeInputValue(ofType, "value", typeName)}))" - case ListType(ofType, false) => - s"$fieldName.fold(__NullValue: __Value)(value => ${writeInputValue(ListType(ofType, nonNull = true), "value", typeName)})" - } - - def writeEnum( - typedef: EnumTypeDefinition, - extensibleEnums: Boolean - )(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = { - - val enumName = safeTypeName(typedef.name) - - val mappingClashedEnumValues = getMappingsClashedNames( - typedef.enumValuesDefinition.map(_.enumValue) - ) - - def safeEnumValue(enumValue: String): String = - safeName(mappingClashedEnumValues.getOrElse(enumValue, enumValue)) - - val enumCases = typedef.enumValuesDefinition - .map(v => - s"case object ${safeEnumValue(v.enumValue)} extends $enumName { val value: String = ${"\"" + safeEnumValue(v.enumValue) + "\""} }" - ) ++ - (if (extensibleEnums) Some(s"final case class __Unknown(value: String) extends $enumName") else None) - - val decoderCases = typedef.enumValuesDefinition - .map(v => s"""case __StringValue ("${v.enumValue}") => Right($enumName.${safeEnumValue(v.enumValue)})""") ++ - (if (extensibleEnums) Some(s"case __StringValue (other) => Right($enumName.__Unknown(other))") else None) - - val encoderCases = typedef.enumValuesDefinition - .map(v => s"""case ${typedef.name}.${safeEnumValue(v.enumValue)} => __EnumValue("${v.enumValue}")""") ++ - (if (extensibleEnums) Some(s"case ${typedef.name}.__Unknown (value) => __EnumValue(value)") else None) - - s"""sealed trait $enumName extends scala.Product with scala.Serializable { def value: String } - object $enumName { - ${enumCases.mkString("\n")} - - implicit val decoder: ScalarDecoder[$enumName] = { - ${decoderCases.mkString("\n")} - case other => Left(DecodingError(s"Can't build ${typedef.name} from input $$other")) - } - implicit val encoder: ArgEncoder[${typedef.name}] = { - ${encoderCases.mkString("\n")} - } - - val values: Vector[$enumName] = Vector(${typedef.enumValuesDefinition - .map(v => safeEnumValue(v.enumValue)) - .mkString(", ")}) - } - """ - } - - def writeScalar( - typedef: ScalarTypeDefinition - )(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = - s"""type ${safeTypeName(typedef.name)} = String - """ - - def safeUnapplyName(name: String): String = - if (reservedKeywords.contains(name) || name.endsWith("_") || isCapital(name)) s"${decapitalize(name)}$$" - else name - - private def isCapital(name: String): Boolean = name.nonEmpty && name.charAt(0).isUpper - - private def decapitalize(name: String): String = if (isCapital(name)) { - val chars = name.toCharArray - chars(0) = chars(0).toLower - new String(chars) - } else { - name - } - - def safeName(name: String): String = - if (reservedKeywords.contains(name) || name.endsWith("_")) s"`$name`" - else if (caseClassReservedFields.contains(name)) s"$name$$" - else name - - @tailrec - def getTypeLetter(typesMap: Map[String, TypeDefinition], letter: String = "A"): String = - if (!typesMap.contains(letter)) letter else getTypeLetter(typesMap, letter + "A") - - private val tripleQuotes = "\"\"\"" - private val doubleQuotes = "\"" - - def writeField(field: FieldDefinition, typeName: String, optionalUnion: Boolean)(implicit - typesMap: TypesMap, - mappingClashedTypeNames: MappingClashedTypeNames, - scalarMappings: ScalarMappings - ): String = - writeFieldInfo(collectFieldInfo(field, typeName, optionalUnion)) - - def writeFieldInfo(fieldInfo: FieldInfo): String = { - val FieldInfo( - name, - safeName, - description, - deprecated, - typeName, - typeParam, - args, - implicits, - innerSelection, - outputType, - builder, - argBuilder, - _ - ) = - fieldInfo - - s"""$description${deprecated}def $safeName$typeParam$args$innerSelection$implicits: SelectionBuilder[$typeName, $outputType] = _root_.caliban.client.SelectionBuilder.Field("$name", $builder$argBuilder)""" - } - - def collectFieldInfo(field: FieldDefinition, typeName: String, optionalUnion: Boolean)(implicit - typesMap: TypesMap, - mappingClashedTypeNames: MappingClashedTypeNames, - scalarMappings: ScalarMappings - ): FieldInfo = { - val description = field.description match { - case Some(d) if d.trim.nonEmpty => s"/**\n * ${d.trim}\n */\n" - case _ => "" - } - val deprecated = field.directives.find(_.name == "deprecated") match { - case None => "" - case Some(directive) => - val body = - directive.arguments.collectFirst { case ("reason", StringValue(reason)) => - reason - }.getOrElse("") - - val quotes = - if (body.contains("\n")) tripleQuotes - else doubleQuotes - - "@deprecated(" + quotes + body + quotes + """, "")""" + "\n" - } - val fieldType = safeTypeName(getTypeName(field.ofType)) - val isScalar = typesMap - .get(fieldType) - .collect { - case _: ScalarTypeDefinition => true - case _: EnumTypeDefinition => true - case _ => false - } - .getOrElse(true) - val unionTypes = typesMap - .get(fieldType) - .collect { case UnionTypeDefinition(_, _, _, memberTypes) => - memberTypes.flatMap(name => typesMap.get(safeTypeName(name))) - } - .getOrElse(Nil) - .collect { case o: ObjectTypeDefinition => o } - .sortBy(_.name) - val interfaceTypes = typesMap - .get(fieldType) - .collect { case InterfaceTypeDefinition(_, name, _, _) => name } - .map(interface => - typesMap.values.collect { - case o @ ObjectTypeDefinition(_, _, implements, _, _) if implements.exists(_.name == interface) => o - } - ) - .getOrElse(Nil) - .toList - .sortBy(_.name) - val typeLetter = getTypeLetter(typesMap) - val (typeParam, innerSelection, outputType, builder) = - if (isScalar) { - ( - "", - "", - writeType(field.ofType), - writeTypeBuilder(field.ofType, "Scalar()") - ) - } else if (unionTypes.nonEmpty) { - if (optionalUnion) { - ( - s"[$typeLetter]", - s"(${unionTypes.map(t => s"""on${t.name}: Option[SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]] = None""").mkString(", ")})", - s"Option[${writeType(field.ofType).replace(fieldType, typeLetter)}]", - writeTypeBuilder( - field.ofType, - s"ChoiceOf(Map(${unionTypes.map(t => s""""${t.name}" -> on${t.name}.fold[FieldBuilder[Option[A]]](NullField)(a => OptionOf(Obj(a)))""").mkString(", ")}))" - ) - ) - } else { - ( - s"[$typeLetter]", - s"(${unionTypes.map(t => s"""on${t.name}: SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]""").mkString(", ")})", - writeType(field.ofType).replace(fieldType, typeLetter), - writeTypeBuilder( - field.ofType, - s"ChoiceOf(Map(${unionTypes.map(t => s""""${t.name}" -> Obj(on${t.name})""").mkString(", ")}))" - ) - ) - } - } else if (interfaceTypes.nonEmpty) { - ( - s"[$typeLetter]", - s"(${interfaceTypes.map(t => s"""on${t.name}: Option[SelectionBuilder[${safeTypeName(t.name)}, $typeLetter]] = None""").mkString(", ")})", - writeType(field.ofType).replace(fieldType, typeLetter), - writeTypeBuilder( - field.ofType, - s"ChoiceOf(Map(${interfaceTypes.map(t => s""""${t.name}" -> on${t.name}""").mkString(", ")}).collect { case (k, Some(v)) => k -> Obj(v)})" - ) - ) - } else { - ( - s"[$typeLetter]", - s"(innerSelection: SelectionBuilder[$fieldType, $typeLetter])", - writeType(field.ofType).replace(fieldType, typeLetter), - writeTypeBuilder(field.ofType, "Obj(innerSelection)") - ) - } - val args = field.args match { - case Nil => "" - case list => s"(${writeArgumentFields(list)})" - } - val argBuilder = field.args match { - case Nil => "" - case list => - s", arguments = List(${list.zipWithIndex.map { case (arg, idx) => - s"""Argument("${arg.name}", ${safeName(arg.name)}, "${arg.ofType.toString}")(encoder$idx)""" - }.mkString(", ")})" - } - val implicits = field.args match { - case Nil => "" - case list => - s"(implicit ${list.zipWithIndex.map { case (arg, idx) => - s"""encoder$idx: ArgEncoder[${writeType(arg.ofType)}]""" - }.mkString(", ")})" - } - - val name = if (optionalUnion && unionTypes.nonEmpty) safeName(field.name + "Option") else safeName(field.name) - val owner = if (typeParam.nonEmpty) Some(fieldType) else None - val fieldTypeInfo = FieldTypeInfo( - field.name, - name, - outputType, - interfaceTypes.map(_.name), - unionTypes.map(_.name), - field.args, - owner - ) - FieldInfo( - field.name, - name, - description, - deprecated, - typeName, - typeParam, - args, - implicits, - innerSelection, - outputType, - builder, - argBuilder, - fieldTypeInfo - ) - } - - def writeArgumentFields( - args: List[InputValueDefinition] - )(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = - s"${args.map(arg => s"${safeName(arg.name)} : ${writeType(arg.ofType)}${writeDefaultArgument(arg)}").mkString(", ")}" - - def writeDefaultArgument(arg: InputValueDefinition): String = - arg.ofType match { - case t if t.nullable => " = None" - case ListType(_, _) => " = Nil" - case _ => "" - } - - def writeDescription(description: Option[String]): String = - description.fold("")(d => s"""@GQLDescription("$d") - |""".stripMargin) - - def mapTypeName( - s: String - )(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = s match { - case "Float" => "Double" - case "ID" => "String" - case other => safeTypeName(other) - } - - def writeType( - t: Type - )(implicit mappingClashedTypeNames: MappingClashedTypeNames, scalarMappings: ScalarMappings): String = t match { - case NamedType(name, true) => mapTypeName(name) - case NamedType(name, false) => s"Option[${mapTypeName(name)}]" - case ListType(ofType, true) => s"List[${writeType(ofType)}]" - case ListType(ofType, false) => s"Option[List[${writeType(ofType)}]]" - } - - def writeTypeBuilder(t: Type, inner: String): String = t match { - case NamedType(_, true) => inner - case NamedType(_, false) => s"OptionOf($inner)" - case ListType(of, true) => s"ListOf(${writeTypeBuilder(of, inner)})" - case ListType(of, false) => s"OptionOf(ListOf(${writeTypeBuilder(of, inner)}))" - } - - @tailrec - def getTypeName(t: Type): String = t match { - case NamedType(name, _) => name - case ListType(ofType, _) => getTypeName(ofType) - } - - val supportedScalars = - Set("Int", "Float", "Double", "Long", "Unit", "String", "Boolean", "BigInt", "BigDecimal") - - def isScalarSupported(scalar: String)(implicit scalarMappings: ScalarMappings): Boolean = - supportedScalars.contains(scalar) || scalarMappings.map(_.contains(scalar)).getOrElse(false) - - val reservedKeywords = Set( - "abstract", - "as", - "case", - "catch", - "class", - "def", - "derives", - "do", - "else", - "enum", - "export", - "extends", - "extension", - "false", - "final", - "finally", - "for", - "forSome", - "given", - "if", - "implicit", - "import", - "infix", - "inline", - "lazy", - "match", - "new", - "null", - "object", - "opaque", - "open", - "override", - "package", - "private", - "protected", - "return", - "sealed", - "super", - "then", - "this", - "throw", - "trait", - "transparent", - "try", - "true", - "type", - "using", - "val", - "var", - "while", - "with", - "yield", - "_" - ) - - val caseClassReservedFields = - Set("wait", "notify", "toString", "notifyAll", "hashCode", "getClass", "finalize", "equals", "clone") - final case class FieldTypeInfo( rawName: String, name: String, @@ -933,5 +831,4 @@ object ClientWriter { argBuilder: String, typeInfo: FieldTypeInfo ) - } diff --git a/tools/src/main/scala/caliban/tools/Codegen.scala b/tools/src/main/scala/caliban/tools/Codegen.scala index 0f8648980..b425e50c3 100644 --- a/tools/src/main/scala/caliban/tools/Codegen.scala +++ b/tools/src/main/scala/caliban/tools/Codegen.scala @@ -1,8 +1,8 @@ package caliban.tools -import caliban.tools.implicits.ScalarMappings import zio.blocking.{ blocking, Blocking } import zio.{ RIO, Task, UIO, ZIO } + import java.io.{ File, PrintWriter } object Codegen { @@ -35,8 +35,13 @@ object Codegen { code = genType match { case GenType.Schema => List( - objectName -> SchemaWriter.write(schema, packageName, effect, arguments.imports, abstractEffectType)( - ScalarMappings(scalarMappings) + objectName -> SchemaWriter.write( + schema, + packageName, + effect, + arguments.imports, + scalarMappings, + abstractEffectType ) ) case GenType.Client => @@ -47,9 +52,8 @@ object Codegen { genView, arguments.imports, splitFiles, - extensibleEnums - )( - ScalarMappings(scalarMappings) + extensibleEnums, + scalarMappings ) } formatted <- if (enableFmt) Formatter.format(code, arguments.fmtPath) else Task.succeed(code) diff --git a/tools/src/main/scala/caliban/tools/SchemaWriter.scala b/tools/src/main/scala/caliban/tools/SchemaWriter.scala index a40a27a9e..0a7177d48 100644 --- a/tools/src/main/scala/caliban/tools/SchemaWriter.scala +++ b/tools/src/main/scala/caliban/tools/SchemaWriter.scala @@ -3,8 +3,6 @@ package caliban.tools import caliban.parsing.adt.Definition.TypeSystemDefinition.TypeDefinition._ import caliban.parsing.adt.Type.{ ListType, NamedType } import caliban.parsing.adt.{ Document, Type } -import caliban.tools.implicits.Implicits._ -import caliban.tools.implicits.ScalarMappings object SchemaWriter { @@ -13,8 +11,206 @@ object SchemaWriter { packageName: Option[String] = None, effect: String = "zio.UIO", imports: Option[List[String]] = None, + scalarMappings: Option[Map[String, String]], isEffectTypeAbstract: Boolean = false - )(implicit scalarMappings: ScalarMappings): String = { + ): String = { + + val interfaceImplementationsMap = (for { + objectDef <- schema.objectTypeDefinitions + interfaceDef <- schema.interfaceTypeDefinitions + if objectDef.implements.exists(_.name == interfaceDef.name) + } yield interfaceDef -> objectDef).groupBy(_._1).map { case (definition, tuples) => + definition -> tuples.map(_._2) + } + + def safeName(name: String): String = + if (reservedKeywords.contains(name) || name.endsWith("_")) s"`$name`" + else if (caseClassReservedFields.contains(name)) s"$name$$" + else name + + def reservedType(typeDefinition: ObjectTypeDefinition): Boolean = + typeDefinition.name == "Query" || typeDefinition.name == "Mutation" || typeDefinition.name == "Subscription" + + def writeRootField(field: FieldDefinition, od: ObjectTypeDefinition): String = { + val argsTypeName = if (field.args.nonEmpty) s" ${argsName(field, od)} =>" else "" + s"${safeName(field.name)} :$argsTypeName $effect[${writeType(field.ofType)}]" + } + + def writeRootQueryOrMutationDef(op: ObjectTypeDefinition): String = { + val typeParamOrEmpty = if (isEffectTypeAbstract) s"[$effect[_]]" else "" + s""" + |${writeDescription(op.description)}final case class ${op.name}$typeParamOrEmpty( + |${op.fields.map(c => writeRootField(c, op)).mkString(",\n")} + |)""".stripMargin + + } + def writeSubscriptionField(field: FieldDefinition, od: ObjectTypeDefinition): String = + "%s:%s ZStream[Any, Nothing, %s]".format( + safeName(field.name), + if (field.args.nonEmpty) s" ${argsName(field, od)} =>" else "", + writeType(field.ofType) + ) + + def writeRootSubscriptionDef(op: ObjectTypeDefinition): String = + s""" + |${writeDescription(op.description)}final case class ${op.name}( + |${op.fields.map(c => writeSubscriptionField(c, op)).mkString(",\n")} + |)""".stripMargin + + def writeObject(typedef: ObjectTypeDefinition): String = + s"""${writeDescription(typedef.description)}final case class ${typedef.name}(${typedef.fields + .map(writeField(_, typedef)) + .mkString(", ")})""" + + def writeInputObject(typedef: InputObjectTypeDefinition): String = + s"""${writeDescription(typedef.description)}final case class ${typedef.name}(${typedef.fields + .map(writeInputValue) + .mkString(", ")})""" + + def writeEnum(typedef: EnumTypeDefinition): String = + s"""${writeDescription(typedef.description)}sealed trait ${typedef.name} extends scala.Product with scala.Serializable + + object ${typedef.name} { + ${typedef.enumValuesDefinition + .map(v => s"${writeDescription(v.description)}case object ${safeName(v.enumValue)} extends ${typedef.name}") + .mkString("\n")} + } + """ + + def writeUnions(unions: Map[UnionTypeDefinition, List[ObjectTypeDefinition]]): String = + if (unions.nonEmpty) { + val flattened = unions.toList.flatMap { case (unionType, objectTypes) => objectTypes.map(_ -> unionType) } + + val (unionsWithoutReusedMembers, reusedUnionMembers) = flattened + .foldLeft( + ( + Map.empty[UnionTypeDefinition, List[ObjectTypeDefinition]], + Map.empty[ObjectTypeDefinition, List[UnionTypeDefinition]] + ) + ) { + case ( + (unionsWithoutReusedMembers, reusedUnionMembers), + (objectType, unionType) + ) => + val isReused = reusedUnionMembers.contains(objectType) || + flattened.exists { case (_objectType, _unionType) => + _unionType.name != unionType.name && _objectType.name == objectType.name + } + + if (isReused) { + ( + unionsWithoutReusedMembers, + reusedUnionMembers.updated( + objectType, + reusedUnionMembers.getOrElse(objectType, List.empty) :+ unionType + ) + ) + } else { + ( + unionsWithoutReusedMembers.updated( + unionType, + unionsWithoutReusedMembers.getOrElse(unionType, List.empty) :+ objectType + ), + reusedUnionMembers + ) + } + } + + s"""${unions.keys.map(writeUnionSealedTrait).mkString("\n")} + + ${unionsWithoutReusedMembers.map { case (union, objects) => writeNotReusedMembers(union, objects) } + .mkString("\n")} + + ${reusedUnionMembers.map { case (objectType, unions) => writeReusedUnionMember(objectType, unions) } + .mkString("\n")} + """ + } else "" + + def writeUnionSealedTrait(union: UnionTypeDefinition): String = + s"""${writeDescription( + union.description + )}sealed trait ${union.name} extends scala.Product with scala.Serializable""" + + def writeReusedUnionMember(typedef: ObjectTypeDefinition, unions: List[UnionTypeDefinition]): String = + s"${writeObject(typedef)} extends ${unions.map(_.name).mkString(" with ")}" + + def writeNotReusedMembers(typedef: UnionTypeDefinition, objects: List[ObjectTypeDefinition]): String = + s"""object ${typedef.name} { + ${objects + .map(o => s"${writeObject(o)} extends ${typedef.name}") + .mkString("\n")} + } + """ + + def writeInterface(interface: InterfaceTypeDefinition, impls: List[ObjectTypeDefinition]): String = + s"""@GQLInterface + ${writeDescription(interface.description)}sealed trait ${interface.name} extends scala.Product with scala.Serializable { + ${interface.fields.map(field => s"def ${safeName(field.name)} : ${writeType(field.ofType)}").mkString("\n")} + } + + object ${interface.name} { + ${impls + .map(o => s"${writeObject(o)} extends ${interface.name}") + .mkString("\n")} + } + """ + + def writeField(field: FieldDefinition, of: ObjectTypeDefinition): String = + if (field.args.nonEmpty) { + s"${writeDescription(field.description)}${safeName(field.name)} : ${argsName(field, of)} => ${writeType(field.ofType)}" + } else { + s"""${writeDescription(field.description)}${safeName(field.name)} : ${writeType(field.ofType)}""" + } + + def writeInputValue(value: InputValueDefinition): String = + s"""${writeDescription(value.description)}${safeName(value.name)} : ${writeType(value.ofType)}""" + + def writeArguments(field: FieldDefinition, of: ObjectTypeDefinition): String = { + def fields(args: List[InputValueDefinition]): String = + s"${args.map(arg => s"${safeName(arg.name)} : ${writeType(arg.ofType)}").mkString(", ")}" + + if (field.args.nonEmpty) { + s"final case class ${argsName(field, of)}(${fields(field.args)})" + } else { + "" + } + } + + def argsName(field: FieldDefinition, od: ObjectTypeDefinition): String = + s"${od.name.capitalize}${field.name.capitalize}Args" + + def escapeDoubleQuotes(input: String): String = + input.replace("\"", "\\\"") + + def writeDescription(description: Option[String]): String = + description.fold("") { + case d if d.contains("\n") => + s"""@GQLDescription(\"\"\"${escapeDoubleQuotes(d)}\"\"\") + |""".stripMargin + case d => + s"""@GQLDescription("${escapeDoubleQuotes(d)}") + |""".stripMargin + } + + def writeType(t: Type): String = { + def write(name: String): String = scalarMappings + .flatMap(_.get(name)) + .getOrElse(checkIsInterfaceImpl(name)) + + def checkIsInterfaceImpl(name: String): String = interfaceImplementationsMap.find { case (_, impls) => + impls.exists(_.name == name) + }.map { case (interface, _) => + s"${interface.name}.$name" + }.getOrElse(name) + + t match { + case NamedType(name, true) => write(name) + case NamedType(name, false) => s"Option[${write(name)}]" + case ListType(ofType, true) => s"List[${writeType(ofType)}]" + case ListType(ofType, false) => s"Option[List[${writeType(ofType)}]]" + } + } + val schemaDef = schema.schemaDefinition val argsTypes = schema.objectTypeDefinitions @@ -27,13 +223,20 @@ object SchemaWriter { val unions = writeUnions(unionTypes) + val interfaceImplementations = interfaceImplementationsMap.values.flatten + + val interfacesStr = interfaceImplementationsMap.map { case (interface, impls) => + writeInterface(interface, impls) + }.mkString("\n") + val objects = schema.objectTypeDefinitions .filterNot(obj => reservedType(obj) || schemaDef.exists(_.query.getOrElse("Query") == obj.name) || schemaDef.exists(_.mutation.getOrElse("Mutation") == obj.name) || schemaDef.exists(_.subscription.getOrElse("Subscription") == obj.name) || - unionTypes.values.flatten.exists(_.name == obj.name) + unionTypes.values.flatten.exists(_.name == obj.name) || + interfaceImplementations.exists(_.name == obj.name) ) .map(writeObject) .mkString("\n") @@ -44,12 +247,12 @@ object SchemaWriter { val queries = schema .objectTypeDefinition(schemaDef.flatMap(_.query).getOrElse("Query")) - .map(t => writeRootQueryOrMutationDef(t, effect, isEffectTypeAbstract)) + .map(t => writeRootQueryOrMutationDef(t)) .getOrElse("") val mutations = schema .objectTypeDefinition(schemaDef.flatMap(_.mutation).getOrElse("Mutation")) - .map(t => writeRootQueryOrMutationDef(t, effect, isEffectTypeAbstract)) + .map(t => writeRootQueryOrMutationDef(t)) .getOrElse("") val subscriptions = schema @@ -60,7 +263,8 @@ object SchemaWriter { val additionalImportsString = imports.fold("")(_.map(i => s"import $i").mkString("\n")) val hasSubscriptions = subscriptions.nonEmpty - val hasTypes = argsTypes.length + objects.length + enums.length + unions.length + inputs.length > 0 + val hasTypes = argsTypes.length + objects.length + enums.length + unions.length + + inputs.length + interfacesStr.length > 0 val hasOperations = queries.length + mutations.length + subscriptions.length > 0 val typesAndOperations = s""" @@ -70,6 +274,7 @@ object SchemaWriter { objects + "\n" + inputs + "\n" + unions + "\n" + + interfacesStr + "\n" + enums + "\n" + "\n}\n" else ""} @@ -91,178 +296,4 @@ object SchemaWriter { $typesAndOperations """ } - - def safeName(name: String): String = ClientWriter.safeName(name) - - def reservedType(typeDefinition: ObjectTypeDefinition): Boolean = - typeDefinition.name == "Query" || typeDefinition.name == "Mutation" || typeDefinition.name == "Subscription" - - def writeRootField(field: FieldDefinition, od: ObjectTypeDefinition, effect: String)(implicit - scalarMappings: ScalarMappings - ): String = { - val argsTypeName = if (field.args.nonEmpty) s" ${argsName(field, od)} =>" else "" - s"${safeName(field.name)} :$argsTypeName $effect[${writeType(field.ofType)}]" - } - - def writeRootQueryOrMutationDef(op: ObjectTypeDefinition, effect: String, isEffectTypeAbstract: Boolean)(implicit - scalarMappings: ScalarMappings - ): String = { - val typeParamOrEmpty = if (isEffectTypeAbstract) s"[$effect[_]]" else "" - s""" - |${writeDescription(op.description)}final case class ${op.name}$typeParamOrEmpty( - |${op.fields.map(c => writeRootField(c, op, effect)).mkString(",\n")} - |)""".stripMargin - - } - def writeSubscriptionField(field: FieldDefinition, od: ObjectTypeDefinition)(implicit - scalarMappings: ScalarMappings - ): String = - "%s:%s ZStream[Any, Nothing, %s]".format( - safeName(field.name), - if (field.args.nonEmpty) s" ${argsName(field, od)} =>" else "", - writeType(field.ofType) - ) - - def writeRootSubscriptionDef(op: ObjectTypeDefinition)(implicit scalarMappings: ScalarMappings): String = - s""" - |${writeDescription(op.description)}final case class ${op.name}( - |${op.fields.map(c => writeSubscriptionField(c, op)).mkString(",\n")} - |)""".stripMargin - - def writeObject(typedef: ObjectTypeDefinition)(implicit scalarMappings: ScalarMappings): String = - s"""${writeDescription(typedef.description)}final case class ${typedef.name}(${typedef.fields - .map(writeField(_, typedef)) - .mkString(", ")})""" - - def writeInputObject(typedef: InputObjectTypeDefinition)(implicit scalarMappings: ScalarMappings): String = - s"""${writeDescription(typedef.description)}final case class ${typedef.name}(${typedef.fields - .map(writeInputValue) - .mkString(", ")})""" - - def writeEnum(typedef: EnumTypeDefinition): String = - s"""${writeDescription(typedef.description)}sealed trait ${typedef.name} extends scala.Product with scala.Serializable - - object ${typedef.name} { - ${typedef.enumValuesDefinition - .map(v => s"${writeDescription(v.description)}case object ${safeName(v.enumValue)} extends ${typedef.name}") - .mkString("\n")} - } - """ - - def writeUnions(unions: Map[UnionTypeDefinition, List[ObjectTypeDefinition]])(implicit - scalarMappings: ScalarMappings - ): String = - if (unions.nonEmpty) { - val flattened = unions.toList.flatMap { case (unionType, objectTypes) => objectTypes.map(_ -> unionType) } - - val (unionsWithoutReusedMembers, reusedUnionMembers) = flattened - .foldLeft( - ( - Map.empty[UnionTypeDefinition, List[ObjectTypeDefinition]], - Map.empty[ObjectTypeDefinition, List[UnionTypeDefinition]] - ) - ) { - case ( - (unionsWithoutReusedMembers, reusedUnionMembers), - (objectType, unionType) - ) => - val isReused = reusedUnionMembers.contains(objectType) || - flattened.exists { case (_objectType, _unionType) => - _unionType.name != unionType.name && _objectType.name == objectType.name - } - - if (isReused) { - ( - unionsWithoutReusedMembers, - reusedUnionMembers.updated( - objectType, - reusedUnionMembers.getOrElse(objectType, List.empty) :+ unionType - ) - ) - } else { - ( - unionsWithoutReusedMembers.updated( - unionType, - unionsWithoutReusedMembers.getOrElse(unionType, List.empty) :+ objectType - ), - reusedUnionMembers - ) - } - } - - s"""${unions.keys.map(writeUnionSealedTrait).mkString("\n")} - - ${unionsWithoutReusedMembers.map { case (union, objects) => writeNotReusedMembers(union, objects) } - .mkString("\n")} - - ${reusedUnionMembers.map { case (objectType, unions) => writeReusedUnionMember(objectType, unions) } - .mkString("\n")} - """ - } else "" - - def writeUnionSealedTrait(union: UnionTypeDefinition): String = - s"""${writeDescription( - union.description - )}sealed trait ${union.name} extends scala.Product with scala.Serializable""" - - def writeReusedUnionMember(typedef: ObjectTypeDefinition, unions: List[UnionTypeDefinition])(implicit - scalarMappings: ScalarMappings - ): String = - s"${writeObject(typedef)} extends ${unions.map(_.name).mkString(" with ")}" - - def writeNotReusedMembers(typedef: UnionTypeDefinition, objects: List[ObjectTypeDefinition])(implicit - scalarMappings: ScalarMappings - ): String = - s"""object ${typedef.name} { - ${objects - .map(o => s"${writeObject(o)} extends ${typedef.name}") - .mkString("\n")} - } - """ - - def writeField(field: FieldDefinition, of: ObjectTypeDefinition)(implicit scalarMappings: ScalarMappings): String = - if (field.args.nonEmpty) { - s"${writeDescription(field.description)}${safeName(field.name)} : ${argsName(field, of)} => ${writeType(field.ofType)}" - } else { - s"""${writeDescription(field.description)}${safeName(field.name)} : ${writeType(field.ofType)}""" - } - - def writeInputValue(value: InputValueDefinition)(implicit scalarMappings: ScalarMappings): String = - s"""${writeDescription(value.description)}${safeName(value.name)} : ${writeType(value.ofType)}""" - - def writeArguments(field: FieldDefinition, of: ObjectTypeDefinition)(implicit - scalarMappings: ScalarMappings - ): String = { - def fields(args: List[InputValueDefinition]): String = - s"${args.map(arg => s"${safeName(arg.name)} : ${writeType(arg.ofType)}").mkString(", ")}" - - if (field.args.nonEmpty) { - s"final case class ${argsName(field, of)}(${fields(field.args)})" - } else { - "" - } - } - - private def argsName(field: FieldDefinition, od: ObjectTypeDefinition): String = - s"${od.name.capitalize}${field.name.capitalize}Args" - - def escapeDoubleQuotes(input: String): String = - input.replace("\"", "\\\"") - - def writeDescription(description: Option[String]): String = - description.fold("") { - case d if d.contains("\n") => - s"""@GQLDescription(\"\"\"${escapeDoubleQuotes(d)}\"\"\") - |""".stripMargin - case d => - s"""@GQLDescription("${escapeDoubleQuotes(d)}") - |""".stripMargin - } - - def writeType(t: Type)(implicit scalarMappings: ScalarMappings): String = t match { - case NamedType(name, true) => scalarMappings.flatMap(m => m.get(name)).getOrElse(name) - case NamedType(name, false) => s"Option[${scalarMappings.flatMap(m => m.get(name)).getOrElse(name)}]" - case ListType(ofType, true) => s"List[${writeType(ofType)}]" - case ListType(ofType, false) => s"Option[List[${writeType(ofType)}]]" - } } diff --git a/tools/src/main/scala/caliban/tools/compiletime/Config.scala b/tools/src/main/scala/caliban/tools/compiletime/Config.scala index 91cb0758b..ee7c01e62 100644 --- a/tools/src/main/scala/caliban/tools/compiletime/Config.scala +++ b/tools/src/main/scala/caliban/tools/compiletime/Config.scala @@ -1,6 +1,7 @@ package caliban.tools.compiletime import caliban.tools.CalibanCommonSettings +import caliban.tools.Codegen.GenType trait Config { case class ClientGenerationSettings( @@ -25,7 +26,8 @@ trait Config { imports = imports, splitFiles = Some(splitFiles), enableFmt = Some(enableFmt), - extensibleEnums = Some(extensibleEnums) + extensibleEnums = Some(extensibleEnums), + GenType.Client ) private[caliban] def asScalaCode: String = { diff --git a/tools/src/main/scala/caliban/tools/implicits/Implicits.scala b/tools/src/main/scala/caliban/tools/implicits/Implicits.scala deleted file mode 100644 index 48012accd..000000000 --- a/tools/src/main/scala/caliban/tools/implicits/Implicits.scala +++ /dev/null @@ -1,14 +0,0 @@ -package caliban.tools.implicits - -import caliban.parsing.adt.Definition.TypeSystemDefinition - -import scala.language.implicitConversions - -object Implicits { - implicit def typesMapToMap(typesMap: TypesMap): Map[String, TypeSystemDefinition.TypeDefinition] = typesMap.typesMap - - implicit def scalarMappingsToMap(typeMappings: ScalarMappings): Option[Map[String, String]] = typeMappings.scalarMap - - implicit def mappingClashedTypeNamesToMap(mappingClashedTypeNames: MappingClashedTypeNames): Map[String, String] = - mappingClashedTypeNames.clashedTypesMap -} diff --git a/tools/src/main/scala/caliban/tools/implicits/MappingClashedTypeNames.scala b/tools/src/main/scala/caliban/tools/implicits/MappingClashedTypeNames.scala deleted file mode 100644 index 9413fbced..000000000 --- a/tools/src/main/scala/caliban/tools/implicits/MappingClashedTypeNames.scala +++ /dev/null @@ -1,3 +0,0 @@ -package caliban.tools.implicits - -case class MappingClashedTypeNames(clashedTypesMap: Map[String, String]) diff --git a/tools/src/main/scala/caliban/tools/implicits/ScalarMappings.scala b/tools/src/main/scala/caliban/tools/implicits/ScalarMappings.scala deleted file mode 100644 index 21b9de624..000000000 --- a/tools/src/main/scala/caliban/tools/implicits/ScalarMappings.scala +++ /dev/null @@ -1,3 +0,0 @@ -package caliban.tools.implicits - -case class ScalarMappings(scalarMap: Option[Map[String, String]]) diff --git a/tools/src/main/scala/caliban/tools/implicits/TypesMap.scala b/tools/src/main/scala/caliban/tools/implicits/TypesMap.scala deleted file mode 100644 index 8c1b7b64d..000000000 --- a/tools/src/main/scala/caliban/tools/implicits/TypesMap.scala +++ /dev/null @@ -1,5 +0,0 @@ -package caliban.tools.implicits - -import caliban.parsing.adt.Definition.TypeSystemDefinition.TypeDefinition - -case class TypesMap(typesMap: Map[String, TypeDefinition]) diff --git a/tools/src/main/scala/caliban/tools/package.scala b/tools/src/main/scala/caliban/tools/package.scala new file mode 100644 index 000000000..e4c92eb79 --- /dev/null +++ b/tools/src/main/scala/caliban/tools/package.scala @@ -0,0 +1,67 @@ +package caliban + +package object tools { + val supportedScalars = Set("Int", "Float", "Double", "Long", "Unit", "String", "Boolean", "BigInt", "BigDecimal") + + val reservedKeywords = Set( + "abstract", + "as", + "case", + "catch", + "class", + "def", + "derives", + "do", + "else", + "enum", + "export", + "extends", + "extension", + "false", + "final", + "finally", + "for", + "forSome", + "given", + "if", + "implicit", + "import", + "infix", + "inline", + "lazy", + "match", + "new", + "null", + "object", + "opaque", + "open", + "override", + "package", + "private", + "protected", + "return", + "sealed", + "super", + "then", + "this", + "throw", + "trait", + "transparent", + "try", + "true", + "type", + "using", + "val", + "var", + "while", + "with", + "yield", + "_" + ) + + val caseClassReservedFields = + Set("wait", "notify", "toString", "notifyAll", "hashCode", "getClass", "finalize", "equals", "clone") + + val tripleQuotes = "\"\"\"" + val doubleQuotes = "\"" +} diff --git a/tools/src/test/scala/caliban/tools/ClientWriterSpec.scala b/tools/src/test/scala/caliban/tools/ClientWriterSpec.scala index 57dc40b42..6ff7b3c9a 100644 --- a/tools/src/test/scala/caliban/tools/ClientWriterSpec.scala +++ b/tools/src/test/scala/caliban/tools/ClientWriterSpec.scala @@ -1,7 +1,6 @@ package caliban.tools import caliban.parsing.Parser -import caliban.tools.implicits.ScalarMappings import zio.RIO import zio.blocking.Blocking import zio.test.Assertion._ @@ -20,8 +19,11 @@ object ClientWriterSpec extends DefaultRunnableSpec { .flatMap(doc => Formatter.format( ClientWriter - .write(doc, additionalImports = Some(additionalImports), extensibleEnums = extensibleEnums)( - ScalarMappings(Some(scalarMappings)) + .write( + doc, + additionalImports = Some(additionalImports), + extensibleEnums = extensibleEnums, + scalarMappings = Some(scalarMappings) ) .head ._2, @@ -36,9 +38,7 @@ object ClientWriterSpec extends DefaultRunnableSpec { .parseQuery(schema) .flatMap(doc => Formatter.format( - ClientWriter.write(doc, packageName = Some("test"), splitFiles = true)( - ScalarMappings(Some(scalarMappings)) - ), + ClientWriter.write(doc, packageName = Some("test"), splitFiles = true, scalarMappings = Some(scalarMappings)), None ) ) diff --git a/tools/src/test/scala/caliban/tools/ClientWriterViewSpec.scala b/tools/src/test/scala/caliban/tools/ClientWriterViewSpec.scala index 9198166d7..18b0fd6c5 100644 --- a/tools/src/test/scala/caliban/tools/ClientWriterViewSpec.scala +++ b/tools/src/test/scala/caliban/tools/ClientWriterViewSpec.scala @@ -1,7 +1,6 @@ package caliban.tools import caliban.parsing.Parser -import caliban.tools.implicits.ScalarMappings import zio.RIO import zio.blocking.Blocking import zio.test.Assertion._ @@ -13,7 +12,7 @@ object ClientWriterViewSpec extends DefaultRunnableSpec { val gen: String => RIO[Blocking, String] = (schema: String) => Parser .parseQuery(schema) - .flatMap(doc => Formatter.format(ClientWriter.write(doc, genView = true)(ScalarMappings(None)).head._2, None)) + .flatMap(doc => Formatter.format(ClientWriter.write(doc, genView = true, scalarMappings = None).head._2, None)) override def spec: ZSpec[TestEnvironment, Any] = suite("ClientWriterViewSpec")( diff --git a/tools/src/test/scala/caliban/tools/SchemaWriterSpec.scala b/tools/src/test/scala/caliban/tools/SchemaWriterSpec.scala index d3c092b10..9e22bdd08 100644 --- a/tools/src/test/scala/caliban/tools/SchemaWriterSpec.scala +++ b/tools/src/test/scala/caliban/tools/SchemaWriterSpec.scala @@ -1,75 +1,57 @@ package caliban.tools import caliban.parsing.Parser -import caliban.tools.implicits.ScalarMappings +import zio.RIO import zio.blocking.Blocking import zio.test.Assertion.equalTo import zio.test._ import zio.test.environment.TestEnvironment -import zio.{ RIO, ZIO } object SchemaWriterSpec extends DefaultRunnableSpec { - implicit val scalarMappings: ScalarMappings = ScalarMappings(None) - def gen( schema: String, + packageName: Option[String] = None, + effect: String = "zio.UIO", + imports: List[String] = List.empty, scalarMappings: Map[String, String] = Map.empty, - customImports: List[String] = List.empty + isEffectTypeAbstract: Boolean = false ): RIO[Blocking, String] = Parser - .parseQuery(schema) + .parseQuery(schema.stripMargin) .flatMap(doc => Formatter - .format(SchemaWriter.write(doc, imports = Some(customImports))(ScalarMappings(Some(scalarMappings))), None) + .format( + SchemaWriter.write( + doc, + packageName, + effect, + Some(imports), + Some(scalarMappings), + isEffectTypeAbstract + ), + None + ) ) - override def spec: ZSpec[TestEnvironment, Any] = - suite("SchemaWriterSpec")( - testM("type with field parameter") { - val schema = - """ + val assertions = List( + ( + "type with field parameter", + gen(""" type Hero { name(pad: Int!): String! nick: String! bday: Int } - |""".stripMargin - - val typeCaseClass: ZIO[Blocking, Throwable, String] = - Parser - .parseQuery(schema) - .map(_.objectTypeDefinitions.map(SchemaWriter.writeObject).mkString("\n")) - .flatMap(Formatter.format(_, None).map(_.trim)) - - val typeCaseClassArgs: ZIO[Blocking, Throwable, String] = - Parser - .parseQuery(schema) - .map { doc => - (for { - typeDef <- doc.objectTypeDefinitions - typeDefField <- typeDef.fields - argClass = SchemaWriter.writeArguments(typeDefField, typeDef) if argClass.nonEmpty - } yield argClass).mkString("\n") - } - .flatMap(Formatter.format(_, None).map(_.trim)) - - val a = assertM(typeCaseClass)( - equalTo( - "final case class Hero(name: HeroNameArgs => String, nick: String, bday: Option[Int])" - ) - ) - - val b = assertM(typeCaseClassArgs)( - equalTo( - "final case class HeroNameArgs(pad: Int)" - ) - ) - - ZIO.mapN(a, b)(_ && _) - }, - testM("simple queries") { - val schema = - """ + |"""), + """ object Types { + | final case class HeroNameArgs(pad: Int) + | final case class Hero(name: HeroNameArgs => String, nick: String, bday: Option[Int]) + | + |}""" + ), + ( + "simple queries", + gen(""" type Query { user(id: Int): User userList: [User]! @@ -78,74 +60,74 @@ object SchemaWriterSpec extends DefaultRunnableSpec { id: Int name: String profilePic: String - }""" - - val result = Parser - .parseQuery(schema) - .map( - _.objectTypeDefinition("Query") - .map(SchemaWriter.writeRootQueryOrMutationDef(_, "zio.UIO", false)) - .mkString("\n") - ) - .flatMap(Formatter.format(_, None).map(_.trim)) - - assertM(result)( - equalTo( - """final case class Query( - user: QueryUserArgs => zio.UIO[Option[User]], - userList: zio.UIO[List[Option[User]]] -)""".stripMargin - ) - ) - }, - testM("simple mutation") { - val schema = - """ + }"""), + """import Types._ + | + |object Types { + | final case class QueryUserArgs(id: Option[Int]) + | final case class User(id: Option[Int], name: Option[String], profilePic: Option[String]) + | + |} + | + |object Operations { + | + | final case class Query( + | user: QueryUserArgs => zio.UIO[Option[User]], + | userList: zio.UIO[List[Option[User]]] + | ) + | + |}""" + ), + ( + "simple mutation", + gen(""" type Mutation { setMessage(message: String): String } - """ - val result = Parser - .parseQuery(schema) - .map( - _.objectTypeDefinition("Mutation") - .map(SchemaWriter.writeRootQueryOrMutationDef(_, "zio.UIO", false)) - .mkString("\n") - ) - .flatMap(Formatter.format(_, None).map(_.trim)) - - assertM(result)( - equalTo( - """final case class Mutation( - | setMessage: MutationSetMessageArgs => zio.UIO[Option[String]] - |)""".stripMargin - ) - ) - }, - testM("simple subscription") { - val schema = - """ + """), + """import Types._ + | + |object Types { + | final case class MutationSetMessageArgs(message: Option[String]) + | + |} + | + |object Operations { + | + | final case class Mutation( + | setMessage: MutationSetMessageArgs => zio.UIO[Option[String]] + | ) + | + |}""" + ), + ( + "simple subscription", + gen(""" type Subscription { UserWatch(id: Int!): String! } - """ - - val result = Parser - .parseQuery(schema) - .map(_.objectTypeDefinition("Subscription").map(SchemaWriter.writeRootSubscriptionDef).mkString("\n")) - - assertM(result)( - equalTo( - """ - |final case class Subscription( - |UserWatch: SubscriptionUserWatchArgs => ZStream[Any, Nothing, String] - |)""".stripMargin - ) - ) - }, - testM("simple queries with abstracted effect type") { - val schema = - """ + """), + """import Types._ + | + |import zio.stream.ZStream + | + |object Types { + | final case class SubscriptionUserWatchArgs(id: Int) + | + |} + | + |object Operations { + | + | final case class Subscription( + | UserWatch: SubscriptionUserWatchArgs => ZStream[Any, Nothing, String] + | ) + | + |}""" + ), + ( + "simple queries with abstracted effect type", + gen( + """ type Query { user(id: Int): User userList: [User]! @@ -154,436 +136,427 @@ object SchemaWriterSpec extends DefaultRunnableSpec { id: Int name: String profilePic: String - }""" - - val result = Parser - .parseQuery(schema) - .map( - _.objectTypeDefinition("Query").map(SchemaWriter.writeRootQueryOrMutationDef(_, "F", true)).mkString("\n") - ) - .flatMap(Formatter.format(_, None).map(_.trim)) - - assertM(result)( - equalTo( - """final case class Query[F[_]]( - user: QueryUserArgs => F[Option[User]], - userList: F[List[Option[User]]] -)""".stripMargin - ) - ) - }, - testM("simple mutation with abstracted effect type") { - val schema = - """ + }""", + effect = "F", + isEffectTypeAbstract = true + ), + """import Types._ + | + |object Types { + | final case class QueryUserArgs(id: Option[Int]) + | final case class User(id: Option[Int], name: Option[String], profilePic: Option[String]) + | + |} + | + |object Operations { + | + | final case class Query[F[_]]( + | user: QueryUserArgs => F[Option[User]], + | userList: F[List[Option[User]]] + | ) + | + |}""" + ), + ( + "simple mutation with abstracted effect type", + gen( + """ type Mutation { setMessage(message: String): String } - """ - val result = Parser - .parseQuery(schema) - .map( - _.objectTypeDefinition("Mutation") - .map(SchemaWriter.writeRootQueryOrMutationDef(_, "F", true)) - .mkString("\n") - ) - .flatMap(Formatter.format(_, None).map(_.trim)) - - assertM(result)( - equalTo( - """final case class Mutation[F[_]]( - | setMessage: MutationSetMessageArgs => F[Option[String]] - |)""".stripMargin - ) - ) - }, - testM("schema test") { - val schema = - """ - | type Subscription { - | postAdded: Post - | } - | type Query { - | posts: [Post] - | } - | type Mutation { - | addPost(author: String, comment: String): Post - | } - | type Post { - | author: String - | comment: String - | } - |""".stripMargin - - assertM(gen(schema))( - equalTo( - """import Types._ - | - |import zio.stream.ZStream - | - |object Types { - | final case class MutationAddPostArgs(author: Option[String], comment: Option[String]) - | final case class Post(author: Option[String], comment: Option[String]) - | - |} - | - |object Operations { - | - | final case class Query( - | posts: zio.UIO[Option[List[Option[Post]]]] - | ) - | - | final case class Mutation( - | addPost: MutationAddPostArgs => zio.UIO[Option[Post]] - | ) - | - | final case class Subscription( - | postAdded: ZStream[Any, Nothing, Option[Post]] - | ) - | - |} - |""".stripMargin - ) - ) - }, - testM("empty schema test") { - assertM(gen(""))(equalTo(System.lineSeparator)) - }, - testM("enum type") { - val schema = - """ + """, + effect = "F", + isEffectTypeAbstract = true + ), + """import Types._ + | + |object Types { + | final case class MutationSetMessageArgs(message: Option[String]) + | + |} + | + |object Operations { + | + | final case class Mutation[F[_]]( + | setMessage: MutationSetMessageArgs => F[Option[String]] + | ) + | + |}""" + ), + ( + "schema test", + gen(""" + | type Subscription { + | postAdded: Post + | } + | type Query { + | posts: [Post] + | } + | type Mutation { + | addPost(author: String, comment: String): Post + | } + | type Post { + | author: String + | comment: String + | } + |"""), + """import Types._ + | + |import zio.stream.ZStream + | + |object Types { + | final case class MutationAddPostArgs(author: Option[String], comment: Option[String]) + | final case class Post(author: Option[String], comment: Option[String]) + | + |} + | + |object Operations { + | + | final case class Query( + | posts: zio.UIO[Option[List[Option[Post]]]] + | ) + | + | final case class Mutation( + | addPost: MutationAddPostArgs => zio.UIO[Option[Post]] + | ) + | + | final case class Subscription( + | postAdded: ZStream[Any, Nothing, Option[Post]] + | ) + | + |}""" + ), + ("empty schema test", gen(""), System.lineSeparator), + ( + "enum type", + gen(""" enum Origin { EARTH MARS BELT } - """.stripMargin - - assertM(gen(schema))( - equalTo( - """object Types { - - sealed trait Origin extends scala.Product with scala.Serializable - - object Origin { - case object EARTH extends Origin - case object MARS extends Origin - case object BELT extends Origin - } - -} -""" - ) - ) - }, - testM("union type") { - val role = - s""" + """), + """object Types { + | + | sealed trait Origin extends scala.Product with scala.Serializable + | + | object Origin { + | case object EARTH extends Origin + | case object MARS extends Origin + | case object BELT extends Origin + | } + | + |}""" + ), + ( + "union type", + gen(s""" \"\"\" role Captain or Pilot \"\"\" - """ - val role2 = - s""" + union Role = Captain | Pilot \"\"\" role2 Captain or Pilot or Stewart \"\"\" - """ - val schema = - s""" - $role - union Role = Captain | Pilot - $role2 union Role2 = Captain | Pilot | Stewart - + type Captain { "ship" shipName: String! } - + type Pilot { shipName: String! } - + type Stewart { shipName: String! } - """.stripMargin - - assertM(gen(schema))( - equalTo { - val role = - s"""\"\"\"role -Captain or Pilot\"\"\"""" - val role2 = - s"""\"\"\"role2 -Captain or Pilot or Stewart\"\"\"""" - s"""import caliban.schema.Annotations._ - -object Types { - - @GQLDescription($role) - sealed trait Role extends scala.Product with scala.Serializable - @GQLDescription($role2) - sealed trait Role2 extends scala.Product with scala.Serializable - - object Role2 { - final case class Stewart(shipName: String) extends Role2 - } - - final case class Captain( - @GQLDescription("ship") - shipName: String - ) extends Role - with Role2 - final case class Pilot(shipName: String) extends Role with Role2 - -} -""" - } - ) - }, - testM("GQLDescription with escaped quotes") { - val schema = - s""" + """), + s"""import caliban.schema.Annotations._ + | + |object Types { + | + | @GQLDescription(\"\"\"role + |Captain or Pilot\"\"\") + | sealed trait Role extends scala.Product with scala.Serializable + | @GQLDescription(\"\"\"role2 + |Captain or Pilot or Stewart\"\"\") + | sealed trait Role2 extends scala.Product with scala.Serializable + | + | object Role2 { + | final case class Stewart(shipName: String) extends Role2 + | } + | + | final case class Captain( + | @GQLDescription("ship") + | shipName: String + | ) extends Role + | with Role2 + | final case class Pilot(shipName: String) extends Role with Role2 + | + |}""" + ), + ( + "GQLDescription with escaped quotes", + gen(s""" type Captain { "foo \\"quotes\\" bar" shipName: String! } - """.stripMargin - - assertM(gen(schema))( - equalTo { - s"""import caliban.schema.Annotations._ - -object Types { - - final case class Captain( - @GQLDescription("foo \\"quotes\\" bar") - shipName: String - ) - -} -""" - } - ) - }, - testM("schema") { - val schema = - """ + """), + """import caliban.schema.Annotations._ + | + |object Types { + | + | final case class Captain( + | @GQLDescription("foo \"quotes\" bar") + | shipName: String + | ) + | + |}""" + ), + ( + "schema", + gen(""" schema { query: Queries } - + type Queries { characters: Int! } - """.stripMargin - - assertM(gen(schema))( - equalTo( - """object Operations { - - final case class Queries( - characters: zio.UIO[Int] - ) - -} -""" - ) - ) - }, - testM("input type") { - val schema = - """ + """), + """object Operations { + | + | final case class Queries( + | characters: zio.UIO[Int] + | ) + | + |}""" + ), + ( + "input type", + gen(""" type Character { name: String! } - + input CharacterArgs { name: String! } - """.stripMargin - - assertM(gen(schema))( - equalTo( - """object Types { - - final case class Character(name: String) - final case class CharacterArgs(name: String) - -} -""" - ) - ) - }, - testM("scala reserved word used") { - val schema = - """ + """), + """object Types { + | + | final case class Character(name: String) + | final case class CharacterArgs(name: String) + | + |}""" + ), + ( + "scala reserved word used", + gen(""" type Character { private: String! object: String! type: String! } - """.stripMargin - - assertM(gen(schema))( - equalTo( - """object Types { - - final case class Character(`private`: String, `object`: String, `type`: String) - -} -""" - ) - ) - }, - testM("final case class reserved field name used") { - val schema = - """ + """), + """ object Types { + | + | final case class Character(`private`: String, `object`: String, `type`: String) + | + |}""" + ), + ( + "final case class reserved field name used", + gen(""" type Character { wait: String! } - """.stripMargin - - assertM(gen(schema))( - equalTo( - """object Types { - - final case class Character(wait$ : String) - -} -""" - ) - ) - }, - testM("args unique class names") { - val schema = - """ - |type Hero { - | callAllies(number: Int!): [Hero!]! - |} - | - |type Villain { - | callAllies(number: Int!, w: String!): [Villain!]! - |} - """.stripMargin + """), + """object Types { + | + | final case class Character(wait$ : String) + | + |}""" + ), + ( + "args unique class names", + gen(""" + |type Hero { + | callAllies(number: Int!): [Hero!]! + |} + | + |type Villain { + | callAllies(number: Int!, w: String!): [Villain!]! + |} + """), + """object Types { + | final case class HeroCallAlliesArgs(number: Int) + | final case class VillainCallAlliesArgs(number: Int, w: String) + | final case class Hero(callAllies: HeroCallAlliesArgs => List[Hero]) + | final case class Villain(callAllies: VillainCallAlliesArgs => List[Villain]) + | + |}""" + ), + ( + "args names root level", + gen(""" + |schema { + | query: Query + | subscription: Subscription + |} + | + |type Params { + | p: Int! + |} + | + |type Query { + | characters(p: Params!): Int! + |} + | + |type Subscription { + | characters(p: Params!): Int! + |} + """), + """import Types._ + | + |import zio.stream.ZStream + | + |object Types { + | final case class QueryCharactersArgs(p: Params) + | final case class SubscriptionCharactersArgs(p: Params) + | final case class Params(p: Int) + | + |} + | + |object Operations { + | + | final case class Query( + | characters: QueryCharactersArgs => zio.UIO[Int] + | ) + | + | final case class Subscription( + | characters: SubscriptionCharactersArgs => ZStream[Any, Nothing, Int] + | ) + | + |}""" + ), + ( + "add scalar mappings and additional imports", + gen( + """ + | scalar OffsetDateTime + | + | type Subscription { + | postAdded: Post + | } + | type Query { + | posts: [Post] + | } + | type Mutation { + | addPost(author: String, comment: String): Post + | } + | type Post { + | date: OffsetDateTime! + | author: String + | comment: String + | } + |""", + scalarMappings = Map("OffsetDateTime" -> "java.time.OffsetDateTime"), + imports = List("java.util.UUID", "a.b._") + ), + """import Types._ + | + |import zio.stream.ZStream + | + |import java.util.UUID + |import a.b._ + | + |object Types { + | final case class MutationAddPostArgs(author: Option[String], comment: Option[String]) + | final case class Post(date: java.time.OffsetDateTime, author: Option[String], comment: Option[String]) + | + |} + | + |object Operations { + | + | final case class Query( + | posts: zio.UIO[Option[List[Option[Post]]]] + | ) + | + | final case class Mutation( + | addPost: MutationAddPostArgs => zio.UIO[Option[Post]] + | ) + | + | final case class Subscription( + | postAdded: ZStream[Any, Nothing, Option[Post]] + | ) + | + |}""" + ), + ( + "interface type", + gen( + s""" + \"\"\" + person + Admin or Customer + \"\"\" + interface Person { + id: ID! + firstName: String! + lastName: String! + } - assertM(gen(schema))( - equalTo( - """object Types { - | final case class HeroCallAlliesArgs(number: Int) - | final case class VillainCallAlliesArgs(number: Int, w: String) - | final case class Hero(callAllies: HeroCallAlliesArgs => List[Hero]) - | final case class Villain(callAllies: VillainCallAlliesArgs => List[Villain]) - | - |} - |""".stripMargin - ) - ) - }, - testM("args names root level") { - val schema = - """ - |schema { - | query: Query - | subscription: Subscription - |} - | - |type Params { - | p: Int! - |} - | - |type Query { - | characters(p: Params!): Int! - |} - | - |type Subscription { - | characters(p: Params!): Int! - |} - """.stripMargin + type Admin implements Person { + id: ID! + "firstName" firstName: String! + lastName: String! + } - assertM(gen(schema))( - equalTo( - """import Types._ - | - |import zio.stream.ZStream - | - |object Types { - | final case class QueryCharactersArgs(p: Params) - | final case class SubscriptionCharactersArgs(p: Params) - | final case class Params(p: Int) - | - |} - | - |object Operations { - | - | final case class Query( - | characters: QueryCharactersArgs => zio.UIO[Int] - | ) - | - | final case class Subscription( - | characters: SubscriptionCharactersArgs => ZStream[Any, Nothing, Int] - | ) - | - |} - |""".stripMargin - ) - ) - }, - testM("add scalar mappings and additional imports") { - val schema = - """ - | scalar OffsetDateTime - | - | type Subscription { - | postAdded: Post - | } - | type Query { - | posts: [Post] - | } - | type Mutation { - | addPost(author: String, comment: String): Post - | } - | type Post { - | date: OffsetDateTime! - | author: String - | comment: String - | } - |""".stripMargin + type Customer implements Person { + id: ID! + firstName: String! + lastName: String! + email: String! + } + """, + scalarMappings = Map("ID" -> "java.util.UUID") + ), + s"""import caliban.schema.Annotations._ + | + |object Types { + | + | @GQLInterface + | @GQLDescription(\"\"\"person + |Admin or Customer\"\"\") + | sealed trait Person extends scala.Product with scala.Serializable { + | def id: java.util.UUID + | def firstName: String + | def lastName: String + | } + | + | object Person { + | final case class Admin( + | id: java.util.UUID, + | @GQLDescription("firstName") + | firstName: String, + | lastName: String + | ) extends Person + | final case class Customer(id: java.util.UUID, firstName: String, lastName: String, email: String) extends Person + | } + | + |}""" + ) + ) - assertM(gen(schema, Map("OffsetDateTime" -> "java.time.OffsetDateTime"), List("java.util.UUID", "a.b._")))( - equalTo( - """import Types._ - | - |import zio.stream.ZStream - | - |import java.util.UUID - |import a.b._ - | - |object Types { - | final case class MutationAddPostArgs(author: Option[String], comment: Option[String]) - | final case class Post(date: java.time.OffsetDateTime, author: Option[String], comment: Option[String]) - | - |} - | - |object Operations { - | - | final case class Query( - | posts: zio.UIO[Option[List[Option[Post]]]] - | ) - | - | final case class Mutation( - | addPost: MutationAddPostArgs => zio.UIO[Option[Post]] - | ) - | - | final case class Subscription( - | postAdded: ZStream[Any, Nothing, Option[Post]] - | ) - | - |} - |""".stripMargin - ) - ) - } - ) @@ TestAspect.sequential + override def spec: ZSpec[TestEnvironment, Any] = suite("SchemaWriterSpec")( + assertions.map { case (name, actual, expected) => + testM(name)( + assertM(actual.map(_.stripMargin.trim))(equalTo(expected.stripMargin.trim)) + ) + }: _* + ) @@ TestAspect.sequential } diff --git a/tools/src/test/scala/caliban/tools/compiletime/ConfigSpec.scala b/tools/src/test/scala/caliban/tools/compiletime/ConfigSpec.scala index feacec6c4..6a19e437b 100644 --- a/tools/src/test/scala/caliban/tools/compiletime/ConfigSpec.scala +++ b/tools/src/test/scala/caliban/tools/compiletime/ConfigSpec.scala @@ -1,6 +1,7 @@ package caliban.tools.compiletime import caliban.tools.CalibanCommonSettings +import caliban.tools.Codegen.GenType import caliban.tools.compiletime.Config._ import zio.test._ import zio.test.environment.TestEnvironment @@ -26,6 +27,7 @@ object ConfigSpec extends DefaultRunnableSpec { assertTrue( fullExample.toCalibanCommonSettings == CalibanCommonSettings( + genType = GenType.Client, clientName = Some("CalibanClient"), scalafmtPath = Some("a/b/c"), headers = List.empty,