diff --git a/cluster/calcium/service.go b/cluster/calcium/service.go index aa94f74a1..c8da863e4 100644 --- a/cluster/calcium/service.go +++ b/cluster/calcium/service.go @@ -101,17 +101,22 @@ func (c *Calcium) RegisterService(ctx context.Context) (unregister func(), err e log.Errorf("[RegisterService] failed to get outbound address: %v", err) return } - if err = c.store.RegisterService(ctx, serviceAddress, c.config.GRPCConfig.ServiceHeartbeatInterval); err != nil { + + ctxRegister, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + if err = c.store.RegisterService(ctxRegister, serviceAddress, c.config.GRPCConfig.ServiceHeartbeatInterval); err != nil { log.Errorf("[RegisterService] failed to register service: %v", err) return } done := make(chan struct{}) - ctx, cancel := context.WithCancel(ctx) + ctxHeartbeat, cancelHeartbeat := context.WithCancel(ctx) go func() { defer close(done) defer func() { - if err := c.store.UnregisterService(context.Background(), serviceAddress); err != nil { + ctxUnregister, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := c.store.UnregisterService(ctxUnregister, serviceAddress); err != nil { log.Errorf("[RegisterService] failed to unregister service: %v", err) } }() @@ -120,17 +125,19 @@ func (c *Calcium) RegisterService(ctx context.Context) (unregister func(), err e for { select { case <-timer.C: - if err := c.store.RegisterService(ctx, serviceAddress, c.config.GRPCConfig.ServiceHeartbeatInterval); err != nil { + ctxRegister, cancel = context.WithTimeout(ctxHeartbeat, time.Second) + if err := c.store.RegisterService(ctxRegister, serviceAddress, c.config.GRPCConfig.ServiceHeartbeatInterval); err != nil { log.Errorf("[RegisterService] failed to register service: %v", err) } - case <-ctx.Done(): - log.Infof("[RegisterService] context done: %v", ctx.Err()) + cancel() + case <-ctxRegister.Done(): + log.Infof("[RegisterService] context done: %v", ctxRegister.Err()) return } } }() return func() { - cancel() + cancelHeartbeat() <-done }, err } diff --git a/cluster/calcium/service_test.go b/cluster/calcium/service_test.go index 171f4229a..17d0c6471 100644 --- a/cluster/calcium/service_test.go +++ b/cluster/calcium/service_test.go @@ -21,7 +21,7 @@ func TestServiceStatusStream(t *testing.T) { c.store = store registered := map[string]int{} - store.On("RegisterService", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("time.Duration")).Return( + store.On("RegisterService", mock.AnythingOfType("*context.timerCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("time.Duration")).Return( func(_ context.Context, addr string, _ time.Duration) error { if v, ok := registered[addr]; ok { registered[addr] = v + 1 @@ -31,7 +31,7 @@ func TestServiceStatusStream(t *testing.T) { return nil }, ) - store.On("UnregisterService", mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")).Return( + store.On("UnregisterService", mock.AnythingOfType("*context.timerCtx"), mock.AnythingOfType("string")).Return( func(_ context.Context, addr string) error { delete(registered, addr) return nil