Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix thread safety in process code #43

Merged
merged 5 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions metric/system/process/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (procStats *Stats) Get() ([]mapstr.M, []mapstr.M, error) {
return nil, nil, fmt.Errorf("error gathering PIDs: %w", err)
}
// We use this to track processes over time.
procStats.ProcsMap = pidMap
procStats.ProcsMap.SetMap(pidMap)

// filter the process list that will be passed down to users
plist = procStats.includeTopProcesses(plist)
Expand Down Expand Up @@ -119,9 +119,8 @@ func (procStats *Stats) GetOne(pid int) (mapstr.M, error) {
if err != nil {
return nil, fmt.Errorf("error fetching PID %d: %w", pid, err)
}
newMap := make(ProcsMap)
newMap[pid] = pidStat
procStats.ProcsMap = newMap

procStats.ProcsMap.SetPid(pid, pidStat)

return procStats.getProcessEvent(&pidStat)
}
Expand All @@ -134,13 +133,14 @@ func (procStats *Stats) GetSelf() (ProcState, error) {
if err != nil {
return ProcState{}, fmt.Errorf("error fetching PID %d: %w", self, err)
}
procStats.ProcsMap[self] = pidStat

procStats.ProcsMap.SetPid(self, pidStat)
return pidStat, nil
}

// pidIter wraps a few lines of generic code that all OS-specific FetchPids() functions must call.
// this also handles the process of adding to the maps/lists in order to limit the code duplication in all the OS implementations
func (procStats *Stats) pidIter(pid int, procMap map[int]ProcState, proclist []ProcState) (map[int]ProcState, []ProcState) {
func (procStats *Stats) pidIter(pid int, procMap ProcsMap, proclist []ProcState) (ProcsMap, []ProcState) {
status, saved, err := procStats.pidFill(pid, true)
if err != nil {
procStats.logger.Debugf("Error fetching PID info for %d, skipping: %s", pid, err)
Expand Down Expand Up @@ -190,7 +190,7 @@ func (procStats *Stats) pidFill(pid int, filter bool) (ProcState, bool, error) {
}

//postprocess with cgroups and percentages
last, ok := procStats.ProcsMap[status.Pid.ValueOr(0)]
last, ok := procStats.ProcsMap.GetPid(status.Pid.ValueOr(0))
status.SampleTime = time.Now()
if procStats.EnableCgroups {
cgStats, err := procStats.cgroups.GetStatsForPid(status.Pid.ValueOr(0))
Expand All @@ -215,7 +215,7 @@ func (procStats *Stats) pidFill(pid int, filter bool) (ProcState, bool, error) {

// cacheCmdLine fills out Env and arg metrics from any stored previous metrics for the pid
func (procStats *Stats) cacheCmdLine(in ProcState) ProcState {
if previousProc, ok := procStats.ProcsMap[in.Pid.ValueOr(0)]; ok {
if previousProc, ok := procStats.ProcsMap.GetPid(in.Pid.ValueOr(0)); ok {
if procStats.CacheCmdLine {
in.Args = previousProc.Args
in.Cmdline = previousProc.Cmdline
Expand Down
40 changes: 37 additions & 3 deletions metric/system/process/process_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package process
import (
"errors"
"fmt"
"sync"

"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-libs/match"
Expand All @@ -33,9 +34,42 @@ import (
sysinfo "github.com/elastic/go-sysinfo"
)

// ProcsMap is a map where the keys are the names of processes and the value is the Process with that name
// ProcsMap is a convenience wrapper for the oft-used idiom of map[int]ProcState
type ProcsMap map[int]ProcState

// ProcsTrack is a thread-safe wrapper for a process Stat object's internal map of processes.
// The GetPid, SetPid and SetMap methods each lock mut before touching pids.
type ProcsTrack struct {
// pids holds the tracked ProcState values, keyed by integer PID.
pids ProcsMap
// mut guards pids. It is held by value, so a ProcsTrack should be shared
// through a pointer rather than copied.
mut sync.Mutex
}

// NewProcsMap returns a pointer to a fresh ProcsTrack whose process map is
// allocated and empty, and whose mutex is in its unlocked zero state.
func NewProcsMap() *ProcsTrack {
	tracker := new(ProcsTrack)
	tracker.pids = ProcsMap{}
	return tracker
}

// GetPid looks up pid in the tracked process map while holding the tracker's
// mutex. It returns the stored ProcState and true when an entry exists, or the
// zero ProcState and false when it does not.
func (pm *ProcsTrack) GetPid(pid int) (ProcState, bool) {
	pm.mut.Lock()
	state, found := pm.pids[pid]
	pm.mut.Unlock()
	return state, found
}

// SetPid stores ps as the tracked state for pid, replacing any existing entry.
// The write is performed while holding the tracker's mutex.
//
// The underlying map is allocated on first use, so SetPid also works on a
// zero-value ProcsTrack or after the map has been replaced with nil; writing
// to a nil map would otherwise panic.
func (pm *ProcsTrack) SetPid(pid int, ps ProcState) {
	pm.mut.Lock()
	defer pm.mut.Unlock()
	if pm.pids == nil {
		pm.pids = make(ProcsMap)
	}
	pm.pids[pid] = ps
}

// SetMap replaces the entire tracked process map with pids while holding the
// tracker's mutex.
//
// The map is stored as passed, not copied, so later changes the caller makes
// to pids are visible through the tracker and bypass its mutex.
//
// A nil pids is stored as an empty map instead, so that a subsequent write to
// the tracked map does not hit a nil map and panic.
func (pm *ProcsTrack) SetMap(pids map[int]ProcState) {
	if pids == nil {
		pids = make(map[int]ProcState)
	}
	pm.mut.Lock()
	defer pm.mut.Unlock()
	pm.pids = pids
}

// ProcCallback is a function that FetchPid* methods can call at various points to do OS-agnostic processing.
// It takes a ProcState as input and returns a ProcState together with an error.
type ProcCallback func(in ProcState) (ProcState, error)

Expand All @@ -53,7 +87,7 @@ type CgroupPctStats struct {
type Stats struct {
Hostfs resolve.Resolver
Procs []string
ProcsMap ProcsMap
ProcsMap *ProcsTrack
CPUTicks bool
EnvWhitelist []string
CacheCmdLine bool
Expand Down Expand Up @@ -128,7 +162,7 @@ func (procStats *Stats) Init() error {
procStats.Hostfs = resolve.NewTestResolver("/")
}

procStats.ProcsMap = make(ProcsMap)
procStats.ProcsMap = NewProcsMap()

if len(procStats.Procs) == 0 {
return nil
Expand Down
2 changes: 1 addition & 1 deletion metric/system/process/process_linux_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (procStats *Stats) FetchPids() (ProcsMap, []ProcState, error) {
return nil, nil, fmt.Errorf("error reading directory names: %w", err)
}

procMap := make(ProcsMap, 0)
procMap := make(ProcsMap)
var plist []ProcState

// Iterate over the directory, fetch just enough info so we can filter based on user input.
Expand Down
4 changes: 2 additions & 2 deletions metric/system/process/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ func TestProcMemPercentage(t *testing.T) {
},
}

procStats.ProcsMap = make(ProcsMap)
procStats.ProcsMap[p.Pid.ValueOr(0)] = p
procStats.ProcsMap = NewProcsMap()
procStats.ProcsMap.SetPid(p.Pid.ValueOr(0), p)

rssPercent := GetProcMemPercentage(p, 10000)
assert.Equal(t, rssPercent.ValueOr(0), 0.1416)
Expand Down
67 changes: 67 additions & 0 deletions report/metrics_report_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package report

import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-libs/monitoring"
)

// TestSystemMetricsReport sets up the metrics reporter and then walks
// processMetrics from several goroutines at once, checking that the
// info.uptime.ms, cpu.total.ticks and memstats.rss keys are all reported.
func TestSystemMetricsReport(t *testing.T) {
	// The error from logging setup is ignored, as in the original test.
	_ = logp.DevelopmentSetup()
	logger := logp.L()
	err := SetupMetrics(logger, "TestSys", "test")
	require.NoError(t, err)

	// testFunc runs concurrently on every goroutine started below, so the
	// flags it writes are guarded by flagMut; unsynchronized writes would be
	// a data race and fail under `go test -race`.
	var flagMut sync.Mutex
	var gotCPU, gotMem, gotInfo bool
	testFunc := func(key string, val interface{}) {
		flagMut.Lock()
		defer flagMut.Unlock()
		switch key {
		case "info.uptime.ms":
			gotInfo = true
		case "cpu.total.ticks":
			gotCPU = true
		case "memstats.rss":
			gotMem = true
		}
	}

	// iterate over the processes a few times,
	// with the concurrency (hopefully) emulating what might
	// happen if this was an HTTP endpoint getting multiple GET requests
	iter := 5
	var wait sync.WaitGroup
	wait.Add(iter)
	for i := 0; i < iter; i++ {
		go func() {
			defer wait.Done()
			processMetrics.Do(monitoring.Full, testFunc)
		}()
	}

	// Wait returns only after every Done call, so the flag reads below are
	// ordered after all writes and need no lock.
	wait.Wait()
	assert.True(t, gotCPU, "Didn't find cpu.total.ticks")
	assert.True(t, gotMem, "Didn't find memstats.rss")
	assert.True(t, gotInfo, "Didn't find info.uptime.ms")
}