diff --git a/conf/admin_windows.go b/conf/admin_windows.go new file mode 100644 index 0000000..a135aa6 --- /dev/null +++ b/conf/admin_windows.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import "golang.org/x/sys/windows/registry" + +const adminRegKey = `Software\WireGuard` + +var adminKey registry.Key + +func openAdminKey() (registry.Key, error) { + if adminKey != 0 { + return adminKey, nil + } + var err error + adminKey, err = registry.OpenKey(registry.LOCAL_MACHINE, adminRegKey, registry.QUERY_VALUE|registry.WOW64_64KEY) + if err != nil { + return 0, err + } + return adminKey, nil +} + +func AdminBool(name string) bool { + key, err := openAdminKey() + if err != nil { + return false + } + val, _, err := key.GetIntegerValue(name) + if err != nil { + return false + } + return val != 0 +} diff --git a/conf/config.go b/conf/config.go new file mode 100644 index 0000000..0a6c6ab --- /dev/null +++ b/conf/config.go @@ -0,0 +1,252 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "encoding/hex" + "fmt" + "net" + "strings" + "time" + + "golang.org/x/crypto/curve25519" + + "golang.zx2c4.com/wireguard/windows/l18n" +) + +const KeyLength = 32 + +type IPCidr struct { + IP net.IP + Cidr uint8 +} + +type Endpoint struct { + Host string + Port uint16 +} + +type Key [KeyLength]byte +type HandshakeTime time.Duration +type Bytes uint64 + +type Config struct { + Name string + Interface Interface + Peers []Peer +} + +type Interface struct { + PrivateKey Key + Addresses []IPCidr + ListenPort uint16 + MTU uint16 + DNS []net.IP + DNSSearch []string + PreUp string + PostUp string + PreDown string + PostDown string +} + +type Peer struct { + PublicKey Key + PresharedKey Key + AllowedIPs []IPCidr + Endpoint Endpoint + PersistentKeepalive uint16 + + RxBytes Bytes + TxBytes Bytes + LastHandshakeTime HandshakeTime +} + +func (r *IPCidr) String() string { + return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr) +} + +func (r *IPCidr) Bits() uint8 { + if r.IP.To4() != nil { + return 32 + } + return 128 +} + +func (r *IPCidr) IPNet() net.IPNet { + return net.IPNet{ + IP: r.IP, + Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())), + } +} + +func (r *IPCidr) MaskSelf() { + bits := int(r.Bits()) + mask := net.CIDRMask(int(r.Cidr), bits) + for i := 0; i < bits/8; i++ { + r.IP[i] &= mask[i] + } +} + +func (e *Endpoint) String() string { + if strings.IndexByte(e.Host, ':') > 0 { + return fmt.Sprintf("[%s]:%d", e.Host, e.Port) + } + return fmt.Sprintf("%s:%d", e.Host, e.Port) +} + +func (e *Endpoint) IsEmpty() bool { + return len(e.Host) == 0 +} + +func (k *Key) String() string { + return base64.StdEncoding.EncodeToString(k[:]) +} + +func (k *Key) HexString() string { + return hex.EncodeToString(k[:]) +} + +func (k *Key) IsZero() bool { + var zeros Key + return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 +} + +func (k *Key) Public() *Key { + var p [KeyLength]byte + curve25519.ScalarBaseMult(&p, (*[KeyLength]byte)(k)) + return (*Key)(&p) +} + +func NewPresharedKey() (*Key, error) { + var k [KeyLength]byte + _, err := rand.Read(k[:]) + if err != nil { + return nil, err + } + return (*Key)(&k), nil +} + +func NewPrivateKey() (*Key, error) { + k, err := NewPresharedKey() + if err != nil { + return nil, err + } + k[0] &= 248 + k[31] = (k[31] & 127) | 64 + return k, nil +} + +func NewPrivateKeyFromString(b64 string) (*Key, error) { + return parseKeyBase64(b64) +} + +func (t HandshakeTime) IsEmpty() bool { + return t == HandshakeTime(0) +} + +func (t HandshakeTime) String() string { + u := time.Unix(0, 0).Add(time.Duration(t)).Unix() + n := time.Now().Unix() + if u == n { + return l18n.Sprintf("Now") + } else if u > n { + return l18n.Sprintf("System clock wound backward!") + } + left := n - u + years := left / (365 * 24 * 60 * 60) + left = left % (365 * 24 * 60 * 60) + days := left / (24 * 60 * 60) + left = left % (24 * 60 * 60) + hours := left / (60 * 60) + left = left % (60 * 60) + minutes := left / 60 + seconds := left % 60 + s := make([]string, 0, 5) + if years > 0 { + s = append(s, l18n.Sprintf("%d year(s)", years)) + } + if days > 0 { + s = append(s, l18n.Sprintf("%d day(s)", days)) + } + if hours > 0 { + s = append(s, l18n.Sprintf("%d hour(s)", hours)) + } + if minutes > 0 { + s = append(s, l18n.Sprintf("%d minute(s)", minutes)) + } + if seconds > 0 { + s = append(s, l18n.Sprintf("%d second(s)", seconds)) + } + timestamp := strings.Join(s, l18n.UnitSeparator()) + return l18n.Sprintf("%s ago", timestamp) +} + +func (b Bytes) String() string { + if b < 1024 { + return l18n.Sprintf("%d\u00a0B", b) + } else if b < 1024*1024 { + return l18n.Sprintf("%.2f\u00a0KiB", float64(b)/1024) + } else if b < 1024*1024*1024 { + return l18n.Sprintf("%.2f\u00a0MiB", float64(b)/(1024*1024)) + } else if b < 1024*1024*1024*1024 { + return l18n.Sprintf("%.2f\u00a0GiB", float64(b)/(1024*1024*1024)) + } + return l18n.Sprintf("%.2f\u00a0TiB", float64(b)/(1024*1024*1024)/1024) +} + +func (conf *Config) DeduplicateNetworkEntries() { + m := make(map[string]bool, len(conf.Interface.Addresses)) + i := 0 + for _, addr := range conf.Interface.Addresses { + s := addr.String() + if m[s] { + continue + } + m[s] = true + conf.Interface.Addresses[i] = addr + i++ + } + conf.Interface.Addresses = conf.Interface.Addresses[:i] + + m = make(map[string]bool, len(conf.Interface.DNS)) + i = 0 + for _, addr := range conf.Interface.DNS { + s := addr.String() + if m[s] { + continue + } + m[s] = true + conf.Interface.DNS[i] = addr + i++ + } + conf.Interface.DNS = conf.Interface.DNS[:i] + + for _, peer := range conf.Peers { + m = make(map[string]bool, len(peer.AllowedIPs)) + i = 0 + for _, addr := range peer.AllowedIPs { + s := addr.String() + if m[s] { + continue + } + m[s] = true + peer.AllowedIPs[i] = addr + i++ + } + peer.AllowedIPs = peer.AllowedIPs[:i] + } +} + +func (conf *Config) Redact() { + conf.Interface.PrivateKey = Key{} + for i := range conf.Peers { + conf.Peers[i].PublicKey = Key{} + conf.Peers[i].PresharedKey = Key{} + } +} diff --git a/conf/dnsresolver_windows.go b/conf/dnsresolver_windows.go new file mode 100644 index 0000000..b17be84 --- /dev/null +++ b/conf/dnsresolver_windows.go @@ -0,0 +1,90 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "fmt" + "log" + "net" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +//sys internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) = wininet.InternetGetConnectedState + +func resolveHostname(name string) (resolvedIPString string, err error) { + maxTries := 10 + systemJustBooted := windows.DurationSinceBoot() <= time.Minute*4 + if systemJustBooted { + maxTries *= 4 + } + for i := 0; i < maxTries; i++ { + if i > 0 { + time.Sleep(time.Second * 4) + } + resolvedIPString, err = resolveHostnameOnce(name) + if err == nil { + return + } + if err == windows.WSATRY_AGAIN { + log.Printf("Temporary DNS error when resolving %s, sleeping for 4 seconds", name) + continue + } + var state uint32 + if err == windows.WSAHOST_NOT_FOUND && systemJustBooted && !internetGetConnectedState(&state, 0) { + log.Printf("Host not found when resolving %s, but no Internet connection available, sleeping for 4 seconds", name) + continue + } + return + } + return +} + +func resolveHostnameOnce(name string) (resolvedIPString string, err error) { + hints := windows.AddrinfoW{ + Family: windows.AF_UNSPEC, + Socktype: windows.SOCK_DGRAM, + Protocol: windows.IPPROTO_IP, + } + var result *windows.AddrinfoW + name16, err := windows.UTF16PtrFromString(name) + if err != nil { + return + } + err = windows.GetAddrInfoW(name16, nil, &hints, &result) + if err != nil { + return + } + if result == nil { + err = windows.WSAHOST_NOT_FOUND + return + } + defer windows.FreeAddrInfoW(result) + ipv6 := "" + for ; result != nil; result = result.Next { + switch result.Family { + case windows.AF_INET: + return (net.IP)((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr[:]).String(), nil + case windows.AF_INET6: + if len(ipv6) != 0 { + continue + } + a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(result.Addr)) + ipv6 = (net.IP)(a.Addr[:]).String() + if a.Scope_id != 0 { + ipv6 += fmt.Sprintf("%%%d", a.Scope_id) + } + } + } + if len(ipv6) != 0 { + return ipv6, nil + } + err = windows.WSAHOST_NOT_FOUND + return +} diff --git a/conf/filewriter_windows.go b/conf/filewriter_windows.go new file mode 100644 index 0000000..b0fca73 --- /dev/null +++ b/conf/filewriter_windows.go @@ -0,0 +1,89 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "crypto/rand" + "encoding/hex" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/windows" +) + +var encryptedFileSd unsafe.Pointer + +func randomFileName() string { + var randBytes [32]byte + _, err := rand.Read(randBytes[:]) + if err != nil { + panic(err) + } + return hex.EncodeToString(randBytes[:]) + ".tmp" +} + +func writeLockedDownFile(destination string, overwrite bool, contents []byte) error { + var err error + sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{}))} + sa.SecurityDescriptor = (*windows.SECURITY_DESCRIPTOR)(atomic.LoadPointer(&encryptedFileSd)) + if sa.SecurityDescriptor == nil { + sa.SecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;;FA;;;SY)(A;;SD;;;BA)") + if err != nil { + return err + } + atomic.StorePointer(&encryptedFileSd, unsafe.Pointer(sa.SecurityDescriptor)) + } + destination16, err := windows.UTF16FromString(destination) + if err != nil { + return err + } + tmpDestination := randomFileName() + tmpDestination16, err := windows.UTF16PtrFromString(tmpDestination) + if err != nil { + return err + } + handle, err := windows.CreateFile(tmpDestination16, windows.GENERIC_WRITE|windows.DELETE, windows.FILE_SHARE_READ, sa, windows.CREATE_ALWAYS, windows.FILE_ATTRIBUTE_NORMAL, 0) + if err != nil { + return err + } + defer windows.CloseHandle(handle) + deleteIt := func() { + yes := byte(1) + windows.SetFileInformationByHandle(handle, windows.FileDispositionInfo, &yes, 1) + } + n, err := windows.Write(handle, contents) + if err != nil { + deleteIt() + return err + } + if n != len(contents) { + deleteIt() + return windows.ERROR_IO_INCOMPLETE + } + fileRenameInfo := &struct { + replaceIfExists byte + rootDirectory windows.Handle + fileNameLength uint32 + fileName [windows.MAX_PATH]uint16 + }{replaceIfExists: func() byte { + if overwrite { + return 1 + } else { + return 0 + } + }(), fileNameLength: uint32(len(destination16) - 1)} + if len(destination16) > len(fileRenameInfo.fileName) { + deleteIt() + return windows.ERROR_BUFFER_OVERFLOW + } + copy(fileRenameInfo.fileName[:], destination16[:]) + err = windows.SetFileInformationByHandle(handle, windows.FileRenameInfo, (*byte)(unsafe.Pointer(fileRenameInfo)), uint32(unsafe.Sizeof(*fileRenameInfo))) + if err != nil { + deleteIt() + return err + } + return nil +} diff --git a/conf/migration_windows.go b/conf/migration_windows.go new file mode 100644 index 0000000..091efa3 --- /dev/null +++ b/conf/migration_windows.go @@ -0,0 +1,96 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "errors" + "io" + "log" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "golang.org/x/sys/windows" +) + +var migrating sync.Mutex +var lastMigrationTimer *time.Timer + +func MigrateUnencryptedConfigs() { migrateUnencryptedConfigs(3) } + +func migrateUnencryptedConfigs(sharingBase int) { + migrating.Lock() + defer migrating.Unlock() + configFileDir, err := tunnelConfigurationsDirectory() + if err != nil { + return + } + files, err := os.ReadDir(configFileDir) + if err != nil { + return + } + ignoreSharingViolations := false + for _, file := range files { + path := filepath.Join(configFileDir, file.Name()) + name := filepath.Base(file.Name()) + if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) { + continue + } + if !file.Type().IsRegular() { + continue + } + info, err := file.Info() + if err != nil { + continue + } + if info.Mode().Perm()&0444 == 0 { + continue + } + + var bytes []byte + var config *Config + // We don't use os.ReadFile, because we actually want RDWR, so that we can take advantage + // of Windows file locking for ensuring the file is finished being written. + f, err := os.OpenFile(path, os.O_RDWR, 0) + if err != nil { + if errors.Is(err, windows.ERROR_SHARING_VIOLATION) { + if ignoreSharingViolations { + continue + } else if sharingBase > 0 { + if lastMigrationTimer != nil { + lastMigrationTimer.Stop() + } + lastMigrationTimer = time.AfterFunc(time.Second/time.Duration(sharingBase*sharingBase), func() { migrateUnencryptedConfigs(sharingBase - 1) }) + ignoreSharingViolations = true + continue + } + } + goto error + } + bytes, err = io.ReadAll(f) + f.Close() + if err != nil { + goto error + } + config, err = FromWgQuickWithUnknownEncoding(string(bytes), strings.TrimSuffix(name, configFileUnencryptedSuffix)) + if err != nil { + goto error + } + err = config.Save(false) + if err != nil { + goto error + } + err = os.Remove(path) + if err != nil { + log.Printf("Unable to remove old path %#q: %v", path, err) + } + continue + error: + log.Printf("Unable to ingest and encrypt %#q: %v", path, err) + } +} diff --git a/conf/mksyscall.go b/conf/mksyscall.go new file mode 100644 index 0000000..c7ed0e4 --- /dev/null +++ b/conf/mksyscall.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go dnsresolver_windows.go migration_windows.go storewatcher_windows.go diff --git a/conf/name.go b/conf/name.go new file mode 100644 index 0000000..a4affb4 --- /dev/null +++ b/conf/name.go @@ -0,0 +1,112 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "regexp" + "strconv" + "strings" +) + +var reservedNames = []string{ + "CON", "PRN", "AUX", "NUL", + "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9", + "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", +} + +const serviceNameForbidden = "$" +const netshellDllForbidden = "\\/:*?\"<>|\t" +const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00" + +var allowedNameFormat *regexp.Regexp + +func init() { + allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$") +} + +func isReserved(name string) bool { + if len(name) == 0 { + return false + } + for _, reserved := range reservedNames { + if strings.EqualFold(name, reserved) { + return true + } + } + return false +} + +func hasSpecialChars(name string) bool { + return strings.ContainsAny(name, specialChars) || strings.ContainsAny(name, netshellDllForbidden) || strings.ContainsAny(name, serviceNameForbidden) +} + +func TunnelNameIsValid(name string) bool { + // Aside from our own restrictions, let's impose the Windows restrictions first + if isReserved(name) || hasSpecialChars(name) { + return false + } + return allowedNameFormat.MatchString(name) +} + +type naturalSortToken struct { + maybeString string + maybeNumber int +} +type naturalSortString struct { + originalString string + tokens []naturalSortToken +} + +var naturalSortDigitFinder = regexp.MustCompile(`\d+|\D+`) + +func newNaturalSortString(s string) (t naturalSortString) { + t.originalString = s + s = strings.ToLower(strings.Join(strings.Fields(s), " ")) + x := naturalSortDigitFinder.FindAllString(s, -1) + t.tokens = make([]naturalSortToken, len(x)) + for i, s := range x { + if n, err := strconv.Atoi(s); err == nil { + t.tokens[i].maybeNumber = n + } else { + t.tokens[i].maybeString = s + } + } + return +} + +func (f1 naturalSortToken) Cmp(f2 naturalSortToken) int { + if len(f1.maybeString) == 0 { + if len(f2.maybeString) > 0 || f1.maybeNumber < f2.maybeNumber { + return -1 + } else if f1.maybeNumber > f2.maybeNumber { + return 1 + } + } else if len(f2.maybeString) == 0 || f1.maybeString > f2.maybeString { + return 1 + } else if f1.maybeString < f2.maybeString { + return -1 + } + return 0 +} + +func TunnelNameIsLess(a, b string) bool { + if a == b { + return false + } + na, nb := newNaturalSortString(a), newNaturalSortString(b) + for i, t := range nb.tokens { + if i == len(na.tokens) { + return true + } + switch na.tokens[i].Cmp(t) { + case -1: + return true + case 1: + return false + } + } + return false +} diff --git a/conf/parser.go b/conf/parser.go new file mode 100644 index 0000000..bd08ee6 --- /dev/null +++ b/conf/parser.go @@ -0,0 +1,506 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "bufio" + "encoding/base64" + "encoding/hex" + "io" + "net" + "strconv" + "strings" + "time" + + "golang.org/x/text/encoding/unicode" + + "golang.zx2c4.com/wireguard/windows/l18n" +) + +type ParseError struct { + why string + offender string +} + +func (e *ParseError) Error() string { + return l18n.Sprintf("%s: %q", e.why, e.offender) +} + +func parseIPCidr(s string) (ipcidr *IPCidr, err error) { + var addrStr, cidrStr string + var cidr int + + i := strings.IndexByte(s, '/') + if i < 0 { + addrStr = s + } else { + addrStr, cidrStr = s[:i], s[i+1:] + } + + err = &ParseError{l18n.Sprintf("Invalid IP address"), s} + addr := net.ParseIP(addrStr) + if addr == nil { + return + } + maybeV4 := addr.To4() + if maybeV4 != nil { + addr = maybeV4 + } + if len(cidrStr) > 0 { + err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s} + cidr, err = strconv.Atoi(cidrStr) + if err != nil || cidr < 0 || cidr > 128 { + return + } + if cidr > 32 && maybeV4 != nil { + return + } + } else { + if maybeV4 != nil { + cidr = 32 + } else { + cidr = 128 + } + } + return &IPCidr{addr, uint8(cidr)}, nil +} + +func parseEndpoint(s string) (*Endpoint, error) { + i := strings.LastIndexByte(s, ':') + if i < 0 { + return nil, &ParseError{l18n.Sprintf("Missing port from endpoint"), s} + } + host, portStr := s[:i], s[i+1:] + if len(host) < 1 { + return nil, &ParseError{l18n.Sprintf("Invalid endpoint host"), host} + } + port, err := parsePort(portStr) + if err != nil { + return nil, err + } + hostColon := strings.IndexByte(host, ':') + if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { + err := &ParseError{l18n.Sprintf("Brackets must contain an IPv6 address"), host} + if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { + end := len(host) - 1 + if i := strings.LastIndexByte(host, '%'); i > 1 { + end = i + } + maybeV6 := net.ParseIP(host[1:end]) + if maybeV6 == nil || len(maybeV6) != net.IPv6len { + return nil, err + } + } else { + return nil, err + } + host = host[1 : len(host)-1] + } + return &Endpoint{host, uint16(port)}, nil +} + +func parseMTU(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 576 || m > 65535 { + return 0, &ParseError{l18n.Sprintf("Invalid MTU"), s} + } + return uint16(m), nil +} + +func parsePort(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{l18n.Sprintf("Invalid port"), s} + } + return uint16(m), nil +} + +func parsePersistentKeepalive(s string) (uint16, error) { + if s == "off" { + return 0, nil + } + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{l18n.Sprintf("Invalid persistent keepalive"), s} + } + return uint16(m), nil +} + +func parseKeyBase64(s string) (*Key, error) { + k, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s} + } + if len(k) != KeyLength { + return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func parseKeyHex(s string) (*Key, error) { + k, err := hex.DecodeString(s) + if err != nil { + return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s} + } + if len(k) != KeyLength { + return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func parseBytesOrStamp(s string) (uint64, error) { + b, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return 0, &ParseError{l18n.Sprintf("Number must be a number between 0 and 2^64-1: %v", err), s} + } + return b, nil +} + +func splitList(s string) ([]string, error) { + var out []string + for _, split := range strings.Split(s, ",") { + trim := strings.TrimSpace(split) + if len(trim) == 0 { + return nil, &ParseError{l18n.Sprintf("Two commas in a row"), s} + } + out = append(out, trim) + } + return out, nil +} + +type parserState int + +const ( + inInterfaceSection parserState = iota + inPeerSection + notInASection +) + +func (c *Config) maybeAddPeer(p *Peer) { + if p != nil { + c.Peers = append(c.Peers, *p) + } +} + +func FromWgQuick(s string, name string) (*Config, error) { + if !TunnelNameIsValid(name) { + return nil, &ParseError{l18n.Sprintf("Tunnel name is not valid"), name} + } + lines := strings.Split(s, "\n") + parserState := notInASection + conf := Config{Name: name} + sawPrivateKey := false + var peer *Peer + for _, line := range lines { + pound := strings.IndexByte(line, '#') + if pound >= 0 { + line = line[:pound] + } + line = strings.TrimSpace(line) + lineLower := strings.ToLower(line) + if len(line) == 0 { + continue + } + if lineLower == "[interface]" { + conf.maybeAddPeer(peer) + parserState = inInterfaceSection + continue + } + if lineLower == "[peer]" { + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + continue + } + if parserState == notInASection { + return nil, &ParseError{l18n.Sprintf("Line must occur in a section"), line} + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line} + } + key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:]) + if len(val) == 0 { + return nil, &ParseError{l18n.Sprintf("Key must have a value"), line} + } + if parserState == inInterfaceSection { + switch key { + case "privatekey": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + conf.Interface.PrivateKey = *k + sawPrivateKey = true + case "listenport": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.Interface.ListenPort = p + case "mtu": + m, err := parseMTU(val) + if err != nil { + return nil, err + } + conf.Interface.MTU = m + case "address": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := parseIPCidr(address) + if err != nil { + return nil, err + } + conf.Interface.Addresses = append(conf.Interface.Addresses, *a) + } + case "dns": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a := net.ParseIP(address) + if a == nil { + conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address) + } else { + conf.Interface.DNS = append(conf.Interface.DNS, a) + } + } + case "preup": + conf.Interface.PreUp = val + case "postup": + conf.Interface.PostUp = val + case "predown": + conf.Interface.PreDown = val + case "postdown": + conf.Interface.PostDown = val + default: + return nil, &ParseError{l18n.Sprintf("Invalid key for [Interface] section"), key} + } + } else if parserState == inPeerSection { + switch key { + case "publickey": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "presharedkey": + k, err := parseKeyBase64(val) + if err != nil { + return nil, err + } + peer.PresharedKey = *k + case "allowedips": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := parseIPCidr(address) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, *a) + } + case "persistentkeepalive": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + e, err := parseEndpoint(val) + if err != nil { + return nil, err + } + peer.Endpoint = *e + default: + return nil, &ParseError{l18n.Sprintf("Invalid key for [Peer] section"), key} + } + } + } + conf.maybeAddPeer(peer) + + if !sawPrivateKey { + return nil, &ParseError{l18n.Sprintf("An interface must have a private key"), l18n.Sprintf("[none specified]")} + } + for _, p := range conf.Peers { + if p.PublicKey.IsZero() { + return nil, &ParseError{l18n.Sprintf("All peers must have public keys"), l18n.Sprintf("[none specified]")} + } + } + + return &conf, nil +} + +func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) { + c, firstErr := FromWgQuick(s, name) + if firstErr == nil { + return c, nil + } + for _, encoding := range unicode.All { + decoded, err := encoding.NewDecoder().String(s) + if err == nil { + c, err := FromWgQuick(decoded, name) + if err == nil { + return c, nil + } + } + } + return nil, firstErr +} + +func FromUAPI(reader io.Reader, existingConfig *Config) (*Config, error) { + parserState := inInterfaceSection + conf := Config{ + Name: existingConfig.Name, + Interface: Interface{ + Addresses: existingConfig.Interface.Addresses, + DNS: existingConfig.Interface.DNS, + DNSSearch: existingConfig.Interface.DNSSearch, + MTU: existingConfig.Interface.MTU, + PreUp: existingConfig.Interface.PreUp, + PostUp: existingConfig.Interface.PostUp, + PreDown: existingConfig.Interface.PreDown, + PostDown: existingConfig.Interface.PostDown, + }, + } + var peer *Peer + lineReader := bufio.NewReader(reader) + for { + line, err := lineReader.ReadString('\n') + if err != nil { + return nil, err + } + line = line[:len(line)-1] + if len(line) == 0 { + break + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line} + } + key, val := line[:equals], line[equals+1:] + if len(val) == 0 { + return nil, &ParseError{l18n.Sprintf("Key must have a value"), line} + } + switch key { + case "public_key": + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + case "errno": + if val == "0" { + continue + } else { + return nil, &ParseError{l18n.Sprintf("Error in getting configuration"), val} + } + } + if parserState == inInterfaceSection { + switch key { + case "private_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + conf.Interface.PrivateKey = *k + case "listen_port": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.Interface.ListenPort = p + case "fwmark": + // Ignored for now. + + default: + return nil, &ParseError{l18n.Sprintf("Invalid key for interface section"), key} + } + } else if parserState == inPeerSection { + switch key { + case "public_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "preshared_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + peer.PresharedKey = *k + case "protocol_version": + if val != "1" { + return nil, &ParseError{l18n.Sprintf("Protocol version must be 1"), val} + } + case "allowed_ip": + a, err := parseIPCidr(val) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, *a) + case "persistent_keepalive_interval": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + e, err := parseEndpoint(val) + if err != nil { + return nil, err + } + peer.Endpoint = *e + case "tx_bytes": + b, err := parseBytesOrStamp(val) + if err != nil { + return nil, err + } + peer.TxBytes = Bytes(b) + case "rx_bytes": + b, err := parseBytesOrStamp(val) + if err != nil { + return nil, err + } + peer.RxBytes = Bytes(b) + case "last_handshake_time_sec": + t, err := parseBytesOrStamp(val) + if err != nil { + return nil, err + } + peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Second) + case "last_handshake_time_nsec": + t, err := parseBytesOrStamp(val) + if err != nil { + return nil, err + } + peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Nanosecond) + default: + return nil, &ParseError{l18n.Sprintf("Invalid key for peer section"), key} + } + } + } + conf.maybeAddPeer(peer) + + return &conf, nil +} diff --git a/conf/parser_test.go b/conf/parser_test.go new file mode 100644 index 0000000..f80d6d1 --- /dev/null +++ b/conf/parser_test.go @@ -0,0 +1,128 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "net" + "reflect" + "runtime" + "testing" +) + +const testInput = ` +[Interface] +Address = 10.192.122.1/24 +Address = 10.10.0.1/16 +PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= +ListenPort = 51820 #comments don't matter + +[Peer] +PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= +Endpoint = 192.95.5.67:1234 +AllowedIPs = 10.192.122.3/32, 10.192.124.1/24 + +[Peer] +PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = [2607:5300:60:6b0::c05f:543]:2468 +AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 +PersistentKeepalive = 100 + +[Peer] +PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= +PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = test.wireguard.com:18981 +AllowedIPs = 10.10.10.230/32` + +func noError(t *testing.T, err error) bool { + if err == nil { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error at %s:%d: %#v", fn, line, err) + return false +} + +func equal(t *testing.T, expected, actual interface{}) bool { + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func lenTest(t *testing.T, actualO interface{}, expected int) bool { + actual := reflect.ValueOf(actualO).Len() + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func contains(t *testing.T, list, element interface{}) bool { + listValue := reflect.ValueOf(list) + for i := 0; i < listValue.Len(); i++ { + if reflect.DeepEqual(listValue.Index(i).Interface(), element) { + return true + } + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element) + return false +} + +func TestFromWgQuick(t *testing.T) { + conf, err := FromWgQuick(testInput, "test") + if noError(t, err) { + + lenTest(t, conf.Interface.Addresses, 2) + contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)}) + contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)}) + equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String()) + equal(t, uint16(51820), conf.Interface.ListenPort) + + lenTest(t, conf.Peers, 3) + lenTest(t, conf.Peers[0].AllowedIPs, 2) + equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoint) + equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.String()) + + lenTest(t, conf.Peers[1].AllowedIPs, 2) + equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoint) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.String()) + equal(t, uint16(100), conf.Peers[1].PersistentKeepalive) + + lenTest(t, conf.Peers[2].AllowedIPs, 1) + equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoint) + equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.String()) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.String()) + } +} + +func TestParseEndpoint(t *testing.T) { + _, err := parseEndpoint("[192.168.42.0:]:51880") + if err == nil { + t.Error("Error was expected") + } + e, err := parseEndpoint("192.168.42.0:51880") + if noError(t, err) { + equal(t, "192.168.42.0", e.Host) + equal(t, uint16(51880), e.Port) + } + e, err = parseEndpoint("test.wireguard.com:18981") + if noError(t, err) { + equal(t, "test.wireguard.com", e.Host) + equal(t, uint16(18981), e.Port) + } + e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") + if noError(t, err) { + equal(t, "2607:5300:60:6b0::c05f:543", e.Host) + equal(t, uint16(2468), e.Port) + } + _, err = parseEndpoint("[::::::invalid:18981") + if err == nil { + t.Error("Error was expected") + } +} diff --git a/conf/path_windows.go b/conf/path_windows.go new file mode 100644 index 0000000..e9ff783 --- /dev/null +++ b/conf/path_windows.go @@ -0,0 +1,128 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "errors" + "os" + "path/filepath" + "strings" + "unsafe" + + "golang.org/x/sys/windows" +) + +var cachedConfigFileDir string +var cachedRootDir string + +func tunnelConfigurationsDirectory() (string, error) { + if cachedConfigFileDir != "" { + return cachedConfigFileDir, nil + } + root, err := RootDirectory(true) + if err != nil { + return "", err + } + c := filepath.Join(root, "Configurations") + err = os.Mkdir(c, os.ModeDir|0700) + if err != nil && !os.IsExist(err) { + return "", err + } + cachedConfigFileDir = c + return cachedConfigFileDir, nil +} + +// PresetRootDirectory causes RootDirectory() to not try any automatic deduction, and instead +// uses what's passed to it. This isn't used by wireguard-windows, but is useful for external +// consumers of our libraries who might want to do strange things. +func PresetRootDirectory(root string) { + cachedRootDir = root +} + +func RootDirectory(create bool) (string, error) { + if cachedRootDir != "" { + return cachedRootDir, nil + } + root, err := windows.KnownFolderPath(windows.FOLDERID_ProgramFiles, windows.KF_FLAG_DEFAULT) + if err != nil { + return "", err + } + root = filepath.Join(root, "WireGuard") + if !create { + return filepath.Join(root, "Data"), nil + } + root16, err := windows.UTF16PtrFromString(root) + if err != nil { + return "", err + } + + // The root directory inherits its ACL from Program Files; we don't want to mess with that + err = windows.CreateDirectory(root16, nil) + if err != nil && err != windows.ERROR_ALREADY_EXISTS { + return "", err + } + + dataDirectorySd, err := windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;OICI;FA;;;SY)(A;OICI;FA;;;BA)") + if err != nil { + return "", err + } + dataDirectorySa := &windows.SecurityAttributes{ + Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), + SecurityDescriptor: dataDirectorySd, + } + + data := filepath.Join(root, "Data") + data16, err := windows.UTF16PtrFromString(data) + if err != nil { + return "", err + } + var dataHandle windows.Handle + for { + err = windows.CreateDirectory(data16, dataDirectorySa) + if err != nil && err != windows.ERROR_ALREADY_EXISTS { + return "", err + } + dataHandle, err = windows.CreateFile(data16, windows.READ_CONTROL|windows.WRITE_OWNER|windows.WRITE_DAC, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_BACKUP_SEMANTICS|windows.FILE_FLAG_OPEN_REPARSE_POINT|windows.FILE_ATTRIBUTE_DIRECTORY, 0) + if err != nil && err != windows.ERROR_FILE_NOT_FOUND { + return "", err + } + if err == nil { + break + } + } + defer windows.CloseHandle(dataHandle) + var fileInfo windows.ByHandleFileInformation + err = windows.GetFileInformationByHandle(dataHandle, &fileInfo) + if err != nil { + return "", err + } + if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_DIRECTORY == 0 { + return "", errors.New("Data directory is actually a file") + } + if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_REPARSE_POINT != 0 { + return "", errors.New("Data directory is reparse point") + } + buf := make([]uint16, windows.MAX_PATH+4) + for { + bufLen, err := windows.GetFinalPathNameByHandle(dataHandle, &buf[0], uint32(len(buf)), 0) + if err != nil { + return "", err + } + if bufLen < uint32(len(buf)) { + break + } + buf = make([]uint16, bufLen) + } + if !strings.EqualFold(`\\?\`+data, windows.UTF16ToString(buf[:])) { + return "", errors.New("Data directory jumped to unexpected location") + } + err = windows.SetKernelObjectSecurity(dataHandle, windows.DACL_SECURITY_INFORMATION|windows.GROUP_SECURITY_INFORMATION|windows.OWNER_SECURITY_INFORMATION|windows.PROTECTED_DACL_SECURITY_INFORMATION, dataDirectorySd) + if err != nil { + return "", err + } + cachedRootDir = data + return cachedRootDir, nil +} diff --git a/conf/store.go b/conf/store.go new file mode 100644 index 0000000..f6f450c --- /dev/null +++ b/conf/store.go @@ -0,0 +1,144 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "errors" + "os" + "path/filepath" + "strings" + + "golang.zx2c4.com/wireguard/windows/conf/dpapi" +) + +const configFileSuffix = ".conf.dpapi" +const configFileUnencryptedSuffix = ".conf" + +func ListConfigNames() ([]string, error) { + configFileDir, err := tunnelConfigurationsDirectory() + if err != nil { + return nil, err + } + files, err := os.ReadDir(configFileDir) + if err != nil { + return nil, err + } + configs := make([]string, len(files)) + i := 0 + for _, file := range files { + name := filepath.Base(file.Name()) + if len(name) <= len(configFileSuffix) || !strings.HasSuffix(name, configFileSuffix) { + continue + } + if !file.Type().IsRegular() { + continue + } + info, err := file.Info() + if err != nil { + continue + } + if info.Mode().Perm()&0444 == 0 { + continue + } + name = strings.TrimSuffix(name, configFileSuffix) + if !TunnelNameIsValid(name) { + continue + } + configs[i] = name + i++ + } + return configs[:i], nil +} + +func LoadFromName(name string) (*Config, error) { + configFileDir, err := tunnelConfigurationsDirectory() + if err != nil { + return nil, err + } + return LoadFromPath(filepath.Join(configFileDir, name+configFileSuffix)) +} + +func LoadFromPath(path string) (*Config, error) { + name, err := NameFromPath(path) + if err != nil { + return nil, err + } + bytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } + if strings.HasSuffix(path, configFileSuffix) { + bytes, err = dpapi.Decrypt(bytes, name) + if err != nil { + return nil, err + } + } + return FromWgQuickWithUnknownEncoding(string(bytes), name) +} + +func PathIsEncrypted(path string) bool { + return strings.HasSuffix(filepath.Base(path), configFileSuffix) +} + +func NameFromPath(path string) (string, error) { + name := filepath.Base(path) + if !((len(name) > len(configFileSuffix) && strings.HasSuffix(name, configFileSuffix)) || + (len(name) > len(configFileUnencryptedSuffix) && strings.HasSuffix(name, configFileUnencryptedSuffix))) { + return "", errors.New("Path must end in either " + configFileSuffix + " or " + configFileUnencryptedSuffix) + } + if strings.HasSuffix(path, configFileSuffix) { + name = strings.TrimSuffix(name, configFileSuffix) + } else { + name = strings.TrimSuffix(name, configFileUnencryptedSuffix) + } + if !TunnelNameIsValid(name) { + return "", errors.New("Tunnel name is not valid") + } + return name, nil +} + +func (config *Config) Save(overwrite bool) error { + if !TunnelNameIsValid(config.Name) { + return errors.New("Tunnel name is not valid") + } + configFileDir, err := tunnelConfigurationsDirectory() + if err != nil { + return err + } + filename := filepath.Join(configFileDir, config.Name+configFileSuffix) + bytes := []byte(config.ToWgQuick()) + bytes, err = dpapi.Encrypt(bytes, config.Name) + if err != nil { + return err + } + return writeLockedDownFile(filename, overwrite, bytes) +} + +func (config *Config) Path() (string, error) { + if !TunnelNameIsValid(config.Name) { + return "", errors.New("Tunnel name is not valid") + } + configFileDir, err := tunnelConfigurationsDirectory() + if err != nil { + return "", err + } + return filepath.Join(configFileDir, config.Name+configFileSuffix), nil +} + +func DeleteName(name string) error { + if !TunnelNameIsValid(name) { + return errors.New("Tunnel name is not valid") + } + configFileDir, err := tunnelConfigurationsDirectory() + if err != nil { + return err + } + return os.Remove(filepath.Join(configFileDir, name+configFileSuffix)) +} + +func (config *Config) Delete() error { + return DeleteName(config.Name) +} diff --git a/conf/store_test.go b/conf/store_test.go new file mode 100644 index 0000000..3cec003 --- /dev/null +++ b/conf/store_test.go @@ -0,0 +1,91 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "reflect" + "testing" +) + +func TestStorage(t *testing.T) { + c, err := FromWgQuick(testInput, "golangTest") + if err != nil { + t.Errorf("Unable to parse test config: %s", err.Error()) + return + } + + err = c.Save() + if err != nil { + t.Errorf("Unable to save config: %s", err.Error()) + } + + configs, err := ListConfigNames() + if err != nil { + t.Errorf("Unable to list configs: %s", err.Error()) + } + + found := false + for _, name := range configs { + if name == "golangTest" { + found = true + break + } + } + if !found { + t.Error("Unable to find saved config in list") + } + + loaded, err := LoadFromName("golangTest") + if err != nil { + t.Errorf("Unable to load config: %s", err.Error()) + return + } + + if !reflect.DeepEqual(loaded, c) { + t.Error("Loaded config is not the same as saved config") + } + + k, err := NewPrivateKey() + if err != nil { + t.Errorf("Unable to generate new private key: %s", err.Error()) + } + c.Interface.PrivateKey = *k + + err = c.Save() + if err != nil { + t.Errorf("Unable to save config a second time: %s", err.Error()) + } + + loaded, err = LoadFromName("golangTest") + if err != nil { + t.Errorf("Unable to load config a second time: %s", err.Error()) + return + } + + if !reflect.DeepEqual(loaded, c) { + t.Error("Second loaded config is not the same as second saved config") + } + + err = DeleteName("golangTest") + if err != nil { + t.Errorf("Unable to delete config: %s", err.Error()) + } + + configs, err = ListConfigNames() + if err != nil { + t.Errorf("Unable to list configs: %s", err.Error()) + } + found = false + for _, name := range configs { + if name == "golangTest" { + found = true + break + } + } + if found { + t.Error("Config wasn't actually deleted") + } +} diff --git a/conf/storewatcher.go b/conf/storewatcher.go new file mode 100644 index 0000000..94586c3 --- /dev/null +++ b/conf/storewatcher.go @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +type StoreCallback struct { + cb func() +} + +var storeCallbacks = make(map[*StoreCallback]bool) + +func RegisterStoreChangeCallback(cb func()) *StoreCallback { + startWatchingConfigDir() + cb() + s := &StoreCallback{cb} + storeCallbacks[s] = true + return s +} + +func (cb *StoreCallback) Unregister() { + delete(storeCallbacks, cb) +} diff --git a/conf/storewatcher_windows.go b/conf/storewatcher_windows.go new file mode 100644 index 0000000..a12fb59 --- /dev/null +++ b/conf/storewatcher_windows.go @@ -0,0 +1,61 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "log" + + "golang.org/x/sys/windows" +) + +var haveStartedWatchingConfigDir bool + +func startWatchingConfigDir() { + if haveStartedWatchingConfigDir { + return + } + haveStartedWatchingConfigDir = true + go func() { + h := windows.InvalidHandle + defer func() { + if h != windows.InvalidHandle { + windows.FindCloseChangeNotification(h) + } + haveStartedWatchingConfigDir = false + }() + startover: + configFileDir, err := tunnelConfigurationsDirectory() + if err != nil { + return + } + h, err = windows.FindFirstChangeNotification(configFileDir, true, windows.FILE_NOTIFY_CHANGE_FILE_NAME|windows.FILE_NOTIFY_CHANGE_DIR_NAME|windows.FILE_NOTIFY_CHANGE_ATTRIBUTES|windows.FILE_NOTIFY_CHANGE_SIZE|windows.FILE_NOTIFY_CHANGE_LAST_WRITE|windows.FILE_NOTIFY_CHANGE_LAST_ACCESS|windows.FILE_NOTIFY_CHANGE_CREATION|windows.FILE_NOTIFY_CHANGE_SECURITY) + if err != nil { + log.Printf("Unable to monitor config directory: %v", err) + return + } + for { + s, err := windows.WaitForSingleObject(h, windows.INFINITE) + if err != nil || s == windows.WAIT_FAILED { + log.Printf("Unable to wait on config directory watcher: %v", err) + windows.FindCloseChangeNotification(h) + h = windows.InvalidHandle + goto startover + } + + for cb := range storeCallbacks { + cb.cb() + } + + err = windows.FindNextChangeNotification(h) + if err != nil { + log.Printf("Unable to monitor config directory again: %v", err) + windows.FindCloseChangeNotification(h) + h = windows.InvalidHandle + goto startover + } + } + }() +} diff --git a/conf/writer.go b/conf/writer.go new file mode 100644 index 0000000..c786ec1 --- /dev/null +++ b/conf/writer.go @@ -0,0 +1,124 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "fmt" + "strings" +) + +func (conf *Config) ToWgQuick() string { + var output strings.Builder + output.WriteString("[Interface]\n") + + output.WriteString(fmt.Sprintf("PrivateKey = %s\n", conf.Interface.PrivateKey.String())) + + if conf.Interface.ListenPort > 0 { + output.WriteString(fmt.Sprintf("ListenPort = %d\n", conf.Interface.ListenPort)) + } + + if len(conf.Interface.Addresses) > 0 { + addrStrings := make([]string, len(conf.Interface.Addresses)) + for i, address := range conf.Interface.Addresses { + addrStrings[i] = address.String() + } + output.WriteString(fmt.Sprintf("Address = %s\n", strings.Join(addrStrings[:], ", "))) + } + + if len(conf.Interface.DNS)+len(conf.Interface.DNSSearch) > 0 { + addrStrings := make([]string, 0, len(conf.Interface.DNS)+len(conf.Interface.DNSSearch)) + for _, address := range conf.Interface.DNS { + addrStrings = append(addrStrings, address.String()) + } + addrStrings = append(addrStrings, conf.Interface.DNSSearch...) + output.WriteString(fmt.Sprintf("DNS = %s\n", strings.Join(addrStrings[:], ", "))) + } + + if conf.Interface.MTU > 0 { + output.WriteString(fmt.Sprintf("MTU = %d\n", conf.Interface.MTU)) + } + + if len(conf.Interface.PreUp) > 0 { + output.WriteString(fmt.Sprintf("PreUp = %s\n", conf.Interface.PreUp)) + } + if len(conf.Interface.PostUp) > 0 { + output.WriteString(fmt.Sprintf("PostUp = %s\n", conf.Interface.PostUp)) + } + if len(conf.Interface.PreDown) > 0 { + output.WriteString(fmt.Sprintf("PreDown = %s\n", conf.Interface.PreDown)) + } + if len(conf.Interface.PostDown) > 0 { + output.WriteString(fmt.Sprintf("PostDown = %s\n", conf.Interface.PostDown)) + } + + for _, peer := range conf.Peers { + output.WriteString("\n[Peer]\n") + + output.WriteString(fmt.Sprintf("PublicKey = %s\n", peer.PublicKey.String())) + + if !peer.PresharedKey.IsZero() { + output.WriteString(fmt.Sprintf("PresharedKey = %s\n", peer.PresharedKey.String())) + } + + if len(peer.AllowedIPs) > 0 { + addrStrings := make([]string, len(peer.AllowedIPs)) + for i, address := range peer.AllowedIPs { + addrStrings[i] = address.String() + } + output.WriteString(fmt.Sprintf("AllowedIPs = %s\n", strings.Join(addrStrings[:], ", "))) + } + + if !peer.Endpoint.IsEmpty() { + output.WriteString(fmt.Sprintf("Endpoint = %s\n", peer.Endpoint.String())) + } + + if peer.PersistentKeepalive > 0 { + output.WriteString(fmt.Sprintf("PersistentKeepalive = %d\n", peer.PersistentKeepalive)) + } + } + return output.String() +} + +func (conf *Config) ToUAPI() (uapi string, dnsErr error) { + var output strings.Builder + output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString())) + + if conf.Interface.ListenPort > 0 { + output.WriteString(fmt.Sprintf("listen_port=%d\n", conf.Interface.ListenPort)) + } + + if len(conf.Peers) > 0 { + output.WriteString("replace_peers=true\n") + } + + for _, peer := range conf.Peers { + output.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey.HexString())) + + if !peer.PresharedKey.IsZero() { + output.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PresharedKey.HexString())) + } + + if !peer.Endpoint.IsEmpty() { + var resolvedIP string + resolvedIP, dnsErr = resolveHostname(peer.Endpoint.Host) + if dnsErr != nil { + return + } + resolvedEndpoint := Endpoint{resolvedIP, peer.Endpoint.Port} + output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String())) + } + + output.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive)) + + if len(peer.AllowedIPs) > 0 { + output.WriteString("replace_allowed_ips=true\n") + for _, address := range peer.AllowedIPs { + output.WriteString(fmt.Sprintf("allowed_ip=%s\n", address.String())) + } + } + } + return output.String(), nil +} diff --git a/conf/zsyscall_windows.go b/conf/zsyscall_windows.go new file mode 100644 index 0000000..783411f --- /dev/null +++ b/conf/zsyscall_windows.go @@ -0,0 +1,50 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package conf + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modwininet = windows.NewLazySystemDLL("wininet.dll") + + procInternetGetConnectedState = modwininet.NewProc("InternetGetConnectedState") +) + +func internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) { + r0, _, _ := syscall.Syscall(procInternetGetConnectedState.Addr(), 2, uintptr(unsafe.Pointer(flags)), uintptr(reserved), 0) + connected = r0 != 0 + return +}