diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 8ef10a466777..55086531a63c 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -8,6 +8,8 @@ ### Bugs Fixed +* `runtime.MarshalAsByteArray` and `runtime.MarshalAsJSON` will preserve the preexisting value of the `Content-Type` header. + ### Other Changes ## 1.9.1 (2023-12-11) diff --git a/sdk/azcore/internal/exported/request.go b/sdk/azcore/internal/exported/request.go index 659f2a7d2ead..8d1ae213c958 100644 --- a/sdk/azcore/internal/exported/request.go +++ b/sdk/azcore/internal/exported/request.go @@ -125,46 +125,11 @@ func (req *Request) OperationValue(value interface{}) bool { // SetBody sets the specified ReadSeekCloser as the HTTP request body, and sets Content-Type and Content-Length // accordingly. If the ReadSeekCloser is nil or empty, Content-Length won't be set. If contentType is "", -// Content-Type won't be set. +// Content-Type won't be set, and if it was set, will be deleted. // Use streaming.NopCloser to turn an io.ReadSeeker into an io.ReadSeekCloser. func (req *Request) SetBody(body io.ReadSeekCloser, contentType string) error { - var err error - var size int64 - if body != nil { - size, err = body.Seek(0, io.SeekEnd) // Seek to the end to get the stream's size - if err != nil { - return err - } - } - if size == 0 { - // treat an empty stream the same as a nil one: assign req a nil body - body = nil - // RFC 9110 specifies a client shouldn't set Content-Length on a request containing no content - // (Del is a no-op when the header has no value) - req.req.Header.Del(shared.HeaderContentLength) - } else { - _, err = body.Seek(0, io.SeekStart) - if err != nil { - return err - } - req.req.Header.Set(shared.HeaderContentLength, strconv.FormatInt(size, 10)) - req.Raw().GetBody = func() (io.ReadCloser, error) { - _, err := body.Seek(0, io.SeekStart) // Seek back to the beginning of the stream - return body, err - } - } - // keep a copy of the body argument. this is to handle cases - // where req.Body is replaced, e.g. httputil.DumpRequest and friends. - req.body = body - req.req.Body = body - req.req.ContentLength = size - if contentType == "" { - // Del is a no-op when the header has no value - req.req.Header.Del(shared.HeaderContentType) - } else { - req.req.Header.Set(shared.HeaderContentType, contentType) - } - return nil + // clobber the existing Content-Type to preserve behavior + return SetBody(req, body, contentType, true) } // RewindBody seeks the request's Body stream back to the beginning so it can be resent when retrying an operation. @@ -211,3 +176,48 @@ type PolicyFunc func(*Request) (*http.Response, error) func (pf PolicyFunc) Do(req *Request) (*http.Response, error) { return pf(req) } + +// SetBody sets the specified ReadSeekCloser as the HTTP request body, and sets Content-Type and Content-Length accordingly. +// - req is the request to modify +// - body is the request body; if nil or empty, Content-Length won't be set +// - contentType is the value for the Content-Type header; if empty, Content-Type will be deleted +// - clobberContentType when true, will overwrite the existing value of Content-Type with contentType +func SetBody(req *Request, body io.ReadSeekCloser, contentType string, clobberContentType bool) error { + var err error + var size int64 + if body != nil { + size, err = body.Seek(0, io.SeekEnd) // Seek to the end to get the stream's size + if err != nil { + return err + } + } + if size == 0 { + // treat an empty stream the same as a nil one: assign req a nil body + body = nil + // RFC 9110 specifies a client shouldn't set Content-Length on a request containing no content + // (Del is a no-op when the header has no value) + req.req.Header.Del(shared.HeaderContentLength) + } else { + _, err = body.Seek(0, io.SeekStart) + if err != nil { + return err + } + req.req.Header.Set(shared.HeaderContentLength, strconv.FormatInt(size, 10)) + req.Raw().GetBody = func() (io.ReadCloser, error) { + _, err := body.Seek(0, io.SeekStart) // Seek back to the beginning of the stream + return body, err + } + } + // keep a copy of the body argument. this is to handle cases + // where req.Body is replaced, e.g. httputil.DumpRequest and friends. + req.body = body + req.req.Body = body + req.req.ContentLength = size + if contentType == "" { + // Del is a no-op when the header has no value + req.req.Header.Del(shared.HeaderContentType) + } else if req.req.Header.Get(shared.HeaderContentType) == "" || clobberContentType { + req.req.Header.Set(shared.HeaderContentType, contentType) + } + return nil +} diff --git a/sdk/azcore/internal/exported/request_test.go b/sdk/azcore/internal/exported/request_test.go index 3acc8e7a76ae..6aee60d141bd 100644 --- a/sdk/azcore/internal/exported/request_test.go +++ b/sdk/azcore/internal/exported/request_test.go @@ -211,3 +211,22 @@ func TestRequestWithContext(t *testing.T) { req2.Raw().Header.Add("added-req2", "value") require.EqualValues(t, "value", req1.Raw().Header.Get("added-req2")) } + +func TestSetBodyWithClobber(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPatch, "https://contoso.com") + require.NoError(t, err) + require.NotNil(t, req) + req.req.Header.Set(shared.HeaderContentType, "clobber-me") + require.NoError(t, SetBody(req, NopCloser(strings.NewReader(`"json-string"`)), shared.ContentTypeAppJSON, true)) + require.EqualValues(t, shared.ContentTypeAppJSON, req.req.Header.Get(shared.HeaderContentType)) +} + +func TestSetBodyWithNoClobber(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPatch, "https://contoso.com") + require.NoError(t, err) + require.NotNil(t, req) + const mergePatch = "application/merge-patch+json" + req.req.Header.Set(shared.HeaderContentType, mergePatch) + require.NoError(t, SetBody(req, NopCloser(strings.NewReader(`"json-string"`)), shared.ContentTypeAppJSON, false)) + require.EqualValues(t, mergePatch, req.req.Header.Get(shared.HeaderContentType)) +} diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go index e97223da29e1..5d1569c8dd2b 100644 --- a/sdk/azcore/runtime/request.go +++ b/sdk/azcore/runtime/request.go @@ -97,7 +97,8 @@ func EncodeByteArray(v []byte, format Base64Encoding) string { func MarshalAsByteArray(req *policy.Request, v []byte, format Base64Encoding) error { // send as a JSON string encode := fmt.Sprintf("\"%s\"", EncodeByteArray(v, format)) - return req.SetBody(exported.NopCloser(strings.NewReader(encode)), shared.ContentTypeAppJSON) + // tsp generated code can set Content-Type so we must prefer that + return exported.SetBody(req, exported.NopCloser(strings.NewReader(encode)), shared.ContentTypeAppJSON, false) } // MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody. @@ -106,7 +107,8 @@ func MarshalAsJSON(req *policy.Request, v interface{}) error { if err != nil { return fmt.Errorf("error marshalling type %T: %s", v, err) } - return req.SetBody(exported.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppJSON) + // tsp generated code can set Content-Type so we must prefer that + return exported.SetBody(req, exported.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppJSON, false) } // MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody.