Skip to content

Commit

Permalink
creds/aws: Add support for DSA signature verification for EC2 (#12340) (
Browse files Browse the repository at this point in the history
#12361)

* creds/aws: import pkcs7 verification package

* Add DSA support

* changelog

* Add DSA to correct verify function

* Remove unneeded tests

* Fix backend test

* Update builtin/credential/aws/pkcs7/README.md

Co-authored-by: Calvin Leung Huang <[email protected]>

* Update builtin/credential/aws/path_login.go

Co-authored-by: Calvin Leung Huang <[email protected]>

Co-authored-by: Calvin Leung Huang <[email protected]>

Co-authored-by: Calvin Leung Huang <[email protected]>
  • Loading branch information
jasonodonnell and calvn authored Aug 20, 2021
1 parent b882dde commit 407532f
Show file tree
Hide file tree
Showing 18 changed files with 2,973 additions and 6 deletions.
17 changes: 17 additions & 0 deletions builtin/credential/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,22 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndAccessListIdentity(t *testing
}
}

// Configure additional metadata to be returned for ec2 logins.
identity := map[string]interface{}{
"ec2_metadata": []string{"instance_id", "region", "ami_id"},
}

// store the identity
_, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Storage: storage,
Path: "config/identity",
Data: identity,
})
if err != nil {
t.Fatal(err)
}

loginInput := map[string]interface{}{
"pkcs7": pkcs7,
"nonce": "vault-client-nonce",
Expand Down Expand Up @@ -1241,6 +1257,7 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndAccessListIdentity(t *testing
delete(loginInput, "pkcs7")
loginInput["identity"] = identityDoc
loginInput["signature"] = identityDocSig

resp, err = b.HandleRequest(context.Background(), loginRequest)
if err != nil {
t.Fatal(err)
Expand Down
6 changes: 3 additions & 3 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ import (
awsClient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/fullsailor/pkcs7"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-retryablehttp"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/credential/aws/pkcs7"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/awsutil"
"github.com/hashicorp/vault/sdk/helper/cidrutil"
Expand Down Expand Up @@ -348,8 +348,8 @@ func (b *backend) parseIdentityDocument(ctx context.Context, s logical.Storage,

// Verify extracts the authenticated attributes in the PKCS#7 signature, and verifies
// the authenticity of the content using 'dsa.PublicKey' embedded in the public certificate.
if pkcs7Data.Verify() != nil {
return nil, fmt.Errorf("failed to verify the signature")
if err := pkcs7Data.Verify(); err != nil {
return nil, fmt.Errorf("failed to verify the signature: %w", err)
}

// Check if the signature has content inside of it
Expand Down
5 changes: 5 additions & 0 deletions builtin/credential/aws/pkcs7/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# PKCS7

This code is used to verify PKCS7 signatures for the EC2 auth method. The code
was forked from [mozilla-services/pkcs7](https://github.com/mozilla-services/pkcs7)
and modified for Vault.
251 changes: 251 additions & 0 deletions builtin/credential/aws/pkcs7/ber.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
package pkcs7

import (
"bytes"
"errors"
)

var encodeIndent = 0

type asn1Object interface {
EncodeTo(writer *bytes.Buffer) error
}

type asn1Structured struct {
tagBytes []byte
content []asn1Object
}

func (s asn1Structured) EncodeTo(out *bytes.Buffer) error {
//fmt.Printf("%s--> tag: % X\n", strings.Repeat("| ", encodeIndent), s.tagBytes)
encodeIndent++
inner := new(bytes.Buffer)
for _, obj := range s.content {
err := obj.EncodeTo(inner)
if err != nil {
return err
}
}
encodeIndent--
out.Write(s.tagBytes)
encodeLength(out, inner.Len())
out.Write(inner.Bytes())
return nil
}

type asn1Primitive struct {
tagBytes []byte
length int
content []byte
}

func (p asn1Primitive) EncodeTo(out *bytes.Buffer) error {
_, err := out.Write(p.tagBytes)
if err != nil {
return err
}
if err = encodeLength(out, p.length); err != nil {
return err
}
//fmt.Printf("%s--> tag: % X length: %d\n", strings.Repeat("| ", encodeIndent), p.tagBytes, p.length)
//fmt.Printf("%s--> content length: %d\n", strings.Repeat("| ", encodeIndent), len(p.content))
out.Write(p.content)

return nil
}

func ber2der(ber []byte) ([]byte, error) {
if len(ber) == 0 {
return nil, errors.New("ber2der: input ber is empty")
}
//fmt.Printf("--> ber2der: Transcoding %d bytes\n", len(ber))
out := new(bytes.Buffer)

obj, _, err := readObject(ber, 0)
if err != nil {
return nil, err
}
obj.EncodeTo(out)

// if offset < len(ber) {
// return nil, fmt.Errorf("ber2der: Content longer than expected. Got %d, expected %d", offset, len(ber))
//}

return out.Bytes(), nil
}

// encodes lengths that are longer than 127 into string of bytes
func marshalLongLength(out *bytes.Buffer, i int) (err error) {
n := lengthLength(i)

for ; n > 0; n-- {
err = out.WriteByte(byte(i >> uint((n-1)*8)))
if err != nil {
return
}
}

return nil
}

// computes the byte length of an encoded length value
func lengthLength(i int) (numBytes int) {
numBytes = 1
for i > 255 {
numBytes++
i >>= 8
}
return
}

// encodes the length in DER format
// If the length fits in 7 bits, the value is encoded directly.
//
// Otherwise, the number of bytes to encode the length is first determined.
// This number is likely to be 4 or less for a 32bit length. This number is
// added to 0x80. The length is encoded in big endian encoding follow after
//
// Examples:
// length | byte 1 | bytes n
// 0 | 0x00 | -
// 120 | 0x78 | -
// 200 | 0x81 | 0xC8
// 500 | 0x82 | 0x01 0xF4
//
func encodeLength(out *bytes.Buffer, length int) (err error) {
if length >= 128 {
l := lengthLength(length)
err = out.WriteByte(0x80 | byte(l))
if err != nil {
return
}
err = marshalLongLength(out, length)
if err != nil {
return
}
} else {
err = out.WriteByte(byte(length))
if err != nil {
return
}
}
return
}

func readObject(ber []byte, offset int) (asn1Object, int, error) {
berLen := len(ber)
if offset >= berLen {
return nil, 0, errors.New("ber2der: offset is after end of ber data")
}
tagStart := offset
b := ber[offset]
offset++
if offset >= berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
tag := b & 0x1F // last 5 bits
if tag == 0x1F {
tag = 0
for ber[offset] >= 0x80 {
tag = tag*128 + ber[offset] - 0x80
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
// jvehent 20170227: this doesn't appear to be used anywhere...
//tag = tag*128 + ber[offset] - 0x80
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
tagEnd := offset

kind := b & 0x20
if kind == 0 {
debugprint("--> Primitive\n")
} else {
debugprint("--> Constructed\n")
}
// read length
var length int
l := ber[offset]
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
hack := 0
if l > 0x80 {
numberOfBytes := (int)(l & 0x7F)
if numberOfBytes > 4 { // int is only guaranteed to be 32bit
return nil, 0, errors.New("ber2der: BER tag length too long")
}
if numberOfBytes == 4 && (int)(ber[offset]) > 0x7F {
return nil, 0, errors.New("ber2der: BER tag length is negative")
}
if (int)(ber[offset]) == 0x0 {
return nil, 0, errors.New("ber2der: BER tag length has leading zero")
}
debugprint("--> (compute length) indicator byte: %x\n", l)
debugprint("--> (compute length) length bytes: % X\n", ber[offset:offset+numberOfBytes])
for i := 0; i < numberOfBytes; i++ {
length = length*256 + (int)(ber[offset])
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
} else if l == 0x80 {
// find length by searching content
markerIndex := bytes.LastIndex(ber[offset:], []byte{0x0, 0x0})
if markerIndex == -1 {
return nil, 0, errors.New("ber2der: Invalid BER format")
}
length = markerIndex
hack = 2
debugprint("--> (compute length) marker found at offset: %d\n", markerIndex+offset)
} else {
length = (int)(l)
}
if length < 0 {
return nil, 0, errors.New("ber2der: invalid negative value found in BER tag length")
}
//fmt.Printf("--> length : %d\n", length)
contentEnd := offset + length
if contentEnd > len(ber) {
return nil, 0, errors.New("ber2der: BER tag length is more than available data")
}
debugprint("--> content start : %d\n", offset)
debugprint("--> content end : %d\n", contentEnd)
debugprint("--> content : % X\n", ber[offset:contentEnd])
var obj asn1Object
if kind == 0 {
obj = asn1Primitive{
tagBytes: ber[tagStart:tagEnd],
length: length,
content: ber[offset:contentEnd],
}
} else {
var subObjects []asn1Object
for offset < contentEnd {
var subObj asn1Object
var err error
subObj, offset, err = readObject(ber[:contentEnd], offset)
if err != nil {
return nil, 0, err
}
subObjects = append(subObjects, subObj)
}
obj = asn1Structured{
tagBytes: ber[tagStart:tagEnd],
content: subObjects,
}
}

return obj, contentEnd + hack, nil
}

func debugprint(format string, a ...interface{}) {
//fmt.Printf(format, a)
}
62 changes: 62 additions & 0 deletions builtin/credential/aws/pkcs7/ber_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package pkcs7

import (
"bytes"
"encoding/asn1"
"strings"
"testing"
)

func TestBer2Der(t *testing.T) {
// indefinite length fixture
ber := []byte{0x30, 0x80, 0x02, 0x01, 0x01, 0x00, 0x00}
expected := []byte{0x30, 0x03, 0x02, 0x01, 0x01}
der, err := ber2der(ber)
if err != nil {
t.Fatalf("ber2der failed with error: %v", err)
}
if !bytes.Equal(der, expected) {
t.Errorf("ber2der result did not match.\n\tExpected: % X\n\tActual: % X", expected, der)
}

if der2, err := ber2der(der); err != nil {
t.Errorf("ber2der on DER bytes failed with error: %v", err)
} else {
if !bytes.Equal(der, der2) {
t.Error("ber2der is not idempotent")
}
}
var thing struct {
Number int
}
rest, err := asn1.Unmarshal(der, &thing)
if err != nil {
t.Errorf("Cannot parse resulting DER because: %v", err)
} else if len(rest) > 0 {
t.Errorf("Resulting DER has trailing data: % X", rest)
}
}

func TestBer2Der_Negatives(t *testing.T) {
fixtures := []struct {
Input []byte
ErrorContains string
}{
{[]byte{0x30, 0x85}, "tag length too long"},
{[]byte{0x30, 0x84, 0x80, 0x0, 0x0, 0x0}, "length is negative"},
{[]byte{0x30, 0x82, 0x0, 0x1}, "length has leading zero"},
{[]byte{0x30, 0x80, 0x1, 0x2}, "Invalid BER format"},
{[]byte{0x30, 0x03, 0x01, 0x02}, "length is more than available data"},
{[]byte{0x30}, "end of ber data reached"},
}

for _, fixture := range fixtures {
_, err := ber2der(fixture.Input)
if err == nil {
t.Errorf("No error thrown. Expected: %s", fixture.ErrorContains)
}
if !strings.Contains(err.Error(), fixture.ErrorContains) {
t.Errorf("Unexpected error thrown.\n\tExpected: /%s/\n\tActual: %s", fixture.ErrorContains, err.Error())
}
}
}
Loading

0 comments on commit 407532f

Please sign in to comment.