beam: Add simple framing system for UnixConn
This is needed for Send/Recieve to correctly handle borders between the messages. The framing uses a single 32bit uint32 length for each frame, of which the high bit is used to indicate whether the message contains a file descriptor or not. This is enough to separate out each message sent and to decide to which message each file descriptors belongs, even though multiple Sends may be coalesced into a single read, and/or one Send can be split into multiple writes. Docker-DCO-1.1-Signed-off-by: Alexander Larsson <alexl@redhat.com> (github: alexlarsson) Docker-DCO-1.1-Signed-off-by: Solomon Hykes <solomon@docker.com> (github: shykes)
This commit is contained in:
parent
d6deab19dc
commit
bf43f17c56
1 changed files with 136 additions and 30 deletions
164
beam/unix.go
164
beam/unix.go
|
@ -21,6 +21,43 @@ func debugCheckpoint(msg string, args ...interface{}) {
|
|||
|
||||
type UnixConn struct {
|
||||
*net.UnixConn
|
||||
fds []*os.File
|
||||
}
|
||||
|
||||
// Framing:
|
||||
// In order to handle framing in Send/Recieve, as these give frame
|
||||
// boundaries we use a very simple 4 bytes header. It is a big endiand
|
||||
// uint32 where the high bit is set if the message includes a file
|
||||
// descriptor. The rest of the uint32 is the length of the next frame.
|
||||
// We need the bit in order to be able to assign recieved fds to
|
||||
// the right message, as multiple messages may be coalesced into
|
||||
// a single recieve operation.
|
||||
func makeHeader(data []byte, fds []int) ([]byte, error) {
|
||||
header := make([]byte, 4)
|
||||
|
||||
length := uint32(len(data))
|
||||
|
||||
if length > 0x7fffffff {
|
||||
return nil, fmt.Errorf("Data to large")
|
||||
}
|
||||
|
||||
if len(fds) != 0 {
|
||||
length = length | 0x80000000
|
||||
}
|
||||
header[0] = byte((length >> 24) & 0xff)
|
||||
header[1] = byte((length >> 16) & 0xff)
|
||||
header[2] = byte((length >> 8) & 0xff)
|
||||
header[3] = byte((length >> 0) & 0xff)
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func parseHeader(header []byte) (uint32, bool) {
|
||||
length := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
|
||||
hasFd := length&0x80000000 != 0
|
||||
length = length & ^uint32(0x80000000)
|
||||
|
||||
return length, hasFd
|
||||
}
|
||||
|
||||
func FileConn(f *os.File) (*UnixConn, error) {
|
||||
|
@ -33,7 +70,7 @@ func FileConn(f *os.File) (*UnixConn, error) {
|
|||
conn.Close()
|
||||
return nil, fmt.Errorf("%d: not a unix connection", f.Fd())
|
||||
}
|
||||
return &UnixConn{uconn}, nil
|
||||
return &UnixConn{UnixConn: uconn}, nil
|
||||
|
||||
}
|
||||
|
||||
|
@ -52,7 +89,7 @@ func (conn *UnixConn) Send(data []byte, f *os.File) error {
|
|||
if f != nil {
|
||||
fds = append(fds, int(f.Fd()))
|
||||
}
|
||||
if err := sendUnix(conn.UnixConn, data, fds...); err != nil {
|
||||
if err := conn.sendUnix(data, fds...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -76,42 +113,104 @@ func (conn *UnixConn) Receive() (rdata []byte, rf *os.File, rerr error) {
|
|||
}
|
||||
debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. Hit enter to continue.\n", rdata, fd)
|
||||
}()
|
||||
for {
|
||||
data, fds, err := receiveUnix(conn.UnixConn)
|
||||
|
||||
// Read header
|
||||
header := make([]byte, 4)
|
||||
nRead := uint32(0)
|
||||
|
||||
for nRead < 4 {
|
||||
n, err := conn.receiveUnix(header[nRead:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var f *os.File
|
||||
if len(fds) > 1 {
|
||||
for _, fd := range fds[1:] {
|
||||
syscall.Close(fd)
|
||||
}
|
||||
}
|
||||
if len(fds) >= 1 {
|
||||
f = os.NewFile(uintptr(fds[0]), "")
|
||||
}
|
||||
return data, f, nil
|
||||
}
|
||||
panic("impossibru")
|
||||
return nil, nil, nil
|
||||
nRead = nRead + uint32(n)
|
||||
}
|
||||
|
||||
func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) {
|
||||
buf := make([]byte, 4096)
|
||||
oob := make([]byte, 4096)
|
||||
length, hasFd := parseHeader(header)
|
||||
|
||||
if hasFd {
|
||||
if len(conn.fds) == 0 {
|
||||
return nil, nil, fmt.Errorf("No expected file descriptor in message")
|
||||
}
|
||||
|
||||
rf = conn.fds[0]
|
||||
conn.fds = conn.fds[1:]
|
||||
}
|
||||
|
||||
rdata = make([]byte, length)
|
||||
|
||||
nRead = 0
|
||||
for nRead < length {
|
||||
n, err := conn.receiveUnix(rdata[nRead:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
nRead = nRead + uint32(n)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (conn *UnixConn) receiveUnix(buf []byte) (int, error) {
|
||||
oob := make([]byte, syscall.CmsgSpace(4))
|
||||
bufn, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
return buf[:bufn], extractFds(oob[:oobn]), nil
|
||||
fd := extractFd(oob[:oobn])
|
||||
if fd != -1 {
|
||||
f := os.NewFile(uintptr(fd), "")
|
||||
conn.fds = append(conn.fds, f)
|
||||
}
|
||||
|
||||
func sendUnix(conn *net.UnixConn, data []byte, fds ...int) error {
|
||||
_, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(fds...), nil)
|
||||
return bufn, nil
|
||||
}
|
||||
|
||||
func (conn *UnixConn) sendUnix(data []byte, fds ...int) error {
|
||||
header, err := makeHeader(data, fds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func extractFds(oob []byte) (fds []int) {
|
||||
// There is a bug in conn.WriteMsgUnix where it doesn't correctly return
|
||||
// the number of bytes writte (http://code.google.com/p/go/issues/detail?id=7645)
|
||||
// So, we can't rely on the return value from it. However, we must use it to
|
||||
// send the fds. In order to handle this we only write one byte using WriteMsgUnix
|
||||
// (when we have to), as that can only ever block or fully suceed. We then write
|
||||
// the rest with conn.Write()
|
||||
// The reader side should not rely on this though, as hopefully this gets fixed
|
||||
// in go later.
|
||||
written := 0
|
||||
if len(fds) != 0 {
|
||||
oob := syscall.UnixRights(fds...)
|
||||
wrote, _, err := conn.WriteMsgUnix(header[0:1], oob, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
written = written + wrote
|
||||
}
|
||||
|
||||
for written < len(header) {
|
||||
wrote, err := conn.Write(header[written:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
written = written + wrote
|
||||
}
|
||||
|
||||
written = 0
|
||||
for written < len(data) {
|
||||
wrote, err := conn.Write(data[written:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
written = written + wrote
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractFd(oob []byte) int {
|
||||
// Grab forklock to make sure no forks accidentally inherit the new
|
||||
// fds before they are made CLOEXEC
|
||||
// There is a slight race condition between ReadMsgUnix returns and
|
||||
|
@ -122,20 +221,27 @@ func extractFds(oob []byte) (fds []int) {
|
|||
defer syscall.ForkLock.Unlock()
|
||||
scms, err := syscall.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return
|
||||
return -1
|
||||
}
|
||||
|
||||
foundFd := -1
|
||||
for _, scm := range scms {
|
||||
gotFds, err := syscall.ParseUnixRights(&scm)
|
||||
fds, err := syscall.ParseUnixRights(&scm)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fds = append(fds, gotFds...)
|
||||
|
||||
for _, fd := range fds {
|
||||
if foundFd == -1 {
|
||||
syscall.CloseOnExec(fd)
|
||||
foundFd = fd
|
||||
} else {
|
||||
syscall.Close(fd)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
return foundFd
|
||||
}
|
||||
|
||||
func socketpair() ([2]int, error) {
|
||||
|
|
Loading…
Reference in a new issue