diff --git a/rpc/client.go b/rpc/client.go index 98b1d598c31..0f8b0c5dd71 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -116,6 +116,7 @@ type clientConn struct { func (c *Client) newClientConn(conn ServerCodec) *clientConn { ctx := context.WithValue(context.Background(), clientContextKey{}, c) + ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo()) handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */, c.logger, 0) return &clientConn{conn, handler} } @@ -434,7 +435,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf // Check type of channel first. chanVal := reflect.ValueOf(channel) if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { - panic("first argument to Subscribe must be a writable channel") + panic(fmt.Sprintf("channel argument of Subscribe has type %T need writable channel", channel)) } if chanVal.IsNil() { panic("channel given to Subscribe must not be nil") @@ -493,8 +494,8 @@ func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error } func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { - // The previous write failed. Try to establish a new connection. if c.writeConn == nil { + // The previous write failed. Try to establish a new connection. if err := c.reconnect(ctx); err != nil { return err } diff --git a/rpc/http.go b/rpc/http.go index 3be9700fe13..550ed619c27 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -61,11 +61,18 @@ type httpConn struct { headers http.Header } -// httpConn is treated specially by Client. +// httpConn implements ServerCodec, but it is treated specially by Client +// and some methods don't work. The panic() stubs here exist to ensure +// this special treatment is correct. + func (hc *httpConn) WriteJSON(context.Context, interface{}) error { panic("writeJSON called on httpConn") } +func (hc *httpConn) peerInfo() PeerInfo { + panic("peerInfo called on httpConn") +} + func (hc *httpConn) remoteAddr() string { return hc.url } @@ -239,10 +246,19 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), code) return } + + // Create request-scoped context. + connInfo := PeerInfo{Transport: "http", RemoteAddr: r.RemoteAddr} + connInfo.HTTP.Version = r.Proto + connInfo.HTTP.Host = r.Host + connInfo.HTTP.Origin = r.Header.Get("Origin") + connInfo.HTTP.UserAgent = r.Header.Get("User-Agent") + ctx := r.Context() + ctx = context.WithValue(ctx, peerInfoContextKey{}, connInfo) + // All checks passed, create a codec that reads directly from the request body // until EOF, writes the response to w, and orders the server to process a // single request. - ctx := r.Context() // The context might be cancelled if the client's connection was closed while waiting for ServeHTTP. if libcommon.FastContextErr(ctx) != nil { @@ -252,15 +268,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ctx = context.WithValue(ctx, "remote", r.RemoteAddr) - ctx = context.WithValue(ctx, "scheme", r.Proto) - ctx = context.WithValue(ctx, "local", r.Host) - if ua := r.Header.Get("User-Agent"); ua != "" { - ctx = context.WithValue(ctx, "User-Agent", ua) - } - if origin := r.Header.Get("Origin"); origin != "" { - ctx = context.WithValue(ctx, "Origin", origin) - } if s.debugSingleRequest { if v := r.Header.Get(dbg.HTTPHeader); v == "true" { ctx = dbg.ContextWithDebug(ctx, true) diff --git a/rpc/http_test.go b/rpc/http_test.go index 80234beb02b..bed37a98f74 100644 --- a/rpc/http_test.go +++ b/rpc/http_test.go @@ -132,3 +132,40 @@ func TestHTTPRespBodyUnlimited(t *testing.T) { t.Fatalf("response has wrong length %d, want %d", len(r), respLength) } } + +func TestHTTPPeerInfo(t *testing.T) { + logger := log.New() + s := newTestServer(logger) + defer s.Stop() + ts := httptest.NewServer(s) + defer ts.Close() + + c, err := Dial(ts.URL, logger) + if err != nil { + t.Fatal(err) + } + c.SetHeader("user-agent", "ua-testing") + c.SetHeader("origin", "origin.example.com") + + // Request peer information. + var info PeerInfo + if err := c.Call(&info, "test_peerInfo"); err != nil { + t.Fatal(err) + } + + if info.RemoteAddr == "" { + t.Error("RemoteAddr not set") + } + if info.Transport != "http" { + t.Errorf("wrong Transport %q", info.Transport) + } + if info.HTTP.Version != "HTTP/1.1" { + t.Errorf("wrong HTTP.Version %q", info.HTTP.Version) + } + if info.HTTP.UserAgent != "ua-testing" { + t.Errorf("wrong HTTP.UserAgent %q", info.HTTP.UserAgent) + } + if info.HTTP.Origin != "origin.example.com" { + t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent) + } +} diff --git a/rpc/json.go b/rpc/json.go index 280ce15902e..1dea3700135 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -205,6 +205,11 @@ func (c *jsonCodec) remoteAddr() string { return c.remote } +func (c *jsonCodec) peerInfo() PeerInfo { + // This returns "ipc" because all other built-in transports have a separate codec type. + return PeerInfo{Transport: "ipc", RemoteAddr: c.remote} +} + func (c *jsonCodec) ReadBatch() (messages []*jsonrpcMessage, batch bool, err error) { // Decode the next JSON object in the input stream. // This verifies basic syntax, etc. diff --git a/rpc/server.go b/rpc/server.go index d023802c636..7e1c7239a33 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -176,3 +176,38 @@ func (s *RPCService) Modules() map[string]string { } return modules } + +// PeerInfo contains information about the remote end of the network connection. +// +// This is available within RPC method handlers through the context. Call +// PeerInfoFromContext to get information about the client connection related to +// the current method call. +type PeerInfo struct { + // Transport is name of the protocol used by the client. + // This can be "http", "ws" or "ipc". + Transport string + + // Address of client. This will usually contain the IP address and port. + RemoteAddr string + + // Addditional information for HTTP and WebSocket connections. + HTTP struct { + // Protocol version, i.e. "HTTP/1.1". This is not set for WebSocket. + Version string + // Header values sent by the client. + UserAgent string + Origin string + Host string + } +} + +type peerInfoContextKey struct{} + +// PeerInfoFromContext returns information about the client's network connection. +// Use this with the context passed to RPC method handler functions. +// +// The zero value is returned if no connection info is present in ctx. +func PeerInfoFromContext(ctx context.Context) PeerInfo { + info, _ := ctx.Value(peerInfoContextKey{}).(PeerInfo) + return info +} diff --git a/rpc/server_test.go b/rpc/server_test.go index 7a95e97ff4c..bc82f06a20d 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -54,7 +54,7 @@ func TestServerRegisterName(t *testing.T) { t.Fatalf("Expected service calc to be registered") } - wantCallbacks := 9 + wantCallbacks := 10 if len(svc.callbacks) != wantCallbacks { t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks)) } diff --git a/rpc/testservice_test.go b/rpc/testservice_test.go index 84aedde7278..74df3298d18 100644 --- a/rpc/testservice_test.go +++ b/rpc/testservice_test.go @@ -81,6 +81,10 @@ func (s *testService) Echo(str string, i int, args *echoArgs) echoResult { return echoResult{str, i, args} } +func (s *testService) PeerInfo(ctx context.Context) PeerInfo { + return PeerInfoFromContext(ctx) +} + func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *echoArgs) echoResult { return echoResult{str, i, args} } diff --git a/rpc/types.go b/rpc/types.go index 334b827ee7a..f9ff7d0fa4a 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -56,8 +56,10 @@ type DataError interface { // a RPC session. Implementations must be go-routine safe since the codec can be called in // multiple go-routines concurrently. type ServerCodec interface { + peerInfo() PeerInfo ReadBatch() (msgs []*jsonrpcMessage, isBatch bool, err error) Close() + jsonWriter } diff --git a/rpc/websocket.go b/rpc/websocket.go index dc6eae0c272..039f633579d 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -67,7 +67,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string, jwtSecret []byte, com logger.Warn("WebSocket upgrade failed", "err", err) return } - codec := NewWebsocketCodec(conn) + codec := NewWebsocketCodec(conn, r.Host, r.Header) s.ServeCodec(codec, 0) }) } @@ -209,7 +209,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale } return nil, hErr } - return NewWebsocketCodec(conn), nil + return NewWebsocketCodec(conn, endpoint, header), nil }, logger) } @@ -247,18 +247,30 @@ func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { type websocketCodec struct { *jsonCodec conn *websocket.Conn + info PeerInfo wg sync.WaitGroup pingReset chan struct{} } -func NewWebsocketCodec(conn *websocket.Conn) ServerCodec { +func NewWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec { conn.SetReadLimit(wsMessageSizeLimit) wc := &websocketCodec{ jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), conn: conn, pingReset: make(chan struct{}, 1), + info: PeerInfo{ + Transport: "ws", + RemoteAddr: conn.RemoteAddr().String(), + }, } + // Fill in connection details. + wc.info.HTTP.Host = host + if req != nil { + wc.info.HTTP.Origin = req.Get("Origin") + wc.info.HTTP.UserAgent = req.Get("User-Agent") + } + // Start pinger. wc.wg.Add(1) go wc.pingLoop() return wc @@ -269,6 +281,10 @@ func (wc *websocketCodec) Close() { wc.wg.Wait() } +func (wc *websocketCodec) peerInfo() PeerInfo { + return wc.info +} + func (wc *websocketCodec) WriteJSON(ctx context.Context, v interface{}) error { err := wc.jsonCodec.WriteJSON(ctx, v) if err == nil { diff --git a/turbo/app/support_cmd.go b/turbo/app/support_cmd.go index 475ae88f53e..0c858ba48b1 100644 --- a/turbo/app/support_cmd.go +++ b/turbo/app/support_cmd.go @@ -256,7 +256,7 @@ func tunnel(ctx context.Context, cancel context.CancelFunc, sigs chan os.Signal, Nodes []*info `json:"nodes"` } - codec := rpc.NewWebsocketCodec(conn) + codec := rpc.NewWebsocketCodec(conn, "wss://"+diagnosticsUrl, nil) //TODO: revise why is it so defer codec.Close() err = codec.WriteJSON(ctx1, &connectionInfo{