diff --git a/README.md b/README.md index 0004771..35b9d86 100644 --- a/README.md +++ b/README.md @@ -1 +1,186 @@ -# go-httpproxy +# A Go HTTP proxy library which has KISS principle + +## Introduction + +`github.com/go-httpproxy/httpproxy` repository provides an HTTP proxy library +for Go (golang). + +The library is regular HTTP proxy; supports HTTP, HTTPS through CONNECT. And +also provides HTTPS connection using "Man in the Middle" style attack. + +It's easy to use. `httpproxy.Proxy` implements `Handler` interface of `net/http` +package to offer `http.ListenAndServe` function. + +### Keep it simple, stupid! + +> KISS is an acronym for "Keep it simple, stupid" as a design principle. The +KISS principle states that most systems work best if they are kept simple rather +than made complicated; therefore simplicity should be a key goal in design and +unnecessary complexity should be avoided. [Wikipedia] + +## Usage + +Library has two significant structs: Proxy and Context. + +### Proxy struct + +```go +// Proxy defines parameters for running an HTTP Proxy. It implements +// http.Handler interface for ListenAndServe function. If you need, you must +// fill Proxy struct before handling requests. +type Proxy struct { + // Session number of last proxy request. + SessionNo int64 + + // RoundTripper interface to obtain remote response. + // By default, it uses &http.Transport{}. + Rt http.RoundTripper + + // Certificate key pair. + Ca tls.Certificate + + // User data to use free. + UserData interface{} + + // Error handler. + OnError func(ctx *Context, when string, err *Error, opErr error) + + // Accept handler. It greets proxy request like ServeHTTP function of + // http.Handler. + // If it returns true, stops processing proxy request. + OnAccept func(ctx *Context, w http.ResponseWriter, r *http.Request) bool + + // Auth handler. If you need authentication, set this handler. + // If it returns true, authentication succeeded. + OnAuth func(ctx *Context, user string, pass string) bool + + // Connect handler. It sets connect action and new host. + // If len(newhost) > 0, host changes. + OnConnect func(ctx *Context, host string) (ConnectAction ConnectAction, + newHost string) + + // Request handler. It greets remote request. + // If it returns non-nil response, stops processing remote request. + OnRequest func(ctx *Context, req *http.Request) (resp *http.Response) + + // Response handler. It greets remote response. + // Remote response sends after this handler. + OnResponse func(ctx *Context, req *http.Request, resp *http.Response) + + // If ConnectAction is ConnectMitm, it sets chunked to Transfer-Encoding. + // By default, it is true. + MitmChunked bool +} +``` + +### Context struct + +```go +// Context keeps context of each proxy request. +type Context struct { + // Pointer of Proxy struct handled this context. + // It's using internally. Don't change in Context struct! + Prx *Proxy + + // Session number of this context obtained from Proxy struct. + SessionNo int64 + + // Sub session number of processing remote connection. + SubSessionNo int64 + + // Action of after the CONNECT, if proxy request method is CONNECT. + // It's using internally. Don't change in Context struct! + ConnectAction ConnectAction + + // Proxy request, if proxy request method is CONNECT. + // It's using internally. Don't change in Context struct! + ConnectReq *http.Request + + // Remote host, if proxy request method is CONNECT. + // It's using internally. Don't change in Context struct! + ConnectHost string + + // User data to use free. + UserData interface{} +} +``` + +### Demo code + +```go +package main + +import ( + "log" + "net/http" + + "github.com/go-httpproxy/httpproxy" +) + +func OnError(ctx *httpproxy.Context, when string, + err *httpproxy.Error, opErr error) { + // Log errors. + log.Printf("ERR: %s: %s [%s]", when, err, opErr) +} + +func OnAccept(ctx *httpproxy.Context, w http.ResponseWriter, + r *http.Request) bool { + // Handle local request has path "/info" + if r.Method == "GET" && !r.URL.IsAbs() && r.URL.Path == "/info" { + w.Write([]byte("This is go-httpproxy.")) + return true + } + return false +} + +func OnAuth(ctx *httpproxy.Context, user string, pass string) bool { + // Auth test user. + if user == "test" && pass == "1234" { + return true + } + return false +} + +func OnConnect(ctx *httpproxy.Context, host string) ( + ConnectAction httpproxy.ConnectAction, newHost string) { + // Apply "Man in the Middle" to all ssl connections. Never change host. + return httpproxy.ConnectMitm, host +} + +func OnRequest(ctx *httpproxy.Context, req *http.Request) ( + resp *http.Response) { + // Log proxying requests. + log.Printf("INFO: Proxy: %s %s", req.Method, req.URL.String()) + return +} + +func OnResponse(ctx *httpproxy.Context, req *http.Request, + resp *http.Response) { + // Add header "Via: go-httpproxy". + resp.Header.Add("Via", "go-httpproxy") +} + +func main() { + // Create a new proxy with default certificate pair. + prx, _ := httpproxy.NewProxy() + + // Set handlers. + prx.OnError = OnError + prx.OnAccept = OnAccept + prx.OnAuth = OnAuth + prx.OnConnect = OnConnect + prx.OnRequest = OnRequest + prx.OnResponse = OnResponse + + // Listen... + http.ListenAndServe(":8080", prx) +} +``` + +## GoDoc + +[https://godoc.org/github.com/go-httpproxy/httpproxy](https://godoc.org/github.com/go-httpproxy/httpproxy) + +## To-Do + +* GoDoc diff --git a/ca.go b/ca.go index 50a9bdb..5895c7b 100644 --- a/ca.go +++ b/ca.go @@ -13,6 +13,7 @@ import ( "time" ) +// Default certificate. var DefaultCaCert = []byte(`-----BEGIN CERTIFICATE----- MIIFkzCCA3ugAwIBAgIJAKEbW2ujNjX9MA0GCSqGSIb3DQEBCwUAMGAxCzAJBgNV BAYTAlRSMREwDwYDVQQIDAhJc3RhbmJ1bDEVMBMGA1UECgwMZ28taHR0cHByb3h5 @@ -46,6 +47,7 @@ Ii9Vb07WDMQXou0ZZs7rnjAKo+sfFElTFewtS1wif4ZYBUJN1ln9G8qKaxbAiElm MgzNfZ7WlnaJf2rfHJbvK9VqJ9z6dLRYPjCHhakJBtzsMdxysEGJ -----END CERTIFICATE-----`) +// Default key. var DefaultCaKey = []byte(`-----BEGIN RSA PRIVATE KEY----- MIIJKQIBAAKCAgEA18cwaaZzhdDEpUXpR9pkYRqsSdT30WhynFhFtcaBOf4eYdpt AJWL2ipo3Ac6bh+YgWfywG4prrSfWOJl+dQ59w439vLek/waBcEeFx+wJ6PFu0ur diff --git a/context.go b/context.go index 5804908..93eebb0 100644 --- a/context.go +++ b/context.go @@ -1,12 +1,38 @@ package httpproxy -import "net/http" +import ( + "bufio" + "crypto/tls" + "net/http" +) -// Context defines context of each proxy connection. +// Context keeps context of each proxy request. type Context struct { - Prx *Proxy - SessionNo int64 + // Pointer of Proxy struct handled this context. + // It's using internally. Don't change in Context struct! + Prx *Proxy + + // Session number of this context obtained from Proxy struct. + SessionNo int64 + + // Sub session number of processing remote connection. + SubSessionNo int64 + + // Action of after the CONNECT, if proxy request method is CONNECT. + // It's using internally. Don't change in Context struct! ConnectAction ConnectAction - ConnectReq *http.Request - UserData interface{} + + // Proxy request, if proxy request method is CONNECT. + // It's using internally. Don't change in Context struct! + ConnectReq *http.Request + + // Remote host, if proxy request method is CONNECT. + // It's using internally. Don't change in Context struct! + ConnectHost string + + // User data to use free. + UserData interface{} + + hijTLSConn *tls.Conn + hijTLSReader *bufio.Reader } diff --git a/demo/main.go b/demo/main.go index f05f444..bc61a4f 100644 --- a/demo/main.go +++ b/demo/main.go @@ -7,35 +7,54 @@ import ( "github.com/go-httpproxy/httpproxy" ) -func OnError(ctx *httpproxy.Context, when string, err *httpproxy.Error, opErr error) { - log.Printf("%s %s %s", when, err, opErr) +func OnError(ctx *httpproxy.Context, when string, + err *httpproxy.Error, opErr error) { + // Log errors. + log.Printf("ERR: %s: %s [%s]", when, err, opErr) } -func OnAccept(ctx *httpproxy.Context, req *http.Request) *http.Response { - return nil +func OnAccept(ctx *httpproxy.Context, w http.ResponseWriter, + r *http.Request) bool { + // Handle local request has path "/info" + if r.Method == "GET" && !r.URL.IsAbs() && r.URL.Path == "/info" { + w.Write([]byte("This is go-httpproxy.")) + return true + } + return false } func OnAuth(ctx *httpproxy.Context, user string, pass string) bool { + // Auth test user. if user == "test" && pass == "1234" { return true } return false } -func OnConnect(ctx *httpproxy.Context, host string) (httpproxy.ConnectAction, string) { +func OnConnect(ctx *httpproxy.Context, host string) ( + ConnectAction httpproxy.ConnectAction, newHost string) { + // Apply "Man in the Middle" to all ssl connections. Never change host. return httpproxy.ConnectMitm, host } -func OnRequest(ctx *httpproxy.Context, req *http.Request) *http.Response { - return nil +func OnRequest(ctx *httpproxy.Context, req *http.Request) ( + resp *http.Response) { + // Log proxying requests. + log.Printf("INFO: Proxy: %s %s", req.Method, req.URL.String()) + return } -func OnResponse(ctx *httpproxy.Context, req *http.Request, resp *http.Response) { - resp.Header.Add("Via", "test") +func OnResponse(ctx *httpproxy.Context, req *http.Request, + resp *http.Response) { + // Add header "Via: go-httpproxy". + resp.Header.Add("Via", "go-httpproxy") } func main() { + // Create a new proxy with default certificate pair. prx, _ := httpproxy.NewProxy() + + // Set handlers. prx.OnError = OnError prx.OnAccept = OnAccept prx.OnAuth = OnAuth @@ -43,5 +62,6 @@ func main() { prx.OnRequest = OnRequest prx.OnResponse = OnResponse + // Listen... http.ListenAndServe(":8080", prx) } diff --git a/doing.go b/doing.go index 5ef6023..809c02d 100644 --- a/doing.go +++ b/doing.go @@ -20,24 +20,19 @@ func doError(ctx *Context, when string, err *Error, opErr error) { } func doAccept(ctx *Context, w http.ResponseWriter, r *http.Request) bool { - if ctx.Prx.OnAccept == nil { - return false - } - resp := ctx.Prx.OnAccept(ctx, r) - if resp == nil { + if ctx.Prx.OnAccept == nil || !ctx.Prx.OnAccept(ctx, w, r) { return false } if r.Close { defer r.Body.Close() } - err := ServeResponse(w, resp) - if err != nil && !isConnectionClosed(err) { - doError(ctx, "Accept", ErrResponseWrite, err) - } return true } func doAuth(ctx *Context, w http.ResponseWriter, r *http.Request) bool { + if r.Method != "CONNECT" && !r.URL.IsAbs() { + return false + } if ctx.Prx.OnAuth == nil { return false } @@ -57,16 +52,17 @@ func doAuth(ctx *Context, w http.ResponseWriter, r *http.Request) bool { if r.Close { defer r.Body.Close() } - err := ServeInMemory(w, 407, nil, []byte("Proxy Authentication Required")) + err := ServeInMemory(w, 407, map[string][]string{"Proxy-Authenticate": {"Basic"}}, + []byte("Proxy Authentication Required")) if err != nil && !isConnectionClosed(err) { doError(ctx, "Auth", ErrResponseWrite, err) } return true } -func doConnect(ctx *Context, w http.ResponseWriter, r *http.Request) (w2 http.ResponseWriter, r2 *http.Request) { +func doConnect(ctx *Context, w http.ResponseWriter, r *http.Request) (w2 http.ResponseWriter) { if r.Method != "CONNECT" { - w2, r2 = w, r + w2 = w return } hij, ok := w.(http.Hijacker) @@ -87,16 +83,22 @@ func doConnect(ctx *Context, w http.ResponseWriter, r *http.Request) (w2 http.Re } hijConn := conn.(*net.TCPConn) ctx.ConnectAction = ConnectProxy + ctx.ConnectReq = r host := r.URL.Host if ctx.Prx.OnConnect != nil { - ctx.ConnectAction, host = ctx.Prx.OnConnect(ctx, host) + var newHost string + ctx.ConnectAction, newHost = ctx.Prx.OnConnect(ctx, host) if ctx.ConnectAction == ConnectNone { ctx.ConnectAction = ConnectProxy } + if newHost != "" { + host = newHost + } } if !hasPort.MatchString(host) { host += ":80" } + ctx.ConnectHost = host switch ctx.ConnectAction { case ConnectProxy: conn, err := net.Dial("tcp", host) @@ -150,39 +152,41 @@ func doConnect(ctx *Context, w http.ResponseWriter, r *http.Request) (w2 http.Re } return } - hijTlsConn := tls.Server(hijConn, tlsConfig) - if err := hijTlsConn.Handshake(); err != nil { - hijTlsConn.Close() + ctx.hijTLSConn = tls.Server(hijConn, tlsConfig) + if err := ctx.hijTLSConn.Handshake(); err != nil { + ctx.hijTLSConn.Close() if !isConnectionClosed(err) { doError(ctx, "Connect", ErrTLSHandshake, err) } return } - hijTlsReader := bufio.NewReader(hijTlsConn) - req, err := http.ReadRequest(hijTlsReader) - if err != nil { - hijTlsConn.Close() - if !isConnectionClosed(err) { - doError(ctx, "Connect", ErrRequestRead, err) - } - return - } - req.RemoteAddr = r.RemoteAddr - if req.URL.IsAbs() { - hijTlsConn.Close() - doError(ctx, "Connect", ErrAbsURLAfterCONNECT, nil) - return + ctx.hijTLSReader = bufio.NewReader(ctx.hijTLSConn) + w2 = NewConnResponseWriter(ctx.hijTLSConn) + return + } + return +} + +func doMitm(ctx *Context, w http.ResponseWriter) (r *http.Request) { + req, err := http.ReadRequest(ctx.hijTLSReader) + if err != nil { + if !isConnectionClosed(err) { + doError(ctx, "Request", ErrRequestRead, err) } - req.URL.Scheme = "https" - req.URL.Host = host - w2 = NewConnResponseWriter(hijTlsConn) - r2 = req return } + req.RemoteAddr = ctx.ConnectReq.RemoteAddr + if req.URL.IsAbs() { + doError(ctx, "Request", ErrAbsURLAfterCONNECT, nil) + return + } + req.URL.Scheme = "https" + req.URL.Host = ctx.ConnectHost + r = req return } -func doRequest(ctx *Context, w http.ResponseWriter, r *http.Request) bool { +func doRequest(ctx *Context, w http.ResponseWriter, r *http.Request) (bool, error) { r.RequestURI = "" if !r.URL.IsAbs() { if r.Close { @@ -192,26 +196,30 @@ func doRequest(ctx *Context, w http.ResponseWriter, r *http.Request) bool { if err != nil && !isConnectionClosed(err) { doError(ctx, "Request", ErrResponseWrite, err) } - return true + return true, err } if ctx.Prx.OnRequest == nil { - return false + return false, nil } resp := ctx.Prx.OnRequest(ctx, r) if resp == nil { - return false + return false, nil } if r.Close { defer r.Body.Close() } + resp.TransferEncoding = nil + if ctx.ConnectAction == ConnectMitm && ctx.Prx.MitmChunked { + resp.TransferEncoding = []string{"chunked"} + } err := ServeResponse(w, resp) if err != nil && !isConnectionClosed(err) { doError(ctx, "Request", ErrResponseWrite, err) } - return true + return true, err } -func doResponse(ctx *Context, w http.ResponseWriter, r *http.Request) bool { +func doResponse(ctx *Context, w http.ResponseWriter, r *http.Request) (bool, error) { resp, err := ctx.Prx.Rt.RoundTrip(r) if err != nil { if r.Close { @@ -220,14 +228,18 @@ func doResponse(ctx *Context, w http.ResponseWriter, r *http.Request) bool { if err != context.Canceled && !isConnectionClosed(err) { doError(ctx, "Response", ErrRoundTrip, err) } - return false + return false, err } if ctx.Prx.OnResponse != nil { ctx.Prx.OnResponse(ctx, r, resp) } + resp.TransferEncoding = nil + if ctx.ConnectAction == ConnectMitm && ctx.Prx.MitmChunked { + resp.TransferEncoding = []string{"chunked"} + } err = ServeResponse(w, resp) if err != nil && !isConnectionClosed(err) { doError(ctx, "Response", ErrResponseWrite, err) } - return true + return true, err } diff --git a/error.go b/error.go index bb8399c..b9a7f41 100644 --- a/error.go +++ b/error.go @@ -7,25 +7,30 @@ import ( "syscall" ) +// Library specific errors. var ( - ErrResponseWrite = NewError("response write") - ErrRequestRead = NewError("request read") - ErrRemoteConnect = NewError("remote connect") - ErrNotSupportHijacking = NewError("httpserver does not support hijacking") - ErrTLSSignHost = NewError("TLS sign host") - ErrTLSHandshake = NewError("TLS handshake") - ErrAbsURLAfterCONNECT = NewError("absolute URL after CONNECT") - ErrRoundTrip = NewError("round trip") + ErrResponseWrite = NewError("response write") + ErrRequestRead = NewError("request read") + ErrRemoteConnect = NewError("remote connect") + ErrNotSupportHijacking = NewError("httpserver does not support hijacking") + ErrTLSSignHost = NewError("TLS sign host") + ErrTLSHandshake = NewError("TLS handshake") + ErrAbsURLAfterCONNECT = NewError("absolute URL after CONNECT") + ErrRoundTrip = NewError("round trip") + ErrUnsupportedTransferEncoding = NewError("unsupported transfer encoding") ) +// Error struct is base of library specific errors. type Error struct { ErrString string } +// NewError returns a new Error. func NewError(errString string) *Error { return &Error{errString} } +// Error implements error interface. func (e *Error) Error() string { return e.ErrString } diff --git a/httpproxy.go b/httpproxy.go index 108c17e..f46ee99 100644 --- a/httpproxy.go +++ b/httpproxy.go @@ -11,35 +11,75 @@ type ConnectAction int // Constants of ConnectAction type. const ( + // ConnectNone specifies that proxy request is not CONNECT. + // If it returned in OnConnect, changed to ConnectProxy. ConnectNone = ConnectAction(iota) + + // ConnectProxy specifies directly socket proxy after the CONNECT. ConnectProxy + + // ConnectMitm specifies proxy "Man in the Middle" style attack + // after the CONNECT. ConnectMitm ) -// Proxy defines parameters for running an HTTP Proxy. Also implements http.Handler interface for ListenAndServe function. +// Proxy defines parameters for running an HTTP Proxy. It implements +// http.Handler interface for ListenAndServe function. If you need, you must +// fill Proxy struct before handling requests. type Proxy struct { - SessionNo int64 - Rt http.RoundTripper - Ca tls.Certificate - UserData interface{} - OnError func(ctx *Context, when string, err *Error, opErr error) - OnAccept func(ctx *Context, req *http.Request) *http.Response - OnAuth func(ctx *Context, user string, pass string) bool - OnConnect func(ctx *Context, host string) (ConnectAction, string) - OnRequest func(ctx *Context, req *http.Request) *http.Response + // Session number of last proxy request. + SessionNo int64 + + // RoundTripper interface to obtain remote response. + // By default, it uses &http.Transport{}. + Rt http.RoundTripper + + // Certificate key pair. + Ca tls.Certificate + + // User data to use free. + UserData interface{} + + // Error handler. + OnError func(ctx *Context, when string, err *Error, opErr error) + + // Accept handler. It greets proxy request like ServeHTTP function of + // http.Handler. + // If it returns true, stops processing proxy request. + OnAccept func(ctx *Context, w http.ResponseWriter, r *http.Request) bool + + // Auth handler. If you need authentication, set this handler. + // If it returns true, authentication succeeded. + OnAuth func(ctx *Context, user string, pass string) bool + + // Connect handler. It sets connect action and new host. + // If len(newhost) > 0, host changes. + OnConnect func(ctx *Context, host string) (ConnectAction ConnectAction, + newHost string) + + // Request handler. It greets remote request. + // If it returns non-nil response, stops processing remote request. + OnRequest func(ctx *Context, req *http.Request) (resp *http.Response) + + // Response handler. It greets remote response. + // Remote response sends after this handler. OnResponse func(ctx *Context, req *http.Request, resp *http.Response) + + // If ConnectAction is ConnectMitm, it sets chunked to Transfer-Encoding. + // By default, it is true. + MitmChunked bool } -// NewProxy returns a new Proxy has defaults. +// NewProxy returns a new Proxy has default certificate and key. func NewProxy() (*Proxy, error) { - return NewProxyWithCert(nil, nil) + return NewProxyCert(nil, nil) } -// NewProxyWithCert returns a new Proxy given certificate and key. -func NewProxyWithCert(caCert, caKey []byte) (result *Proxy, error error) { +// NewProxyCert returns a new Proxy given certificate and key. +func NewProxyCert(caCert, caKey []byte) (result *Proxy, error error) { result = &Proxy{ Rt: &http.Transport{TLSClientConfig: &tls.Config{}, - Proxy: http.ProxyFromEnvironment}, + Proxy: http.ProxyFromEnvironment}, MitmChunked: true, } if caCert == nil { caCert = DefaultCaCert @@ -54,7 +94,7 @@ func NewProxyWithCert(caCert, caKey []byte) (result *Proxy, error error) { return } -// ServeHTTP has been needed for implementing http.Handler. +// ServeHTTP implements http.Handler. func (prx *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := &Context{Prx: prx, SessionNo: atomic.AddInt64(&prx.SessionNo, 1)} @@ -67,23 +107,40 @@ func (prx *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { } removeProxyHeaders(r) - if w2, r2 := doConnect(ctx, w, r); r2 != nil { - if r != r2 { - ctx.ConnectReq = r - } - w, r = w2, r2 + if w2 := doConnect(ctx, w, r); w2 != nil { + w = w2 } else { return } - if doRequest(ctx, w, r) { - if w2, ok := w.(*ConnResponseWriter); ok { - w2.Close() + for { + var cyclic = false + if ctx.ConnectAction == ConnectMitm { + if prx.MitmChunked { + cyclic = true + } + r = doMitm(ctx, w) + } + if r == nil { + break + } + ctx.SubSessionNo += 1 + if b, err := doRequest(ctx, w, r); err != nil { + break + } else { + if b { + if !cyclic { + break + } else { + continue + } + } + } + if b, err := doResponse(ctx, w, r); err != nil || !b || !cyclic { + break } - return } - doResponse(ctx, w, r) if w2, ok := w.(*ConnResponseWriter); ok { w2.Close() } diff --git a/utils.go b/utils.go index ae48325..ebf9c37 100644 --- a/utils.go +++ b/utils.go @@ -6,6 +6,7 @@ import ( "io" "io/ioutil" "net/http" + "net/http/httputil" "regexp" "strconv" "strings" @@ -49,9 +50,39 @@ func ServeResponse(w http.ResponseWriter, resp *http.Response) error { } else { h.Del("Content-Length") } - w.WriteHeader(resp.StatusCode) - _, err := io.Copy(w, resp.Body) - return err + h.Del("Transfer-Encoding") + te := "" + if len(resp.TransferEncoding) > 0 { + if len(resp.TransferEncoding) > 1 { + return ErrUnsupportedTransferEncoding + } + te = resp.TransferEncoding[0] + } + switch te { + case "": + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + return err + } + case "chunked": + h.Add("Transfer-Encoding", "chunked") + //h.Del("Content-Length") + h.Set("Connection", "close") + w.WriteHeader(resp.StatusCode) + w2 := httputil.NewChunkedWriter(w) + if _, err := io.Copy(w2, resp.Body); err != nil { + return err + } + if err := w2.Close(); err != nil { + return err + } + if _, err := w.Write([]byte("\r\n")); err != nil { + return err + } + default: + return ErrUnsupportedTransferEncoding + } + return nil } func ServeInMemory(w http.ResponseWriter, code int, header http.Header, body []byte) error {