Skip to content

Commit

Permalink
fix: changes in accepting the format(form ECPointFormat to string) of…
Browse files Browse the repository at this point in the history
… ec format.
  • Loading branch information
gstarovo committed Oct 22, 2024
1 parent 1a01313 commit b1de13b
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 36 deletions.
82 changes: 48 additions & 34 deletions tlslite/keyexchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,9 @@ def makeServerKeyExchange(self, sigHash=None):
except StopIteration:
raise TLSIllegalParameterException("No common EC point format")

ext_negotiated = 'uncompressed' if \
ext_negotiated == ECPointFormat.uncompressed else 'compressed'

ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated)

version = self.serverHello.server_version
Expand Down Expand Up @@ -747,7 +750,12 @@ def processClientKeyExchange(self, clientKeyExchange):
]
if not ext_supported:
raise TLSIllegalParameterException("No common EC point format")
return kex.calc_shared_key(self.ecdhXs, ecdhYc, ext_supported)
ext_supported = map(
lambda x: 'uncompressed' if
x == ECPointFormat.uncompressed else
'compressed', ext_supported
)
return kex.calc_shared_key(self.ecdhXs, ecdhYc, set(ext_supported))

def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
"""Process the server key exchange, return premaster secret"""
Expand Down Expand Up @@ -783,8 +791,15 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
raise TLSIllegalParameterException(
"No common EC point format")

ext_negotiated = 'uncompressed' if \
ext_negotiated == ECPointFormat.uncompressed else 'compressed'
ext_supported = map(
lambda x: 'uncompressed' if
x == ECPointFormat.uncompressed else
'compressed', ext_supported
)
self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated)
return kex.calc_shared_key(ecdhXc, ecdh_Ys, ext_supported)
return kex.calc_shared_key(ecdhXc, ecdh_Ys, set(ext_supported))

def makeClientKeyExchange(self):
"""Make client key exchange for ECDHE"""
Expand Down Expand Up @@ -931,11 +946,11 @@ def get_random_private_key(self):
"""
raise NotImplementedError("Abstract class")

def calc_public_value(self, private, frm_negotiated=None):
def calc_public_value(self, private, point_format=None):
"""Calculate the public value from the provided private value."""
raise NotImplementedError("Abstract class")

def calc_shared_key(self, private, peer_share, frm_supported=None):
def calc_shared_key(self, private, peer_share, valid_point_formats=None):
"""Calcualte the shared key given our private and remote share value"""
raise NotImplementedError("Abstract class")

Expand Down Expand Up @@ -968,10 +983,11 @@ def get_random_private_key(self):
needed_bytes = divceil(paramStrength(self.prime) * 2, 8)
return bytesToNumber(getRandomBytes(needed_bytes))

def calc_public_value(self, private, frm_negotiated=None):
def calc_public_value(self, private, point_format=None):
"""
Calculate the public value for given private value.
Frm_negotiated added for API compatibility, not needed for FFDH.
:param point_format: ignored, used for compatibility with ECDH groups
:rtype: int
"""
Expand All @@ -993,9 +1009,10 @@ def _normalise_peer_share(self, peer_share):
"Key share does not match FFDH prime")
return bytesToNumber(peer_share)

def calc_shared_key(self, private, peer_share, frm_supported=None):
def calc_shared_key(self, private, peer_share, valid_point_formats=None):
"""Calculate the shared key.
Frm_supported added for API compatibility, not needed for FFDH.
:param valid_point_formats: ignored, used for compatibility with ECDH groups
:rtype: bytearray"""
peer_share = self._normalise_peer_share(peer_share)
Expand Down Expand Up @@ -1052,51 +1069,46 @@ def _get_fun_gen_size(self):
else:
return x448, bytearray(X448_G), X448_ORDER_SIZE

@staticmethod
def _get_point_format(ext):
"""Get point format from the numeric value."""
transform = {ECPointFormat.uncompressed: 'uncompressed',
ECPointFormat.ansiX962_compressed_char2: 'compressed',
ECPointFormat.ansiX962_compressed_prime: 'compressed'}
return transform[ext]

def calc_public_value(self,
private,
frm_negotiated=ECPointFormat.uncompressed):
"""Calculate public value for given private key."""
point_fmt = self._get_point_format(frm_negotiated)
point_format='uncompressed'):
"""Calculate public value for given private key.
:param private: Private key for the selected key exchange group.
:param str point_format: The point format to use for the
ECDH public key. Applies only to NIST curves.
"""
if isinstance(private, ecdsa.keys.SigningKey):
return private.verifying_key.to_string(point_fmt)
return private.verifying_key.to_string(point_format)
if self.group in self._x_groups:
fun, generator, _ = self._get_fun_gen_size()
return fun(private, generator)

curve = getCurveByName(GroupName.toStr(self.group))
point = curve.generator * private
return bytearray(point.to_bytes(encoding=point_fmt))
return bytearray(point.to_bytes(encoding=point_format))

def calc_shared_key(self, private, peer_share,
frm_supported=set([ECPointFormat.uncompressed])):
valid_point_formats=set(['uncompressed'])):
"""Calculate the shared key.
:type private: bytearray | SigningKey
:param private: private value
:param bytearray | SigningKey private: private value
:type peer_share: bytearray
:param peer_share: public value
:param bytearray peer_share: public value
:type frm_supported: set(ECPointFormat)
:param frm_supported: acceptable point formats for public value
:param set(str) valid_point_formats: list of point formats that
the peer share can be in; ["uncompressed"] by default.
:rtype: bytearray
:returns: shared key
:raises TLSIllegalParameterException
when the paramentrs for point are invalid
"""
valid_encodings = set([self._get_point_format(i) \
for i in frm_supported])
:raises TLSDecodeError
when the the valid_point_formats is empty
"""
if self.group in self._x_groups:
fun, _, size = self._get_fun_gen_size()
if len(peer_share) != size:
Expand All @@ -1111,17 +1123,19 @@ def calc_shared_key(self, private, peer_share,
try:
abstractPoint = ecdsa.ellipticcurve.AbstractPoint()
point = abstractPoint.from_bytes(curve.curve, peer_share,
valid_encodings=valid_encodings)
valid_encodings=valid_point_formats)
ecdhYc = ecdsa.ellipticcurve.Point(
curve.curve, point[0], point[1])

except (AssertionError, DecodeError):
except (AssertionError):
raise TLSIllegalParameterException("Invalid ECC point")
except DecodeError as err:
raise TLSDecodeError(f"Unexpected error {err=}, {type(err)=}") from err
if isinstance(private, ecdsa.keys.SigningKey):
ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private)
ecdh.load_received_public_key_bytes(peer_share,
valid_encodings=
valid_encodings)
valid_point_formats)
return bytearray(ecdh.generate_sharedsecret_bytes())
S = ecdhYc * private

Expand Down
2 changes: 1 addition & 1 deletion tlslite/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self):
self.resumptionMasterSecret = bytearray(0)
self.tickets = None
self.tls_1_0_tickets = None
self.ec_point_format = None
self.ec_point_format = 0

def create(self, masterSecret, sessionID, cipherSuite,
srpUsername, clientCertChain, serverCertChain,
Expand Down
3 changes: 2 additions & 1 deletion tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3242,7 +3242,8 @@ def _ticket_to_session(self, settings, ticket_ext):
serverName=ticket.server_name.decode("utf-8") if
ticket.server_name else "",
encryptThenMAC=ticket.encrypt_then_mac,
extendedMasterSecret=ticket.extended_master_secret)
extendedMasterSecret=ticket.extended_master_secret,
ec_point_format=0)
return session

def _serverGetClientHello(self, settings, private_key, cert_chain,
Expand Down

0 comments on commit b1de13b

Please sign in to comment.