diff --git a/openapi3/schema.go b/openapi3/schema.go index 8a34282ff..a4e3b8d00 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -397,7 +397,7 @@ func (schema *Schema) WithAdditionalProperties(v *Schema) *Schema { func (schema *Schema) IsEmpty() bool { if schema.Type != "" || schema.Format != "" || len(schema.Enum) != 0 || schema.UniqueItems || schema.ExclusiveMin || schema.ExclusiveMax || - schema.Nullable || + schema.Nullable || schema.ReadOnly || schema.WriteOnly || schema.AllowEmptyValue || schema.Min != nil || schema.Max != nil || schema.MultipleOf != nil || schema.MinLength != 0 || schema.MaxLength != nil || schema.Pattern != "" || schema.MinItems != 0 || schema.MaxItems != nil || @@ -452,6 +452,10 @@ func (schema *Schema) validate(c context.Context, stack []*Schema) (err error) { } stack = append(stack, schema) + if schema.ReadOnly && schema.WriteOnly { + return errors.New("A property MUST NOT be marked as both readOnly and writeOnly being true") + } + for _, item := range schema.OneOf { v := item.Value if v == nil { @@ -577,37 +581,44 @@ func (schema *Schema) validate(c context.Context, stack []*Schema) (err error) { } func (schema *Schema) IsMatching(value interface{}) bool { - return schema.visitJSON(value, true) == nil + settings := newSchemaValidationSettings(FailFast()) + return schema.visitJSON(settings, value) == nil } func (schema *Schema) IsMatchingJSONBoolean(value bool) bool { - return schema.visitJSON(value, true) == nil + settings := newSchemaValidationSettings(FailFast()) + return schema.visitJSON(settings, value) == nil } func (schema *Schema) IsMatchingJSONNumber(value float64) bool { - return schema.visitJSON(value, true) == nil + settings := newSchemaValidationSettings(FailFast()) + return schema.visitJSON(settings, value) == nil } func (schema *Schema) IsMatchingJSONString(value string) bool { - return schema.visitJSON(value, true) == nil + settings := newSchemaValidationSettings(FailFast()) + return schema.visitJSON(settings, value) == nil } func (schema *Schema) IsMatchingJSONArray(value []interface{}) bool { - return schema.visitJSON(value, true) == nil + settings := newSchemaValidationSettings(FailFast()) + return schema.visitJSON(settings, value) == nil } func (schema *Schema) IsMatchingJSONObject(value map[string]interface{}) bool { - return schema.visitJSON(value, true) == nil + settings := newSchemaValidationSettings(FailFast()) + return schema.visitJSON(settings, value) == nil } -func (schema *Schema) VisitJSON(value interface{}) error { - return schema.visitJSON(value, false) +func (schema *Schema) VisitJSON(value interface{}, opts ...SchemaValidationOption) error { + settings := newSchemaValidationSettings(opts...) + return schema.visitJSON(settings, value) } -func (schema *Schema) visitJSON(value interface{}, fast bool) (err error) { +func (schema *Schema) visitJSON(settings *schemaValidationSettings, value interface{}) (err error) { switch value := value.(type) { case nil: - return schema.visitJSONNull(fast) + return schema.visitJSONNull(settings) case float64: if math.IsNaN(value) { return ErrSchemaInputNaN @@ -620,23 +631,23 @@ func (schema *Schema) visitJSON(value interface{}, fast bool) (err error) { if schema.IsEmpty() { return } - if err = schema.visitSetOperations(value, fast); err != nil { + if err = schema.visitSetOperations(settings, value); err != nil { return } switch value := value.(type) { case nil: - return schema.visitJSONNull(fast) + return schema.visitJSONNull(settings) case bool: - return schema.visitJSONBoolean(value, fast) + return schema.visitJSONBoolean(settings, value) case float64: - return schema.visitJSONNumber(value, fast) + return schema.visitJSONNumber(settings, value) case string: - return schema.visitJSONString(value, fast) + return schema.visitJSONString(settings, value) case []interface{}: - return schema.visitJSONArray(value, fast) + return schema.visitJSONArray(settings, value) case map[string]interface{}: - return schema.visitJSONObject(value, fast) + return schema.visitJSONObject(settings, value) default: return &SchemaError{ Value: value, @@ -647,14 +658,16 @@ func (schema *Schema) visitJSON(value interface{}, fast bool) (err error) { } } -func (schema *Schema) visitSetOperations(value interface{}, fast bool) (err error) { +func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, value interface{}) (err error) { + var oldfailfast bool + if enum := schema.Enum; len(enum) != 0 { for _, v := range enum { if value == v { return } } - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -670,8 +683,9 @@ func (schema *Schema) visitSetOperations(value interface{}, fast bool) (err erro if v == nil { return foundUnresolvedRef(ref.Ref) } - if err := v.visitJSON(value, true); err == nil { - if fast { + oldfailfast, settings.failfast = settings.failfast, true + if err := v.visitJSON(settings, value); err == nil { + if oldfailfast { return errSchema } return &SchemaError{ @@ -680,6 +694,7 @@ func (schema *Schema) visitSetOperations(value interface{}, fast bool) (err erro SchemaField: "not", } } + settings.failfast = oldfailfast } if v := schema.OneOf; len(v) > 0 { @@ -689,12 +704,14 @@ func (schema *Schema) visitSetOperations(value interface{}, fast bool) (err erro if v == nil { return foundUnresolvedRef(item.Ref) } - if err := v.visitJSON(value, true); err == nil { + oldfailfast, settings.failfast = settings.failfast, true + if err := v.visitJSON(settings, value); err == nil { ok++ } + settings.failfast = oldfailfast } if ok != 1 { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -712,13 +729,15 @@ func (schema *Schema) visitSetOperations(value interface{}, fast bool) (err erro if v == nil { return foundUnresolvedRef(item.Ref) } - if err := v.visitJSON(value, true); err == nil { + oldfailfast, settings.failfast = settings.failfast, true + if err := v.visitJSON(settings, value); err == nil { ok = true break } + settings.failfast = oldfailfast } if !ok { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -734,8 +753,9 @@ func (schema *Schema) visitSetOperations(value interface{}, fast bool) (err erro if v == nil { return foundUnresolvedRef(item.Ref) } - if err := v.visitJSON(value, false); err != nil { - if fast { + oldfailfast, settings.failfast = settings.failfast, false + if err := v.visitJSON(settings, value); err != nil { + if oldfailfast { return errSchema } return &SchemaError{ @@ -745,15 +765,16 @@ func (schema *Schema) visitSetOperations(value interface{}, fast bool) (err erro Origin: err, } } + settings.failfast = oldfailfast } return } -func (schema *Schema) visitJSONNull(fast bool) (err error) { +func (schema *Schema) visitJSONNull(settings *schemaValidationSettings) (err error) { if schema.Nullable { return } - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -765,25 +786,27 @@ func (schema *Schema) visitJSONNull(fast bool) (err error) { } func (schema *Schema) VisitJSONBoolean(value bool) error { - return schema.visitJSONBoolean(value, false) + settings := newSchemaValidationSettings() + return schema.visitJSONBoolean(settings, value) } -func (schema *Schema) visitJSONBoolean(value bool, fast bool) (err error) { +func (schema *Schema) visitJSONBoolean(settings *schemaValidationSettings, value bool) (err error) { if schemaType := schema.Type; schemaType != "" && schemaType != "boolean" { - return schema.expectedType("boolean", fast) + return schema.expectedType(settings, "boolean") } return } func (schema *Schema) VisitJSONNumber(value float64) error { - return schema.visitJSONNumber(value, false) + settings := newSchemaValidationSettings() + return schema.visitJSONNumber(settings, value) } -func (schema *Schema) visitJSONNumber(value float64, fast bool) (err error) { +func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value float64) (err error) { schemaType := schema.Type if schemaType == "integer" { if bigFloat := big.NewFloat(value); !bigFloat.IsInt() { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -794,12 +817,12 @@ func (schema *Schema) visitJSONNumber(value float64, fast bool) (err error) { } } } else if schemaType != "" && schemaType != "number" { - return schema.expectedType("number, integer", fast) + return schema.expectedType(settings, "number, integer") } // "exclusiveMinimum" if v := schema.ExclusiveMin; v && !(*schema.Min < value) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -812,7 +835,7 @@ func (schema *Schema) visitJSONNumber(value float64, fast bool) (err error) { // "exclusiveMaximum" if v := schema.ExclusiveMax; v && !(*schema.Max > value) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -825,7 +848,7 @@ func (schema *Schema) visitJSONNumber(value float64, fast bool) (err error) { // "minimum" if v := schema.Min; v != nil && !(*v <= value) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -838,7 +861,7 @@ func (schema *Schema) visitJSONNumber(value float64, fast bool) (err error) { // "maximum" if v := schema.Max; v != nil && !(*v >= value) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -854,7 +877,7 @@ func (schema *Schema) visitJSONNumber(value float64, fast bool) (err error) { // "A numeric instance is valid only if division by this keyword's // value results in an integer." if bigFloat := big.NewFloat(value / *v); !bigFloat.IsInt() { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -868,12 +891,13 @@ func (schema *Schema) visitJSONNumber(value float64, fast bool) (err error) { } func (schema *Schema) VisitJSONString(value string) error { - return schema.visitJSONString(value, false) + settings := newSchemaValidationSettings() + return schema.visitJSONString(settings, value) } -func (schema *Schema) visitJSONString(value string, fast bool) (err error) { +func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value string) (err error) { if schemaType := schema.Type; schemaType != "" && schemaType != "string" { - return schema.expectedType("string", fast) + return schema.expectedType(settings, "string") } // "minLength" and "maxLength" @@ -890,7 +914,7 @@ func (schema *Schema) visitJSONString(value string, fast bool) (err error) { } } if minLength != 0 && length < int64(minLength) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -901,7 +925,7 @@ func (schema *Schema) visitJSONString(value string, fast bool) (err error) { } } if maxLength != nil && length > int64(*maxLength) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -958,19 +982,20 @@ func (schema *Schema) visitJSONString(value string, fast bool) (err error) { } func (schema *Schema) VisitJSONArray(value []interface{}) error { - return schema.visitJSONArray(value, false) + settings := newSchemaValidationSettings() + return schema.visitJSONArray(settings, value) } -func (schema *Schema) visitJSONArray(value []interface{}, fast bool) (err error) { +func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value []interface{}) (err error) { if schemaType := schema.Type; schemaType != "" && schemaType != "array" { - return schema.expectedType("array", fast) + return schema.expectedType(settings, "array") } lenValue := int64(len(value)) // "minItems" if v := schema.MinItems; v != 0 && lenValue < int64(v) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -983,7 +1008,7 @@ func (schema *Schema) visitJSONArray(value []interface{}, fast bool) (err error) // "maxItems" if v := schema.MaxItems; v != nil && lenValue > int64(*v) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -999,7 +1024,7 @@ func (schema *Schema) visitJSONArray(value []interface{}, fast bool) (err error) sliceUniqueItemsChecker = isSliceOfUniqueItems } if v := schema.UniqueItems; v && !sliceUniqueItemsChecker(value) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -1026,12 +1051,13 @@ func (schema *Schema) visitJSONArray(value []interface{}, fast bool) (err error) } func (schema *Schema) VisitJSONObject(value map[string]interface{}) error { - return schema.visitJSONObject(value, false) + settings := newSchemaValidationSettings() + return schema.visitJSONObject(settings, value) } -func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) (err error) { +func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value map[string]interface{}) (err error) { if schemaType := schema.Type; schemaType != "" && schemaType != "object" { - return schema.expectedType("object", fast) + return schema.expectedType(settings, "object") } // "properties" @@ -1040,7 +1066,7 @@ func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) ( // "minProperties" if v := schema.MinProps; v != 0 && lenValue < int64(v) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -1053,7 +1079,7 @@ func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) ( // "maxProperties" if v := schema.MaxProps; v != nil && lenValue > int64(*v) { - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -1078,7 +1104,7 @@ func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) ( return foundUnresolvedRef(propertyRef.Ref) } if err := p.VisitJSON(v); err != nil { - if fast { + if settings.failfast { return errSchema } return markSchemaErrorKey(err, k) @@ -1090,7 +1116,7 @@ func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) ( if additionalProperties != nil || allowed == nil || (allowed != nil && *allowed) { if additionalProperties != nil { if err := additionalProperties.VisitJSON(v); err != nil { - if fast { + if settings.failfast { return errSchema } return markSchemaErrorKey(err, k) @@ -1098,7 +1124,7 @@ func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) ( } continue } - if fast { + if settings.failfast { return errSchema } return &SchemaError{ @@ -1108,9 +1134,17 @@ func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) ( Reason: fmt.Sprintf("Property '%s' is unsupported", k), } } + + // "required" for _, k := range schema.Required { if _, ok := value[k]; !ok { - if fast { + if s := schema.Properties[k]; s != nil && s.Value.ReadOnly && settings.asreq { + continue + } + if s := schema.Properties[k]; s != nil && s.Value.WriteOnly && settings.asrep { + continue + } + if settings.failfast { return errSchema } return markSchemaErrorKey(&SchemaError{ @@ -1124,8 +1158,8 @@ func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) ( return } -func (schema *Schema) expectedType(typ string, fast bool) error { - if fast { +func (schema *Schema) expectedType(settings *schemaValidationSettings, typ string) error { + if settings.failfast { return errSchema } return &SchemaError{ diff --git a/openapi3/schema_validation_settings.go b/openapi3/schema_validation_settings.go new file mode 100644 index 000000000..6c073cd43 --- /dev/null +++ b/openapi3/schema_validation_settings.go @@ -0,0 +1,29 @@ +package openapi3 + +// SchemaValidationOption describes options a user has when validating request / response bodies. +type SchemaValidationOption func(*schemaValidationSettings) + +type schemaValidationSettings struct { + failfast bool + asreq, asrep bool // exclusive (XOR) fields +} + +// FailFast returns schema validation errors quicker. +func FailFast() SchemaValidationOption { + return func(s *schemaValidationSettings) { s.failfast = true } +} + +func VisitAsRequest() SchemaValidationOption { + return func(s *schemaValidationSettings) { s.asreq, s.asrep = true, false } +} +func VisitAsResponse() SchemaValidationOption { + return func(s *schemaValidationSettings) { s.asreq, s.asrep = false, true } +} + +func newSchemaValidationSettings(opts ...SchemaValidationOption) *schemaValidationSettings { + settings := &schemaValidationSettings{} + for _, opt := range opts { + opt(settings) + } + return settings +} diff --git a/openapi3filter/validate_readonly_test.go b/openapi3filter/validate_readonly_test.go new file mode 100644 index 000000000..fac9bf524 --- /dev/null +++ b/openapi3filter/validate_readonly_test.go @@ -0,0 +1,88 @@ +package openapi3filter + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { + const spec = `{ + "openapi": "3.0.3", + "info": { + "version": "1.0.0", + "title": "title", + "description": "desc", + "contact": { + "email": "email" + } + }, + "paths": { + "/accounts": { + "post": { + "description": "Create a new account", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["_id"], + "properties": { + "_id": { + "type": "string", + "description": "Unique identifier for this object.", + "pattern": "[0-9a-v]+$", + "minLength": 20, + "maxLength": 20, + "readOnly": true + } + } + } + } + } + }, + "responses": { + "201": { + "description": "Successfully created a new account" + }, + "400": { + "description": "The server could not understand the request due to invalid syntax", + } + } + } + } + } +} +` + + type Request struct { + ID string `json:"_id"` + } + + sl := openapi3.NewSwaggerLoader() + l, err := sl.LoadSwaggerFromData([]byte(spec)) + require.NoError(t, err) + router := NewRouter().WithSwagger(l) + + b, err := json.Marshal(Request{ID: "bt6kdc3d0cvp6u8u3ft0"}) + require.NoError(t, err) + + httpReq, err := http.NewRequest(http.MethodPost, "/accounts", bytes.NewReader(b)) + require.NoError(t, err) + httpReq.Header.Add("Content-Type", "application/json") + + route, pathParams, err := router.FindRoute(httpReq.Method, httpReq.URL) + require.NoError(t, err) + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + }) + require.NoError(t, err) +} diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index f22e71efb..69fc58bd1 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -192,7 +192,7 @@ func ValidateRequestBody(c context.Context, input *RequestValidationInput, reque } // Validate JSON with the schema - if err := contentType.Schema.Value.VisitJSON(value); err != nil { + if err := contentType.Schema.Value.VisitJSON(value, openapi3.VisitAsRequest()); err != nil { return &RequestError{ Input: input, RequestBody: requestBody, diff --git a/openapi3filter/validate_response.go b/openapi3filter/validate_response.go index dad0864d2..9a458aa1b 100644 --- a/openapi3filter/validate_response.go +++ b/openapi3filter/validate_response.go @@ -129,7 +129,7 @@ func ValidateResponse(c context.Context, input *ResponseValidationInput) error { } // Validate data with the schema. - if err := contentType.Schema.Value.VisitJSON(value); err != nil { + if err := contentType.Schema.Value.VisitJSON(value, openapi3.VisitAsResponse()); err != nil { return &ResponseError{ Input: input, Reason: "response body doesn't match the schema",