diff --git a/service/glacier/treehash.go b/service/glacier/treehash.go index 1d7534fbde8..1e62c565ebe 100644 --- a/service/glacier/treehash.go +++ b/service/glacier/treehash.go @@ -55,25 +55,36 @@ func ComputeHashes(r io.ReadSeeker) Hash { // // See http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html for more information. func ComputeTreeHash(hashes [][]byte) []byte { - if hashes == nil || len(hashes) == 0 { + hashCount := len(hashes) + switch hashCount { + case 0: return nil + case 1: + return hashes[0] } - - for len(hashes) > 1 { - tmpHashes := [][]byte{} - - for i := 0; i < len(hashes); i += 2 { - if i+1 <= len(hashes)-1 { - tmpHash := append(append([]byte{}, hashes[i]...), hashes[i+1]...) - tmpSum := sha256.Sum256(tmpHash) - tmpHashes = append(tmpHashes, tmpSum[:]) - } else { - tmpHashes = append(tmpHashes, hashes[i]) + leaves := make([][32]byte, hashCount) + for i := range leaves { + copy(leaves[i][:], hashes[i]) + } + var ( + queue = leaves[:0] + h256 = sha256.New() + buf [32]byte + ) + for len(leaves) > 1 { + for i := 0; i < len(leaves); i += 2 { + if i+1 == len(leaves) { + queue = append(queue, leaves[i]) + break } + h256.Write(leaves[i][:]) + h256.Write(leaves[i+1][:]) + h256.Sum(buf[:0]) + queue = append(queue, buf) + h256.Reset() } - - hashes = tmpHashes + leaves = queue + queue = queue[:0] } - - return hashes[0] + return leaves[0][:] } diff --git a/service/glacier/treehash_test.go b/service/glacier/treehash_test.go index 46f0facdb85..af9dbd782c4 100644 --- a/service/glacier/treehash_test.go +++ b/service/glacier/treehash_test.go @@ -3,8 +3,10 @@ package glacier_test import ( "bytes" "crypto/sha256" + "encoding/hex" "fmt" "io" + "testing" "github.com/aws/aws-sdk-go/service/glacier" ) @@ -61,3 +63,52 @@ func ExampleComputeTreeHash() { // Output: // TreeHash: 154e26c78fd74d0c2c9b3cc4644191619dc4f2cd539ae2a74d5fd07957a3ee6a } + +func TestComputeHashes(t *testing.T) { + + t.Run("no hash", func(t *testing.T) { + var hashes [][]byte + tree := glacier.ComputeTreeHash(hashes) + if tree != nil { + t.Fatalf("expected []byte(nil), got %v", tree) + } + }) + + t.Run("one hash", func(t *testing.T) { + hash := sha256.Sum256([]byte("hash")) + tree := glacier.ComputeTreeHash([][]byte{hash[:]}) + + expected, actual := hex.EncodeToString(hash[:]), hex.EncodeToString(tree) + if expected != actual { + t.Fatalf("expected %v, got %v", expected, actual) + } + }) + + t.Run("even hashes", func(t *testing.T) { + h1 := sha256.Sum256([]byte("h1")) + h2 := sha256.Sum256([]byte("h2")) + tree := glacier.ComputeTreeHash([][]byte{h1[:], h2[:]}) + var ( + expected = "0228c4e26bfc81adb535b2809bdfb1929d9b1cb05c1b2c60a8a4904edbe88ba1" + actual = hex.EncodeToString(tree) + ) + if expected != actual { + t.Fatalf("expected %v, got %v", expected, actual) + } + }) + + t.Run("odd hashes", func(t *testing.T) { + h1 := sha256.Sum256([]byte("h1")) + h2 := sha256.Sum256([]byte("h2")) + h3 := sha256.Sum256([]byte("h3")) + tree := glacier.ComputeTreeHash([][]byte{h1[:], h2[:], h3[:]}) + var ( + expected = "81cbe4a143be0fcffe9d4c3d90db3a2963c154783f54fa19cb1e8912f3ca5724" + actual = hex.EncodeToString(tree) + ) + if expected != actual { + t.Fatalf("expected %v, got %v", expected, actual) + } + }) + +}