diff --git a/server/tcp.go b/server/tcp.go index 4ee805cd..73cbe499 100644 --- a/server/tcp.go +++ b/server/tcp.go @@ -3,6 +3,7 @@ package input import ( "bufio" "encoding/binary" + "io" "net" "sync" "sync/atomic" @@ -58,8 +59,12 @@ func (h *HEPInput) serveTCP(addr string) { } func (h *HEPInput) handleTCP(c net.Conn) { + h.handleStream(c, "TCP") +} + +func (h *HEPInput) handleStream(c net.Conn, protocol string) { defer func() { - logp.Info("closing TCP connection from %s", c.RemoteAddr()) + logp.Info("closing %s connection from %s", protocol, c.RemoteAddr()) err := c.Close() if err != nil { logp.Err("%v", err) @@ -67,17 +72,6 @@ func (h *HEPInput) handleTCP(c net.Conn) { }() r := bufio.NewReader(c) - readBytes := func(buffer []byte) (int, error) { - n := uint(0) - for n < uint(len(buffer)) { - nn, err := r.Read(buffer[n:]) - n += uint(nn) - if err != nil { - return 0, err - } - } - return int(n), nil - } for { if atomic.LoadUint32(&h.stopped) == 1 { return @@ -96,11 +90,11 @@ func (h *HEPInput) handleTCP(c net.Conn) { return } buf := h.buffer.Get().([]byte) - n, err := readBytes(buf[:size]) - if err != nil || n > maxPktLen { + n, err := io.ReadFull(r, buf[:size]) + if err != nil || n != int(size) { logp.Warn("%v, unusal packet size with %d bytes", err, n) atomic.AddUint64(&h.stats.ErrCount, 1) - continue + return } h.inputCh <- buf[:n] atomic.AddUint64(&h.stats.PktCount, 1) diff --git a/server/tls.go b/server/tls.go index 07343f55..36c6fdf3 100644 --- a/server/tls.go +++ b/server/tls.go @@ -3,18 +3,18 @@ package input import ( "crypto/tls" "net" + "path/filepath" "sync" "sync/atomic" "time" - "path/filepath" - "github.com/sipcapture/heplify-server/config" "github.com/negbie/cert" "github.com/negbie/logp" + "github.com/sipcapture/heplify-server/config" ) -func parseTLSVersion(versionText string ) uint16 { - switch(versionText){ +func parseTLSVersion(versionText string) uint16 { + switch versionText { case "1.0": logp.Warn("TLS1.0 is not recommended. Use 1.2 or greater where possible") return tls.VersionTLS10 @@ -50,7 +50,7 @@ func (h *HEPInput) serveTLS(addr string) { cPath := config.Setting.TLSCertFolder minTLSVersion := parseTLSVersion(config.Setting.TLSMinVersion) // load any existing certs, otherwise generate a new one - ca, err := cert.NewCertificateAuthority( filepath.Join(cPath, "heplify-server") ) + ca, err := cert.NewCertificateAuthority(filepath.Join(cPath, "heplify-server")) if err != nil { logp.Err("%v", err) return @@ -88,30 +88,5 @@ func (h *HEPInput) serveTLS(addr string) { } func (h *HEPInput) handleTLS(c net.Conn) { - defer func() { - logp.Info("closing TLS connection from %s", c.RemoteAddr()) - err := c.Close() - if err != nil { - logp.Err("%v", err) - } - }() - - for { - if atomic.LoadUint32(&h.stopped) == 1 { - return - } - - buf := h.buffer.Get().([]byte) - n, err := c.Read(buf) - if err != nil { - logp.Warn("%v from %s", err, c.RemoteAddr()) - return - } else if n > maxPktLen { - logp.Warn("received too big packet with %d bytes", n) - atomic.AddUint64(&h.stats.ErrCount, 1) - continue - } - h.inputCh <- buf[:n] - atomic.AddUint64(&h.stats.PktCount, 1) - } + h.handleStream(c, "TLS") }