diff --git a/core/src/main/scala-2/sttp/tapir/macros/LowPrioSchemaMacrosVersionSpecific.scala b/core/src/main/scala-2/sttp/tapir/macros/LowPrioSchemaMacrosVersionSpecific.scala deleted file mode 100644 index 7c77e3bbc2..0000000000 --- a/core/src/main/scala-2/sttp/tapir/macros/LowPrioSchemaMacrosVersionSpecific.scala +++ /dev/null @@ -1,4 +0,0 @@ -package sttp.tapir -package macros - -trait LowPrioSchemaMacrosVersionSpecific diff --git a/core/src/main/scala-2/sttp/tapir/macros/SchemaCompanionMacrosExtensions.scala b/core/src/main/scala-2/sttp/tapir/macros/SchemaCompanionMacrosExtensions.scala new file mode 100644 index 0000000000..a302d1a4be --- /dev/null +++ b/core/src/main/scala-2/sttp/tapir/macros/SchemaCompanionMacrosExtensions.scala @@ -0,0 +1,4 @@ +package sttp.tapir +package macros + +trait SchemaCompanionMacrosExtensions diff --git a/core/src/main/scala-3/sttp/tapir/macros/CodecMacros.scala b/core/src/main/scala-3/sttp/tapir/macros/CodecMacros.scala index c63d9dac46..45b50d5c7c 100644 --- a/core/src/main/scala-3/sttp/tapir/macros/CodecMacros.scala +++ b/core/src/main/scala-3/sttp/tapir/macros/CodecMacros.scala @@ -3,6 +3,10 @@ package sttp.tapir.macros import sttp.tapir.CodecFormat.TextPlain import sttp.tapir.{Codec, SchemaAnnotations, Validator} import sttp.tapir.internal.CodecValueClassMacro +import sttp.tapir.Mapping +import sttp.tapir.DecodeResult +import sttp.tapir.DecodeResult.Value +import sttp.tapir.Schema trait CodecMacros { @@ -36,6 +40,11 @@ trait CodecMacros { inline def derivedEnumerationValueCustomise[L, T <: scala.Enumeration#Value]: CreateDerivedEnumerationCodec[L, T] = new CreateDerivedEnumerationCodec(derivedEnumerationValueValidator[T], SchemaAnnotations.derived[T]) + inline given derivedStringBasedUnionEnumeration[T](using IsUnionOf[String, T]): Codec[String, T, TextPlain] = + lazy val values = UnionDerivation.constValueUnionTuple[String, T] + lazy val validator = Validator.enumeration(values.toList.asInstanceOf[List[T]]) + Codec.string.validate(validator.asInstanceOf[Validator[String]]).map(_.asInstanceOf[T])(_.asInstanceOf[String]) + /** A default codec for enumerations, which returns a string-based enumeration codec, using the enum's `.toString` to encode values, and * performing a case-insensitive search through the possible values, converted to strings using `.toString`. * diff --git a/core/src/main/scala-3/sttp/tapir/macros/LowPrioSchemaMacrosVersionSpecific.scala b/core/src/main/scala-3/sttp/tapir/macros/LowPrioSchemaMacrosVersionSpecific.scala deleted file mode 100644 index c0634480fd..0000000000 --- a/core/src/main/scala-3/sttp/tapir/macros/LowPrioSchemaMacrosVersionSpecific.scala +++ /dev/null @@ -1,21 +0,0 @@ -package sttp.tapir -package macros - -import sttp.tapir.Schema.SName - -import scala.compiletime.* -import scala.compiletime.ops.any.IsConst - -trait LowPrioSchemaMacrosVersionSpecific: - inline given derivedStringBasedUnionEnumeration[S](using IsUnionOf[String, S]): Schema[S] = - val values = UnionDerivation.constValueUnionTuple[String, S] - Schema - .string[S] - .name(SName(values.toList.mkString("_or_"))) - .validate(Validator.enumeration(values.toList.asInstanceOf[List[S]])) - - inline given constStringToEnum[S <: String](using IsConst[S] =:= true): Schema[S] = - Schema - .string[S] - .name(SName(constValue[S])) - .validate(Validator.enumeration(List(constValue[S]).asInstanceOf[List[S]])) diff --git a/core/src/main/scala-3/sttp/tapir/macros/SchemaCompanionMacrosExtensions.scala b/core/src/main/scala-3/sttp/tapir/macros/SchemaCompanionMacrosExtensions.scala new file mode 100644 index 0000000000..2e71be73e6 --- /dev/null +++ b/core/src/main/scala-3/sttp/tapir/macros/SchemaCompanionMacrosExtensions.scala @@ -0,0 +1,15 @@ +package sttp.tapir +package macros + +import sttp.tapir.Schema.SName + +object SchemaCompanionMacrosExtensions extends SchemaCompanionMacrosExtensions + +trait SchemaCompanionMacrosExtensions: + inline given derivedStringBasedUnionEnumeration[S](using IsUnionOf[String, S]): Schema[S] = + lazy val values = UnionDerivation.constValueUnionTuple[String, S] + lazy val validator = Validator.enumeration(values.toList.asInstanceOf[List[S]]) + Schema + .string[S] + .name(SName(values.toList.mkString("_or_"))) + .validate(validator) diff --git a/core/src/main/scala-3/sttp/tapir/macros/union_derivation.scala b/core/src/main/scala-3/sttp/tapir/macros/union_derivation.scala index 2bd64a3129..47e9c46124 100644 --- a/core/src/main/scala-3/sttp/tapir/macros/union_derivation.scala +++ b/core/src/main/scala-3/sttp/tapir/macros/union_derivation.scala @@ -5,9 +5,9 @@ import scala.deriving.* import scala.quoted.* @scala.annotation.implicitNotFound("${A} is not a union of ${T}") -sealed trait IsUnionOf[T, A] +private[tapir] sealed trait IsUnionOf[T, A] -object IsUnionOf: +private[tapir] object IsUnionOf: private val singleton: IsUnionOf[Any, Any] = new IsUnionOf[Any, Any] {} @@ -31,9 +31,11 @@ object IsUnionOf: case o: OrType => validateTypes(o) ('{ IsUnionOf.singleton.asInstanceOf[IsUnionOf[T, A]] }).asExprOf[IsUnionOf[T, A]] - case other => report.errorAndAbort(s"${tpe.show} is not a Union") + case o => + if o <:< bound then ('{ IsUnionOf.singleton.asInstanceOf[IsUnionOf[T, A]] }).asExprOf[IsUnionOf[T, A]] + else report.errorAndAbort(s"${tpe.show} is not a Union") -object UnionDerivation: +private[tapir] object UnionDerivation: transparent inline def constValueUnionTuple[T, A](using IsUnionOf[T, A]): Tuple = ${ constValueUnionTupleImpl[T, A] } private def constValueUnionTupleImpl[T: Type, A: Type](using Quotes): Expr[Tuple] = diff --git a/core/src/main/scala/sttp/tapir/Schema.scala b/core/src/main/scala/sttp/tapir/Schema.scala index 9844d360a1..1826df96c8 100644 --- a/core/src/main/scala/sttp/tapir/Schema.scala +++ b/core/src/main/scala/sttp/tapir/Schema.scala @@ -5,7 +5,7 @@ import sttp.tapir.Schema.{SName, Title} import sttp.tapir.SchemaType._ import sttp.tapir.generic.{Configuration, Derived} import sttp.tapir.internal.{ValidatorSyntax, isBasicValue} -import sttp.tapir.macros.{LowPrioSchemaMacrosVersionSpecific, SchemaCompanionMacros, SchemaMacros} +import sttp.tapir.macros.{SchemaCompanionMacrosExtensions, SchemaCompanionMacros, SchemaMacros} import sttp.tapir.model.Delimited import java.io.InputStream @@ -422,6 +422,6 @@ object Schema extends LowPrioritySchema with SchemaCompanionMacros { def anyObject[T]: Schema[T] = Schema(SProduct(Nil), None) } -trait LowPrioritySchema extends LowPrioSchemaMacrosVersionSpecific { +trait LowPrioritySchema extends SchemaCompanionMacrosExtensions { implicit def derivedSchema[T](implicit derived: Derived[Schema[T]]): Schema[T] = derived.value } diff --git a/core/src/test/scala-3/sttp/tapir/CodecScala3Test.scala b/core/src/test/scala-3/sttp/tapir/CodecScala3Test.scala new file mode 100644 index 0000000000..75ba64a0c2 --- /dev/null +++ b/core/src/test/scala-3/sttp/tapir/CodecScala3Test.scala @@ -0,0 +1,25 @@ +package sttp.tapir + +import org.scalatest.{Assertion, Inside} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.scalacheck.Checkers +import sttp.tapir.CodecFormat.TextPlain +import sttp.tapir.DecodeResult.Value + +import sttp.tapir.DecodeResult.InvalidValue + +class CodecScala3Test extends AnyFlatSpec with Matchers with Checkers with Inside { + + it should "derive a codec for a string-based union type" in { + // given + val codec = summon[Codec[String, "Apple" | "Banana", TextPlain]] + + // then + codec.encode("Apple") shouldBe "Apple" + codec.encode("Banana") shouldBe "Banana" + codec.decode("Apple") shouldBe Value("Apple") + codec.decode("Banana") shouldBe Value("Banana") + codec.decode("Orange") should matchPattern { case DecodeResult.InvalidValue(List(ValidationError(_, "Orange", _, _))) => } + } +} diff --git a/core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala b/core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala index 66279c178f..968104e971 100644 --- a/core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala +++ b/core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala @@ -101,7 +101,7 @@ class SchemaMacroScala3Test extends AnyFlatSpec with Matchers: it should "derive schema for a const as a string-based union type" in { // when - val s: Schema["a"] = Schema.constStringToEnum + val s: Schema["a"] = Schema.derivedStringBasedUnionEnumeration // then s.name.map(_.show) shouldBe Some("a") @@ -110,6 +110,20 @@ class SchemaMacroScala3Test extends AnyFlatSpec with Matchers: s.validator should matchPattern { case Validator.Enumeration(List("a"), _, _) => } } + it should "derive a schema for a union of unions when all are string-based constants" in { + // when + type AorB = "a" | "b" + type C = "c" + type AorBorC = AorB | C + val s: Schema[AorBorC] = Schema.derivedStringBasedUnionEnumeration[AorBorC] + + // then + s.name.map(_.show) shouldBe Some("a_or_b_or_c") + + s.schemaType should matchPattern { case SchemaType.SString() => } + s.validator should matchPattern { case Validator.Enumeration(List("a", "b", "c"), _, _) => } + } + object SchemaMacroScala3Test: enum Fruit: case Apple, Banana diff --git a/doc/endpoint/enumerations.md b/doc/endpoint/enumerations.md index 8099ae921a..06bf3929a4 100644 --- a/doc/endpoint/enumerations.md +++ b/doc/endpoint/enumerations.md @@ -263,6 +263,23 @@ enum ColorEnum { given Schema[ColorEnum] = Schema.derivedEnumeration.defaultStringBased ``` +### Scala 3 string-based constant union types to enum + +If a union type is a string-based constant union type, it can be auto-derived as field type or manually derived by using the `Schema.derivedStringBasedUnionEnumeration[T]` method. + +Constant strings can be derived by using the `Schema.constStringToEnum[T]` method. + +Examples: +```scala +val aOrB: Schema["a" | "b"] = Schema.derivedStringBasedUnionEnumeration +``` +```scala +val a: Schema["a"] = Schema.constStringToEnum +``` +```scala +case class Foo(aOrB: "a" | "b", optA: Option["a"]) derives Schema +``` + ### Creating an enum schema by hand Creating an enumeration [schema](schema.md) by hand is exactly the same as for any other type. The only difference diff --git a/doc/endpoint/schemas.md b/doc/endpoint/schemas.md index 53bd6a85cd..f31804d563 100644 --- a/doc/endpoint/schemas.md +++ b/doc/endpoint/schemas.md @@ -118,21 +118,9 @@ If any of the components of the union type is a generic type, any of its validat the union type, as it's not possible to generate a runtime check for the generic type. ### Derivation for string-based constant union types +e.g. `type AorB = "a" | "b"` -If a union type is a string-based constant union type, it can be auto-derived as field type or manually derived by using the `Schema.derivedStringBasedUnionEnumeration[T]` method. - -Constant strings can be derived by using the `Schema.constStringToEnum[T]` method. - -Examples: -```scala -val aOrB: Schema["a" | "b"] = Schema.derivedStringBasedUnionEnumeration -``` -```scala -val a: Schema["a"] = Schema.constStringToEnum -``` -```scala -case class Foo(aOrB: "a" | "b", optA: Option["a"]) derives Schema -``` +See [enumerations](enumerations.md#scala-3-string-based-constant-union-types-to-enum) on how to use string-based unions of constant types as enums. ## Configuring derivation