Rework Pickler for coproducts and enums #3222

Merged
merged 12 commits into from
Oct 6, 2023
24 changes: 9 additions & 15 deletions core/src/main/scala-3/sttp/tapir/internal/EnumerationMacros.scala
@@ -4,41 +4,35 @@ import scala.quoted.*

private[tapir] object EnumerationMacros:

transparent inline def isEnumeration[T]: Boolean = inline compiletime.erasedValue[T] match
case _: Null => false
case _: Nothing => false
case _: reflect.Enum => allChildrenObjectsOrEnumerationCases[T]
case _ => false

/** Checks whether type T has child types, and all children of type T are objects or enum cases or sealed parents of such. Useful for
* determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which case it's not really an
* enumeration in context of schemas and JSON codecs.
/** Checks whether type T is an "enumeration", which means either a Scala 3 enum, or a sealed hierarchy where all members are case objects
* or enum cases. Useful for deriving schemas and JSON codecs.
*/
inline def allChildrenObjectsOrEnumerationCases[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] }
inline def isEnumeration[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] }

def allChildrenObjectsOrEnumerationCasesImpl[T: Type](using q: Quotes): Expr[Boolean] =
Expr(enumerationTypeChildren[T](failOnError = false).nonEmpty)
val typeChildren = enumerationTypeChildren[T](failOnError = false)
Expr(typeChildren.nonEmpty && !typeChildren.exists(_.isEmpty))

/** Recursively scans a symbol and builds a list of all children and their children, as long as all of them are objects or enum cases or
* sealed parents of such. Useful for determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which
* case it's not really an enumeration (in context of schemas and JSON codecs).
*/
def enumerationTypeChildren[T: Type](failOnError: Boolean)(using q: Quotes): List[q.reflect.Symbol] =
def enumerationTypeChildren[T: Type](failOnError: Boolean)(using q: Quotes): List[Option[q.reflect.Symbol]] =
import quotes.reflect.*

val tpe = TypeRepr.of[T]
val symbol = tpe.typeSymbol

def flatChildren(s: Symbol): List[Symbol] = s.children.toList.flatMap { c =>
def flatChildren(s: Symbol): List[Option[Symbol]] = s.children.toList.flatMap { c =>
if (c.isClassDef) {
if (!(c.flags is Flags.Sealed))
if (failOnError)
report.errorAndAbort("All children must be objects or enum cases, or sealed parent of such.")
else
Nil
List(None)
else
flatChildren(c)
} else List(c)
} else List(Some(c))
}

flatChildren(symbol)
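
The effect of switching from `List[Symbol]` to `List[Option[Symbol]]` can be illustrated with a small, dependency-free model. This is only a sketch under stated assumptions: `Node`, `flatChildrenOld`, and the standalone `flatChildren` / `isEnumeration` below are hypothetical stand-ins for the `q.reflect.Symbol`-based macro code and are not part of tapir.

```scala
// Hypothetical, simplified stand-in for q.reflect.Symbol (not tapir code).
enum Node:
  case Obj(name: String)                  // case object or parameterless enum case
  case SealedParent(children: List[Node]) // sealed trait/class with its own children
  case NonSealedClass(name: String)       // e.g. a case class with parameters

import Node.*

// Old behaviour: a disqualifying child is silently dropped (Nil).
def flatChildrenOld(nodes: List[Node]): List[String] =
  nodes.flatMap {
    case Obj(name)         => List(name)
    case SealedParent(cs)  => flatChildrenOld(cs)
    case NonSealedClass(_) => Nil
  }

// New behaviour: a disqualifying child is recorded as None.
def flatChildren(nodes: List[Node]): List[Option[String]] =
  nodes.flatMap {
    case Obj(name)         => List(Some(name))
    case SealedParent(cs)  => flatChildren(cs)
    case NonSealedClass(_) => List(None)
  }

// Mirrors the new check: at least one child, and no disqualifying child.
def isEnumeration(nodes: List[Node]): Boolean =
  val children = flatChildren(nodes)
  children.nonEmpty && !children.exists(_.isEmpty)

val allObjects = List(Obj("Green"), SealedParent(List(Obj("Pink"))))
val mixed      = List(Obj("Green"), NonSealedClass("Pink"))
```

In this model, `flatChildrenOld(mixed)` is non-empty, so a `nonEmpty`-only check would accept a hierarchy that mixes objects with parameterized classes, while `isEnumeration(mixed)` rejects it. The `flatMap(_.toList)` at the `failOnError = true` call site simply unwraps the `Some` values again.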
@@ -26,7 +26,7 @@ private[tapir] object ValidatorMacros {
report.errorAndAbort("Can only enumerate values of a sealed trait, class or enum.")
}

val instances = enumerationTypeChildren[T](failOnError = true).distinct
val instances = enumerationTypeChildren[T](failOnError = true).flatMap(_.toList).distinct
.sortBy(_.name)
.map(x =>
tpe.memberType(x).asType match {
67 changes: 57 additions & 10 deletions doc/endpoint/pickler.md
@@ -79,25 +79,42 @@ However, this can negatively impact compilation performance, as the same pickler

## Configuring pickler derivation

It is possible to configure schema and codec derivation by providing an implicit `sttp.tapir.generic.Configuration`, just as for standalone [schema derivation](schemas.md). This configuration allows switching field naming policy to `snake_case`, `kebab_case`, or an arbitrary transformation function, as well as setting the field name for the coproduct (sealed hierarchy) type discriminator, which is discussed in details in further sections.
It is possible to configure schema and codec derivation by providing an implicit `sttp.tapir.pickler.PicklerConfiguration`. This configuration allows switching field naming policy to `snake_case`, `kebab_case`, or an arbitrary transformation function, as well as setting the field name/value for the coproduct (sealed hierarchy) type discriminator, which is discussed in detail in further sections.

```scala
import sttp.tapir.generic.Configuration
import sttp.tapir.pickler.PicklerConfiguration

given customConfiguration: Configuration = Configuration.default.withSnakeCaseMemberNames
given customConfiguration: PicklerConfiguration =
PicklerConfiguration
.default
.withSnakeCaseMemberNames
```

## Enums / sealed traits / coproducts

Pickler derivation for coproduct types (enums / sealed hierarchies) works automatically, by adding an `$type` discriminator field with the full class name. This is the default behavior of uPickle, but it can be overridden either by changing the discriminator field name, or by using custom logic to get field value from base trait.
Pickler derivation for coproduct types (enums with parameters / sealed hierarchies) works automatically, by adding a `$type` discriminator field with the short class name.

```scala
import sttp.tapir.pickler.PicklerConfiguration

// encodes a case object as { "$type": "MyType" }
given PicklerConfiguration = PicklerConfiguration.default
```

This behavior can be overridden either by changing the discriminator field name, or by using custom logic to derive the field value from the base trait.
Sealed hierarchies in which all cases are objects are treated differently: they are considered [enumerations](#enumerations).

A discriminator field can be specified for coproducts by providing it in the configuration; this will be only used during automatic and semi-automatic derivation:

```scala
import sttp.tapir.generic.Configuration

given customConfiguration: Configuration =
Configuration.default.withDiscriminator("who_am_i")
import sttp.tapir.pickler.PicklerConfiguration

// encodes a case object as { "who_am_i": "full.pkg.path.MyType" }
given customConfiguration: PicklerConfiguration =
PicklerConfiguration
.default
.withDiscriminator("who_am_i")
.withFullDiscriminatorValues
```

The discriminator will be added as a field to all coproduct child codecs and schemas, if it’s not yet present. The schema of the added field will always be a Schema.string. Finally, the mapping between the discriminator field values and the child schemas will be generated using `Configuration.toDiscriminatorValue(childSchemaName)`.
@@ -136,14 +153,23 @@ Schemas generated by picklers can be customized using annotations, just like wit

## Enumerations

Scala 3 `enums`, where all cases are parameterless, are treated as an enumeration (not as a coproduct / sealed hierarchy). They are also automatically handled by `Pickler.derived[T]`: enum values are encoded as simple strings representing the type name. For example:
Tapir schemas and JSON codecs treat the following cases as "enumerations":
1. Scala 3 `enums`, where all cases are parameterless
2. Sealed hierarchies (coproducts), where all cases are case objects

Such types are handled by `Pickler.derived[T]`: possible values are encoded as simple strings representing the case objects. For example:

```scala
import sttp.tapir.json.pickler.*

enum ColorEnum:
case Green, Pink

// or:
// sealed trait ColorEnum
// case object Green extends ColorEnum
// case object Pink extends ColorEnum

case class ColorResponse(color: ColorEnum, description: String)

given Pickler[ColorEnum] = Pickler.derived
@@ -157,7 +183,27 @@ pResponse.toCodec.encode(
pResponse.schema
```

If you need to customize enum value encoding, use `Pickler.derivedEnumeration[T]`:
If a sealed hierarchy or enum contains case classes with parameters, it is no longer an "enumeration" and will be treated as a standard sealed hierarchy (coproduct):

```scala
import sttp.tapir.json.pickler.*

sealed trait ColorEnum
case object Green extends ColorEnum
case class Pink(intensity: Int) extends ColorEnum

case class ColorResponse(color1: ColorEnum, color2: ColorEnum)

given Pickler[ColorEnum] = Pickler.derived
val pResponse = Pickler.derived[ColorResponse]

// {"color1":{"$type":"Pink","intensity":85},"color2":{"$type":"Green"}}
pResponse.toCodec.encode(
ColorResponse(Pink(85), Green)
)
```

If you need to customize enumeration value encoding, use `Pickler.derivedEnumeration[T]`:

```scala
import sttp.tapir.json.pickler.*
@@ -191,4 +237,5 @@ you can proceed with `Pickler.derived[T]`.

* Tapir pickler serialises None values as `null`, instead of wrapping the value in an array
* Value classes (case classes extending AnyVal) will be serialised as simple values
* Discriminator field value is a short class name, instead of full package with class name

@@ -1,12 +1,15 @@
package sttp.tapir.json.pickler

import sttp.tapir.generic.Configuration
import _root_.upickle.implicits.{macros => upickleMacros}
import sttp.tapir.macros.CreateDerivedEnumerationSchema
import sttp.tapir.{Schema, SchemaAnnotations, SchemaType, Validator}
import upickle.core.{Annotator, Types}

import scala.deriving.Mirror
import scala.reflect.ClassTag

import compiletime.*

/** A builder allowing deriving Pickler for enums, used by [[Pickler.derivedEnumeration]]. Can be used to set non-standard encoding logic,
* schema type or default value for an enum.
*/
@@ -23,37 +26,77 @@ class CreateDerivedEnumerationPickler[T: ClassTag](
* The low-level representation of the enumeration. Defaults to a string.
*/
inline def apply(
encode: Option[T => Any] = Some(identity),
encode: T => Any = identity,
schemaType: SchemaType[T] = SchemaType.SString[T](),
default: Option[T] = None
)(using m: Mirror.Of[T]): Pickler[T] = {
)(using m: Mirror.SumOf[T]): Pickler[T] = {
val schema: Schema[T] = new CreateDerivedEnumerationSchema(validator, schemaAnnotations).apply(
encode,
Some(encode),
schemaType,
default
)
given Configuration = Configuration.default
given SubtypeDiscriminator[T] = EnumValueDiscriminator[T](
encode.map(_.andThen(_.toString)).getOrElse(_.toString),
validator
)
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = Pickler.summonChildPicklerInstances[T, m.MirroredElemTypes]
Pickler.picklerSum(schema, childPicklers)
lazy val childReadWriters = buildEnumerationReadWriters[T, m.MirroredElemTypes]
val tapirPickle = new TapirPickle[T] {
override lazy val reader: Reader[T] = {
val readersForPossibleValues: Seq[TaggedReader[T]] =
validator.possibleValues.zip(childReadWriters.map(_._1)).map { case (enumValue, reader) =>
TaggedReader.Leaf[T](encode(enumValue).toString, reader.asInstanceOf[LeafWrapper[_]].r.asInstanceOf[Reader[T]])
}
new TaggedReader.Node[T](readersForPossibleValues: _*)
}

override lazy val writer: Writer[T] =
new TaggedWriter.Node[T](childReadWriters.map(_._2.asInstanceOf[TaggedWriter[T]]): _*) {
override def findWriter(v: Any): (String, ObjectWriter[T]) =
val (t, writer) = super.findWriter(v)
// Here our custom encoding transforms the value of a singleton object
val overriddenTag = encode(v.asInstanceOf[T]).toString
(overriddenTag, writer)
}
}
new Pickler[T](tapirPickle, schema)
}

private inline def buildEnumerationReadWriters[T: ClassTag, Cases <: Tuple]: List[(Types#Reader[_], Types#Writer[_])] =
inline erasedValue[Cases] match {
case _: (enumerationCase *: enumerationCasesTail) =>
val processedHead = readWriterForEnumerationCase[enumerationCase]
val processedTail = buildEnumerationReadWriters[T, enumerationCasesTail]
(processedHead +: processedTail)
case _: EmptyTuple.type => Nil
}

/** Enumeration cases and case objects in an enumeration need special writers and readers, which are generated here, instead of being
* taken from child picklers. For example, for enum Color and case values Red and Blue, a Writer should just use the object Red or Blue
* and serialize it to "Red" or "Blue". If the user needs to encode the singleton object using a custom function, this happens on a higher
* level - the top level of coproduct reader and writer.
*/
private inline def readWriterForEnumerationCase[C]: (Types#Reader[C], Types#Writer[C]) =
val pickle = new TapirPickle[C] {
// We probably don't need a separate TapirPickle for each C, this could be optimized.
// https://github.com/softwaremill/tapir/issues/3192
override lazy val writer = annotate[C](
SingletonWriter[C](null.asInstanceOf[C]),
upickleMacros.tagName[C],
Annotator.Checker.Val(upickleMacros.getSingleton[C])
)
override lazy val reader = annotate[C](SingletonReader[C](upickleMacros.getSingleton[C]), upickleMacros.tagName[C])
}
(pickle.reader, pickle.writer)

/** Creates the Pickler assuming the low-level representation is a `String`. The encoding function passes the object unchanged (which
* means `.toString` will be used to represent the enumeration in JSON and documentation). Typically you don't need to explicitly use
* `Pickler.derivedEnumeration[T].defaultStringBased`, as this is the default behavior of [[Pickler.derived]] for enums.
*/
inline def defaultStringBased(using Mirror.Of[T]) = apply()
inline def defaultStringBased(using Mirror.SumOf[T]) = apply()

/** Creates the Pickler assuming the low-level representation is a `String`. Provide your custom encoding function for representing an
* enum value as a String. It will be used to represent the enumeration in JSON and documentation. This approach is recommended if you
* need to encode enums using a common field in their base trait, or another specific logic for extracting string representation.
*/
inline def customStringBased(encode: T => String)(using Mirror.Of[T]): Pickler[T] =
inline def customStringBased(encode: T => String)(using Mirror.SumOf[T]): Pickler[T] =
apply(
Some(encode),
encode,
schemaType = SchemaType.SString[T](),
default = None
)
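
How a single `encode` function drives both directions in `apply` can be sketched without uPickle. The model below is an illustration under stated assumptions: `EnumCodec` and `enumCodec` are hypothetical names, and plain functions stand in for the `TaggedReader` / `TaggedWriter` nodes built in the diff above.

```scala
enum Color:
  case Green, Pink

// Hypothetical stand-in for the derived pickler: a pair of plain functions.
final case class EnumCodec[T](write: T => String, read: String => Option[T])

def enumCodec[T](possibleValues: List[T], encode: T => Any): EnumCodec[T] =
  // Reader side: one entry per possible value, keyed by the encoded tag
  // (analogous to building a leaf reader per validator.possibleValues entry).
  val byTag: Map[String, T] =
    possibleValues.map(v => encode(v).toString -> v).toMap
  // Writer side: the tag is recomputed from the value with the same function
  // (analogous to the overridden findWriter).
  EnumCodec(v => encode(v).toString, tag => byTag.get(tag))

// Default: the value itself is passed through, so `.toString` gives the tag.
val defaultCodec = enumCodec[Color](Color.values.toList, c => c)

// Custom: encode by ordinal, similar in spirit to customStringBased.
val ordinalCodec = enumCodec[Color](Color.values.toList, _.ordinal)
```

Because the reader's keys and the writer's tags come from the same function, a value written by `ordinalCodec` can be read back by it, while a tag produced under a different encoding is not recognized.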