diff --git a/channel.go b/channel.go index 4477d00e..4c5aa4a9 100644 --- a/channel.go +++ b/channel.go @@ -155,6 +155,7 @@ type Channel struct { relayHost RelayHost relayMaxTimeout time.Duration relayTimerVerify bool + internalHandlers *handlerMap handler Handler onPeerStatusChanged func(*Peer) closed chan struct{} @@ -284,12 +285,7 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) { ch.mutable.state = ChannelClient ch.mutable.conns = make(map[uint32]*Connection) ch.createCommonStats() - - // Register internal unless the root handler has been overridden, since - // Register will panic. - if opts.Handler == nil { - ch.registerInternal() - } + ch.internalHandlers = ch.createInternalHandlers() registerNewChannel(ch) diff --git a/connection.go b/connection.go index 937e1fd2..220c0f66 100644 --- a/connection.go +++ b/connection.go @@ -161,23 +161,24 @@ type connectionEvents struct { type Connection struct { channelConnectionCommon - connID uint32 - connDirection connectionDirection - opts ConnectionOptions - conn net.Conn - localPeerInfo LocalPeerInfo - remotePeerInfo PeerInfo - sendCh chan *Frame - stopCh chan struct{} - state connectionState - stateMut sync.RWMutex - inbound *messageExchangeSet - outbound *messageExchangeSet - handler Handler - nextMessageID atomic.Uint32 - events connectionEvents - commonStatsTags map[string]string - relay *Relayer + connID uint32 + connDirection connectionDirection + opts ConnectionOptions + conn net.Conn + localPeerInfo LocalPeerInfo + remotePeerInfo PeerInfo + sendCh chan *Frame + stopCh chan struct{} + state connectionState + stateMut sync.RWMutex + inbound *messageExchangeSet + outbound *messageExchangeSet + internalHandlers *handlerMap + handler Handler + nextMessageID atomic.Uint32 + events connectionEvents + commonStatsTags map[string]string + relay *Relayer // outboundHP is the host:port we used to create this outbound connection. // It may not match remotePeerInfo.HostPort, in which case the connection is @@ -311,6 +312,7 @@ func (ch *Channel) newConnection(conn net.Conn, initialID uint32, outboundHP str outboundHP: outboundHP, inbound: newMessageExchangeSet(log, messageExchangeSetInbound), outbound: newMessageExchangeSet(log, messageExchangeSetOutbound), + internalHandlers: ch.internalHandlers, handler: ch.handler, events: events, commonStatsTags: ch.commonStatsTags, diff --git a/inbound.go b/inbound.go index c133ffa9..8808d30e 100644 --- a/inbound.go +++ b/inbound.go @@ -191,6 +191,15 @@ func (c *Connection) dispatchInbound(_ uint32, _ uint32, call *InboundCall, fram } }() + // Internal handlers (e.g., introspection) trump all other user-registered handlers on + // the "tchannel" name. + if call.ServiceName() == "tchannel" { + if h := c.internalHandlers.find(call.Method()); h != nil { + h.Handle(call.mex.ctx, call) + return + } + } + c.handler.Handle(call.mex.ctx, call) } diff --git a/introspection.go b/introspection.go index f4d30ee9..c119a364 100644 --- a/introspection.go +++ b/introspection.go @@ -519,7 +519,9 @@ func introspectRuntimeVersion() RuntimeVersion { // registerInternal registers the following internal handlers which return runtime state: // _gometa_introspect: TChannel internal state. // _gometa_runtime: Golang runtime stats. -func (ch *Channel) registerInternal() { +func (ch *Channel) createInternalHandlers() *handlerMap { + internalHandlers := &handlerMap{} + endpoints := []struct { name string handler func([]byte) interface{} @@ -528,7 +530,6 @@ func (ch *Channel) registerInternal() { {"_gometa_runtime", handleInternalRuntime}, } - tchanSC := ch.GetSubChannel("tchannel") for _, ep := range endpoints { // We need ep in our closure. ep := ep @@ -545,7 +546,13 @@ func (ch *Channel) registerInternal() { } NewArgWriter(call.Response().Arg3Writer()).WriteJSON(ep.handler(arg3)) } - ch.Register(HandlerFunc(handler), ep.name) - tchanSC.Register(HandlerFunc(handler), ep.name) + + h := HandlerFunc(handler) + internalHandlers.register(h, ep.name) + + // Register under the service name of channel as well (for backwards compatibility). + ch.GetSubChannel(ch.PeerInfo().ServiceName).Register(h, ep.name) } + + return internalHandlers } diff --git a/introspection_test.go b/introspection_test.go index 61218c6e..04665a8d 100644 --- a/introspection_test.go +++ b/introspection_test.go @@ -21,6 +21,7 @@ package tchannel_test import ( + "context" "math" "strconv" "testing" @@ -37,7 +38,10 @@ import ( // Purpose of this test is to ensure introspection doesn't cause any panics // and we have coverage of the introspection code. func TestIntrospection(t *testing.T) { - testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { + opts := testutils.NewOpts(). + AddLogFilter("Couldn't find handler", 1). // call with service name fails + NoRelay() // "tchannel" service name is not forwarded. + testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { client := testutils.NewClient(t, nil) defer client.Close() @@ -46,27 +50,21 @@ func TestIntrospection(t *testing.T) { var resp map[string]interface{} peer := client.Peers().GetOrAdd(ts.HostPort()) - err := json.CallPeer(ctx, peer, ts.ServiceName(), "_gometa_introspect", map[string]interface{}{ + err := json.CallPeer(ctx, peer, "tchannel", "_gometa_introspect", map[string]interface{}{ "includeExchanges": true, "includeEmptyPeers": true, "includeTombstones": true, }, &resp) require.NoError(t, err, "Call _gometa_introspect failed") - err = json.CallPeer(ctx, peer, ts.ServiceName(), "_gometa_runtime", map[string]interface{}{ + err = json.CallPeer(ctx, peer, ts.ServiceName(), "_gometa_introspect", nil /* arg */, &resp) + require.NoError(t, err, "Call _gometa_introspect failed") + + // Try making the call on any other service name will fail. + err = json.CallPeer(ctx, peer, "unknown-service", "_gometa_runtime", map[string]interface{}{ "includeGoStacks": true, }, &resp) - require.NoError(t, err, "Call _gometa_runtime failed") - - if !ts.HasRelay() { - // Try making the call on the "tchannel" service which is where meta handlers - // are registered. This will only work when we call it directly as the relay - // will not forward the tchannel service. - err = json.CallPeer(ctx, peer, "tchannel", "_gometa_runtime", map[string]interface{}{ - "includeGoStacks": true, - }, &resp) - require.NoError(t, err, "Call _gometa_runtime failed") - } + require.Error(t, err, "_gometa_introspect should only be registered under tchannel") }) } @@ -148,3 +146,28 @@ func TestIntrospectClosedConn(t *testing.T) { } }) } + +func TestIntrospectionNotBlocked(t *testing.T) { + testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { + subCh := ts.Server().GetSubChannel("tchannel") + subCh.SetHandler(HandlerFunc(func(ctx context.Context, inbound *InboundCall) { + panic("should not be called") + })) + + // Ensure that tchannel is also relayed + if ts.HasRelay() { + ts.RelayHost().Add("tchannel", ts.Server().PeerInfo().HostPort) + } + + ctx, cancel := NewContext(time.Second) + defer cancel() + + client := ts.NewClient(nil) + peer := client.Peers().GetOrAdd(ts.HostPort()) + + // Ensure that SetHandler doesn't block introspection. + var resp interface{} + err := json.CallPeer(Wrap(ctx), peer, "tchannel", "_gometa_runtime", nil, &resp) + require.NoError(t, err, "Call _gometa_runtime failed") + }) +} diff --git a/subchannel_test.go b/subchannel_test.go index 31615aa0..bac4e43e 100644 --- a/subchannel_test.go +++ b/subchannel_test.go @@ -195,7 +195,6 @@ func TestGetHandlers(t *testing.T) { }{ { serviceName: ch.ServiceName(), - // Default service name comes with extra introspection methods. wantMethods: []string{"_gometa_introspect", "_gometa_runtime", "method1", "method2"}, }, {