diff --git a/input_tcp.go b/input_tcp.go index f3c95538..025312aa 100644 --- a/input_tcp.go +++ b/input_tcp.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "fmt" + "crypto/tls" "io" "log" "net" @@ -13,15 +14,23 @@ import ( // TCPInput used for internal communication type TCPInput struct { data chan []byte - address string listener net.Listener + address string + config *TCPInputConfig +} + +type TCPInputConfig struct { + secure bool + certificatePath string + keyPath string } // NewTCPInput constructor for TCPInput, accepts address with port -func NewTCPInput(address string) (i *TCPInput) { +func NewTCPInput(address string, config *TCPInputConfig) (i *TCPInput) { i = new(TCPInput) i.data = make(chan []byte, 1000) i.address = address + i.config = config i.listen(address) @@ -36,16 +45,30 @@ func (i *TCPInput) Read(data []byte) (int, error) { } func (i *TCPInput) listen(address string) { - listener, err := net.Listen("tcp", address) - i.listener = listener + if i.config.secure { + cer, err := tls.LoadX509KeyPair(i.config.certificatePath, i.config.keyPath) + if err != nil { + log.Fatal("Error while loading --input-file certificate:", err) + } + + config := &tls.Config{Certificates: []tls.Certificate{cer}} + listener, err := tls.Listen("tcp", address, config) + if err != nil { + log.Fatal("Can't start --input-tcp with secure connection:", err) + } + i.listener = listener + } else { + listener, err := net.Listen("tcp", address) + if err != nil { + log.Fatal("Can't start:", err) + } - if err != nil { - log.Fatal("Can't start:", err) + i.listener = listener } go func() { for { - conn, err := listener.Accept() + conn, err := i.listener.Accept() if err != nil { log.Println("Error while Accept()", err) diff --git a/input_tcp_test.go b/input_tcp_test.go index 2dcd6c52..ef62b6b3 100644 --- a/input_tcp_test.go +++ b/input_tcp_test.go @@ -2,8 +2,18 @@ package main import ( "io" + "os" "log" "net" + "io/ioutil" + "crypto/x509" + "crypto/rsa" + "crypto/rand" + "crypto/tls" + "encoding/pem" + "math/big" + "time" + "bytes" "sync" "testing" ) @@ -12,7 +22,7 @@ func TestTCPInput(t *testing.T) { wg := new(sync.WaitGroup) quit := make(chan int) - input := NewTCPInput("127.0.0.1:0") + input := NewTCPInput("127.0.0.1:0", &TCPInputConfig{}) output := NewTestOutput(func(data []byte) { wg.Done() }) @@ -46,3 +56,81 @@ func TestTCPInput(t *testing.T) { close(quit) } + +func genCertificate(template *x509.Certificate) ([]byte, []byte) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit) + template.SerialNumber = serialNumber + template.BasicConstraintsValid = true + template.NotBefore = time.Now() + template.NotAfter = time.Now().Add(time.Hour) + + derBytes, _ := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + + var certPem, keyPem bytes.Buffer + pem.Encode(&certPem, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&keyPem, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return certPem.Bytes(), keyPem.Bytes() +} + +func TestTCPInputSecure(t *testing.T) { + serverCertPem, serverPrivPem := genCertificate(&x509.Certificate{ + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::")}, + }) + + serverCertPemFile, _ := ioutil.TempFile("", "server.crt") + serverCertPemFile.Write(serverCertPem) + serverCertPemFile.Close() + + serverPrivPemFile, _ := ioutil.TempFile("", "server.key") + serverPrivPemFile.Write(serverPrivPem) + serverPrivPemFile.Close() + + defer func(){ + os.Remove(serverPrivPemFile.Name()) + os.Remove(serverCertPemFile.Name()) + }() + + wg := new(sync.WaitGroup) + quit := make(chan int) + + input := NewTCPInput("127.0.0.1:0", &TCPInputConfig{ + secure: true, + certificatePath: serverCertPemFile.Name(), + keyPath: serverPrivPemFile.Name(), + }) + output := NewTestOutput(func(data []byte) { + wg.Done() + }) + + Plugins.Inputs = []io.Reader{input} + Plugins.Outputs = []io.Writer{output} + + go Start(quit) + + conf := &tls.Config{ + InsecureSkipVerify: true, + } + + conn, err := tls.Dial("tcp", input.listener.Addr().String(), conf) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + msg := []byte("1 1 1\nGET / HTTP/1.1\r\n\r\n") + + for i := 0; i < 100; i++ { + wg.Add(1) + conn.Write(msg) + conn.Write([]byte(payloadSeparator)) + } + + wg.Wait() + + close(quit) +} \ No newline at end of file diff --git a/output_tcp.go b/output_tcp.go index 933dcb81..69ba05e8 100644 --- a/output_tcp.go +++ b/output_tcp.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "fmt" "io" "log" @@ -16,14 +17,20 @@ type TCPOutput struct { limit int buf chan []byte bufStats *GorStat + config *TCPOutputConfig +} + +type TCPOutputConfig struct { + secure bool } // NewTCPOutput constructor for TCPOutput // Initialize 10 workers which hold keep-alive connection -func NewTCPOutput(address string) io.Writer { +func NewTCPOutput(address string, config *TCPOutputConfig) io.Writer { o := new(TCPOutput) o.address = address + o.config = config o.buf = make(chan []byte, 100) if Settings.outputTCPStats { @@ -89,7 +96,11 @@ func (o *TCPOutput) Write(data []byte) (n int, err error) { } func (o *TCPOutput) connect(address string) (conn net.Conn, err error) { - conn, err = net.Dial("tcp", address) + if o.config.secure { + conn, err = tls.Dial("tcp", address, &tls.Config{}) + } else { + conn, err = net.Dial("tcp", address) + } return } diff --git a/output_tcp_test.go b/output_tcp_test.go index a7efd385..64ef8cad 100644 --- a/output_tcp_test.go +++ b/output_tcp_test.go @@ -17,7 +17,7 @@ func TestTCPOutput(t *testing.T) { wg.Done() }) input := NewTestInput() - output := NewTCPOutput(listener.Addr().String()) + output := NewTCPOutput(listener.Addr().String(), &TCPOutputConfig{}) Plugins.Inputs = []io.Reader{input} Plugins.Outputs = []io.Writer{output} @@ -69,7 +69,7 @@ func BenchmarkTCPOutput(b *testing.B) { wg.Done() }) input := NewTestInput() - output := NewTCPOutput(listener.Addr().String()) + output := NewTCPOutput(listener.Addr().String(), &TCPOutputConfig{}) Plugins.Inputs = []io.Reader{input} Plugins.Outputs = []io.Writer{output} diff --git a/plugins.go b/plugins.go index 9cd11367..3685313e 100644 --- a/plugins.go +++ b/plugins.go @@ -110,11 +110,11 @@ func InitPlugins() { } for _, options := range Settings.inputTCP { - registerPlugin(NewTCPInput, options) + registerPlugin(NewTCPInput, options, &Settings.inputTCPConfig) } for _, options := range Settings.outputTCP { - registerPlugin(NewTCPOutput, options) + registerPlugin(NewTCPOutput, options, &Settings.outputTCPConfig) } for _, options := range Settings.inputFile { diff --git a/settings.go b/settings.go index 654b1dbd..a55ca42e 100644 --- a/settings.go +++ b/settings.go @@ -37,9 +37,11 @@ type AppSettings struct { outputStdout bool outputNull bool - inputTCP MultiOption - outputTCP MultiOption - outputTCPStats bool + inputTCP MultiOption + inputTCPConfig TCPInputConfig + outputTCP MultiOption + outputTCPConfig TCPOutputConfig + outputTCPStats bool inputFile MultiOption inputFileLoop bool @@ -93,7 +95,13 @@ func init() { flag.BoolVar(&Settings.outputNull, "output-null", false, "Used for testing inputs. Drops all requests.") flag.Var(&Settings.inputTCP, "input-tcp", "Used for internal communication between Gor instances. Example: \n\t# Receive requests from other Gor instances on 28020 port, and redirect output to staging\n\tgor --input-tcp :28020 --output-http staging.com") + flag.BoolVar(&Settings.inputTCPConfig.secure, "input-tcp-secure", false, "Turn on TLS security. Do not forget to specify certificate and key files.") + flag.StringVar(&Settings.inputTCPConfig.certificatePath, "input-tcp-certificate", "", "Path to PEM encoded certificate file. Used when TLS turned on.") + flag.StringVar(&Settings.inputTCPConfig.keyPath, "input-tcp-certificate-key", "", "Path to PEM encoded certificate key file. Used when TLS turned on.") + + flag.Var(&Settings.outputTCP, "output-tcp", "Used for internal communication between Gor instances. Example: \n\t# Listen for requests on 80 port and forward them to other Gor instance on 28020 port\n\tgor --input-raw :80 --output-tcp replay.local:28020") + flag.BoolVar(&Settings.outputTCPConfig.secure, "output-tcp-secure", false, "Use TLS secure connection. --input-file on another end should have TLS turned on as well.") flag.BoolVar(&Settings.outputTCPStats, "output-tcp-stats", false, "Report TCP output queue stats to console every 5 seconds.") flag.Var(&Settings.inputFile, "input-file", "Read requests from file: \n\tgor --input-file ./requests.gor --output-http staging.com")