-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfilter.go
131 lines (107 loc) · 2.94 KB
/
filter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package device_filter
import (
"bufio"
"fmt"
"os"
"github.com/taenzeyang/gpu-device-filter/nvml"
"github.com/taenzeyang/gpu-device-filter/util"
log "github.com/sirupsen/logrus"
"sigs.k8s.io/yaml"
)
type (
	// Flags carries the options that decide which configuration is loaded
	// and which named entry inside it is used.
	Flags struct {
		// ConfigFile is the path of the YAML configuration to read.
		// The special value "-" makes ParseConfigFile read from stdin.
		ConfigFile string
		// SelectedConfig is the key looked up in Spec.MigConfigs by
		// GetSelectedMigConfig. It may be filled in automatically when it is
		// empty and only one config exists.
		SelectedConfig string
	}

	// DeviceInfo describes a single GPU that matched the selected config.
	DeviceInfo struct {
		// Index is the GPU's position in the list returned by
		// util.GetGPUDeviceIDs.
		Index int
		// DeviceId is the String() form of the GPU's device ID.
		DeviceId string
	}
)
// ParseConfigFile reads the YAML configuration named by f.ConfigFile and
// unmarshals it into a Spec. When f.ConfigFile is "-", the configuration is
// read from standard input instead of from a file.
//
// It returns an error (wrapping the underlying cause) if the input cannot be
// read completely or cannot be unmarshalled into a Spec.
func ParseConfigFile(f *Flags) (*Spec, error) {
	var configYaml []byte
	if f.ConfigFile == "-" {
		scanner := bufio.NewScanner(os.Stdin)
		// bufio.Scanner rejects any line longer than 64 KiB by default and then
		// stops scanning; raise the ceiling so long YAML lines are accepted.
		scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
		for scanner.Scan() {
			configYaml = append(configYaml, scanner.Bytes()...)
			configYaml = append(configYaml, '\n')
		}
		// Scan returns false both at EOF and on failure; only Err tells them
		// apart. Without this check a failed read yields a truncated config.
		if err := scanner.Err(); err != nil {
			return nil, fmt.Errorf("read error: %w", err)
		}
	} else {
		var err error
		configYaml, err = os.ReadFile(f.ConfigFile)
		if err != nil {
			return nil, fmt.Errorf("read error: %w", err)
		}
	}

	var spec Spec
	if err := yaml.Unmarshal(configYaml, &spec); err != nil {
		return nil, fmt.Errorf("unmarshal error: %w", err)
	}
	return &spec, nil
}
// GetSelectedMigConfig returns the entry of spec.MigConfigs whose key is
// f.SelectedConfig.
//
// If f.SelectedConfig is empty:
//   - and spec holds exactly one config, that config is chosen and, as a side
//     effect, f.SelectedConfig is set to its key;
//   - and spec holds more than one config, an error is returned because the
//     choice is ambiguous.
//
// An error is also returned when f or spec is nil, or when the selected key
// is not present in spec.MigConfigs.
func GetSelectedMigConfig(f *Flags, spec *Spec) (MigConfigSpecSlice, error) {
	// Callers such as Apply may hand over a nil spec after a failed parse;
	// report it instead of dereferencing nil.
	if f == nil {
		return nil, fmt.Errorf("no flags provided")
	}
	if spec == nil {
		return nil, fmt.Errorf("no config spec provided")
	}

	if f.SelectedConfig == "" {
		switch len(spec.MigConfigs) {
		case 0:
			// Nothing to default to; the lookup below reports the miss.
		case 1:
			for name := range spec.MigConfigs {
				f.SelectedConfig = name
			}
		default:
			return nil, fmt.Errorf("missing required flag 'selected-config' when more than one config available")
		}
	}

	migConfig, exists := spec.MigConfigs[f.SelectedConfig]
	if !exists {
		return nil, fmt.Errorf("selected mig-config not present: %v", f.SelectedConfig)
	}
	return migConfig, nil
}
// DeviceFilter initializes NVML, enumerates the GPUs on this node via
// util.GetGPUDeviceIDs, and returns the GPUs selected by migConfig.
//
// For each entry in migConfig, a GPU is selected when the entry's
// MatchesDeviceFilter accepts its device ID and MatchesDevices accepts its
// index. Results are ordered by config entry, then by GPU index. A GPU matched
// by several entries is reported once per matching entry.
//
// The returned slice is never nil on success (it is empty when nothing
// matches). Errors from NVML initialization or device enumeration are wrapped
// with %w so callers can inspect them with errors.Is / errors.As.
func DeviceFilter(migConfig MigConfigSpecSlice) ([]DeviceInfo, error) {
	n := nvml.New()
	if err := util.NvmlInit(n); err != nil {
		return nil, fmt.Errorf("error initializing NVML: %w", err)
	}
	defer util.TryNvmlShutdown(n)

	deviceIDs, err := util.GetGPUDeviceIDs()
	if err != nil {
		return nil, fmt.Errorf("Error enumerating GPU device IDs: %w", err)
	}

	deviceInfos := make([]DeviceInfo, 0)
	for _, mc := range migConfig {
		if mc.DeviceFilter == nil {
			log.Infof("Walking Config for (devices=%v)", mc.Devices)
		} else {
			log.Infof("Walking Config for (device-filter=%v, devices=%v)", mc.DeviceFilter, mc.Devices)
		}
		for i, deviceID := range deviceIDs {
			// Both the device-ID filter and the index selector must accept
			// the GPU for it to be reported.
			if !mc.MatchesDeviceFilter(deviceID) || !mc.MatchesDevices(i) {
				continue
			}
			deviceInfos = append(deviceInfos, DeviceInfo{
				Index:    i,
				DeviceId: deviceID.String(),
			})
		}
	}
	return deviceInfos, nil
}
// Apply loads the configuration from /var/run/config.yaml, selects the config
// named "test", and returns the GPUs on this node that it selects.
//
// Each stage (parse, select, filter) stops the pipeline on failure and the
// error is returned wrapped with context. A nil error therefore means the
// returned slice comes from a successfully parsed and selected config.
// Previously failures were only logged: the function still returned a nil
// error, and a parse failure caused a nil dereference in the next stage.
func Apply() ([]DeviceInfo, error) {
	f := &Flags{ConfigFile: "/var/run/config.yaml", SelectedConfig: "test"}
	//TODO: Check flags

	log.Infof("Parsing config file...")
	spec, err := ParseConfigFile(f)
	if err != nil {
		// spec is nil here; continuing would dereference it.
		return nil, fmt.Errorf("error parsing config file: %w", err)
	}

	log.Infof("Selecting specific config...")
	migConfig, err := GetSelectedMigConfig(f, spec)
	if err != nil {
		return nil, fmt.Errorf("error selecting config: %w", err)
	}

	log.Infof("Selecting specific devices...")
	deviceInfos, err := DeviceFilter(migConfig)
	if err != nil {
		return nil, fmt.Errorf("error selecting specific devices: %w", err)
	}
	return deviceInfos, nil
}