diff --git a/stdcopy/stdcopy.go b/stdcopy/stdcopy.go index 9be0184..b37ae39 100644 --- a/stdcopy/stdcopy.go +++ b/stdcopy/stdcopy.go @@ -3,12 +3,24 @@ package stdcopy import ( "encoding/binary" "errors" + "fmt" "io" "github.com/Sirupsen/logrus" ) +// StdType is the type of standard stream +// a writer can multiplex to. +type StdType byte + const ( + // Stdin represents standard input stream type. + Stdin StdType = iota + // Stdout represents standard output stream type. + Stdout + // Stderr represents standard error steam type. + Stderr + stdWriterPrefixLen = 8 stdWriterFdIndex = 0 stdWriterSizeIndex = 4 @@ -16,38 +28,32 @@ const ( startingBufLen = 32*1024 + stdWriterPrefixLen + 1 ) -// StdType prefixes type and length to standard stream. -type StdType [stdWriterPrefixLen]byte - -var ( - // Stdin represents standard input stream type. - Stdin = StdType{0: 0} - // Stdout represents standard output stream type. - Stdout = StdType{0: 1} - // Stderr represents standard error steam type. - Stderr = StdType{0: 2} -) - -// StdWriter is wrapper of io.Writer with extra customized info. -type StdWriter struct { +// stdWriter is wrapper of io.Writer with extra customized info. +type stdWriter struct { io.Writer - prefix StdType - sizeBuf []byte + prefix byte } -func (w *StdWriter) Write(buf []byte) (n int, err error) { - var n1, n2 int +// Write sends the buffer to the underneath writer. +// It insert the prefix header before the buffer, +// so stdcopy.StdCopy knows where to multiplex the output. +// It makes stdWriter to implement io.Writer. +func (w *stdWriter) Write(buf []byte) (n int, err error) { if w == nil || w.Writer == nil { return 0, errors.New("Writer not instantiated") } - binary.BigEndian.PutUint32(w.prefix[4:], uint32(len(buf))) - n1, err = w.Writer.Write(w.prefix[:]) - if err != nil { - n = n1 - stdWriterPrefixLen - } else { - n2, err = w.Writer.Write(buf) - n = n1 + n2 - stdWriterPrefixLen + if buf == nil { + return 0, nil } + + header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix} + binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(buf))) + + line := append(header[:], buf...) + + n, err = w.Writer.Write(line) + n -= stdWriterPrefixLen + if n < 0 { n = 0 } @@ -60,16 +66,13 @@ func (w *StdWriter) Write(buf []byte) (n int, err error) { // This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection. // `t` indicates the id of the stream to encapsulate. // It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr. -func NewStdWriter(w io.Writer, t StdType) *StdWriter { - return &StdWriter{ - Writer: w, - prefix: t, - sizeBuf: make([]byte, 4), +func NewStdWriter(w io.Writer, t StdType) io.Writer { + return &stdWriter{ + Writer: w, + prefix: byte(t), } } -var errInvalidStdHeader = errors.New("Unrecognized input header") - // StdCopy is a modified version of io.Copy. // // StdCopy will demultiplex `src`, assuming that it contains two streams, @@ -110,18 +113,18 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error) } // Check the first byte to know where to write - switch buf[stdWriterFdIndex] { - case 0: + switch StdType(buf[stdWriterFdIndex]) { + case Stdin: fallthrough - case 1: + case Stdout: // Write on stdout out = dstout - case 2: + case Stderr: // Write on stderr out = dsterr default: logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex]) - return 0, errInvalidStdHeader + return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex]) } // Retrieve the size of the frame diff --git a/stdcopy/stdcopy_test.go b/stdcopy/stdcopy_test.go index 796d165..3137a75 100644 --- a/stdcopy/stdcopy_test.go +++ b/stdcopy/stdcopy_test.go @@ -17,10 +17,9 @@ func TestNewStdWriter(t *testing.T) { } func TestWriteWithUnitializedStdWriter(t *testing.T) { - writer := StdWriter{ - Writer: nil, - prefix: Stdout, - sizeBuf: make([]byte, 4), + writer := stdWriter{ + Writer: nil, + prefix: byte(Stdout), } n, err := writer.Write([]byte("Something here")) if n != 0 || err == nil { @@ -180,7 +179,7 @@ func TestStdCopyDetectsCorruptedFrame(t *testing.T) { src: buffer} written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader) if written != startingBufLen { - t.Fatalf("Expected 0 bytes read, got %d", written) + t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written) } if err != nil { t.Fatal("Didn't get nil error")