Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit GUI process execution to one per UID #2267

Merged
merged 7 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions client/ui/client_ui.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"strconv"
"strings"
"sync"
"syscall"
"time"
"unicode"

Expand All @@ -34,6 +33,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)

Expand Down Expand Up @@ -62,8 +62,25 @@ func main() {
var errorMSG string
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")

tmpDir := "/tmp"
if runtime.GOOS == "windows" {
tmpDir = os.TempDir()
}

var saveLogsInFile bool
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))

flag.Parse()

if saveLogsInFile {
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
err := util.InitLog("trace", logFile)
if err != nil {
log.Errorf("error while initializing log: %v", err)
return
}
}

a := app.NewWithID("NetBird")
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))

Expand All @@ -76,8 +93,12 @@ func main() {
if showSettings || showRoutes {
a.Run()
} else {
if err := checkPIDFile(); err != nil {
log.Errorf("check PID file: %v", err)
running, err := isAnotherProcessRunning()
if err != nil {
log.Errorf("error while checking process: %v", err)
}
if running {
log.Warn("another process is running")
return
}
client.setDefaultFonts()
Expand Down Expand Up @@ -861,19 +882,3 @@ func openURL(url string) error {
}
return err
}

// checkPIDFile exists and return error, or write new.
func checkPIDFile() error {
pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid")
if piddata, err := os.ReadFile(pidFile); err == nil {
if pid, err := strconv.Atoi(string(piddata)); err == nil {
if process, err := os.FindProcess(pid); err == nil {
if err := process.Signal(syscall.Signal(0)); err == nil {
return fmt.Errorf("process already exists: %d", pid)
}
}
}
}

return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
}
37 changes: 37 additions & 0 deletions client/ui/process.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

import (
"os"
"path/filepath"
"strings"

"github.com/shirou/gopsutil/v3/process"
)

func isAnotherProcessRunning() (bool, error) {
processes, err := process.Processes()
if err != nil {
return false, err
}

pid := os.Getpid()
processName := strings.ToLower(filepath.Base(os.Args[0]))

for _, p := range processes {
if int(p.Pid) == pid {
continue
}

runningProcessPath, err := p.Exe()
// most errors are related to short-lived processes
if err != nil {
continue
}

if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
return true, nil
}
}

return false, nil
}
26 changes: 26 additions & 0 deletions client/ui/process_nonwindows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//go:build !windows

package main

import (
"os"

"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)

func isProcessOwnedByCurrentUser(p *process.Process) bool {
currentUserID := os.Getuid()
uids, err := p.Uids()
if err != nil {
log.Errorf("get process uids: %v", err)
return false
}
for _, id := range uids {
log.Debugf("checking process uid: %d", id)
if int(id) == currentUserID {
return true
}
}
return false
}
24 changes: 24 additions & 0 deletions client/ui/process_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package main

import (
"os/user"

"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)

func isProcessOwnedByCurrentUser(p *process.Process) bool {
processUsername, err := p.Username()
if err != nil {
log.Errorf("get process username error: %v", err)
return false
}

currUser, err := user.Current()
if err != nil {
log.Errorf("get current user error: %v", err)
return false
}

return processUsername == currUser.Username
}
Loading