From 55d9a8cffe23657e42b6710c104aca42a81aab45 Mon Sep 17 00:00:00 2001 From: gaowenju Date: Wed, 1 Jun 2022 19:00:36 +0800 Subject: [PATCH] feat: close connection after responding the short-connection request Change-Id: I22d17642339ece5f198936bf30bd50768808f46a --- pkg/common/errors/errors.go | 15 ++++++++------- pkg/protocol/http1/server.go | 9 +++++---- pkg/route/engine.go | 2 +- pkg/route/engine_test.go | 33 +++++++++++++++++++++++++++++++-- 4 files changed, 45 insertions(+), 14 deletions(-) diff --git a/pkg/common/errors/errors.go b/pkg/common/errors/errors.go index 0c1991556..38e90c9c1 100644 --- a/pkg/common/errors/errors.go +++ b/pkg/common/errors/errors.go @@ -49,13 +49,14 @@ import ( var ( // These errors are the base error, which are used for checking in errors.Is() - ErrNeedMore = errors.New("need more data") - ErrChunkedStream = errors.New("chunked stream") - ErrBodyTooLarge = errors.New("body size exceeds the given limit") - ErrHijacked = errors.New("connection has been hijacked") - ErrIdleTimeout = errors.New("idle timeout") - ErrTimeout = errors.New("timeout") - ErrNothingRead = errors.New("nothing read") + ErrNeedMore = errors.New("need more data") + ErrChunkedStream = errors.New("chunked stream") + ErrBodyTooLarge = errors.New("body size exceeds the given limit") + ErrHijacked = errors.New("connection has been hijacked") + ErrIdleTimeout = errors.New("idle timeout") + ErrTimeout = errors.New("timeout") + ErrNothingRead = errors.New("nothing read") + ErrShortConnection = errors.New("short connection") ) // ErrorType is an unsigned 64-bit error code as defined in the hertz spec. diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index 4fdaf3d39..b4e9b4e98 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -45,9 +45,10 @@ import ( const NextProtoTLS = suite.HTTP1 var ( - errHijacked = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil) - errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePublic, nil) - errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request") + errHijacked = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil) + errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePublic, nil) + errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection") + errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request") ) type Option struct { @@ -327,7 +328,7 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { } if connectionClose { - return + return errShortConnection } if s.IdleTimeout == 0 { diff --git a/pkg/route/engine.go b/pkg/route/engine.go index b8a626948..6ee5b4a1d 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -365,7 +365,7 @@ func errProcess(conn io.Closer, err error) { }() // Quiet close the connection - if errors.Is(err, errs.ErrIdleTimeout) { + if errors.Is(err, errs.ErrShortConnection) || errors.Is(err, errs.ErrIdleTimeout) { return } diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 96e950626..451ea35db 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -43,16 +43,19 @@ package route import ( "context" "crypto/tls" + "errors" "fmt" "html/template" "io/ioutil" "net" "net/http" + "sync/atomic" "testing" "time" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/config" + errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" ) @@ -148,9 +151,35 @@ func TestEngineUnescapeRaw(t *testing.T) { } } +func TestConnectionClose(t *testing.T) { + engine := NewEngine(config.NewOptions(nil)) + atomic.StoreUint32(&engine.status, statusRunning) + engine.Init() + engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { + ctx.String(200, "ok") + }) + conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n") + err := engine.Serve(context.Background(), conn) + assert.True(t, errors.Is(err, errs.ErrShortConnection)) +} + +func TestConnectionClose01(t *testing.T) { + engine := NewEngine(config.NewOptions(nil)) + atomic.StoreUint32(&engine.status, statusRunning) + engine.Init() + engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { + ctx.SetConnectionClose() + ctx.String(200, "ok") + }) + conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + err := engine.Serve(context.Background(), conn) + assert.True(t, errors.Is(err, errs.ErrShortConnection)) +} + func TestIdleTimeout(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) engine.options.IdleTimeout = 0 + atomic.StoreUint32(&engine.status, statusRunning) engine.Init() engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { time.Sleep(100 * time.Millisecond) @@ -180,9 +209,9 @@ func TestIdleTimeout(t *testing.T) { func TestIdleTimeout01(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) engine.options.IdleTimeout = 1 * time.Second - engine.status = statusRunning + atomic.StoreUint32(&engine.status, statusRunning) engine.Init() - // engine.status = route.statusRunning + atomic.StoreUint32(&engine.status, statusRunning) engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { time.Sleep(10 * time.Millisecond) ctx.String(200, "ok")