Skip to content

Commit

Permalink
feat: add string-based const enum Codec
Browse files Browse the repository at this point in the history
process feedback
simplify macro
  • Loading branch information
ThijsBroersen committed Jun 14, 2024
1 parent 2c29764 commit 816841f
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 46 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package sttp.tapir
package macros

trait SchemaCompanionMacrosExtensions
9 changes: 9 additions & 0 deletions core/src/main/scala-3/sttp/tapir/macros/CodecMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package sttp.tapir.macros
import sttp.tapir.CodecFormat.TextPlain
import sttp.tapir.{Codec, SchemaAnnotations, Validator}
import sttp.tapir.internal.CodecValueClassMacro
import sttp.tapir.Mapping
import sttp.tapir.DecodeResult
import sttp.tapir.DecodeResult.Value
import sttp.tapir.Schema

trait CodecMacros {

Expand Down Expand Up @@ -36,6 +40,11 @@ trait CodecMacros {
inline def derivedEnumerationValueCustomise[L, T <: scala.Enumeration#Value]: CreateDerivedEnumerationCodec[L, T] =
new CreateDerivedEnumerationCodec(derivedEnumerationValueValidator[T], SchemaAnnotations.derived[T])

inline given derivedStringBasedUnionEnumeration[T](using IsUnionOf[String, T]): Codec[String, T, TextPlain] =
lazy val values = UnionDerivation.constValueUnionTuple[String, T]
lazy val validator = Validator.enumeration(values.toList.asInstanceOf[List[T]])
Codec.string.validate(validator.asInstanceOf[Validator[String]]).map(_.asInstanceOf[T])(_.asInstanceOf[String])

/** A default codec for enumerations, which returns a string-based enumeration codec, using the enum's `.toString` to encode values, and
* performing a case-insensitive search through the possible values, converted to strings using `.toString`.
*
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package sttp.tapir
package macros

import sttp.tapir.Schema.SName

object SchemaCompanionMacrosExtensions extends SchemaCompanionMacrosExtensions

trait SchemaCompanionMacrosExtensions:
inline given derivedStringBasedUnionEnumeration[S](using IsUnionOf[String, S]): Schema[S] =
lazy val values = UnionDerivation.constValueUnionTuple[String, S]
lazy val validator = Validator.enumeration(values.toList.asInstanceOf[List[S]])
Schema
.string[S]
.name(SName(values.toList.mkString("_or_")))
.validate(validator)
10 changes: 6 additions & 4 deletions core/src/main/scala-3/sttp/tapir/macros/union_derivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import scala.deriving.*
import scala.quoted.*

@scala.annotation.implicitNotFound("${A} is not a union of ${T}")
sealed trait IsUnionOf[T, A]
private[tapir] sealed trait IsUnionOf[T, A]

object IsUnionOf:
private[tapir] object IsUnionOf:

private val singleton: IsUnionOf[Any, Any] = new IsUnionOf[Any, Any] {}

Expand All @@ -31,9 +31,11 @@ object IsUnionOf:
case o: OrType =>
validateTypes(o)
('{ IsUnionOf.singleton.asInstanceOf[IsUnionOf[T, A]] }).asExprOf[IsUnionOf[T, A]]
case other => report.errorAndAbort(s"${tpe.show} is not a Union")
case o =>
if o <:< bound then ('{ IsUnionOf.singleton.asInstanceOf[IsUnionOf[T, A]] }).asExprOf[IsUnionOf[T, A]]
else report.errorAndAbort(s"${tpe.show} is not a Union")

object UnionDerivation:
private[tapir] object UnionDerivation:
transparent inline def constValueUnionTuple[T, A](using IsUnionOf[T, A]): Tuple = ${ constValueUnionTupleImpl[T, A] }

private def constValueUnionTupleImpl[T: Type, A: Type](using Quotes): Expr[Tuple] =
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/sttp/tapir/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import sttp.tapir.Schema.{SName, Title}
import sttp.tapir.SchemaType._
import sttp.tapir.generic.{Configuration, Derived}
import sttp.tapir.internal.{ValidatorSyntax, isBasicValue}
import sttp.tapir.macros.{LowPrioSchemaMacrosVersionSpecific, SchemaCompanionMacros, SchemaMacros}
import sttp.tapir.macros.{SchemaCompanionMacrosExtensions, SchemaCompanionMacros, SchemaMacros}
import sttp.tapir.model.Delimited

import java.io.InputStream
Expand Down Expand Up @@ -422,6 +422,6 @@ object Schema extends LowPrioritySchema with SchemaCompanionMacros {
def anyObject[T]: Schema[T] = Schema(SProduct(Nil), None)
}

trait LowPrioritySchema extends LowPrioSchemaMacrosVersionSpecific {
trait LowPrioritySchema extends SchemaCompanionMacrosExtensions {
implicit def derivedSchema[T](implicit derived: Derived[Schema[T]]): Schema[T] = derived.value
}
25 changes: 25 additions & 0 deletions core/src/test/scala-3/sttp/tapir/CodecScala3Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package sttp.tapir

import org.scalatest.{Assertion, Inside}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.scalacheck.Checkers
import sttp.tapir.CodecFormat.TextPlain
import sttp.tapir.DecodeResult.Value

import sttp.tapir.DecodeResult.InvalidValue

class CodecScala3Test extends AnyFlatSpec with Matchers with Checkers with Inside {

it should "derive a codec for a string-based union type" in {
// given
val codec = summon[Codec[String, "Apple" | "Banana", TextPlain]]

// then
codec.encode("Apple") shouldBe "Apple"
codec.encode("Banana") shouldBe "Banana"
codec.decode("Apple") shouldBe Value("Apple")
codec.decode("Banana") shouldBe Value("Banana")
codec.decode("Orange") should matchPattern { case DecodeResult.InvalidValue(List(ValidationError(_, "Orange", _, _))) => }
}
}
16 changes: 15 additions & 1 deletion core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class SchemaMacroScala3Test extends AnyFlatSpec with Matchers:

it should "derive schema for a const as a string-based union type" in {
// when
val s: Schema["a"] = Schema.constStringToEnum
val s: Schema["a"] = Schema.derivedStringBasedUnionEnumeration

// then
s.name.map(_.show) shouldBe Some("a")
Expand All @@ -110,6 +110,20 @@ class SchemaMacroScala3Test extends AnyFlatSpec with Matchers:
s.validator should matchPattern { case Validator.Enumeration(List("a"), _, _) => }
}

it should "derive a schema for a union of unions when all are string-based constants" in {
// when
type AorB = "a" | "b"
type C = "c"
type AorBorC = AorB | C
val s: Schema[AorBorC] = Schema.derivedStringBasedUnionEnumeration[AorBorC]

// then
s.name.map(_.show) shouldBe Some("a_or_b_or_c")

s.schemaType should matchPattern { case SchemaType.SString() => }
s.validator should matchPattern { case Validator.Enumeration(List("a", "b", "c"), _, _) => }
}

object SchemaMacroScala3Test:
enum Fruit:
case Apple, Banana
Expand Down
17 changes: 17 additions & 0 deletions doc/endpoint/enumerations.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,23 @@ enum ColorEnum {
given Schema[ColorEnum] = Schema.derivedEnumeration.defaultStringBased
```

### Scala 3 string-based constant union types to enum

If a union type is a string-based constant union type, it can be auto-derived as field type or manually derived by using the `Schema.derivedStringBasedUnionEnumeration[T]` method.

Constant strings can be derived by using the `Schema.constStringToEnum[T]` method.

Examples:
```scala
val aOrB: Schema["a" | "b"] = Schema.derivedStringBasedUnionEnumeration
```
```scala
val a: Schema["a"] = Schema.constStringToEnum
```
```scala
case class Foo(aOrB: "a" | "b", optA: Option["a"]) derives Schema
```

### Creating an enum schema by hand

Creating an enumeration [schema](schema.md) by hand is exactly the same as for any other type. The only difference
Expand Down
16 changes: 2 additions & 14 deletions doc/endpoint/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,9 @@ If any of the components of the union type is a generic type, any of its validat
the union type, as it's not possible to generate a runtime check for the generic type.

### Derivation for string-based constant union types
e.g. `type AorB = "a" | "b"`

If a union type is a string-based constant union type, it can be auto-derived as field type or manually derived by using the `Schema.derivedStringBasedUnionEnumeration[T]` method.

Constant strings can be derived by using the `Schema.constStringToEnum[T]` method.

Examples:
```scala
val aOrB: Schema["a" | "b"] = Schema.derivedStringBasedUnionEnumeration
```
```scala
val a: Schema["a"] = Schema.constStringToEnum
```
```scala
case class Foo(aOrB: "a" | "b", optA: Option["a"]) derives Schema
```
See [enumerations](enumerations.md#scala-3-string-based-constant-union-types-to-enum) on how to use string-based unions of constant types as enums.

## Configuring derivation

Expand Down

0 comments on commit 816841f

Please sign in to comment.