diff --git a/bao.go b/bao/bao.go similarity index 60% rename from bao.go rename to bao/bao.go index cd4435b..beda376 100644 --- a/bao.go +++ b/bao/bao.go @@ -1,4 +1,5 @@ -package blake3 +// Package bao implements BLAKE3 verified streaming. +package bao import ( "bytes" @@ -6,45 +7,63 @@ import ( "errors" "io" "math/bits" + + "lukechampine.com/blake3/guts" ) -func compressGroup(p []byte, counter uint64) node { - var stack [54 - maxSIMD][8]uint32 +func bytesToCV(b []byte) (cv [8]uint32) { + _ = b[31] // bounds check hint + for i := range cv { + cv[i] = binary.LittleEndian.Uint32(b[4*i:]) + } + return cv +} + +func cvToBytes(cv *[8]uint32) *[32]byte { + var b [32]byte + for i, w := range cv { + binary.LittleEndian.PutUint32(b[4*i:], w) + } + return &b +} + +func compressGroup(p []byte, counter uint64) guts.Node { + var stack [54 - guts.MaxSIMD][8]uint32 var sc uint64 pushSubtree := func(cv [8]uint32) { i := 0 for sc&(1< 0 { if buflen == len(buf) { - pushSubtree(chainingValue(compressBuffer(&buf, buflen, &iv, counter+(sc*maxSIMD), 0))) + pushSubtree(guts.ChainingValue(guts.CompressBuffer(&buf, buflen, &guts.IV, counter+(sc*guts.MaxSIMD), 0))) buflen = 0 } n := copy(buf[buflen:], p) buflen += n p = p[n:] } - n := compressBuffer(&buf, buflen, &iv, counter+(sc*maxSIMD), 0) + n := guts.CompressBuffer(&buf, buflen, &guts.IV, counter+(sc*guts.MaxSIMD), 0) for i := bits.TrailingZeros64(sc); i < bits.Len64(sc); i++ { if sc&(1< 0 { chunks := (dataLen + groupSize - 1) / groupSize @@ -57,16 +76,16 @@ func BaoEncodedSize(dataLen int, group int, outboard bool) int { return size } -// BaoEncode computes the intermediate BLAKE3 tree hashes of data and writes -// them to dst. If outboard is false, the contents of data are also written to -// dst, interleaved with the tree hashes. It also returns the tree root, i.e. -// the 256-bit BLAKE3 hash. The group parameter controls how many chunks are -// hashed per "group," as a power of 2; for standard Bao, use 0. +// Encode computes the intermediate BLAKE3 tree hashes of data and writes them +// to dst. If outboard is false, the contents of data are also written to dst, +// interleaved with the tree hashes. It also returns the tree root, i.e. the +// 256-bit BLAKE3 hash. The group parameter controls how many chunks are hashed +// per "group," as a power of 2; for standard Bao, use 0. // // Note that dst is not written sequentially, and therefore must be initialized -// with sufficient capacity to hold the encoding; see BaoEncodedSize. -func BaoEncode(dst io.WriterAt, data io.Reader, dataLen int64, group int, outboard bool) ([32]byte, error) { - groupSize := uint64(chunkSize << group) +// with sufficient capacity to hold the encoding; see EncodedSize. 
+func Encode(dst io.WriterAt, data io.Reader, dataLen int64, group int, outboard bool) ([32]byte, error) { + groupSize := uint64(guts.ChunkSize << group) buf := make([]byte, groupSize) var err error read := func(p []byte) []byte { @@ -97,9 +116,9 @@ func BaoEncode(dst io.WriterAt, data io.Reader, dataLen int64, group int, outboa write(g, off) } n := compressGroup(g, counter) - counter += bufLen / chunkSize - n.flags |= flags - return 0, chainingValue(n) + counter += bufLen / guts.ChunkSize + n.Flags |= flags + return 0, guts.ChainingValue(n) } mid := uint64(1) << (bits.Len64(bufLen-1) - 1) lchildren, l := rec(mid, 0, off+64) @@ -110,23 +129,23 @@ func BaoEncode(dst io.WriterAt, data io.Reader, dataLen int64, group int, outboa rchildren, r := rec(bufLen-mid, 0, off+64+llen) write(cvToBytes(&l)[:], off) write(cvToBytes(&r)[:], off+32) - return 2 + lchildren + rchildren, chainingValue(parentNode(l, r, iv, flags)) + return 2 + lchildren + rchildren, guts.ChainingValue(guts.ParentNode(l, r, &guts.IV, flags)) } binary.LittleEndian.PutUint64(buf[:8], uint64(dataLen)) write(buf[:8], 0) - _, root := rec(uint64(dataLen), flagRoot, 8) + _, root := rec(uint64(dataLen), guts.FlagRoot, 8) return *cvToBytes(&root), err } -// BaoDecode reads content and tree data from the provided reader(s), and +// Decode reads content and tree data from the provided reader(s), and // streams the verified content to dst. It returns false if verification fails. // If the content and tree data are interleaved, outboard should be nil. -func BaoDecode(dst io.Writer, data, outboard io.Reader, group int, root [32]byte) (bool, error) { +func Decode(dst io.Writer, data, outboard io.Reader, group int, root [32]byte) (bool, error) { if outboard == nil { outboard = data } - groupSize := uint64(chunkSize << group) + groupSize := uint64(guts.ChunkSize << group) buf := make([]byte, groupSize) var err error read := func(r io.Reader, p []byte) []byte { @@ -151,23 +170,23 @@ func BaoDecode(dst io.Writer, data, outboard io.Reader, group int, root [32]byte return false } else if bufLen <= groupSize { n := compressGroup(read(data, buf[:bufLen]), counter) - counter += bufLen / chunkSize - n.flags |= flags - valid := cv == chainingValue(n) + counter += bufLen / guts.ChunkSize + n.Flags |= flags + valid := cv == guts.ChainingValue(n) if valid { write(dst, buf[:bufLen]) } return valid } l, r := readParent() - n := parentNode(l, r, iv, flags) + n := guts.ParentNode(l, r, &guts.IV, flags) mid := uint64(1) << (bits.Len64(bufLen-1) - 1) - return chainingValue(n) == cv && rec(l, mid, 0) && rec(r, bufLen-mid, 0) + return guts.ChainingValue(n) == cv && rec(l, mid, 0) && rec(r, bufLen-mid, 0) } read(outboard, buf[:8]) dataLen := binary.LittleEndian.Uint64(buf[:8]) - ok := rec(bytesToCV(root[:]), dataLen, flagRoot) + ok := rec(bytesToCV(root[:]), dataLen, guts.FlagRoot) return ok, err } @@ -182,34 +201,34 @@ func (b *bufferAt) WriteAt(p []byte, off int64) (int, error) { return len(p), nil } -// BaoEncodeBuf returns the Bao encoding and root (i.e. BLAKE3 hash) for data. -func BaoEncodeBuf(data []byte, group int, outboard bool) ([]byte, [32]byte) { - buf := bufferAt{buf: make([]byte, BaoEncodedSize(len(data), group, outboard))} - root, _ := BaoEncode(&buf, bytes.NewReader(data), int64(len(data)), group, outboard) +// EncodeBuf returns the Bao encoding and root (i.e. BLAKE3 hash) for data. 
+func EncodeBuf(data []byte, group int, outboard bool) ([]byte, [32]byte) { + buf := bufferAt{buf: make([]byte, EncodedSize(len(data), group, outboard))} + root, _ := Encode(&buf, bytes.NewReader(data), int64(len(data)), group, outboard) return buf.buf, root } -// BaoVerifyBuf verifies the Bao encoding and root (i.e. BLAKE3 hash) for data. +// VerifyBuf verifies the Bao encoding and root (i.e. BLAKE3 hash) for data. // If the content and tree data are interleaved, outboard should be nil. -func BaoVerifyBuf(data, outboard []byte, group int, root [32]byte) bool { +func VerifyBuf(data, outboard []byte, group int, root [32]byte) bool { d, o := bytes.NewBuffer(data), bytes.NewBuffer(outboard) var or io.Reader = o if outboard == nil { or = nil } - ok, _ := BaoDecode(io.Discard, d, or, group, root) + ok, _ := Decode(io.Discard, d, or, group, root) return ok && d.Len() == 0 && o.Len() == 0 // check for trailing data } -// BaoExtractSlice returns the slice encoding for the given offset and length. -// When extracting from an outboard encoding, data should contain only the chunk +// ExtractSlice returns the slice encoding for the given offset and length. When +// extracting from an outboard encoding, data should contain only the chunk // groups that will be present in the slice. -func BaoExtractSlice(dst io.Writer, data, outboard io.Reader, group int, offset uint64, length uint64) error { +func ExtractSlice(dst io.Writer, data, outboard io.Reader, group int, offset uint64, length uint64) error { combinedEncoding := outboard == nil if combinedEncoding { outboard = data } - groupSize := uint64(chunkSize << group) + groupSize := uint64(guts.ChunkSize << group) buf := make([]byte, groupSize) var err error read := func(r io.Reader, n uint64, copy bool) { @@ -245,11 +264,11 @@ func BaoExtractSlice(dst io.Writer, data, outboard io.Reader, group int, offset return err } -// BaoDecodeSlice reads from data, which must contain a slice encoding for the +// DecodeSlice reads from data, which must contain a slice encoding for the // given offset and length, and streams verified content to dst. It returns // false if verification fails. 
-func BaoDecodeSlice(dst io.Writer, data io.Reader, group int, offset, length uint64, root [32]byte) (bool, error) { - groupSize := uint64(chunkSize << group) +func DecodeSlice(dst io.Writer, data io.Reader, group int, offset, length uint64, root [32]byte) (bool, error) { + groupSize := uint64(guts.ChunkSize << group) buf := make([]byte, groupSize) var err error read := func(n uint64) []byte { @@ -276,9 +295,9 @@ func BaoDecodeSlice(dst io.Writer, data io.Reader, group int, offset, length uin if !inSlice { return true } - n := compressGroup(read(bufLen), pos/chunkSize) - n.flags |= flags - valid := cv == chainingValue(n) + n := compressGroup(read(bufLen), pos/guts.ChunkSize) + n.Flags |= flags + valid := cv == guts.ChainingValue(n) if valid { // only write within range p := buf[:bufLen] @@ -296,36 +315,35 @@ func BaoDecodeSlice(dst io.Writer, data io.Reader, group int, offset, length uin return true } l, r := readParent() - n := parentNode(l, r, iv, flags) + n := guts.ParentNode(l, r, &guts.IV, flags) mid := uint64(1) << (bits.Len64(bufLen-1) - 1) - return chainingValue(n) == cv && rec(l, pos, mid, 0) && rec(r, pos+mid, bufLen-mid, 0) + return guts.ChainingValue(n) == cv && rec(l, pos, mid, 0) && rec(r, pos+mid, bufLen-mid, 0) } dataLen := binary.LittleEndian.Uint64(read(8)) if dataLen < offset+length { return false, errors.New("invalid slice length") } - ok := rec(bytesToCV(root[:]), 0, dataLen, flagRoot) + ok := rec(bytesToCV(root[:]), 0, dataLen, guts.FlagRoot) return ok, err } -// BaoVerifySlice verifies the Bao slice encoding in data, returning the +// VerifySlice verifies the Bao slice encoding in data, returning the // verified bytes. -func BaoVerifySlice(data []byte, group int, offset uint64, length uint64, root [32]byte) ([]byte, bool) { +func VerifySlice(data []byte, group int, offset uint64, length uint64, root [32]byte) ([]byte, bool) { d := bytes.NewBuffer(data) var buf bytes.Buffer - if ok, _ := BaoDecodeSlice(&buf, d, group, offset, length, root); !ok || d.Len() > 0 { + if ok, _ := DecodeSlice(&buf, d, group, offset, length, root); !ok || d.Len() > 0 { return nil, false } return buf.Bytes(), true } -// BaoVerifyChunks verifies the provided chunks using the provided outboard -// encoding, -func BaoVerifyChunk(chunks, outboard []byte, group int, offset uint64, root [32]byte) bool { +// VerifyChunks verifies the provided chunks with a full outboard encoding. 
+func VerifyChunk(chunks, outboard []byte, group int, offset uint64, root [32]byte) bool { cbuf := bytes.NewBuffer(chunks) obuf := bytes.NewBuffer(outboard) - groupSize := uint64(chunkSize << group) + groupSize := uint64(guts.ChunkSize << group) length := uint64(len(chunks)) nodesWithin := func(bufLen uint64) int { n := int(bufLen / groupSize) @@ -342,18 +360,18 @@ func BaoVerifyChunk(chunks, outboard []byte, group int, offset uint64, root [32] if !inSlice { return true } - n := compressGroup(cbuf.Next(int(groupSize)), pos/chunkSize) - n.flags |= flags - return cv == chainingValue(n) + n := compressGroup(cbuf.Next(int(groupSize)), pos/guts.ChunkSize) + n.Flags |= flags + return cv == guts.ChainingValue(n) } if !inSlice { _ = obuf.Next(64 * nodesWithin(bufLen)) // skip return true } l, r := bytesToCV(obuf.Next(32)), bytesToCV(obuf.Next(32)) - n := parentNode(l, r, iv, flags) + n := guts.ParentNode(l, r, &guts.IV, flags) mid := uint64(1) << (bits.Len64(bufLen-1) - 1) - return chainingValue(n) == cv && rec(l, pos, mid, 0) && rec(r, pos+mid, bufLen-mid, 0) + return guts.ChainingValue(n) == cv && rec(l, pos, mid, 0) && rec(r, pos+mid, bufLen-mid, 0) } if obuf.Len() < 8 { @@ -363,5 +381,5 @@ func BaoVerifyChunk(chunks, outboard []byte, group int, offset uint64, root [32] if dataLen < offset+length || obuf.Len() != 64*nodesWithin(dataLen) { return false } - return rec(bytesToCV(root[:]), 0, dataLen, flagRoot) + return rec(bytesToCV(root[:]), 0, dataLen, guts.FlagRoot) } diff --git a/bao_test.go b/bao/bao_test.go similarity index 73% rename from bao_test.go rename to bao/bao_test.go index 315531d..7c91533 100644 --- a/bao_test.go +++ b/bao/bao_test.go @@ -1,37 +1,41 @@ -package blake3_test +package bao_test import ( "bytes" "encoding/binary" + "encoding/hex" "fmt" "os" "testing" "lukechampine.com/blake3" + "lukechampine.com/blake3/bao" ) +func toHex(data []byte) string { return hex.EncodeToString(data) } + func TestBaoGolden(t *testing.T) { - data, err := os.ReadFile("testdata/vectors.json") + data, err := os.ReadFile("../testdata/vectors.json") if err != nil { t.Fatal(err) } - goldenInterleaved, err := os.ReadFile("testdata/bao-golden.bao") + goldenInterleaved, err := os.ReadFile("../testdata/bao-golden.bao") if err != nil { t.Fatal(err) } - goldenOutboard, err := os.ReadFile("testdata/bao-golden.obao") + goldenOutboard, err := os.ReadFile("../testdata/bao-golden.obao") if err != nil { t.Fatal(err) } - interleaved, root := blake3.BaoEncodeBuf(data, 0, false) + interleaved, root := bao.EncodeBuf(data, 0, false) if toHex(root[:]) != "6654fbd1836b531b25e2782c9cc9b792c80abb36b024f59db5d5f6bd3187ddfe" { t.Errorf("bad root: %x", root) } else if !bytes.Equal(interleaved, goldenInterleaved) { t.Error("bad interleaved encoding") } - outboard, root := blake3.BaoEncodeBuf(data, 0, true) + outboard, root := bao.EncodeBuf(data, 0, true) if toHex(root[:]) != "6654fbd1836b531b25e2782c9cc9b792c80abb36b024f59db5d5f6bd3187ddfe" { t.Errorf("bad root: %x", root) } else if !bytes.Equal(outboard, goldenOutboard) { @@ -39,20 +43,20 @@ func TestBaoGolden(t *testing.T) { } // test empty input - interleaved, root = blake3.BaoEncodeBuf(nil, 0, false) + interleaved, root = bao.EncodeBuf(nil, 0, false) if toHex(root[:]) != "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262" { t.Errorf("bad root: %x", root) } else if toHex(interleaved[:]) != "0000000000000000" { t.Errorf("bad interleaved encoding: %x", interleaved) - } else if !blake3.BaoVerifyBuf(interleaved, nil, 0, root) { + } else if 
!bao.VerifyBuf(interleaved, nil, 0, root) { t.Error("verify failed") } - outboard, root = blake3.BaoEncodeBuf(nil, 0, true) + outboard, root = bao.EncodeBuf(nil, 0, true) if toHex(root[:]) != "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262" { t.Errorf("bad root: %x", root) } else if toHex(outboard[:]) != "0000000000000000" { t.Errorf("bad outboard encoding: %x", outboard) - } else if !blake3.BaoVerifyBuf(nil, outboard, 0, root) { + } else if !bao.VerifyBuf(nil, outboard, 0, root) { t.Error("verify failed") } } @@ -62,32 +66,32 @@ func TestBaoInterleaved(t *testing.T) { blake3.New(0, nil).XOF().Read(data) for group := 0; group < 10; group++ { - interleaved, root := blake3.BaoEncodeBuf(data, group, false) - if !blake3.BaoVerifyBuf(interleaved, nil, group, root) { + interleaved, root := bao.EncodeBuf(data, group, false) + if !bao.VerifyBuf(interleaved, nil, group, root) { t.Fatal("verify failed") } badRoot := root badRoot[0] ^= 1 - if blake3.BaoVerifyBuf(interleaved, nil, group, badRoot) { + if bao.VerifyBuf(interleaved, nil, group, badRoot) { t.Fatal("verify succeeded with bad root") } badPrefix := append([]byte(nil), interleaved...) badPrefix[0] ^= 1 - if blake3.BaoVerifyBuf(badPrefix, nil, group, root) { + if bao.VerifyBuf(badPrefix, nil, group, root) { t.Fatal("verify succeeded with bad length prefix") } badCVs := append([]byte(nil), interleaved...) badCVs[8] ^= 1 - if blake3.BaoVerifyBuf(badCVs, nil, group, root) { + if bao.VerifyBuf(badCVs, nil, group, root) { t.Fatal("verify succeeded with bad cv data") } badData := append([]byte(nil), interleaved...) badData[len(badData)-1] ^= 1 - if blake3.BaoVerifyBuf(badData, nil, group, root) { + if bao.VerifyBuf(badData, nil, group, root) { t.Fatal("verify succeeded with bad content") } extraData := append(append([]byte(nil), interleaved...), 1, 2, 3) - if blake3.BaoVerifyBuf(extraData, nil, group, root) { + if bao.VerifyBuf(extraData, nil, group, root) { t.Fatal("verify succeeded with extra data") } } @@ -98,23 +102,23 @@ func TestBaoOutboard(t *testing.T) { blake3.New(0, nil).XOF().Read(data) for group := 0; group < 10; group++ { - outboard, root := blake3.BaoEncodeBuf(data, group, true) - if !blake3.BaoVerifyBuf(data, outboard, group, root) { + outboard, root := bao.EncodeBuf(data, group, true) + if !bao.VerifyBuf(data, outboard, group, root) { t.Fatal("verify failed") } badRoot := root badRoot[0] ^= 1 - if blake3.BaoVerifyBuf(data, outboard, group, badRoot) { + if bao.VerifyBuf(data, outboard, group, badRoot) { t.Fatal("verify succeeded with bad root") } badPrefix := append([]byte(nil), outboard...) badPrefix[0] ^= 1 - if blake3.BaoVerifyBuf(data, badPrefix, group, root) { + if bao.VerifyBuf(data, badPrefix, group, root) { t.Fatal("verify succeeded with bad length prefix") } badCVs := append([]byte(nil), outboard...) 
badCVs[8] ^= 1 - if blake3.BaoVerifyBuf(data, badCVs, group, root) { + if bao.VerifyBuf(data, badCVs, group, root) { t.Fatal("verify succeeded with bad cv data") } } @@ -147,7 +151,7 @@ func TestBaoChunkGroup(t *testing.T) { {212992, "760c549edfe95c734b1d6a9b846d81692ed3ca022b541442949a0e42fe570df2"}, } { input := baoInput(test.inputLen) - _, root := blake3.BaoEncodeBuf(input, group, false) + _, root := bao.EncodeBuf(input, group, false) if out := fmt.Sprintf("%x", root); out != test.exp { t.Errorf("output %v did not match test vector:\n\texpected: %v...\n\t got: %v...", test.inputLen, test.exp[:10], out[:10]) } @@ -158,12 +162,12 @@ func TestBaoStreaming(t *testing.T) { data := make([]byte, 1<<20) blake3.New(0, nil).XOF().Read(data) - enc, root := blake3.BaoEncodeBuf(data, 0, false) + enc, root := bao.EncodeBuf(data, 0, false) if root != blake3.Sum256(data) { t.Fatal("bad root") } var buf bytes.Buffer - if ok, err := blake3.BaoDecode(&buf, bytes.NewReader(enc), nil, 0, root); err != nil || !ok { + if ok, err := bao.Decode(&buf, bytes.NewReader(enc), nil, 0, root); err != nil || !ok { t.Fatal("decode failed") } else if !bytes.Equal(buf.Bytes(), data) { t.Fatal("bad decode") @@ -171,7 +175,7 @@ func TestBaoStreaming(t *testing.T) { // corrupt root; nothing should be written to buf buf.Reset() - if ok, err := blake3.BaoDecode(&buf, bytes.NewReader(enc), nil, 0, [32]byte{}); err != nil { + if ok, err := bao.Decode(&buf, bytes.NewReader(enc), nil, 0, [32]byte{}); err != nil { t.Fatal("decode failed") } else if ok { t.Fatal("decode succeeded with bad root") @@ -182,7 +186,7 @@ func TestBaoStreaming(t *testing.T) { // corrupt a byte halfway through; buf should only be partially written buf.Reset() enc[len(enc)/2] ^= 1 - if ok, err := blake3.BaoDecode(&buf, bytes.NewReader(enc), nil, 0, root); err != nil { + if ok, err := bao.Decode(&buf, bytes.NewReader(enc), nil, 0, root); err != nil { t.Fatal("decode failed") } else if ok { t.Fatal("decode succeeded with bad data") @@ -207,11 +211,11 @@ func TestBaoSlice(t *testing.T) { } { // combined encoding { - enc, root := blake3.BaoEncodeBuf(data, 0, false) + enc, root := bao.EncodeBuf(data, 0, false) var buf bytes.Buffer - if err := blake3.BaoExtractSlice(&buf, bytes.NewReader(enc), nil, 0, test.off, test.len); err != nil { + if err := bao.ExtractSlice(&buf, bytes.NewReader(enc), nil, 0, test.off, test.len); err != nil { t.Error(err) - } else if vdata, ok := blake3.BaoVerifySlice(buf.Bytes(), 0, test.off, test.len, root); !ok { + } else if vdata, ok := bao.VerifySlice(buf.Bytes(), 0, test.off, test.len, root); !ok { t.Error("combined verify failed", test) } else if !bytes.Equal(vdata, data[test.off:][:test.len]) { t.Error("combined bad decode", test, vdata, data[test.off:][:test.len]) @@ -219,15 +223,15 @@ func TestBaoSlice(t *testing.T) { } // outboard encoding { - enc, root := blake3.BaoEncodeBuf(data, 0, true) + enc, root := bao.EncodeBuf(data, 0, true) start, end := (test.off/1024)*1024, ((test.off+test.len+1024-1)/1024)*1024 if end > uint64(len(data)) { end = uint64(len(data)) } var buf bytes.Buffer - if err := blake3.BaoExtractSlice(&buf, bytes.NewReader(data[start:end]), bytes.NewReader(enc), 0, test.off, test.len); err != nil { + if err := bao.ExtractSlice(&buf, bytes.NewReader(data[start:end]), bytes.NewReader(enc), 0, test.off, test.len); err != nil { t.Error(err) - } else if vdata, ok := blake3.BaoVerifySlice(buf.Bytes(), 0, test.off, test.len, root); !ok { + } else if vdata, ok := bao.VerifySlice(buf.Bytes(), 0, test.off, test.len, root); 
!ok { t.Error("outboard verify failed", test) } else if !bytes.Equal(vdata, data[test.off:][:test.len]) { t.Error("outboard bad decode", test, vdata, data[test.off:][:test.len]) diff --git a/blake3.go b/blake3.go index 262cfa8..259e126 100644 --- a/blake3.go +++ b/blake3.go @@ -8,51 +8,11 @@ import ( "io" "math" "math/bits" -) - -const ( - flagChunkStart = 1 << iota - flagChunkEnd - flagParent - flagRoot - flagKeyedHash - flagDeriveKeyContext - flagDeriveKeyMaterial - blockSize = 64 - chunkSize = 1024 - - maxSIMD = 16 // AVX-512 vectors can store 16 words + "lukechampine.com/blake3/bao" + "lukechampine.com/blake3/guts" ) -var iv = [8]uint32{ - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, - 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, -} - -// A node represents a chunk or parent in the BLAKE3 Merkle tree. -type node struct { - cv [8]uint32 // chaining value from previous node - block [16]uint32 - counter uint64 - blockLen uint32 - flags uint32 -} - -// parentNode returns a node that incorporates the chaining values of two child -// nodes. -func parentNode(left, right [8]uint32, key [8]uint32, flags uint32) node { - n := node{ - cv: key, - counter: 0, // counter is reset for parents - blockLen: blockSize, // block is full - flags: flags | flagParent, - } - copy(n.block[:8], left[:]) - copy(n.block[8:], right[:]) - return n -} - // Hasher implements hash.Hash. type Hasher struct { key [8]uint32 @@ -60,10 +20,10 @@ type Hasher struct { size int // output size, for Sum // log(n) set of Merkle subtree roots, at most one per height. - stack [50][8]uint32 // 2^50 * maxSIMD * chunkSize = 2^64 - counter uint64 // number of buffers hashed; also serves as a bit vector indicating which stack elems are occupied + stack [64 - (guts.MaxSIMD + 10)][8]uint32 // 10 = log2(guts.ChunkSize) + counter uint64 // number of buffers hashed; also serves as a bit vector indicating which stack elems are occupied - buf [maxSIMD * chunkSize]byte + buf [guts.MaxSIMD * guts.ChunkSize]byte buflen int } @@ -75,7 +35,7 @@ func (h *Hasher) pushSubtree(cv [8]uint32) { // seek to first open stack slot, merging subtrees as we go i := 0 for h.hasSubtreeAtHeight(i) { - cv = chainingValue(parentNode(h.stack[i], cv, h.key, h.flags)) + cv = guts.ChainingValue(guts.ParentNode(h.stack[i], cv, &h.key, h.flags)) i++ } h.stack[i] = cv @@ -84,14 +44,14 @@ func (h *Hasher) pushSubtree(cv [8]uint32) { // rootNode computes the root of the Merkle tree. It does not modify the // stack. 
-func (h *Hasher) rootNode() node { - n := compressBuffer(&h.buf, h.buflen, &h.key, h.counter*maxSIMD, h.flags) +func (h *Hasher) rootNode() guts.Node { + n := guts.CompressBuffer(&h.buf, h.buflen, &h.key, h.counter*guts.MaxSIMD, h.flags) for i := bits.TrailingZeros64(h.counter); i < bits.Len64(h.counter); i++ { if h.hasSubtreeAtHeight(i) { - n = parentNode(h.stack[i], chainingValue(n), h.key, h.flags) + n = guts.ParentNode(h.stack[i], guts.ChainingValue(n), &h.key, h.flags) } } - n.flags |= flagRoot + n.Flags |= guts.FlagRoot return n } @@ -100,8 +60,8 @@ func (h *Hasher) Write(p []byte) (int, error) { lenp := len(p) for len(p) > 0 { if h.buflen == len(h.buf) { - n := compressBuffer(&h.buf, h.buflen, &h.key, h.counter*maxSIMD, h.flags) - h.pushSubtree(chainingValue(n)) + n := guts.CompressBuffer(&h.buf, h.buflen, &h.key, h.counter*guts.MaxSIMD, h.flags) + h.pushSubtree(guts.ChainingValue(n)) h.buflen = 0 } n := copy(h.buf[h.buflen:], p) @@ -125,8 +85,7 @@ func (h *Hasher) Sum(b []byte) (sum []byte) { // path for small digests (requiring a single compression), and a // high-latency-high-throughput path for large digests. if dst := sum[len(b):]; len(dst) <= 64 { - var out [64]byte - wordsToBytes(compressNode(h.rootNode()), &out) + out := guts.WordsToBytes(guts.CompressNode(h.rootNode())) copy(dst, out[:]) } else { h.XOF().Read(dst) @@ -165,13 +124,13 @@ func newHasher(key [8]uint32, flags uint32, size int) *Hasher { // the hash is unkeyed. Otherwise, len(key) must be 32. func New(size int, key []byte) *Hasher { if key == nil { - return newHasher(iv, 0, size) + return newHasher(guts.IV, 0, size) } var keyWords [8]uint32 for i := range keyWords { keyWords[i] = binary.LittleEndian.Uint32(key[i*4:]) } - return newHasher(keyWords, flagKeyedHash, size) + return newHasher(keyWords, guts.FlagKeyedHash, size) } // Sum256 and Sum512 always use the same hasher state, so we can save some time @@ -187,20 +146,25 @@ func Sum256(b []byte) (out [32]byte) { // Sum512 returns the unkeyed BLAKE3 hash of b, truncated to 512 bits. func Sum512(b []byte) (out [64]byte) { - var n node - if len(b) <= blockSize { - hashBlock(&out, b) - return - } else if len(b) <= chunkSize { - n = compressChunk(b, &iv, 0, 0) - n.flags |= flagRoot + var n guts.Node + if len(b) <= guts.BlockSize { + var block [64]byte + copy(block[:], b) + return guts.WordsToBytes(guts.CompressNode(guts.Node{ + CV: guts.IV, + Block: guts.BytesToWords(block), + BlockLen: uint32(len(b)), + Flags: guts.FlagChunkStart | guts.FlagChunkEnd | guts.FlagRoot, + })) + } else if len(b) <= guts.ChunkSize { + n = guts.CompressChunk(b, &guts.IV, 0, 0) + n.Flags |= guts.FlagRoot } else { h := *defaultHasher h.Write(b) n = h.rootNode() } - wordsToBytes(compressNode(n), &out) - return + return guts.WordsToBytes(guts.CompressNode(n)) } // DeriveKey derives a subkey from ctx and srcKey. 
ctx should be hardcoded, @@ -217,14 +181,14 @@ func Sum512(b []byte) (out [64]byte) { func DeriveKey(subKey []byte, ctx string, srcKey []byte) { // construct the derivation Hasher const derivationIVLen = 32 - h := newHasher(iv, flagDeriveKeyContext, 32) + h := newHasher(guts.IV, guts.FlagDeriveKeyContext, 32) h.Write([]byte(ctx)) derivationIV := h.Sum(make([]byte, 0, derivationIVLen)) var ivWords [8]uint32 for i := range ivWords { ivWords[i] = binary.LittleEndian.Uint32(derivationIV[i*4:]) } - h = newHasher(ivWords, flagDeriveKeyMaterial, 0) + h = newHasher(ivWords, guts.FlagDeriveKeyMaterial, 0) // derive the subKey h.Write(srcKey) h.XOF().Read(subKey) @@ -233,8 +197,8 @@ func DeriveKey(subKey []byte, ctx string, srcKey []byte) { // An OutputReader produces an seekable stream of 2^64 - 1 pseudorandom output // bytes. type OutputReader struct { - n node - buf [maxSIMD * blockSize]byte + n guts.Node + buf [guts.MaxSIMD * guts.BlockSize]byte off uint64 } @@ -248,11 +212,11 @@ func (or *OutputReader) Read(p []byte) (int, error) { } lenp := len(p) for len(p) > 0 { - if or.off%(maxSIMD*blockSize) == 0 { - or.n.counter = or.off / blockSize - compressBlocks(&or.buf, or.n) + if or.off%(guts.MaxSIMD*guts.BlockSize) == 0 { + or.n.Counter = or.off / guts.BlockSize + guts.CompressBlocks(&or.buf, or.n) } - n := copy(p, or.buf[or.off%(maxSIMD*blockSize):]) + n := copy(p, or.buf[or.off%(guts.MaxSIMD*guts.BlockSize):]) p = p[n:] or.off += uint64(n) } @@ -283,9 +247,9 @@ func (or *OutputReader) Seek(offset int64, whence int) (int64, error) { panic("invalid whence") } or.off = off - or.n.counter = uint64(off) / blockSize - if or.off%(maxSIMD*blockSize) != 0 { - compressBlocks(&or.buf, or.n) + or.n.Counter = uint64(off) / guts.BlockSize + if or.off%(guts.MaxSIMD*guts.BlockSize) != 0 { + guts.CompressBlocks(&or.buf, or.n) } // NOTE: or.off >= 2^63 will result in a negative return value. // Nothing we can do about this. @@ -294,3 +258,41 @@ func (or *OutputReader) Seek(offset int64, whence int) (int64, error) { // ensure that Hasher implements hash.Hash var _ hash.Hash = (*Hasher)(nil) + +// EncodedSize returns the size of a Bao encoding for the provided quantity +// of data. +// +// Deprecated: Use bao.EncodedSize instead. +func BaoEncodedSize(dataLen int, outboard bool) int { + return bao.EncodedSize(dataLen, 0, outboard) +} + +// BaoEncode computes the intermediate BLAKE3 tree hashes of data and writes +// them to dst. +// +// Deprecated: Use bao.Encode instead. +func BaoEncode(dst io.WriterAt, data io.Reader, dataLen int64, outboard bool) ([32]byte, error) { + return bao.Encode(dst, data, dataLen, 0, outboard) +} + +// BaoDecode reads content and tree data from the provided reader(s), and +// streams the verified content to dst. +// +// Deprecated: Use bao.Decode instead. +func BaoDecode(dst io.Writer, data, outboard io.Reader, root [32]byte) (bool, error) { + return bao.Decode(dst, data, outboard, 0, root) +} + +// BaoEncodeBuf returns the Bao encoding and root (i.e. BLAKE3 hash) for data. +// +// Deprecated: Use bao.EncodeBuf instead. +func BaoEncodeBuf(data []byte, outboard bool) ([]byte, [32]byte) { + return bao.EncodeBuf(data, 0, outboard) +} + +// BaoVerifyBuf verifies the Bao encoding and root (i.e. BLAKE3 hash) for data. +// +// Deprecated: Use bao.VerifyBuf instead. 
+func BaoVerifyBuf(data, outboard []byte, root [32]byte) bool { + return bao.VerifyBuf(data, outboard, 0, root) +} diff --git a/compress_amd64.go b/compress_amd64.go deleted file mode 100644 index 0114bda..0000000 --- a/compress_amd64.go +++ /dev/null @@ -1,147 +0,0 @@ -package blake3 - -import "unsafe" - -//go:generate go run avo/gen.go -out blake3_amd64.s - -//go:noescape -func compressChunksAVX512(cvs *[16][8]uint32, buf *[16 * chunkSize]byte, key *[8]uint32, counter uint64, flags uint32) - -//go:noescape -func compressChunksAVX2(cvs *[8][8]uint32, buf *[8 * chunkSize]byte, key *[8]uint32, counter uint64, flags uint32) - -//go:noescape -func compressBlocksAVX512(out *[1024]byte, block *[16]uint32, cv *[8]uint32, counter uint64, blockLen uint32, flags uint32) - -//go:noescape -func compressBlocksAVX2(out *[512]byte, msgs *[16]uint32, cv *[8]uint32, counter uint64, blockLen uint32, flags uint32) - -//go:noescape -func compressParentsAVX2(parents *[8][8]uint32, cvs *[16][8]uint32, key *[8]uint32, flags uint32) - -func compressNode(n node) (out [16]uint32) { - compressNodeGeneric(&out, n) - return -} - -func compressBufferAVX512(buf *[maxSIMD * chunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) node { - var cvs [maxSIMD][8]uint32 - compressChunksAVX512(&cvs, buf, key, counter, flags) - numChunks := uint64(buflen / chunkSize) - if buflen%chunkSize != 0 { - // use non-asm for remainder - partialChunk := buf[buflen-buflen%chunkSize : buflen] - cvs[numChunks] = chainingValue(compressChunk(partialChunk, key, counter+numChunks, flags)) - numChunks++ - } - return mergeSubtrees(&cvs, numChunks, key, flags) -} - -func compressBufferAVX2(buf *[maxSIMD * chunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) node { - var cvs [maxSIMD][8]uint32 - cvHalves := (*[2][8][8]uint32)(unsafe.Pointer(&cvs)) - bufHalves := (*[2][8 * chunkSize]byte)(unsafe.Pointer(buf)) - compressChunksAVX2(&cvHalves[0], &bufHalves[0], key, counter, flags) - numChunks := uint64(buflen / chunkSize) - if numChunks > 8 { - compressChunksAVX2(&cvHalves[1], &bufHalves[1], key, counter+8, flags) - } - if buflen%chunkSize != 0 { - // use non-asm for remainder - partialChunk := buf[buflen-buflen%chunkSize : buflen] - cvs[numChunks] = chainingValue(compressChunk(partialChunk, key, counter+numChunks, flags)) - numChunks++ - } - return mergeSubtrees(&cvs, numChunks, key, flags) -} - -func compressBuffer(buf *[maxSIMD * chunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) node { - if buflen <= chunkSize { - return compressChunk(buf[:buflen], key, counter, flags) - } - switch { - case haveAVX512 && buflen >= chunkSize*2: - return compressBufferAVX512(buf, buflen, key, counter, flags) - case haveAVX2 && buflen >= chunkSize*2: - return compressBufferAVX2(buf, buflen, key, counter, flags) - default: - return compressBufferGeneric(buf, buflen, key, counter, flags) - } -} - -func compressChunk(chunk []byte, key *[8]uint32, counter uint64, flags uint32) node { - n := node{ - cv: *key, - counter: counter, - blockLen: blockSize, - flags: flags | flagChunkStart, - } - blockBytes := (*[64]byte)(unsafe.Pointer(&n.block))[:] - for len(chunk) > blockSize { - copy(blockBytes, chunk) - chunk = chunk[blockSize:] - n.cv = chainingValue(n) - n.flags &^= flagChunkStart - } - // pad last block with zeros - n.block = [16]uint32{} - copy(blockBytes, chunk) - n.blockLen = uint32(len(chunk)) - n.flags |= flagChunkEnd - return n -} - -func hashBlock(out *[64]byte, buf []byte) { - var block [16]uint32 - 
copy((*[64]byte)(unsafe.Pointer(&block))[:], buf) - compressNodeGeneric((*[16]uint32)(unsafe.Pointer(out)), node{ - cv: iv, - block: block, - blockLen: uint32(len(buf)), - flags: flagChunkStart | flagChunkEnd | flagRoot, - }) -} - -func compressBlocks(out *[maxSIMD * blockSize]byte, n node) { - switch { - case haveAVX512: - compressBlocksAVX512(out, &n.block, &n.cv, n.counter, n.blockLen, n.flags) - case haveAVX2: - outs := (*[2][512]byte)(unsafe.Pointer(out)) - compressBlocksAVX2(&outs[0], &n.block, &n.cv, n.counter, n.blockLen, n.flags) - compressBlocksAVX2(&outs[1], &n.block, &n.cv, n.counter+8, n.blockLen, n.flags) - default: - outs := (*[maxSIMD][64]byte)(unsafe.Pointer(out)) - compressBlocksGeneric(outs, n) - } -} - -func mergeSubtrees(cvs *[maxSIMD][8]uint32, numCVs uint64, key *[8]uint32, flags uint32) node { - if !haveAVX2 { - return mergeSubtreesGeneric(cvs, numCVs, key, flags) - } - for numCVs > 2 { - if numCVs%2 == 0 { - compressParentsAVX2((*[8][8]uint32)(unsafe.Pointer(cvs)), cvs, key, flags) - } else { - keep := cvs[numCVs-1] - compressParentsAVX2((*[8][8]uint32)(unsafe.Pointer(cvs)), cvs, key, flags) - cvs[numCVs/2] = keep - numCVs++ - } - numCVs /= 2 - } - return parentNode(cvs[0], cvs[1], *key, flags) -} - -func wordsToBytes(words [16]uint32, block *[64]byte) { - *block = *(*[64]byte)(unsafe.Pointer(&words)) -} - -func bytesToCV(b []byte) [8]uint32 { - return *(*[8]uint32)(unsafe.Pointer(&b[0])) -} - -func cvToBytes(cv *[8]uint32) *[32]byte { - return (*[32]byte)(unsafe.Pointer(cv)) -} diff --git a/compress_generic.go b/compress_generic.go deleted file mode 100644 index b033b65..0000000 --- a/compress_generic.go +++ /dev/null @@ -1,143 +0,0 @@ -package blake3 - -import ( - "bytes" - "math/bits" -) - -func compressNodeGeneric(out *[16]uint32, n node) { - g := func(a, b, c, d, mx, my uint32) (uint32, uint32, uint32, uint32) { - a += b + mx - d = bits.RotateLeft32(d^a, -16) - c += d - b = bits.RotateLeft32(b^c, -12) - a += b + my - d = bits.RotateLeft32(d^a, -8) - c += d - b = bits.RotateLeft32(b^c, -7) - return a, b, c, d - } - - // NOTE: we unroll all of the rounds, as well as the permutations that occur - // between rounds. 
- - // round 1 (also initializes state) - // columns - s0, s4, s8, s12 := g(n.cv[0], n.cv[4], iv[0], uint32(n.counter), n.block[0], n.block[1]) - s1, s5, s9, s13 := g(n.cv[1], n.cv[5], iv[1], uint32(n.counter>>32), n.block[2], n.block[3]) - s2, s6, s10, s14 := g(n.cv[2], n.cv[6], iv[2], n.blockLen, n.block[4], n.block[5]) - s3, s7, s11, s15 := g(n.cv[3], n.cv[7], iv[3], n.flags, n.block[6], n.block[7]) - // diagonals - s0, s5, s10, s15 = g(s0, s5, s10, s15, n.block[8], n.block[9]) - s1, s6, s11, s12 = g(s1, s6, s11, s12, n.block[10], n.block[11]) - s2, s7, s8, s13 = g(s2, s7, s8, s13, n.block[12], n.block[13]) - s3, s4, s9, s14 = g(s3, s4, s9, s14, n.block[14], n.block[15]) - - // round 2 - s0, s4, s8, s12 = g(s0, s4, s8, s12, n.block[2], n.block[6]) - s1, s5, s9, s13 = g(s1, s5, s9, s13, n.block[3], n.block[10]) - s2, s6, s10, s14 = g(s2, s6, s10, s14, n.block[7], n.block[0]) - s3, s7, s11, s15 = g(s3, s7, s11, s15, n.block[4], n.block[13]) - s0, s5, s10, s15 = g(s0, s5, s10, s15, n.block[1], n.block[11]) - s1, s6, s11, s12 = g(s1, s6, s11, s12, n.block[12], n.block[5]) - s2, s7, s8, s13 = g(s2, s7, s8, s13, n.block[9], n.block[14]) - s3, s4, s9, s14 = g(s3, s4, s9, s14, n.block[15], n.block[8]) - - // round 3 - s0, s4, s8, s12 = g(s0, s4, s8, s12, n.block[3], n.block[4]) - s1, s5, s9, s13 = g(s1, s5, s9, s13, n.block[10], n.block[12]) - s2, s6, s10, s14 = g(s2, s6, s10, s14, n.block[13], n.block[2]) - s3, s7, s11, s15 = g(s3, s7, s11, s15, n.block[7], n.block[14]) - s0, s5, s10, s15 = g(s0, s5, s10, s15, n.block[6], n.block[5]) - s1, s6, s11, s12 = g(s1, s6, s11, s12, n.block[9], n.block[0]) - s2, s7, s8, s13 = g(s2, s7, s8, s13, n.block[11], n.block[15]) - s3, s4, s9, s14 = g(s3, s4, s9, s14, n.block[8], n.block[1]) - - // round 4 - s0, s4, s8, s12 = g(s0, s4, s8, s12, n.block[10], n.block[7]) - s1, s5, s9, s13 = g(s1, s5, s9, s13, n.block[12], n.block[9]) - s2, s6, s10, s14 = g(s2, s6, s10, s14, n.block[14], n.block[3]) - s3, s7, s11, s15 = g(s3, s7, s11, s15, n.block[13], n.block[15]) - s0, s5, s10, s15 = g(s0, s5, s10, s15, n.block[4], n.block[0]) - s1, s6, s11, s12 = g(s1, s6, s11, s12, n.block[11], n.block[2]) - s2, s7, s8, s13 = g(s2, s7, s8, s13, n.block[5], n.block[8]) - s3, s4, s9, s14 = g(s3, s4, s9, s14, n.block[1], n.block[6]) - - // round 5 - s0, s4, s8, s12 = g(s0, s4, s8, s12, n.block[12], n.block[13]) - s1, s5, s9, s13 = g(s1, s5, s9, s13, n.block[9], n.block[11]) - s2, s6, s10, s14 = g(s2, s6, s10, s14, n.block[15], n.block[10]) - s3, s7, s11, s15 = g(s3, s7, s11, s15, n.block[14], n.block[8]) - s0, s5, s10, s15 = g(s0, s5, s10, s15, n.block[7], n.block[2]) - s1, s6, s11, s12 = g(s1, s6, s11, s12, n.block[5], n.block[3]) - s2, s7, s8, s13 = g(s2, s7, s8, s13, n.block[0], n.block[1]) - s3, s4, s9, s14 = g(s3, s4, s9, s14, n.block[6], n.block[4]) - - // round 6 - s0, s4, s8, s12 = g(s0, s4, s8, s12, n.block[9], n.block[14]) - s1, s5, s9, s13 = g(s1, s5, s9, s13, n.block[11], n.block[5]) - s2, s6, s10, s14 = g(s2, s6, s10, s14, n.block[8], n.block[12]) - s3, s7, s11, s15 = g(s3, s7, s11, s15, n.block[15], n.block[1]) - s0, s5, s10, s15 = g(s0, s5, s10, s15, n.block[13], n.block[3]) - s1, s6, s11, s12 = g(s1, s6, s11, s12, n.block[0], n.block[10]) - s2, s7, s8, s13 = g(s2, s7, s8, s13, n.block[2], n.block[6]) - s3, s4, s9, s14 = g(s3, s4, s9, s14, n.block[4], n.block[7]) - - // round 7 - s0, s4, s8, s12 = g(s0, s4, s8, s12, n.block[11], n.block[15]) - s1, s5, s9, s13 = g(s1, s5, s9, s13, n.block[5], n.block[0]) - s2, s6, s10, s14 = g(s2, s6, s10, s14, n.block[1], 
n.block[9]) - s3, s7, s11, s15 = g(s3, s7, s11, s15, n.block[8], n.block[6]) - s0, s5, s10, s15 = g(s0, s5, s10, s15, n.block[14], n.block[10]) - s1, s6, s11, s12 = g(s1, s6, s11, s12, n.block[2], n.block[12]) - s2, s7, s8, s13 = g(s2, s7, s8, s13, n.block[3], n.block[4]) - s3, s4, s9, s14 = g(s3, s4, s9, s14, n.block[7], n.block[13]) - - // finalization - *out = [16]uint32{ - s0 ^ s8, s1 ^ s9, s2 ^ s10, s3 ^ s11, - s4 ^ s12, s5 ^ s13, s6 ^ s14, s7 ^ s15, - s8 ^ n.cv[0], s9 ^ n.cv[1], s10 ^ n.cv[2], s11 ^ n.cv[3], - s12 ^ n.cv[4], s13 ^ n.cv[5], s14 ^ n.cv[6], s15 ^ n.cv[7], - } -} - -func chainingValue(n node) (cv [8]uint32) { - full := compressNode(n) - copy(cv[:], full[:]) - return -} - -func compressBufferGeneric(buf *[maxSIMD * chunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) (n node) { - if buflen <= chunkSize { - return compressChunk(buf[:buflen], key, counter, flags) - } - var cvs [maxSIMD][8]uint32 - var numCVs uint64 - for bb := bytes.NewBuffer(buf[:buflen]); bb.Len() > 0; numCVs++ { - cvs[numCVs] = chainingValue(compressChunk(bb.Next(chunkSize), key, counter+numCVs, flags)) - } - return mergeSubtrees(&cvs, numCVs, key, flags) -} - -func compressBlocksGeneric(outs *[maxSIMD][64]byte, n node) { - for i := range outs { - wordsToBytes(compressNode(n), &outs[i]) - n.counter++ - } -} - -func mergeSubtreesGeneric(cvs *[maxSIMD][8]uint32, numCVs uint64, key *[8]uint32, flags uint32) node { - for numCVs > 2 { - rem := numCVs / 2 - for i := range cvs[:rem] { - cvs[i] = chainingValue(parentNode(cvs[i*2], cvs[i*2+1], *key, flags)) - } - if numCVs%2 != 0 { - cvs[rem] = cvs[rem*2] - rem++ - } - numCVs = rem - } - return parentNode(cvs[0], cvs[1], *key, flags) -} diff --git a/compress_noasm.go b/compress_noasm.go deleted file mode 100644 index c38819d..0000000 --- a/compress_noasm.go +++ /dev/null @@ -1,93 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -package blake3 - -import "encoding/binary" - -func compressNode(n node) (out [16]uint32) { - compressNodeGeneric(&out, n) - return -} - -func compressBuffer(buf *[maxSIMD * chunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) node { - return compressBufferGeneric(buf, buflen, key, counter, flags) -} - -func compressChunk(chunk []byte, key *[8]uint32, counter uint64, flags uint32) node { - n := node{ - cv: *key, - counter: counter, - blockLen: blockSize, - flags: flags | flagChunkStart, - } - var block [blockSize]byte - for len(chunk) > blockSize { - copy(block[:], chunk) - chunk = chunk[blockSize:] - bytesToWords(block, &n.block) - n.cv = chainingValue(n) - n.flags &^= flagChunkStart - } - // pad last block with zeros - block = [blockSize]byte{} - n.blockLen = uint32(len(chunk)) - copy(block[:], chunk) - bytesToWords(block, &n.block) - n.flags |= flagChunkEnd - return n -} - -func hashBlock(out *[64]byte, buf []byte) { - var block [64]byte - var words [16]uint32 - copy(block[:], buf) - bytesToWords(block, &words) - compressNodeGeneric(&words, node{ - cv: iv, - block: words, - blockLen: uint32(len(buf)), - flags: flagChunkStart | flagChunkEnd | flagRoot, - }) - wordsToBytes(words, out) -} - -func compressBlocks(out *[maxSIMD * blockSize]byte, n node) { - var outs [maxSIMD][64]byte - compressBlocksGeneric(&outs, n) - for i := range outs { - copy(out[i*64:], outs[i][:]) - } -} - -func mergeSubtrees(cvs *[maxSIMD][8]uint32, numCVs uint64, key *[8]uint32, flags uint32) node { - return mergeSubtreesGeneric(cvs, numCVs, key, flags) -} - -func bytesToWords(bytes [64]byte, words *[16]uint32) { - for i := 
range words { - words[i] = binary.LittleEndian.Uint32(bytes[4*i:]) - } -} - -func wordsToBytes(words [16]uint32, block *[64]byte) { - for i, w := range words { - binary.LittleEndian.PutUint32(block[4*i:], w) - } -} - -func bytesToCV(b []byte) [8]uint32 { - var cv [8]uint32 - for i := range cv { - cv[i] = binary.LittleEndian.Uint32(b[4*i:]) - } - return cv -} - -func cvToBytes(cv *[8]uint32) *[32]byte { - var b [32]byte - for i, w := range cv { - binary.LittleEndian.PutUint32(b[4*i:], w) - } - return &b -} diff --git a/guts/compress_amd64.go b/guts/compress_amd64.go new file mode 100644 index 0000000..57b10be --- /dev/null +++ b/guts/compress_amd64.go @@ -0,0 +1,135 @@ +package guts + +import "unsafe" + +//go:generate go run avo/gen.go -out blake3_amd64.s + +//go:noescape +func compressChunksAVX512(cvs *[16][8]uint32, buf *[16 * ChunkSize]byte, key *[8]uint32, counter uint64, flags uint32) + +//go:noescape +func compressChunksAVX2(cvs *[8][8]uint32, buf *[8 * ChunkSize]byte, key *[8]uint32, counter uint64, flags uint32) + +//go:noescape +func compressBlocksAVX512(out *[1024]byte, block *[16]uint32, cv *[8]uint32, counter uint64, blockLen uint32, flags uint32) + +//go:noescape +func compressBlocksAVX2(out *[512]byte, msgs *[16]uint32, cv *[8]uint32, counter uint64, blockLen uint32, flags uint32) + +//go:noescape +func compressParentsAVX2(parents *[8][8]uint32, cvs *[16][8]uint32, key *[8]uint32, flags uint32) + +func compressBufferAVX512(buf *[MaxSIMD * ChunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) Node { + var cvs [MaxSIMD][8]uint32 + compressChunksAVX512(&cvs, buf, key, counter, flags) + numChunks := uint64(buflen / ChunkSize) + if buflen%ChunkSize != 0 { + // use non-asm for remainder + partialChunk := buf[buflen-buflen%ChunkSize : buflen] + cvs[numChunks] = ChainingValue(CompressChunk(partialChunk, key, counter+numChunks, flags)) + numChunks++ + } + return mergeSubtrees(&cvs, numChunks, key, flags) +} + +func compressBufferAVX2(buf *[MaxSIMD * ChunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) Node { + var cvs [MaxSIMD][8]uint32 + cvHalves := (*[2][8][8]uint32)(unsafe.Pointer(&cvs)) + bufHalves := (*[2][8 * ChunkSize]byte)(unsafe.Pointer(buf)) + compressChunksAVX2(&cvHalves[0], &bufHalves[0], key, counter, flags) + numChunks := uint64(buflen / ChunkSize) + if numChunks > 8 { + compressChunksAVX2(&cvHalves[1], &bufHalves[1], key, counter+8, flags) + } + if buflen%ChunkSize != 0 { + // use non-asm for remainder + partialChunk := buf[buflen-buflen%ChunkSize : buflen] + cvs[numChunks] = ChainingValue(CompressChunk(partialChunk, key, counter+numChunks, flags)) + numChunks++ + } + return mergeSubtrees(&cvs, numChunks, key, flags) +} + +// CompressBuffer compresses up to MaxSIMD chunks in parallel and returns their +// root node. +func CompressBuffer(buf *[MaxSIMD * ChunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) Node { + if buflen <= ChunkSize { + return CompressChunk(buf[:buflen], key, counter, flags) + } + switch { + case haveAVX512 && buflen >= ChunkSize*2: + return compressBufferAVX512(buf, buflen, key, counter, flags) + case haveAVX2 && buflen >= ChunkSize*2: + return compressBufferAVX2(buf, buflen, key, counter, flags) + default: + return compressBufferGeneric(buf, buflen, key, counter, flags) + } +} + +// CompressChunk compresses a single chunk, returning its final (uncompressed) +// node. 
+func CompressChunk(chunk []byte, key *[8]uint32, counter uint64, flags uint32) Node { + n := Node{ + CV: *key, + Counter: counter, + BlockLen: BlockSize, + Flags: flags | FlagChunkStart, + } + blockBytes := (*[64]byte)(unsafe.Pointer(&n.Block))[:] + for len(chunk) > BlockSize { + copy(blockBytes, chunk) + chunk = chunk[BlockSize:] + n.CV = ChainingValue(n) + n.Flags &^= FlagChunkStart + } + // pad last block with zeros + n.Block = [16]uint32{} + copy(blockBytes, chunk) + n.BlockLen = uint32(len(chunk)) + n.Flags |= FlagChunkEnd + return n +} + +// CompressBlocks compresses MaxSIMD copies of n with successive counter values, +// storing the results in out. +func CompressBlocks(out *[MaxSIMD * BlockSize]byte, n Node) { + switch { + case haveAVX512: + compressBlocksAVX512(out, &n.Block, &n.CV, n.Counter, n.BlockLen, n.Flags) + case haveAVX2: + outs := (*[2][512]byte)(unsafe.Pointer(out)) + compressBlocksAVX2(&outs[0], &n.Block, &n.CV, n.Counter, n.BlockLen, n.Flags) + compressBlocksAVX2(&outs[1], &n.Block, &n.CV, n.Counter+8, n.BlockLen, n.Flags) + default: + outs := (*[MaxSIMD][64]byte)(unsafe.Pointer(out)) + compressBlocksGeneric(outs, n) + } +} + +func mergeSubtrees(cvs *[MaxSIMD][8]uint32, numCVs uint64, key *[8]uint32, flags uint32) Node { + if !haveAVX2 { + return mergeSubtreesGeneric(cvs, numCVs, key, flags) + } + for numCVs > 2 { + if numCVs%2 == 0 { + compressParentsAVX2((*[8][8]uint32)(unsafe.Pointer(cvs)), cvs, key, flags) + } else { + keep := cvs[numCVs-1] + compressParentsAVX2((*[8][8]uint32)(unsafe.Pointer(cvs)), cvs, key, flags) + cvs[numCVs/2] = keep + numCVs++ + } + numCVs /= 2 + } + return ParentNode(cvs[0], cvs[1], key, flags) +} + +// BytesToWords converts an array of 64 bytes to an array of 16 bytes. +func BytesToWords(bytes [64]byte) [16]uint32 { + return *(*[16]uint32)(unsafe.Pointer(&bytes)) +} + +// WordsToBytes converts an array of 16 words to an array of 64 bytes. +func WordsToBytes(words [16]uint32) [64]byte { + return *(*[64]byte)(unsafe.Pointer(&words)) +} diff --git a/blake3_amd64.s b/guts/compress_amd64.s similarity index 99% rename from blake3_amd64.s rename to guts/compress_amd64.s index df6bd97..66c6bb1 100644 --- a/blake3_amd64.s +++ b/guts/compress_amd64.s @@ -1,4 +1,4 @@ -// Code generated by command: go run gen.go -out blake3_amd64.s. DO NOT EDIT. +// Code generated by command: go run gen.go -out compress_amd64.s. DO NOT EDIT. #include "textflag.h" diff --git a/guts/compress_generic.go b/guts/compress_generic.go new file mode 100644 index 0000000..6572836 --- /dev/null +++ b/guts/compress_generic.go @@ -0,0 +1,145 @@ +package guts + +import ( + "bytes" + "math/bits" +) + +// CompressNode compresses a node into a 16-word output. +func CompressNode(n Node) (out [16]uint32) { + g := func(a, b, c, d, mx, my uint32) (uint32, uint32, uint32, uint32) { + a += b + mx + d = bits.RotateLeft32(d^a, -16) + c += d + b = bits.RotateLeft32(b^c, -12) + a += b + my + d = bits.RotateLeft32(d^a, -8) + c += d + b = bits.RotateLeft32(b^c, -7) + return a, b, c, d + } + + // NOTE: we unroll all of the rounds, as well as the permutations that occur + // between rounds. 
+ + // round 1 (also initializes state) + // columns + s0, s4, s8, s12 := g(n.CV[0], n.CV[4], IV[0], uint32(n.Counter), n.Block[0], n.Block[1]) + s1, s5, s9, s13 := g(n.CV[1], n.CV[5], IV[1], uint32(n.Counter>>32), n.Block[2], n.Block[3]) + s2, s6, s10, s14 := g(n.CV[2], n.CV[6], IV[2], n.BlockLen, n.Block[4], n.Block[5]) + s3, s7, s11, s15 := g(n.CV[3], n.CV[7], IV[3], n.Flags, n.Block[6], n.Block[7]) + // diagonals + s0, s5, s10, s15 = g(s0, s5, s10, s15, n.Block[8], n.Block[9]) + s1, s6, s11, s12 = g(s1, s6, s11, s12, n.Block[10], n.Block[11]) + s2, s7, s8, s13 = g(s2, s7, s8, s13, n.Block[12], n.Block[13]) + s3, s4, s9, s14 = g(s3, s4, s9, s14, n.Block[14], n.Block[15]) + + // round 2 + s0, s4, s8, s12 = g(s0, s4, s8, s12, n.Block[2], n.Block[6]) + s1, s5, s9, s13 = g(s1, s5, s9, s13, n.Block[3], n.Block[10]) + s2, s6, s10, s14 = g(s2, s6, s10, s14, n.Block[7], n.Block[0]) + s3, s7, s11, s15 = g(s3, s7, s11, s15, n.Block[4], n.Block[13]) + s0, s5, s10, s15 = g(s0, s5, s10, s15, n.Block[1], n.Block[11]) + s1, s6, s11, s12 = g(s1, s6, s11, s12, n.Block[12], n.Block[5]) + s2, s7, s8, s13 = g(s2, s7, s8, s13, n.Block[9], n.Block[14]) + s3, s4, s9, s14 = g(s3, s4, s9, s14, n.Block[15], n.Block[8]) + + // round 3 + s0, s4, s8, s12 = g(s0, s4, s8, s12, n.Block[3], n.Block[4]) + s1, s5, s9, s13 = g(s1, s5, s9, s13, n.Block[10], n.Block[12]) + s2, s6, s10, s14 = g(s2, s6, s10, s14, n.Block[13], n.Block[2]) + s3, s7, s11, s15 = g(s3, s7, s11, s15, n.Block[7], n.Block[14]) + s0, s5, s10, s15 = g(s0, s5, s10, s15, n.Block[6], n.Block[5]) + s1, s6, s11, s12 = g(s1, s6, s11, s12, n.Block[9], n.Block[0]) + s2, s7, s8, s13 = g(s2, s7, s8, s13, n.Block[11], n.Block[15]) + s3, s4, s9, s14 = g(s3, s4, s9, s14, n.Block[8], n.Block[1]) + + // round 4 + s0, s4, s8, s12 = g(s0, s4, s8, s12, n.Block[10], n.Block[7]) + s1, s5, s9, s13 = g(s1, s5, s9, s13, n.Block[12], n.Block[9]) + s2, s6, s10, s14 = g(s2, s6, s10, s14, n.Block[14], n.Block[3]) + s3, s7, s11, s15 = g(s3, s7, s11, s15, n.Block[13], n.Block[15]) + s0, s5, s10, s15 = g(s0, s5, s10, s15, n.Block[4], n.Block[0]) + s1, s6, s11, s12 = g(s1, s6, s11, s12, n.Block[11], n.Block[2]) + s2, s7, s8, s13 = g(s2, s7, s8, s13, n.Block[5], n.Block[8]) + s3, s4, s9, s14 = g(s3, s4, s9, s14, n.Block[1], n.Block[6]) + + // round 5 + s0, s4, s8, s12 = g(s0, s4, s8, s12, n.Block[12], n.Block[13]) + s1, s5, s9, s13 = g(s1, s5, s9, s13, n.Block[9], n.Block[11]) + s2, s6, s10, s14 = g(s2, s6, s10, s14, n.Block[15], n.Block[10]) + s3, s7, s11, s15 = g(s3, s7, s11, s15, n.Block[14], n.Block[8]) + s0, s5, s10, s15 = g(s0, s5, s10, s15, n.Block[7], n.Block[2]) + s1, s6, s11, s12 = g(s1, s6, s11, s12, n.Block[5], n.Block[3]) + s2, s7, s8, s13 = g(s2, s7, s8, s13, n.Block[0], n.Block[1]) + s3, s4, s9, s14 = g(s3, s4, s9, s14, n.Block[6], n.Block[4]) + + // round 6 + s0, s4, s8, s12 = g(s0, s4, s8, s12, n.Block[9], n.Block[14]) + s1, s5, s9, s13 = g(s1, s5, s9, s13, n.Block[11], n.Block[5]) + s2, s6, s10, s14 = g(s2, s6, s10, s14, n.Block[8], n.Block[12]) + s3, s7, s11, s15 = g(s3, s7, s11, s15, n.Block[15], n.Block[1]) + s0, s5, s10, s15 = g(s0, s5, s10, s15, n.Block[13], n.Block[3]) + s1, s6, s11, s12 = g(s1, s6, s11, s12, n.Block[0], n.Block[10]) + s2, s7, s8, s13 = g(s2, s7, s8, s13, n.Block[2], n.Block[6]) + s3, s4, s9, s14 = g(s3, s4, s9, s14, n.Block[4], n.Block[7]) + + // round 7 + s0, s4, s8, s12 = g(s0, s4, s8, s12, n.Block[11], n.Block[15]) + s1, s5, s9, s13 = g(s1, s5, s9, s13, n.Block[5], n.Block[0]) + s2, s6, s10, s14 = g(s2, s6, s10, s14, n.Block[1], 
n.Block[9]) + s3, s7, s11, s15 = g(s3, s7, s11, s15, n.Block[8], n.Block[6]) + s0, s5, s10, s15 = g(s0, s5, s10, s15, n.Block[14], n.Block[10]) + s1, s6, s11, s12 = g(s1, s6, s11, s12, n.Block[2], n.Block[12]) + s2, s7, s8, s13 = g(s2, s7, s8, s13, n.Block[3], n.Block[4]) + s3, s4, s9, s14 = g(s3, s4, s9, s14, n.Block[7], n.Block[13]) + + // finalization + return [16]uint32{ + s0 ^ s8, s1 ^ s9, s2 ^ s10, s3 ^ s11, + s4 ^ s12, s5 ^ s13, s6 ^ s14, s7 ^ s15, + s8 ^ n.CV[0], s9 ^ n.CV[1], s10 ^ n.CV[2], s11 ^ n.CV[3], + s12 ^ n.CV[4], s13 ^ n.CV[5], s14 ^ n.CV[6], s15 ^ n.CV[7], + } +} + +// ChainingValue compresses n and returns the first 8 output words. +func ChainingValue(n Node) (cv [8]uint32) { + full := CompressNode(n) + copy(cv[:], full[:]) + return +} + +func compressBufferGeneric(buf *[MaxSIMD * ChunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) (n Node) { + if buflen <= ChunkSize { + return CompressChunk(buf[:buflen], key, counter, flags) + } + var cvs [MaxSIMD][8]uint32 + var numCVs uint64 + for bb := bytes.NewBuffer(buf[:buflen]); bb.Len() > 0; numCVs++ { + cvs[numCVs] = ChainingValue(CompressChunk(bb.Next(ChunkSize), key, counter+numCVs, flags)) + } + return mergeSubtrees(&cvs, numCVs, key, flags) +} + +func compressBlocksGeneric(outs *[MaxSIMD][64]byte, n Node) { + for i := range outs { + outs[i] = WordsToBytes(CompressNode(n)) + n.Counter++ + } +} + +func mergeSubtreesGeneric(cvs *[MaxSIMD][8]uint32, numCVs uint64, key *[8]uint32, flags uint32) Node { + for numCVs > 2 { + rem := numCVs / 2 + for i := range cvs[:rem] { + cvs[i] = ChainingValue(ParentNode(cvs[i*2], cvs[i*2+1], key, flags)) + } + if numCVs%2 != 0 { + cvs[rem] = cvs[rem*2] + rem++ + } + numCVs = rem + } + return ParentNode(cvs[0], cvs[1], key, flags) +} diff --git a/guts/compress_noasm.go b/guts/compress_noasm.go new file mode 100644 index 0000000..76517cb --- /dev/null +++ b/guts/compress_noasm.go @@ -0,0 +1,67 @@ +//go:build !amd64 +// +build !amd64 + +package guts + +import "encoding/binary" + +// CompressBuffer compresses up to MaxSIMD chunks in parallel and returns their +// root node. +func CompressBuffer(buf *[MaxSIMD * ChunkSize]byte, buflen int, key *[8]uint32, counter uint64, flags uint32) Node { + return compressBufferGeneric(buf, buflen, key, counter, flags) +} + +// CompressChunk compresses a single chunk, returning its final (uncompressed) +// node. +func CompressChunk(chunk []byte, key *[8]uint32, counter uint64, flags uint32) Node { + n := Node{ + CV: *key, + Counter: counter, + BlockLen: BlockSize, + Flags: flags | FlagChunkStart, + } + var block [BlockSize]byte + for len(chunk) > BlockSize { + copy(block[:], chunk) + chunk = chunk[BlockSize:] + n.Block = BytesToWords(block) + n.CV = ChainingValue(n) + n.Flags &^= FlagChunkStart + } + // pad last block with zeros + block = [BlockSize]byte{} + n.BlockLen = uint32(copy(block[:], chunk)) + n.Block = BytesToWords(block) + n.Flags |= FlagChunkEnd + return n +} + +// CompressBlocks compresses MaxSIMD copies of n with successive counter values, +// storing the results in out. +func CompressBlocks(out *[MaxSIMD * BlockSize]byte, n Node) { + var outs [MaxSIMD][64]byte + compressBlocksGeneric(&outs, n) + for i := range outs { + copy(out[i*64:], outs[i][:]) + } +} + +func mergeSubtrees(cvs *[MaxSIMD][8]uint32, numCVs uint64, key *[8]uint32, flags uint32) Node { + return mergeSubtreesGeneric(cvs, numCVs, key, flags) +} + +// BytesToWords converts an array of 64 bytes to an array of 16 bytes. 
+func BytesToWords(bytes [64]byte) (words [16]uint32) { + for i := range words { + words[i] = binary.LittleEndian.Uint32(bytes[4*i:]) + } + return +} + +// WordsToBytes converts an array of 16 words to an array of 64 bytes. +func WordsToBytes(words [16]uint32) (block [64]byte) { + for i, w := range words { + binary.LittleEndian.PutUint32(block[4*i:], w) + } + return +} diff --git a/cpu.go b/guts/cpu.go similarity index 83% rename from cpu.go rename to guts/cpu.go index c2a61e7..34e1038 100644 --- a/cpu.go +++ b/guts/cpu.go @@ -1,6 +1,7 @@ +//go:build !darwin // +build !darwin -package blake3 +package guts import "github.com/klauspost/cpuid/v2" diff --git a/cpu_darwin.go b/guts/cpu_darwin.go similarity index 96% rename from cpu_darwin.go rename to guts/cpu_darwin.go index 372b734..b1b35c7 100644 --- a/cpu_darwin.go +++ b/guts/cpu_darwin.go @@ -1,4 +1,4 @@ -package blake3 +package guts import ( "syscall" diff --git a/guts/node.go b/guts/node.go new file mode 100644 index 0000000..b14ab88 --- /dev/null +++ b/guts/node.go @@ -0,0 +1,48 @@ +// Package guts provides a low-level interface to the BLAKE3 cryptographic hash +// function. +package guts + +// Various constants. +const ( + FlagChunkStart = 1 << iota + FlagChunkEnd + FlagParent + FlagRoot + FlagKeyedHash + FlagDeriveKeyContext + FlagDeriveKeyMaterial + + BlockSize = 64 + ChunkSize = 1024 + + MaxSIMD = 16 // AVX-512 vectors can store 16 words +) + +// IV is the BLAKE3 initialization vector. +var IV = [8]uint32{ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, + 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +} + +// A Node represents a chunk or parent in the BLAKE3 Merkle tree. +type Node struct { + CV [8]uint32 // chaining value from previous node + Block [16]uint32 + Counter uint64 + BlockLen uint32 + Flags uint32 +} + +// ParentNode returns a Node that incorporates the chaining values of two child +// nodes. +func ParentNode(left, right [8]uint32, key *[8]uint32, flags uint32) Node { + n := Node{ + CV: *key, + Counter: 0, // counter is reset for parents + BlockLen: BlockSize, // block is full + Flags: flags | FlagParent, + } + copy(n.Block[:8], left[:]) + copy(n.Block[8:], right[:]) + return n +}