diff --git a/openapi3filter/middleware.go b/openapi3filter/middleware.go new file mode 100644 index 000000000..bf2f04d32 --- /dev/null +++ b/openapi3filter/middleware.go @@ -0,0 +1,272 @@ +package openapi3filter + +import ( + "bytes" + "io" + "log" + "net/http" + + "github.com/getkin/kin-openapi/routers" +) + +// Validator provides HTTP request and response validation middleware. +type Validator struct { + router routers.Router + errFunc ErrFunc + logFunc LogFunc + strict bool +} + +// ErrFunc handles errors that may occur during validation. +type ErrFunc func(w http.ResponseWriter, status int, code ErrCode, err error) + +// LogFunc handles log messages that may occur during validation. +type LogFunc func(message string, err error) + +// ErrCode is used for classification of different types of errors that may +// occur during validation. These may be used to write an appropriate response +// in ErrFunc. +type ErrCode int + +const ( + // ErrCodeOK indicates no error. It is also the default value. + ErrCodeOK = 0 + // ErrCodeCannotFindRoute happens when the validator fails to resolve the + // request to a defined OpenAPI route. + ErrCodeCannotFindRoute = iota + // ErrCodeRequestInvalid happens when the inbound request does not conform + // to the OpenAPI 3 specification. + ErrCodeRequestInvalid = iota + // ErrCodeResponseInvalid happens when the wrapped handler response does + // not conform to the OpenAPI 3 specification. + ErrCodeResponseInvalid = iota +) + +func (e ErrCode) responseText() string { + switch e { + case ErrCodeOK: + return "OK" + case ErrCodeCannotFindRoute: + return "not found" + case ErrCodeRequestInvalid: + return "bad request" + default: + return "server error" + } +} + +// NewValidator returns a new response validation middlware, using the given +// routes from an OpenAPI 3 specification. +func NewValidator(router routers.Router, options ...ValidatorOption) *Validator { + v := &Validator{ + router: router, + errFunc: func(w http.ResponseWriter, status int, code ErrCode, _ error) { + http.Error(w, code.responseText(), status) + }, + logFunc: func(message string, err error) { + log.Printf("%s: %v", message, err) + }, + } + for i := range options { + options[i](v) + } + return v +} + +// ValidatorOption defines an option that may be specified when creating a +// Validator. +type ValidatorOption func(*Validator) + +// OnErr provides a callback that handles writing an HTTP response on a +// validation error. This allows customization of error responses without +// prescribing a particular form. This callback is only called on response +// validator errors in Strict mode. +func OnErr(f ErrFunc) ValidatorOption { + return func(v *Validator) { + v.errFunc = f + } +} + +// OnLog provides a callback that handles logging in the Validator. This allows +// the validator to integrate with a services' existing logging system without +// prescribing a particular one. +func OnLog(f LogFunc) ValidatorOption { + return func(v *Validator) { + v.logFunc = f + } +} + +// Strict, if set, causes an internal server error to be sent if the wrapped +// handler response fails response validation. If not set, the response is sent +// and the error is only logged. +func Strict(strict bool) ValidatorOption { + return func(v *Validator) { + v.strict = strict + } +} + +// Middleware returns an http.Handler which wraps the given handler with +// request and response validation. +func (v *Validator) Middleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + route, pathParams, err := v.router.FindRoute(r) + if err != nil { + v.logFunc("validation error: failed to find route for "+r.URL.String(), err) + v.errFunc(w, http.StatusNotFound, ErrCodeCannotFindRoute, err) + return + } + requestValidationInput := &RequestValidationInput{ + Request: r, + PathParams: pathParams, + Route: route, + } + err = ValidateRequest(r.Context(), requestValidationInput) + if err != nil { + v.logFunc("invalid request", err) + v.errFunc(w, http.StatusBadRequest, ErrCodeRequestInvalid, err) + return + } + + var wr responseWrapper + if v.strict { + wr = &strictResponseWrapper{w: w} + } else { + wr = &warnResponseWrapper{w: w} + } + + h.ServeHTTP(wr, r) + + err = ValidateResponse(r.Context(), &ResponseValidationInput{ + RequestValidationInput: requestValidationInput, + Status: wr.statusCode(), + Header: wr.Header(), + Body: nopCloser{bytes.NewBuffer(wr.bodyContents())}, + }) + if err != nil { + v.logFunc("invalid response", err) + if v.strict { + v.errFunc(w, http.StatusInternalServerError, ErrCodeResponseInvalid, err) + } + return + } + + err = wr.flushBodyContents() + if err != nil { + v.logFunc("failed to write response", err) + } + }) +} + +type nopCloser struct { + io.Reader +} + +// Close implements io.Closer. +func (nopCloser) Close() error { + return nil +} + +type responseWrapper interface { + http.ResponseWriter + + // flushBodyContents writes the buffered response to the client, if it has + // not yet been written. + flushBodyContents() error + + // statusCode returns the response status code, 0 if not set yet. + statusCode() int + + // bodyContents returns the buffered + bodyContents() []byte +} + +type warnResponseWrapper struct { + w http.ResponseWriter + fl http.Flusher + status int + body bytes.Buffer +} + +// Write implements http.ResponseWriter. +func (l *warnResponseWrapper) Write(b []byte) (int, error) { + if l.status == 0 { + l.w.WriteHeader(l.status) + } + n, err := l.w.Write(b) + if err == nil { + l.body.Write(b) + } + return n, err +} + +// WriteHeader implements http.ResponseWriter. +func (l *warnResponseWrapper) WriteHeader(status int) { + if l.status == 0 { + l.status = status + } + l.w.WriteHeader(l.status) +} + +// Header implements http.ResponseWriter. +func (wr *warnResponseWrapper) Header() http.Header { + return wr.w.Header() +} + +// Flush implements the optional http.Flusher interface. +func (wr *warnResponseWrapper) Flush() { + if fl, ok := wr.w.(http.Flusher); ok { + fl.Flush() + } +} + +func (l *warnResponseWrapper) flushBodyContents() error { + return nil +} + +func (l *warnResponseWrapper) statusCode() int { + return l.status +} + +func (l *warnResponseWrapper) bodyContents() []byte { + return l.body.Bytes() +} + +type strictResponseWrapper struct { + w http.ResponseWriter + status int + body bytes.Buffer +} + +// Write implements http.ResponseWriter. +func (wr *strictResponseWrapper) Write(b []byte) (int, error) { + if wr.status == 0 { + wr.status = http.StatusOK + } + return wr.body.Write(b) +} + +// WriteHeader implements http.ResponseWriter. +func (wr *strictResponseWrapper) WriteHeader(status int) { + if wr.status == 0 { + wr.status = status + } +} + +// Header implements http.ResponseWriter. +func (wr *strictResponseWrapper) Header() http.Header { + return wr.w.Header() +} + +func (wr *strictResponseWrapper) flushBodyContents() error { + wr.w.WriteHeader(wr.status) + _, err := wr.w.Write(wr.body.Bytes()) + return err +} + +func (wr *strictResponseWrapper) statusCode() int { + return wr.status +} + +func (wr *strictResponseWrapper) bodyContents() []byte { + return wr.body.Bytes() +} diff --git a/openapi3filter/middleware_test.go b/openapi3filter/middleware_test.go new file mode 100644 index 000000000..a7930e069 --- /dev/null +++ b/openapi3filter/middleware_test.go @@ -0,0 +1,384 @@ +package openapi3filter_test + +import ( + "bytes" + "context" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "regexp" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +// TODO: simple OpenAPI spec +const validatorSpec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: '0.0.0' +paths: + /test: + post: + operationId: newTest + description: create a new test + parameters: + - in: query + name: version + schema: + type: string + required: true + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/TestContents' } + responses: + '201': + description: 'created test' + content: + application/json: + schema: { $ref: '#/components/schemas/TestResource' } + '400': { $ref: '#/components/responses/ErrorResponse' } + '500': { $ref: '#/components/responses/ErrorResponse' } + /test/{id}: + get: + operationId: getTest + description: get a test + parameters: + - in: path + name: id + schema: + type: string + required: true + - in: query + name: version + schema: + type: string + required: true + responses: + '200': + description: 'respond with test resource' + content: + application/json: + schema: { $ref: '#/components/schemas/TestResource' } + '400': { $ref: '#/components/responses/ErrorResponse' } + '404': { $ref: '#/components/responses/ErrorResponse' } + '500': { $ref: '#/components/responses/ErrorResponse' } +components: + schemas: + TestContents: + type: object + properties: + name: + type: string + expected: + type: number + actual: + type: number + required: [name, expected, actual] + additionalProperties: false + TestResource: + type: object + properties: + id: + type: string + contents: + { $ref: '#/components/schemas/TestContents' } + required: [id, contents] + additionalProperties: false + Error: + type: object + properties: + code: + type: string + message: + type: string + required: [code, message] + additionalProperties: false + responses: + ErrorResponse: + description: 'an error occurred' + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } +` + +type validatorTestHandler struct { + contentType string + getBody, postBody string + errBody string + errStatusCode int +} + +const validatorOkResponse = `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}}` + +func (h validatorTestHandler) withDefaults() validatorTestHandler { + if h.contentType == "" { + h.contentType = "application/json" + } + if h.getBody == "" { + h.getBody = validatorOkResponse + } + if h.postBody == "" { + h.postBody = validatorOkResponse + } + if h.errBody == "" { + h.errBody = `{"code":"bad","message":"bad things"}` + } + return h +} + +var testUrlRE = regexp.MustCompile(`^/test(/\d+)?$`) + +func (h *validatorTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", h.contentType) + if h.errStatusCode != 0 { + w.WriteHeader(h.errStatusCode) + w.Write([]byte(h.errBody)) + return + } + if !testUrlRE.MatchString(r.URL.Path) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(h.errBody)) + return + } + switch r.Method { + case "GET": + w.WriteHeader(http.StatusOK) + w.Write([]byte(h.getBody)) + case "POST": + w.WriteHeader(http.StatusCreated) + w.Write([]byte(h.postBody)) + default: + http.Error(w, h.errBody, http.StatusMethodNotAllowed) + } +} + +func TestValidator(t *testing.T) { + doc, err := openapi3.NewLoader().LoadFromData([]byte(validatorSpec)) + if err != nil { + t.Error("failed to load test fixture spec:", err) + } + ctx := context.Background() + if err := doc.Validate(ctx); err != nil { + t.Error("invalid test fixture spec:", err) + } + type testRequest struct { + method, path, body, contentType string + } + type testResponse struct { + statusCode int + body string + } + tests := []struct { + name string + handler validatorTestHandler + options []openapi3filter.ValidatorOption + request testRequest + response testResponse + strict bool + }{{ + name: "valid GET", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=1", + }, + response: testResponse{ + 200, validatorOkResponse, + }, + strict: true, + }, { + name: "valid POST", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9, "actual": 10}`, + contentType: "application/json", + }, + response: testResponse{ + 201, validatorOkResponse, + }, + strict: true, + }, { + name: "not found; no GET operation for /test", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test?version=1", + }, + response: testResponse{ + 404, "not found\n", + }, + strict: true, + }, { + name: "not found; no POST operation for /test/42", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test/42?version=1", + }, + response: testResponse{ + 404, "not found\n", + }, + strict: true, + }, { + name: "invalid request; missing version", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42", + }, + response: testResponse{ + 400, "bad request\n", + }, + strict: true, + }, { + name: "invalid POST request; wrong property type", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": "nine", "actual": "ten"}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "bad request\n", + }, + strict: true, + }, { + name: "invalid POST request; missing property", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "bad request\n", + }, + strict: true, + }, { + name: "invalid POST request; extra property", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9, "actual": 10, "ideal": 8}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "bad request\n", + }, + strict: true, + }, { + name: "valid response; 404 error", + handler: validatorTestHandler{ + contentType: "application/json", + errBody: `{"code": "404", "message": "not found"}`, + errStatusCode: 404, + }.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=1", + }, + response: testResponse{ + 404, `{"code": "404", "message": "not found"}`, + }, + strict: true, + }, { + name: "invalid response; invalid error", + handler: validatorTestHandler{ + errBody: `"not found"`, + errStatusCode: 404, + }.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=1", + }, + response: testResponse{ + 500, "server error\n", + }, + strict: true, + }, { + name: "invalid POST response; not strict", + handler: validatorTestHandler{ + postBody: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + }.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9, "actual": 10}`, + contentType: "application/json", + }, + response: testResponse{ + statusCode: 201, + body: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + }, + strict: false, + }} + for i, test := range tests { + t.Logf("test#%d: %s", i, test.name) + func() { + // Set up a test HTTP server + var h http.Handler + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + })) + defer s.Close() + + // Update the OpenAPI servers section with the test server URL This is + // needed by the router which matches request routes for OpenAPI + // validation. + doc.Servers = []*openapi3.Server{{URL: s.URL}} + if err := doc.Validate(ctx); err != nil { + t.Error("failed to validate with test server:", err) + } + // Create the router and validator + router, err := gorillamux.NewRouter(doc) + if err != nil { + t.Error("failed to create router:", err) + } + v := openapi3filter.NewValidator(router, append(test.options, openapi3filter.Strict(test.strict))...) + // Now wrap the test handler with the validator middlware + h = v.Middleware(&test.handler) + + // Test: make a client request + var requestBody io.Reader + if test.request.body != "" { + requestBody = bytes.NewBufferString(test.request.body) + } + req, err := http.NewRequest(test.request.method, s.URL+test.request.path, requestBody) + if err != nil { + t.Error("failed to create request:", err) + } + if test.request.contentType != "" { + req.Header.Set("Content-Type", test.request.contentType) + } + resp, err := s.Client().Do(req) + if err != nil { + t.Error("request failed:", err) + } + defer resp.Body.Close() + if test.response.statusCode != resp.StatusCode { + t.Errorf("response code expect %d got %d", test.response.statusCode, resp.StatusCode) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error("failed to read response body:", err) + } + if test.response.body != string(body) { + t.Errorf("response body expect %q got %q", test.response.body, string(body)) + } + }() + } +} + +// TODO: example