package genetlink import ( "syscall" "time" "github.com/mdlayher/netlink" "golang.org/x/net/bpf" ) // Protocol is the netlink protocol constant used to specify generic netlink. const Protocol = 0x10 // unix.NETLINK_GENERIC // A Conn is a generic netlink connection. A Conn can be used to send and // receive generic netlink messages to and from netlink. // // A Conn is safe for concurrent use, but to avoid contention in // high-throughput applications, the caller should almost certainly create a // pool of Conns and distribute them among workers. type Conn struct { // Operating system-specific netlink connection. c *netlink.Conn } // Dial dials a generic netlink connection. Config specifies optional // configuration for the underlying netlink connection. If config is // nil, a default configuration will be used. func Dial(config *netlink.Config) (*Conn, error) { c, err := netlink.Dial(Protocol, config) if err != nil { return nil, err } return NewConn(c), nil } // NewConn creates a Conn that wraps an existing *netlink.Conn for // generic netlink communications. // // NewConn is primarily useful for tests. Most applications should use // Dial instead. func NewConn(c *netlink.Conn) *Conn { return &Conn{c: c} } // Close closes the connection and unblocks any pending read operations. func (c *Conn) Close() error { return c.c.Close() } // GetFamily retrieves a generic netlink family with the specified name. // // If the family does not exist, the error value can be checked using // `errors.Is(err, os.ErrNotExist)`. func (c *Conn) GetFamily(name string) (Family, error) { return c.getFamily(name) } // ListFamilies retrieves all registered generic netlink families. func (c *Conn) ListFamilies() ([]Family, error) { return c.listFamilies() } // JoinGroup joins a netlink multicast group by its ID. func (c *Conn) JoinGroup(group uint32) error { return c.c.JoinGroup(group) } // LeaveGroup leaves a netlink multicast group by its ID. func (c *Conn) LeaveGroup(group uint32) error { return c.c.LeaveGroup(group) } // SetBPF attaches an assembled BPF program to a Conn. func (c *Conn) SetBPF(filter []bpf.RawInstruction) error { return c.c.SetBPF(filter) } // RemoveBPF removes a BPF filter from a Conn. func (c *Conn) RemoveBPF() error { return c.c.RemoveBPF() } // SetOption enables or disables a netlink socket option for the Conn. func (c *Conn) SetOption(option netlink.ConnOption, enable bool) error { return c.c.SetOption(option, enable) } // SetReadBuffer sets the size of the operating system's receive buffer // associated with the Conn. func (c *Conn) SetReadBuffer(bytes int) error { return c.c.SetReadBuffer(bytes) } // SetWriteBuffer sets the size of the operating system's transmit buffer // associated with the Conn. func (c *Conn) SetWriteBuffer(bytes int) error { return c.c.SetWriteBuffer(bytes) } // SyscallConn returns a raw network connection. This implements the // syscall.Conn interface. // // SyscallConn is intended for advanced use cases, such as getting and setting // arbitrary socket options using the netlink socket's file descriptor. // // Once invoked, it is the caller's responsibility to ensure that operations // performed using Conn and the syscall.RawConn do not conflict with // each other. func (c *Conn) SyscallConn() (syscall.RawConn, error) { return c.c.SyscallConn() } // SetDeadline sets the read and write deadlines associated with the connection. func (c *Conn) SetDeadline(t time.Time) error { return c.c.SetDeadline(t) } // SetReadDeadline sets the read deadline associated with the connection. func (c *Conn) SetReadDeadline(t time.Time) error { return c.c.SetReadDeadline(t) } // SetWriteDeadline sets the write deadline associated with the connection. func (c *Conn) SetWriteDeadline(t time.Time) error { return c.c.SetWriteDeadline(t) } // Send sends a single Message to netlink, wrapping it in a netlink.Message // using the specified generic netlink family and flags. On success, Send // returns a copy of the netlink.Message with all parameters populated, for // later validation. func (c *Conn) Send(m Message, family uint16, flags netlink.HeaderFlags) (netlink.Message, error) { nm, err := packMessage(m, family, flags) if err != nil { return netlink.Message{}, err } reqnm, err := c.c.Send(nm) if err != nil { return netlink.Message{}, err } return reqnm, nil } // Receive receives one or more Messages from netlink. The netlink.Messages // used to wrap each Message are available for later validation. func (c *Conn) Receive() ([]Message, []netlink.Message, error) { msgs, err := c.c.Receive() if err != nil { return nil, nil, err } gmsgs, err := unpackMessages(msgs) if err != nil { return nil, nil, err } return gmsgs, msgs, nil } // Execute sends a single Message to netlink using Send, receives one or more // replies using Receive, and then checks the validity of the replies against // the request using netlink.Validate. // // Execute acquires a lock for the duration of the function call which blocks // concurrent calls to Send and Receive, in order to ensure consistency between // generic netlink request/reply messages. // // See the documentation of Send, Receive, and netlink.Validate for details // about each function. func (c *Conn) Execute(m Message, family uint16, flags netlink.HeaderFlags) ([]Message, error) { nm, err := packMessage(m, family, flags) if err != nil { return nil, err } // Locking behavior handled by netlink.Conn.Execute. msgs, err := c.c.Execute(nm) if err != nil { return nil, err } return unpackMessages(msgs) } // packMessage packs a generic netlink Message into a netlink.Message with the // appropriate generic netlink family and netlink flags. func packMessage(m Message, family uint16, flags netlink.HeaderFlags) (netlink.Message, error) { nm := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(family), Flags: flags, }, } mb, err := m.MarshalBinary() if err != nil { return netlink.Message{}, err } nm.Data = mb return nm, nil } // unpackMessages unpacks generic netlink Messages from a slice of netlink.Messages. func unpackMessages(msgs []netlink.Message) ([]Message, error) { gmsgs := make([]Message, 0, len(msgs)) for _, nm := range msgs { var gm Message if err := (&gm).UnmarshalBinary(nm.Data); err != nil { return nil, err } gmsgs = append(gmsgs, gm) } return gmsgs, nil }