diff --git a/client.go b/client.go index 396e142..e5c7c3e 100644 --- a/client.go +++ b/client.go @@ -357,7 +357,7 @@ func (hc *HostClient) acquireConn() (*clientConn, error) { return c, nil } -func (hc *HostClient) dialHostHard() (conn net.Conn, err error) { +func (hc *HostClient) dialHostHard() (conn network.Conn, err error) { hc.addrsLock.Lock() n := len(hc.addrs) hc.addrsLock.Unlock() @@ -384,8 +384,8 @@ func (hc *HostClient) dialHostHard() (conn net.Conn, err error) { return nil, err } -func dialAddr(addr string, dial network.Dialer, tlsConfig *tls.Config, timeout time.Duration, isTLS bool) (net.Conn, error) { - var conn net.Conn +func dialAddr(addr string, dial network.Dialer, tlsConfig *tls.Config, timeout time.Duration, isTLS bool) (network.Conn, error) { + var conn network.Conn var err error if dial == nil { hlog.Warnf("HERTZ: HostClient: no dialer specified, trying to use default dialer") @@ -516,9 +516,9 @@ func (hc *HostClient) nextAddr() string { return addr } -func (hc *HostClient) newClientConn(c net.Conn, singleUse bool) (*clientConn, error) { +func (hc *HostClient) newClientConn(c network.Conn, singleUse bool) (*clientConn, error) { cc := &clientConn{} - cc.tconn = &h2Conn{c} + cc.tconn = c cc.createdTime = time.Now() cc.readerDone = make(chan struct{}) cc.nextStreamID = 1 @@ -589,7 +589,7 @@ func (hc *HostClient) newClientConn(c net.Conn, singleUse bool) (*clientConn, er // clientConn is the state of a single HTTP/2 client connection to an // HTTP/2 server. type clientConn struct { - tconn net.Conn // usually *tls.Conn, except specialized impls + tconn network.Conn hc *HostClient // readLoop goroutine fields: diff --git a/client_test.go b/client_test.go index c402924..c2af1ae 100644 --- a/client_test.go +++ b/client_test.go @@ -464,20 +464,12 @@ func (c *testNetConn) Close() error { if c.onClose != nil { c.onClose() } else { - if cwrb, ok := c.Conn.(CloseWithoutResetBuffer); ok { - return cwrb.CloseNoResetBuffer() - } else { - return c.Conn.Close() - } + return c.Conn.Close() } } c.closed = true - if cwrb, ok := c.Conn.(CloseWithoutResetBuffer); ok { - return cwrb.CloseNoResetBuffer() - } else { - return c.Conn.Close() - } + return c.Conn.Close() } // Tests that the Transport only keeps one pending dial open per destination address. @@ -826,8 +818,8 @@ func (fw flushWriter) Write(p []byte) (n int, err error) { type clientTester struct { t *testing.T tr *HostClient - cc, sc net.Conn // server and client conn - fr *Framer // server's framer + cc, sc network.Conn // server and client conn + fr *Framer // server's framer client func() error server func() error } @@ -868,7 +860,7 @@ func newClientTester(t *testing.T) *clientTester { t.Fatal(err) } ln.Close() - ct.cc = &h2Conn{cc} + ct.cc = cc // ct.sc = standard.NewConn(sc, 4096) ct.sc = newMockNetworkConn(sc) ct.fr = NewFramer(ct.sc, ct.sc) @@ -2089,11 +2081,7 @@ func (c *noteCloseConn) SetReadTimeout(t time.Duration) error { func (c *noteCloseConn) Close() error { c.onceClose.Do(c.closefn) - if cwrb, ok := c.Conn.(CloseWithoutResetBuffer); ok { - return cwrb.CloseNoResetBuffer() - } else { - return c.Conn.Close() - } + return c.Conn.Close() } // RFC 7540 section 8.1.2.2 @@ -3498,7 +3486,7 @@ func TestHostClientRetryAfterGOAWAY(t *testing.T) { ct := &clientTester{ t: t, tr: tr, - cc: &h2Conn{cc}, + cc: cc, sc: newMockNetworkConn(sc), } ct.fr = NewFramer(sc, ct.sc) diff --git a/go.mod b/go.mod index 12dc5a6..15dce84 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,15 @@ module github.com/hertz-contrib/http2 go 1.18 require ( - github.com/cloudwego/hertz v0.5.3-0.20230208034101-28c304eb7082 + github.com/cloudwego/hertz v0.6.2 golang.org/x/net v0.7.0 ) require ( github.com/bytedance/go-tagexpr/v2 v2.9.2 // indirect github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 // indirect - github.com/bytedance/sonic v1.5.0 // indirect - github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06 // indirect + github.com/bytedance/sonic v1.8.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/cloudwego/netpoll v0.3.1 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/golang/protobuf v1.5.0 // indirect diff --git a/go.sum b/go.sum index bbeb01f..9ee82cf 100644 --- a/go.sum +++ b/go.sum @@ -2,12 +2,16 @@ github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6H github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= -github.com/bytedance/sonic v1.5.0 h1:XWdTi8bwPgxIML+eNV1IwNuTROK6EUrQ65ey8yd6fRQ= +github.com/bytedance/mockey v1.2.1 h1:g84ngI88hz1DR4wZTL3yOuqlEcq67MretBfQUdXwrmw= +github.com/bytedance/mockey v1.2.1/go.mod h1:+Jm/fzWZAuhEDrPXVjDf/jLM2BlLXJkwk94zf2JZ3X4= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06 h1:1sDoSuDPWzhkdzNVxCxtIaKiAe96ESVPv8coGwc1gZ4= +github.com/bytedance/sonic v1.8.1 h1:NqAHCaGaTzro0xMmnTCLUyRlbEP6r8MCA1cJUrH3Pu4= +github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/cloudwego/hertz v0.5.3-0.20230208034101-28c304eb7082 h1:JACtt2oDZdk/7SWncpDyAS7Qw8A/QeupuKm1d747cws= -github.com/cloudwego/hertz v0.5.3-0.20230208034101-28c304eb7082/go.mod h1:K1U0RlU07CDeBINfHNbafH/3j9uSgIW8otbjUys3OPY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/cloudwego/hertz v0.6.2 h1:8NM0yHbyv8B4dNYgICirk733S7monTNB+uR9as1It1Y= +github.com/cloudwego/hertz v0.6.2/go.mod h1:2em2hGREvCBawsTQcQxyWBGVlCeo+N1pp2q0HkkbwR0= github.com/cloudwego/netpoll v0.3.1 h1:xByoORmCLIyKZ8gS+da06WDo3j+jvmhaqS2KeKejtBk= github.com/cloudwego/netpoll v0.3.1/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -20,22 +24,35 @@ github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/henrylee2cn/ameda v1.4.8/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= github.com/henrylee2cn/ameda v1.4.10 h1:JdvI2Ekq7tapdPsuhrc4CaFiqw6QXFvZIULWJgQyCAk= github.com/henrylee2cn/ameda v1.4.10/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 h1:yE9ULgp02BhYIrO6sdV/FPe0xQM6fNHkVQW2IAymfM0= github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -45,17 +62,23 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= @@ -63,6 +86,7 @@ google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+Rur google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/http2.go b/http2.go index 48ae2e8..b824efd 100644 --- a/http2.go +++ b/http2.go @@ -36,12 +36,12 @@ import ( "crypto/tls" "fmt" "io" - "net" "os" "strconv" "strings" "sync" + "github.com/cloudwego/hertz/pkg/network" "golang.org/x/net/http/httpguts" ) @@ -339,18 +339,7 @@ func validPseudoPath(v string) bool { return (len(v) > 0 && v[0] == '/') || v == "*" } -type CloseWithoutResetBuffer interface { - CloseNoResetBuffer() error -} - -type h2Conn struct { - net.Conn -} - -func (c *h2Conn) Close() error { - if cwrb, ok := c.Conn.(CloseWithoutResetBuffer); ok { - return cwrb.CloseNoResetBuffer() - } else { - return c.Conn.Close() - } +type h2ServerConn struct { + network.Conn + rw *responseWriter } diff --git a/response_writer.go b/response_writer.go index caef8b6..e603909 100644 --- a/response_writer.go +++ b/response_writer.go @@ -31,6 +31,7 @@ import ( "strings" "sync" + "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/hertz-contrib/http2/internal/bytesconv" @@ -425,3 +426,30 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error { return err } } + +type extWriter struct { + rw *responseWriter +} + +func (w *extWriter) Write(p []byte) (n int, err error) { + return w.rw.Write(p) +} + +func (w *extWriter) Flush() error { + w.rw.Flush() + return nil +} + +func (w *extWriter) Finalize() error { + w.rw.handlerDone() + return nil +} + +func NewResponserWriter(conn network.Conn) network.ExtWriter { + c, ok := conn.(*h2ServerConn) + if !ok { + panic("the conn is not the H2 Conn!") + } + + return &extWriter{c.rw} +} diff --git a/server.go b/server.go index 8c62ce5..1c09530 100644 --- a/server.go +++ b/server.go @@ -201,8 +201,7 @@ func (s *Server) Serve(ctx context.Context, c network.Conn) error { sc := &serverConn{ srv: s, engine: &s.BaseEngine, - conn: &h2Conn{c}, - rawConn: c, + conn: c, baseCtx: ctx, remoteAddrStr: c.RemoteAddr().String(), bw: newBufferedWriter(c), @@ -319,8 +318,7 @@ func (sc *serverConn) rejectConn(err ErrCode, debug string) { type serverConn struct { // Immutable: srv *Server - conn net.Conn - rawConn network.Conn + conn network.Conn bw *bufferedWriter // writing to conn baseCtx context.Context framer *Framer @@ -1631,6 +1629,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { rw, err := sc.newWriterAndRequest(st, f) st.rw = rw + st.reqCtx.SetConn(&h2ServerConn{sc.conn, rw}) if err != nil { return err } @@ -1689,10 +1688,6 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream } reqCtx := sc.engine.AcquireReqCtx() - if connection, ok := sc.conn.(network.Conn); ok { - reqCtx.SetConn(connection) - } - reqCtx.SetConn(sc.rawConn) reqCtx.Request.Header.SetProtocol(consts.HTTP20) reqCtx.Request.Header.InitContentLengthWithValue(-1) @@ -1881,6 +1876,11 @@ func (sc *serverConn) runHandler(rw *responseWriter, reqCtx *app.RequestContext, } return } else { + if writer := reqCtx.Response.GetHijackWriter(); writer != nil { + writer.Finalize() + return + } + rw.WriteHeader(reqCtx.Response.StatusCode()) err := writeResponseBody(rw, reqCtx) if err != nil { diff --git a/server_test.go b/server_test.go index 109c561..d9d0cee 100644 --- a/server_test.go +++ b/server_test.go @@ -3676,3 +3676,92 @@ func TestProtocolErrorAfterGoAway(t *testing.T) { } } } + +func TestServer_HijackWriter(t *testing.T) { + testServerResponse(t, func(c context.Context, ctx *app.RequestContext) error { + ctx.Response.HijackWriter(NewResponserWriter(ctx.GetConn())) + ctx.Write([]byte("Hello")) + ctx.Flush() + ctx.Write([]byte("World")) + ctx.Flush() + return nil + }, func(st *hertzServerTester) { + getSlash(st) + hf := st.wantHeaders() + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + + df := st.wantData() + if !df.valid { + t.Fatal("data Frame is invalid") + } + + if string(df.data) != "Hello" { + t.Fatalf("Got %v; want %v", string(df.data), "Hello") + } + + df = st.wantData() + if !df.valid { + t.Fatal("data Frame is invalid") + } + + if string(df.data) != "World" { + t.Fatalf("Got %v; want %v", string(df.data), "World") + } + + df = st.wantData() + if !df.StreamEnded() { + t.Fatal("want STREAM_ENDED flag") + } + }) +} + +func TestServer_HijackWriter_Flush(t *testing.T) { + ch := make(chan struct{}) + testServerResponse(t, func(c context.Context, ctx *app.RequestContext) error { + ctx.Response.HijackWriter(NewResponserWriter(ctx.GetConn())) + ctx.Write([]byte("Hello")) + ctx.Flush() + + ch <- struct{}{} + + ctx.Write([]byte("World")) + ctx.Flush() + + ch <- struct{}{} + return nil + }, func(st *hertzServerTester) { + getSlash(st) + hf := st.wantHeaders() + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + + <-ch + df := st.wantData() + if !df.valid { + t.Fatal("data Frame is invalid") + } + + if string(df.data) != "Hello" { + t.Fatalf("Got %v; want %v", string(df.data), "Hello") + } + + <-ch + df = st.wantData() + if !df.valid { + t.Fatal("data Frame is invalid") + } + + if string(df.data) != "World" { + t.Fatalf("Got %v; want %v", string(df.data), "World") + } + + df = st.wantData() + if !df.StreamEnded() { + t.Fatal("want STREAM_ENDED flag") + } + + }) +}