Skip to content

Commit

Permalink
only generate enum query param codecs for enums that're actually used…
Browse files Browse the repository at this point in the history
… as query params
  • Loading branch information
hughsimpson committed Mar 14, 2024
1 parent 7b5771d commit d1bbdd7
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ object BasicGenerator {
if (!targetScala3 && doc.components.toSeq.flatMap(_.schemas).exists(_._2.isInstanceOf[OpenapiSchemaEnum])) "\n import enumeratum._"
else ""

val endpointsByTag = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val EndpointDefs(endpointsByTag, queryParamRefs) = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val taggedObjs = endpointsByTag.collect {
case (Some(headTag), body) if body.nonEmpty =>
val taggedObj =
Expand All @@ -57,7 +57,7 @@ object BasicGenerator {
|
|${indent(2)(imports)}$enumImport
|
|${indent(2)(classGenerator.classDefs(doc, targetScala3).getOrElse(""))}
|${indent(2)(classGenerator.classDefs(doc, targetScala3, queryParamRefs).getOrElse(""))}
|
|${indent(2)(endpointsByTag.getOrElse(None, ""))}
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{

class ClassDefinitionGenerator {

def classDefs(doc: OpenapiDocument, targetScala3: Boolean = false): Option[String] = {
val generatesEnums = doc.components.exists(_.schemas.exists(_._2.isInstanceOf[OpenapiSchemaEnum]))
def classDefs(doc: OpenapiDocument, targetScala3: Boolean = false, queryParamRefs: Set[String] = Set.empty): Option[String] = {
val generatesQueryParamEnums =
doc.components.toSeq
.flatMap(_.schemas.collect { case (name, _: OpenapiSchemaEnum) => name })
.exists(queryParamRefs.contains)
val enumQuerySerdeHelper =
if (!generatesEnums) ""
if (!generatesQueryParamEnums) ""
else if (targetScala3) "" // TODO
else
""" def makeQueryCodecForEnum[T <: EnumEntry](T: Enum[T] with CirceEnum[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] =
Expand All @@ -33,7 +36,7 @@ class ClassDefinitionGenerator {
case (name, obj: OpenapiSchemaObject) =>
generateClass(name, obj)
case (name, obj: OpenapiSchemaEnum) =>
generateEnum(name, obj, targetScala3)
generateEnum(name, obj, targetScala3, queryParamRefs)
case (name, OpenapiSchemaMap(valueSchema, _)) => generateMap(name, valueSchema)
case (n, x) => throw new NotImplementedError(s"Only objects, enums and maps supported! (for $n found ${x})")
})
Expand All @@ -50,18 +53,27 @@ class ClassDefinitionGenerator {
}

// Uses enumeratum for scala 2, but generates scala 3 enums instead where it can
private[codegen] def generateEnum(name: String, obj: OpenapiSchemaEnum, targetScala3: Boolean): Seq[String] = if (targetScala3) {
private[codegen] def generateEnum(
name: String,
obj: OpenapiSchemaEnum,
targetScala3: Boolean,
queryParamRefs: Set[String]
): Seq[String] = if (targetScala3) {
s"""enum $name derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec {
| case ${obj.items.map(_.value).mkString(", ")}
|}""".stripMargin :: Nil
} else {
val members = obj.items.map { i => s"case object ${i.value} extends $name" }
val maybeQueryCodecDefn =
if (queryParamRefs contains name)
s"""
| implicit val ${name.head.toLower +: name.tail}Codec: sttp.tapir.Codec[List[String], ${name}, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum(${name})""".stripMargin
else ""
s"""|sealed trait $name extends EnumEntry
|object $name extends Enum[$name] with CirceEnum[$name] {
| val values = findValues
|${indent(2)(members.mkString("\n"))}
| implicit val ${name.head.toLower +: name.tail}Codec: sttp.tapir.Codec[List[String], ${name}, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum(${name})
|${indent(2)(members.mkString("\n"))}$maybeQueryCodecDefn
|}""".stripMargin :: Nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, Openapi
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaArray,
OpenapiSchemaBinary,
OpenapiSchemaEnum,
OpenapiSchemaRef,
OpenapiSchemaSimpleType
}
Expand All @@ -14,16 +15,25 @@ case class Location(path: String, method: String) {
override def toString: String = s"${method.toUpperCase} ${path}"
}

case class GeneratedEndpoints(namesAndBodies: Seq[(Option[String], Seq[(String, String)])], queryParamRefs: Set[String]) {
def merge(that: GeneratedEndpoints): GeneratedEndpoints =
GeneratedEndpoints(
(namesAndBodies ++ that.namesAndBodies).groupBy(_._1).mapValues(_.map(_._2).reduce(_ ++ _)).toSeq,
queryParamRefs ++ that.queryParamRefs
)
}
case class EndpointDefs(endpointDecls: Map[Option[String], String], queryParamRefs: Set[String])

class EndpointGenerator {
private def bail(msg: String)(implicit location: Location): Nothing = throw new NotImplementedError(s"$msg at $location")

private[codegen] def allEndpoints: String = "generatedEndpoints"

def endpointDefs(doc: OpenapiDocument, useHeadTagForObjectNames: Boolean): Map[Option[String], String] = {
def endpointDefs(doc: OpenapiDocument, useHeadTagForObjectNames: Boolean): EndpointDefs = {
val components = Option(doc.components).flatten
val geMap =
doc.paths.flatMap(generatedEndpoints(components, useHeadTagForObjectNames)).groupBy(_._1).mapValues(_.map(_._2).reduce(_ ++ _))
geMap.mapValues { ge =>
val GeneratedEndpoints(geMap, queryParamRefs) =
doc.paths.map(generatedEndpoints(components, useHeadTagForObjectNames)).foldLeft(GeneratedEndpoints(Nil, Set.empty))(_ merge _)
val endpointDecls = geMap.map { case (k, ge) =>
val definitions = ge
.map { case (name, definition) =>
s"""|lazy val $name =
Expand All @@ -33,20 +43,21 @@ class EndpointGenerator {
.mkString("\n")
val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})"

s"""|$definitions
k -> s"""|$definitions
|
|$allEP
|""".stripMargin
}.toMap
EndpointDefs(endpointDecls, queryParamRefs)
}

private[codegen] def generatedEndpoints(components: Option[OpenapiComponent], useHeadTagForObjectNames: Boolean)(
p: OpenapiPath
): Seq[(Option[String], Seq[(String, String)])] = {
): GeneratedEndpoints = {
val parameters = components.map(_.parameters).getOrElse(Map.empty)
val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty)

p.methods
val (fileNamesAndParams, unflattenedQueryParamRefs) = p.methods
.map(_.withResolvedParentParameters(parameters, p.parameters))
.map { m =>
implicit val location: Location = Location(p.url, m.methodType)
Expand All @@ -68,11 +79,18 @@ class EndpointGenerator {
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
(maybeTargetFileName, (name, definition))
val queryParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" => queryParam.schema }
.collect { case OpenapiSchemaRef(ref) if ref.startsWith("#/components/schemas/") => ref.stripPrefix("#/components/schemas/") }
.toSet
(maybeTargetFileName, (name, definition)) -> queryParamRefs
}
.unzip
val namesAndParamsByFile = fileNamesAndParams
.groupBy(_._1)
.toSeq
.map { case (maybeTargetFileName, defns) => maybeTargetFileName -> defns.map(_._2) }
GeneratedEndpoints(namesAndParamsByFile, unflattenedQueryParamRefs.foldLeft(Set.empty[String])(_ ++ _))
}

private def urlMapper(url: String, parameters: Seq[OpenapiParameter])(implicit location: Location): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
TestHelpers.enumQueryParamDocs,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false
) shouldCompile ()
targetScala3 = false,
useHeadTagForObjectNames = false
)("TapirGeneratedEndpoints") shouldCompile ()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {

val res: String = parserRes match {
case Left(value) => throw new Exception(value)
case Right(doc) => new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
case Right(doc) => new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false).endpointDecls(None)
}

val compileUnit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
),
null
)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
val generatedCode =
BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false).endpointDecls(None)
generatedCode should include("val getTestAsdId =")
generatedCode shouldCompile ()
}
Expand Down Expand Up @@ -131,7 +132,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
)
BasicGenerator.imports ++
new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None) shouldCompile ()
new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false).endpointDecls(None) shouldCompile ()
}

it should "handle status codes" in {
Expand Down Expand Up @@ -174,7 +175,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
),
null
)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
val generatedCode =
BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false).endpointDecls(None)
generatedCode should include(
""".out(stringBody.description("Processing").and(statusCode(sttp.model.StatusCode(202))))"""
) // status code with body
Expand Down

0 comments on commit d1bbdd7

Please sign in to comment.