diff --git a/go/lib/infra/modules/trust/v2/BUILD.bazel b/go/lib/infra/modules/trust/v2/BUILD.bazel index 327efadb70..43f52bfd12 100644 --- a/go/lib/infra/modules/trust/v2/BUILD.bazel +++ b/go/lib/infra/modules/trust/v2/BUILD.bazel @@ -23,18 +23,35 @@ go_library( "//go/lib/scrypto:go_default_library", "//go/lib/scrypto/trc/v2:go_default_library", "//go/lib/serrors:go_default_library", + "@org_golang_x_xerrors//:go_default_library", ], ) go_test( name = "go_default_test", - srcs = ["main_test.go"], + srcs = [ + "export_test.go", + "main_test.go", + "provider_test.go", + ], data = [ "//go/lib/infra/modules/trust/v2/testdata:crypto_tar", ], embed = [":go_default_library"], deps = [ + "//go/lib/addr:go_default_library", + "//go/lib/infra:go_default_library", + "//go/lib/infra/modules/trust/v2/internal/decoded:go_default_library", + "//go/lib/infra/modules/trust/v2/mock_v2:go_default_library", "//go/lib/log:go_default_library", + "//go/lib/scrypto:go_default_library", + "//go/lib/scrypto/trc/v2:go_default_library", + "//go/lib/serrors:go_default_library", + "//go/lib/util:go_default_library", "//go/lib/xtest:go_default_library", + "@com_github_golang_mock//gomock:go_default_library", + "@com_github_stretchr_testify//assert:go_default_library", + "@com_github_stretchr_testify//require:go_default_library", + "@org_golang_x_xerrors//:go_default_library", ], ) diff --git a/go/lib/infra/modules/trust/v2/export_test.go b/go/lib/infra/modules/trust/v2/export_test.go new file mode 100644 index 0000000000..38b503fde5 --- /dev/null +++ b/go/lib/infra/modules/trust/v2/export_test.go @@ -0,0 +1,32 @@ +// Copyright 2019 Anapaya Systems +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trust + +// NewCryptoProvider allows instantiating the private cryptoProvider for +// black-box testing. +var NewCryptoProvider = newTestCryptoProvider + +// newTestCryptoProvider returns a new crypto provider for testing. +func newTestCryptoProvider(db DBRead, recurser Recurser, resolver Resolver, router Router, + alwaysCacheOnly bool) CryptoProvider { + + return &cryptoProvider{ + db: db, + recurser: recurser, + resolver: resolver, + router: router, + alwaysCacheOnly: alwaysCacheOnly, + } +} diff --git a/go/lib/infra/modules/trust/v2/internal/decoded/BUILD.bazel b/go/lib/infra/modules/trust/v2/internal/decoded/BUILD.bazel index 27832d9f02..263739ae7c 100644 --- a/go/lib/infra/modules/trust/v2/internal/decoded/BUILD.bazel +++ b/go/lib/infra/modules/trust/v2/internal/decoded/BUILD.bazel @@ -8,5 +8,6 @@ go_library( deps = [ "//go/lib/scrypto/cert/v2:go_default_library", "//go/lib/scrypto/trc/v2:go_default_library", + "//go/lib/serrors:go_default_library", ], ) diff --git a/go/lib/infra/modules/trust/v2/internal/decoded/decode.go b/go/lib/infra/modules/trust/v2/internal/decoded/decode.go index 97d70d8889..bc3b60aa86 100644 --- a/go/lib/infra/modules/trust/v2/internal/decoded/decode.go +++ b/go/lib/infra/modules/trust/v2/internal/decoded/decode.go @@ -19,8 +19,12 @@ import ( "github.com/scionproto/scion/go/lib/scrypto/cert/v2" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" + "github.com/scionproto/scion/go/lib/serrors" ) +// ErrParse indicates that parsign failed. +var ErrParse = serrors.New("parse error") + // TRC is a container for the decoded TRC. type TRC struct { TRC *trc.TRC @@ -28,6 +32,24 @@ type TRC struct { Raw []byte } +// DecodeTRC decodes the TRC. +func DecodeTRC(raw []byte) (TRC, error) { + signed, err := trc.ParseSigned(raw) + if err != nil { + return TRC{}, serrors.WithCtx(ErrParse, err, "part", "signed") + } + decoded, err := signed.EncodedTRC.Decode() + if err != nil { + return TRC{}, serrors.WithCtx(ErrParse, err, "part", "decode payload") + } + d := TRC{ + TRC: decoded, + Signed: signed, + Raw: raw, + } + return d, nil +} + func (d TRC) String() string { if d.TRC == nil { return "" diff --git a/go/lib/infra/modules/trust/v2/main_test.go b/go/lib/infra/modules/trust/v2/main_test.go index 700bd74697..cae7e5f63b 100644 --- a/go/lib/infra/modules/trust/v2/main_test.go +++ b/go/lib/infra/modules/trust/v2/main_test.go @@ -16,17 +16,53 @@ package trust_test import ( "fmt" + "io/ioutil" "os" "os/exec" + "path/filepath" "testing" + "github.com/stretchr/testify/require" + + "github.com/scionproto/scion/go/lib/addr" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/log" + "github.com/scionproto/scion/go/lib/scrypto" + "github.com/scionproto/scion/go/lib/scrypto/trc/v2" "github.com/scionproto/scion/go/lib/xtest" ) // tmpDir contains the generated crypto material. var tmpDir string +type TRCDesc struct { + ISD addr.ISD + Version scrypto.Version +} + +func (desc TRCDesc) File() string { + return fmt.Sprintf("ISD%d/trcs/ISD%d-V%d.trc", desc.ISD, desc.ISD, desc.Version) +} + +var ( + trc1v1 = TRCDesc{ISD: 1, Version: 1} + + // primary ASes + ia110 = xtest.MustParseIA("1-ff00:0:110") + ia120 = xtest.MustParseIA("1-ff00:0:120") + ia130 = xtest.MustParseIA("1-ff00:0:130") +) + +var ( + trc2v1 = TRCDesc{ISD: 2, Version: 1} + + // primary ASes + ia210 = xtest.MustParseIA("2-ff00:0:210") + + // non-primary ASes + ia122 = xtest.MustParseIA("1-ff00:0:122") +) + func TestMain(m *testing.M) { var cleanF func() tmpDir, cleanF = xtest.MustTempDir("", "test-trust") @@ -41,3 +77,19 @@ func TestMain(m *testing.M) { log.Root().SetHandler(log.DiscardHandler()) os.Exit(m.Run()) } + +func loadTRC(t *testing.T, desc TRCDesc) decoded.TRC { + t.Helper() + file := filepath.Join(tmpDir, desc.File()) + raw, err := ioutil.ReadFile(file) + require.NoError(t, err) + signed, err := trc.ParseSigned(raw) + require.NoError(t, err) + trcObj, err := signed.EncodedTRC.Decode() + require.NoError(t, err) + return decoded.TRC{ + Raw: raw, + Signed: signed, + TRC: trcObj, + } +} diff --git a/go/lib/infra/modules/trust/v2/mock_v2/BUILD.bazel b/go/lib/infra/modules/trust/v2/mock_v2/BUILD.bazel new file mode 100644 index 0000000000..9ec41df2c4 --- /dev/null +++ b/go/lib/infra/modules/trust/v2/mock_v2/BUILD.bazel @@ -0,0 +1,16 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = ["v2.go"], + importpath = "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/mock_v2", + visibility = ["//visibility:public"], + deps = [ + "//go/lib/addr:go_default_library", + "//go/lib/infra/modules/trust/v2:go_default_library", + "//go/lib/infra/modules/trust/v2/internal/decoded:go_default_library", + "//go/lib/scrypto:go_default_library", + "//go/lib/scrypto/trc/v2:go_default_library", + "@com_github_golang_mock//gomock:go_default_library", + ], +) diff --git a/go/lib/infra/modules/trust/v2/mock_v2/v2.go b/go/lib/infra/modules/trust/v2/mock_v2/v2.go new file mode 100644 index 0000000000..8a5836ed1f --- /dev/null +++ b/go/lib/infra/modules/trust/v2/mock_v2/v2.go @@ -0,0 +1,343 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/scionproto/scion/go/lib/infra/modules/trust/v2 (interfaces: DB,Recurser,Resolver,Router) + +// Package mock_v2 is a generated GoMock package. +package mock_v2 + +import ( + context "context" + sql "database/sql" + gomock "github.com/golang/mock/gomock" + addr "github.com/scionproto/scion/go/lib/addr" + v2 "github.com/scionproto/scion/go/lib/infra/modules/trust/v2" + decoded "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" + scrypto "github.com/scionproto/scion/go/lib/scrypto" + v20 "github.com/scionproto/scion/go/lib/scrypto/trc/v2" + net "net" + reflect "reflect" +) + +// MockDB is a mock of DB interface +type MockDB struct { + ctrl *gomock.Controller + recorder *MockDBMockRecorder +} + +// MockDBMockRecorder is the mock recorder for MockDB +type MockDBMockRecorder struct { + mock *MockDB +} + +// NewMockDB creates a new mock instance +func NewMockDB(ctrl *gomock.Controller) *MockDB { + mock := &MockDB{ctrl: ctrl} + mock.recorder = &MockDBMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockDB) EXPECT() *MockDBMockRecorder { + return m.recorder +} + +// BeginTransaction mocks base method +func (m *MockDB) BeginTransaction(arg0 context.Context, arg1 *sql.TxOptions) (v2.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTransaction", arg0, arg1) + ret0, _ := ret[0].(v2.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTransaction indicates an expected call of BeginTransaction +func (mr *MockDBMockRecorder) BeginTransaction(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTransaction", reflect.TypeOf((*MockDB)(nil).BeginTransaction), arg0, arg1) +} + +// ChainExists mocks base method +func (m *MockDB) ChainExists(arg0 context.Context, arg1 decoded.TRC) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ChainExists", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ChainExists indicates an expected call of ChainExists +func (mr *MockDBMockRecorder) ChainExists(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChainExists", reflect.TypeOf((*MockDB)(nil).ChainExists), arg0, arg1) +} + +// Close mocks base method +func (m *MockDB) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockDBMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDB)(nil).Close)) +} + +// GetRawChain mocks base method +func (m *MockDB) GetRawChain(arg0 context.Context, arg1 addr.IA, arg2 scrypto.Version) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRawChain", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRawChain indicates an expected call of GetRawChain +func (mr *MockDBMockRecorder) GetRawChain(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawChain", reflect.TypeOf((*MockDB)(nil).GetRawChain), arg0, arg1, arg2) +} + +// GetRawTRC mocks base method +func (m *MockDB) GetRawTRC(arg0 context.Context, arg1 addr.ISD, arg2 scrypto.Version) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRawTRC", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRawTRC indicates an expected call of GetRawTRC +func (mr *MockDBMockRecorder) GetRawTRC(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawTRC", reflect.TypeOf((*MockDB)(nil).GetRawTRC), arg0, arg1, arg2) +} + +// GetTRC mocks base method +func (m *MockDB) GetTRC(arg0 context.Context, arg1 addr.ISD, arg2 scrypto.Version) (*v20.TRC, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTRC", arg0, arg1, arg2) + ret0, _ := ret[0].(*v20.TRC) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTRC indicates an expected call of GetTRC +func (mr *MockDBMockRecorder) GetTRC(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTRC", reflect.TypeOf((*MockDB)(nil).GetTRC), arg0, arg1, arg2) +} + +// GetTRCInfo mocks base method +func (m *MockDB) GetTRCInfo(arg0 context.Context, arg1 addr.ISD, arg2 scrypto.Version) (v2.TRCInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTRCInfo", arg0, arg1, arg2) + ret0, _ := ret[0].(v2.TRCInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTRCInfo indicates an expected call of GetTRCInfo +func (mr *MockDBMockRecorder) GetTRCInfo(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTRCInfo", reflect.TypeOf((*MockDB)(nil).GetTRCInfo), arg0, arg1, arg2) +} + +// InsertChain mocks base method +func (m *MockDB) InsertChain(arg0 context.Context, arg1 decoded.Chain) (bool, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertChain", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// InsertChain indicates an expected call of InsertChain +func (mr *MockDBMockRecorder) InsertChain(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChain", reflect.TypeOf((*MockDB)(nil).InsertChain), arg0, arg1) +} + +// InsertTRC mocks base method +func (m *MockDB) InsertTRC(arg0 context.Context, arg1 decoded.TRC) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertTRC", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertTRC indicates an expected call of InsertTRC +func (mr *MockDBMockRecorder) InsertTRC(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTRC", reflect.TypeOf((*MockDB)(nil).InsertTRC), arg0, arg1) +} + +// SetMaxIdleConns mocks base method +func (m *MockDB) SetMaxIdleConns(arg0 int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxIdleConns", arg0) +} + +// SetMaxIdleConns indicates an expected call of SetMaxIdleConns +func (mr *MockDBMockRecorder) SetMaxIdleConns(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxIdleConns", reflect.TypeOf((*MockDB)(nil).SetMaxIdleConns), arg0) +} + +// SetMaxOpenConns mocks base method +func (m *MockDB) SetMaxOpenConns(arg0 int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetMaxOpenConns", arg0) +} + +// SetMaxOpenConns indicates an expected call of SetMaxOpenConns +func (mr *MockDBMockRecorder) SetMaxOpenConns(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxOpenConns", reflect.TypeOf((*MockDB)(nil).SetMaxOpenConns), arg0) +} + +// TRCExists mocks base method +func (m *MockDB) TRCExists(arg0 context.Context, arg1 decoded.TRC) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TRCExists", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TRCExists indicates an expected call of TRCExists +func (mr *MockDBMockRecorder) TRCExists(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TRCExists", reflect.TypeOf((*MockDB)(nil).TRCExists), arg0, arg1) +} + +// MockRecurser is a mock of Recurser interface +type MockRecurser struct { + ctrl *gomock.Controller + recorder *MockRecurserMockRecorder +} + +// MockRecurserMockRecorder is the mock recorder for MockRecurser +type MockRecurserMockRecorder struct { + mock *MockRecurser +} + +// NewMockRecurser creates a new mock instance +func NewMockRecurser(ctrl *gomock.Controller) *MockRecurser { + mock := &MockRecurser{ctrl: ctrl} + mock.recorder = &MockRecurserMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockRecurser) EXPECT() *MockRecurserMockRecorder { + return m.recorder +} + +// AllowRecursion mocks base method +func (m *MockRecurser) AllowRecursion(arg0 net.Addr) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AllowRecursion", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AllowRecursion indicates an expected call of AllowRecursion +func (mr *MockRecurserMockRecorder) AllowRecursion(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowRecursion", reflect.TypeOf((*MockRecurser)(nil).AllowRecursion), arg0) +} + +// MockResolver is a mock of Resolver interface +type MockResolver struct { + ctrl *gomock.Controller + recorder *MockResolverMockRecorder +} + +// MockResolverMockRecorder is the mock recorder for MockResolver +type MockResolverMockRecorder struct { + mock *MockResolver +} + +// NewMockResolver creates a new mock instance +func NewMockResolver(ctrl *gomock.Controller) *MockResolver { + mock := &MockResolver{ctrl: ctrl} + mock.recorder = &MockResolverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockResolver) EXPECT() *MockResolverMockRecorder { + return m.recorder +} + +// Chain mocks base method +func (m *MockResolver) Chain(arg0 context.Context, arg1 v2.ChainReq, arg2 net.Addr) (decoded.Chain, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Chain", arg0, arg1, arg2) + ret0, _ := ret[0].(decoded.Chain) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Chain indicates an expected call of Chain +func (mr *MockResolverMockRecorder) Chain(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chain", reflect.TypeOf((*MockResolver)(nil).Chain), arg0, arg1, arg2) +} + +// TRC mocks base method +func (m *MockResolver) TRC(arg0 context.Context, arg1 v2.TRCReq, arg2 net.Addr) (decoded.TRC, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TRC", arg0, arg1, arg2) + ret0, _ := ret[0].(decoded.TRC) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TRC indicates an expected call of TRC +func (mr *MockResolverMockRecorder) TRC(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TRC", reflect.TypeOf((*MockResolver)(nil).TRC), arg0, arg1, arg2) +} + +// MockRouter is a mock of Router interface +type MockRouter struct { + ctrl *gomock.Controller + recorder *MockRouterMockRecorder +} + +// MockRouterMockRecorder is the mock recorder for MockRouter +type MockRouterMockRecorder struct { + mock *MockRouter +} + +// NewMockRouter creates a new mock instance +func NewMockRouter(ctrl *gomock.Controller) *MockRouter { + mock := &MockRouter{ctrl: ctrl} + mock.recorder = &MockRouterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockRouter) EXPECT() *MockRouterMockRecorder { + return m.recorder +} + +// ChooseServer mocks base method +func (m *MockRouter) ChooseServer(arg0 context.Context, arg1 addr.ISD) (net.Addr, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ChooseServer", arg0, arg1) + ret0, _ := ret[0].(net.Addr) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ChooseServer indicates an expected call of ChooseServer +func (mr *MockRouterMockRecorder) ChooseServer(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChooseServer", reflect.TypeOf((*MockRouter)(nil).ChooseServer), arg0, arg1) +} diff --git a/go/lib/infra/modules/trust/v2/provider.go b/go/lib/infra/modules/trust/v2/provider.go index bf0a8cf3a1..2a745c3c57 100644 --- a/go/lib/infra/modules/trust/v2/provider.go +++ b/go/lib/infra/modules/trust/v2/provider.go @@ -17,13 +17,21 @@ package trust import ( "context" "net" + "time" + + "golang.org/x/xerrors" "github.com/scionproto/scion/go/lib/addr" "github.com/scionproto/scion/go/lib/infra" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/scrypto" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" + "github.com/scionproto/scion/go/lib/serrors" ) +// ErrInactive indicates that the requested material is inactive. +var ErrInactive = serrors.New("inactive") + // CryptoProvider provides crypto material. A crypto provider can spawn network // requests if necessary and permitted. type CryptoProvider interface { @@ -45,3 +53,134 @@ type CryptoProvider interface { GetRawChain(ctx context.Context, ia addr.IA, version scrypto.Version, opts infra.ChainOpts, client net.Addr) ([]byte, error) } + +type cryptoProvider struct { + db DBRead + recurser Recurser + resolver Resolver + router Router + // alwaysCacheOnly forces the cryptoProvider to always send cache-only + // requests. This should be set in the CS. + alwaysCacheOnly bool +} + +func (p *cryptoProvider) GetTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, + opts infra.TRCOpts) (*trc.TRC, error) { + + t, _, err := p.getCheckedTRC(ctx, isd, version, opts, nil) + return t, err +} + +func (p *cryptoProvider) GetRawTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, + opts infra.TRCOpts, client net.Addr) ([]byte, error) { + + _, raw, err := p.getCheckedTRC(ctx, isd, version, opts, client) + return raw, err +} + +func (p *cryptoProvider) getCheckedTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, + opts infra.TRCOpts, client net.Addr) (*trc.TRC, []byte, error) { + + decTRC, err := p.getTRC(ctx, isd, version, opts, nil) + if err != nil { + return nil, nil, serrors.WrapStr("unable to get requested TRC", err) + } + if !opts.AllowInactive { + info, err := p.db.GetTRCInfo(ctx, isd, scrypto.LatestVer) + if err != nil { + return nil, nil, serrors.WrapStr("unable to get latest TRC info", err) + } + switch { + case info.Version > decTRC.TRC.Version+1: + return nil, nil, serrors.WrapStr("inactivated by latest TRC version", ErrInactive, + "latest", info.Version) + case info.Version == decTRC.TRC.Version+1 && graceExpired(info): + return nil, nil, serrors.WrapStr("grace period has passed", ErrInactive, + "end", info.Validity.NotBefore.Add(info.GracePeriod), "latest", info.Version) + case !decTRC.TRC.Validity.Contains(time.Now()): + if !version.IsLatest() || opts.LocalOnly { + return nil, nil, serrors.WrapStr("requested TRC expired", ErrInactive, + "validity", decTRC.TRC.Validity) + } + // There might exist a more recent TRC that is not available locally + // yet. Fetch it if the latest version was requested and recursion + // is allowed. + fetched, err := p.fetchTRC(ctx, isd, scrypto.LatestVer, opts, client) + if err != nil { + return nil, nil, serrors.WrapStr("unable to fetch latest TRC from network", err) + } + if fetched.TRC.Version <= decTRC.TRC.Version { + return nil, nil, serrors.WrapStr("latest TRC from network not newer than local", + ErrInactive, "net_version", fetched.TRC.Version, + "local_version", decTRC.TRC.Version, "validity", decTRC.TRC.Validity) + } + if !fetched.TRC.Validity.Contains(time.Now()) { + return nil, nil, serrors.WrapStr("latest TRC from network expired", ErrInactive, + "version", fetched.TRC.Version, "validity", fetched.TRC.Version) + } + return fetched.TRC, fetched.Raw, nil + } + } + return decTRC.TRC, decTRC.Raw, nil +} + +// getTRC attempts to grab the TRC from the database; if the TRC is not found, +// it follows up with a network request (if allowed). The options specify +// whether this function is allowed to create new network requests. Parameter +// client contains the node that caused the function to be called, or nil if the +// function was called due to a local feature. +func (p *cryptoProvider) getTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, + opts infra.TRCOpts, client net.Addr) (decoded.TRC, error) { + + raw, err := p.db.GetRawTRC(ctx, isd, version) + switch { + case err == nil: + return decoded.DecodeTRC(raw) + case !xerrors.Is(err, ErrNotFound): + return decoded.TRC{}, serrors.WrapStr("error querying DB for TRC", err) + case opts.LocalOnly: + return decoded.TRC{}, serrors.WrapStr("localOnly requested", err) + default: + return p.fetchTRC(ctx, isd, version, opts, client) + } +} + +// fetchTRC fetches a TRC via a network request, if allowed. +func (p *cryptoProvider) fetchTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, + opts infra.TRCOpts, client net.Addr) (decoded.TRC, error) { + + server := opts.Server + if err := p.recurser.AllowRecursion(client); err != nil { + return decoded.TRC{}, err + } + // In case the server is provided, cache-only should be set. + cacheOnly := server != nil || p.alwaysCacheOnly + req := TRCReq{ + ISD: isd, + Version: version, + CacheOnly: cacheOnly, + } + // Choose remote server, if not set. + if server == nil { + var err error + if server, err = p.router.ChooseServer(ctx, isd); err != nil { + return decoded.TRC{}, serrors.WrapStr("unable to route TRC request", err) + } + } + decTRC, err := p.resolver.TRC(ctx, req, server) + if err != nil { + return decoded.TRC{}, serrors.WrapStr("unable to resolve signed TRC from network", err) + } + return decTRC, nil +} + +func (p *cryptoProvider) GetRawChain(ctx context.Context, ia addr.IA, version scrypto.Version, + opts infra.ChainOpts, client net.Addr) ([]byte, error) { + + // TODO(roosd): implement. + return nil, serrors.New("not implemented") +} + +func graceExpired(info TRCInfo) bool { + return time.Now().After(info.Validity.NotBefore.Add(info.GracePeriod)) +} diff --git a/go/lib/infra/modules/trust/v2/provider_test.go b/go/lib/infra/modules/trust/v2/provider_test.go new file mode 100644 index 0000000000..375d63dc34 --- /dev/null +++ b/go/lib/infra/modules/trust/v2/provider_test.go @@ -0,0 +1,425 @@ +// Copyright 2019 Anapaya Systems +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trust_test + +import ( + "encoding/json" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/scionproto/scion/go/lib/infra" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/mock_v2" + "github.com/scionproto/scion/go/lib/scrypto" + "github.com/scionproto/scion/go/lib/scrypto/trc/v2" + "github.com/scionproto/scion/go/lib/serrors" + "github.com/scionproto/scion/go/lib/util" +) + +func TestCryptoProviderGetTRC(t *testing.T) { + internal := serrors.New("internal") + type mocks struct { + DB *mock_v2.MockDB + Recurser *mock_v2.MockRecurser + Resolver *mock_v2.MockResolver + Router *mock_v2.MockRouter + } + tests := map[string]struct { + Expect func(m *mocks, dec *decoded.TRC) + Opts infra.TRCOpts + ExpectedErr error + CacheOnly bool + }{ + "TRC in database, allow inactive": { + Expect: func(m *mocks, dec *decoded.TRC) { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + dec.Raw, nil, + ) + }, + Opts: infra.TRCOpts{AllowInactive: true}, + }, + "TRC in database, is newest": { + Expect: func(m *mocks, dec *decoded.TRC) { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + trust.TRCInfo{Version: dec.TRC.Version}, nil, + ) + }, + }, + "TRC in database, within graceperiod": { + Expect: func(m *mocks, dec *decoded.TRC) { + info := trust.TRCInfo{ + Version: dec.TRC.Version + 1, + GracePeriod: time.Hour, + Validity: scrypto.Validity{NotBefore: util.UnixTime{Time: time.Now()}}, + } + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + info, nil, + ) + }, + }, + "not found, resolve success": { + Expect: func(m *mocks, dec *decoded.TRC) { + ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + nil, trust.ErrNotFound, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) + req := trust.TRCReq{ + ISD: dec.TRC.ISD, + Version: dec.TRC.Version, + CacheOnly: true, + } + m.Resolver.EXPECT().TRC(gomock.Any(), req, ip).Return(*dec, nil) + }, + Opts: infra.TRCOpts{ + TrustStoreOpts: infra.TrustStoreOpts{ + Server: &net.IPAddr{IP: []byte{127, 0, 0, 1}}, + }, + AllowInactive: true, + }, + }, + "TRC in database, newest but expired": { + Expect: func(m *mocks, dec *decoded.TRC) { + dec.TRC.Validity.NotAfter.Time = time.Now() + dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) + dec.Raw, _ = json.Marshal(dec.Signed) + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + trust.TRCInfo{Version: dec.TRC.Version}, nil, + ) + }, + ExpectedErr: trust.ErrInactive, + }, + "TRC in database, invalidated by newer": { + Expect: func(m *mocks, dec *decoded.TRC) { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + trust.TRCInfo{Version: dec.TRC.Version + 2}, nil, + ) + }, + ExpectedErr: trust.ErrInactive, + }, + "TRC in database, outside graceperiod": { + Expect: func(m *mocks, dec *decoded.TRC) { + info := trust.TRCInfo{ + Version: dec.TRC.Version + 1, + GracePeriod: time.Second, + Validity: scrypto.Validity{NotBefore: dec.TRC.Validity.NotBefore}, + } + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + info, nil, + ) + }, + ExpectedErr: trust.ErrInactive, + }, + "DB error": { + Expect: func(m *mocks, dec *decoded.TRC) { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + nil, internal, + ) + }, + ExpectedErr: internal, + }, + "Fail getting TRC info": { + Expect: func(m *mocks, dec *decoded.TRC) { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + trust.TRCInfo{}, internal, + ) + }, + ExpectedErr: internal, + }, + "not found, local only": { + Expect: func(m *mocks, dec *decoded.TRC) { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + nil, trust.ErrNotFound, + ) + }, + Opts: infra.TRCOpts{TrustStoreOpts: infra.TrustStoreOpts{LocalOnly: true}}, + ExpectedErr: trust.ErrNotFound, + }, + "not found, recursion not allowed": { + Expect: func(m *mocks, dec *decoded.TRC) { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + nil, trust.ErrNotFound, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(internal) + }, + ExpectedErr: internal, + }, + "not found, router error": { + Expect: func(m *mocks, dec *decoded.TRC) { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + nil, trust.ErrNotFound, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) + m.Router.EXPECT().ChooseServer(gomock.Any(), dec.TRC.ISD).Return(nil, internal) + }, + ExpectedErr: internal, + }, + "not found, resolve error": { + Expect: func(m *mocks, dec *decoded.TRC) { + ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + nil, trust.ErrNotFound, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) + m.Router.EXPECT().ChooseServer(gomock.Any(), dec.TRC.ISD).Return(ip, nil) + req := trust.TRCReq{ + ISD: dec.TRC.ISD, + Version: dec.TRC.Version, + CacheOnly: false, + } + m.Resolver.EXPECT().TRC(gomock.Any(), req, ip).Return(decoded.TRC{}, internal) + }, + ExpectedErr: internal, + }, + "not found, server set": { + Expect: func(m *mocks, dec *decoded.TRC) { + ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + nil, trust.ErrNotFound, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) + req := trust.TRCReq{ + ISD: dec.TRC.ISD, + Version: dec.TRC.Version, + CacheOnly: true, + } + m.Resolver.EXPECT().TRC(gomock.Any(), req, ip).Return(decoded.TRC{}, internal) + }, + Opts: infra.TRCOpts{TrustStoreOpts: infra.TrustStoreOpts{ + Server: &net.IPAddr{IP: []byte{127, 0, 0, 1}}}, + }, + ExpectedErr: internal, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + mctrl := gomock.NewController(t) + defer mctrl.Finish() + m := mocks{ + DB: mock_v2.NewMockDB(mctrl), + Recurser: mock_v2.NewMockRecurser(mctrl), + Resolver: mock_v2.NewMockResolver(mctrl), + Router: mock_v2.NewMockRouter(mctrl), + } + decoded := loadTRC(t, trc1v1) + test.Expect(&m, &decoded) + provider := trust.NewCryptoProvider(m.DB, m.Recurser, m.Resolver, m.Router, false) + ptrc, err := provider.GetTRC(nil, trc1v1.ISD, trc1v1.Version, test.Opts) + if test.ExpectedErr != nil { + require.Error(t, err) + assert.Truef(t, xerrors.Is(err, test.ExpectedErr), + "actual: %s expected: %s", err, test.ExpectedErr) + } else { + require.NoError(t, err) + assert.Equal(t, decoded.TRC, ptrc) + } + }) + } +} + +func TestCryptoProviderGetTRCLatest(t *testing.T) { + internal := serrors.New("internal") + type mocks struct { + DB *mock_v2.MockDB + Recurser *mock_v2.MockRecurser + Resolver *mock_v2.MockResolver + Router *mock_v2.MockRouter + } + tests := map[string]struct { + Expect func(m *mocks, dec *decoded.TRC) decoded.TRC + Opts infra.TRCOpts + ExpectedErr error + CacheOnly bool + }{ + "TRC in database, allow inactive": { + Expect: func(m *mocks, dec *decoded.TRC) decoded.TRC { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + dec.Raw, nil, + ) + return *dec + }, + Opts: infra.TRCOpts{AllowInactive: true}, + }, + "not found, resolve success": { + Expect: func(m *mocks, dec *decoded.TRC) decoded.TRC { + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + nil, trust.ErrNotFound, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) + ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} + m.Router.EXPECT().ChooseServer(gomock.Any(), dec.TRC.ISD).Return(ip, nil) + req := trust.TRCReq{ + ISD: dec.TRC.ISD, + Version: scrypto.Version(scrypto.LatestVer), + CacheOnly: false, + } + m.Resolver.EXPECT().TRC(gomock.Any(), req, ip).Return(*dec, nil) + return *dec + }, + Opts: infra.TRCOpts{ + AllowInactive: true, + }, + }, + "newest expired, recursion not allowed": { + Expect: func(m *mocks, dec *decoded.TRC) decoded.TRC { + dec.TRC.Validity.NotAfter.Time = time.Now() + dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) + dec.Raw, _ = json.Marshal(dec.Signed) + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + trust.TRCInfo{Version: dec.TRC.Version}, nil, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(internal) + return decoded.TRC{} + }, + ExpectedErr: internal, + }, + "newest expired, network returns same": { + Expect: func(m *mocks, dec *decoded.TRC) decoded.TRC { + dec.TRC.Validity.NotAfter.Time = time.Now() + dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) + dec.Raw, _ = json.Marshal(dec.Signed) + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + trust.TRCInfo{Version: dec.TRC.Version}, nil, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) + ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} + m.Router.EXPECT().ChooseServer(gomock.Any(), dec.TRC.ISD).Return(ip, nil) + req := trust.TRCReq{ + ISD: dec.TRC.ISD, + Version: scrypto.Version(scrypto.LatestVer), + CacheOnly: false, + } + m.Resolver.EXPECT().TRC(gomock.Any(), req, ip).Return(*dec, nil) + return decoded.TRC{} + }, + ExpectedErr: trust.ErrInactive, + }, + "newest expired, network returns expired": { + Expect: func(m *mocks, dec *decoded.TRC) decoded.TRC { + dec.TRC.Validity.NotAfter.Time = time.Now() + dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) + dec.Raw, _ = json.Marshal(dec.Signed) + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + trust.TRCInfo{Version: dec.TRC.Version}, nil, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) + ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} + m.Router.EXPECT().ChooseServer(gomock.Any(), dec.TRC.ISD).Return(ip, nil) + req := trust.TRCReq{ + ISD: dec.TRC.ISD, + Version: scrypto.Version(scrypto.LatestVer), + CacheOnly: false, + } + newer := decoded.TRC{TRC: &(*dec.TRC)} + newer.TRC.Version += 1 + newer.Signed.EncodedTRC, _ = trc.Encode(newer.TRC) + newer.Raw, _ = json.Marshal(newer.Signed) + m.Resolver.EXPECT().TRC(gomock.Any(), req, ip).Return(newer, nil) + return decoded.TRC{} + }, + ExpectedErr: trust.ErrInactive, + }, + "newest expired, network returns newer": { + Expect: func(m *mocks, dec *decoded.TRC) decoded.TRC { + dec.TRC.Validity.NotAfter.Time = time.Now() + dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) + dec.Raw, _ = json.Marshal(dec.Signed) + m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + dec.Raw, nil, + ) + m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + trust.TRCInfo{Version: dec.TRC.Version}, nil, + ) + m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) + ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} + m.Router.EXPECT().ChooseServer(gomock.Any(), dec.TRC.ISD).Return(ip, nil) + req := trust.TRCReq{ + ISD: dec.TRC.ISD, + Version: scrypto.Version(scrypto.LatestVer), + CacheOnly: false, + } + newer := decoded.TRC{TRC: &(*dec.TRC)} + newer.TRC.Version += 1 + newer.TRC.Validity = &scrypto.Validity{ + NotAfter: util.UnixTime{Time: time.Now().Add(1000 * time.Hour)}, + } + newer.Signed.EncodedTRC, _ = trc.Encode(newer.TRC) + newer.Raw, _ = json.Marshal(newer.Signed) + m.Resolver.EXPECT().TRC(gomock.Any(), req, ip).Return(newer, nil) + return newer + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + mctrl := gomock.NewController(t) + defer mctrl.Finish() + m := mocks{ + DB: mock_v2.NewMockDB(mctrl), + Recurser: mock_v2.NewMockRecurser(mctrl), + Resolver: mock_v2.NewMockResolver(mctrl), + Router: mock_v2.NewMockRouter(mctrl), + } + decoded := loadTRC(t, trc1v1) + expected := test.Expect(&m, &decoded) + provider := trust.NewCryptoProvider(m.DB, m.Recurser, m.Resolver, + m.Router, test.CacheOnly) + trcObj, err := provider.GetTRC(nil, trc1v1.ISD, scrypto.LatestVer, test.Opts) + assert.Equal(t, expected.TRC, trcObj) + if test.ExpectedErr != nil { + require.Error(t, err) + assert.Truef(t, xerrors.Is(err, test.ExpectedErr), + "actual: %s expected: %s", err, test.ExpectedErr) + } else { + require.NoError(t, err) + } + }) + } + +} diff --git a/go/lib/infra/modules/trust/v2/recurser.go b/go/lib/infra/modules/trust/v2/recurser.go index 9cf79cba12..8c736abd47 100644 --- a/go/lib/infra/modules/trust/v2/recurser.go +++ b/go/lib/infra/modules/trust/v2/recurser.go @@ -23,7 +23,9 @@ import ( type Recurser interface { // AllowRecursion indicates whether the recursion is allowed for the // provided Peer. Recursions started by the local trust store have a nil - // address and should generally be allowed. + // address and should generally be allowed. The nil value indicates + // recursion is allowed. Non-nil return values indicate that recursion is + // not allowed and specify the reason. AllowRecursion(peer net.Addr) error } diff --git a/go/lib/infra/modules/trust/v2/testdata/gen_crypto_tar.sh b/go/lib/infra/modules/trust/v2/testdata/gen_crypto_tar.sh index 265cf23a9d..2affca92e0 100755 --- a/go/lib/infra/modules/trust/v2/testdata/gen_crypto_tar.sh +++ b/go/lib/infra/modules/trust/v2/testdata/gen_crypto_tar.sh @@ -1,7 +1,10 @@ #! /bin/bash # usage: gen_crypto_tar.sh - +# +# Example: (generate crypto tar from root dir) +# CRYPTO_PATH="./go/lib/infra/modules/trust/v2/testdata" +# $CRYPTO_PATH/gen_crypto_tar.sh ./bin/scion-pki $CRYPTO_PATH/crypto.tar set -e TMP=`mktemp -d` diff --git a/tools/gomocks b/tools/gomocks index e09cce645c..76d049404f 100755 --- a/tools/gomocks +++ b/tools/gomocks @@ -47,6 +47,7 @@ MOCK_TARGETS = [ "Splitter,Validator"), (SCION_PACKAGE_PREFIX + "/go/lib/infra/modules/seghandler", "Storage,Verifier"), (SCION_PACKAGE_PREFIX + "/go/lib/infra/modules/trust/trustdb", "TrustDB"), + (SCION_PACKAGE_PREFIX + "/go/lib/infra/modules/trust/v2", "DB,Recurser,Resolver,Router"), (SCION_PACKAGE_PREFIX + "/go/lib/l4", "L4Header"), (SCION_PACKAGE_PREFIX + "/go/lib/log", "Handler,Logger"), (SCION_PACKAGE_PREFIX + "/go/lib/overlay/conn", "Conn"),