Make stdcopy.stdWriter goroutine safe.
Stop using global variables as prefixes to inject the writer header. That can cause issues when two writers set the length of the buffer in the same header concurrently. Stop Writing to the internal buffer twice for each write. This could mess up with the ordering information is written. Signed-off-by: David Calavera <david.calavera@gmail.com>
This commit is contained in:
parent
29aeaf7880
commit
f55298771e
2 changed files with 44 additions and 42 deletions
|
@ -3,12 +3,24 @@ package stdcopy
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
"github.com/Sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StdType is the type of standard stream
|
||||||
|
// a writer can multiplex to.
|
||||||
|
type StdType byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// Stdin represents standard input stream type.
|
||||||
|
Stdin StdType = iota
|
||||||
|
// Stdout represents standard output stream type.
|
||||||
|
Stdout
|
||||||
|
// Stderr represents standard error steam type.
|
||||||
|
Stderr
|
||||||
|
|
||||||
stdWriterPrefixLen = 8
|
stdWriterPrefixLen = 8
|
||||||
stdWriterFdIndex = 0
|
stdWriterFdIndex = 0
|
||||||
stdWriterSizeIndex = 4
|
stdWriterSizeIndex = 4
|
||||||
|
@ -16,38 +28,32 @@ const (
|
||||||
startingBufLen = 32*1024 + stdWriterPrefixLen + 1
|
startingBufLen = 32*1024 + stdWriterPrefixLen + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
// StdType prefixes type and length to standard stream.
|
// stdWriter is wrapper of io.Writer with extra customized info.
|
||||||
type StdType [stdWriterPrefixLen]byte
|
type stdWriter struct {
|
||||||
|
|
||||||
var (
|
|
||||||
// Stdin represents standard input stream type.
|
|
||||||
Stdin = StdType{0: 0}
|
|
||||||
// Stdout represents standard output stream type.
|
|
||||||
Stdout = StdType{0: 1}
|
|
||||||
// Stderr represents standard error steam type.
|
|
||||||
Stderr = StdType{0: 2}
|
|
||||||
)
|
|
||||||
|
|
||||||
// StdWriter is wrapper of io.Writer with extra customized info.
|
|
||||||
type StdWriter struct {
|
|
||||||
io.Writer
|
io.Writer
|
||||||
prefix StdType
|
prefix byte
|
||||||
sizeBuf []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StdWriter) Write(buf []byte) (n int, err error) {
|
// Write sends the buffer to the underneath writer.
|
||||||
var n1, n2 int
|
// It insert the prefix header before the buffer,
|
||||||
|
// so stdcopy.StdCopy knows where to multiplex the output.
|
||||||
|
// It makes stdWriter to implement io.Writer.
|
||||||
|
func (w *stdWriter) Write(buf []byte) (n int, err error) {
|
||||||
if w == nil || w.Writer == nil {
|
if w == nil || w.Writer == nil {
|
||||||
return 0, errors.New("Writer not instantiated")
|
return 0, errors.New("Writer not instantiated")
|
||||||
}
|
}
|
||||||
binary.BigEndian.PutUint32(w.prefix[4:], uint32(len(buf)))
|
if buf == nil {
|
||||||
n1, err = w.Writer.Write(w.prefix[:])
|
return 0, nil
|
||||||
if err != nil {
|
|
||||||
n = n1 - stdWriterPrefixLen
|
|
||||||
} else {
|
|
||||||
n2, err = w.Writer.Write(buf)
|
|
||||||
n = n1 + n2 - stdWriterPrefixLen
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
|
||||||
|
binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(buf)))
|
||||||
|
|
||||||
|
line := append(header[:], buf...)
|
||||||
|
|
||||||
|
n, err = w.Writer.Write(line)
|
||||||
|
n -= stdWriterPrefixLen
|
||||||
|
|
||||||
if n < 0 {
|
if n < 0 {
|
||||||
n = 0
|
n = 0
|
||||||
}
|
}
|
||||||
|
@ -60,16 +66,13 @@ func (w *StdWriter) Write(buf []byte) (n int, err error) {
|
||||||
// This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
|
// This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
|
||||||
// `t` indicates the id of the stream to encapsulate.
|
// `t` indicates the id of the stream to encapsulate.
|
||||||
// It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
|
// It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
|
||||||
func NewStdWriter(w io.Writer, t StdType) *StdWriter {
|
func NewStdWriter(w io.Writer, t StdType) io.Writer {
|
||||||
return &StdWriter{
|
return &stdWriter{
|
||||||
Writer: w,
|
Writer: w,
|
||||||
prefix: t,
|
prefix: byte(t),
|
||||||
sizeBuf: make([]byte, 4),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errInvalidStdHeader = errors.New("Unrecognized input header")
|
|
||||||
|
|
||||||
// StdCopy is a modified version of io.Copy.
|
// StdCopy is a modified version of io.Copy.
|
||||||
//
|
//
|
||||||
// StdCopy will demultiplex `src`, assuming that it contains two streams,
|
// StdCopy will demultiplex `src`, assuming that it contains two streams,
|
||||||
|
@ -110,18 +113,18 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the first byte to know where to write
|
// Check the first byte to know where to write
|
||||||
switch buf[stdWriterFdIndex] {
|
switch StdType(buf[stdWriterFdIndex]) {
|
||||||
case 0:
|
case Stdin:
|
||||||
fallthrough
|
fallthrough
|
||||||
case 1:
|
case Stdout:
|
||||||
// Write on stdout
|
// Write on stdout
|
||||||
out = dstout
|
out = dstout
|
||||||
case 2:
|
case Stderr:
|
||||||
// Write on stderr
|
// Write on stderr
|
||||||
out = dsterr
|
out = dsterr
|
||||||
default:
|
default:
|
||||||
logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex])
|
logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex])
|
||||||
return 0, errInvalidStdHeader
|
return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve the size of the frame
|
// Retrieve the size of the frame
|
||||||
|
|
|
@ -17,10 +17,9 @@ func TestNewStdWriter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteWithUnitializedStdWriter(t *testing.T) {
|
func TestWriteWithUnitializedStdWriter(t *testing.T) {
|
||||||
writer := StdWriter{
|
writer := stdWriter{
|
||||||
Writer: nil,
|
Writer: nil,
|
||||||
prefix: Stdout,
|
prefix: byte(Stdout),
|
||||||
sizeBuf: make([]byte, 4),
|
|
||||||
}
|
}
|
||||||
n, err := writer.Write([]byte("Something here"))
|
n, err := writer.Write([]byte("Something here"))
|
||||||
if n != 0 || err == nil {
|
if n != 0 || err == nil {
|
||||||
|
@ -180,7 +179,7 @@ func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
|
||||||
src: buffer}
|
src: buffer}
|
||||||
written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
||||||
if written != startingBufLen {
|
if written != startingBufLen {
|
||||||
t.Fatalf("Expected 0 bytes read, got %d", written)
|
t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Didn't get nil error")
|
t.Fatal("Didn't get nil error")
|
||||||
|
|
Loading…
Reference in a new issue