Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: distinguish between global and local transporter #356

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pkg/app/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"github.com/cloudwego/hertz/pkg/common/tracer/stats"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/route"
)

// WithKeepAliveTimeout sets keep-alive timeout.
Expand Down Expand Up @@ -216,7 +215,7 @@ func WithExitWaitTime(timeout time.Duration) config.Option {
// NOTE: If a tls server is started, it won't accept non-tls request.
func WithTLS(cfg *tls.Config) config.Option {
return config.Option{F: func(o *config.Options) {
route.SetTransporter(standard.NewTransporter)
o.TransporterNewer = standard.NewTransporter
o.TLS = cfg
}}
}
Expand All @@ -231,7 +230,7 @@ func WithListenConfig(l *net.ListenConfig) config.Option {
// WithTransport sets which network library to use.
func WithTransport(transporter func(options *config.Options) network.Transporter) config.Option {
return config.Option{F: func(o *config.Options) {
route.SetTransporter(transporter)
o.TransporterNewer = transporter
}}
}

Expand Down
5 changes: 5 additions & 0 deletions pkg/common/config/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"time"

"github.com/cloudwego/hertz/pkg/app/server/registry"
"github.com/cloudwego/hertz/pkg/network"
)

// Option is the only struct that can be used to set Options.
Expand Down Expand Up @@ -68,11 +69,15 @@ type Options struct {
TraceLevel interface{}
ListenConfig *net.ListenConfig

// TransporterNewer is the function to create a transporter.
TransporterNewer func(opt *Options) network.Transporter
welkeyever marked this conversation as resolved.
Show resolved Hide resolved

// Registry is used for service registry.
Registry registry.Registry
// RegistryInfo is base info used for service registry.
RegistryInfo *registry.Info
// Enable automatically HTML template reloading mechanism.
FGYFFFF marked this conversation as resolved.
Show resolved Hide resolved

AutoReloadRender bool
// If AutoReloadInterval is set to 0(default).
// The HTML template will reload according to files' changing event
Expand Down
31 changes: 28 additions & 3 deletions pkg/route/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ import (
"github.com/cloudwego/hertz/pkg/protocol/suite"
)

const unknownTransporterName = "unknown"

var (
defaultTransporter = standard.NewTransporter

Expand Down Expand Up @@ -199,15 +201,35 @@ func (engine *Engine) GetOptions() *config.Options {
return engine.options
}

// SetTransporter only sets the global default value for the transporter.
// Use WithTransporter during engine creation to set the transporter for the engine.
func SetTransporter(transporter func(options *config.Options) network.Transporter) {
defaultTransporter = transporter
}

func (engine *Engine) GetTransporterName() (tName string) {
return getTransporterName(engine.transport)
}

func getTransporterName(transporter network.Transporter) (tName string) {
defer func() {
err := recover()
if err != nil || tName == "" {
tName = unknownTransporterName
}
}()
t := reflect.ValueOf(transporter).Type().String()
tName = strings.Split(strings.TrimPrefix(t, "*"), ".")[0]
return tName
}

// Deprecated: This only get the global default transporter - may not be the real one used by the engine.
// Use engine.GetTransporterName for the real transporter used.
func GetTransporterName() (tName string) {
defer func() {
err := recover()
if err != nil {
tName = "unknown"
if err != nil || tName == "" {
welkeyever marked this conversation as resolved.
Show resolved Hide resolved
tName = unknownTransporterName
}
}()
fName := runtime.FuncForPC(reflect.ValueOf(defaultTransporter).Pointer()).Name()
Expand Down Expand Up @@ -363,7 +385,7 @@ func (engine *Engine) alpnEnable() bool {
}

func (engine *Engine) listenAndServe() error {
hlog.SystemLogger().Infof("Using network library=%s", GetTransporterName())
hlog.SystemLogger().Infof("Using network library=%s", engine.GetTransporterName())
return engine.transport.ListenAndServe(engine.onData)
}

Expand Down Expand Up @@ -520,6 +542,9 @@ func NewEngine(opt *config.Options) *Engine {
enableTrace: true,
options: opt,
}
if opt.TransporterNewer != nil {
engine.transport = opt.TransporterNewer(opt)
}
engine.RouterGroup.engine = engine

traceLevel := initTrace(engine)
Expand Down
37 changes: 37 additions & 0 deletions pkg/route/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,33 @@ import (
)

func TestNew_Engine(t *testing.T) {
defaultTransporter = standard.NewTransporter
opt := config.NewOptions([]config.Option{})
router := NewEngine(opt)
assert.DeepEqual(t, "standard", router.GetTransporterName())
assert.DeepEqual(t, "/", router.basePath)
assert.DeepEqual(t, router.engine, router)
assert.DeepEqual(t, 0, len(router.Handlers))
}

func TestNew_Engine_WithTransporter(t *testing.T) {
defaultTransporter = netpoll.NewTransporter
opt := config.NewOptions([]config.Option{})
router := NewEngine(opt)
assert.DeepEqual(t, "netpoll", router.GetTransporterName())

defaultTransporter = netpoll.NewTransporter
opt.TransporterNewer = standard.NewTransporter
router = NewEngine(opt)
assert.DeepEqual(t, "standard", router.GetTransporterName())
assert.DeepEqual(t, "netpoll", GetTransporterName())
}

func TestGetTransporterName(t *testing.T) {
name := getTransporterName(&fakeTransporter{})
assert.DeepEqual(t, "route", name)
}

func TestEngineUnescape(t *testing.T) {
e := NewEngine(config.NewOptions(nil))

Expand Down Expand Up @@ -524,3 +544,20 @@ func (m *mockConn) WriteBinary(b []byte) (n int, err error) {
func (m *mockConn) Flush() error {
panic("implement me")
}

type fakeTransporter struct{}

func (f *fakeTransporter) Close() error {
// TODO implement me
panic("implement me")
}

func (f *fakeTransporter) Shutdown(ctx context.Context) error {
// TODO implement me
panic("implement me")
}

func (f *fakeTransporter) ListenAndServe(onData network.OnData) error {
// TODO implement me
panic("implement me")
}