Skip to content

Commit

Permalink
Retrieve shard tree length if it isn't provided in the config (sigsto…
Browse files Browse the repository at this point in the history
…re#810)

* Get sharding details when configuring API

Signed-off-by: Priya Wadhwa <[email protected]>

* If tree length isn't provided then retrieve it

Signed-off-by: Priya Wadhwa <[email protected]>
  • Loading branch information
priyawadhwa authored May 7, 2022
1 parent c3b87cd commit 5ed77ae
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 26 deletions.
10 changes: 1 addition & 9 deletions cmd/rekor-server/app/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/sigstore/rekor/pkg/generated/restapi"
"github.com/sigstore/rekor/pkg/generated/restapi/operations"
"github.com/sigstore/rekor/pkg/log"
"github.com/sigstore/rekor/pkg/sharding"
"github.com/sigstore/rekor/pkg/types/alpine"
alpine_v001 "github.com/sigstore/rekor/pkg/types/alpine/v0.0.1"
hashedrekord "github.com/sigstore/rekor/pkg/types/hashedrekord"
Expand Down Expand Up @@ -104,16 +103,9 @@ var serveCmd = &cobra.Command{
server.Port = int(viper.GetUint("port"))
server.EnabledListeners = []string{"http"}

// Update logRangeMap if flag was passed in
shardingConfig := viper.GetString("trillian_log_server.sharding_config")
treeID := viper.GetUint("trillian_log_server.tlog_id")

ranges, err := sharding.NewLogRanges(shardingConfig, treeID)
if err != nil {
log.Logger.Fatalf("unable get sharding details from sharding config: %v", err)
}

api.ConfigureAPI(ranges, treeID)
api.ConfigureAPI(treeID)
server.ConfigureAPI()

http.Handle("/metrics", promhttp.Handler())
Expand Down
12 changes: 9 additions & 3 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type API struct {
certChainPem string // PEM encoded timestamping cert chain
}

func NewAPI(ranges sharding.LogRanges, treeID uint) (*API, error) {
func NewAPI(treeID uint) (*API, error) {
logRPCServer := fmt.Sprintf("%s:%d",
viper.GetString("trillian_log_server.address"),
viper.GetUint("trillian_log_server.port"))
Expand All @@ -78,6 +78,12 @@ func NewAPI(ranges sharding.LogRanges, treeID uint) (*API, error) {
logAdminClient := trillian.NewTrillianAdminClient(tConn)
logClient := trillian.NewTrillianLogClient(tConn)

shardingConfig := viper.GetString("trillian_log_server.sharding_config")
ranges, err := sharding.NewLogRanges(ctx, logClient, shardingConfig, treeID)
if err != nil {
return nil, errors.Wrap(err, "unable get sharding details from sharding config")
}

tid := int64(treeID)
if tid == 0 {
log.Logger.Info("No tree ID specified, attempting to create a new tree")
Expand Down Expand Up @@ -160,11 +166,11 @@ var (
storageClient storage.AttestationStorage
)

func ConfigureAPI(ranges sharding.LogRanges, treeID uint) {
func ConfigureAPI(treeID uint) {
cfg := radix.PoolConfig{}
var err error

api, err = NewAPI(ranges, treeID)
api, err = NewAPI(treeID)
if err != nil {
log.Logger.Panic(err)
}
Expand Down
60 changes: 48 additions & 12 deletions pkg/sharding/ranges.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
package sharding

import (
"context"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"strconv"
"strings"

"github.com/ghodss/yaml"
"github.com/google/trillian"
"github.com/google/trillian/types"
"github.com/pkg/errors"
"github.com/sigstore/rekor/pkg/log"
)

Expand All @@ -41,7 +44,7 @@ type LogRange struct {
decodedPublicKey string
}

func NewLogRanges(path string, treeID uint) (LogRanges, error) {
func NewLogRanges(ctx context.Context, logClient trillian.TrillianLogClient, path string, treeID uint) (LogRanges, error) {
if path == "" {
log.Logger.Info("No config file specified, skipping init of logRange map")
return LogRanges{}, nil
Expand All @@ -50,30 +53,63 @@ func NewLogRanges(path string, treeID uint) (LogRanges, error) {
return LogRanges{}, errors.New("non-zero tlog_id required when passing in shard config filepath; please set the active tree ID via the `--trillian_log_server.tlog_id` flag")
}
// otherwise, try to read contents of the sharding config
ranges, err := logRangesFromPath(path)
if err != nil {
return LogRanges{}, errors.Wrap(err, "log ranges from path")
}
for i, r := range ranges {
r, err := updateRange(ctx, logClient, r)
if err != nil {
return LogRanges{}, errors.Wrapf(err, "updating range for tree id %d", r.TreeID)
}
ranges[i] = r
}
log.Logger.Info("Ranges: %v", ranges)
return LogRanges{
inactive: ranges,
active: int64(treeID),
}, nil
}

func logRangesFromPath(path string) (Ranges, error) {
var ranges Ranges
contents, err := ioutil.ReadFile(path)
if err != nil {
return LogRanges{}, err
return Ranges{}, err
}
if string(contents) == "" {
log.Logger.Info("Sharding config file contents empty, skipping init of logRange map")
return LogRanges{}, nil
return Ranges{}, nil
}
if err := yaml.Unmarshal(contents, &ranges); err != nil {
return LogRanges{}, err
return Ranges{}, err
}
for i, r := range ranges {
return ranges, nil
}

// updateRange fills in any missing information about the range
func updateRange(ctx context.Context, logClient trillian.TrillianLogClient, r LogRange) (LogRange, error) {
// If a tree length wasn't passed in, get it ourselves
if r.TreeLength == 0 {
resp, err := logClient.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{LogId: r.TreeID})
if err != nil {
return LogRange{}, errors.Wrapf(err, "getting signed log root for tree %d", r.TreeID)
}
var root types.LogRootV1
if err := root.UnmarshalBinary(resp.SignedLogRoot.LogRoot); err != nil {
return LogRange{}, err
}
r.TreeLength = int64(root.TreeSize)
}
// If a public key was provided, decode it
if r.EncodedPublicKey != "" {
decoded, err := base64.StdEncoding.DecodeString(r.EncodedPublicKey)
if err != nil {
return LogRanges{}, err
return LogRange{}, err
}
r.decodedPublicKey = string(decoded)
ranges[i] = r
}
return LogRanges{
inactive: ranges,
active: int64(treeID),
}, nil
return r, nil
}

func (l *LogRanges) ResolveVirtualIndex(index int) (int64, int64) {
Expand Down
39 changes: 38 additions & 1 deletion pkg/sharding/ranges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
package sharding

import (
"context"
"io/ioutil"
"path/filepath"
"reflect"
"testing"

"github.com/google/trillian"
"google.golang.org/grpc"
)

func TestNewLogRanges(t *testing.T) {
Expand Down Expand Up @@ -47,7 +51,9 @@ func TestNewLogRanges(t *testing.T) {
}},
active: int64(45),
}
got, err := NewLogRanges(file, treeID)
ctx := context.Background()
tc := trillian.NewTrillianLogClient(&grpc.ClientConn{})
got, err := NewLogRanges(ctx, tc, file, treeID)
if err != nil {
t.Fatal(err)
}
Expand All @@ -59,6 +65,37 @@ func TestNewLogRanges(t *testing.T) {
}
}

func TestLogRangesFromPath(t *testing.T) {
contents := `
- treeID: 0001
treeLength: 3
encodedPublicKey: c2hhcmRpbmcK
- treeID: 0002
treeLength: 4`
file := filepath.Join(t.TempDir(), "sharding-config")
if err := ioutil.WriteFile(file, []byte(contents), 0644); err != nil {
t.Fatal(err)
}
expected := Ranges{
{
TreeID: 1,
TreeLength: 3,
EncodedPublicKey: "c2hhcmRpbmcK",
}, {
TreeID: 2,
TreeLength: 4,
},
}

got, err := logRangesFromPath(file)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, got) {
t.Fatalf("expected %v got %v", expected, got)
}
}

func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
lrs := LogRanges{
inactive: []LogRange{
Expand Down
1 change: 0 additions & 1 deletion tests/sharding-e2e-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ docker stop $REKOR_CONTAINER_ID
SHARDING_CONFIG=sharding-config.yaml
cat << EOF > $SHARDING_CONFIG
- treeID: $INITIAL_TREE_ID
treeLength: 3
encodedPublicKey: $ENCODED_PUBLIC_KEY
EOF

Expand Down

0 comments on commit 5ed77ae

Please sign in to comment.