Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Actually #624, thanks to @orensolo #634

Merged
merged 1 commit into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions openapi3filter/issue624_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,22 @@ paths:

router, err := gorillamux.NewRouter(doc)
require.NoError(t, err)
httpReq, err := http.NewRequest(http.MethodGet, `/items?test=test1`, nil)
require.NoError(t, err)

route, pathParams, err := router.FindRoute(httpReq)
require.NoError(t, err)
for _, testcase := range []string{`test1`, `test[1`} {
t.Run(testcase, func(t *testing.T) {
httpReq, err := http.NewRequest(http.MethodGet, `/items?test=`+testcase, nil)
require.NoError(t, err)

requestValidationInput := &RequestValidationInput{
Request: httpReq,
PathParams: pathParams,
Route: route,
route, pathParams, err := router.FindRoute(httpReq)
require.NoError(t, err)

requestValidationInput := &RequestValidationInput{
Request: httpReq,
PathParams: pathParams,
Route: route,
}
err = ValidateRequest(ctx, requestValidationInput)
require.NoError(t, err)
})
}
err = ValidateRequest(ctx, requestValidationInput)
require.NoError(t, err)
}
18 changes: 9 additions & 9 deletions openapi3filter/req_resp_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,11 @@ func invalidSerializationMethodErr(sm *openapi3.SerializationMethod) error {
// Decodes a parameter defined via the content property as an object. It uses
// the user specified decoder, or our build-in decoder for application/json
func decodeContentParameter(param *openapi3.Parameter, input *RequestValidationInput) (
value interface{}, schema *openapi3.Schema, found bool, err error) {

value interface{},
schema *openapi3.Schema,
found bool,
err error,
) {
var paramValues []string
switch param.In {
case openapi3.ParameterInPath:
Expand Down Expand Up @@ -186,28 +189,25 @@ func defaultContentParameterDecoder(param *openapi3.Parameter, values []string)
}
outSchema = mt.Schema.Value

unmarshal := func(encoded string) (decoded interface{}, err error) {
unmarshal := func(encoded string, paramSchema *openapi3.SchemaRef) (decoded interface{}, err error) {
if err = json.Unmarshal([]byte(encoded), &decoded); err != nil {
const specialJSONChars = `[]{}":,`
if !strings.ContainsAny(encoded, specialJSONChars) {
// A string in a query parameter is not serialized with (double) quotes
// as JSON would expect, so let's fallback to that.
if paramSchema != nil && paramSchema.Value.Type != "object" {
decoded, err = encoded, nil
}
}
return
}

if len(values) == 1 {
if outValue, err = unmarshal(values[0]); err != nil {
if outValue, err = unmarshal(values[0], mt.Schema); err != nil {
err = fmt.Errorf("error unmarshaling parameter %q", param.Name)
return
}
} else {
outArray := make([]interface{}, 0, len(values))
for _, v := range values {
var item interface{}
if item, err = unmarshal(v); err != nil {
if item, err = unmarshal(v, outSchema.Items); err != nil {
err = fmt.Errorf("error unmarshaling parameter %q", param.Name)
return
}
Expand Down
22 changes: 7 additions & 15 deletions openapi3filter/validate_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@ var ErrInvalidEmptyValue = errors.New("empty value is not allowed")
//
// Note: One can tune the behavior of uniqueItems: true verification
// by registering a custom function with openapi3.RegisterArrayUniqueItemsChecker
func ValidateRequest(ctx context.Context, input *RequestValidationInput) error {
var (
err error
me openapi3.MultiError
)
func ValidateRequest(ctx context.Context, input *RequestValidationInput) (err error) {
var me openapi3.MultiError

options := input.Options
if options == nil {
Expand All @@ -52,9 +49,8 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error {
}
if security != nil {
if err = ValidateSecurityRequirements(ctx, input, *security); err != nil && !options.MultiError {
return err
return
}

if err != nil {
me = append(me, err)
}
Expand All @@ -70,9 +66,8 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error {
}

if err = ValidateParameter(ctx, input, parameter); err != nil && !options.MultiError {
return err
return
}

if err != nil {
me = append(me, err)
}
Expand All @@ -81,9 +76,8 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error {
// For each parameter of the Operation
for _, parameter := range operationParameters {
if err = ValidateParameter(ctx, input, parameter.Value); err != nil && !options.MultiError {
return err
return
}

if err != nil {
me = append(me, err)
}
Expand All @@ -93,9 +87,8 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error {
requestBody := operation.RequestBody
if requestBody != nil && !options.ExcludeRequestBody {
if err = ValidateRequestBody(ctx, input, requestBody.Value); err != nil && !options.MultiError {
return err
return
}

if err != nil {
me = append(me, err)
}
Expand All @@ -104,8 +97,7 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error {
if len(me) > 0 {
return me
}

return nil
return
}

// ValidateParameter validates a parameter's value by JSON schema.
Expand Down
13 changes: 6 additions & 7 deletions openapi3filter/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func TestFilter(t *testing.T) {
}
err = ValidateResponse(context.Background(), responseValidationInput)
require.NoError(t, err)
return err
return nil
}
expect := func(req ExampleRequest, resp ExampleResponse) error {
return expectWithDecoder(req, resp, nil)
Expand All @@ -207,13 +207,12 @@ func TestFilter(t *testing.T) {
resp := ExampleResponse{
Status: 200,
}
// Test paths

// Test paths
req := ExampleRequest{
Method: "POST",
URL: "http://example.com/api/prefix/v/suffix",
}

err = expect(req, resp)
require.NoError(t, err)

Expand Down Expand Up @@ -328,15 +327,15 @@ func TestFilter(t *testing.T) {
// enough.
req = ExampleRequest{
Method: "POST",
URL: "http://example.com/api/prefix/v/suffix?contentArg={\"name\":\"bob\", \"id\":\"a\"}",
URL: `http://example.com/api/prefix/v/suffix?contentArg={"name":"bob", "id":"a"}`,
}
err = expect(req, resp)
require.NoError(t, err)

// Now it should fail due the ID being too long
req = ExampleRequest{
Method: "POST",
URL: "http://example.com/api/prefix/v/suffix?contentArg={\"name\":\"bob\", \"id\":\"EXCEEDS_MAX_LENGTH\"}",
URL: `http://example.com/api/prefix/v/suffix?contentArg={"name":"bob", "id":"EXCEEDS_MAX_LENGTH"}`,
}
err = expect(req, resp)
require.IsType(t, &RequestError{}, err)
Expand All @@ -351,15 +350,15 @@ func TestFilter(t *testing.T) {

req = ExampleRequest{
Method: "POST",
URL: "http://example.com/api/prefix/v/suffix?contentArg2={\"name\":\"bob\", \"id\":\"a\"}",
URL: `http://example.com/api/prefix/v/suffix?contentArg2={"name":"bob", "id":"a"}`,
}
err = expectWithDecoder(req, resp, customDecoder)
require.NoError(t, err)

// Now it should fail due the ID being too long
req = ExampleRequest{
Method: "POST",
URL: "http://example.com/api/prefix/v/suffix?contentArg2={\"name\":\"bob\", \"id\":\"EXCEEDS_MAX_LENGTH\"}",
URL: `http://example.com/api/prefix/v/suffix?contentArg2={"name":"bob", "id":"EXCEEDS_MAX_LENGTH"}`,
}
err = expectWithDecoder(req, resp, customDecoder)
require.IsType(t, &RequestError{}, err)
Expand Down