Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Port rest/ to Python 3 #3823

Merged
merged 3 commits into from
Sep 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/3823.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
rest/ is now ported to Python 3.
9 changes: 6 additions & 3 deletions synapse/rest/client/v1/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def on_GET(self, request):

nonce = self.hs.get_secrets().token_hex(64)
self.nonces[nonce] = int(self.reactor.seconds())
return (200, {"nonce": nonce.encode('ascii')})
return (200, {"nonce": nonce})

@defer.inlineCallbacks
def on_POST(self, request):
Expand Down Expand Up @@ -164,7 +164,7 @@ def on_POST(self, request):
key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
want_mac.update(nonce)
want_mac.update(nonce.encode('utf8'))
want_mac.update(b"\x00")
want_mac.update(username)
want_mac.update(b"\x00")
Expand All @@ -173,7 +173,10 @@ def on_POST(self, request):
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()

if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
if not hmac.compare_digest(
want_mac.encode('ascii'),
got_mac.encode('ascii')
):
raise SynapseError(403, "HMAC incorrect")

# Reuse the parts of RegisterRestServlet to reduce code duplication
Expand Down
12 changes: 6 additions & 6 deletions synapse/rest/client/v1/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,20 @@ def on_GET(self, request):
is_guest = requester.is_guest
room_id = None
if is_guest:
if "room_id" not in request.args:
if b"room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
if "room_id" in request.args:
room_id = request.args["room_id"][0]
if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode('ascii')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we've specified that room_id must be ascii?


pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
if b"timeout" in request.args:
try:
timeout = int(request.args["timeout"][0])
timeout = int(request.args[b"timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")

as_client_event = "raw" not in request.args
as_client_event = b"raw" not in request.args

chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(),
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/v1/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, hs):
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(
Expand Down
44 changes: 22 additions & 22 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# limitations under the License.

import logging
import urllib
import xml.etree.ElementTree as ET

from six.moves.urllib import parse as urlparse
from six.moves import urllib

from canonicaljson import json
from saml2 import BINDING_HTTP_POST, config
Expand Down Expand Up @@ -134,7 +133,7 @@ def on_POST(self, request):
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
relay_state = "&RelayState=" + urllib.parse.quote(
login_submission["relay_state"])
result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
Expand Down Expand Up @@ -366,7 +365,7 @@ def on_POST(self, request):
(user_id, token) = yield handler.register_saml2(username)
# Forward to the RelayState callback along with ava
if 'RelayState' in request.args:
request.redirect(urllib.unquote(
request.redirect(urllib.parse.unquote(
request.args['RelayState'][0]) +
'?status=authenticated&access_token=' +
token + '&user_id=' + user_id + '&ava=' +
Expand All @@ -377,7 +376,7 @@ def on_POST(self, request):
"user_id": user_id, "token": token,
"ava": saml2_auth.ava}))
elif 'RelayState' in request.args:
request.redirect(urllib.unquote(
request.redirect(urllib.parse.unquote(
request.args['RelayState'][0]) +
'?status=not_authenticated')
finish_request(request)
Expand All @@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet):

def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
self.cas_service_url = hs.config.cas_service_url.encode('ascii')

def on_GET(self, request):
args = request.args
if "redirectUrl" not in args:
if b"redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
client_redirect_url_param = urllib.urlencode({
"redirectUrl": args["redirectUrl"][0]
})
hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
service_param = urllib.urlencode({
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
})
request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
client_redirect_url_param = urllib.parse.urlencode({
b"redirectUrl": args[b"redirectUrl"][0]
}).encode('ascii')
hs_redirect_url = (self.cas_service_url +
b"/_matrix/client/api/v1/login/cas/ticket")
service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii')
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request)


Expand All @@ -422,11 +422,11 @@ def __init__(self, hs):

@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = request.args["redirectUrl"][0]
client_redirect_url = request.args[b"redirectUrl"][0]
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": request.args["ticket"],
"ticket": request.args[b"ticket"][0].decode('ascii'),
"service": self.cas_service_url
}
try:
Expand Down Expand Up @@ -471,11 +471,11 @@ def handle_cas_response(self, request, cas_response_body, client_redirect_url):
finish_request(request)

def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url))
query = dict(urlparse.parse_qsl(url_parts[4]))
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(url_parts)
url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
return urllib.parse.urlunparse(url_parts)

def parse_cas_response(self, cas_response_body):
user = None
Expand Down
24 changes: 12 additions & 12 deletions synapse/rest/client/v1/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def on_PUT(self, request):
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
raise SynapseError(400, str(e))

requester = yield self.auth.get_user_by_req(request)

Expand All @@ -73,7 +73,7 @@ def on_PUT(self, request):
content,
)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
raise SynapseError(400, str(e))

before = parse_string(request, "before")
if before:
Expand All @@ -95,9 +95,9 @@ def on_PUT(self, request):
)
self.notify_user(user_id)
except InconsistentRuleException as e:
raise SynapseError(400, e.message)
raise SynapseError(400, str(e))
except RuleNotFoundException as e:
raise SynapseError(400, e.message)
raise SynapseError(400, str(e))

defer.returnValue((200, {}))

Expand Down Expand Up @@ -142,10 +142,10 @@ def on_GET(self, request):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)

if path[0] == '':
if path[0] == b'':
defer.returnValue((200, rules))
elif path[0] == 'global':
path = path[1:]
elif path[0] == b'global':
path = [x.decode('ascii') for x in path[1:]]
result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result))
else:
Expand Down Expand Up @@ -192,24 +192,24 @@ def set_rule_attr(self, user_id, spec, val):
def _rule_spec_from_path(path):
if len(path) < 2:
raise UnrecognizedRequestError()
if path[0] != 'pushrules':
if path[0] != b'pushrules':
raise UnrecognizedRequestError()

scope = path[1]
scope = path[1].decode('ascii')
path = path[2:]
if scope != 'global':
raise UnrecognizedRequestError()

if len(path) == 0:
raise UnrecognizedRequestError()

template = path[0]
template = path[0].decode('ascii')
path = path[1:]

if len(path) == 0 or len(path[0]) == 0:
raise UnrecognizedRequestError()

rule_id = path[0]
rule_id = path[0].decode('ascii')

spec = {
'scope': scope,
Expand All @@ -220,7 +220,7 @@ def _rule_spec_from_path(path):
path = path[1:]

if len(path) > 0 and len(path[0]) > 0:
spec['attr'] = path[0]
spec['attr'] = path[0].decode('ascii')

return spec

Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/v1/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def on_GET(self, request):
]

for p in pushers:
for k, v in p.items():
for k, v in list(p.items()):
if k not in allowed_keys:
del p[k]

Expand Down Expand Up @@ -126,7 +126,7 @@ def on_POST(self, request):
profile_tag=content.get('profile_tag', ""),
)
except PusherConfigException as pce:
raise SynapseError(400, "Config Error: " + pce.message,
raise SynapseError(400, "Config Error: " + str(pce),
errcode=Codes.MISSING_PARAM)

self.notifier.on_new_replication_data()
Expand Down
14 changes: 8 additions & 6 deletions synapse/rest/client/v1/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def on_POST(self, request, room_id, event_type, txn_id=None):
"sender": requester.user.to_string(),
}

if 'ts' in request.args and requester.app_service:
if b'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)

event = yield self.event_creation_hander.create_and_send_nonmember_event(
Expand Down Expand Up @@ -255,7 +255,9 @@ def on_POST(self, request, room_identifier, txn_id=None):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
remote_room_hosts = request.args["server_name"]
remote_room_hosts = [
x.decode('ascii') for x in request.args[b"server_name"]
]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
Expand Down Expand Up @@ -461,10 +463,10 @@ def on_GET(self, request, room_id):
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
as_client_event = "raw" not in request.args
filter_bytes = parse_string(request, "filter")
as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
Expand Down Expand Up @@ -560,7 +562,7 @@ def on_GET(self, request, room_id, event_id):
# picking the API shape for symmetry with /messages
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
filter_json = urlparse.unquote(filter_bytes)
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/v2_alpha/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, hs):

@defer.inlineCallbacks
def on_GET(self, request):
if "from" in request.args:
if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
raise SynapseError(
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/v2_alpha/thirdparty.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request, allow_guest=True)

fields = request.args
fields.pop("access_token", None)
fields.pop(b"access_token", None)

results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
Expand All @@ -102,7 +102,7 @@ def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request, allow_guest=True)

fields = request.args
fields.pop("access_token", None)
fields.pop(b"access_token", None)

results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/key/v1/server_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,5 @@ def render_GET(self, request):
)

def getChild(self, name, request):
if name == '':
if name == b'':
return self
4 changes: 2 additions & 2 deletions synapse/rest/key/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
class KeyApiV2Resource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.putChild("server", LocalKey(hs))
self.putChild("query", RemoteKey(hs))
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))
6 changes: 4 additions & 2 deletions synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def render_GET(self, request):
def async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
query = {server: {}}
query = {server.decode('ascii'): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(
Expand All @@ -112,11 +112,12 @@ def async_render_GET(self, request):
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
query = {server: {key_id: arguments}}
query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}}
else:
raise SynapseError(
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
)

yield self.query_keys(request, query, query_remote_on_cache_miss=True)

def render_POST(self, request):
Expand All @@ -135,6 +136,7 @@ def async_render_POST(self, request):
@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query)

store_queries = []
for server_name, key_ids in query.items():
if (
Expand Down
Loading