Skip to content

Commit

Permalink
Merge pull request #3208 from softwaremill/is-scala-enum
Browse files Browse the repository at this point in the history
Refactor reused enum-based macros
  • Loading branch information
adamw authored Sep 28, 2023
2 parents ad396fc + 8d6575d commit 7174c40
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 43 deletions.
44 changes: 44 additions & 0 deletions core/src/main/scala-3/sttp/tapir/internal/EnumerationMacros.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package sttp.tapir.internal

import scala.quoted.*

private[tapir] object EnumerationMacros:

transparent inline def isEnumeration[T]: Boolean = inline compiletime.erasedValue[T] match
case _: Null => false
case _: Nothing => false
case _: reflect.Enum => allChildrenObjectsOrEnumerationCases[T]
case _ => false

/** Checks whether type T has child types, and all children of type T are objects or enum cases or sealed parents of such. Useful for
* determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which case it's not really an
* enumeration in context of schemas and JSON codecs.
*/
inline def allChildrenObjectsOrEnumerationCases[T]: Boolean = ${ allChildrenObjectsOrEnumerationCasesImpl[T] }

def allChildrenObjectsOrEnumerationCasesImpl[T: Type](using q: Quotes): Expr[Boolean] =
Expr(enumerationTypeChildren[T](failOnError = false).nonEmpty)

/** Recursively scans a symbol and builds a list of all children and their children, as long as all of them are objects or enum cases or
* sealed parents of such. Useful for determining whether an enum is indeed an enum, or will be desugared to a sealed hierarchy, in which
* case it's not really an enumeration (in context of schemas and JSON codecs).
*/
def enumerationTypeChildren[T: Type](failOnError: Boolean)(using q: Quotes): List[q.reflect.Symbol] =
import quotes.reflect.*

val tpe = TypeRepr.of[T]
val symbol = tpe.typeSymbol

def flatChildren(s: Symbol): List[Symbol] = s.children.toList.flatMap { c =>
if (c.isClassDef) {
if (!(c.flags is Flags.Sealed))
if (failOnError)
report.errorAndAbort("All children must be objects or enum cases, or sealed parent of such.")
else
Nil
else
flatChildren(c)
} else List(c)
}

flatChildren(symbol)
16 changes: 3 additions & 13 deletions core/src/main/scala-3/sttp/tapir/macros/ValidatorMacros.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package sttp.tapir.macros

import sttp.tapir.internal.EnumerationMacros.*
import sttp.tapir.Validator
import sttp.tapir.{Schema, SchemaType}

import scala.compiletime
import sttp.tapir.Schema

import scala.quoted.*

Expand All @@ -27,16 +26,7 @@ private[tapir] object ValidatorMacros {
report.errorAndAbort("Can only enumerate values of a sealed trait, class or enum.")
}

def flatChildren(s: Symbol): List[Symbol] = s.children.toList.flatMap { c =>
if (c.isClassDef) {
if (!(c.flags is Flags.Sealed))
report.errorAndAbort("All children must be objects or enum cases, or sealed parent of such.")
else
flatChildren(c)
} else List(c)
}

val instances = flatChildren(symbol).distinct
val instances = enumerationTypeChildren[T](failOnError = true).distinct
.sortBy(_.name)
.map(x =>
tpe.memberType(x).asType match {
Expand Down
17 changes: 13 additions & 4 deletions doc/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

All suggestions welcome :)!

If you'd like to contribute, see the list of [issues](https://github.com/softwaremill/tapir/issues) and pick one!
If you'd like to contribute, see the list of [issues](https://github.com/softwaremill/tapir/issues) and pick one!
Or report your own. If you have an idea you'd like to discuss, that's always a good option.

If you are having doubts on the *why* or *how* something works, don't hesitate to ask a question on
[discourse](https://softwaremill.community/c/tapir) or via github. This probably means that the documentation, scaladocs or
If you are having doubts on the _why_ or _how_ something works, don't hesitate to ask a question on
[discourse](https://softwaremill.community/c/tapir) or via github. This probably means that the documentation, scaladocs or
code is unclear and can be improved for the benefit of all.

## Conventions

### Enumerations

Scala 3 introduces `enum`, which can be used to represent sealed hierarchies with simpler syntax, or actual "true" enumerations,
that is parameterless enums or sealed traits with only case objects as children. Tapir needs to treat the latter differently,
in order to allow using OpenAPI `enum` elements and derive JSON codecs which represent them as simple values (without discriminator).
Let's use the name `enumeration` in Tapir codebase to represent these "true" enumerations and avoid ambiguity.

## Acknowledgments

Tuple-concatenating code is copied from [akka-http](https://github.com/akka/akka-http/blob/master/akka-http/src/main/scala/akka/http/scaladsl/server/util/TupleOps.scala)
Expand All @@ -17,4 +26,4 @@ Parts of generic derivation configuration is copied from [circe](https://github.

Implementation of mirror for union and intersection types are originally implemented by [Iltotore](https://github.com/Iltotore) in [this gist](https://gist.github.com/Iltotore/eece20188d383f7aee16a0b89eeb887f)

Tapir logo & stickers have been drawn by [impurepics](https://twitter.com/impurepics).
Tapir logo & stickers have been drawn by [impurepics](https://twitter.com/impurepics).
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package sttp.tapir.json.pickler

import sttp.tapir.internal.EnumerationMacros.*
import sttp.tapir.Codec.JsonCodec
import sttp.tapir.DecodeResult.Error.JsonDecodeException
import sttp.tapir.DecodeResult.{Error, Value}
Expand Down Expand Up @@ -68,7 +69,7 @@ object Pickler:
s"Unexpected product type (case class) ${implicitly[ClassTag[T]].runtimeClass.getSimpleName()}, this method should only be used with sum types (like sealed hierarchy)"
)
case _: Mirror.SumOf[T] =>
inline if (isScalaEnum[T])
inline if (isEnumeration[T])
error("oneOfUsingField cannot be used with enums. Try Pickler.derivedEnumeration instead.")
else {
given schemaV: Schema[V] = discriminatorPickler.schema
Expand Down Expand Up @@ -277,7 +278,7 @@ object Pickler:
case p: Mirror.ProductOf[T] => picklerProduct(p, childPicklers)
case _: Mirror.SumOf[T] =>
val schema: Schema[T] =
inline if (isScalaEnum[T])
inline if (isEnumeration[T])
Schema.derivedEnumeration[T].defaultStringBased
else
Schema.derived[T]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sttp.tapir.json.pickler
import _root_.upickle.core.Annotator.Checker
import _root_.upickle.core.{ObjVisitor, Visitor, _}
import _root_.upickle.implicits.{WritersVersionSpecific, macros => upickleMacros}
import sttp.tapir.internal.EnumerationMacros.*
import sttp.tapir.Schema
import sttp.tapir.SchemaType.SProduct
import sttp.tapir.generic.Configuration
Expand Down Expand Up @@ -64,7 +65,7 @@ private[pickler] trait Writers extends WritersVersionSpecific with UpickleHelper
)
}

inline if upickleMacros.isMemberOfSealedHierarchy[T] && !macros.isScalaEnum[T] then
inline if upickleMacros.isMemberOfSealedHierarchy[T] && !isEnumeration[T] then
annotate[T](
writer,
upickleMacros.tagName[T],
Expand All @@ -76,8 +77,8 @@ private[pickler] trait Writers extends WritersVersionSpecific with UpickleHelper
annotate[T](SingletonWriter[T](null.asInstanceOf[T]), upickleMacros.tagName[T], Annotator.Checker.Val(upickleMacros.getSingleton[T]))
else writer

inline def macroSumW[T: ClassTag](inline childWriters: => List[Any], subtypeDiscriminator: SubtypeDiscriminator[T])(
using Configuration
inline def macroSumW[T: ClassTag](inline childWriters: => List[Any], subtypeDiscriminator: SubtypeDiscriminator[T])(using
Configuration
) =
implicit val currentlyDeriving: _root_.upickle.core.CurrentlyDeriving[T] = new _root_.upickle.core.CurrentlyDeriving()
val writers: List[TaggedWriter[_ <: T]] = childWriters
Expand Down
22 changes: 1 addition & 21 deletions json/pickler/src/main/scala/sttp/tapir/json/pickler/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package sttp.tapir.json.pickler

import _root_.upickle.implicits.*
import _root_.upickle.implicits.{macros => uMacros}
import sttp.tapir.internal.EnumerationMacros.*
import sttp.tapir.SchemaType
import sttp.tapir.SchemaType.SProduct

import scala.quoted.*

import compiletime.*

/** Macros, mostly copied from uPickle, and modified to allow our customizations like passing writers/readers as parameters, adjusting
Expand Down Expand Up @@ -92,23 +92,3 @@ private[pickler] object macros:

Expr.block(statements, '{})
}

transparent inline def isScalaEnum[X]: Boolean = inline compiletime.erasedValue[X] match
case _: Null => false
case _: Nothing => false
case _: reflect.Enum => allChildrenObjectsOrEnumCases[X]
case _ => false

/** Checks whether all children of type T are objects or enum cases or sealed parents of such. Useful for determining whether an enum is
* indeed an enum, or will be desugared to a sealed hierarchy, in which case it's not really an enumeration in context of schemas and
* JSON codecs.
*/
inline def allChildrenObjectsOrEnumCases[T]: Boolean = ${ allChildrenObjectsOrEnumCasesImpl[T] }

import scala.quoted._

def allChildrenObjectsOrEnumCasesImpl[T: Type](using q: Quotes): Expr[Boolean] =
import quotes.reflect.*
val tpe = TypeRepr.of[T]
val symbol = tpe.typeSymbol
Expr(symbol.children.nonEmpty && !symbol.children.exists(c => c.isClassDef && !(c.flags is Flags.Sealed)))

0 comments on commit 7174c40

Please sign in to comment.