diff --git a/core/src/main/scala/sttp/tapir/Schema.scala b/core/src/main/scala/sttp/tapir/Schema.scala index 6c284f45e7..0b695f6708 100644 --- a/core/src/main/scala/sttp/tapir/Schema.scala +++ b/core/src/main/scala/sttp/tapir/Schema.scala @@ -124,7 +124,7 @@ object Schema extends SchemaMagnoliaDerivation with LowPrioritySchema { implicit def schemaForMap[V: Schema]: Schema[Map[String, V]] = macro SchemaMapMacro.schemaForMap[Map[String, V], V] - def oneOf[E, V](extractor: E => V, asString: V => String)(mapping: (V, Schema[_])*): Schema[E] = macro oneOfMacro[E, V] + def oneOfUsingField[E, V](extractor: E => V, asString: V => String)(mapping: (V, Schema[_])*): Schema[E] = macro oneOfMacro[E, V] } trait LowPrioritySchema { diff --git a/core/src/main/scala/sttp/tapir/SchemaType.scala b/core/src/main/scala/sttp/tapir/SchemaType.scala index d9e17700ed..99115f265f 100644 --- a/core/src/main/scala/sttp/tapir/SchemaType.scala +++ b/core/src/main/scala/sttp/tapir/SchemaType.scala @@ -45,6 +45,22 @@ object SchemaType { } case class SCoproduct(info: SObjectInfo, schemas: List[Schema[_]], discriminator: Option[Discriminator]) extends SObject { override def show: String = "oneOf:" + schemas.mkString(",") + + def addDiscriminatorField[D]( + discriminatorName: FieldName, + discriminatorSchema: Schema[D] = Schema(SchemaType.SString), + discriminatorMappingOverride: Map[String, SRef] = Map.empty + ): SCoproduct = { + SCoproduct( + info, + schemas.map { + case s @ Schema(st: SchemaType.SProduct, _, _, _, _) => + s.copy(schemaType = st.copy(fields = st.fields.toSeq :+ (discriminatorName -> discriminatorSchema))) + case s => s + }, + Some(Discriminator(discriminatorName.encodedName, discriminatorMappingOverride)) + ) + } } case class SOpenProduct(info: SObjectInfo, valueSchema: Schema[_]) extends SObject { override def show: String = s"map" diff --git a/core/src/main/scala/sttp/tapir/generic/Configuration.scala b/core/src/main/scala/sttp/tapir/generic/Configuration.scala index 2caf37e934..f9c637b024 100644 --- a/core/src/main/scala/sttp/tapir/generic/Configuration.scala +++ b/core/src/main/scala/sttp/tapir/generic/Configuration.scala @@ -2,20 +2,14 @@ package sttp.tapir.generic import java.util.regex.Pattern -final case class Configuration(toEncodedName: String => String) { - def withSnakeCaseMemberNames: Configuration = - copy( - toEncodedName = Configuration.snakeCaseTransformation - ) - - def withKebabCaseMemberNames: Configuration = - copy( - toEncodedName = Configuration.kebabCaseTransformation - ) +final case class Configuration(toEncodedName: String => String, discriminator: Option[String]) { + def withSnakeCaseMemberNames: Configuration = copy(toEncodedName = Configuration.snakeCaseTransformation) + def withKebabCaseMemberNames: Configuration = copy(toEncodedName = Configuration.kebabCaseTransformation) + def withDiscriminator(d: String): Configuration = copy(discriminator = Some(d)) } object Configuration { - implicit val default: Configuration = Configuration(Predef.identity) + implicit val default: Configuration = Configuration(Predef.identity, None) private val basePattern: Pattern = Pattern.compile("([A-Z]+)([A-Z][a-z])") private val swapPattern: Pattern = Pattern.compile("([a-z\\d])([A-Z])") diff --git a/core/src/main/scala/sttp/tapir/generic/internal/SchemaMagnoliaDerivation.scala b/core/src/main/scala/sttp/tapir/generic/internal/SchemaMagnoliaDerivation.scala index 952e4bcc0b..734773b12d 100644 --- a/core/src/main/scala/sttp/tapir/generic/internal/SchemaMagnoliaDerivation.scala +++ b/core/src/main/scala/sttp/tapir/generic/internal/SchemaMagnoliaDerivation.scala @@ -87,8 +87,13 @@ trait SchemaMagnoliaDerivation { private def isDeprecated(annotations: Seq[Any]): Boolean = annotations.collectFirst { case _: deprecated => true } getOrElse false - def dispatch[T](ctx: SealedTrait[Schema, T]): Schema[T] = { - Schema(SCoproduct(typeNameToObjectInfo(ctx.typeName, ctx.annotations), ctx.subtypes.map(_.typeclass).toList, None)) + def dispatch[T](ctx: SealedTrait[Schema, T])(implicit genericDerivationConfig: Configuration): Schema[T] = { + val baseCoproduct = SCoproduct(typeNameToObjectInfo(ctx.typeName, ctx.annotations), ctx.subtypes.map(_.typeclass).toList, None) + val coproduct = genericDerivationConfig.discriminator match { + case Some(d) => baseCoproduct.addDiscriminatorField(FieldName(d)) + case None => baseCoproduct + } + Schema(coproduct) } implicit def schemaForCaseClass[T]: Derived[Schema[T]] = macro MagnoliaDerivedMacro.derivedGen[T] diff --git a/core/src/test/scala/sttp/tapir/SchemaMacroTest.scala b/core/src/test/scala/sttp/tapir/SchemaModifyTest.scala similarity index 100% rename from core/src/test/scala/sttp/tapir/SchemaMacroTest.scala rename to core/src/test/scala/sttp/tapir/SchemaModifyTest.scala diff --git a/core/src/test/scala/sttp/tapir/SchemaTest.scala b/core/src/test/scala/sttp/tapir/SchemaTest.scala index a8f3cbe277..2f960461c0 100644 --- a/core/src/test/scala/sttp/tapir/SchemaTest.scala +++ b/core/src/test/scala/sttp/tapir/SchemaTest.scala @@ -77,4 +77,31 @@ class SchemaTest extends AnyFlatSpec with Matchers { openProductSchema .modifyUnsafe[Nothing]()(_.description("test")) shouldBe openProductSchema.description("test") } + + it should "generate one-of schema using the given discriminator" in { + val coproduct = SCoproduct( + SObjectInfo("A"), + List( + Schema(SProduct(SObjectInfo("H"), List(FieldName("f1") -> Schema(SInteger)))), + Schema(SProduct(SObjectInfo("G"), List(FieldName("f1") -> Schema(SString), FieldName("f2") -> Schema(SString)))), + Schema(SString) + ), + None + ) + + coproduct.addDiscriminatorField(FieldName("who_am_i")) shouldBe SCoproduct( + SObjectInfo("A"), + List( + Schema(SProduct(SObjectInfo("H"), List(FieldName("f1") -> Schema(SInteger), FieldName("who_am_i") -> Schema(SString)))), + Schema( + SProduct( + SObjectInfo("G"), + List(FieldName("f1") -> Schema(SString), FieldName("f2") -> Schema(SString), FieldName("who_am_i") -> Schema(SString)) + ) + ), + Schema(SString) + ), + Some(Discriminator("who_am_i", Map.empty)) + ) + } } diff --git a/core/src/test/scala/sttp/tapir/generic/SchemaGenericTest.scala b/core/src/test/scala/sttp/tapir/generic/SchemaGenericTest.scala index 1ed8b9ab94..1c65a6caae 100644 --- a/core/src/test/scala/sttp/tapir/generic/SchemaGenericTest.scala +++ b/core/src/test/scala/sttp/tapir/generic/SchemaGenericTest.scala @@ -337,6 +337,29 @@ class SchemaGenericTest extends AnyFlatSpec with Matchers { implicitly[Schema[IList]] shouldBe expectedISchema implicitly[Schema[JList]] shouldBe expectedJSchema } + + it should "generate one-of schema using the given discriminator" in { + implicit val customConf: Configuration = Configuration.default.withDiscriminator("who_am_i") + + implicitly[Schema[Entity]].schemaType shouldBe SCoproduct( + SObjectInfo("sttp.tapir.generic.Entity"), + List( + Schema( + SProduct( + SObjectInfo("sttp.tapir.generic.Organization"), + List((FieldName("name"), Schema(SString)), (FieldName("who_am_i"), Schema(SString))) + ) + ), + Schema( + SProduct( + SObjectInfo("sttp.tapir.generic.Person"), + List((FieldName("first"), Schema(SString)), (FieldName("age"), Schema(SInteger)), (FieldName("who_am_i"), Schema(SString))) + ) + ) + ), + Some(Discriminator("who_am_i", Map.empty)) + ) + } } case class StringValueClass(value: String) extends AnyVal @@ -391,3 +414,7 @@ case class JOpt(data: Option[IOpt]) case class IList(i1: List[IList], i2: Int) case class JList(data: List[IList]) + +sealed trait Entity +case class Person(first: String, age: Int) extends Entity +case class Organization(name: String) extends Entity diff --git a/doc/endpoint/customtypes.md b/doc/endpoint/customtypes.md index b7ca7870c9..39fc721090 100644 --- a/doc/endpoint/customtypes.md +++ b/doc/endpoint/customtypes.md @@ -92,7 +92,7 @@ Automatic codec derivation usually requires other implicits, such as: For case classes types, `Schema[_]` values are derived automatically using [Magnolia](https://propensive.com/opensource/magnolia/), given that schemas are defined for all the case class's fields. It is possible to configure the automatic derivation to use -snake-case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value: +snake_case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value: ```scala mdoc:silent import sttp.tapir.generic.Configuration @@ -117,11 +117,25 @@ be derived automatically. ### Sealed traits / coproducts -Tapir supports schema generation for coproduct types (sealed trait hierarchies) out of the box, but they need to be defined -by hand as `implicit` values. To properly reflect the schema in [OpenAPI](../openapi.md) documentation, a -discriminator object can be specified. +Schema derivation for coproduct types (sealed trait hierarchies) is supported as well. By default, such hierarchies +will be represented as a coproduct which contains a list of child schemas, without any discriminator field. -For example, given following coproduct: +A discriminator field can be specified for coproducts by providing it in the configuration; this will be only used +during automatic derivation: + +```scala mdoc:silent:reset +import sttp.tapir.generic.Configuration + +implicit val customConfiguration: Configuration = + Configuration.default.withDiscriminator("who_am_i") +``` + +Alternatively, derived schemas can be customised (see below), and a discriminator can be added by calling +the `SchemaType.SCoproduct.addDiscriminatorField(name, schema, mapingOverride)` method. + +Finally, if the discriminator is a field that's defined on the base trait (and hence in each implementation), the +schemas can be specified using `Schema.oneOfUsingField`, for example (this will also generate the appropriate +mapping overrides): ```scala mdoc:silent:reset sealed trait Entity { @@ -133,17 +147,13 @@ case class Person(firstName:String, lastName:String) extends Entity { case class Organization(name: String) extends Entity { def kind: String = "org" } -``` -The schema may look like this: - -```scala mdoc:silent import sttp.tapir._ val sPerson = implicitly[Schema[Person]] val sOrganization = implicitly[Schema[Organization]] implicit val sEntity: Schema[Entity] = - Schema.oneOf[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization) + Schema.oneOfUsingField[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization) ``` ## Customising derived schemas diff --git a/docs/openapi-docs/src/test/scala/sttp/tapir/docs/openapi/VerifyYamlTest.scala b/docs/openapi-docs/src/test/scala/sttp/tapir/docs/openapi/VerifyYamlTest.scala index 566390ebc7..1cc0f300b1 100644 --- a/docs/openapi-docs/src/test/scala/sttp/tapir/docs/openapi/VerifyYamlTest.scala +++ b/docs/openapi-docs/src/test/scala/sttp/tapir/docs/openapi/VerifyYamlTest.scala @@ -237,7 +237,8 @@ class VerifyYamlTest extends AnyFunSuite with Matchers { test("should match the expected yaml when using coproduct types with discriminator") { val sPerson = implicitly[Schema[Person]] val sOrganization = implicitly[Schema[Organization]] - implicit val sEntity: Schema[Entity] = Schema.oneOf[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization) + implicit val sEntity: Schema[Entity] = + Schema.oneOfUsingField[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization) val expectedYaml = loadYaml("expected_coproduct_discriminator.yml") val endpoint_wit_sealed_trait: Endpoint[Unit, Unit, Entity, Any] = endpoint @@ -264,7 +265,8 @@ class VerifyYamlTest extends AnyFunSuite with Matchers { val sPerson = implicitly[Schema[Person]] val sOrganization = implicitly[Schema[Organization]] @silent("never used") // it is used - implicit val sEntity: Schema[Entity] = Schema.oneOf[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization) + implicit val sEntity: Schema[Entity] = + Schema.oneOfUsingField[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization) val expectedYaml = loadYaml("expected_coproduct_discriminator_nested.yml") val endpoint_wit_sealed_trait: Endpoint[Unit, Unit, NestedEntity, Any] = endpoint diff --git a/generated-doc/out/endpoint/customtypes.md b/generated-doc/out/endpoint/customtypes.md index c37004cc3d..21666c872c 100644 --- a/generated-doc/out/endpoint/customtypes.md +++ b/generated-doc/out/endpoint/customtypes.md @@ -92,7 +92,7 @@ Automatic codec derivation usually requires other implicits, such as: For case classes types, `Schema[_]` values are derived automatically using [Magnolia](https://propensive.com/opensource/magnolia/), given that schemas are defined for all the case class's fields. It is possible to configure the automatic derivation to use -snake-case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value: +snake_case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value: ```scala import sttp.tapir.generic.Configuration @@ -117,11 +117,25 @@ be derived automatically. ### Sealed traits / coproducts -Tapir supports schema generation for coproduct types (sealed trait hierarchies) out of the box, but they need to be defined -by hand as `implicit` values. To properly reflect the schema in [OpenAPI](../openapi.md) documentation, a -discriminator object can be specified. +Schema derivation for coproduct types (sealed trait hierarchies) is supported as well. By default, such hierarchies +will be represented as a coproduct which contains a list of child schemas, without any discriminator field. -For example, given following coproduct: +A discriminator field can be specified for coproducts by providing it in the configuration; this will be only used +during automatic derivation: + +```scala +import sttp.tapir.generic.Configuration + +implicit val customConfiguration: Configuration = + Configuration.default.withDiscriminator("who_am_i") +``` + +Alternatively, derived schemas can be customised (see below), and a discriminator can be added by calling +the `SchemaType.SCoproduct.addDiscriminatorField(name, schema, mapingOverride)` method. + +Finally, if the discriminator is a field that's defined on the base trait (and hence in each implementation), the +schemas can be specified using `Schema.oneOfUsingField`, for example (this will also generate the appropriate +mapping overrides): ```scala sealed trait Entity { @@ -133,17 +147,13 @@ case class Person(firstName:String, lastName:String) extends Entity { case class Organization(name: String) extends Entity { def kind: String = "org" } -``` -The schema may look like this: - -```scala import sttp.tapir._ val sPerson = implicitly[Schema[Person]] val sOrganization = implicitly[Schema[Organization]] implicit val sEntity: Schema[Entity] = - Schema.oneOf[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization) + Schema.oneOfUsingField[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization) ``` ## Customising derived schemas