Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix derived enumeration schema validators #1989

Merged
merged 3 commits into from
Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ object SchemaEnumerationMacro {
case Nil => c.abort(c.enclosingPosition, s"Invalid enum name: ${weakTypeT.toString}")
}

val validator = q"_root_.sttp.tapir.Validator.enumeration($enumeration.values.toList)"
val validator =
q"_root_.sttp.tapir.Validator.enumeration($enumeration.values.toList, v => Option(v), Some(sttp.tapir.Schema.SName(${enumNameComponents
.mkString(".")})))"
val schemaAnnotations = c.inferImplicitValue(appliedType(SchemaAnnotations, weakTypeT))

c.Expr[Schema[T]](q"$schemaAnnotations.enrich(Schema.string[$weakTypeT].validate($validator))")
Expand Down
22 changes: 16 additions & 6 deletions core/src/main/scala-3/sttp/tapir/macros/SchemaMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,13 @@ object SchemaCompanionMacros {
)
val sname = SName(SNameMacros.typeFullName[E], ${ Expr(typeParams) })
val subtypes = mappingAsList.map(_._2)
Schema(SCoproduct[E](subtypes, _root_.scala.Some(discriminator)) { e =>
val ee = $extractor(e)
mappingAsMap.get(ee).map(s => SchemaWithValue(s.asInstanceOf[Schema[Any]], e))
}, Some(sname))
Schema(
SCoproduct[E](subtypes, _root_.scala.Some(discriminator)) { e =>
val ee = $extractor(e)
mappingAsMap.get(ee).map(s => SchemaWithValue(s.asInstanceOf[Schema[Any]], e))
},
Some(sname)
)
}
}

Expand All @@ -220,16 +223,23 @@ object SchemaCompanionMacros {
val enumerationPath = tpe.show.split("\\.").dropRight(1).mkString(".")
val enumeration = Symbol.requiredModule(enumerationPath)

val sName = '{ Some(Schema.SName(${ Expr(enumerationPath) })) }

'{
SchemaAnnotations
.derived[T]
.enrich(
Schema
.string[T]
.validate(Validator.enumeration(${ Ref(enumeration).asExprOf[scala.Enumeration] }.values.toList.asInstanceOf[List[T]]))
.validate(
Validator.enumeration(
${ Ref(enumeration).asExprOf[scala.Enumeration] }.values.toList.asInstanceOf[List[T]],
v => Option(v),
$sName
)
)
)
}
}
}

}
24 changes: 21 additions & 3 deletions core/src/test/scala/sttp/tapir/generic/SchemaGenericAutoTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,29 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers {

it should "derive schema for enumeration and enrich schema" in {
val expected = Schema[Countries.Country](SString())
.validate(Validator.enumeration[Countries.Country](Countries.values.toList))
.validate(
Validator.enumeration[Countries.Country](
Countries.values.toList,
(v: Countries.Country) => Option(v),
Some(SName("sttp.tapir.generic.Countries"))
)
)
.description("country")
.default(Countries.PL)
.name(SName("country-encoded-name"))
implicitly[Schema[Countries.Country]] shouldBe expected

val actual = implicitly[Schema[Countries.Country]]

(actual.validator, expected.validator) match {
case (Validator.Enumeration(va, Some(ea), Some(na)), Validator.Enumeration(ve, Some(ee), Some(ne))) =>
va shouldBe ve
ea(Countries.PL) shouldBe ee(Countries.PL)
na shouldBe ne
case _ => Assertions.fail()
}
actual.description shouldBe expected.description
actual.default shouldBe expected.default
actual.name shouldBe expected.name
}
}

Expand Down Expand Up @@ -482,5 +500,5 @@ case object UnknownEntity extends Entity
@encodedName("country-encoded-name")
object Countries extends Enumeration {
type Country = Value
val PL, NL, RUS = Value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🇺🇦

val PL, NL = Value
}
2 changes: 2 additions & 0 deletions doc/endpoint/integrations.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ object Color extends Enumeration with EnumHelper {
}
```

Tapir `Schema` for any `Enumeration.Value` can also be auto or semi-auto derived using `import sttp.tapir.generic.auto._` or `Schema.derivedEnumerationValue`.

## NewType integration

If you use [scala-newtype](https://github.com/estatico/scala-newtype), the `tapir-newtype` module will provide implicit codecs and
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
openapi: 3.0.3
info:
title: Numbers
version: '1.0'
paths:
/numbers:
get:
operationId: getNumbers
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/Number'
required: true
responses:
'200':
description: ''
'400':
description: 'Invalid value for: body'
content:
text/plain:
schema:
type: string
components:
schemas:
Number:
required:
- value
type: object
properties:
value:
$ref: '#/components/schemas/Numbers'
Numbers:
type: string
enum:
- One
- Two
- Three
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import sttp.model.{Method, StatusCode}
import sttp.tapir.Schema.SName
import sttp.tapir.Schema.annotations.description
import sttp.tapir.docs.apispec.DocsExtension
import sttp.tapir.docs.openapi.VerifyYamlTest.Problem
import sttp.tapir.docs.openapi.VerifyYamlTest._
import sttp.tapir.docs.openapi.dtos.VerifyYamlTestData._
import sttp.tapir.docs.openapi.dtos.VerifyYamlTestData2._
import sttp.tapir.docs.openapi.dtos.Book
Expand Down Expand Up @@ -642,6 +642,17 @@ class VerifyYamlTest extends AnyFunSuite with Matchers {
noIndentation(actualYaml) shouldBe expectedYaml
}

test("should contain named schema component and values for enumeration") {
implicit val numberCodec: io.circe.Codec[Number] = null

val actualYaml = OpenAPIDocsInterpreter()
.toOpenAPI(endpoint.in("numbers").in(jsonBody[Number]), Info("Numbers", "1.0"))
.toYaml

val expectedYaml = load("expected_enumeration_values.yml")

noIndentation(actualYaml) shouldBe expectedYaml
}
}

object VerifyYamlTest {
Expand All @@ -655,4 +666,10 @@ object VerifyYamlTest {
)
)
}

object Numbers extends Enumeration {
val One, Two, Three = Value
}

case class Number(value: Numbers.Value)
}