Use a slice instead of a map of io.WriteClosers in broadcastwriter

Maps rely on the keys being comparable.
Using an interface type as the map key is dangerous,
because some interface types are not comparable.
I talked about this in my "Stupid Gopher Tricks" talk:
	https://talks.golang.org/2015/tricks.slide

In this case, if the user-provided writer is backed by a slice
(such as io.MultiWriter) then the code will panic at run time.

Signed-off-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
Andrew Gerrand 2015-08-21 13:51:28 +01:00
parent 82a251318a
commit 3a41b3f1ce
2 changed files with 37 additions and 17 deletions

View file

@ -7,46 +7,48 @@ import (
// BroadcastWriter accumulate multiple io.WriteCloser by stream. // BroadcastWriter accumulate multiple io.WriteCloser by stream.
type BroadcastWriter struct { type BroadcastWriter struct {
sync.Mutex mu sync.Mutex
writers map[io.WriteCloser]struct{} writers []io.WriteCloser
} }
// AddWriter adds new io.WriteCloser. // AddWriter adds new io.WriteCloser.
func (w *BroadcastWriter) AddWriter(writer io.WriteCloser) { func (w *BroadcastWriter) AddWriter(writer io.WriteCloser) {
w.Lock() w.mu.Lock()
w.writers[writer] = struct{}{} w.writers = append(w.writers, writer)
w.Unlock() w.mu.Unlock()
} }
// Write writes bytes to all writers. Failed writers will be evicted during // Write writes bytes to all writers. Failed writers will be evicted during
// this call. // this call.
func (w *BroadcastWriter) Write(p []byte) (n int, err error) { func (w *BroadcastWriter) Write(p []byte) (n int, err error) {
w.Lock() w.mu.Lock()
for sw := range w.writers { var evict []int
for i, sw := range w.writers {
if n, err := sw.Write(p); err != nil || n != len(p) { if n, err := sw.Write(p); err != nil || n != len(p) {
// On error, evict the writer // On error, evict the writer
delete(w.writers, sw) evict = append(evict, i)
} }
} }
w.Unlock() for n, i := range evict {
w.writers = append(w.writers[:i-n], w.writers[i-n+1:]...)
}
w.mu.Unlock()
return len(p), nil return len(p), nil
} }
// Clean closes and removes all writers. Last non-eol-terminated part of data // Clean closes and removes all writers. Last non-eol-terminated part of data
// will be saved. // will be saved.
func (w *BroadcastWriter) Clean() error { func (w *BroadcastWriter) Clean() error {
w.Lock() w.mu.Lock()
for w := range w.writers { for _, sw := range w.writers {
w.Close() sw.Close()
} }
w.writers = make(map[io.WriteCloser]struct{}) w.writers = nil
w.Unlock() w.mu.Unlock()
return nil return nil
} }
// New creates a new BroadcastWriter. // New creates a new BroadcastWriter.
func New() *BroadcastWriter { func New() *BroadcastWriter {
return &BroadcastWriter{ return &BroadcastWriter{}
writers: make(map[io.WriteCloser]struct{}),
}
} }

View file

@ -3,6 +3,7 @@ package broadcastwriter
import ( import (
"bytes" "bytes"
"errors" "errors"
"strings"
"testing" "testing"
) )
@ -82,6 +83,23 @@ func TestBroadcastWriter(t *testing.T) {
t.Errorf("Buffer contains %v", bufferC.String()) t.Errorf("Buffer contains %v", bufferC.String())
} }
// Test4: Test eviction on multiple simultaneous failures
bufferB.failOnWrite = true
bufferC.failOnWrite = true
bufferD := &dummyWriter{}
writer.AddWriter(bufferD)
writer.Write([]byte("yo"))
writer.Write([]byte("ink"))
if strings.Contains(bufferB.String(), "yoink") {
t.Errorf("bufferB received write. contents: %q", bufferB)
}
if strings.Contains(bufferC.String(), "yoink") {
t.Errorf("bufferC received write. contents: %q", bufferC)
}
if g, w := bufferD.String(), "yoink"; g != w {
t.Errorf("bufferD = %q, want %q", g, w)
}
writer.Clean() writer.Clean()
} }