keksvpn/vendor/golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/configure_linux.go
2022-02-27 04:22:11 +01:00

293 lines
7.9 KiB
Go

//go:build linux
// +build linux
package wglinux
import (
"encoding/binary"
"fmt"
"net"
"unsafe"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nlenc"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// configAttrs creates the required encoded netlink attributes to configure
// the device specified by name using the non-nil fields in cfg.
//
// An error is returned if cfg.ListenPort cannot be represented in the 16-bit
// listen port attribute, or if encoding the attributes (including any peer)
// fails.
func configAttrs(name string, cfg wgtypes.Config) ([]byte, error) {
	ae := netlink.NewAttributeEncoder()
	ae.String(unix.WGDEVICE_A_IFNAME, name)

	if cfg.PrivateKey != nil {
		ae.Bytes(unix.WGDEVICE_A_PRIVATE_KEY, (*cfg.PrivateKey)[:])
	}

	if cfg.ListenPort != nil {
		port := *cfg.ListenPort
		// The attribute is a uint16; without this check an out-of-range value
		// would be silently truncated to an unrelated port.
		if port < 0 || port > 65535 {
			return nil, fmt.Errorf("wglinux: invalid listen port: %d", port)
		}
		ae.Uint16(unix.WGDEVICE_A_LISTEN_PORT, uint16(port))
	}

	if cfg.FirewallMark != nil {
		// Intentionally converted without a range check: the wrapping uint32
		// conversion keeps marks >= 1<<31 working where int is 32 bits.
		ae.Uint32(unix.WGDEVICE_A_FWMARK, uint32(*cfg.FirewallMark))
	}

	if cfg.ReplacePeers {
		ae.Uint32(unix.WGDEVICE_A_FLAGS, unix.WGDEVICE_F_REPLACE_PEERS)
	}

	// Only apply peer attributes if necessary.
	if len(cfg.Peers) > 0 {
		ae.Nested(unix.WGDEVICE_A_PEERS, func(nae *netlink.AttributeEncoder) error {
			// Netlink arrays use type as an array index.
			for i, p := range cfg.Peers {
				nae.Nested(uint16(i), encodePeer(p))
			}

			return nil
		})
	}

	return ae.Encode()
}
const (
	// ipBatchChunk is the tunable allowed IP batch limit: a configuration
	// whose peers carry more allowed IPs than this in total is batched, and
	// no single peer entry in a batch holds more than this many allowed IPs.
	//
	// Since the encoded size of a peer is not known in advance, a deliberately
	// small value is used. Tests in this package also rely on this constant,
	// so change it with care.
	ipBatchChunk = 256

	// peerBatchChunk is the largest number of peers a configuration may list
	// before it is split into batches.
	peerBatchChunk = 32
)
// shouldBatch reports whether cfg is large enough to be split into several
// configurations: either it lists more than peerBatchChunk peers, or its peers
// carry more than ipBatchChunk allowed IPs in total.
func shouldBatch(cfg wgtypes.Config) bool {
	if len(cfg.Peers) > peerBatchChunk {
		return true
	}

	// The running total only grows, so stop as soon as the limit is passed.
	total := 0
	for i := range cfg.Peers {
		total += len(cfg.Peers[i].AllowedIPs)
		if total > ipBatchChunk {
			return true
		}
	}
	return false
}
// buildBatches splits cfg into a sequence of configurations, each of which can
// be applied on its own. A configuration that does not need batching (see
// shouldBatch) is returned unchanged as a single-element slice.
//
// When splitting, every batch carries exactly one peer entry with at most
// ipBatchChunk allowed IPs (copied from the original). A peer with more IPs is
// spread over consecutive batches, and a peer with no IPs still gets one
// batch. Device-level fields of cfg are repeated in each batch, except that
// ReplacePeers is kept only on the first batch so later batches do not
// discard work done by earlier ones.
func buildBatches(cfg wgtypes.Config) []wgtypes.Config {
	if !shouldBatch(cfg) {
		return []wgtypes.Config{cfg}
	}

	// Every batch starts from cfg without its peers.
	template := cfg
	template.Peers = nil

	// Peers that already appeared in an earlier batch. Their one-time fields
	// (notably ReplaceAllowedIPs) must not be repeated, or a later chunk would
	// wipe out the allowed IPs added by the previous ones.
	seen := make(map[wgtypes.Key]struct{})

	out := make([]wgtypes.Config, 0, len(cfg.Peers))
	for _, peer := range cfg.Peers {
		rest := peer.AllowedIPs
		for {
			n := len(rest)
			if n > ipBatchChunk {
				n = ipBatchChunk
			}
			chunk := make([]net.IPNet, n)
			copy(chunk, rest[:n])
			rest = rest[n:]

			entry := wgtypes.PeerConfig{
				// The public key identifies the peer in every chunk.
				PublicKey: peer.PublicKey,
				// Repeated on every chunk so each message is processed the
				// same way.
				UpdateOnly: peer.UpdateOnly,
				Remove:     peer.Remove,
				AllowedIPs: chunk,
			}
			if _, ok := seen[peer.PublicKey]; !ok {
				seen[peer.PublicKey] = struct{}{}
				entry.PresharedKey = peer.PresharedKey
				entry.Endpoint = peer.Endpoint
				entry.PersistentKeepaliveInterval = peer.PersistentKeepaliveInterval
				// Must only ride on the first chunk, otherwise appending the
				// remaining chunks would replace the earlier IPs.
				entry.ReplaceAllowedIPs = peer.ReplaceAllowedIPs
			}

			batch := template
			// Only the very first batch may replace the device's peers.
			batch.ReplacePeers = template.ReplacePeers && len(out) == 0
			batch.Peers = []wgtypes.PeerConfig{entry}
			out = append(out, batch)

			if len(rest) == 0 {
				break
			}
		}
	}

	return out
}
// encodePeer returns a function to encode PeerConfig nested attributes.
//
// The returned function returns an error if p.PersistentKeepaliveInterval is
// negative or exceeds 65535 seconds (the range of the 16-bit keepalive
// attribute). Sub-second precision of the interval is truncated. Failures
// from encoding the endpoint or allowed IPs surface through the encoder.
func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error {
	return func(ae *netlink.AttributeEncoder) error {
		ae.Bytes(unix.WGPEER_A_PUBLIC_KEY, p.PublicKey[:])

		// Flags are stored in a single attribute.
		var flags uint32
		if p.Remove {
			flags |= unix.WGPEER_F_REMOVE_ME
		}
		if p.ReplaceAllowedIPs {
			flags |= unix.WGPEER_F_REPLACE_ALLOWEDIPS
		}
		if p.UpdateOnly {
			flags |= unix.WGPEER_F_UPDATE_ONLY
		}
		if flags != 0 {
			ae.Uint32(unix.WGPEER_A_FLAGS, flags)
		}

		if p.PresharedKey != nil {
			ae.Bytes(unix.WGPEER_A_PRESHARED_KEY, (*p.PresharedKey)[:])
		}

		if p.Endpoint != nil {
			ae.Do(unix.WGPEER_A_ENDPOINT, encodeSockaddr(*p.Endpoint))
		}

		if p.PersistentKeepaliveInterval != nil {
			interval := *p.PersistentKeepaliveInterval
			secs := interval.Seconds()
			// Converting an out-of-range float to uint16 is
			// implementation-defined in Go (a negative duration could turn
			// into a huge interval), so reject such values explicitly.
			if secs < 0 || secs > 65535 {
				return fmt.Errorf("wglinux: invalid persistent keepalive interval: %s", interval)
			}
			ae.Uint16(unix.WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, uint16(secs))
		}

		// Only apply allowed IPs if necessary.
		if len(p.AllowedIPs) > 0 {
			ae.Nested(unix.WGPEER_A_ALLOWEDIPS, encodeAllowedIPs(p.AllowedIPs))
		}

		return nil
	}
}
// encodeSockaddr returns a function which encodes a net.UDPAddr as raw
// sockaddr_in or sockaddr_in6 bytes.
//
// The returned function returns an error if endpoint.IP is not a 4- or
// 16-byte address, or if endpoint.Port lies outside 0-65535. IPv4 and
// IPv4-mapped IPv6 addresses are encoded as sockaddr_in. endpoint.Zone is
// not encoded.
func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) {
	return func() ([]byte, error) {
		if !isValidIP(endpoint.IP) {
			return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String())
		}

		// The sockaddr port field is 16 bits wide; without this check an
		// out-of-range port would be silently truncated to a different one.
		if endpoint.Port < 0 || endpoint.Port > 65535 {
			return nil, fmt.Errorf("wglinux: invalid endpoint port: %d", endpoint.Port)
		}

		// Is this an IPv6 address?
		if isIPv6(endpoint.IP) {
			var addr [16]byte
			copy(addr[:], endpoint.IP.To16())

			sa := unix.RawSockaddrInet6{
				Family: unix.AF_INET6,
				Port:   sockaddrPort(endpoint.Port),
				Addr:   addr,
			}

			// Emit the raw in-memory bytes of the struct.
			return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil
		}

		// IPv4 address handling.
		var addr [4]byte
		copy(addr[:], endpoint.IP.To4())

		sa := unix.RawSockaddrInet4{
			Family: unix.AF_INET,
			Port:   sockaddrPort(endpoint.Port),
			Addr:   addr,
		}

		// Emit the raw in-memory bytes of the struct.
		return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil
	}
}
// encodeAllowedIPs returns a function to encode allowed IP nested attributes.
//
// Each entry is encoded with its address family, address bytes (4 bytes for
// IPv4, including IPv4-mapped IPv6, otherwise 16) and prefix length. The
// returned function returns an error if an address is not 4 or 16 bytes long,
// or if a non-empty mask is not a canonical prefix mask. An empty mask is
// encoded as a /0 prefix.
func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error {
	return func(ae *netlink.AttributeEncoder) error {
		for i, ipn := range ipns {
			if !isValidIP(ipn.IP) {
				return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String())
			}

			// IPMask.Size reports (0, 0) for a non-canonical mask such as
			// 255.0.255.0. Encoding that as a /0 prefix would silently turn
			// the entry into a catch-all, so reject it instead.
			ones, bits := ipn.Mask.Size()
			if bits == 0 && len(ipn.Mask) != 0 {
				return fmt.Errorf("wglinux: invalid allowed IP mask: %s", ipn.Mask.String())
			}

			family := uint16(unix.AF_INET6)
			ip := ipn.IP
			if !isIPv6(ip) {
				// Make sure address is 4 bytes if IPv4.
				family = unix.AF_INET
				ip = ip.To4()
			}

			// Netlink arrays use type as an array index.
			ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error {
				nae.Uint16(unix.WGALLOWEDIP_A_FAMILY, family)
				nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ip)
				nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones))
				return nil
			})
		}

		return nil
	}
}
// isValidIP reports whether ip has the length of an IPv4 (4-byte) or IPv6
// (16-byte) address, i.e. whether ip.To16 would return a non-nil result.
func isValidIP(ip net.IP) bool {
	switch len(ip) {
	case net.IPv4len, net.IPv6len:
		return true
	default:
		return false
	}
}

// isIPv6 reports whether ip is a valid address that is not IPv4, in either
// its 4-byte form or its IPv4-mapped 16-byte form.
func isIPv6(ip net.IP) bool {
	if ip.To4() != nil {
		return false
	}
	return isValidIP(ip)
}
// sockaddrPort converts port to network byte order for the Port field of a
// raw sockaddr structure: stored in host memory, the returned value's bytes
// are the port's big-endian bytes.
func sockaddrPort(port int) uint16 {
	wire := make([]byte, 2)
	binary.BigEndian.PutUint16(wire, uint16(port))
	// Read the big-endian bytes back using the host's native byte order.
	return nlenc.Uint16(wire)
}