From b85408e7bc8b435cd3fd6a2a586e2f0506008165 Mon Sep 17 00:00:00 2001 From: Denis Rechkunov Date: Thu, 11 Aug 2022 13:24:58 +0200 Subject: [PATCH] Add `StrictMode` with event validation In `StrictMode` required fields are: * Timestamp * Datastream.Namespace * Datastream.Dataset * Datastream.Type * Source.InputId --- cmd/cmd.go | 4 +- config/config.go | 4 + {server => controller}/controller_client.go | 2 +- .../controller_client_test.go | 2 +- {server => controller}/run.go | 5 +- server/config.go | 20 +++ server/server.go | 77 ++++++++- server/server_test.go | 146 ++++++++++++++++-- 8 files changed, 244 insertions(+), 16 deletions(-) rename {server => controller}/controller_client.go (99%) rename {server => controller}/controller_client_test.go (99%) rename {server => controller}/run.go (96%) create mode 100644 server/config.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 339903c..f08eefc 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -15,7 +15,7 @@ import ( "github.com/spf13/cobra" _ "github.com/elastic/elastic-agent-libs/logp/configure" - "github.com/elastic/elastic-agent-shipper/server" + "github.com/elastic/elastic-agent-shipper/controller" ) // NewCommand returns a new command structure @@ -41,7 +41,7 @@ func runCmd() *cobra.Command { Use: "run", Short: "Start the elastic-agent-shipper.", Run: func(_ *cobra.Command, _ []string) { - if err := server.LoadAndRun(); err != nil { + if err := controller.LoadAndRun(); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n\n", err) os.Exit(1) } diff --git a/config/config.go b/config/config.go index 1412212..6175ae1 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,7 @@ import ( "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-shipper/monitoring" "github.com/elastic/elastic-agent-shipper/queue" + "github.com/elastic/elastic-agent-shipper/server" "github.com/elastic/go-ucfg/json" ) @@ -43,6 +44,7 @@ type ShipperConfig struct { Port int `config:"port"` //Port to listen on Monitor monitoring.Config `config:"monitoring"` //Queue monitoring settings Queue queue.Config `config:"queue"` //Queue settings + Server server.Config `config:"server"` //gRPC Server settings } // ReadConfig returns the populated config from the specified path @@ -64,6 +66,7 @@ func ReadConfig() (ShipperConfig, error) { Log: logp.DefaultConfig(logp.SystemdEnvironment), Monitor: monitoring.DefaultConfig(), Queue: queue.DefaultConfig(), + Server: server.DefaultConfig(), } err = raw.Unpack(&config) if err != nil { @@ -84,6 +87,7 @@ func ReadConfigFromJSON(raw string) (ShipperConfig, error) { Log: logp.DefaultConfig(logp.SystemdEnvironment), Monitor: monitoring.DefaultConfig(), Queue: queue.DefaultConfig(), + Server: server.DefaultConfig(), } err = rawCfg.Unpack(&shipperConfig) if err != nil { diff --git a/server/controller_client.go b/controller/controller_client.go similarity index 99% rename from server/controller_client.go rename to controller/controller_client.go index 8ab26a4..14b2b64 100644 --- a/server/controller_client.go +++ b/controller/controller_client.go @@ -2,7 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -package server +package controller import ( "context" diff --git a/server/controller_client_test.go b/controller/controller_client_test.go similarity index 99% rename from server/controller_client_test.go rename to controller/controller_client_test.go index 93c613e..02237c2 100644 --- a/server/controller_client_test.go +++ b/controller/controller_client_test.go @@ -2,7 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -package server +package controller import ( "context" diff --git a/server/run.go b/controller/run.go similarity index 96% rename from server/run.go rename to controller/run.go index 28793c2..90de2e8 100644 --- a/server/run.go +++ b/controller/run.go @@ -2,7 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -package server +package controller import ( "context" @@ -22,6 +22,7 @@ import ( "github.com/elastic/elastic-agent-shipper/monitoring" "github.com/elastic/elastic-agent-shipper/output" "github.com/elastic/elastic-agent-shipper/queue" + "github.com/elastic/elastic-agent-shipper/server" pb "github.com/elastic/elastic-agent-shipper-client/pkg/proto" ) @@ -104,7 +105,7 @@ func (c *clientHandler) Run(cfg config.ShipperConfig, unit *client.Unit) error { opts = []grpc.ServerOption{grpc.Creds(creds)} } grpcServer := grpc.NewServer(opts...) - shipperServer, err := NewShipperServer(queue) + shipperServer, err := server.NewShipperServer(cfg.Server, queue) if err != nil { return fmt.Errorf("failed to initialise the server: %w", err) } diff --git a/server/config.go b/server/config.go new file mode 100644 index 0000000..750413d --- /dev/null +++ b/server/config.go @@ -0,0 +1,20 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package server + +type Config struct { + // StrictMode means that every incoming event will be validated against the + // list of required fields. This introduces some additional overhead but can + // be really handy for client developers on the debugging stage. + // Normally, it should be disabled during production use and enabled for testing. + StrictMode bool `config:"strict_mode"` +} + +// DefaultConfig returns default configuration for the gRPC server +func DefaultConfig() Config { + return Config{ + StrictMode: false, + } +} diff --git a/server/server.go b/server/server.go index 7f508bd..0830e1d 100644 --- a/server/server.go +++ b/server/server.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "strings" "sync" "time" @@ -52,11 +53,13 @@ type shipperServer struct { ctx context.Context stop func() + cfg Config + pb.UnimplementedProducerServer } // NewShipperServer creates a new server instance for handling gRPC endpoints. -func NewShipperServer(publisher Publisher) (ShipperServer, error) { +func NewShipperServer(cfg Config, publisher Publisher) (ShipperServer, error) { if publisher == nil { return nil, errors.New("publisher cannot be nil") } @@ -71,6 +74,7 @@ func NewShipperServer(publisher Publisher) (ShipperServer, error) { logger: logp.NewLogger("shipper-server"), publisher: publisher, close: &sync.Once{}, + cfg: cfg, } s.ctx, s.stop = context.WithCancel(context.Background()) @@ -103,6 +107,15 @@ func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.Publis return resp, status.Error(codes.FailedPrecondition, fmt.Sprintf("UUID does not match. Expected = %s, actual = %s", serv.uuid, req.Uuid)) } + if serv.cfg.StrictMode { + for _, e := range req.Events { + err := serv.validateEvent(e) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + } + } + for _, e := range req.Events { _, err := serv.publisher.Publish(e) if err == nil { @@ -190,3 +203,65 @@ func (serv *shipperServer) Close() error { return nil } + +func (serv *shipperServer) validateEvent(m *messages.Event) error { + var msgs []string + + if err := m.Timestamp.CheckValid(); err != nil { + msgs = append(msgs, fmt.Sprintf("timestamp: %s", err)) + } + + if err := serv.validateDataStream(m.DataStream); err != nil { + msgs = append(msgs, fmt.Sprintf("datastream: %s", err)) + } + + if err := serv.validateSource(m.Source); err != nil { + msgs = append(msgs, fmt.Sprintf("source: %s", err)) + } + + if len(msgs) == 0 { + return nil + } + + return errors.New(strings.Join(msgs, "; ")) +} + +func (serv *shipperServer) validateSource(s *messages.Source) error { + if s == nil { + return fmt.Errorf("cannot be nil") + } + + var msgs []string + if s.InputId == "" { + msgs = append(msgs, "input_id is a required field") + } + + if len(msgs) == 0 { + return nil + } + + return errors.New(strings.Join(msgs, "; ")) +} + +func (serv *shipperServer) validateDataStream(ds *messages.DataStream) error { + if ds == nil { + return fmt.Errorf("cannot be nil") + } + + var msgs []string + if ds.Dataset == "" { + msgs = append(msgs, "dataset is a required field") + } + if ds.Namespace == "" { + msgs = append(msgs, "namespace is a required field") + } + if ds.Type == "" { + msgs = append(msgs, "type is a required field") + } + + if len(msgs) == 0 { + return nil + } + + return errors.New(strings.Join(msgs, "; ")) +} diff --git a/server/server_test.go b/server/server_test.go index a21458b..e870ed1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -18,7 +18,9 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -54,19 +56,14 @@ func TestPublish(t *testing.T) { publisher := &publisherMock{ persistedIndex: 42, } - shipper, err := NewShipperServer(publisher) + shipper, err := NewShipperServer(DefaultConfig(), publisher) defer func() { _ = shipper.Close() }() require.NoError(t, err) client, stop := startServer(t, ctx, shipper) defer stop() // get the current UUID - pirCtx, cancel := context.WithCancel(ctx) - consumer, err := client.PersistedIndex(pirCtx, &messages.PersistedIndexRequest{}) - require.NoError(t, err) - pir, err := consumer.Recv() - require.NoError(t, err) - cancel() // close the stream + pir := getPersistedIndex(t, ctx, client) t.Run("should successfully publish a batch", func(t *testing.T) { publisher.q = make([]*messages.Event, 0, 3) @@ -134,6 +131,127 @@ func TestPublish(t *testing.T) { require.Contains(t, err.Error(), "UUID does not match") require.Nil(t, reply) }) + + t.Run("should return validation errors", func(t *testing.T) { + cases := []struct { + name string + event *messages.Event + expectedMsg string + }{ + { + name: "no timestamp", + event: &messages.Event{ + Source: &messages.Source{ + InputId: "input", + StreamId: "stream", + }, + DataStream: &messages.DataStream{ + Type: "log", + Dataset: "default", + Namespace: "default", + }, + Metadata: sampleValues, + Fields: sampleValues, + }, + expectedMsg: "timestamp: proto:\u00a0invalid nil Timestamp", + }, + { + name: "no source", + event: &messages.Event{ + Timestamp: timestamppb.Now(), + DataStream: &messages.DataStream{ + Type: "log", + Dataset: "default", + Namespace: "default", + }, + Metadata: sampleValues, + Fields: sampleValues, + }, + expectedMsg: "source: cannot be nil", + }, + { + name: "no input ID", + event: &messages.Event{ + Timestamp: timestamppb.Now(), + Source: &messages.Source{ + StreamId: "stream", + }, + DataStream: &messages.DataStream{ + Type: "log", + Dataset: "default", + Namespace: "default", + }, + Metadata: sampleValues, + Fields: sampleValues, + }, + expectedMsg: "source: input_id is a required field", + }, + { + name: "no datastream", + event: &messages.Event{ + Timestamp: timestamppb.Now(), + Source: &messages.Source{ + InputId: "input", + StreamId: "stream", + }, + Metadata: sampleValues, + Fields: sampleValues, + }, + expectedMsg: "datastream: cannot be nil", + }, + { + name: "invalid data stream", + event: &messages.Event{ + Timestamp: timestamppb.Now(), + Source: &messages.Source{ + InputId: "input", + StreamId: "stream", + }, + DataStream: &messages.DataStream{}, + Metadata: sampleValues, + Fields: sampleValues, + }, + expectedMsg: "datastream: dataset is a required field; namespace is a required field; type is a required field", + }, + } + + publisher.q = make([]*messages.Event, 0, len(cases)) + + cfg := Config{ + StrictMode: true, // so we can test the validation + } + + strictShipper, err := NewShipperServer(cfg, publisher) + defer func() { _ = strictShipper.Close() }() + require.NoError(t, err) + strictClient, stop := startServer(t, ctx, strictShipper) + defer stop() + strictPir := getPersistedIndex(t, ctx, strictClient) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + reply, err := strictClient.PublishEvents(ctx, &messages.PublishRequest{ + Uuid: strictPir.Uuid, + Events: []*messages.Event{tc.event}, + }) + require.Error(t, err) + require.Nil(t, reply) + + status, ok := status.FromError(err) + require.True(t, ok, "expected gRPC error") + require.Equal(t, codes.InvalidArgument, status.Code()) + require.Equal(t, tc.expectedMsg, status.Message()) + + // no validation in non-strict mode + reply, err = client.PublishEvents(ctx, &messages.PublishRequest{ + Uuid: pir.Uuid, + Events: []*messages.Event{tc.event}, + }) + require.NoError(t, err) + require.Equal(t, uint32(1), reply.AcceptedCount) + }) + } + }) } func TestPersistedIndex(t *testing.T) { @@ -142,7 +260,7 @@ func TestPersistedIndex(t *testing.T) { publisher := &publisherMock{persistedIndex: 42} t.Run("server should send updates to the clients", func(t *testing.T) { - shipper, err := NewShipperServer(publisher) + shipper, err := NewShipperServer(DefaultConfig(), publisher) defer func() { _ = shipper.Close() }() require.NoError(t, err) client, stop := startServer(t, ctx, shipper) @@ -168,7 +286,7 @@ func TestPersistedIndex(t *testing.T) { }) t.Run("server should properly shutdown", func(t *testing.T) { - shipper, err := NewShipperServer(publisher) + shipper, err := NewShipperServer(DefaultConfig(), publisher) require.NoError(t, err) client, stop := startServer(t, ctx, shipper) defer stop() @@ -232,6 +350,16 @@ func createConsumers(t *testing.T, ctx context.Context, client pb.ProducerClient return cl } +func getPersistedIndex(t *testing.T, ctx context.Context, client pb.ProducerClient) *messages.PersistedIndexReply { + pirCtx, cancel := context.WithCancel(ctx) + defer cancel() + consumer, err := client.PersistedIndex(pirCtx, &messages.PersistedIndexRequest{}) + require.NoError(t, err) + pir, err := consumer.Recv() + require.NoError(t, err) + return pir +} + type consumerList struct { consumers []pb.Producer_PersistedIndexClient stop func()