Skip to content

Commit

Permalink
Add support for SP 800 108 Counter mode KDF
Browse files Browse the repository at this point in the history
  • Loading branch information
Legrandin committed Jan 27, 2023
1 parent 2c61a7b commit 6e8cfe0
Show file tree
Hide file tree
Showing 11 changed files with 11,975 additions and 10 deletions.
1 change: 1 addition & 0 deletions Changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Changelog

New features
---------------
* Added support for the Counter Mode KDF defined in SP 800-108 Rev 1.
* Reduce the minimum tag length for the EAX cipher to 2 bytes.
* An RSA object has 4 new properties ``dp``, ``dq``, ``invq`` and ``invq``
for the CRT coefficients (``invp`` is the same as the old ``u``).
Expand Down
43 changes: 43 additions & 0 deletions Doc/src/protocol/kdf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,49 @@ Example, for deriving two AES256 keys::

.. autofunction:: Crypto.Protocol.KDF.HKDF

Counter Mode
++++++++++++

A KDF can be generically constructed with a pseudorandom function (PRF).
If the PRF has a fixed-length output,
you can evaluate the PRF multiple times and concatenate the results until you collect enough derived keying material.

This function implements such type of KDF, where a counter contributes to each invokation of the PRF, as defined in
`NIST SP 800-108 Rev 1 <https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-108r1.pdf>`_.
The NIST standard only allows the use of HMAC (recommended) and CMAC (not recommended) as PRF.

This KDF is not suitable for deriving keys from a password.

Example 1 (HMAC as PRF, one AES128 key to derive)::

>> from Crypto.Hash import SHA256, HMAC
>>
>> def prf(s, x):
>> return HMAC.new(s, x, SHA256).digest()
>>
>> key_derived = SP800_108_Counter(secret, 16, prf, label=b'Key A')

Example 2 (HMAC as PRF, two AES128 keys to derive)::

>> from Crypto.Hash import SHA256, HMAC
>>
>> def prf(s, x):
>> return HMAC.new(s, x, SHA256).digest()
>>
>> key_A, key_B = SP800_108_Counter(secret, 16, prf, num_keys=2, label=b'Key AB')

Example 3 (CMAC as PRF, two AES256 keys to derive)::

>> from Crypto.Cipher import AES
>> from Crypto.Hash import SHA256, CMAC
>>
>> def prf(s, x):
>> return CMAC.new(s, x, AES).digest()
>>
>> key_A, key_B = SP800_108_Counter(secret, 32, prf, num_keys=2, label=b'Key AB')

.. autofunction:: Crypto.Protocol.KDF.SP800_108_Counter

PBKDF1
+++++++

Expand Down
1 change: 0 additions & 1 deletion lib/Crypto/Cipher/AES.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
from typing import ByteString, Dict, Optional, Tuple, Union, overload
from typing_extensions import Literal

Expand Down
70 changes: 66 additions & 4 deletions lib/Crypto/Protocol/KDF.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from functools import reduce

from Crypto.Util.py3compat import (tobytes, bord, _copy_bytes, iter_range,
tostr, bchr, bstr)
tostr, bchr, bstr)

from Crypto.Hash import SHA1, SHA256, HMAC, CMAC, BLAKE2s
from Crypto.Util.strxor import strxor
Expand Down Expand Up @@ -287,13 +287,13 @@ def HKDF(master, key_len, salt, hashmod, num_keys=1, context=None):
The unguessable value used by the KDF to generate the other keys.
It must be a high-entropy secret, though not necessarily uniform.
It must not be a password.
key_len (integer):
The length in bytes of every derived key.
salt (byte string):
A non-secret, reusable value that strengthens the randomness
extraction step.
Ideally, it is as long as the digest size of the chosen hash.
If empty, a string of zeroes in used.
key_len (integer):
The length in bytes of every derived key.
hashmod (module):
A cryptographic hash algorithm from :mod:`Crypto.Hash`.
:mod:`Crypto.Hash.SHA512` is a good choice.
Expand Down Expand Up @@ -352,7 +352,7 @@ def scrypt(password, salt, key_len, N, r, p, num_keys=1):
but it should be randomly chosen for each derivation.
It is recommended to be at least 16 bytes long.
key_len (integer):
The length in bytes of every derived key.
The length in bytes of each derived key.
N (integer):
CPU/Memory cost parameter. It must be a power of 2 and less
than :math:`2^{32}`.
Expand Down Expand Up @@ -578,3 +578,65 @@ def bcrypt_check(password, bcrypt_hash):
mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=bcrypt_hash2).digest()
if mac1 != mac2:
raise ValueError("Incorrect bcrypt hash")


def SP800_108_Counter(master, key_len, prf, num_keys=None, label=b'', context=b''):
"""Derive one or more keys from a master secret using
a pseudorandom function in Counter Mode, as specified in
`NIST SP 800-108r1 <https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-108r1.pdf>`_.
Args:
master (byte string):
The secret value used by the KDF to derive the other keys.
It must not be a password.
The length on the secret must be consistent with the input expected by
the :data:`prf` function.
key_len (integer):
The length in bytes of each derived key.
prf (function):
A pseudorandom function that takes two byte strings as parameters:
the secret and an input. It returns another byte string.
num_keys (integer):
The number of keys to derive. Every key is :data:`key_len` bytes long.
By default, only 1 key is derived.
label (byte string):
Optional description of the purpose of the derived keys.
It must not contain zero bytes.
context (byte string):
Optional information pertaining to
the protocol that uses the keys, such as the identity of the
participants, nonces, session IDs, etc.
It must not contain zero bytes.
Return:
- a byte string (if ``num_keys`` is not specified), or
- a tuple of byte strings (if ``num_key`` is specified).
"""

if num_keys is None:
num_keys = 1

if label.find(b'\x00') != -1:
raise ValueError("Null byte found in label")

if context.find(b'\x00') != -1:
raise ValueError("Null byte found in context")

key_len_enc = long_to_bytes(key_len * num_keys * 8, 4)
output_len = key_len * num_keys

i = 1
dk = b""
while len(dk) < output_len:
info = long_to_bytes(i, 4) + label + b'\x00' + context + key_len_enc
dk += prf(master, info)
i += 1
if i > 0xFFFFFFFF:
raise ValueError("Overflow in SP800 108 counter")

if num_keys == 1:
return dk[:key_len]
else:
kol = [dk[idx:idx + key_len]
for idx in iter_range(0, output_len, key_len)]
return kol
18 changes: 17 additions & 1 deletion lib/Crypto/Protocol/KDF.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from types import ModuleType
from typing import Optional, Callable, Tuple, Union, Dict, Any
from typing import Optional, Callable, Tuple, Union, Dict, Any, ByteString, overload
from typing_extensions import Literal

RNG = Callable[[int], bytes]
PRF = Callable[[bytes, bytes], bytes]

def PBKDF1(password: str, salt: bytes, dkLen: int, count: Optional[int]=1000, hashAlgo: Optional[ModuleType]=None) -> bytes: ...
def PBKDF2(password: str, salt: bytes, dkLen: Optional[int]=16, count: Optional[int]=1000, prf: Optional[RNG]=None, hmac_hash_module: Optional[ModuleType]=None) -> bytes: ...
Expand All @@ -22,3 +24,17 @@ def _bcrypt_decode(data: bytes) -> bytes: ...
def _bcrypt_hash(password:bytes , cost: int, salt: bytes, constant:bytes, invert:bool) -> bytes: ...
def bcrypt(password: Union[bytes, str], cost: int, salt: Optional[bytes]=None) -> bytes: ...
def bcrypt_check(password: Union[bytes, str], bcrypt_hash: Union[bytes, bytearray, str]) -> None: ...

@overload
def SP800_108_Counter(master: ByteString,
key_len: int,
prf: PRF,
num_keys: Literal[None] = None,
label: ByteString = b'', context: ByteString = b'') -> bytes: ...

@overload
def SP800_108_Counter(master: ByteString,
key_len: int,
prf: PRF,
num_keys: int,
label: ByteString = b'', context: ByteString = b'') -> Tuple[bytes]: ...
79 changes: 77 additions & 2 deletions lib/Crypto/SelfTest/Protocol/test_KDF.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@
# SOFTWARE.
# ===================================================================

import re
import unittest
from binascii import unhexlify

from Crypto.Util.py3compat import b, bchr

from Crypto.SelfTest.st_common import list_test_cases
from Crypto.SelfTest.loader import load_test_vectors_wycheproof
from Crypto.SelfTest.loader import load_test_vectors, load_test_vectors_wycheproof
from Crypto.Hash import SHA1, HMAC, SHA256, MD5, SHA224, SHA384, SHA512
from Crypto.Cipher import AES, DES3

from Crypto.Protocol.KDF import (PBKDF1, PBKDF2, _S2V, HKDF, scrypt,
bcrypt, bcrypt_check)
bcrypt, bcrypt_check,
SP800_108_Counter)

from Crypto.Protocol.KDF import _bcrypt_decode

Expand Down Expand Up @@ -708,6 +710,78 @@ def runTest(self):
self.test_verify(tv)


def load_hash_by_name(hash_name):
return __import__("Crypto.Hash." + hash_name, globals(), locals(), ["new"])


class SP800_180_Counter_Tests(unittest.TestCase):

def test_negative_zeroes(self):
def prf(s, x):
return HMAC.new(s, x, SHA256).digest()

self.assertRaises(ValueError, SP800_108_Counter, b'0' * 16, 1, prf,
label=b'A\x00B')
self.assertRaises(ValueError, SP800_108_Counter, b'0' * 16, 1, prf,
context=b'A\x00B')

def test_multiple_keys(self):
def prf(s, x):
return HMAC.new(s, x, SHA256).digest()

key = b'0' * 16
expected = SP800_108_Counter(key, 2*3*23, prf)
for r in (1, 2, 3, 23):
dks = SP800_108_Counter(key, r, prf, 138//r)
self.assertEqual(len(dks), 138//r)
self.assertEqual(len(dks[0]), r)
self.assertEqual(b''.join(dks), expected)


def add_tests_sp800_108_counter(cls):

test_vectors_sp800_108_counter = load_test_vectors(("Protocol", ),
"KDF_SP800_108_COUNTER.txt",
"NIST SP 800 108 KDF Counter Mode",
{'count': lambda x: int(x)},
) or []

mac_type = None
for idx, tv in enumerate(test_vectors_sp800_108_counter):

if isinstance(tv, str):
res = re.match(r"\[HMAC-(SHA-[0-9]+)\]", tv)
if res:
hash_name = res.group(1).replace("-", "")
hash_module = load_hash_by_name(hash_name)
mac_type = "hmac"
continue
res = re.match(r"\[CMAC-AES-128\]", tv)
if res:
mac_type = "cmac"
continue
assert res

if mac_type == "hmac":
def prf(s, x, hash_module=hash_module):
return HMAC.new(s, x, hash_module).digest()
elif mac_type == "cmac":
def prf(s, x, hash_module=hash_module):
return CMAC.new(s, x, AES).digest()
continue

def kdf_test(self, prf=prf, kin=tv.kin, label=tv.label,
context=tv.context, kout=tv.kout, count=tv.count):
result = SP800_108_Counter(kin, len(kout), prf, 1, label, context)
assert(len(result) == len(kout))
self.assertEqual(result, kout)

setattr(cls, "test_kdf_sp800_108_counter_%d" % idx, kdf_test)


add_tests_sp800_108_counter(SP800_180_Counter_Tests)


def get_tests(config={}):
wycheproof_warnings = config.get('wycheproof_warnings')

Expand All @@ -723,6 +797,7 @@ def get_tests(config={}):
tests += [TestVectorsHKDFWycheproof(wycheproof_warnings)]
tests += list_test_cases(scrypt_Tests)
tests += list_test_cases(bcrypt_Tests)
tests += list_test_cases(SP800_180_Counter_Tests)

return tests

Expand Down
2 changes: 1 addition & 1 deletion test_vectors/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
include MANIFEST.in
include setup.py
include LICENSE.rst
recursive-include src *.h *.c
recursive-include src *.h *.c *.cpp
graft pycryptodome_test_vectors
global-exclude __pycache__ *.pyc
Loading

0 comments on commit 6e8cfe0

Please sign in to comment.