Skip to content

Commit

Permalink
tcp/unix input: Stop accepting connections after socket is closed (#2…
Browse files Browse the repository at this point in the history
…9712)

Stop accepting connections after tcp/unix socket is closed. This will suppress debug messages for net.ErrClosed.

    2021-12-22T13:49:04.151Z	INFO	[tcp]	streaming/listener.go:172	StoppingTCPserver	{"address": "0.0.0.0:7000"}
    2021-12-22T13:49:04.151Z	DEBUG	[tcp]	streaming/listener.go:129	Can not accept the connection	{"address": "0.0.0.0:7000", "error": "accept tcp [::]:7000: use of closed network connection"}
  • Loading branch information
andrewkroh authored Feb 4, 2022
1 parent ddc0c2f commit 8c79b67
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 77 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.next.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ https://github.com/elastic/beats/compare/v7.0.0-alpha2...main[Check the HEAD dif

*Filebeat*

- tcp/unix input: Stop accepting connections after socket is closed. {pull}29712[29712]
- Fix using log_group_name_prefix in aws-cloudwatch input. {pull}29695[29695]

*Heartbeat*
Expand Down
140 changes: 76 additions & 64 deletions filebeat/inputsource/common/streaming/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ import (
"context"
"fmt"
"net"
"strings"
"sync"

"github.com/pkg/errors"

"github.com/elastic/beats/v7/filebeat/inputsource"
"github.com/elastic/beats/v7/libbeat/common/atomic"
"github.com/elastic/beats/v7/libbeat/logp"
Expand Down Expand Up @@ -60,8 +63,33 @@ var (
"delimiter": FramingDelimiter,
"rfc6587": FramingRFC6587,
}

availableFramingTypesErrFormat string
)

func init() {
framingTypeNames := make([]string, 0, len(framingTypes))
for t := range framingTypes {
framingTypeNames = append(framingTypeNames, t)
}

availableFramingTypesErrFormat = fmt.Sprintf("invalid framing type %%q, "+
"the supported types are [%v]", strings.Join(framingTypeNames, ", "))
}

// Unpack unpacks the FramingType string value.
func (f *FramingType) Unpack(value string) error {
value = strings.ToLower(value)

ft, ok := framingTypes[value]
if !ok {
return fmt.Errorf(availableFramingTypesErrFormat, value)
}

*f = ft
return nil
}

// NewListener creates a new Listener
func NewListener(family inputsource.Family, location string, handlerFactory HandlerFactory, listenerFactory ListenerFactory, config *ListenerConfig) *Listener {
return &Listener{
Expand All @@ -73,7 +101,9 @@ func NewListener(family inputsource.Family, location string, handlerFactory Hand
}
}

// Start listen to the socket.
// Start listening to the socket and accepting connections. The method is
// non-blocking and starts a goroutine to service the socket. Stop must be
// called to ensure proper cleanup.
func (l *Listener) Start() error {
if err := l.initListen(context.Background()); err != nil {
return err
Expand Down Expand Up @@ -117,73 +147,73 @@ func (l *Listener) initListen(ctx context.Context) error {
}

func (l *Listener) run() {
l.log.Info("Started listening for " + l.family.String() + " connection")
l.log.Debug("Start accepting connections")
defer func() {
l.Listener.Close()
l.log.Debug("Stopped accepting connections")
}()

for {
conn, err := l.Listener.Accept()
if err != nil {
select {
case <-l.ctx.Done():
if l.ctx.Err() != nil {
// Shutdown.
return
default:
l.log.Debugw("Can not accept the connection", "error", err)
continue
}

if errors.Is(err, net.ErrClosed) {
return
}

l.log.Debugw("Cannot accept new connection", "error", err)
continue
}

l.wg.Add(1)
go func() {
defer logp.Recover("recovering from a " + l.family.String() + " client crash")
defer l.wg.Done()
l.handleConnection(conn)
}()
}
}

ctx, cancel := ctxtool.WithFunc(l.ctx, func() { conn.Close() })
defer cancel()

l.registerHandler()
defer l.unregisterHandler()
func (l *Listener) handleConnection(conn net.Conn) {
log := l.log
if remoteAddr := conn.RemoteAddr().String(); remoteAddr != "" {
log = log.With("remote_address", remoteAddr)
}
defer log.Recover("Panic in connection handler")

if l.family == inputsource.FamilyUnix {
// unix sockets have an empty `RemoteAddr` value, so no need to capture it
l.log.Debugw("New client", "total", l.clientsCount.Load())
} else {
l.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", l.clientsCount.Load())
}
// Ensure accepted connection is closed on return and at shutdown.
connCtx, cancel := ctxtool.WithFunc(l.ctx, func() {
conn.Close()
})
defer cancel()

handler := l.handlerFactory(*l.config)
err := handler(ctx, conn)
if err != nil {
l.log.Debugw("client error", "error", err)
}
// Track number of clients.
l.clientsCount.Inc()
log.Debugw("New client connection.", "active_clients", l.clientsCount.Load())
defer func() {
l.clientsCount.Dec()
log.Debugw("Client disconnected.", "active_clients", l.clientsCount.Load())
}()

defer func() {
if l.family == inputsource.FamilyUnix {
// unix sockets have an empty `RemoteAddr` value, so no need to capture it
l.log.Debugw("client disconnected", "total", l.clientsCount.Load())
} else {
l.log.Debugw("client disconnected", "remote_address", conn.RemoteAddr(), "total", l.clientsCount.Load())
}
}()
}()
handler := l.handlerFactory(*l.config)
if err := handler(connCtx, conn); err != nil {
log.Debugw("Client error", "error", err)
return
}
}

// Stop stops accepting new incoming connections and Close any active clients
// Stop stops accepting new incoming connections and closes all active clients.
func (l *Listener) Stop() {
l.log.Info("Stopping" + l.family.String() + "server")
l.log.Debugw("Stopping socket listener. Waiting for active connections to close.", "active_clients", l.clientsCount.Load())
l.ctx.Cancel()
l.wg.Wait()
l.log.Info(l.family.String() + " server stopped")
}

func (l *Listener) registerHandler() {
l.clientsCount.Inc()
}

func (l *Listener) unregisterHandler() {
l.clientsCount.Dec()
l.log.Info("Socket listener stopped")
}

// SplitFunc allows to create a `bufio.SplitFunc` based on a framing &
// SplitFunc allows to create a `bufio.SplitFunc` based on a framing and
// delimiter provided.
func SplitFunc(framing FramingType, lineDelimiter []byte) (bufio.SplitFunc, error) {
if len(lineDelimiter) == 0 {
Expand All @@ -202,24 +232,6 @@ func SplitFunc(framing FramingType, lineDelimiter []byte) (bufio.SplitFunc, erro
case FramingRFC6587:
return FactoryRFC6587Framing(lineDelimiter), nil
default:
return nil, fmt.Errorf("unknown SplitFunc for framing %d and line delimiter %s", framing, string(lineDelimiter))
}

}

// Unpack for config
func (f *FramingType) Unpack(value string) error {
ft, ok := framingTypes[value]
if !ok {
availableTypes := make([]string, len(framingTypes))
i := 0
for t := range framingTypes {
availableTypes[i] = t
i++
}
return fmt.Errorf("invalid framing type '%s', supported types: %v", value, availableTypes)

return nil, fmt.Errorf("unknown SplitFunc for framing %d and line delimiter %q", framing, lineDelimiter)
}
*f = ft
return nil
}
14 changes: 10 additions & 4 deletions filebeat/tests/system/test_syslog.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_syslog_with_tcp(self):

filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains("Started listening for TCP connection"))
self.wait_until(lambda: self.log_contains("Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP
sock.connect((host, port))
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_syslog_with_tcp_invalid_message(self):

filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains("Started listening for TCP connection"))
self.wait_until(lambda: self.log_contains("Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP
sock.connect((host, port))
Expand Down Expand Up @@ -173,7 +173,10 @@ def run_filebeat_and_send_using_socket(self, socket_type, send_over_socket):

filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains("Started listening for UNIX connection"))
if socket_type == "stream":
self.wait_until(lambda: self.log_contains("Start accepting connections"))
else:
self.wait_until(lambda: self.log_contains("Started listening"))

sock = send_over_socket(path,
"<13>Oct 11 22:14:15 wopr.mymachine.co postfix/smtpd[2000]:"
Expand Down Expand Up @@ -233,7 +236,10 @@ def run_filebeat_and_send_invalid_message_using_socket(self, socket_type, send_o

filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains("Started listening for UNIX connection"))
if socket_type == "stream":
self.wait_until(lambda: self.log_contains("Start accepting connections"))
else:
self.wait_until(lambda: self.log_contains("Started listening"))

sock = send_over_socket(path, "invalid\n")

Expand Down
4 changes: 2 additions & 2 deletions filebeat/tests/system/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def send_events_with_delimiter(self, delimiter):

filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains("Started listening for TCP connection"))
self.wait_until(lambda: self.log_contains("Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP
sock.connect((host, port))
Expand Down Expand Up @@ -100,7 +100,7 @@ def send_events_with_rfc6587_framing(self, framing):

filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains("Started listening for TCP connection"))
self.wait_until(lambda: self.log_contains("Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP
sock.connect((host, port))
Expand Down
12 changes: 6 additions & 6 deletions filebeat/tests/system/test_tcp_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_tcp_over_tls_and_verify_valid_server_without_mutual_auth(self):
filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains(
"Started listening for TCP connection"))
"Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP
tls = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_REQUIRED,
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_tcp_over_tls_and_verify_invalid_server_without_mutual_auth(self):
filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains(
"Started listening for TCP connection"))
"Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP
tls = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_REQUIRED,
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_tcp_over_tls_mutual_auth_fails(self):
filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains(
"Started listening for TCP connection"))
"Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tls = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_REQUIRED,
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_tcp_over_tls_mutual_auth_succeed(self):
filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains(
"Started listening for TCP connection"))
"Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

Expand Down Expand Up @@ -264,7 +264,7 @@ def test_tcp_tls_with_a_plain_text_socket(self):
filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains(
"Started listening for TCP connection"))
"Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP
sock.connect((config.get('host'), config.get('port')))
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_tcp_over_tls_mutual_auth_rfc6587_framing(self):
filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains(
"Started listening for TCP connection"))
"Start accepting connections"))

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

Expand Down
2 changes: 1 addition & 1 deletion filebeat/tests/system/test_unix.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def send_events_with_delimiter(self, delimiter):

filebeat = self.start_beat()

self.wait_until(lambda: self.log_contains("Started listening for UNIX connection"))
self.wait_until(lambda: self.log_contains("Start accepting connections"))

sock = send_stream_socket(path, delimiter)

Expand Down

0 comments on commit 8c79b67

Please sign in to comment.