From 9778df485170edc5acf47b479cc271d65da09676 Mon Sep 17 00:00:00 2001 From: nate Date: Tue, 28 Sep 2021 10:57:18 -0600 Subject: [PATCH] [dbnode,query] Ensure only single listener for interrupt channel (#3778) Passing the interrupt channel to multiple goroutines could cause a race where the main thread ends up missing the interrupt that triggers a server shutdown. This commit ensures that only a single goroutine is listening for the interrupt at a given time and all other interested parties can check the interrupted channel. The interrupted channel will be closed as soon as an interrupt is received. Since closed channels return immediately, this allows any interested goroutine to know if it should terminate by simply checking the interrupted channel. --- src/cluster/kv/util/runtime/options.go | 18 +++--- src/cluster/kv/util/runtime/value.go | 2 +- src/cluster/services/options.go | 10 ++-- src/cluster/services/services.go | 19 ++++--- src/cluster/services/services_mock.go | 26 ++++----- src/cluster/services/services_test.go | 9 +-- src/cluster/services/types.go | 8 +-- .../m3coordinator/downsample/options.go | 4 +- src/dbnode/environment/config.go | 4 +- src/dbnode/server/server.go | 23 ++++++-- src/metrics/matcher/namespaces.go | 2 +- src/metrics/matcher/namespaces_test.go | 15 +++-- src/metrics/matcher/options.go | 18 +++--- src/query/server/query.go | 29 +++++++--- src/x/os/interrupt.go | 52 ++++++++++++++++++ src/x/os/interrupt_test.go | 55 +++++++++++++++++++ src/x/watch/options.go | 18 +++--- src/x/watch/value.go | 14 +++-- src/x/watch/value_test.go | 10 ++-- 19 files changed, 234 insertions(+), 102 deletions(-) create mode 100644 src/x/os/interrupt_test.go diff --git a/src/cluster/kv/util/runtime/options.go b/src/cluster/kv/util/runtime/options.go index b51a1ab612..d5a331d941 100644 --- a/src/cluster/kv/util/runtime/options.go +++ b/src/cluster/kv/util/runtime/options.go @@ -63,11 +63,11 @@ type Options interface { // ProcessFn returns the process function. ProcessFn() ProcessFn - // InterruptCh returns the interrupt channel. - InterruptCh() <-chan error + // InterruptedCh returns the interrupted channel. + InterruptedCh() <-chan struct{} - // SetInterruptCh sets the interrupt channel. - SetInterruptCh(value <-chan error) Options + // SetInterruptedCh sets the interrupted channel. + SetInterruptedCh(value <-chan struct{}) Options } type options struct { @@ -76,7 +76,7 @@ type options struct { kvStore kv.Store unmarshalFn UnmarshalFn processFn ProcessFn - interruptCh <-chan error + interruptedCh <-chan struct{} } // NewOptions creates a new set of options. @@ -137,11 +137,11 @@ func (o *options) ProcessFn() ProcessFn { return o.processFn } -func (o *options) SetInterruptCh(ch <-chan error) Options { - o.interruptCh = ch +func (o *options) SetInterruptedCh(ch <-chan struct{}) Options { + o.interruptedCh = ch return o } -func (o *options) InterruptCh() <-chan error { - return o.interruptCh +func (o *options) InterruptedCh() <-chan struct{} { + return o.interruptedCh } diff --git a/src/cluster/kv/util/runtime/value.go b/src/cluster/kv/util/runtime/value.go index 0775371ef0..7a10ad475c 100644 --- a/src/cluster/kv/util/runtime/value.go +++ b/src/cluster/kv/util/runtime/value.go @@ -83,7 +83,7 @@ func (v *value) initValue() { SetGetUpdateFn(v.getUpdateFn). SetProcessFn(v.updateFn). SetKey(v.key). - SetInterruptCh(v.opts.InterruptCh()) + SetInterruptedCh(v.opts.InterruptedCh()) v.Value = watch.NewValue(valueOpts) } diff --git a/src/cluster/services/options.go b/src/cluster/services/options.go index f172b92726..ca9ea97ec4 100644 --- a/src/cluster/services/options.go +++ b/src/cluster/services/options.go @@ -252,17 +252,17 @@ func NewQueryOptions() QueryOptions { return new(queryOptions) } type queryOptions struct { includeUnhealthy bool - interruptCh <-chan error + interruptedCh <-chan struct{} } func (qo *queryOptions) IncludeUnhealthy() bool { return qo.includeUnhealthy } func (qo *queryOptions) SetIncludeUnhealthy(h bool) QueryOptions { qo.includeUnhealthy = h; return qo } -func (qo *queryOptions) InterruptCh() <-chan error { - return qo.interruptCh +func (qo *queryOptions) InterruptedCh() <-chan struct{} { + return qo.interruptedCh } -func (qo *queryOptions) SetInterruptCh(ch <-chan error) QueryOptions { - qo.interruptCh = ch +func (qo *queryOptions) SetInterruptedCh(ch <-chan struct{}) QueryOptions { + qo.interruptedCh = ch return qo } diff --git a/src/cluster/services/services.go b/src/cluster/services/services.go index 27d52c065b..4d0486973a 100644 --- a/src/cluster/services/services.go +++ b/src/cluster/services/services.go @@ -33,6 +33,7 @@ import ( ps "github.com/m3db/m3/src/cluster/placement/service" "github.com/m3db/m3/src/cluster/placement/storage" "github.com/m3db/m3/src/cluster/shard" + xos "github.com/m3db/m3/src/x/os" xwatch "github.com/m3db/m3/src/x/watch" "github.com/uber-go/tally" @@ -307,7 +308,7 @@ func (c *client) Watch(sid ServiceID, opts QueryOptions) (Watch, error) { return nil, err } - initValue, err := c.waitForInitValue(kvm.kv, placementWatch, sid, c.opts.InitTimeout(), opts.InterruptCh()) + initValue, err := c.waitForInitValue(kvm.kv, placementWatch, sid, c.opts.InitTimeout(), opts.InterruptedCh()) if err != nil { return nil, fmt.Errorf("could not get init value for '%s', err: %w", key, err) } @@ -592,12 +593,12 @@ func (c *client) waitForInitValue( w kv.ValueWatch, sid ServiceID, timeout time.Duration, - interruptCh <-chan error, + interruptedCh <-chan struct{}, ) (kv.Value, error) { - if interruptCh == nil { - // NB(nate): if no interrupt channel is provided, then this wait is not + if interruptedCh == nil { + // NB(nate): if no interrupted channel is provided, then this wait is not // gracefully interruptable. - interruptCh = make(chan error) + interruptedCh = make(chan struct{}) } if timeout < 0 { @@ -607,8 +608,8 @@ func (c *client) waitForInitValue( select { case <-w.C(): return w.Get(), nil - case err := <-interruptCh: - return nil, err + case <-interruptedCh: + return nil, xos.ErrInterrupted } } select { @@ -616,8 +617,8 @@ func (c *client) waitForInitValue( return w.Get(), nil case <-time.After(timeout): return kvStore.Get(c.placementKeyFn(sid)) - case err := <-interruptCh: - return nil, err + case <-interruptedCh: + return nil, xos.ErrInterrupted } } diff --git a/src/cluster/services/services_mock.go b/src/cluster/services/services_mock.go index 3ac1ea907f..e9e3876967 100644 --- a/src/cluster/services/services_mock.go +++ b/src/cluster/services/services_mock.go @@ -1271,18 +1271,18 @@ func (mr *MockQueryOptionsMockRecorder) IncludeUnhealthy() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncludeUnhealthy", reflect.TypeOf((*MockQueryOptions)(nil).IncludeUnhealthy)) } -// InterruptCh mocks base method. -func (m *MockQueryOptions) InterruptCh() <-chan error { +// InterruptedCh mocks base method. +func (m *MockQueryOptions) InterruptedCh() <-chan struct{} { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InterruptCh") - ret0, _ := ret[0].(<-chan error) + ret := m.ctrl.Call(m, "InterruptedCh") + ret0, _ := ret[0].(<-chan struct{}) return ret0 } -// InterruptCh indicates an expected call of InterruptCh. -func (mr *MockQueryOptionsMockRecorder) InterruptCh() *gomock.Call { +// InterruptedCh indicates an expected call of InterruptedCh. +func (mr *MockQueryOptionsMockRecorder) InterruptedCh() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterruptCh", reflect.TypeOf((*MockQueryOptions)(nil).InterruptCh)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterruptedCh", reflect.TypeOf((*MockQueryOptions)(nil).InterruptedCh)) } // SetIncludeUnhealthy mocks base method. @@ -1299,18 +1299,18 @@ func (mr *MockQueryOptionsMockRecorder) SetIncludeUnhealthy(h interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIncludeUnhealthy", reflect.TypeOf((*MockQueryOptions)(nil).SetIncludeUnhealthy), h) } -// SetInterruptCh mocks base method. -func (m *MockQueryOptions) SetInterruptCh(value <-chan error) QueryOptions { +// SetInterruptedCh mocks base method. +func (m *MockQueryOptions) SetInterruptedCh(value <-chan struct{}) QueryOptions { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetInterruptCh", value) + ret := m.ctrl.Call(m, "SetInterruptedCh", value) ret0, _ := ret[0].(QueryOptions) return ret0 } -// SetInterruptCh indicates an expected call of SetInterruptCh. -func (mr *MockQueryOptionsMockRecorder) SetInterruptCh(value interface{}) *gomock.Call { +// SetInterruptedCh indicates an expected call of SetInterruptedCh. +func (mr *MockQueryOptionsMockRecorder) SetInterruptedCh(value interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInterruptCh", reflect.TypeOf((*MockQueryOptions)(nil).SetInterruptCh), value) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInterruptedCh", reflect.TypeOf((*MockQueryOptions)(nil).SetInterruptedCh), value) } // MockMetadata is a mock of Metadata interface. diff --git a/src/cluster/services/services_test.go b/src/cluster/services/services_test.go index 32cb2a6936..076f061103 100644 --- a/src/cluster/services/services_test.go +++ b/src/cluster/services/services_test.go @@ -35,6 +35,7 @@ import ( "github.com/m3db/m3/src/cluster/placement/storage" "github.com/m3db/m3/src/cluster/shard" "github.com/m3db/m3/src/x/instrument" + xos "github.com/m3db/m3/src/x/os" xwatch "github.com/m3db/m3/src/x/watch" "github.com/golang/mock/gomock" @@ -920,15 +921,15 @@ func TestWatchInterruptedWithTimeout(t *testing.T) { func testWatchInterrupted(t *testing.T, s Services) { sid := NewServiceID().SetName("m3db").SetZone("zone1") - interruptCh := make(chan error, 1) - interruptCh <- errors.New("interrupt") + interruptedCh := make(chan struct{}) + close(interruptedCh) qopts := NewQueryOptions(). SetIncludeUnhealthy(true). - SetInterruptCh(interruptCh) + SetInterruptedCh(interruptedCh) _, err := s.Watch(sid, qopts) require.Error(t, err) - require.Contains(t, err.Error(), "interrupt") + require.True(t, errors.Is(err, xos.ErrInterrupted)) } func TestHeartbeatService(t *testing.T) { diff --git a/src/cluster/services/types.go b/src/cluster/services/types.go index 61745d713e..f7ede189e8 100644 --- a/src/cluster/services/types.go +++ b/src/cluster/services/types.go @@ -288,11 +288,11 @@ type QueryOptions interface { // SetIncludeUnhealthy sets the value of IncludeUnhealthy. SetIncludeUnhealthy(h bool) QueryOptions - // InterruptCh returns the interrupt channel. - InterruptCh() <-chan error + // InterruptedCh returns the interrupted channel. + InterruptedCh() <-chan struct{} - // SetInterruptCh sets the interrupt channel. - SetInterruptCh(value <-chan error) QueryOptions + // SetInterruptedCh sets the interrupted channel. + SetInterruptedCh(value <-chan struct{}) QueryOptions } // Metadata contains the metadata for a service. diff --git a/src/cmd/services/m3coordinator/downsample/options.go b/src/cmd/services/m3coordinator/downsample/options.go index 251b43ed2d..6d2c83d06e 100644 --- a/src/cmd/services/m3coordinator/downsample/options.go +++ b/src/cmd/services/m3coordinator/downsample/options.go @@ -131,7 +131,7 @@ type DownsamplerOptions struct { TagOptions models.TagOptions MetricsAppenderPoolOptions pool.ObjectPoolOptions RWOptions xio.Options - InterruptCh <-chan error + InterruptedCh <-chan struct{} } // AutoMappingRule is a mapping rule to apply to metrics. @@ -728,7 +728,7 @@ func (cfg Configuration) newAggregator(o DownsamplerOptions) (agg, error) { SetKVStore(o.RulesKVStore). SetNamespaceTag([]byte(namespaceTag)). SetRequireNamespaceWatchOnInit(cfg.Matcher.RequireNamespaceWatchOnInit). - SetInterruptCh(o.InterruptCh) + SetInterruptedCh(o.InterruptedCh) // NB(r): If rules are being explicitly set in config then we are // going to use an in memory KV store for rules and explicitly set them up. diff --git a/src/dbnode/environment/config.go b/src/dbnode/environment/config.go index a3f4b412c6..7c532d66f2 100644 --- a/src/dbnode/environment/config.go +++ b/src/dbnode/environment/config.go @@ -194,7 +194,7 @@ func (c ConfigureResults) SyncCluster() (ConfigureResult, error) { // ConfigurationParameters are options used to create new ConfigureResults type ConfigurationParameters struct { - InterruptCh <-chan error + InterruptedCh <-chan struct{} InstrumentOpts instrument.Options HashingSeed uint32 HostID string @@ -316,7 +316,7 @@ func (c Configuration) configureDynamic(cfgParams ConfigurationParameters) (Conf SetServiceID(serviceID). SetQueryOptions(services.NewQueryOptions(). SetIncludeUnhealthy(true). - SetInterruptCh(cfgParams.InterruptCh)). + SetInterruptedCh(cfgParams.InterruptedCh)). SetInstrumentOptions(cfgParams.InstrumentOpts). SetHashGen(sharding.NewHashGenWithSeed(cfgParams.HashingSeed)) topoInit := topology.NewDynamicInitializer(topoOpts) diff --git a/src/dbnode/server/server.go b/src/dbnode/server/server.go index e44765b90c..a9315b8eea 100644 --- a/src/dbnode/server/server.go +++ b/src/dbnode/server/server.go @@ -214,6 +214,13 @@ func Run(runOpts RunOptions) { }() } + interruptOpts := xos.NewInterruptOptions() + if runOpts.InterruptCh != nil { + interruptOpts.InterruptCh = runOpts.InterruptCh + } + intWatchCancel := xos.WatchForInterrupt(logger, interruptOpts) + defer intWatchCancel() + defer logger.Sync() cfg.Debug.SetRuntimeValues(logger) @@ -746,7 +753,7 @@ func Run(runOpts RunOptions) { logger.Info("creating dynamic config service client with m3cluster") envCfgResults, err = envConfig.Configure(environment.ConfigurationParameters{ - InterruptCh: runOpts.InterruptCh, + InterruptedCh: interruptOpts.InterruptedCh, InstrumentOpts: iOpts, HashingSeed: cfg.Hashing.Seed, NewDirectoryMode: newDirectoryMode, @@ -759,7 +766,7 @@ func Run(runOpts RunOptions) { logger.Info("creating static config service client with m3cluster") envCfgResults, err = envConfig.Configure(environment.ConfigurationParameters{ - InterruptCh: runOpts.InterruptCh, + InterruptedCh: interruptOpts.InterruptedCh, InstrumentOpts: iOpts, HostID: hostID, ForceColdWritesEnabled: forceColdWrites, @@ -1091,10 +1098,14 @@ func Run(runOpts RunOptions) { ) }() - // Wait for process interrupt. - xos.WaitForInterrupt(logger, xos.InterruptOptions{ - InterruptCh: runOpts.InterruptCh, - }) + // Stop our async watch and now block waiting for the interrupt. + intWatchCancel() + select { + case <-interruptOpts.InterruptedCh: + logger.Warn("interrupt already received. closing") + default: + xos.WaitForInterrupt(logger, interruptOpts) + } // Attempt graceful server close. closedCh := make(chan struct{}) diff --git a/src/metrics/matcher/namespaces.go b/src/metrics/matcher/namespaces.go index 79e42ad5fd..398a1688a7 100644 --- a/src/metrics/matcher/namespaces.go +++ b/src/metrics/matcher/namespaces.go @@ -144,7 +144,7 @@ func NewNamespaces(key string, opts Options) Namespaces { SetKVStore(n.store). SetUnmarshalFn(n.toNamespaces). SetProcessFn(n.process). - SetInterruptCh(opts.InterruptCh()) + SetInterruptedCh(opts.InterruptedCh()) n.Value = runtime.NewValue(key, valueOpts) return n } diff --git a/src/metrics/matcher/namespaces_test.go b/src/metrics/matcher/namespaces_test.go index 60e383f801..2895beab22 100644 --- a/src/metrics/matcher/namespaces_test.go +++ b/src/metrics/matcher/namespaces_test.go @@ -33,7 +33,6 @@ import ( "github.com/m3db/m3/src/metrics/generated/proto/rulepb" "github.com/m3db/m3/src/metrics/matcher/cache" "github.com/m3db/m3/src/metrics/rules" - xos "github.com/m3db/m3/src/x/os" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/require" @@ -128,14 +127,14 @@ func TestNamespacesWatchRulesetHardErr(t *testing.T) { } func TestNamespacesOpenWithInterrupt(t *testing.T) { - interruptCh := make(chan error, 1) - interruptCh <- xos.NewInterruptError("interrupt!") + interruptedCh := make(chan struct{}, 1) + interruptedCh <- struct{}{} - _, _, nss, _ := testNamespacesWithInterruptCh(interruptCh) + _, _, nss, _ := testNamespacesWithInterruptedCh(interruptedCh) err := nss.Open() require.Error(t, err) - require.Equal(t, err.Error(), "interrupt!") + require.Equal(t, err.Error(), "interrupted") } func TestToNamespacesNilValue(t *testing.T) { @@ -289,7 +288,7 @@ func TestNamespacesProcess(t *testing.T) { } } -func testNamespacesWithInterruptCh(interruptCh chan error) (kv.TxnStore, cache.Cache, *namespaces, Options) { +func testNamespacesWithInterruptedCh(interruptedCh chan struct{}) (kv.TxnStore, cache.Cache, *namespaces, Options) { store := mem.NewStore() cache := newMemCache() opts := NewOptions(). @@ -306,13 +305,13 @@ func testNamespacesWithInterruptCh(interruptCh chan error) (kv.TxnStore, cache.C SetOnRuleSetUpdatedFn(func(namespace []byte, ruleSet RuleSet) { cache.Register(namespace, ruleSet) }). - SetInterruptCh(interruptCh) + SetInterruptedCh(interruptedCh) return store, cache, NewNamespaces(testNamespacesKey, opts).(*namespaces), opts } func testNamespaces() (kv.TxnStore, cache.Cache, *namespaces, Options) { - return testNamespacesWithInterruptCh(nil) + return testNamespacesWithInterruptedCh(nil) } type memResults struct { diff --git a/src/metrics/matcher/options.go b/src/metrics/matcher/options.go index 176d4342da..add505d6b6 100644 --- a/src/metrics/matcher/options.go +++ b/src/metrics/matcher/options.go @@ -142,11 +142,11 @@ type Options interface { // RequireNamespaceWatchOnInit returns the flag to ensure matcher is initialized with a loaded namespace watch. RequireNamespaceWatchOnInit() bool - // InterruptCh returns the interrupt channel. - InterruptCh() <-chan error + // InterruptedCh returns the interrupted channel. + InterruptedCh() <-chan struct{} - // SetInterruptCh sets the interrupt channel. - SetInterruptCh(value <-chan error) Options + // SetInterruptedCh sets the interrupted channel. + SetInterruptedCh(value <-chan struct{}) Options } type options struct { @@ -164,7 +164,7 @@ type options struct { onNamespaceRemovedFn OnNamespaceRemovedFn onRuleSetUpdatedFn OnRuleSetUpdatedFn requireNamespaceWatchOnInit bool - interruptCh <-chan error + interruptedCh <-chan struct{} } // NewOptions creates a new set of options. @@ -325,13 +325,13 @@ func (o *options) RequireNamespaceWatchOnInit() bool { return o.requireNamespaceWatchOnInit } -func (o *options) SetInterruptCh(ch <-chan error) Options { - o.interruptCh = ch +func (o *options) SetInterruptedCh(ch <-chan struct{}) Options { + o.interruptedCh = ch return o } -func (o *options) InterruptCh() <-chan error { - return o.interruptCh +func (o *options) InterruptedCh() <-chan struct{} { + return o.interruptedCh } func defaultRuleSetKeyFn(namespace []byte) string { diff --git a/src/query/server/query.go b/src/query/server/query.go index 6c6ef9200a..04d8054a60 100644 --- a/src/query/server/query.go +++ b/src/query/server/query.go @@ -222,6 +222,13 @@ func Run(runOpts RunOptions) RunResult { }() } + interruptOpts := xos.NewInterruptOptions() + if runOpts.InterruptCh != nil { + interruptOpts.InterruptCh = runOpts.InterruptCh + } + intWatchCancel := xos.WatchForInterrupt(logger, interruptOpts) + defer intWatchCancel() + defer logger.Sync() cfg.Debug.SetRuntimeValues(logger) @@ -483,6 +490,7 @@ func Run(runOpts RunOptions) RunResult { downsampler, clusterClient, err = newDownsamplerAsync(cfg.Downsample, etcdConfig, backendStorage, clusterNamespacesWatcher, tsdbOpts.TagOptions(), clockOpts, instrumentOptions, rwOpts, runOpts, + interruptOpts, ) if err != nil { var interruptErr *xos.InterruptError @@ -516,6 +524,7 @@ func Run(runOpts RunOptions) RunResult { downsampler, clusterClient, err = newDownsamplerAsync(cfg.Downsample, cfg.ClusterManagement.Etcd, backendStorage, clusterNamespacesWatcher, tsdbOpts.TagOptions(), clockOpts, instrumentOptions, rwOpts, runOpts, + interruptOpts, ) if err != nil { logger.Fatal("unable to setup downsampler for prom remote backend", zap.Error(err)) @@ -735,10 +744,14 @@ func Run(runOpts RunOptions) RunResult { defer server.Close() } - // Wait for process interrupt. - xos.WaitForInterrupt(logger, xos.InterruptOptions{ - InterruptCh: runOpts.InterruptCh, - }) + // Stop our async watch and now block waiting for the interrupt. + intWatchCancel() + select { + case <-interruptOpts.InterruptedCh: + logger.Warn("interrupt already received. closing") + default: + xos.WaitForInterrupt(logger, interruptOpts) + } return runResult } @@ -789,7 +802,7 @@ func resolveEtcdForM3DB(cfg config.Configuration) (*etcdclient.Configuration, er func newDownsamplerAsync( cfg downsample.Configuration, etcdCfg *etcdclient.Configuration, storage storage.Appender, clusterNamespacesWatcher m3.ClusterNamespacesWatcher, tagOptions models.TagOptions, clockOpts clock.Options, - instrumentOptions instrument.Options, rwOpts xio.Options, runOpts RunOptions, + instrumentOptions instrument.Options, rwOpts xio.Options, runOpts RunOptions, interruptOpts xos.InterruptOptions, ) (downsample.Downsampler, clusterclient.Client, error) { var ( clusterClient clusterclient.Client @@ -821,7 +834,7 @@ func newDownsamplerAsync( cfg, clusterClient, storage, clusterNamespacesWatcher, tagOptions, clockOpts, instrumentOptions, rwOpts, runOpts.ApplyCustomRuleStore, - runOpts.InterruptCh) + interruptOpts.InterruptedCh) if err != nil { return nil, err } @@ -861,7 +874,7 @@ func newDownsampler( instrumentOpts instrument.Options, rwOpts xio.Options, applyCustomRuleStore downsample.CustomRuleStoreFn, - interruptCh <-chan error, + interruptedCh <-chan struct{}, ) (downsample.Downsampler, error) { // Namespace the downsampler metrics. instrumentOpts = instrumentOpts.SetMetricsScope( @@ -917,7 +930,7 @@ func newDownsampler( TagOptions: tagOptions, MetricsAppenderPoolOptions: metricsAppenderPoolOptions, RWOptions: rwOpts, - InterruptCh: interruptCh, + InterruptedCh: interruptedCh, }) if err != nil { return nil, fmt.Errorf("unable to create downsampler: %w", err) diff --git a/src/x/os/interrupt.go b/src/x/os/interrupt.go index 96729ea93c..dcf83c5267 100644 --- a/src/x/os/interrupt.go +++ b/src/x/os/interrupt.go @@ -24,6 +24,7 @@ import ( "fmt" "os" "os/signal" + "sync" "syscall" "go.uber.org/zap" @@ -34,6 +35,12 @@ type InterruptOptions struct { // InterruptChannel is an existing interrupt channel, if none // specified one will be created. InterruptCh <-chan error + + // InterruptedChannel is a channel that will be closed once an + // interrupt has been seen. Use this to pass to goroutines who + // want to be notified about interruptions that you don't want + // consuming from the main interrupt channel. + InterruptedCh chan struct{} } // InterruptError is an error representing an interrupt. @@ -41,6 +48,18 @@ type InterruptError struct { interrupt string } +// ErrInterrupted is an error indicating that the interrupted channel was closed, +// meaning an interrupt was received on the main interrupt channel. +var ErrInterrupted = NewInterruptError("interrupted") + +// NewInterruptOptions creates InterruptOptions with sane defaults. +func NewInterruptOptions() InterruptOptions { + return InterruptOptions{ + InterruptCh: NewInterruptChannel(1), + InterruptedCh: make(chan struct{}), + } +} + // NewInterruptError creates a new InterruptError. func NewInterruptError(interrupt string) error { return &InterruptError{interrupt: interrupt} @@ -50,6 +69,35 @@ func (i *InterruptError) Error() string { return i.interrupt } +// WatchForInterrupt watches for interrupts in a non-blocking fashion and closes +// the interrupted channel when an interrupt is found. Use this method to +// watch for interrupts while the caller continues to execute (e.g. during +// server startup). To ensure child goroutines get properly closed, pass them +// the interrupted channel. If the interrupted channel is closed, then the +// goroutine knows to stop its work. This method returns a function that +// can be used to stop the watch. +func WatchForInterrupt(logger *zap.Logger, opts InterruptOptions) func() { + interruptCh := opts.InterruptCh + closed := make(chan struct{}) + go func() { + select { + case err := <-interruptCh: + logger.Warn("interrupt", zap.Error(err)) + close(opts.InterruptedCh) + case <-closed: + logger.Info("interrupt watch stopped") + return + } + }() + + var doOnce sync.Once + return func() { + doOnce.Do(func() { + close(closed) + }) + } +} + // WaitForInterrupt will wait for an interrupt to occur and return when done. func WaitForInterrupt(logger *zap.Logger, opts InterruptOptions) { // Handle interrupts. @@ -63,6 +111,10 @@ func WaitForInterrupt(logger *zap.Logger, opts InterruptOptions) { } logger.Warn("interrupt", zap.Error(<-interruptCh)) + + if opts.InterruptedCh != nil { + close(opts.InterruptedCh) + } } // NewInterruptChannel will return an interrupt channel useful with multiple diff --git a/src/x/os/interrupt_test.go b/src/x/os/interrupt_test.go new file mode 100644 index 0000000000..ca25c5b014 --- /dev/null +++ b/src/x/os/interrupt_test.go @@ -0,0 +1,55 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package xos + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestWatchForInterrupt(t *testing.T) { + interruptCh := make(chan error, 1) + interruptCh <- NewInterruptError("interrupt time") + + opts := NewInterruptOptions() + opts.InterruptCh = interruptCh + + _ = WatchForInterrupt(zap.NewNop(), opts) + + select { + case <-opts.InterruptedCh: + case <-time.After(5 * time.Second): + t.Fail() + } +} + +func TestWatchForInterruptCloseTwice(t *testing.T) { + opts := NewInterruptOptions() + closer := WatchForInterrupt(zap.NewNop(), opts) + + require.NotPanics(t, func() { + closer() + closer() + }) +} diff --git a/src/x/watch/options.go b/src/x/watch/options.go index 5021a436e0..3ff3a3df71 100644 --- a/src/x/watch/options.go +++ b/src/x/watch/options.go @@ -68,11 +68,11 @@ type Options interface { // SetKey sets the key for the watch. SetKey(key string) Options - // InterruptCh returns the interrupt channel. - InterruptCh() <-chan error + // InterruptedCh returns the interrupted channel. + InterruptedCh() <-chan struct{} - // SetInterruptCh sets the interrupt channel. - SetInterruptCh(value <-chan error) Options + // SetInterruptedCh sets the interrupted channel. + SetInterruptedCh(value <-chan struct{}) Options } type options struct { @@ -82,7 +82,7 @@ type options struct { getUpdateFn GetUpdateFn processFn ProcessFn key string - interruptCh <-chan error + interruptedCh <-chan struct{} } // NewOptions creates a new set of options. @@ -153,11 +153,11 @@ func (o *options) SetKey(key string) Options { return &opts } -func (o *options) InterruptCh() <-chan error { - return o.interruptCh +func (o *options) InterruptedCh() <-chan struct{} { + return o.interruptedCh } -func (o *options) SetInterruptCh(ch <-chan error) Options { - o.interruptCh = ch +func (o *options) SetInterruptedCh(ch <-chan struct{}) Options { + o.interruptedCh = ch return o } diff --git a/src/x/watch/value.go b/src/x/watch/value.go index 24ac9a27a9..82532c3efb 100644 --- a/src/x/watch/value.go +++ b/src/x/watch/value.go @@ -27,6 +27,8 @@ import ( "time" "go.uber.org/zap" + + xos "github.com/m3db/m3/src/x/os" ) var ( @@ -114,11 +116,11 @@ func (v *value) Watch() error { // error condition is resolved. defer func() { go v.watchUpdates(v.updatable) }() - interruptCh := v.opts.InterruptCh() - if interruptCh == nil { - // NB(nate): if no interrupt channel is provided, then this wait is not + interruptedCh := v.opts.InterruptedCh() + if interruptedCh == nil { + // NB(nate): if no interrupted channel is provided, then this wait is not // gracefully interruptable. - interruptCh = make(chan error) + interruptedCh = make(chan struct{}) } select { @@ -128,8 +130,8 @@ func (v *value) Watch() error { innerError: errInitWatchTimeout, key: v.opts.Key(), } - case err = <-interruptCh: - return err + case <-interruptedCh: + return xos.ErrInterrupted } update, err := v.getUpdateFn(v.updatable) diff --git a/src/x/watch/value_test.go b/src/x/watch/value_test.go index b4b9c1718a..a59cd08d49 100644 --- a/src/x/watch/value_test.go +++ b/src/x/watch/value_test.go @@ -27,8 +27,6 @@ import ( "time" "github.com/m3db/m3/src/x/instrument" - xos "github.com/m3db/m3/src/x/os" - "github.com/stretchr/testify/require" ) @@ -95,18 +93,18 @@ func TestValueWatchSuccess(t *testing.T) { } func TestValueWatchInterrupt(t *testing.T) { - interruptCh := make(chan error, 1) - interruptCh <- xos.NewInterruptError("interrupt!") + interruptedCh := make(chan struct{}) + close(interruptedCh) opts := testValueOptions(). - SetInterruptCh(interruptCh). + SetInterruptedCh(interruptedCh). SetNewUpdatableFn(testUpdatableFn(NewWatchable())) val := NewValue(opts).(*value) err := val.Watch() require.Error(t, err) - require.Equal(t, err.Error(), "interrupt!") + require.Equal(t, err.Error(), "interrupted") } func TestValueUnwatchNotWatching(t *testing.T) {