Skip to content

Commit

Permalink
Merge pull request #806 from softwaremill/coproducts-derivation
Browse files Browse the repository at this point in the history
#315: configurable coproducts schema derivation
  • Loading branch information
adamw authored Oct 22, 2020
2 parents eb7f287 + 3bd9a4d commit 4c48178
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import sttp.tapir.apispec.{ExampleValue, ReferenceOr, Schema, SecurityRequiremen
import scala.collection.immutable.ListMap

case class OpenAPI(
openapi: String = "3.0.3",
openapi: String = "3.0.1",
info: Info,
tags: List[Tag],
servers: List[Server],
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/sttp/tapir/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ object Schema extends SchemaMagnoliaDerivation with LowPrioritySchema {

implicit def schemaForMap[V: Schema]: Schema[Map[String, V]] = macro SchemaMapMacro.schemaForMap[Map[String, V], V]

def oneOf[E, V](extractor: E => V, asString: V => String)(mapping: (V, Schema[_])*): Schema[E] = macro oneOfMacro[E, V]
def oneOfUsingField[E, V](extractor: E => V, asString: V => String)(mapping: (V, Schema[_])*): Schema[E] = macro oneOfMacro[E, V]
}

trait LowPrioritySchema {
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/scala/sttp/tapir/SchemaType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ object SchemaType {
}
case class SCoproduct(info: SObjectInfo, schemas: List[Schema[_]], discriminator: Option[Discriminator]) extends SObject {
override def show: String = "oneOf:" + schemas.mkString(",")

def addDiscriminatorField[D](
discriminatorName: FieldName,
discriminatorSchema: Schema[D] = Schema(SchemaType.SString),
discriminatorMappingOverride: Map[String, SRef] = Map.empty
): SCoproduct = {
SCoproduct(
info,
schemas.map {
case s @ Schema(st: SchemaType.SProduct, _, _, _, _) =>
s.copy(schemaType = st.copy(fields = st.fields.toSeq :+ (discriminatorName -> discriminatorSchema)))
case s => s
},
Some(Discriminator(discriminatorName.encodedName, discriminatorMappingOverride))
)
}
}
case class SOpenProduct(info: SObjectInfo, valueSchema: Schema[_]) extends SObject {
override def show: String = s"map"
Expand Down
16 changes: 5 additions & 11 deletions core/src/main/scala/sttp/tapir/generic/Configuration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,14 @@ package sttp.tapir.generic

import java.util.regex.Pattern

final case class Configuration(toEncodedName: String => String) {
def withSnakeCaseMemberNames: Configuration =
copy(
toEncodedName = Configuration.snakeCaseTransformation
)

def withKebabCaseMemberNames: Configuration =
copy(
toEncodedName = Configuration.kebabCaseTransformation
)
final case class Configuration(toEncodedName: String => String, discriminator: Option[String]) {
def withSnakeCaseMemberNames: Configuration = copy(toEncodedName = Configuration.snakeCaseTransformation)
def withKebabCaseMemberNames: Configuration = copy(toEncodedName = Configuration.kebabCaseTransformation)
def withDiscriminator(d: String): Configuration = copy(discriminator = Some(d))
}

object Configuration {
implicit val default: Configuration = Configuration(Predef.identity)
implicit val default: Configuration = Configuration(Predef.identity, None)

private val basePattern: Pattern = Pattern.compile("([A-Z]+)([A-Z][a-z])")
private val swapPattern: Pattern = Pattern.compile("([a-z\\d])([A-Z])")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,13 @@ trait SchemaMagnoliaDerivation {
private def isDeprecated(annotations: Seq[Any]): Boolean =
annotations.collectFirst { case _: deprecated => true } getOrElse false

def dispatch[T](ctx: SealedTrait[Schema, T]): Schema[T] = {
Schema(SCoproduct(typeNameToObjectInfo(ctx.typeName, ctx.annotations), ctx.subtypes.map(_.typeclass).toList, None))
def dispatch[T](ctx: SealedTrait[Schema, T])(implicit genericDerivationConfig: Configuration): Schema[T] = {
val baseCoproduct = SCoproduct(typeNameToObjectInfo(ctx.typeName, ctx.annotations), ctx.subtypes.map(_.typeclass).toList, None)
val coproduct = genericDerivationConfig.discriminator match {
case Some(d) => baseCoproduct.addDiscriminatorField(FieldName(d))
case None => baseCoproduct
}
Schema(coproduct)
}

implicit def schemaForCaseClass[T]: Derived[Schema[T]] = macro MagnoliaDerivedMacro.derivedGen[T]
Expand Down
27 changes: 27 additions & 0 deletions core/src/test/scala/sttp/tapir/SchemaTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,31 @@ class SchemaTest extends AnyFlatSpec with Matchers {
openProductSchema
.modifyUnsafe[Nothing]()(_.description("test")) shouldBe openProductSchema.description("test")
}

it should "generate one-of schema using the given discriminator" in {
val coproduct = SCoproduct(
SObjectInfo("A"),
List(
Schema(SProduct(SObjectInfo("H"), List(FieldName("f1") -> Schema(SInteger)))),
Schema(SProduct(SObjectInfo("G"), List(FieldName("f1") -> Schema(SString), FieldName("f2") -> Schema(SString)))),
Schema(SString)
),
None
)

coproduct.addDiscriminatorField(FieldName("who_am_i")) shouldBe SCoproduct(
SObjectInfo("A"),
List(
Schema(SProduct(SObjectInfo("H"), List(FieldName("f1") -> Schema(SInteger), FieldName("who_am_i") -> Schema(SString)))),
Schema(
SProduct(
SObjectInfo("G"),
List(FieldName("f1") -> Schema(SString), FieldName("f2") -> Schema(SString), FieldName("who_am_i") -> Schema(SString))
)
),
Schema(SString)
),
Some(Discriminator("who_am_i", Map.empty))
)
}
}
27 changes: 27 additions & 0 deletions core/src/test/scala/sttp/tapir/generic/SchemaGenericTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,29 @@ class SchemaGenericTest extends AnyFlatSpec with Matchers {
implicitly[Schema[IList]] shouldBe expectedISchema
implicitly[Schema[JList]] shouldBe expectedJSchema
}

it should "generate one-of schema using the given discriminator" in {
implicit val customConf: Configuration = Configuration.default.withDiscriminator("who_am_i")

implicitly[Schema[Entity]].schemaType shouldBe SCoproduct(
SObjectInfo("sttp.tapir.generic.Entity"),
List(
Schema(
SProduct(
SObjectInfo("sttp.tapir.generic.Organization"),
List((FieldName("name"), Schema(SString)), (FieldName("who_am_i"), Schema(SString)))
)
),
Schema(
SProduct(
SObjectInfo("sttp.tapir.generic.Person"),
List((FieldName("first"), Schema(SString)), (FieldName("age"), Schema(SInteger)), (FieldName("who_am_i"), Schema(SString)))
)
)
),
Some(Discriminator("who_am_i", Map.empty))
)
}
}

case class StringValueClass(value: String) extends AnyVal
Expand Down Expand Up @@ -391,3 +414,7 @@ case class JOpt(data: Option[IOpt])

case class IList(i1: List[IList], i2: Int)
case class JList(data: List[IList])

sealed trait Entity
case class Person(first: String, age: Int) extends Entity
case class Organization(name: String) extends Entity
33 changes: 23 additions & 10 deletions doc/endpoint/customtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ Automatic codec derivation usually requires other implicits, such as:
* codecs for individual form fields
* schema of the custom type, through the `Schema[T]` implicit

Note the derivation of e.g. circe json encoders/decoders and tapir schema are separate processes, and must be
hence configured separately.

## Schema derivation

For case classes types, `Schema[_]` values are derived automatically using [Magnolia](https://propensive.com/opensource/magnolia/), given
that schemas are defined for all the case class's fields. It is possible to configure the automatic derivation to use
snake-case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value:
snake_case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value:

```scala mdoc:silent
import sttp.tapir.generic.Configuration
Expand All @@ -117,11 +120,25 @@ be derived automatically.

### Sealed traits / coproducts

Tapir supports schema generation for coproduct types (sealed trait hierarchies) out of the box, but they need to be defined
by hand as `implicit` values. To properly reflect the schema in [OpenAPI](../openapi.md) documentation, a
discriminator object can be specified.
Schema derivation for coproduct types (sealed trait hierarchies) is supported as well. By default, such hierarchies
will be represented as a coproduct which contains a list of child schemas, without any discriminator field.

A discriminator field can be specified for coproducts by providing it in the configuration; this will be only used
during automatic derivation:

```scala mdoc:silent:reset
import sttp.tapir.generic.Configuration

implicit val customConfiguration: Configuration =
Configuration.default.withDiscriminator("who_am_i")
```

Alternatively, derived schemas can be customised (see below), and a discriminator can be added by calling
the `SchemaType.SCoproduct.addDiscriminatorField(name, schema, mapingOverride)` method.

For example, given following coproduct:
Finally, if the discriminator is a field that's defined on the base trait (and hence in each implementation), the
schemas can be specified using `Schema.oneOfUsingField`, for example (this will also generate the appropriate
mapping overrides):

```scala mdoc:silent:reset
sealed trait Entity {
Expand All @@ -133,17 +150,13 @@ case class Person(firstName:String, lastName:String) extends Entity {
case class Organization(name: String) extends Entity {
def kind: String = "org"
}
```

The schema may look like this:

```scala mdoc:silent
import sttp.tapir._

val sPerson = implicitly[Schema[Person]]
val sOrganization = implicitly[Schema[Organization]]
implicit val sEntity: Schema[Entity] =
Schema.oneOf[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization)
Schema.oneOfUsingField[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization)
```

## Customising derived schemas
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ class VerifyYamlTest extends AnyFunSuite with Matchers {
test("should match the expected yaml when using coproduct types with discriminator") {
val sPerson = implicitly[Schema[Person]]
val sOrganization = implicitly[Schema[Organization]]
implicit val sEntity: Schema[Entity] = Schema.oneOf[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization)
implicit val sEntity: Schema[Entity] =
Schema.oneOfUsingField[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization)

val expectedYaml = loadYaml("expected_coproduct_discriminator.yml")
val endpoint_wit_sealed_trait: Endpoint[Unit, Unit, Entity, Any] = endpoint
Expand All @@ -264,7 +265,8 @@ class VerifyYamlTest extends AnyFunSuite with Matchers {
val sPerson = implicitly[Schema[Person]]
val sOrganization = implicitly[Schema[Organization]]
@silent("never used") // it is used
implicit val sEntity: Schema[Entity] = Schema.oneOf[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization)
implicit val sEntity: Schema[Entity] =
Schema.oneOfUsingField[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization)

val expectedYaml = loadYaml("expected_coproduct_discriminator_nested.yml")
val endpoint_wit_sealed_trait: Endpoint[Unit, Unit, NestedEntity, Any] = endpoint
Expand Down
33 changes: 23 additions & 10 deletions generated-doc/out/endpoint/customtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ Automatic codec derivation usually requires other implicits, such as:
* codecs for individual form fields
* schema of the custom type, through the `Schema[T]` implicit

Note the derivation of e.g. circe json encoders/decoders and tapir schema are separate processes, and must be
hence configured separately.

## Schema derivation

For case classes types, `Schema[_]` values are derived automatically using [Magnolia](https://propensive.com/opensource/magnolia/), given
that schemas are defined for all the case class's fields. It is possible to configure the automatic derivation to use
snake-case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value:
snake_case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value:

```scala
import sttp.tapir.generic.Configuration
Expand All @@ -117,11 +120,25 @@ be derived automatically.

### Sealed traits / coproducts

Tapir supports schema generation for coproduct types (sealed trait hierarchies) out of the box, but they need to be defined
by hand as `implicit` values. To properly reflect the schema in [OpenAPI](../openapi.md) documentation, a
discriminator object can be specified.
Schema derivation for coproduct types (sealed trait hierarchies) is supported as well. By default, such hierarchies
will be represented as a coproduct which contains a list of child schemas, without any discriminator field.

A discriminator field can be specified for coproducts by providing it in the configuration; this will be only used
during automatic derivation:

```scala
import sttp.tapir.generic.Configuration

implicit val customConfiguration: Configuration =
Configuration.default.withDiscriminator("who_am_i")
```

Alternatively, derived schemas can be customised (see below), and a discriminator can be added by calling
the `SchemaType.SCoproduct.addDiscriminatorField(name, schema, mapingOverride)` method.

For example, given following coproduct:
Finally, if the discriminator is a field that's defined on the base trait (and hence in each implementation), the
schemas can be specified using `Schema.oneOfUsingField`, for example (this will also generate the appropriate
mapping overrides):

```scala
sealed trait Entity {
Expand All @@ -133,17 +150,13 @@ case class Person(firstName:String, lastName:String) extends Entity {
case class Organization(name: String) extends Entity {
def kind: String = "org"
}
```

The schema may look like this:

```scala
import sttp.tapir._

val sPerson = implicitly[Schema[Person]]
val sOrganization = implicitly[Schema[Organization]]
implicit val sEntity: Schema[Entity] =
Schema.oneOf[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization)
Schema.oneOfUsingField[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization)
```

## Customising derived schemas
Expand Down

0 comments on commit 4c48178

Please sign in to comment.