diff --git a/netns_linux.go b/netns_linux.go index b6049d2..1cf5e13 100644 --- a/netns_linux.go +++ b/netns_linux.go @@ -47,7 +47,7 @@ func New() (ns NsHandle, err error) { // Get gets a handle to the current threads network namespace. func Get() (NsHandle, error) { - return GetFromPid(os.Getpid()) + return GetFromThread(os.Getpid(), syscall.Gettid()) } // GetFromName gets a handle to a named network namespace such as one @@ -60,7 +60,7 @@ func GetFromName(name string) (NsHandle, error) { return NsHandle(fd), nil } -// GetFromName gets a handle to the network namespace of a given pid. +// GetFromPid gets a handle to the network namespace of a given pid. func GetFromPid(pid int) (NsHandle, error) { fd, err := syscall.Open(fmt.Sprintf("/proc/%d/ns/net", pid), syscall.O_RDONLY, 0) if err != nil { @@ -69,7 +69,17 @@ func GetFromPid(pid int) (NsHandle, error) { return NsHandle(fd), nil } -// GetFromName gets a handle to the network namespace of a docker container. +// GetFromThread gets a handle to the network namespace of a given pid and tid. +func GetFromThread(pid, tid int) (NsHandle, error) { + name := fmt.Sprintf("/proc/%d/task/%d/ns/net", pid, tid) + fd, err := syscall.Open(name, syscall.O_RDONLY, 0) + if err != nil { + return -1, err + } + return NsHandle(fd), nil +} + +// GetFromDocker gets a handle to the network namespace of a docker container. // Id is prefixed matched against the running docker containers, so a short // identifier can be used as long as it isn't ambiguous. func GetFromDocker(id string) (NsHandle, error) { @@ -169,7 +179,7 @@ func getPidForContainer(id string) (int, error) { return pid, fmt.Errorf("Ambiguous id supplied: %v", filenames) } else if len(filenames) == 1 { filename = filenames[0] - break; + break } } diff --git a/netns_test.go b/netns_test.go index b685b9d..e51981c 100644 --- a/netns_test.go +++ b/netns_test.go @@ -2,6 +2,7 @@ package netns import ( "runtime" + "sync" "testing" ) @@ -42,3 +43,24 @@ func TestNone(t *testing.T) { t.Fatal("None ns is open", ns) } } + +func TestThreaded(t *testing.T) { + ncpu := runtime.GOMAXPROCS(-1) + if ncpu < 2 { + t.Skip("-cpu=2 or larger required") + } + + // Lock this thread simply to ensure other threads get used. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + wg := &sync.WaitGroup{} + for i := 0; i < ncpu; i++ { + wg.Add(1) + go func() { + defer wg.Done() + TestGetNewSetDelete(t) + }() + } + wg.Wait() +}