Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TCP+SNI support arbitrary large Client Hello #423

Merged
merged 3 commits into from
Feb 18, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 69 additions & 8 deletions proxy/tcp/sni_proxy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package tcp

import (
"bufio"
"errors"
"io"
"log"
"net"
Expand Down Expand Up @@ -34,27 +36,86 @@ type SNIProxy struct {
Noroute metrics.Counter
}

// Create a buffer large enough to hold the client hello message including
// the tls record header and the handshake message header.
// The function requires at least the first 9 bytes of the tls conversation
// in "data".
// nil, error is returned if the data does not follow the
// specification (https://tools.ietf.org/html/rfc5246) or if the client hello
// is fragmented over multiple records.
func createClientHelloBuffer(data []byte) ([]byte, error) {
// TLS record header
// -----------------
// byte 0: rec type (should be 0x16 == Handshake)
// byte 1-2: version (should be 0x3000 < v < 0x3003)
// byte 3-4: rec len
if len(data) < 9 {
return nil, errors.New("At least 9 bytes required to determine client hello length")
}

if data[0] != 0x16 {
return nil, errors.New("Not a TLS handshake")
}

recordLength := int(data[3])<<8 | int(data[4])
if recordLength <= 0 || recordLength > 16384 {
return nil, errors.New("Invalid TLS record length")
}

// Handshake record header
// -----------------------
// byte 5: hs msg type (should be 0x01 == client_hello)
// byte 6-8: hs msg len
if data[5] != 0x01 {
return nil, errors.New("Not a client hello")
}

handshakeLength := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if handshakeLength <= 0 || handshakeLength > recordLength-4 {
return nil, errors.New("Invalid client hello length (fragmentation not implemented)")
}

return make([]byte, handshakeLength+9), nil //9 for the header bytes
}

func (p *SNIProxy) ServeTCP(in net.Conn) error {
defer in.Close()

if p.Conn != nil {
p.Conn.Inc(1)
}

// capture client hello
data := make([]byte, 1024)
n, err := in.Read(data)
tlsReader := bufio.NewReader(in)
data, err := tlsReader.Peek(9)
if err != nil {
log.Print("[DEBUG] tcp+sni: TLS handshake failed (failed to peek data)")
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}

tlsData, err := createClientHelloBuffer(data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd prefer if this function would only determine the size of the buffer and not allocate it. Also, keep it in the tls_clienthello.go since this is where the parsing logic for the TLS ClientHello header is. Also, I'd keep data instead of tlsData since there is no difference.

if err != nil {
log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err)
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}

_, err = io.ReadFull(tlsReader, tlsData)
if err != nil {
log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err)
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}
data = data[:n]

host, ok := readServerName(data)
host, ok := readServerName(tlsData[5:])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment where the 5: comes from (magic number).

if !ok {
log.Print("[DEBUG] tcp+sni: TLS handshake failed")
log.Print("[DEBUG] tcp+sni: TLS handshake failed (unable to parse client hello)")
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
Expand Down Expand Up @@ -88,8 +149,8 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error {
}
defer out.Close()

// copy client hello
n, err = out.Write(data)
// write the data already read from the connection
n, err := out.Write(tlsData)
if err != nil {
log.Print("[WARN] tcp+sni: copy client hello failed. ", err)
if p.ConnFail != nil {
Expand Down
58 changes: 58 additions & 0 deletions proxy/tcp/sni_proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package tcp

import (
"testing"

"github.com/fabiolb/fabio/assert"
)

func TestCreateClientHelloBufferNotTLS(t *testing.T) {
assertEqual := assert.Equal(t)

testCases := [][]byte{
// not enough data
{0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x05},

// not tls record
{0x15, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xeb},

// too large record
// |---------|
{0x16, 0x03, 0x01, 0x40, 0x01, 0x01, 0x00, 0x01, 0xec},

// zero record length
// |----------|
{0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0xec},

// not client hello
// |----|
{0x16, 0x03, 0x01, 0x01, 0xF4, 0x02, 0x00, 0x01, 0xeb},

// bad handshake length
// |----- 0 --------|
{0x16, 0x03, 0x01, 0x00, 0xaa, 0x01, 0x00, 0x00, 0x00},

// Fragmentation (handshake larger than record)
// |- 500 ---| |----- 497 ------|
{0x16, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xf1},
}

for i := 0; i < len(testCases); i++ {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like these test cases. Can you move them to the tls_clienthello_test.go since they test the parsing.

When you refactor the method to just return the size of the buffer you can test for that as well. My suggestion is:

tests := []struct {
    name string
    data []byte
    fail bool
    len int
}{
    {
        name: "bad handshake length",
        data: ...,
        fail: true,
    },
    ...
}

for _, tt := range tests {
    t.Run(tt.name, func(t *testing.T) {
        n, err := clientHelloBufferSize(tt.data)
        if tt.fail && err == nil {
            t.Fatal("should fail")
        }
        if got, want := n, tt.len; !tt.fail && got != want {
            t.Fatalf("got buffer size %d want %d", got, want)
        }
    }
}

Also I

_, err := createClientHelloBuffer(testCases[i])
if err == nil {
t.Logf("Case idx %d did not return an error", i)
}
assertEqual(err != nil, true)
}
}

func TestCreateClientHelloBufferOk(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can then probably roll that test case into the previous table driven test.

assertEqual := assert.Equal(t)
// Largest possible client hello message
// |- 16384 -| |----- 16380 ----|
data := []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f, 0xfc}
buffer, err := createClientHelloBuffer(data)
assertEqual(err, nil)
assertEqual(buffer != nil, true)
assertEqual(len(buffer), 16384+5) // record length + record header
}
73 changes: 10 additions & 63 deletions proxy/tcp/tls_clienthello.go
Original file line number Diff line number Diff line change
@@ -1,73 +1,20 @@
package tcp

// record types
const (
handshakeRecord = 0x16
clientHelloType = 0x01
)

// readServerName returns the server name from a TLS ClientHello message which
// has the server_name extension (SNI). ok is set to true if the ClientHello
// message was parsed successfully. If the server_name extension was not set
// and empty string is returned as serverName.
func readServerName(data []byte) (serverName string, ok bool) {
if m, ok := readClientHello(data); ok {
return m.serverName, true
}
return "", false
}

// readClientHello
func readClientHello(data []byte) (m *clientHelloMsg, ok bool) {
if len(data) < 9 {
// println("buf too short")
return nil, false
}

// TLS record header
// -----------------
// byte 0: rec type (should be 0x16 == Handshake)
// byte 1-2: version (should be 0x3000 < v < 0x3003)
// byte 3-4: rec len
recType := data[0]
if recType != handshakeRecord {
// println("no handshake ")
return nil, false
// an empty string is returned as serverName.
// clientHelloHandshakeMsg must contain the full client hello handshake
// message including the 4 byte header.
// See: https://www.ietf.org/rfc/rfc5246.txt
func readServerName(clientHelloHandshakeMsg []byte) (serverName string, ok bool) {
m := new(clientHelloMsg)
if !m.unmarshal(clientHelloHandshakeMsg) {
//println("client_hello unmarshal failed")
return "", false
}

recLen := int(data[3])<<8 | int(data[4])
if recLen == 0 || recLen > len(data)-5 {
// println("rec too short")
return nil, false
}

// Handshake record header
// -----------------------
// byte 5: hs msg type (should be 0x01 == client_hello)
// byte 6-8: hs msg len
hsType := data[5]
if hsType != clientHelloType {
// println("no client_hello")
return nil, false
}

hsLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if hsLen == 0 || hsLen > len(data)-9 {
// println("handshake rec too short")
return nil, false
}

// byte 9- : client hello msg
//
// m.unmarshal parses the entire handshake message and
// not just the client hello. Therefore, we need to pass
// data from byte 5 instead of byte 9. (see comment below)
m = new(clientHelloMsg)
if !m.unmarshal(data[5:]) {
// println("client_hello unmarshal failed")
return nil, false
}
return m, true
return m.serverName, true
}

// The code below is a verbatim copy from go1.7/src/crypto/tls/handshake_messages.go
Expand Down
57 changes: 57 additions & 0 deletions proxy/tcp/tls_clienthello_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package tcp

import (
"encoding/hex"
"testing"

"github.com/fabiolb/fabio/assert"
)

func TestReadServerNameBadData(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check whether you can also write a table driven test like the one I've outlined above?

assertEqual := assert.Equal(t)
clientHelloMsg := []byte{0x16, 0x03, 0x01, 0x45, 0x03, 0x01, 0x2, 0x01}
serverName, ok := readServerName(clientHelloMsg)
assertEqual(serverName, "")
assertEqual(ok, false)
}

func TestReadServerNameNoExtension(t *testing.T) {
assertEqual := assert.Equal(t)
// Client hello from:
// openssl s_client -connect google.com:443
clientHelloMsg, _ := hex.DecodeString(
"0100013503036dfb09de7b16503dd1bb304dcbe54079913b65abf53de997f73b26c99e" +
"67ba28000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f006b006a00" +
"390038ff8500c400c3008800870081c032c02ec02ac026c00fc005009d003d0035" +
"00c00084c02fc02bc027c023c013c00900a2009e006700400033003200be00bd00" +
"450044c031c02dc029c025c00ec004009c003c002f00ba0041c011c007c00cc002" +
"00050004c012c00800160013c00dc003000a00150012000900ff01000074000b00" +
"0403000102000a003a0038000e000d0019001c000b000c001b00180009000a001a" +
"00160017000800060007001400150004000500120013000100020003000f001000" +
"1100230000000d00260024060106020603efef050105020503040104020403eeee" +
"eded030103020303020102020203")
serverName, ok := readServerName(clientHelloMsg)
assertEqual(serverName, "")
assertEqual(ok, true)
}

func TestReadServerNameOk(t *testing.T) {
assertEqual := assert.Equal(t)
// Client hello from:
// openssl s_client -connect google.com:443 -servername google.com
clientHelloMsg, _ := hex.DecodeString(
"0100014803032657cacce41598fa82e5b75061050bc31c5affdba106b8e7431852" +
"24af0fa1aa000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f00" +
"6b006a00390038ff8500c400c3008800870081c032c02ec02ac026c00fc005" +
"009d003d003500c00084c02fc02bc027c023c013c00900a2009e0067004000" +
"33003200be00bd00450044c031c02dc029c025c00ec004009c003c002f00ba" +
"0041c011c007c00cc00200050004c012c00800160013c00dc003000a001500" +
"12000900ff010000870000000f000d00000a676f6f676c652e636f6d000b00" +
"0403000102000a003a0038000e000d0019001c000b000c001b00180009000a" +
"001a0016001700080006000700140015000400050012001300010002000300" +
"0f0010001100230000000d00260024060106020603efef0501050205030401" +
"04020403eeeeeded030103020303020102020203")
serverName, ok := readServerName(clientHelloMsg)
assertEqual(serverName, "google.com")
assertEqual(ok, true)
}