diff --git a/encrypt.go b/encrypt.go index 5fa2d67..da57ae6 100644 --- a/encrypt.go +++ b/encrypt.go @@ -42,6 +42,14 @@ const ( // EncryptionAlgorithmDESCBC is the DES CBC encryption algorithm EncryptionAlgorithmDESCBC = iota + // EncryptionAlgorithmAES128CBC is the AES 128 bits with CBC encryption algorithm + // Avoid this algorithm unless required for interoperability; use AES GCM instead. + EncryptionAlgorithmAES128CBC + + // EncryptionAlgorithmAES256CBC is the AES 256 bits with CBC encryption algorithm + // Avoid this algorithm unless required for interoperability; use AES GCM instead. + EncryptionAlgorithmAES256CBC + // EncryptionAlgorithmAES128GCM is the AES 128 bits with GCM encryption algorithm EncryptionAlgorithmAES128GCM @@ -56,7 +64,7 @@ var ContentEncryptionAlgorithm = EncryptionAlgorithmDESCBC // ErrUnsupportedEncryptionAlgorithm is returned when attempting to encrypt // content with an unsupported algorithm. -var ErrUnsupportedEncryptionAlgorithm = errors.New("pkcs7: cannot encrypt content: only DES-CBC and AES-GCM supported") +var ErrUnsupportedEncryptionAlgorithm = errors.New("pkcs7: cannot encrypt content: only DES-CBC, AES-CBC, and AES-GCM supported") // ErrPSKNotProvided is returned when attempting to encrypt // using a PSK without actually providing the PSK. @@ -72,12 +80,15 @@ type aesGCMParameters struct { func encryptAESGCM(content []byte, key []byte) ([]byte, *encryptedContentInfo, error) { var keyLen int var algID asn1.ObjectIdentifier - if ContentEncryptionAlgorithm == EncryptionAlgorithmAES128GCM { + switch ContentEncryptionAlgorithm { + case EncryptionAlgorithmAES128GCM: keyLen = 16 algID = OIDEncryptionAlgorithmAES128GCM - } else { + case EncryptionAlgorithmAES256GCM: keyLen = 32 algID = OIDEncryptionAlgorithmAES256GCM + default: + return nil, nil, fmt.Errorf("invalid ContentEncryptionAlgorithm in encryptAESGCM: %d", ContentEncryptionAlgorithm) } if key == nil { // Create AES key @@ -180,6 +191,63 @@ func encryptDESCBC(content []byte, key []byte) ([]byte, *encryptedContentInfo, e return key, &eci, nil } +func encryptAESCBC(content []byte, key []byte) ([]byte, *encryptedContentInfo, error) { + var keyLen int + var algID asn1.ObjectIdentifier + switch ContentEncryptionAlgorithm { + case EncryptionAlgorithmAES128CBC: + keyLen = 16 + algID = OIDEncryptionAlgorithmAES128CBC + case EncryptionAlgorithmAES256CBC: + keyLen = 32 + algID = OIDEncryptionAlgorithmAES256CBC + default: + return nil, nil, fmt.Errorf("invalid ContentEncryptionAlgorithm in encryptAESCBC: %d", ContentEncryptionAlgorithm) + } + + if key == nil { + // Create AES key + key = make([]byte, keyLen) + + _, err := rand.Read(key) + if err != nil { + return nil, nil, err + } + } + + // Create CBC IV + iv := make([]byte, aes.BlockSize) + _, err := rand.Read(iv) + if err != nil { + return nil, nil, err + } + + // Encrypt padded content + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + mode := cipher.NewCBCEncrypter(block, iv) + plaintext, err := pad(content, mode.BlockSize()) + if err != nil { + return nil, nil, err + } + cyphertext := make([]byte, len(plaintext)) + mode.CryptBlocks(cyphertext, plaintext) + + // Prepare ASN.1 Encrypted Content Info + eci := encryptedContentInfo{ + ContentType: OIDData, + ContentEncryptionAlgorithm: pkix.AlgorithmIdentifier{ + Algorithm: algID, + Parameters: asn1.RawValue{Tag: 4, Bytes: iv}, + }, + EncryptedContent: marshalEncryptedContent(cyphertext), + } + + return key, &eci, nil +} + // Encrypt creates and returns an envelope data PKCS7 structure with encrypted // recipient keys for each recipient public key. // @@ -200,7 +268,10 @@ func Encrypt(content []byte, recipients []*x509.Certificate) ([]byte, error) { switch ContentEncryptionAlgorithm { case EncryptionAlgorithmDESCBC: key, eci, err = encryptDESCBC(content, nil) - + case EncryptionAlgorithmAES128CBC: + fallthrough + case EncryptionAlgorithmAES256CBC: + key, eci, err = encryptAESCBC(content, nil) case EncryptionAlgorithmAES128GCM: fallthrough case EncryptionAlgorithmAES256GCM: diff --git a/encrypt_test.go b/encrypt_test.go index e8a206a..c64381e 100644 --- a/encrypt_test.go +++ b/encrypt_test.go @@ -9,6 +9,8 @@ import ( func TestEncrypt(t *testing.T) { modes := []int{ EncryptionAlgorithmDESCBC, + EncryptionAlgorithmAES128CBC, + EncryptionAlgorithmAES256CBC, EncryptionAlgorithmAES128GCM, EncryptionAlgorithmAES256GCM, }