diff --git a/cwt_test.go b/cwt_test.go index 7269b95..351e7bb 100644 --- a/cwt_test.go +++ b/cwt_test.go @@ -17,10 +17,10 @@ func ExampleCWTMessage() { msgToSign.Payload = []byte("hello world") msgToSign.Headers.Protected.SetAlgorithm(cose.AlgorithmES512) + msgToSign.Headers.Protected.SetType("application/cwt") claims := make(cose.CWTClaims) claims[cose.CWTClaimIssuer] = "issuer.example" claims[cose.CWTClaimSubject] = "subject.example" - msgToSign.Headers.Protected.SetCWTClaims(claims) msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1") @@ -41,12 +41,14 @@ func ExampleCWTMessage() { panic(err) } sig, err := msgToSign.MarshalCBOR() + // uncomment to review EDN + // coseSign1Diagnostic, err := cbor.Diagnose(sig) + // fmt.Println(coseSign1Diagnostic) if err != nil { panic(err) } fmt.Println("message signed") - // create a verifier from a trusted public key publicKey := privateKey.Public() verifier, err := cose.NewVerifier(cose.AlgorithmES512, publicKey) diff --git a/headers.go b/headers.go index 119ea7b..151ae32 100644 --- a/headers.go +++ b/headers.go @@ -24,6 +24,7 @@ const ( HeaderLabelCounterSignatureV2 int64 = 11 HeaderLabelCounterSignature0V2 int64 = 12 HeaderLabelCWTClaims int64 = 15 + HeaderLabelType int64 = 16 HeaderLabelX5Bag int64 = 32 HeaderLabelX5Chain int64 = 33 HeaderLabelX5T int64 = 34 @@ -103,10 +104,28 @@ func (h ProtectedHeader) SetAlgorithm(alg Algorithm) { h[HeaderLabelAlgorithm] = alg } +// SetType sets the type of the cose object in the protected header. +func (h ProtectedHeader) SetType(typ any) (any, error) { + if !canTstr(typ) && !canUint(typ) { + return typ, errors.New("header parameter: type: require tstr / uint type") + } + h[HeaderLabelType] = typ + return typ, nil +} + // SetCWTClaims sets the CWT Claims value of the protected header. -func (h ProtectedHeader) SetCWTClaims(claims CWTClaims) { - // TODO: validate claims, for example ensuring that 1 and 2 are tstr, not bstr +func (h ProtectedHeader) SetCWTClaims(claims CWTClaims) (CWTClaims, error) { + iss, hasIss := claims[1] + if hasIss && !canTstr(iss) { + return claims, errors.New("cwt claim: iss: require tstr") + } + sub, hasSub := claims[2] + if hasSub && !canTstr(sub) { + return claims, errors.New("cwt claim: sub: require tstr") + } + // TODO: validate claims, other claims h[HeaderLabelCWTClaims] = claims + return claims, nil } // Algorithm gets the algorithm value from the algorithm header. @@ -478,6 +497,25 @@ func validateHeaderParameters(h map[any]any, protected bool) error { if err := ensureCritical(value, h); err != nil { return fmt.Errorf("header parameter: crit: %w", err) } + case HeaderLabelType: + is_tstr := canTstr(value) + if !is_tstr && !canUint(value) { + return errors.New("header parameter: type: require tstr / uint type") + } + if is_tstr { + v := value.(string) + if len(v) == 0 { + return errors.New("header parameter: type: require non-empty string") + } + if v[0] == ' ' || v[len(v)-1] == ' ' { + return errors.New("header parameter: type: require no leading/trailing whitespace") + } + // Basic check that the content type is of form type/subtype. + // We don't check the precise definition though (RFC 6838 Section 4.2). + if strings.Count(v, "/") != 1 { + return errors.New("header parameter: type: require text of form type/subtype") + } + } case HeaderLabelContentType: is_tstr := canTstr(value) if !is_tstr && !canUint(value) {