Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #2586 from matrix-org/rav/frontend_proxy_auth_header
Browse files Browse the repository at this point in the history
Front-end proxy: pass through auth header
  • Loading branch information
richvdh authored Oct 27, 2017
2 parents 8854c03 + 173567a commit 8b56977
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 28 deletions.
7 changes: 7 additions & 0 deletions synapse/app/frontend_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,16 @@ def on_POST(self, request, device_id):

if body:
# They're actually trying to upload something, proxy to main synapse.
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
headers = {
"Authorization": auth_headers,
}
result = yield self.http_client.post_json_get_json(
self.main_uri + request.uri,
body,
headers=headers,
)

defer.returnValue((200, result))
Expand Down
108 changes: 80 additions & 28 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,34 @@ def send_request():
raise e

@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}):
def post_urlencoded_get_json(self, uri, args={}, headers=None):
"""
Args:
uri (str):
args (dict[str, str|List[str]]): query params
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred[object]: parsed json
"""

# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)

query_bytes = urllib.urlencode(encode_urlencode_args(args), True)

actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)

response = yield self.request(
"POST",
uri.encode("ascii"),
headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
}),
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)

Expand All @@ -135,18 +150,33 @@ def post_urlencoded_get_json(self, uri, args={}):
defer.returnValue(json.loads(body))

@defer.inlineCallbacks
def post_json_get_json(self, uri, post_json):
def post_json_get_json(self, uri, post_json, headers=None):
"""
Args:
uri (str):
post_json (object):
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred[object]: parsed json
"""
json_str = encode_canonical_json(post_json)

logger.debug("HTTP POST %s -> %s", json_str, uri)

actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)

response = yield self.request(
"POST",
uri.encode("ascii"),
headers=Headers({
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}),
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)

Expand All @@ -160,7 +190,7 @@ def post_json_get_json(self, uri, post_json):
defer.returnValue(json.loads(body))

@defer.inlineCallbacks
def get_json(self, uri, args={}):
def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
Expand All @@ -169,6 +199,8 @@ def get_json(self, uri, args={}):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Expand All @@ -177,13 +209,13 @@ def get_json(self, uri, args={}):
error message.
"""
try:
body = yield self.get_raw(uri, args)
body = yield self.get_raw(uri, args, headers=headers)
defer.returnValue(json.loads(body))
except CodeMessageException as e:
raise self._exceptionFromFailedRequest(e.code, e.msg)

@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}):
def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
Expand All @@ -193,6 +225,8 @@ def put_json(self, uri, json_body, args={}):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Expand All @@ -205,13 +239,17 @@ def put_json(self, uri, json_body, args={}):

json_str = encode_canonical_json(json_body)

actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)

response = yield self.request(
"PUT",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
"Content-Type": ["application/json"]
}),
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)

Expand All @@ -226,7 +264,7 @@ def put_json(self, uri, json_body, args={}):
raise CodeMessageException(response.code, body)

@defer.inlineCallbacks
def get_raw(self, uri, args={}):
def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
Expand All @@ -235,6 +273,8 @@ def get_raw(self, uri, args={}):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
Expand All @@ -246,12 +286,16 @@ def get_raw(self, uri, args={}):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)

actual_headers = {
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)

response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
headers=Headers(actual_headers),
)

body = yield make_deferred_yieldable(readBody(response))
Expand All @@ -274,27 +318,33 @@ def _exceptionFromFailedRequest(self, response, body):
# The two should be factored out.

@defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None):
def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""

actual_headers = {
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)

response = yield self.request(
"GET",
url.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
headers=Headers(actual_headers),
)

headers = dict(response.headers.getAllRawHeaders())
resp_headers = dict(response.headers.getAllRawHeaders())

if 'Content-Length' in headers and headers['Content-Length'] > max_size:
if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
Expand Down Expand Up @@ -326,7 +376,9 @@ def get_file(self, url, output_stream, max_size=None):
Codes.UNKNOWN,
)

defer.returnValue((length, headers, response.request.absoluteURI, response.code))
defer.returnValue(
(length, resp_headers, response.request.absoluteURI, response.code),
)


# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
Expand Down

0 comments on commit 8b56977

Please sign in to comment.