From 4a522b767a7282b1aee04ad2c2ab0b700311f096 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 3 Oct 2023 13:38:33 +0200 Subject: [PATCH] Apply changes from #3222 --- .../tapir/internal/EnumerationMacros.scala | 18 ++------ .../CreateDerivedEnumerationPickler.scala | 46 ++++++++++++++++--- .../sttp/tapir/json/pickler/Pickler.scala | 37 ++++++--------- .../sttp/tapir/json/pickler/Readers.scala | 6 +-- 4 files changed, 63 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala-3/sttp/tapir/internal/EnumerationMacros.scala b/core/src/main/scala-3/sttp/tapir/internal/EnumerationMacros.scala index 6efa99bcc5..63dd51eea9 100644 --- a/core/src/main/scala-3/sttp/tapir/internal/EnumerationMacros.scala +++ b/core/src/main/scala-3/sttp/tapir/internal/EnumerationMacros.scala @@ -4,22 +4,14 @@ import scala.quoted.* private[tapir] object EnumerationMacros: - transparent inline def isEnumeration[T]: Boolean = - // allChildrenObjectsOrEnumerationCases[T] - inline compiletime.erasedValue[T] match - case _: Null => false - case _: Nothing => false - case _: reflect.Enum => allChildrenObjectsOrEnumerationCases[T] - case _ => false - - /** Checks whether type T has child types, and all children of type T are objects or enum cases or sealed parents of such. Useful for - * determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which case it's not really an - * enumeration in context of schemas and JSON codecs. + /** Checks whether type T is an "enumeration", which means either Scalas 3 enum, or a sealed hierarchy where all members are case objects + * or enum cases. Useful for deriving schemas and JSON codecs. */ - inline def allChildrenObjectsOrEnumerationCases[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] } + inline def isEnumeration[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] } def allChildrenObjectsOrEnumerationCasesImpl[T: Type](using q: Quotes): Expr[Boolean] = - Expr(enumerationTypeChildren[T](failOnError = false).forall(_.isDefined)) + val typeChildren = enumerationTypeChildren[T](failOnError = false) + Expr(typeChildren.nonEmpty && !typeChildren.exists(_.isEmpty)) /** Recursively scans a symbol and builds a list of all children and their children, as long as all of them are objects or enum cases or * sealed parents of such. Useful for determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which diff --git a/json/pickler/src/main/scala/sttp/tapir/json/pickler/CreateDerivedEnumerationPickler.scala b/json/pickler/src/main/scala/sttp/tapir/json/pickler/CreateDerivedEnumerationPickler.scala index cbb75c4bb3..e4d5896e0a 100644 --- a/json/pickler/src/main/scala/sttp/tapir/json/pickler/CreateDerivedEnumerationPickler.scala +++ b/json/pickler/src/main/scala/sttp/tapir/json/pickler/CreateDerivedEnumerationPickler.scala @@ -1,12 +1,16 @@ package sttp.tapir.json.pickler +import _root_.upickle.implicits.{macros => upickleMacros} import sttp.tapir.generic.Configuration import sttp.tapir.macros.CreateDerivedEnumerationSchema import sttp.tapir.{Schema, SchemaAnnotations, SchemaType, Validator} +import upickle.core.{Annotator, Types} import scala.deriving.Mirror import scala.reflect.ClassTag +import compiletime.* + /** A builder allowing deriving Pickler for enums, used by [[Pickler.derivedEnumeration]]. Can be used to set non-standard encoding logic, * schema type or default value for an enum. */ @@ -26,32 +30,62 @@ class CreateDerivedEnumerationPickler[T: ClassTag]( encode: Option[T => Any] = Some(identity), schemaType: SchemaType[T] = SchemaType.SString[T](), default: Option[T] = None - )(using m: Mirror.Of[T]): Pickler[T] = { + )(using m: Mirror.SumOf[T]): Pickler[T] = { val schema: Schema[T] = new CreateDerivedEnumerationSchema(validator, schemaAnnotations).apply( encode, schemaType, default ) given Configuration = Configuration.default - given SubtypeDiscriminator[T] = EnumValueDiscriminator[T]( + given subtypeDiscriminator: SubtypeDiscriminator[T] = EnumValueDiscriminator[T]( encode.map(_.andThen(_.toString)).getOrElse(_.toString), validator ) - lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = Pickler.summonChildPicklerInstances[T, m.MirroredElemTypes] - Pickler.picklerSum(schema, childPicklers) + + lazy val childReadWriters = buildEnumerationReadWriters[T, m.MirroredElemTypes] + val tapirPickle = new TapirPickle[T] { + override lazy val reader: Reader[T] = + macroSumR[T](childReadWriters.map(_._1), subtypeDiscriminator) + + override lazy val writer: Writer[T] = + macroSumW[T](childReadWriters.map(_._2), subtypeDiscriminator) + } + new Pickler[T](tapirPickle, schema) } + private inline def buildEnumerationReadWriters[T: ClassTag, Cases <: Tuple]: List[(Types#Reader[_], Types#Writer[_])] = + inline erasedValue[Cases] match { + case _: (enumerationCase *: enumerationCasesTail) => + val processedHead = readWriterForEnumerationCase[enumerationCase] + val processedTail = buildEnumerationReadWriters[T, enumerationCasesTail] + (processedHead +: processedTail) + case _: EmptyTuple.type => Nil + } + + private inline def readWriterForEnumerationCase[C]: (Types#Reader[C], Types#Writer[C]) = + val pickle = new TapirPickle[C] { + // We probably don't need a separate TapirPickle for each C, this could be optimized. + // https://github.com/softwaremill/tapir/issues/3192 + override lazy val writer = annotate[C]( + SingletonWriter[C](null.asInstanceOf[C]), + upickleMacros.tagName[C], + Annotator.Checker.Val(upickleMacros.getSingleton[C]) + ) + override lazy val reader = annotate[C](SingletonReader[C](upickleMacros.getSingleton[C]), upickleMacros.tagName[C]) + } + (pickle.reader, pickle.writer) + /** Creates the Pickler assuming the low-level representation is a `String`. The encoding function passes the object unchanged (which * means `.toString` will be used to represent the enumeration in JSON and documentation). Typically you don't need to explicitly use * `Pickler.derivedEnumeration[T].defaultStringBased`, as this is the default behavior of [[Pickler.derived]] for enums. */ - inline def defaultStringBased(using Mirror.Of[T]) = apply() + inline def defaultStringBased(using Mirror.SumOf[T]) = apply() /** Creates the Pickler assuming the low-level representation is a `String`. Provide your custom encoding function for representing an * enum value as a String. It will be used to represent the enumeration in JSON and documentation. This approach is recommended if you * need to encode enums using a common field in their base trait, or another specific logic for extracting string representation. */ - inline def customStringBased(encode: T => String)(using Mirror.Of[T]): Pickler[T] = + inline def customStringBased(encode: T => String)(using Mirror.SumOf[T]): Pickler[T] = apply( Some(encode), schemaType = SchemaType.SString[T](), diff --git a/json/pickler/src/main/scala/sttp/tapir/json/pickler/Pickler.scala b/json/pickler/src/main/scala/sttp/tapir/json/pickler/Pickler.scala index 822715587e..0f8bfbf7ae 100644 --- a/json/pickler/src/main/scala/sttp/tapir/json/pickler/Pickler.scala +++ b/json/pickler/src/main/scala/sttp/tapir/json/pickler/Pickler.scala @@ -202,8 +202,6 @@ object Pickler: inline given picklerForAnyVal[T <: AnyVal]: Pickler[T] = ${ picklerForAnyValImpl[T] } - // - private inline def errorForType[T](inline template: String): Null = ${ errorForTypeImpl[T]('template) } private def errorForTypeImpl[T: Type](template: Expr[String])(using Quotes): Expr[Null] = { @@ -277,20 +275,20 @@ object Pickler: ) private[pickler] inline def buildNewPickler[T: ClassTag]()(using m: Mirror.Of[T], c: Configuration): Pickler[T] = - // The lazy modifier is necessary for preventing infinite recursion in the derived instance for recursive types such as Lst - lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes] - inline m match { - case p: Mirror.ProductOf[T] => picklerProduct(p, childPicklers) - case _: Mirror.SumOf[T] => - val schema: Schema[T] = - inline if (isEnumeration[T]) Schema.derivedEnumeration[T].defaultStringBased - else - lazy val derivedChildSchemas: Tuple.Map[m.MirroredElemTypes, Schema] = - childPicklers.map([t] => (p: t) => p.asInstanceOf[Pickler[t]].schema).asInstanceOf[Tuple.Map[m.MirroredElemTypes, Schema]] - coproductSchema(derivedChildSchemas) - given SubtypeDiscriminator[T] = DefaultSubtypeDiscriminator[T]() - picklerSum(schema, childPicklers) - } + inline m match + case p: Mirror.ProductOf[T] => + // The lazy modifier is necessary for preventing infinite recursion in the derived instance for recursive types such as Lst + lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes] + picklerProduct(p, childPicklers) + case sum: Mirror.SumOf[T] => + inline if (isEnumeration[T]) + new CreateDerivedEnumerationPickler(Validator.derivedEnumeration[T], SchemaAnnotations.derived[T]).defaultStringBased(using sum) + else + val schema = Schema.derived[T] + lazy val childPicklers: Tuple.Map[m.MirroredElemTypes, Pickler] = summonChildPicklerInstances[T, m.MirroredElemTypes] + given SubtypeDiscriminator[T] = DefaultSubtypeDiscriminator[T]() + picklerSum(schema, childPicklers) + private[pickler] inline def summonChildPicklerInstances[T: ClassTag, Fields <: Tuple](using m: Mirror.Of[T], @@ -360,11 +358,6 @@ object Pickler: ): Schema[T] = SchemaDerivation.productSchema(genericDerivationConfig, childSchemas) - private inline def coproductSchema[T, TFields <: Tuple](childSchemas: Tuple.Map[TFields, Schema])(using - genericDerivationConfig: Configuration - ): Schema[T] = - SchemaDerivation.coproductSchema(genericDerivationConfig, childSchemas) - private[tapir] inline def picklerSum[T: ClassTag, CP <: Tuple](schema: Schema[T], childPicklers: => CP)(using m: Mirror.Of[T], config: Configuration, @@ -378,7 +371,7 @@ object Pickler: subtypeDiscriminator ) override lazy val reader: Reader[T] = - macroSumR[T](childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.reader), subtypeDiscriminator) + macroSumR[T](childPicklers.map([a] => (obj: a) => obj.asInstanceOf[Pickler[a]].innerUpickle.reader).productIterator.toList, subtypeDiscriminator) } new Pickler[T](tapirPickle, schema) diff --git a/json/pickler/src/main/scala/sttp/tapir/json/pickler/Readers.scala b/json/pickler/src/main/scala/sttp/tapir/json/pickler/Readers.scala index 9ba80ef6bd..1dd2816937 100644 --- a/json/pickler/src/main/scala/sttp/tapir/json/pickler/Readers.scala +++ b/json/pickler/src/main/scala/sttp/tapir/json/pickler/Readers.scala @@ -50,7 +50,7 @@ private[pickler] trait Readers extends ReadersVersionSpecific with UpickleHelper else if upickleMacros.isMemberOfSealedHierarchy[T] then annotate[T](reader, upickleMacros.tagName[T]) else reader - inline def macroSumR[T](derivedChildReaders: Tuple, subtypeDiscriminator: SubtypeDiscriminator[T]): Reader[T] = + inline def macroSumR[T](derivedChildReaders: List[Any], subtypeDiscriminator: SubtypeDiscriminator[T]): Reader[T] = implicit val currentlyDeriving: _root_.upickle.core.CurrentlyDeriving[T] = new _root_.upickle.core.CurrentlyDeriving() subtypeDiscriminator match { case discriminator: CustomSubtypeDiscriminator[T] => @@ -70,13 +70,13 @@ private[pickler] trait Readers extends ReadersVersionSpecific with UpickleHelper new TaggedReader.Node[T](readersFromMapping.asInstanceOf[Seq[TaggedReader[T]]]: _*) case discriminator: EnumValueDiscriminator[T] => val readersForPossibleValues: Seq[TaggedReader[T]] = - discriminator.validator.possibleValues.zip(derivedChildReaders.toList).map { case (enumValue, reader) => + discriminator.validator.possibleValues.zip(derivedChildReaders).map { case (enumValue, reader) => TaggedReader.Leaf[T](discriminator.encode(enumValue), reader.asInstanceOf[LeafWrapper[_]].r.asInstanceOf[Reader[T]]) } new TaggedReader.Node[T](readersForPossibleValues: _*) case _: DefaultSubtypeDiscriminator[T] => - val readers = derivedChildReaders.toList.asInstanceOf[List[TaggedReader[T]]] + val readers = derivedChildReaders.asInstanceOf[List[TaggedReader[T]]] Reader.merge(readers: _*) } }