Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework Pickler for coproducts and enums #3222

Merged
merged 12 commits into from
Oct 6, 2023
24 changes: 9 additions & 15 deletions core/src/main/scala-3/sttp/tapir/internal/EnumerationMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,35 @@ import scala.quoted.*

private[tapir] object EnumerationMacros:

transparent inline def isEnumeration[T]: Boolean = inline compiletime.erasedValue[T] match
case _: Null => false
case _: Nothing => false
case _: reflect.Enum => allChildrenObjectsOrEnumerationCases[T]
case _ => false

/** Checks whether type T has child types, and all children of type T are objects or enum cases or sealed parents of such. Useful for
* determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which case it's not really an
* enumeration in context of schemas and JSON codecs.
/** Checks whether type T is an "enumeration", which means either Scalas 3 enum, or a sealed hierarchy where all members are case objects
* or enum cases. Useful for deriving schemas and JSON codecs.
*/
inline def allChildrenObjectsOrEnumerationCases[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] }
inline def isEnumeration[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] }

def allChildrenObjectsOrEnumerationCasesImpl[T: Type](using q: Quotes): Expr[Boolean] =
Expr(enumerationTypeChildren[T](failOnError = false).nonEmpty)
val typeChildren = enumerationTypeChildren[T](failOnError = false)
Expr(typeChildren.nonEmpty && !typeChildren.exists(_.isEmpty))

/** Recursively scans a symbol and builds a list of all children and their children, as long as all of them are objects or enum cases or
* sealed parents of such. Useful for determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which
* case it's not really an enumeration (in context of schemas and JSON codecs).
*/
def enumerationTypeChildren[T: Type](failOnError: Boolean)(using q: Quotes): List[q.reflect.Symbol] =
def enumerationTypeChildren[T: Type](failOnError: Boolean)(using q: Quotes): List[Option[q.reflect.Symbol]] =
import quotes.reflect.*

val tpe = TypeRepr.of[T]
val symbol = tpe.typeSymbol

def flatChildren(s: Symbol): List[Symbol] = s.children.toList.flatMap { c =>
def flatChildren(s: Symbol): List[Option[Symbol]] = s.children.toList.flatMap { c =>
if (c.isClassDef) {
if (!(c.flags is Flags.Sealed))
if (failOnError)
report.errorAndAbort("All children must be objects or enum cases, or sealed parent of such.")
else
Nil
List(None)
else
flatChildren(c)
} else List(c)
} else List(Some(c))
}

flatChildren(symbol)
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ private[tapir] object ValidatorMacros {
report.errorAndAbort("Can only enumerate values of a sealed trait, class or enum.")
}

val instances = enumerationTypeChildren[T](failOnError = true).distinct
val instances = enumerationTypeChildren[T](failOnError = true).flatMap(_.toList).distinct
.sortBy(_.name)
.map(x =>
tpe.memberType(x).asType match {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package sttp.tapir.json.pickler

import _root_.upickle.implicits.{macros => upickleMacros}
import sttp.tapir.generic.Configuration
import sttp.tapir.macros.CreateDerivedEnumerationSchema
import sttp.tapir.{Schema, SchemaAnnotations, SchemaType, Validator}
import upickle.core.{Annotator, Types}

import scala.deriving.Mirror
import scala.reflect.ClassTag

import compiletime.*

/** A builder allowing deriving Pickler for enums, used by [[Pickler.derivedEnumeration]]. Can be used to set non-standard encoding logic,
* schema type or default value for an enum.
*/
Expand All @@ -26,32 +30,62 @@ class CreateDerivedEnumerationPickler[T: ClassTag](
encode: Option[T => Any] = Some(identity),
schemaType: SchemaType[T] = SchemaType.SString[T](),
default: Option[T] = None
)(using m: Mirror.Of[T]): Pickler[T] = {
)(using m: Mirror.SumOf[T]): Pickler[T] = {
val schema: Schema[T] = new CreateDerivedEnumerationSchema(validator, schemaAnnotations).apply(
encode,
schemaType,
default
)
given Configuration = Configuration.default
given SubtypeDiscriminator[T] = EnumValueDiscriminator[T](
given subtypeDiscriminator: SubtypeDiscriminator[T] = EnumValueDiscriminator[T](
encode.map(_.andThen(_.toString)).getOrElse(_.toString),
validator
)
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = Pickler.summonChildPicklerInstances[T, m.MirroredElemTypes]
Pickler.picklerSum(schema, childPicklers)

lazy val childReadWriters = buildEnumerationReadWriters[T, m.MirroredElemTypes]
val tapirPickle = new TapirPickle[T] {
override lazy val reader: Reader[T] =
macroSumR[T](childReadWriters.map(_._1), subtypeDiscriminator)

override lazy val writer: Writer[T] =
macroSumW[T](childReadWriters.map(_._2), subtypeDiscriminator)
}
new Pickler[T](tapirPickle, schema)
}

private inline def buildEnumerationReadWriters[T: ClassTag, Cases <: Tuple]: List[(Types#Reader[_], Types#Writer[_])] =
inline erasedValue[Cases] match {
case _: (enumerationCase *: enumerationCasesTail) =>
val processedHead = readWriterForEnumerationCase[enumerationCase]
val processedTail = buildEnumerationReadWriters[T, enumerationCasesTail]
(processedHead +: processedTail)
case _: EmptyTuple.type => Nil
}

private inline def readWriterForEnumerationCase[C]: (Types#Reader[C], Types#Writer[C]) =
val pickle = new TapirPickle[C] {
// We probably don't need a separate TapirPickle for each C, this could be optimized.
// https://github.com/softwaremill/tapir/issues/3192
override lazy val writer = annotate[C](
SingletonWriter[C](null.asInstanceOf[C]),
upickleMacros.tagName[C],
Annotator.Checker.Val(upickleMacros.getSingleton[C])
)
override lazy val reader = annotate[C](SingletonReader[C](upickleMacros.getSingleton[C]), upickleMacros.tagName[C])
}
(pickle.reader, pickle.writer)

/** Creates the Pickler assuming the low-level representation is a `String`. The encoding function passes the object unchanged (which
* means `.toString` will be used to represent the enumeration in JSON and documentation). Typically you don't need to explicitly use
* `Pickler.derivedEnumeration[T].defaultStringBased`, as this is the default behavior of [[Pickler.derived]] for enums.
*/
inline def defaultStringBased(using Mirror.Of[T]) = apply()
inline def defaultStringBased(using Mirror.SumOf[T]) = apply()

/** Creates the Pickler assuming the low-level representation is a `String`. Provide your custom encoding function for representing an
* enum value as a String. It will be used to represent the enumeration in JSON and documentation. This approach is recommended if you
* need to encode enums using a common field in their base trait, or another specific logic for extracting string representation.
*/
inline def customStringBased(encode: T => String)(using Mirror.Of[T]): Pickler[T] =
inline def customStringBased(encode: T => String)(using Mirror.SumOf[T]): Pickler[T] =
apply(
Some(encode),
schemaType = SchemaType.SString[T](),
Expand Down
29 changes: 15 additions & 14 deletions json/pickler/src/main/scala/sttp/tapir/json/pickler/Pickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,20 @@ object Pickler:
)

private[pickler] inline def buildNewPickler[T: ClassTag]()(using m: Mirror.Of[T], c: Configuration): Pickler[T] =
// The lazy modifier is necessary for preventing infinite recursion in the derived instance for recursive types such as Lst
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes]
inline m match {
case p: Mirror.ProductOf[T] => picklerProduct(p, childPicklers)
case _: Mirror.SumOf[T] =>
val schema: Schema[T] =
inline if (isEnumeration[T])
Schema.derivedEnumeration[T].defaultStringBased
else
Schema.derived[T]
given SubtypeDiscriminator[T] = DefaultSubtypeDiscriminator[T]()
picklerSum(schema, childPicklers)
}
inline m match
case p: Mirror.ProductOf[T] =>
// The lazy modifier is necessary for preventing infinite recursion in the derived instance for recursive types such as Lst
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes]
picklerProduct(p, childPicklers)
case sum: Mirror.SumOf[T] =>
inline if (isEnumeration[T])
new CreateDerivedEnumerationPickler(Validator.derivedEnumeration[T], SchemaAnnotations.derived[T]).defaultStringBased(using sum)
else
val schema = Schema.derived[T]
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes]
given SubtypeDiscriminator[T] = DefaultSubtypeDiscriminator[T]()
picklerSum(schema, childPicklers)


private[pickler] inline def summonChildPicklerInstances[T: ClassTag, Fields <: Tuple](using
m: Mirror.Of[T],
Expand Down Expand Up @@ -372,7 +373,7 @@ object Pickler:
subtypeDiscriminator
)
override lazy val reader: Reader[T] =
macroSumR[T](childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.reader), subtypeDiscriminator)
macroSumR[T](childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.reader).productIterator.toList, subtypeDiscriminator)

}
new Pickler[T](tapirPickle, schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ private[pickler] trait Readers extends ReadersVersionSpecific with UpickleHelper
else if upickleMacros.isMemberOfSealedHierarchy[T] then annotate[T](reader, upickleMacros.tagName[T])
else reader

inline def macroSumR[T](derivedChildReaders: Tuple, subtypeDiscriminator: SubtypeDiscriminator[T]): Reader[T] =
inline def macroSumR[T](derivedChildReaders: List[Any], subtypeDiscriminator: SubtypeDiscriminator[T]): Reader[T] =
implicit val currentlyDeriving: _root_.upickle.core.CurrentlyDeriving[T] = new _root_.upickle.core.CurrentlyDeriving()
subtypeDiscriminator match {
case discriminator: CustomSubtypeDiscriminator[T] =>
Expand All @@ -70,13 +70,13 @@ private[pickler] trait Readers extends ReadersVersionSpecific with UpickleHelper
new TaggedReader.Node[T](readersFromMapping.asInstanceOf[Seq[TaggedReader[T]]]: _*)
case discriminator: EnumValueDiscriminator[T] =>
val readersForPossibleValues: Seq[TaggedReader[T]] =
discriminator.validator.possibleValues.zip(derivedChildReaders.toList).map { case (enumValue, reader) =>
discriminator.validator.possibleValues.zip(derivedChildReaders).map { case (enumValue, reader) =>
TaggedReader.Leaf[T](discriminator.encode(enumValue), reader.asInstanceOf[LeafWrapper[_]].r.asInstanceOf[Reader[T]])
}
new TaggedReader.Node[T](readersForPossibleValues: _*)

case _: DefaultSubtypeDiscriminator[T] =>
val readers = derivedChildReaders.toList.asInstanceOf[List[TaggedReader[T]]]
val readers = derivedChildReaders.asInstanceOf[List[TaggedReader[T]]]
Reader.merge(readers: _*)
}
}
11 changes: 11 additions & 0 deletions json/pickler/src/test/scala/sttp/tapir/json/pickler/Fixtures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,14 @@ object Fixtures:
}

case class StatusResponse(status: Status)

case class SealedVariantContainer(v: SealedVariant)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe also introduce a test case with enums (parameterless and not)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unless we already don't have one :) )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have such tests

  1. Parameterless enums
  2. With parameters
    it should "support sealed hierarchies looking like enums" in {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, sorry, I didn't unwrap enough context :) But maybe if they will be in separate files it will be easier to see what's already there

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(btw. open api docs tests are grouped in a similar way - each test has its own test classes in the companion object etc.)


sealed trait SealedVariant
case object VariantA extends SealedVariant
case object VariantB extends SealedVariant
case object VariantC extends SealedVariant

sealed trait NotAllSealedVariant
case object NotAllSealedVariantA extends NotAllSealedVariant
case class NotAllSealedVariantB(innerField: Int) extends NotAllSealedVariant
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sttp.tapir.json.pickler

import _root_.upickle.{default => udefault}
import magnolia1.SealedTrait
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import sttp.tapir.DecodeResult.Value
Expand Down Expand Up @@ -360,28 +361,6 @@ class PicklerTest extends AnyFlatSpec with Matchers {
decoded shouldBe Value(inputObject)
}

it should "work2" in {
sealed trait Entity {
def kind: String
}
case class Person(firstName: String, lastName: String) extends Entity {
def kind: String = "person"
}
case class Organization(name: String) extends Entity {
def kind: String = "org"
}

import sttp.tapir.*
import sttp.tapir.json.*

val pPerson = Pickler.derived[Person]
val pOrganization = Pickler.derived[Organization]
given pEntity: Pickler[Entity] =
Pickler.oneOfUsingField[Entity, String](_.kind, _.toString)("person" -> pPerson, "org" -> pOrganization)

// { "$type": "person", "firstName": "Jessica", "lastName": "West" }
pEntity.toCodec.encode(Person("Jessica", "West"))
}
it should "Set discriminator value using oneOfUsingField" in {
// given
val picklerOk = Pickler.derived[StatusOk]
Expand Down Expand Up @@ -523,6 +502,21 @@ class PicklerTest extends AnyFlatSpec with Matchers {
codec.decode(encoded) shouldBe Value(inputObj)
}

it should "handle sealed hierarchies consisting of objects only" in {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should somehow separate tests per-feature, here: for enumerations (direct, wrapped, customised)

// given
import generic.auto.* // for Pickler auto-derivation
val inputObj = SealedVariantContainer(VariantA)

// when
val pickler = Pickler.derived[SealedVariantContainer]
val codec = pickler.toCodec
val encoded = codec.encode(inputObj)

// then
encoded shouldBe """{"v":"VariantA"}"""

}

it should "handle value classes" in {
// when
val pickler = Pickler.derived[ClassWithValues]
Expand Down