diff --git a/go.mod b/go.mod index d61eecb..68d1267 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.12 require ( github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 // indirect github.com/gliderlabs/ssh v0.2.2 + github.com/gomodule/redigo v2.0.0+incompatible github.com/sirupsen/logrus v1.4.2 github.com/urfave/cli v1.22.1 + go.etcd.io/bbolt v1.3.3 golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 ) diff --git a/go.sum b/go.sum index ac50b29..a3f2bd4 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gliderlabs/ssh v0.2.2 h1:6zsha5zo/TWhRhwqCD3+EarCAgZ2yN28ipRnGPnwkI0= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -21,6 +23,8 @@ github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk= +go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/main.go b/main.go index 1135c16..ca2c887 100644 --- a/main.go +++ b/main.go @@ -34,6 +34,16 @@ func main() { Usage: "path to host key", Value: "/etc/ssh/ssh_host_rsa_key", }, + cli.StringFlag{ + Name: "subnet, s", + Usage: "subnet for ip allocation", + Value: "10.199.254.0/24", + }, + cli.StringFlag{ + Name: "redis, r", + Usage: "redis url", + Value: "redis://127.0.0.1:6379", + }, } app.Before = func(cx *cli.Context) error { if cx.Bool("debug") { @@ -47,6 +57,8 @@ func main() { ListenPort: cx.Int("port"), KeysPath: cx.String("key-dir"), HostKeyPath: cx.String("host-key"), + RedisURL: cx.String("redis"), + Subnet: cx.String("subnet"), } srv, err := NewServer(cfg) if err != nil { diff --git a/net.go b/net.go new file mode 100644 index 0000000..f0e105e --- /dev/null +++ b/net.go @@ -0,0 +1,187 @@ +package main + +import ( + "fmt" + "net" + "strings" + + "github.com/gomodule/redigo/redis" +) + +const ( + ipsKey = "gatekeeper/ips" +) + +type subnetRange struct { + Start net.IP + End net.IP + Subnet *net.IPNet +} + +func (s *Server) getIPs() (map[string]net.IP, error) { + c, err := s.getConn() + if err != nil { + return nil, err + } + defer c.Close() + + values, err := redis.StringMap(c.Do("HGETALL", ipsKey)) + if err != nil { + return nil, err + } + + ips := make(map[string]net.IP, len(values)) + for id, val := range values { + ip := net.ParseIP(string(val)) + ips[id] = ip + } + return ips, nil +} + +func (s *Server) getIP(id string) (net.IP, error) { + allIPs, err := s.getIPs() + if err != nil { + return nil, err + } + + if ip, exists := allIPs[id]; exists { + return ip, nil + } + return nil, nil +} + +func (s *Server) getOrAllocateIP(id, subnet string) (net.IP, *net.IPNet, error) { + r, err := s.parseSubnetRange(subnet) + if err != nil { + return nil, nil, err + } + + ip, err := s.getIP(id) + if err != nil { + return nil, nil, err + } + + if ip != nil { + return ip, r.Subnet, nil + } + + ip, err = s.allocateIP(id, r) + if err != nil { + return nil, nil, err + } + + return ip, r.Subnet, nil +} + +func (s *Server) allocateIP(id string, r *subnetRange) (net.IP, error) { + c, err := s.getConn() + if err != nil { + return nil, err + } + defer c.Close() + + reservedIPs, err := s.getIPs() + if err != nil { + return nil, err + } + + if ip, exists := reservedIPs[id]; exists { + return ip, nil + } + + lookup := map[string]string{} + for id, ip := range reservedIPs { + lookup[ip.String()] = id + } + for ip := r.Start; !ip.Equal(r.End); s.nextIP(ip) { + // filter out network, gateway and broadcast + if !s.validIP(ip) { + continue + } + if _, exists := lookup[ip.String()]; exists { + // ip already reserved + continue + } + + // save + if _, err := c.Do("HSET", ipsKey, id, ip.String()); err != nil { + return nil, err + } + return ip, nil + } + + return nil, fmt.Errorf("no available IPs") +} + +func (s *Server) releaseIP(id string) error { + c, err := s.getConn() + if err != nil { + return err + } + defer c.Close() + + ip, err := s.getIP(id) + if err != nil { + return err + } + + if ip != nil { + if _, err := c.Do("HDEL", ipsKey, id); err != nil { + return err + } + } + return nil +} + +func (s *Server) nextIP(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} + +func (s *Server) validIP(ip net.IP) bool { + v := ip[len(ip)-1] + switch v { + case 0, 1, 255: + return false + } + return true +} + +// parseSubnetRange parses the subnet range +// format can either be a subnet like 10.0.0.0/8 or range like 10.0.0.100-10.0.0.200/24 +func (s *Server) parseSubnetRange(subnet string) (*subnetRange, error) { + parts := strings.Split(subnet, "-") + if len(parts) == 1 { + ip, sub, err := net.ParseCIDR(parts[0]) + if err != nil { + return nil, err + } + + end := make(net.IP, len(ip)) + copy(end, ip) + end[len(end)-1] = 254 + return &subnetRange{ + Start: ip, + End: end, + Subnet: sub, + }, nil + } + if len(parts) > 2 || !strings.Contains(subnet, "/") { + return nil, fmt.Errorf("invalid range specified; expect format 10.0.0.100-10.0.0.200/24") + } + start := net.ParseIP(parts[0]) + end, sub, err := net.ParseCIDR(parts[1]) + if err != nil { + return nil, err + } + + return &subnetRange{ + Start: start, + End: end, + Subnet: sub, + }, nil +} diff --git a/server.go b/server.go index c46f507..818282d 100644 --- a/server.go +++ b/server.go @@ -10,13 +10,15 @@ import ( "github.com/gliderlabs/ssh" "github.com/sirupsen/logrus" - gossh "golang.org/x/crypto/ssh" ) type ServerConfig struct { ListenPort int + DBPath string KeysPath string HostKeyPath string + RedisURL string + Subnet string } type Server struct { @@ -35,6 +37,37 @@ func NewServer(cfg *ServerConfig) (*Server, error) { }, nil } +func (s *Server) Run() error { + if err := s.loadKeys(); err != nil { + return err + } + + ssh.Handle(func(session ssh.Session) { + //authorizedKey := gossh.MarshalAuthorizedKey(session.PublicKey()) + id := s.getID(session.PublicKey()) + ip, ipnet, err := s.getOrAllocateIP(id, s.cfg.Subnet) + if err != nil { + logrus.Error(err) + return + } + logrus.Debugf("config: id=%s ip=%s net=%s", id, ip, ipnet) + io.WriteString(session, ip.String()+"\n") + }) + + pubKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { + return s.isAuthorized(ctx, key) + }) + + logrus.Infof("starting ssh server on port %d", s.cfg.ListenPort) + opts := []ssh.Option{ + pubKeyOption, + } + if _, err := os.Stat(s.cfg.HostKeyPath); err == nil { + opts = append(opts, ssh.HostKeyFile(s.cfg.HostKeyPath)) + } + return ssh.ListenAndServe(fmt.Sprintf(":%d", s.cfg.ListenPort), nil, pubKeyOption) +} + func (s *Server) loadKeys() error { s.mu.Lock() defer s.mu.Unlock() @@ -67,31 +100,6 @@ func (s *Server) loadKeys() error { return nil } -func (s *Server) Run() error { - if err := s.loadKeys(); err != nil { - return err - } - - ssh.Handle(func(s ssh.Session) { - authorizedKey := gossh.MarshalAuthorizedKey(s.PublicKey()) - io.WriteString(s, fmt.Sprintf("pub key used by %s\n", s.User())) - s.Write(authorizedKey) - }) - - pubKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { - return s.isAuthorized(ctx, key) - }) - - logrus.Infof("starting ssh server on port %d", s.cfg.ListenPort) - opts := []ssh.Option{ - pubKeyOption, - } - if _, err := os.Stat(s.cfg.HostKeyPath); err == nil { - opts = append(opts, ssh.HostKeyFile(s.cfg.HostKeyPath)) - } - return ssh.ListenAndServe(fmt.Sprintf(":%d", s.cfg.ListenPort), nil, pubKeyOption) -} - func (s *Server) isAuthorized(ctx ssh.Context, key ssh.PublicKey) bool { for _, k := range s.publicKeys { if ssh.KeysEqual(key, k) { diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..d139c37 --- /dev/null +++ b/utils.go @@ -0,0 +1,19 @@ +package main + +import ( + "crypto/sha256" + "fmt" + + "github.com/gliderlabs/ssh" + "github.com/gomodule/redigo/redis" +) + +func (s *Server) getConn() (redis.Conn, error) { + return redis.DialURL(s.cfg.RedisURL) +} + +func (s *Server) getID(key ssh.PublicKey) string { + h := sha256.New() + h.Write(key.Marshal()) + return fmt.Sprintf("%x", h.Sum(nil)) +}