diff --git a/build.sbt b/build.sbt index c126c86c77..1cb70420c1 100644 --- a/build.sbt +++ b/build.sbt @@ -457,7 +457,7 @@ lazy val testing: ProjectMatrix = (projectMatrix in file("testing")) .jvmPlatform(scalaVersions = scala2And3Versions) .jsPlatform(scalaVersions = scala2And3Versions, settings = commonJsSettings) .nativePlatform(scalaVersions = List(scala3), settings = commonNativeSettings) - .dependsOn(core) + .dependsOn(core, circeJson % Test) lazy val tests: ProjectMatrix = (projectMatrix in file("tests")) .settings(commonSettings) diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala index ba9b0a0579..93bef28c6d 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala @@ -220,9 +220,10 @@ class ServerInterpreter[R, F[_], B, S]( val statusCode = outputValues.statusCode.getOrElse(defaultStatusCode) val headers = outputValues.headers - outputValues.body match { - case Some(bodyFromHeaders) => ServerResponse(statusCode, headers, Some(bodyFromHeaders(Headers(headers))), Some(output)).unit - case None => ServerResponse(statusCode, headers, None: Option[B], Some(output)).unit + (statusCode, outputValues.body) match { + case (StatusCode.NoContent | StatusCode.NotModified, Some(_)) => monad.error(new IllegalStateException(s"Unexpected response body when status code == $statusCode")) + case (_, Some(bodyFromHeaders)) => ServerResponse(statusCode, headers, Some(bodyFromHeaders(Headers(headers))), Some(output)).unit + case (_, None) => ServerResponse(statusCode, headers, None: Option[B], Some(output)).unit } } } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index e82f8aa2dd..a5f501eb27 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -769,6 +769,34 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( r.code shouldBe StatusCode.InternalServerError r.body shouldBe Symbol("left") } + }, + testServer( + "fail when status is 204 or 304, but there's a body", + NonEmptyList.of( + route(List( + endpoint.in("no_content").out(jsonBody[Unit]).out(statusCode(StatusCode.NoContent)).serverLogicSuccess[F](_ => pureResult(())), + endpoint.in("not_modified").out(jsonBody[Unit]).out(statusCode(StatusCode.NotModified)).serverLogicSuccess[F](_ => pureResult(())), + endpoint + .in("one_of") + .in(query[String]("select_err")) + .errorOut( + sttp.tapir.oneOf[ErrorInfo]( + oneOfVariant(statusCode(StatusCode.NotFound).and(jsonBody[NotFound])), + oneOfVariant(statusCode(StatusCode.NoContent).and(jsonBody[NoContentData])), + ) + ) + .serverLogic[F] { selectErr => + if (selectErr == "no_content") + pureResult[F, Either[ErrorInfo, Unit]](Left(NoContentData("error"))) + else + pureResult[F, Either[ErrorInfo, Unit]](Left(NotFound("error"))) + } + ))) + ) { (backend, baseUri) => + basicRequest.get(uri"$baseUri/no_content").send(backend).map(_.code shouldBe StatusCode.InternalServerError) >> + basicRequest.get(uri"$baseUri/not_modified").send(backend).map(_.code shouldBe StatusCode.InternalServerError) >> + basicRequest.get(uri"$baseUri/one_of?select_err=no_content").send(backend).map(_.code shouldBe StatusCode.InternalServerError) >> + basicRequest.get(uri"$baseUri/one_of?select_err=not_found").send(backend).map(_.code shouldBe StatusCode.NotFound) } ) @@ -787,3 +815,7 @@ object Animal extends Enum[Animal] with TapirCodecEnumeratum { override def values = findValues } + +sealed trait ErrorInfo +case class NotFound(what: String) extends ErrorInfo +case class NoContentData(msg: String) extends ErrorInfo diff --git a/testing/src/main/scala/sttp/tapir/testing/EndpointVerificationError.scala b/testing/src/main/scala/sttp/tapir/testing/EndpointVerificationError.scala index a054891de6..1065ba118e 100644 --- a/testing/src/main/scala/sttp/tapir/testing/EndpointVerificationError.scala +++ b/testing/src/main/scala/sttp/tapir/testing/EndpointVerificationError.scala @@ -2,6 +2,7 @@ package sttp.tapir.testing import sttp.model.Method import sttp.tapir.AnyEndpoint +import sttp.model.StatusCode sealed trait EndpointVerificationError @@ -52,3 +53,17 @@ case class IncorrectPathsError(e: AnyEndpoint, at: Int) extends EndpointVerifica case class DuplicatedMethodDefinitionError(e: AnyEndpoint, methods: List[Method]) extends EndpointVerificationError { override def toString: String = s"An endpoint ${e.show} have multiple method definitions: $methods" } + +/** + * Endpoint `e` defines outputs where status code indicates no body, but at the same time a body output is specified. For status codes 204 and 304 it's forbidden by specification. + * + * Example of incorrectly defined endpoint: + * + * {{{ + * endpoint.get.in("x").out(jsonBody[Unit]).out(statusCode(StatusCode.NoContent)) + * }}} + * + */ +case class UnexpectedBodyError(e: AnyEndpoint, statusCode: StatusCode) extends EndpointVerificationError { + override def toString: String = s"An endpoint ${e.show} may return status code ${statusCode} with body, which is not allowed by specificiation." +} diff --git a/testing/src/main/scala/sttp/tapir/testing/EndpointVerifier.scala b/testing/src/main/scala/sttp/tapir/testing/EndpointVerifier.scala index 7de8b6ef1a..745ae98e41 100644 --- a/testing/src/main/scala/sttp/tapir/testing/EndpointVerifier.scala +++ b/testing/src/main/scala/sttp/tapir/testing/EndpointVerifier.scala @@ -1,8 +1,9 @@ package sttp.tapir.testing import sttp.model.Method -import sttp.tapir.internal.{RichEndpointInput, UrlencodedData} -import sttp.tapir.{AnyEndpoint, EndpointInput, testing} +import sttp.model.StatusCode.{NoContent, NotModified} +import sttp.tapir.internal.{RichEndpointInput, RichEndpointOutput, UrlencodedData} +import sttp.tapir.{AnyEndpoint, EndpointIO, EndpointInput, EndpointOutput, testing} import scala.annotation.tailrec @@ -10,7 +11,8 @@ object EndpointVerifier { def apply(endpoints: List[AnyEndpoint]): Set[EndpointVerificationError] = { findShadowedEndpoints(endpoints, List()).groupBy(_.e).map(_._2.head).toSet ++ findIncorrectPaths(endpoints).toSet ++ - findDuplicatedMethodDefinitions(endpoints).toSet + findDuplicatedMethodDefinitions(endpoints).toSet ++ + findIncorrectStatusWithBody(endpoints).toSet } private def findIncorrectPaths(endpoints: List[AnyEndpoint]): List[IncorrectPathsError] = { @@ -35,6 +37,19 @@ object EndpointVerifier { in.filter(e => checkIfShadows(endpoint, e)).map(e => testing.ShadowedEndpointError(e, endpoint)) } + private def findIncorrectStatusWithBody(endpoints: List[AnyEndpoint]): List[UnexpectedBodyError] = + endpoints.flatMap { e => + val outputs = (e.output.asBasicOutputsList ++ e.errorOutput.asBasicOutputsList) + outputs.flatMap { outputElems => + val hasBody = outputElems.collectFirst { case b: EndpointIO.Body[_, _] => b }.isDefined + val noBodyStatusCodes = outputElems.collect { + case EndpointOutput.FixedStatusCode(NoContent, _, _) => NoContent + case EndpointOutput.FixedStatusCode(NotModified, _, _) => NotModified + } + if (hasBody) noBodyStatusCodes.map(UnexpectedBodyError(e, _)) else Nil + } + } + private def checkIfShadows(e1: AnyEndpoint, e2: AnyEndpoint): Boolean = checkMethods(e1, e2) && checkPaths(e1, e2) diff --git a/testing/src/test/scala/sttp/tapir/testing/EndpointVerifierTest.scala b/testing/src/test/scala/sttp/tapir/testing/EndpointVerifierTest.scala index 6a7968c895..da3fd801ed 100644 --- a/testing/src/test/scala/sttp/tapir/testing/EndpointVerifierTest.scala +++ b/testing/src/test/scala/sttp/tapir/testing/EndpointVerifierTest.scala @@ -1,9 +1,12 @@ package sttp.tapir.testing +import io.circe.generic.auto._ import org.scalatest.flatspec.AnyFlatSpecLike import org.scalatest.matchers.should.Matchers -import sttp.model.Method +import sttp.model.{Method, StatusCode} import sttp.tapir._ +import sttp.tapir.generic.auto._ +import sttp.tapir.json.circe._ class EndpointVerifierTest extends AnyFlatSpecLike with Matchers { @@ -286,4 +289,42 @@ class EndpointVerifierTest extends AnyFlatSpecLike with Matchers { result shouldBe Set() } + + it should "detect endpoints with body where status code doesn't allow a body" in { + + val e1 = endpoint.in("endpoint1_Err").out(stringBody).out(statusCode(StatusCode.NoContent)) + val e2 = endpoint.in("endpoint2_Ok").out(stringBody).out(statusCode(StatusCode.BadRequest)) + val e3 = endpoint.in("endpoint3_Err").out(stringBody).out(statusCode(StatusCode.NotModified)) + val e4 = endpoint.in("endpoint4_ok").out(emptyOutputAs(NoContent)).out(statusCode(StatusCode.NoContent)) + val e5 = endpoint + .in("endpoint5_err") + .out(stringBody) + .errorOut( + sttp.tapir.oneOf[ErrorInfo]( + oneOfVariant(statusCode(StatusCode.NotFound).and(jsonBody[NotFound])), + oneOfVariant(statusCode(StatusCode.NoContent).and(jsonBody[NoContentData])) + ) + ) + val e6 = endpoint + .in("endpoint6_ok") + .errorOut( + sttp.tapir.oneOf[ErrorInfo]( + oneOfVariant(statusCode(StatusCode.NotFound).and(jsonBody[NotFound])), + oneOfVariant(statusCode(StatusCode.NoContent).and(emptyOutputAs(NoContent))) + ) + ) + + val result = EndpointVerifier(List(e1, e2, e3, e4, e5, e6)) + + result shouldBe Set( + UnexpectedBodyError(e1, StatusCode.NoContent), + UnexpectedBodyError(e3, StatusCode.NotModified), + UnexpectedBodyError(e5, StatusCode.NoContent) + ) + } } + +sealed trait ErrorInfo +case class NotFound(what: String) extends ErrorInfo +case object NoContent extends ErrorInfo +case class NoContentData(msg: String) extends ErrorInfo