Skip to content

Commit

Permalink
codegen: Fix issues with jsoniter in scala3 (#3963)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson authored Aug 5, 2024
1 parent 81c9a76 commit 9916314
Show file tree
Hide file tree
Showing 11 changed files with 461 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class ClassDefinitionGenerator {
jsonSerdeLib,
jsonParamRefs,
allTransitiveJsonParamRefs,
fullModelPath,
validateNonDiscriminatedOneOfs,
adtInheritanceMap.mapValues(_.map(_._1)),
targetScala3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ object EnumGenerator {
case _ if !jsonParamRefs.contains(name) => " derives enumextensions.EnumMirror"
case JsonSerdeLib.Circe if !queryParamRefs.contains(name) => " derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec"
case JsonSerdeLib.Circe => " derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec, enumextensions.EnumMirror"
case JsonSerdeLib.Jsoniter | JsonSerdeLib.Zio if !queryParamRefs.contains(name) => s" extends java.lang.Enum[$name]"
case JsonSerdeLib.Jsoniter | JsonSerdeLib.Zio => s" extends java.lang.Enum[$name] derives enumextensions.EnumMirror"
case JsonSerdeLib.Jsoniter if !queryParamRefs.contains(name) => ""
case JsonSerdeLib.Jsoniter => " derives enumextensions.EnumMirror"
case JsonSerdeLib.Zio if !queryParamRefs.contains(name) => s" extends java.lang.Enum[$name]"
case JsonSerdeLib.Zio => s" extends java.lang.Enum[$name] derives enumextensions.EnumMirror"
}
s"""$maybeCompanion
|enum $name$maybeCodecExtensions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ object JsonSerdeGenerator {
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
jsonParamRefs: Set[String],
allTransitiveJsonParamRefs: Set[String],
fullModelPath: String,
validateNonDiscriminatedOneOfs: Boolean,
adtInheritanceMap: Map[String, Seq[String]],
targetScala3: Boolean
Expand All @@ -41,7 +40,6 @@ object JsonSerdeGenerator {
jsonParamRefs,
allTransitiveJsonParamRefs,
adtInheritanceMap,
if (fullModelPath.isEmpty) None else Some(fullModelPath),
validateNonDiscriminatedOneOfs
)
case JsonSerdeLib.Zio => genZioSerdes(doc, allSchemas, allTransitiveJsonParamRefs, validateNonDiscriminatedOneOfs, targetScala3)
Expand Down Expand Up @@ -233,7 +231,6 @@ object JsonSerdeGenerator {
jsonParamRefs: Set[String],
allTransitiveJsonParamRefs: Set[String],
adtInheritanceMap: Map[String, Seq[String]],
fullModelPath: Option[String],
validateNonDiscriminatedOneOfs: Boolean
): Option[String] = {
// For jsoniter-scala, we define explicit serdes for any 'primitive' params (e.g. List[java.util.UUID]) that we reference.
Expand Down Expand Up @@ -271,7 +268,7 @@ object JsonSerdeGenerator {
Some(genJsoniterEnumSerde(name))
// For ADTs, generate the serde if it's referenced in any json model
case (name, schema: OpenapiSchemaOneOf) if allTransitiveJsonParamRefs.contains(name) =>
Some(generateJsoniterAdtSerde(allSchemas, name, schema, fullModelPath, validateNonDiscriminatedOneOfs))
Some(generateJsoniterAdtSerde(allSchemas, name, schema, validateNonDiscriminatedOneOfs))
case (_, _: OpenapiSchemaObject | _: OpenapiSchemaMap | _: OpenapiSchemaEnum | _: OpenapiSchemaOneOf) => None
case (n, x) => throw new NotImplementedError(s"Only objects, enums, maps and oneOf supported! (for $n found ${x})")
})
Expand Down Expand Up @@ -304,10 +301,8 @@ object JsonSerdeGenerator {
allSchemas: Map[String, OpenapiSchemaType],
name: String,
schema: OpenapiSchemaOneOf,
maybeFullModelPath: Option[String],
validateNonDiscriminatedOneOfs: Boolean
): String = {
val fullPathPrefix = maybeFullModelPath.map(_ + ".").getOrElse("")
val uncapitalisedName = BasicGenerator.uncapitalise(name)
schema match {
case OpenapiSchemaOneOf(_, Some(discriminator)) =>
Expand All @@ -323,11 +318,11 @@ object JsonSerdeGenerator {
val body = if (schemaToJsonMapping.exists { case (className, jsonValue) => className != jsonValue }) {
val discriminatorMap = indent(2)(
schemaToJsonMapping
.map { case (k, v) => s"""case "$fullPathPrefix$k" => "$v"""" }
.map { case (k, v) => s"""case "$k" => "$v"""" }
.mkString("\n", "\n", "\n")
)
val config =
s"""$jsoniterBaseConfig.withRequireDiscriminatorFirst(false).withDiscriminatorFieldName(Some("${discriminator.propertyName}")).withAdtLeafClassNameMapper{$discriminatorMap}"""
s"""$jsoniterBaseConfig.withRequireDiscriminatorFirst(false).withDiscriminatorFieldName(Some("${discriminator.propertyName}")).withAdtLeafClassNameMapper(x => com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker.simpleClassName(x) match {$discriminatorMap})"""
val serde =
s"implicit lazy val ${uncapitalisedName}Codec: $jsoniterPkgCore.JsonValueCodec[$name] = $jsoniterPkgMacros.JsonCodecMaker.make($config)"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package sttp.tapir.generated

object TapirGeneratedEndpoints {

import sttp.tapir._
import sttp.tapir.model._
import sttp.tapir.generic.auto._
import sttp.tapir.json.jsoniter._
import com.github.plokhotnyuk.jsoniter_scala.macros._
import com.github.plokhotnyuk.jsoniter_scala.core._

import sttp.tapir.generated.TapirGeneratedEndpointsJsonSerdes._
import TapirGeneratedEndpointsSchemas._


case class CommaSeparatedValues[T](values: List[T])
case class ExplodedValues[T](values: List[T])
trait ExtraParamSupport[T] {
def decode(s: String): sttp.tapir.DecodeResult[T]
def encode(t: T): String
}
implicit def makePathCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[String, T, sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.string.mapDecode(support.decode)(support.encode)
}
implicit def makeQueryCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(support.decode)(support.encode)
}
implicit def makeQueryOptCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(maybeV => DecodeResult.sequence(maybeV.toSeq.map(support.decode)).map(_.headOption))(_.map(support.encode))
}
implicit def makeUnexplodedQuerySeqCodecFromListHead[T](implicit support: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], CommaSeparatedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(values => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(s => CommaSeparatedValues(s.toList)))(_.values.map(support.encode).mkString(","))
}
implicit def makeUnexplodedQueryOptSeqCodecFromListHead[T](implicit support: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], Option[CommaSeparatedValues[T]], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode{
case None => DecodeResult.Value(None)
case Some(values) => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(r => Some(CommaSeparatedValues(r.toList)))
}(_.map(_.values.map(support.encode).mkString(",")))
}
implicit def makeExplodedQuerySeqCodecFromListSeq[T](implicit support: sttp.tapir.Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
}


sealed trait ADTWithoutDiscriminator
sealed trait ADTWithDiscriminator
sealed trait ADTWithDiscriminatorNoMapping
case class SubtypeWithoutD1 (
s: String,
i: Option[Int] = None,
a: Seq[String],
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithD1 (
s: String,
i: Option[Int] = None,
d: Option[Double] = None
) extends ADTWithDiscriminator with ADTWithDiscriminatorNoMapping
case class SubtypeWithoutD3 (
s: String,
i: Option[Int] = None,
e: Option[AnEnum] = None,
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithoutD2 (
a: Seq[String],
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithD2 (
s: String,
a: Option[Seq[String]] = None
) extends ADTWithDiscriminator with ADTWithDiscriminatorNoMapping

enum AnEnum {
case Foo, Bar, Baz
}



lazy val putAdtTest =
endpoint
.put
.in(("adt" / "test"))
.in(jsonBody[ADTWithoutDiscriminator])
.out(jsonBody[ADTWithoutDiscriminator].description("successful operation"))

lazy val postAdtTest =
endpoint
.post
.in(("adt" / "test"))
.in(jsonBody[ADTWithDiscriminatorNoMapping])
.out(jsonBody[ADTWithDiscriminator].description("successful operation"))


lazy val generatedEndpoints = List(putAdtTest, postAdtTest)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
lazy val root = (project in file("."))
.enablePlugins(OpenapiCodegenPlugin)
.settings(
scalaVersion := "3.3.3",
version := "0.1",
openapiJsonSerdeLib := "jsoniter"
)

libraryDependencies ++= Seq(
"com.softwaremill.sttp.tapir" %% "tapir-jsoniter-scala" % "1.10.0",
"com.softwaremill.sttp.tapir" %% "tapir-openapi-docs" % "1.10.0",
"com.softwaremill.sttp.apispec" %% "openapi-circe-yaml" % "0.8.0",
"com.beachape" %% "enumeratum" % "1.7.4",
"com.github.plokhotnyuk.jsoniter-scala" %% "jsoniter-scala-core" % "2.30.7",
"com.github.plokhotnyuk.jsoniter-scala" %% "jsoniter-scala-macros" % "2.30.7" % "compile-internal",
"org.scalatest" %% "scalatest" % "3.2.19" % Test,
"com.softwaremill.sttp.tapir" %% "tapir-sttp-stub-server" % "1.10.0" % Test
)

import sttp.tapir.sbt.OpenapiCodegenPlugin.autoImport.{openapiJsonSerdeLib, openapiUseHeadTagForObjectName}

import scala.io.Source

TaskKey[Unit]("check") := {
val generatedCode =
Source.fromFile("target/scala-3.3.3/src_managed/main/sbt-openapi-codegen/TapirGeneratedEndpoints.scala").getLines.mkString("\n")
val expected = Source.fromFile("Expected.scala.txt").getLines.mkString("\n")
val generatedTrimmed =
generatedCode.linesIterator.zipWithIndex.filterNot(_._1.forall(_.isWhitespace)).map { case (a, i) => a.trim -> i }.toSeq
val expectedTrimmed = expected.linesIterator.filterNot(_.forall(_.isWhitespace)).map(_.trim).toSeq
if (generatedTrimmed.size != expectedTrimmed.size)
sys.error(s"expected ${expectedTrimmed.size} non-empty lines, found ${generatedTrimmed.size}")
generatedTrimmed.zip(expectedTrimmed).foreach { case ((a, i), b) =>
if (a != b) sys.error(s"Generated code did not match (expected '$b' on line $i, found '$a')")
}
println("Skipping swagger roundtrip for petstore")
()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sbt.version=1.10.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
val pluginVersion = System.getProperty("plugin.version")
if (pluginVersion == null)
throw new RuntimeException("""|
|
|The system property 'plugin.version' is not defined.
|Specify this property using the scriptedLaunchOpts -D.
|
|""".stripMargin)
else addSbtPlugin("com.softwaremill.sttp.tapir" % "sbt-openapi-codegen" % pluginVersion)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
object Main extends App {
import sttp.apispec.openapi.circe.yaml._
import sttp.tapir.generated._
import sttp.tapir.docs.openapi._

val docs = OpenAPIDocsInterpreter().toOpenAPI(TapirGeneratedEndpoints.generatedEndpoints, "My Bookshop", "1.0")

import java.nio.file.{Paths, Files}
import java.nio.charset.StandardCharsets

Files.write(Paths.get("target/swagger.yaml"), docs.toYaml.getBytes(StandardCharsets.UTF_8))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import com.github.plokhotnyuk.jsoniter_scala.core.writeToString
import io.circe.parser.parse
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
import sttp.client3.UriContext
import sttp.client3.testing.SttpBackendStub
import sttp.tapir.generated.{TapirGeneratedEndpoints, TapirGeneratedEndpointsJsonSerdes}
import sttp.tapir.generated.TapirGeneratedEndpoints.*
import sttp.tapir.generated.TapirGeneratedEndpointsSchemas.*
import TapirGeneratedEndpointsJsonSerdes._
import sttp.tapir.server.stub.TapirStubInterpreter

import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global

class JsonRoundtrip extends AnyFreeSpec with Matchers {
"oneOf without discriminator can be round-tripped by generated serdes" in {
val route = TapirGeneratedEndpoints.putAdtTest.serverLogic[Future]({
case foo: SubtypeWithoutD1 =>
Future successful Right[Unit, ADTWithoutDiscriminator](SubtypeWithoutD1(foo.s + "+SubtypeWithoutD1", foo.i, foo.a))
case foo: SubtypeWithoutD2 => Future successful Right[Unit, ADTWithoutDiscriminator](SubtypeWithoutD2(foo.a :+ "+SubtypeWithoutD2"))
case foo: SubtypeWithoutD3 =>
Future successful Right[Unit, ADTWithoutDiscriminator](SubtypeWithoutD3(foo.s + "+SubtypeWithoutD3", foo.i, foo.e))
})

val stub = TapirStubInterpreter(SttpBackendStub.asynchronousFuture)
.whenServerEndpoint(route)
.thenRunLogic()
.backend()

def normalise(json: String): String = parse(json).toTry.get.noSpacesSortKeys
locally {
val reqBody = SubtypeWithoutD1("a string", Some(123), Seq("string 1", "string 2"))
val reqJsonBody = writeToString(reqBody)
val respBody = SubtypeWithoutD1("a string+SubtypeWithoutD1", Some(123), Seq("string 1", "string 2"))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"s":"a string","i":123,"a":["string 1","string 2"]}"""
respJsonBody shouldEqual """{"s":"a string+SubtypeWithoutD1","i":123,"a":["string 1","string 2"]}"""
Await.result(
sttp.client3.basicRequest
.put(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.code.code === 200
resp.body shouldEqual Right(respJsonBody)
},
1.second
)
}

locally {
val reqBody = SubtypeWithoutD2(Seq("string 1", "string 2"))
val reqJsonBody = writeToString(reqBody)
val respBody = SubtypeWithoutD2(Seq("string 1", "string 2", "+SubtypeWithoutD2"))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"a":["string 1","string 2"]}"""
respJsonBody shouldEqual """{"a":["string 1","string 2","+SubtypeWithoutD2"]}"""
Await.result(
sttp.client3.basicRequest
.put(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.body shouldEqual Right(respJsonBody)
resp.code.code === 200
},
1.second
)
}

locally {
val reqBody = SubtypeWithoutD3("a string", Some(123), Some(AnEnum.Foo))
val reqJsonBody = writeToString(reqBody)
val respBody = SubtypeWithoutD3("a string+SubtypeWithoutD3", Some(123), Some(AnEnum.Foo))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"s":"a string","i":123,"e":"Foo"}"""
respJsonBody shouldEqual """{"s":"a string+SubtypeWithoutD3","i":123,"e":"Foo"}"""
Await.result(
sttp.client3.basicRequest
.put(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.body shouldEqual Right(respJsonBody)
resp.code.code === 200
},
1.second
)
}
}
"oneOf with discriminator can be round-tripped by generated serdes" in {
val route = TapirGeneratedEndpoints.postAdtTest.serverLogic[Future]({
case foo: SubtypeWithD1 => Future successful Right[Unit, ADTWithDiscriminator](SubtypeWithD1(foo.s + "+SubtypeWithD1", foo.i, foo.d))
case foo: SubtypeWithD2 => Future successful Right[Unit, ADTWithDiscriminator](SubtypeWithD2(foo.s + "+SubtypeWithD2", foo.a))
})

val stub = TapirStubInterpreter(SttpBackendStub.asynchronousFuture)
.whenServerEndpoint(route)
.thenRunLogic()
.backend()

def normalise(json: String): String = parse(json).toTry.get.noSpacesSortKeys

locally {
val reqBody: ADTWithDiscriminatorNoMapping = SubtypeWithD1("a string", Some(123), Some(23.4))
val reqJsonBody = writeToString(reqBody)
val respBody: ADTWithDiscriminator = SubtypeWithD1("a string+SubtypeWithD1", Some(123), Some(23.4))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"type":"SubtypeWithD1","s":"a string","i":123,"d":23.4}"""
respJsonBody shouldEqual """{"type":"SubA","s":"a string+SubtypeWithD1","i":123,"d":23.4}"""
Await.result(
sttp.client3.basicRequest
.post(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.code.code === 200
resp.body shouldEqual Right(respJsonBody)
},
1.second
)
}

locally {
val reqBody: ADTWithDiscriminatorNoMapping = SubtypeWithD2("a string", Some(Seq("string 1", "string 2")))
val reqJsonBody = writeToString(reqBody)
val respBody: ADTWithDiscriminator = SubtypeWithD2("a string+SubtypeWithD2", Some(Seq("string 1", "string 2")))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"type":"SubtypeWithD2","s":"a string","a":["string 1","string 2"]}"""
respJsonBody shouldEqual """{"type":"SubB","s":"a string+SubtypeWithD2","a":["string 1","string 2"]}"""
Await.result(
sttp.client3.basicRequest
.post(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.code.code === 200
resp.body shouldEqual Right(respJsonBody)
},
1.second
)
}

}
}
Loading

0 comments on commit 9916314

Please sign in to comment.