Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OpenAPI constraints for array schemas #709

Merged
merged 3 commits into from
Aug 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,13 @@ lazy val tests: ProjectMatrix = (projectMatrix in file("tests"))
"io.circe" %% "circe-generic" % Versions.circe,
"com.softwaremill.common" %% "tagging" % "2.2.1",
scalaTest,
"com.softwaremill.macwire" %% "macros" % "2.3.7" % "provided"
),
"com.softwaremill.macwire" %% "macros" % "2.3.7" % "provided",
"com.beachape" %% "enumeratum" % Versions.enumeratum,
"com.beachape" %% "enumeratum-circe" % Versions.enumeratum),
libraryDependencies ++= loggerDependencies
)
.jvmPlatform(scalaVersions = allScalaVersions)
.dependsOn(core, circeJson)
.dependsOn(core, circeJson, enumeratum)

// integrations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,18 @@ private[schema] class TSchemaToOSchema(schemaReferenceMapper: SchemaReferenceMap
)
}

val primitiveValidators = typeData.schema.schemaType match {
case TSchemaType.SArray(_) => asPrimitiveValidators(typeData.validator)
case _ => asPrimitiveValidatorsDeep(typeData.validator)
}
val wholeNumbers = typeData.schema.schemaType match {
case TSchemaType.SInteger => true
case _ => false
}

result
.map(addMetadata(_, typeData.schema))
.map(
addConstraints(_, asPrimitiveValidators(typeData.validator), typeData.schema.schemaType.isInstanceOf[TSchemaType.SInteger.type])
)
.map(addConstraints(_, primitiveValidators, wholeNumbers))
}

private def addMetadata(oschema: OSchema, tschema: TSchema[_]): OSchema = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,27 @@ package object schema {

private[schema] def asPrimitiveValidators(v: Validator[_]): Seq[Validator.Primitive[_]] = {
v match {
case Validator.Mapped(wrapped, _) => asPrimitiveValidators(wrapped)
case Validator.All(validators) => validators.flatMap(asPrimitiveValidators)
case Validator.Any(validators) => validators.flatMap(asPrimitiveValidators)
case Validator.CollectionElements(wrapped, _) => asPrimitiveValidators(wrapped)
case Validator.Product(_) => Nil
case Validator.Coproduct(_) => Nil
case Validator.OpenProduct(_) => Nil
case bv: Validator.Primitive[_] => List(bv)
case Validator.Mapped(wrapped, _) => asPrimitiveValidators(wrapped)
case Validator.All(validators) => validators.flatMap(asPrimitiveValidators)
case Validator.Any(validators) => validators.flatMap(asPrimitiveValidators)
case Validator.CollectionElements(_, _) => Nil
case Validator.Product(_) => Nil
case Validator.Coproduct(_) => Nil
case Validator.OpenProduct(_) => Nil
case bv: Validator.Primitive[_] => List(bv)
}
}

private[schema] def asPrimitiveValidatorsDeep(v: Validator[_]): Seq[Validator.Primitive[_]] = {
v match {
case Validator.Mapped(wrapped, _) => asPrimitiveValidatorsDeep(wrapped)
case Validator.All(validators) => validators.flatMap(asPrimitiveValidatorsDeep)
case Validator.Any(validators) => validators.flatMap(asPrimitiveValidatorsDeep)
case Validator.CollectionElements(mapped, _) => asPrimitiveValidatorsDeep(mapped)
case Validator.Product(_) => Nil
case Validator.Coproduct(_) => Nil
case Validator.OpenProduct(_) => Nil
case bv: Validator.Primitive[_] => List(bv)
}
}

Expand Down
34 changes: 34 additions & 0 deletions docs/openapi-docs/src/test/resources/expected_valid_enum_array.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
openapi: 3.0.1
info:
title: Fruits
version: '1.0'
paths:
/enum-test:
get:
operationId: getEnum-test
responses:
'200':
description: ''
content:
application/json:
schema:
$ref: '#/components/schemas/FruitWithEnum'
components:
schemas:
FruitWithEnum:
required:
- fruit
- amount
type: object
properties:
fruit:
type: string
amount:
type: integer
fruitType:
type: array
items:
type: string
enum:
- APPLE
- PEAR
21 changes: 21 additions & 0 deletions docs/openapi-docs/src/test/resources/expected_valid_int_array.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
openapi: 3.0.1
info:
title: Entities
version: '1.0'
paths:
/:
get:
operationId: getRoot
requestBody:
content:
application/json:
schema:
type: array
items:
type: integer
minimum: 1
maximum: 10
required: false
responses:
'200':
description: ''
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import io.circe.generic.auto._
import sttp.model.{Method, StatusCode}
import sttp.tapir.EndpointIO.Example
import sttp.tapir._
import sttp.tapir.codec.enumeratum.TapirCodecEnumeratum
import sttp.tapir.docs.openapi.dtos.Book
import sttp.tapir.docs.openapi.dtos.a.{Pet => APet}
import sttp.tapir.docs.openapi.dtos.b.{Pet => BPet}
Expand All @@ -20,7 +21,7 @@ import scala.io.Source
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class VerifyYamlTest extends AnyFunSuite with Matchers {
class VerifyYamlTest extends AnyFunSuite with Matchers with TapirCodecEnumeratum {
val all_the_way: Endpoint[(FruitAmount, String), Unit, (FruitAmount, Int), Nothing] = endpoint
.in(("fruit" / path[String] / "amount" / path[Int]).mapTo(FruitAmount))
.in(query[String]("color"))
Expand Down Expand Up @@ -573,6 +574,16 @@ class VerifyYamlTest extends AnyFunSuite with Matchers {
actualYamlNoIndent shouldBe expectedYaml
}

test("render validator for additional properties of array elements") {
val expectedYaml = loadYaml("expected_valid_int_array.yml")

val actualYaml = Validation.in_valid_int_array
.toOpenAPI(Info("Entities", "1.0"))
.toYaml
val actualYamlNoIndent = noIndentation(actualYaml)
actualYamlNoIndent shouldBe expectedYaml
}

test("render enum validator for classes") {
val expectedYaml = loadYaml("expected_valid_enum_class.yml")

Expand Down Expand Up @@ -613,6 +624,16 @@ class VerifyYamlTest extends AnyFunSuite with Matchers {
actualYamlNoIndent shouldBe expectedYaml
}

test("use enum validator for array elements") {
val out_enum_array = endpoint.in(("enum-test")).out(jsonBody[FruitWithEnum])
val expectedYaml = loadYaml("expected_valid_enum_array.yml")

val actualYaml = List(out_enum_array).toOpenAPI(Info("Fruits", "1.0")).toYaml
val actualYamlNoIndent = noIndentation(actualYaml)

actualYamlNoIndent shouldBe expectedYaml
}

test("support example of list and not-list types") {
val expectedYaml = loadYaml("expected_examples_of_list_and_not_list_types.yml")
val actualYaml = endpoint.post
Expand Down
16 changes: 16 additions & 0 deletions tests/src/main/scala/sttp/tapir/tests/FruitAmount.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package sttp.tapir.tests

import enumeratum.EnumEntry
import enumeratum.Enum
import sttp.tapir._
import sttp.tapir.codec.enumeratum.TapirCodecEnumeratum

import scala.collection.immutable

case class FruitAmount(fruit: String, amount: Int)

Expand All @@ -12,8 +17,19 @@ case class ValidFruitAmount(fruit: StringWrapper, amount: IntWrapper)

case class ColorWrapper(color: Color)

case class FruitWithEnum(fruit: String, amount: Int, fruitType: List[FruitType])

sealed trait Entity {
def name: String
}
case class Person(name: String, age: Int) extends Entity
case class Organization(name: String) extends Entity

sealed trait FruitType extends EnumEntry

object FruitType extends Enum[FruitType] {
case object APPLE extends FruitType
case object PEAR extends FruitType

override def values: immutable.IndexedSeq[FruitType] = findValues
}
8 changes: 8 additions & 0 deletions tests/src/main/scala/sttp/tapir/tests/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,14 @@ package object tests {
endpoint.in(jsonBody[ColorWrapper])
}

val in_valid_int_array: Endpoint[List[IntWrapper], Unit, Unit, Nothing] = {
implicit val schemaForIntWrapper: Schema[IntWrapper] = Schema(SchemaType.SInteger)
implicit val encoder: Encoder[IntWrapper] = Encoder.encodeInt.contramap(_.v)
implicit val decode: Decoder[IntWrapper] = Decoder.decodeInt.map(IntWrapper.apply)
implicit val v: Validator[IntWrapper] = Validator.all(Validator.min(1), Validator.max(10)).contramap(_.v)
endpoint.in(jsonBody[List[IntWrapper]])
}

val allEndpoints: Set[Endpoint[_, _, _, _]] = wireSet[Endpoint[_, _, _, _]]
}

Expand Down