fix wireguard routing
Signed-off-by: Evan Hazlett <ejhazlett@gmail.com>
This commit is contained in:
parent
a1dda3097a
commit
87adb7d366
10 changed files with 367 additions and 56 deletions
|
@ -83,10 +83,16 @@ func main() {
|
|||
Value: generateKey(),
|
||||
EnvVar: "HEIMDALL_CLUSTER_KEY",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "node-network",
|
||||
Usage: "subnet to be used for nodes",
|
||||
Value: "10.10.0.0/16",
|
||||
EnvVar: "HEIMDALL_NODE_NETWORK",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "peer-network",
|
||||
Usage: "subnet to be used for peers",
|
||||
Value: "10.254.0.0/16",
|
||||
Value: "10.51.0.0/16",
|
||||
EnvVar: "HEIMDALL_PEER_NETWORK",
|
||||
},
|
||||
cli.StringFlag{
|
||||
|
|
|
@ -45,6 +45,7 @@ func runServer(cx *cli.Context) error {
|
|||
GRPCAddress: cx.String("addr"),
|
||||
GRPCPeerAddress: cx.String("peer"),
|
||||
ClusterKey: cx.String("cluster-key"),
|
||||
NodeNetwork: cx.String("node-network"),
|
||||
PeerNetwork: cx.String("peer-network"),
|
||||
GatewayIP: cx.String("gateway-ip"),
|
||||
GatewayPort: cx.Int("gateway-port"),
|
||||
|
|
|
@ -31,7 +31,9 @@ type Config struct {
|
|||
GRPCPeerAddress string
|
||||
// ClusterKey is a preshared key for cluster peers
|
||||
ClusterKey string
|
||||
// PeerNetwork is the subnet that will be used for cluster peers
|
||||
// NodeNetwork is the network for the cluster nodes
|
||||
NodeNetwork string
|
||||
// PeerNetwork is the subnet that is used for cluster peers
|
||||
PeerNetwork string
|
||||
// GatewayIP is the IP used for peer communication
|
||||
GatewayIP string
|
||||
|
|
5
go.sum
5
go.sum
|
@ -17,6 +17,7 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:ma
|
|||
github.com/crosbymichael/guard v0.0.0-20190716141324-5c2daadf8067 h1:jlV8Svz9lOwvxWBt2RN3uA1JUZ8AFj46boym2+Fx488=
|
||||
github.com/crosbymichael/guard v0.0.0-20190716141324-5c2daadf8067/go.mod h1:+l2fIHwwiNb/sUw9RcsUH6wXnO07793PC4XjDWCuiHs=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/ehazlett/ttlcache v0.0.0-20190820213212-4400e3aef9f0/go.mod h1:D7IiYXsX2n2xixWvFTxGeZucCvvNtI14ikLj6L9Kp9E=
|
||||
github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JYMGs=
|
||||
|
@ -32,6 +33,7 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a
|
|||
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
|
||||
github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE=
|
||||
github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
|
@ -69,6 +71,7 @@ github.com/olebedev/emitter v0.0.0-20190110104742-e8d1457e6aee/go.mod h1:eT2/Pcs
|
|||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||
|
@ -100,6 +103,7 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
|||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/urfave/cli v1.21.0/go.mod h1:lxDj6qX9Q6lWQxIrbrT0nwecwUtRnhVZAJjJZrVUZZQ=
|
||||
github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY=
|
||||
|
@ -142,5 +146,6 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks
|
|||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
|
183
server/net.go
183
server/net.go
|
@ -24,6 +24,7 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
|
@ -36,8 +37,55 @@ type subnetRange struct {
|
|||
Subnet *net.IPNet
|
||||
}
|
||||
|
||||
func (s *Server) getIPs(ctx context.Context) (map[string]net.IP, error) {
|
||||
values, err := redis.StringMap(s.local(ctx, "HGETALL", ipsKey))
|
||||
func (s *Server) updateNodeNetwork(ctx context.Context, subnet string) error {
|
||||
if _, err := s.master(ctx, "SET", s.getNodeNetworkKey(s.cfg.ID), subnet); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) getOrAllocatePeerIP(ctx context.Context, id string) (net.IP, *net.IPNet, error) {
|
||||
r, err := parseSubnetRange(s.cfg.PeerNetwork)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ip, err := s.getPeerIP(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
return ip, r.Subnet, nil
|
||||
}
|
||||
|
||||
ip, err = s.allocatePeerIP(ctx, id, r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return ip, r.Subnet, nil
|
||||
}
|
||||
|
||||
func (s *Server) getNodeIP(ctx context.Context, id string) (net.IP, *net.IPNet, error) {
|
||||
subnet, err := redis.String(s.local(ctx, "GET", s.getNodeNetworkKey(id)))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
r, err := parseSubnetRange(subnet)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ip := r.Start
|
||||
// assign .1 for router
|
||||
ip[len(ip)-1] = 1
|
||||
|
||||
return ip, r.Subnet, nil
|
||||
}
|
||||
|
||||
func (s *Server) getPeerIPs(ctx context.Context) (map[string]net.IP, error) {
|
||||
values, err := redis.StringMap(s.local(ctx, "HGETALL", peerIPsKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -50,8 +98,22 @@ func (s *Server) getIPs(ctx context.Context) (map[string]net.IP, error) {
|
|||
return ips, nil
|
||||
}
|
||||
|
||||
func (s *Server) getIP(ctx context.Context, id string) (net.IP, error) {
|
||||
allIPs, err := s.getIPs(ctx)
|
||||
func (s *Server) getNodeIPs(ctx context.Context) (map[string]net.IP, error) {
|
||||
values, err := redis.StringMap(s.local(ctx, "HGETALL", nodeIPsKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ips := make(map[string]net.IP, len(values))
|
||||
for id, val := range values {
|
||||
ip := net.ParseIP(string(val))
|
||||
ips[id] = ip
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (s *Server) getPeerIP(ctx context.Context, id string) (net.IP, error) {
|
||||
allIPs, err := s.getPeerIPs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -62,31 +124,8 @@ func (s *Server) getIP(ctx context.Context, id string) (net.IP, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Server) getOrAllocateIP(ctx context.Context, id string) (net.IP, *net.IPNet, error) {
|
||||
r, err := s.parseSubnetRange(s.cfg.PeerNetwork)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ip, err := s.getIP(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
return ip, r.Subnet, nil
|
||||
}
|
||||
|
||||
ip, err = s.allocateIP(ctx, id, r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return ip, r.Subnet, nil
|
||||
}
|
||||
|
||||
func (s *Server) allocateIP(ctx context.Context, id string, r *subnetRange) (net.IP, error) {
|
||||
reservedIPs, err := s.getIPs(ctx)
|
||||
func (s *Server) allocatePeerIP(ctx context.Context, id string, r *subnetRange) (net.IP, error) {
|
||||
reservedIPs, err := s.getPeerIPs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -110,7 +149,7 @@ func (s *Server) allocateIP(ctx context.Context, id string, r *subnetRange) (net
|
|||
}
|
||||
|
||||
// save
|
||||
if _, err := s.master(ctx, "HSET", ipsKey, id, ip.String()); err != nil {
|
||||
if _, err := s.master(ctx, "HSET", peerIPsKey, id, ip.String()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ip, nil
|
||||
|
@ -119,14 +158,14 @@ func (s *Server) allocateIP(ctx context.Context, id string, r *subnetRange) (net
|
|||
return nil, fmt.Errorf("no available IPs")
|
||||
}
|
||||
|
||||
func (s *Server) releaseIP(ctx context.Context, id string) error {
|
||||
ip, err := s.getIP(ctx, id)
|
||||
func (s *Server) releasePeerIP(ctx context.Context, id string) error {
|
||||
ip, err := s.getPeerIP(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
if _, err := s.master(ctx, "HDEL", ipsKey, id); err != nil {
|
||||
if _, err := s.master(ctx, "HDEL", peerIPsKey, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -153,7 +192,7 @@ func (s *Server) validIP(ip net.IP) bool {
|
|||
|
||||
// parseSubnetRange parses the subnet range
|
||||
// format can either be a subnet like 10.0.0.0/8 or range like 10.0.0.100-10.0.0.200/24
|
||||
func (s *Server) parseSubnetRange(subnet string) (*subnetRange, error) {
|
||||
func parseSubnetRange(subnet string) (*subnetRange, error) {
|
||||
parts := strings.Split(subnet, "-")
|
||||
if len(parts) == 1 {
|
||||
ip, sub, err := net.ParseCIDR(parts[0])
|
||||
|
@ -185,3 +224,79 @@ func (s *Server) parseSubnetRange(subnet string) (*subnetRange, error) {
|
|||
Subnet: sub,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// vendored from
|
||||
func nextSubnet(n *net.IPNet, prefix int) (*net.IPNet, bool) {
|
||||
_, currentLast := addressRange(n)
|
||||
mask := net.CIDRMask(prefix, 8*len(currentLast))
|
||||
currentSubnet := &net.IPNet{IP: currentLast.Mask(mask), Mask: mask}
|
||||
_, last := addressRange(currentSubnet)
|
||||
last = inc(last)
|
||||
next := &net.IPNet{IP: last.Mask(mask), Mask: mask}
|
||||
if last.Equal(net.IPv4zero) || last.Equal(net.IPv6zero) {
|
||||
return nil, false
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
func addressRange(network *net.IPNet) (net.IP, net.IP) {
|
||||
firstIP := network.IP
|
||||
prefixLen, bits := network.Mask.Size()
|
||||
if prefixLen == bits {
|
||||
lastIP := make([]byte, len(firstIP))
|
||||
copy(lastIP, firstIP)
|
||||
return firstIP, lastIP
|
||||
}
|
||||
|
||||
firstIPInt, bits := ipToInt(firstIP)
|
||||
hostLen := uint(bits) - uint(prefixLen)
|
||||
lastIPInt := big.NewInt(1)
|
||||
lastIPInt.Lsh(lastIPInt, hostLen)
|
||||
lastIPInt.Sub(lastIPInt, big.NewInt(1))
|
||||
lastIPInt.Or(lastIPInt, firstIPInt)
|
||||
|
||||
return firstIP, intToIP(lastIPInt, bits)
|
||||
|
||||
}
|
||||
|
||||
func inc(IP net.IP) net.IP {
|
||||
IP = checkIPv4(IP)
|
||||
incIP := make([]byte, len(IP))
|
||||
copy(incIP, IP)
|
||||
for j := len(incIP) - 1; j >= 0; j-- {
|
||||
incIP[j]++
|
||||
if incIP[j] > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return incIP
|
||||
}
|
||||
|
||||
func ipToInt(ip net.IP) (*big.Int, int) {
|
||||
val := &big.Int{}
|
||||
val.SetBytes([]byte(ip))
|
||||
if len(ip) == net.IPv4len {
|
||||
return val, 32
|
||||
} else if len(ip) == net.IPv6len {
|
||||
return val, 128
|
||||
} else {
|
||||
panic(fmt.Errorf("Unsupported address length %d", len(ip)))
|
||||
}
|
||||
}
|
||||
|
||||
func intToIP(ipInt *big.Int, bits int) net.IP {
|
||||
ipBytes := ipInt.Bytes()
|
||||
ret := make([]byte, bits/8)
|
||||
for i := 1; i <= len(ipBytes); i++ {
|
||||
ret[len(ret)-i] = ipBytes[len(ipBytes)-i]
|
||||
|
||||
}
|
||||
return net.IP(ret)
|
||||
}
|
||||
|
||||
func checkIPv4(ip net.IP) net.IP {
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
return v4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
|
82
server/net_test.go
Normal file
82
server/net_test.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
/*
|
||||
Copyright 2019 Stellar Project
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in the
|
||||
Software without restriction, including without limitation the rights to use, copy,
|
||||
modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
and to permit persons to whom the Software is furnished to do so, subject to the
|
||||
following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies
|
||||
or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
|
||||
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
|
||||
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
|
||||
USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stellarproject/heimdall"
|
||||
)
|
||||
|
||||
const (
|
||||
testPeerNetwork = "10.51.0.0/16"
|
||||
testNodeNetwork = "10.10.0.0/16"
|
||||
)
|
||||
|
||||
func TestNetSuite(t *testing.T) {
|
||||
redisURL := os.Getenv("TEST_REDIS_URL")
|
||||
if redisURL == "" {
|
||||
t.Skip("TEST_REDIS_URL env var must be set")
|
||||
}
|
||||
cfg := &heimdall.Config{
|
||||
ID: "test",
|
||||
RedisURL: redisURL,
|
||||
NodeNetwork: testNodeNetwork,
|
||||
PeerNetwork: testPeerNetwork,
|
||||
}
|
||||
|
||||
srv, err := NewServer(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
ctx := context.Background()
|
||||
if _, err := srv.master(ctx, "FLUSHDB"); err != nil {
|
||||
t.Errorf("error tearing down: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// run tests
|
||||
t.Run("AllocatePeerIP", testNetAllocatePeerIP(srv))
|
||||
}
|
||||
|
||||
func testNetAllocatePeerIP(s *Server) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ip, ipnet, err := s.getOrAllocatePeerIP(ctx, "test-node")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedIP := "10.51.0.2"
|
||||
|
||||
if ip.String() != expectedIP {
|
||||
t.Errorf("expected ip %s; received %s", expectedIP, ip.String())
|
||||
}
|
||||
|
||||
if ipnet.String() != testPeerNetwork {
|
||||
t.Errorf("expected net %s; received %s", testPeerNetwork, ipnet.String())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -98,7 +98,29 @@ func (s *Server) updatePeerInfo(ctx context.Context) error {
|
|||
endpoint := fmt.Sprintf("%s:%d", s.cfg.GatewayIP, s.cfg.GatewayPort)
|
||||
|
||||
// build allowedIPs from routes and peer network
|
||||
allowedIPs := []string{s.cfg.PeerNetwork}
|
||||
allowedIPs := []string{}
|
||||
nodes, err := s.getNodes(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logrus.Debugf("nodes: %+v", nodes)
|
||||
|
||||
for _, node := range nodes {
|
||||
// only add the route if a peer to prevent route duplicate
|
||||
if node.ID != s.cfg.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
_, gatewayNet, err := s.getNodeIP(ctx, node.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logrus.Debugf("peer network %s", gatewayNet)
|
||||
allowedIPs = append(allowedIPs, gatewayNet.String())
|
||||
}
|
||||
|
||||
routes, err := s.getRoutes(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -111,7 +133,6 @@ func (s *Server) updatePeerInfo(ctx context.Context) error {
|
|||
}
|
||||
|
||||
logrus.Debugf("adding route to allowed IPs: %s", route.Network)
|
||||
|
||||
allowedIPs = append(allowedIPs, route.Network)
|
||||
}
|
||||
|
||||
|
@ -198,7 +219,7 @@ func (s *Server) updatePeerConfig(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
gatewayIP, _, err := s.getOrAllocateIP(ctx, s.cfg.ID)
|
||||
gatewayIP, _, err := s.getNodeIP(ctx, s.cfg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -206,7 +227,7 @@ func (s *Server) updatePeerConfig(ctx context.Context) error {
|
|||
Iface: defaultWireguardInterface,
|
||||
PrivateKey: keyPair.PrivateKey,
|
||||
ListenPort: s.cfg.GatewayPort,
|
||||
Address: gatewayIP.String() + "/32",
|
||||
Address: gatewayIP.To4().String() + "/32",
|
||||
Peers: peers,
|
||||
}
|
||||
|
||||
|
@ -251,6 +272,12 @@ func hashData(data []byte) string {
|
|||
}
|
||||
|
||||
func hashConfig(cfgPath string) (string, error) {
|
||||
if _, err := os.Stat(cfgPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
peerData, err := ioutil.ReadFile(cfgPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
|
@ -48,7 +48,9 @@ const (
|
|||
nodeJoinKey = "heimdall:join"
|
||||
peersKey = "heimdall:peers"
|
||||
routesKey = "heimdall:routes"
|
||||
ipsKey = "heimdall:ips"
|
||||
peerIPsKey = "heimdall:peerips"
|
||||
nodeIPsKey = "heimdall:nodeips"
|
||||
nodeNetworksKey = "heimdall:nodenetworks"
|
||||
|
||||
wireguardConfigPath = "/etc/wireguard/darknet.conf"
|
||||
)
|
||||
|
@ -134,20 +136,29 @@ func (s *Server) Run() error {
|
|||
}
|
||||
}
|
||||
|
||||
// ensure keypair
|
||||
if _, err := s.getOrCreateKeyPair(ctx, s.cfg.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure wireguard is started
|
||||
_, _ = wgquick(ctx, "up", getTunnelName())
|
||||
|
||||
if err := s.updatePeerInfo(ctx); err != nil {
|
||||
// ensure node network subnet
|
||||
if err := s.ensureNetworkSubnet(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// start node heartbeat to update in redis
|
||||
go s.nodeHeartbeat(ctx)
|
||||
|
||||
// initial peer info update
|
||||
if err := s.updatePeerInfo(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// initial config update
|
||||
if err := s.updatePeerConfig(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// start peer config updater to configure wireguard as peers join
|
||||
go s.peerUpdater(ctx)
|
||||
|
||||
|
@ -193,6 +204,57 @@ func getPool(u string) *redis.Pool {
|
|||
return pool
|
||||
}
|
||||
|
||||
func (s *Server) ensureNetworkSubnet(ctx context.Context) error {
|
||||
network, err := redis.String(s.local(ctx, "GET", s.getNodeNetworkKey(s.cfg.ID)))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return err
|
||||
}
|
||||
// allocate initial node subnet
|
||||
r, err := parseSubnetRange(s.cfg.NodeNetwork)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// iterate node networks to find first free
|
||||
nodeNetworkKeys, err := redis.Strings(s.local(ctx, "KEYS", s.getNodeNetworkKey("*")))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Debug(nodeNetworkKeys)
|
||||
lookup := map[string]struct{}{}
|
||||
for _, netKey := range nodeNetworkKeys {
|
||||
n, err := redis.String(s.local(ctx, "GET", netKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lookup[n] = struct{}{}
|
||||
}
|
||||
|
||||
subnet := r.Subnet
|
||||
size, _ := subnet.Mask.Size()
|
||||
|
||||
for {
|
||||
n, ok := nextSubnet(subnet, size)
|
||||
if !ok {
|
||||
return fmt.Errorf("error getting next subnet")
|
||||
}
|
||||
if _, exists := lookup[n.String()]; exists {
|
||||
subnet = n
|
||||
continue
|
||||
}
|
||||
logrus.Debugf("allocated network %s for %s", n.String(), s.cfg.ID)
|
||||
if err := s.updateNodeNetwork(ctx, n.String()); err != nil {
|
||||
return err
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
logrus.Debugf("node network: %s", network)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) getOrCreateKeyPair(ctx context.Context, id string) (*v1.KeyPair, error) {
|
||||
key := s.getKeyPairKey(id)
|
||||
keyData, err := redis.Bytes(s.master(ctx, "GET", key))
|
||||
|
@ -242,6 +304,10 @@ func (s *Server) getKeyPairKey(id string) string {
|
|||
return fmt.Sprintf("%s:%s", keypairsKey, id)
|
||||
}
|
||||
|
||||
func (s *Server) getNodeNetworkKey(id string) string {
|
||||
return fmt.Sprintf("%s:%s", nodeNetworksKey, id)
|
||||
}
|
||||
|
||||
func (s *Server) getClient(addr string) (*client.Client, error) {
|
||||
return client.NewClient(s.cfg.ID, addr)
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ PostDown = iptables -D FORWARD -i {{ .Iface }} -j ACCEPT; iptables -t nat -D POS
|
|||
{{ range .Peers }}
|
||||
[Peer]
|
||||
PublicKey = {{ .KeyPair.PublicKey }}
|
||||
AllowedIPs = {{ allowedIPs .AllowedIPs }}
|
||||
{{ if .AllowedIPs }}AllowedIPs = {{ allowedIPs .AllowedIPs }}{{ end }}
|
||||
Endpoint = {{ .Endpoint }}
|
||||
{{ end }}
|
||||
`
|
||||
|
@ -114,10 +114,17 @@ func getTunnelName() string {
|
|||
func restartWireguardTunnel(ctx context.Context) error {
|
||||
tunnelName := getTunnelName()
|
||||
logrus.Infof("restarting tunnel %s", tunnelName)
|
||||
d, err := wg(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// only stop if running
|
||||
if string(d) != "" {
|
||||
d, err := wgquick(ctx, "down", tunnelName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, string(d))
|
||||
}
|
||||
}
|
||||
u, err := wgquick(ctx, "up", tunnelName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, string(u))
|
||||
|
|
2
utils.go
2
utils.go
|
@ -41,7 +41,7 @@ func NodeID() string {
|
|||
|
||||
var i net.Interface
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagLoopback != 0 {
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagPointToPoint != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue