Skip to content

Commit

Permalink
shake: add validity check
Browse files Browse the repository at this point in the history
  • Loading branch information
MingLLuo committed Mar 11, 2024
1 parent 7067223 commit 40ba86e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 16 deletions.
71 changes: 61 additions & 10 deletions sha3/sha3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var testDigests = map[string]func() hash.Hash{
// testShakes contains functions that return sha3.ShakeHash instances for
// with output-length equal to the KAT length.
var testShakes = map[string]struct {
constructor func(N []byte, S []byte) ShakeHash
constructor func(N []byte, S []byte) (ShakeHash, error)
defAlgoName string
defCustomStr string
}{
Expand Down Expand Up @@ -136,7 +136,10 @@ func TestKeccakKats(t *testing.T) {
if err != nil {
t.Errorf("error decoding KAT: %s", err)
}
d := v.constructor(N, S)
d, err := v.constructor(N, S)
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}
in, err := hex.DecodeString(kat.Message)
if err != nil {
t.Errorf("error decoding KAT: %s", err)
Expand Down Expand Up @@ -221,7 +224,10 @@ func TestUnalignedWrite(t *testing.T) {
for alg, df := range testShakes {
want := make([]byte, 16)
got := make([]byte, 16)
d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr))
d, err := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr))
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}

d.Reset()
d.Write(buf)
Expand Down Expand Up @@ -286,12 +292,19 @@ func TestAppendNoRealloc(t *testing.T) {
func TestSqueezing(t *testing.T) {
testUnalignedAndGeneric(t, func(impl string) {
for algo, v := range testShakes {
d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
d0, err := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}

d0.Write([]byte(testString))
ref := make([]byte, 32)
d0.Read(ref)

d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
d1, err := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}
d1.Write([]byte(testString))
var multiple []byte
for range ref {
Expand Down Expand Up @@ -327,7 +340,10 @@ func TestReset(t *testing.T) {

for _, v := range testShakes {
// Calculate hash for the first time
c := v.constructor(nil, []byte{0x99, 0x98})
c, err := v.constructor(nil, []byte{0x99, 0x98})
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}
c.Write(sequentialBytes(0x100))
c.Read(out1)

Expand All @@ -350,7 +366,10 @@ func TestClone(t *testing.T) {
for _, size := range []int{0x1, 0x100} {
in := sequentialBytes(size)
for _, v := range testShakes {
h1 := v.constructor(nil, []byte{0x01})
h1, err := v.constructor(nil, []byte{0x01})
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}
h1.Write([]byte{0x01})

h2 := h1.Clone()
Expand All @@ -368,6 +387,26 @@ func TestClone(t *testing.T) {
}
}

// TestValidity tests the length validity checks for cSHAKE.
func TestValidity(t *testing.T) {
inValidBytes := make([]byte, 256)

for _, v := range testShakes {
_, err := v.constructor(nil, inValidBytes)
if err == nil {
t.Error("expected error for S length")
}
_, err = v.constructor(inValidBytes, nil)
if err == nil {
t.Error("expected error for N length")
}
_, err = v.constructor(inValidBytes, inValidBytes)
if err == nil {
t.Error("expected error for N and S length")
}
}
}

// BenchmarkPermutationFunction measures the speed of the permutation function
// with no input data.
func BenchmarkPermutationFunction(b *testing.B) {
Expand Down Expand Up @@ -460,20 +499,32 @@ func ExampleNewCShake256() {
msg := []byte("The quick brown fox jumps over the lazy dog")

// Example 1: Simple cshake
c1 := NewCShake256([]byte("NAME"), []byte("Partition1"))
c1, err := NewCShake256([]byte("NAME"), []byte("Partition1"))
if err != nil {
fmt.Println(err)
return
}
c1.Write(msg)
c1.Read(out)
fmt.Println(hex.EncodeToString(out))

// Example 2: Different customization string produces different digest
c1 = NewCShake256([]byte("NAME"), []byte("Partition2"))
c1, err = NewCShake256([]byte("NAME"), []byte("Partition2"))
if err != nil {
fmt.Println(err)
return
}
c1.Write(msg)
c1.Read(out)
fmt.Println(hex.EncodeToString(out))

// Example 3: Longer output length produces longer digest
out = make([]byte, 64)
c1 = NewCShake256([]byte("NAME"), []byte("Partition1"))
c1, err = NewCShake256([]byte("NAME"), []byte("Partition1"))
if err != nil {
fmt.Println(err)
return
}
c1.Write(msg)
c1.Read(out)
fmt.Println(hex.EncodeToString(out))
Expand Down
19 changes: 13 additions & 6 deletions sha3/shake.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package sha3

import (
"encoding/binary"
"errors"
"hash"
"io"
)
Expand Down Expand Up @@ -80,7 +81,11 @@ func leftEncode(value uint64) []byte {
return b[i-1:]
}

func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) (ShakeHash, error) {
if len(N) >= 256 || len(S) >= 256 {
return nil, errors.New("crypto/cSHAKE: N and S can be at most 255 bytes long")
}

c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}}

// leftEncode returns max 9 bytes
Expand All @@ -90,7 +95,7 @@ func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
c.initBlock = append(c.initBlock, leftEncode(uint64(len(S)*8))...)
c.initBlock = append(c.initBlock, S...)
c.Write(bytepad(c.initBlock, c.rate))
return &c
return &c, nil
}

// Reset resets the hash to initial state.
Expand Down Expand Up @@ -137,9 +142,10 @@ func NewShake256() ShakeHash {
// desired. S is a customization byte string used for domain separation - two cSHAKE
// computations on same input with different S yield unrelated outputs.
// When N and S are both empty, this is equivalent to NewShake128.
func NewCShake128(N, S []byte) ShakeHash {
// N and S can be at most 255 bytes long.
func NewCShake128(N, S []byte) (ShakeHash, error) {
if len(N) == 0 && len(S) == 0 {
return NewShake128()
return NewShake128(), nil
}
return newCShake(N, S, rate128, 32, dsbyteCShake)
}
Expand All @@ -150,9 +156,10 @@ func NewCShake128(N, S []byte) ShakeHash {
// desired. S is a customization byte string used for domain separation - two cSHAKE
// computations on same input with different S yield unrelated outputs.
// When N and S are both empty, this is equivalent to NewShake256.
func NewCShake256(N, S []byte) ShakeHash {
// N and S can be at most 255 bytes long.
func NewCShake256(N, S []byte) (ShakeHash, error) {
if len(N) == 0 && len(S) == 0 {
return NewShake256()
return NewShake256(), nil
}
return newCShake(N, S, rate256, 64, dsbyteCShake)
}
Expand Down

0 comments on commit 40ba86e

Please sign in to comment.