Skip to content

Commit

Permalink
Merge pull request #1126 from softwaremill/nicer-partial-endpoint-types
Browse files Browse the repository at this point in the history
Nicer partial endpoint types
  • Loading branch information
adamw authored Apr 1, 2021
2 parents a068920 + 7c29558 commit d2888b1
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 36 deletions.
18 changes: 8 additions & 10 deletions core/src/main/scala/sttp/tapir/Endpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -350,23 +350,21 @@ trait EndpointServerLogicOps[I, E, O, -R] { outer: Endpoint[I, E, O, R] =>
* An example use-case is defining an endpoint with fully-defined errors, and with authorization logic built-in.
* Such an endpoint can be then extended by multiple other endpoints.
*/
def serverLogicForCurrent[U, F[_]](f: I => F[Either[E, U]]): PartialServerEndpoint[U, Unit, E, O, R, F] =
new PartialServerEndpoint[U, Unit, E, O, R, F](this.copy(input = emptyInput)) {
override type T = I
override def tInput: EndpointInput[T] = outer.input
override def partialLogic: MonadError[F] => T => F[Either[E, U]] = _ => f
def serverLogicForCurrent[U, F[_]](f: I => F[Either[E, U]]): PartialServerEndpoint[I, U, Unit, E, O, R, F] =
new PartialServerEndpoint[I, U, Unit, E, O, R, F](this.copy(input = emptyInput)) {
override def tInput: EndpointInput[I] = outer.input
override def partialLogic: MonadError[F] => I => F[Either[E, U]] = _ => f
}

/** Same as [[serverLogicForCurrent]], but requires `E` to be a throwable, and coverts failed effects of type `E` to
* endpoint errors.
*/
def serverLogicForCurrentRecoverErrors[U, F[_]](
f: I => F[U]
)(implicit eIsThrowable: E <:< Throwable, eClassTag: ClassTag[E]): PartialServerEndpoint[U, Unit, E, O, R, F] =
new PartialServerEndpoint[U, Unit, E, O, R, F](this.copy(input = emptyInput)) {
override type T = I
override def tInput: EndpointInput[T] = outer.input
override def partialLogic: MonadError[F] => T => F[Either[E, U]] = recoverErrors(f)
)(implicit eIsThrowable: E <:< Throwable, eClassTag: ClassTag[E]): PartialServerEndpoint[I, U, Unit, E, O, R, F] =
new PartialServerEndpoint[I, U, Unit, E, O, R, F](this.copy(input = emptyInput)) {
override def tInput: EndpointInput[I] = outer.input
override def partialLogic: MonadError[F] => I => F[Either[E, U]] = recoverErrors(f)
}
}

Expand Down
35 changes: 16 additions & 19 deletions core/src/main/scala/sttp/tapir/server/PartialServerEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,32 @@ import scala.reflect.ClassTag
/** An endpoint, with some of the server logic already provided, and some left unspecified.
* See [[Endpoint.serverLogicForCurrent]].
*
* The part of the server logic which is provided transforms some inputs either to an error of type `E`, or value of
* type `U`.
* The part of the server logic which is provided transforms some inputs of type `T`, either to an error of type `E`,
* or value of type `U`.
*
* The part of the server logic which is not provided, transforms a tuple: `(U, I)` either into an error, or a value
* of type `O`.
* The part of the server logic which is not provided, will have to transform a tuple: `(U, I)` either into an error,
* or a value of type `O`.
*
* Inputs/outputs can be added to partial endpoints as to regular endpoints, however the shape of the error outputs
* is fixed and cannot be changed.
*
* @tparam T Original type of the input, transformed into U
* @tparam U Type of partially transformed input.
* @tparam I Input parameter types.
* @tparam E Error output parameter types.
* @tparam O Output parameter types.
* @tparam R The capabilities that are required by this endpoint's inputs/outputs. `Any`, if no requirements.
* @tparam F The effect type used in the provided partial server logic.
*/
abstract class PartialServerEndpoint[U, I, E, O, -R, F[_]](partialEndpoint: Endpoint[I, E, O, R])
abstract class PartialServerEndpoint[T, U, I, E, O, -R, F[_]](partialEndpoint: Endpoint[I, E, O, R])
extends EndpointInputsOps[I, E, O, R]
with EndpointOutputsOps[I, E, O, R]
with EndpointInfoOps[I, E, O, R]
with EndpointMetaOps[I, E, O, R] { outer =>
// original type of the partial input (transformed into U)
type T
protected def tInput: EndpointInput[T]
protected def partialLogic: MonadError[F] => T => F[Either[E, U]]

override type EndpointType[_I, _E, _O, -_R] = PartialServerEndpoint[U, _I, _E, _O, _R, F]
override type EndpointType[_I, _E, _O, -_R] = PartialServerEndpoint[T, U, _I, _E, _O, _R, F]

def endpoint: Endpoint[(T, I), E, O, R] = partialEndpoint.prependIn(tInput)

Expand All @@ -46,13 +45,12 @@ abstract class PartialServerEndpoint[U, I, E, O, -R, F[_]](partialEndpoint: Endp
override def output: EndpointOutput[O] = partialEndpoint.output
override def info: EndpointInfo = partialEndpoint.info

private def withEndpoint[I2, O2, R2 <: R](e2: Endpoint[I2, E, O2, R2]): PartialServerEndpoint[U, I2, E, O2, R2, F] =
new PartialServerEndpoint[U, I2, E, O2, R2, F](e2) {
override type T = outer.T
private def withEndpoint[I2, O2, R2 <: R](e2: Endpoint[I2, E, O2, R2]): PartialServerEndpoint[T, U, I2, E, O2, R2, F] =
new PartialServerEndpoint[T, U, I2, E, O2, R2, F](e2) {
override protected def tInput: EndpointInput[T] = outer.tInput
override protected def partialLogic: MonadError[F] => T => F[Either[E, U]] = outer.partialLogic
}
override private[tapir] def withInput[I2, R2](input: EndpointInput[I2]): PartialServerEndpoint[U, I2, E, O, R with R2, F] =
override private[tapir] def withInput[I2, R2](input: EndpointInput[I2]): PartialServerEndpoint[T, U, I2, E, O, R with R2, F] =
withEndpoint(partialEndpoint.withInput(input))
override private[tapir] def withOutput[O2, R2](output: EndpointOutput[O2]) = withEndpoint(partialEndpoint.withOutput(output))
override private[tapir] def withInfo(info: EndpointInfo) = withEndpoint(partialEndpoint.withInfo(info))
Expand All @@ -62,24 +60,23 @@ abstract class PartialServerEndpoint[U, I, E, O, -R, F[_]](partialEndpoint: Endp

def serverLogicForCurrent[V, UV](
f: I => F[Either[E, V]]
)(implicit concat: ParamConcat.Aux[U, V, UV]): PartialServerEndpoint[UV, Unit, E, O, R, F] = serverLogicForCurrentM(_ => f)
)(implicit concat: ParamConcat.Aux[U, V, UV]): PartialServerEndpoint[(T, I), UV, Unit, E, O, R, F] = serverLogicForCurrentM(_ => f)

def serverLogicForCurrentRecoverErrors[V, UV](
f: I => F[V]
)(implicit
concat: ParamConcat.Aux[U, V, UV],
eIsThrowable: E <:< Throwable,
eClassTag: ClassTag[E]
): PartialServerEndpoint[UV, Unit, E, O, R, F] =
): PartialServerEndpoint[(T, I), UV, Unit, E, O, R, F] =
serverLogicForCurrentM(recoverErrors(f))

private def serverLogicForCurrentM[V, UV](
_f: MonadError[F] => I => F[Either[E, V]]
)(implicit concat: ParamConcat.Aux[U, V, UV]): PartialServerEndpoint[UV, Unit, E, O, R, F] =
new PartialServerEndpoint[UV, Unit, E, O, R, F](partialEndpoint.copy(input = emptyInput)) {
override type T = (outer.T, I)
override def tInput: EndpointInput[(outer.T, I)] = outer.tInput.and(outer.partialEndpoint.input)
override def partialLogic: MonadError[F] => ((outer.T, I)) => F[Either[E, UV]] =
)(implicit concat: ParamConcat.Aux[U, V, UV]): PartialServerEndpoint[(T, I), UV, Unit, E, O, R, F] =
new PartialServerEndpoint[(T, I), UV, Unit, E, O, R, F](partialEndpoint.copy(input = emptyInput)) {
override def tInput: EndpointInput[(T, I)] = outer.tInput.and(outer.partialEndpoint.input)
override def partialLogic: MonadError[F] => ((T, I)) => F[Either[E, UV]] =
implicit monad => { case (t, i) =>
outer.partialLogic(monad)(t).flatMap {
case Left(e) => (Left(e): Either[E, UV]).unit
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/scala/sttp/tapir/EndpointTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class EndpointTest extends AnyFlatSpec with EndpointTestExtensions with Matchers
case class User1(x: String, y: Int)
case class User2(z: Double)
case class Result(u1: User1, u2: User2, a: String)
val base: PartialServerEndpoint[User1, Unit, String, Unit, Any, Future] = endpoint
val base: PartialServerEndpoint[(String, Int), User1, Unit, String, Unit, Any, Future] = endpoint
.errorOut(stringBody)
.in(query[String]("x"))
.in(query[Int]("y"))
Expand Down Expand Up @@ -297,7 +297,7 @@ class EndpointTest extends AnyFlatSpec with EndpointTestExtensions with Matchers
}

"PartialServerEndpoint" should "include all inputs when recovering the endpoint" in {
val pe: PartialServerEndpoint[String, Unit, Int, Unit, Any, Future] =
val pe: PartialServerEndpoint[String, String, Unit, Int, Unit, Any, Future] =
endpoint
.in("secure")
.in(query[String]("token"))
Expand Down
2 changes: 1 addition & 1 deletion doc/server/logic.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def auth(token: String): Future[Either[Int, User]] = Future {
else Left(1001) // error code
}

val secureEndpoint: PartialServerEndpoint[User, Unit, Int, Unit, Any, Future] = endpoint
val secureEndpoint: PartialServerEndpoint[String, User, Unit, Int, Unit, Any, Future] = endpoint
.in(header[String]("X-AUTH-TOKEN"))
.errorOut(plainBody[Int])
.serverLogicForCurrent(auth)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.Route
import sttp.client3._
import sttp.tapir._
import sttp.tapir.server.PartialServerEndpoint
import sttp.tapir.server.{PartialServerEndpoint, ServerEndpoint}
import sttp.tapir.server.akkahttp.AkkaHttpServerInterpreter

import scala.concurrent.duration._
Expand All @@ -26,13 +26,13 @@ object PartialServerLogicAkka extends App {
}

// 1st approach: define a base endpoint, which has the authentication logic built-in
val secureEndpoint: PartialServerEndpoint[User, Unit, Int, Unit, Any, Future] = endpoint
val secureEndpoint: PartialServerEndpoint[String, User, Unit, Int, Unit, Any, Future] = endpoint
.in(header[String]("X-AUTH-TOKEN"))
.errorOut(plainBody[Int])
.serverLogicForCurrent(auth)

// extend the base endpoint to define (potentially multiple) proper endpoints, define the rest of the server logic
val secureHelloWorld1WithLogic = secureEndpoint.get
val secureHelloWorld1WithLogic: ServerEndpoint[(String, String), Int, String, Any, Future] = secureEndpoint.get
.in("hello1")
.in(query[String]("salutation"))
.out(stringBody)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import java.io.{ByteArrayInputStream, File, InputStream}
import java.nio.ByteBuffer
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt
import java.io.{ByteArrayInputStream, InputStream}

class ServerBasicTests[F[_], ROUTE](
backend: SttpBackend[IO, Any],
Expand Down

0 comments on commit d2888b1

Please sign in to comment.