From 3a41b3f1cecb0fe38c6b64221bc1f5703323a8ae Mon Sep 17 00:00:00 2001 From: Andrew Gerrand Date: Fri, 21 Aug 2015 13:51:28 +0100 Subject: [PATCH] Use a slice instead of a map of io.WriteClosers in broadcastwriter Maps rely on the keys being comparable. Using an interface type as the map key is dangerous, because some interface types are not comparable. I talked about this in my "Stupid Gopher Tricks" talk: https://talks.golang.org/2015/tricks.slide In this case, if the user-provided writer is backed by a slice (such as io.MultiWriter) then the code will panic at run time. Signed-off-by: Andrew Gerrand --- broadcastwriter/broadcastwriter.go | 36 +++++++++++++------------ broadcastwriter/broadcastwriter_test.go | 18 +++++++++++++ 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/broadcastwriter/broadcastwriter.go b/broadcastwriter/broadcastwriter.go index b3bfa99..e49810c 100644 --- a/broadcastwriter/broadcastwriter.go +++ b/broadcastwriter/broadcastwriter.go @@ -7,46 +7,48 @@ import ( // BroadcastWriter accumulate multiple io.WriteCloser by stream. type BroadcastWriter struct { - sync.Mutex - writers map[io.WriteCloser]struct{} + mu sync.Mutex + writers []io.WriteCloser } // AddWriter adds new io.WriteCloser. func (w *BroadcastWriter) AddWriter(writer io.WriteCloser) { - w.Lock() - w.writers[writer] = struct{}{} - w.Unlock() + w.mu.Lock() + w.writers = append(w.writers, writer) + w.mu.Unlock() } // Write writes bytes to all writers. Failed writers will be evicted during // this call. func (w *BroadcastWriter) Write(p []byte) (n int, err error) { - w.Lock() - for sw := range w.writers { + w.mu.Lock() + var evict []int + for i, sw := range w.writers { if n, err := sw.Write(p); err != nil || n != len(p) { // On error, evict the writer - delete(w.writers, sw) + evict = append(evict, i) } } - w.Unlock() + for n, i := range evict { + w.writers = append(w.writers[:i-n], w.writers[i-n+1:]...) + } + w.mu.Unlock() return len(p), nil } // Clean closes and removes all writers. Last non-eol-terminated part of data // will be saved. func (w *BroadcastWriter) Clean() error { - w.Lock() - for w := range w.writers { - w.Close() + w.mu.Lock() + for _, sw := range w.writers { + sw.Close() } - w.writers = make(map[io.WriteCloser]struct{}) - w.Unlock() + w.writers = nil + w.mu.Unlock() return nil } // New creates a new BroadcastWriter. func New() *BroadcastWriter { - return &BroadcastWriter{ - writers: make(map[io.WriteCloser]struct{}), - } + return &BroadcastWriter{} } diff --git a/broadcastwriter/broadcastwriter_test.go b/broadcastwriter/broadcastwriter_test.go index bc24320..1ff4cae 100644 --- a/broadcastwriter/broadcastwriter_test.go +++ b/broadcastwriter/broadcastwriter_test.go @@ -3,6 +3,7 @@ package broadcastwriter import ( "bytes" "errors" + "strings" "testing" ) @@ -82,6 +83,23 @@ func TestBroadcastWriter(t *testing.T) { t.Errorf("Buffer contains %v", bufferC.String()) } + // Test4: Test eviction on multiple simultaneous failures + bufferB.failOnWrite = true + bufferC.failOnWrite = true + bufferD := &dummyWriter{} + writer.AddWriter(bufferD) + writer.Write([]byte("yo")) + writer.Write([]byte("ink")) + if strings.Contains(bufferB.String(), "yoink") { + t.Errorf("bufferB received write. contents: %q", bufferB) + } + if strings.Contains(bufferC.String(), "yoink") { + t.Errorf("bufferC received write. contents: %q", bufferC) + } + if g, w := bufferD.String(), "yoink"; g != w { + t.Errorf("bufferD = %q, want %q", g, w) + } + writer.Clean() }