From 87adb7d366b1b3d9e6be4b0398504d75452f3d39 Mon Sep 17 00:00:00 2001 From: Evan Hazlett Date: Sun, 6 Oct 2019 02:28:58 -0400 Subject: [PATCH] fix wireguard routing Signed-off-by: Evan Hazlett --- cmd/heimdall/main.go | 8 +- cmd/heimdall/run.go | 1 + config.go | 4 +- go.sum | 5 ++ server/net.go | 183 +++++++++++++++++++++++++++++++++++-------- server/net_test.go | 82 +++++++++++++++++++ server/peer.go | 35 ++++++++- server/server.go | 90 ++++++++++++++++++--- server/wireguard.go | 13 ++- utils.go | 2 +- 10 files changed, 367 insertions(+), 56 deletions(-) create mode 100644 server/net_test.go diff --git a/cmd/heimdall/main.go b/cmd/heimdall/main.go index c6f54c4..a568c0d 100644 --- a/cmd/heimdall/main.go +++ b/cmd/heimdall/main.go @@ -83,10 +83,16 @@ func main() { Value: generateKey(), EnvVar: "HEIMDALL_CLUSTER_KEY", }, + cli.StringFlag{ + Name: "node-network", + Usage: "subnet to be used for nodes", + Value: "10.10.0.0/16", + EnvVar: "HEIMDALL_NODE_NETWORK", + }, cli.StringFlag{ Name: "peer-network", Usage: "subnet to be used for peers", - Value: "10.254.0.0/16", + Value: "10.51.0.0/16", EnvVar: "HEIMDALL_PEER_NETWORK", }, cli.StringFlag{ diff --git a/cmd/heimdall/run.go b/cmd/heimdall/run.go index 76e8fc9..dfce06e 100644 --- a/cmd/heimdall/run.go +++ b/cmd/heimdall/run.go @@ -45,6 +45,7 @@ func runServer(cx *cli.Context) error { GRPCAddress: cx.String("addr"), GRPCPeerAddress: cx.String("peer"), ClusterKey: cx.String("cluster-key"), + NodeNetwork: cx.String("node-network"), PeerNetwork: cx.String("peer-network"), GatewayIP: cx.String("gateway-ip"), GatewayPort: cx.Int("gateway-port"), diff --git a/config.go b/config.go index d44f0da..9c7d8e0 100644 --- a/config.go +++ b/config.go @@ -31,7 +31,9 @@ type Config struct { GRPCPeerAddress string // ClusterKey is a preshared key for cluster peers ClusterKey string - // PeerNetwork is the subnet that will be used for cluster peers + // NodeNetwork is the network for the cluster nodes + NodeNetwork string + // PeerNetwork is the subnet that is used for cluster peers PeerNetwork string // GatewayIP is the IP used for peer communication GatewayIP string diff --git a/go.sum b/go.sum index 1dc2f8e..9a4f229 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:ma github.com/crosbymichael/guard v0.0.0-20190716141324-5c2daadf8067 h1:jlV8Svz9lOwvxWBt2RN3uA1JUZ8AFj46boym2+Fx488= github.com/crosbymichael/guard v0.0.0-20190716141324-5c2daadf8067/go.mod h1:+l2fIHwwiNb/sUw9RcsUH6wXnO07793PC4XjDWCuiHs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/ehazlett/ttlcache v0.0.0-20190820213212-4400e3aef9f0/go.mod h1:D7IiYXsX2n2xixWvFTxGeZucCvvNtI14ikLj6L9Kp9E= github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JYMGs= @@ -32,6 +33,7 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE= github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -69,6 +71,7 @@ github.com/olebedev/emitter v0.0.0-20190110104742-e8d1457e6aee/go.mod h1:eT2/Pcs github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= @@ -100,6 +103,7 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/urfave/cli v1.21.0/go.mod h1:lxDj6qX9Q6lWQxIrbrT0nwecwUtRnhVZAJjJZrVUZZQ= github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY= @@ -142,5 +146,6 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/server/net.go b/server/net.go index 242807b..20449a9 100644 --- a/server/net.go +++ b/server/net.go @@ -24,6 +24,7 @@ package server import ( "context" "fmt" + "math/big" "net" "strings" @@ -36,8 +37,55 @@ type subnetRange struct { Subnet *net.IPNet } -func (s *Server) getIPs(ctx context.Context) (map[string]net.IP, error) { - values, err := redis.StringMap(s.local(ctx, "HGETALL", ipsKey)) +func (s *Server) updateNodeNetwork(ctx context.Context, subnet string) error { + if _, err := s.master(ctx, "SET", s.getNodeNetworkKey(s.cfg.ID), subnet); err != nil { + return err + } + return nil +} + +func (s *Server) getOrAllocatePeerIP(ctx context.Context, id string) (net.IP, *net.IPNet, error) { + r, err := parseSubnetRange(s.cfg.PeerNetwork) + if err != nil { + return nil, nil, err + } + + ip, err := s.getPeerIP(ctx, id) + if err != nil { + return nil, nil, err + } + + if ip != nil { + return ip, r.Subnet, nil + } + + ip, err = s.allocatePeerIP(ctx, id, r) + if err != nil { + return nil, nil, err + } + + return ip, r.Subnet, nil +} + +func (s *Server) getNodeIP(ctx context.Context, id string) (net.IP, *net.IPNet, error) { + subnet, err := redis.String(s.local(ctx, "GET", s.getNodeNetworkKey(id))) + if err != nil { + return nil, nil, err + } + r, err := parseSubnetRange(subnet) + if err != nil { + return nil, nil, err + } + + ip := r.Start + // assign .1 for router + ip[len(ip)-1] = 1 + + return ip, r.Subnet, nil +} + +func (s *Server) getPeerIPs(ctx context.Context) (map[string]net.IP, error) { + values, err := redis.StringMap(s.local(ctx, "HGETALL", peerIPsKey)) if err != nil { return nil, err } @@ -50,8 +98,22 @@ func (s *Server) getIPs(ctx context.Context) (map[string]net.IP, error) { return ips, nil } -func (s *Server) getIP(ctx context.Context, id string) (net.IP, error) { - allIPs, err := s.getIPs(ctx) +func (s *Server) getNodeIPs(ctx context.Context) (map[string]net.IP, error) { + values, err := redis.StringMap(s.local(ctx, "HGETALL", nodeIPsKey)) + if err != nil { + return nil, err + } + + ips := make(map[string]net.IP, len(values)) + for id, val := range values { + ip := net.ParseIP(string(val)) + ips[id] = ip + } + return ips, nil +} + +func (s *Server) getPeerIP(ctx context.Context, id string) (net.IP, error) { + allIPs, err := s.getPeerIPs(ctx) if err != nil { return nil, err } @@ -62,31 +124,8 @@ func (s *Server) getIP(ctx context.Context, id string) (net.IP, error) { return nil, nil } -func (s *Server) getOrAllocateIP(ctx context.Context, id string) (net.IP, *net.IPNet, error) { - r, err := s.parseSubnetRange(s.cfg.PeerNetwork) - if err != nil { - return nil, nil, err - } - - ip, err := s.getIP(ctx, id) - if err != nil { - return nil, nil, err - } - - if ip != nil { - return ip, r.Subnet, nil - } - - ip, err = s.allocateIP(ctx, id, r) - if err != nil { - return nil, nil, err - } - - return ip, r.Subnet, nil -} - -func (s *Server) allocateIP(ctx context.Context, id string, r *subnetRange) (net.IP, error) { - reservedIPs, err := s.getIPs(ctx) +func (s *Server) allocatePeerIP(ctx context.Context, id string, r *subnetRange) (net.IP, error) { + reservedIPs, err := s.getPeerIPs(ctx) if err != nil { return nil, err } @@ -110,7 +149,7 @@ func (s *Server) allocateIP(ctx context.Context, id string, r *subnetRange) (net } // save - if _, err := s.master(ctx, "HSET", ipsKey, id, ip.String()); err != nil { + if _, err := s.master(ctx, "HSET", peerIPsKey, id, ip.String()); err != nil { return nil, err } return ip, nil @@ -119,14 +158,14 @@ func (s *Server) allocateIP(ctx context.Context, id string, r *subnetRange) (net return nil, fmt.Errorf("no available IPs") } -func (s *Server) releaseIP(ctx context.Context, id string) error { - ip, err := s.getIP(ctx, id) +func (s *Server) releasePeerIP(ctx context.Context, id string) error { + ip, err := s.getPeerIP(ctx, id) if err != nil { return err } if ip != nil { - if _, err := s.master(ctx, "HDEL", ipsKey, id); err != nil { + if _, err := s.master(ctx, "HDEL", peerIPsKey, id); err != nil { return err } } @@ -153,7 +192,7 @@ func (s *Server) validIP(ip net.IP) bool { // parseSubnetRange parses the subnet range // format can either be a subnet like 10.0.0.0/8 or range like 10.0.0.100-10.0.0.200/24 -func (s *Server) parseSubnetRange(subnet string) (*subnetRange, error) { +func parseSubnetRange(subnet string) (*subnetRange, error) { parts := strings.Split(subnet, "-") if len(parts) == 1 { ip, sub, err := net.ParseCIDR(parts[0]) @@ -185,3 +224,79 @@ func (s *Server) parseSubnetRange(subnet string) (*subnetRange, error) { Subnet: sub, }, nil } + +// vendored from +func nextSubnet(n *net.IPNet, prefix int) (*net.IPNet, bool) { + _, currentLast := addressRange(n) + mask := net.CIDRMask(prefix, 8*len(currentLast)) + currentSubnet := &net.IPNet{IP: currentLast.Mask(mask), Mask: mask} + _, last := addressRange(currentSubnet) + last = inc(last) + next := &net.IPNet{IP: last.Mask(mask), Mask: mask} + if last.Equal(net.IPv4zero) || last.Equal(net.IPv6zero) { + return nil, false + } + return next, true +} + +func addressRange(network *net.IPNet) (net.IP, net.IP) { + firstIP := network.IP + prefixLen, bits := network.Mask.Size() + if prefixLen == bits { + lastIP := make([]byte, len(firstIP)) + copy(lastIP, firstIP) + return firstIP, lastIP + } + + firstIPInt, bits := ipToInt(firstIP) + hostLen := uint(bits) - uint(prefixLen) + lastIPInt := big.NewInt(1) + lastIPInt.Lsh(lastIPInt, hostLen) + lastIPInt.Sub(lastIPInt, big.NewInt(1)) + lastIPInt.Or(lastIPInt, firstIPInt) + + return firstIP, intToIP(lastIPInt, bits) + +} + +func inc(IP net.IP) net.IP { + IP = checkIPv4(IP) + incIP := make([]byte, len(IP)) + copy(incIP, IP) + for j := len(incIP) - 1; j >= 0; j-- { + incIP[j]++ + if incIP[j] > 0 { + break + } + } + return incIP +} + +func ipToInt(ip net.IP) (*big.Int, int) { + val := &big.Int{} + val.SetBytes([]byte(ip)) + if len(ip) == net.IPv4len { + return val, 32 + } else if len(ip) == net.IPv6len { + return val, 128 + } else { + panic(fmt.Errorf("Unsupported address length %d", len(ip))) + } +} + +func intToIP(ipInt *big.Int, bits int) net.IP { + ipBytes := ipInt.Bytes() + ret := make([]byte, bits/8) + for i := 1; i <= len(ipBytes); i++ { + ret[len(ret)-i] = ipBytes[len(ipBytes)-i] + + } + return net.IP(ret) +} + +func checkIPv4(ip net.IP) net.IP { + if v4 := ip.To4(); v4 != nil { + return v4 + } + return ip +} diff --git a/server/net_test.go b/server/net_test.go new file mode 100644 index 0000000..242eeee --- /dev/null +++ b/server/net_test.go @@ -0,0 +1,82 @@ +/* + Copyright 2019 Stellar Project + + Permission is hereby granted, free of charge, to any person obtaining a copy of + this software and associated documentation files (the "Software"), to deal in the + Software without restriction, including without limitation the rights to use, copy, + modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, + and to permit persons to whom the Software is furnished to do so, subject to the + following conditions: + + The above copyright notice and this permission notice shall be included in all copies + or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, + INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE + FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE + USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ + +package server + +import ( + "context" + "os" + "testing" + + "github.com/stellarproject/heimdall" +) + +const ( + testPeerNetwork = "10.51.0.0/16" + testNodeNetwork = "10.10.0.0/16" +) + +func TestNetSuite(t *testing.T) { + redisURL := os.Getenv("TEST_REDIS_URL") + if redisURL == "" { + t.Skip("TEST_REDIS_URL env var must be set") + } + cfg := &heimdall.Config{ + ID: "test", + RedisURL: redisURL, + NodeNetwork: testNodeNetwork, + PeerNetwork: testPeerNetwork, + } + + srv, err := NewServer(cfg) + if err != nil { + t.Fatal(err) + } + defer func() { + ctx := context.Background() + if _, err := srv.master(ctx, "FLUSHDB"); err != nil { + t.Errorf("error tearing down: %s", err) + } + }() + + // run tests + t.Run("AllocatePeerIP", testNetAllocatePeerIP(srv)) +} + +func testNetAllocatePeerIP(s *Server) func(t *testing.T) { + return func(t *testing.T) { + ctx := context.Background() + ip, ipnet, err := s.getOrAllocatePeerIP(ctx, "test-node") + if err != nil { + t.Fatal(err) + } + + expectedIP := "10.51.0.2" + + if ip.String() != expectedIP { + t.Errorf("expected ip %s; received %s", expectedIP, ip.String()) + } + + if ipnet.String() != testPeerNetwork { + t.Errorf("expected net %s; received %s", testPeerNetwork, ipnet.String()) + } + } +} diff --git a/server/peer.go b/server/peer.go index a4c4b8b..376b2e8 100644 --- a/server/peer.go +++ b/server/peer.go @@ -98,7 +98,29 @@ func (s *Server) updatePeerInfo(ctx context.Context) error { endpoint := fmt.Sprintf("%s:%d", s.cfg.GatewayIP, s.cfg.GatewayPort) // build allowedIPs from routes and peer network - allowedIPs := []string{s.cfg.PeerNetwork} + allowedIPs := []string{} + nodes, err := s.getNodes(ctx) + if err != nil { + return err + } + + logrus.Debugf("nodes: %+v", nodes) + + for _, node := range nodes { + // only add the route if a peer to prevent route duplicate + if node.ID != s.cfg.ID { + continue + } + + _, gatewayNet, err := s.getNodeIP(ctx, node.ID) + if err != nil { + return err + } + + logrus.Debugf("peer network %s", gatewayNet) + allowedIPs = append(allowedIPs, gatewayNet.String()) + } + routes, err := s.getRoutes(ctx) if err != nil { return err @@ -111,7 +133,6 @@ func (s *Server) updatePeerInfo(ctx context.Context) error { } logrus.Debugf("adding route to allowed IPs: %s", route.Network) - allowedIPs = append(allowedIPs, route.Network) } @@ -198,7 +219,7 @@ func (s *Server) updatePeerConfig(ctx context.Context) error { return err } - gatewayIP, _, err := s.getOrAllocateIP(ctx, s.cfg.ID) + gatewayIP, _, err := s.getNodeIP(ctx, s.cfg.ID) if err != nil { return err } @@ -206,7 +227,7 @@ func (s *Server) updatePeerConfig(ctx context.Context) error { Iface: defaultWireguardInterface, PrivateKey: keyPair.PrivateKey, ListenPort: s.cfg.GatewayPort, - Address: gatewayIP.String() + "/32", + Address: gatewayIP.To4().String() + "/32", Peers: peers, } @@ -251,6 +272,12 @@ func hashData(data []byte) string { } func hashConfig(cfgPath string) (string, error) { + if _, err := os.Stat(cfgPath); err != nil { + if os.IsNotExist(err) { + return "", nil + } + return "", err + } peerData, err := ioutil.ReadFile(cfgPath) if err != nil { return "", err diff --git a/server/server.go b/server/server.go index 5a1a1a5..7636045 100644 --- a/server/server.go +++ b/server/server.go @@ -41,14 +41,16 @@ import ( ) const ( - masterKey = "heimdall:master" - clusterKey = "heimdall:key" - keypairsKey = "heimdall:keypairs" - nodesKey = "heimdall:nodes" - nodeJoinKey = "heimdall:join" - peersKey = "heimdall:peers" - routesKey = "heimdall:routes" - ipsKey = "heimdall:ips" + masterKey = "heimdall:master" + clusterKey = "heimdall:key" + keypairsKey = "heimdall:keypairs" + nodesKey = "heimdall:nodes" + nodeJoinKey = "heimdall:join" + peersKey = "heimdall:peers" + routesKey = "heimdall:routes" + peerIPsKey = "heimdall:peerips" + nodeIPsKey = "heimdall:nodeips" + nodeNetworksKey = "heimdall:nodenetworks" wireguardConfigPath = "/etc/wireguard/darknet.conf" ) @@ -134,20 +136,29 @@ func (s *Server) Run() error { } } + // ensure keypair if _, err := s.getOrCreateKeyPair(ctx, s.cfg.ID); err != nil { return err } - // ensure wireguard is started - _, _ = wgquick(ctx, "up", getTunnelName()) - - if err := s.updatePeerInfo(ctx); err != nil { + // ensure node network subnet + if err := s.ensureNetworkSubnet(ctx); err != nil { return err } // start node heartbeat to update in redis go s.nodeHeartbeat(ctx) + // initial peer info update + if err := s.updatePeerInfo(ctx); err != nil { + return err + } + + // initial config update + if err := s.updatePeerConfig(ctx); err != nil { + return err + } + // start peer config updater to configure wireguard as peers join go s.peerUpdater(ctx) @@ -193,6 +204,57 @@ func getPool(u string) *redis.Pool { return pool } +func (s *Server) ensureNetworkSubnet(ctx context.Context) error { + network, err := redis.String(s.local(ctx, "GET", s.getNodeNetworkKey(s.cfg.ID))) + if err != nil { + if err != redis.ErrNil { + return err + } + // allocate initial node subnet + r, err := parseSubnetRange(s.cfg.NodeNetwork) + if err != nil { + return err + } + // iterate node networks to find first free + nodeNetworkKeys, err := redis.Strings(s.local(ctx, "KEYS", s.getNodeNetworkKey("*"))) + if err != nil { + return err + } + logrus.Debug(nodeNetworkKeys) + lookup := map[string]struct{}{} + for _, netKey := range nodeNetworkKeys { + n, err := redis.String(s.local(ctx, "GET", netKey)) + if err != nil { + return err + } + lookup[n] = struct{}{} + } + + subnet := r.Subnet + size, _ := subnet.Mask.Size() + + for { + n, ok := nextSubnet(subnet, size) + if !ok { + return fmt.Errorf("error getting next subnet") + } + if _, exists := lookup[n.String()]; exists { + subnet = n + continue + } + logrus.Debugf("allocated network %s for %s", n.String(), s.cfg.ID) + if err := s.updateNodeNetwork(ctx, n.String()); err != nil { + return err + } + break + } + + return nil + } + logrus.Debugf("node network: %s", network) + return nil +} + func (s *Server) getOrCreateKeyPair(ctx context.Context, id string) (*v1.KeyPair, error) { key := s.getKeyPairKey(id) keyData, err := redis.Bytes(s.master(ctx, "GET", key)) @@ -242,6 +304,10 @@ func (s *Server) getKeyPairKey(id string) string { return fmt.Sprintf("%s:%s", keypairsKey, id) } +func (s *Server) getNodeNetworkKey(id string) string { + return fmt.Sprintf("%s:%s", nodeNetworksKey, id) +} + func (s *Server) getClient(addr string) (*client.Client, error) { return client.NewClient(s.cfg.ID, addr) } diff --git a/server/wireguard.go b/server/wireguard.go index 1bd75c8..7ce2886 100644 --- a/server/wireguard.go +++ b/server/wireguard.go @@ -49,7 +49,7 @@ PostDown = iptables -D FORWARD -i {{ .Iface }} -j ACCEPT; iptables -t nat -D POS {{ range .Peers }} [Peer] PublicKey = {{ .KeyPair.PublicKey }} -AllowedIPs = {{ allowedIPs .AllowedIPs }} +{{ if .AllowedIPs }}AllowedIPs = {{ allowedIPs .AllowedIPs }}{{ end }} Endpoint = {{ .Endpoint }} {{ end }} ` @@ -114,9 +114,16 @@ func getTunnelName() string { func restartWireguardTunnel(ctx context.Context) error { tunnelName := getTunnelName() logrus.Infof("restarting tunnel %s", tunnelName) - d, err := wgquick(ctx, "down", tunnelName) + d, err := wg(ctx, nil) if err != nil { - return errors.Wrap(err, string(d)) + return err + } + // only stop if running + if string(d) != "" { + d, err := wgquick(ctx, "down", tunnelName) + if err != nil { + return errors.Wrap(err, string(d)) + } } u, err := wgquick(ctx, "up", tunnelName) if err != nil { diff --git a/utils.go b/utils.go index eb7dd7b..d784367 100644 --- a/utils.go +++ b/utils.go @@ -41,7 +41,7 @@ func NodeID() string { var i net.Interface for _, iface := range ifaces { - if iface.Flags&net.FlagLoopback != 0 { + if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagPointToPoint != 0 { continue }