Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Check if group IDs are valid before using them. #8977

Merged
merged 6 commits into from
Dec 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8977.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Properly return 400 errors on invalid group IDs.
2 changes: 1 addition & 1 deletion synapse/handlers/groups_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _create_rerouter(func_name):

async def f(self, group_id, *args, **kwargs):
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s was not legal group ID" % (group_id,))
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))

if self.is_mine_id(group_id):
return await getattr(self.groups_server_handler, func_name)(
Expand Down
48 changes: 45 additions & 3 deletions synapse/rest/client/v2_alpha/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import logging
from functools import wraps
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL about wraps vs decorators!


from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
Expand All @@ -25,6 +26,22 @@
logger = logging.getLogger(__name__)


def _validate_group_id(f):
"""Wrapper to validate the form of the group ID.

Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
"""

@wraps(f)
def wrapper(self, request, group_id, *args, **kwargs):
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))

return f(self, request, group_id, *args, **kwargs)

return wrapper


class GroupServlet(RestServlet):
"""Get the group profile
"""
Expand All @@ -37,6 +54,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
Expand All @@ -47,6 +65,7 @@ async def on_GET(self, request, group_id):

return 200, group_description

@_validate_group_id
async def on_POST(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -71,6 +90,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
Expand Down Expand Up @@ -102,6 +122,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -117,6 +138,7 @@ async def on_PUT(self, request, group_id, category_id, room_id):

return 200, resp

@_validate_group_id
async def on_DELETE(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -142,6 +164,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
Expand All @@ -152,6 +175,7 @@ async def on_GET(self, request, group_id, category_id):

return 200, category

@_validate_group_id
async def on_PUT(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -163,6 +187,7 @@ async def on_PUT(self, request, group_id, category_id):

return 200, resp

@_validate_group_id
async def on_DELETE(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -186,6 +211,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
Expand All @@ -209,6 +235,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
Expand All @@ -219,6 +246,7 @@ async def on_GET(self, request, group_id, role_id):

return 200, category

@_validate_group_id
async def on_PUT(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -230,6 +258,7 @@ async def on_PUT(self, request, group_id, role_id):

return 200, resp

@_validate_group_id
async def on_DELETE(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -253,6 +282,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
Expand Down Expand Up @@ -284,6 +314,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -299,6 +330,7 @@ async def on_PUT(self, request, group_id, role_id, user_id):

return 200, resp

@_validate_group_id
async def on_DELETE(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -322,13 +354,11 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()

if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s was not legal group ID" % (group_id,))

result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
Expand All @@ -348,6 +378,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
Expand All @@ -371,6 +402,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -393,6 +425,7 @@ def __init__(self, hs):
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand Down Expand Up @@ -449,6 +482,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -460,6 +494,7 @@ async def on_PUT(self, request, group_id, room_id):

return 200, result

@_validate_group_id
async def on_DELETE(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -486,6 +521,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id, room_id, config_key):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand Down Expand Up @@ -514,6 +550,7 @@ def __init__(self, hs):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id

@_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand Down Expand Up @@ -541,6 +578,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -565,6 +603,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -589,6 +628,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -613,6 +653,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()

@_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand All @@ -637,6 +678,7 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()

@_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
Expand Down