diff --git a/pyproject.toml b/pyproject.toml index f3d3e8c..4cd3ab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sd-jwt" -version = "0.9.1" +version = "0.10.0" description = "The reference implementation of the IETF SD-JWT specification." authors = ["Daniel Fett "] readme = "README.md" diff --git a/src/sd_jwt/__init__.py b/src/sd_jwt/__init__.py index d69d16e..61fb31c 100644 --- a/src/sd_jwt/__init__.py +++ b/src/sd_jwt/__init__.py @@ -1 +1 @@ -__version__ = "0.9.1" +__version__ = "0.10.0" diff --git a/src/sd_jwt/common.py b/src/sd_jwt/common.py index 1398e69..0c88f4e 100644 --- a/src/sd_jwt/common.py +++ b/src/sd_jwt/common.py @@ -33,7 +33,7 @@ def __init__(self, error_location: any): class SDJWTCommon: - SD_JWT_TYP_HEADER = None # "sd+jwt" + SD_JWT_HEADER = "sd+jwt" KB_JWT_TYP_HEADER = "kb+jwt" JWS_KEY_DISCLOSURES = "disclosures" JWS_KEY_KB_JWT = "kb_jwt" @@ -91,14 +91,14 @@ def _create_hash_mappings(self, disclosurses_list: List): decoded_disclosure = loads( self._base64url_decode(disclosure).decode("utf-8") ) - hash = self._b64hash(disclosure.encode("ascii")) - if hash in self._hash_to_decoded_disclosure: + _hash = self._b64hash(disclosure.encode("ascii")) + if _hash in self._hash_to_decoded_disclosure: raise ValueError( - f"Duplicate disclosure hash {hash} for disclosure {decoded_disclosure}" + f"Duplicate disclosure hash {_hash} for disclosure {decoded_disclosure}" ) - self._hash_to_decoded_disclosure[hash] = decoded_disclosure - self._hash_to_disclosure[hash] = disclosure + self._hash_to_decoded_disclosure[_hash] = decoded_disclosure + self._hash_to_disclosure[_hash] = disclosure def _check_for_sd_claim(self, the_object): # Recursively check for the presence of the _sd claim, also diff --git a/src/sd_jwt/holder.py b/src/sd_jwt/holder.py index 605f8c2..a63eec6 100644 --- a/src/sd_jwt/holder.py +++ b/src/sd_jwt/holder.py @@ -94,7 +94,7 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): zip_longest(claims_to_disclose, sd_jwt_claims, fillvalue=None) ): if ( - type(element) is dict + isinstance(element, dict) and len(element) == 1 and SD_LIST_PREFIX in element and type(element[SD_LIST_PREFIX]) is str @@ -116,11 +116,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): continue self.hs_disclosures.append(self._hash_to_disclosure[digest_to_check]) - if type(disclosure_value) is dict: + if isinstance(disclosure_value, dict): if claims_to_disclose_element is True: # Tolerate a "True" for a disclosure of an object claims_to_disclose_element = {} - if not type(claims_to_disclose_element) is dict: + if not isinstance(claims_to_disclose_element, dict): raise ValueError( f"To disclose object elements in arrays, provide an object (can be empty).\n" f"Found {claims_to_disclose_element} instead.\n" @@ -130,11 +130,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): self._select_disclosures( disclosure_value, claims_to_disclose_element ) - elif type(disclosure_value) is list: + elif isinstance(disclosure_value, list): if claims_to_disclose_element is True: # Tolerate a "True" for a disclosure of an array claims_to_disclose_element = [] - if not type(claims_to_disclose_element) is list: + if not isinstance(claims_to_disclose_element, list): raise ValueError( f"To disclose array elements nested in arrays, provide an array (can be empty).\n" f"Found {claims_to_disclose_element} instead.\n" @@ -155,7 +155,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose): if claims_to_disclose is True: # Tolerate a "True" for a disclosure of an object claims_to_disclose = {} - if not type(claims_to_disclose) is dict: + if not isinstance(claims_to_disclose, dict): raise ValueError( f"To disclose object elements, an object must be provided as disclosure information.\n" f"Found {claims_to_disclose} (type {type(claims_to_disclose)}) instead.\n" diff --git a/src/sd_jwt/issuer.py b/src/sd_jwt/issuer.py index 44f1845..3314013 100644 --- a/src/sd_jwt/issuer.py +++ b/src/sd_jwt/issuer.py @@ -36,7 +36,7 @@ def __init__( sign_alg=None, add_decoy_claims: bool = False, serialization_format: str = "compact", - extra_header_parameters: Dict = None, + extra_header_parameters: dict = {}, ): super().__init__(serialization_format=serialization_format) @@ -78,21 +78,20 @@ def _create_sd_claims(self, user_claims): # # If the user claims are a list, apply this function # to each item in the list. - if type(user_claims) is list: + if isinstance(user_claims, list): return self._create_sd_claims_list(user_claims) # If the user claims are a dictionary, apply this function # to each key/value pair in the dictionary. - elif type(user_claims) is dict: + elif isinstance(user_claims, dict): return self._create_sd_claims_object(user_claims) # For other types, assume that the value can be disclosed. - else: - if isinstance(user_claims, SDObj): - raise ValueError( - f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj." - ) - return user_claims + elif isinstance(user_claims, SDObj): + raise ValueError( + f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj." + ) + return user_claims def _create_sd_claims_list(self, user_claims: List): # Walk through all elements in the list. @@ -168,12 +167,13 @@ def _create_signed_jws(self): self.sd_jwt = JWS(payload=dumps(self.sd_jwt_payload)) - # Assemble protected headers - _protected_headers = {"alg": self._sign_alg} - if self.SD_JWT_TYP_HEADER: - _protected_headers["typ"] = self.SD_JWT_TYP_HEADER - if self._extra_header_parameters: - _protected_headers.update(self._extra_header_parameters) + # Assemble protected headers starting with default + _protected_headers = { + "alg": self._sign_alg, + "typ": self.SD_JWT_HEADER + } + # override if any + _protected_headers.update(self.extra_header_parameters) self.sd_jwt.add_signature( self._issuer_key,