From 0173fe002fb3c3b7f15158ed2c9b1a3bd5a8dc93 Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Tue, 17 Dec 2019 08:16:44 +0100 Subject: [PATCH] TrustStore: Add Implementation for GetASKey Fixes #3524 --- go/lib/infra/modules/trust/v2/provider.go | 50 +++++++----- .../infra/modules/trust/v2/provider_test.go | 79 +++++++++++++++++++ 2 files changed, 110 insertions(+), 19 deletions(-) diff --git a/go/lib/infra/modules/trust/v2/provider.go b/go/lib/infra/modules/trust/v2/provider.go index 6506eac0c1..f5c2ac8c34 100644 --- a/go/lib/infra/modules/trust/v2/provider.go +++ b/go/lib/infra/modules/trust/v2/provider.go @@ -25,6 +25,7 @@ import ( "github.com/scionproto/scion/go/lib/infra" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/scrypto" + "github.com/scionproto/scion/go/lib/scrypto/cert/v2" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" "github.com/scionproto/scion/go/lib/serrors" ) @@ -49,8 +50,8 @@ type CryptoProvider interface { // not available locally. Otherwise, the default server is queried. How the // default server is determined differs between implementations. GetRawChain(context.Context, ChainID, infra.ChainOpts) ([]byte, error) - //GetASKey returns from trust store the public key required to verify signature - //originated from an AS. + // GetASKey returns from trust store the public key required to verify + // signature originated from an AS. GetASKey(context.Context, ChainID, infra.ChainOpts) (scrypto.KeyMeta, error) } @@ -184,35 +185,53 @@ func (p *cryptoProvider) fetchTRC(ctx context.Context, id TRCID, func (p *cryptoProvider) GetRawChain(ctx context.Context, id ChainID, opts infra.ChainOpts) ([]byte, error) { + chain, err := p.getCheckedChain(ctx, id, opts) + return chain.Raw, err +} + +func (p *cryptoProvider) GetASKey(ctx context.Context, + id ChainID, opts infra.ChainOpts) (scrypto.KeyMeta, error) { + + chain, err := p.getCheckedChain(ctx, id, opts) + if err != nil { + return scrypto.KeyMeta{}, err + } + return chain.AS.Keys[cert.SigningKey], nil +} + +func (p *cryptoProvider) getCheckedChain(ctx context.Context, id ChainID, + opts infra.ChainOpts) (decoded.Chain, error) { + chain, err := p.getChain(ctx, id, opts) if err != nil { - return nil, serrors.WrapStr("unable to get requested certificate chain", err) + return decoded.Chain{}, serrors.WrapStr("unable to get requested certificate chain", err) } if opts.AllowInactive { - return chain.Raw, nil + return chain, nil } err = p.issuerActive(ctx, chain, opts.TrustStoreOpts) switch { case err == nil: - return chain.Raw, nil + return chain, nil case !xerrors.Is(err, ErrInactive): - return nil, err + return decoded.Chain{}, err case !id.Version.IsLatest(): - return nil, err + return decoded.Chain{}, err case opts.LocalOnly: - return nil, err + return decoded.Chain{}, err default: // In case the latest certificate chain is requested, there might be a more // recent and active one that is not locally available yet. fetched, err := p.fetchChain(ctx, id, opts) if err != nil { - return nil, serrors.WrapStr("unable to fetch latest certificate chain from network", - err) + return decoded.Chain{}, + serrors.WrapStr("unable to fetch latest certificate chain from network", err) } if err := p.issuerActive(ctx, fetched, opts.TrustStoreOpts); err != nil { - return nil, serrors.WrapStr("latest certificate chain from network not active", err) + return decoded.Chain{}, + serrors.WrapStr("latest certificate chain from network not active", err) } - return fetched.Raw, nil + return fetched, nil } } @@ -303,13 +322,6 @@ func (p *cryptoProvider) fetchChain(ctx context.Context, id ChainID, return chain, nil } -func (p *cryptoProvider) GetASKey(ctx context.Context, - id ChainID, opts infra.ChainOpts) (scrypto.KeyMeta, error) { - - // TODO(karampok): implement. - return scrypto.KeyMeta{}, serrors.New("not implemented") -} - func graceExpired(info TRCInfo) bool { return time.Now().After(info.Validity.NotBefore.Add(info.GracePeriod)) } diff --git a/go/lib/infra/modules/trust/v2/provider_test.go b/go/lib/infra/modules/trust/v2/provider_test.go index d18e4eb5ec..18ec94155b 100644 --- a/go/lib/infra/modules/trust/v2/provider_test.go +++ b/go/lib/infra/modules/trust/v2/provider_test.go @@ -1248,3 +1248,82 @@ func TestCryptoProviderGetRawChain(t *testing.T) { }) } } + +func TestCryptoProviderGetASKey(t *testing.T) { + internal := serrors.New("internal") + dec110v1 := loadChain(t, chain110v1) + tests := map[string]struct { + DB func(t *testing.T, ctrl *gomock.Controller) trust.DB + Recurser func(t *testing.T, ctrl *gomock.Controller) trust.Recurser + Resolver func(t *testing.T, ctrl *gomock.Controller) trust.Resolver + Router func(t *testing.T, ctrl *gomock.Controller) trust.Router + ChainDesc ChainDesc + Opts infra.ChainOpts + ExpectedErr error + ExpectedKeyMeta scrypto.KeyMeta + }{ + "chain in database, allow inactive": { + DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { + db := mock_v2.NewMockDB(ctrl) + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( + dec110v1.Raw, nil, + ) + return db + }, + Recurser: func(t *testing.T, ctrl *gomock.Controller) trust.Recurser { + return mock_v2.NewMockRecurser(ctrl) + }, + Resolver: func(t *testing.T, ctrl *gomock.Controller) trust.Resolver { + return mock_v2.NewMockResolver(ctrl) + }, + Router: func(t *testing.T, ctrl *gomock.Controller) trust.Router { + return mock_v2.NewMockRouter(ctrl) + }, + ChainDesc: chain110v1, + Opts: infra.ChainOpts{AllowInactive: true}, + ExpectedKeyMeta: dec110v1.AS.Keys[cert.SigningKey], + }, + "database error": { + DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { + db := mock_v2.NewMockDB(ctrl) + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( + nil, internal, + ) + return db + }, + Recurser: func(t *testing.T, ctrl *gomock.Controller) trust.Recurser { + return mock_v2.NewMockRecurser(ctrl) + }, + Resolver: func(t *testing.T, ctrl *gomock.Controller) trust.Resolver { + return mock_v2.NewMockResolver(ctrl) + }, + Router: func(t *testing.T, ctrl *gomock.Controller) trust.Router { + return mock_v2.NewMockRouter(ctrl) + }, + ChainDesc: chain110v1, + Opts: infra.ChainOpts{}, + ExpectedErr: internal, + ExpectedKeyMeta: scrypto.KeyMeta{}, + }, + } + for n, tc := range tests { + name, test := n, tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mctrl := gomock.NewController(t) + defer mctrl.Finish() + p := trust.NewCryptoProvider( + test.DB(t, mctrl), + test.Recurser(t, mctrl), + test.Resolver(t, mctrl), + test.Router(t, mctrl)) + km, err := p.GetASKey(nil, + trust.ChainID{IA: test.ChainDesc.IA, Version: test.ChainDesc.Version}, + test.Opts) + xtest.AssertErrorsIs(t, err, test.ExpectedErr) + assert.Equal(t, test.ExpectedKeyMeta, km) + }) + } +}