Skip to content

Commit

Permalink
Driver loading refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Leonardo Milleri <[email protected]>
  • Loading branch information
lmilleri authored and Leonardo Milleri committed Jul 24, 2023
1 parent 154d2d6 commit d5d6cb4
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 62 deletions.
171 changes: 109 additions & 62 deletions pkg/plugins/generic/generic_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,67 @@ import (

var PluginName = "generic_plugin"

// driver id
const (
Vfio = iota
VirtioVdpa
)

// driver name
const (
vfioPciDriver = "vfio_pci"
virtioVdpaDriver = "virtio_vdpa"
)

// function type for determining if a given driver has to be loaded in the kernel
type needDriver func(state *sriovnetworkv1.SriovNetworkNodeState, driverState *DriverState) bool

type DriverState struct {
DriverName string
DeviceType string
VdpaType string
NeedDriverFunc needDriver
DriverLoaded bool
}

type DriverStateMapType map[uint]*DriverState

type GenericPlugin struct {
PluginName string
SpecVersion string
DesireState *sriovnetworkv1.SriovNetworkNodeState
LastState *sriovnetworkv1.SriovNetworkNodeState
LoadVfioDriver uint
LoadVirtioVdpaDriver uint
RunningOnHost bool
HostManager host.HostManagerInterface
PluginName string
SpecVersion string
DesireState *sriovnetworkv1.SriovNetworkNodeState
LastState *sriovnetworkv1.SriovNetworkNodeState
DriverStateMap DriverStateMapType
RunningOnHost bool
HostManager host.HostManagerInterface
}

const scriptsPath = "bindata/scripts/enable-kargs.sh"

const (
unloaded = iota
loading
loaded
)

// Initialize our plugin and set up initial values
func NewGenericPlugin(runningOnHost bool) (plugin.VendorPlugin, error) {
driverStateMap := make(map[uint]*DriverState)
driverStateMap[Vfio] = &DriverState{
DriverName: vfioPciDriver,
DeviceType: constants.DeviceTypeVfioPci,
VdpaType: "",
NeedDriverFunc: needDriverCheckDeviceType,
DriverLoaded: false,
}
driverStateMap[VirtioVdpa] = &DriverState{
DriverName: virtioVdpaDriver,
DeviceType: constants.DeviceTypeNetDevice,
VdpaType: constants.VdpaTypeVirtio,
NeedDriverFunc: needDriverCheckVdpaType,
DriverLoaded: false,
}

return &GenericPlugin{
PluginName: PluginName,
SpecVersion: "1.0",
LoadVfioDriver: unloaded,
LoadVirtioVdpaDriver: unloaded,
RunningOnHost: runningOnHost,
HostManager: host.NewHostManager(runningOnHost),
PluginName: PluginName,
SpecVersion: "1.0",
DriverStateMap: driverStateMap,
RunningOnHost: runningOnHost,
HostManager: host.NewHostManager(runningOnHost),
}, nil
}

Expand All @@ -60,7 +93,7 @@ func (p *GenericPlugin) Spec() string {
return p.SpecVersion
}

// OnNodeStateChange Invoked when SriovNetworkNodeState CR is created or updated, return if need dain and/or reboot node
// OnNodeStateChange Invoked when SriovNetworkNodeState CR is created or updated, return if need drain and/or reboot node
func (p *GenericPlugin) OnNodeStateChange(new *sriovnetworkv1.SriovNetworkNodeState) (needDrain bool, needReboot bool, err error) {
glog.Info("generic-plugin OnNodeStateChange()")
needDrain = false
Expand All @@ -69,32 +102,30 @@ func (p *GenericPlugin) OnNodeStateChange(new *sriovnetworkv1.SriovNetworkNodeSt
p.DesireState = new

needDrain = needDrainNode(new.Spec.Interfaces, new.Status.Interfaces)
needReboot = needRebootNode(new, &p.LoadVfioDriver, &p.LoadVirtioVdpaDriver)
needReboot = needRebootNode(new, p.DriverStateMap)

if needReboot {
needDrain = true
}
return
}

// Apply config change
func (p *GenericPlugin) Apply() error {
glog.Infof("generic-plugin Apply(): desiredState=%v", p.DesireState.Spec)
if p.LoadVfioDriver == loading {
if err := p.HostManager.LoadKernelModule("vfio_pci"); err != nil {
glog.Errorf("generic-plugin Apply(): fail to load vfio_pci kmod: %v", err)
return err
func syncDriverState(p *GenericPlugin) error {
for _, driverState := range p.DriverStateMap {
if !driverState.DriverLoaded && driverState.NeedDriverFunc(p.DesireState, driverState) {
if err := p.HostManager.LoadKernelModule(driverState.DriverName); err != nil {
glog.Errorf("generic-plugin Apply(): fail to load %s kmod: %v", driverState.DriverName, err)
return err
}
driverState.DriverLoaded = true
}
p.LoadVfioDriver = loaded
}
return nil
}

if p.LoadVirtioVdpaDriver == loading {
if err := p.HostManager.LoadKernelModule("virtio_vdpa"); err != nil {
glog.Errorf("generic-plugin Apply(): fail to load virtio_vdpa kmod: %v", err)
return err
}
p.LoadVirtioVdpaDriver = loaded
}
// Apply config change
func (p *GenericPlugin) Apply() error {
glog.Infof("generic-plugin Apply(): desiredState=%v", p.DesireState.Spec)

if p.LastState != nil {
glog.Infof("generic-plugin Apply(): lastStat=%v", p.LastState.Spec)
Expand All @@ -104,6 +135,11 @@ func (p *GenericPlugin) Apply() error {
}
}

err := syncDriverState(p)
if err != nil {
return err
}

// Create a map with all the PFs we will need to configure
// we need to create it here before we access the host file system using the chroot function
// because the skipConfigVf needs the mstconfig package that exist only inside the sriov-config-daemon file system
Expand All @@ -129,21 +165,21 @@ func (p *GenericPlugin) Apply() error {
return nil
}

func needVfioDriver(state *sriovnetworkv1.SriovNetworkNodeState) bool {
func needDriverCheckDeviceType(state *sriovnetworkv1.SriovNetworkNodeState, driverState *DriverState) bool {
for _, iface := range state.Spec.Interfaces {
for i := range iface.VfGroups {
if iface.VfGroups[i].DeviceType == constants.DeviceTypeVfioPci {
if iface.VfGroups[i].DeviceType == driverState.DeviceType {
return true
}
}
}
return false
}

func needVirtioVdpaDriver(state *sriovnetworkv1.SriovNetworkNodeState) bool {
func needDriverCheckVdpaType(state *sriovnetworkv1.SriovNetworkNodeState, driverState *DriverState) bool {
for _, iface := range state.Spec.Interfaces {
for i := range iface.VfGroups {
if iface.VfGroups[i].VdpaType == constants.VdpaTypeVirtio {
if iface.VfGroups[i].VdpaType == driverState.VdpaType {
return true
}
}
Expand Down Expand Up @@ -217,35 +253,46 @@ func needDrainNode(desired sriovnetworkv1.Interfaces, current sriovnetworkv1.Int
return
}

func needRebootNode(state *sriovnetworkv1.SriovNetworkNodeState, loadVfioDriver *uint, loadVirtioVdpaDriver *uint) (needReboot bool) {
needReboot = false
if *loadVfioDriver != loaded {
if needVfioDriver(state) {
*loadVfioDriver = loading
update, err := tryEnableIommuInKernelArgs()
if err != nil {
glog.Errorf("generic-plugin needRebootNode():fail to enable iommu in kernel args: %v", err)
}
if update {
glog.V(2).Infof("generic-plugin needRebootNode(): need reboot for enabling iommu kernel args")
}
needReboot = needReboot || update
func needRebootIfVfio(state *sriovnetworkv1.SriovNetworkNodeState, driverMap DriverStateMapType) (needReboot bool) {
driverState := driverMap[Vfio]
if !driverState.DriverLoaded && driverState.NeedDriverFunc(state, driverState) {
var err error
needReboot, err = tryEnableIommuInKernelArgs()
if err != nil {
glog.Errorf("generic-plugin needRebootNode():fail to enable iommu in kernel args: %v", err)
}
}

if *loadVirtioVdpaDriver != loaded {
if needVirtioVdpaDriver(state) {
*loadVirtioVdpaDriver = loading
if needReboot {
glog.V(2).Infof("generic-plugin needRebootNode(): need reboot for enabling iommu kernel args")
}
}
return needReboot
}

func needRebootNode(state *sriovnetworkv1.SriovNetworkNodeState, driverMap DriverStateMapType) (needReboot bool) {
needReboot = needRebootIfVfio(state, driverMap)
updateNode, err := utils.WriteSwitchdevConfFile(state)

update, err := utils.WriteSwitchdevConfFile(state)
if err != nil {
glog.Errorf("generic-plugin needRebootNode(): fail to write switchdev device config file")
}
if update {
if updateNode {
glog.V(2).Infof("generic-plugin needRebootNode(): need reboot for updating switchdev device configuration")
}
needReboot = needReboot || update
needReboot = needReboot || updateNode
return
}

// ////////////// for testing purposes only ///////////////////////
func (p *GenericPlugin) GetDriverStateMap() DriverStateMapType {
return p.DriverStateMap
}

func (p *GenericPlugin) LoadDriverForTests(state *sriovnetworkv1.SriovNetworkNodeState) {
for _, driverState := range p.DriverStateMap {
if !driverState.DriverLoaded && driverState.NeedDriverFunc(state, driverState) {
driverState.DriverLoaded = true
}
}
}

//////////////////////////////////////////////////////////////////
113 changes: 113 additions & 0 deletions pkg/plugins/generic/generic_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,119 @@ var _ = Describe("Generic plugin", func() {
Expect(needReboot).To(BeFalse())
Expect(needDrain).To(BeTrue())
})

It("should load vfio_pci driver", func() {
networkNodeState := &sriovnetworkv1.SriovNetworkNodeState{
Spec: sriovnetworkv1.SriovNetworkNodeStateSpec{
Interfaces: sriovnetworkv1.Interfaces{{
PciAddress: "0000:00:00.0",
NumVfs: 2,
Mtu: 1500,
VfGroups: []sriovnetworkv1.VfGroup{{
DeviceType: "vfio-pci",
PolicyName: "policy-1",
ResourceName: "resource-1",
VfRange: "0-1",
Mtu: 1500,
}}}},
},
Status: sriovnetworkv1.SriovNetworkNodeStateStatus{
Interfaces: sriovnetworkv1.InterfaceExts{{
PciAddress: "0000:00:00.0",
NumVfs: 2,
TotalVfs: 2,
DeviceID: "1015",
Vendor: "15b3",
Name: "sriovif1",
Mtu: 1500,
Mac: "0c:42:a1:55:ee:46",
Driver: "mlx5_core",
EswitchMode: "legacy",
LinkSpeed: "25000 Mb/s",
LinkType: "ETH",
VFs: []sriovnetworkv1.VirtualFunction{{
PciAddress: "0000:00:00.1",
DeviceID: "1016",
Vendor: "15b3",
VfID: 0,
Driver: "mlx5_core",
Name: "sriovif1v0",
Mtu: 1500,
Mac: "8e:d6:2c:62:87:1b",
}, {
PciAddress: "0000:00:00.2",
DeviceID: "1016",
Vendor: "15b3",
VfID: 0,
Driver: "mlx5_core",
}},
}},
},
}

concretePlugin := genericPlugin.(*generic.GenericPlugin)
driverStateMap := concretePlugin.GetDriverStateMap()
driverState := driverStateMap[generic.Vfio]
concretePlugin.LoadDriverForTests(networkNodeState)
Expect(driverState.DriverLoaded).To(BeTrue())
})

It("should load virtio_vdpa driver", func() {
networkNodeState := &sriovnetworkv1.SriovNetworkNodeState{
Spec: sriovnetworkv1.SriovNetworkNodeStateSpec{
Interfaces: sriovnetworkv1.Interfaces{{
PciAddress: "0000:00:00.0",
NumVfs: 2,
Mtu: 1500,
VfGroups: []sriovnetworkv1.VfGroup{{
DeviceType: "netdevice",
VdpaType: "virtio",
PolicyName: "policy-1",
ResourceName: "resource-1",
VfRange: "0-1",
Mtu: 1500,
}}}},
},
Status: sriovnetworkv1.SriovNetworkNodeStateStatus{
Interfaces: sriovnetworkv1.InterfaceExts{{
PciAddress: "0000:00:00.0",
NumVfs: 2,
TotalVfs: 2,
DeviceID: "1015",
Vendor: "15b3",
Name: "sriovif1",
Mtu: 1500,
Mac: "0c:42:a1:55:ee:46",
Driver: "mlx5_core",
EswitchMode: "legacy",
LinkSpeed: "25000 Mb/s",
LinkType: "ETH",
VFs: []sriovnetworkv1.VirtualFunction{{
PciAddress: "0000:00:00.1",
DeviceID: "1016",
Vendor: "15b3",
VfID: 0,
Driver: "mlx5_core",
Name: "sriovif1v0",
Mtu: 1500,
Mac: "8e:d6:2c:62:87:1b",
}, {
PciAddress: "0000:00:00.2",
DeviceID: "1016",
Vendor: "15b3",
VfID: 0,
Driver: "mlx5_core",
}},
}},
},
}

concretePlugin := genericPlugin.(*generic.GenericPlugin)
driverStateMap := concretePlugin.GetDriverStateMap()
driverState := driverStateMap[generic.VirtioVdpa]
concretePlugin.LoadDriverForTests(networkNodeState)
Expect(driverState.DriverLoaded).To(BeTrue())
})
})

})

0 comments on commit d5d6cb4

Please sign in to comment.