From d34c27c3deeaf027826dbfa8de9a13fef51d54c8 Mon Sep 17 00:00:00 2001
From: Leonid Bugaev <leonsbox@gmail.com>
Date: Sun, 31 Jul 2016 20:16:01 +0300
Subject: [PATCH] Track FIN packets to check for closed connection (#350)

* Track FIN packets to check for closed connection

* Add tests for FIN / fix bugs

* Fmt changes
---
 http_client_test.go                     | 31 +++++++++++++++
 input_raw_test.go                       | 49 +++++++++++++++++++++++
 raw_socket_listener/listener.go         | 24 ++++++-----
 raw_socket_listener/listener_test.go    | 40 +++++++++++++++++++
 raw_socket_listener/tcp_message.go      | 53 +++++++++++++++++++------
 raw_socket_listener/tcp_message_test.go |  7 ++--
 raw_socket_listener/tcp_packet.go       |  8 ++++
 7 files changed, 185 insertions(+), 27 deletions(-)

diff --git a/http_client_test.go b/http_client_test.go
index d93984a6..79c71e70 100644
--- a/http_client_test.go
+++ b/http_client_test.go
@@ -87,6 +87,37 @@ func TestHTTPClientSend(t *testing.T) {
 	wg.Wait()
 }
 
+func TestHTTPClientResonseByClose(t *testing.T) {
+	wg := new(sync.WaitGroup)
+
+	payload := []byte("GET / HTTP/1.1\r\n\r\n")
+	ln, _ := net.Listen("tcp", ":0")
+	go func(){
+		for {
+			conn, _ := ln.Accept()
+			buf := make([]byte, 4096)
+			conn.Read(buf)
+
+			conn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
+			conn.Write([]byte("ab"))
+			conn.Close()
+
+			wg.Done()
+		}
+	}()
+
+	client := NewHTTPClient(ln.Addr().String(), &HTTPClientConfig{Debug: true})
+
+	wg.Add(1)
+	resp, _ := client.Send(payload)
+
+	if !bytes.Equal(resp, []byte("HTTP/1.1 200 OK\r\n\r\nab")) {
+		t.Error("Should return valid response", string(resp))
+	}
+
+	wg.Wait()
+}
+
 // https://github.com/buger/gor/issues/184
 func TestHTTPClientResponseBuffer(t *testing.T) {
 	testCases := []struct {
diff --git a/input_raw_test.go b/input_raw_test.go
index b8fe1b11..a40b0fad 100644
--- a/input_raw_test.go
+++ b/input_raw_test.go
@@ -79,6 +79,55 @@ func TestRAWInputIPv4(t *testing.T) {
 	}
 
 	wg.Wait()
+
+	close(quit)
+}
+
+func TestRAWInputNoKeepAlive(t *testing.T) {
+	wg := new(sync.WaitGroup)
+	quit := make(chan int)
+
+	listener, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	origin := &http.Server{
+		Handler:      http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				w.Write([]byte("a"))
+				w.Write([]byte("b"))
+			}),
+		ReadTimeout:  10 * time.Second,
+		WriteTimeout: 10 * time.Second,
+	}
+	origin.SetKeepAlivesEnabled(false)
+	go origin.Serve(listener)
+	defer listener.Close()
+
+	originAddr := listener.Addr().String()
+
+	input := NewRAWInput(originAddr, EnginePcap, true, testRawExpire, "")
+	defer input.Close()
+
+	output := NewTestOutput(func(data []byte) {
+		wg.Done()
+	})
+
+	Plugins.Inputs = []io.Reader{input}
+	Plugins.Outputs = []io.Writer{output}
+
+	client := NewHTTPClient("http://"+listener.Addr().String(), &HTTPClientConfig{})
+
+	go Start(quit)
+
+	for i := 0; i < 100; i++ {
+		// request + response
+		wg.Add(2)
+		client.Get("/")
+		time.Sleep(2 * time.Millisecond)
+	}
+
+	wg.Wait()
+
 	close(quit)
 }
 
diff --git a/raw_socket_listener/listener.go b/raw_socket_listener/listener.go
index a4c88fde..f85597b3 100644
--- a/raw_socket_listener/listener.go
+++ b/raw_socket_listener/listener.go
@@ -181,6 +181,7 @@ func (t *Listener) dispatchMessage(message *TCPMessage) {
 			delete(t.respAliases, message.Ack)
 			delete(t.respWithoutReq, message.Ack)
 		}
+
 		return
 	}
 
@@ -337,7 +338,7 @@ func (t *Listener) readPcap() {
 				var allAddr []string
 				for _, dc := range devices {
 					for _, addr := range dc.Addresses {
-						allAddr = append(allAddr, "(dst host " + addr.IP.String() + " and src host " + addr.IP.String() + ")")
+						allAddr = append(allAddr, "(dst host "+addr.IP.String()+" and src host "+addr.IP.String()+")")
 					}
 				}
 
@@ -441,7 +442,7 @@ func (t *Listener) readPcap() {
 					}
 
 					// Invalid length
-					if int(ihl * 4) > ipLength {
+					if int(ihl*4) > ipLength {
 						continue
 					}
 
@@ -452,7 +453,7 @@ func (t *Listener) readPcap() {
 						continue
 					}
 
-					data = data[ihl * 4:]
+					data = data[ihl*4:]
 				} else {
 					// Truncated IP info
 					if len(data) < 40 {
@@ -471,10 +472,11 @@ func (t *Listener) readPcap() {
 				}
 
 				dataOffset := (data[12] & 0xF0) >> 4
+				isFIN := data[13]&0x01 != 0
 
 				// We need only packets with data inside
 				// Check that the buffer is larger than the size of the TCP header
-				if len(data) > int(dataOffset*4) {
+				if len(data) > int(dataOffset*4) || isFIN {
 					if !bpfSupported {
 						destPort := binary.BigEndian.Uint16(data[2:4])
 						srcPort := binary.BigEndian.Uint16(data[0:2])
@@ -555,16 +557,16 @@ func (t *Listener) readPcapFile() {
 			var addr, data []byte
 
 			if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil {
-			  tcp, _ := tcpLayer.(*layers.TCP)
-			  data = append(tcp.LayerContents(), tcp.LayerPayload()...)
-			  copy(data[2:4], []byte{0, 1})
+				tcp, _ := tcpLayer.(*layers.TCP)
+				data = append(tcp.LayerContents(), tcp.LayerPayload()...)
+				copy(data[2:4], []byte{0, 1})
 			} else {
 				continue
 			}
 
 			if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil {
-			  ip, _ := ipLayer.(*layers.IPv4)
-			  addr = ip.SrcIP
+				ip, _ := ipLayer.(*layers.IPv4)
+				addr = ip.SrcIP
 			} else if ipLayer = packet.Layer(layers.LayerTypeIPv6); ipLayer != nil {
 				ip, _ := ipLayer.(*layers.IPv6)
 				addr = ip.SrcIP
@@ -763,13 +765,13 @@ func (t *Listener) processTCPPacket(packet *TCPPacket) {
 	// If message contains only single packet immediately dispatch it
 	if message.complete {
 		if isIncoming {
-			// log.Println("I'm finished", string(message.Bytes()), message.ResponseID, t.messages)
 			if t.trackResponse {
 				if resp, ok := t.messages[message.ResponseID]; ok {
-					t.dispatchMessage(message)
 					if resp.complete {
 						t.dispatchMessage(resp)
 					}
+
+					t.dispatchMessage(message)
 				}
 			} else {
 				t.dispatchMessage(message)
diff --git a/raw_socket_listener/listener_test.go b/raw_socket_listener/listener_test.go
index b66b23ab..72e01ded 100644
--- a/raw_socket_listener/listener_test.go
+++ b/raw_socket_listener/listener_test.go
@@ -46,6 +46,46 @@ func TestRawListenerInput(t *testing.T) {
 	}
 }
 
+func TestRawListenerInputResponseByClose(t *testing.T) {
+	var req, resp *TCPMessage
+
+	listener := NewListener("", "0", EnginePcap, true, 10*time.Millisecond)
+	defer listener.Close()
+
+	reqPacket := buildPacket(true, 1, 1, []byte("GET / HTTP/1.1\r\n\r\n"))
+
+	respAck := reqPacket.Seq + uint32(len(reqPacket.Data))
+	respPacket := buildPacket(false, respAck, reqPacket.Seq+1, []byte("HTTP/1.1 200 OK\r\nConnection: close\r\n\r\nasd"))
+	finPacket := buildPacket(false, respAck, reqPacket.Seq+2, []byte(""))
+	finPacket.IsFIN = true
+
+	listener.packetsChan <- reqPacket.Dump()
+	listener.packetsChan <- respPacket.Dump()
+	listener.packetsChan <- finPacket.Dump()
+
+	select {
+	case req = <-listener.messagesChan:
+	case <-time.After(time.Millisecond):
+		t.Error("Should return request immediately")
+		return
+	}
+
+	if !req.IsIncoming {
+		t.Error("Should be request")
+	}
+
+	select {
+	case resp = <-listener.messagesChan:
+	case <-time.After(20 * time.Millisecond):
+		t.Error("Should return response immediately")
+		return
+	}
+
+	if resp.IsIncoming {
+		t.Error("Should be response")
+	}
+}
+
 func TestRawListenerInputWithoutResponse(t *testing.T) {
 	var req *TCPMessage
 
diff --git a/raw_socket_listener/tcp_message.go b/raw_socket_listener/tcp_message.go
index 15837bdb..a394153c 100644
--- a/raw_socket_listener/tcp_message.go
+++ b/raw_socket_listener/tcp_message.go
@@ -72,7 +72,7 @@ func (t *TCPMessage) BodySize() (size int) {
 
 	size += len(proto.Body(t.packets[t.headerPacket].Data))
 
-	for _, p := range t.packets[t.headerPacket + 1:] {
+	for _, p := range t.packets[t.headerPacket+1:] {
 		size += len(p.Data)
 	}
 
@@ -145,7 +145,17 @@ func (t *TCPMessage) checkSeqIntegrity() {
 		t.seqMissing = false
 	}
 
-	for i, p := range t.packets {
+	offset := len(t.packets) - 1
+
+	if t.packets[offset].IsFIN {
+		offset--
+	}
+
+	for i, p := range t.packets[:offset] {
+		if p.IsFIN {
+			continue
+		}
+
 		// If final packet
 		if len(t.packets) == i+1 {
 			t.seqMissing = false
@@ -228,6 +238,15 @@ func (t *TCPMessage) checkIfComplete() {
 			if bytes.LastIndex(lastPacket.Data, bChunkEnd) != -1 {
 				t.complete = true
 			}
+		default:
+			if len(t.packets) == 0 {
+				return
+			}
+
+			last := t.packets[len(t.packets)-1]
+			if last.IsFIN {
+				t.complete = true
+			}
 		}
 	}
 }
@@ -302,10 +321,11 @@ func (t *TCPMessage) updateMethodType() {
 type httpBodyType uint8
 
 const (
-	httpBodyNotSet        httpBodyType = 0
-	httpBodyEmpty         httpBodyType = 1
-	httpBodyContentLength httpBodyType = 2
-	httpBodyChunked       httpBodyType = 3
+	httpBodyNotSet          httpBodyType = 0
+	httpBodyEmpty           httpBodyType = 1
+	httpBodyContentLength   httpBodyType = 2
+	httpBodyChunked         httpBodyType = 3
+	httpBodyConnectionClose httpBodyType = 4
 )
 
 func (t *TCPMessage) updateBodyType() {
@@ -326,7 +346,7 @@ func (t *TCPMessage) updateBodyType() {
 		t.bodyType = httpBodyEmpty
 		return
 	case httpMethodWithBody:
-		var lengthB, encB []byte
+		var lengthB, encB, connB []byte
 
 		for _, p := range t.packets[:t.headerPacket+1] {
 			lengthB = proto.Header(p.Data, []byte("Content-Length"))
@@ -340,12 +360,21 @@ func (t *TCPMessage) updateBodyType() {
 			t.bodyType = httpBodyContentLength
 			t.contentLength, _ = strconv.Atoi(string(lengthB))
 			return
-		} else {
+		}
+
+		for _, p := range t.packets[:t.headerPacket+1] {
+			encB = proto.Header(p.Data, []byte("Transfer-Encoding"))
+
+			if len(encB) > 0 {
+				t.bodyType = httpBodyChunked
+				return
+			}
+
 			for _, p := range t.packets[:t.headerPacket+1] {
-				encB = proto.Header(p.Data, []byte("Transfer-Encoding"))
+				connB = proto.Header(p.Data, []byte("Connection"))
 
-				if len(encB) > 0 {
-					t.bodyType = httpBodyChunked
+				if len(connB) > 0 && bytes.Equal(connB, []byte("close")) {
+					t.bodyType = httpBodyConnectionClose
 					return
 				}
 			}
@@ -363,7 +392,7 @@ const (
 	httpExpect100Continue httpExpectType = 2
 )
 
-var bExpectHeader = []byte("Expect:")
+var bExpectHeader = []byte("Expect")
 var bExpect100Value = []byte("100-continue")
 
 func (t *TCPMessage) check100Continue() {
diff --git a/raw_socket_listener/tcp_message_test.go b/raw_socket_listener/tcp_message_test.go
index 0af3bd91..781a8243 100644
--- a/raw_socket_listener/tcp_message_test.go
+++ b/raw_socket_listener/tcp_message_test.go
@@ -216,12 +216,11 @@ func TestTCPMessageBodyType(t *testing.T) {
 	}
 }
 
-
 func TestTCPMessageBodySize(t *testing.T) {
 	testCases := []struct {
-		direction        bool
-		payloads         []string
-		expectedSize     int
+		direction    bool
+		payloads     []string
+		expectedSize int
 	}{
 		{true, []string{"GET / HTTP/1.1\r\n\r\n"}, 0},
 		{true, []string{"POST / HTTP/1.1\r\nContent-Length: 2\r\n\r\nab"}, 2},
diff --git a/raw_socket_listener/tcp_packet.go b/raw_socket_listener/tcp_packet.go
index d53834e9..0e3e8324 100644
--- a/raw_socket_listener/tcp_packet.go
+++ b/raw_socket_listener/tcp_packet.go
@@ -33,6 +33,7 @@ type TCPPacket struct {
 	Ack        uint32
 	OrigAck    uint32
 	DataOffset uint8
+	IsFIN      bool
 
 	Raw  []byte
 	Data []byte
@@ -71,6 +72,7 @@ func (t *TCPPacket) ParseBasic() {
 	t.Seq = binary.BigEndian.Uint32(t.Raw[4:8])
 	t.Ack = binary.BigEndian.Uint32(t.Raw[8:12])
 	t.DataOffset = (t.Raw[12] & 0xF0) >> 4
+	t.IsFIN = t.Raw[13]&0x01 != 0
 
 	// log.Println("DataOffset:", t.DataOffset, t.DestPort, t.SrcPort, t.Seq, t.Ack)
 
@@ -90,6 +92,11 @@ func (t *TCPPacket) Dump() []byte {
 	binary.BigEndian.PutUint32(tcpBuf[8:12], t.Ack)
 
 	tcpBuf[12] = 64
+
+	if t.IsFIN {
+		tcpBuf[13] = tcpBuf[13] | 0x01
+	}
+
 	copy(tcpBuf[16:], t.Data)
 
 	return buf
@@ -109,6 +116,7 @@ func (t *TCPPacket) String() string {
 		"Sequence:" + strconv.Itoa(int(t.Seq)),
 		"Acknowledgment:" + strconv.Itoa(int(t.Ack)),
 		"Header len:" + strconv.Itoa(int(t.DataOffset)),
+		"FIN:" + strconv.FormatBool(t.IsFIN),
 
 		"Data size:" + strconv.Itoa(len(t.Data)),
 		"Data:" + string(t.Data[:maxLen]),