conf: add from wireguard-windows/conf/@b73dcdb

Signed-off-by: Vincent Batts <vbatts@hashbangbash.com>
This commit is contained in:
Vincent Batts 2021-04-21 07:03:18 -04:00
parent 56fa2f2258
commit 77158f3dde
No known key found for this signature in database
GPG Key ID: 524F155275DF0C3E
16 changed files with 1939 additions and 0 deletions

36
conf/admin_windows.go Normal file
View File

@ -0,0 +1,36 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import "golang.org/x/sys/windows/registry"
const adminRegKey = `Software\WireGuard`
var adminKey registry.Key
func openAdminKey() (registry.Key, error) {
if adminKey != 0 {
return adminKey, nil
}
var err error
adminKey, err = registry.OpenKey(registry.LOCAL_MACHINE, adminRegKey, registry.QUERY_VALUE|registry.WOW64_64KEY)
if err != nil {
return 0, err
}
return adminKey, nil
}
func AdminBool(name string) bool {
key, err := openAdminKey()
if err != nil {
return false
}
val, _, err := key.GetIntegerValue(name)
if err != nil {
return false
}
return val != 0
}

252
conf/config.go Normal file
View File

@ -0,0 +1,252 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"strings"
"time"
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/windows/l18n"
)
const KeyLength = 32
type IPCidr struct {
IP net.IP
Cidr uint8
}
type Endpoint struct {
Host string
Port uint16
}
type Key [KeyLength]byte
type HandshakeTime time.Duration
type Bytes uint64
type Config struct {
Name string
Interface Interface
Peers []Peer
}
type Interface struct {
PrivateKey Key
Addresses []IPCidr
ListenPort uint16
MTU uint16
DNS []net.IP
DNSSearch []string
PreUp string
PostUp string
PreDown string
PostDown string
}
type Peer struct {
PublicKey Key
PresharedKey Key
AllowedIPs []IPCidr
Endpoint Endpoint
PersistentKeepalive uint16
RxBytes Bytes
TxBytes Bytes
LastHandshakeTime HandshakeTime
}
func (r *IPCidr) String() string {
return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
}
func (r *IPCidr) Bits() uint8 {
if r.IP.To4() != nil {
return 32
}
return 128
}
func (r *IPCidr) IPNet() net.IPNet {
return net.IPNet{
IP: r.IP,
Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())),
}
}
func (r *IPCidr) MaskSelf() {
bits := int(r.Bits())
mask := net.CIDRMask(int(r.Cidr), bits)
for i := 0; i < bits/8; i++ {
r.IP[i] &= mask[i]
}
}
func (e *Endpoint) String() string {
if strings.IndexByte(e.Host, ':') > 0 {
return fmt.Sprintf("[%s]:%d", e.Host, e.Port)
}
return fmt.Sprintf("%s:%d", e.Host, e.Port)
}
func (e *Endpoint) IsEmpty() bool {
return len(e.Host) == 0
}
func (k *Key) String() string {
return base64.StdEncoding.EncodeToString(k[:])
}
func (k *Key) HexString() string {
return hex.EncodeToString(k[:])
}
func (k *Key) IsZero() bool {
var zeros Key
return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
}
func (k *Key) Public() *Key {
var p [KeyLength]byte
curve25519.ScalarBaseMult(&p, (*[KeyLength]byte)(k))
return (*Key)(&p)
}
func NewPresharedKey() (*Key, error) {
var k [KeyLength]byte
_, err := rand.Read(k[:])
if err != nil {
return nil, err
}
return (*Key)(&k), nil
}
func NewPrivateKey() (*Key, error) {
k, err := NewPresharedKey()
if err != nil {
return nil, err
}
k[0] &= 248
k[31] = (k[31] & 127) | 64
return k, nil
}
func NewPrivateKeyFromString(b64 string) (*Key, error) {
return parseKeyBase64(b64)
}
func (t HandshakeTime) IsEmpty() bool {
return t == HandshakeTime(0)
}
func (t HandshakeTime) String() string {
u := time.Unix(0, 0).Add(time.Duration(t)).Unix()
n := time.Now().Unix()
if u == n {
return l18n.Sprintf("Now")
} else if u > n {
return l18n.Sprintf("System clock wound backward!")
}
left := n - u
years := left / (365 * 24 * 60 * 60)
left = left % (365 * 24 * 60 * 60)
days := left / (24 * 60 * 60)
left = left % (24 * 60 * 60)
hours := left / (60 * 60)
left = left % (60 * 60)
minutes := left / 60
seconds := left % 60
s := make([]string, 0, 5)
if years > 0 {
s = append(s, l18n.Sprintf("%d year(s)", years))
}
if days > 0 {
s = append(s, l18n.Sprintf("%d day(s)", days))
}
if hours > 0 {
s = append(s, l18n.Sprintf("%d hour(s)", hours))
}
if minutes > 0 {
s = append(s, l18n.Sprintf("%d minute(s)", minutes))
}
if seconds > 0 {
s = append(s, l18n.Sprintf("%d second(s)", seconds))
}
timestamp := strings.Join(s, l18n.UnitSeparator())
return l18n.Sprintf("%s ago", timestamp)
}
func (b Bytes) String() string {
if b < 1024 {
return l18n.Sprintf("%d\u00a0B", b)
} else if b < 1024*1024 {
return l18n.Sprintf("%.2f\u00a0KiB", float64(b)/1024)
} else if b < 1024*1024*1024 {
return l18n.Sprintf("%.2f\u00a0MiB", float64(b)/(1024*1024))
} else if b < 1024*1024*1024*1024 {
return l18n.Sprintf("%.2f\u00a0GiB", float64(b)/(1024*1024*1024))
}
return l18n.Sprintf("%.2f\u00a0TiB", float64(b)/(1024*1024*1024)/1024)
}
func (conf *Config) DeduplicateNetworkEntries() {
m := make(map[string]bool, len(conf.Interface.Addresses))
i := 0
for _, addr := range conf.Interface.Addresses {
s := addr.String()
if m[s] {
continue
}
m[s] = true
conf.Interface.Addresses[i] = addr
i++
}
conf.Interface.Addresses = conf.Interface.Addresses[:i]
m = make(map[string]bool, len(conf.Interface.DNS))
i = 0
for _, addr := range conf.Interface.DNS {
s := addr.String()
if m[s] {
continue
}
m[s] = true
conf.Interface.DNS[i] = addr
i++
}
conf.Interface.DNS = conf.Interface.DNS[:i]
for _, peer := range conf.Peers {
m = make(map[string]bool, len(peer.AllowedIPs))
i = 0
for _, addr := range peer.AllowedIPs {
s := addr.String()
if m[s] {
continue
}
m[s] = true
peer.AllowedIPs[i] = addr
i++
}
peer.AllowedIPs = peer.AllowedIPs[:i]
}
}
func (conf *Config) Redact() {
conf.Interface.PrivateKey = Key{}
for i := range conf.Peers {
conf.Peers[i].PublicKey = Key{}
conf.Peers[i].PresharedKey = Key{}
}
}

View File

@ -0,0 +1,90 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"fmt"
"log"
"net"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
//sys internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) = wininet.InternetGetConnectedState
func resolveHostname(name string) (resolvedIPString string, err error) {
maxTries := 10
systemJustBooted := windows.DurationSinceBoot() <= time.Minute*4
if systemJustBooted {
maxTries *= 4
}
for i := 0; i < maxTries; i++ {
if i > 0 {
time.Sleep(time.Second * 4)
}
resolvedIPString, err = resolveHostnameOnce(name)
if err == nil {
return
}
if err == windows.WSATRY_AGAIN {
log.Printf("Temporary DNS error when resolving %s, sleeping for 4 seconds", name)
continue
}
var state uint32
if err == windows.WSAHOST_NOT_FOUND && systemJustBooted && !internetGetConnectedState(&state, 0) {
log.Printf("Host not found when resolving %s, but no Internet connection available, sleeping for 4 seconds", name)
continue
}
return
}
return
}
func resolveHostnameOnce(name string) (resolvedIPString string, err error) {
hints := windows.AddrinfoW{
Family: windows.AF_UNSPEC,
Socktype: windows.SOCK_DGRAM,
Protocol: windows.IPPROTO_IP,
}
var result *windows.AddrinfoW
name16, err := windows.UTF16PtrFromString(name)
if err != nil {
return
}
err = windows.GetAddrInfoW(name16, nil, &hints, &result)
if err != nil {
return
}
if result == nil {
err = windows.WSAHOST_NOT_FOUND
return
}
defer windows.FreeAddrInfoW(result)
ipv6 := ""
for ; result != nil; result = result.Next {
switch result.Family {
case windows.AF_INET:
return (net.IP)((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr[:]).String(), nil
case windows.AF_INET6:
if len(ipv6) != 0 {
continue
}
a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(result.Addr))
ipv6 = (net.IP)(a.Addr[:]).String()
if a.Scope_id != 0 {
ipv6 += fmt.Sprintf("%%%d", a.Scope_id)
}
}
}
if len(ipv6) != 0 {
return ipv6, nil
}
err = windows.WSAHOST_NOT_FOUND
return
}

View File

@ -0,0 +1,89 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"crypto/rand"
"encoding/hex"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
)
var encryptedFileSd unsafe.Pointer
func randomFileName() string {
var randBytes [32]byte
_, err := rand.Read(randBytes[:])
if err != nil {
panic(err)
}
return hex.EncodeToString(randBytes[:]) + ".tmp"
}
func writeLockedDownFile(destination string, overwrite bool, contents []byte) error {
var err error
sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{}))}
sa.SecurityDescriptor = (*windows.SECURITY_DESCRIPTOR)(atomic.LoadPointer(&encryptedFileSd))
if sa.SecurityDescriptor == nil {
sa.SecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;;FA;;;SY)(A;;SD;;;BA)")
if err != nil {
return err
}
atomic.StorePointer(&encryptedFileSd, unsafe.Pointer(sa.SecurityDescriptor))
}
destination16, err := windows.UTF16FromString(destination)
if err != nil {
return err
}
tmpDestination := randomFileName()
tmpDestination16, err := windows.UTF16PtrFromString(tmpDestination)
if err != nil {
return err
}
handle, err := windows.CreateFile(tmpDestination16, windows.GENERIC_WRITE|windows.DELETE, windows.FILE_SHARE_READ, sa, windows.CREATE_ALWAYS, windows.FILE_ATTRIBUTE_NORMAL, 0)
if err != nil {
return err
}
defer windows.CloseHandle(handle)
deleteIt := func() {
yes := byte(1)
windows.SetFileInformationByHandle(handle, windows.FileDispositionInfo, &yes, 1)
}
n, err := windows.Write(handle, contents)
if err != nil {
deleteIt()
return err
}
if n != len(contents) {
deleteIt()
return windows.ERROR_IO_INCOMPLETE
}
fileRenameInfo := &struct {
replaceIfExists byte
rootDirectory windows.Handle
fileNameLength uint32
fileName [windows.MAX_PATH]uint16
}{replaceIfExists: func() byte {
if overwrite {
return 1
} else {
return 0
}
}(), fileNameLength: uint32(len(destination16) - 1)}
if len(destination16) > len(fileRenameInfo.fileName) {
deleteIt()
return windows.ERROR_BUFFER_OVERFLOW
}
copy(fileRenameInfo.fileName[:], destination16[:])
err = windows.SetFileInformationByHandle(handle, windows.FileRenameInfo, (*byte)(unsafe.Pointer(fileRenameInfo)), uint32(unsafe.Sizeof(*fileRenameInfo)))
if err != nil {
deleteIt()
return err
}
return nil
}

96
conf/migration_windows.go Normal file
View File

@ -0,0 +1,96 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"errors"
"io"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"golang.org/x/sys/windows"
)
var migrating sync.Mutex
var lastMigrationTimer *time.Timer
func MigrateUnencryptedConfigs() { migrateUnencryptedConfigs(3) }
func migrateUnencryptedConfigs(sharingBase int) {
migrating.Lock()
defer migrating.Unlock()
configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
return
}
files, err := os.ReadDir(configFileDir)
if err != nil {
return
}
ignoreSharingViolations := false
for _, file := range files {
path := filepath.Join(configFileDir, file.Name())
name := filepath.Base(file.Name())
if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) {
continue
}
if !file.Type().IsRegular() {
continue
}
info, err := file.Info()
if err != nil {
continue
}
if info.Mode().Perm()&0444 == 0 {
continue
}
var bytes []byte
var config *Config
// We don't use os.ReadFile, because we actually want RDWR, so that we can take advantage
// of Windows file locking for ensuring the file is finished being written.
f, err := os.OpenFile(path, os.O_RDWR, 0)
if err != nil {
if errors.Is(err, windows.ERROR_SHARING_VIOLATION) {
if ignoreSharingViolations {
continue
} else if sharingBase > 0 {
if lastMigrationTimer != nil {
lastMigrationTimer.Stop()
}
lastMigrationTimer = time.AfterFunc(time.Second/time.Duration(sharingBase*sharingBase), func() { migrateUnencryptedConfigs(sharingBase - 1) })
ignoreSharingViolations = true
continue
}
}
goto error
}
bytes, err = io.ReadAll(f)
f.Close()
if err != nil {
goto error
}
config, err = FromWgQuickWithUnknownEncoding(string(bytes), strings.TrimSuffix(name, configFileUnencryptedSuffix))
if err != nil {
goto error
}
err = config.Save(false)
if err != nil {
goto error
}
err = os.Remove(path)
if err != nil {
log.Printf("Unable to remove old path %#q: %v", path, err)
}
continue
error:
log.Printf("Unable to ingest and encrypt %#q: %v", path, err)
}
}

8
conf/mksyscall.go Normal file
View File

@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go dnsresolver_windows.go migration_windows.go storewatcher_windows.go

112
conf/name.go Normal file
View File

@ -0,0 +1,112 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"regexp"
"strconv"
"strings"
)
var reservedNames = []string{
"CON", "PRN", "AUX", "NUL",
"COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9",
"LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
}
const serviceNameForbidden = "$"
const netshellDllForbidden = "\\/:*?\"<>|\t"
const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00"
var allowedNameFormat *regexp.Regexp
func init() {
allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$")
}
func isReserved(name string) bool {
if len(name) == 0 {
return false
}
for _, reserved := range reservedNames {
if strings.EqualFold(name, reserved) {
return true
}
}
return false
}
func hasSpecialChars(name string) bool {
return strings.ContainsAny(name, specialChars) || strings.ContainsAny(name, netshellDllForbidden) || strings.ContainsAny(name, serviceNameForbidden)
}
func TunnelNameIsValid(name string) bool {
// Aside from our own restrictions, let's impose the Windows restrictions first
if isReserved(name) || hasSpecialChars(name) {
return false
}
return allowedNameFormat.MatchString(name)
}
type naturalSortToken struct {
maybeString string
maybeNumber int
}
type naturalSortString struct {
originalString string
tokens []naturalSortToken
}
var naturalSortDigitFinder = regexp.MustCompile(`\d+|\D+`)
func newNaturalSortString(s string) (t naturalSortString) {
t.originalString = s
s = strings.ToLower(strings.Join(strings.Fields(s), " "))
x := naturalSortDigitFinder.FindAllString(s, -1)
t.tokens = make([]naturalSortToken, len(x))
for i, s := range x {
if n, err := strconv.Atoi(s); err == nil {
t.tokens[i].maybeNumber = n
} else {
t.tokens[i].maybeString = s
}
}
return
}
func (f1 naturalSortToken) Cmp(f2 naturalSortToken) int {
if len(f1.maybeString) == 0 {
if len(f2.maybeString) > 0 || f1.maybeNumber < f2.maybeNumber {
return -1
} else if f1.maybeNumber > f2.maybeNumber {
return 1
}
} else if len(f2.maybeString) == 0 || f1.maybeString > f2.maybeString {
return 1
} else if f1.maybeString < f2.maybeString {
return -1
}
return 0
}
func TunnelNameIsLess(a, b string) bool {
if a == b {
return false
}
na, nb := newNaturalSortString(a), newNaturalSortString(b)
for i, t := range nb.tokens {
if i == len(na.tokens) {
return true
}
switch na.tokens[i].Cmp(t) {
case -1:
return true
case 1:
return false
}
}
return false
}

506
conf/parser.go Normal file
View File

@ -0,0 +1,506 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"bufio"
"encoding/base64"
"encoding/hex"
"io"
"net"
"strconv"
"strings"
"time"
"golang.org/x/text/encoding/unicode"
"golang.zx2c4.com/wireguard/windows/l18n"
)
type ParseError struct {
why string
offender string
}
func (e *ParseError) Error() string {
return l18n.Sprintf("%s: %q", e.why, e.offender)
}
func parseIPCidr(s string) (ipcidr *IPCidr, err error) {
var addrStr, cidrStr string
var cidr int
i := strings.IndexByte(s, '/')
if i < 0 {
addrStr = s
} else {
addrStr, cidrStr = s[:i], s[i+1:]
}
err = &ParseError{l18n.Sprintf("Invalid IP address"), s}
addr := net.ParseIP(addrStr)
if addr == nil {
return
}
maybeV4 := addr.To4()
if maybeV4 != nil {
addr = maybeV4
}
if len(cidrStr) > 0 {
err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s}
cidr, err = strconv.Atoi(cidrStr)
if err != nil || cidr < 0 || cidr > 128 {
return
}
if cidr > 32 && maybeV4 != nil {
return
}
} else {
if maybeV4 != nil {
cidr = 32
} else {
cidr = 128
}
}
return &IPCidr{addr, uint8(cidr)}, nil
}
func parseEndpoint(s string) (*Endpoint, error) {
i := strings.LastIndexByte(s, ':')
if i < 0 {
return nil, &ParseError{l18n.Sprintf("Missing port from endpoint"), s}
}
host, portStr := s[:i], s[i+1:]
if len(host) < 1 {
return nil, &ParseError{l18n.Sprintf("Invalid endpoint host"), host}
}
port, err := parsePort(portStr)
if err != nil {
return nil, err
}
hostColon := strings.IndexByte(host, ':')
if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 {
err := &ParseError{l18n.Sprintf("Brackets must contain an IPv6 address"), host}
if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 {
end := len(host) - 1
if i := strings.LastIndexByte(host, '%'); i > 1 {
end = i
}
maybeV6 := net.ParseIP(host[1:end])
if maybeV6 == nil || len(maybeV6) != net.IPv6len {
return nil, err
}
} else {
return nil, err
}
host = host[1 : len(host)-1]
}
return &Endpoint{host, uint16(port)}, nil
}
func parseMTU(s string) (uint16, error) {
m, err := strconv.Atoi(s)
if err != nil {
return 0, err
}
if m < 576 || m > 65535 {
return 0, &ParseError{l18n.Sprintf("Invalid MTU"), s}
}
return uint16(m), nil
}
func parsePort(s string) (uint16, error) {
m, err := strconv.Atoi(s)
if err != nil {
return 0, err
}
if m < 0 || m > 65535 {
return 0, &ParseError{l18n.Sprintf("Invalid port"), s}
}
return uint16(m), nil
}
func parsePersistentKeepalive(s string) (uint16, error) {
if s == "off" {
return 0, nil
}
m, err := strconv.Atoi(s)
if err != nil {
return 0, err
}
if m < 0 || m > 65535 {
return 0, &ParseError{l18n.Sprintf("Invalid persistent keepalive"), s}
}
return uint16(m), nil
}
func parseKeyBase64(s string) (*Key, error) {
k, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s}
}
if len(k) != KeyLength {
return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s}
}
var key Key
copy(key[:], k)
return &key, nil
}
func parseKeyHex(s string) (*Key, error) {
k, err := hex.DecodeString(s)
if err != nil {
return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s}
}
if len(k) != KeyLength {
return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s}
}
var key Key
copy(key[:], k)
return &key, nil
}
func parseBytesOrStamp(s string) (uint64, error) {
b, err := strconv.ParseUint(s, 10, 64)
if err != nil {
return 0, &ParseError{l18n.Sprintf("Number must be a number between 0 and 2^64-1: %v", err), s}
}
return b, nil
}
func splitList(s string) ([]string, error) {
var out []string
for _, split := range strings.Split(s, ",") {
trim := strings.TrimSpace(split)
if len(trim) == 0 {
return nil, &ParseError{l18n.Sprintf("Two commas in a row"), s}
}
out = append(out, trim)
}
return out, nil
}
type parserState int
const (
inInterfaceSection parserState = iota
inPeerSection
notInASection
)
func (c *Config) maybeAddPeer(p *Peer) {
if p != nil {
c.Peers = append(c.Peers, *p)
}
}
func FromWgQuick(s string, name string) (*Config, error) {
if !TunnelNameIsValid(name) {
return nil, &ParseError{l18n.Sprintf("Tunnel name is not valid"), name}
}
lines := strings.Split(s, "\n")
parserState := notInASection
conf := Config{Name: name}
sawPrivateKey := false
var peer *Peer
for _, line := range lines {
pound := strings.IndexByte(line, '#')
if pound >= 0 {
line = line[:pound]
}
line = strings.TrimSpace(line)
lineLower := strings.ToLower(line)
if len(line) == 0 {
continue
}
if lineLower == "[interface]" {
conf.maybeAddPeer(peer)
parserState = inInterfaceSection
continue
}
if lineLower == "[peer]" {
conf.maybeAddPeer(peer)
peer = &Peer{}
parserState = inPeerSection
continue
}
if parserState == notInASection {
return nil, &ParseError{l18n.Sprintf("Line must occur in a section"), line}
}
equals := strings.IndexByte(line, '=')
if equals < 0 {
return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line}
}
key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:])
if len(val) == 0 {
return nil, &ParseError{l18n.Sprintf("Key must have a value"), line}
}
if parserState == inInterfaceSection {
switch key {
case "privatekey":
k, err := parseKeyBase64(val)
if err != nil {
return nil, err
}
conf.Interface.PrivateKey = *k
sawPrivateKey = true
case "listenport":
p, err := parsePort(val)
if err != nil {
return nil, err
}
conf.Interface.ListenPort = p
case "mtu":
m, err := parseMTU(val)
if err != nil {
return nil, err
}
conf.Interface.MTU = m
case "address":
addresses, err := splitList(val)
if err != nil {
return nil, err
}
for _, address := range addresses {
a, err := parseIPCidr(address)
if err != nil {
return nil, err
}
conf.Interface.Addresses = append(conf.Interface.Addresses, *a)
}
case "dns":
addresses, err := splitList(val)
if err != nil {
return nil, err
}
for _, address := range addresses {
a := net.ParseIP(address)
if a == nil {
conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address)
} else {
conf.Interface.DNS = append(conf.Interface.DNS, a)
}
}
case "preup":
conf.Interface.PreUp = val
case "postup":
conf.Interface.PostUp = val
case "predown":
conf.Interface.PreDown = val
case "postdown":
conf.Interface.PostDown = val
default:
return nil, &ParseError{l18n.Sprintf("Invalid key for [Interface] section"), key}
}
} else if parserState == inPeerSection {
switch key {
case "publickey":
k, err := parseKeyBase64(val)
if err != nil {
return nil, err
}
peer.PublicKey = *k
case "presharedkey":
k, err := parseKeyBase64(val)
if err != nil {
return nil, err
}
peer.PresharedKey = *k
case "allowedips":
addresses, err := splitList(val)
if err != nil {
return nil, err
}
for _, address := range addresses {
a, err := parseIPCidr(address)
if err != nil {
return nil, err
}
peer.AllowedIPs = append(peer.AllowedIPs, *a)
}
case "persistentkeepalive":
p, err := parsePersistentKeepalive(val)
if err != nil {
return nil, err
}
peer.PersistentKeepalive = p
case "endpoint":
e, err := parseEndpoint(val)
if err != nil {
return nil, err
}
peer.Endpoint = *e
default:
return nil, &ParseError{l18n.Sprintf("Invalid key for [Peer] section"), key}
}
}
}
conf.maybeAddPeer(peer)
if !sawPrivateKey {
return nil, &ParseError{l18n.Sprintf("An interface must have a private key"), l18n.Sprintf("[none specified]")}
}
for _, p := range conf.Peers {
if p.PublicKey.IsZero() {
return nil, &ParseError{l18n.Sprintf("All peers must have public keys"), l18n.Sprintf("[none specified]")}
}
}
return &conf, nil
}
func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) {
c, firstErr := FromWgQuick(s, name)
if firstErr == nil {
return c, nil
}
for _, encoding := range unicode.All {
decoded, err := encoding.NewDecoder().String(s)
if err == nil {
c, err := FromWgQuick(decoded, name)
if err == nil {
return c, nil
}
}
}
return nil, firstErr
}
func FromUAPI(reader io.Reader, existingConfig *Config) (*Config, error) {
parserState := inInterfaceSection
conf := Config{
Name: existingConfig.Name,
Interface: Interface{
Addresses: existingConfig.Interface.Addresses,
DNS: existingConfig.Interface.DNS,
DNSSearch: existingConfig.Interface.DNSSearch,
MTU: existingConfig.Interface.MTU,
PreUp: existingConfig.Interface.PreUp,
PostUp: existingConfig.Interface.PostUp,
PreDown: existingConfig.Interface.PreDown,
PostDown: existingConfig.Interface.PostDown,
},
}
var peer *Peer
lineReader := bufio.NewReader(reader)
for {
line, err := lineReader.ReadString('\n')
if err != nil {
return nil, err
}
line = line[:len(line)-1]
if len(line) == 0 {
break
}
equals := strings.IndexByte(line, '=')
if equals < 0 {
return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line}
}
key, val := line[:equals], line[equals+1:]
if len(val) == 0 {
return nil, &ParseError{l18n.Sprintf("Key must have a value"), line}
}
switch key {
case "public_key":
conf.maybeAddPeer(peer)
peer = &Peer{}
parserState = inPeerSection
case "errno":
if val == "0" {
continue
} else {
return nil, &ParseError{l18n.Sprintf("Error in getting configuration"), val}
}
}
if parserState == inInterfaceSection {
switch key {
case "private_key":
k, err := parseKeyHex(val)
if err != nil {
return nil, err
}
conf.Interface.PrivateKey = *k
case "listen_port":
p, err := parsePort(val)
if err != nil {
return nil, err
}
conf.Interface.ListenPort = p
case "fwmark":
// Ignored for now.
default:
return nil, &ParseError{l18n.Sprintf("Invalid key for interface section"), key}
}
} else if parserState == inPeerSection {
switch key {
case "public_key":
k, err := parseKeyHex(val)
if err != nil {
return nil, err
}
peer.PublicKey = *k
case "preshared_key":
k, err := parseKeyHex(val)
if err != nil {
return nil, err
}
peer.PresharedKey = *k
case "protocol_version":
if val != "1" {
return nil, &ParseError{l18n.Sprintf("Protocol version must be 1"), val}
}
case "allowed_ip":
a, err := parseIPCidr(val)
if err != nil {
return nil, err
}
peer.AllowedIPs = append(peer.AllowedIPs, *a)
case "persistent_keepalive_interval":
p, err := parsePersistentKeepalive(val)
if err != nil {
return nil, err
}
peer.PersistentKeepalive = p
case "endpoint":
e, err := parseEndpoint(val)
if err != nil {
return nil, err
}
peer.Endpoint = *e
case "tx_bytes":
b, err := parseBytesOrStamp(val)
if err != nil {
return nil, err
}
peer.TxBytes = Bytes(b)
case "rx_bytes":
b, err := parseBytesOrStamp(val)
if err != nil {
return nil, err
}
peer.RxBytes = Bytes(b)
case "last_handshake_time_sec":
t, err := parseBytesOrStamp(val)
if err != nil {
return nil, err
}
peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Second)
case "last_handshake_time_nsec":
t, err := parseBytesOrStamp(val)
if err != nil {
return nil, err
}
peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Nanosecond)
default:
return nil, &ParseError{l18n.Sprintf("Invalid key for peer section"), key}
}
}
}
conf.maybeAddPeer(peer)
return &conf, nil
}

128
conf/parser_test.go Normal file
View File

@ -0,0 +1,128 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"net"
"reflect"
"runtime"
"testing"
)
const testInput = `
[Interface]
Address = 10.192.122.1/24
Address = 10.10.0.1/16
PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
ListenPort = 51820 #comments don't matter
[Peer]
PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=
Endpoint = 192.95.5.67:1234
AllowedIPs = 10.192.122.3/32, 10.192.124.1/24
[Peer]
PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
Endpoint = [2607:5300:60:6b0::c05f:543]:2468
AllowedIPs = 10.192.122.4/32, 192.168.0.0/16
PersistentKeepalive = 100
[Peer]
PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
Endpoint = test.wireguard.com:18981
AllowedIPs = 10.10.10.230/32`
func noError(t *testing.T, err error) bool {
if err == nil {
return true
}
_, fn, line, _ := runtime.Caller(1)
t.Errorf("Error at %s:%d: %#v", fn, line, err)
return false
}
func equal(t *testing.T, expected, actual interface{}) bool {
if reflect.DeepEqual(expected, actual) {
return true
}
_, fn, line, _ := runtime.Caller(1)
t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected)
return false
}
func lenTest(t *testing.T, actualO interface{}, expected int) bool {
actual := reflect.ValueOf(actualO).Len()
if reflect.DeepEqual(expected, actual) {
return true
}
_, fn, line, _ := runtime.Caller(1)
t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected)
return false
}
func contains(t *testing.T, list, element interface{}) bool {
listValue := reflect.ValueOf(list)
for i := 0; i < listValue.Len(); i++ {
if reflect.DeepEqual(listValue.Index(i).Interface(), element) {
return true
}
}
_, fn, line, _ := runtime.Caller(1)
t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element)
return false
}
func TestFromWgQuick(t *testing.T) {
conf, err := FromWgQuick(testInput, "test")
if noError(t, err) {
lenTest(t, conf.Interface.Addresses, 2)
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)})
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)})
equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String())
equal(t, uint16(51820), conf.Interface.ListenPort)
lenTest(t, conf.Peers, 3)
lenTest(t, conf.Peers[0].AllowedIPs, 2)
equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoint)
equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.String())
lenTest(t, conf.Peers[1].AllowedIPs, 2)
equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoint)
equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.String())
equal(t, uint16(100), conf.Peers[1].PersistentKeepalive)
lenTest(t, conf.Peers[2].AllowedIPs, 1)
equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoint)
equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.String())
equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.String())
}
}
func TestParseEndpoint(t *testing.T) {
_, err := parseEndpoint("[192.168.42.0:]:51880")
if err == nil {
t.Error("Error was expected")
}
e, err := parseEndpoint("192.168.42.0:51880")
if noError(t, err) {
equal(t, "192.168.42.0", e.Host)
equal(t, uint16(51880), e.Port)
}
e, err = parseEndpoint("test.wireguard.com:18981")
if noError(t, err) {
equal(t, "test.wireguard.com", e.Host)
equal(t, uint16(18981), e.Port)
}
e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468")
if noError(t, err) {
equal(t, "2607:5300:60:6b0::c05f:543", e.Host)
equal(t, uint16(2468), e.Port)
}
_, err = parseEndpoint("[::::::invalid:18981")
if err == nil {
t.Error("Error was expected")
}
}

128
conf/path_windows.go Normal file
View File

@ -0,0 +1,128 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"errors"
"os"
"path/filepath"
"strings"
"unsafe"
"golang.org/x/sys/windows"
)
var cachedConfigFileDir string
var cachedRootDir string
func tunnelConfigurationsDirectory() (string, error) {
if cachedConfigFileDir != "" {
return cachedConfigFileDir, nil
}
root, err := RootDirectory(true)
if err != nil {
return "", err
}
c := filepath.Join(root, "Configurations")
err = os.Mkdir(c, os.ModeDir|0700)
if err != nil && !os.IsExist(err) {
return "", err
}
cachedConfigFileDir = c
return cachedConfigFileDir, nil
}
// PresetRootDirectory causes RootDirectory() to not try any automatic deduction, and instead
// uses what's passed to it. This isn't used by wireguard-windows, but is useful for external
// consumers of our libraries who might want to do strange things.
func PresetRootDirectory(root string) {
cachedRootDir = root
}
func RootDirectory(create bool) (string, error) {
if cachedRootDir != "" {
return cachedRootDir, nil
}
root, err := windows.KnownFolderPath(windows.FOLDERID_ProgramFiles, windows.KF_FLAG_DEFAULT)
if err != nil {
return "", err
}
root = filepath.Join(root, "WireGuard")
if !create {
return filepath.Join(root, "Data"), nil
}
root16, err := windows.UTF16PtrFromString(root)
if err != nil {
return "", err
}
// The root directory inherits its ACL from Program Files; we don't want to mess with that
err = windows.CreateDirectory(root16, nil)
if err != nil && err != windows.ERROR_ALREADY_EXISTS {
return "", err
}
dataDirectorySd, err := windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;OICI;FA;;;SY)(A;OICI;FA;;;BA)")
if err != nil {
return "", err
}
dataDirectorySa := &windows.SecurityAttributes{
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
SecurityDescriptor: dataDirectorySd,
}
data := filepath.Join(root, "Data")
data16, err := windows.UTF16PtrFromString(data)
if err != nil {
return "", err
}
var dataHandle windows.Handle
for {
err = windows.CreateDirectory(data16, dataDirectorySa)
if err != nil && err != windows.ERROR_ALREADY_EXISTS {
return "", err
}
dataHandle, err = windows.CreateFile(data16, windows.READ_CONTROL|windows.WRITE_OWNER|windows.WRITE_DAC, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_BACKUP_SEMANTICS|windows.FILE_FLAG_OPEN_REPARSE_POINT|windows.FILE_ATTRIBUTE_DIRECTORY, 0)
if err != nil && err != windows.ERROR_FILE_NOT_FOUND {
return "", err
}
if err == nil {
break
}
}
defer windows.CloseHandle(dataHandle)
var fileInfo windows.ByHandleFileInformation
err = windows.GetFileInformationByHandle(dataHandle, &fileInfo)
if err != nil {
return "", err
}
if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_DIRECTORY == 0 {
return "", errors.New("Data directory is actually a file")
}
if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_REPARSE_POINT != 0 {
return "", errors.New("Data directory is reparse point")
}
buf := make([]uint16, windows.MAX_PATH+4)
for {
bufLen, err := windows.GetFinalPathNameByHandle(dataHandle, &buf[0], uint32(len(buf)), 0)
if err != nil {
return "", err
}
if bufLen < uint32(len(buf)) {
break
}
buf = make([]uint16, bufLen)
}
if !strings.EqualFold(`\\?\`+data, windows.UTF16ToString(buf[:])) {
return "", errors.New("Data directory jumped to unexpected location")
}
err = windows.SetKernelObjectSecurity(dataHandle, windows.DACL_SECURITY_INFORMATION|windows.GROUP_SECURITY_INFORMATION|windows.OWNER_SECURITY_INFORMATION|windows.PROTECTED_DACL_SECURITY_INFORMATION, dataDirectorySd)
if err != nil {
return "", err
}
cachedRootDir = data
return cachedRootDir, nil
}

144
conf/store.go Normal file
View File

@ -0,0 +1,144 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"errors"
"os"
"path/filepath"
"strings"
"golang.zx2c4.com/wireguard/windows/conf/dpapi"
)
const configFileSuffix = ".conf.dpapi"
const configFileUnencryptedSuffix = ".conf"
func ListConfigNames() ([]string, error) {
configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
return nil, err
}
files, err := os.ReadDir(configFileDir)
if err != nil {
return nil, err
}
configs := make([]string, len(files))
i := 0
for _, file := range files {
name := filepath.Base(file.Name())
if len(name) <= len(configFileSuffix) || !strings.HasSuffix(name, configFileSuffix) {
continue
}
if !file.Type().IsRegular() {
continue
}
info, err := file.Info()
if err != nil {
continue
}
if info.Mode().Perm()&0444 == 0 {
continue
}
name = strings.TrimSuffix(name, configFileSuffix)
if !TunnelNameIsValid(name) {
continue
}
configs[i] = name
i++
}
return configs[:i], nil
}
func LoadFromName(name string) (*Config, error) {
configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
return nil, err
}
return LoadFromPath(filepath.Join(configFileDir, name+configFileSuffix))
}
func LoadFromPath(path string) (*Config, error) {
name, err := NameFromPath(path)
if err != nil {
return nil, err
}
bytes, err := os.ReadFile(path)
if err != nil {
return nil, err
}
if strings.HasSuffix(path, configFileSuffix) {
bytes, err = dpapi.Decrypt(bytes, name)
if err != nil {
return nil, err
}
}
return FromWgQuickWithUnknownEncoding(string(bytes), name)
}
func PathIsEncrypted(path string) bool {
return strings.HasSuffix(filepath.Base(path), configFileSuffix)
}
func NameFromPath(path string) (string, error) {
name := filepath.Base(path)
if !((len(name) > len(configFileSuffix) && strings.HasSuffix(name, configFileSuffix)) ||
(len(name) > len(configFileUnencryptedSuffix) && strings.HasSuffix(name, configFileUnencryptedSuffix))) {
return "", errors.New("Path must end in either " + configFileSuffix + " or " + configFileUnencryptedSuffix)
}
if strings.HasSuffix(path, configFileSuffix) {
name = strings.TrimSuffix(name, configFileSuffix)
} else {
name = strings.TrimSuffix(name, configFileUnencryptedSuffix)
}
if !TunnelNameIsValid(name) {
return "", errors.New("Tunnel name is not valid")
}
return name, nil
}
func (config *Config) Save(overwrite bool) error {
if !TunnelNameIsValid(config.Name) {
return errors.New("Tunnel name is not valid")
}
configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
return err
}
filename := filepath.Join(configFileDir, config.Name+configFileSuffix)
bytes := []byte(config.ToWgQuick())
bytes, err = dpapi.Encrypt(bytes, config.Name)
if err != nil {
return err
}
return writeLockedDownFile(filename, overwrite, bytes)
}
func (config *Config) Path() (string, error) {
if !TunnelNameIsValid(config.Name) {
return "", errors.New("Tunnel name is not valid")
}
configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
return "", err
}
return filepath.Join(configFileDir, config.Name+configFileSuffix), nil
}
func DeleteName(name string) error {
if !TunnelNameIsValid(name) {
return errors.New("Tunnel name is not valid")
}
configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
return err
}
return os.Remove(filepath.Join(configFileDir, name+configFileSuffix))
}
func (config *Config) Delete() error {
return DeleteName(config.Name)
}

91
conf/store_test.go Normal file
View File

@ -0,0 +1,91 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"reflect"
"testing"
)
func TestStorage(t *testing.T) {
c, err := FromWgQuick(testInput, "golangTest")
if err != nil {
t.Errorf("Unable to parse test config: %s", err.Error())
return
}
err = c.Save()
if err != nil {
t.Errorf("Unable to save config: %s", err.Error())
}
configs, err := ListConfigNames()
if err != nil {
t.Errorf("Unable to list configs: %s", err.Error())
}
found := false
for _, name := range configs {
if name == "golangTest" {
found = true
break
}
}
if !found {
t.Error("Unable to find saved config in list")
}
loaded, err := LoadFromName("golangTest")
if err != nil {
t.Errorf("Unable to load config: %s", err.Error())
return
}
if !reflect.DeepEqual(loaded, c) {
t.Error("Loaded config is not the same as saved config")
}
k, err := NewPrivateKey()
if err != nil {
t.Errorf("Unable to generate new private key: %s", err.Error())
}
c.Interface.PrivateKey = *k
err = c.Save()
if err != nil {
t.Errorf("Unable to save config a second time: %s", err.Error())
}
loaded, err = LoadFromName("golangTest")
if err != nil {
t.Errorf("Unable to load config a second time: %s", err.Error())
return
}
if !reflect.DeepEqual(loaded, c) {
t.Error("Second loaded config is not the same as second saved config")
}
err = DeleteName("golangTest")
if err != nil {
t.Errorf("Unable to delete config: %s", err.Error())
}
configs, err = ListConfigNames()
if err != nil {
t.Errorf("Unable to list configs: %s", err.Error())
}
found = false
for _, name := range configs {
if name == "golangTest" {
found = true
break
}
}
if found {
t.Error("Config wasn't actually deleted")
}
}

24
conf/storewatcher.go Normal file
View File

@ -0,0 +1,24 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
type StoreCallback struct {
cb func()
}
var storeCallbacks = make(map[*StoreCallback]bool)
func RegisterStoreChangeCallback(cb func()) *StoreCallback {
startWatchingConfigDir()
cb()
s := &StoreCallback{cb}
storeCallbacks[s] = true
return s
}
func (cb *StoreCallback) Unregister() {
delete(storeCallbacks, cb)
}

View File

@ -0,0 +1,61 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"log"
"golang.org/x/sys/windows"
)
var haveStartedWatchingConfigDir bool
func startWatchingConfigDir() {
if haveStartedWatchingConfigDir {
return
}
haveStartedWatchingConfigDir = true
go func() {
h := windows.InvalidHandle
defer func() {
if h != windows.InvalidHandle {
windows.FindCloseChangeNotification(h)
}
haveStartedWatchingConfigDir = false
}()
startover:
configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
return
}
h, err = windows.FindFirstChangeNotification(configFileDir, true, windows.FILE_NOTIFY_CHANGE_FILE_NAME|windows.FILE_NOTIFY_CHANGE_DIR_NAME|windows.FILE_NOTIFY_CHANGE_ATTRIBUTES|windows.FILE_NOTIFY_CHANGE_SIZE|windows.FILE_NOTIFY_CHANGE_LAST_WRITE|windows.FILE_NOTIFY_CHANGE_LAST_ACCESS|windows.FILE_NOTIFY_CHANGE_CREATION|windows.FILE_NOTIFY_CHANGE_SECURITY)
if err != nil {
log.Printf("Unable to monitor config directory: %v", err)
return
}
for {
s, err := windows.WaitForSingleObject(h, windows.INFINITE)
if err != nil || s == windows.WAIT_FAILED {
log.Printf("Unable to wait on config directory watcher: %v", err)
windows.FindCloseChangeNotification(h)
h = windows.InvalidHandle
goto startover
}
for cb := range storeCallbacks {
cb.cb()
}
err = windows.FindNextChangeNotification(h)
if err != nil {
log.Printf("Unable to monitor config directory again: %v", err)
windows.FindCloseChangeNotification(h)
h = windows.InvalidHandle
goto startover
}
}
}()
}

124
conf/writer.go Normal file
View File

@ -0,0 +1,124 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"fmt"
"strings"
)
func (conf *Config) ToWgQuick() string {
var output strings.Builder
output.WriteString("[Interface]\n")
output.WriteString(fmt.Sprintf("PrivateKey = %s\n", conf.Interface.PrivateKey.String()))
if conf.Interface.ListenPort > 0 {
output.WriteString(fmt.Sprintf("ListenPort = %d\n", conf.Interface.ListenPort))
}
if len(conf.Interface.Addresses) > 0 {
addrStrings := make([]string, len(conf.Interface.Addresses))
for i, address := range conf.Interface.Addresses {
addrStrings[i] = address.String()
}
output.WriteString(fmt.Sprintf("Address = %s\n", strings.Join(addrStrings[:], ", ")))
}
if len(conf.Interface.DNS)+len(conf.Interface.DNSSearch) > 0 {
addrStrings := make([]string, 0, len(conf.Interface.DNS)+len(conf.Interface.DNSSearch))
for _, address := range conf.Interface.DNS {
addrStrings = append(addrStrings, address.String())
}
addrStrings = append(addrStrings, conf.Interface.DNSSearch...)
output.WriteString(fmt.Sprintf("DNS = %s\n", strings.Join(addrStrings[:], ", ")))
}
if conf.Interface.MTU > 0 {
output.WriteString(fmt.Sprintf("MTU = %d\n", conf.Interface.MTU))
}
if len(conf.Interface.PreUp) > 0 {
output.WriteString(fmt.Sprintf("PreUp = %s\n", conf.Interface.PreUp))
}
if len(conf.Interface.PostUp) > 0 {
output.WriteString(fmt.Sprintf("PostUp = %s\n", conf.Interface.PostUp))
}
if len(conf.Interface.PreDown) > 0 {
output.WriteString(fmt.Sprintf("PreDown = %s\n", conf.Interface.PreDown))
}
if len(conf.Interface.PostDown) > 0 {
output.WriteString(fmt.Sprintf("PostDown = %s\n", conf.Interface.PostDown))
}
for _, peer := range conf.Peers {
output.WriteString("\n[Peer]\n")
output.WriteString(fmt.Sprintf("PublicKey = %s\n", peer.PublicKey.String()))
if !peer.PresharedKey.IsZero() {
output.WriteString(fmt.Sprintf("PresharedKey = %s\n", peer.PresharedKey.String()))
}
if len(peer.AllowedIPs) > 0 {
addrStrings := make([]string, len(peer.AllowedIPs))
for i, address := range peer.AllowedIPs {
addrStrings[i] = address.String()
}
output.WriteString(fmt.Sprintf("AllowedIPs = %s\n", strings.Join(addrStrings[:], ", ")))
}
if !peer.Endpoint.IsEmpty() {
output.WriteString(fmt.Sprintf("Endpoint = %s\n", peer.Endpoint.String()))
}
if peer.PersistentKeepalive > 0 {
output.WriteString(fmt.Sprintf("PersistentKeepalive = %d\n", peer.PersistentKeepalive))
}
}
return output.String()
}
func (conf *Config) ToUAPI() (uapi string, dnsErr error) {
var output strings.Builder
output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString()))
if conf.Interface.ListenPort > 0 {
output.WriteString(fmt.Sprintf("listen_port=%d\n", conf.Interface.ListenPort))
}
if len(conf.Peers) > 0 {
output.WriteString("replace_peers=true\n")
}
for _, peer := range conf.Peers {
output.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey.HexString()))
if !peer.PresharedKey.IsZero() {
output.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PresharedKey.HexString()))
}
if !peer.Endpoint.IsEmpty() {
var resolvedIP string
resolvedIP, dnsErr = resolveHostname(peer.Endpoint.Host)
if dnsErr != nil {
return
}
resolvedEndpoint := Endpoint{resolvedIP, peer.Endpoint.Port}
output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String()))
}
output.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive))
if len(peer.AllowedIPs) > 0 {
output.WriteString("replace_allowed_ips=true\n")
for _, address := range peer.AllowedIPs {
output.WriteString(fmt.Sprintf("allowed_ip=%s\n", address.String()))
}
}
}
return output.String(), nil
}

50
conf/zsyscall_windows.go Normal file
View File

@ -0,0 +1,50 @@
// Code generated by 'go generate'; DO NOT EDIT.
package conf
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modwininet = windows.NewLazySystemDLL("wininet.dll")
procInternetGetConnectedState = modwininet.NewProc("InternetGetConnectedState")
)
func internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) {
r0, _, _ := syscall.Syscall(procInternetGetConnectedState.Addr(), 2, uintptr(unsafe.Pointer(flags)), uintptr(reserved), 0)
connected = r0 != 0
return
}