From f6c6ae7fad0f36dc2a86077e28ea54014f977680 Mon Sep 17 00:00:00 2001 From: David Calavera Date: Wed, 20 May 2015 16:48:39 -0700 Subject: [PATCH] Extract sockets initialization to a package. Because I just used it somewhere else and it would be nice if I didn't have to copy and paste the code. Signed-off-by: David Calavera --- sockets/README.md | 0 sockets/tcp_socket.go | 69 ++++++++++++++++++++++++++++++++++++ sockets/unix_socket.go | 80 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+) create mode 100644 sockets/README.md create mode 100644 sockets/tcp_socket.go create mode 100644 sockets/unix_socket.go diff --git a/sockets/README.md b/sockets/README.md new file mode 100644 index 0000000..e69de29 diff --git a/sockets/tcp_socket.go b/sockets/tcp_socket.go new file mode 100644 index 0000000..ac9edae --- /dev/null +++ b/sockets/tcp_socket.go @@ -0,0 +1,69 @@ +package sockets + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "os" + + "github.com/docker/docker/pkg/listenbuffer" +) + +type TlsConfig struct { + CA string + Certificate string + Key string + Verify bool +} + +func NewTlsConfig(tlsCert, tlsKey, tlsCA string, verify bool) *TlsConfig { + return &TlsConfig{ + Verify: verify, + Certificate: tlsCert, + Key: tlsKey, + CA: tlsCA, + } +} + +func NewTcpSocket(addr string, config *TlsConfig, activate <-chan struct{}) (net.Listener, error) { + l, err := listenbuffer.NewListenBuffer("tcp", addr, activate) + if err != nil { + return nil, err + } + if config != nil { + if l, err = setupTls(l, config); err != nil { + return nil, err + } + } + return l, nil +} + +func setupTls(l net.Listener, config *TlsConfig) (net.Listener, error) { + tlsCert, err := tls.LoadX509KeyPair(config.Certificate, config.Key) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("Could not load X509 key pair (%s, %s): %v", config.Certificate, config.Key, err) + } + return nil, fmt.Errorf("Error reading X509 key pair (%s, %s): %q. Make sure the key is encrypted.", + config.Certificate, config.Key, err) + } + tlsConfig := &tls.Config{ + NextProtos: []string{"http/1.1"}, + Certificates: []tls.Certificate{tlsCert}, + // Avoid fallback on insecure SSL protocols + MinVersion: tls.VersionTLS10, + } + if config.CA != "" { + certPool := x509.NewCertPool() + file, err := ioutil.ReadFile(config.CA) + if err != nil { + return nil, fmt.Errorf("Could not read CA certificate: %v", err) + } + certPool.AppendCertsFromPEM(file) + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientCAs = certPool + } + return tls.NewListener(l, tlsConfig), nil +} diff --git a/sockets/unix_socket.go b/sockets/unix_socket.go new file mode 100644 index 0000000..0536382 --- /dev/null +++ b/sockets/unix_socket.go @@ -0,0 +1,80 @@ +// +build linux + +package sockets + +import ( + "fmt" + "net" + "os" + "strconv" + "syscall" + + "github.com/Sirupsen/logrus" + "github.com/docker/docker/pkg/listenbuffer" + "github.com/docker/libcontainer/user" +) + +func NewUnixSocket(path, group string, activate <-chan struct{}) (net.Listener, error) { + if err := syscall.Unlink(path); err != nil && !os.IsNotExist(err) { + return nil, err + } + mask := syscall.Umask(0777) + defer syscall.Umask(mask) + l, err := listenbuffer.NewListenBuffer("unix", path, activate) + if err != nil { + return nil, err + } + if err := setSocketGroup(path, group); err != nil { + l.Close() + return nil, err + } + if err := os.Chmod(path, 0660); err != nil { + l.Close() + return nil, err + } + return l, nil +} + +func setSocketGroup(path, group string) error { + if group == "" { + return nil + } + if err := changeGroup(path, group); err != nil { + if group != "docker" { + return err + } + logrus.Debugf("Warning: could not change group %s to docker: %v", path, err) + } + return nil +} + +func changeGroup(path string, nameOrGid string) error { + gid, err := lookupGidByName(nameOrGid) + if err != nil { + return err + } + logrus.Debugf("%s group found. gid: %d", nameOrGid, gid) + return os.Chown(path, 0, gid) +} + +func lookupGidByName(nameOrGid string) (int, error) { + groupFile, err := user.GetGroupPath() + if err != nil { + return -1, err + } + groups, err := user.ParseGroupFileFilter(groupFile, func(g user.Group) bool { + return g.Name == nameOrGid || strconv.Itoa(g.Gid) == nameOrGid + }) + if err != nil { + return -1, err + } + if groups != nil && len(groups) > 0 { + return groups[0].Gid, nil + } + gid, err := strconv.Atoi(nameOrGid) + if err == nil { + logrus.Warnf("Could not find GID %d", gid) + return gid, nil + } + return -1, fmt.Errorf("Group %s not found", nameOrGid) +}