diff --git a/context.go b/context.go index 4f9ec67d6..3d7535648 100644 --- a/context.go +++ b/context.go @@ -31,6 +31,11 @@ func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte if nc == nil { return nil, ErrInvalidConnection } + // Check whether the context is done already before making + // the request. + if ctx.Err() != nil { + return nil, ctx.Err() + } nc.mu.Lock() // If user wants the old style. @@ -116,6 +121,9 @@ func (s *Subscription) NextMsgWithContext(ctx context.Context) (*Msg, error) { if s == nil { return nil, ErrBadSubscription } + if ctx.Err() != nil { + return nil, ctx.Err() + } s.mu.Lock() err := s.validateNextMsgState() @@ -124,7 +132,6 @@ func (s *Subscription) NextMsgWithContext(ctx context.Context) (*Msg, error) { return nil, err } - // snapshot mch := s.mch s.mu.Unlock() diff --git a/test/context_test.go b/test/context_test.go index 04359103a..453984f8a 100644 --- a/test/context_test.go +++ b/test/context_test.go @@ -17,7 +17,6 @@ package test import ( "context" - "errors" "strings" "testing" "time" @@ -137,16 +136,6 @@ func testContextRequestWithTimeoutCanceled(t *testing.T, nc *nats.Conn) { // Cancel the context already so that rest of requests fail. cancelCB() - // Wait for context to be eventually canceled. - waitFor(t, 1*time.Millisecond, 50*time.Millisecond, func() error { - select { - case <-ctx.Done(): - return nil - default: - return errors.New("Timeout waiting for context to be canceled") - } - }) - // Context is already canceled so requests should immediately fail. _, err = nc.RequestWithContext(ctx, "fast", []byte("world")) if err == nil {