diff --git a/stdcopy/stdcopy.go b/stdcopy/stdcopy.go index 89c64ba..9be0184 100644 --- a/stdcopy/stdcopy.go +++ b/stdcopy/stdcopy.go @@ -12,6 +12,8 @@ const ( stdWriterPrefixLen = 8 stdWriterFdIndex = 0 stdWriterSizeIndex = 4 + + startingBufLen = 32*1024 + stdWriterPrefixLen + 1 ) // StdType prefixes type and length to standard stream. @@ -80,7 +82,7 @@ var errInvalidStdHeader = errors.New("Unrecognized input header") // `written` will hold the total number of bytes written to `dstout` and `dsterr`. func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error) { var ( - buf = make([]byte, 32*1024+stdWriterPrefixLen+1) + buf = make([]byte, startingBufLen) bufLen = len(buf) nr, nw int er, ew error diff --git a/stdcopy/stdcopy_test.go b/stdcopy/stdcopy_test.go index c9964aa..7f0d2ef 100644 --- a/stdcopy/stdcopy_test.go +++ b/stdcopy/stdcopy_test.go @@ -85,6 +85,30 @@ func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) { } } +func TestStdCopyWriteAndRead(t *testing.T) { + buffer := new(bytes.Buffer) + stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) + dstOut := NewStdWriter(buffer, Stdout) + _, err := dstOut.Write(stdOutBytes) + if err != nil { + t.Fatal(err) + } + stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) + dstErr := NewStdWriter(buffer, Stderr) + _, err = dstErr.Write(stdErrBytes) + if err != nil { + t.Fatal(err) + } + written, err := StdCopy(ioutil.Discard, ioutil.Discard, buffer) + if err != nil { + t.Fatal(err) + } + expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes) + if written != int64(expectedTotalWritten) { + t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written) + } +} + func TestStdCopyWithInvalidInputHeader(t *testing.T) { dstOut := NewStdWriter(ioutil.Discard, Stdout) dstErr := NewStdWriter(ioutil.Discard, Stderr)