diff --git a/x-pack/osquerybeat/internal/osqd/osqueryd.go b/x-pack/osquerybeat/internal/osqd/osqueryd.go index 09a5d866f4f..206846f985b 100644 --- a/x-pack/osquerybeat/internal/osqd/osqueryd.go +++ b/x-pack/osquerybeat/internal/osqd/osqueryd.go @@ -5,6 +5,7 @@ package osqd import ( + "bufio" "context" "fmt" "io" @@ -263,8 +264,9 @@ func (q *OSQueryD) prepare(ctx context.Context) (func(), error) { // Write the autoload file extensionAutoloadPath := q.resolveDataPath(osqueryAutoload) - if err := ioutil.WriteFile(extensionAutoloadPath, []byte(extensionPath), 0644); err != nil { - return nil, errors.Wrap(err, "failed write osquery extension autoload file") + err = prepareAutoloadFile(extensionAutoloadPath, extensionPath, q.log) + if err != nil { + return nil, errors.Wrapf(err, "failed to prepare extensions autoload file") } // Write the flagsfile in order to lock down/prevent loading default flags from osquery global locations. @@ -286,6 +288,60 @@ func (q *OSQueryD) prepare(ctx context.Context) (func(), error) { return func() {}, nil } +func prepareAutoloadFile(extensionAutoloadPath, mandatoryExtensionPath string, log *logp.Logger) error { + ok, err := fileutil.FileExists(extensionAutoloadPath) + if err != nil { + return errors.Wrapf(err, "failed to check osquery.autoload file exists") + } + + rewrite := false + + if ok { + log.Debugf("Extensions autoload file %s exists, verify the first extension is ours", extensionAutoloadPath) + err = verifyAutoloadFile(extensionAutoloadPath, mandatoryExtensionPath) + if err != nil { + log.Debugf("Extensions autoload file %v verification failed, err: %v, create a new one", extensionAutoloadPath, err) + rewrite = true + } + } else { + log.Debugf("Extensions autoload file %s doesn't exists, create a new one", extensionAutoloadPath) + rewrite = true + } + + if rewrite { + if err := ioutil.WriteFile(extensionAutoloadPath, []byte(mandatoryExtensionPath), 0644); err != nil { + return errors.Wrap(err, "failed write osquery extension autoload file") + } + } + return nil +} + +func verifyAutoloadFile(extensionAutoloadPath, mandatoryExtensionPath string) error { + f, err := os.Open(extensionAutoloadPath) + if err != nil { + return err + } + defer f.Close() + scanner := bufio.NewScanner(f) + for i := 0; scanner.Scan(); i++ { + line := scanner.Text() + if i == 0 { + // Check that the first line is the mandatory extension + if line != mandatoryExtensionPath { + return errors.New("extentsions autoload file is missing mandatory extension in the first line of the file") + } + } + + // Check that the line contains the valid path that exists + _, err := os.Stat(line) + if err != nil { + return err + } + } + + return scanner.Err() +} + func (q *OSQueryD) prepareBinPath() error { // If path to osquery was not set use the current executable path if q.binPath == "" { diff --git a/x-pack/osquerybeat/internal/osqd/osqueryd_test.go b/x-pack/osquerybeat/internal/osqd/osqueryd_test.go index 513e921cd2e..ae8387601dd 100644 --- a/x-pack/osquerybeat/internal/osqd/osqueryd_test.go +++ b/x-pack/osquerybeat/internal/osqd/osqueryd_test.go @@ -5,8 +5,18 @@ package osqd import ( + "bufio" + "errors" + "io/ioutil" + "os" + "path/filepath" "testing" + "github.com/elastic/beats/v7/libbeat/common" + "github.com/elastic/beats/v7/libbeat/logp" + "github.com/elastic/beats/v7/x-pack/osquerybeat/internal/fileutil" + + "github.com/gofrs/uuid" "github.com/google/go-cmp/cmp" ) @@ -46,3 +56,120 @@ func TestNew(t *testing.T) { t.Error(diff) } } + +func TestVerifyAutoloadFileMissing(t *testing.T) { + dir := uuid.Must(uuid.NewV4()).String() + extensionAutoloadPath := filepath.Join(dir, osqueryAutoload) + mandatoryExtensionPath := filepath.Join(dir, extensionName) + err := verifyAutoloadFile(extensionAutoloadPath, mandatoryExtensionPath) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected error: %v, got: %v", os.ErrNotExist, err) + } +} + +// TestPrepareAutoloadFile tests possibly different states of the osquery.autoload file and that it is restored into the workable state +func TestPrepareAutoloadFile(t *testing.T) { + validLogger := logp.NewLogger("osqueryd_test") + + // Prepare the directory with extension + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + mandatoryExtensionPath := filepath.Join(dir, extensionName) + + // Write fake extension file for testing + err = ioutil.WriteFile(mandatoryExtensionPath, nil, 0644) + if err != nil { + t.Fatal(err) + } + + randomContent := func(sz int) []byte { + b, err := common.RandomBytes(sz) + if err != nil { + t.Fatal(err) + } + return b + } + + tests := []struct { + Name string + FileContent []byte + }{ + { + Name: "Empty file", + FileContent: nil, + }, + { + Name: "File with mandatory extension", + FileContent: []byte(mandatoryExtensionPath), + }, + { + Name: "Missing mandatory extension, should restore the file", + FileContent: []byte(filepath.Join(dir, "foobar.ext")), + }, + { + Name: "User extension path doesn't exists", + FileContent: []byte(mandatoryExtensionPath + "\n" + filepath.Join(dir, "foobar.ext")), + }, + { + Name: "Random garbage", + FileContent: randomContent(1234), + }, + } + + for _, tc := range tests { + t.Run(tc.Name, func(t *testing.T) { + + // Setup + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatal(err) + } + + defer os.RemoveAll(dir) + + extensionAutoloadPath := filepath.Join(dir, osqueryAutoload) + + err = ioutil.WriteFile(extensionAutoloadPath, tc.FileContent, 0644) + if err != nil { + t.Fatal(err) + } + + err = prepareAutoloadFile(extensionAutoloadPath, mandatoryExtensionPath, validLogger) + if err != nil { + t.Fatal(err) + } + + // Check the content, should have our mandatory extension and possibly the other extension paths with each extension existing on the disk + f, err := os.Open(extensionAutoloadPath) + if err != nil { + t.Fatal(err) + } + defer f.Close() + scanner := bufio.NewScanner(f) + for i := 0; scanner.Scan(); i++ { + line := scanner.Text() + if i == 0 { + if line != mandatoryExtensionPath { + t.Fatalf("expected the fist line of the file to be: %v , got: %v", mandatoryExtensionPath, line) + } + } + // Check that it is a valid path to the file on the disk + ok, err := fileutil.FileExists(line) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatalf("expected to have only valid paths to the extensions files that exists, got: %v", line) + } + } + + err = scanner.Err() + if err != nil { + t.Fatal(err) + } + }) + } +}