b8c77e8b4a
Signed-off-by: Evan Hazlett <ejhazlett@gmail.com>
190 lines
3.8 KiB
Go
190 lines
3.8 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"sync"
|
|
"time"
|
|
|
|
v1 "github.com/crosbymichael/guard/api/v1"
|
|
"github.com/gliderlabs/ssh"
|
|
"github.com/gomodule/redigo/redis"
|
|
"github.com/pkg/errors"
|
|
"github.com/sirupsen/logrus"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
const (
|
|
ipsKey = "gatekeeper/ips"
|
|
configKeyPrefix = "gatekeeper/config/"
|
|
)
|
|
|
|
type ServerConfig struct {
|
|
ListenPort int
|
|
DBPath string
|
|
KeysPath string
|
|
HostKeyPath string
|
|
RedisURL string
|
|
Subnet string
|
|
GuardAddr string
|
|
GuardTunnel string
|
|
GuardDNS string
|
|
}
|
|
|
|
type Server struct {
|
|
cfg *ServerConfig
|
|
publicKeys []ssh.PublicKey
|
|
mu *sync.Mutex
|
|
}
|
|
|
|
func NewServer(cfg *ServerConfig) (*Server, error) {
|
|
if cfg.ListenPort == 0 {
|
|
cfg.ListenPort = 2222
|
|
}
|
|
return &Server{
|
|
cfg: cfg,
|
|
mu: &sync.Mutex{},
|
|
}, nil
|
|
}
|
|
|
|
func (s *Server) Run() error {
|
|
if err := s.loadKeys(); err != nil {
|
|
return err
|
|
}
|
|
|
|
ssh.Handle(func(session ssh.Session) {
|
|
id := s.getID(session.PublicKey())
|
|
config, err := s.getConfig(id)
|
|
if err != nil {
|
|
logrus.Error(err)
|
|
return
|
|
}
|
|
logrus.Debugf("config: id=%s", id)
|
|
io.WriteString(session, config)
|
|
})
|
|
|
|
pubKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
|
|
return s.isAuthorized(ctx, key)
|
|
})
|
|
|
|
logrus.Infof("starting ssh server on port %d", s.cfg.ListenPort)
|
|
opts := []ssh.Option{
|
|
pubKeyOption,
|
|
}
|
|
if _, err := os.Stat(s.cfg.HostKeyPath); err == nil {
|
|
opts = append(opts, ssh.HostKeyFile(s.cfg.HostKeyPath))
|
|
}
|
|
return ssh.ListenAndServe(fmt.Sprintf(":%d", s.cfg.ListenPort), nil, opts...)
|
|
}
|
|
|
|
func (s *Server) loadKeys() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.cfg.KeysPath == "" {
|
|
return nil
|
|
}
|
|
|
|
keys, err := ioutil.ReadDir(s.cfg.KeysPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pubKeys := []ssh.PublicKey{}
|
|
for _, k := range keys {
|
|
logrus.Debugf("loading public key %s", k.Name())
|
|
p := filepath.Join(s.cfg.KeysPath, k.Name())
|
|
data, err := ioutil.ReadFile(p)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
k, _, _, _, err := ssh.ParseAuthorizedKey(data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
pubKeys = append(pubKeys, k)
|
|
}
|
|
|
|
s.publicKeys = pubKeys
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) isAuthorized(ctx ssh.Context, key ssh.PublicKey) bool {
|
|
for _, k := range s.publicKeys {
|
|
if ssh.KeysEqual(key, k) {
|
|
return true
|
|
}
|
|
}
|
|
logrus.WithFields(logrus.Fields{
|
|
"user": ctx.User(),
|
|
"addr": ctx.RemoteAddr(),
|
|
}).Warn("access denied")
|
|
return false
|
|
}
|
|
|
|
func (s *Server) getConfig(id string) (string, error) {
|
|
c, err := s.getConn()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer c.Close()
|
|
key := path.Join(configKeyPrefix, id)
|
|
cfg, err := redis.String(c.Do("GET", key))
|
|
if err != nil && err != redis.ErrNil {
|
|
return "", err
|
|
}
|
|
if cfg != "" {
|
|
return cfg, nil
|
|
}
|
|
|
|
conn, err := grpc.Dial(s.cfg.GuardAddr, grpc.WithInsecure())
|
|
if err != nil {
|
|
return "", errors.Wrap(err, "error connecting to guard server")
|
|
}
|
|
defer conn.Close()
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer cancel()
|
|
|
|
client := v1.NewWireguardClient(conn)
|
|
|
|
ip, ipnet, err := s.getOrAllocateIP(id, s.cfg.Subnet)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
r, err := client.NewPeer(ctx, &v1.NewPeerRequest{
|
|
ID: s.cfg.GuardTunnel,
|
|
PeerID: id,
|
|
Address: ip.String() + "/32",
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// generate peer tunnel config
|
|
t := &v1.Tunnel{
|
|
PrivateKey: r.Peer.PrivateKey,
|
|
Address: r.Peer.AllowedIPs[0],
|
|
DNS: s.cfg.GuardDNS,
|
|
Peers: []*v1.Peer{
|
|
{
|
|
ID: r.Tunnel.ID,
|
|
PublicKey: r.Tunnel.PublicKey,
|
|
Endpoint: net.JoinHostPort(r.Tunnel.Endpoint, r.Tunnel.ListenPort),
|
|
AllowedIPs: []string{ipnet.String()},
|
|
},
|
|
},
|
|
}
|
|
b := bytes.NewBuffer(nil)
|
|
t.Render(b)
|
|
if _, err := c.Do("SET", key, b.String()); err != nil {
|
|
return "", err
|
|
}
|
|
return b.String(), nil
|
|
}
|