From 076956cde3b78a5ef3328f4167ae5f34e6090288 Mon Sep 17 00:00:00 2001 From: gstarovo Date: Wed, 17 Apr 2024 11:19:54 +0200 Subject: [PATCH] changes in point extension format --- .github/workflows/ci.yml | 8 +- .gitignore | 2 +- scripts/tls.py | 8 +- test | 0 tests/tlstest.py | 64 +++++++++++++- tlslite/handshakesettings.py | 17 +++- tlslite/keyexchange.py | 115 +++++++++++++++++-------- tlslite/session.py | 8 +- tlslite/tlsconnection.py | 57 ++++++++---- unit_tests/test_tlslite_keyexchange.py | 9 +- 10 files changed, 223 insertions(+), 65 deletions(-) delete mode 100644 test diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6581b6fe..64d4eb2b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -433,7 +433,13 @@ jobs: COVERALLS_FLAG_NAME: ${{ matrix.name }} COVERALLS_PARALLEL: true COVERALLS_SERVICE_NAME: github - run: coveralls + PY_VERSION: ${{ matrix.python-version }} + run: | + if [[ $PY_VERSION == "2.6" ]]; then + COVERALLS_SKIP_SSL_VERIFY=1 coveralls + else + coveralls + fi - name: Publish coverage to Codeclimate if: ${{ contains(matrix.opt-deps, 'codeclimate') }} env: diff --git a/.gitignore b/.gitignore index 56433754..daedfe76 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,4 @@ coverage.xml pylint_report.txt build/ docs/_build/ -htmlcov/ +htmlcov/ \ No newline at end of file diff --git a/scripts/tls.py b/scripts/tls.py index a3f27ebe..d717631b 100755 --- a/scripts/tls.py +++ b/scripts/tls.py @@ -367,6 +367,7 @@ def printGoodConnection(connection, seconds): print(" Extended Master Secret: {0}".format( connection.extendedMasterSecret)) print(" Session Resumed: {0}".format(connection.resumed)) + print(" Session used ec point format extension: {0}".format(connection.session.ec_point_format)) def printExporter(connection, expLabel, expLength): if expLabel is None: @@ -415,6 +416,8 @@ def clientCmd(argv): if cipherlist: settings.cipherNames = [item for cipher in cipherlist for item in cipher.split(',')] + # CHANGED + settings.ec_point_formats = [] try: start = time_stamp() if username and password: @@ -424,7 +427,7 @@ def clientCmd(argv): connection.handshakeClientCert(cert_chain, privateKey, settings=settings, serverName=address[0], alpn=alpn) stop = time_stamp() - print("Handshake success") + print("Handshake success") except TLSLocalAlert as a: if a.description == AlertDescription.user_canceled: print(str(a)) @@ -567,6 +570,9 @@ def serverCmd(argv): if cipherlist: settings.cipherNames = [item for cipher in cipherlist for item in cipher.split(',')] + # CHANGED + + settings.ec_point_formats = [2, 0] class MySimpleEchoHandler(BaseRequestHandler): def handle(self): diff --git a/test b/test deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/tlstest.py b/tests/tlstest.py index 18a64b73..337beaf8 100755 --- a/tests/tlstest.py +++ b/tests/tlstest.py @@ -44,7 +44,7 @@ from xmlrpc import client as xmlrpclib import ssl from tlslite import * -from tlslite.constants import KeyUpdateMessageType +from tlslite.constants import KeyUpdateMessageType, ECPointFormat try: from tack.structures.Tack import Tack @@ -286,6 +286,34 @@ def connect(): test_no += 1 + print("Test {0} - client compressed/uncompressed - uncompressed, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.ec_point_format == ECPointFormat.uncompressed + connection.close() + + test_no += 1 + + print("Test {0} - client compressed - compressed, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_char2 + connection.close() + + test_no += 1 + print("Test {0} - mismatched ECDSA curve, TLSv1.2".format(test_no)) synchro.recv(1) connection = connect() @@ -2162,6 +2190,37 @@ def connect(): test_no += 1 + print("Test {0} server uncompressed ec format - uncompressed, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.ec_point_formats = [ECPointFormat.uncompressed] + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + testConnServer(connection) + assert connection.session.ec_point_format == ECPointFormat.uncompressed + connection.close() + + test_no += 1 + + print("Test {0} server compressed ec format - compressed, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + testConnServer(connection) + assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_char2 + connection.close() + + test_no +=1 + print("Test {0} - mismatched ECDSA curve, TLSv1.2".format(test_no)) synchro.send(b'R') connection = connect() @@ -3416,7 +3475,7 @@ def heartbeat_response_check(message): assert synchro.recv(1) == b'R' connection.close() - test_no += 1 + test_no +=1 print("Tests {0}-{1} - XMLRPXC server".format(test_no, test_no + 2)) @@ -3449,6 +3508,7 @@ def add(self, x, y): return x + y synchro.close() synchroSocket.close() + test_no += 2 print("Test succeeded") diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py index 38e560a2..cc836563 100644 --- a/tlslite/handshakesettings.py +++ b/tlslite/handshakesettings.py @@ -7,7 +7,7 @@ """Class for setting handshake parameters.""" -from .constants import CertificateType +from .constants import CertificateType, ECPointFormat from .utils import cryptomath from .utils import cipherfactory from .utils.compat import ecdsaAllCurves, int_types @@ -61,6 +61,9 @@ TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm", "aes128ccm_8", "aes256ccm", "aes256ccm_8"] PSK_MODES = ["psk_dhe_ke", "psk_ke"] +EC_POINT_FORMATS = [ECPointFormat.ansiX962_compressed_char2, + ECPointFormat.ansiX962_compressed_prime, + ECPointFormat.uncompressed] class Keypair(object): @@ -353,6 +356,10 @@ class HandshakeSettings(object): :vartype keyExchangeNames: list :ivar keyExchangeNames: Enabled key exchange types for the connection, influences selected cipher suites. + + :vartype ec_point_formats: list + :ivat ec_point_formats: Enabeled point format extension for + elliptic curves. """ def _init_key_settings(self): @@ -396,6 +403,7 @@ def _init_misc_extensions(self): # resumed connections (as tickets are single-use in TLS 1.3 self.ticket_count = 2 self.record_size_limit = 2**14 + 1 # TLS 1.3 includes content type + self.ec_point_formats = list(EC_POINT_FORMATS) def __init__(self): """Initialise default values for settings.""" @@ -598,6 +606,12 @@ def _sanityCheckExtensions(other): if other.record_size_limit is not None and \ not 64 <= other.record_size_limit <= 2**14 + 1: raise ValueError("record_size_limit cannot exceed 2**14+1 bytes") + + bad_ec_ext = [i for i in other.ec_point_formats if + i not in EC_POINT_FORMATS] + if bad_ec_ext: + raise ValueError("Unknown ec point format extension: " + "{0}".format(bad_ec_ext)) HandshakeSettings._sanityCheckEMSExtension(other) @@ -667,6 +681,7 @@ def _copy_extension_settings(self, other): other.sendFallbackSCSV = self.sendFallbackSCSV other.useEncryptThenMAC = self.useEncryptThenMAC other.usePaddingExtension = self.usePaddingExtension + other.ec_point_formats = self.ec_point_formats # session tickets other.padding_cb = self.padding_cb other.ticketKeys = self.ticketKeys diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py index 2242aad3..82c0a1b3 100644 --- a/tlslite/keyexchange.py +++ b/tlslite/keyexchange.py @@ -12,7 +12,7 @@ TLSDecodeError from .messages import ServerKeyExchange, ClientKeyExchange, CertificateVerify from .constants import SignatureAlgorithm, HashAlgorithm, CipherSuite, \ - ExtensionType, GroupName, ECCurveType, SignatureScheme + ExtensionType, GroupName, ECCurveType, SignatureScheme, ECPointFormat from .utils.ecc import getCurveByName, getPointByteSize from .utils.rsakey import RSAKey from .utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \ @@ -705,14 +705,16 @@ def makeServerKeyExchange(self, sigHash=None): kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version) self.ecdhXs = kex.get_random_private_key() - if isinstance(self.ecdhXs, ecdsa.keys.SigningKey): - ecdhYs = bytearray( - self.ecdhXs.get_verifying_key().to_string( - encoding = 'uncompressed' - ) - ) - else: - ecdhYs = kex.calc_public_value(self.ecdhXs) + ext_negotiated = ECPointFormat.uncompressed + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + for ext in ext_c.formats: + if ext in ext_s.formats: + ext_negotiated = ext + break + + ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated) version = self.serverHello.server_version serverKeyExchange = ServerKeyExchange(self.cipherSuite, version) @@ -730,7 +732,14 @@ def processClientKeyExchange(self, clientKeyExchange): raise TLSDecodeError("No key share") kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version) - return kex.calc_shared_key(self.ecdhXs, ecdhYc) + ext_supported = [ECPointFormat.uncompressed] + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + ext_supported = [ + ext for ext in ext_c.formats if ext in ext_s.formats + ] + return kex.calc_shared_key(self.ecdhXs, ecdhYc, ext_supported) def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): """Process the server key exchange, return premaster secret""" @@ -748,15 +757,16 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): kex = ECDHKeyExchange(serverKeyExchange.named_curve, self.serverHello.server_version) ecdhXc = kex.get_random_private_key() - if isinstance(ecdhXc, ecdsa.keys.SigningKey): - self.ecdhYc = bytearray( - ecdhXc.get_verifying_key().to_string( - encoding = 'uncompressed' - ) - ) - else: - self.ecdhYc = kex.calc_public_value(ecdhXc) - return kex.calc_shared_key(ecdhXc, ecdh_Ys) + ext_negotiated = ECPointFormat.uncompressed + ext_supported = [ECPointFormat.uncompressed] + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + ext_supported = [i for i in ext_c.formats if i in ext_s.formats] + ext_negotiated = ext_supported[0] + + self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated) + return kex.calc_shared_key(ecdhXc, ecdh_Ys, ext_supported) def makeClientKeyExchange(self): """Make client key exchange for ECDHE""" @@ -903,11 +913,11 @@ def get_random_private_key(self): """ raise NotImplementedError("Abstract class") - def calc_public_value(self, private): + def calc_public_value(self, private, frm_negotiated=None): """Calculate the public value from the provided private value.""" raise NotImplementedError("Abstract class") - def calc_shared_key(self, private, peer_share): + def calc_shared_key(self, private, peer_share, frm_supported=None): """Calcualte the shared key given our private and remote share value""" raise NotImplementedError("Abstract class") @@ -940,9 +950,10 @@ def get_random_private_key(self): needed_bytes = divceil(paramStrength(self.prime) * 2, 8) return bytesToNumber(getRandomBytes(needed_bytes)) - def calc_public_value(self, private): + def calc_public_value(self, private, frm_negotiated=None): """ Calculate the public value for given private value. + Frm_negotiated added for API compatibility, not needed for FFDH. :rtype: int """ @@ -964,8 +975,11 @@ def _normalise_peer_share(self, peer_share): "Key share does not match FFDH prime") return bytesToNumber(peer_share) - def calc_shared_key(self, private, peer_share): - """Calculate the shared key.""" + def calc_shared_key(self, private, peer_share, frm_supported=None): + """Calculate the shared key. + Frm_supported added for API compatibility, not needed for FFDH. + + :rtype: bytearray""" peer_share = self._normalise_peer_share(peer_share) # First half of RFC 2631, Section 2.1.5. Validate the client's public # key. @@ -984,7 +998,6 @@ def calc_shared_key(self, private, peer_share): class ECDHKeyExchange(RawDHKeyExchange): """Implementation of the Elliptic Curve Diffie-Hellman key exchange.""" - _x_groups = set((GroupName.x25519, GroupName.x448)) @staticmethod @@ -1021,20 +1034,50 @@ def _get_fun_gen_size(self): else: return x448, bytearray(X448_G), X448_ORDER_SIZE - def calc_public_value(self, private): + @staticmethod + def _get_point_format(ext): + """Get extension name from the numeric value.""" + transform = {ECPointFormat.uncompressed: 'uncompressed', + ECPointFormat.ansiX962_compressed_char2: 'compressed', + ECPointFormat.ansiX962_compressed_prime: 'compressed'} + return transform[ext] + + def calc_public_value(self, + private, + frm_negotiated=ECPointFormat.uncompressed): """Calculate public value for given private key.""" + point_fmt = self._get_point_format(frm_negotiated) if isinstance(private, ecdsa.keys.SigningKey): - return private.verifying_key.to_string('uncompressed') + return private.verifying_key.to_string(point_fmt) if self.group in self._x_groups: fun, generator, _ = self._get_fun_gen_size() return fun(private, generator) - else: - curve = getCurveByName(GroupName.toStr(self.group)) - point = curve.generator * private - return bytearray(point.to_bytes('uncompressed')) - def calc_shared_key(self, private, peer_share): - """Calculate the shared key,""" + curve = getCurveByName(GroupName.toStr(self.group)) + point = curve.generator * private + return bytearray(point.to_bytes(encoding=point_fmt)) + + def calc_shared_key(self, private, peer_share, + frm_supported=set([ECPointFormat.uncompressed])): + """Calculate the shared key. + + :type private: bytearray | SigningKey + :param private: private value + + :type peer_share: bytearray + :param peer_share: public value + + :type frm_supported: set(ECPointFormat) + :param frm_supported: acceptable point formats for public value + + :rtype: bytearray + :returns: shared key + + :raises TLSIllegalParameterException + when the paramentrs for point are invalid + """ + valid_encodings = set([self._get_point_format(i) \ + for i in frm_supported]) if self.group in self._x_groups: fun, _, size = self._get_fun_gen_size() @@ -1049,7 +1092,8 @@ def calc_shared_key(self, private, peer_share): curve = getCurveByName(GroupName.toRepr(self.group)) try: abstractPoint = ecdsa.ellipticcurve.AbstractPoint() - point = abstractPoint.from_bytes(curve.curve, peer_share) + point = abstractPoint.from_bytes(curve.curve, peer_share, + valid_encodings=valid_encodings) ecdhYc = ecdsa.ellipticcurve.Point( curve.curve, point[0], point[1]) @@ -1057,7 +1101,8 @@ def calc_shared_key(self, private, peer_share): raise TLSIllegalParameterException("Invalid ECC point") if isinstance(private, ecdsa.keys.SigningKey): ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private) - ecdh.load_received_public_key_bytes(peer_share) + ecdh.load_received_public_key_bytes(peer_share, + valid_encodings=valid_encodings) return bytearray(ecdh.generate_sharedsecret_bytes()) S = ecdhYc * private diff --git a/tlslite/session.py b/tlslite/session.py index 0e310b71..1a9939ae 100644 --- a/tlslite/session.py +++ b/tlslite/session.py @@ -72,6 +72,10 @@ class Session(object): :vartype tls_1_0_tickets: list :ivar tls_1_0_tickets: list of TLS 1.2 and earlier session tickets received from the server + + :vartype ec_point_format: int + :ivar ec_point_format: used ec point extension format; + created for testing """ def __init__(self): @@ -94,6 +98,7 @@ def __init__(self): self.resumptionMasterSecret = bytearray(0) self.tickets = None self.tls_1_0_tickets = None + self.ec_point_format = None def create(self, masterSecret, sessionID, cipherSuite, srpUsername, clientCertChain, serverCertChain, @@ -102,7 +107,7 @@ def create(self, masterSecret, sessionID, cipherSuite, appProto=bytearray(0), cl_app_secret=bytearray(0), sr_app_secret=bytearray(0), exporterMasterSecret=bytearray(0), resumptionMasterSecret=bytearray(0), tickets=None, - tls_1_0_tickets=None): + tls_1_0_tickets=None, ec_point_format=None): self.masterSecret = masterSecret self.sessionID = sessionID self.cipherSuite = cipherSuite @@ -123,6 +128,7 @@ def create(self, masterSecret, sessionID, cipherSuite, # NOTE we need a reference copy not a copy of object here! self.tickets = tickets self.tls_1_0_tickets = tls_1_0_tickets + self.ec_point_format = ec_point_format def _clone(self): other = Session() diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 582097a7..b5efdfcb 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -655,6 +655,12 @@ def _handshakeClientAsyncHelper(self, srpParams, certParams, anonParams, alpnExt = serverHello.getExtension(ExtensionType.alpn) if alpnExt: alpnProto = alpnExt.protocol_names[0] + ext_c = clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = serverHello.getExtension(ExtensionType.ec_point_formats) + ext_ec_point = ECPointFormat.uncompressed + if ext_c and ext_s: + ext_ec_point = [i for i in ext_c.formats \ + if i in ext_s.formats][0] # Create the session object which is used for resumptions self.session = Session() @@ -667,7 +673,8 @@ def _handshakeClientAsyncHelper(self, srpParams, certParams, anonParams, appProto=alpnProto, # NOTE it must be a reference not a copy tickets=self.tickets, - tls_1_0_tickets=self.tls_1_0_tickets) + tls_1_0_tickets=self.tls_1_0_tickets, + ec_point_format=ext_ec_point) self._handshakeDone(resumed=False) self._serverRandom = serverHello.random self._clientRandom = clientHello.random @@ -745,7 +752,6 @@ def _clientSendClientHello(self, settings, session, srpUsername, for group_name in settings.keyShares: group_id = getattr(GroupName, group_name) key_share = self._genKeyShareEntry(group_id, (3, 4)) - shares.append(key_share) # if TLS 1.3 is enabled, key_share must always be sent # (unless only static PSK is used) @@ -762,8 +768,12 @@ def _clientSendClientHello(self, settings, session, srpUsername, if next((cipher for cipher in cipherSuites \ if cipher in CipherSuite.ecdhAllSuites), None) is not None: groups.extend(self._curveNamesToList(settings)) - extensions.append(ECPointFormatsExtension().\ - create([ECPointFormat.uncompressed])) + if settings.ec_point_formats: + extensions.append(ECPointFormatsExtension().\ + create(settings.ec_point_formats)) + else: + extensions.append(ECPointFormatsExtension().\ + create(list([ECPointFormat.uncompressed]))) # Advertise FFDHE groups if we have DHE ciphers if next((cipher for cipher in cipherSuites if cipher in CipherSuite.dhAllSuites), None) is not None: @@ -838,7 +848,7 @@ def _clientSendClientHello(self, settings, session, srpUsername, session_id, wireCipherSuites, certificateTypes, srpUsername, - reqTack, nextProtos is not None, + reqTack, nextProtos is not None, serverName, extensions=extensions) @@ -915,6 +925,7 @@ def _clientGetServerHello(self, settings, session, clientHello): hello_retry = None ext = result.getExtension(ExtensionType.supported_versions) + if result.random == TLS_1_3_HRR and ext and ext.version > (3, 3): self.version = ext.version hello_retry = result @@ -974,7 +985,6 @@ def _clientGetServerHello(self, settings, session, clientHello): "did sent the key share " "for"): yield result - key_share = self._genKeyShareEntry(group_id, (3, 4)) # old key shares need to be removed @@ -1212,7 +1222,6 @@ def _clientTLS13Handshake(self, settings, session, clientHello, raise TLSIllegalParameterException("Server selected not " "advertised group.") kex = self._getKEX(sr_kex.group, self.version) - shared_sec = kex.calc_shared_key(cl_kex.private, sr_kex.key_exchange) else: @@ -1855,8 +1864,8 @@ def _clientFinished(self, premasterSecret, clientRandom, serverRandom, cipherSuite, clientRandom, serverRandom) - self._calcPendingStates(cipherSuite, masterSecret, - clientRandom, serverRandom, + self._calcPendingStates(cipherSuite, masterSecret, + clientRandom, serverRandom, cipherImplementations) #Exchange ChangeCipherSpec and Finished messages @@ -1989,7 +1998,7 @@ def _clientGetKeyFromChain(self, certificate, settings, tack_ext=None): def handshakeServer(self, verifierDB=None, certChain=None, privateKey=None, reqCert=False, sessionCache=None, settings=None, checker=None, - reqCAs = None, + reqCAs = None, tacks=None, activationFlags=0, nextProtos=None, anon=False, alpn=None, sni=None): """Perform a handshake in the role of server. @@ -2090,7 +2099,7 @@ def handshakeServer(self, verifierDB=None, def handshakeServerAsync(self, verifierDB=None, certChain=None, privateKey=None, reqCert=False, sessionCache=None, settings=None, checker=None, - reqCAs=None, + reqCAs=None, tacks=None, activationFlags=0, nextProtos=None, anon=False, alpn=None, sni=None ): @@ -2108,9 +2117,9 @@ def handshakeServerAsync(self, verifierDB=None, handshaker = self._handshakeServerAsyncHelper(\ verifierDB=verifierDB, cert_chain=certChain, privateKey=privateKey, reqCert=reqCert, - sessionCache=sessionCache, settings=settings, - reqCAs=reqCAs, - tacks=tacks, activationFlags=activationFlags, + sessionCache=sessionCache, settings=settings, + reqCAs=reqCAs, + tacks=tacks, activationFlags=activationFlags, nextProtos=nextProtos, anon=anon, alpn=alpn, sni=sni) for result in self._handshakeWrapperAsync(handshaker, checker): yield result @@ -2270,8 +2279,12 @@ def _handshakeServerAsyncHelper(self, verifierDB, if clientHello.getExtension(ExtensionType.ec_point_formats): # even though the selected cipher may not use ECC, client may want # to send a CA certificate with ECDSA... - extensions.append(ECPointFormatsExtension().create( - [ECPointFormat.uncompressed])) + if settings.ec_point_formats: + extensions.append(ECPointFormatsExtension(). + create(settings.ec_point_formats)) + else: + extensions.append(ECPointFormatsExtension().\ + create(list([ECPointFormat.uncompressed]))) # if client sent Heartbeat extension if clientHello.getExtension(ExtensionType.heartbeat): @@ -2412,6 +2425,11 @@ def _handshakeServerAsyncHelper(self, verifierDB, srpUsername = clientHello.srp_username.decode("utf-8") if clientHello.server_name: serverName = clientHello.server_name.decode("utf-8") + ext_c = clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = serverHello.getExtension(ExtensionType.ec_point_formats) + ext_ec_point = ECPointFormat.uncompressed + if ext_c and ext_s: + ext_ec_point = [i for i in ext_c.formats if i in ext_s.formats][0] # We'll update the session master secret once it is calculated # in _serverFinished @@ -2424,7 +2442,8 @@ def _handshakeServerAsyncHelper(self, verifierDB, extendedMasterSecret=self.extendedMasterSecret, appProto=selectedALPN, # NOTE it must be a reference, not a copy! - tickets=self.tickets) + tickets=self.tickets, + ec_point_format=ext_ec_point) # Exchange Finished messages for result in self._serverFinished(premasterSecret, @@ -2709,8 +2728,8 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite, (psk is None and privateKey): self.ecdhCurve = selected_group kex = self._getKEX(selected_group, version) - key_share = self._genKeyShareEntry(selected_group, version) - + key_share = self._genKeyShareEntry(selected_group, + version) try: shared_sec = kex.calc_shared_key(key_share.private, cl_key_share.key_exchange) diff --git a/unit_tests/test_tlslite_keyexchange.py b/unit_tests/test_tlslite_keyexchange.py index cfc02aa4..ea06a047 100644 --- a/unit_tests/test_tlslite_keyexchange.py +++ b/unit_tests/test_tlslite_keyexchange.py @@ -20,7 +20,7 @@ CertificateRequest, ClientKeyExchange from tlslite.constants import CipherSuite, CertificateType, AlertDescription, \ HashAlgorithm, SignatureAlgorithm, GroupName, ECCurveType, \ - SignatureScheme + SignatureScheme, ECPointFormat from tlslite.errors import TLSLocalAlert, TLSIllegalParameterException, \ TLSDecryptionFailed, TLSInsufficientSecurity, TLSUnknownPSKIdentity, \ TLSInternalError, TLSDecodeError @@ -33,7 +33,8 @@ from tlslite.mathtls import makeX, makeU, makeK, goodGroupParameters from tlslite.handshakehashes import HandshakeHashes from tlslite import VerifierDB -from tlslite.extensions import SupportedGroupsExtension, SNIExtension +from tlslite.extensions import SupportedGroupsExtension, SNIExtension, \ + ECPointFormatsExtension from tlslite.utils.ecc import getCurveByName, getPointByteSize from tlslite.utils.compat import a2b_hex import ecdsa @@ -2523,13 +2524,13 @@ def test_calc_public_value(self): kex = RawDHKeyExchange(None, None) with self.assertRaises(NotImplementedError): - kex.calc_public_value(None) + kex.calc_public_value(None, None) def test_calc_shared_value(self): kex = RawDHKeyExchange(None, None) with self.assertRaises(NotImplementedError): - kex.calc_shared_key(None, None) + kex.calc_shared_key(None, None, None) class TestFFDHKeyExchange(unittest.TestCase):