Merge pull request #20706 from calavera/remove_concurrent_access_to_stdtypes
Make stdcopy.StdWriter thread safe.
This commit is contained in:
commit
fcb2e0b085
2 changed files with 44 additions and 42 deletions
|
@ -3,12 +3,24 @@ package stdcopy
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
"github.com/Sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StdType is the type of standard stream
|
||||||
|
// a writer can multiplex to.
|
||||||
|
type StdType byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// Stdin represents standard input stream type.
|
||||||
|
Stdin StdType = iota
|
||||||
|
// Stdout represents standard output stream type.
|
||||||
|
Stdout
|
||||||
|
// Stderr represents standard error steam type.
|
||||||
|
Stderr
|
||||||
|
|
||||||
stdWriterPrefixLen = 8
|
stdWriterPrefixLen = 8
|
||||||
stdWriterFdIndex = 0
|
stdWriterFdIndex = 0
|
||||||
stdWriterSizeIndex = 4
|
stdWriterSizeIndex = 4
|
||||||
|
@ -16,38 +28,32 @@ const (
|
||||||
startingBufLen = 32*1024 + stdWriterPrefixLen + 1
|
startingBufLen = 32*1024 + stdWriterPrefixLen + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
// StdType prefixes type and length to standard stream.
|
// stdWriter is wrapper of io.Writer with extra customized info.
|
||||||
type StdType [stdWriterPrefixLen]byte
|
type stdWriter struct {
|
||||||
|
|
||||||
var (
|
|
||||||
// Stdin represents standard input stream type.
|
|
||||||
Stdin = StdType{0: 0}
|
|
||||||
// Stdout represents standard output stream type.
|
|
||||||
Stdout = StdType{0: 1}
|
|
||||||
// Stderr represents standard error steam type.
|
|
||||||
Stderr = StdType{0: 2}
|
|
||||||
)
|
|
||||||
|
|
||||||
// StdWriter is wrapper of io.Writer with extra customized info.
|
|
||||||
type StdWriter struct {
|
|
||||||
io.Writer
|
io.Writer
|
||||||
prefix StdType
|
prefix byte
|
||||||
sizeBuf []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StdWriter) Write(buf []byte) (n int, err error) {
|
// Write sends the buffer to the underneath writer.
|
||||||
var n1, n2 int
|
// It insert the prefix header before the buffer,
|
||||||
|
// so stdcopy.StdCopy knows where to multiplex the output.
|
||||||
|
// It makes stdWriter to implement io.Writer.
|
||||||
|
func (w *stdWriter) Write(buf []byte) (n int, err error) {
|
||||||
if w == nil || w.Writer == nil {
|
if w == nil || w.Writer == nil {
|
||||||
return 0, errors.New("Writer not instantiated")
|
return 0, errors.New("Writer not instantiated")
|
||||||
}
|
}
|
||||||
binary.BigEndian.PutUint32(w.prefix[4:], uint32(len(buf)))
|
if buf == nil {
|
||||||
n1, err = w.Writer.Write(w.prefix[:])
|
return 0, nil
|
||||||
if err != nil {
|
|
||||||
n = n1 - stdWriterPrefixLen
|
|
||||||
} else {
|
|
||||||
n2, err = w.Writer.Write(buf)
|
|
||||||
n = n1 + n2 - stdWriterPrefixLen
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
|
||||||
|
binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(buf)))
|
||||||
|
|
||||||
|
line := append(header[:], buf...)
|
||||||
|
|
||||||
|
n, err = w.Writer.Write(line)
|
||||||
|
n -= stdWriterPrefixLen
|
||||||
|
|
||||||
if n < 0 {
|
if n < 0 {
|
||||||
n = 0
|
n = 0
|
||||||
}
|
}
|
||||||
|
@ -60,16 +66,13 @@ func (w *StdWriter) Write(buf []byte) (n int, err error) {
|
||||||
// This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
|
// This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
|
||||||
// `t` indicates the id of the stream to encapsulate.
|
// `t` indicates the id of the stream to encapsulate.
|
||||||
// It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
|
// It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
|
||||||
func NewStdWriter(w io.Writer, t StdType) *StdWriter {
|
func NewStdWriter(w io.Writer, t StdType) io.Writer {
|
||||||
return &StdWriter{
|
return &stdWriter{
|
||||||
Writer: w,
|
Writer: w,
|
||||||
prefix: t,
|
prefix: byte(t),
|
||||||
sizeBuf: make([]byte, 4),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errInvalidStdHeader = errors.New("Unrecognized input header")
|
|
||||||
|
|
||||||
// StdCopy is a modified version of io.Copy.
|
// StdCopy is a modified version of io.Copy.
|
||||||
//
|
//
|
||||||
// StdCopy will demultiplex `src`, assuming that it contains two streams,
|
// StdCopy will demultiplex `src`, assuming that it contains two streams,
|
||||||
|
@ -110,18 +113,18 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the first byte to know where to write
|
// Check the first byte to know where to write
|
||||||
switch buf[stdWriterFdIndex] {
|
switch StdType(buf[stdWriterFdIndex]) {
|
||||||
case 0:
|
case Stdin:
|
||||||
fallthrough
|
fallthrough
|
||||||
case 1:
|
case Stdout:
|
||||||
// Write on stdout
|
// Write on stdout
|
||||||
out = dstout
|
out = dstout
|
||||||
case 2:
|
case Stderr:
|
||||||
// Write on stderr
|
// Write on stderr
|
||||||
out = dsterr
|
out = dsterr
|
||||||
default:
|
default:
|
||||||
logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex])
|
logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex])
|
||||||
return 0, errInvalidStdHeader
|
return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve the size of the frame
|
// Retrieve the size of the frame
|
||||||
|
|
|
@ -17,10 +17,9 @@ func TestNewStdWriter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteWithUnitializedStdWriter(t *testing.T) {
|
func TestWriteWithUnitializedStdWriter(t *testing.T) {
|
||||||
writer := StdWriter{
|
writer := stdWriter{
|
||||||
Writer: nil,
|
Writer: nil,
|
||||||
prefix: Stdout,
|
prefix: byte(Stdout),
|
||||||
sizeBuf: make([]byte, 4),
|
|
||||||
}
|
}
|
||||||
n, err := writer.Write([]byte("Something here"))
|
n, err := writer.Write([]byte("Something here"))
|
||||||
if n != 0 || err == nil {
|
if n != 0 || err == nil {
|
||||||
|
@ -180,7 +179,7 @@ func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
|
||||||
src: buffer}
|
src: buffer}
|
||||||
written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
||||||
if written != startingBufLen {
|
if written != startingBufLen {
|
||||||
t.Fatalf("Expected 0 bytes read, got %d", written)
|
t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Didn't get nil error")
|
t.Fatal("Didn't get nil error")
|
||||||
|
|
Loading…
Reference in a new issue