Skip to content

Commit

Permalink
Circe - fully qualify class names for components and properties to av…
Browse files Browse the repository at this point in the history
…oid naming conflicts

Previously, when a property was named the same as its component, references to the component class were misinterpreted as references to the property class, causing errors. Fully qualifying the references to each should prevent this from happening.

A regression test covering a minimal failing case is included.

Fixes issue guardrail-dev#2050
  • Loading branch information
Jonnty committed Nov 11, 2024
1 parent cd6955b commit 77ce8dd
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 26 deletions.
32 changes: 32 additions & 0 deletions modules/sample/src/main/resources/issues/issue2050.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
openapi: 3.0.1
info:
title: Minimal Error Case
description: Testing that internal naming conflicts do not occur when a component and its own property have the same name
version: "1.0"
servers:
- url: "http://localhost:1234
paths:
/test:
get:
operationId: Test
responses:
'200':
description: A test
content:
application/json:
schema:
$ref: '#/components/schemas/Test'
components:
schemas:
Test:
title: Test
required:
- test
type: object
properties:
test:
title: test
enum:
- optionA
- optionB
type: string
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
}

private[this] def renderIntermediate(
clsName: NonEmptyList[String],
model: Tracker[Schema[_]],
dtoName: String,
concreteTypes: List[PropMeta[ScalaLanguage]],
Expand All @@ -510,12 +511,12 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
): Target[(PropMeta[ScalaLanguage], (Option[Defn.Val], Option[Defn.Val], Defn.Class))] =
for {
prefixes <- Cl.vendorPrefixes()
customTpe = CustomTypeName(model, prefixes).getOrElse(dtoName)
tpe <- Sc.pureTypeName(customTpe)
customTpe = NonEmptyList.of(CustomTypeName(model, prefixes).getOrElse(dtoName))
tpe <- Sc.pureTypeName(customTpe.last)
props <- extractProperties(model)
requiredFields = getRequiredFieldsRec(model)
(params, nestedDefinitions) <- prepareProperties(
NonEmptyList.of(customTpe),
customTpe,
propertyToTypeLookup = Map.empty,
props,
requiredFields,
Expand All @@ -526,10 +527,10 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
defaultPropertyRequirement,
components
)
encoder <- encodeModel(customTpe, dtoPackage, params, parents = List.empty)
decoder <- decodeModel(customTpe, dtoPackage, supportPackage, params, parents = List.empty)
defn <- renderDTOClass(customTpe, supportPackage, params, parents = List.empty)
} yield (PropMeta[ScalaLanguage](customTpe, tpe), (encoder, decoder, defn))
encoder <- encodeModel(clsName ::: customTpe, dtoPackage, params, parents = List.empty)
decoder <- decodeModel(clsName ::: customTpe, dtoPackage, supportPackage, params, parents = List.empty)
defn <- renderDTOClass(customTpe.last, supportPackage, params, parents = List.empty)
} yield (PropMeta[ScalaLanguage](customTpe.last, tpe), (encoder, decoder, defn))

private[this] def fromModel(
clsName: NonEmptyList[String],
Expand Down Expand Up @@ -565,8 +566,8 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
defaultPropertyRequirement,
components
)
encoder <- encodeModel(clsName.last, dtoPackage, params, parents)
decoder <- decodeModel(clsName.last, dtoPackage, supportPackage, params, parents)
encoder <- encodeModel(clsName, dtoPackage, params, parents)
decoder <- decodeModel(clsName, dtoPackage, supportPackage, params, parents)
tpe <- parseTypeName(clsName.last)
fullType <- selectType(dtoPackage.foldRight(clsName)((x, xs) => xs.prepend(x)))
nestedClasses <- nestedDefinitions.flatTraverse {
Expand Down Expand Up @@ -652,6 +653,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
case other => Target.raiseUserError(s"Unexpected type ${other}")
}
(pm, defns) <- renderIntermediate(
clsName,
model,
dtoName,
concreteTypes,
Expand Down Expand Up @@ -867,7 +869,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
paramsAndNestedDefinitions <- props.traverse[Target, (Tracker[ProtocolParameter[ScalaLanguage]], Option[NestedProtocolElems[ScalaLanguage]])] {
case (name, schema) =>
for {
typeName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName))
typeName <- formatTypeName(name).map(formattedName => getClsName(name).prependList(dtoPackage).append(formattedName))
tpe <- selectType(typeName)
maybeNestedDefinition <- processProperty(name, schema)
resolvedType <- ModelResolver.propMetaWithName[ScalaLanguage, Target](() => Target.pure(tpe), schema, components)
Expand Down Expand Up @@ -1407,19 +1409,21 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
} yield names.flatMap(n => reduced.get(n))

private def encodeModel(
clsName: String,
clsName: NonEmptyList[String],
dtoPackage: List[String],
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
) =
)(implicit Lt: LanguageTerms[ScalaLanguage, Target]) = {
import Lt._
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
qualifiedClsType <- selectType(clsName.prependList(dtoPackage))
(discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value))
readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList
typeName = Type.Name(clsName)

encVal = {
def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) =
q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))"""
Expand Down Expand Up @@ -1448,14 +1452,14 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
}
}

val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName))
val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName.last))
val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})"
val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) =>
q"$acc ++ $field"
}

q"""
${circeVersion.encoderObjectCompanion}.instance[${Type.Name(clsName)}](a => _root_.io.circe.JsonObject.fromIterable($allFields))
${circeVersion.encoderObjectCompanion}.instance[${qualifiedClsType}](a => _root_.io.circe.JsonObject.fromIterable($allFields))
"""
}
(readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys =>
Expand All @@ -1466,24 +1470,28 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
}

} yield Option(q"""
implicit val ${suffixClsName("encode", clsName)}: ${circeVersion.encoderObject}[${Type.Name(clsName)}] = {
implicit val ${suffixClsName("encode", clsName.last)}: ${circeVersion.encoderObject}[${qualifiedClsType}] = {
..${readOnlyDefn};
${readOnlyFilter(encVal)}
}
""")
}

private def decodeModel(
clsName: String,
clsName: NonEmptyList[String],
dtoPackage: List[String],
supportPackage: List[String],
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
)(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = {
import Lt._
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
qualifiedClsType <- selectType(clsName.prependList(dtoPackage))
qualifiedClsTerm <- selectTerm(clsName.prependList(dtoPackage))
params = allParams.filterNot(param => discriminatorNames.contains(param.name.value))
needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull)
paramCount = params.length
Expand All @@ -1492,9 +1500,9 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
if (paramCount == 0) {
Target.pure(
q"""
new _root_.io.circe.Decoder[${Type.Name(clsName)}] {
final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${Type.Name(clsName)}] =
_root_.scala.Right(${Term.Name(clsName)}())
new _root_.io.circe.Decoder[${qualifiedClsType}] {
final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${qualifiedClsType}] =
_root_.scala.Right(${qualifiedClsTerm}())
}
"""
)
Expand Down Expand Up @@ -1569,17 +1577,17 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
.map { pairs =>
val (terms, enumerators) = pairs.unzip
q"""
new _root_.io.circe.Decoder[${Type.Name(clsName)}] {
final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${Type.Name(clsName)}] =
new _root_.io.circe.Decoder[${qualifiedClsType}] {
final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${qualifiedClsType}] =
for {
..${enumerators}
} yield ${Term.Name(clsName)}(..${terms})
} yield ${qualifiedClsTerm}(..${terms})
}
"""
}
}
} yield Option(q"""
implicit val ${suffixClsName("decode", clsName)}: _root_.io.circe.Decoder[${Type.Name(clsName)}] = $decVal
implicit val ${suffixClsName("decode", clsName.last)}: _root_.io.circe.Decoder[${qualifiedClsType}] = $decVal
""")
}

Expand Down
1 change: 1 addition & 0 deletions project/src/main/scala/RegressionTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ object RegressionTests {
ExampleCase(sampleResource("issues/issue1260.yaml"), "issues.issue1260"),
ExampleCase(sampleResource("issues/issue1218.yaml"), "issues.issue1218").frameworks("scala" -> Set("http4s", "http4s-v0.22")),
ExampleCase(sampleResource("issues/issue1594.yaml"), "issues.issue1594"),
ExampleCase(sampleResource("issues/issue2050.yaml"), "issues.issue2050"),
ExampleCase(sampleResource("multipart-form-data.yaml"), "multipartFormData"),
ExampleCase(sampleResource("petstore.json"), "examples").args("--import", "examples.support.PositiveLong"),
// ExampleCase(sampleResource("petstore-openapi-3.0.2.yaml"), "examples.petstore.openapi302").args("--import", "examples.support.PositiveLong"),
Expand Down

0 comments on commit 77ce8dd

Please sign in to comment.