diff --git a/src/xpra/client/client_base.py b/src/xpra/client/client_base.py index 2a81f8bac8..eba21dddfe 100644 --- a/src/xpra/client/client_base.py +++ b/src/xpra/client/client_base.py @@ -9,6 +9,7 @@ import sys import socket import binascii +import string from xpra.gtk_common.gobject_compat import import_gobject, import_glib gobject = import_gobject() @@ -21,7 +22,7 @@ from xpra.platform.features import GOT_PASSWORD_PROMPT_SUGGESTION from xpra.platform.info import get_name from xpra.os_util import get_hex_uuid, get_machine_id, get_user_uuid, load_binary_file, SIGNAMES, strtobytes, bytestostr -from xpra.util import typedict, updict, xor +from xpra.util import typedict, updict, xor, repr_ellipsized EXIT_OK = 0 EXIT_CONNECTION_LOST = 1 @@ -167,6 +168,8 @@ def setup_connection(self, conn): self._protocol.large_packets.append("server-settings") self._protocol.set_compression_level(self.compression_level) self._protocol.receive_aliases.update(self._aliases) + self._protocol.enable_default_encoder() + self._protocol.enable_default_compressor() self.have_more = self._protocol.source_has_more def init_packet_handlers(self): @@ -174,11 +177,12 @@ def init_packet_handlers(self): "hello": self._process_hello, } self._ui_packet_handlers = { - "challenge": self._process_challenge, - "disconnect": self._process_disconnect, - "set_deflate": self._process_set_deflate, - Protocol.CONNECTION_LOST: self._process_connection_lost, - Protocol.GIBBERISH: self._process_gibberish, + "challenge": self._process_challenge, + "disconnect": self._process_disconnect, + "set_deflate": self._process_set_deflate, + Protocol.CONNECTION_LOST: self._process_connection_lost, + Protocol.GIBBERISH: self._process_gibberish, + Protocol.INVALID: self._process_invalid, } def init_authenticated_packet_handlers(self): @@ -521,8 +525,14 @@ def _process_set_deflate(self, packet): pass def _process_gibberish(self, packet): - (_, data) = packet - log.info("Received uninterpretable nonsense: %s", repr(data)) + (_, message, data) = packet + p = self._protocol + if (p and p.input_packetcount>0) or not all(c in string.printable for c in data): + log.info("Received uninterpretable nonsense: %s", message) + log.info(" packet no %i data: %s", p.input_packetcount, repr_ellipsized(data)) + else: + #looks like the first packet back is just text, print it: + log.info("Failed to connect: %s", repr_ellipsized(data.strip("\n").strip("\r"))) if str(data).find("assword")>0: log.warn("Your ssh program appears to be asking for a password." + GOT_PASSWORD_PROMPT_SUGGESTION) @@ -534,6 +544,13 @@ def _process_gibberish(self, packet): else: self.quit(EXIT_PACKET_FAILURE) + def _process_invalid(self, packet): + (_, message, data) = packet + log.info("Received invalid packet: %s", message) + log(" data: %s", repr_ellipsized(data)) + self.quit(EXIT_PACKET_FAILURE) + + def process_packet(self, proto, packet): try: handler = None diff --git a/src/xpra/net/compression.py b/src/xpra/net/compression.py index acce6d7740..627a047f2f 100644 --- a/src/xpra/net/compression.py +++ b/src/xpra/net/compression.py @@ -100,6 +100,10 @@ def compressed_wrapper(datatype, data, level=5, lz4=False): return LevelCompressed(datatype, cdata, cl, algo) +class InvalidCompressionException(Exception): + pass + + def get_compression_type(level): if level & LZ4_FLAG: return "lz4" @@ -111,12 +115,16 @@ def get_compression_type(level): def decompress(data, level): #log.info("decompress(%s bytes, %s) type=%s", len(data), get_compression_type(level)) if level & LZ4_FLAG: - assert has_lz4, "lz4 is not available" - assert use_lz4, "lz4 is not enabled" + if not has_lz4: + raise InvalidCompressionException("lz4 is not available") + if not use_lz4: + raise InvalidCompressionException("lz4 is not enabled") return LZ4_uncompress(data) elif level & BZ2_FLAG: - assert use_bz2, "bz2 is not enabled" + if not use_bz2: + raise InvalidCompressionException("bz2 is not enabled") return bz2.decompress(data) else: - assert use_zlib, "zlib is not enabled" + if not use_zlib: + raise InvalidCompressionException("zlib is not enabled") return zlib.decompress(data) diff --git a/src/xpra/net/packet_encoding.py b/src/xpra/net/packet_encoding.py index 8a61b00f21..804860d3a3 100644 --- a/src/xpra/net/packet_encoding.py +++ b/src/xpra/net/packet_encoding.py @@ -81,18 +81,29 @@ def get_packet_encoding_type(protocol_flags): else: return "bencode" + +class InvalidPacketEncodingException(Exception): + pass + + def decode(data, protocol_flags): if protocol_flags & FLAGS_RENCODE: - assert has_rencode, "rencode packet encoder is not available" - assert use_rencode, "rencode packet encoder is disabled" + if not has_rencode: + raise InvalidPacketEncodingException("rencode is not available") + if not use_rencode: + raise InvalidPacketEncodingException("rencode is disabled") return list(rencode_loads(data)) elif protocol_flags & FLAGS_YAML: - assert has_rencode, "yaml packet encoder is not available!" - assert use_yaml, "yaml packet encoder is disabled" + if not has_yaml: + raise InvalidPacketEncodingException("yaml is not available") + if not use_yaml: + raise InvalidPacketEncodingException("yaml is disabled") return list(yaml_decode(data)) else: - assert has_rencode, "bencode packet encoder is not available!" - assert use_bencode, "bencode packet encoder is disabled" + if not has_bencode: + raise InvalidPacketEncodingException("bencode is not available") + if not use_bencode: + raise InvalidPacketEncodingException("bencode is disabled") #if sys.version>='3': # data = data.decode("latin1") packet, l = bdecode(data) diff --git a/src/xpra/net/protocol.py b/src/xpra/net/protocol.py index a7e5af35dc..1422fdfa8b 100644 --- a/src/xpra/net/protocol.py +++ b/src/xpra/net/protocol.py @@ -23,11 +23,11 @@ from xpra.util import repr_ellipsized from xpra.net.bytestreams import ABORT from xpra.net import compression -from xpra.net.compression import nocompress, zcompress, bzcompress, lz4_compress, Compressed, LevelCompressed, get_compression_caps +from xpra.net.compression import nocompress, zcompress, bzcompress, lz4_compress, Compressed, LevelCompressed, get_compression_caps, InvalidCompressionException from xpra.net.header import unpack_header, pack_header, pack_header_and_data, FLAGS_RENCODE, FLAGS_YAML, FLAGS_CIPHER, FLAGS_NOHEADER from xpra.net.crypto import get_crypto_caps, get_cipher from xpra.net import packet_encoding -from xpra.net.packet_encoding import get_packet_encoding_caps, rencode_dumps, decode, has_rencode, \ +from xpra.net.packet_encoding import InvalidPacketEncodingException, get_packet_encoding_caps, rencode_dumps, decode, has_rencode, \ bencode, has_bencode, yaml_encode, has_yaml @@ -80,6 +80,7 @@ class ConnectionClosedException(Exception): class Protocol(object): CONNECTION_LOST = "connection-lost" GIBBERISH = "gibberish" + INVALID = "invalid" def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None): """ @@ -131,8 +132,6 @@ def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None): self._read_parser_thread = make_daemon_thread(self._read_parse_thread_loop, "parse") self._write_format_thread = make_daemon_thread(self._write_format_thread_loop, "format") self._source_has_more = threading.Event() - self.enable_default_encoder() - self.enable_default_compressor() STATE_FIELDS = ("max_packet_size", "large_packets", "send_aliases", "receive_aliases", "cipher_in", "cipher_in_name", "cipher_in_block_size", @@ -617,29 +616,42 @@ def _read(self): return self.input_raw_packetcount += 1 - def _call_connection_lost(self, message="", exc_info=False): - log("will call connection lost: %s", message) + def _internal_error(self, message="", exc_info=False): + log.error("internal error: %s", message) self.idle_add(self._connection_lost, message, exc_info) def _connection_lost(self, message="", exc_info=False): - log.info("connection lost: %s", message, exc_info=exc_info) + log("connection lost: %s", message, exc_info=exc_info) self.close() return False + + def invalid(self, msg, data): + self.idle_add(self._process_packet_cb, self, [Protocol.INVALID, msg, data]) + # Then hang up: + self.timeout_add(1000, self._connection_lost, msg) + def gibberish(self, msg, data): - self.idle_add(self._process_packet_cb, self, [Protocol.GIBBERISH, data]) + self.idle_add(self._process_packet_cb, self, [Protocol.GIBBERISH, msg, data]) # Then hang up: self.timeout_add(1000, self._connection_lost, msg) + + #delegates to invalid_header() + #(so this can more easily be intercepted and overriden + # see tcp-proxy) def _invalid_header(self, data): - self.invalid_header(self, data) + #call via idle_add gives the client time + #to disconnect (and so we don't bother with it) + self.idle_add(self.invalid_header, self, data) def invalid_header(self, proto, data): - err = "invalid packet header byte: '%#x'" % ord(data[0]) + err = "invalid packet header: '%s'" % binascii.hexlify(data[:8]) if len(data)>1: - err += " read buffer=0x%s" % repr_ellipsized(data) + err += " read buffer=%s" % repr_ellipsized(data) self.gibberish(err, data) + def _read_parse_thread_loop(self): log("read_parse_thread_loop starting") try: @@ -715,11 +727,13 @@ def do_read_parse_thread_loop(self): #this packet is seemingly too big, but check again from the main UI thread #this gives 'set_max_packet_size' a chance to run from "hello" def check_packet_size(size_to_check, packet_header): - if not self._closed: - log("check_packet_size(%s, 0x%s) limit is %s", size_to_check, repr_ellipsized(packet_header), self.max_packet_size) - if size_to_check>self.max_packet_size: - self._call_connection_lost("invalid packet: size requested is %s (maximum allowed is %s - packet header: 0x%s), dropping this connection!" % - (size_to_check, self.max_packet_size, repr_ellipsized(packet_header))) + if self._closed: + return False + log("check_packet_size(%s, 0x%s) limit is %s", size_to_check, repr_ellipsized(packet_header), self.max_packet_size) + if size_to_check>self.max_packet_size: + msg = "packet size requested is %s but maximum allowed is %s" % \ + (size_to_check, self.max_packet_size) + self.invalid(msg, packet_header) return False self.timeout_add(1000, check_packet_size, payload_size, read_buffer[:32]) @@ -748,22 +762,28 @@ def debug_str(s): if not data.endswith(padding): log("decryption failed: string does not end with '%s': %s (%s) -> %s (%s)", padding, debug_str(raw_string), type(raw_string), debug_str(data), type(data)) - self._connection_lost("encryption error (wrong key?)") + self._internal_error("encryption error (wrong key?)") return data = data[:-len(padding)] #uncompress if needed: if compression_level>0: try: data = compression.decompress(data, compression_level) + except InvalidCompressionException, e: + self.invalid("invalid compression: %s" % e, data) + return except Exception, e: ctype = compression.get_compression_type(compression_level) log("%s packet decompression failed", ctype, exc_info=True) + msg = "%s packet decompression failed" % ctype if self.cipher_in: - return self._call_connection_lost("%s packet decompression failed (invalid encryption key?): %s" % (ctype, e)) - return self._call_connection_lost("%s packet decompression failed: %s" % (ctype, e)) + msg += " (invalid encryption key?)" + msg = "msg: %s" % e + return self.gibberish(msg, data) if self.cipher_in and not (protocol_flags & FLAGS_CIPHER): - return self._call_connection_lost("unencrypted packet dropped: %s" % repr_ellipsized(data)) + self.invalid("unencrypted packet dropped", data) + return if self._closed: return @@ -773,18 +793,22 @@ def debug_str(s): payload_size = -1 packet_index = 0 if len(raw_packets)>=4: - return self._call_connection_lost("too many raw packets: %s" % len(raw_packets)) + self.invalid("too many raw packets: %s" % len(raw_packets), data) + return continue #final packet (packet_index==0), decode it: try: packet = decode(data, protocol_flags) + except InvalidPacketEncodingException, e: + self.invalid("invalid packet encoding: %s" % e, data) + return except ValueError, e: etype = packet_encoding.get_packet_encoding_type(protocol_flags) - log.error("value error parsing %s packet: %s", etype, e, exc_info=True) + log.error("failed to parse %s packet: %s", etype, e, exc_info=not self._closed) if self._closed: return log("failed to parse %s packet: %s", etype, binascii.hexlify(data)) - msg = "gibberish received: %s, packet index=%s, packet size=%s, buffer size=%s, error=%s" % (repr_ellipsized(data), packet_index, payload_size, bl, e) + msg = "packet index=%s, packet size=%s, buffer size=%s, error=%s" % (packet_index, payload_size, bl, e) self.gibberish(msg, data) return diff --git a/src/xpra/server/server_core.py b/src/xpra/server/server_core.py index 7f04b01e3a..bd16cf0ad0 100644 --- a/src/xpra/server/server_core.py +++ b/src/xpra/server/server_core.py @@ -202,6 +202,7 @@ def init_packet_handlers(self): "info-request": self._process_info_request, Protocol.CONNECTION_LOST: self._process_connection_lost, Protocol.GIBBERISH: self._process_gibberish, + Protocol.INVALID: self._process_invalid, } def init_aliases(self): @@ -325,9 +326,7 @@ def invalid_header(self, proto, data): if proto.input_packetcount==0 and self._tcp_proxy: self.start_tcp_proxy(proto, data) return - err = "invalid packet header byte: '%#x', not an xpra client?" % ord(data[0]) - if len(data)>1: - err += " read buffer=0x%s" % repr_ellipsized(data) + err = "invalid packet format, not an xpra client?" proto.gibberish(err, data) def start_tcp_proxy(self, proto, data): @@ -378,24 +377,30 @@ def force_disconnect(self, proto): proto.close() def disconnect_client(self, protocol, reason): - if protocol: + if protocol and not protocol._closed: self.disconnect_protocol(protocol, reason) - log.info("Connection lost") def disconnect_protocol(self, protocol, reason): - log.info("Disconnecting existing client %s, reason is: %s", protocol, reason) + log.info("Disconnecting client %s: %s", protocol, reason) protocol.flush_then_close(["disconnect", reason]) def _process_connection_lost(self, proto, packet): - log.info("Connection lost") if proto in self._potential_protocols: + log.info("Connection lost") self._potential_protocols.remove(proto) def _process_gibberish(self, proto, packet): - data = packet[1] - log.info("Received uninterpretable nonsense: %s", repr(data)) - self.disconnect_client(proto, "invalid packet format, not an xpra client?") + (_, message, data) = packet + log("Received uninterpretable nonsense from %s: %s", proto, message) + log(" data: %s", repr_ellipsized(data)) + self.disconnect_client(proto, message) + + def _process_invalid(self, protocol, packet): + (_, message, data) = packet + log("Received invalid packet: %s", message) + log(" data: %s", repr_ellipsized(data)) + self.disconnect_client(protocol, message) def send_version_info(self, proto):