add advertise grpc address option

Signed-off-by: Evan Hazlett <ejhazlett@gmail.com>
commit 190ec3130d (parent 4bc08c92c2)
Author: Evan Hazlett <ejhazlett@gmail.com>
Date:   2019-10-07 13:52:47 -04:00
12 changed files with 415 additions and 204 deletions
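Reviewer summary: this change (a) adds an advertise gRPC address so the master publishes a reachable endpoint rather than its bind address, (b) reworks Join to take a full v1.JoinRequest and return master, node, and peer info in one call, (c) rewrites the advertised redis URL to the wireguard gateway IP, and (d) moves MASTERAUTH handling into getPool. The sketch below reconstructs the config fields this diff reads from s.cfg; it is inferred from usage here, not copied from the real heimdall.Config, and the field types are assumptions.

// Inferred shape of heimdall.Config as exercised by this diff (sketch,
// not the authoritative definition; types are assumptions).
type Config struct {
    ID                   string
    GRPCAddress          string // bind address
    AdvertiseGRPCAddress string // new in this commit: address published to peers
    GRPCPeerAddress      string
    ClusterKey           string
    RedisURL             string
    AdvertiseRedisURL    string // no longer advertised directly (see updateMasterInfo)
    EndpointIP           string
    EndpointPort         int
    InterfaceName        string
}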


@@ -60,7 +60,35 @@ func (s *Server) Join(ctx context.Context, req *v1.JoinRequest) (*v1.JoinRespons
return nil, err
}
peers, err := s.getPeers(ctx)
if err != nil {
return nil, err
}
if err := s.ensureNetworkSubnet(ctx, req.ID); err != nil {
return nil, err
}
node, err := s.getNode(ctx, req.ID)
if err != nil {
if err != redis.ErrNil {
return nil, err
}
n, err := s.createNode(ctx, req)
if err != nil {
return nil, err
}
if err := s.updatePeerInfo(ctx, req.ID); err != nil {
return nil, err
}
node = n
}
return &v1.JoinResponse{
Master: &master,
Node: node,
Peers: peers,
}, nil
}
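Review note: the get-or-create branch above keys off redigo's redis.ErrNil, which is how a missing key is reported; any other error aborts the join. The pattern in isolation (sketch; helper name is made up):

import "github.com/gomodule/redigo/redis"

// getOrNil distinguishes "absent" from "failed", as the Join handler does:
// redigo reports a missing key as redis.ErrNil rather than an empty value.
func getOrNil(conn redis.Conn, key string) ([]byte, error) {
    data, err := redis.Bytes(conn.Do("GET", key))
    if err == redis.ErrNil {
        return nil, nil // absent: caller creates the node record
    }
    return data, err
}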


@@ -29,6 +29,7 @@ import (
"strings"
"github.com/gomodule/redigo/redis"
"github.com/pkg/errors"
)
type subnetRange struct {
@@ -37,8 +38,8 @@ type subnetRange struct {
Subnet *net.IPNet
}
func (s *Server) updateNodeNetwork(ctx context.Context, subnet string) error {
if _, err := s.master(ctx, "SET", s.getNodeNetworkKey(s.cfg.ID), subnet); err != nil {
func (s *Server) updateNodeNetwork(ctx context.Context, id string, subnet string) error {
if _, err := s.master(ctx, "SET", s.getNodeNetworkKey(id), subnet); err != nil {
return err
}
return nil
@@ -70,7 +71,7 @@ func (s *Server) getOrAllocatePeerIP(ctx context.Context, id string) (net.IP, *n
func (s *Server) getNodeIP(ctx context.Context, id string) (net.IP, *net.IPNet, error) {
subnet, err := redis.String(s.local(ctx, "GET", s.getNodeNetworkKey(id)))
if err != nil {
return nil, nil, err
return nil, nil, errors.Wrap(err, "error getting node ip")
}
r, err := parseSubnetRange(subnet)
if err != nil {


@@ -23,6 +23,7 @@ package server
import (
"context"
"fmt"
"net/url"
"sort"
"strings"
@@ -82,24 +83,15 @@ func (s *Server) getNode(ctx context.Context, id string) (*v1.Node, error) {
func (s *Server) configureNode() error {
ctx := context.Background()
nodeKeys, err := redis.Strings(s.local(ctx, "KEYS", s.getNodeKey("*")))
nodes, err := s.getNodes(ctx)
if err != nil {
return err
}
// attempt to connect to existing
if len(nodeKeys) > 0 {
for _, nodeKey := range nodeKeys {
nodeData, err := redis.Bytes(s.local(ctx, "GET", nodeKey))
if err != nil {
logrus.Warn(err)
continue
}
var node v1.Node
if err := proto.Unmarshal(nodeData, &node); err != nil {
return err
}
if len(nodes) > 0 {
for _, node := range nodes {
// ignore self
if node.Addr == s.cfg.GRPCAddress {
if node.ID == s.cfg.ID {
continue
}
@@ -109,14 +101,28 @@ func (s *Server) configureNode() error {
logrus.Warn(err)
continue
}
m, err := c.Join(s.cfg.ClusterKey)
r, err := c.Join(&v1.JoinRequest{
ID: s.cfg.ID,
ClusterKey: s.cfg.ClusterKey,
GRPCAddress: s.cfg.GRPCAddress,
EndpointIP: s.cfg.EndpointIP,
EndpointPort: uint64(s.cfg.EndpointPort),
InterfaceName: s.cfg.InterfaceName,
})
if err != nil {
c.Close()
logrus.Warn(err)
continue
}
if err := s.joinMaster(m); err != nil {
// TODO: start tunnel
if err := s.updatePeerConfig(ctx, r.Node, r.Peers); err != nil {
return err
}
// TODO: wait for tunnel to come up
time.Sleep(time.Second * 20)
if err := s.joinMaster(r.Master); err != nil {
c.Close()
logrus.Warn(err)
continue
@@ -154,6 +160,8 @@ func (s *Server) configureNode() error {
return err
}
// TODO: start tunnel
if err := s.joinMaster(&master); err != nil {
return err
}
@@ -163,14 +171,19 @@ func (s *Server) configureNode() error {
return nil
}
func (s *Server) disableReplica() {
s.wpool = getPool(s.cfg.RedisURL)
func (s *Server) disableReplica() error {
p, err := getPool(s.cfg.RedisURL)
if err != nil {
return err
}
s.wpool = p
// signal replica monitor to stop if started as a peer
close(s.replicaCh)
// unset peer
s.cfg.GRPCPeerAddress = ""
return nil
}
func (s *Server) replicaMonitor() {
@@ -187,7 +200,7 @@ func (s *Server) replicaMonitor() {
logrus.Error(err)
continue
}
if n.ID != s.cfg.ID {
if n == nil || n.ID != s.cfg.ID {
logrus.Debugf("waiting for new master to initialize: %s", n.ID)
continue
}
@@ -242,12 +255,15 @@ func (s *Server) masterHeartbeat() {
func (s *Server) joinMaster(m *v1.Master) error {
// configure replica
logrus.Infof("configuring node as replica of %+v", m.ID)
conn, err := redis.DialURL(s.cfg.RedisURL)
pool, err := getPool(s.cfg.RedisURL)
if err != nil {
return errors.Wrap(err, "unable to connect to redis")
return err
}
conn := pool.Get()
defer conn.Close()
logrus.Debugf("configuring redis as slave of %s", m.RedisURL)
u, err := url.Parse(m.RedisURL)
if err != nil {
return errors.Wrap(err, "error parsing master redis url")
@@ -258,15 +274,11 @@ func (s *Server) joinMaster(m *v1.Master) error {
if _, err := conn.Do("REPLICAOF", host, port); err != nil {
return err
}
// auth
auth, ok := u.User.Password()
if ok {
if _, err := conn.Do("CONFIG", "SET", "MASTERAUTH", auth); err != nil {
return errors.Wrap(err, "error authenticating to redis")
}
}
s.wpool = getPool(m.RedisURL)
s.wpool, err = getPool(m.RedisURL)
if err != nil {
return err
}
return nil
}
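Review note: REPLICAOF host port is the Redis 5 spelling of SLAVEOF and turns the local redis into a replica of the master; the MASTERAUTH step that used to live here now happens in getPool's dial function. For reference, the inverse operation when a replica is promoted back to master (not part of this diff) is:

// sketch: promotion back to master; "REPLICAOF NO ONE" is standard Redis
// and mirrors the REPLICAOF call above.
if _, err := conn.Do("REPLICAOF", "NO", "ONE"); err != nil {
    return err
}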
@@ -275,10 +287,21 @@ func (s *Server) updateMasterInfo(ctx context.Context) error {
if _, err := s.master(ctx, "SET", clusterKey, s.cfg.ClusterKey); err != nil {
return err
}
// build redis url with gateway ip
gatewayIP, _, err := s.getNodeIP(ctx, s.cfg.ID)
if err != nil {
return err
}
u, err := url.Parse(s.cfg.RedisURL)
if err != nil {
return err
}
// update master redis url to gateway to serve over wireguard
u.Host = fmt.Sprintf("%s:%s", gatewayIP.String(), u.Port())
m := &v1.Master{
ID: s.cfg.ID,
GRPCAddress: s.cfg.GRPCAddress,
RedisURL: s.cfg.AdvertiseRedisURL,
GRPCAddress: s.cfg.AdvertiseGRPCAddress,
RedisURL: u.String(),
}
data, err := proto.Marshal(m)
if err != nil {
@@ -292,48 +315,91 @@ func (s *Server) updateMasterInfo(ctx context.Context) error {
if _, err := s.master(ctx, "EXPIRE", masterKey, int(masterHeartbeatInterval.Seconds())); err != nil {
return errors.Wrap(err, "error setting expire for master info")
}
return nil
}
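The rewrite above swaps only the host portion of the configured redis URL for the node's wireguard gateway IP, so replicas reach the master's redis over the tunnel while the scheme, password, and port survive. A standalone illustration (function name and addresses are made up):

import (
    "fmt"
    "net/url"
)

// rewriteToGateway keeps scheme, userinfo, and port but points the host
// at the wireguard gateway, as updateMasterInfo does above.
func rewriteToGateway(redisURL, gatewayIP string) (string, error) {
    u, err := url.Parse(redisURL)
    if err != nil {
        return "", err
    }
    u.Host = fmt.Sprintf("%s:%s", gatewayIP, u.Port())
    return u.String(), nil
}

// rewriteToGateway("redis://:pass@203.0.113.10:6379", "10.11.0.1")
// => "redis://:pass@10.11.0.1:6379", nil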
func (s *Server) nodeHeartbeat(ctx context.Context) {
func (s *Server) updateNodeInfo(ctx context.Context) {
logrus.Debugf("starting node heartbeat: ttl=%s", nodeHeartbeatInterval)
t := time.NewTicker(nodeHeartbeatInterval)
key := s.getNodeKey(s.cfg.ID)
for range t.C {
keyPair, err := s.getOrCreateKeyPair(ctx, s.cfg.ID)
if err != nil {
logrus.Error(err)
continue
}
nodeIP, _, err := s.getNodeIP(ctx, s.cfg.ID)
if err != nil {
logrus.Error(err)
continue
}
node := &v1.Node{
Updated: time.Now(),
ID: s.cfg.ID,
Addr: s.cfg.GRPCAddress,
KeyPair: keyPair,
EndpointIP: s.cfg.EndpointIP,
EndpointPort: uint64(s.cfg.EndpointPort),
GatewayIP: nodeIP.String(),
}
data, err := proto.Marshal(node)
if err != nil {
logrus.Error(err)
continue
}
if _, err := s.master(ctx, "SET", key, data); err != nil {
logrus.Error(err)
continue
}
if _, err := s.master(ctx, "EXPIRE", key, nodeHeartbeatExpiry); err != nil {
if err := s.updateLocalNodeInfo(ctx); err != nil {
logrus.Error(err)
continue
}
}
}
func (s *Server) updateLocalNodeInfo(ctx context.Context) error {
key := s.getNodeKey(s.cfg.ID)
keyPair, err := s.getOrCreateKeyPair(ctx, s.cfg.ID)
if err != nil {
return err
}
nodeIP, _, err := s.getNodeIP(ctx, s.cfg.ID)
if err != nil {
return err
}
node := &v1.Node{
Updated: time.Now(),
ID: s.cfg.ID,
Addr: s.cfg.GRPCAddress,
KeyPair: keyPair,
EndpointIP: s.cfg.EndpointIP,
EndpointPort: uint64(s.cfg.EndpointPort),
GatewayIP: nodeIP.String(),
InterfaceName: s.cfg.InterfaceName,
}
data, err := proto.Marshal(node)
if err != nil {
return err
}
if _, err := s.master(ctx, "SET", key, data); err != nil {
return err
}
if _, err := s.master(ctx, "EXPIRE", key, nodeHeartbeatExpiry); err != nil {
return err
}
return nil
}
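updateLocalNodeInfo doubles as the liveness record: the node entry is written and then given a TTL, so a node that stops refreshing simply expires out of the cluster. The pattern in isolation (sketch; helper name, key, and TTL are illustrative):

import "github.com/gomodule/redigo/redis"

// heartbeat writes a record with a TTL; nodeHeartbeatExpiry plays the
// ttlSeconds role in the diff above.
func heartbeat(conn redis.Conn, key string, data []byte, ttlSeconds int) error {
    if _, err := conn.Do("SET", key, data); err != nil {
        return err
    }
    _, err := conn.Do("EXPIRE", key, ttlSeconds)
    return err
}

Redis can also collapse this into one round trip with SET key value EX seconds.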
func (s *Server) createNode(ctx context.Context, req *v1.JoinRequest) (*v1.Node, error) {
key := s.getNodeKey(req.ID)
keyPair, err := s.getOrCreateKeyPair(ctx, req.ID)
if err != nil {
return nil, errors.Wrapf(err, "error getting/creating keypair for %s", req.ID)
}
nodeIP, _, err := s.getNodeIP(ctx, req.ID)
if err != nil {
return nil, errors.Wrapf(err, "error getting node ip for %s", req.ID)
}
node := &v1.Node{
Updated: time.Now(),
ID: req.ID,
Addr: req.GRPCAddress,
KeyPair: keyPair,
EndpointIP: req.EndpointIP,
EndpointPort: uint64(req.EndpointPort),
GatewayIP: nodeIP.String(),
InterfaceName: req.InterfaceName,
}
data, err := proto.Marshal(node)
if err != nil {
return nil, err
}
if _, err := s.master(ctx, "SET", key, data); err != nil {
return nil, err
}
if _, err := s.master(ctx, "EXPIRE", key, nodeHeartbeatExpiry); err != nil {
return nil, err
}
return node, nil
}


@@ -69,10 +69,9 @@ func (s *Server) getPeers(ctx context.Context) ([]*v1.Peer, error) {
func (s *Server) peerUpdater(ctx context.Context) {
logrus.Debugf("starting peer config updater: ttl=%s", peerConfigUpdateInterval)
t := time.NewTicker(peerConfigUpdateInterval)
for range t.C {
logrus.Debug("peer config update")
uctx, cancel := context.WithTimeout(ctx, peerConfigUpdateInterval)
if err := s.updatePeerInfo(uctx, s.cfg.ID); err != nil {
logrus.Errorf("updateLocalPeerInfo: %s", err)
@@ -80,7 +79,23 @@ func (s *Server) peerUpdater(ctx context.Context) {
continue
}
if err := s.updatePeerConfig(uctx); err != nil {
peers, err := s.getPeers(ctx)
if err != nil {
logrus.Error(err)
cancel()
continue
}
logrus.Debugf("peer update: peers %+v", peers)
node, err := s.getNode(ctx, s.cfg.ID)
if err != nil {
logrus.Error(err)
cancel()
continue
}
if err := s.updatePeerConfig(uctx, node, peers); err != nil {
logrus.Errorf("updatePeerConfig: %s", err)
cancel()
continue
@@ -208,47 +223,29 @@ func (s *Server) getPeerInfo(ctx context.Context, id string) (*v1.Peer, error) {
return &peer, nil
}
func (s *Server) updatePeerConfig(ctx context.Context) error {
peerKeys, err := redis.Strings(s.local(ctx, "KEYS", s.getPeerKey("*")))
if err != nil {
return err
}
var peers []*v1.Peer
for _, peerKey := range peerKeys {
peerData, err := redis.Bytes(s.local(ctx, "GET", peerKey))
if err != nil {
return err
}
var p v1.Peer
if err := proto.Unmarshal(peerData, &p); err != nil {
return err
}
func (s *Server) updatePeerConfig(ctx context.Context, node *v1.Node, peers []*v1.Peer) error {
var nodePeers []*v1.Peer
for _, peer := range peers {
// do not add self as a peer
if p.ID == s.cfg.ID {
if peer.ID == node.ID {
continue
}
peers = append(peers, &p)
nodePeers = append(nodePeers, peer)
}
keyPair, err := s.getOrCreateKeyPair(ctx, s.cfg.ID)
keyPair, err := s.getOrCreateKeyPair(ctx, node.ID)
if err != nil {
return err
}
gatewayIP, gatewayNet, err := s.getNodeIP(ctx, s.cfg.ID)
if err != nil {
return err
}
size, _ := gatewayNet.Mask.Size()
//size, _ := gatewayNet.Mask.Size()
wireguardCfg := &wg.Config{
Iface: s.cfg.InterfaceName,
Iface: node.InterfaceName,
PrivateKey: keyPair.PrivateKey,
ListenPort: s.cfg.EndpointPort,
Address: fmt.Sprintf("%s/%d", gatewayIP.To4().String(), size),
Peers: peers,
ListenPort: int(node.EndpointPort),
Address: fmt.Sprintf("%s/%d", node.GatewayIP, 16),
Peers: nodePeers,
}
wireguardConfigPath := s.getWireguardConfigPath()
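Review note: the interface address is now rendered with a hardcoded /16 where the prefix length was previously taken from the allocated subnet (note the commented-out Mask.Size() line). If the subnet string is available, the prefix can still be derived rather than pinned; a sketch (helper name and CIDR value are illustrative):

import (
    "fmt"
    "net"
)

// ifaceAddr derives the prefix length from the node's subnet instead of
// pinning /16, e.g. ("10.11.0.1", "10.11.0.0/16") -> "10.11.0.1/16".
func ifaceAddr(gatewayIP, subnet string) (string, error) {
    _, ipnet, err := net.ParseCIDR(subnet)
    if err != nil {
        return "", err
    }
    size, _ := ipnet.Mask.Size()
    return fmt.Sprintf("%s/%d", gatewayIP, size), nil
}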


@@ -25,6 +25,7 @@ import (
"context"
"fmt"
"io/ioutil"
"net/url"
"path/filepath"
"runtime"
"runtime/pprof"
@@ -81,7 +82,10 @@ type Server struct {
// NewServer returns a new Heimdall server
func NewServer(cfg *heimdall.Config) (*Server, error) {
pool := getPool(cfg.RedisURL)
pool, err := getPool(cfg.RedisURL)
if err != nil {
return nil, err
}
return &Server{
cfg: cfg,
rpool: pool,
@@ -121,13 +125,28 @@ func (s *Server) Run() error {
}
defer c.Close()
master, err := c.Join(s.cfg.ClusterKey)
r, err := c.Join(&v1.JoinRequest{
ID: s.cfg.ID,
ClusterKey: s.cfg.ClusterKey,
GRPCAddress: s.cfg.GRPCAddress,
EndpointIP: s.cfg.EndpointIP,
EndpointPort: uint64(s.cfg.EndpointPort),
InterfaceName: s.cfg.InterfaceName,
})
if err != nil {
return err
}
logrus.Debugf("master info received: %+v", master)
if err := s.joinMaster(master); err != nil {
logrus.Debugf("response: %+v", r)
// start tunnel
if err := s.updatePeerConfig(ctx, r.Node, r.Peers); err != nil {
return errors.Wrap(err, "error updating peer config")
}
// TODO: wait for tunnel to come up
time.Sleep(time.Second * 20)
logrus.Debugf("master info received: %+v", r)
if err := s.joinMaster(r.Master); err != nil {
return err
}
@@ -144,12 +163,17 @@ func (s *Server) Run() error {
}
// ensure node network subnet
if err := s.ensureNetworkSubnet(ctx); err != nil {
if err := s.ensureNetworkSubnet(ctx, s.cfg.ID); err != nil {
return err
}
// initial node update
if err := s.updateLocalNodeInfo(ctx); err != nil {
return err
}
// start node heartbeat to update in redis
go s.nodeHeartbeat(ctx)
go s.updateNodeInfo(ctx)
// initial peer info update
if err := s.updatePeerInfo(ctx, s.cfg.ID); err != nil {
@@ -157,7 +181,15 @@ func (s *Server) Run() error {
}
// initial config update
if err := s.updatePeerConfig(ctx); err != nil {
node, err := s.getNode(ctx, s.cfg.ID)
if err != nil {
return err
}
peers, err := s.getPeers(ctx)
if err != nil {
return err
}
if err := s.updatePeerConfig(ctx, node, peers); err != nil {
return err
}
@@ -184,7 +216,7 @@ func (s *Server) Run() error {
}
}()
err := <-errCh
err = <-errCh
return err
}
@@ -194,20 +226,33 @@ func (s *Server) Stop() error {
return nil
}
func getPool(u string) *redis.Pool {
func getPool(redisUrl string) (*redis.Pool, error) {
pool := redis.NewPool(func() (redis.Conn, error) {
conn, err := redis.DialURL(u)
conn, err := redis.DialURL(redisUrl)
if err != nil {
return nil, errors.Wrap(err, "unable to connect to redis")
}
u, err := url.Parse(redisUrl)
if err != nil {
return nil, err
}
auth, ok := u.User.Password()
if ok {
logrus.Debug("setting masterauth for redis")
if _, err := conn.Do("CONFIG", "SET", "MASTERAUTH", auth); err != nil {
return nil, errors.Wrap(err, "error authenticating to redis")
}
}
return conn, nil
}, 10)
return pool
return pool, nil
}
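getPool now surfaces dial failures and, when the redis URL carries a password, issues CONFIG SET MASTERAUTH on each new connection so pooled connections can authenticate replication. A hypothetical caller (URL illustrative); note that with redigo's redis.NewPool, dial errors surface on the first Do rather than from Get:

pool, err := getPool("redis://:clusterkey@127.0.0.1:6379")
if err != nil {
    return err
}
conn := pool.Get()
defer conn.Close()
if _, err := conn.Do("PING"); err != nil {
    return err // dial or auth failures appear here
}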
func (s *Server) ensureNetworkSubnet(ctx context.Context) error {
network, err := redis.String(s.local(ctx, "GET", s.getNodeNetworkKey(s.cfg.ID)))
func (s *Server) ensureNetworkSubnet(ctx context.Context, id string) error {
network, err := redis.String(s.local(ctx, "GET", s.getNodeNetworkKey(id)))
if err != nil {
if err != redis.ErrNil {
return err
@@ -222,7 +267,6 @@ func (s *Server) ensureNetworkSubnet(ctx context.Context) error {
if err != nil {
return err
}
logrus.Debug(nodeNetworkKeys)
lookup := map[string]struct{}{}
for _, netKey := range nodeNetworkKeys {
n, err := redis.String(s.local(ctx, "GET", netKey))
@@ -244,8 +288,8 @@ func (s *Server) ensureNetworkSubnet(ctx context.Context) error {
subnet = n
continue
}
logrus.Debugf("allocated network %s for %s", n.String(), s.cfg.ID)
if err := s.updateNodeNetwork(ctx, n.String()); err != nil {
logrus.Debugf("allocated network %s for %s", n.String(), id)
if err := s.updateNodeNetwork(ctx, id, n.String()); err != nil {
return err
}
break
@@ -253,7 +297,7 @@ func (s *Server) ensureNetworkSubnet(ctx context.Context) error {
return nil
}
logrus.Debugf("node network: %s", network)
logrus.Debugf("node network for %s: %s", id, network)
return nil
}
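For context, ensureNetworkSubnet builds a set of already-claimed node subnets from the node network keys, then takes the first free candidate and persists it via updateNodeNetwork. The duplicate check in isolation (sketch; subnet values illustrative):

// first-free allocation against the set of claimed subnets.
claimed := map[string]struct{}{"10.10.0.0/16": {}}
for _, candidate := range []string{"10.10.0.0/16", "10.11.0.0/16"} {
    if _, ok := claimed[candidate]; ok {
        continue
    }
    // candidate is free: record it with updateNodeNetwork and stop.
    break
}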