296 lines
8.3 KiB
Go
296 lines
8.3 KiB
Go
package wgwindows
|
|
|
|
import (
|
|
"net"
|
|
"os"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"golang.org/x/sys/windows"
|
|
|
|
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
|
|
"golang.zx2c4.com/wireguard/wgctrl/internal/wgwindows/internal/ioctl"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
// Compile-time check that *Client satisfies wginternal.Client.
var _ wginternal.Client = (*Client)(nil)
|
|
|
|
// A Client provides access to WireGuardNT ioctl information.
type Client struct {
	// cachedAdapters maps a WireGuard adapter name to the device instance ID
	// found for it; refreshInstanceIdCache replaces the whole map on each run.
	cachedAdapters map[string]string

	// lastLenGuess holds the byte count reported by the most recent
	// successful IoctlGet and seeds the buffer size of the next Device call.
	lastLenGuess uint32
}
|
|
|
|
var (
|
|
deviceClassNetGUID = windows.GUID{0x4d36e972, 0xe325, 0x11ce, [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
|
|
deviceInterfaceNetGUID = windows.GUID{0xcac88484, 0x7515, 0x4c03, [8]byte{0x82, 0xe6, 0x71, 0xa8, 0x7a, 0xba, 0xc3, 0x61}}
|
|
devpkeyWgName = windows.DEVPROPKEY{
|
|
FmtID: windows.DEVPROPGUID{0x65726957, 0x7547, 0x7261, [8]byte{0x64, 0x4e, 0x61, 0x6d, 0x65, 0x4b, 0x65, 0x79}},
|
|
PID: windows.DEVPROPID_FIRST_USABLE + 1,
|
|
}
|
|
)
|
|
|
|
var enumerator = `SWD\WireGuard`
|
|
|
|
func init() {
|
|
if maj, min, _ := windows.RtlGetNtVersionNumbers(); (maj == 6 && min <= 1) || maj < 6 {
|
|
enumerator = `ROOT\WIREGUARD`
|
|
}
|
|
}
|
|
|
|
func (c *Client) refreshInstanceIdCache() error {
|
|
cachedAdapters := make(map[string]string, 5)
|
|
devInfo, err := windows.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, 0, windows.DIGCF_PRESENT, 0, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer windows.SetupDiDestroyDeviceInfoList(devInfo)
|
|
for i := 0; ; i++ {
|
|
devInfoData, err := windows.SetupDiEnumDeviceInfo(devInfo, i)
|
|
if err != nil {
|
|
if err == windows.ERROR_NO_MORE_ITEMS {
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
prop, err := windows.SetupDiGetDeviceProperty(devInfo, devInfoData, &devpkeyWgName)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
adapterName, ok := prop.(string)
|
|
if !ok {
|
|
continue
|
|
}
|
|
var status, problemCode uint32
|
|
ret := windows.CM_Get_DevNode_Status(&status, &problemCode, devInfoData.DevInst, 0)
|
|
if ret != nil || (status&windows.DN_DRIVER_LOADED|windows.DN_STARTED) != windows.DN_DRIVER_LOADED|windows.DN_STARTED {
|
|
continue
|
|
}
|
|
instanceId, err := windows.SetupDiGetDeviceInstanceId(devInfo, devInfoData)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
cachedAdapters[adapterName] = instanceId
|
|
}
|
|
c.cachedAdapters = cachedAdapters
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) interfaceHandle(name string) (windows.Handle, error) {
|
|
instanceId, ok := c.cachedAdapters[name]
|
|
if !ok {
|
|
err := c.refreshInstanceIdCache()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
instanceId, ok = c.cachedAdapters[name]
|
|
if !ok {
|
|
return 0, os.ErrNotExist
|
|
}
|
|
}
|
|
interfaces, err := windows.CM_Get_Device_Interface_List(instanceId, &deviceInterfaceNetGUID, windows.CM_GET_DEVICE_INTERFACE_LIST_PRESENT)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
interface16, err := windows.UTF16PtrFromString(interfaces[0])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return windows.CreateFile(interface16, windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0)
|
|
}
|
|
|
|
// Devices implements wginternal.Client.
|
|
func (c *Client) Devices() ([]*wgtypes.Device, error) {
|
|
err := c.refreshInstanceIdCache()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ds := make([]*wgtypes.Device, 0, len(c.cachedAdapters))
|
|
for name := range c.cachedAdapters {
|
|
d, err := c.Device(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ds = append(ds, d)
|
|
}
|
|
return ds, nil
|
|
}
|
|
|
|
// New creates a new Client
|
|
func New() *Client {
|
|
return &Client{}
|
|
}
|
|
|
|
// Close implements wginternal.Client.
|
|
func (c *Client) Close() error {
|
|
return nil
|
|
}
|
|
|
|
// Device implements wginternal.Client.
|
|
func (c *Client) Device(name string) (*wgtypes.Device, error) {
|
|
handle, err := c.interfaceHandle(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer windows.CloseHandle(handle)
|
|
|
|
size := c.lastLenGuess
|
|
if size == 0 {
|
|
size = 512
|
|
}
|
|
var buf []byte
|
|
for {
|
|
buf = make([]byte, size)
|
|
err = windows.DeviceIoControl(handle, ioctl.IoctlGet, nil, 0, &buf[0], size, &size, nil)
|
|
if err == windows.ERROR_MORE_DATA {
|
|
continue
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
break
|
|
}
|
|
c.lastLenGuess = size
|
|
interfaze := (*ioctl.Interface)(unsafe.Pointer(&buf[0]))
|
|
|
|
device := wgtypes.Device{Type: wgtypes.WindowsKernel, Name: name}
|
|
if interfaze.Flags&ioctl.InterfaceHasPrivateKey != 0 {
|
|
device.PrivateKey = interfaze.PrivateKey
|
|
}
|
|
if interfaze.Flags&ioctl.InterfaceHasPublicKey != 0 {
|
|
device.PublicKey = interfaze.PublicKey
|
|
}
|
|
if interfaze.Flags&ioctl.InterfaceHasListenPort != 0 {
|
|
device.ListenPort = int(interfaze.ListenPort)
|
|
}
|
|
var p *ioctl.Peer
|
|
for i := uint32(0); i < interfaze.PeerCount; i++ {
|
|
if p == nil {
|
|
p = interfaze.FirstPeer()
|
|
} else {
|
|
p = p.NextPeer()
|
|
}
|
|
peer := wgtypes.Peer{}
|
|
if p.Flags&ioctl.PeerHasPublicKey != 0 {
|
|
peer.PublicKey = p.PublicKey
|
|
}
|
|
if p.Flags&ioctl.PeerHasPresharedKey != 0 {
|
|
peer.PresharedKey = p.PresharedKey
|
|
}
|
|
if p.Flags&ioctl.PeerHasEndpoint != 0 {
|
|
peer.Endpoint = &net.UDPAddr{IP: p.Endpoint.IP(), Port: int(p.Endpoint.Port())}
|
|
}
|
|
if p.Flags&ioctl.PeerHasPersistentKeepalive != 0 {
|
|
peer.PersistentKeepaliveInterval = time.Duration(p.PersistentKeepalive) * time.Second
|
|
}
|
|
if p.Flags&ioctl.PeerHasProtocolVersion != 0 {
|
|
peer.ProtocolVersion = int(p.ProtocolVersion)
|
|
}
|
|
peer.TransmitBytes = int64(p.TxBytes)
|
|
peer.ReceiveBytes = int64(p.RxBytes)
|
|
if p.LastHandshake != 0 {
|
|
peer.LastHandshakeTime = time.Unix(0, int64((p.LastHandshake-116444736000000000)*100))
|
|
}
|
|
var a *ioctl.AllowedIP
|
|
for j := uint32(0); j < p.AllowedIPsCount; j++ {
|
|
if a == nil {
|
|
a = p.FirstAllowedIP()
|
|
} else {
|
|
a = a.NextAllowedIP()
|
|
}
|
|
var ip net.IP
|
|
var bits int
|
|
if a.AddressFamily == windows.AF_INET {
|
|
ip = a.Address[:4]
|
|
bits = 32
|
|
} else if a.AddressFamily == windows.AF_INET6 {
|
|
ip = a.Address[:16]
|
|
bits = 128
|
|
}
|
|
peer.AllowedIPs = append(peer.AllowedIPs, net.IPNet{
|
|
IP: ip,
|
|
Mask: net.CIDRMask(int(a.Cidr), bits),
|
|
})
|
|
}
|
|
device.Peers = append(device.Peers, peer)
|
|
}
|
|
return &device, nil
|
|
}
|
|
|
|
// ConfigureDevice implements wginternal.Client.
|
|
func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
|
|
handle, err := c.interfaceHandle(name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer windows.CloseHandle(handle)
|
|
|
|
preallocation := unsafe.Sizeof(ioctl.Interface{}) + uintptr(len(cfg.Peers))*unsafe.Sizeof(ioctl.Peer{})
|
|
for i := range cfg.Peers {
|
|
preallocation += uintptr(len(cfg.Peers[i].AllowedIPs)) * unsafe.Sizeof(ioctl.AllowedIP{})
|
|
}
|
|
var b ioctl.ConfigBuilder
|
|
b.Preallocate(uint32(preallocation))
|
|
interfaze := &ioctl.Interface{PeerCount: uint32(len(cfg.Peers))}
|
|
if cfg.ReplacePeers {
|
|
interfaze.Flags |= ioctl.InterfaceReplacePeers
|
|
}
|
|
if cfg.PrivateKey != nil {
|
|
interfaze.PrivateKey = *cfg.PrivateKey
|
|
interfaze.Flags |= ioctl.InterfaceHasPrivateKey
|
|
}
|
|
if cfg.ListenPort != nil {
|
|
interfaze.ListenPort = uint16(*cfg.ListenPort)
|
|
interfaze.Flags |= ioctl.InterfaceHasListenPort
|
|
}
|
|
b.AppendInterface(interfaze)
|
|
for i := range cfg.Peers {
|
|
peer := &ioctl.Peer{
|
|
Flags: ioctl.PeerHasPublicKey,
|
|
PublicKey: cfg.Peers[i].PublicKey,
|
|
AllowedIPsCount: uint32(len(cfg.Peers[i].AllowedIPs)),
|
|
}
|
|
if cfg.Peers[i].ReplaceAllowedIPs {
|
|
peer.Flags |= ioctl.PeerReplaceAllowedIPs
|
|
}
|
|
if cfg.Peers[i].UpdateOnly {
|
|
peer.Flags |= ioctl.PeerUpdateOnly
|
|
}
|
|
if cfg.Peers[i].Remove {
|
|
peer.Flags |= ioctl.PeerRemove
|
|
}
|
|
if cfg.Peers[i].PresharedKey != nil {
|
|
peer.Flags |= ioctl.PeerHasPresharedKey
|
|
peer.PresharedKey = *cfg.Peers[i].PresharedKey
|
|
}
|
|
if cfg.Peers[i].Endpoint != nil {
|
|
peer.Flags |= ioctl.PeerHasEndpoint
|
|
peer.Endpoint.SetIP(cfg.Peers[i].Endpoint.IP, uint16(cfg.Peers[i].Endpoint.Port))
|
|
}
|
|
if cfg.Peers[i].PersistentKeepaliveInterval != nil {
|
|
peer.Flags |= ioctl.PeerHasPersistentKeepalive
|
|
peer.PersistentKeepalive = uint16(*cfg.Peers[i].PersistentKeepaliveInterval / time.Second)
|
|
}
|
|
b.AppendPeer(peer)
|
|
for j := range cfg.Peers[i].AllowedIPs {
|
|
var family ioctl.AddressFamily
|
|
var ip net.IP
|
|
if ip = cfg.Peers[i].AllowedIPs[j].IP.To4(); ip != nil {
|
|
family = windows.AF_INET
|
|
} else if ip = cfg.Peers[i].AllowedIPs[j].IP.To16(); ip != nil {
|
|
family = windows.AF_INET6
|
|
} else {
|
|
ip = cfg.Peers[i].AllowedIPs[j].IP
|
|
}
|
|
cidr, _ := cfg.Peers[i].AllowedIPs[j].Mask.Size()
|
|
a := &ioctl.AllowedIP{
|
|
AddressFamily: family,
|
|
Cidr: uint8(cidr),
|
|
}
|
|
copy(a.Address[:], ip)
|
|
b.AppendAllowedIP(a)
|
|
}
|
|
}
|
|
interfaze, size := b.Interface()
|
|
return windows.DeviceIoControl(handle, ioctl.IoctlSet, nil, 0, (*byte)(unsafe.Pointer(interfaze)), size, &size, nil)
|
|
}
|