From 46f170b59534be2677d5aae23c12c589531c7fc1 Mon Sep 17 00:00:00 2001 From: anrs Date: Fri, 27 Aug 2021 15:44:41 +0800 Subject: [PATCH] use the real context from exterior (#466) Co-authored-by: anrs --- cluster/calcium/calcium.go | 4 ++-- cluster/calcium/wal.go | 20 +++++++++--------- cluster/calcium/wal_test.go | 26 +++++++++++------------ core.go | 2 +- store/redis/ephemeral.go | 12 +++++------ wal/hydro.go | 10 ++++----- wal/hydro_test.go | 13 ++++++------ wal/mocks/WAL.go | 7 ++++--- wal/wal.go | 42 ++++++++----------------------------- wal/wal_test.go | 11 +++++----- 10 files changed, 63 insertions(+), 84 deletions(-) diff --git a/cluster/calcium/calcium.go b/cluster/calcium/calcium.go index 8e874681c..3d50139ca 100644 --- a/cluster/calcium/calcium.go +++ b/cluster/calcium/calcium.go @@ -85,8 +85,8 @@ func New(config types.Config, t *testing.T) (*Calcium, error) { } // DisasterRecover . -func (c *Calcium) DisasterRecover() { - c.wal.Recover() +func (c *Calcium) DisasterRecover(ctx context.Context) { + c.wal.Recover(ctx) } // Finalizer use for defer diff --git a/cluster/calcium/wal.go b/cluster/calcium/wal.go index 5f075dd43..a1bd1c4cf 100644 --- a/cluster/calcium/wal.go +++ b/cluster/calcium/wal.go @@ -79,13 +79,13 @@ func (h *CreateWorkloadHandler) Event() string { } // Check . -func (h *CreateWorkloadHandler) Check(raw interface{}) (bool, error) { +func (h *CreateWorkloadHandler) Check(ctx context.Context, raw interface{}) (bool, error) { wrk, ok := raw.(*types.Workload) if !ok { return false, types.NewDetailedErr(types.ErrInvalidType, raw) } - ctx, cancel := getReplayContext(context.Background()) + ctx, cancel := getReplayContext(ctx) defer cancel() _, err := h.calcium.GetWorkload(ctx, wrk.ID) @@ -122,13 +122,13 @@ func (h *CreateWorkloadHandler) Decode(bs []byte) (interface{}, error) { } // Handle . -func (h *CreateWorkloadHandler) Handle(raw interface{}) error { +func (h *CreateWorkloadHandler) Handle(ctx context.Context, raw interface{}) error { wrk, ok := raw.(*types.Workload) if !ok { return types.NewDetailedErr(types.ErrInvalidType, raw) } - ctx, cancel := getReplayContext(context.Background()) + ctx, cancel := getReplayContext(ctx) defer cancel() // There hasn't been the exact workload metadata, so we must remove it. @@ -173,7 +173,7 @@ func (h *CreateLambdaHandler) Event() string { } // Check . -func (h *CreateLambdaHandler) Check(interface{}) (bool, error) { +func (h *CreateLambdaHandler) Check(context.Context, interface{}) (bool, error) { return true, nil } @@ -194,20 +194,20 @@ func (h *CreateLambdaHandler) Decode(bs []byte) (interface{}, error) { } // Handle . -func (h *CreateLambdaHandler) Handle(raw interface{}) error { +func (h *CreateLambdaHandler) Handle(ctx context.Context, raw interface{}) error { opts, ok := raw.(*types.ListWorkloadsOptions) if !ok { return types.NewDetailedErr(types.ErrInvalidType, raw) } - workloadIDs, err := h.getWorkloadIDs(opts) + workloadIDs, err := h.getWorkloadIDs(ctx, opts) if err != nil { log.Errorf(context.TODO(), "[CreateLambdaHandler.Handle] Get workloads %s/%s/%v failed: %v", opts.Appname, opts.Entrypoint, opts.Labels, err) return err } - ctx, cancel := getReplayContext(context.Background()) + ctx, cancel := getReplayContext(ctx) defer cancel() if err := h.calcium.doRemoveWorkloadSync(ctx, workloadIDs); err != nil { @@ -220,8 +220,8 @@ func (h *CreateLambdaHandler) Handle(raw interface{}) error { return nil } -func (h *CreateLambdaHandler) getWorkloadIDs(opts *types.ListWorkloadsOptions) ([]string, error) { - ctx, cancel := getReplayContext(context.Background()) +func (h *CreateLambdaHandler) getWorkloadIDs(ctx context.Context, opts *types.ListWorkloadsOptions) ([]string, error) { + ctx, cancel := getReplayContext(ctx) defer cancel() workloads, err := h.calcium.ListWorkloads(ctx, opts) diff --git a/cluster/calcium/wal_test.go b/cluster/calcium/wal_test.go index 9c39f2ee9..0fa3d3ae7 100644 --- a/cluster/calcium/wal_test.go +++ b/cluster/calcium/wal_test.go @@ -31,10 +31,10 @@ func TestHandleCreateWorkloadNoHandle(t *testing.T) { defer store.AssertExpectations(t) store.On("GetWorkload", mock.Anything, wrkid).Return(wrk, nil).Once() - c.wal.Recover() + c.wal.Recover(context.TODO()) // Recovers nothing. - c.wal.Recover() + c.wal.Recover(context.TODO()) } func TestHandleCreateWorkloadError(t *testing.T) { @@ -59,12 +59,12 @@ func TestHandleCreateWorkloadError(t *testing.T) { store := c.store.(*storemocks.Store) defer store.AssertExpectations(t) store.On("GetWorkload", mock.Anything, wrkid).Return(wrk, fmt.Errorf("err")).Once() - c.wal.Recover() + c.wal.Recover(context.TODO()) err = types.NewDetailedErr(types.ErrBadCount, fmt.Sprintf("keys: [%s]", wrkid)) store.On("GetWorkload", mock.Anything, wrkid).Return(wrk, err) store.On("GetNode", mock.Anything, wrk.Nodename).Return(nil, fmt.Errorf("err")).Once() - c.wal.Recover() + c.wal.Recover(context.TODO()) store.On("GetNode", mock.Anything, wrk.Nodename).Return(node, nil) eng, ok := node.Engine.(*enginemocks.API) @@ -73,15 +73,15 @@ func TestHandleCreateWorkloadError(t *testing.T) { eng.On("VirtualizationRemove", mock.Anything, wrk.ID, true, true). Return(fmt.Errorf("err")). Once() - c.wal.Recover() + c.wal.Recover(context.TODO()) eng.On("VirtualizationRemove", mock.Anything, wrk.ID, true, true). Return(fmt.Errorf("Error: No such container: %s", wrk.ID)). Once() - c.wal.Recover() + c.wal.Recover(context.TODO()) // Nothing recovered. - c.wal.Recover() + c.wal.Recover(context.TODO()) } func TestHandleCreateWorkloadHandled(t *testing.T) { @@ -118,10 +118,10 @@ func TestHandleCreateWorkloadHandled(t *testing.T) { Return(nil). Once() - c.wal.Recover() + c.wal.Recover(context.TODO()) // Recovers nothing. - c.wal.Recover() + c.wal.Recover(context.TODO()) } func TestHandleCreateLambda(t *testing.T) { @@ -154,7 +154,7 @@ func TestHandleCreateLambda(t *testing.T) { Return(nil, fmt.Errorf("err")). Once() store.On("ListNodeWorkloads", mock.Anything, mock.Anything, mock.Anything).Return(nil, types.ErrNoETCD) - c.wal.Recover() + c.wal.Recover(context.TODO()) store.On("ListWorkloads", mock.Anything, deployOpts.Name, deployOpts.Entrypoint.Name, "", int64(0), deployOpts.Labels). Return([]*types.Workload{wrk}, nil). @@ -179,11 +179,11 @@ func TestHandleCreateLambda(t *testing.T) { Once() lock := &lockmocks.DistributedLock{} - lock.On("Lock", mock.Anything).Return(context.Background(), nil) + lock.On("Lock", mock.Anything).Return(context.TODO(), nil) lock.On("Unlock", mock.Anything).Return(nil) store.On("CreateLock", mock.Anything, mock.Anything).Return(lock, nil) - c.wal.Recover() + c.wal.Recover(context.TODO()) // Recovered nothing. - c.wal.Recover() + c.wal.Recover(context.TODO()) } diff --git a/core.go b/core.go index 64701c643..472ead01a 100644 --- a/core.go +++ b/core.go @@ -81,7 +81,7 @@ func serve(c *cli.Context) error { return err } defer cluster.Finalizer() - cluster.DisasterRecover() + cluster.DisasterRecover(c.Context) stop := make(chan struct{}, 1) vibranium := rpc.New(cluster, config, stop) diff --git a/store/redis/ephemeral.go b/store/redis/ephemeral.go index 4cd6dcd9d..85d38a950 100644 --- a/store/redis/ephemeral.go +++ b/store/redis/ephemeral.go @@ -22,7 +22,7 @@ func (r *Rediaron) StartEphemeral(ctx context.Context, path string, heartbeat ti return nil, nil, errors.Wrap(types.ErrKeyExists, path) } - cctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) expiry := make(chan struct{}) var wg sync.WaitGroup @@ -37,11 +37,11 @@ func (r *Rediaron) StartEphemeral(ctx context.Context, path string, heartbeat ti for { select { case <-tick.C: - if err := r.refreshEphemeral(path, heartbeat); err != nil { + if err := r.refreshEphemeral(ctx, path, heartbeat); err != nil { r.revokeEphemeral(path) return } - case <-cctx.Done(): + case <-ctx.Done(): r.revokeEphemeral(path) return } @@ -55,15 +55,15 @@ func (r *Rediaron) StartEphemeral(ctx context.Context, path string, heartbeat ti } func (r *Rediaron) revokeEphemeral(path string) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) defer cancel() if _, err := r.cli.Del(ctx, path).Result(); err != nil { log.Errorf(context.TODO(), "[refreshEphemeral] revoke with %s failed: %v", path, err) } } -func (r *Rediaron) refreshEphemeral(path string, ttl time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) +func (r *Rediaron) refreshEphemeral(ctx context.Context, path string, ttl time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() _, err := r.cli.Expire(ctx, path, ttl).Result() return err diff --git a/wal/hydro.go b/wal/hydro.go index 3d27f00b4..9a7c63159 100644 --- a/wal/hydro.go +++ b/wal/hydro.go @@ -42,7 +42,7 @@ func (h *Hydro) Register(handler EventHandler) { } // Recover starts a disaster recovery, which will replay all the events. -func (h *Hydro) Recover() { +func (h *Hydro) Recover(ctx context.Context) { ch, _ := h.kv.Scan([]byte(EventPrefix)) events := []HydroEvent{} @@ -62,27 +62,27 @@ func (h *Hydro) Recover() { continue } - if err := h.recover(handler, ev); err != nil { + if err := h.recover(ctx, handler, ev); err != nil { log.Errorf(context.TODO(), "[Recover] handle event %d (%s) failed: %v", ev.ID, ev.Type, err) continue } } } -func (h *Hydro) recover(handler EventHandler, event HydroEvent) error { +func (h *Hydro) recover(ctx context.Context, handler EventHandler, event HydroEvent) error { item, err := handler.Decode(event.Item) if err != nil { return err } - switch handle, err := handler.Check(item); { + switch handle, err := handler.Check(ctx, item); { case err != nil: return err case !handle: return event.Delete() } - if err := handler.Handle(item); err != nil { + if err := handler.Handle(ctx, item); err != nil { return err } diff --git a/wal/hydro_test.go b/wal/hydro_test.go index 7d2e629d3..27b2aef27 100644 --- a/wal/hydro_test.go +++ b/wal/hydro_test.go @@ -1,6 +1,7 @@ package wal import ( + "context" "fmt" "io/ioutil" "os" @@ -74,7 +75,7 @@ func TestRecoverFailedAsNoSuchHandler(t *testing.T) { hydro.handlers.Delete(eventype) - hydro.Recover() + hydro.Recover(context.TODO()) require.True(t, encoded) require.False(t, decoded) require.False(t, checked) @@ -98,7 +99,7 @@ func TestRecoverFailedAsCheckError(t *testing.T) { require.NoError(t, err) require.NotNil(t, commit) - hydro.Recover() + hydro.Recover(context.TODO()) require.True(t, encoded) require.True(t, decoded) require.True(t, checked) @@ -145,7 +146,7 @@ func TestRecoverFailedAsDecodeLogError(t *testing.T) { require.NoError(t, err) require.NotNil(t, commit) - hydro.Recover() + hydro.Recover(context.TODO()) require.True(t, encoded) require.True(t, decoded) require.False(t, checked) @@ -171,7 +172,7 @@ func TestHydroRecoverDiscardNoNeedEvent(t *testing.T) { require.NoError(t, err) require.NotNil(t, commit) - hydro.Recover() + hydro.Recover(context.TODO()) require.True(t, encoded) require.True(t, decoded) require.True(t, checked) @@ -191,7 +192,7 @@ func TestHydroRecover(t *testing.T) { require.NoError(t, err) require.NotNil(t, commit) - hydro.Recover() + hydro.Recover(context.TODO()) require.True(t, encoded) require.True(t, decoded) require.True(t, checked) @@ -236,7 +237,7 @@ func TestHydroRecoverWithRealLithium(t *testing.T) { hydro.Log(handler.event, struct{}{}) hydro.Log(handler.event, struct{}{}) - hydro.Recover() + hydro.Recover(context.TODO()) ch, _ := hydro.kv.Scan([]byte(EventPrefix)) for range ch { diff --git a/wal/mocks/WAL.go b/wal/mocks/WAL.go index 656695c44..890384403 100644 --- a/wal/mocks/WAL.go +++ b/wal/mocks/WAL.go @@ -3,6 +3,7 @@ package mocks import ( + context "context" time "time" mock "github.com/stretchr/testify/mock" @@ -66,9 +67,9 @@ func (_m *WAL) Open(_a0 string, _a1 time.Duration) error { return r0 } -// Recover provides a mock function with given fields: -func (_m *WAL) Recover() { - _m.Called() +// Recover provides a mock function with given fields: _a0 +func (_m *WAL) Recover(_a0 context.Context) { + _m.Called(_a0) } // Register provides a mock function with given fields: _a0 diff --git a/wal/wal.go b/wal/wal.go index a86792cea..45daec9d9 100644 --- a/wal/wal.go +++ b/wal/wal.go @@ -1,6 +1,9 @@ package wal -import "time" +import ( + "context" + "time" +) const ( // EventPrefix indicates the key prefix of all events' keys. @@ -17,7 +20,7 @@ type WAL interface { // Recoverer is the interface that wraps the basic Recover method. type Recoverer interface { - Recover() + Recover(context.Context) } // Registry is the interface that wraps the basic Register method. @@ -51,7 +54,7 @@ func (h SimpleEventHandler) Event() string { } // Check . -func (h SimpleEventHandler) Check(raw interface{}) (bool, error) { +func (h SimpleEventHandler) Check(ctx context.Context, raw interface{}) (bool, error) { return h.check(raw) } @@ -66,45 +69,18 @@ func (h SimpleEventHandler) Decode(bs []byte) (interface{}, error) { } // Handle . -func (h SimpleEventHandler) Handle(raw interface{}) error { +func (h SimpleEventHandler) Handle(ctx context.Context, raw interface{}) error { return h.handle(raw) } // EventHandler is the interface that groups a few methods. type EventHandler interface { Event() string - Check(interface{}) (need bool, err error) + Check(context.Context, interface{}) (need bool, err error) Encode(interface{}) ([]byte, error) Decode([]byte) (interface{}, error) - Handle(interface{}) error + Handle(context.Context, interface{}) error } // Commit is a function for committing an event log. type Commit func() error - -// Register registers a new event to doit. -func Register(handler EventHandler) { - wal.Register(handler) -} - -// Log records a log item. -func Log(event string, item interface{}) (Commit, error) { - return wal.Log(event, item) -} - -// Recover makes a disaster recovery. -func Recover() { - wal.Recover() -} - -// Close closes a WAL file. -func Close() error { - return wal.Close() -} - -// Open opens a WAL file. -func Open(path string, timeout time.Duration) error { - return wal.Open(path, timeout) -} - -var wal WAL = NewHydro() diff --git a/wal/wal_test.go b/wal/wal_test.go index 18ed81266..1792b397a 100644 --- a/wal/wal_test.go +++ b/wal/wal_test.go @@ -38,8 +38,9 @@ func TestRecover(t *testing.T) { path := "/tmp/wal.unitest.wal" os.Remove(path) - require.NoError(t, Open(path, time.Second)) - defer Close() + wal WAL := NewHydro() + require.NoError(t, wal.Open(path, time.Second)) + defer wal.Close() hydro, ok := wal.(*Hydro) require.True(t, ok) @@ -48,7 +49,7 @@ func TestRecover(t *testing.T) { eventype := "create" - Register(SimpleEventHandler{ + wal.Register(SimpleEventHandler{ event: eventype, encode: encode, decode: decode, @@ -56,9 +57,9 @@ func TestRecover(t *testing.T) { handle: handle, }) - Log(eventype, struct{}{}) + wal.Log(eventype, struct{}{}) - Recover() + wal.Recover(context.TODO()) require.True(t, checked) require.True(t, handled) require.True(t, encoded)