diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index cb381f2d9ad7..b30959363a0d 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -17,8 +17,7 @@ rather than silently truncating data in case the underlying tcp connection is closed prematurely. (thanks to @jochen-ott-by for the contribution) #20412 - UnboundLocalError when SansIOHTTPPolicy handles an exception #15222 - -### Other Changes +- Add default content type header of `text/plain` and content length header for users who pass unicode strings to the `content` kwarg of `HttpRequest` in 2.7 #21550 ## 1.19.1 (2021-11-01) diff --git a/sdk/core/azure-core/azure/core/rest/_helpers.py b/sdk/core/azure-core/azure/core/rest/_helpers.py index 934bfa613c76..aa5b05bd9e10 100644 --- a/sdk/core/azure-core/azure/core/rest/_helpers.py +++ b/sdk/core/azure-core/azure/core/rest/_helpers.py @@ -134,7 +134,7 @@ def _shared_set_content_body(content): if isinstance(content, ET.Element): # XML body return set_xml_body(content) - if isinstance(content, (str, bytes)): + if isinstance(content, (six.string_types, bytes)): headers = {} body = content if isinstance(content, six.string_types): diff --git a/sdk/core/azure-core/tests/test_rest_http_request.py b/sdk/core/azure-core/tests/test_rest_http_request.py index 70c184b5cc3d..c56871dcd477 100644 --- a/sdk/core/azure-core/tests/test_rest_http_request.py +++ b/sdk/core/azure-core/tests/test_rest_http_request.py @@ -255,6 +255,17 @@ def test_data_str_input(): assert len(request.headers) == 1 assert request.headers['Content-Type'] == 'application/x-www-form-urlencoded' +def test_content_str_input(): + requests = [ + HttpRequest("POST", "/fake", content="hello, world!"), + HttpRequest("POST", "/fake", content=u"hello, world!"), + ] + for request in requests: + assert len(request.headers) == 2 + assert request.headers["Content-Type"] == "text/plain" + assert request.headers["Content-Length"] == "13" + assert request.content == "hello, world!" + @pytest.mark.parametrize(("value"), (object(), {"key": "value"})) def test_multipart_invalid_value(value):