Skip to content

Commit

Permalink
sha3: make APIs usable with zero allocations
Browse files Browse the repository at this point in the history
The "buf points into storage" pattern is nice, but causes the whole
state struct to escape, since escape analysis can't track the pointer
once it's assigned to buf.

Change-Id: I31c0e83f946d66bedb5a180e96ab5d5e936eb322
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/544817
Reviewed-by: Cherry Mui <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Roland Shoemaker <[email protected]>
Reviewed-by: Mauri de Souza Meneguzzo <[email protected]>
Auto-Submit: Filippo Valsorda <[email protected]>
  • Loading branch information
FiloSottile authored and gopherbot committed May 7, 2024
1 parent 59b5a86 commit 477a5b4
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 36 deletions.
53 changes: 53 additions & 0 deletions sha3/allocations_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !noopt

package sha3_test

import (
"testing"

"golang.org/x/crypto/sha3"
)

var sink byte

// TestAllocations checks that typical uses of the sha3 package do not
// allocate. Each scenario is run through testing.AllocsPerRun, and the
// subtest fails if the reported allocation count is greater than zero.
//
// Every scenario folds the first output byte into the package-level sink so
// that the result is actually used.
func TestAllocations(t *testing.T) {
	scenarios := []struct {
		name string
		body func()
	}{
		{
			// Fixed-output hash: Write followed by Sum into a
			// pre-sized destination slice.
			name: "New",
			body: func() {
				hasher := sha3.New256()
				msg := []byte("ABC")
				hasher.Write(msg)
				digest := make([]byte, 0, 32)
				digest = hasher.Sum(digest)
				sink ^= digest[0]
			},
		},
		{
			// SHAKE: Write, Sum into a pre-sized slice, then Read
			// further output into that same slice.
			name: "NewShake",
			body: func() {
				xof := sha3.NewShake128()
				msg := []byte("ABC")
				xof.Write(msg)
				digest := make([]byte, 0, 32)
				digest = xof.Sum(digest)
				sink ^= digest[0]
				xof.Read(digest)
				sink ^= digest[0]
			},
		},
		{
			// One-shot helper.
			name: "Sum",
			body: func() {
				msg := []byte("ABC")
				digest := sha3.Sum256(msg)
				sink ^= digest[0]
			},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			allocs := testing.AllocsPerRun(10, sc.body)
			if allocs > 0 {
				t.Errorf("expected zero allocations, got %0.1f", allocs)
			}
		})
	}
}
60 changes: 24 additions & 36 deletions sha3/sha3.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ const (
type state struct {
// Generic sponge components.
a [25]uint64 // main state of the hash
buf []byte // points into storage
rate int // the number of bytes of state to use

// dsbyte contains the "domain separation" bits and the first bit of
Expand All @@ -40,6 +39,7 @@ type state struct {
// Extendable-Output Functions (May 2014)"
dsbyte byte

i, n int // storage[i:n] is the buffer, i is only used while squeezing
storage [maxRate]byte

// Specific to SHA-3 and SHAKE.
Expand All @@ -54,24 +54,18 @@ func (d *state) BlockSize() int { return d.rate }
func (d *state) Size() int { return d.outputLen }

// Reset clears the internal state by zeroing the sponge state and
// the byte buffer, and setting Sponge.state to absorbing.
// the buffer indexes, and setting Sponge.state to absorbing.
func (d *state) Reset() {
// Zero the permutation's state.
for i := range d.a {
d.a[i] = 0
}
d.state = spongeAbsorbing
d.buf = d.storage[:0]
d.i, d.n = 0, 0
}

func (d *state) clone() *state {
ret := *d
if ret.state == spongeAbsorbing {
ret.buf = ret.storage[:len(ret.buf)]
} else {
ret.buf = ret.storage[d.rate-cap(d.buf) : d.rate]
}

return &ret
}

Expand All @@ -82,43 +76,40 @@ func (d *state) permute() {
case spongeAbsorbing:
// If we're absorbing, we need to xor the input into the state
// before applying the permutation.
xorIn(d, d.buf)
d.buf = d.storage[:0]
xorIn(d, d.storage[:d.rate])
d.n = 0
keccakF1600(&d.a)
case spongeSqueezing:
// If we're squeezing, we need to apply the permutation before
// copying more output.
keccakF1600(&d.a)
d.buf = d.storage[:d.rate]
copyOut(d, d.buf)
d.i = 0
copyOut(d, d.storage[:d.rate])
}
}

// padAndPermute appends the domain separation bits in dsbyte, applies
// the multi-bitrate 10..1 padding rule, and permutes the state.
func (d *state) padAndPermute(dsbyte byte) {
if d.buf == nil {
d.buf = d.storage[:0]
}
func (d *state) padAndPermute() {
// Pad with this instance's domain-separator bits. We know that there's
// at least one byte of space in the buffer because, if it were full,
// permute would have been called to empty it. dsbyte also contains the
// first one bit for the padding. See the comment in the state struct.
d.buf = append(d.buf, dsbyte)
zerosStart := len(d.buf)
d.buf = d.storage[:d.rate]
for i := zerosStart; i < d.rate; i++ {
d.buf[i] = 0
d.storage[d.n] = d.dsbyte
d.n++
for d.n < d.rate {
d.storage[d.n] = 0
d.n++
}
// This adds the final one bit for the padding. Because of the way that
// bits are numbered from the LSB upwards, the final bit is the MSB of
// the last byte.
d.buf[d.rate-1] ^= 0x80
d.storage[d.rate-1] ^= 0x80
// Apply the permutation
d.permute()
d.state = spongeSqueezing
d.buf = d.storage[:d.rate]
copyOut(d, d.buf)
d.n = d.rate
copyOut(d, d.storage[:d.rate])
}

// Write absorbs more data into the hash's state. It panics if any
Expand All @@ -127,28 +118,25 @@ func (d *state) Write(p []byte) (written int, err error) {
if d.state != spongeAbsorbing {
panic("sha3: Write after Read")
}
if d.buf == nil {
d.buf = d.storage[:0]
}
written = len(p)

for len(p) > 0 {
if len(d.buf) == 0 && len(p) >= d.rate {
if d.n == 0 && len(p) >= d.rate {
// The fast path; absorb a full "rate" bytes of input and apply the permutation.
xorIn(d, p[:d.rate])
p = p[d.rate:]
keccakF1600(&d.a)
} else {
// The slow path; buffer the input until we can fill the sponge, and then xor it in.
todo := d.rate - len(d.buf)
todo := d.rate - d.n
if todo > len(p) {
todo = len(p)
}
d.buf = append(d.buf, p[:todo]...)
d.n += copy(d.storage[d.n:], p[:todo])
p = p[todo:]

// If the sponge is full, apply the permutation.
if len(d.buf) == d.rate {
if d.n == d.rate {
d.permute()
}
}
Expand All @@ -161,19 +149,19 @@ func (d *state) Write(p []byte) (written int, err error) {
func (d *state) Read(out []byte) (n int, err error) {
// If we're still absorbing, pad and apply the permutation.
if d.state == spongeAbsorbing {
d.padAndPermute(d.dsbyte)
d.padAndPermute()
}

n = len(out)

// Now, do the squeezing.
for len(out) > 0 {
n := copy(out, d.buf)
d.buf = d.buf[n:]
n := copy(out, d.storage[d.i:d.n])
d.i += n
out = out[n:]

// Apply the permutation if we've squeezed the sponge dry.
if len(d.buf) == 0 {
if d.i == d.rate {
d.permute()
}
}
Expand Down

0 comments on commit 477a5b4

Please sign in to comment.