Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Python 3: Convert some unicode/bytes uses (#3569)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkowl authored Aug 1, 2018
1 parent c4842e1 commit da77851
Show file tree
Hide file tree
Showing 17 changed files with 122 additions and 67 deletions.
1 change: 1 addition & 0 deletions changelog.d/3569.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.
4 changes: 2 additions & 2 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ def _get_appservice_user_id(self, request):
if ip_address not in app_service.ip_range_whitelist:
defer.returnValue((None, None))

if "user_id" not in request.args:
if b"user_id" not in request.args:
defer.returnValue((app_service.sender, app_service))

user_id = request.args["user_id"][0]
user_id = request.args[b"user_id"][0].decode('utf8')
if app_service.sender == user_id:
defer.returnValue((app_service.sender, app_service))

Expand Down
2 changes: 1 addition & 1 deletion synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
param_dict = dict(kv.split("=") for kv in params)

def strip_quotes(value):
if value.startswith(b"\""):
if value.startswith("\""):
return value[1:-1]
else:
return value
Expand Down
29 changes: 20 additions & 9 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import logging
import unicodedata

import attr
import bcrypt
Expand Down Expand Up @@ -626,6 +627,7 @@ def validate_login(self, username, login_submission):
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")

if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
Expand Down Expand Up @@ -707,9 +709,10 @@ def _check_local_password(self, user_id, password):
multiple inexact matches.
Args:
user_id (str): complete @user:id
user_id (unicode): complete @user:id
password (unicode): the provided password
Returns:
(str) the canonical_user_id, or None if unknown user / bad password
(unicode) the canonical_user_id, or None if unknown user / bad password
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
Expand Down Expand Up @@ -849,14 +852,19 @@ def hash(self, password):
"""Computes a secure hash of password.
Args:
password (str): Password to hash.
password (unicode): Password to hash.
Returns:
Deferred(str): Hashed password.
Deferred(unicode): Hashed password.
"""
def _do_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)

return bcrypt.hashpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
).decode('ascii')

return make_deferred_yieldable(
threads.deferToThreadPool(
Expand All @@ -868,16 +876,19 @@ def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
Args:
password (str): Password to hash.
stored_hash (str): Expected hash value.
password (unicode): Password to hash.
stored_hash (unicode): Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""

def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)

return bcrypt.checkpw(
password.encode('utf8') + self.hs.config.password_pepper,
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash.encode('utf8')
)

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def register(
Args:
localpart : The local part of the user ID to register. If None,
one will be generated.
password (str) : The password to assign to this user so they can
password (unicode) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be
Expand Down
35 changes: 25 additions & 10 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cgi
import collections
import logging
import urllib

from six.moves import http_client
from six import PY3
from six.moves import http_client, urllib

from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json

Expand Down Expand Up @@ -264,6 +265,7 @@ def __init__(self, hs, canonical_json=True):
self.hs = hs

def register_paths(self, method, path_patterns, callback):
method = method.encode("utf-8") # method is bytes on py3
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
Expand Down Expand Up @@ -296,8 +298,19 @@ def _async_render(self, request):
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.

def _unquote(s):
if PY3:
# On Python 3, unquote is unicode -> unicode
return urllib.parse.unquote(s)
else:
# On Python 2, unquote is bytes -> bytes We need to encode the
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')

kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
name: _unquote(value) if value else value
for name, value in group_dict.items()
})

Expand All @@ -313,9 +326,9 @@ def _get_handler_for_request(self, request):
request (twisted.web.http.Request):
Returns:
Tuple[Callable, dict[str, str]]: callback method, and the dict
mapping keys to path components as specified in the handler's
path match regexp.
Tuple[Callable, dict[unicode, unicode]]: callback method, and the
dict mapping keys to path components as specified in the
handler's path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
Expand All @@ -327,7 +340,7 @@ def _get_handler_for_request(self, request):
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path)
m = path_entry.pattern.match(request.path.decode('ascii'))
if m:
# We found a match!
return path_entry.callback, m.groupdict()
Expand Down Expand Up @@ -383,7 +396,7 @@ def __init__(self, path):
self.url = path

def render_GET(self, request):
return redirectTo(self.url, request)
return redirectTo(self.url.encode('ascii'), request)

def getChild(self, name, request):
if len(name) == 0:
Expand All @@ -404,12 +417,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
return

if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
json_bytes = (encode_pretty_printed_json(json_object) + "\n"
).encode("utf-8")
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
json_bytes = encode_canonical_json(json_object)
else:
json_bytes = json.dumps(json_object)
json_bytes = json.dumps(json_object).encode("utf-8")

return respond_with_json_bytes(
request, code, json_bytes,
Expand Down
10 changes: 9 additions & 1 deletion synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
if not content_bytes and allow_empty_body:
return None

# Decode to Unicode so that simplejson will return Unicode strings on
# Python 2
try:
content = json.loads(content_bytes)
content_unicode = content_bytes.decode('utf8')
except UnicodeDecodeError:
logger.warn("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

try:
content = json.loads(content_unicode)
except Exception as e:
logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
Expand Down
22 changes: 15 additions & 7 deletions synapse/rest/client/v1/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import hmac
import logging

from six import text_type
from six.moves import http_client

from twisted.internet import defer
Expand Down Expand Up @@ -131,7 +132,10 @@ def on_POST(self, request):
400, "username must be specified", errcode=Codes.BAD_JSON,
)
else:
if (not isinstance(body['username'], str) or len(body['username']) > 512):
if (
not isinstance(body['username'], text_type)
or len(body['username']) > 512
):
raise SynapseError(400, "Invalid username")

username = body["username"].encode("utf-8")
Expand All @@ -143,7 +147,10 @@ def on_POST(self, request):
400, "password must be specified", errcode=Codes.BAD_JSON,
)
else:
if (not isinstance(body['password'], str) or len(body['password']) > 512):
if (
not isinstance(body['password'], text_type)
or len(body['password']) > 512
):
raise SynapseError(400, "Invalid password")

password = body["password"].encode("utf-8")
Expand All @@ -166,17 +173,18 @@ def on_POST(self, request):
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()

if not hmac.compare_digest(want_mac, got_mac):
raise SynapseError(
403, "HMAC incorrect",
)
if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
raise SynapseError(403, "HMAC incorrect")

# Reuse the parts of RegisterRestServlet to reduce code duplication
from synapse.rest.client.v2_alpha.register import RegisterRestServlet

register = RegisterRestServlet(self.hs)

(user_id, _) = yield register.registration_handler.register(
localpart=username.lower(), password=password, admin=bool(admin),
localpart=body['username'].lower(),
password=body["password"],
admin=bool(admin),
generate_token=False,
)

Expand Down
12 changes: 6 additions & 6 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,15 @@ def __init__(self, hs):
def on_POST(self, request):
body = parse_json_object_from_request(request)

kind = "user"
if "kind" in request.args:
kind = request.args["kind"][0]
kind = b"user"
if b"kind" in request.args:
kind = request.args[b"kind"][0]

if kind == "guest":
if kind == b"guest":
ret = yield self._do_guest_registration(body)
defer.returnValue(ret)
return
elif kind != "user":
elif kind != b"user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,)
)
Expand Down Expand Up @@ -389,8 +389,8 @@ def on_POST(self, request):
assert_params_in_dict(params, ["password"])

desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
new_password = params.get("password", None)

if desired_username is not None:
desired_username = desired_username.lower()
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/media/v1/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def ensure_media_is_in_local_cache(self, file_info):
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "w"), self.hs.get_reactor())
open(local_path, "wb"), self.hs.get_reactor())
yield res.write_to_consumer(consumer)
yield consumer.wait()
defer.returnValue(local_path)
Expand Down
2 changes: 1 addition & 1 deletion synapse/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def _make_state_cache_entry(

def _ordered_events(events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()

return sorted(events, key=key_func)

Expand Down
14 changes: 10 additions & 4 deletions synapse/storage/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,13 @@


def encode_json(json_object):
return frozendict_json_encoder.encode(json_object)
"""
Encode a Python object as JSON and return it in a Unicode string.
"""
out = frozendict_json_encoder.encode(json_object)
if isinstance(out, bytes):
out = out.decode('utf8')
return out


class _EventPeristenceQueue(object):
Expand Down Expand Up @@ -1058,7 +1064,7 @@ def _update_outliers_txn(self, txn, events_and_contexts):

metadata_json = encode_json(
event.internal_metadata.get_dict()
).decode("UTF-8")
)

sql = (
"UPDATE event_json SET internal_metadata = ?"
Expand Down Expand Up @@ -1172,8 +1178,8 @@ def event_dict(event):
"room_id": event.room_id,
"internal_metadata": encode_json(
event.internal_metadata.get_dict()
).decode("UTF-8"),
"json": encode_json(event_dict(event)).decode("UTF-8"),
),
"json": encode_json(event_dict(event)),
}
for event, _ in events_and_contexts
],
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _get_event_reference_hashes_txn(self, txn, event_id):
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
A dict[unicode, bytes] of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
Expand Down
2 changes: 1 addition & 1 deletion synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __deepcopy__(self, memo):
@classmethod
def from_string(cls, s):
"""Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0] != cls.SIGIL:
if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(400, "Expected %s string to start with '%s'" % (
cls.__name__, cls.SIGIL,
))
Expand Down
Loading

0 comments on commit da77851

Please sign in to comment.