From 2f1b7ad2b504abc98b06bd24626b4f305e73ad60 Mon Sep 17 00:00:00 2001 From: Yury Kulazhenkov Date: Tue, 9 Jan 2024 17:06:14 +0200 Subject: [PATCH] Add additional methods to host/kernel.go New methods are: * BindDriverByBusAndDevice - binds device to the provided driver * UnbindDriverByBusAndDevice unbind device identified by bus and device ID from the driver Both methods allows to work with devices not only on PCI bus. +refactor driver-related methods +add unit-tests for changed methods Signed-off-by: Yury Kulazhenkov --- pkg/consts/constants.go | 9 +- pkg/helper/host.go | 3 +- pkg/helper/mock/mock_helper.go | 28 ++++ pkg/host/host_test.go | 241 +++++++++++++++++++++++++++++++++ pkg/host/kernel.go | 201 ++++++++++++++++++++------- 5 files changed, 427 insertions(+), 55 deletions(-) create mode 100644 pkg/host/host_test.go diff --git a/pkg/consts/constants.go b/pkg/consts/constants.go index d56d36e7f7..e7255368d7 100644 --- a/pkg/consts/constants.go +++ b/pkg/consts/constants.go @@ -77,13 +77,16 @@ const ( CheckpointFileName = "sno-initial-node-state.json" Unknown = "Unknown" - SysBusPciDevices = "/sys/bus/pci/devices" - SysBusPciDrivers = "/sys/bus/pci/drivers" - SysBusPciDriversProbe = "/sys/bus/pci/drivers_probe" + SysBus = "/sys/bus" + SysBusPciDevices = SysBus + "/pci/devices" + SysBusPciDrivers = SysBus + "/pci/drivers" + SysBusPciDriversProbe = SysBus + "/pci/drivers_probe" SysClassNet = "/sys/class/net" ProcKernelCmdLine = "/proc/cmdline" NetClass = 0x02 NumVfsFile = "sriov_numvfs" + BusPci = "pci" + BusVdpa = "vdpa" UdevFolder = "/etc/udev" UdevRulesFolder = UdevFolder + "/rules.d" diff --git a/pkg/helper/host.go b/pkg/helper/host.go index df33257623..0c9655d279 100644 --- a/pkg/helper/host.go +++ b/pkg/helper/host.go @@ -1,4 +1,5 @@ -package helper + +})package helper import ( "sigs.k8s.io/controller-runtime/pkg/log" diff --git a/pkg/helper/mock/mock_helper.go b/pkg/helper/mock/mock_helper.go index 74814f4e64..75015a925e 100644 --- a/pkg/helper/mock/mock_helper.go +++ b/pkg/helper/mock/mock_helper.go @@ -81,6 +81,20 @@ func (mr *MockHostHelpersInterfaceMockRecorder) BindDpdkDriver(arg0, arg1 interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindDpdkDriver", reflect.TypeOf((*MockHostHelpersInterface)(nil).BindDpdkDriver), arg0, arg1) } +// BindDriverByBusAndDevice mocks base method. +func (m *MockHostHelpersInterface) BindDriverByBusAndDevice(arg0, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BindDriverByBusAndDevice", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// BindDriverByBusAndDevice indicates an expected call of BindDriverByBusAndDevice. +func (mr *MockHostHelpersInterfaceMockRecorder) BindDriverByBusAndDevice(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindDriverByBusAndDevice", reflect.TypeOf((*MockHostHelpersInterface)(nil).BindDriverByBusAndDevice), arg0, arg1, arg2) +} + // Chroot mocks base method. func (m *MockHostHelpersInterface) Chroot(arg0 string) (func() error, error) { m.ctrl.T.Helper() @@ -994,6 +1008,20 @@ func (mr *MockHostHelpersInterfaceMockRecorder) Unbind(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unbind", reflect.TypeOf((*MockHostHelpersInterface)(nil).Unbind), arg0) } +// UnbindDriverByBusAndDevice mocks base method. +func (m *MockHostHelpersInterface) UnbindDriverByBusAndDevice(bus, device string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnbindDriverByBusAndDevice", bus, device) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnbindDriverByBusAndDevice indicates an expected call of UnbindDriverByBusAndDevice. +func (mr *MockHostHelpersInterfaceMockRecorder) UnbindDriverByBusAndDevice(bus, device interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnbindDriverByBusAndDevice", reflect.TypeOf((*MockHostHelpersInterface)(nil).UnbindDriverByBusAndDevice), bus, device) +} + // UnbindDriverIfNeeded mocks base method. func (m *MockHostHelpersInterface) UnbindDriverIfNeeded(arg0 string, arg1 bool) error { m.ctrl.T.Helper() diff --git a/pkg/host/host_test.go b/pkg/host/host_test.go new file mode 100644 index 0000000000..e70559ee0f --- /dev/null +++ b/pkg/host/host_test.go @@ -0,0 +1,241 @@ +package host + +import ( + "os" + "path/filepath" + "testing" + + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/vars" + "github.com/k8snetworkplumbingwg/sriov-network-operator/test/util/fakefilesystem" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "go.uber.org/zap/zapcore" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" +) + +const ( + testUnknownDev = "unknown-dev" + testUnknownDriver = "unknown-driver" + testDev = "0000:d8:00.0" + testDevPath = "/sys/bus/pci/devices/" + testDev + testDevDriverPath = testDevPath + "/driver" + testDevDriverOverridePath = testDevPath + "/driver_override" + testDriversRelPath = "../../../../bus/pci/drivers/" + testDevDriverSymlink = testDriversRelPath + testDriver + testDevDriver2Symlink = testDriversRelPath + testDriver2 + testDriver = "test-driver" + testDriverPath = "/sys/bus/pci/drivers/" + testDriver + testDriverBindPath = testDriverPath + "/bind" + testDriverUnbindPath = testDriverPath + "/unbind" + testDriver2 = "vfio-pci" // dpdk + testDriver2Path = "/sys/bus/pci/drivers/" + testDriver2 + testDriver2BindPath = testDriver2Path + "/bind" + testDriver2UnbindPath = testDriver2Path + "/unbind" + testPciDeviceProbePath = "/sys/bus/pci/drivers_probe" +) + +func TestHostManager(t *testing.T) { + log.SetLogger(zap.New( + zap.WriteTo(GinkgoWriter), + zap.Level(zapcore.Level(-2)), + zap.UseDevMode(true))) + RegisterFailHandler(Fail) + RunSpecs(t, "Config Daemon Suite") +} + +func testFileContent(path, expectedContent string) { + d, err := os.ReadFile(filepath.Join(vars.FilesystemRoot, path)) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + ExpectWithOffset(1, string(d)).To(Equal(expectedContent)) +} + +var _ = Describe("Kernel", func() { + Context("Drivers", func() { + var ( + cleanFakeFs func() + k KernelInterface + ) + configureFS := func(f *fakefilesystem.FS) { + var err error + vars.FilesystemRoot, cleanFakeFs, err = f.Use() + Expect(err).ToNot(HaveOccurred()) + } + BeforeEach(func() { + cleanFakeFs = nil + k = newKernelInterface(nil) + }) + AfterEach(func() { + if cleanFakeFs != nil { + cleanFakeFs() + } + }) + Context("Unbind, UnbindDriverByBusAndDevice", func() { + It("unknown device", func() { + Expect(k.UnbindDriverByBusAndDevice(consts.BusPci, testUnknownDev)).NotTo(HaveOccurred()) + }) + It("known device, no driver", func() { + configureFS(&fakefilesystem.FS{Dirs: []string{testDevPath}}) + Expect(k.Unbind(testDev)).NotTo(HaveOccurred()) + }) + It("has driver, succeed", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriverPath}, + Symlinks: map[string]string{testDevDriverPath: testDevDriverSymlink}, + Files: map[string][]byte{testDriverUnbindPath: {}}, + }) + Expect(k.Unbind(testDev)).NotTo(HaveOccurred()) + // check that echo to unbind path was done + testFileContent(testDriverUnbindPath, testDev) + }) + It("has driver, failed to unbind", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath}, + Symlinks: map[string]string{testDevDriverPath: testDevDriverSymlink}, + }) + Expect(k.Unbind(testDev)).To(HaveOccurred()) + }) + }) + Context("HasDriver", func() { + It("unknown device", func() { + has, driver := k.HasDriver(testUnknownDev) + Expect(has).To(BeFalse()) + Expect(driver).To(BeEmpty()) + }) + It("known device, no driver", func() { + configureFS(&fakefilesystem.FS{Dirs: []string{testDevPath}}) + has, driver := k.HasDriver(testDev) + Expect(has).To(BeFalse()) + Expect(driver).To(BeEmpty()) + }) + It("has driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriverPath}, + Symlinks: map[string]string{testDevDriverPath: testDevDriverSymlink}, + }) + has, driver := k.HasDriver(testDev) + Expect(has).To(BeTrue()) + Expect(driver).To(Equal(testDriver)) + }) + }) + Context("BindDriverByBusAndDevice", func() { + It("unknown device", func() { + Expect(k.BindDriverByBusAndDevice(consts.BusPci, testUnknownDev, testDriver)).To(HaveOccurred()) + }) + It("bind to unknown driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath}, + }) + Expect(k.BindDriverByBusAndDevice(consts.BusPci, testDev, testUnknownDriver)).To(HaveOccurred()) + }) + It("already has required driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriverPath}, + Symlinks: map[string]string{testDevDriverPath: testDevDriverSymlink}, + Files: map[string][]byte{testDriverBindPath: {}, testDriverUnbindPath: {}}, + }) + Expect(k.BindDriverByBusAndDevice(consts.BusPci, testDev, testDriver)).NotTo(HaveOccurred()) + // check that echo to bind/unbind path was not executed + testFileContent(testDriverBindPath, "") + testFileContent(testDriverUnbindPath, "") + }) + It("no driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriverPath}, + Files: map[string][]byte{testDriverBindPath: {}, testDriverUnbindPath: {}}, + }) + Expect(k.BindDriverByBusAndDevice(consts.BusPci, testDev, testDriver)).NotTo(HaveOccurred()) + // check that echo to bind/unbind path was not executed + testFileContent(testDriverBindPath, testDev) + testFileContent(testDriverUnbindPath, "") + }) + It("wrong driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriverPath, testDriver2Path}, + Symlinks: map[string]string{testDevDriverPath: testDevDriverSymlink}, + Files: map[string][]byte{testDriverUnbindPath: {}, testDriver2BindPath: {}}, + }) + Expect(k.BindDriverByBusAndDevice(consts.BusPci, testDev, testDriver2)).NotTo(HaveOccurred()) + // should unbind from driver1 + testFileContent(testDriverUnbindPath, testDev) + // should bind to driver2 + testFileContent(testDriver2BindPath, testDev) + }) + }) + Context("BindDefaultDriver", func() { + It("unknown device", func() { + Expect(k.BindDefaultDriver(testUnknownDev)).To(HaveOccurred()) + }) + It("no driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath}, + Files: map[string][]byte{testPciDeviceProbePath: {}, testDevDriverOverridePath: {}}, + }) + Expect(k.BindDefaultDriver(testDev)).NotTo(HaveOccurred()) + // should probe driver for dev + testFileContent(testPciDeviceProbePath, testDev) + }) + It("already bind to default driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath}, + Symlinks: map[string]string{testDevDriverPath: testDevDriverSymlink}, + }) + Expect(k.BindDefaultDriver(testDev)).NotTo(HaveOccurred()) + }) + It("bind to dpdk driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriver2Path}, + Symlinks: map[string]string{testDevDriverPath: testDevDriver2Symlink}, + Files: map[string][]byte{testPciDeviceProbePath: {}, testDriver2UnbindPath: {}}, + }) + Expect(k.BindDefaultDriver(testDev)).NotTo(HaveOccurred()) + // should unbind from dpdk driver + testFileContent(testDriver2UnbindPath, testDev) + // should probe driver for dev + testFileContent(testPciDeviceProbePath, testDev) + }) + }) + Context("BindDpdkDriver", func() { + It("unknown device", func() { + Expect(k.BindDpdkDriver(testUnknownDev, testDriver2)).To(HaveOccurred()) + }) + It("no driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriver2Path}, + Files: map[string][]byte{testDevDriverOverridePath: {}}, + }) + Expect(k.BindDpdkDriver(testDev, testDriver2)).NotTo(HaveOccurred()) + // should reset driver override + testFileContent(testDevDriverOverridePath, "") + }) + It("already bind to required driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath}, + Symlinks: map[string]string{testDevDriverPath: testDevDriver2Symlink}, + }) + Expect(k.BindDpdkDriver(testDev, testDriver2)).NotTo(HaveOccurred()) + }) + It("bind to wrong driver", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriverPath, testDriver2Path}, + Symlinks: map[string]string{testDevDriverPath: testDevDriverSymlink}, + Files: map[string][]byte{testDriverUnbindPath: {}, testDriver2BindPath: {}, testDevDriverOverridePath: {}}, + }) + Expect(k.BindDpdkDriver(testDev, testDriver2)).NotTo(HaveOccurred()) + // should unbind from driver1 + testFileContent(testDriverUnbindPath, testDev) + // should bind to driver2 + testFileContent(testDriver2BindPath, testDev) + }) + It("fail to bind", func() { + configureFS(&fakefilesystem.FS{ + Dirs: []string{testDevPath, testDriverPath,}, + Symlinks: map[string]string{testDevDriverPath: testDevDriverSymlink}, + Files: map[string][]byte{testDriverUnbindPath: {}, testDevDriverOverridePath: {}}, + }) + Expect(k.BindDpdkDriver(testDev, testDriver2)).To(HaveOccurred()) + }) + }) + }) +}) diff --git a/pkg/host/kernel.go b/pkg/host/kernel.go index ac7277eeff..bf6ee23b73 100644 --- a/pkg/host/kernel.go +++ b/pkg/host/kernel.go @@ -1,12 +1,12 @@ package host import ( + "errors" "fmt" "os" "path/filepath" "strings" - dputils "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils" "sigs.k8s.io/controller-runtime/pkg/log" sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" @@ -36,12 +36,21 @@ type KernelInterface interface { BindDpdkDriver(string, string) error // BindDefaultDriver binds the virtual function to is default driver BindDefaultDriver(string) error + // BindDriverByBusAndDevice binds device to the provided driver + // bus - the bus path in the sysfs, e.g. "pci" or "vdpa" + // device - the name of the device on the bus, e.g. 0000:85:1e.5 for PCI or vpda1 for VDPA + // driver - the name of the driver, e.g. vfio-pci or vhost_vdpa. + BindDriverByBusAndDevice(string, string, string) error // HasDriver returns try if the virtual function is bind to a driver HasDriver(string) (bool, string) // RebindVfToDefaultDriver rebinds the virtual function to is default driver RebindVfToDefaultDriver(string) error // UnbindDriverIfNeeded unbinds the virtual function from a driver if needed UnbindDriverIfNeeded(string, bool) error + // UnbindDriverByBusAndDevice unbind device identified by bus and device ID from the driver + // bus - the bus path in the sysfs, e.g. "pci" or "vdpa" + // device - the name of the device on the bus, e.g. 0000:85:1e.5 for PCI or vpda1 for VDPA + UnbindDriverByBusAndDevice(bus, device string) error // LoadKernelModule loads a kernel module to the host LoadKernelModule(name string, args ...string) error // IsKernelModuleLoaded returns try if the requested kernel module is loaded @@ -165,18 +174,7 @@ func (k *kernel) IsKernelArgsSet(cmdLine string, karg string) bool { // Unbind unbind driver for one device func (k *kernel) Unbind(pciAddr string) error { log.Log.V(2).Info("Unbind(): unbind device driver for device", "device", pciAddr) - yes, driver := k.HasDriver(pciAddr) - if !yes { - return nil - } - - filePath := filepath.Join(vars.FilesystemRoot, consts.SysBusPciDrivers, driver, "unbind") - err := os.WriteFile(filePath, []byte(pciAddr), os.ModeAppend) - if err != nil { - log.Log.Error(err, "Unbind(): fail to unbind driver for device", "device", pciAddr) - return err - } - return nil + return k.UnbindDriverByBusAndDevice(consts.BusPci, pciAddr) } // BindDpdkDriver bind dpdk driver for one device @@ -185,44 +183,36 @@ func (k *kernel) BindDpdkDriver(pciAddr, driver string) error { log.Log.V(2).Info("BindDpdkDriver(): bind device to driver", "device", pciAddr, "driver", driver) - if yes, d := k.HasDriver(pciAddr); yes { - if driver == d { + curDriver, err := getDriverByBusAndDevice(consts.BusPci, pciAddr) + if err != nil { + return err + } + if curDriver != "" { + if curDriver == driver { log.Log.V(2).Info("BindDpdkDriver(): device already bound to driver", "device", pciAddr, "driver", driver) return nil } - if err := k.Unbind(pciAddr); err != nil { + if err := k.UnbindDriverByBusAndDevice(consts.BusPci, pciAddr); err != nil { return err } } - - driverOverridePath := filepath.Join(vars.FilesystemRoot, consts.SysBusPciDevices, pciAddr, "driver_override") - err := os.WriteFile(driverOverridePath, []byte(driver), os.ModeAppend) - if err != nil { - log.Log.Error(err, "BindDpdkDriver(): fail to write driver_override for device", - "device", pciAddr, "driver", driver) + if err := setDriverOverride(consts.BusPci, pciAddr, driver); err != nil { return err } - bindPath := filepath.Join(vars.FilesystemRoot, consts.SysBusPciDrivers, driver, "bind") - err = os.WriteFile(bindPath, []byte(pciAddr), os.ModeAppend) - if err != nil { - log.Log.Error(err, "BindDpdkDriver(): fail to bind driver for device", - "driver", driver, "device", pciAddr) - _, err := os.Readlink(filepath.Join(vars.FilesystemRoot, consts.SysBusPciDevices, pciAddr, "iommu_group")) - if err != nil { + if err := bindDriver(consts.BusPci, pciAddr, driver); err != nil { + _, innerErr := os.Readlink(filepath.Join(vars.FilesystemRoot, consts.SysBusPciDevices, pciAddr, "iommu_group")) + if innerErr != nil { log.Log.Error(err, "Could not read IOMMU group for device", "device", pciAddr) return fmt.Errorf( - "cannot bind driver %s to device %s, make sure IOMMU is enabled in BIOS. %w", driver, pciAddr, err) + "cannot bind driver %s to device %s, make sure IOMMU is enabled in BIOS. %w", driver, pciAddr, innerErr) } return err } - err = os.WriteFile(driverOverridePath, []byte(""), os.ModeAppend) - if err != nil { - log.Log.Error(err, "BindDpdkDriver(): failed to clear driver_override for device", "device", pciAddr) + if err := setDriverOverride(consts.BusPci, pciAddr, ""); err != nil { return err } - return nil } @@ -231,32 +221,52 @@ func (k *kernel) BindDpdkDriver(pciAddr, driver string) error { func (k *kernel) BindDefaultDriver(pciAddr string) error { log.Log.V(2).Info("BindDefaultDriver(): bind device to default driver", "device", pciAddr) - if yes, d := k.HasDriver(pciAddr); yes { - if !sriovnetworkv1.StringInArray(d, vars.DpdkDrivers) { + curDriver, err := getDriverByBusAndDevice(consts.BusPci, pciAddr) + if err != nil { + return err + } + if curDriver != "" { + if !sriovnetworkv1.StringInArray(curDriver, vars.DpdkDrivers) { log.Log.V(2).Info("BindDefaultDriver(): device already bound to default driver", - "device", pciAddr, "driver", d) + "device", pciAddr, "driver", curDriver) return nil } - if err := k.Unbind(pciAddr); err != nil { + if err := k.UnbindDriverByBusAndDevice(consts.BusPci, pciAddr); err != nil { return err } } - - driverOverridePath := filepath.Join(vars.FilesystemRoot, consts.SysBusPciDevices, pciAddr, "driver_override") - err := os.WriteFile(driverOverridePath, []byte("\x00"), os.ModeAppend) - if err != nil { - log.Log.Error(err, "BindDefaultDriver(): failed to write driver_override for device", "device", pciAddr) + if err := setDriverOverride(consts.BusPci, pciAddr, ""); err != nil { + return err + } + if err := probeDriver(consts.BusPci, pciAddr); err != nil { return err } + return nil +} - pciDriversProbe := filepath.Join(vars.FilesystemRoot, consts.SysBusPciDriversProbe) - err = os.WriteFile(pciDriversProbe, []byte(pciAddr), os.ModeAppend) +// BindDriverByBusAndDevice binds device to the provided driver +// bus - the bus path in the sysfs, e.g. "pci" or "vdpa" +// device - the name of the device on the bus, e.g. 0000:85:1e.5 for PCI or vpda1 for VDPA +// driver - the name of the driver, e.g. vfio-pci or vhost_vdpa. +func (k *kernel) BindDriverByBusAndDevice(bus, device, driver string) error { + log.Log.V(2).Info("BindDriverByBusAndDevice(): bind device to driver", + "bus", bus, "device", device, "driver", driver) + + curDriver, err := getDriverByBusAndDevice(bus, device) if err != nil { - log.Log.Error(err, "BindDefaultDriver(): failed to bind driver for device", "device", pciAddr) return err } - - return nil + if curDriver != "" { + if curDriver == driver { + log.Log.V(2).Info("BindDriverByBusAndDevice(): device already bound to driver", + "bus", bus, "device", device, "driver", driver) + return nil + } + if err := k.UnbindDriverByBusAndDevice(bus, device); err != nil { + return err + } + } + return bindDriver(bus, device, driver) } // Workaround function to handle a case where the vf default driver is stuck and not able to create the vf kernel interface. @@ -287,14 +297,33 @@ func (k *kernel) UnbindDriverIfNeeded(vfAddr string, isRdma bool) error { return nil } +// UnbindDriverByBusAndDevice unbind device identified by bus and device ID from the driver +// bus - the bus path in the sysfs, e.g. "pci" or "vdpa" +// device - the name of the device on the bus, e.g. 0000:85:1e.5 for PCI or vpda1 for VDPA +func (k *kernel) UnbindDriverByBusAndDevice(bus, device string) error { + log.Log.V(2).Info("UnbindDriverByBusAndDevice(): unbind device driver for device", "bus", bus, "device", device) + driver, err := getDriverByBusAndDevice(bus, device) + if err != nil { + return err + } + if driver == "" { + log.Log.V(2).Info("UnbindDriverByBusAndDevice(): device has no driver", "bus", bus, "device", device) + return nil + } + return unbindDriver(bus, device, driver) +} + func (k *kernel) HasDriver(pciAddr string) (bool, string) { - driver, err := dputils.GetDriverName(pciAddr) + driver, err := getDriverByBusAndDevice(consts.BusPci, pciAddr) if err != nil { log.Log.V(2).Info("HasDriver(): device driver is empty for device", "device", pciAddr) return false, "" } - log.Log.V(2).Info("HasDriver(): device driver for device", "device", pciAddr, "driver", driver) - return true, driver + if driver != "" { + log.Log.V(2).Info("HasDriver(): device driver for device", "device", pciAddr, "driver", driver) + return true, driver + } + return false, "" } func (k *kernel) TryEnableRdma() (bool, error) { @@ -625,3 +654,73 @@ func (k *kernel) IsKernelLockdownMode() bool { } return strings.Contains(stdout, "[integrity]") || strings.Contains(stdout, "[confidentiality]") } + +// returns driver for device on the bus +func getDriverByBusAndDevice(bus, device string) (string, error) { + driverLink := filepath.Join(vars.FilesystemRoot, consts.SysBus, bus, "devices", device, "driver") + driverInfo, err := os.Readlink(driverLink) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + log.Log.V(2).Info("getDriverByBusAndDevice(): driver path for device not exist", "bus", bus, "device", device, "driver", driverInfo) + return "", nil + } + log.Log.Error(err, "getDriverByBusAndDevice(): error getting driver info for device", "bus", bus, "device", device) + return "", err + } + log.Log.V(2).Info("getDriverByBusAndDevice(): driver for device", "bus", bus, "device", device, "driver", driverInfo) + return filepath.Base(driverInfo), nil +} + +// binds device to the provide driver +func bindDriver(bus, device, driver string) error { + log.Log.V(2).Info("bindDriver(): bind to driver", "bus", bus, "device", device, "driver", driver) + bindPath := filepath.Join(vars.FilesystemRoot, consts.SysBus, bus, "drivers", driver, "bind") + err := os.WriteFile(bindPath, []byte(device), os.ModeAppend) + if err != nil { + log.Log.Error(err, "bindDriver(): failed to bind driver", "bus", bus, "device", device, "driver", driver) + return err + } + return nil +} + +// unbind device from the driver +func unbindDriver(bus, device, driver string) error { + log.Log.V(2).Info("unbindDriver(): unbind from driver", "bus", bus, "device", device, "driver", driver) + unbindPath := filepath.Join(vars.FilesystemRoot, consts.SysBus, bus, "drivers", driver, "unbind") + err := os.WriteFile(unbindPath, []byte(device), os.ModeAppend) + if err != nil { + log.Log.Error(err, "unbindDriver(): failed to unbind driver", "bus", bus, "device", device, "driver", driver) + return err + } + return nil +} + +// probes driver for device on the bus +func probeDriver(bus, device string) error { + log.Log.V(2).Info("probeDriver(): drivers probe", "bus", bus, "device", device) + probePath := filepath.Join(vars.FilesystemRoot, consts.SysBus, bus, "drivers_probe") + err := os.WriteFile(probePath, []byte(device), os.ModeAppend) + if err != nil { + log.Log.Error(err, "probeDriver(): failed to trigger driver probe", "bus", bus, "device", device) + return err + } + return nil +} + +// set driver override for the bus/device, +// resets overrride if override arg is """ +func setDriverOverride(bus, device, override string) error { + driverOverridePath := filepath.Join(vars.FilesystemRoot, consts.SysBus, bus, "devices", device, "driver_override") + if override != "" { + log.Log.V(2).Info("setDriverOverride(): configure driver override for device", "bus", bus, "device", device, "driver", override) + } else { + log.Log.V(2).Info("setDriverOverride(): reset driver override for device", "bus", bus, "device", device) + } + err := os.WriteFile(driverOverridePath, []byte(override), os.ModeAppend) + if err != nil { + log.Log.Error(err, "setDriverOverride(): fail to write driver_override for device", + "bus", bus, "device", device, "driver", override) + return err + } + return nil +}