Skip to content

Commit

Permalink
feat(trie): faster header decoding (#2649)
Browse files Browse the repository at this point in the history
- Performance increases as number of variants increases
- Changed from slice to fixed size array (thanks Eclesio)
- Update comments
- Add test verifying slice is sorted by bit mask
- Update relevant benchmark for 7 variants
  • Loading branch information
qdm12 authored Jul 26, 2022
1 parent cb1da40 commit d9460e3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 19 deletions.
35 changes: 18 additions & 17 deletions internal/trie/node/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,28 +123,29 @@ func decodeHeader(reader io.Reader) (variant byte,

var ErrVariantUnknown = errors.New("node variant is unknown")

// variantsOrderedByBitMask is an array of all variants sorted
// in ascending order by the number of LHS set bits each variant mask has.
// See https://spec.polkadot.network/#defn-node-header
// WARNING: DO NOT MUTATE.
// This array is defined at global scope for performance
// reasons only, instead of having it locally defined in
// the decodeHeaderByte function below.
// For 7 variants, the performance is improved by ~20%.
var variantsOrderedByBitMask = [...]variant{
leafVariant, // mask 1100_0000
branchVariant, // mask 1100_0000
branchWithValueVariant, // mask 1100_0000
}

func decodeHeaderByte(header byte) (variantBits,
partialKeyLengthHeader, partialKeyLengthHeaderMask byte, err error) {
// variants is a slice of all variants sorted in ascending
// order by the number of bits each variant mask occupy
// in the header byte.
// See https://spec.polkadot.network/#defn-node-header
// Performance note: see `Benchmark_decodeHeaderByte`;
// running with a locally scoped slice is as fast as having
// it at global scope.
variants := []variant{
leafVariant, // mask 1100_0000
branchVariant, // mask 1100_0000
branchWithValueVariant, // mask 1100_0000
}

for i := len(variants) - 1; i >= 0; i-- {
variantBits = header & variants[i].mask
if variantBits != variants[i].bits {
for i := len(variantsOrderedByBitMask) - 1; i >= 0; i-- {
variantBits = header & variantsOrderedByBitMask[i].mask
if variantBits != variantsOrderedByBitMask[i].bits {
continue
}

partialKeyLengthHeaderMask = ^variants[i].mask
partialKeyLengthHeaderMask = ^variantsOrderedByBitMask[i].mask
partialKeyLengthHeader = header & partialKeyLengthHeaderMask
return variantBits, partialKeyLengthHeader,
partialKeyLengthHeaderMask, nil
Expand Down
21 changes: 19 additions & 2 deletions internal/trie/node/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"io"
"math"
"sort"
"testing"

"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -419,11 +420,27 @@ func Test_decodeHeaderByte(t *testing.T) {
}
}

func Test_variantsOrderedByBitMask(t *testing.T) {
t.Parallel()

slice := make([]variant, len(variantsOrderedByBitMask))
sortedSlice := make([]variant, len(variantsOrderedByBitMask))
copy(slice, variantsOrderedByBitMask[:])
copy(sortedSlice, variantsOrderedByBitMask[:])

sort.Slice(slice, func(i, j int) bool {
return slice[i].mask > slice[j].mask
})

assert.Equal(t, sortedSlice, slice)
}

func Benchmark_decodeHeaderByte(b *testing.B) {
// For 7 variants defined in the variants array:
// With global scoped variants slice:
// 3.453 ns/op 0 B/op 0 allocs/op
// 2.987 ns/op 0 B/op 0 allocs/op
// With locally scoped variants slice:
// 3.441 ns/op 0 B/op 0 allocs/op
// 3.873 ns/op 0 B/op 0 allocs/op
header := leafVariant.bits | 0b0000_0001
b.ResetTimer()
for i := 0; i < b.N; i++ {
Expand Down

0 comments on commit d9460e3

Please sign in to comment.