diff --git a/pkg/ns/ns_test.go b/pkg/ns/ns_test.go index 42fc6322..7ad882f5 100644 --- a/pkg/ns/ns_test.go +++ b/pkg/ns/ns_test.go @@ -28,13 +28,9 @@ import ( . "github.com/onsi/gomega" ) -const CurrentNetNS = "/proc/self/ns/net" - var _ = Describe("Linux namespace operations", func() { Describe("WithNetNS", func() { var ( - originalNetNS *os.File - targetNetNSName string targetNetNSPath string targetNetNS *os.File @@ -42,8 +38,6 @@ var _ = Describe("Linux namespace operations", func() { BeforeEach(func() { var err error - originalNetNS, err = os.Open(CurrentNetNS) - Expect(err).NotTo(HaveOccurred()) targetNetNSName = fmt.Sprintf("test-netns-%d", rand.Int()) @@ -60,8 +54,6 @@ var _ = Describe("Linux namespace operations", func() { err := exec.Command("ip", "netns", "del", targetNetNSName).Run() Expect(err).NotTo(HaveOccurred()) - - Expect(originalNetNS.Close()).To(Succeed()) }) It("executes the callback within the target network namespace", func() { @@ -71,7 +63,7 @@ var _ = Describe("Linux namespace operations", func() { var actualInode uint64 var innerErr error err = ns.WithNetNS(targetNetNS, false, func(*os.File) error { - actualInode, innerErr = testhelpers.GetInode(CurrentNetNS) + actualInode, innerErr = testhelpers.GetInodeCurNetNS() return nil }) Expect(err).NotTo(HaveOccurred()) @@ -81,7 +73,7 @@ var _ = Describe("Linux namespace operations", func() { }) It("provides the original namespace as the argument to the callback", func() { - hostNSInode, err := testhelpers.GetInode(CurrentNetNS) + hostNSInode, err := testhelpers.GetInodeCurNetNS() Expect(err).NotTo(HaveOccurred()) var inputNSInode uint64 @@ -97,7 +89,7 @@ var _ = Describe("Linux namespace operations", func() { }) It("restores the calling thread to the original network namespace", func() { - preTestInode, err := testhelpers.GetInode(CurrentNetNS) + preTestInode, err := testhelpers.GetInodeCurNetNS() Expect(err).NotTo(HaveOccurred()) err = ns.WithNetNS(targetNetNS, false, func(*os.File) error { @@ -105,7 +97,7 @@ var _ = Describe("Linux namespace operations", func() { }) Expect(err).NotTo(HaveOccurred()) - postTestInode, err := testhelpers.GetInode(CurrentNetNS) + postTestInode, err := testhelpers.GetInodeCurNetNS() Expect(err).NotTo(HaveOccurred()) Expect(postTestInode).To(Equal(preTestInode)) @@ -113,14 +105,14 @@ var _ = Describe("Linux namespace operations", func() { Context("when the callback returns an error", func() { It("restores the calling thread to the original namespace before returning", func() { - preTestInode, err := testhelpers.GetInode(CurrentNetNS) + preTestInode, err := testhelpers.GetInodeCurNetNS() Expect(err).NotTo(HaveOccurred()) _ = ns.WithNetNS(targetNetNS, false, func(*os.File) error { return errors.New("potato") }) - postTestInode, err := testhelpers.GetInode(CurrentNetNS) + postTestInode, err := testhelpers.GetInodeCurNetNS() Expect(err).NotTo(HaveOccurred()) Expect(postTestInode).To(Equal(preTestInode)) @@ -136,7 +128,7 @@ var _ = Describe("Linux namespace operations", func() { Describe("validating inode mapping to namespaces", func() { It("checks that different namespaces have different inodes", func() { - hostNSInode, err := testhelpers.GetInode(CurrentNetNS) + hostNSInode, err := testhelpers.GetInodeCurNetNS() Expect(err).NotTo(HaveOccurred()) testNsInode, err := testhelpers.GetInode(targetNetNSPath) diff --git a/pkg/testhelpers/testhelpers.go b/pkg/testhelpers/testhelpers.go index 0963121d..004006a9 100644 --- a/pkg/testhelpers/testhelpers.go +++ b/pkg/testhelpers/testhelpers.go @@ -27,6 +27,16 @@ import ( . "github.com/onsi/gomega" ) +func getCurrentThreadNetNSPath() string { + pid := unix.Getpid() + tid := unix.Gettid() + return fmt.Sprintf("/proc/%d/task/%d/ns/net", pid, tid) +} + +func GetInodeCurNetNS() (uint64, error) { + return GetInode(getCurrentThreadNetNSPath()) +} + func GetInode(path string) (uint64, error) { file, err := os.Open(path) if err != nil { @@ -68,9 +78,7 @@ func MakeNetworkNS(containerID string) string { defer GinkgoRecover() // capture current thread's original netns - pid := unix.Getpid() - tid := unix.Gettid() - currentThreadNetNSPath := fmt.Sprintf("/proc/%d/task/%d/ns/net", pid, tid) + currentThreadNetNSPath := getCurrentThreadNetNSPath() originalNetNS, err := unix.Open(currentThreadNetNSPath, unix.O_RDONLY, 0) Expect(err).NotTo(HaveOccurred()) defer unix.Close(originalNetNS)