diff --git a/ber.go b/ber.go index 5852567..89e96d3 100644 --- a/ber.go +++ b/ber.go @@ -5,7 +5,7 @@ import ( "errors" ) -var encodeIndent = 0 +// var encodeIndent = 0 type asn1Object interface { EncodeTo(writer *bytes.Buffer) error @@ -18,7 +18,7 @@ type asn1Structured struct { func (s asn1Structured) EncodeTo(out *bytes.Buffer) error { //fmt.Printf("%s--> tag: % X\n", strings.Repeat("| ", encodeIndent), s.tagBytes) - encodeIndent++ + //encodeIndent++ inner := new(bytes.Buffer) for _, obj := range s.content { err := obj.EncodeTo(inner) @@ -26,7 +26,7 @@ func (s asn1Structured) EncodeTo(out *bytes.Buffer) error { return err } } - encodeIndent-- + //encodeIndent-- out.Write(s.tagBytes) encodeLength(out, inner.Len()) out.Write(inner.Bytes()) @@ -133,49 +133,35 @@ func encodeLength(out *bytes.Buffer, length int) (err error) { } func readObject(ber []byte, offset int) (asn1Object, int, error) { - berLen := len(ber) - if offset >= berLen { - return nil, 0, errors.New("ber2der: offset is after end of ber data") - } + //fmt.Printf("\n====> Starting readObject at offset: %d\n\n", offset) tagStart := offset b := ber[offset] offset++ - if offset >= berLen { - return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") - } tag := b & 0x1F // last 5 bits if tag == 0x1F { tag = 0 for ber[offset] >= 0x80 { tag = tag*128 + ber[offset] - 0x80 offset++ - if offset > berLen { - return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") - } } - // jvehent 20170227: this doesn't appear to be used anywhere... - //tag = tag*128 + ber[offset] - 0x80 + tag = tag*128 + ber[offset] - 0x80 offset++ - if offset > berLen { - return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") - } } tagEnd := offset kind := b & 0x20 - if kind == 0 { - debugprint("--> Primitive\n") - } else { - debugprint("--> Constructed\n") - } + /* + if kind == 0 { + fmt.Print("--> Primitive\n") + } else { + fmt.Print("--> Constructed\n") + } + */ // read length var length int l := ber[offset] offset++ - if offset > berLen { - return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") - } - hack := 0 + indefinite := false if l > 0x80 { numberOfBytes := (int)(l & 0x7F) if numberOfBytes > 4 { // int is only guaranteed to be 32bit @@ -184,42 +170,33 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) { if numberOfBytes == 4 && (int)(ber[offset]) > 0x7F { return nil, 0, errors.New("ber2der: BER tag length is negative") } - if (int)(ber[offset]) == 0x0 { + if 0x0 == (int)(ber[offset]) { return nil, 0, errors.New("ber2der: BER tag length has leading zero") } - debugprint("--> (compute length) indicator byte: %x\n", l) - debugprint("--> (compute length) length bytes: % X\n", ber[offset:offset+numberOfBytes]) + //fmt.Printf("--> (compute length) indicator byte: %x\n", l) + //fmt.Printf("--> (compute length) length bytes: % X\n", ber[offset:offset+numberOfBytes]) for i := 0; i < numberOfBytes; i++ { length = length*256 + (int)(ber[offset]) offset++ - if offset > berLen { - return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") - } } } else if l == 0x80 { - // find length by searching content - markerIndex := bytes.LastIndex(ber[offset:], []byte{0x0, 0x0}) - if markerIndex == -1 { - return nil, 0, errors.New("ber2der: Invalid BER format") - } - length = markerIndex - hack = 2 - debugprint("--> (compute length) marker found at offset: %d\n", markerIndex+offset) + indefinite = true } else { length = (int)(l) } - if length < 0 { - return nil, 0, errors.New("ber2der: invalid negative value found in BER tag length") - } + //fmt.Printf("--> length : %d\n", length) contentEnd := offset + length if contentEnd > len(ber) { return nil, 0, errors.New("ber2der: BER tag length is more than available data") } - debugprint("--> content start : %d\n", offset) - debugprint("--> content end : %d\n", contentEnd) - debugprint("--> content : % X\n", ber[offset:contentEnd]) + //fmt.Printf("--> content start : %d\n", offset) + //fmt.Printf("--> content end : %d\n", contentEnd) + //fmt.Printf("--> content : % X\n", ber[offset:contentEnd]) var obj asn1Object + if indefinite && kind == 0 { + return nil, 0, errors.New("ber2der: Indefinite form tag must have constructed encoding") + } if kind == 0 { obj = asn1Primitive{ tagBytes: ber[tagStart:tagEnd], @@ -228,14 +205,25 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) { } } else { var subObjects []asn1Object - for offset < contentEnd { + for (offset < contentEnd) || indefinite { var subObj asn1Object var err error - subObj, offset, err = readObject(ber[:contentEnd], offset) + subObj, offset, err = readObject(ber, offset) if err != nil { return nil, 0, err } subObjects = append(subObjects, subObj) + + if indefinite { + terminated, err := isIndefiniteTermination(ber, offset) + if err != nil { + return nil, 0, err + } + + if terminated { + break + } + } } obj = asn1Structured{ tagBytes: ber[tagStart:tagEnd], @@ -243,9 +231,18 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) { } } - return obj, contentEnd + hack, nil + // Apply indefinite form length with 0x0000 terminator. + if indefinite { + contentEnd = offset + 2 + } + + return obj, contentEnd, nil } -func debugprint(format string, a ...interface{}) { - //fmt.Printf(format, a) +func isIndefiniteTermination(ber []byte, offset int) (bool, error) { + if len(ber) - offset < 2 { + return false, errors.New("ber2der: Invalid BER format") + } + + return bytes.Index(ber[offset:], []byte{0x0, 0x0}) == 0, nil } diff --git a/ber_test.go b/ber_test.go index fcb4b6a..19a0f51 100644 --- a/ber_test.go +++ b/ber_test.go @@ -15,7 +15,7 @@ func TestBer2Der(t *testing.T) { if err != nil { t.Fatalf("ber2der failed with error: %v", err) } - if !bytes.Equal(der, expected) { + if bytes.Compare(der, expected) != 0 { t.Errorf("ber2der result did not match.\n\tExpected: % X\n\tActual: % X", expected, der) } @@ -42,12 +42,11 @@ func TestBer2Der_Negatives(t *testing.T) { Input []byte ErrorContains string }{ - {[]byte{0x30, 0x85}, "tag length too long"}, + {[]byte{0x30, 0x85}, "length too long"}, {[]byte{0x30, 0x84, 0x80, 0x0, 0x0, 0x0}, "length is negative"}, {[]byte{0x30, 0x82, 0x0, 0x1}, "length has leading zero"}, - {[]byte{0x30, 0x80, 0x1, 0x2}, "Invalid BER format"}, + {[]byte{0x30, 0x80, 0x1, 0x2, 0x1, 0x2}, "Invalid BER format"}, {[]byte{0x30, 0x03, 0x01, 0x02}, "length is more than available data"}, - {[]byte{0x30}, "end of ber data reached"}, } for _, fixture := range fixtures { @@ -60,3 +59,39 @@ func TestBer2Der_Negatives(t *testing.T) { } } } + +func TestBer2Der_NestedMultipleIndefinite(t *testing.T) { + // indefinite length fixture + ber := []byte{0x30, 0x80, 0x30, 0x80, 0x02, 0x01, 0x01, 0x00, 0x00, 0x30, 0x80, 0x02, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00} + expected := []byte{0x30, 0x0A, 0x30, 0x03, 0x02, 0x01, 0x01, 0x30, 0x03, 0x02, 0x01, 0x02} + + der, err := ber2der(ber) + if err != nil { + t.Fatalf("ber2der failed with error: %v", err) + } + if bytes.Compare(der, expected) != 0 { + t.Errorf("ber2der result did not match.\n\tExpected: % X\n\tActual: % X", expected, der) + } + + if der2, err := ber2der(der); err != nil { + t.Errorf("ber2der on DER bytes failed with error: %v", err) + } else { + if !bytes.Equal(der, der2) { + t.Error("ber2der is not idempotent") + } + } + var thing struct { + Nest1 struct { + Number int + } + Nest2 struct { + Number int + } + } + rest, err := asn1.Unmarshal(der, &thing) + if err != nil { + t.Errorf("Cannot parse resulting DER because: %v", err) + } else if len(rest) > 0 { + t.Errorf("Resulting DER has trailing data: % X", rest) + } +} diff --git a/encrypt.go b/encrypt.go index 6b26557..23fe64e 100644 --- a/encrypt.go +++ b/encrypt.go @@ -60,7 +60,7 @@ const ( // ContentEncryptionAlgorithm determines the algorithm used to encrypt the // plaintext message. Change the value of this variable to change which // algorithm is used in the Encrypt() function. -var ContentEncryptionAlgorithm = EncryptionAlgorithmDESCBC +var ContentEncryptionAlgorithm = EncryptionAlgorithmAES128CBC // ErrUnsupportedEncryptionAlgorithm is returned when attempting to encrypt // content with an unsupported algorithm. diff --git a/pkcs7.go b/pkcs7.go index ccc6cc6..7f7cb1f 100644 --- a/pkcs7.go +++ b/pkcs7.go @@ -158,6 +158,7 @@ func Parse(data []byte) (p7 *PKCS7, err error) { } var info contentInfo der, err := ber2der(data) + if err != nil { return nil, err } diff --git a/verify.go b/verify.go index c8ead23..1eb93fa 100644 --- a/verify.go +++ b/verify.go @@ -53,6 +53,7 @@ func verifySignature(p7 *PKCS7, signer signerInfo, truststore *x509.CertPool) (e h := hash.New() h.Write(p7.Content) computed := h.Sum(nil) + if subtle.ConstantTimeCompare(digest, computed) != 1 { return &MessageDigestMismatchError{ ExpectedDigest: digest, @@ -113,6 +114,7 @@ func (p7 *PKCS7) UnmarshalSignedAttribute(attributeType asn1.ObjectIdentifier, o func parseSignedData(data []byte) (*PKCS7, error) { var sd signedData asn1.Unmarshal(data, &sd) + certs, err := sd.Certificates.Parse() if err != nil { return nil, err @@ -120,29 +122,45 @@ func parseSignedData(data []byte) (*PKCS7, error) { // fmt.Printf("--> Signed Data Version %d\n", sd.Version) var compound asn1.RawValue - var content unsignedData // The Content.Bytes maybe empty on PKI responses. + + if rest, err := asn1.Unmarshal(sd.ContentInfo.Content.Bytes, &compound); err != nil { + return nil, err + } else if len(rest) > 0 { + return nil, errors.New("unexpected trailing data") + } + + if compound.Class != asn1.ClassUniversal || compound.Tag != asn1.TagOctetString { + return nil, errors.New("bad tag or class") + } + if len(sd.ContentInfo.Content.Bytes) > 0 { if _, err := asn1.Unmarshal(sd.ContentInfo.Content.Bytes, &compound); err != nil { return nil, err } } - // Compound octet string + + var value []byte + if compound.IsCompound { - if compound.Tag == 4 { - if _, err = asn1.Unmarshal(compound.Bytes, &content); err != nil { + rest := compound.Bytes + for len(rest) > 0 { + if rest, err = asn1.Unmarshal(rest, &compound); err != nil { return nil, err } - } else { - content = compound.Bytes + // Don't allow further constructed types. + if compound.Class != asn1.ClassUniversal || compound.Tag != asn1.TagOctetString || compound.IsCompound { + return nil, errors.New("bad class or tag") + } + value = append(value, compound.Bytes...) } } else { - // assuming this is tag 04 - content = compound.Bytes + value = compound.Bytes } + return &PKCS7{ - Content: content, + Content: value, Certificates: certs, CRLs: sd.CRLs, Signers: sd.SignerInfos, @@ -199,13 +217,19 @@ func getSignatureAlgorithm(digestEncryption, digest pkix.AlgorithmIdentifier) (x digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA384), digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA512): switch { - case digest.Algorithm.Equal(OIDDigestAlgorithmSHA1): + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA1), + // Some Microsoft issued certificates use the general RSA algorithm + // for the encrypt OID and the digest+algorithm OID for the digest. + digest.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA1): return x509.SHA1WithRSA, nil - case digest.Algorithm.Equal(OIDDigestAlgorithmSHA256): + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA256), + digest.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA256): return x509.SHA256WithRSA, nil - case digest.Algorithm.Equal(OIDDigestAlgorithmSHA384): + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA384), + digest.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA384): return x509.SHA384WithRSA, nil - case digest.Algorithm.Equal(OIDDigestAlgorithmSHA512): + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA512), + digest.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA512): return x509.SHA512WithRSA, nil default: return -1, fmt.Errorf("pkcs7: unsupported digest %q for encryption algorithm %q", @@ -214,7 +238,8 @@ func getSignatureAlgorithm(digestEncryption, digest pkix.AlgorithmIdentifier) (x case digestEncryption.Algorithm.Equal(OIDDigestAlgorithmDSA), digestEncryption.Algorithm.Equal(OIDDigestAlgorithmDSASHA1): switch { - case digest.Algorithm.Equal(OIDDigestAlgorithmSHA1): + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA1), + digest.Algorithm.Equal(OIDDigestAlgorithmDSASHA1): return x509.DSAWithSHA1, nil case digest.Algorithm.Equal(OIDDigestAlgorithmSHA256): return x509.DSAWithSHA256, nil