diff --git a/CHANGELOG.md b/CHANGELOG.md index 005887bcd7..72d831823f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Fixed +- Fix thread safety in process code #43 + ## [0.4.3] ## Fixed diff --git a/metric/system/process/process.go b/metric/system/process/process.go index 895f22ea88..2d83beeb97 100644 --- a/metric/system/process/process.go +++ b/metric/system/process/process.go @@ -70,7 +70,7 @@ func (procStats *Stats) Get() ([]mapstr.M, []mapstr.M, error) { return nil, nil, fmt.Errorf("error gathering PIDs: %w", err) } // We use this to track processes over time. - procStats.ProcsMap = pidMap + procStats.ProcsMap.SetMap(pidMap) // filter the process list that will be passed down to users plist = procStats.includeTopProcesses(plist) @@ -119,9 +119,8 @@ func (procStats *Stats) GetOne(pid int) (mapstr.M, error) { if err != nil { return nil, fmt.Errorf("error fetching PID %d: %w", pid, err) } - newMap := make(ProcsMap) - newMap[pid] = pidStat - procStats.ProcsMap = newMap + + procStats.ProcsMap.SetPid(pid, pidStat) return procStats.getProcessEvent(&pidStat) } @@ -134,13 +133,14 @@ func (procStats *Stats) GetSelf() (ProcState, error) { if err != nil { return ProcState{}, fmt.Errorf("error fetching PID %d: %w", self, err) } - procStats.ProcsMap[self] = pidStat + + procStats.ProcsMap.SetPid(self, pidStat) return pidStat, nil } // pidIter wraps a few lines of generic code that all OS-specific FetchPids() functions must call. 
// this also handles the process of adding to the maps/lists in order to limit the code duplication in all the OS implementations -func (procStats *Stats) pidIter(pid int, procMap map[int]ProcState, proclist []ProcState) (map[int]ProcState, []ProcState) { +func (procStats *Stats) pidIter(pid int, procMap ProcsMap, proclist []ProcState) (ProcsMap, []ProcState) { status, saved, err := procStats.pidFill(pid, true) if err != nil { procStats.logger.Debugf("Error fetching PID info for %d, skipping: %s", pid, err) @@ -190,7 +190,7 @@ func (procStats *Stats) pidFill(pid int, filter bool) (ProcState, bool, error) { } //postprocess with cgroups and percentages - last, ok := procStats.ProcsMap[status.Pid.ValueOr(0)] + last, ok := procStats.ProcsMap.GetPid(status.Pid.ValueOr(0)) status.SampleTime = time.Now() if procStats.EnableCgroups { cgStats, err := procStats.cgroups.GetStatsForPid(status.Pid.ValueOr(0)) @@ -215,7 +215,7 @@ func (procStats *Stats) pidFill(pid int, filter bool) (ProcState, bool, error) { // cacheCmdLine fills out Env and arg metrics from any stored previous metrics for the pid func (procStats *Stats) cacheCmdLine(in ProcState) ProcState { - if previousProc, ok := procStats.ProcsMap[in.Pid.ValueOr(0)]; ok { + if previousProc, ok := procStats.ProcsMap.GetPid(in.Pid.ValueOr(0)); ok { if procStats.CacheCmdLine { in.Args = previousProc.Args in.Cmdline = previousProc.Cmdline diff --git a/metric/system/process/process_common.go b/metric/system/process/process_common.go index fe648426ee..674eeca844 100644 --- a/metric/system/process/process_common.go +++ b/metric/system/process/process_common.go @@ -23,6 +23,7 @@ package process import ( "errors" "fmt" + "sync" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/match" @@ -33,9 +34,41 @@ import ( sysinfo "github.com/elastic/go-sysinfo" ) -// ProcsMap is a map where the keys are the names of processes and the value is the Process with that name +// ProcsMap is a convenience wrapper for
the oft-used idiom of map[int]ProcState type ProcsMap map[int]ProcState +// ProcsTrack is a thread-safe wrapper for a process Stat object's internal map of processes. +type ProcsTrack struct { + pids ProcsMap + mut sync.RWMutex +} + +func NewProcsTrack() *ProcsTrack { + return &ProcsTrack{ + pids: make(ProcsMap, 0), + } +} + +func (pm *ProcsTrack) GetPid(pid int) (ProcState, bool) { + pm.mut.RLock() + defer pm.mut.RUnlock() + proc, ok := pm.pids[pid] + return proc, ok +} + +func (pm *ProcsTrack) SetPid(pid int, ps ProcState) { + pm.mut.Lock() + defer pm.mut.Unlock() + pm.pids[pid] = ps +} + +func (pm *ProcsTrack) SetMap(pids map[int]ProcState) { + pm.mut.Lock() + defer pm.mut.Unlock() + pm.pids = pids + +} + // ProcCallback is a function that FetchPid* methods can call at various points to do OS-agnostic processing type ProcCallback func(in ProcState) (ProcState, error) @@ -53,7 +86,7 @@ type CgroupPctStats struct { type Stats struct { Hostfs resolve.Resolver Procs []string - ProcsMap ProcsMap + ProcsMap *ProcsTrack CPUTicks bool EnvWhitelist []string CacheCmdLine bool @@ -128,7 +161,7 @@ func (procStats *Stats) Init() error { procStats.Hostfs = resolve.NewTestResolver("/") } - procStats.ProcsMap = make(ProcsMap) + procStats.ProcsMap = NewProcsTrack() if len(procStats.Procs) == 0 { return nil diff --git a/metric/system/process/process_linux_common.go b/metric/system/process/process_linux_common.go index 2a737638e4..622a8be5ae 100644 --- a/metric/system/process/process_linux_common.go +++ b/metric/system/process/process_linux_common.go @@ -59,7 +59,7 @@ func (procStats *Stats) FetchPids() (ProcsMap, []ProcState, error) { return nil, nil, fmt.Errorf("error reading directory names: %w", err) } - procMap := make(ProcsMap, 0) + procMap := make(ProcsMap) var plist []ProcState // Iterate over the directory, fetch just enough info so we can filter based on user input.
diff --git a/metric/system/process/process_test.go b/metric/system/process/process_test.go index cf8c786af3..811aa42a23 100644 --- a/metric/system/process/process_test.go +++ b/metric/system/process/process_test.go @@ -206,8 +206,8 @@ func TestProcMemPercentage(t *testing.T) { }, } - procStats.ProcsMap = make(ProcsMap) - procStats.ProcsMap[p.Pid.ValueOr(0)] = p + procStats.ProcsMap = NewProcsTrack() + procStats.ProcsMap.SetPid(p.Pid.ValueOr(0), p) rssPercent := GetProcMemPercentage(p, 10000) assert.Equal(t, rssPercent.ValueOr(0), 0.1416) diff --git a/report/metrics_report_test.go b/report/metrics_report_test.go new file mode 100644 index 0000000000..207b261f0f --- /dev/null +++ b/report/metrics_report_test.go @@ -0,0 +1,70 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package report + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/logp" + "github.com/elastic/elastic-agent-libs/monitoring" +) + +func TestSystemMetricsReport(t *testing.T) { + _ = logp.DevelopmentSetup() + logger := logp.L() + err := SetupMetrics(logger, "TestSys", "test") + require.NoError(t, err) + + var gotCPU, gotMem, gotInfo bool + testFunc := func(key string, val interface{}) { + if key == "info.uptime.ms" { + gotInfo = true + } + if key == "cpu.total.ticks" { + gotCPU = true + } + if key == "memstats.rss" { + gotMem = true + } + } + + //iterate over the processes a few times, + // with the concurrency (hopefully) emulating what might + // happen if this was an HTTP endpoint getting multiple GET requests + iter := 100 + var wait sync.WaitGroup + wait.Add(iter) + ch := make(chan struct{}) + for i := 0; i < iter; i++ { + go func() { + <-ch + processMetrics.Do(monitoring.Full, testFunc) + wait.Done() + }() + } + close(ch) + + wait.Wait() + assert.True(t, gotCPU, "Didn't find cpu.total.ticks") + assert.True(t, gotMem, "Didn't find memstats.rss") + assert.True(t, gotInfo, "Didn't find info.uptime.ms") +}