diff --git a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipaddress/common.go b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipaddress/common.go index 8a7c8e82..4b42678e 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipaddress/common.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipaddress/common.go @@ -40,7 +40,7 @@ import ( ) func create(ctx context.Context, conn *networkservice.Connection, isClient bool) error { - if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil { + if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil && mechanism.GetVLAN() == 0 { // Note: These are switched from normal because if we are the client, we need to assign the IP // in the Endpoints NetNS for the Dst. If we are the *server* we need to assign the IP for the // clients NetNS (ie the source). diff --git a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipneighbors/common.go b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipneighbors/common.go index c52f8671..ee81f13a 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipneighbors/common.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipneighbors/common.go @@ -31,7 +31,7 @@ import ( ) func create(conn *networkservice.Connection) error { - if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil { + if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil && mechanism.GetVLAN() == 0 { netlinkHandle, err := link.GetNetlinkHandle(mechanism.GetNetNSURL()) if err != nil { return errors.WithStack(err) diff --git a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/routes/common.go b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/routes/common.go index b9dfcb17..85bdcef0 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/routes/common.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/routes/common.go @@ -35,7 +35,7 @@ import ( ) func create(ctx context.Context, conn *networkservice.Connection, isClient bool) error { - if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil { + if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil && mechanism.GetVLAN() == 0 { netlinkHandle, err := link.GetNetlinkHandle(mechanism.GetNetNSURL()) if err != nil { return errors.WithStack(err) diff --git a/pkg/kernel/networkservice/connectioncontextkernel/mtu/common.go b/pkg/kernel/networkservice/connectioncontextkernel/mtu/common.go index 34a13e49..3b06eea6 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/mtu/common.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/mtu/common.go @@ -33,7 +33,7 @@ import ( ) func setMTU(ctx context.Context, conn *networkservice.Connection) error { - if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil { + if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil && mechanism.GetVLAN() == 0 { // Note: These are switched from normal because if we are the client, we need to assign the IP // in the Endpoints NetNS for the Dst. If we are the *server* we need to assign the IP for the // clients NetNS (ie the source). diff --git a/pkg/kernel/networkservice/inject/client.go b/pkg/kernel/networkservice/inject/client.go index 0ec8f1b0..17562ad7 100644 --- a/pkg/kernel/networkservice/inject/client.go +++ b/pkg/kernel/networkservice/inject/client.go @@ -18,6 +18,7 @@ package inject import ( "context" + "sync" "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" @@ -31,13 +32,16 @@ import ( "github.com/networkservicemesh/sdk-kernel/pkg/kernel/networkservice/vfconfig" ) -type injectClient struct{} +type injectClient struct { + vfRefCountMap map[string]int + vfRefCountMutex sync.Mutex +} // NewClient - returns a new networkservice.NetworkServiceClient that moves given network // interface into the Endpoint's pod network namespace on Request and back to Forwarder's // network namespace on Close func NewClient() networkservice.NetworkServiceClient { - return &injectClient{} + return &injectClient{vfRefCountMap: make(map[string]int)} } func (c *injectClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) { @@ -54,7 +58,7 @@ func (c *injectClient) Request(ctx context.Context, request *networkservice.Netw } if !isEstablished { - if err := move(ctx, conn, metadata.IsClient(c), false); err != nil { + if err := move(ctx, conn, c.vfRefCountMap, &c.vfRefCountMutex, metadata.IsClient(c), false); err != nil { closeCtx, cancelClose := postponeCtxFunc() defer cancelClose() @@ -72,7 +76,7 @@ func (c *injectClient) Request(ctx context.Context, request *networkservice.Netw func (c *injectClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*empty.Empty, error) { rv, err := next.Client(ctx).Close(ctx, conn, opts...) - injectErr := move(ctx, conn, metadata.IsClient(c), true) + injectErr := move(ctx, conn, c.vfRefCountMap, &c.vfRefCountMutex, metadata.IsClient(c), true) if err != nil && injectErr != nil { return nil, errors.Wrap(err, injectErr.Error()) diff --git a/pkg/kernel/networkservice/inject/common.go b/pkg/kernel/networkservice/inject/common.go index 9dcd8165..8d16c583 100644 --- a/pkg/kernel/networkservice/inject/common.go +++ b/pkg/kernel/networkservice/inject/common.go @@ -19,6 +19,7 @@ package inject import ( "context" "strings" + "sync" "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" @@ -69,7 +70,7 @@ func renameInterface(origIfName, desiredIfName string, curNetNS, targetNetNS net }) } -func move(ctx context.Context, conn *networkservice.Connection, isClient, isMoveBack bool) error { +func move(ctx context.Context, conn *networkservice.Connection, vfRefCountMap map[string]int, vfRefCountMutex sync.Locker, isClient, isMoveBack bool) error { mech := kernel.ToMechanism(conn.GetMechanism()) if mech == nil { return nil @@ -102,12 +103,20 @@ func move(ctx context.Context, conn *networkservice.Connection, isClient, isMove defer func() { _ = contNetNS.Close() }() } + vfRefCountMutex.Lock() + defer vfRefCountMutex.Unlock() + + vfRefKey := vfConfig.VFPCIAddress + if vfRefKey == "" { + vfRefKey = vfConfig.VFInterfaceName + } + ifName := mech.GetInterfaceName() if !isMoveBack { - err = moveToContNetNS(vfConfig, ifName, hostNetNS, contNetNS) + err = moveToContNetNS(vfConfig, vfRefCountMap, vfRefKey, ifName, hostNetNS, contNetNS) vfConfig.ContNetNS = contNetNS } else { - err = moveToHostNetNS(vfConfig, ifName, hostNetNS, contNetNS) + err = moveToHostNetNS(vfConfig, vfRefCountMap, vfRefKey, ifName, hostNetNS, contNetNS) } if err != nil { // link may not be available at this stage for cases like veth pair (might be deleted in previous chain element itself) @@ -120,7 +129,13 @@ func move(ctx context.Context, conn *networkservice.Connection, isClient, isMove return nil } -func moveToContNetNS(vfConfig *vfconfig.VFConfig, ifName string, hostNetNS, contNetNS netns.NsHandle) (err error) { +func moveToContNetNS(vfConfig *vfconfig.VFConfig, vfRefCountMap map[string]int, vfRefKey, ifName string, hostNetNS, contNetNS netns.NsHandle) (err error) { + if _, exists := vfRefCountMap[vfRefKey]; !exists { + vfRefCountMap[vfRefKey] = 1 + } else { + vfRefCountMap[vfRefKey]++ + return + } link, _ := kernellink.FindHostDevice("", ifName, contNetNS) if link != nil { return @@ -136,28 +151,39 @@ func moveToContNetNS(vfConfig *vfconfig.VFConfig, ifName string, hostNetNS, cont return } -func moveToHostNetNS(vfConfig *vfconfig.VFConfig, ifName string, hostNetNS, contNetNS netns.NsHandle) (err error) { - if vfConfig != nil && vfConfig.VFInterfaceName != ifName { - link, _ := kernellink.FindHostDevice(vfConfig.VFPCIAddress, vfConfig.VFInterfaceName, hostNetNS) - if link != nil { - linkName := link.GetName() - if linkName != vfConfig.VFInterfaceName { - if err = netlink.LinkSetName(link.GetLink(), vfConfig.VFInterfaceName); err != nil { - err = errors.Wrapf(err, "failed to rename interface from %s to %s", linkName, vfConfig.VFInterfaceName) +func moveToHostNetNS(vfConfig *vfconfig.VFConfig, vfRefCountMap map[string]int, vfRefKey, ifName string, hostNetNS, contNetNS netns.NsHandle) error { + var refCount int + if count, exists := vfRefCountMap[vfRefKey]; exists && count > 0 { + refCount = count - 1 + vfRefCountMap[vfRefKey] = refCount + } else { + return nil + } + + if refCount == 0 { + delete(vfRefCountMap, vfRefKey) + if vfConfig != nil && vfConfig.VFInterfaceName != ifName { + link, _ := kernellink.FindHostDevice(vfConfig.VFPCIAddress, vfConfig.VFInterfaceName, hostNetNS) + if link != nil { + linkName := link.GetName() + if linkName != vfConfig.VFInterfaceName { + if err := netlink.LinkSetName(link.GetLink(), vfConfig.VFInterfaceName); err != nil { + return errors.Wrapf(err, "failed to rename interface from %s to %s: %v", linkName, vfConfig.VFInterfaceName, err) + } } + return nil } - return - } - err = renameInterface(ifName, vfConfig.VFInterfaceName, hostNetNS, contNetNS) - if err == nil { - err = moveInterfaceToAnotherNamespace(vfConfig.VFInterfaceName, hostNetNS, contNetNS, hostNetNS) + err := renameInterface(ifName, vfConfig.VFInterfaceName, hostNetNS, contNetNS) + if err == nil { + err = moveInterfaceToAnotherNamespace(vfConfig.VFInterfaceName, hostNetNS, contNetNS, hostNetNS) + } + return err } - } else { link, _ := kernellink.FindHostDevice("", ifName, hostNetNS) if link != nil { return nil } - err = moveInterfaceToAnotherNamespace(ifName, hostNetNS, contNetNS, hostNetNS) + return moveInterfaceToAnotherNamespace(ifName, hostNetNS, contNetNS, hostNetNS) } - return + return nil } diff --git a/pkg/kernel/networkservice/inject/server.go b/pkg/kernel/networkservice/inject/server.go index 103e5879..43c6226c 100644 --- a/pkg/kernel/networkservice/inject/server.go +++ b/pkg/kernel/networkservice/inject/server.go @@ -19,6 +19,7 @@ package inject import ( "context" + "sync" "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" @@ -32,12 +33,15 @@ import ( "github.com/networkservicemesh/sdk-kernel/pkg/kernel/networkservice/vfconfig" ) -type injectServer struct{} +type injectServer struct { + vfRefCountMap map[string]int + vfRefCountMutex sync.Mutex +} // NewServer - returns a new networkservice.NetworkServiceServer that moves given network interface into the Client's // pod network namespace on Request and back to Forwarder's network namespace on Close func NewServer() networkservice.NetworkServiceServer { - return &injectServer{} + return &injectServer{vfRefCountMap: make(map[string]int)} } func (s *injectServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { @@ -52,7 +56,7 @@ func (s *injectServer) Request(ctx context.Context, request *networkservice.Netw } if !isEstablished { - if err := move(ctx, request.GetConnection(), metadata.IsClient(s), false); err != nil { + if err := move(ctx, request.GetConnection(), s.vfRefCountMap, &s.vfRefCountMutex, metadata.IsClient(s), false); err != nil { return nil, err } } @@ -64,7 +68,7 @@ func (s *injectServer) Request(ctx context.Context, request *networkservice.Netw moveCtx, cancelMove := postponeCtxFunc() defer cancelMove() - if moveRenameErr := move(moveCtx, request.GetConnection(), metadata.IsClient(s), true); moveRenameErr != nil { + if moveRenameErr := move(moveCtx, request.GetConnection(), s.vfRefCountMap, &s.vfRefCountMutex, metadata.IsClient(s), true); moveRenameErr != nil { err = errors.Wrapf(err, "server request failed, failed to move back the interface: %s", moveRenameErr.Error()) } } @@ -75,7 +79,7 @@ func (s *injectServer) Request(ctx context.Context, request *networkservice.Netw func (s *injectServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { _, err := next.Server(ctx).Close(ctx, conn) - moveRenameErr := move(ctx, conn, metadata.IsClient(s), true) + moveRenameErr := move(ctx, conn, s.vfRefCountMap, &s.vfRefCountMutex, metadata.IsClient(s), true) if err != nil && moveRenameErr != nil { return nil, errors.Wrap(err, moveRenameErr.Error())