Skip to content

Commit

Permalink
Consistently apply Unix socket settings (#277)
Browse files Browse the repository at this point in the history
Previously, we only supported setting the group for the server-side
socket. This change makes it possible to set it on the client side as
well. Also fixes a bug where the gRPC broker on the server side would
previously not consume the directory/group environment variables.
  • Loading branch information
tomhjp authored Sep 5, 2023
1 parent c1fefa8 commit b8dba49
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 40 deletions.
50 changes: 45 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ type Client struct {
// forcefully killed.
processKilled bool

hostSocketDir string
unixSocketCfg UnixSocketConfig
}

// NegotiatedVersion returns the protocol version negotiated with the server.
Expand Down Expand Up @@ -240,6 +240,28 @@ type ClientConfig struct {
// SkipHostEnv allows plugins to run without inheriting the parent process'
// environment variables.
SkipHostEnv bool

// UnixSocketConfig configures additional options for any Unix sockets
// that are created. Not normally required. Not supported on Windows.
UnixSocketConfig *UnixSocketConfig
}

type UnixSocketConfig struct {
// If set, go-plugin will change the owner of any Unix sockets created to
// this group, and set them as group-writable. Can be a name or gid. The
// client process must be a member of this group or chown will fail.
Group string

// The directory to create Unix sockets in. Internally managed by go-plugin
// and deleted when the plugin is killed.
directory string
}

func unixSocketConfigFromEnv() UnixSocketConfig {
return UnixSocketConfig{
Group: os.Getenv(EnvUnixSocketGroup),
directory: os.Getenv(EnvUnixSocketDir),
}
}

// ReattachConfig is used to configure a client to reattach to an
Expand Down Expand Up @@ -445,7 +467,7 @@ func (c *Client) Kill() {
c.l.Lock()
runner := c.runner
addr := c.address
hostSocketDir := c.hostSocketDir
hostSocketDir := c.unixSocketCfg.directory
c.l.Unlock()

// If there is no runner or ID, there is nothing to kill.
Expand Down Expand Up @@ -629,15 +651,33 @@ func (c *Client) Start() (addr net.Addr, err error) {
}
}

if c.config.UnixSocketConfig != nil {
c.unixSocketCfg.Group = c.config.UnixSocketConfig.Group
}

if c.unixSocketCfg.Group != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketGroup, c.unixSocketCfg.Group))
}

var runner runner.Runner
switch {
case c.config.RunnerFunc != nil:
c.hostSocketDir, err = os.MkdirTemp("", "")
c.unixSocketCfg.directory, err = os.MkdirTemp("", "plugin-dir")
if err != nil {
return nil, err
}
c.logger.Trace("created temporary directory for unix sockets", "dir", c.hostSocketDir)
runner, err = c.config.RunnerFunc(c.logger, cmd, c.hostSocketDir)
// os.MkdirTemp creates folders with 0o700, so if we have a group
// configured we need to make it group-writable.
if c.unixSocketCfg.Group != "" {
err = setGroupWritable(c.unixSocketCfg.directory, c.unixSocketCfg.Group, 0o770)
if err != nil {
return nil, err
}
}
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketDir, c.unixSocketCfg.directory))
c.logger.Trace("created temporary directory for unix sockets", "dir", c.unixSocketCfg.directory)

runner, err = c.config.RunnerFunc(c.logger, cmd, c.unixSocketCfg.directory)
if err != nil {
return nil, err
}
Expand Down
97 changes: 97 additions & 0 deletions client_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build !windows
// +build !windows

package plugin

import (
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"syscall"
"testing"

"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-plugin/internal/cmdrunner"
"github.com/hashicorp/go-plugin/runner"
)

func TestSetGroup(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("go-plugin doesn't support unix sockets on Windows")
}

group, err := user.LookupGroupId(fmt.Sprintf("%d", os.Getgid()))
if err != nil {
t.Fatal(err)
}
for name, tc := range map[string]struct {
group string
}{
"as integer": {fmt.Sprintf("%d", os.Getgid())},
"as name": {group.Name},
} {
t.Run(name, func(t *testing.T) {
process := helperProcess("mock")
c := NewClient(&ClientConfig{
HandshakeConfig: testHandshake,
Plugins: testPluginMap,
UnixSocketConfig: &UnixSocketConfig{
Group: tc.group,
},
RunnerFunc: func(l hclog.Logger, cmd *exec.Cmd, tmpDir string) (runner.Runner, error) {
// Run tests inside the RunnerFunc to ensure we don't race
// with the code that deletes tmpDir when the client fails
// to start properly.

// Test that it creates a directory with the proper owners and permissions.
info, err := os.Lstat(tmpDir)
if err != nil {
t.Fatal(err)
}
if info.Mode()&os.ModePerm != 0o770 {
t.Fatal(info.Mode())
}
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
t.Fatal()
}
if stat.Gid != uint32(os.Getgid()) {
t.Fatalf("Expected %d, but got %d", os.Getgid(), stat.Gid)
}

// Check the correct environment variables were set to forward
// Unix socket config onto the plugin.
var foundUnixSocketDir, foundUnixSocketGroup bool
for _, env := range cmd.Env {
if env == fmt.Sprintf("%s=%s", EnvUnixSocketDir, tmpDir) {
foundUnixSocketDir = true
}
if env == fmt.Sprintf("%s=%s", EnvUnixSocketGroup, tc.group) {
foundUnixSocketGroup = true
}
}
if !foundUnixSocketDir {
t.Errorf("Did not find correct %s env in %v", EnvUnixSocketDir, cmd.Env)
}
if !foundUnixSocketGroup {
t.Errorf("Did not find correct %s env in %v", EnvUnixSocketGroup, cmd.Env)
}

process.Env = append(process.Env, cmd.Env...)
return cmdrunner.NewCmdRunner(l, process)
},
})
defer c.Kill()

_, err := c.Start()
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}
})
}
}
8 changes: 4 additions & 4 deletions grpc_broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ type GRPCBroker struct {
doneCh chan struct{}
o sync.Once

socketDir string
unixSocketCfg UnixSocketConfig
addrTranslator runner.AddrTranslator

sync.Mutex
Expand All @@ -279,14 +279,14 @@ type gRPCBrokerPending struct {
doneCh chan struct{}
}

func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator runner.AddrTranslator) *GRPCBroker {
func newGRPCBroker(s streamer, tls *tls.Config, unixSocketCfg UnixSocketConfig, addrTranslator runner.AddrTranslator) *GRPCBroker {
return &GRPCBroker{
streamer: s,
streams: make(map[uint32]*gRPCBrokerPending),
tls: tls,
doneCh: make(chan struct{}),

socketDir: socketDir,
unixSocketCfg: unixSocketCfg,
addrTranslator: addrTranslator,
}
}
Expand All @@ -295,7 +295,7 @@ func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator
//
// This should not be called multiple times with the same ID at one time.
func (b *GRPCBroker) Accept(id uint32) (net.Listener, error) {
listener, err := serverListener(b.socketDir)
listener, err := serverListener(b.unixSocketCfg)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func newGRPCClient(doneCtx context.Context, c *Client) (*GRPCClient, error) {

// Start the broker.
brokerGRPCClient := newGRPCBrokerClient(conn)
broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.hostSocketDir, c.runner)
broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.unixSocketCfg, c.runner)
go broker.Run()
go brokerGRPCClient.StartStream()

Expand Down
2 changes: 1 addition & 1 deletion grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (s *GRPCServer) Init() error {
// Register the broker service
brokerServer := newGRPCBrokerServer()
plugin.RegisterGRPCBrokerServer(s.server, brokerServer)
s.broker = newGRPCBroker(brokerServer, s.TLS, "", nil)
s.broker = newGRPCBroker(brokerServer, s.TLS, unixSocketConfigFromEnv(), nil)
go s.broker.Run()

// Register the controller
Expand Down
57 changes: 33 additions & 24 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func Serve(opts *ServeConfig) {
}

// Register a listener so we can accept a connection
listener, err := serverListener(os.Getenv(EnvUnixSocketDir))
listener, err := serverListener(unixSocketConfigFromEnv())
if err != nil {
logger.Error("plugin init error", "error", err)
return
Expand Down Expand Up @@ -496,12 +496,12 @@ func Serve(opts *ServeConfig) {
}
}

func serverListener(dir string) (net.Listener, error) {
func serverListener(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
if runtime.GOOS == "windows" {
return serverListener_tcp()
}

return serverListener_unix(dir)
return serverListener_unix(unixSocketCfg)
}

func serverListener_tcp() (net.Listener, error) {
Expand Down Expand Up @@ -546,8 +546,8 @@ func serverListener_tcp() (net.Listener, error) {
return nil, errors.New("Couldn't bind plugin TCP listener")
}

func serverListener_unix(dir string) (net.Listener, error) {
tf, err := os.CreateTemp(dir, "plugin")
func serverListener_unix(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
tf, err := os.CreateTemp(unixSocketCfg.directory, "plugin")
if err != nil {
return nil, err
}
Expand All @@ -569,25 +569,8 @@ func serverListener_unix(dir string) (net.Listener, error) {

// By default, unix sockets are only writable by the owner. Set up a custom
// group owner and group write permissions if configured.
if groupString := os.Getenv(EnvUnixSocketGroup); groupString != "" {
groupID, err := strconv.Atoi(groupString)
if err != nil {
group, err := user.LookupGroup(groupString)
if err != nil {
return nil, fmt.Errorf("failed to find group ID from %s=%s environment variable: %w", EnvUnixSocketGroup, groupString, err)
}
groupID, err = strconv.Atoi(group.Gid)
if err != nil {
return nil, fmt.Errorf("failed to parse %q group's Gid as an integer: %w", groupString, err)
}
}

err = os.Chown(path, -1, groupID)
if err != nil {
return nil, err
}

err = os.Chmod(path, 0o660)
if unixSocketCfg.Group != "" {
err = setGroupWritable(path, unixSocketCfg.Group, 0o660)
if err != nil {
return nil, err
}
Expand All @@ -601,6 +584,32 @@ func serverListener_unix(dir string) (net.Listener, error) {
}, nil
}

func setGroupWritable(path, groupString string, mode os.FileMode) error {
groupID, err := strconv.Atoi(groupString)
if err != nil {
group, err := user.LookupGroup(groupString)
if err != nil {
return fmt.Errorf("failed to find gid from %q: %w", groupString, err)
}
groupID, err = strconv.Atoi(group.Gid)
if err != nil {
return fmt.Errorf("failed to parse %q group's gid as an integer: %w", groupString, err)
}
}

err = os.Chown(path, -1, groupID)
if err != nil {
return err
}

err = os.Chmod(path, mode)
if err != nil {
return err
}

return nil
}

// rmListener is an implementation of net.Listener that forwards most
// calls to the listener but also removes a file as part of the close. We
// use this to cleanup the unix domain socket on close.
Expand Down
6 changes: 2 additions & 4 deletions server_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ func TestUnixSocketGroupPermissions(t *testing.T) {
t.Fatal(err)
}
for name, tc := range map[string]struct {
gid string
group string
}{
"as integer": {fmt.Sprintf("%d", os.Getgid())},
"as name": {group.Name},
} {
t.Run(name, func(t *testing.T) {
t.Setenv(EnvUnixSocketGroup, tc.gid)

ln, err := serverListener_unix("")
ln, err := serverListener_unix(UnixSocketConfig{Group: tc.group})
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func TestPluginGRPCConn(t testing.T, ps map[string]Plugin) (*GRPCClient, *GRPCSe
}

brokerGRPCClient := newGRPCBrokerClient(conn)
broker := newGRPCBroker(brokerGRPCClient, nil, "", nil)
broker := newGRPCBroker(brokerGRPCClient, nil, UnixSocketConfig{}, nil)
go broker.Run()
go brokerGRPCClient.StartStream()

Expand Down

0 comments on commit b8dba49

Please sign in to comment.