Skip to content

Commit

Permalink
Circe - fully qualify class names for components and properties
Browse files Browse the repository at this point in the history
Previously, when a property was named the same as its component, references to the component class and property class conflicted, causing errors. Fully qualifying the references to each should prevent this from happening.

A regression test covering a minimal failing case is included.

Fixes issue guardrail-dev#2050
  • Loading branch information
Jonnty committed Nov 11, 2024
1 parent cd6955b commit d99909e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 26 deletions.
32 changes: 32 additions & 0 deletions modules/sample/src/main/resources/issues/issue2050.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
openapi: 3.0.1
info:
title: Minimal Error Case
description: Testing that internal naming conflicts do not occur when a component and its own property have the same name
version: "1.0"
servers:
- url: "http://localhost:1234"
paths:
/test:
get:
operationId: Test
responses:
'200':
description: A test
content:
application/json:
schema:
$ref: '#/components/schemas/Test'
components:
schemas:
Test:
title: Test
required:
- test
type: object
properties:
test:
title: test
enum:
- optionA
- optionB
type: string
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
}

private[this] def renderIntermediate(
clsName: NonEmptyList[String],
model: Tracker[Schema[_]],
dtoName: String,
concreteTypes: List[PropMeta[ScalaLanguage]],
Expand All @@ -510,12 +511,12 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
): Target[(PropMeta[ScalaLanguage], (Option[Defn.Val], Option[Defn.Val], Defn.Class))] =
for {
prefixes <- Cl.vendorPrefixes()
customTpe = CustomTypeName(model, prefixes).getOrElse(dtoName)
tpe <- Sc.pureTypeName(customTpe)
customTpe = NonEmptyList.of(CustomTypeName(model, prefixes).getOrElse(dtoName))
tpe <- Sc.pureTypeName(customTpe.last)
props <- extractProperties(model)
requiredFields = getRequiredFieldsRec(model)
(params, nestedDefinitions) <- prepareProperties(
NonEmptyList.of(customTpe),
customTpe,
propertyToTypeLookup = Map.empty,
props,
requiredFields,
Expand All @@ -526,10 +527,10 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
defaultPropertyRequirement,
components
)
encoder <- encodeModel(customTpe, dtoPackage, params, parents = List.empty)
decoder <- decodeModel(customTpe, dtoPackage, supportPackage, params, parents = List.empty)
defn <- renderDTOClass(customTpe, supportPackage, params, parents = List.empty)
} yield (PropMeta[ScalaLanguage](customTpe, tpe), (encoder, decoder, defn))
encoder <- encodeModel(clsName ::: customTpe, dtoPackage, params, parents = List.empty)
decoder <- decodeModel(clsName ::: customTpe, dtoPackage, supportPackage, params, parents = List.empty)
defn <- renderDTOClass(customTpe.last, supportPackage, params, parents = List.empty)
} yield (PropMeta[ScalaLanguage](customTpe.last, tpe), (encoder, decoder, defn))

private[this] def fromModel(
clsName: NonEmptyList[String],
Expand Down Expand Up @@ -565,8 +566,8 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
defaultPropertyRequirement,
components
)
encoder <- encodeModel(clsName.last, dtoPackage, params, parents)
decoder <- decodeModel(clsName.last, dtoPackage, supportPackage, params, parents)
encoder <- encodeModel(clsName, dtoPackage, params, parents)
decoder <- decodeModel(clsName, dtoPackage, supportPackage, params, parents)
tpe <- parseTypeName(clsName.last)
fullType <- selectType(dtoPackage.foldRight(clsName)((x, xs) => xs.prepend(x)))
nestedClasses <- nestedDefinitions.flatTraverse {
Expand Down Expand Up @@ -652,6 +653,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
case other => Target.raiseUserError(s"Unexpected type ${other}")
}
(pm, defns) <- renderIntermediate(
clsName,
model,
dtoName,
concreteTypes,
Expand Down Expand Up @@ -867,7 +869,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
paramsAndNestedDefinitions <- props.traverse[Target, (Tracker[ProtocolParameter[ScalaLanguage]], Option[NestedProtocolElems[ScalaLanguage]])] {
case (name, schema) =>
for {
typeName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName))
typeName <- formatTypeName(name).map(formattedName => getClsName(name).prependList(dtoPackage).append(formattedName))
tpe <- selectType(typeName)
maybeNestedDefinition <- processProperty(name, schema)
resolvedType <- ModelResolver.propMetaWithName[ScalaLanguage, Target](() => Target.pure(tpe), schema, components)
Expand Down Expand Up @@ -1407,19 +1409,21 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
} yield names.flatMap(n => reduced.get(n))

private def encodeModel(
clsName: String,
clsName: NonEmptyList[String],
dtoPackage: List[String],
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
) =
)(implicit Lt: LanguageTerms[ScalaLanguage, Target]) = {
import Lt._
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
qualifiedClsType <- selectType(clsName.prependList(dtoPackage))
(discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value))
readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList
typeName = Type.Name(clsName)

encVal = {
def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) =
q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))"""
Expand Down Expand Up @@ -1448,14 +1452,14 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
}
}

val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName))
val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName.last))
val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})"
val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) =>
q"$acc ++ $field"
}

q"""
${circeVersion.encoderObjectCompanion}.instance[${Type.Name(clsName)}](a => _root_.io.circe.JsonObject.fromIterable($allFields))
${circeVersion.encoderObjectCompanion}.instance[${qualifiedClsType}](a => _root_.io.circe.JsonObject.fromIterable($allFields))
"""
}
(readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys =>
Expand All @@ -1466,24 +1470,28 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
}

} yield Option(q"""
implicit val ${suffixClsName("encode", clsName)}: ${circeVersion.encoderObject}[${Type.Name(clsName)}] = {
implicit val ${suffixClsName("encode", clsName.last)}: ${circeVersion.encoderObject}[${qualifiedClsType}] = {
..${readOnlyDefn};
${readOnlyFilter(encVal)}
}
""")
}

private def decodeModel(
clsName: String,
clsName: NonEmptyList[String],
dtoPackage: List[String],
supportPackage: List[String],
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
)(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = {
import Lt._
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
qualifiedClsType <- selectType(clsName.prependList(dtoPackage))
qualifiedClsTerm <- selectTerm(clsName.prependList(dtoPackage))
params = allParams.filterNot(param => discriminatorNames.contains(param.name.value))
needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull)
paramCount = params.length
Expand All @@ -1492,9 +1500,9 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
if (paramCount == 0) {
Target.pure(
q"""
new _root_.io.circe.Decoder[${Type.Name(clsName)}] {
final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${Type.Name(clsName)}] =
_root_.scala.Right(${Term.Name(clsName)}())
new _root_.io.circe.Decoder[${qualifiedClsType}] {
final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${qualifiedClsType}] =
_root_.scala.Right(${qualifiedClsTerm}())
}
"""
)
Expand Down Expand Up @@ -1569,17 +1577,17 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
.map { pairs =>
val (terms, enumerators) = pairs.unzip
q"""
new _root_.io.circe.Decoder[${Type.Name(clsName)}] {
final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${Type.Name(clsName)}] =
new _root_.io.circe.Decoder[${qualifiedClsType}] {
final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${qualifiedClsType}] =
for {
..${enumerators}
} yield ${Term.Name(clsName)}(..${terms})
} yield ${qualifiedClsTerm}(..${terms})
}
"""
}
}
} yield Option(q"""
implicit val ${suffixClsName("decode", clsName)}: _root_.io.circe.Decoder[${Type.Name(clsName)}] = $decVal
implicit val ${suffixClsName("decode", clsName.last)}: _root_.io.circe.Decoder[${qualifiedClsType}] = $decVal
""")
}

Expand Down
1 change: 1 addition & 0 deletions project/src/main/scala/RegressionTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ object RegressionTests {
ExampleCase(sampleResource("issues/issue1260.yaml"), "issues.issue1260"),
ExampleCase(sampleResource("issues/issue1218.yaml"), "issues.issue1218").frameworks("scala" -> Set("http4s", "http4s-v0.22")),
ExampleCase(sampleResource("issues/issue1594.yaml"), "issues.issue1594"),
ExampleCase(sampleResource("issues/issue2050.yaml"), "issues.issue2050"),
ExampleCase(sampleResource("multipart-form-data.yaml"), "multipartFormData"),
ExampleCase(sampleResource("petstore.json"), "examples").args("--import", "examples.support.PositiveLong"),
// ExampleCase(sampleResource("petstore-openapi-3.0.2.yaml"), "examples.petstore.openapi302").args("--import", "examples.support.PositiveLong"),
Expand Down

0 comments on commit d99909e

Please sign in to comment.