From 011eed231300ce287eb4bc9b291bc1130233dc1e Mon Sep 17 00:00:00 2001 From: r-gochain <40002763+r-gochain@users.noreply.github.com> Date: Mon, 23 Aug 2021 15:19:42 +0300 Subject: [PATCH] WebSockets #37 (#64) * WebSockets #37 * copy context to a goroutine * circleci: newer google-cloud-sdk Co-authored-by: jmank88 --- .circleci/config.yml | 4 +- go.mod | 1 + handler.go | 137 +++++++++++------------ main.go | 17 ++- proxy.go | 18 ++- websocketproxy.go | 258 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 355 insertions(+), 80 deletions(-) create mode 100644 websocketproxy.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 4858304..5d8115d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -14,8 +14,8 @@ jobs: - run: name: install gcloud command: | - wget https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-204.0.0-linux-x86_64.tar.gz --directory-prefix=tmp - tar -xvzf tmp/google-cloud-sdk-204.0.0-linux-x86_64.tar.gz -C tmp + wget https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-353.0.0-linux-x86_64.tar.gz --directory-prefix=tmp + tar -xvzf tmp/google-cloud-sdk-353.0.0-linux-x86_64.tar.gz -C tmp ./tmp/google-cloud-sdk/install.sh -q - setup_remote_docker - deploy: diff --git a/go.mod b/go.mod index 5410008..6462c5c 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-chi/chi/v5 v5.0.3 github.com/gochain/gochain/v3 v3.4.7 github.com/golang/snappy v0.0.4 // indirect + github.com/gorilla/websocket v1.4.2 // indirect github.com/pelletier/go-toml v1.9.3 github.com/rs/cors v1.8.0 github.com/treeder/gcputils v0.1.1 diff --git a/handler.go b/handler.go index 7cbca9e..a0a8af3 100644 --- a/handler.go +++ b/handler.go @@ -71,32 +71,34 @@ func parseRequests(r *http.Request) (string, []string, []ModifiedRequest, error) if err != nil { return "", nil, nil, fmt.Errorf("failed to read body: %v", err) } - type rpcRequest struct { - ID json.RawMessage `json:"id"` - Method string `json:"method"` - Params []json.RawMessage `json:"params"` + methods, res, err = parseMessage(body, ip) + if err != nil { + return "", nil, nil, err } - if isBatch(body) { - var arr []rpcRequest - err = json.Unmarshal(body, &arr) - if err != nil { - return "", nil, nil, fmt.Errorf("failed to parse JSON batch request: %v", err) - } - for _, t := range arr { - methods = append(methods, t.Method) - res = append(res, ModifiedRequest{ - ID: t.ID, - Path: t.Method, - RemoteAddr: ip, - Params: t.Params, - }) - } - } else { - var t rpcRequest - err = json.Unmarshal(body, &t) - if err != nil { - return "", nil, nil, fmt.Errorf("failed to parse JSON request: %v", err) - } + } + if len(res) == 0 { + methods = append(methods, r.URL.Path) + res = append(res, ModifiedRequest{ + Path: r.URL.Path, + RemoteAddr: ip, + }) + } + return ip, methods, res, nil +} + +func parseMessage(body []byte, ip string) (methods []string, res []ModifiedRequest, err error) { + type rpcRequest struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params []json.RawMessage `json:"params"` + } + if isBatch(body) { + var arr []rpcRequest + err := json.Unmarshal(body, &arr) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse JSON batch request: %v", err) + } + for _, t := range arr { methods = append(methods, t.Method) res = append(res, ModifiedRequest{ ID: t.ID, @@ -105,15 +107,21 @@ func parseRequests(r *http.Request) (string, []string, []ModifiedRequest, error) Params: t.Params, }) } - } - if len(res) == 0 { - methods = append(methods, r.URL.Path) + } else { + var t rpcRequest + err := json.Unmarshal(body, &t) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse JSON request: %v", err) + } + methods = append(methods, t.Method) res = append(res, ModifiedRequest{ - Path: r.URL.Path, + ID: t.ID, + Path: t.Method, RemoteAddr: ip, + Params: t.Params, }) } - return ip, methods, res, nil + return methods, res, nil } const ( @@ -123,16 +131,18 @@ const ( jsonRPCInternal = -32603 ) +type ErrResponse struct { + Version string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` +} + func jsonRPCError(id json.RawMessage, jsonCode int, msg string) interface{} { - type errResponse struct { - Version string `json:"jsonrpc"` - ID json.RawMessage `json:"id"` - Error struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error"` - } - resp := errResponse{ + + resp := ErrResponse{ Version: "2.0", ID: id, } @@ -187,8 +197,13 @@ func (t *myTransport) RoundTrip(req *http.Request) (*http.Response, error) { ctx = gotils.With(ctx, "remoteIp", ip) ctx = gotils.With(ctx, "methods", methods) - if blockResponse := t.block(ctx, parsedRequests); blockResponse != nil { - return blockResponse, nil + errorCode, resp := t.block(ctx, parsedRequests) + if resp != nil { + resp, err := jsonRPCResponse(errorCode, resp) + if err != nil { + gotils.L(ctx).Error().Printf("Failed to construct a response: %v", err) + } + return resp, nil } gotils.L(ctx).Info().Print("Forwarding request") @@ -197,53 +212,33 @@ func (t *myTransport) RoundTrip(req *http.Request) (*http.Response, error) { } // block returns a response only if the request should be blocked, otherwise it returns nil if allowed. -func (t *myTransport) block(ctx context.Context, parsedRequests []ModifiedRequest) *http.Response { +func (t *myTransport) block(ctx context.Context, parsedRequests []ModifiedRequest) (int, interface{}) { var union *blockRange for _, parsedRequest := range parsedRequests { ctx = gotils.With(ctx, "ip", parsedRequest.RemoteAddr) if allowed, added := t.AllowVisitor(parsedRequest); !allowed { gotils.L(ctx).Info().Print("Request blocked: Rate limited") - resp, err := jsonRPCResponse(http.StatusTooManyRequests, jsonRPCLimit(parsedRequest.ID)) - if err != nil { - gotils.L(ctx).Error().Printf("Failed to construct rate-limit response: %v", err) - } - return resp + return http.StatusTooManyRequests, jsonRPCLimit(parsedRequest.ID) } else if added { gotils.L(ctx).Info().Printf("Added new visitor, ip: %v", parsedRequest.RemoteAddr) } if !t.MatchAnyRule(parsedRequest.Path) { gotils.L(ctx).Info().Print("Request blocked: Method not allowed") - resp, err := jsonRPCResponse(http.StatusMethodNotAllowed, jsonRPCUnauthorized(parsedRequest.ID, parsedRequest.Path)) - if err != nil { - gotils.L(ctx).Error().Printf("Failed to construct not-allowed response: %v", err) - } - return resp + return http.StatusMethodNotAllowed, jsonRPCUnauthorized(parsedRequest.ID, parsedRequest.Path) } if t.blockRangeLimit > 0 && parsedRequest.Path == "eth_getLogs" { r, invalid, err := t.parseRange(ctx, parsedRequest) if err != nil { - resp, err := jsonRPCResponse(http.StatusInternalServerError, jsonRPCError(parsedRequest.ID, jsonRPCInternal, err.Error())) - if err != nil { - gotils.L(ctx).Error().Printf("Failed to construct internal error response: %v", err) - } - return resp + return http.StatusInternalServerError, jsonRPCError(parsedRequest.ID, jsonRPCInternal, err.Error()) } else if invalid != nil { gotils.L(ctx).Info().Printf("Request blocked: Invalid params: %v", invalid) - resp, err := jsonRPCResponse(http.StatusBadRequest, jsonRPCError(parsedRequest.ID, jsonRPCInvalidParams, invalid.Error())) - if err != nil { - gotils.L(ctx).Error().Printf("Failed to construct invalid params response: %v", err) - } - return resp + return http.StatusBadRequest, jsonRPCError(parsedRequest.ID, jsonRPCInvalidParams, invalid.Error()) } if r != nil { if l := r.len(); l > t.blockRangeLimit { gotils.L(ctx).Info().Println("Request blocked: Exceeds block range limit, range:", l, "limit:", t.blockRangeLimit) - resp, err := jsonRPCResponse(http.StatusBadRequest, jsonRPCBlockRangeLimit(parsedRequest.ID, l, t.blockRangeLimit)) - if err != nil { - gotils.L(ctx).Error().Printf("Failed to construct block range limit response: %v", err) - } - return resp + return http.StatusBadRequest, jsonRPCBlockRangeLimit(parsedRequest.ID, l, t.blockRangeLimit) } if union == nil { union = r @@ -251,17 +246,13 @@ func (t *myTransport) block(ctx context.Context, parsedRequests []ModifiedReques union.extend(r) if l := union.len(); l > t.blockRangeLimit { gotils.L(ctx).Info().Println("Request blocked: Exceeds block range limit, range:", l, "limit:", t.blockRangeLimit) - resp, err := jsonRPCResponse(http.StatusBadRequest, jsonRPCBlockRangeLimit(parsedRequest.ID, l, t.blockRangeLimit)) - if err != nil { - gotils.L(ctx).Error().Printf("Failed to construct block range limit response: %v", err) - } - return resp + return http.StatusBadRequest, jsonRPCBlockRangeLimit(parsedRequest.ID, l, t.blockRangeLimit) } } } } } - return nil + return 0, nil } type blockRange struct{ start, end uint64 } diff --git a/main.go b/main.go index 80938e6..6974a34 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ var requestsPerMinuteLimit int type ConfigData struct { Port string `toml:",omitempty"` URL string `toml:",omitempty"` + WSURL string `toml:",omitempty"` Allow []string `toml:",omitempty"` RPM int `toml:",omitempty"` NoLimit []string `toml:",omitempty"` @@ -37,6 +38,7 @@ func main() { var configPath string var port string var redirecturl string + var redirectWSUrl string var allowedPaths string var noLimitIPs string var blockRangeLimit uint64 @@ -64,6 +66,12 @@ func main() { Usage: "redirect url", Destination: &redirecturl, }, + &cli.StringFlag{ + Name: "wsurl, w", + Value: "ws://127.0.0.1:8041", + Usage: "redirect websocket url", + Destination: &redirectWSUrl, + }, &cli.StringFlag{ Name: "allow, a", Usage: "comma separated list of allowed paths", @@ -111,6 +119,12 @@ func main() { } cfg.URL = redirecturl } + if redirectWSUrl != "" { + if cfg.WSURL != "" { + return errors.New("ws url set in two places") + } + cfg.WSURL = redirectWSUrl + } if requestsPerMinuteLimit != 0 { if cfg.RPM != 0 { return errors.New("rpm set in two places") @@ -150,7 +164,7 @@ func (cfg *ConfigData) run(ctx context.Context) error { sort.Strings(cfg.Allow) sort.Strings(cfg.NoLimit) - gotils.L(ctx).Info().Println("Server starting, port:", cfg.Port, "redirectURL:", cfg.URL, + gotils.L(ctx).Info().Println("Server starting, port:", cfg.Port, "redirectURL:", cfg.URL, "redirectWSURL:", cfg.WSURL, "rpmLimit:", cfg.RPM, "exempt:", cfg.NoLimit, "allowed:", cfg.Allow) // Create proxy server. @@ -189,5 +203,6 @@ func (cfg *ConfigData) run(ctx context.Context) error { w.WriteHeader(http.StatusOK) }) r.HandleFunc("/*", server.RPCProxy) + r.HandleFunc("/ws", server.WSProxy) return http.ListenAndServe(":"+cfg.Port, r) } diff --git a/proxy.go b/proxy.go index 06d2ba5..3fb0d1e 100644 --- a/proxy.go +++ b/proxy.go @@ -21,8 +21,9 @@ import ( ) type Server struct { - target *url.URL - proxy *httputil.ReverseProxy + target *url.URL + proxy *httputil.ReverseProxy + wsProxy *WebsocketProxy myTransport homepage []byte } @@ -32,8 +33,11 @@ func (cfg *ConfigData) NewServer() (*Server, error) { if err != nil { return nil, err } - - s := &Server{target: url, proxy: httputil.NewSingleHostReverseProxy(url)} + wsurl, err := url.Parse(cfg.WSURL) + if err != nil { + return nil, err + } + s := &Server{target: url, proxy: httputil.NewSingleHostReverseProxy(url), wsProxy: NewProxy(wsurl)} s.myTransport.blockRangeLimit = cfg.BlockRangeLimit s.myTransport.url = cfg.URL s.matcher, err = newMatcher(cfg.Allow) @@ -46,6 +50,7 @@ func (cfg *ConfigData) NewServer() (*Server, error) { s.noLimitIPs[ip] = struct{}{} } s.proxy.Transport = &s.myTransport + s.wsProxy.Transport = &s.myTransport // Generate static home page. id := json.RawMessage([]byte(`"ID"`)) @@ -88,6 +93,11 @@ func (p *Server) RPCProxy(w http.ResponseWriter, r *http.Request) { p.proxy.ServeHTTP(w, r) } +func (p *Server) WSProxy(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-rpc-proxy", "rpc-proxy") + p.wsProxy.ServeHTTP(w, r) +} + func (p *Server) Example(w http.ResponseWriter, r *http.Request) { method := chi.URLParam(r, "method") args := []string{ diff --git a/websocketproxy.go b/websocketproxy.go new file mode 100644 index 0000000..ebccf6f --- /dev/null +++ b/websocketproxy.go @@ -0,0 +1,258 @@ +// Package websocketproxy is a reverse proxy for WebSocket connections. +package main + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + + "github.com/gorilla/websocket" + "github.com/treeder/gotils/v2" +) + +var ( + // DefaultUpgrader specifies the parameters for upgrading an HTTP + // connection to a WebSocket connection. + DefaultUpgrader = &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + // DefaultDialer is a dialer with all fields set to the default zero values. + DefaultDialer = websocket.DefaultDialer +) + +// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket +// connection and proxies it to another server. +type WebsocketProxy struct { + // Director, if non-nil, is a function that may copy additional request + // headers from the incoming WebSocket connection into the output headers + // which will be forwarded to another server. + Director func(incoming *http.Request, out http.Header) + + // Backend returns the backend URL which the proxy uses to reverse proxy + // the incoming WebSocket connection. Request is the initial incoming and + // unmodified request. + Backend func(*http.Request) *url.URL + + // Upgrader specifies the parameters for upgrading a incoming HTTP + // connection to a WebSocket connection. If nil, DefaultUpgrader is used. + Upgrader *websocket.Upgrader + + // Dialer contains options for connecting to the backend WebSocket server. + // If nil, DefaultDialer is used. + Dialer *websocket.Dialer + + Transport *myTransport +} + +// NewProxy returns a new Websocket reverse proxy that rewrites the +// URL's to the scheme, host and base path provider in target. +func NewProxy(target *url.URL) *WebsocketProxy { + backend := func(r *http.Request) *url.URL { + // Shallow copy + u := *target + u.Fragment = r.URL.Fragment + u.Path = r.URL.Path + u.RawQuery = r.URL.RawQuery + return &u + } + return &WebsocketProxy{Backend: backend} +} + +// ServeHTTP implements the http.Handler that proxies WebSocket connections. +func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + if w.Backend == nil { + gotils.L(ctx).Error().Print("webproxy backend function is not defined") + http.Error(rw, "internal server error (code: 1)", http.StatusInternalServerError) + return + } + + backendURL := w.Backend(req) + if backendURL == nil { + gotils.L(ctx).Error().Print("websocketproxy: backend URL is nil") + http.Error(rw, "internal server error (code: 2)", http.StatusInternalServerError) + return + } + + dialer := w.Dialer + if w.Dialer == nil { + dialer = DefaultDialer + } + + // Pass headers from the incoming request to the dialer to forward them to + // the final destinations. + requestHeader := http.Header{} + if origin := req.Header.Get("Origin"); origin != "" { + requestHeader.Add("Origin", origin) + } + for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { + requestHeader.Add("Sec-WebSocket-Protocol", prot) + } + for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] { + requestHeader.Add("Cookie", cookie) + } + if req.Host != "" { + requestHeader.Set("Host", req.Host) + } + + // Pass X-Forwarded-For headers too, code below is a part of + // httputil.ReverseProxy. See http://en.wikipedia.org/wiki/X-Forwarded-For + // for more information + // TODO: use RFC7239 http://tools.ietf.org/html/rfc7239 + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := req.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + requestHeader.Set("X-Forwarded-For", clientIP) + } + + // Set the originating protocol of the incoming HTTP request. The SSL might + // be terminated on our site and because we doing proxy adding this would + // be helpful for applications on the backend. + requestHeader.Set("X-Forwarded-Proto", "http") + if req.TLS != nil { + requestHeader.Set("X-Forwarded-Proto", "https") + } + + // Enable the director to copy any additional headers it desires for + // forwarding to the remote server. + if w.Director != nil { + w.Director(req, requestHeader) + } + + // Connect to the backend URL, also pass the headers we get from the requst + // together with the Forwarded headers we prepared above. + // TODO: support multiplexing on the same backend connection instead of + // opening a new TCP connection time for each request. This should be + // optional: + // http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01 + connBackend, resp, err := dialer.Dial(backendURL.String(), requestHeader) + if err != nil { + gotils.L(ctx).Error().Printf("websocketproxy:%s", err) + if resp != nil { + // If the WebSocket handshake fails, ErrBadHandshake is returned + // along with a non-nil *http.Response so that callers can handle + // redirects, authentication, etcetera. + if err := copyResponse(rw, resp); err != nil { + gotils.L(ctx).Error().Printf("websocketproxy: couldn't write response after failed remote backend handshake %s", err) + } + } else { + http.Error(rw, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) + } + return + } + defer connBackend.Close() + + upgrader := w.Upgrader + if w.Upgrader == nil { + upgrader = DefaultUpgrader + } + + // Only pass those headers to the upgrader. + upgradeHeader := http.Header{} + if hdr := resp.Header.Get("Sec-Websocket-Protocol"); hdr != "" { + upgradeHeader.Set("Sec-Websocket-Protocol", hdr) + } + if hdr := resp.Header.Get("Set-Cookie"); hdr != "" { + upgradeHeader.Set("Set-Cookie", hdr) + } + + // Now upgrade the existing incoming request to a WebSocket connection. + // Also pass the header that we gathered from the Dial handshake. + connPub, err := upgrader.Upgrade(rw, req, upgradeHeader) + if err != nil { + gotils.L(ctx).Error().Printf("websocketproxy: couldn't upgrade %s", err) + return + } + defer connPub.Close() + + errClient := make(chan error, 1) + errBackend := make(chan error, 1) + replicateWebsocketConn := func(ctx context.Context, ip string, limit bool, dst, src *websocket.Conn, errc chan error) { + for { + msgType, msg, err := src.ReadMessage() + if err != nil { + m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) + if e, ok := err.(*websocket.CloseError); ok { + if e.Code != websocket.CloseNoStatusReceived { + m = websocket.FormatCloseMessage(e.Code, e.Text) + } + } + errc <- err + dst.WriteMessage(websocket.CloseMessage, m) + break + } + if limit { + methods, res, err := parseMessage(msg, ip) + if err != nil { + errc <- err + err = src.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))) + if err != nil { + errc <- err + } + break + } + ctx = gotils.With(ctx, "remoteIp", ip) + ctx = gotils.With(ctx, "methods", methods) + if len(methods) > 0 { + _, resp := w.Transport.block(ctx, res) + if resp != nil { + errc <- errors.New(resp.(ErrResponse).Error.Message) + err = src.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, resp.(ErrResponse).Error.Message)) + if err != nil { + errc <- err + } + break + } + } + } + err = dst.WriteMessage(msgType, msg) + if err != nil { + errc <- err + break + } + } + } + ip := getIP(req) + go replicateWebsocketConn(ctx, ip, true, connBackend, connPub, errBackend) + go replicateWebsocketConn(ctx, ip, false, connPub, connBackend, errClient) + + var message string + select { + case err = <-errClient: + message = "websocketproxy: Error when copying from backend to client: %v" + case err = <-errBackend: + message = "websocketproxy: Error when copying from client to backend: %v" + + } + if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { + gotils.L(ctx).Error().Printf("%s %s", message, err) + } +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func copyResponse(rw http.ResponseWriter, resp *http.Response) error { + copyHeader(rw.Header(), resp.Header) + rw.WriteHeader(resp.StatusCode) + defer resp.Body.Close() + + _, err := io.Copy(rw, resp.Body) + return err +}