diff --git a/pyproject.toml b/pyproject.toml index f3d3e8c..4cd3ab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sd-jwt" -version = "0.9.1" +version = "0.10.0" description = "The reference implementation of the IETF SD-JWT specification." authors = ["Daniel Fett "] readme = "README.md" diff --git a/src/sd_jwt/__init__.py b/src/sd_jwt/__init__.py index d69d16e..61fb31c 100644 --- a/src/sd_jwt/__init__.py +++ b/src/sd_jwt/__init__.py @@ -1 +1 @@ -__version__ = "0.9.1" +__version__ = "0.10.0" diff --git a/src/sd_jwt/bin/generate.py b/src/sd_jwt/bin/generate.py index 01341c5..ad00641 100755 --- a/src/sd_jwt/bin/generate.py +++ b/src/sd_jwt/bin/generate.py @@ -39,7 +39,7 @@ def generate_test_case_data(settings: Dict, testcase_path: Path, type: str): use_decoys = testcase.get("add_decoy_claims", False) serialization_format = testcase.get("serialization_format", "compact") include_default_claims = testcase.get("include_default_claims", True) - extra_header_parameters = testcase.get("extra_header_parameters", None) + extra_header_parameters = testcase.get("extra_header_parameters", {}) claims = {} if include_default_claims: diff --git a/src/sd_jwt/common.py b/src/sd_jwt/common.py index 1398e69..0c88f4e 100644 --- a/src/sd_jwt/common.py +++ b/src/sd_jwt/common.py @@ -33,7 +33,7 @@ def __init__(self, error_location: any): class SDJWTCommon: - SD_JWT_TYP_HEADER = None # "sd+jwt" + SD_JWT_HEADER = "sd+jwt" KB_JWT_TYP_HEADER = "kb+jwt" JWS_KEY_DISCLOSURES = "disclosures" JWS_KEY_KB_JWT = "kb_jwt" @@ -91,14 +91,14 @@ def _create_hash_mappings(self, disclosurses_list: List): decoded_disclosure = loads( self._base64url_decode(disclosure).decode("utf-8") ) - hash = self._b64hash(disclosure.encode("ascii")) - if hash in self._hash_to_decoded_disclosure: + _hash = self._b64hash(disclosure.encode("ascii")) + if _hash in self._hash_to_decoded_disclosure: raise ValueError( - f"Duplicate disclosure hash {hash} for disclosure {decoded_disclosure}" + f"Duplicate disclosure hash {_hash} for disclosure {decoded_disclosure}" ) - self._hash_to_decoded_disclosure[hash] = decoded_disclosure - self._hash_to_disclosure[hash] = disclosure + self._hash_to_decoded_disclosure[_hash] = decoded_disclosure + self._hash_to_disclosure[_hash] = disclosure def _check_for_sd_claim(self, the_object): # Recursively check for the presence of the _sd claim, also diff --git a/src/sd_jwt/holder.py b/src/sd_jwt/holder.py index 605f8c2..a63eec6 100644 --- a/src/sd_jwt/holder.py +++ b/src/sd_jwt/holder.py @@ -94,7 +94,7 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): zip_longest(claims_to_disclose, sd_jwt_claims, fillvalue=None) ): if ( - type(element) is dict + isinstance(element, dict) and len(element) == 1 and SD_LIST_PREFIX in element and type(element[SD_LIST_PREFIX]) is str @@ -116,11 +116,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): continue self.hs_disclosures.append(self._hash_to_disclosure[digest_to_check]) - if type(disclosure_value) is dict: + if isinstance(disclosure_value, dict): if claims_to_disclose_element is True: # Tolerate a "True" for a disclosure of an object claims_to_disclose_element = {} - if not type(claims_to_disclose_element) is dict: + if not isinstance(claims_to_disclose_element, dict): raise ValueError( f"To disclose object elements in arrays, provide an object (can be empty).\n" f"Found {claims_to_disclose_element} instead.\n" @@ -130,11 +130,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): self._select_disclosures( disclosure_value, claims_to_disclose_element ) - elif type(disclosure_value) is list: + elif isinstance(disclosure_value, list): if claims_to_disclose_element is True: # Tolerate a "True" for a disclosure of an array claims_to_disclose_element = [] - if not type(claims_to_disclose_element) is list: + if not isinstance(claims_to_disclose_element, list): raise ValueError( f"To disclose array elements nested in arrays, provide an array (can be empty).\n" f"Found {claims_to_disclose_element} instead.\n" @@ -155,7 +155,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose): if claims_to_disclose is True: # Tolerate a "True" for a disclosure of an object claims_to_disclose = {} - if not type(claims_to_disclose) is dict: + if not isinstance(claims_to_disclose, dict): raise ValueError( f"To disclose object elements, an object must be provided as disclosure information.\n" f"Found {claims_to_disclose} (type {type(claims_to_disclose)}) instead.\n" diff --git a/src/sd_jwt/issuer.py b/src/sd_jwt/issuer.py index 44f1845..5b49fdf 100644 --- a/src/sd_jwt/issuer.py +++ b/src/sd_jwt/issuer.py @@ -36,7 +36,7 @@ def __init__( sign_alg=None, add_decoy_claims: bool = False, serialization_format: str = "compact", - extra_header_parameters: Dict = None, + extra_header_parameters: dict = {}, ): super().__init__(serialization_format=serialization_format) @@ -78,21 +78,20 @@ def _create_sd_claims(self, user_claims): # # If the user claims are a list, apply this function # to each item in the list. - if type(user_claims) is list: + if isinstance(user_claims, list): return self._create_sd_claims_list(user_claims) # If the user claims are a dictionary, apply this function # to each key/value pair in the dictionary. - elif type(user_claims) is dict: + elif isinstance(user_claims, dict): return self._create_sd_claims_object(user_claims) # For other types, assume that the value can be disclosed. - else: - if isinstance(user_claims, SDObj): - raise ValueError( - f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj." - ) - return user_claims + elif isinstance(user_claims, SDObj): + raise ValueError( + f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj." + ) + return user_claims def _create_sd_claims_list(self, user_claims: List): # Walk through all elements in the list. @@ -168,12 +167,13 @@ def _create_signed_jws(self): self.sd_jwt = JWS(payload=dumps(self.sd_jwt_payload)) - # Assemble protected headers - _protected_headers = {"alg": self._sign_alg} - if self.SD_JWT_TYP_HEADER: - _protected_headers["typ"] = self.SD_JWT_TYP_HEADER - if self._extra_header_parameters: - _protected_headers.update(self._extra_header_parameters) + # Assemble protected headers starting with default + _protected_headers = { + "alg": self._sign_alg, + "typ": self.SD_JWT_HEADER + } + # override if any + _protected_headers.update(self._extra_header_parameters) self.sd_jwt.add_signature( self._issuer_key, diff --git a/tests/test_disclose_all_shortcut.py b/tests/test_disclose_all_shortcut.py index e1eac5b..db4e319 100644 --- a/tests/test_disclose_all_shortcut.py +++ b/tests/test_disclose_all_shortcut.py @@ -10,7 +10,7 @@ def test_e2e(testcase, settings): demo_keys = get_jwk(settings["key_settings"], True, seed) use_decoys = testcase.get("add_decoy_claims", False) serialization_format = testcase.get("serialization_format", "compact") - extra_header_parameters = testcase.get("extra_header_parameters", None) + extra_header_parameters = testcase.get("extra_header_parameters", {}) # Issuer: Produce SD-JWT and issuance format for selected example @@ -59,7 +59,8 @@ def cb_get_issuer_key(issuer, header_parameters): assert verified == expected_claims expected_header_parameters = { - "alg": testcase.get("sign_alg", "ES256") + "alg": testcase.get("sign_alg", "ES256"), + "typ": "sd+jwt" } expected_header_parameters.update(extra_header_parameters or {}) diff --git a/tests/test_e2e_testcases.py b/tests/test_e2e_testcases.py index f6ca529..6f4a786 100644 --- a/tests/test_e2e_testcases.py +++ b/tests/test_e2e_testcases.py @@ -10,7 +10,7 @@ def test_e2e(testcase, settings): demo_keys = get_jwk(settings["key_settings"], True, seed) use_decoys = testcase.get("add_decoy_claims", False) serialization_format = testcase.get("serialization_format", "compact") - extra_header_parameters = testcase.get("extra_header_parameters", None) + extra_header_parameters = testcase.get("extra_header_parameters", {}) # Issuer: Produce SD-JWT and issuance format for selected example @@ -74,7 +74,8 @@ def cb_get_issuer_key(issuer, header_parameters): assert verified == expected_claims expected_header_parameters = { - "alg": testcase.get("sign_alg", "ES256") + "alg": testcase.get("sign_alg", "ES256"), + "typ": "sd+jwt" } expected_header_parameters.update(extra_header_parameters or {})