diff --git a/x-pack/osquerybeat/beater/osquery_runner_test.go b/x-pack/osquerybeat/beater/osquery_runner_test.go index b1fe1c132e1..e04bcb5c1e0 100644 --- a/x-pack/osquerybeat/beater/osquery_runner_test.go +++ b/x-pack/osquerybeat/beater/osquery_runner_test.go @@ -12,6 +12,8 @@ import ( "golang.org/x/sync/errgroup" + "github.com/google/go-cmp/cmp" + "github.com/elastic/beats/v7/libbeat/logp" "github.com/elastic/beats/v7/x-pack/osquerybeat/internal/config" "github.com/elastic/beats/v7/x-pack/osquerybeat/internal/osqd" @@ -84,9 +86,7 @@ func TestOsqueryRunnerCancellable(t *testing.T) { } // Cancel - go func() { - cn() - }() + cn() // Wait for runner stop er := waitGroupWithTimeout(parentCtx, g, to) @@ -94,3 +94,84 @@ func TestOsqueryRunnerCancellable(t *testing.T) { t.Fatal("failed running:", er) } } + +func TestOsqueryRunnerRestart(t *testing.T) { + to := 10 * time.Second + + parentCtx := context.Background() + logger := logp.NewLogger("osquery_runner") + + runCh := make(chan struct{}, 1) + + var runs int + + runfn := func(ctx context.Context, flags osqd.Flags, inputCh <-chan []config.InputConfig) error { + runs++ + runCh <- struct{}{} + <-ctx.Done() + return nil + } + + ctx, cn := context.WithCancel(parentCtx) + defer cn() + + g, ctx := errgroup.WithContext(ctx) + + // Start runner + runner := newOsqueryRunner(logger) + g.Go(func() error { + return runner.Run(ctx, runfn) + }) + + // Sent input that will start the runner function + runner.Update(ctx, nil) + + // Wait for runner start + err := waitForStart(ctx, runCh, to) + if err != nil { + t.Fatal("failed starting:", err) + } + + inputConfigs := []config.InputConfig{ + { + Osquery: &config.OsqueryConfig{ + Options: map[string]interface{}{ + "foo": "bar", + }, + }, + }, + } + + // Update flags, this should restart the run function + runner.Update(ctx, inputConfigs) + + // Should get another run + err = waitForStart(ctx, runCh, to) + if err != nil { + t.Fatal("failed starting after flags update:", err) + } + + // Update with the same flags, should not restart the runner function + runner.Update(ctx, inputConfigs) + + // Should timeout on waiting for another run + err = waitForStart(ctx, runCh, 300*time.Millisecond) + if err != context.DeadlineExceeded { + t.Fatal("unexpected error type after update with the same flags:", err) + } + + // Cancel + cn() + + // Wait for runner stop + er := waitGroupWithTimeout(parentCtx, g, to) + if er != nil && !errors.Is(er, context.Canceled) { + t.Fatal("failed running:", er) + } + + // Check that there were total of 2 executions of run function + diff := cmp.Diff(2, runs) + if diff != "" { + t.Error(diff) + } +} diff --git a/x-pack/osquerybeat/internal/fetch/fetch.go b/x-pack/osquerybeat/internal/fetch/fetch.go index 4de7bc326b8..795869d7f2e 100644 --- a/x-pack/osquerybeat/internal/fetch/fetch.go +++ b/x-pack/osquerybeat/internal/fetch/fetch.go @@ -5,21 +5,31 @@ package fetch import ( + "context" "fmt" "io/ioutil" "log" "net/http" "os" + "strings" "github.com/elastic/beats/v7/x-pack/osquerybeat/internal/hash" ) -func Download(url, fp string) (hashout string, err error) { +// Download downloads the osquery distro package +// writes the content into a given filepath +// returns the sha256 hash +func Download(ctx context.Context, url, fp string) (hashout string, err error) { log.Printf("Download %s to %s", url, fp) cli := http.Client{} - res, err := cli.Get(url) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return + } + + res, err := cli.Do(req) if err != nil { return } @@ -32,7 +42,7 @@ func Download(url, fp string) (hashout string, err error) { if err != nil { log.Printf("Failed to read the error response body: %v", err) } else { - s = string(b) + s = strings.TrimSpace(string(b)) } return hashout, fmt.Errorf("failed fetch %s, status: %d, message: %s", url, res.StatusCode, s) } diff --git a/x-pack/osquerybeat/internal/fetch/fetch_test.go b/x-pack/osquerybeat/internal/fetch/fetch_test.go new file mode 100644 index 00000000000..f234a78f256 --- /dev/null +++ b/x-pack/osquerybeat/internal/fetch/fetch_test.go @@ -0,0 +1,92 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package fetch + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/gofrs/uuid" + "github.com/google/go-cmp/cmp" +) + +func TestDownload(t *testing.T) { + ctx := context.Background() + + localFilePathUUID := func() string { + return uuid.Must(uuid.NewV4()).String() + } + tests := []struct { + Name string + Path string + LocalFilePath string + Status int + Payload string + Hash string + ErrStr string + }{ + { + Name: "Http OK", + Path: "/ok", + LocalFilePath: localFilePathUUID(), + Status: http.StatusOK, + Payload: "serenity now", + Hash: "d1071dfdfd6a5bdf08d9b110f664731cf327cc3d341038f0739699690b599281", + }, + { + Name: "Http OK, empty local file path", + Path: "/ok2", + LocalFilePath: "", + Status: http.StatusOK, + Payload: "serenity now", + Hash: "d1071dfdfd6a5bdf08d9b110f664731cf327cc3d341038f0739699690b599281", + ErrStr: "no such file or directory", + }, + { + Name: "Http not found", + Path: "/notfound", + LocalFilePath: localFilePathUUID(), + Payload: "file not found", + Status: http.StatusNotFound, + ErrStr: "file not found", + }, + } + + mux := http.NewServeMux() + for _, tc := range tests { + mux.HandleFunc(tc.Path, func(payload string, status int) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + http.Error(w, payload, status) + } + }(tc.Payload, tc.Status)) + } + + svr := httptest.NewServer(mux) + defer svr.Close() + + for _, tc := range tests { + t.Run(tc.Name, func(t *testing.T) { + hash, err := Download(ctx, svr.URL+tc.Path, tc.LocalFilePath) + defer os.Remove(tc.LocalFilePath) + + if err != nil { + if tc.ErrStr == "" { + t.Fatal("unexpected download error:", err) + } + return + } + + diff := cmp.Diff(tc.Hash, hash) + if diff != "" { + t.Fatal(diff) + } + + }) + } + +} diff --git a/x-pack/osquerybeat/scripts/mage/distro.go b/x-pack/osquerybeat/scripts/mage/distro.go index dde52e3e7b1..1be99ae3f6d 100644 --- a/x-pack/osquerybeat/scripts/mage/distro.go +++ b/x-pack/osquerybeat/scripts/mage/distro.go @@ -5,6 +5,7 @@ package mage import ( + "context" "errors" "fmt" "io/ioutil" @@ -121,7 +122,7 @@ func checkCacheAndFetch(osarch distro.OSArch, spec distro.Spec) (fetched bool, e log.Printf("Hash mismatch, expected: %s, got: %s.", specHash, fileHash) } - fileHash, err = fetch.Download(url, fp) + fileHash, err = fetch.Download(context.Background(), url, fp) if err != nil { log.Printf("File %s fetch failed, err: %v", url, err) return