Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SplitHTTP: Server supports HTTP/3 #3554

Merged
merged 5 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 73 additions & 33 deletions transport/internet/splithttp/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"sync"
"time"

"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
Expand Down Expand Up @@ -233,10 +235,13 @@ func (c *httpResponseBodyWriter) Close() error {

type Listener struct {
sync.Mutex
server http.Server
listener net.Listener
config *Config
addConn internet.ConnHandler
server http.Server
h3server *http3.Server
listener net.Listener
h3listener *quic.EarlyListener
config *Config
addConn internet.ConnHandler
isH3 bool
}

func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
Expand All @@ -253,6 +258,17 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
var listener net.Listener
var err error
var localAddr = gonet.TCPAddr{}
handler := &requestHandler{
host: shSettings.Host,
path: shSettings.GetNormalizedPath(),
ln: l,
sessionMu: &sync.Mutex{},
sessions: sync.Map{},
localAddr: localAddr,
}
tlsConfig := getTLSConfig(streamSettings)
l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"


if port == net.Port(0) { // unix
listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
Expand All @@ -263,6 +279,29 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err)
}
errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address)
} else if l.isH3 { // quic
Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
IP: address.IP(),
Port: int(port),
}, streamSettings.SocketSettings)
if err != nil {
return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
}
h3listener, err := quic.ListenEarly(Conn,tlsConfig, nil)
if err != nil {
return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
}
l.h3listener = h3listener
errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)

l.h3server = &http3.Server{
Handler: handler,
}
go func() {
if err := l.h3server.ServeListener(l.h3listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
}
}()
} else { // tcp
localAddr = gonet.TCPAddr{
IP: address.IP(),
Expand All @@ -275,41 +314,29 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
if err != nil {
return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
}
l.listener = listener
errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
}

// h2cHandler can handle both plaintext HTTP/1.1 and h2c
h2cHandler := h2c.NewHandler(handler, &http2.Server{})
l.server = http.Server{
Handler: h2cHandler,
ReadHeaderTimeout: time.Second * 4,
MaxHeaderBytes: 8192,
}
go func() {
if err := l.server.Serve(l.listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
}
}()
}
l.listener = listener
if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
}
}

handler := &requestHandler{
host: shSettings.Host,
path: shSettings.GetNormalizedPath(),
ln: l,
sessionMu: &sync.Mutex{},
sessions: sync.Map{},
localAddr: localAddr,
}

// h2cHandler can handle both plaintext HTTP/1.1 and h2c
h2cHandler := h2c.NewHandler(handler, &http2.Server{})

l.listener = listener

l.server = http.Server{
Handler: h2cHandler,
ReadHeaderTimeout: time.Second * 4,
MaxHeaderBytes: 8192,
}

go func() {
if err := l.server.Serve(l.listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
}
}()

return l, err
}

Expand All @@ -320,9 +347,22 @@ func (ln *Listener) Addr() net.Addr {

// Close implements net.Listener.Close().
func (ln *Listener) Close() error {
return ln.listener.Close()
if ln.h3server != nil {
ll11l1lIllIl1lll marked this conversation as resolved.
Show resolved Hide resolved
if err := ln.h3server.Close(); err != nil {
return err
}
} else if ln.listener != nil {
return ln.listener.Close()
}
return errors.New("listener does not have an HTTP/3 server or a net.listener")
}
func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
config := v2tls.ConfigFromStreamSettings(streamSettings)
if config == nil {
return &tls.Config{}
}
return config.GetTLSConfig()
}

func init() {
common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
}
40 changes: 40 additions & 0 deletions transport/internet/splithttp/splithttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol/tls/cert"
"github.com/xtls/xray-core/testing/servers/tcp"
"github.com/xtls/xray-core/testing/servers/udp"
"github.com/xtls/xray-core/transport/internet"
. "github.com/xtls/xray-core/transport/internet/splithttp"
"github.com/xtls/xray-core/transport/internet/stat"
Expand Down Expand Up @@ -204,3 +205,42 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
t.Error("Expected h2 but got:", resp.ProtoMajor)
}
}

func Test_listenSHAndDial_QUIC(t *testing.T) {
if runtime.GOARCH == "arm64" {
return
}

listenPort := udp.PickPort()

start := time.Now()

streamSettings := &internet.MemoryStreamConfig{
ProtocolName: "splithttp",
ProtocolSettings: &Config{
Path: "shs",
},
SecurityType: "tls",
SecuritySettings: &tls.Config{
AllowInsecure: true,
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
NextProtocol: []string{"h3"},
},
}
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() {
_ = conn.Close()
}()
})
common.Must(err)
defer listen.Close()

conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
common.Must(err)
_ = conn.Close()

end := time.Now()
if !end.Before(start.Add(time.Second * 5)) {
t.Error("end: ", end, " start: ", start)
}
}