diff --git a/go/lib/infra/modules/trust/v2/BUILD.bazel b/go/lib/infra/modules/trust/v2/BUILD.bazel index 12ed4bcc23..2c355a7bc4 100644 --- a/go/lib/infra/modules/trust/v2/BUILD.bazel +++ b/go/lib/infra/modules/trust/v2/BUILD.bazel @@ -21,6 +21,7 @@ go_library( "//go/lib/infra/modules/trust/v2/internal/decoded:go_default_library", "//go/lib/log:go_default_library", "//go/lib/scrypto:go_default_library", + "//go/lib/scrypto/cert/v2:go_default_library", "//go/lib/scrypto/trc/v2:go_default_library", "//go/lib/serrors:go_default_library", "//go/lib/snet:go_default_library", @@ -52,6 +53,7 @@ go_test( "//go/lib/infra/modules/trust/v2/mock_v2:go_default_library", "//go/lib/log:go_default_library", "//go/lib/scrypto:go_default_library", + "//go/lib/scrypto/cert/v2:go_default_library", "//go/lib/scrypto/trc/v2:go_default_library", "//go/lib/serrors:go_default_library", "//go/lib/snet:go_default_library", diff --git a/go/lib/infra/modules/trust/v2/inserter.go b/go/lib/infra/modules/trust/v2/inserter.go index 6cbb3dc679..ca460298e0 100644 --- a/go/lib/infra/modules/trust/v2/inserter.go +++ b/go/lib/infra/modules/trust/v2/inserter.go @@ -20,6 +20,7 @@ import ( "github.com/scionproto/scion/go/lib/addr" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/scrypto" + "github.com/scionproto/scion/go/lib/scrypto/cert/v2" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" "github.com/scionproto/scion/go/lib/serrors" ) @@ -98,9 +99,13 @@ type fwdInserter struct { func (ins *fwdInserter) InsertTRC(ctx context.Context, decTRC decoded.TRC, trcProvider TRCProviderFunc) error { - if insert, err := ins.shouldInsertTRC(ctx, decTRC, trcProvider); err != nil || !insert { + insert, err := ins.shouldInsertTRC(ctx, decTRC, trcProvider) + if err != nil { return err } + if !insert { + return nil + } cs := ins.router.chooseServer() if err := ins.rpc.SendTRC(ctx, decTRC.Raw, cs); err != nil { return serrors.WrapStr("unable to push TRC to certificate server", err, "addr", cs) @@ -119,13 +124,16 @@ func (ins *fwdInserter) InsertTRC(ctx context.Context, decTRC decoded.TRC, func (ins *fwdInserter) InsertChain(ctx context.Context, chain decoded.Chain, trcProvider TRCProviderFunc) error { - if insert, err := ins.shouldInsertChain(ctx, chain, trcProvider); err != nil || !insert { + insert, err := ins.shouldInsertChain(ctx, chain, trcProvider) + if err != nil { return err } + if !insert { + return nil + } cs := ins.router.chooseServer() if err := ins.rpc.SendCertChain(ctx, chain.Raw, cs); err != nil { - return serrors.WrapStr("unable to push chain to certificate server", err, - "addr", cs) + return serrors.WrapStr("unable to push chain to certificate server", err, "addr", cs) } if _, _, err := ins.db.InsertChain(ctx, chain); err != nil { return serrors.WrapStr("unable to insert chain", err) @@ -144,8 +152,11 @@ func (ins *baseInserter) shouldInsertTRC(ctx context.Context, decTRC decoded.TRC trcProvider TRCProviderFunc) (bool, error) { found, err := ins.db.TRCExists(ctx, decTRC) - if err != nil || found { - return !found, err + if err != nil { + return false, err + } + if found { + return false, nil } if decTRC.TRC.Base() { // XXX(roosd): remove when TAACs are supported. @@ -191,5 +202,53 @@ func (ins *baseInserter) checkUpdate(ctx context.Context, prev *trc.TRC, next de func (ins *baseInserter) shouldInsertChain(ctx context.Context, chain decoded.Chain, trcProvider TRCProviderFunc) (bool, error) { - return false, serrors.New("not implemented") + found, err := ins.db.ChainExists(ctx, chain) + if err != nil { + return false, err + } + if found { + return false, nil + } + if err := ins.validateChain(chain); err != nil { + return false, serrors.WrapStr("error validating the certificate chain", err) + } + t, err := trcProvider(ctx, chain.Issuer.Subject.I, chain.Issuer.Issuer.TRCVersion) + if err != nil { + return false, serrors.WrapStr("unable to get issuing TRC", err, + "isd", chain.Issuer.Subject.I, "version", chain.Issuer.Issuer.TRCVersion) + } + if err := ins.verifyChain(chain, t); err != nil { + return false, serrors.WrapStr("error verifying the certificate chain", err) + } + return true, nil +} + +func (ins *baseInserter) validateChain(chain decoded.Chain) error { + if err := chain.Issuer.Validate(); err != nil { + return serrors.Wrap(ErrValidation, err, "part", "issuer") + } + if err := chain.AS.Validate(); err != nil { + return serrors.Wrap(ErrValidation, err, "part", "AS") + } + return nil +} + +func (ins *baseInserter) verifyChain(chain decoded.Chain, t *trc.TRC) error { + issVerifier := cert.IssuerVerifier{ + TRC: t, + Issuer: chain.Issuer, + SignedIssuer: &chain.Chain.Issuer, + } + if err := issVerifier.Verify(); err != nil { + return serrors.Wrap(ErrVerification, err, "part", "issuer") + } + asVerifier := cert.ASVerifier{ + Issuer: chain.Issuer, + AS: chain.AS, + SignedAS: &chain.Chain.AS, + } + if err := asVerifier.Verify(); err != nil { + return serrors.Wrap(ErrVerification, err, "part", "AS") + } + return nil } diff --git a/go/lib/infra/modules/trust/v2/inserter_test.go b/go/lib/infra/modules/trust/v2/inserter_test.go index 6f9665c6d5..8d27cd1f3b 100644 --- a/go/lib/infra/modules/trust/v2/inserter_test.go +++ b/go/lib/infra/modules/trust/v2/inserter_test.go @@ -19,12 +19,16 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - "golang.org/x/xerrors" + "github.com/scionproto/scion/go/lib/addr" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/mock_v2" + "github.com/scionproto/scion/go/lib/scrypto" + "github.com/scionproto/scion/go/lib/scrypto/trc/v2" + "github.com/scionproto/scion/go/lib/serrors" + "github.com/scionproto/scion/go/lib/snet" + "github.com/scionproto/scion/go/lib/xtest" ) func TestInserterInsertTRC(t *testing.T) { @@ -89,12 +93,214 @@ func TestInserterInsertTRC(t *testing.T) { ins := trust.NewInserter(db, test.Unsafe) err := ins.InsertTRC(context.Background(), decoded, nil) - if test.ExpectedErr != nil { - require.Truef(t, xerrors.Is(err, test.ExpectedErr), - "Expected: %s Actual: %s", test.ExpectedErr, err) - } else { - require.NoError(t, err) + xtest.AssertErrorsIs(t, err, test.ExpectedErr) + }) + } +} + +func TestInserterInsertChain(t *testing.T) { + notFound := serrors.New("not found") + dbErr := serrors.New("db error") + tests := map[string]struct { + Expect func(*mock_v2.MockDB, decoded.Chain) + ExpectedErr error + TRCProvider trust.TRCProviderFunc + }{ + "valid": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + db.EXPECT().InsertChain(gomock.Any(), dec).Return( + true, true, nil, + ) + }, + }, + "exists with same contents": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + true, nil, + ) + }, + }, + "exists with different contents": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + true, trust.ErrContentMismatch, + ) + }, + ExpectedErr: trust.ErrContentMismatch, + }, + "TRC not found": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + }, + TRCProvider: func(context.Context, addr.ISD, scrypto.Version) (*trc.TRC, error) { + return nil, notFound + }, + ExpectedErr: notFound, + }, + "insert fails": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + db.EXPECT().InsertChain(gomock.Any(), dec).Return( + false, false, dbErr, + ) + }, + ExpectedErr: dbErr, + }, + "invalid AS certificate": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + dec.AS.Subject = addr.IA{} + }, + ExpectedErr: trust.ErrValidation, + }, + "invalid issuer certificate": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + dec.Issuer.Subject = addr.IA{} + }, + ExpectedErr: trust.ErrValidation, + }, + "forged AS certificate": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + dec.Chain.AS.Signature[0] ^= 0xFF + }, + ExpectedErr: trust.ErrVerification, + }, + "forged issuer certificate": { + Expect: func(db *mock_v2.MockDB, dec decoded.Chain) { + db.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + dec.Chain.Issuer.Signature[0] ^= 0xFF + }, + ExpectedErr: trust.ErrVerification, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + mctrl := gomock.NewController(t) + defer mctrl.Finish() + + db := mock_v2.NewMockDB(mctrl) + decoded := loadChain(t, chain110v1) + test.Expect(db, decoded) + ins := trust.NewInserter(db, false) + + decTRC := loadTRC(t, trc1v1) + p := func(ctx context.Context, isd addr.ISD, ver scrypto.Version) (*trc.TRC, error) { + return decTRC.TRC, nil + } + if test.TRCProvider != nil { + p = test.TRCProvider + } + + err := ins.InsertChain(context.Background(), decoded, p) + xtest.AssertErrorsIs(t, err, test.ExpectedErr) + }) + } +} + +func TestFwdInserterInsertChain(t *testing.T) { + internal := serrors.New("internal") + type mocks struct { + DB *mock_v2.MockDB + Router *mock_v2.MockRouter + RPC *mock_v2.MockRPC + } + tests := map[string]struct { + Expect func(*mocks, decoded.Chain) + ExpectedErr error + TRCProvider trust.TRCProviderFunc + }{ + "valid": { + Expect: func(m *mocks, dec decoded.Chain) { + m.DB.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + a := snet.NewSVCAddr(addr.IA{}, nil, nil, addr.SvcCS) + m.RPC.EXPECT().SendCertChain(gomock.Any(), dec.Raw, a).Return(nil) + m.DB.EXPECT().InsertChain(gomock.Any(), dec).Return( + true, true, nil, + ) + }, + }, + "already exists": { + Expect: func(m *mocks, dec decoded.Chain) { + m.DB.EXPECT().ChainExists(gomock.Any(), dec).Return( + true, nil, + ) + }, + }, + "mismatch": { + Expect: func(m *mocks, dec decoded.Chain) { + m.DB.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, trust.ErrContentMismatch, + ) + }, + ExpectedErr: trust.ErrContentMismatch, + }, + "rpc fails": { + Expect: func(m *mocks, dec decoded.Chain) { + m.DB.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + a := snet.NewSVCAddr(addr.IA{}, nil, nil, addr.SvcCS) + m.RPC.EXPECT().SendCertChain(gomock.Any(), dec.Raw, a).Return(internal) + }, + ExpectedErr: internal, + }, + "insert fails": { + Expect: func(m *mocks, dec decoded.Chain) { + m.DB.EXPECT().ChainExists(gomock.Any(), dec).Return( + false, nil, + ) + a := snet.NewSVCAddr(addr.IA{}, nil, nil, addr.SvcCS) + m.RPC.EXPECT().SendCertChain(gomock.Any(), dec.Raw, a).Return(nil) + m.DB.EXPECT().InsertChain(gomock.Any(), dec).Return( + false, false, trust.ErrContentMismatch, + ) + }, + ExpectedErr: trust.ErrContentMismatch, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + mctrl := gomock.NewController(t) + defer mctrl.Finish() + + m := &mocks{ + DB: mock_v2.NewMockDB(mctrl), + RPC: mock_v2.NewMockRPC(mctrl), + Router: mock_v2.NewMockRouter(mctrl), + } + decoded := loadChain(t, chain110v1) + test.Expect(m, decoded) + ins := trust.NewFwdInserter(m.DB, m.RPC) + + decTRC := loadTRC(t, trc1v1) + p := func(ctx context.Context, isd addr.ISD, ver scrypto.Version) (*trc.TRC, error) { + return decTRC.TRC, nil + } + if test.TRCProvider != nil { + p = test.TRCProvider } + + err := ins.InsertChain(context.Background(), decoded, p) + xtest.AssertErrorsIs(t, err, test.ExpectedErr) }) } } diff --git a/go/lib/infra/modules/trust/v2/main_test.go b/go/lib/infra/modules/trust/v2/main_test.go index 198d85d43e..73c2a8f78a 100644 --- a/go/lib/infra/modules/trust/v2/main_test.go +++ b/go/lib/infra/modules/trust/v2/main_test.go @@ -28,6 +28,7 @@ import ( "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/log" "github.com/scionproto/scion/go/lib/scrypto" + "github.com/scionproto/scion/go/lib/scrypto/cert/v2" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" "github.com/scionproto/scion/go/lib/xtest" ) @@ -47,26 +48,46 @@ func (desc TRCDesc) File() string { return fmt.Sprintf("ISD%d/trcs/ISD%d-V%d.trc", desc.ISD, desc.ISD, desc.Version) } -var ( - trc1v1 = TRCDesc{ISD: 1, Version: 1} - trc1v2 = TRCDesc{ISD: 1, Version: 2} - trc1v3 = TRCDesc{ISD: 1, Version: 3} - trc1v4 = TRCDesc{ISD: 1, Version: 4} +type ChainDesc struct { + IA addr.IA + Version scrypto.Version +} + +func (desc ChainDesc) File() string { + return fmt.Sprintf("ISD%d/AS%s/certs/%s-V%d.crt", desc.IA.I, desc.IA.A.FileFmt(), + desc.IA.FileFmt(true), desc.Version) +} - // primary ASes +// Primary ASes ISD 1 +var ( ia110 = xtest.MustParseIA("1-ff00:0:110") ia120 = xtest.MustParseIA("1-ff00:0:120") ia130 = xtest.MustParseIA("1-ff00:0:130") ) +// Non-primary ASes ISD 1 var ( - trc2v1 = TRCDesc{ISD: 2, Version: 1} + ia122 = xtest.MustParseIA("1-ff00:0:122") +) - // primary ASes +// Primary ASes ISD 2 +var ( ia210 = xtest.MustParseIA("2-ff00:0:210") +) - // non-primary ASes - ia122 = xtest.MustParseIA("1-ff00:0:122") +// TRCs +var ( + trc1v1 = TRCDesc{ISD: 1, Version: 1} + trc1v2 = TRCDesc{ISD: 1, Version: 2} + trc1v3 = TRCDesc{ISD: 1, Version: 3} + trc1v4 = TRCDesc{ISD: 1, Version: 4} + + trc2v1 = TRCDesc{ISD: 2, Version: 1} +) + +// Chains +var ( + chain110v1 = ChainDesc{IA: ia110, Version: 1} ) func TestMain(m *testing.M) { @@ -100,3 +121,19 @@ func loadTRC(t *testing.T, desc TRCDesc) decoded.TRC { TRC: trcObj, } } + +func loadChain(t *testing.T, desc ChainDesc) decoded.Chain { + t.Helper() + file := filepath.Join(tmpDir, desc.File()) + var err error + var chain decoded.Chain + chain.Raw, err = ioutil.ReadFile(file) + require.NoError(t, err, help) + chain.Chain, err = cert.ParseChain(chain.Raw) + require.NoError(t, err, help) + chain.Issuer, err = chain.Chain.Issuer.Encoded.Decode() + require.NoError(t, err, help) + chain.AS, err = chain.Chain.AS.Encoded.Decode() + require.NoError(t, err, help) + return chain +}