Skip to content

Commit

Permalink
nvidia: support disabling the nvidia plugin (#8353)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mahmood Ali authored Jul 21, 2020
1 parent ae76263 commit 60dd7ae
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 13 deletions.
6 changes: 5 additions & 1 deletion client/devicemanager/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,11 @@ START:

// Start fingerprinting
fingerprintCh, err := devicePlugin.Fingerprint(i.ctx)
if err != nil {
if err == device.ErrPluginDisabled {
i.logger.Info("fingerprinting failed: plugin is not enabled")
i.handleFingerprintError()
return
} else if err != nil {
i.logger.Error("fingerprinting failed", "error", err)
i.handleFingerprintError()
return
Expand Down
24 changes: 22 additions & 2 deletions devices/gpu/nvidia/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ const (
// notAvailable value is returned to nomad server in case some properties were
// undetected by nvml driver
notAvailable = "N/A"
)

const (
// Nvidia-container-runtime environment variable names
NvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES"
)
Expand Down Expand Up @@ -59,6 +57,10 @@ var (

// configSpec is the specification of the plugin's configuration
configSpec = hclspec.NewObject(map[string]*hclspec.Spec{
"enabled": hclspec.NewDefault(
hclspec.NewAttr("enabled", "bool", false),
hclspec.NewLiteral("true"),
),
"ignored_gpu_ids": hclspec.NewDefault(
hclspec.NewAttr("ignored_gpu_ids", "list(string)", false),
hclspec.NewLiteral("[]"),
Expand All @@ -72,12 +74,16 @@ var (

// Config contains configuration information for the plugin.
type Config struct {
Enabled bool `codec:"enabled"`
IgnoredGPUIDs []string `codec:"ignored_gpu_ids"`
FingerprintPeriod string `codec:"fingerprint_period"`
}

// NvidiaDevice contains all plugin specific data
type NvidiaDevice struct {
// enabled indicates whether the plugin should be enabled
enabled bool

// nvmlClient is used to get data from nvidia
nvmlClient nvml.NvmlClient

Expand Down Expand Up @@ -133,6 +139,8 @@ func (d *NvidiaDevice) SetConfig(cfg *base.Config) error {
}
}

d.enabled = config.Enabled

for _, ignoredGPUId := range config.IgnoredGPUIDs {
d.ignoredGPUIDs[ignoredGPUId] = struct{}{}
}
Expand All @@ -149,6 +157,10 @@ func (d *NvidiaDevice) SetConfig(cfg *base.Config) error {
// Fingerprint streams detected devices. If device changes are detected or the
// devices health changes, messages will be emitted.
func (d *NvidiaDevice) Fingerprint(ctx context.Context) (<-chan *device.FingerprintResponse, error) {
if !d.enabled {
return nil, device.ErrPluginDisabled
}

outCh := make(chan *device.FingerprintResponse)
go d.fingerprint(ctx, outCh)
return outCh, nil
Expand All @@ -169,6 +181,10 @@ func (d *NvidiaDevice) Reserve(deviceIDs []string) (*device.ContainerReservation
if len(deviceIDs) == 0 {
return &device.ContainerReservation{}, nil
}
if !d.enabled {
return nil, device.ErrPluginDisabled
}

// Due to the asynchronous nature of NvidiaPlugin, there is a possibility
// of race condition
//
Expand Down Expand Up @@ -202,6 +218,10 @@ func (d *NvidiaDevice) Reserve(deviceIDs []string) (*device.ContainerReservation

// Stats streams statistics for the detected devices.
func (d *NvidiaDevice) Stats(ctx context.Context, interval time.Duration) (<-chan *device.StatsResponse, error) {
if !d.enabled {
return nil, device.ErrPluginDisabled
}

outCh := make(chan *device.StatsResponse)
go d.stats(ctx, outCh, interval)
return outCh, nil
Expand Down
46 changes: 36 additions & 10 deletions devices/gpu/nvidia/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (c *MockNvmlClient) GetStatsData() ([]*nvml.StatsData, error) {
}

func TestReserve(t *testing.T) {
for _, testCase := range []struct {
cases := []struct {
Name string
ExpectedReservation *device.ContainerReservation
ExpectedError error
Expand All @@ -47,7 +47,8 @@ func TestReserve(t *testing.T) {
"UUID3",
},
Device: &NvidiaDevice{
logger: hclog.NewNullLogger(),
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Expand All @@ -66,7 +67,8 @@ func TestReserve(t *testing.T) {
devices: map[string]struct{}{
"UUID3": {},
},
logger: hclog.NewNullLogger(),
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Expand All @@ -88,7 +90,8 @@ func TestReserve(t *testing.T) {
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Expand All @@ -102,13 +105,36 @@ func TestReserve(t *testing.T) {
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
logger: hclog.NewNullLogger(),
enabled: true,
},
},
} {
actualReservation, actualError := testCase.Device.Reserve(testCase.RequestedIDs)
req := require.New(t)
req.Equal(testCase.ExpectedReservation, actualReservation)
req.Equal(testCase.ExpectedError, actualError)
{
Name: "Device is disabled",
ExpectedReservation: nil,
ExpectedError: device.ErrPluginDisabled,
RequestedIDs: []string{
"UUID1",
"UUID2",
"UUID3",
},
Device: &NvidiaDevice{
devices: map[string]struct{}{
"UUID1": {},
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
enabled: false,
},
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
actualReservation, actualError := c.Device.Reserve(c.RequestedIDs)
require.Equal(t, c.ExpectedReservation, actualReservation)
require.Equal(t, c.ExpectedError, actualError)
})
}
}
5 changes: 5 additions & 0 deletions plugins/device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ const (
DeviceTypeGPU = "gpu"
)

var (
// ErrPluginDisabled indicates that the device plugin is disabled
ErrPluginDisabled = fmt.Errorf("device is not enabled")
)

// DevicePlugin is the interface for a plugin that can expose detected devices
// to Nomad and inform it how to mount them.
type DevicePlugin interface {
Expand Down

0 comments on commit 60dd7ae

Please sign in to comment.