Skip to content

Commit

Permalink
type hinting for symmetric ciphers (#5719)
Browse files Browse the repository at this point in the history
* type hinting for symmetric ciphers

* make our interface verifier happy
  • Loading branch information
reaperhulk authored Jan 31, 2021
1 parent 576adfb commit 4372d3f
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 118 deletions.
31 changes: 31 additions & 0 deletions src/cryptography/hazmat/primitives/_cipheralgorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

import abc


# This exists to break an import cycle. It is normally accessible from the
# ciphers module.


class CipherAlgorithm(metaclass=abc.ABCMeta):
@abc.abstractproperty
def name(self):
"""
A string naming this mode (e.g. "AES", "Camellia").
"""

@abc.abstractproperty
def key_size(self):
"""
The size of the key being used as an integer in bits (e.g. 128, 256).
"""


class BlockCipherAlgorithm(metaclass=abc.ABCMeta):
@abc.abstractproperty
def block_size(self):
"""
The size of a block as an integer in bits (e.g. 64, 128).
"""
69 changes: 55 additions & 14 deletions src/cryptography/hazmat/primitives/ciphers/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


import os
import typing

from cryptography import exceptions, utils
from cryptography.hazmat.backends.openssl import aead
Expand All @@ -13,7 +14,7 @@
class ChaCha20Poly1305(object):
_MAX_SIZE = 2 ** 32

def __init__(self, key):
def __init__(self, key: bytes):
if not backend.aead_cipher_supported(self):
raise exceptions.UnsupportedAlgorithm(
"ChaCha20Poly1305 is not supported by this version of OpenSSL",
Expand All @@ -27,10 +28,15 @@ def __init__(self, key):
self._key = key

@classmethod
def generate_key(cls):
def generate_key(cls) -> bytes:
return os.urandom(32)

def encrypt(self, nonce, data, associated_data):
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""

Expand All @@ -43,14 +49,24 @@ def encrypt(self, nonce, data, associated_data):
self._check_params(nonce, data, associated_data)
return aead._encrypt(backend, self, nonce, data, associated_data, 16)

def decrypt(self, nonce, data, associated_data):
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""

self._check_params(nonce, data, associated_data)
return aead._decrypt(backend, self, nonce, data, associated_data, 16)

def _check_params(self, nonce, data, associated_data):
def _check_params(
self,
nonce: bytes,
data: bytes,
associated_data: bytes,
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_bytes("data", data)
utils._check_bytes("associated_data", associated_data)
Expand All @@ -61,7 +77,7 @@ def _check_params(self, nonce, data, associated_data):
class AESCCM(object):
_MAX_SIZE = 2 ** 32

def __init__(self, key, tag_length=16):
def __init__(self, key: bytes, tag_length: int = 16):
utils._check_byteslike("key", key)
if len(key) not in (16, 24, 32):
raise ValueError("AESCCM key must be 128, 192, or 256 bits.")
Expand All @@ -76,7 +92,7 @@ def __init__(self, key, tag_length=16):
self._tag_length = tag_length

@classmethod
def generate_key(cls, bit_length):
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")

Expand All @@ -85,7 +101,12 @@ def generate_key(cls, bit_length):

return os.urandom(bit_length // 8)

def encrypt(self, nonce, data, associated_data):
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""

Expand All @@ -101,7 +122,12 @@ def encrypt(self, nonce, data, associated_data):
backend, self, nonce, data, associated_data, self._tag_length
)

def decrypt(self, nonce, data, associated_data):
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""

Expand All @@ -128,15 +154,15 @@ def _check_params(self, nonce, data, associated_data):
class AESGCM(object):
_MAX_SIZE = 2 ** 32

def __init__(self, key):
def __init__(self, key: bytes):
utils._check_byteslike("key", key)
if len(key) not in (16, 24, 32):
raise ValueError("AESGCM key must be 128, 192, or 256 bits.")

self._key = key

@classmethod
def generate_key(cls, bit_length):
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")

Expand All @@ -145,7 +171,12 @@ def generate_key(cls, bit_length):

return os.urandom(bit_length // 8)

def encrypt(self, nonce, data, associated_data):
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""

Expand All @@ -158,14 +189,24 @@ def encrypt(self, nonce, data, associated_data):
self._check_params(nonce, data, associated_data)
return aead._encrypt(backend, self, nonce, data, associated_data, 16)

def decrypt(self, nonce, data, associated_data):
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""

self._check_params(nonce, data, associated_data)
return aead._decrypt(backend, self, nonce, data, associated_data, 16)

def _check_params(self, nonce, data, associated_data):
def _check_params(
self,
nonce: bytes,
data: bytes,
associated_data: bytes,
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_bytes("data", data)
utils._check_bytes("associated_data", associated_data)
Expand Down
70 changes: 27 additions & 43 deletions src/cryptography/hazmat/primitives/ciphers/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,135 +25,119 @@ def _verify_key_size(algorithm, key):
return key


@utils.register_interface(BlockCipherAlgorithm)
@utils.register_interface(CipherAlgorithm)
class AES(object):
class AES(CipherAlgorithm, BlockCipherAlgorithm):
name = "AES"
block_size = 128
# 512 added to support AES-256-XTS, which uses 512-bit keys
key_sizes = frozenset([128, 192, 256, 512])

def __init__(self, key):
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8


@utils.register_interface(BlockCipherAlgorithm)
@utils.register_interface(CipherAlgorithm)
class Camellia(object):
class Camellia(CipherAlgorithm, BlockCipherAlgorithm):
name = "camellia"
block_size = 128
key_sizes = frozenset([128, 192, 256])

def __init__(self, key):
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8


@utils.register_interface(BlockCipherAlgorithm)
@utils.register_interface(CipherAlgorithm)
class TripleDES(object):
class TripleDES(CipherAlgorithm, BlockCipherAlgorithm):
name = "3DES"
block_size = 64
key_sizes = frozenset([64, 128, 192])

def __init__(self, key):
def __init__(self, key: bytes):
if len(key) == 8:
key += key + key
elif len(key) == 16:
key += key[:8]
self.key = _verify_key_size(self, key)

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8


@utils.register_interface(BlockCipherAlgorithm)
@utils.register_interface(CipherAlgorithm)
class Blowfish(object):
class Blowfish(CipherAlgorithm, BlockCipherAlgorithm):
name = "Blowfish"
block_size = 64
key_sizes = frozenset(range(32, 449, 8))

def __init__(self, key):
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8


@utils.register_interface(BlockCipherAlgorithm)
@utils.register_interface(CipherAlgorithm)
class CAST5(object):
class CAST5(CipherAlgorithm, BlockCipherAlgorithm):
name = "CAST5"
block_size = 64
key_sizes = frozenset(range(40, 129, 8))

def __init__(self, key):
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8


@utils.register_interface(CipherAlgorithm)
class ARC4(object):
class ARC4(CipherAlgorithm):
name = "RC4"
key_sizes = frozenset([40, 56, 64, 80, 128, 160, 192, 256])

def __init__(self, key):
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8


@utils.register_interface(CipherAlgorithm)
class IDEA(object):
class IDEA(CipherAlgorithm):
name = "IDEA"
block_size = 64
key_sizes = frozenset([128])

def __init__(self, key):
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8


@utils.register_interface(BlockCipherAlgorithm)
@utils.register_interface(CipherAlgorithm)
class SEED(object):
class SEED(CipherAlgorithm, BlockCipherAlgorithm):
name = "SEED"
block_size = 128
key_sizes = frozenset([128])

def __init__(self, key):
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8


@utils.register_interface(CipherAlgorithm)
@utils.register_interface(ModeWithNonce)
class ChaCha20(object):
class ChaCha20(CipherAlgorithm, ModeWithNonce):
name = "ChaCha20"
key_sizes = frozenset([256])

def __init__(self, key, nonce):
def __init__(self, key: bytes, nonce: bytes):
self.key = _verify_key_size(self, key)
utils._check_byteslike("nonce", nonce)

Expand All @@ -165,5 +149,5 @@ def __init__(self, key, nonce):
nonce = utils.read_only_property("_nonce")

@property
def key_size(self):
def key_size(self) -> int:
return len(self.key) * 8
15 changes: 1 addition & 14 deletions src/cryptography/hazmat/primitives/ciphers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,10 @@
)
from cryptography.hazmat.backends import _get_backend
from cryptography.hazmat.backends.interfaces import CipherBackend
from cryptography.hazmat.primitives._cipheralgorithm import CipherAlgorithm
from cryptography.hazmat.primitives.ciphers import modes


class CipherAlgorithm(metaclass=abc.ABCMeta):
@abc.abstractproperty
def name(self):
"""
A string naming this mode (e.g. "AES", "Camellia").
"""

@abc.abstractproperty
def key_size(self):
"""
The size of the key being used as an integer in bits (e.g. 128, 256).
"""


class BlockCipherAlgorithm(metaclass=abc.ABCMeta):
@abc.abstractproperty
def block_size(self):
Expand Down
Loading

0 comments on commit 4372d3f

Please sign in to comment.