diff --git a/devices/gpu/nvidia/device.go b/devices/gpu/nvidia/device.go index 064161cf5c5..c0abf874396 100644 --- a/devices/gpu/nvidia/device.go +++ b/devices/gpu/nvidia/device.go @@ -28,9 +28,7 @@ const ( // notAvailable value is returned to nomad server in case some properties were // undetected by nvml driver notAvailable = "N/A" -) -const ( // Nvidia-container-runtime environment variable names NvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES" ) @@ -59,6 +57,10 @@ var ( // configSpec is the specification of the plugin's configuration configSpec = hclspec.NewObject(map[string]*hclspec.Spec{ + "enabled": hclspec.NewDefault( + hclspec.NewAttr("enabled", "bool", false), + hclspec.NewLiteral("true"), + ), "ignored_gpu_ids": hclspec.NewDefault( hclspec.NewAttr("ignored_gpu_ids", "list(string)", false), hclspec.NewLiteral("[]"), @@ -68,16 +70,22 @@ var ( hclspec.NewLiteral("\"1m\""), ), }) + + errDeviceNotEnabled = fmt.Errorf("Nvidia device is not enabled") ) // Config contains configuration information for the plugin. type Config struct { + Enabled bool `codec:"enabled"` IgnoredGPUIDs []string `codec:"ignored_gpu_ids"` FingerprintPeriod string `codec:"fingerprint_period"` } // NvidiaDevice contains all plugin specific data type NvidiaDevice struct { + // enabled indicates whether the plugin should be enabled + enabled bool + // nvmlClient is used to get data from nvidia nvmlClient nvml.NvmlClient @@ -133,6 +141,8 @@ func (d *NvidiaDevice) SetConfig(cfg *base.Config) error { } } + d.enabled = config.Enabled + for _, ignoredGPUId := range config.IgnoredGPUIDs { d.ignoredGPUIDs[ignoredGPUId] = struct{}{} } @@ -149,6 +159,10 @@ func (d *NvidiaDevice) SetConfig(cfg *base.Config) error { // Fingerprint streams detected devices. If device changes are detected or the // devices health changes, messages will be emitted. func (d *NvidiaDevice) Fingerprint(ctx context.Context) (<-chan *device.FingerprintResponse, error) { + if !d.enabled { + return nil, errDeviceNotEnabled + } + outCh := make(chan *device.FingerprintResponse) go d.fingerprint(ctx, outCh) return outCh, nil @@ -169,6 +183,10 @@ func (d *NvidiaDevice) Reserve(deviceIDs []string) (*device.ContainerReservation if len(deviceIDs) == 0 { return &device.ContainerReservation{}, nil } + if !d.enabled { + return nil, errDeviceNotEnabled + } + // Due to the asynchronous nature of NvidiaPlugin, there is a possibility // of race condition // @@ -202,6 +220,10 @@ func (d *NvidiaDevice) Reserve(deviceIDs []string) (*device.ContainerReservation // Stats streams statistics for the detected devices. func (d *NvidiaDevice) Stats(ctx context.Context, interval time.Duration) (<-chan *device.StatsResponse, error) { + if !d.enabled { + return nil, errDeviceNotEnabled + } + outCh := make(chan *device.StatsResponse) go d.stats(ctx, outCh, interval) return outCh, nil diff --git a/devices/gpu/nvidia/device_test.go b/devices/gpu/nvidia/device_test.go index 717491f2b69..0554a0b4205 100644 --- a/devices/gpu/nvidia/device_test.go +++ b/devices/gpu/nvidia/device_test.go @@ -26,7 +26,7 @@ func (c *MockNvmlClient) GetStatsData() ([]*nvml.StatsData, error) { } func TestReserve(t *testing.T) { - for _, testCase := range []struct { + cases := []struct { Name string ExpectedReservation *device.ContainerReservation ExpectedError error @@ -47,7 +47,8 @@ func TestReserve(t *testing.T) { "UUID3", }, Device: &NvidiaDevice{ - logger: hclog.NewNullLogger(), + logger: hclog.NewNullLogger(), + enabled: true, }, }, { @@ -66,7 +67,8 @@ func TestReserve(t *testing.T) { devices: map[string]struct{}{ "UUID3": {}, }, - logger: hclog.NewNullLogger(), + logger: hclog.NewNullLogger(), + enabled: true, }, }, { @@ -88,7 +90,8 @@ func TestReserve(t *testing.T) { "UUID2": {}, "UUID3": {}, }, - logger: hclog.NewNullLogger(), + logger: hclog.NewNullLogger(), + enabled: true, }, }, { @@ -102,13 +105,36 @@ func TestReserve(t *testing.T) { "UUID2": {}, "UUID3": {}, }, - logger: hclog.NewNullLogger(), + logger: hclog.NewNullLogger(), + enabled: true, }, }, - } { - actualReservation, actualError := testCase.Device.Reserve(testCase.RequestedIDs) - req := require.New(t) - req.Equal(testCase.ExpectedReservation, actualReservation) - req.Equal(testCase.ExpectedError, actualError) + { + Name: "Device is disabled", + ExpectedReservation: nil, + ExpectedError: errDeviceNotEnabled, + RequestedIDs: []string{ + "UUID1", + "UUID2", + "UUID3", + }, + Device: &NvidiaDevice{ + devices: map[string]struct{}{ + "UUID1": {}, + "UUID2": {}, + "UUID3": {}, + }, + logger: hclog.NewNullLogger(), + enabled: false, + }, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + actualReservation, actualError := c.Device.Reserve(c.RequestedIDs) + require.Equal(t, c.ExpectedReservation, actualReservation) + require.Equal(t, c.ExpectedError, actualError) + }) } }