Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: close connection after responding the short-connection request #31

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions pkg/common/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ import (

var (
// These errors are the base error, which are used for checking in errors.Is()
ErrNeedMore = errors.New("need more data")
ErrChunkedStream = errors.New("chunked stream")
ErrBodyTooLarge = errors.New("body size exceeds the given limit")
ErrHijacked = errors.New("connection has been hijacked")
ErrIdleTimeout = errors.New("idle timeout")
ErrTimeout = errors.New("timeout")
ErrNothingRead = errors.New("nothing read")
ErrNeedMore = errors.New("need more data")
ErrChunkedStream = errors.New("chunked stream")
ErrBodyTooLarge = errors.New("body size exceeds the given limit")
ErrHijacked = errors.New("connection has been hijacked")
ErrIdleTimeout = errors.New("idle timeout")
ErrTimeout = errors.New("timeout")
ErrNothingRead = errors.New("nothing read")
ErrShortConnection = errors.New("short connection")
)

// ErrorType is an unsigned 64-bit error code as defined in the hertz spec.
Expand Down
9 changes: 5 additions & 4 deletions pkg/protocol/http1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ import (
const NextProtoTLS = suite.HTTP1

var (
errHijacked = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil)
errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePublic, nil)
errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request")
errHijacked = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil)
errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePublic, nil)
errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection")
errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request")
)

type Option struct {
Expand Down Expand Up @@ -327,7 +328,7 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
}

if connectionClose {
return
return errShortConnection
}

if s.IdleTimeout == 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/route/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func errProcess(conn io.Closer, err error) {
}()

// Quiet close the connection
if errors.Is(err, errs.ErrIdleTimeout) {
if errors.Is(err, errs.ErrShortConnection) || errors.Is(err, errs.ErrIdleTimeout) {
return
}

Expand Down
33 changes: 31 additions & 2 deletions pkg/route/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,19 @@ package route
import (
"context"
"crypto/tls"
"errors"
"fmt"
"html/template"
"io/ioutil"
"net"
"net/http"
"sync/atomic"
"testing"
"time"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/common/config"
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/test/mock"
)
Expand Down Expand Up @@ -148,9 +151,35 @@ func TestEngineUnescapeRaw(t *testing.T) {
}
}

func TestConnectionClose(t *testing.T) {
engine := NewEngine(config.NewOptions(nil))
atomic.StoreUint32(&engine.status, statusRunning)
engine.Init()
engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) {
ctx.String(200, "ok")
})
conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n")
err := engine.Serve(context.Background(), conn)
assert.True(t, errors.Is(err, errs.ErrShortConnection))
}

func TestConnectionClose01(t *testing.T) {
engine := NewEngine(config.NewOptions(nil))
atomic.StoreUint32(&engine.status, statusRunning)
engine.Init()
engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) {
ctx.SetConnectionClose()
ctx.String(200, "ok")
})
conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
err := engine.Serve(context.Background(), conn)
assert.True(t, errors.Is(err, errs.ErrShortConnection))
}

func TestIdleTimeout(t *testing.T) {
engine := NewEngine(config.NewOptions(nil))
engine.options.IdleTimeout = 0
atomic.StoreUint32(&engine.status, statusRunning)
engine.Init()
engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) {
time.Sleep(100 * time.Millisecond)
Expand Down Expand Up @@ -180,9 +209,9 @@ func TestIdleTimeout(t *testing.T) {
func TestIdleTimeout01(t *testing.T) {
engine := NewEngine(config.NewOptions(nil))
engine.options.IdleTimeout = 1 * time.Second
engine.status = statusRunning
atomic.StoreUint32(&engine.status, statusRunning)
engine.Init()
// engine.status = route.statusRunning
atomic.StoreUint32(&engine.status, statusRunning)
engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) {
time.Sleep(10 * time.Millisecond)
ctx.String(200, "ok")
Expand Down