Skip to content

Commit

Permalink
codestyle: Type-annotate ca_util.py and add to mypy
Browse files Browse the repository at this point in the history
Type-annotate ca_util.py and also minimally annotate
revocation_notifier.py so ca_util.py passed mypy checks.

Signed-off-by: Stefan Berger <[email protected]>
  • Loading branch information
stefanberger committed Dec 5, 2022
1 parent 45edb4f commit 9222f81
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
59 changes: 36 additions & 23 deletions keylime/ca_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,27 @@
import zipfile
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
from typing import Any, Dict, List, Optional, Tuple, Union

import yaml

try:
from yaml import CSafeDumper as SafeDumper
from yaml import CSafeLoader as SafeLoader
except ImportError:
from yaml import SafeLoader, SafeDumper
from yaml import SafeLoader, SafeDumper # type: ignore

from cryptography import exceptions as crypto_exceptions
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.x509 import Certificate
from cryptography.x509.extensions import ExtensionNotFound
from cryptography.x509.general_name import UniformResourceIdentifier

Expand All @@ -51,10 +58,10 @@

logger = keylime_logging.init_logging("ca-util")

global_password = None
global_password: Optional[str] = None


def load_cert_by_path(cert_path):
def load_cert_by_path(cert_path: str) -> Certificate:
cert = None
with open(cert_path, "rb") as ca_file:
cert = x509.load_pem_x509_certificate(
Expand All @@ -64,7 +71,7 @@ def load_cert_by_path(cert_path):
return cert


def setpassword(pw):
def setpassword(pw: Optional[str]) -> None:
global global_password
if not pw:
pw = getpass.getpass("Please enter the password to decrypt your keystore: ")
Expand All @@ -75,14 +82,20 @@ def setpassword(pw):
global_password = pw


def cmd_mkcert(workingdir, name, password=None):
def cmd_mkcert(workingdir: str, name: str, password: Optional[str] = None) -> None:
cwd = os.getcwd()
mask = os.umask(0o037)
try:
fs_util.ch_dir(workingdir)
priv = read_private()
cacert = load_cert_by_path("cacert.crt")
ca_pk = serialization.load_pem_private_key(priv[0]["ca"], password=None, backend=default_backend())
if not isinstance(
ca_pk, (EllipticCurvePrivateKey, RSAPrivateKey, DSAPrivateKey, Ed448PrivateKey, Ed25519PrivateKey)
):
raise Exception(
f"Private key of type {type(ca_pk).__name__} cannot be used for creating an x509 certificate"
)

cert, pk = ca_impl.mk_signed_cert(cacert, ca_pk, name, priv[0]["lastserial"] + 1)

Expand Down Expand Up @@ -131,7 +144,7 @@ def cmd_mkcert(workingdir, name, password=None):
os.chdir(cwd)


def cmd_init(workingdir):
def cmd_init(workingdir: str) -> None:
cwd = os.getcwd()
mask = os.umask(0o037)
try:
Expand Down Expand Up @@ -200,7 +213,7 @@ def cmd_init(workingdir):
os.umask(mask)


def cmd_certpkg(workingdir, name, insecure=False):
def cmd_certpkg(workingdir: str, name: str, insecure: bool = False) -> Tuple[bytes, int, str]:
cwd = os.getcwd()
try:
fs_util.ch_dir(workingdir)
Expand Down Expand Up @@ -267,13 +280,13 @@ def cmd_certpkg(workingdir, name, insecure=False):
os.chdir(cwd)


def convert_crl_to_pem(derfile, pemfile):
def convert_crl_to_pem(derfile: str, pemfile: str) -> None:
with open(derfile, "rb") as der_f, open(pemfile, "wb") as pem_f:
der_crl = der_f.read()
pem_f.write(x509.load_der_x509_crl(der_crl).public_bytes(encoding=serialization.Encoding.PEM))


def get_crl_distpoint(cert_path):
def get_crl_distpoint(cert_path: str) -> Optional[str]:
cert_obj = load_cert_by_path(cert_path)

try:
Expand All @@ -292,7 +305,7 @@ def get_crl_distpoint(cert_path):
# to check: openssl crl -inform DER -text -noout -in cacrl.der


def cmd_revoke(workingdir, name=None, serial=None):
def cmd_revoke(workingdir: str, name: Optional[str] = None, serial: Optional[int] = None) -> bytes:
cwd = os.getcwd()
try:
fs_util.ch_dir(workingdir)
Expand All @@ -308,15 +321,15 @@ def cmd_revoke(workingdir, name=None, serial=None):
serial = cert.serial_number

# convert serial to string
serial = str(serial)
serial_str = str(serial)

# get the ca key cert and keys as strings
with open("cacert.crt", encoding="utf-8") as f:
cacert = f.read()
ca_pk = priv[0]["ca"].decode("utf-8")

if serial not in priv[0]["revoked_keys"]:
priv[0]["revoked_keys"].append(serial)
if serial_str not in priv[0]["revoked_keys"]:
priv[0]["revoked_keys"].append(serial_str)

crl = ca_impl.gencrl(priv[0]["revoked_keys"], cacert, ca_pk)

Expand All @@ -336,7 +349,7 @@ def cmd_revoke(workingdir, name=None, serial=None):
# regenerate the crl without revoking anything


def cmd_regencrl(workingdir):
def cmd_regencrl(workingdir: str) -> bytes:
cwd = os.getcwd()
try:
fs_util.ch_dir(workingdir)
Expand All @@ -361,7 +374,7 @@ def cmd_regencrl(workingdir):
return crl


def cmd_listen(workingdir, cert_path):
def cmd_listen(workingdir: str, cert_path: str) -> None:
cwd = os.getcwd()
try:
fs_util.ch_dir(workingdir)
Expand All @@ -378,7 +391,7 @@ def cmd_listen(workingdir, cert_path):
logger.info("Hosting CRL on %s:%d", socket.getfqdn(), config.CRL_PORT)
t.start()

def check_expiration():
def check_expiration() -> None:
logger.info("checking CRL for expiration every hour")
while True: # pylint: disable=R1702
try:
Expand All @@ -405,7 +418,7 @@ def check_expiration():
t2 = threading.Thread(target=check_expiration, daemon=True)
t2.start()

def revoke_callback(revocation):
def revoke_callback(revocation: Dict[str, Union[str, bytes]]) -> None:
json_meta = json.loads(revocation["meta_data"])
serial = json_meta["cert_serial"]
if revocation.get("type", None) != "revocation" or serial is None:
Expand Down Expand Up @@ -434,12 +447,12 @@ def revoke_callback(revocation):
class ThreadedCRLServer(ThreadingMixIn, HTTPServer):
published_crl = None

def setcrl(self, crl):
def setcrl(self, crl: bytes) -> None:
self.published_crl = crl


class CRLHandler(BaseHTTPRequestHandler):
def do_GET(self):
def do_GET(self) -> None:
logger.info("GET invoked from %s with uri: %s", str(self.client_address), self.path)

assert isinstance(self.server, ThreadedCRLServer)
Expand All @@ -453,13 +466,13 @@ def do_GET(self):
self.wfile.write(self.server.published_crl)


def rmfiles(path):
def rmfiles(path: str) -> None:
files = glob.glob(path)
for f in files:
os.remove(f)


def write_private(inp):
def write_private(inp: Tuple[Dict[str, Any], str]) -> None:
priv = inp[0]
salt = inp[1]

Expand All @@ -475,7 +488,7 @@ def write_private(inp):
os.umask(mask)


def read_private(warn=False):
def read_private(warn: bool = False) -> Tuple[Dict[str, Any], str]:
if global_password is None:
setpassword(getpass.getpass("Please enter the password to decrypt your keystore: "))

Expand All @@ -498,7 +511,7 @@ def read_private(warn=False):
return {"revoked_keys": [], "ca": b""}, base64.b64encode(crypto.generate_random_key()).decode()


def main(argv=sys.argv): # pylint: disable=dangerous-default-value
def main(argv: List[str] = sys.argv) -> None: # pylint: disable=dangerous-default-value
parser = argparse.ArgumentParser(argv[0])
parser.add_argument(
"-c",
Expand Down
4 changes: 2 additions & 2 deletions keylime/revocation_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import time
from multiprocessing import Process
from typing import Optional, cast
from typing import Callable, Optional, cast

import requests

Expand Down Expand Up @@ -190,7 +190,7 @@ def process_revocation(revocation, callback, cert_path):
callback(message)


def await_notifications(callback, revocation_cert_path):
def await_notifications(callback: Callable, revocation_cert_path: str) -> None:
assert config.getboolean("agent", "enable_revocation_notifications", fallback=False)
try:
import zmq # pylint: disable=import-outside-toplevel
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ ignore_errors = False
[mypy-keylime.ca_impl_openssl]
ignore_errors = False

[mypy-keylime.ca_util]
ignore_errors = False

[mypy-keylime.cert_utils]
ignore_errors = False

Expand Down

0 comments on commit 9222f81

Please sign in to comment.