From 92d7ff4f4d94cc6103c5bd9cd4b9c7b0c47cc0c8 Mon Sep 17 00:00:00 2001 From: gaowenju Date: Tue, 8 Nov 2022 20:52:12 +0800 Subject: [PATCH] feat: distinguish between global and local transporter --- pkg/app/server/option.go | 5 ++--- pkg/common/config/option.go | 5 +++++ pkg/route/engine.go | 31 ++++++++++++++++++++++++++++--- pkg/route/engine_test.go | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 6 deletions(-) diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index fc913e400..9447734cb 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -27,7 +27,6 @@ import ( "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" - "github.com/cloudwego/hertz/pkg/route" ) // WithKeepAliveTimeout sets keep-alive timeout. @@ -216,7 +215,7 @@ func WithExitWaitTime(timeout time.Duration) config.Option { // NOTE: If a tls server is started, it won't accept non-tls request. func WithTLS(cfg *tls.Config) config.Option { return config.Option{F: func(o *config.Options) { - route.SetTransporter(standard.NewTransporter) + o.TransporterNewer = standard.NewTransporter o.TLS = cfg }} } @@ -231,7 +230,7 @@ func WithListenConfig(l *net.ListenConfig) config.Option { // WithTransport sets which network library to use. func WithTransport(transporter func(options *config.Options) network.Transporter) config.Option { return config.Option{F: func(o *config.Options) { - route.SetTransporter(transporter) + o.TransporterNewer = transporter }} } diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index cd69716a3..15037f067 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -22,6 +22,7 @@ import ( "time" "github.com/cloudwego/hertz/pkg/app/server/registry" + "github.com/cloudwego/hertz/pkg/network" ) // Option is the only struct that can be used to set Options. @@ -68,11 +69,15 @@ type Options struct { TraceLevel interface{} ListenConfig *net.ListenConfig + // TransporterNewer is the function to create a transporter. + TransporterNewer func(opt *Options) network.Transporter + // Registry is used for service registry. Registry registry.Registry // RegistryInfo is base info used for service registry. RegistryInfo *registry.Info // Enable automatically HTML template reloading mechanism. + AutoReloadRender bool // If AutoReloadInterval is set to 0(default). // The HTML template will reload according to files' changing event diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 45ae38f89..8a272a8a6 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -76,6 +76,8 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/suite" ) +const unknownTransporterName = "unknown" + var ( defaultTransporter = standard.NewTransporter @@ -199,15 +201,35 @@ func (engine *Engine) GetOptions() *config.Options { return engine.options } +// SetTransporter only sets the global default value for the transporter. +// Use WithTransporter during engine creation to set the transporter for the engine. func SetTransporter(transporter func(options *config.Options) network.Transporter) { defaultTransporter = transporter } +func (engine *Engine) GetTransporterName() (tName string) { + return getTransporterName(engine.transport) +} + +func getTransporterName(transporter network.Transporter) (tName string) { + defer func() { + err := recover() + if err != nil || tName == "" { + tName = unknownTransporterName + } + }() + t := reflect.ValueOf(transporter).Type().String() + tName = strings.Split(strings.TrimPrefix(t, "*"), ".")[0] + return tName +} + +// Deprecated: This only get the global default transporter - may not be the real one used by the engine. +// Use engine.GetTransporterName for the real transporter used. func GetTransporterName() (tName string) { defer func() { err := recover() - if err != nil { - tName = "unknown" + if err != nil || tName == "" { + tName = unknownTransporterName } }() fName := runtime.FuncForPC(reflect.ValueOf(defaultTransporter).Pointer()).Name() @@ -363,7 +385,7 @@ func (engine *Engine) alpnEnable() bool { } func (engine *Engine) listenAndServe() error { - hlog.SystemLogger().Infof("Using network library=%s", GetTransporterName()) + hlog.SystemLogger().Infof("Using network library=%s", engine.GetTransporterName()) return engine.transport.ListenAndServe(engine.onData) } @@ -520,6 +542,9 @@ func NewEngine(opt *config.Options) *Engine { enableTrace: true, options: opt, } + if opt.TransporterNewer != nil { + engine.transport = opt.TransporterNewer(opt) + } engine.RouterGroup.engine = engine traceLevel := initTrace(engine) diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 9b99180e1..f3007ac06 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -64,13 +64,33 @@ import ( ) func TestNew_Engine(t *testing.T) { + defaultTransporter = standard.NewTransporter opt := config.NewOptions([]config.Option{}) router := NewEngine(opt) + assert.DeepEqual(t, "standard", router.GetTransporterName()) assert.DeepEqual(t, "/", router.basePath) assert.DeepEqual(t, router.engine, router) assert.DeepEqual(t, 0, len(router.Handlers)) } +func TestNew_Engine_WithTransporter(t *testing.T) { + defaultTransporter = netpoll.NewTransporter + opt := config.NewOptions([]config.Option{}) + router := NewEngine(opt) + assert.DeepEqual(t, "netpoll", router.GetTransporterName()) + + defaultTransporter = netpoll.NewTransporter + opt.TransporterNewer = standard.NewTransporter + router = NewEngine(opt) + assert.DeepEqual(t, "standard", router.GetTransporterName()) + assert.DeepEqual(t, "netpoll", GetTransporterName()) +} + +func TestGetTransporterName(t *testing.T) { + name := getTransporterName(&fakeTransporter{}) + assert.DeepEqual(t, "route", name) +} + func TestEngineUnescape(t *testing.T) { e := NewEngine(config.NewOptions(nil)) @@ -524,3 +544,20 @@ func (m *mockConn) WriteBinary(b []byte) (n int, err error) { func (m *mockConn) Flush() error { panic("implement me") } + +type fakeTransporter struct{} + +func (f *fakeTransporter) Close() error { + // TODO implement me + panic("implement me") +} + +func (f *fakeTransporter) Shutdown(ctx context.Context) error { + // TODO implement me + panic("implement me") +} + +func (f *fakeTransporter) ListenAndServe(onData network.OnData) error { + // TODO implement me + panic("implement me") +}