diff --git a/pkg/stub/stub.go b/pkg/stub/stub.go index 3cc837b0..58f84acc 100644 --- a/pkg/stub/stub.go +++ b/pkg/stub/stub.go @@ -137,7 +137,8 @@ type PostUpdateContainerInterface interface { // Stub is the interface the stub provides for the plugin implementation. type Stub interface { - // Run the plugin. Starts the plugin then waits for an error or the plugin to stop + // Run the plugin in blocked way and wait for critical error return from plugin service. + // before this function return, stopped plugin can be restart. Run(context.Context) error // Start the plugin. Start(context.Context) error @@ -255,7 +256,6 @@ type stub struct { rpcs *ttrpc.Server rpcc *ttrpc.Client runtime api.RuntimeService - closeOnce sync.Once started bool doneC chan struct{} srvErrC chan error @@ -288,7 +288,6 @@ func New(p interface{}, opts ...Option) (Stub, error) { idx: os.Getenv(api.PluginIdxEnvVar), socketPath: api.DefaultSocketPath, dialer: func(p string) (stdnet.Conn, error) { return stdnet.Dial("unix", p) }, - doneC: make(chan struct{}), } for _, o := range opts { @@ -319,7 +318,7 @@ func (stub *stub) Start(ctx context.Context) (retErr error) { if stub.started { return fmt.Errorf("stub already started") } - stub.started = true + stub.doneC = make(chan struct{}) err := stub.connect() if err != nil { @@ -401,6 +400,7 @@ func (stub *stub) Start(ctx context.Context) (retErr error) { log.Infof(ctx, "Started plugin %s...", stub.Name()) + stub.started = true return nil } @@ -413,24 +413,31 @@ func (stub *stub) Stop() { stub.close() } +// reset stub to the status that can initiate a new +// NRI connection, the caller must hold lock. func (stub *stub) close() { - stub.closeOnce.Do(func() { - if stub.rpcl != nil { - stub.rpcl.Close() - } - if stub.rpcs != nil { - stub.rpcs.Close() - } - if stub.rpcc != nil { - stub.rpcc.Close() - } - if stub.rpcm != nil { - stub.rpcm.Close() - } - if stub.srvErrC != nil { - <-stub.doneC - } - }) + if !stub.started { + return + } + + if stub.rpcl != nil { + stub.rpcl.Close() + } + if stub.rpcs != nil { + stub.rpcs.Close() + } + if stub.rpcc != nil { + stub.rpcc.Close() + } + if stub.rpcm != nil { + stub.rpcm.Close() + } + if stub.srvErrC != nil { + <-stub.doneC + } + + stub.started = false + stub.conn = nil } // Run the plugin. Start event processing then wait for an error or getting stopped. @@ -441,22 +448,33 @@ func (stub *stub) Run(ctx context.Context) error { return err } - err = <-stub.srvErrC - if err == ttrpc.ErrServerClosed { - return nil + for { + select { + case <-ctx.Done(): + return nil + case err = <-stub.srvErrC: + if isRecoverableErr(err) { + log.Warnf(ctx, "Plugin service stopped", "error", err) + continue + } + return err + } } +} - return err +func isRecoverableErr(err error) bool { + // For now, the error reports from the ttrpc level are regarded as tolerable errors. + return errors.Is(err, ttrpc.ErrProtocol) || + errors.Is(err, ttrpc.ErrClosed) || + errors.Is(err, ttrpc.ErrServerClosed) || + errors.Is(err, ttrpc.ErrStreamClosed) } -// Wait for the plugin to stop. +// Wait for the plugin to stop, should be called after Start() or Run(). func (stub *stub) Wait() { - stub.Lock() - if stub.srvErrC == nil { - return + if stub.started { + <-stub.doneC } - stub.Unlock() - <-stub.doneC } // Name returns the full indexed name of the plugin. @@ -518,7 +536,9 @@ func (stub *stub) register(ctx context.Context) error { // Handle a lost connection. func (stub *stub) connClosed() { + stub.Lock() stub.close() + stub.Unlock() if stub.onClose != nil { stub.onClose() return