151 lines
3.9 KiB
Go
151 lines
3.9 KiB
Go
|
//go:build linux
|
||
|
// +build linux
|
||
|
|
||
|
package socket
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"os"
|
||
|
"runtime"
|
||
|
|
||
|
"golang.org/x/sync/errgroup"
|
||
|
"golang.org/x/sys/unix"
|
||
|
)
|
||
|
|
||
|
// errNetNSDisabled is the sentinel error reported by netNS operations when the
// system does not have Linux network namespaces enabled.
var errNetNSDisabled = errors.New(
	"socket: Linux network namespaces are not enabled on this system",
)
|
||
|
|
||
|
// withNetNS invokes fn within the context of the network namespace specified by
|
||
|
// fd, while also managing the logic required to safely do so by manipulating
|
||
|
// thread-local state.
|
||
|
func withNetNS(fd int, fn func() (*Conn, error)) (*Conn, error) {
|
||
|
var (
|
||
|
eg errgroup.Group
|
||
|
conn *Conn
|
||
|
)
|
||
|
|
||
|
eg.Go(func() error {
|
||
|
// Retrieve and store the calling OS thread's network namespace so the
|
||
|
// thread can be reassigned to it after creating a socket in another network
|
||
|
// namespace.
|
||
|
runtime.LockOSThread()
|
||
|
|
||
|
ns, err := threadNetNS()
|
||
|
if err != nil {
|
||
|
// No thread-local manipulation, unlock.
|
||
|
runtime.UnlockOSThread()
|
||
|
return err
|
||
|
}
|
||
|
defer ns.Close()
|
||
|
|
||
|
// Beyond this point, the thread's network namespace is poisoned. Do not
|
||
|
// unlock the OS thread until all network namespace manipulation completes
|
||
|
// to avoid returning to the caller with altered thread-local state.
|
||
|
|
||
|
// Assign the current OS thread the goroutine is locked to to the given
|
||
|
// network namespace.
|
||
|
if err := ns.Set(fd); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Attempt Conn creation and unconditionally restore the original namespace.
|
||
|
c, err := fn()
|
||
|
if nerr := ns.Restore(); nerr != nil {
|
||
|
// Failed to restore original namespace. Return an error and allow the
|
||
|
// runtime to terminate the thread.
|
||
|
if err == nil {
|
||
|
_ = c.Close()
|
||
|
}
|
||
|
|
||
|
return nerr
|
||
|
}
|
||
|
|
||
|
// No more thread-local state manipulation; return the new Conn.
|
||
|
runtime.UnlockOSThread()
|
||
|
conn = c
|
||
|
return nil
|
||
|
})
|
||
|
|
||
|
if err := eg.Wait(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return conn, nil
|
||
|
}
|
||
|
|
||
|
// netNS wraps a handle to a network namespace and provides the operations
// needed to switch the calling thread between namespaces.
//
// Callers must invoke runtime.LockOSThread before using a netNS to manipulate
// any network namespace.
type netNS struct {
	// f is the open file that refers to the network namespace. It is left
	// unset when disabled is true.
	f *os.File

	// disabled reports that network namespaces are unavailable on this
	// system, in which case operations either do nothing or return an error.
	disabled bool
}
|
||
|
|
||
|
// threadNetNS constructs a netNS using the network namespace of the calling
|
||
|
// thread. If the namespace is not the default namespace, runtime.LockOSThread
|
||
|
// should be invoked first.
|
||
|
func threadNetNS() (*netNS, error) {
|
||
|
return fileNetNS(fmt.Sprintf("/proc/self/task/%d/ns/net", unix.Gettid()))
|
||
|
}
|
||
|
|
||
|
// fileNetNS opens file and creates a netNS. fileNetNS should only be called
|
||
|
// directly in tests.
|
||
|
func fileNetNS(file string) (*netNS, error) {
|
||
|
f, err := os.Open(file)
|
||
|
switch {
|
||
|
case err == nil:
|
||
|
return &netNS{f: f}, nil
|
||
|
case os.IsNotExist(err):
|
||
|
// Network namespaces are not enabled on this system. Use this signal
|
||
|
// to return errors elsewhere if the caller explicitly asks for a
|
||
|
// network namespace to be set.
|
||
|
return &netNS{disabled: true}, nil
|
||
|
default:
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Close releases the handle to a network namespace.
|
||
|
func (n *netNS) Close() error {
|
||
|
return n.do(func() error { return n.f.Close() })
|
||
|
}
|
||
|
|
||
|
// FD returns a file descriptor which represents the network namespace.
|
||
|
func (n *netNS) FD() int {
|
||
|
if n.disabled {
|
||
|
// No reasonable file descriptor value in this case, so specify a
|
||
|
// non-existent one.
|
||
|
return -1
|
||
|
}
|
||
|
|
||
|
return int(n.f.Fd())
|
||
|
}
|
||
|
|
||
|
// Restore restores the original network namespace for the calling thread.
|
||
|
func (n *netNS) Restore() error {
|
||
|
return n.do(func() error { return n.Set(n.FD()) })
|
||
|
}
|
||
|
|
||
|
// Set sets a new network namespace for the current thread using fd.
|
||
|
func (n *netNS) Set(fd int) error {
|
||
|
return n.do(func() error {
|
||
|
return os.NewSyscallError("setns", unix.Setns(fd, unix.CLONE_NEWNET))
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// do runs fn if network namespaces are enabled on this system.
|
||
|
func (n *netNS) do(fn func() error) error {
|
||
|
if n.disabled {
|
||
|
return errNetNSDisabled
|
||
|
}
|
||
|
|
||
|
return fn()
|
||
|
}
|