Skip to content

Commit

Permalink
refactor: daemon to only interface with LoggedHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
chetan committed Feb 13, 2021
1 parent c062642 commit a2b5f7d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 42 deletions.
4 changes: 2 additions & 2 deletions bin/vproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ func startDaemon(c *cli.Context) error {
httpsPort := c.Int("https")

vhostMux := vproxy.CreateVhostMux([]string{}, httpsPort > 0)
rootMux := vproxy.NewLoggedHandler(vhostMux)
loggedHandler := vproxy.NewLoggedHandler(vhostMux)

// start daemon
d := vproxy.NewDaemon(vhostMux, rootMux, listen, httpPort, httpsPort)
d := vproxy.NewDaemon(loggedHandler, listen, httpPort, httpsPort)
d.Run()

return nil
Expand Down
39 changes: 6 additions & 33 deletions daemon.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package vproxy

import (
"crypto/tls"
"fmt"
"log"
"net"
Expand All @@ -25,7 +24,6 @@ const PONG = "hello from vproxy"
type Daemon struct {
wg sync.WaitGroup

vhostMux *VhostMux
loggedHandler *LoggedHandler

listenHost string
Expand All @@ -40,8 +38,8 @@ type Daemon struct {
}

// NewDaemon
func NewDaemon(vhost *VhostMux, mux *LoggedHandler, listen string, httpPort int, httpsPort int) *Daemon {
return &Daemon{vhostMux: vhost, loggedHandler: mux, listenHost: listen, httpPort: httpPort, httpsPort: httpsPort}
func NewDaemon(lh *LoggedHandler, listen string, httpPort int, httpsPort int) *Daemon {
return &Daemon{loggedHandler: lh, listenHost: listen, httpPort: httpPort, httpsPort: httpsPort}
}

func rerunWithSudo() {
Expand Down Expand Up @@ -116,13 +114,7 @@ func (d *Daemon) Run() {

if d.enableTLS() {
fmt.Printf("[*] starting proxy: https://%s\n", d.httpsAddr)
if len(d.vhostMux.Servers) > 0 {
fmt.Printf(" vhosts:\n")
for _, server := range d.vhostMux.Servers {
fmt.Printf(" - %s -> %d\n", server.Host, server.Port)
}
}

d.loggedHandler.DumpServers(os.Stdout)
go d.startTLS()
}

Expand Down Expand Up @@ -171,7 +163,7 @@ func (d *Daemon) startTLS() {

server := http.Server{
Handler: d.loggedHandler,
TLSConfig: createTLSConfig(d.vhostMux),
TLSConfig: d.loggedHandler.CreateTLSConfig(),
// ErrorLog: nullLogger,
}

Expand Down Expand Up @@ -209,7 +201,6 @@ func (d *Daemon) registerVhost(w http.ResponseWriter, r *http.Request) {
// Remove this client when this handler exits
fmt.Printf("[*] removing vhost: %s -> %d\n", vhost.Host, vhost.Port)
d.loggedHandler.RemoveLogListener(vhost.Host)
delete(d.vhostMux.Servers, vhost.Host)
d.restartTLS()
}()

Expand Down Expand Up @@ -242,7 +233,6 @@ func (d *Daemon) addVhost(binding string, w http.ResponseWriter) (chan string, *
}

fmt.Printf("[*] registering new vhost: %s -> %d\n", vhost.Host, vhost.Port)
d.vhostMux.Servers[vhost.Host] = vhost

// Set the headers related to event streaming.
w.Header().Set("Content-Type", "text/event-stream")
Expand All @@ -251,7 +241,7 @@ func (d *Daemon) addVhost(binding string, w http.ResponseWriter) (chan string, *
w.Header().Set("Access-Control-Allow-Origin", "*")

logChan := make(chan string, 10)
d.loggedHandler.AddLogListener(vhost.Host, logChan)
d.loggedHandler.AddVhost(vhost, logChan)
d.restartTLS()

err = addToHosts(vhost.Host)
Expand All @@ -272,22 +262,5 @@ func (d *Daemon) hello(w http.ResponseWriter, r *http.Request) {
// listClients currently connected to the vproxy daemon
func (d *Daemon) listClients(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
fmt.Fprintf(w, " %d vhosts:\n", len(d.vhostMux.Servers))
for _, v := range d.vhostMux.Servers {
fmt.Fprintf(w, "%s -> %s:%d\n", v.Host, v.ServiceHost, v.Port)
}
}

// Create multi-certificate TLS config from vhost config
func createTLSConfig(vhost *VhostMux) *tls.Config {
cfg := &tls.Config{}
for _, server := range vhost.Servers {
cert, err := tls.LoadX509KeyPair(server.Cert, server.Key)
if err != nil {
log.Fatal("failed to load keypair:", err)
}
cfg.Certificates = append(cfg.Certificates, cert)
}
cfg.BuildNameToCertificate()
return cfg
d.loggedHandler.DumpServers(w)
}
41 changes: 34 additions & 7 deletions logged_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package vproxy

import (
"bufio"
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
Expand All @@ -15,6 +17,7 @@ import (
type LoggedHandler struct {
*http.ServeMux
VhostLogListeners map[string]chan string
vhostMux *VhostMux
}

// NewLoggedHandler wraps the given handler with a request/response logger
Expand All @@ -27,30 +30,54 @@ func NewLoggedHandler(handler http.Handler) *LoggedHandler {
return lh
}

func (mux *LoggedHandler) AddLogListener(host string, listener chan string) {
mux.VhostLogListeners[host] = listener
func (lh *LoggedHandler) AddVhost(vhost *Vhost, listener chan string) {
lh.VhostLogListeners[vhost.Host] = listener
lh.vhostMux.Servers[vhost.Host] = vhost
}

func (mux *LoggedHandler) RemoveLogListener(host string) {
delete(mux.VhostLogListeners, host)
func (lh *LoggedHandler) RemoveLogListener(host string) {
delete(lh.VhostLogListeners, host)
delete(lh.vhostMux.Servers, host)
}

func (mux *LoggedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// DumpServers to the given writer
func (lh *LoggedHandler) DumpServers(w io.Writer) {
fmt.Fprintf(w, "%d vhosts:\n", len(lh.vhostMux.Servers))
for _, v := range lh.vhostMux.Servers {
fmt.Fprintf(w, "%s -> %s:%d\n", v.Host, v.ServiceHost, v.Port)
}
}

// Create multi-certificate TLS config from vhost config
func (lh *LoggedHandler) CreateTLSConfig() *tls.Config {
cfg := &tls.Config{}
for _, server := range lh.vhostMux.Servers {
cert, err := tls.LoadX509KeyPair(server.Cert, server.Key)
if err != nil {
log.Fatal("failed to load keypair:", err)
}
cfg.Certificates = append(cfg.Certificates, cert)
}
cfg.BuildNameToCertificate()
return cfg
}

func (lh *LoggedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
record := &LogRecord{
ResponseWriter: w,
}

// serve request and capture timings
startTime := time.Now()
mux.ServeMux.ServeHTTP(record, r)
lh.ServeMux.ServeHTTP(record, r)
finishTime := time.Now()
elapsedTime := finishTime.Sub(startTime)
host := getHostName(r.Host)

l := fmt.Sprintf("%s [%s] %s [ %d ] %s %d %s", r.RemoteAddr, host, r.Method, record.status, r.URL, r.ContentLength, elapsedTime)
log.Println(l)

if listener, ok := mux.VhostLogListeners[host]; ok {
if listener, ok := lh.VhostLogListeners[host]; ok {
listener <- l
}
}
Expand Down

0 comments on commit a2b5f7d

Please sign in to comment.