From a27c9e791d943bf6c94bb3f750c250aa5cfc133b Mon Sep 17 00:00:00 2001 From: Toly Date: Mon, 1 Jul 2024 17:59:33 +0200 Subject: [PATCH] openapi3: resolve recursive file references (#974) * openapi3: resolve recursive file references * fix file name capitalization --- openapi3/loader.go | 143 +++++++++--------- openapi3/loader_circular_test.go | 41 +++++ .../circularRef2/AwsEnvironmentSettings.yaml | 7 + openapi3/testdata/circularRef2/circular2.yaml | 16 ++ 4 files changed, 137 insertions(+), 70 deletions(-) create mode 100644 openapi3/loader_circular_test.go create mode 100644 openapi3/testdata/circularRef2/AwsEnvironmentSettings.yaml create mode 100644 openapi3/testdata/circularRef2/circular2.yaml diff --git a/openapi3/loader.go b/openapi3/loader.go index 381b38bc..003ba6c7 100644 --- a/openapi3/loader.go +++ b/openapi3/loader.go @@ -55,6 +55,9 @@ func NewLoader() *Loader { func (loader *Loader) resetVisitedPathItemRefs() { loader.visitedPathItemRefs = make(map[string]struct{}) + loader.visitedRefs = make(map[string]struct{}) + loader.visitedPath = nil + loader.backtrack = make(map[string][]func(value any)) } // LoadFromURI loads a spec from a remote URL @@ -558,6 +561,12 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Header) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var header Header if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &header); err != nil { @@ -566,15 +575,8 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat component.Value = &header component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*Header) - }) { - return nil - } var resolved HeaderRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -587,6 +589,7 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } value := component.Value if value == nil { @@ -610,6 +613,12 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Parameter) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var param Parameter if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, ¶m); err != nil { @@ -618,15 +627,8 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum component.Value = ¶m component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*Parameter) - }) { - return nil - } var resolved ParameterRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -639,6 +641,7 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } value := component.Value if value == nil { @@ -672,6 +675,12 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*RequestBody) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var requestBody RequestBody if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &requestBody); err != nil { @@ -680,15 +689,8 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d component.Value = &requestBody component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*RequestBody) - }) { - return nil - } var resolved RequestBodyRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -701,6 +703,7 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } value := component.Value if value == nil { @@ -741,6 +744,12 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Response) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var resp Response if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &resp); err != nil { @@ -749,15 +758,8 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen component.Value = &resp component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*Response) - }) { - return nil - } var resolved ResponseRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -770,6 +772,7 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } value := component.Value if value == nil { @@ -821,6 +824,12 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Schema) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var schema Schema if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &schema); err != nil { @@ -829,15 +838,8 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat component.Value = &schema component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*Schema) - }) { - return nil - } var resolved SchemaRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -850,6 +852,7 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } value := component.Value if value == nil { @@ -904,6 +907,12 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*SecurityScheme) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var scheme SecurityScheme if _, err = loader.loadSingleElementFromURI(ref, documentPath, &scheme); err != nil { @@ -912,15 +921,8 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme component.Value = &scheme component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*SecurityScheme) - }) { - return nil - } var resolved SecuritySchemeRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -933,6 +935,7 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } return nil } @@ -942,6 +945,12 @@ func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentP if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Example) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var example Example if _, err = loader.loadSingleElementFromURI(ref, documentPath, &example); err != nil { @@ -950,15 +959,8 @@ func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentP component.Value = &example component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*Example) - }) { - return nil - } var resolved ExampleRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -971,6 +973,7 @@ func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentP component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } return nil } @@ -984,6 +987,12 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Callback) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var resolved Callback if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &resolved); err != nil { @@ -992,15 +1001,8 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen component.Value = &resolved component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*Callback) - }) { - return nil - } var resolved CallbackRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -1013,6 +1015,7 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } value := component.Value if value == nil { @@ -1036,6 +1039,12 @@ func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *u if component.Value != nil { return nil } + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Link) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var link Link if _, err = loader.loadSingleElementFromURI(ref, documentPath, &link); err != nil { @@ -1044,15 +1053,8 @@ func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *u component.Value = &link component.refPath = *documentPath } else { - if !loader.shouldVisitRef(ref, func(value any) { - component.Value = value.(*Link) - }) { - return nil - } var resolved LinkRef - loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -1065,6 +1067,7 @@ func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *u component.Value = resolved.Value component.refPath = resolved.refPath } + defer loader.unvisitRef(ref, component.Value) } return nil } @@ -1079,6 +1082,12 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat if !pathItem.isEmpty() { return } + if !loader.shouldVisitRef(ref, func(value any) { + *pathItem = *value.(*PathItem) + }) { + return nil + } + loader.visitRef(ref) if isSingleRefElement(ref) { var p PathItem if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &p); err != nil { @@ -1086,15 +1095,8 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat } *pathItem = p } else { - if !loader.shouldVisitRef(ref, func(value any) { - *pathItem = *value.(*PathItem) - }) { - return nil - } var resolved PathItem - loader.visitRef(ref) doc, documentPath, err = loader.resolveComponent(doc, ref, documentPath, &resolved) - defer loader.unvisitRef(ref, &resolved) if err != nil { if err == errMUSTPathItem { return nil @@ -1104,6 +1106,7 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat *pathItem = resolved } pathItem.Ref = ref + defer loader.unvisitRef(ref, pathItem) } for _, parameter := range pathItem.Parameters { diff --git a/openapi3/loader_circular_test.go b/openapi3/loader_circular_test.go new file mode 100644 index 00000000..864e7708 --- /dev/null +++ b/openapi3/loader_circular_test.go @@ -0,0 +1,41 @@ +package openapi3 + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoadCircular(t *testing.T) { + loader := NewLoader() + loader.IsExternalRefsAllowed = true + doc, err := loader.LoadFromFile("testdata/circularRef2/circular2.yaml") + require.NoError(t, err) + require.NotNil(t, doc) + + ref := "./AwsEnvironmentSettings.yaml" + + arr := NewArraySchema() + obj := NewObjectSchema() + arr.Items = &SchemaRef{ + Ref: ref, + Value: obj, + } + obj.Description = "test" + obj.Properties = map[string]*SchemaRef{ + "children": { + Value: arr, + }, + } + + expected := &SchemaRef{ + Ref: ref, + Value: obj, + } + + actual := doc.Paths.Map()["/sample"].Put.RequestBody.Value.Content.Get("application/json").Schema + actual.refPath = url.URL{} + + require.Equal(t, expected.Value, actual.Value) +} diff --git a/openapi3/testdata/circularRef2/AwsEnvironmentSettings.yaml b/openapi3/testdata/circularRef2/AwsEnvironmentSettings.yaml new file mode 100644 index 00000000..33ba38f3 --- /dev/null +++ b/openapi3/testdata/circularRef2/AwsEnvironmentSettings.yaml @@ -0,0 +1,7 @@ +type: object +properties: + children: + type: array + items: + $ref: './AwsEnvironmentSettings.yaml' +description: test \ No newline at end of file diff --git a/openapi3/testdata/circularRef2/circular2.yaml b/openapi3/testdata/circularRef2/circular2.yaml new file mode 100644 index 00000000..f10bf590 --- /dev/null +++ b/openapi3/testdata/circularRef2/circular2.yaml @@ -0,0 +1,16 @@ +openapi: 3.0.0 +info: + title: Circular Reference Example + version: 1.0.0 +paths: + /sample: + put: + requestBody: + required: true + content: + application/json: + schema: + $ref: './AwsEnvironmentSettings.yaml' + responses: + '200': + description: Ok \ No newline at end of file