Skip to content

Commit

Permalink
Merge pull request #1818 from softwaremill/annotations-enumeratum
Browse files Browse the repository at this point in the history
macro for collecting schema annotations
  • Loading branch information
adamw authored Feb 2, 2022
2 parents 0709088 + 30c3d1f commit 90fe9ff
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 24 deletions.
34 changes: 34 additions & 0 deletions core/src/main/scala-2/sttp/tapir/internal/SchemaAnnotations.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package sttp.tapir.internal

import sttp.tapir.Schema.SName
import sttp.tapir.{Schema, Validator}

final case class SchemaAnnotations[T](
description: Option[String],
encodedExample: Option[Any],
default: Option[T],
format: Option[String],
deprecated: Option[Boolean],
encodedName: Option[String],
validate: Option[Validator[T]]
) {
private case class SchemaEnrich(current: Schema[T]) {
def optionally(f: Schema[T] => Option[Schema[T]]): SchemaEnrich = f(current).map(SchemaEnrich.apply).getOrElse(this)
}

def enrich(s: Schema[T]): Schema[T] = {
SchemaEnrich(s)
.optionally(s => description.map(s.description(_)))
.optionally(s => encodedExample.map(s.encodedExample(_)))
.optionally(s => default.map(s.default(_)))
.optionally(s => format.map(s.format(_)))
.optionally(s => deprecated.map(s.deprecated(_)))
.optionally(s => encodedName.map(en => s.name(SName(en))))
.optionally(s => validate.map(s.validate))
.current
}
}

object SchemaAnnotations {
implicit def schemaAnnotations[T]: SchemaAnnotations[T] = macro SchemaAnnotationsMacro.derived[T]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package sttp.tapir.internal

import scala.reflect.macros.blackbox

object SchemaAnnotationsMacro {
def derived[T: c.WeakTypeTag](c: blackbox.Context): c.Expr[SchemaAnnotations[T]] = {
import c.universe._

val DescriptionAnn = typeOf[sttp.tapir.Schema.annotations.description]
val EncodedExampleAnn = typeOf[sttp.tapir.Schema.annotations.encodedExample]
val DefaultAnn = typeOf[sttp.tapir.Schema.annotations.default[_]]
val FormatAnn = typeOf[sttp.tapir.Schema.annotations.format]
val DeprecatedAnn = typeOf[sttp.tapir.Schema.annotations.deprecated]
val EncodedNameAnn = typeOf[sttp.tapir.Schema.annotations.encodedName]
val ValidateAnn = typeOf[sttp.tapir.Schema.annotations.validate[_]]

val annotations = weakTypeOf[T].typeSymbol.annotations

val firstArg: Annotation => Tree = a => a.tree.children.tail.head

val description = annotations.collectFirst { case ann if ann.tree.tpe <:< DescriptionAnn => firstArg(ann) }
val encodedExample = annotations.collectFirst { case ann if ann.tree.tpe <:< EncodedExampleAnn => firstArg(ann) }
val default = annotations.collectFirst { case ann if ann.tree.tpe <:< DefaultAnn => firstArg(ann) }
val format = annotations.collectFirst { case ann if ann.tree.tpe <:< FormatAnn => firstArg(ann) }
val deprecated = annotations.collectFirst { case ann if ann.tree.tpe <:< DeprecatedAnn => q"""true""" }
val encodedName = annotations.collectFirst { case ann if ann.tree.tpe <:< EncodedNameAnn => firstArg(ann) }
val validator = annotations.collectFirst { case ann if ann.tree.tpe <:< ValidateAnn => firstArg(ann) }

c.Expr[SchemaAnnotations[T]](
q"""_root_.sttp.tapir.internal.SchemaAnnotations.apply($description, $encodedExample, $default, $format, $deprecated, $encodedName, $validator)"""
)
}
}
19 changes: 18 additions & 1 deletion core/src/test/scala-2/sttp/tapir/SchemaMacroTest2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import sttp.tapir.Schema.SName
import sttp.tapir.SchemaMacroTestData2.ValueClasses.DoubleValue
import sttp.tapir.SchemaMacroTestData2.{Type, ValueClasses}
import sttp.tapir.SchemaMacroTestData2.{MyString, Type, ValueClasses}
import sttp.tapir.SchemaType.{SArray, SProduct, SString}
import sttp.tapir.TestUtil.field
import sttp.tapir.generic.auto._
import sttp.tapir.internal.SchemaAnnotations

// tests which pass only on Scala2
class SchemaMacroTest2 extends AnyFlatSpec with Matchers {
Expand Down Expand Up @@ -49,4 +50,20 @@ class SchemaMacroTest2 extends AnyFlatSpec with Matchers {
val ex = the[IllegalArgumentException] thrownBy schemaForCaseClass[Type.MapType]
ex.getMessage.contains("requirement failed: Cannot derive schema for generic value class") shouldBe true
}

it should "derive schema annotations and enrich schema" in {
val baseSchema = Schema.string[MyString]

val enriched = implicitly[SchemaAnnotations[MyString]].enrich(baseSchema)

enriched shouldBe Schema
.string[MyString]
.description("my-string")
.encodedExample("encoded-example")
.default(MyString("default"))
.format("utf8")
.deprecated(true)
.name(SName("encoded-name"))
.validate(Validator.pass[MyString])
}
}
11 changes: 11 additions & 0 deletions core/src/test/scala-2/sttp/tapir/SchemaMacroTestData2.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package sttp.tapir

import sttp.tapir.Schema.annotations._

object SchemaMacroTestData2 {
object ValueClasses {
case class UserName(name: String) extends AnyVal
Expand All @@ -15,4 +17,13 @@ object SchemaMacroTestData2 {
final case class Num[N <: AnyVal: Numeric](n: N) extends Type
final case class MapType(obj: Map[String, Type]) extends Type
}

@description("my-string")
@encodedExample("encoded-example")
@default[MyString](MyString("default"))
@format("utf8")
@Schema.annotations.deprecated
@encodedName("encoded-name")
@validate[MyString](Validator.pass[MyString])
case class MyString(value: String)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
openapi: 3.0.3
info:
title: Numbers
version: '1.0'
paths:
/numbers:
get:
operationId: getNumbers
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/NumberWithMsg'
required: true
responses:
'200':
description: ''
'400':
description: 'Invalid value for: body'
content:
text/plain:
schema:
type: string
components:
schemas:
MyNumber:
type: integer
description: |-
* 1 - One
* 2 - Two
* 3 - Three
enum:
- 1
- 2
- 3
NumberWithMsg:
required:
- number
- msg
type: object
properties:
number:
$ref: '#/components/schemas/MyNumber'
msg:
type: string
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package sttp.tapir.docs.openapi
import io.circe.generic.auto._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import sttp.tapir.Schema.annotations.description
import sttp.tapir._
import sttp.tapir.docs.openapi.VerifyYamlEnumeratumTest.Enumeratum
import sttp.tapir.generic.auto._
import sttp.tapir.json.circe.jsonBody
import sttp.tapir.openapi.Info
import sttp.tapir.openapi.circe.yaml._
import VerifyYamlEnumeratumTest._

class VerifyYamlEnumeratumTest extends AnyFunSuite with Matchers {
test("use enumeratum validator for array elements") {
Expand All @@ -24,11 +25,22 @@ class VerifyYamlEnumeratumTest extends AnyFunSuite with Matchers {

actualYamlNoIndent shouldBe expectedYaml
}

test("add metadata from annotations on enumeratum") {
import sttp.tapir.codec.enumeratum._
val expectedYaml = load("validator/expected_valid_enumeratum_with_metadata.yml")

val actualYaml =
OpenAPIDocsInterpreter().toOpenAPI(endpoint.in("numbers").in(jsonBody[Enumeratum.NumberWithMsg]), Info("Numbers", "1.0")).toYaml

noIndentation(actualYaml) shouldBe expectedYaml
}
}

object VerifyYamlEnumeratumTest {
object Enumeratum {
import enumeratum.{Enum, EnumEntry}
import enumeratum.values.{IntEnum, IntEnumEntry}

case class FruitWithEnum(fruit: String, amount: Int, fruitType: List[FruitType])

Expand All @@ -39,5 +51,17 @@ object VerifyYamlEnumeratumTest {
case object PEAR extends FruitType
override def values: scala.collection.immutable.IndexedSeq[FruitType] = findValues
}

@description("* 1 - One\n* 2 - Two\n* 3 - Three")
sealed abstract class MyNumber(val value: Int) extends IntEnumEntry

object MyNumber extends IntEnum[MyNumber] {
case object One extends MyNumber(1)
case object Two extends MyNumber(2)
case object Three extends MyNumber(3)
override def values = findValues
}

case class NumberWithMsg(number: MyNumber, msg: String)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ import enumeratum._
import enumeratum.values._
import sttp.tapir.Schema.SName
import sttp.tapir._
import sttp.tapir.internal.SchemaAnnotations

trait TapirCodecEnumeratum {
// Regular enums

def validatorEnumEntry[E <: EnumEntry](implicit enum: Enum[E]): Validator[E] =
Validator.enumeration(enum.values.toList, v => Some(v.entryName), Some(SName(fullName(`enum`))))

implicit def schemaForEnumEntry[E <: EnumEntry](implicit enum: Enum[E]): Schema[E] =
Schema[E](SchemaType.SString()).validate(validatorEnumEntry)
implicit def schemaForEnumEntry[E <: EnumEntry](implicit annotations: SchemaAnnotations[E], enum: Enum[E]): Schema[E] =
annotations.enrich(Schema[E](SchemaType.SString()).validate(validatorEnumEntry))

def plainCodecEnumEntryUsing[E <: EnumEntry](f: String => Option[E])(implicit enum: Enum[E]): Codec[String, E, CodecFormat.TextPlain] =
Codec.string
Expand All @@ -36,23 +37,23 @@ trait TapirCodecEnumeratum {
def validatorValueEnumEntry[T, E <: ValueEnumEntry[T]](implicit enum: ValueEnum[T, E]): Validator[E] =
Validator.enumeration(enum.values.toList, v => Some(v.value), Some(SName(fullName(`enum`))))

implicit def schemaForIntEnumEntry[E <: IntEnumEntry](implicit enum: IntEnum[E]): Schema[E] =
Schema[E](SchemaType.SInteger()).validate(validatorValueEnumEntry[Int, E])
implicit def schemaForIntEnumEntry[E <: IntEnumEntry](implicit annotations: SchemaAnnotations[E], enum: IntEnum[E]): Schema[E] =
annotations.enrich(Schema[E](SchemaType.SInteger()).validate(validatorValueEnumEntry[Int, E]))

implicit def schemaForLongEnumEntry[E <: LongEnumEntry](implicit enum: LongEnum[E]): Schema[E] =
Schema[E](SchemaType.SInteger()).validate(validatorValueEnumEntry[Long, E])
implicit def schemaForLongEnumEntry[E <: LongEnumEntry](implicit annotations: SchemaAnnotations[E], enum: LongEnum[E]): Schema[E] =
annotations.enrich(Schema[E](SchemaType.SInteger()).validate(validatorValueEnumEntry[Long, E]))

implicit def schemaForShortEnumEntry[E <: ShortEnumEntry](implicit enum: ShortEnum[E]): Schema[E] =
Schema[E](SchemaType.SInteger()).validate(validatorValueEnumEntry[Short, E])
implicit def schemaForShortEnumEntry[E <: ShortEnumEntry](implicit annotations: SchemaAnnotations[E], enum: ShortEnum[E]): Schema[E] =
annotations.enrich(Schema[E](SchemaType.SInteger()).validate(validatorValueEnumEntry[Short, E]))

implicit def schemaForStringEnumEntry[E <: StringEnumEntry](implicit enum: StringEnum[E]): Schema[E] =
Schema[E](SchemaType.SString()).validate(validatorValueEnumEntry[String, E])
implicit def schemaForStringEnumEntry[E <: StringEnumEntry](implicit annotations: SchemaAnnotations[E], enum: StringEnum[E]): Schema[E] =
annotations.enrich(Schema[E](SchemaType.SString()).validate(validatorValueEnumEntry[String, E]))

implicit def schemaForByteEnumEntry[E <: ByteEnumEntry](implicit enum: ByteEnum[E]): Schema[E] =
Schema[E](SchemaType.SInteger()).validate(validatorValueEnumEntry[Byte, E])
implicit def schemaForByteEnumEntry[E <: ByteEnumEntry](implicit annotations: SchemaAnnotations[E], enum: ByteEnum[E]): Schema[E] =
annotations.enrich(Schema[E](SchemaType.SInteger()).validate(validatorValueEnumEntry[Byte, E]))

implicit def schemaForCharEnumEntry[E <: CharEnumEntry](implicit enum: CharEnum[E]): Schema[E] =
Schema[E](SchemaType.SString()).validate(validatorValueEnumEntry[Char, E])
implicit def schemaForCharEnumEntry[E <: CharEnumEntry](implicit annotations: SchemaAnnotations[E], enum: CharEnum[E]): Schema[E] =
annotations.enrich(Schema[E](SchemaType.SString()).validate(validatorValueEnumEntry[Char, E]))

def plainCodecValueEnumEntry[T, E <: ValueEnumEntry[T]](implicit
enum: ValueEnum[T, E],
Expand Down
Loading

0 comments on commit 90fe9ff

Please sign in to comment.