Skip to content
This repository has been archived by the owner on Sep 21, 2023. It is now read-only.

Commit

Permalink
Add StrictMode with event validation
Browse files Browse the repository at this point in the history
In `StrictMode` required fields are:

* Timestamp
* Datastream.Namespace
* Datastream.Dataset
* Datastream.Type
* Source.InputId
  • Loading branch information
rdner committed Aug 12, 2022
1 parent dcb8427 commit 1330785
Show file tree
Hide file tree
Showing 9 changed files with 393 additions and 14 deletions.
4 changes: 2 additions & 2 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/spf13/cobra"

_ "github.com/elastic/elastic-agent-libs/logp/configure"
"github.com/elastic/elastic-agent-shipper/server"
"github.com/elastic/elastic-agent-shipper/controller"
)

// NewCommand returns a new command structure
Expand All @@ -41,7 +41,7 @@ func runCmd() *cobra.Command {
Use: "run",
Short: "Start the elastic-agent-shipper.",
Run: func(_ *cobra.Command, _ []string) {
if err := server.LoadAndRun(); err != nil {
if err := controller.LoadAndRun(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n\n", err)
os.Exit(1)
}
Expand Down
4 changes: 4 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-shipper/monitoring"
"github.com/elastic/elastic-agent-shipper/queue"
"github.com/elastic/elastic-agent-shipper/server"
"github.com/elastic/go-ucfg/json"
)

Expand Down Expand Up @@ -43,6 +44,7 @@ type ShipperConfig struct {
Port int `config:"port"` //Port to listen on
Monitor monitoring.Config `config:"monitoring"` //Queue monitoring settings
Queue queue.Config `config:"queue"` //Queue settings
Server server.Config `config:"server"` //gRPC Server settings
}

// ReadConfig returns the populated config from the specified path
Expand All @@ -64,6 +66,7 @@ func ReadConfig() (ShipperConfig, error) {
Log: logp.DefaultConfig(logp.SystemdEnvironment),
Monitor: monitoring.DefaultConfig(),
Queue: queue.DefaultConfig(),
Server: server.DefaultConfig(),
}
err = raw.Unpack(&config)
if err != nil {
Expand All @@ -84,6 +87,7 @@ func ReadConfigFromJSON(raw string) (ShipperConfig, error) {
Log: logp.DefaultConfig(logp.SystemdEnvironment),
Monitor: monitoring.DefaultConfig(),
Queue: queue.DefaultConfig(),
Server: server.DefaultConfig(),
}
err = rawCfg.Unpack(&shipperConfig)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package server
package controller

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package server
package controller

import (
"context"
Expand Down
152 changes: 152 additions & 0 deletions controller/run.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package controller

import (
"context"
"fmt"
"net"
"os"
"os/signal"
"sync"
"syscall"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

"github.com/elastic/elastic-agent-client/v7/pkg/client"
"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-shipper/config"
"github.com/elastic/elastic-agent-shipper/monitoring"
"github.com/elastic/elastic-agent-shipper/output"
"github.com/elastic/elastic-agent-shipper/queue"
"github.com/elastic/elastic-agent-shipper/server"

pb "github.com/elastic/elastic-agent-shipper-client/pkg/proto"
)

// LoadAndRun loads the config object and runs the gRPC server
func LoadAndRun() error {
agentClient, _, err := client.NewV2FromReader(os.Stdin, client.VersionInfo{Name: "elastic-agent-shipper", Version: "v2"})
if err != nil {
return fmt.Errorf("error reading control config from agent: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err = runController(ctx, agentClient)

return err
}

// handle shutdown of the shipper
func handleShutdown(stopFunc func(), done doneChan) {
var callback sync.Once
log := logp.L()
// On termination signals, gracefully stop the Beat
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
go func() {
for {
select {
case <-done:
log.Debugf("Shutting down from agent controller")
callback.Do(stopFunc)
return
case sig := <-sigc:
switch sig {
case syscall.SIGINT, syscall.SIGTERM:
log.Debug("Received sigterm/sigint, stopping")
case syscall.SIGHUP:
log.Debug("Received sighup, stopping")
}
callback.Do(stopFunc)
return
}
}

}()
}

// Run starts the gRPC server
func (c *clientHandler) Run(cfg config.ShipperConfig, unit *client.Unit) error {
log := logp.L()

// When there is queue-specific configuration in ShipperConfig, it should
// be passed in here.
queue, err := queue.New(cfg.Queue)
if err != nil {
return fmt.Errorf("couldn't create queue: %w", err)
}

// Make a placeholder console output to read the queue's events
out := output.NewConsole(queue)
out.Start()

lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", cfg.Port))
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}

monHandler, err := loadMonitoring(cfg, queue)
if err != nil {
return fmt.Errorf("error loading outputs: %w", err)
}

_ = unit.UpdateState(client.UnitStateConfiguring, "starting shipper server", nil)

var opts []grpc.ServerOption
if cfg.TLS {
creds, err := credentials.NewServerTLSFromFile(cfg.Cert, cfg.Key)
if err != nil {
return fmt.Errorf("failed to generate credentials %w", err)
}
opts = []grpc.ServerOption{grpc.Creds(creds)}
}
grpcServer := grpc.NewServer(opts...)
shipperServer, err := server.NewShipperServer(cfg.Server, queue)
if err != nil {
return fmt.Errorf("failed to initialise the server: %w", err)
}
pb.RegisterProducerServer(grpcServer, shipperServer)

shutdownFunc := func() {
grpcServer.GracefulStop()
monHandler.End()
queue.Close()
// The output will shut down once the queue is closed.
// We call Wait to give it a chance to finish with events
// it has already read.
out.Wait()
shipperServer.Close()
}
handleShutdown(shutdownFunc, c.shutdownInit)
log.Debugf("gRPC server is listening on port %d", cfg.Port)
_ = unit.UpdateState(client.UnitStateHealthy, "Shipper Running", nil)

// This will get sent after the server has shutdown, signaling to the runloop that it can stop.
// The shipper has no queues connected right now, but once it does, this function can't run until
// after the queues have emptied and/or shutdown. We'll presumably have a better idea of how this
// will work once we have queues connected here.
defer func() {
log.Debugf("shipper has completed shutdown, stopping")
c.shutdownComplete.Done()
}()
c.shutdownComplete.Add(1)
return grpcServer.Serve(lis)

}

// Initialize metrics and outputs
func loadMonitoring(cfg config.ShipperConfig, queue *queue.Queue) (*monitoring.QueueMonitor, error) {
//startup monitor
mon, err := monitoring.NewFromConfig(cfg.Monitor, queue)
if err != nil {
return nil, fmt.Errorf("error initializing output monitor: %w", err)
}

mon.Watch()

return mon, nil
}
File renamed without changes.
20 changes: 20 additions & 0 deletions server/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package server

type Config struct {
// StrictMode means that every incoming event will be validated against the
// list of requried fields. This introduces some additional overhead but can
// be really handy for client developers on the debugging stage.
// Normally, it should be disabled during production use and enabled for testing.
StrictMode bool `config:"strict_mode"`
}

// DefaultConfig returns default configuration for the gRPC server
func DefaultConfig() Config {
return Config{
StrictMode: false,
}
}
77 changes: 76 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -52,11 +53,13 @@ type shipperServer struct {
ctx context.Context
stop func()

cfg Config

pb.UnimplementedProducerServer
}

// NewShipperServer creates a new server instance for handling gRPC endpoints.
func NewShipperServer(publisher Publisher) (ShipperServer, error) {
func NewShipperServer(cfg Config, publisher Publisher) (ShipperServer, error) {
if publisher == nil {
return nil, errors.New("publisher cannot be nil")
}
Expand All @@ -71,6 +74,7 @@ func NewShipperServer(publisher Publisher) (ShipperServer, error) {
logger: logp.NewLogger("shipper-server"),
publisher: publisher,
close: &sync.Once{},
cfg: cfg,
}

s.ctx, s.stop = context.WithCancel(context.Background())
Expand Down Expand Up @@ -103,6 +107,15 @@ func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.Publis
return resp, status.Error(codes.FailedPrecondition, fmt.Sprintf("UUID does not match. Expected = %s, actual = %s", serv.uuid, req.Uuid))
}

if serv.cfg.StrictMode {
for _, e := range req.Events {
err := serv.validateEvent(e)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
}

for _, e := range req.Events {
_, err := serv.publisher.Publish(e)
if err == nil {
Expand Down Expand Up @@ -190,3 +203,65 @@ func (serv *shipperServer) Close() error {

return nil
}

func (serv *shipperServer) validateEvent(m *messages.Event) error {
var msgs []string

if err := m.Timestamp.CheckValid(); err != nil {
msgs = append(msgs, fmt.Sprintf("timestamp: %s", err))
}

if err := serv.validateDataStream(m.DataStream); err != nil {
msgs = append(msgs, fmt.Sprintf("datastream: %s", err))
}

if err := serv.validateSource(m.Source); err != nil {
msgs = append(msgs, fmt.Sprintf("source: %s", err))
}

if len(msgs) == 0 {
return nil
}

return errors.New(strings.Join(msgs, "; "))
}

func (serv *shipperServer) validateSource(s *messages.Source) error {
if s == nil {
return fmt.Errorf("cannot be nil")
}

var msgs []string
if s.InputId == "" {
msgs = append(msgs, "input_id is a required field")
}

if len(msgs) == 0 {
return nil
}

return errors.New(strings.Join(msgs, "; "))
}

func (serv *shipperServer) validateDataStream(ds *messages.DataStream) error {
if ds == nil {
return fmt.Errorf("cannot be nil")
}

var msgs []string
if ds.Dataset == "" {
msgs = append(msgs, "dataset is a required field")
}
if ds.Namespace == "" {
msgs = append(msgs, "namespace is a required field")
}
if ds.Type == "" {
msgs = append(msgs, "type is a required field")
}

if len(msgs) == 0 {
return nil
}

return errors.New(strings.Join(msgs, "; "))
}
Loading

0 comments on commit 1330785

Please sign in to comment.