diff --git a/conmon/conmon.c b/conmon/conmon.c index f766f609..9a1a75b7 100644 --- a/conmon/conmon.c +++ b/conmon/conmon.c @@ -1002,7 +1002,13 @@ int main(int argc, char *argv[]) } } else { num_read = read(masterfd, buf, BUF_SIZE); - if (num_read <= 0) + if (num_read == 0) { + ninfo("Remote socket closed"); + close(conn_sock); + conn_sock = -1; + continue; + } + if (num_read < 0) goto out; ninfo("got data on connection: %d", num_read); if (terminal) { diff --git a/server/container_attach.go b/server/container_attach.go index a94690e0..9d96a56a 100644 --- a/server/container_attach.go +++ b/server/container_attach.go @@ -11,6 +11,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/kubernetes-incubator/cri-o/oci" + "github.com/kubernetes-incubator/cri-o/utils" "golang.org/x/net/context" pb "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" @@ -74,18 +75,22 @@ func (ss streamService) Attach(containerID string, inputStream io.Reader, output }() } - stdinDone := make(chan struct{}) + stdinDone := make(chan error) go func() { + var err error if inputStream != nil { - io.Copy(conn, inputStream) + _, err = utils.CopyDetachable(conn, inputStream, nil) } - close(stdinDone) + stdinDone <- err }() select { case err := <-receiveStdout: return err - case <-stdinDone: + case err := <-stdinDone: + if _, ok := err.(utils.DetachError); ok { + return nil + } if outputStream != nil || errorStream != nil { return <-receiveStdout } diff --git a/utils/utils.go b/utils/utils.go index 1635e4b4..8494c02a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -79,3 +79,63 @@ func newProp(name string, units interface{}) systemdDbus.Property { Value: dbus.MakeVariant(units), } } + +// DetachError is special error which returned in case of container detach. +type DetachError struct{} + +func (DetachError) Error() string { + return "detached from container" +} + +// CopyDetachable is similar to io.Copy but support a detach key sequence to break out. +func CopyDetachable(dst io.Writer, src io.Reader, keys []byte) (written int64, err error) { + if len(keys) == 0 { + // Default keys : ctrl-p ctrl-q + keys = []byte{16, 17} + } + + buf := make([]byte, 32*1024) + for { + nr, er := src.Read(buf) + if nr > 0 { + preservBuf := []byte{} + for i, key := range keys { + preservBuf = append(preservBuf, buf[0:nr]...) + if nr != 1 || buf[0] != key { + break + } + if i == len(keys)-1 { + // src.Close() + return 0, DetachError{} + } + nr, er = src.Read(buf) + } + var nw int + var ew error + if len(preservBuf) > 0 { + nw, ew = dst.Write(preservBuf) + nr = len(preservBuf) + } else { + nw, ew = dst.Write(buf[0:nr]) + } + if nw > 0 { + written += int64(nw) + } + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err +}