diff --git a/virtcontainers/network.go b/virtcontainers/network.go index 94e961572f..01a666c8a0 100644 --- a/virtcontainers/network.go +++ b/virtcontainers/network.go @@ -1300,7 +1300,7 @@ func networkInfoFromLink(handle *netlink.Handle, link netlink.Link) (NetworkInfo }, nil } -func createEndpointsFromScan(networkNSPath string, config NetworkConfig) ([]Endpoint, error) { +func createEndpointsFromScan(networkNSPath string, config *NetworkConfig) ([]Endpoint, error) { var endpoints []Endpoint netnsHandle, err := netns.GetFromPath(networkNSPath) @@ -1441,26 +1441,24 @@ func (n *Network) Run(networkNSPath string, cb func() error) error { } // Add adds all needed interfaces inside the network namespace. -func (n *Network) Add(s *Sandbox, hotplug bool) error { - span, _ := n.trace(s.ctx, "add") +func (n *Network) Add(ctx context.Context, config *NetworkConfig, hypervisor hypervisor, hotplug bool) ([]Endpoint, error) { + span, _ := n.trace(ctx, "add") defer span.Finish() - endpoints, err := createEndpointsFromScan(s.config.NetworkConfig.NetNSPath, s.config.NetworkConfig) + endpoints, err := createEndpointsFromScan(config.NetNSPath, config) if err != nil { - return err + return endpoints, err } - s.networkNS.Endpoints = endpoints - - err = doNetNS(s.config.NetworkConfig.NetNSPath, func(_ ns.NetNS) error { - for _, endpoint := range s.networkNS.Endpoints { + err = doNetNS(config.NetNSPath, func(_ ns.NetNS) error { + for _, endpoint := range endpoints { networkLogger().WithField("endpoint-type", endpoint.Type()).WithField("hotplug", hotplug).Info("Attaching endpoint") if hotplug { - if err := endpoint.HotAttach(s.hypervisor); err != nil { + if err := endpoint.HotAttach(hypervisor); err != nil { return err } } else { - if err := endpoint.Attach(s.hypervisor); err != nil { + if err := endpoint.Attach(hypervisor); err != nil { return err } } @@ -1469,30 +1467,30 @@ func (n *Network) Add(s *Sandbox, hotplug bool) error { return nil }) if err != nil { - return err + return []Endpoint{}, err } networkLogger().Debug("Network added") - return nil + return endpoints, nil } // Remove network endpoints in the network namespace. It also deletes the network // namespace in case the namespace has been created by us. -func (n *Network) Remove(s *Sandbox, hotunplug bool) error { - span, _ := n.trace(s.ctx, "remove") +func (n *Network) Remove(ctx context.Context, ns *NetworkNamespace, hypervisor hypervisor, hotunplug bool) error { + span, _ := n.trace(ctx, "remove") defer span.Finish() - for _, endpoint := range s.networkNS.Endpoints { + for _, endpoint := range ns.Endpoints { // Detach for an endpoint should enter the network namespace // if required. networkLogger().WithField("endpoint-type", endpoint.Type()).WithField("hotunplug", hotunplug).Info("Detaching endpoint") if hotunplug { - if err := endpoint.HotDetach(s.hypervisor, s.networkNS.NetNsCreated, s.networkNS.NetNsPath); err != nil { + if err := endpoint.HotDetach(hypervisor, ns.NetNsCreated, ns.NetNsPath); err != nil { return err } } else { - if err := endpoint.Detach(s.networkNS.NetNsCreated, s.networkNS.NetNsPath); err != nil { + if err := endpoint.Detach(ns.NetNsCreated, ns.NetNsPath); err != nil { return err } } @@ -1500,9 +1498,9 @@ func (n *Network) Remove(s *Sandbox, hotunplug bool) error { networkLogger().Debug("Network removed") - if s.networkNS.NetNsCreated { - networkLogger().Infof("Network namespace %q deleted", s.networkNS.NetNsPath) - return deleteNetNS(s.networkNS.NetNsPath) + if ns.NetNsCreated { + networkLogger().Infof("Network namespace %q deleted", ns.NetNsPath) + return deleteNetNS(ns.NetNsPath) } return nil diff --git a/virtcontainers/sandbox.go b/virtcontainers/sandbox.go index 60942d4ce3..1a9e715db7 100644 --- a/virtcontainers/sandbox.go +++ b/virtcontainers/sandbox.go @@ -797,10 +797,13 @@ func (s *Sandbox) createNetwork() error { // after vm is started. if s.factory == nil { // Add the network - if err := s.network.Add(s, false); err != nil { + endpoints, err := s.network.Add(s.ctx, &s.config.NetworkConfig, s.hypervisor, false) + if err != nil { return err } + s.networkNS.Endpoints = endpoints + if s.config.NetworkConfig.NetmonConfig.Enable { if err := s.startNetworkMonitor(); err != nil { return err @@ -822,7 +825,7 @@ func (s *Sandbox) removeNetwork() error { } } - return s.network.Remove(s, s.factory != nil) + return s.network.Remove(s.ctx, &s.networkNS, s.hypervisor, s.factory != nil) } func (s *Sandbox) generateNetInfo(inf *vcTypes.Interface) (NetworkInfo, error) { @@ -954,10 +957,13 @@ func (s *Sandbox) startVM() error { // In case of vm factory, network interfaces are hotplugged // after vm is started. if s.factory != nil { - if err := s.network.Add(s, true); err != nil { + endpoints, err := s.network.Add(s.ctx, &s.config.NetworkConfig, s.hypervisor, true) + if err != nil { return err } + s.networkNS.Endpoints = endpoints + if s.config.NetworkConfig.NetmonConfig.Enable { if err := s.startNetworkMonitor(); err != nil { return err