conf: add from wireguard-windows/conf/@b73dcdb
Signed-off-by: Vincent Batts <vbatts@hashbangbash.com>
This commit is contained in:
parent
56fa2f2258
commit
77158f3dde
16 changed files with 1939 additions and 0 deletions
36
conf/admin_windows.go
Normal file
36
conf/admin_windows.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import "golang.org/x/sys/windows/registry"
|
||||
|
||||
const adminRegKey = `Software\WireGuard`
|
||||
|
||||
var adminKey registry.Key
|
||||
|
||||
func openAdminKey() (registry.Key, error) {
|
||||
if adminKey != 0 {
|
||||
return adminKey, nil
|
||||
}
|
||||
var err error
|
||||
adminKey, err = registry.OpenKey(registry.LOCAL_MACHINE, adminRegKey, registry.QUERY_VALUE|registry.WOW64_64KEY)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return adminKey, nil
|
||||
}
|
||||
|
||||
func AdminBool(name string) bool {
|
||||
key, err := openAdminKey()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
val, _, err := key.GetIntegerValue(name)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return val != 0
|
||||
}
|
252
conf/config.go
Normal file
252
conf/config.go
Normal file
|
@ -0,0 +1,252 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/l18n"
|
||||
)
|
||||
|
||||
const KeyLength = 32
|
||||
|
||||
type IPCidr struct {
|
||||
IP net.IP
|
||||
Cidr uint8
|
||||
}
|
||||
|
||||
type Endpoint struct {
|
||||
Host string
|
||||
Port uint16
|
||||
}
|
||||
|
||||
type Key [KeyLength]byte
|
||||
type HandshakeTime time.Duration
|
||||
type Bytes uint64
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
Interface Interface
|
||||
Peers []Peer
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
PrivateKey Key
|
||||
Addresses []IPCidr
|
||||
ListenPort uint16
|
||||
MTU uint16
|
||||
DNS []net.IP
|
||||
DNSSearch []string
|
||||
PreUp string
|
||||
PostUp string
|
||||
PreDown string
|
||||
PostDown string
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
PublicKey Key
|
||||
PresharedKey Key
|
||||
AllowedIPs []IPCidr
|
||||
Endpoint Endpoint
|
||||
PersistentKeepalive uint16
|
||||
|
||||
RxBytes Bytes
|
||||
TxBytes Bytes
|
||||
LastHandshakeTime HandshakeTime
|
||||
}
|
||||
|
||||
func (r *IPCidr) String() string {
|
||||
return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
|
||||
}
|
||||
|
||||
func (r *IPCidr) Bits() uint8 {
|
||||
if r.IP.To4() != nil {
|
||||
return 32
|
||||
}
|
||||
return 128
|
||||
}
|
||||
|
||||
func (r *IPCidr) IPNet() net.IPNet {
|
||||
return net.IPNet{
|
||||
IP: r.IP,
|
||||
Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *IPCidr) MaskSelf() {
|
||||
bits := int(r.Bits())
|
||||
mask := net.CIDRMask(int(r.Cidr), bits)
|
||||
for i := 0; i < bits/8; i++ {
|
||||
r.IP[i] &= mask[i]
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Endpoint) String() string {
|
||||
if strings.IndexByte(e.Host, ':') > 0 {
|
||||
return fmt.Sprintf("[%s]:%d", e.Host, e.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", e.Host, e.Port)
|
||||
}
|
||||
|
||||
func (e *Endpoint) IsEmpty() bool {
|
||||
return len(e.Host) == 0
|
||||
}
|
||||
|
||||
func (k *Key) String() string {
|
||||
return base64.StdEncoding.EncodeToString(k[:])
|
||||
}
|
||||
|
||||
func (k *Key) HexString() string {
|
||||
return hex.EncodeToString(k[:])
|
||||
}
|
||||
|
||||
func (k *Key) IsZero() bool {
|
||||
var zeros Key
|
||||
return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
|
||||
}
|
||||
|
||||
func (k *Key) Public() *Key {
|
||||
var p [KeyLength]byte
|
||||
curve25519.ScalarBaseMult(&p, (*[KeyLength]byte)(k))
|
||||
return (*Key)(&p)
|
||||
}
|
||||
|
||||
func NewPresharedKey() (*Key, error) {
|
||||
var k [KeyLength]byte
|
||||
_, err := rand.Read(k[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return (*Key)(&k), nil
|
||||
}
|
||||
|
||||
func NewPrivateKey() (*Key, error) {
|
||||
k, err := NewPresharedKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k[0] &= 248
|
||||
k[31] = (k[31] & 127) | 64
|
||||
return k, nil
|
||||
}
|
||||
|
||||
func NewPrivateKeyFromString(b64 string) (*Key, error) {
|
||||
return parseKeyBase64(b64)
|
||||
}
|
||||
|
||||
func (t HandshakeTime) IsEmpty() bool {
|
||||
return t == HandshakeTime(0)
|
||||
}
|
||||
|
||||
func (t HandshakeTime) String() string {
|
||||
u := time.Unix(0, 0).Add(time.Duration(t)).Unix()
|
||||
n := time.Now().Unix()
|
||||
if u == n {
|
||||
return l18n.Sprintf("Now")
|
||||
} else if u > n {
|
||||
return l18n.Sprintf("System clock wound backward!")
|
||||
}
|
||||
left := n - u
|
||||
years := left / (365 * 24 * 60 * 60)
|
||||
left = left % (365 * 24 * 60 * 60)
|
||||
days := left / (24 * 60 * 60)
|
||||
left = left % (24 * 60 * 60)
|
||||
hours := left / (60 * 60)
|
||||
left = left % (60 * 60)
|
||||
minutes := left / 60
|
||||
seconds := left % 60
|
||||
s := make([]string, 0, 5)
|
||||
if years > 0 {
|
||||
s = append(s, l18n.Sprintf("%d year(s)", years))
|
||||
}
|
||||
if days > 0 {
|
||||
s = append(s, l18n.Sprintf("%d day(s)", days))
|
||||
}
|
||||
if hours > 0 {
|
||||
s = append(s, l18n.Sprintf("%d hour(s)", hours))
|
||||
}
|
||||
if minutes > 0 {
|
||||
s = append(s, l18n.Sprintf("%d minute(s)", minutes))
|
||||
}
|
||||
if seconds > 0 {
|
||||
s = append(s, l18n.Sprintf("%d second(s)", seconds))
|
||||
}
|
||||
timestamp := strings.Join(s, l18n.UnitSeparator())
|
||||
return l18n.Sprintf("%s ago", timestamp)
|
||||
}
|
||||
|
||||
func (b Bytes) String() string {
|
||||
if b < 1024 {
|
||||
return l18n.Sprintf("%d\u00a0B", b)
|
||||
} else if b < 1024*1024 {
|
||||
return l18n.Sprintf("%.2f\u00a0KiB", float64(b)/1024)
|
||||
} else if b < 1024*1024*1024 {
|
||||
return l18n.Sprintf("%.2f\u00a0MiB", float64(b)/(1024*1024))
|
||||
} else if b < 1024*1024*1024*1024 {
|
||||
return l18n.Sprintf("%.2f\u00a0GiB", float64(b)/(1024*1024*1024))
|
||||
}
|
||||
return l18n.Sprintf("%.2f\u00a0TiB", float64(b)/(1024*1024*1024)/1024)
|
||||
}
|
||||
|
||||
func (conf *Config) DeduplicateNetworkEntries() {
|
||||
m := make(map[string]bool, len(conf.Interface.Addresses))
|
||||
i := 0
|
||||
for _, addr := range conf.Interface.Addresses {
|
||||
s := addr.String()
|
||||
if m[s] {
|
||||
continue
|
||||
}
|
||||
m[s] = true
|
||||
conf.Interface.Addresses[i] = addr
|
||||
i++
|
||||
}
|
||||
conf.Interface.Addresses = conf.Interface.Addresses[:i]
|
||||
|
||||
m = make(map[string]bool, len(conf.Interface.DNS))
|
||||
i = 0
|
||||
for _, addr := range conf.Interface.DNS {
|
||||
s := addr.String()
|
||||
if m[s] {
|
||||
continue
|
||||
}
|
||||
m[s] = true
|
||||
conf.Interface.DNS[i] = addr
|
||||
i++
|
||||
}
|
||||
conf.Interface.DNS = conf.Interface.DNS[:i]
|
||||
|
||||
for _, peer := range conf.Peers {
|
||||
m = make(map[string]bool, len(peer.AllowedIPs))
|
||||
i = 0
|
||||
for _, addr := range peer.AllowedIPs {
|
||||
s := addr.String()
|
||||
if m[s] {
|
||||
continue
|
||||
}
|
||||
m[s] = true
|
||||
peer.AllowedIPs[i] = addr
|
||||
i++
|
||||
}
|
||||
peer.AllowedIPs = peer.AllowedIPs[:i]
|
||||
}
|
||||
}
|
||||
|
||||
func (conf *Config) Redact() {
|
||||
conf.Interface.PrivateKey = Key{}
|
||||
for i := range conf.Peers {
|
||||
conf.Peers[i].PublicKey = Key{}
|
||||
conf.Peers[i].PresharedKey = Key{}
|
||||
}
|
||||
}
|
90
conf/dnsresolver_windows.go
Normal file
90
conf/dnsresolver_windows.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) = wininet.InternetGetConnectedState
|
||||
|
||||
func resolveHostname(name string) (resolvedIPString string, err error) {
|
||||
maxTries := 10
|
||||
systemJustBooted := windows.DurationSinceBoot() <= time.Minute*4
|
||||
if systemJustBooted {
|
||||
maxTries *= 4
|
||||
}
|
||||
for i := 0; i < maxTries; i++ {
|
||||
if i > 0 {
|
||||
time.Sleep(time.Second * 4)
|
||||
}
|
||||
resolvedIPString, err = resolveHostnameOnce(name)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if err == windows.WSATRY_AGAIN {
|
||||
log.Printf("Temporary DNS error when resolving %s, sleeping for 4 seconds", name)
|
||||
continue
|
||||
}
|
||||
var state uint32
|
||||
if err == windows.WSAHOST_NOT_FOUND && systemJustBooted && !internetGetConnectedState(&state, 0) {
|
||||
log.Printf("Host not found when resolving %s, but no Internet connection available, sleeping for 4 seconds", name)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func resolveHostnameOnce(name string) (resolvedIPString string, err error) {
|
||||
hints := windows.AddrinfoW{
|
||||
Family: windows.AF_UNSPEC,
|
||||
Socktype: windows.SOCK_DGRAM,
|
||||
Protocol: windows.IPPROTO_IP,
|
||||
}
|
||||
var result *windows.AddrinfoW
|
||||
name16, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = windows.GetAddrInfoW(name16, nil, &hints, &result)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
err = windows.WSAHOST_NOT_FOUND
|
||||
return
|
||||
}
|
||||
defer windows.FreeAddrInfoW(result)
|
||||
ipv6 := ""
|
||||
for ; result != nil; result = result.Next {
|
||||
switch result.Family {
|
||||
case windows.AF_INET:
|
||||
return (net.IP)((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr[:]).String(), nil
|
||||
case windows.AF_INET6:
|
||||
if len(ipv6) != 0 {
|
||||
continue
|
||||
}
|
||||
a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(result.Addr))
|
||||
ipv6 = (net.IP)(a.Addr[:]).String()
|
||||
if a.Scope_id != 0 {
|
||||
ipv6 += fmt.Sprintf("%%%d", a.Scope_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(ipv6) != 0 {
|
||||
return ipv6, nil
|
||||
}
|
||||
err = windows.WSAHOST_NOT_FOUND
|
||||
return
|
||||
}
|
89
conf/filewriter_windows.go
Normal file
89
conf/filewriter_windows.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var encryptedFileSd unsafe.Pointer
|
||||
|
||||
func randomFileName() string {
|
||||
var randBytes [32]byte
|
||||
_, err := rand.Read(randBytes[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hex.EncodeToString(randBytes[:]) + ".tmp"
|
||||
}
|
||||
|
||||
func writeLockedDownFile(destination string, overwrite bool, contents []byte) error {
|
||||
var err error
|
||||
sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{}))}
|
||||
sa.SecurityDescriptor = (*windows.SECURITY_DESCRIPTOR)(atomic.LoadPointer(&encryptedFileSd))
|
||||
if sa.SecurityDescriptor == nil {
|
||||
sa.SecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;;FA;;;SY)(A;;SD;;;BA)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
atomic.StorePointer(&encryptedFileSd, unsafe.Pointer(sa.SecurityDescriptor))
|
||||
}
|
||||
destination16, err := windows.UTF16FromString(destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpDestination := randomFileName()
|
||||
tmpDestination16, err := windows.UTF16PtrFromString(tmpDestination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
handle, err := windows.CreateFile(tmpDestination16, windows.GENERIC_WRITE|windows.DELETE, windows.FILE_SHARE_READ, sa, windows.CREATE_ALWAYS, windows.FILE_ATTRIBUTE_NORMAL, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer windows.CloseHandle(handle)
|
||||
deleteIt := func() {
|
||||
yes := byte(1)
|
||||
windows.SetFileInformationByHandle(handle, windows.FileDispositionInfo, &yes, 1)
|
||||
}
|
||||
n, err := windows.Write(handle, contents)
|
||||
if err != nil {
|
||||
deleteIt()
|
||||
return err
|
||||
}
|
||||
if n != len(contents) {
|
||||
deleteIt()
|
||||
return windows.ERROR_IO_INCOMPLETE
|
||||
}
|
||||
fileRenameInfo := &struct {
|
||||
replaceIfExists byte
|
||||
rootDirectory windows.Handle
|
||||
fileNameLength uint32
|
||||
fileName [windows.MAX_PATH]uint16
|
||||
}{replaceIfExists: func() byte {
|
||||
if overwrite {
|
||||
return 1
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
}(), fileNameLength: uint32(len(destination16) - 1)}
|
||||
if len(destination16) > len(fileRenameInfo.fileName) {
|
||||
deleteIt()
|
||||
return windows.ERROR_BUFFER_OVERFLOW
|
||||
}
|
||||
copy(fileRenameInfo.fileName[:], destination16[:])
|
||||
err = windows.SetFileInformationByHandle(handle, windows.FileRenameInfo, (*byte)(unsafe.Pointer(fileRenameInfo)), uint32(unsafe.Sizeof(*fileRenameInfo)))
|
||||
if err != nil {
|
||||
deleteIt()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
96
conf/migration_windows.go
Normal file
96
conf/migration_windows.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var migrating sync.Mutex
|
||||
var lastMigrationTimer *time.Timer
|
||||
|
||||
func MigrateUnencryptedConfigs() { migrateUnencryptedConfigs(3) }
|
||||
|
||||
func migrateUnencryptedConfigs(sharingBase int) {
|
||||
migrating.Lock()
|
||||
defer migrating.Unlock()
|
||||
configFileDir, err := tunnelConfigurationsDirectory()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
files, err := os.ReadDir(configFileDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ignoreSharingViolations := false
|
||||
for _, file := range files {
|
||||
path := filepath.Join(configFileDir, file.Name())
|
||||
name := filepath.Base(file.Name())
|
||||
if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) {
|
||||
continue
|
||||
}
|
||||
if !file.Type().IsRegular() {
|
||||
continue
|
||||
}
|
||||
info, err := file.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.Mode().Perm()&0444 == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var bytes []byte
|
||||
var config *Config
|
||||
// We don't use os.ReadFile, because we actually want RDWR, so that we can take advantage
|
||||
// of Windows file locking for ensuring the file is finished being written.
|
||||
f, err := os.OpenFile(path, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
if errors.Is(err, windows.ERROR_SHARING_VIOLATION) {
|
||||
if ignoreSharingViolations {
|
||||
continue
|
||||
} else if sharingBase > 0 {
|
||||
if lastMigrationTimer != nil {
|
||||
lastMigrationTimer.Stop()
|
||||
}
|
||||
lastMigrationTimer = time.AfterFunc(time.Second/time.Duration(sharingBase*sharingBase), func() { migrateUnencryptedConfigs(sharingBase - 1) })
|
||||
ignoreSharingViolations = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
goto error
|
||||
}
|
||||
bytes, err = io.ReadAll(f)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
goto error
|
||||
}
|
||||
config, err = FromWgQuickWithUnknownEncoding(string(bytes), strings.TrimSuffix(name, configFileUnencryptedSuffix))
|
||||
if err != nil {
|
||||
goto error
|
||||
}
|
||||
err = config.Save(false)
|
||||
if err != nil {
|
||||
goto error
|
||||
}
|
||||
err = os.Remove(path)
|
||||
if err != nil {
|
||||
log.Printf("Unable to remove old path %#q: %v", path, err)
|
||||
}
|
||||
continue
|
||||
error:
|
||||
log.Printf("Unable to ingest and encrypt %#q: %v", path, err)
|
||||
}
|
||||
}
|
8
conf/mksyscall.go
Normal file
8
conf/mksyscall.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go dnsresolver_windows.go migration_windows.go storewatcher_windows.go
|
112
conf/name.go
Normal file
112
conf/name.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var reservedNames = []string{
|
||||
"CON", "PRN", "AUX", "NUL",
|
||||
"COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9",
|
||||
"LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
|
||||
}
|
||||
|
||||
const serviceNameForbidden = "$"
|
||||
const netshellDllForbidden = "\\/:*?\"<>|\t"
|
||||
const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00"
|
||||
|
||||
var allowedNameFormat *regexp.Regexp
|
||||
|
||||
func init() {
|
||||
allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$")
|
||||
}
|
||||
|
||||
func isReserved(name string) bool {
|
||||
if len(name) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, reserved := range reservedNames {
|
||||
if strings.EqualFold(name, reserved) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasSpecialChars(name string) bool {
|
||||
return strings.ContainsAny(name, specialChars) || strings.ContainsAny(name, netshellDllForbidden) || strings.ContainsAny(name, serviceNameForbidden)
|
||||
}
|
||||
|
||||
func TunnelNameIsValid(name string) bool {
|
||||
// Aside from our own restrictions, let's impose the Windows restrictions first
|
||||
if isReserved(name) || hasSpecialChars(name) {
|
||||
return false
|
||||
}
|
||||
return allowedNameFormat.MatchString(name)
|
||||
}
|
||||
|
||||
type naturalSortToken struct {
|
||||
maybeString string
|
||||
maybeNumber int
|
||||
}
|
||||
type naturalSortString struct {
|
||||
originalString string
|
||||
tokens []naturalSortToken
|
||||
}
|
||||
|
||||
var naturalSortDigitFinder = regexp.MustCompile(`\d+|\D+`)
|
||||
|
||||
func newNaturalSortString(s string) (t naturalSortString) {
|
||||
t.originalString = s
|
||||
s = strings.ToLower(strings.Join(strings.Fields(s), " "))
|
||||
x := naturalSortDigitFinder.FindAllString(s, -1)
|
||||
t.tokens = make([]naturalSortToken, len(x))
|
||||
for i, s := range x {
|
||||
if n, err := strconv.Atoi(s); err == nil {
|
||||
t.tokens[i].maybeNumber = n
|
||||
} else {
|
||||
t.tokens[i].maybeString = s
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (f1 naturalSortToken) Cmp(f2 naturalSortToken) int {
|
||||
if len(f1.maybeString) == 0 {
|
||||
if len(f2.maybeString) > 0 || f1.maybeNumber < f2.maybeNumber {
|
||||
return -1
|
||||
} else if f1.maybeNumber > f2.maybeNumber {
|
||||
return 1
|
||||
}
|
||||
} else if len(f2.maybeString) == 0 || f1.maybeString > f2.maybeString {
|
||||
return 1
|
||||
} else if f1.maybeString < f2.maybeString {
|
||||
return -1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func TunnelNameIsLess(a, b string) bool {
|
||||
if a == b {
|
||||
return false
|
||||
}
|
||||
na, nb := newNaturalSortString(a), newNaturalSortString(b)
|
||||
for i, t := range nb.tokens {
|
||||
if i == len(na.tokens) {
|
||||
return true
|
||||
}
|
||||
switch na.tokens[i].Cmp(t) {
|
||||
case -1:
|
||||
return true
|
||||
case 1:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
506
conf/parser.go
Normal file
506
conf/parser.go
Normal file
|
@ -0,0 +1,506 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/l18n"
|
||||
)
|
||||
|
||||
type ParseError struct {
|
||||
why string
|
||||
offender string
|
||||
}
|
||||
|
||||
func (e *ParseError) Error() string {
|
||||
return l18n.Sprintf("%s: %q", e.why, e.offender)
|
||||
}
|
||||
|
||||
func parseIPCidr(s string) (ipcidr *IPCidr, err error) {
|
||||
var addrStr, cidrStr string
|
||||
var cidr int
|
||||
|
||||
i := strings.IndexByte(s, '/')
|
||||
if i < 0 {
|
||||
addrStr = s
|
||||
} else {
|
||||
addrStr, cidrStr = s[:i], s[i+1:]
|
||||
}
|
||||
|
||||
err = &ParseError{l18n.Sprintf("Invalid IP address"), s}
|
||||
addr := net.ParseIP(addrStr)
|
||||
if addr == nil {
|
||||
return
|
||||
}
|
||||
maybeV4 := addr.To4()
|
||||
if maybeV4 != nil {
|
||||
addr = maybeV4
|
||||
}
|
||||
if len(cidrStr) > 0 {
|
||||
err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s}
|
||||
cidr, err = strconv.Atoi(cidrStr)
|
||||
if err != nil || cidr < 0 || cidr > 128 {
|
||||
return
|
||||
}
|
||||
if cidr > 32 && maybeV4 != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if maybeV4 != nil {
|
||||
cidr = 32
|
||||
} else {
|
||||
cidr = 128
|
||||
}
|
||||
}
|
||||
return &IPCidr{addr, uint8(cidr)}, nil
|
||||
}
|
||||
|
||||
func parseEndpoint(s string) (*Endpoint, error) {
|
||||
i := strings.LastIndexByte(s, ':')
|
||||
if i < 0 {
|
||||
return nil, &ParseError{l18n.Sprintf("Missing port from endpoint"), s}
|
||||
}
|
||||
host, portStr := s[:i], s[i+1:]
|
||||
if len(host) < 1 {
|
||||
return nil, &ParseError{l18n.Sprintf("Invalid endpoint host"), host}
|
||||
}
|
||||
port, err := parsePort(portStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostColon := strings.IndexByte(host, ':')
|
||||
if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 {
|
||||
err := &ParseError{l18n.Sprintf("Brackets must contain an IPv6 address"), host}
|
||||
if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 {
|
||||
end := len(host) - 1
|
||||
if i := strings.LastIndexByte(host, '%'); i > 1 {
|
||||
end = i
|
||||
}
|
||||
maybeV6 := net.ParseIP(host[1:end])
|
||||
if maybeV6 == nil || len(maybeV6) != net.IPv6len {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
host = host[1 : len(host)-1]
|
||||
}
|
||||
return &Endpoint{host, uint16(port)}, nil
|
||||
}
|
||||
|
||||
func parseMTU(s string) (uint16, error) {
|
||||
m, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if m < 576 || m > 65535 {
|
||||
return 0, &ParseError{l18n.Sprintf("Invalid MTU"), s}
|
||||
}
|
||||
return uint16(m), nil
|
||||
}
|
||||
|
||||
func parsePort(s string) (uint16, error) {
|
||||
m, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if m < 0 || m > 65535 {
|
||||
return 0, &ParseError{l18n.Sprintf("Invalid port"), s}
|
||||
}
|
||||
return uint16(m), nil
|
||||
}
|
||||
|
||||
func parsePersistentKeepalive(s string) (uint16, error) {
|
||||
if s == "off" {
|
||||
return 0, nil
|
||||
}
|
||||
m, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if m < 0 || m > 65535 {
|
||||
return 0, &ParseError{l18n.Sprintf("Invalid persistent keepalive"), s}
|
||||
}
|
||||
return uint16(m), nil
|
||||
}
|
||||
|
||||
func parseKeyBase64(s string) (*Key, error) {
|
||||
k, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s}
|
||||
}
|
||||
if len(k) != KeyLength {
|
||||
return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s}
|
||||
}
|
||||
var key Key
|
||||
copy(key[:], k)
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func parseKeyHex(s string) (*Key, error) {
|
||||
k, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s}
|
||||
}
|
||||
if len(k) != KeyLength {
|
||||
return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s}
|
||||
}
|
||||
var key Key
|
||||
copy(key[:], k)
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func parseBytesOrStamp(s string) (uint64, error) {
|
||||
b, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, &ParseError{l18n.Sprintf("Number must be a number between 0 and 2^64-1: %v", err), s}
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func splitList(s string) ([]string, error) {
|
||||
var out []string
|
||||
for _, split := range strings.Split(s, ",") {
|
||||
trim := strings.TrimSpace(split)
|
||||
if len(trim) == 0 {
|
||||
return nil, &ParseError{l18n.Sprintf("Two commas in a row"), s}
|
||||
}
|
||||
out = append(out, trim)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type parserState int
|
||||
|
||||
const (
|
||||
inInterfaceSection parserState = iota
|
||||
inPeerSection
|
||||
notInASection
|
||||
)
|
||||
|
||||
func (c *Config) maybeAddPeer(p *Peer) {
|
||||
if p != nil {
|
||||
c.Peers = append(c.Peers, *p)
|
||||
}
|
||||
}
|
||||
|
||||
func FromWgQuick(s string, name string) (*Config, error) {
|
||||
if !TunnelNameIsValid(name) {
|
||||
return nil, &ParseError{l18n.Sprintf("Tunnel name is not valid"), name}
|
||||
}
|
||||
lines := strings.Split(s, "\n")
|
||||
parserState := notInASection
|
||||
conf := Config{Name: name}
|
||||
sawPrivateKey := false
|
||||
var peer *Peer
|
||||
for _, line := range lines {
|
||||
pound := strings.IndexByte(line, '#')
|
||||
if pound >= 0 {
|
||||
line = line[:pound]
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
lineLower := strings.ToLower(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if lineLower == "[interface]" {
|
||||
conf.maybeAddPeer(peer)
|
||||
parserState = inInterfaceSection
|
||||
continue
|
||||
}
|
||||
if lineLower == "[peer]" {
|
||||
conf.maybeAddPeer(peer)
|
||||
peer = &Peer{}
|
||||
parserState = inPeerSection
|
||||
continue
|
||||
}
|
||||
if parserState == notInASection {
|
||||
return nil, &ParseError{l18n.Sprintf("Line must occur in a section"), line}
|
||||
}
|
||||
equals := strings.IndexByte(line, '=')
|
||||
if equals < 0 {
|
||||
return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line}
|
||||
}
|
||||
key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:])
|
||||
if len(val) == 0 {
|
||||
return nil, &ParseError{l18n.Sprintf("Key must have a value"), line}
|
||||
}
|
||||
if parserState == inInterfaceSection {
|
||||
switch key {
|
||||
case "privatekey":
|
||||
k, err := parseKeyBase64(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Interface.PrivateKey = *k
|
||||
sawPrivateKey = true
|
||||
case "listenport":
|
||||
p, err := parsePort(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Interface.ListenPort = p
|
||||
case "mtu":
|
||||
m, err := parseMTU(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Interface.MTU = m
|
||||
case "address":
|
||||
addresses, err := splitList(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, address := range addresses {
|
||||
a, err := parseIPCidr(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Interface.Addresses = append(conf.Interface.Addresses, *a)
|
||||
}
|
||||
case "dns":
|
||||
addresses, err := splitList(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, address := range addresses {
|
||||
a := net.ParseIP(address)
|
||||
if a == nil {
|
||||
conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address)
|
||||
} else {
|
||||
conf.Interface.DNS = append(conf.Interface.DNS, a)
|
||||
}
|
||||
}
|
||||
case "preup":
|
||||
conf.Interface.PreUp = val
|
||||
case "postup":
|
||||
conf.Interface.PostUp = val
|
||||
case "predown":
|
||||
conf.Interface.PreDown = val
|
||||
case "postdown":
|
||||
conf.Interface.PostDown = val
|
||||
default:
|
||||
return nil, &ParseError{l18n.Sprintf("Invalid key for [Interface] section"), key}
|
||||
}
|
||||
} else if parserState == inPeerSection {
|
||||
switch key {
|
||||
case "publickey":
|
||||
k, err := parseKeyBase64(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.PublicKey = *k
|
||||
case "presharedkey":
|
||||
k, err := parseKeyBase64(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.PresharedKey = *k
|
||||
case "allowedips":
|
||||
addresses, err := splitList(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, address := range addresses {
|
||||
a, err := parseIPCidr(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.AllowedIPs = append(peer.AllowedIPs, *a)
|
||||
}
|
||||
case "persistentkeepalive":
|
||||
p, err := parsePersistentKeepalive(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.PersistentKeepalive = p
|
||||
case "endpoint":
|
||||
e, err := parseEndpoint(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.Endpoint = *e
|
||||
default:
|
||||
return nil, &ParseError{l18n.Sprintf("Invalid key for [Peer] section"), key}
|
||||
}
|
||||
}
|
||||
}
|
||||
conf.maybeAddPeer(peer)
|
||||
|
||||
if !sawPrivateKey {
|
||||
return nil, &ParseError{l18n.Sprintf("An interface must have a private key"), l18n.Sprintf("[none specified]")}
|
||||
}
|
||||
for _, p := range conf.Peers {
|
||||
if p.PublicKey.IsZero() {
|
||||
return nil, &ParseError{l18n.Sprintf("All peers must have public keys"), l18n.Sprintf("[none specified]")}
|
||||
}
|
||||
}
|
||||
|
||||
return &conf, nil
|
||||
}
|
||||
|
||||
func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) {
|
||||
c, firstErr := FromWgQuick(s, name)
|
||||
if firstErr == nil {
|
||||
return c, nil
|
||||
}
|
||||
for _, encoding := range unicode.All {
|
||||
decoded, err := encoding.NewDecoder().String(s)
|
||||
if err == nil {
|
||||
c, err := FromWgQuick(decoded, name)
|
||||
if err == nil {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
func FromUAPI(reader io.Reader, existingConfig *Config) (*Config, error) {
|
||||
parserState := inInterfaceSection
|
||||
conf := Config{
|
||||
Name: existingConfig.Name,
|
||||
Interface: Interface{
|
||||
Addresses: existingConfig.Interface.Addresses,
|
||||
DNS: existingConfig.Interface.DNS,
|
||||
DNSSearch: existingConfig.Interface.DNSSearch,
|
||||
MTU: existingConfig.Interface.MTU,
|
||||
PreUp: existingConfig.Interface.PreUp,
|
||||
PostUp: existingConfig.Interface.PostUp,
|
||||
PreDown: existingConfig.Interface.PreDown,
|
||||
PostDown: existingConfig.Interface.PostDown,
|
||||
},
|
||||
}
|
||||
var peer *Peer
|
||||
lineReader := bufio.NewReader(reader)
|
||||
for {
|
||||
line, err := lineReader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = line[:len(line)-1]
|
||||
if len(line) == 0 {
|
||||
break
|
||||
}
|
||||
equals := strings.IndexByte(line, '=')
|
||||
if equals < 0 {
|
||||
return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line}
|
||||
}
|
||||
key, val := line[:equals], line[equals+1:]
|
||||
if len(val) == 0 {
|
||||
return nil, &ParseError{l18n.Sprintf("Key must have a value"), line}
|
||||
}
|
||||
switch key {
|
||||
case "public_key":
|
||||
conf.maybeAddPeer(peer)
|
||||
peer = &Peer{}
|
||||
parserState = inPeerSection
|
||||
case "errno":
|
||||
if val == "0" {
|
||||
continue
|
||||
} else {
|
||||
return nil, &ParseError{l18n.Sprintf("Error in getting configuration"), val}
|
||||
}
|
||||
}
|
||||
if parserState == inInterfaceSection {
|
||||
switch key {
|
||||
case "private_key":
|
||||
k, err := parseKeyHex(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Interface.PrivateKey = *k
|
||||
case "listen_port":
|
||||
p, err := parsePort(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Interface.ListenPort = p
|
||||
case "fwmark":
|
||||
// Ignored for now.
|
||||
|
||||
default:
|
||||
return nil, &ParseError{l18n.Sprintf("Invalid key for interface section"), key}
|
||||
}
|
||||
} else if parserState == inPeerSection {
|
||||
switch key {
|
||||
case "public_key":
|
||||
k, err := parseKeyHex(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.PublicKey = *k
|
||||
case "preshared_key":
|
||||
k, err := parseKeyHex(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.PresharedKey = *k
|
||||
case "protocol_version":
|
||||
if val != "1" {
|
||||
return nil, &ParseError{l18n.Sprintf("Protocol version must be 1"), val}
|
||||
}
|
||||
case "allowed_ip":
|
||||
a, err := parseIPCidr(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.AllowedIPs = append(peer.AllowedIPs, *a)
|
||||
case "persistent_keepalive_interval":
|
||||
p, err := parsePersistentKeepalive(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.PersistentKeepalive = p
|
||||
case "endpoint":
|
||||
e, err := parseEndpoint(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.Endpoint = *e
|
||||
case "tx_bytes":
|
||||
b, err := parseBytesOrStamp(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.TxBytes = Bytes(b)
|
||||
case "rx_bytes":
|
||||
b, err := parseBytesOrStamp(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.RxBytes = Bytes(b)
|
||||
case "last_handshake_time_sec":
|
||||
t, err := parseBytesOrStamp(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Second)
|
||||
case "last_handshake_time_nsec":
|
||||
t, err := parseBytesOrStamp(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Nanosecond)
|
||||
default:
|
||||
return nil, &ParseError{l18n.Sprintf("Invalid key for peer section"), key}
|
||||
}
|
||||
}
|
||||
}
|
||||
conf.maybeAddPeer(peer)
|
||||
|
||||
return &conf, nil
|
||||
}
|
128
conf/parser_test.go
Normal file
128
conf/parser_test.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const testInput = `
|
||||
[Interface]
|
||||
Address = 10.192.122.1/24
|
||||
Address = 10.10.0.1/16
|
||||
PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
|
||||
ListenPort = 51820 #comments don't matter
|
||||
|
||||
[Peer]
|
||||
PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=
|
||||
Endpoint = 192.95.5.67:1234
|
||||
AllowedIPs = 10.192.122.3/32, 10.192.124.1/24
|
||||
|
||||
[Peer]
|
||||
PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
|
||||
Endpoint = [2607:5300:60:6b0::c05f:543]:2468
|
||||
AllowedIPs = 10.192.122.4/32, 192.168.0.0/16
|
||||
PersistentKeepalive = 100
|
||||
|
||||
[Peer]
|
||||
PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
|
||||
PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=
|
||||
Endpoint = test.wireguard.com:18981
|
||||
AllowedIPs = 10.10.10.230/32`
|
||||
|
||||
func noError(t *testing.T, err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
_, fn, line, _ := runtime.Caller(1)
|
||||
t.Errorf("Error at %s:%d: %#v", fn, line, err)
|
||||
return false
|
||||
}
|
||||
|
||||
func equal(t *testing.T, expected, actual interface{}) bool {
|
||||
if reflect.DeepEqual(expected, actual) {
|
||||
return true
|
||||
}
|
||||
_, fn, line, _ := runtime.Caller(1)
|
||||
t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected)
|
||||
return false
|
||||
}
|
||||
func lenTest(t *testing.T, actualO interface{}, expected int) bool {
|
||||
actual := reflect.ValueOf(actualO).Len()
|
||||
if reflect.DeepEqual(expected, actual) {
|
||||
return true
|
||||
}
|
||||
_, fn, line, _ := runtime.Caller(1)
|
||||
t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected)
|
||||
return false
|
||||
}
|
||||
func contains(t *testing.T, list, element interface{}) bool {
|
||||
listValue := reflect.ValueOf(list)
|
||||
for i := 0; i < listValue.Len(); i++ {
|
||||
if reflect.DeepEqual(listValue.Index(i).Interface(), element) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
_, fn, line, _ := runtime.Caller(1)
|
||||
t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element)
|
||||
return false
|
||||
}
|
||||
|
||||
func TestFromWgQuick(t *testing.T) {
|
||||
conf, err := FromWgQuick(testInput, "test")
|
||||
if noError(t, err) {
|
||||
|
||||
lenTest(t, conf.Interface.Addresses, 2)
|
||||
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)})
|
||||
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)})
|
||||
equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String())
|
||||
equal(t, uint16(51820), conf.Interface.ListenPort)
|
||||
|
||||
lenTest(t, conf.Peers, 3)
|
||||
lenTest(t, conf.Peers[0].AllowedIPs, 2)
|
||||
equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoint)
|
||||
equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.String())
|
||||
|
||||
lenTest(t, conf.Peers[1].AllowedIPs, 2)
|
||||
equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoint)
|
||||
equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.String())
|
||||
equal(t, uint16(100), conf.Peers[1].PersistentKeepalive)
|
||||
|
||||
lenTest(t, conf.Peers[2].AllowedIPs, 1)
|
||||
equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoint)
|
||||
equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.String())
|
||||
equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEndpoint(t *testing.T) {
|
||||
_, err := parseEndpoint("[192.168.42.0:]:51880")
|
||||
if err == nil {
|
||||
t.Error("Error was expected")
|
||||
}
|
||||
e, err := parseEndpoint("192.168.42.0:51880")
|
||||
if noError(t, err) {
|
||||
equal(t, "192.168.42.0", e.Host)
|
||||
equal(t, uint16(51880), e.Port)
|
||||
}
|
||||
e, err = parseEndpoint("test.wireguard.com:18981")
|
||||
if noError(t, err) {
|
||||
equal(t, "test.wireguard.com", e.Host)
|
||||
equal(t, uint16(18981), e.Port)
|
||||
}
|
||||
e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468")
|
||||
if noError(t, err) {
|
||||
equal(t, "2607:5300:60:6b0::c05f:543", e.Host)
|
||||
equal(t, uint16(2468), e.Port)
|
||||
}
|
||||
_, err = parseEndpoint("[::::::invalid:18981")
|
||||
if err == nil {
|
||||
t.Error("Error was expected")
|
||||
}
|
||||
}
|
128
conf/path_windows.go
Normal file
128
conf/path_windows.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var cachedConfigFileDir string
|
||||
var cachedRootDir string
|
||||
|
||||
func tunnelConfigurationsDirectory() (string, error) {
|
||||
if cachedConfigFileDir != "" {
|
||||
return cachedConfigFileDir, nil
|
||||
}
|
||||
root, err := RootDirectory(true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c := filepath.Join(root, "Configurations")
|
||||
err = os.Mkdir(c, os.ModeDir|0700)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
return "", err
|
||||
}
|
||||
cachedConfigFileDir = c
|
||||
return cachedConfigFileDir, nil
|
||||
}
|
||||
|
||||
// PresetRootDirectory causes RootDirectory() to not try any automatic deduction, and instead
|
||||
// uses what's passed to it. This isn't used by wireguard-windows, but is useful for external
|
||||
// consumers of our libraries who might want to do strange things.
|
||||
func PresetRootDirectory(root string) {
|
||||
cachedRootDir = root
|
||||
}
|
||||
|
||||
func RootDirectory(create bool) (string, error) {
|
||||
if cachedRootDir != "" {
|
||||
return cachedRootDir, nil
|
||||
}
|
||||
root, err := windows.KnownFolderPath(windows.FOLDERID_ProgramFiles, windows.KF_FLAG_DEFAULT)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
root = filepath.Join(root, "WireGuard")
|
||||
if !create {
|
||||
return filepath.Join(root, "Data"), nil
|
||||
}
|
||||
root16, err := windows.UTF16PtrFromString(root)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// The root directory inherits its ACL from Program Files; we don't want to mess with that
|
||||
err = windows.CreateDirectory(root16, nil)
|
||||
if err != nil && err != windows.ERROR_ALREADY_EXISTS {
|
||||
return "", err
|
||||
}
|
||||
|
||||
dataDirectorySd, err := windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;OICI;FA;;;SY)(A;OICI;FA;;;BA)")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dataDirectorySa := &windows.SecurityAttributes{
|
||||
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||
SecurityDescriptor: dataDirectorySd,
|
||||
}
|
||||
|
||||
data := filepath.Join(root, "Data")
|
||||
data16, err := windows.UTF16PtrFromString(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var dataHandle windows.Handle
|
||||
for {
|
||||
err = windows.CreateDirectory(data16, dataDirectorySa)
|
||||
if err != nil && err != windows.ERROR_ALREADY_EXISTS {
|
||||
return "", err
|
||||
}
|
||||
dataHandle, err = windows.CreateFile(data16, windows.READ_CONTROL|windows.WRITE_OWNER|windows.WRITE_DAC, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_BACKUP_SEMANTICS|windows.FILE_FLAG_OPEN_REPARSE_POINT|windows.FILE_ATTRIBUTE_DIRECTORY, 0)
|
||||
if err != nil && err != windows.ERROR_FILE_NOT_FOUND {
|
||||
return "", err
|
||||
}
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
defer windows.CloseHandle(dataHandle)
|
||||
var fileInfo windows.ByHandleFileInformation
|
||||
err = windows.GetFileInformationByHandle(dataHandle, &fileInfo)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_DIRECTORY == 0 {
|
||||
return "", errors.New("Data directory is actually a file")
|
||||
}
|
||||
if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_REPARSE_POINT != 0 {
|
||||
return "", errors.New("Data directory is reparse point")
|
||||
}
|
||||
buf := make([]uint16, windows.MAX_PATH+4)
|
||||
for {
|
||||
bufLen, err := windows.GetFinalPathNameByHandle(dataHandle, &buf[0], uint32(len(buf)), 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if bufLen < uint32(len(buf)) {
|
||||
break
|
||||
}
|
||||
buf = make([]uint16, bufLen)
|
||||
}
|
||||
if !strings.EqualFold(`\\?\`+data, windows.UTF16ToString(buf[:])) {
|
||||
return "", errors.New("Data directory jumped to unexpected location")
|
||||
}
|
||||
err = windows.SetKernelObjectSecurity(dataHandle, windows.DACL_SECURITY_INFORMATION|windows.GROUP_SECURITY_INFORMATION|windows.OWNER_SECURITY_INFORMATION|windows.PROTECTED_DACL_SECURITY_INFORMATION, dataDirectorySd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cachedRootDir = data
|
||||
return cachedRootDir, nil
|
||||
}
|
144
conf/store.go
Normal file
144
conf/store.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/conf/dpapi"
|
||||
)
|
||||
|
||||
const configFileSuffix = ".conf.dpapi"
|
||||
const configFileUnencryptedSuffix = ".conf"
|
||||
|
||||
func ListConfigNames() ([]string, error) {
|
||||
configFileDir, err := tunnelConfigurationsDirectory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
files, err := os.ReadDir(configFileDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configs := make([]string, len(files))
|
||||
i := 0
|
||||
for _, file := range files {
|
||||
name := filepath.Base(file.Name())
|
||||
if len(name) <= len(configFileSuffix) || !strings.HasSuffix(name, configFileSuffix) {
|
||||
continue
|
||||
}
|
||||
if !file.Type().IsRegular() {
|
||||
continue
|
||||
}
|
||||
info, err := file.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.Mode().Perm()&0444 == 0 {
|
||||
continue
|
||||
}
|
||||
name = strings.TrimSuffix(name, configFileSuffix)
|
||||
if !TunnelNameIsValid(name) {
|
||||
continue
|
||||
}
|
||||
configs[i] = name
|
||||
i++
|
||||
}
|
||||
return configs[:i], nil
|
||||
}
|
||||
|
||||
func LoadFromName(name string) (*Config, error) {
|
||||
configFileDir, err := tunnelConfigurationsDirectory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return LoadFromPath(filepath.Join(configFileDir, name+configFileSuffix))
|
||||
}
|
||||
|
||||
func LoadFromPath(path string) (*Config, error) {
|
||||
name, err := NameFromPath(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bytes, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasSuffix(path, configFileSuffix) {
|
||||
bytes, err = dpapi.Decrypt(bytes, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return FromWgQuickWithUnknownEncoding(string(bytes), name)
|
||||
}
|
||||
|
||||
func PathIsEncrypted(path string) bool {
|
||||
return strings.HasSuffix(filepath.Base(path), configFileSuffix)
|
||||
}
|
||||
|
||||
func NameFromPath(path string) (string, error) {
|
||||
name := filepath.Base(path)
|
||||
if !((len(name) > len(configFileSuffix) && strings.HasSuffix(name, configFileSuffix)) ||
|
||||
(len(name) > len(configFileUnencryptedSuffix) && strings.HasSuffix(name, configFileUnencryptedSuffix))) {
|
||||
return "", errors.New("Path must end in either " + configFileSuffix + " or " + configFileUnencryptedSuffix)
|
||||
}
|
||||
if strings.HasSuffix(path, configFileSuffix) {
|
||||
name = strings.TrimSuffix(name, configFileSuffix)
|
||||
} else {
|
||||
name = strings.TrimSuffix(name, configFileUnencryptedSuffix)
|
||||
}
|
||||
if !TunnelNameIsValid(name) {
|
||||
return "", errors.New("Tunnel name is not valid")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func (config *Config) Save(overwrite bool) error {
|
||||
if !TunnelNameIsValid(config.Name) {
|
||||
return errors.New("Tunnel name is not valid")
|
||||
}
|
||||
configFileDir, err := tunnelConfigurationsDirectory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filename := filepath.Join(configFileDir, config.Name+configFileSuffix)
|
||||
bytes := []byte(config.ToWgQuick())
|
||||
bytes, err = dpapi.Encrypt(bytes, config.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeLockedDownFile(filename, overwrite, bytes)
|
||||
}
|
||||
|
||||
func (config *Config) Path() (string, error) {
|
||||
if !TunnelNameIsValid(config.Name) {
|
||||
return "", errors.New("Tunnel name is not valid")
|
||||
}
|
||||
configFileDir, err := tunnelConfigurationsDirectory()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(configFileDir, config.Name+configFileSuffix), nil
|
||||
}
|
||||
|
||||
func DeleteName(name string) error {
|
||||
if !TunnelNameIsValid(name) {
|
||||
return errors.New("Tunnel name is not valid")
|
||||
}
|
||||
configFileDir, err := tunnelConfigurationsDirectory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Remove(filepath.Join(configFileDir, name+configFileSuffix))
|
||||
}
|
||||
|
||||
func (config *Config) Delete() error {
|
||||
return DeleteName(config.Name)
|
||||
}
|
91
conf/store_test.go
Normal file
91
conf/store_test.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStorage(t *testing.T) {
|
||||
c, err := FromWgQuick(testInput, "golangTest")
|
||||
if err != nil {
|
||||
t.Errorf("Unable to parse test config: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = c.Save()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to save config: %s", err.Error())
|
||||
}
|
||||
|
||||
configs, err := ListConfigNames()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to list configs: %s", err.Error())
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, name := range configs {
|
||||
if name == "golangTest" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Unable to find saved config in list")
|
||||
}
|
||||
|
||||
loaded, err := LoadFromName("golangTest")
|
||||
if err != nil {
|
||||
t.Errorf("Unable to load config: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(loaded, c) {
|
||||
t.Error("Loaded config is not the same as saved config")
|
||||
}
|
||||
|
||||
k, err := NewPrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to generate new private key: %s", err.Error())
|
||||
}
|
||||
c.Interface.PrivateKey = *k
|
||||
|
||||
err = c.Save()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to save config a second time: %s", err.Error())
|
||||
}
|
||||
|
||||
loaded, err = LoadFromName("golangTest")
|
||||
if err != nil {
|
||||
t.Errorf("Unable to load config a second time: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(loaded, c) {
|
||||
t.Error("Second loaded config is not the same as second saved config")
|
||||
}
|
||||
|
||||
err = DeleteName("golangTest")
|
||||
if err != nil {
|
||||
t.Errorf("Unable to delete config: %s", err.Error())
|
||||
}
|
||||
|
||||
configs, err = ListConfigNames()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to list configs: %s", err.Error())
|
||||
}
|
||||
found = false
|
||||
for _, name := range configs {
|
||||
if name == "golangTest" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
t.Error("Config wasn't actually deleted")
|
||||
}
|
||||
}
|
24
conf/storewatcher.go
Normal file
24
conf/storewatcher.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
type StoreCallback struct {
|
||||
cb func()
|
||||
}
|
||||
|
||||
var storeCallbacks = make(map[*StoreCallback]bool)
|
||||
|
||||
func RegisterStoreChangeCallback(cb func()) *StoreCallback {
|
||||
startWatchingConfigDir()
|
||||
cb()
|
||||
s := &StoreCallback{cb}
|
||||
storeCallbacks[s] = true
|
||||
return s
|
||||
}
|
||||
|
||||
func (cb *StoreCallback) Unregister() {
|
||||
delete(storeCallbacks, cb)
|
||||
}
|
61
conf/storewatcher_windows.go
Normal file
61
conf/storewatcher_windows.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var haveStartedWatchingConfigDir bool
|
||||
|
||||
func startWatchingConfigDir() {
|
||||
if haveStartedWatchingConfigDir {
|
||||
return
|
||||
}
|
||||
haveStartedWatchingConfigDir = true
|
||||
go func() {
|
||||
h := windows.InvalidHandle
|
||||
defer func() {
|
||||
if h != windows.InvalidHandle {
|
||||
windows.FindCloseChangeNotification(h)
|
||||
}
|
||||
haveStartedWatchingConfigDir = false
|
||||
}()
|
||||
startover:
|
||||
configFileDir, err := tunnelConfigurationsDirectory()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h, err = windows.FindFirstChangeNotification(configFileDir, true, windows.FILE_NOTIFY_CHANGE_FILE_NAME|windows.FILE_NOTIFY_CHANGE_DIR_NAME|windows.FILE_NOTIFY_CHANGE_ATTRIBUTES|windows.FILE_NOTIFY_CHANGE_SIZE|windows.FILE_NOTIFY_CHANGE_LAST_WRITE|windows.FILE_NOTIFY_CHANGE_LAST_ACCESS|windows.FILE_NOTIFY_CHANGE_CREATION|windows.FILE_NOTIFY_CHANGE_SECURITY)
|
||||
if err != nil {
|
||||
log.Printf("Unable to monitor config directory: %v", err)
|
||||
return
|
||||
}
|
||||
for {
|
||||
s, err := windows.WaitForSingleObject(h, windows.INFINITE)
|
||||
if err != nil || s == windows.WAIT_FAILED {
|
||||
log.Printf("Unable to wait on config directory watcher: %v", err)
|
||||
windows.FindCloseChangeNotification(h)
|
||||
h = windows.InvalidHandle
|
||||
goto startover
|
||||
}
|
||||
|
||||
for cb := range storeCallbacks {
|
||||
cb.cb()
|
||||
}
|
||||
|
||||
err = windows.FindNextChangeNotification(h)
|
||||
if err != nil {
|
||||
log.Printf("Unable to monitor config directory again: %v", err)
|
||||
windows.FindCloseChangeNotification(h)
|
||||
h = windows.InvalidHandle
|
||||
goto startover
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
124
conf/writer.go
Normal file
124
conf/writer.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (conf *Config) ToWgQuick() string {
|
||||
var output strings.Builder
|
||||
output.WriteString("[Interface]\n")
|
||||
|
||||
output.WriteString(fmt.Sprintf("PrivateKey = %s\n", conf.Interface.PrivateKey.String()))
|
||||
|
||||
if conf.Interface.ListenPort > 0 {
|
||||
output.WriteString(fmt.Sprintf("ListenPort = %d\n", conf.Interface.ListenPort))
|
||||
}
|
||||
|
||||
if len(conf.Interface.Addresses) > 0 {
|
||||
addrStrings := make([]string, len(conf.Interface.Addresses))
|
||||
for i, address := range conf.Interface.Addresses {
|
||||
addrStrings[i] = address.String()
|
||||
}
|
||||
output.WriteString(fmt.Sprintf("Address = %s\n", strings.Join(addrStrings[:], ", ")))
|
||||
}
|
||||
|
||||
if len(conf.Interface.DNS)+len(conf.Interface.DNSSearch) > 0 {
|
||||
addrStrings := make([]string, 0, len(conf.Interface.DNS)+len(conf.Interface.DNSSearch))
|
||||
for _, address := range conf.Interface.DNS {
|
||||
addrStrings = append(addrStrings, address.String())
|
||||
}
|
||||
addrStrings = append(addrStrings, conf.Interface.DNSSearch...)
|
||||
output.WriteString(fmt.Sprintf("DNS = %s\n", strings.Join(addrStrings[:], ", ")))
|
||||
}
|
||||
|
||||
if conf.Interface.MTU > 0 {
|
||||
output.WriteString(fmt.Sprintf("MTU = %d\n", conf.Interface.MTU))
|
||||
}
|
||||
|
||||
if len(conf.Interface.PreUp) > 0 {
|
||||
output.WriteString(fmt.Sprintf("PreUp = %s\n", conf.Interface.PreUp))
|
||||
}
|
||||
if len(conf.Interface.PostUp) > 0 {
|
||||
output.WriteString(fmt.Sprintf("PostUp = %s\n", conf.Interface.PostUp))
|
||||
}
|
||||
if len(conf.Interface.PreDown) > 0 {
|
||||
output.WriteString(fmt.Sprintf("PreDown = %s\n", conf.Interface.PreDown))
|
||||
}
|
||||
if len(conf.Interface.PostDown) > 0 {
|
||||
output.WriteString(fmt.Sprintf("PostDown = %s\n", conf.Interface.PostDown))
|
||||
}
|
||||
|
||||
for _, peer := range conf.Peers {
|
||||
output.WriteString("\n[Peer]\n")
|
||||
|
||||
output.WriteString(fmt.Sprintf("PublicKey = %s\n", peer.PublicKey.String()))
|
||||
|
||||
if !peer.PresharedKey.IsZero() {
|
||||
output.WriteString(fmt.Sprintf("PresharedKey = %s\n", peer.PresharedKey.String()))
|
||||
}
|
||||
|
||||
if len(peer.AllowedIPs) > 0 {
|
||||
addrStrings := make([]string, len(peer.AllowedIPs))
|
||||
for i, address := range peer.AllowedIPs {
|
||||
addrStrings[i] = address.String()
|
||||
}
|
||||
output.WriteString(fmt.Sprintf("AllowedIPs = %s\n", strings.Join(addrStrings[:], ", ")))
|
||||
}
|
||||
|
||||
if !peer.Endpoint.IsEmpty() {
|
||||
output.WriteString(fmt.Sprintf("Endpoint = %s\n", peer.Endpoint.String()))
|
||||
}
|
||||
|
||||
if peer.PersistentKeepalive > 0 {
|
||||
output.WriteString(fmt.Sprintf("PersistentKeepalive = %d\n", peer.PersistentKeepalive))
|
||||
}
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
|
||||
func (conf *Config) ToUAPI() (uapi string, dnsErr error) {
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString()))
|
||||
|
||||
if conf.Interface.ListenPort > 0 {
|
||||
output.WriteString(fmt.Sprintf("listen_port=%d\n", conf.Interface.ListenPort))
|
||||
}
|
||||
|
||||
if len(conf.Peers) > 0 {
|
||||
output.WriteString("replace_peers=true\n")
|
||||
}
|
||||
|
||||
for _, peer := range conf.Peers {
|
||||
output.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey.HexString()))
|
||||
|
||||
if !peer.PresharedKey.IsZero() {
|
||||
output.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PresharedKey.HexString()))
|
||||
}
|
||||
|
||||
if !peer.Endpoint.IsEmpty() {
|
||||
var resolvedIP string
|
||||
resolvedIP, dnsErr = resolveHostname(peer.Endpoint.Host)
|
||||
if dnsErr != nil {
|
||||
return
|
||||
}
|
||||
resolvedEndpoint := Endpoint{resolvedIP, peer.Endpoint.Port}
|
||||
output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String()))
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive))
|
||||
|
||||
if len(peer.AllowedIPs) > 0 {
|
||||
output.WriteString("replace_allowed_ips=true\n")
|
||||
for _, address := range peer.AllowedIPs {
|
||||
output.WriteString(fmt.Sprintf("allowed_ip=%s\n", address.String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
return output.String(), nil
|
||||
}
|
50
conf/zsyscall_windows.go
Normal file
50
conf/zsyscall_windows.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package conf
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modwininet = windows.NewLazySystemDLL("wininet.dll")
|
||||
|
||||
procInternetGetConnectedState = modwininet.NewProc("InternetGetConnectedState")
|
||||
)
|
||||
|
||||
func internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) {
|
||||
r0, _, _ := syscall.Syscall(procInternetGetConnectedState.Addr(), 2, uintptr(unsafe.Pointer(flags)), uintptr(reserved), 0)
|
||||
connected = r0 != 0
|
||||
return
|
||||
}
|
Loading…
Reference in a new issue