Skip to content

Commit

Permalink
Merge pull request #61 from magicus/fix-salt-leading-zero
Browse files Browse the repository at this point in the history
Keep leading zero bytes in salt, based on gist by @Gdocal
  • Loading branch information
cocagne authored Nov 1, 2024
2 parents fbdfe9d + e623a6a commit 8c55b0a
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions srp/_pysrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import hashlib
import os
import binascii
import six


_rfc5054_compat = False
Expand Down Expand Up @@ -139,7 +140,7 @@ def get_ng( ng_type, n_hex, g_hex ):

def bytes_to_long(s):
n = 0
for b in s:
for b in six.iterbytes(s):
n = (n << 8) | b
return n

Expand All @@ -154,7 +155,7 @@ def long_to_bytes(n):
x = x | (b << off)
off += 8
l.reverse()
return (''.join(l)).encode('latin1')
return six.b(''.join(l))


def get_random( nbytes ):
Expand All @@ -167,14 +168,14 @@ def get_random_of_length( nbytes ):


def old_H( hash_class, s1, s2 = '', s3=''):
if isinstance(s1, int):
if isinstance(s1, six.integer_types):
s1 = long_to_bytes(s1)
if s2 and isinstance(s2, int):
if s2 and isinstance(s2, six.integer_types):
s2 = long_to_bytes(s2)
if s3 and isinstance(s3, int):
if s3 and isinstance(s3, six.integer_types):
s3 = long_to_bytes(s3)
s = s1 + s2 + s3
return int(hash_class(s).hexdigest(), 16)
return long(hash_class(s).hexdigest(), 16)


def H( hash_class, *args, **kwargs ):
Expand All @@ -184,7 +185,7 @@ def H( hash_class, *args, **kwargs ):

for s in args:
if s is not None:
data = long_to_bytes(s) if isinstance(s, int) else s
data = long_to_bytes(s) if isinstance(s, six.integer_types) else s
if width is not None and _rfc5054_compat:
h.update( bytes(width - len(data)))
h.update( data )
Expand All @@ -206,16 +207,16 @@ def HNxorg( hash_class, N, g ):
hN = hash_class( bin_N ).digest()
hg = hash_class( b''.join( [b'\0'*padding, bin_g] ) ).digest()

return ( ''.join( chr( hN[i] ^ hg[i] ) for i in range(0,len(hN)) ) ).encode('latin1')
return six.b( ''.join( chr( six.indexbytes(hN, i) ^ six.indexbytes(hg, i) ) for i in range(0,len(hN)) ) )



def gen_x( hash_class, salt, username, password ):
username = username.encode() if hasattr(username, 'encode') else username
password = password.encode() if hasattr(password, 'encode') else password
if _no_username_in_x:
username = b''
return bytes_to_long( H(hash_class, salt, H( hash_class, username + b':' + password ) ))
username = six.b('')
return bytes_to_long( H(hash_class, salt, H( hash_class, username + six.b(':') + password ) ))



Expand All @@ -237,7 +238,9 @@ def calculate_M( hash_class, N, g, I, s, A, B, K ):
h = hash_class()
h.update( HNxorg( hash_class, N, g ) )
h.update( hash_class(I).digest() )
h.update( long_to_bytes(s) )
if isinstance(s, six.integer_types):
s = long_to_bytes(s)
h.update( s )
h.update( long_to_bytes(A) )
h.update( long_to_bytes(B) )
h.update( K )
Expand All @@ -259,9 +262,9 @@ class Verifier:
def __init__(self, username, bytes_s, bytes_v, bytes_A=None, hash_alg=SHA1, ng_type=NG_2048, n_hex=None, g_hex=None, bytes_b=None, k_hex=None):
if ng_type == NG_CUSTOM and (n_hex is None or g_hex is None):
raise ValueError("Both n_hex and g_hex are required when ng_type = NG_CUSTOM")
if bytes_b and len(bytes_b) != 32:
raise ValueError("32 bytes required for bytes_b")
self.s = bytes_to_long(bytes_s)
if bytes_b and len(bytes_b) != 256:
raise ValueError("256 bytes required for bytes_b")
self.s = bytes_s
self.v = bytes_to_long(bytes_v)
self.I = username
self.K = None
Expand All @@ -288,7 +291,7 @@ def __init__(self, username, bytes_s, bytes_v, bytes_A=None, hash_alg=SHA1, ng_t
if bytes_b:
self.b = bytes_to_long(bytes_b)
else:
self.b = get_random_of_length( 32 )
self.b = get_random_of_length( 256 )
self.B = (k*self.v + pow(g, self.b, N)) % N


Expand All @@ -312,7 +315,7 @@ def get_challenge(self):
if self.safety_failed:
return None,None
else:
return (long_to_bytes(self.s), long_to_bytes(self.B))
return (self.s, long_to_bytes(self.B))

# returns H_AMK on success, None on failure
def verify_session(self, user_M, bytes_A=None):
Expand Down Expand Up @@ -347,8 +350,8 @@ class User:
def __init__(self, username, password, hash_alg=SHA1, ng_type=NG_2048, n_hex=None, g_hex=None, bytes_a=None, bytes_A=None, k_hex=None):
if ng_type == NG_CUSTOM and (n_hex is None or g_hex is None):
raise ValueError("Both n_hex and g_hex are required when ng_type = NG_CUSTOM")
if bytes_a and len(bytes_a) != 32:
raise ValueError("32 bytes required for bytes_a")
if bytes_a and len(bytes_a) != 256:
raise ValueError("256 bytes required for bytes_a")
N,g = get_ng( ng_type, n_hex, g_hex )
hash_class = _hash_map[ hash_alg ]
if k_hex is None:
Expand All @@ -361,7 +364,7 @@ def __init__(self, username, password, hash_alg=SHA1, ng_type=NG_2048, n_hex=Non
if bytes_a:
self.a = bytes_to_long(bytes_a)
else:
self.a = get_random_of_length( 32 )
self.a = get_random_of_length( 256 )
if bytes_A:
self.A = bytes_to_long(bytes_A)
else:
Expand Down Expand Up @@ -401,7 +404,7 @@ def start_authentication(self):
# Returns M or None if SRP-6a safety check is violated
def process_challenge(self, bytes_s, bytes_B):

self.s = bytes_to_long( bytes_s )
self.s = bytes_s
self.B = bytes_to_long( bytes_B )

N = self.N
Expand Down

0 comments on commit 8c55b0a

Please sign in to comment.