diff --git a/stdcopy/stdcopy_test.go b/stdcopy/stdcopy_test.go index 7351ce0..c9964aa 100644 --- a/stdcopy/stdcopy_test.go +++ b/stdcopy/stdcopy_test.go @@ -2,6 +2,7 @@ package stdcopy import ( "bytes" + "errors" "io/ioutil" "strings" "testing" @@ -49,6 +50,41 @@ func TestWrite(t *testing.T) { } } +type errWriter struct { + n int + err error +} + +func (f *errWriter) Write(buf []byte) (int, error) { + return f.n, f.err +} + +func TestWriteWithWriterError(t *testing.T) { + expectedError := errors.New("expected") + expectedReturnedBytes := 10 + writer := NewStdWriter(&errWriter{ + n: stdWriterPrefixLen + expectedReturnedBytes, + err: expectedError}, Stdout) + data := []byte("This won't get written, sigh") + n, err := writer.Write(data) + if err != expectedError { + t.Fatalf("Didn't get expected error.") + } + if n != expectedReturnedBytes { + t.Fatalf("Didn't get expected writen bytes %d, got %d.", + expectedReturnedBytes, n) + } +} + +func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) { + writer := NewStdWriter(&errWriter{n: -1}, Stdout) + data := []byte("This won't get written, sigh") + actual, _ := writer.Write(data) + if actual != 0 { + t.Fatalf("Expected returned written bytes equal to 0, got %d", actual) + } +} + func TestStdCopyWithInvalidInputHeader(t *testing.T) { dstOut := NewStdWriter(ioutil.Discard, Stdout) dstErr := NewStdWriter(ioutil.Discard, Stderr)