Skip to content

Commit

Permalink
#315: add discriminator to generic derivation configuration, as in circe
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Oct 22, 2020
1 parent b3a12b7 commit d5e555f
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 36 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/sttp/tapir/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ object Schema extends SchemaMagnoliaDerivation with LowPrioritySchema {

implicit def schemaForMap[V: Schema]: Schema[Map[String, V]] = macro SchemaMapMacro.schemaForMap[Map[String, V], V]

def oneOf[E, V](extractor: E => V, asString: V => String)(mapping: (V, Schema[_])*): Schema[E] = macro oneOfMacro[E, V]
def oneOfUsingField[E, V](extractor: E => V, asString: V => String)(mapping: (V, Schema[_])*): Schema[E] = macro oneOfMacro[E, V]
}

trait LowPrioritySchema {
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/scala/sttp/tapir/SchemaType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ object SchemaType {
}
case class SCoproduct(info: SObjectInfo, schemas: List[Schema[_]], discriminator: Option[Discriminator]) extends SObject {
override def show: String = "oneOf:" + schemas.mkString(",")

def addDiscriminatorField[D](
discriminatorName: FieldName,
discriminatorSchema: Schema[D] = Schema(SchemaType.SString),
discriminatorMappingOverride: Map[String, SRef] = Map.empty
): SCoproduct = {
SCoproduct(
info,
schemas.map {
case s @ Schema(st: SchemaType.SProduct, _, _, _, _) =>
s.copy(schemaType = st.copy(fields = st.fields.toSeq :+ (discriminatorName -> discriminatorSchema)))
case s => s
},
Some(Discriminator(discriminatorName.encodedName, discriminatorMappingOverride))
)
}
}
case class SOpenProduct(info: SObjectInfo, valueSchema: Schema[_]) extends SObject {
override def show: String = s"map"
Expand Down
16 changes: 5 additions & 11 deletions core/src/main/scala/sttp/tapir/generic/Configuration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,14 @@ package sttp.tapir.generic

import java.util.regex.Pattern

final case class Configuration(toEncodedName: String => String) {
def withSnakeCaseMemberNames: Configuration =
copy(
toEncodedName = Configuration.snakeCaseTransformation
)

def withKebabCaseMemberNames: Configuration =
copy(
toEncodedName = Configuration.kebabCaseTransformation
)
final case class Configuration(toEncodedName: String => String, discriminator: Option[String]) {
def withSnakeCaseMemberNames: Configuration = copy(toEncodedName = Configuration.snakeCaseTransformation)
def withKebabCaseMemberNames: Configuration = copy(toEncodedName = Configuration.kebabCaseTransformation)
def withDiscriminator(d: String): Configuration = copy(discriminator = Some(d))
}

object Configuration {
implicit val default: Configuration = Configuration(Predef.identity)
implicit val default: Configuration = Configuration(Predef.identity, None)

private val basePattern: Pattern = Pattern.compile("([A-Z]+)([A-Z][a-z])")
private val swapPattern: Pattern = Pattern.compile("([a-z\\d])([A-Z])")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,13 @@ trait SchemaMagnoliaDerivation {
private def isDeprecated(annotations: Seq[Any]): Boolean =
annotations.collectFirst { case _: deprecated => true } getOrElse false

def dispatch[T](ctx: SealedTrait[Schema, T]): Schema[T] = {
Schema(SCoproduct(typeNameToObjectInfo(ctx.typeName, ctx.annotations), ctx.subtypes.map(_.typeclass).toList, None))
def dispatch[T](ctx: SealedTrait[Schema, T])(implicit genericDerivationConfig: Configuration): Schema[T] = {
val baseCoproduct = SCoproduct(typeNameToObjectInfo(ctx.typeName, ctx.annotations), ctx.subtypes.map(_.typeclass).toList, None)
val coproduct = genericDerivationConfig.discriminator match {
case Some(d) => baseCoproduct.addDiscriminatorField(FieldName(d))
case None => baseCoproduct
}
Schema(coproduct)
}

implicit def schemaForCaseClass[T]: Derived[Schema[T]] = macro MagnoliaDerivedMacro.derivedGen[T]
Expand Down
27 changes: 27 additions & 0 deletions core/src/test/scala/sttp/tapir/SchemaTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,31 @@ class SchemaTest extends AnyFlatSpec with Matchers {
openProductSchema
.modifyUnsafe[Nothing]()(_.description("test")) shouldBe openProductSchema.description("test")
}

it should "generate one-of schema using the given discriminator" in {
val coproduct = SCoproduct(
SObjectInfo("A"),
List(
Schema(SProduct(SObjectInfo("H"), List(FieldName("f1") -> Schema(SInteger)))),
Schema(SProduct(SObjectInfo("G"), List(FieldName("f1") -> Schema(SString), FieldName("f2") -> Schema(SString)))),
Schema(SString)
),
None
)

coproduct.addDiscriminatorField(FieldName("who_am_i")) shouldBe SCoproduct(
SObjectInfo("A"),
List(
Schema(SProduct(SObjectInfo("H"), List(FieldName("f1") -> Schema(SInteger), FieldName("who_am_i") -> Schema(SString)))),
Schema(
SProduct(
SObjectInfo("G"),
List(FieldName("f1") -> Schema(SString), FieldName("f2") -> Schema(SString), FieldName("who_am_i") -> Schema(SString))
)
),
Schema(SString)
),
Some(Discriminator("who_am_i", Map.empty))
)
}
}
27 changes: 27 additions & 0 deletions core/src/test/scala/sttp/tapir/generic/SchemaGenericTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,29 @@ class SchemaGenericTest extends AnyFlatSpec with Matchers {
implicitly[Schema[IList]] shouldBe expectedISchema
implicitly[Schema[JList]] shouldBe expectedJSchema
}

it should "generate one-of schema using the given discriminator" in {
implicit val customConf: Configuration = Configuration.default.withDiscriminator("who_am_i")

implicitly[Schema[Entity]].schemaType shouldBe SCoproduct(
SObjectInfo("sttp.tapir.generic.Entity"),
List(
Schema(
SProduct(
SObjectInfo("sttp.tapir.generic.Organization"),
List((FieldName("name"), Schema(SString)), (FieldName("who_am_i"), Schema(SString)))
)
),
Schema(
SProduct(
SObjectInfo("sttp.tapir.generic.Person"),
List((FieldName("first"), Schema(SString)), (FieldName("age"), Schema(SInteger)), (FieldName("who_am_i"), Schema(SString)))
)
)
),
Some(Discriminator("who_am_i", Map.empty))
)
}
}

case class StringValueClass(value: String) extends AnyVal
Expand Down Expand Up @@ -391,3 +414,7 @@ case class JOpt(data: Option[IOpt])

case class IList(i1: List[IList], i2: Int)
case class JList(data: List[IList])

sealed trait Entity
case class Person(first: String, age: Int) extends Entity
case class Organization(name: String) extends Entity
30 changes: 20 additions & 10 deletions doc/endpoint/customtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ Automatic codec derivation usually requires other implicits, such as:

For case classes types, `Schema[_]` values are derived automatically using [Magnolia](https://propensive.com/opensource/magnolia/), given
that schemas are defined for all the case class's fields. It is possible to configure the automatic derivation to use
snake-case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value:
snake_case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value:

```scala mdoc:silent
import sttp.tapir.generic.Configuration
Expand All @@ -117,11 +117,25 @@ be derived automatically.

### Sealed traits / coproducts

Tapir supports schema generation for coproduct types (sealed trait hierarchies) out of the box, but they need to be defined
by hand as `implicit` values. To properly reflect the schema in [OpenAPI](../openapi.md) documentation, a
discriminator object can be specified.
Schema derivation for coproduct types (sealed trait hierarchies) is supported as well. By default, such hierarchies
will be represented as a coproduct which contains a list of child schemas, without any discriminator field.

For example, given following coproduct:
A discriminator field can be specified for coproducts by providing it in the configuration; this will be only used
during automatic derivation:

```scala mdoc:silent:reset
import sttp.tapir.generic.Configuration

implicit val customConfiguration: Configuration =
Configuration.default.withDiscriminator("who_am_i")
```

Alternatively, derived schemas can be customised (see below), and a discriminator can be added by calling
the `SchemaType.SCoproduct.addDiscriminatorField(name, schema, mapingOverride)` method.

Finally, if the discriminator is a field that's defined on the base trait (and hence in each implementation), the
schemas can be specified using `Schema.oneOfUsingField`, for example (this will also generate the appropriate
mapping overrides):

```scala mdoc:silent:reset
sealed trait Entity {
Expand All @@ -133,17 +147,13 @@ case class Person(firstName:String, lastName:String) extends Entity {
case class Organization(name: String) extends Entity {
def kind: String = "org"
}
```

The schema may look like this:

```scala mdoc:silent
import sttp.tapir._

val sPerson = implicitly[Schema[Person]]
val sOrganization = implicitly[Schema[Organization]]
implicit val sEntity: Schema[Entity] =
Schema.oneOf[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization)
Schema.oneOfUsingField[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization)
```

## Customising derived schemas
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ class VerifyYamlTest extends AnyFunSuite with Matchers {
test("should match the expected yaml when using coproduct types with discriminator") {
val sPerson = implicitly[Schema[Person]]
val sOrganization = implicitly[Schema[Organization]]
implicit val sEntity: Schema[Entity] = Schema.oneOf[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization)
implicit val sEntity: Schema[Entity] =
Schema.oneOfUsingField[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization)

val expectedYaml = loadYaml("expected_coproduct_discriminator.yml")
val endpoint_wit_sealed_trait: Endpoint[Unit, Unit, Entity, Any] = endpoint
Expand All @@ -264,7 +265,8 @@ class VerifyYamlTest extends AnyFunSuite with Matchers {
val sPerson = implicitly[Schema[Person]]
val sOrganization = implicitly[Schema[Organization]]
@silent("never used") // it is used
implicit val sEntity: Schema[Entity] = Schema.oneOf[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization)
implicit val sEntity: Schema[Entity] =
Schema.oneOfUsingField[Entity, String](_.name, _.toString)("john" -> sPerson, "sml" -> sOrganization)

val expectedYaml = loadYaml("expected_coproduct_discriminator_nested.yml")
val endpoint_wit_sealed_trait: Endpoint[Unit, Unit, NestedEntity, Any] = endpoint
Expand Down
30 changes: 20 additions & 10 deletions generated-doc/out/endpoint/customtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ Automatic codec derivation usually requires other implicits, such as:

For case classes types, `Schema[_]` values are derived automatically using [Magnolia](https://propensive.com/opensource/magnolia/), given
that schemas are defined for all the case class's fields. It is possible to configure the automatic derivation to use
snake-case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value:
snake_case, kebab-case or a custom field naming policy, by providing an implicit `sttp.tapir.generic.Configuration` value:

```scala
import sttp.tapir.generic.Configuration
Expand All @@ -117,11 +117,25 @@ be derived automatically.

### Sealed traits / coproducts

Tapir supports schema generation for coproduct types (sealed trait hierarchies) out of the box, but they need to be defined
by hand as `implicit` values. To properly reflect the schema in [OpenAPI](../openapi.md) documentation, a
discriminator object can be specified.
Schema derivation for coproduct types (sealed trait hierarchies) is supported as well. By default, such hierarchies
will be represented as a coproduct which contains a list of child schemas, without any discriminator field.

For example, given following coproduct:
A discriminator field can be specified for coproducts by providing it in the configuration; this will be only used
during automatic derivation:

```scala
import sttp.tapir.generic.Configuration

implicit val customConfiguration: Configuration =
Configuration.default.withDiscriminator("who_am_i")
```

Alternatively, derived schemas can be customised (see below), and a discriminator can be added by calling
the `SchemaType.SCoproduct.addDiscriminatorField(name, schema, mapingOverride)` method.

Finally, if the discriminator is a field that's defined on the base trait (and hence in each implementation), the
schemas can be specified using `Schema.oneOfUsingField`, for example (this will also generate the appropriate
mapping overrides):

```scala
sealed trait Entity {
Expand All @@ -133,17 +147,13 @@ case class Person(firstName:String, lastName:String) extends Entity {
case class Organization(name: String) extends Entity {
def kind: String = "org"
}
```

The schema may look like this:

```scala
import sttp.tapir._

val sPerson = implicitly[Schema[Person]]
val sOrganization = implicitly[Schema[Organization]]
implicit val sEntity: Schema[Entity] =
Schema.oneOf[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization)
Schema.oneOfUsingField[Entity, String](_.kind, _.toString)("person" -> sPerson, "org" -> sOrganization)
```

## Customising derived schemas
Expand Down

0 comments on commit d5e555f

Please sign in to comment.