
Commit

Apply changes from #3222
kciesielski committed Oct 3, 2023
1 parent 1228f5e commit 4a522b7
Showing 4 changed files with 63 additions and 44 deletions.
18 changes: 5 additions & 13 deletions core/src/main/scala-3/sttp/tapir/internal/EnumerationMacros.scala
@@ -4,22 +4,14 @@ import scala.quoted.*

private[tapir] object EnumerationMacros:

transparent inline def isEnumeration[T]: Boolean =
// allChildrenObjectsOrEnumerationCases[T]
inline compiletime.erasedValue[T] match
case _: Null => false
case _: Nothing => false
case _: reflect.Enum => allChildrenObjectsOrEnumerationCases[T]
case _ => false

/** Checks whether type T has child types, and all children of type T are objects or enum cases or sealed parents of such. Useful for
* determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which case it's not really an
* enumeration in context of schemas and JSON codecs.
/** Checks whether type T is an "enumeration", which means either a Scala 3 enum, or a sealed hierarchy where all members are case objects
* or enum cases. Useful for deriving schemas and JSON codecs.
*/
inline def allChildrenObjectsOrEnumerationCases[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] }
inline def isEnumeration[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] }

def allChildrenObjectsOrEnumerationCasesImpl[T: Type](using q: Quotes): Expr[Boolean] =
Expr(enumerationTypeChildren[T](failOnError = false).forall(_.isDefined))
val typeChildren = enumerationTypeChildren[T](failOnError = false)
Expr(typeChildren.nonEmpty && !typeChildren.exists(_.isEmpty))
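To make the intent of this check concrete, here is a small, hypothetical sketch (not part of the commit) of the kinds of types it is meant to distinguish:

// A Scala 3 enum with only parameterless cases - an "enumeration":
enum Color:
  case Red, Green, Blue
// isEnumeration[Color] is expected to be true

// A sealed hierarchy whose members are all case objects - also an "enumeration":
sealed trait Size
case object Small extends Size
case object Large extends Size
// isEnumeration[Size] is expected to be true

// An enum with a parameterized case desugars to a sealed hierarchy containing a case class,
// so it does not count as an enumeration for schemas and JSON codecs:
enum Shape:
  case Square
  case Circle(radius: Double)
// isEnumeration[Shape] is expected to be false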

/** Recursively scans a symbol and builds a list of all children and their children, as long as all of them are objects or enum cases or
* sealed parents of such. Useful for determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which
@@ -1,12 +1,16 @@
package sttp.tapir.json.pickler

import _root_.upickle.implicits.{macros => upickleMacros}
import sttp.tapir.generic.Configuration
import sttp.tapir.macros.CreateDerivedEnumerationSchema
import sttp.tapir.{Schema, SchemaAnnotations, SchemaType, Validator}
import upickle.core.{Annotator, Types}

import scala.deriving.Mirror
import scala.reflect.ClassTag

import compiletime.*

/** A builder allowing deriving Pickler for enums, used by [[Pickler.derivedEnumeration]]. Can be used to set non-standard encoding logic,
* schema type or default value for an enum.
*/
@@ -26,32 +30,62 @@ class CreateDerivedEnumerationPickler[T: ClassTag](
encode: Option[T => Any] = Some(identity),
schemaType: SchemaType[T] = SchemaType.SString[T](),
default: Option[T] = None
)(using m: Mirror.Of[T]): Pickler[T] = {
)(using m: Mirror.SumOf[T]): Pickler[T] = {
val schema: Schema[T] = new CreateDerivedEnumerationSchema(validator, schemaAnnotations).apply(
encode,
schemaType,
default
)
given Configuration = Configuration.default
given SubtypeDiscriminator[T] = EnumValueDiscriminator[T](
given subtypeDiscriminator: SubtypeDiscriminator[T] = EnumValueDiscriminator[T](
encode.map(_.andThen(_.toString)).getOrElse(_.toString),
validator
)
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = Pickler.summonChildPicklerInstances[T, m.MirroredElemTypes]
Pickler.picklerSum(schema, childPicklers)

lazy val childReadWriters = buildEnumerationReadWriters[T, m.MirroredElemTypes]
val tapirPickle = new TapirPickle[T] {
override lazy val reader: Reader[T] =
macroSumR[T](childReadWriters.map(_._1), subtypeDiscriminator)

override lazy val writer: Writer[T] =
macroSumW[T](childReadWriters.map(_._2), subtypeDiscriminator)
}
new Pickler[T](tapirPickle, schema)
}

private inline def buildEnumerationReadWriters[T: ClassTag, Cases <: Tuple]: List[(Types#Reader[_], Types#Writer[_])] =
inline erasedValue[Cases] match {
case _: (enumerationCase *: enumerationCasesTail) =>
val processedHead = readWriterForEnumerationCase[enumerationCase]
val processedTail = buildEnumerationReadWriters[T, enumerationCasesTail]
(processedHead +: processedTail)
case _: EmptyTuple.type => Nil
}
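buildEnumerationReadWriters follows the usual Scala 3 idiom of recursing over a tuple of types at compile time with inline matches on erasedValue. A minimal, self-contained sketch of that idiom, with invented names and no upickle involvement:

import scala.compiletime.{erasedValue, summonInline}
import scala.reflect.ClassTag

// Walk the element types of a tuple at compile time and produce one value per type,
// mirroring how buildEnumerationReadWriters walks m.MirroredElemTypes.
inline def namesOf[Elems <: Tuple]: List[String] =
  inline erasedValue[Elems] match
    case _: (head *: tail) =>
      summonInline[ClassTag[head]].runtimeClass.getSimpleName :: namesOf[tail]
    case _: EmptyTuple => Nil

// namesOf[(Int, String, Boolean)] should yield List("int", "String", "boolean")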

private inline def readWriterForEnumerationCase[C]: (Types#Reader[C], Types#Writer[C]) =
val pickle = new TapirPickle[C] {
// We probably don't need a separate TapirPickle for each C, this could be optimized.
// https://github.com/softwaremill/tapir/issues/3192
override lazy val writer = annotate[C](
SingletonWriter[C](null.asInstanceOf[C]),
upickleMacros.tagName[C],
Annotator.Checker.Val(upickleMacros.getSingleton[C])
)
override lazy val reader = annotate[C](SingletonReader[C](upickleMacros.getSingleton[C]), upickleMacros.tagName[C])
}
(pickle.reader, pickle.writer)

/** Creates the Pickler assuming the low-level representation is a `String`. The encoding function passes the object unchanged (which
* means `.toString` will be used to represent the enumeration in JSON and documentation). Typically you don't need to explicitly use
* `Pickler.derivedEnumeration[T].defaultStringBased`, as this is the default behavior of [[Pickler.derived]] for enums.
*/
inline def defaultStringBased(using Mirror.Of[T]) = apply()
inline def defaultStringBased(using Mirror.SumOf[T]) = apply()
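A usage sketch of the default string-based derivation (the Color enum is hypothetical, and the entry point is the Pickler.derivedEnumeration builder referenced in the class docs):

import sttp.tapir.json.pickler.Pickler

enum Color:
  case Red, Green, Blue

// Values are encoded via .toString, i.e. as "Red" / "Green" / "Blue" in JSON and docs.
val colorPickler: Pickler[Color] =
  Pickler.derivedEnumeration[Color].defaultStringBased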

/** Creates the Pickler assuming the low-level representation is a `String`. Provide your custom encoding function for representing an
* enum value as a String. It will be used to represent the enumeration in JSON and documentation. This approach is recommended if you
* need to encode enums using a common field in their base trait, or another specific logic for extracting string representation.
*/
inline def customStringBased(encode: T => String)(using Mirror.Of[T]): Pickler[T] =
inline def customStringBased(encode: T => String)(using Mirror.SumOf[T]): Pickler[T] =
apply(
Some(encode),
schemaType = SchemaType.SString[T](),
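A corresponding sketch for the custom-string variant (the Country enum is hypothetical), where each value carries its own string representation:

import sttp.tapir.json.pickler.Pickler

enum Country(val code: String):
  case PL extends Country("pl")
  case NL extends Country("nl")

// Values are encoded with the provided function instead of .toString,
// so Country.PL is represented as "pl" in JSON and documentation.
val countryPickler: Pickler[Country] =
  Pickler.derivedEnumeration[Country].customStringBased(_.code)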
37 changes: 15 additions & 22 deletions json/pickler/src/main/scala/sttp/tapir/json/pickler/Pickler.scala
@@ -202,8 +202,6 @@ object Pickler:

inline given picklerForAnyVal[T <: AnyVal]: Pickler[T] = ${ picklerForAnyValImpl[T] }

//

private inline def errorForType[T](inline template: String): Null = ${ errorForTypeImpl[T]('template) }

private def errorForTypeImpl[T: Type](template: Expr[String])(using Quotes): Expr[Null] = {
@@ -277,20 +275,20 @@ object Pickler:
)

private[pickler] inline def buildNewPickler[T: ClassTag]()(using m: Mirror.Of[T], c: Configuration): Pickler[T] =
// The lazy modifier is necessary for preventing infinite recursion in the derived instance for recursive types such as Lst
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes]
inline m match {
case p: Mirror.ProductOf[T] => picklerProduct(p, childPicklers)
case _: Mirror.SumOf[T] =>
val schema: Schema[T] =
inline if (isEnumeration[T]) Schema.derivedEnumeration[T].defaultStringBased
else
lazy val derivedChildSchemas: Tuple.Map[m.MirroredElemTypes, Schema] =
childPicklers.map([t] => (p: t) => p.asInstanceOf[Pickler[t]].schema).asInstanceOf[Tuple.Map[m.MirroredElemTypes, Schema]]
coproductSchema(derivedChildSchemas)
given SubtypeDiscriminator[T] = DefaultSubtypeDiscriminator[T]()
picklerSum(schema, childPicklers)
}
inline m match
case p: Mirror.ProductOf[T] =>
// The lazy modifier is necessary for preventing infinite recursion in the derived instance for recursive types such as Lst
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes]
picklerProduct(p, childPicklers)
case sum: Mirror.SumOf[T] =>
inline if (isEnumeration[T])
new CreateDerivedEnumerationPickler(Validator.derivedEnumeration[T], SchemaAnnotations.derived[T]).defaultStringBased(using sum)
else
val schema = Schema.derived[T]
lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes]
given SubtypeDiscriminator[T] = DefaultSubtypeDiscriminator[T]()
picklerSum(schema, childPicklers)
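The net effect of the change above: for sum types, derivation branches on isEnumeration and hands "simple" enums to CreateDerivedEnumerationPickler, while other sealed hierarchies keep the discriminator-based sum encoding. A hedged sketch of the observable behaviour (type names invented; the explicit Configuration is only there in case no default given is in scope):

import sttp.tapir.generic.Configuration
import sttp.tapir.json.pickler.Pickler

given Configuration = Configuration.default

enum Status:
  case Active, Deleted

sealed trait Event
case class Created(id: Int) extends Event
case class Removed(id: Int) extends Event

// isEnumeration[Status] holds, so the enumeration pickler is used:
// Status values are written as plain JSON strings ("Active" / "Deleted").
val statusPickler: Pickler[Status] = Pickler.derived

// Event has case-class children, so the sealed-hierarchy path is taken:
// each child is written as a JSON object with a discriminator/tag field added.
val eventPickler: Pickler[Event] = Pickler.derived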


private[pickler] inline def summonChildPicklerInstances[T: ClassTag, Fields <: Tuple](using
m: Mirror.Of[T],
@@ -360,11 +358,6 @@
): Schema[T] =
SchemaDerivation.productSchema(genericDerivationConfig, childSchemas)

private inline def coproductSchema[T, TFields <: Tuple](childSchemas: Tuple.Map[TFields, Schema])(using
genericDerivationConfig: Configuration
): Schema[T] =
SchemaDerivation.coproductSchema(genericDerivationConfig, childSchemas)

private[tapir] inline def picklerSum[T: ClassTag, CP <: Tuple](schema: Schema[T], childPicklers: => CP)(using
m: Mirror.Of[T],
config: Configuration,
@@ -378,7 +371,7 @@
subtypeDiscriminator
)
override lazy val reader: Reader[T] =
macroSumR[T](childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.reader), subtypeDiscriminator)
macroSumR[T](childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.reader).productIterator.toList, subtypeDiscriminator)

}
new Pickler[T](tapirPickle, schema)
@@ -50,7 +50,7 @@ private[pickler] trait Readers extends ReadersVersionSpecific with UpickleHelper
else if upickleMacros.isMemberOfSealedHierarchy[T] then annotate[T](reader, upickleMacros.tagName[T])
else reader

inline def macroSumR[T](derivedChildReaders: Tuple, subtypeDiscriminator: SubtypeDiscriminator[T]): Reader[T] =
inline def macroSumR[T](derivedChildReaders: List[Any], subtypeDiscriminator: SubtypeDiscriminator[T]): Reader[T] =
implicit val currentlyDeriving: _root_.upickle.core.CurrentlyDeriving[T] = new _root_.upickle.core.CurrentlyDeriving()
subtypeDiscriminator match {
case discriminator: CustomSubtypeDiscriminator[T] =>
@@ -70,13 +70,13 @@ private[pickler] trait Readers extends ReadersVersionSpecific with UpickleHelper
new TaggedReader.Node[T](readersFromMapping.asInstanceOf[Seq[TaggedReader[T]]]: _*)
case discriminator: EnumValueDiscriminator[T] =>
val readersForPossibleValues: Seq[TaggedReader[T]] =
discriminator.validator.possibleValues.zip(derivedChildReaders.toList).map { case (enumValue, reader) =>
discriminator.validator.possibleValues.zip(derivedChildReaders).map { case (enumValue, reader) =>
TaggedReader.Leaf[T](discriminator.encode(enumValue), reader.asInstanceOf[LeafWrapper[_]].r.asInstanceOf[Reader[T]])
}
new TaggedReader.Node[T](readersForPossibleValues: _*)

case _: DefaultSubtypeDiscriminator[T] =>
val readers = derivedChildReaders.toList.asInstanceOf[List[TaggedReader[T]]]
val readers = derivedChildReaders.asInstanceOf[List[TaggedReader[T]]]
Reader.merge(readers: _*)
}
}
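A conceptual, self-contained sketch of the EnumValueDiscriminator branch above: the encoded possible values are zipped positionally with the per-case readers (both lists are assumed to follow the declaration order of the cases), and each pair becomes a tagged leaf. Plain strings stand in for the actual Reader instances:

// Placeholder data illustrating the zip inside macroSumR's EnumValueDiscriminator branch.
val encodedValues = List("Red", "Green", "Blue")                     // discriminator.encode over possibleValues
val childReaders  = List("readerRed", "readerGreen", "readerBlue")   // stands in for derivedChildReaders
val tagged = encodedValues.zip(childReaders)
// tagged == List(("Red", "readerRed"), ("Green", "readerGreen"), ("Blue", "readerBlue"))
// macroSumR wraps each pair in a TaggedReader.Leaf and combines them into a TaggedReader.Node.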
