Skip to content

Commit

Permalink
feat: reworked
Browse files Browse the repository at this point in the history
  • Loading branch information
ConsoleTVs committed Mar 30, 2024
1 parent 61b89f2 commit 4f9921e
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 27 deletions.
59 changes: 56 additions & 3 deletions example/main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package main

import (
"errors"
"fmt"
"net/http"
"time"

"github.com/studiolambda/akumu"
"github.com/studiolambda/akumu/middleware"
)

type User struct {
Expand All @@ -19,6 +21,51 @@ func user(request *http.Request) error {
JSON(User{Name: "Erik C. Forés", Email: "[email protected]"})
}

func fails(request *http.Request) error {
return errors.New("something went wrong")
}

func fails2(request *http.Request) error {
return akumu.
Response(http.StatusNotFound).
Failed(errors.New("something went wrong"))
}

func fails3(request *http.Request) error {
return akumu.Failed(akumu.Problem{
Type: "http://example.com/problems/not-found",
Title: http.StatusText(http.StatusNotFound),
Detail: "The requested resource could not be found.",
Status: http.StatusNotFound,
Instance: request.URL.String(),
})
}

func fails4(request *http.Request) error {
return akumu.Response(http.StatusUnauthorized).
Failed(akumu.Problem{
Type: "http://example.com/problems/not-found",
Title: http.StatusText(http.StatusNotFound),
Detail: "The requested resource could not be found.",
Status: http.StatusNotFound,
Instance: request.URL.String(),
})
}

func fails5(request *http.Request) error {
return akumu.Response(http.StatusNotFound).
Failed(akumu.Problem{
Type: "http://example.com/problems/not-found",
Title: "page not found",
Detail: "the requested resource could not be found.",
Instance: request.URL.String(),
})
}

func panicable(request *http.Request) error {
panic("something went wrong")
}

func sse(request *http.Request) error {
messages := make(chan []byte)

Expand All @@ -39,12 +86,18 @@ func sse(request *http.Request) error {
func main() {
router := http.NewServeMux()

router.HandleFunc("GET /", akumu.Handle(user))
router.HandleFunc("GET /sse", akumu.Handle(sse))
router.Handle("GET /", akumu.Handler(user))
router.Handle("GET /sse", akumu.Handler(sse))
router.Handle("GET /fails", akumu.Handler(fails))
router.Handle("GET /fails2", akumu.Handler(fails2))
router.Handle("GET /fails3", akumu.Handler(fails3))
router.Handle("GET /fails4", akumu.Handler(fails4))
router.Handle("GET /fails5", akumu.Handler(fails5))
router.Handle("GET /panic", akumu.Handler(panicable))

server := http.Server{
Addr: ":8080",
Handler: router,
Handler: middleware.Recover(router),
}

fmt.Println("Server is running on: ", server.Addr)
Expand Down
14 changes: 2 additions & 12 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@ import "net/http"

type Handler func(*http.Request) error

func Handle(handler Handler) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
response := handler(request)

if responder, ok := response.(Responder); ok {
responder.Handle(writer, request)

return
}

http.Error(writer, response.Error(), http.StatusInternalServerError)
}
func (handler Handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
HandleError(writer, request, handler(request), http.StatusInternalServerError)
}
36 changes: 36 additions & 0 deletions middleware/recover.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package middleware

import (
"errors"
"fmt"
"net/http"

"github.com/studiolambda/akumu"
)

func Recover(handler http.Handler) http.Handler {
return RecoverWith(handler, func(value any) error {
switch err := (value).(type) {
case error:
return err
case string:
return errors.New(err)
case fmt.Stringer:
return errors.New(err.String())
}

return errors.New("an unexpected error occurred")
})
}

func RecoverWith(handler http.Handler, handle func(value any) error) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
defer func() {
if err := recover(); err != nil {
akumu.HandleError(writer, request, handle(err), http.StatusInternalServerError)
}
}()

handler.ServeHTTP(writer, request)
})
}
16 changes: 16 additions & 0 deletions problem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// This package implements the Problem Details for HTTP APIs (RFC 7807) specification.
package akumu

// Problem represents a problem details for HTTP APIs.
// See https://tools.ietf.org/html/rfc7807 for more information.
type Problem struct {
Type string `json:"type"`
Title string `json:"title"`
Detail string `json:"detail"`
Status int `json:"status"`
Instance string `json:"instance"`
}

func (problem Problem) Error() string {
return problem.Title
}
32 changes: 32 additions & 0 deletions problem_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package akumu_test

import (
"net/http"
"testing"

"github.com/studiolambda/akumu"
)

func problemHandler(request *http.Request) error {
return akumu.Failed(akumu.Problem{
Type: "http://example.com/problems/not-found",
Title: http.StatusText(http.StatusNotFound),
Detail: "The requested resource could not be found.",
Status: http.StatusNotFound,
Instance: request.URL.String(),
})
}

func TestProblemHandler(t *testing.T) {
request, err := http.NewRequest("GET", "/", nil)

if err != nil {
t.Fatalf("unable to create request: %s", err)
}

response := akumu.Record(problemHandler, request)

if response.Code != http.StatusNotFound {
t.Fatalf("unexpected status code: %d", response.Code)
}
}
85 changes: 73 additions & 12 deletions responder.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,37 @@ type ResponderHandler func(http.ResponseWriter, *http.Request, Responder)

type Responder struct {
handler ResponderHandler
code int
status int
headers http.Header
body io.Reader
err error
stream <-chan []byte
}

func HandleError(writer http.ResponseWriter, request *http.Request, err error, status int) {
if responder, ok := err.(Responder); ok {
responder.Handle(writer, request)

return
}

if request.Header.Get("Accept") == "application/problem+json" {
problem := Problem{
Type: "about:blank",
Title: http.StatusText(status),
Detail: err.Error(),
Status: status,
Instance: request.URL.String(),
}

Failed(problem).Handle(writer, request)

return
}

http.Error(writer, err.Error(), status)
}

func DefaultResponderHandler(writer http.ResponseWriter, request *http.Request, responder Responder) {
for key, values := range responder.headers {
for _, value := range values {
Expand All @@ -26,7 +50,7 @@ func DefaultResponderHandler(writer http.ResponseWriter, request *http.Request,
}

if responder.err != nil {
http.Error(writer, responder.err.Error(), responder.code)
HandleError(writer, request, responder.err, responder.status)

return
}
Expand All @@ -35,19 +59,19 @@ func DefaultResponderHandler(writer http.ResponseWriter, request *http.Request,
body, err := io.ReadAll(responder.body)

if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
HandleError(writer, request, err, http.StatusInternalServerError)

return
}

writer.WriteHeader(responder.code)
writer.WriteHeader(responder.status)
writer.Write(body)

return
}

if responder.stream != nil {
writer.WriteHeader(responder.code)
writer.WriteHeader(responder.status)

for {
select {
Expand All @@ -64,26 +88,30 @@ func DefaultResponderHandler(writer http.ResponseWriter, request *http.Request,
}
}

writer.WriteHeader(responder.code)
writer.WriteHeader(responder.status)
}

func Response(code int) Responder {
func Response(status int) Responder {
return Responder{
handler: DefaultResponderHandler,
code: code,
status: status,
headers: make(http.Header),
body: nil,
err: nil,
stream: nil,
}
}

func Failed(err error) Responder {
return Response(http.StatusInternalServerError).Failed(err)
}

func (responder Responder) Error() string {
return http.StatusText(responder.code)
return http.StatusText(responder.status)
}

func (responder Responder) Code(code int) Responder {
responder.code = code
func (responder Responder) Status(status int) Responder {
responder.status = status

return responder
}
Expand All @@ -95,12 +123,14 @@ func (responder Responder) Headers(headers http.Header) Responder {
}

func (responder Responder) Header(key, value string) Responder {
responder.headers = responder.headers.Clone()
responder.headers.Set(key, value)

return responder
}

func (responder Responder) AppendHeader(key, value string) Responder {
responder.headers = responder.headers.Clone()
responder.headers.Add(key, value)

return responder
Expand Down Expand Up @@ -131,17 +161,48 @@ func (responder Responder) SSE(stream <-chan []byte) Responder {
}

func (responder Responder) Failed(err error) Responder {
if problem, ok := err.(Problem); ok {
if problem.Type == "" {
problem.Type = "about:blank"
}

if problem.Status == 0 {
problem.Status = responder.status
}

if problem.Title == "" {
problem.Title = http.StatusText(problem.Status)
}

return responder.
Status(problem.Status).
JSON(problem).
Header("Content-Type", "application/problem+json")
}

responder.err = err

return responder
}

func (responder Responder) Text(body string) Responder {
return responder.
Header("Content-Type", "text/plain").
Body([]byte(body))
}

func (responder Responder) HTML(html string) Responder {
return responder.
Header("Content-Type", "text/html").
Body([]byte(html))
}

func (responder Responder) JSON(body any) Responder {
buffer := &bytes.Buffer{}

if err := json.NewEncoder(buffer).Encode(body); err != nil {
return responder.
Code(http.StatusInternalServerError).
Status(http.StatusInternalServerError).
Failed(err)
}

Expand Down
24 changes: 24 additions & 0 deletions testing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package akumu

import (
"net/http"
"net/http/httptest"
)

func RecordServer(server *http.Server, request *http.Request) *httptest.ResponseRecorder {
recorder := httptest.NewRecorder()
server.Handler.ServeHTTP(recorder, request)

return recorder
}

func RecordHandler(handler http.Handler, request *http.Request) *httptest.ResponseRecorder {
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, request)

return recorder
}

func Record(handler Handler, request *http.Request) *httptest.ResponseRecorder {
return RecordHandler(handler, request)
}
Loading

0 comments on commit 4f9921e

Please sign in to comment.