diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index 905b5a65886c..933500f7cfb0 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -98,6 +98,7 @@ https://github.com/elastic/beats/compare/v7.0.0-alpha2...main[Check the HEAD dif *Filebeat* +- tcp/unix input: Stop accepting connections after socket is closed. {pull}29712[29712] - Fix using log_group_name_prefix in aws-cloudwatch input. {pull}29695[29695] *Heartbeat* diff --git a/filebeat/inputsource/common/streaming/listener.go b/filebeat/inputsource/common/streaming/listener.go index 11e60a35b7bf..f256a03c24c9 100644 --- a/filebeat/inputsource/common/streaming/listener.go +++ b/filebeat/inputsource/common/streaming/listener.go @@ -23,8 +23,11 @@ import ( "context" "fmt" "net" + "strings" "sync" + "github.com/pkg/errors" + "github.com/elastic/beats/v7/filebeat/inputsource" "github.com/elastic/beats/v7/libbeat/common/atomic" "github.com/elastic/beats/v7/libbeat/logp" @@ -60,8 +63,33 @@ var ( "delimiter": FramingDelimiter, "rfc6587": FramingRFC6587, } + + availableFramingTypesErrFormat string ) +func init() { + framingTypeNames := make([]string, 0, len(framingTypes)) + for t := range framingTypes { + framingTypeNames = append(framingTypeNames, t) + } + + availableFramingTypesErrFormat = fmt.Sprintf("invalid framing type %%q, "+ + "the supported types are [%v]", strings.Join(framingTypeNames, ", ")) +} + +// Unpack unpacks the FramingType string value. +func (f *FramingType) Unpack(value string) error { + value = strings.ToLower(value) + + ft, ok := framingTypes[value] + if !ok { + return fmt.Errorf(availableFramingTypesErrFormat, value) + } + + *f = ft + return nil +} + // NewListener creates a new Listener func NewListener(family inputsource.Family, location string, handlerFactory HandlerFactory, listenerFactory ListenerFactory, config *ListenerConfig) *Listener { return &Listener{ @@ -73,7 +101,9 @@ func NewListener(family inputsource.Family, location string, handlerFactory Hand } } -// Start listen to the socket. +// Start listening to the socket and accepting connections. The method is +// non-blocking and starts a goroutine to service the socket. Stop must be +// called to ensure proper cleanup. func (l *Listener) Start() error { if err := l.initListen(context.Background()); err != nil { return err @@ -117,73 +147,73 @@ func (l *Listener) initListen(ctx context.Context) error { } func (l *Listener) run() { - l.log.Info("Started listening for " + l.family.String() + " connection") + l.log.Debug("Start accepting connections") + defer func() { + l.Listener.Close() + l.log.Debug("Stopped accepting connections") + }() for { conn, err := l.Listener.Accept() if err != nil { - select { - case <-l.ctx.Done(): + if l.ctx.Err() != nil { + // Shutdown. return - default: - l.log.Debugw("Can not accept the connection", "error", err) - continue } + + if errors.Is(err, net.ErrClosed) { + return + } + + l.log.Debugw("Cannot accept new connection", "error", err) + continue } l.wg.Add(1) go func() { - defer logp.Recover("recovering from a " + l.family.String() + " client crash") defer l.wg.Done() + l.handleConnection(conn) + }() + } +} - ctx, cancel := ctxtool.WithFunc(l.ctx, func() { conn.Close() }) - defer cancel() - - l.registerHandler() - defer l.unregisterHandler() +func (l *Listener) handleConnection(conn net.Conn) { + log := l.log + if remoteAddr := conn.RemoteAddr().String(); remoteAddr != "" { + log = log.With("remote_address", remoteAddr) + } + defer log.Recover("Panic in connection handler") - if l.family == inputsource.FamilyUnix { - // unix sockets have an empty `RemoteAddr` value, so no need to capture it - l.log.Debugw("New client", "total", l.clientsCount.Load()) - } else { - l.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", l.clientsCount.Load()) - } + // Ensure accepted connection is closed on return and at shutdown. + connCtx, cancel := ctxtool.WithFunc(l.ctx, func() { + conn.Close() + }) + defer cancel() - handler := l.handlerFactory(*l.config) - err := handler(ctx, conn) - if err != nil { - l.log.Debugw("client error", "error", err) - } + // Track number of clients. + l.clientsCount.Inc() + log.Debugw("New client connection.", "active_clients", l.clientsCount.Load()) + defer func() { + l.clientsCount.Dec() + log.Debugw("Client disconnected.", "active_clients", l.clientsCount.Load()) + }() - defer func() { - if l.family == inputsource.FamilyUnix { - // unix sockets have an empty `RemoteAddr` value, so no need to capture it - l.log.Debugw("client disconnected", "total", l.clientsCount.Load()) - } else { - l.log.Debugw("client disconnected", "remote_address", conn.RemoteAddr(), "total", l.clientsCount.Load()) - } - }() - }() + handler := l.handlerFactory(*l.config) + if err := handler(connCtx, conn); err != nil { + log.Debugw("Client error", "error", err) + return } } -// Stop stops accepting new incoming connections and Close any active clients +// Stop stops accepting new incoming connections and closes all active clients. func (l *Listener) Stop() { - l.log.Info("Stopping" + l.family.String() + "server") + l.log.Debugw("Stopping socket listener. Waiting for active connections to close.", "active_clients", l.clientsCount.Load()) l.ctx.Cancel() l.wg.Wait() - l.log.Info(l.family.String() + " server stopped") -} - -func (l *Listener) registerHandler() { - l.clientsCount.Inc() -} - -func (l *Listener) unregisterHandler() { - l.clientsCount.Dec() + l.log.Info("Socket listener stopped") } -// SplitFunc allows to create a `bufio.SplitFunc` based on a framing & +// SplitFunc allows to create a `bufio.SplitFunc` based on a framing and // delimiter provided. func SplitFunc(framing FramingType, lineDelimiter []byte) (bufio.SplitFunc, error) { if len(lineDelimiter) == 0 { @@ -202,24 +232,6 @@ func SplitFunc(framing FramingType, lineDelimiter []byte) (bufio.SplitFunc, erro case FramingRFC6587: return FactoryRFC6587Framing(lineDelimiter), nil default: - return nil, fmt.Errorf("unknown SplitFunc for framing %d and line delimiter %s", framing, string(lineDelimiter)) - } - -} - -// Unpack for config -func (f *FramingType) Unpack(value string) error { - ft, ok := framingTypes[value] - if !ok { - availableTypes := make([]string, len(framingTypes)) - i := 0 - for t := range framingTypes { - availableTypes[i] = t - i++ - } - return fmt.Errorf("invalid framing type '%s', supported types: %v", value, availableTypes) - + return nil, fmt.Errorf("unknown SplitFunc for framing %d and line delimiter %q", framing, lineDelimiter) } - *f = ft - return nil } diff --git a/filebeat/tests/system/test_syslog.py b/filebeat/tests/system/test_syslog.py index be4a600f7143..d72d51cf7db1 100644 --- a/filebeat/tests/system/test_syslog.py +++ b/filebeat/tests/system/test_syslog.py @@ -31,7 +31,7 @@ def test_syslog_with_tcp(self): filebeat = self.start_beat() - self.wait_until(lambda: self.log_contains("Started listening for TCP connection")) + self.wait_until(lambda: self.log_contains("Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP sock.connect((host, port)) @@ -73,7 +73,7 @@ def test_syslog_with_tcp_invalid_message(self): filebeat = self.start_beat() - self.wait_until(lambda: self.log_contains("Started listening for TCP connection")) + self.wait_until(lambda: self.log_contains("Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP sock.connect((host, port)) @@ -173,7 +173,10 @@ def run_filebeat_and_send_using_socket(self, socket_type, send_over_socket): filebeat = self.start_beat() - self.wait_until(lambda: self.log_contains("Started listening for UNIX connection")) + if socket_type == "stream": + self.wait_until(lambda: self.log_contains("Start accepting connections")) + else: + self.wait_until(lambda: self.log_contains("Started listening")) sock = send_over_socket(path, "<13>Oct 11 22:14:15 wopr.mymachine.co postfix/smtpd[2000]:" @@ -233,7 +236,10 @@ def run_filebeat_and_send_invalid_message_using_socket(self, socket_type, send_o filebeat = self.start_beat() - self.wait_until(lambda: self.log_contains("Started listening for UNIX connection")) + if socket_type == "stream": + self.wait_until(lambda: self.log_contains("Start accepting connections")) + else: + self.wait_until(lambda: self.log_contains("Started listening")) sock = send_over_socket(path, "invalid\n") diff --git a/filebeat/tests/system/test_tcp.py b/filebeat/tests/system/test_tcp.py index dea22e7388a7..2f98e95c8ba6 100644 --- a/filebeat/tests/system/test_tcp.py +++ b/filebeat/tests/system/test_tcp.py @@ -59,7 +59,7 @@ def send_events_with_delimiter(self, delimiter): filebeat = self.start_beat() - self.wait_until(lambda: self.log_contains("Started listening for TCP connection")) + self.wait_until(lambda: self.log_contains("Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP sock.connect((host, port)) @@ -100,7 +100,7 @@ def send_events_with_rfc6587_framing(self, framing): filebeat = self.start_beat() - self.wait_until(lambda: self.log_contains("Started listening for TCP connection")) + self.wait_until(lambda: self.log_contains("Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP sock.connect((host, port)) diff --git a/filebeat/tests/system/test_tcp_tls.py b/filebeat/tests/system/test_tcp_tls.py index b4edd0144b5f..3e853ff84ecb 100644 --- a/filebeat/tests/system/test_tcp_tls.py +++ b/filebeat/tests/system/test_tcp_tls.py @@ -66,7 +66,7 @@ def test_tcp_over_tls_and_verify_valid_server_without_mutual_auth(self): filebeat = self.start_beat() self.wait_until(lambda: self.log_contains( - "Started listening for TCP connection")) + "Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP tls = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_REQUIRED, @@ -118,7 +118,7 @@ def test_tcp_over_tls_and_verify_invalid_server_without_mutual_auth(self): filebeat = self.start_beat() self.wait_until(lambda: self.log_contains( - "Started listening for TCP connection")) + "Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP tls = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_REQUIRED, @@ -159,7 +159,7 @@ def test_tcp_over_tls_mutual_auth_fails(self): filebeat = self.start_beat() self.wait_until(lambda: self.log_contains( - "Started listening for TCP connection")) + "Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) tls = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_REQUIRED, @@ -206,7 +206,7 @@ def test_tcp_over_tls_mutual_auth_succeed(self): filebeat = self.start_beat() self.wait_until(lambda: self.log_contains( - "Started listening for TCP connection")) + "Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -264,7 +264,7 @@ def test_tcp_tls_with_a_plain_text_socket(self): filebeat = self.start_beat() self.wait_until(lambda: self.log_contains( - "Started listening for TCP connection")) + "Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TCP sock.connect((config.get('host'), config.get('port'))) @@ -317,7 +317,7 @@ def test_tcp_over_tls_mutual_auth_rfc6587_framing(self): filebeat = self.start_beat() self.wait_until(lambda: self.log_contains( - "Started listening for TCP connection")) + "Start accepting connections")) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/filebeat/tests/system/test_unix.py b/filebeat/tests/system/test_unix.py index bc506b47d7b2..90bde8096cae 100644 --- a/filebeat/tests/system/test_unix.py +++ b/filebeat/tests/system/test_unix.py @@ -59,7 +59,7 @@ def send_events_with_delimiter(self, delimiter): filebeat = self.start_beat() - self.wait_until(lambda: self.log_contains("Started listening for UNIX connection")) + self.wait_until(lambda: self.log_contains("Start accepting connections")) sock = send_stream_socket(path, delimiter)