diff --git a/oci/oci.go b/oci/oci.go index 287bc73e..033c1fe3 100644 --- a/oci/oci.go +++ b/oci/oci.go @@ -18,6 +18,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/kubernetes-incubator/cri-o/utils" + "github.com/containernetworking/cni/pkg/ns" "golang.org/x/sys/unix" "k8s.io/kubernetes/pkg/fields" pb "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" @@ -344,6 +345,7 @@ type Container struct { annotations fields.Set image *pb.ImageSpec sandbox string + netns ns.NetNS terminal bool state *ContainerState metadata *pb.ContainerMetadata @@ -360,7 +362,7 @@ type ContainerState struct { } // NewContainer creates a container object. -func NewContainer(id string, name string, bundlePath string, logPath string, labels map[string]string, annotations map[string]string, image *pb.ImageSpec, metadata *pb.ContainerMetadata, sandbox string, terminal bool) (*Container, error) { +func NewContainer(id string, name string, bundlePath string, logPath string, netns ns.NetNS, labels map[string]string, annotations map[string]string, image *pb.ImageSpec, metadata *pb.ContainerMetadata, sandbox string, terminal bool) (*Container, error) { c := &Container{ id: id, name: name, @@ -368,6 +370,7 @@ func NewContainer(id string, name string, bundlePath string, logPath string, lab logPath: logPath, labels: labels, sandbox: sandbox, + netns: netns, terminal: terminal, metadata: metadata, annotations: annotations, @@ -421,7 +424,12 @@ func (c *Container) NetNsPath() (string, error) { if c.state == nil { return "", fmt.Errorf("container state is not populated") } - return fmt.Sprintf("/proc/%d/ns/net", c.state.Pid), nil + + if c.netns == nil { + return fmt.Sprintf("/proc/%d/ns/net", c.state.Pid), nil + } + + return c.netns.Path(), nil } // Metadata returns the metadata of the container. diff --git a/server/container_create.go b/server/container_create.go index 5445a3ab..20674e56 100644 --- a/server/container_create.go +++ b/server/container_create.go @@ -273,14 +273,20 @@ func (s *Server) createSandboxContainer(containerID string, containerName string logrus.Debugf("pod container state %+v", podInfraState) - for nsType, nsFile := range map[string]string{ - "ipc": "ipc", - "network": "net", - } { - nsPath := fmt.Sprintf("/proc/%d/ns/%s", podInfraState.Pid, nsFile) - if err := specgen.AddOrReplaceLinuxNamespace(nsType, nsPath); err != nil { - return nil, err - } + ipcNsPath := fmt.Sprintf("/proc/%d/ns/ipc", podInfraState.Pid) + if err := specgen.AddOrReplaceLinuxNamespace("ipc", ipcNsPath); err != nil { + return nil, err + } + + netNsPath := sb.netNsPath() + if netNsPath == "" { + // The sandbox does not have a permanent namespace, + // it's on the host one. + netNsPath = fmt.Sprintf("/proc/%d/ns/net", podInfraState.Pid) + } + + if err := specgen.AddOrReplaceLinuxNamespace("network", netNsPath); err != nil { + return nil, err } imageSpec := containerConfig.GetImage() @@ -336,7 +342,7 @@ func (s *Server) createSandboxContainer(containerID string, containerName string return nil, err } - container, err := oci.NewContainer(containerID, containerName, containerDir, logPath, labels, annotations, imageSpec, metadata, sb.id, containerConfig.GetTty()) + container, err := oci.NewContainer(containerID, containerName, containerDir, logPath, sb.netNs(), labels, annotations, imageSpec, metadata, sb.id, containerConfig.GetTty()) if err != nil { return nil, err } diff --git a/server/sandbox.go b/server/sandbox.go index 34174575..7e9a38b5 100644 --- a/server/sandbox.go +++ b/server/sandbox.go @@ -1,15 +1,128 @@ package server import ( + "crypto/rand" "errors" "fmt" + "os" + "path/filepath" + "sync" + "github.com/Sirupsen/logrus" "github.com/docker/docker/pkg/stringid" "github.com/kubernetes-incubator/cri-o/oci" + "github.com/containernetworking/cni/pkg/ns" "k8s.io/kubernetes/pkg/fields" pb "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" ) +type sandboxNetNs struct { + sync.Mutex + ns ns.NetNS + symlink *os.File + closed bool +} + +func (ns *sandboxNetNs) symlinkCreate(name string) error { + b := make([]byte, 4) + _, randErr := rand.Reader.Read(b) + if randErr != nil { + return randErr + } + + nsName := fmt.Sprintf("%s-%x", name, b) + symlinkPath := filepath.Join(nsRunDir, nsName) + + if err := os.Symlink(ns.ns.Path(), symlinkPath); err != nil { + return err + } + + fd, err := os.Open(symlinkPath) + if err != nil { + if removeErr := os.RemoveAll(symlinkPath); removeErr != nil { + return removeErr + } + + return err + } + + ns.symlink = fd + + return nil +} + +func (ns *sandboxNetNs) symlinkRemove() error { + if err := ns.symlink.Close(); err != nil { + return err + } + + return os.RemoveAll(ns.symlink.Name()) +} + +func isSymbolicLink(path string) (bool, error) { + fi, err := os.Lstat(path) + if err != nil { + return false, err + } + + return fi.Mode()&os.ModeSymlink == os.ModeSymlink, nil +} + +func netNsGet(nspath, name string) (*sandboxNetNs, error) { + if err := ns.IsNSorErr(nspath); err != nil { + return nil, errSandboxClosedNetNS + } + + symlink, symlinkErr := isSymbolicLink(nspath) + if symlinkErr != nil { + return nil, symlinkErr + } + + var resolvedNsPath string + if symlink { + path, err := os.Readlink(nspath) + if err != nil { + return nil, err + } + resolvedNsPath = path + } else { + resolvedNsPath = nspath + } + + netNS, err := ns.GetNS(resolvedNsPath) + if err != nil { + return nil, err + } + + netNs := &sandboxNetNs{ns: netNS, closed: false,} + + if symlink { + fd, err := os.Open(nspath) + if err != nil { + return nil, err + } + + netNs.symlink = fd + } else { + if err := netNs.symlinkCreate(name); err != nil { + return nil, err + } + } + + return netNs, nil +} + +func hostNetNsPath() (string, error) { + netNS, err := ns.GetCurrentNS() + if err != nil { + return "", err + } + + defer netNS.Close() + + return netNS.Path(), nil +} + type sandbox struct { id string name string @@ -20,6 +133,7 @@ type sandbox struct { containers oci.Store processLabel string mountLabel string + netns *sandboxNetNs metadata *pb.PodSandboxMetadata shmPath string } @@ -27,10 +141,12 @@ type sandbox struct { const ( podDefaultNamespace = "default" defaultShmSize = 64 * 1024 * 1024 + nsRunDir = "/var/run/netns" ) var ( - errSandboxIDEmpty = errors.New("PodSandboxId should not be empty") + errSandboxIDEmpty = errors.New("PodSandboxId should not be empty") + errSandboxClosedNetNS = errors.New("PodSandbox networking namespace is closed") ) func (s *sandbox) addContainer(c *oci.Container) { @@ -45,6 +161,77 @@ func (s *sandbox) removeContainer(c *oci.Container) { s.containers.Delete(c.Name()) } +func (s *sandbox) netNs() ns.NetNS { + if s.netns == nil { + return nil + } + + return s.netns.ns +} + +func (s *sandbox) netNsPath() string { + if s.netns == nil { + return "" + } + + return s.netns.symlink.Name() +} + +func (s *sandbox) netNsCreate() error { + if s.netns != nil { + return fmt.Errorf("net NS already created") + } + + netNS, err := ns.NewNS() + if err != nil { + return err + } + + s.netns = &sandboxNetNs{ + ns: netNS, + closed: false, + } + + if err := s.netns.symlinkCreate(s.name); err != nil { + logrus.Warnf("Could not create nentns symlink %v", err) + + if err := s.netns.ns.Close(); err != nil { + return err + } + + return err + } + + return nil +} + +func (s *sandbox) netNsRemove() error { + if s.netns == nil { + logrus.Warn("no networking namespace") + return nil + } + + s.netns.Lock() + defer s.netns.Unlock() + + if s.netns.closed { + // netNsRemove() can be called multiple + // times without returning an error. + return nil + } + + if err := s.netns.symlinkRemove(); err != nil { + return err + } + + if err := s.netns.ns.Close(); err != nil { + return err + } + + s.netns.closed = true + return nil +} + func (s *Server) generatePodIDandName(name string, namespace string, attempt uint32) (string, string, error) { var ( err error diff --git a/server/sandbox_remove.go b/server/sandbox_remove.go index 00c1d74f..9064b716 100644 --- a/server/sandbox_remove.go +++ b/server/sandbox_remove.go @@ -73,6 +73,10 @@ func (s *Server) RemovePodSandbox(ctx context.Context, req *pb.RemovePodSandboxR } } + if err := sb.netNsRemove(); err != nil { + return nil, fmt.Errorf("failed to remove networking namespace for sandbox %s: %v", sb.id, err) + } + // Remove the files related to the sandbox podSandboxDir := filepath.Join(s.config.SandboxDir, sb.id) if err := os.RemoveAll(podSandboxDir); err != nil { diff --git a/server/sandbox_run.go b/server/sandbox_run.go index 5b494a74..7c0d604c 100644 --- a/server/sandbox_run.go +++ b/server/sandbox_run.go @@ -17,10 +17,30 @@ import ( pb "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" ) +func (s *Server) runContainer(container *oci.Container) error { + if err := s.runtime.CreateContainer(container); err != nil { + return err + } + + if err := s.runtime.UpdateStatus(container); err != nil { + return err + } + + if err := s.runtime.StartContainer(container); err != nil { + return err + } + + if err := s.runtime.UpdateStatus(container); err != nil { + return err + } + + return nil +} + // RunPodSandbox creates and runs a pod-level sandbox. -func (s *Server) RunPodSandbox(ctx context.Context, req *pb.RunPodSandboxRequest) (*pb.RunPodSandboxResponse, error) { +func (s *Server) RunPodSandbox(ctx context.Context, req *pb.RunPodSandboxRequest) (resp *pb.RunPodSandboxResponse, err error) { logrus.Debugf("RunPodSandboxRequest %+v", req) - var processLabel, mountLabel string + var processLabel, mountLabel, netNsPath string // process req.Name name := req.GetConfig().GetMetadata().GetName() if name == "" { @@ -30,7 +50,6 @@ func (s *Server) RunPodSandbox(ctx context.Context, req *pb.RunPodSandboxRequest namespace := req.GetConfig().GetMetadata().GetNamespace() attempt := req.GetConfig().GetMetadata().GetAttempt() - var err error id, name, err := s.generatePodIDandName(name, namespace, attempt) if err != nil { return nil, err @@ -235,6 +254,34 @@ func (s *Server) RunPodSandbox(ctx context.Context, req *pb.RunPodSandboxRequest if err != nil { return nil, err } + + netNsPath, err = hostNetNsPath() + if err != nil { + return nil, err + } + } else { + // Create the sandbox network namespace + if err = sb.netNsCreate(); err != nil { + return nil, err + } + + defer func() { + if err == nil { + return + } + + if netnsErr := sb.netNsRemove(); netnsErr != nil { + logrus.Warnf("Failed to remove networking namespace: %v", netnsErr) + } + } () + + // Pass the created namespace path to the runtime + err = g.AddOrReplaceLinuxNamespace("network", sb.netNsPath()) + if err != nil { + return nil, err + } + + netNsPath = sb.netNsPath() } if req.GetConfig().GetLinux().GetSecurityContext().GetNamespaceOptions().GetHostPid() { @@ -267,40 +314,24 @@ func (s *Server) RunPodSandbox(ctx context.Context, req *pb.RunPodSandboxRequest } } - container, err := oci.NewContainer(containerID, containerName, podSandboxDir, podSandboxDir, labels, annotations, nil, nil, id, false) + container, err := oci.NewContainer(containerID, containerName, podSandboxDir, podSandboxDir, sb.netNs(), labels, annotations, nil, nil, id, false) if err != nil { return nil, err } sb.infraContainer = container - if err = s.runtime.CreateContainer(container); err != nil { - return nil, err - } - - if err = s.runtime.UpdateStatus(container); err != nil { - return nil, err - } - // setup the network podNamespace := "" - netnsPath, err := container.NetNsPath() - if err != nil { - return nil, err - } - if err = s.netPlugin.SetUpPod(netnsPath, podNamespace, id, containerName); err != nil { + if err = s.netPlugin.SetUpPod(netNsPath, podNamespace, id, containerName); err != nil { return nil, fmt.Errorf("failed to create network for container %s in sandbox %s: %v", containerName, id, err) } - if err = s.runtime.StartContainer(container); err != nil { + if err = s.runContainer(container); err != nil { return nil, err } - if err = s.runtime.UpdateStatus(container); err != nil { - return nil, err - } - - resp := &pb.RunPodSandboxResponse{PodSandboxId: &id} + resp = &pb.RunPodSandboxResponse{PodSandboxId: &id} logrus.Debugf("RunPodSandboxResponse: %+v", resp) return resp, nil } diff --git a/server/sandbox_stop.go b/server/sandbox_stop.go index 2506ad4e..47f570c2 100644 --- a/server/sandbox_stop.go +++ b/server/sandbox_stop.go @@ -35,6 +35,11 @@ func (s *Server) StopPodSandbox(ctx context.Context, req *pb.StopPodSandboxReque podInfraContainer.Name(), sb.id, err) } + // Close the sandbox networking namespace. + if err := sb.netNsRemove(); err != nil { + return nil, err + } + containers := sb.containers.List() containers = append(containers, podInfraContainer) diff --git a/server/server.go b/server/server.go index e5620c87..c0e5ccb9 100644 --- a/server/server.go +++ b/server/server.go @@ -92,7 +92,7 @@ func (s *Server) loadContainer(id string) error { return err } - ctr, err := oci.NewContainer(id, name, containerPath, m.Annotations["ocid/log_path"], labels, annotations, img, &metadata, sb.id, tty) + ctr, err := oci.NewContainer(id, name, containerPath, m.Annotations["ocid/log_path"], sb.netNs(), labels, annotations, img, &metadata, sb.id, tty) if err != nil { return err } @@ -106,6 +106,22 @@ func (s *Server) loadContainer(id string) error { return nil } +func configNetNsPath(spec rspec.Spec) (string, error) { + for _, ns := range spec.Linux.Namespaces { + if ns.Type != rspec.NetworkNamespace { + continue + } + + if ns.Path == "" { + return "", fmt.Errorf("empty networking namespace") + } + + return ns.Path, nil + } + + return "", fmt.Errorf("missing networking namespace") +} + func (s *Server) loadSandbox(id string) error { config, err := ioutil.ReadFile(filepath.Join(s.config.SandboxDir, id, "config.json")) if err != nil { @@ -151,6 +167,22 @@ func (s *Server) loadSandbox(id string) error { metadata: &metadata, shmPath: m.Annotations["ocid/shm_path"], } + + // We add a netNS only if we can load a permanent one. + // Otherwise, the sandbox will live in the host namespace. + netNsPath, err := configNetNsPath(m) + if err == nil { + netNS, nsErr := netNsGet(netNsPath, sb.name) + // If we can't load the networking namespace + // because it's closed, we just set the sb netns + // pointer to nil. Otherwise we return an error. + if nsErr != nil && nsErr != errSandboxClosedNetNS { + return nsErr + } + + sb.netns = netNS + } + s.addSandbox(sb) sandboxPath := filepath.Join(s.config.SandboxDir, id) @@ -163,7 +195,7 @@ func (s *Server) loadSandbox(id string) error { if err != nil { return err } - scontainer, err := oci.NewContainer(m.Annotations["ocid/container_id"], cname, sandboxPath, sandboxPath, labels, annotations, nil, nil, id, false) + scontainer, err := oci.NewContainer(m.Annotations["ocid/container_id"], cname, sandboxPath, sandboxPath, sb.netNs(), labels, annotations, nil, nil, id, false) if err != nil { return err } diff --git a/vendor/src/github.com/containernetworking/cni/pkg/ns/README.md b/vendor/src/github.com/containernetworking/cni/pkg/ns/README.md new file mode 100644 index 00000000..99aed9c8 --- /dev/null +++ b/vendor/src/github.com/containernetworking/cni/pkg/ns/README.md @@ -0,0 +1,34 @@ +### Namespaces, Threads, and Go +On Linux each OS thread can have a different network namespace. Go's thread scheduling model switches goroutines between OS threads based on OS thread load and whether the goroutine would block other goroutines. This can result in a goroutine switching network namespaces without notice and lead to errors in your code. + +### Namespace Switching +Switching namespaces with the `ns.Set()` method is not recommended without additional strategies to prevent unexpected namespace changes when your goroutines switch OS threads. + +Go provides the `runtime.LockOSThread()` function to ensure a specific goroutine executes on its current OS thread and prevents any other goroutine from running in that thread until the locked one exits. Careful usage of `LockOSThread()` and goroutines can provide good control over which network namespace a given goroutine executes in. + +For example, you cannot rely on the `ns.Set()` namespace being the current namespace after the `Set()` call unless you do two things. First, the goroutine calling `Set()` must have previously called `LockOSThread()`. Second, you must ensure `runtime.UnlockOSThread()` is not called somewhere in-between. You also cannot rely on the initial network namespace remaining the current network namespace if any other code in your program switches namespaces, unless you have already called `LockOSThread()` in that goroutine. Note that `LockOSThread()` prevents the Go scheduler from optimally scheduling goroutines for best performance, so `LockOSThread()` should only be used in small, isolated goroutines that release the lock quickly. + +### Do() The Recommended Thing +The `ns.Do()` method provides control over network namespaces for you by implementing these strategies. All code dependent on a particular network namespace (including the root namespace) should be wrapped in the `ns.Do()` method to ensure the correct namespace is selected for the duration of your code. For example: + +```go +targetNs, err := ns.NewNS() +if err != nil { + return err +} +err = targetNs.Do(func(hostNs ns.NetNS) error { + dummy := &netlink.Dummy{ + LinkAttrs: netlink.LinkAttrs{ + Name: "dummy0", + }, + } + return netlink.LinkAdd(dummy) +}) +``` + +Note this requirement to wrap every network call is very onerous - any libraries you call might call out to network services such as DNS, and all such calls need to be protected after you call `ns.Do()`. The CNI plugins all exit very soon after calling `ns.Do()` which helps to minimize the problem. + +### Further Reading + - https://github.com/golang/go/wiki/LockOSThread + - http://morsmachine.dk/go-scheduler + - https://github.com/containernetworking/cni/issues/262 diff --git a/vendor/src/github.com/containernetworking/cni/pkg/ns/ns.go b/vendor/src/github.com/containernetworking/cni/pkg/ns/ns.go new file mode 100644 index 00000000..3246ebf3 --- /dev/null +++ b/vendor/src/github.com/containernetworking/cni/pkg/ns/ns.go @@ -0,0 +1,302 @@ +// Copyright 2015 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ns + +import ( + "crypto/rand" + "fmt" + "os" + "path" + "runtime" + "sync" + "syscall" + + "golang.org/x/sys/unix" +) + +type NetNS interface { + // Executes the passed closure in this object's network namespace, + // attempting to restore the original namespace before returning. + // However, since each OS thread can have a different network namespace, + // and Go's thread scheduling is highly variable, callers cannot + // guarantee any specific namespace is set unless operations that + // require that namespace are wrapped with Do(). Also, no code called + // from Do() should call runtime.UnlockOSThread(), or the risk + // of executing code in an incorrect namespace will be greater. See + // https://github.com/golang/go/wiki/LockOSThread for further details. + Do(toRun func(NetNS) error) error + + // Sets the current network namespace to this object's network namespace. + // Note that since Go's thread scheduling is highly variable, callers + // cannot guarantee the requested namespace will be the current namespace + // after this function is called; to ensure this wrap operations that + // require the namespace with Do() instead. + Set() error + + // Returns the filesystem path representing this object's network namespace + Path() string + + // Returns a file descriptor representing this object's network namespace + Fd() uintptr + + // Cleans up this instance of the network namespace; if this instance + // is the last user the namespace will be destroyed + Close() error +} + +type netNS struct { + file *os.File + mounted bool + closed bool +} + +func getCurrentThreadNetNSPath() string { + // /proc/self/ns/net returns the namespace of the main thread, not + // of whatever thread this goroutine is running on. Make sure we + // use the thread's net namespace since the thread is switching around + return fmt.Sprintf("/proc/%d/task/%d/ns/net", os.Getpid(), unix.Gettid()) +} + +// Returns an object representing the current OS thread's network namespace +func GetCurrentNS() (NetNS, error) { + return GetNS(getCurrentThreadNetNSPath()) +} + +const ( + // https://github.com/torvalds/linux/blob/master/include/uapi/linux/magic.h + NSFS_MAGIC = 0x6e736673 + PROCFS_MAGIC = 0x9fa0 +) + +type NSPathNotExistErr struct{ msg string } + +func (e NSPathNotExistErr) Error() string { return e.msg } + +type NSPathNotNSErr struct{ msg string } + +func (e NSPathNotNSErr) Error() string { return e.msg } + +func IsNSorErr(nspath string) error { + stat := syscall.Statfs_t{} + if err := syscall.Statfs(nspath, &stat); err != nil { + if os.IsNotExist(err) { + err = NSPathNotExistErr{msg: fmt.Sprintf("failed to Statfs %q: %v", nspath, err)} + } else { + err = fmt.Errorf("failed to Statfs %q: %v", nspath, err) + } + return err + } + + switch stat.Type { + case PROCFS_MAGIC, NSFS_MAGIC: + return nil + default: + return NSPathNotNSErr{msg: fmt.Sprintf("unknown FS magic on %q: %x", nspath, stat.Type)} + } +} + +// Returns an object representing the namespace referred to by @path +func GetNS(nspath string) (NetNS, error) { + err := IsNSorErr(nspath) + if err != nil { + return nil, err + } + + fd, err := os.Open(nspath) + if err != nil { + return nil, err + } + + return &netNS{file: fd}, nil +} + +// Creates a new persistent network namespace and returns an object +// representing that namespace, without switching to it +func NewNS() (NetNS, error) { + const nsRunDir = "/var/run/netns" + + b := make([]byte, 16) + _, err := rand.Reader.Read(b) + if err != nil { + return nil, fmt.Errorf("failed to generate random netns name: %v", err) + } + + err = os.MkdirAll(nsRunDir, 0755) + if err != nil { + return nil, err + } + + // create an empty file at the mount point + nsName := fmt.Sprintf("cni-%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) + nsPath := path.Join(nsRunDir, nsName) + mountPointFd, err := os.Create(nsPath) + if err != nil { + return nil, err + } + mountPointFd.Close() + + // Ensure the mount point is cleaned up on errors; if the namespace + // was successfully mounted this will have no effect because the file + // is in-use + defer os.RemoveAll(nsPath) + + var wg sync.WaitGroup + wg.Add(1) + + // do namespace work in a dedicated goroutine, so that we can safely + // Lock/Unlock OSThread without upsetting the lock/unlock state of + // the caller of this function + var fd *os.File + go (func() { + defer wg.Done() + runtime.LockOSThread() + + var origNS NetNS + origNS, err = GetNS(getCurrentThreadNetNSPath()) + if err != nil { + return + } + defer origNS.Close() + + // create a new netns on the current thread + err = unix.Unshare(unix.CLONE_NEWNET) + if err != nil { + return + } + defer origNS.Set() + + // bind mount the new netns from the current thread onto the mount point + err = unix.Mount(getCurrentThreadNetNSPath(), nsPath, "none", unix.MS_BIND, "") + if err != nil { + return + } + + fd, err = os.Open(nsPath) + if err != nil { + return + } + })() + wg.Wait() + + if err != nil { + unix.Unmount(nsPath, unix.MNT_DETACH) + return nil, fmt.Errorf("failed to create namespace: %v", err) + } + + return &netNS{file: fd, mounted: true}, nil +} + +func (ns *netNS) Path() string { + return ns.file.Name() +} + +func (ns *netNS) Fd() uintptr { + return ns.file.Fd() +} + +func (ns *netNS) errorIfClosed() error { + if ns.closed { + return fmt.Errorf("%q has already been closed", ns.file.Name()) + } + return nil +} + +func (ns *netNS) Close() error { + if err := ns.errorIfClosed(); err != nil { + return err + } + + if err := ns.file.Close(); err != nil { + return fmt.Errorf("Failed to close %q: %v", ns.file.Name(), err) + } + ns.closed = true + + if ns.mounted { + if err := unix.Unmount(ns.file.Name(), unix.MNT_DETACH); err != nil { + return fmt.Errorf("Failed to unmount namespace %s: %v", ns.file.Name(), err) + } + if err := os.RemoveAll(ns.file.Name()); err != nil { + return fmt.Errorf("Failed to clean up namespace %s: %v", ns.file.Name(), err) + } + ns.mounted = false + } + + return nil +} + +func (ns *netNS) Do(toRun func(NetNS) error) error { + if err := ns.errorIfClosed(); err != nil { + return err + } + + containedCall := func(hostNS NetNS) error { + threadNS, err := GetNS(getCurrentThreadNetNSPath()) + if err != nil { + return fmt.Errorf("failed to open current netns: %v", err) + } + defer threadNS.Close() + + // switch to target namespace + if err = ns.Set(); err != nil { + return fmt.Errorf("error switching to ns %v: %v", ns.file.Name(), err) + } + defer threadNS.Set() // switch back + + return toRun(hostNS) + } + + // save a handle to current network namespace + hostNS, err := GetNS(getCurrentThreadNetNSPath()) + if err != nil { + return fmt.Errorf("Failed to open current namespace: %v", err) + } + defer hostNS.Close() + + var wg sync.WaitGroup + wg.Add(1) + + var innerError error + go func() { + defer wg.Done() + runtime.LockOSThread() + innerError = containedCall(hostNS) + }() + wg.Wait() + + return innerError +} + +func (ns *netNS) Set() error { + if err := ns.errorIfClosed(); err != nil { + return err + } + + if _, _, err := unix.Syscall(unix.SYS_SETNS, ns.Fd(), uintptr(unix.CLONE_NEWNET), 0); err != 0 { + return fmt.Errorf("Error switching to ns %v: %v", ns.file.Name(), err) + } + + return nil +} + +// WithNetNSPath executes the passed closure under the given network +// namespace, restoring the original namespace afterwards. +func WithNetNSPath(nspath string, toRun func(NetNS) error) error { + ns, err := GetNS(nspath) + if err != nil { + return err + } + defer ns.Close() + return ns.Do(toRun) +}