gatekeeper/server.go

191 lines
3.8 KiB
Go
Raw Normal View History

package main
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"path"
"path/filepath"
"sync"
"time"
v1 "github.com/crosbymichael/guard/api/v1"
"github.com/gliderlabs/ssh"
"github.com/gomodule/redigo/redis"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
)
const (
ipsKey = "gatekeeper/ips"
configKeyPrefix = "gatekeeper/config/"
)
type ServerConfig struct {
ListenPort int
DBPath string
KeysPath string
HostKeyPath string
RedisURL string
Subnet string
GuardAddr string
GuardTunnel string
GuardDNS string
}
type Server struct {
cfg *ServerConfig
publicKeys []ssh.PublicKey
mu *sync.Mutex
}
func NewServer(cfg *ServerConfig) (*Server, error) {
if cfg.ListenPort == 0 {
cfg.ListenPort = 2222
}
return &Server{
cfg: cfg,
mu: &sync.Mutex{},
}, nil
}
func (s *Server) Run() error {
if err := s.loadKeys(); err != nil {
return err
}
ssh.Handle(func(session ssh.Session) {
id := s.getID(session.PublicKey())
config, err := s.getConfig(id)
if err != nil {
logrus.Error(err)
return
}
logrus.Debugf("config: id=%s", id)
io.WriteString(session, config)
})
pubKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
return s.isAuthorized(ctx, key)
})
logrus.Infof("starting ssh server on port %d", s.cfg.ListenPort)
opts := []ssh.Option{
pubKeyOption,
}
if _, err := os.Stat(s.cfg.HostKeyPath); err == nil {
opts = append(opts, ssh.HostKeyFile(s.cfg.HostKeyPath))
}
return ssh.ListenAndServe(fmt.Sprintf(":%d", s.cfg.ListenPort), nil, opts...)
}
func (s *Server) loadKeys() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cfg.KeysPath == "" {
return nil
}
keys, err := ioutil.ReadDir(s.cfg.KeysPath)
if err != nil {
return err
}
pubKeys := []ssh.PublicKey{}
for _, k := range keys {
logrus.Debugf("loading public key %s", k.Name())
p := filepath.Join(s.cfg.KeysPath, k.Name())
data, err := ioutil.ReadFile(p)
if err != nil {
return err
}
k, _, _, _, err := ssh.ParseAuthorizedKey(data)
if err != nil {
return err
}
pubKeys = append(pubKeys, k)
}
s.publicKeys = pubKeys
return nil
}
func (s *Server) isAuthorized(ctx ssh.Context, key ssh.PublicKey) bool {
for _, k := range s.publicKeys {
if ssh.KeysEqual(key, k) {
return true
}
}
logrus.WithFields(logrus.Fields{
"user": ctx.User(),
"addr": ctx.RemoteAddr(),
}).Warn("access denied")
return false
}
func (s *Server) getConfig(id string) (string, error) {
c, err := s.getConn()
if err != nil {
return "", err
}
defer c.Close()
key := path.Join(configKeyPrefix, id)
cfg, err := redis.String(c.Do("GET", key))
if err != nil && err != redis.ErrNil {
return "", err
}
if cfg != "" {
return cfg, nil
}
conn, err := grpc.Dial(s.cfg.GuardAddr, grpc.WithInsecure())
if err != nil {
return "", errors.Wrap(err, "error connecting to guard server")
}
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
client := v1.NewWireguardClient(conn)
ip, ipnet, err := s.getOrAllocateIP(id, s.cfg.Subnet)
if err != nil {
return "", err
}
r, err := client.NewPeer(ctx, &v1.NewPeerRequest{
ID: s.cfg.GuardTunnel,
PeerID: id,
Address: ip.String() + "/32",
})
if err != nil {
return "", err
}
// generate peer tunnel config
t := &v1.Tunnel{
PrivateKey: r.Peer.PrivateKey,
Address: r.Peer.AllowedIPs[0],
DNS: s.cfg.GuardDNS,
Peers: []*v1.Peer{
{
ID: r.Tunnel.ID,
PublicKey: r.Tunnel.PublicKey,
Endpoint: net.JoinHostPort(r.Tunnel.Endpoint, r.Tunnel.ListenPort),
AllowedIPs: []string{ipnet.String()},
},
},
}
b := bytes.NewBuffer(nil)
t.Render(b)
if _, err := c.Do("SET", key, b.String()); err != nil {
return "", err
}
return b.String(), nil
}