Use a slice instead of a map of io.WriteClosers in broadcastwriter

Maps rely on the keys being comparable.
Using an interface type as the map key is dangerous,
because some interface types are not comparable.
I talked about this in my "Stupid Gopher Tricks" talk:
	https://talks.golang.org/2015/tricks.slide

In this case, if the user-provided writer is backed by a slice
(such as io.MultiWriter) then the code will panic at run time.

Signed-off-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
Andrew Gerrand 2015-08-21 13:51:28 +01:00
parent 82a251318a
commit 3a41b3f1ce
2 changed files with 37 additions and 17 deletions

View file

@ -7,46 +7,48 @@ import (
// BroadcastWriter accumulate multiple io.WriteCloser by stream.
type BroadcastWriter struct {
sync.Mutex
writers map[io.WriteCloser]struct{}
mu sync.Mutex
writers []io.WriteCloser
}
// AddWriter adds new io.WriteCloser.
func (w *BroadcastWriter) AddWriter(writer io.WriteCloser) {
w.Lock()
w.writers[writer] = struct{}{}
w.Unlock()
w.mu.Lock()
w.writers = append(w.writers, writer)
w.mu.Unlock()
}
// Write writes bytes to all writers. Failed writers will be evicted during
// this call.
func (w *BroadcastWriter) Write(p []byte) (n int, err error) {
w.Lock()
for sw := range w.writers {
w.mu.Lock()
var evict []int
for i, sw := range w.writers {
if n, err := sw.Write(p); err != nil || n != len(p) {
// On error, evict the writer
delete(w.writers, sw)
evict = append(evict, i)
}
}
w.Unlock()
for n, i := range evict {
w.writers = append(w.writers[:i-n], w.writers[i-n+1:]...)
}
w.mu.Unlock()
return len(p), nil
}
// Clean closes and removes all writers. Last non-eol-terminated part of data
// will be saved.
func (w *BroadcastWriter) Clean() error {
w.Lock()
for w := range w.writers {
w.Close()
w.mu.Lock()
for _, sw := range w.writers {
sw.Close()
}
w.writers = make(map[io.WriteCloser]struct{})
w.Unlock()
w.writers = nil
w.mu.Unlock()
return nil
}
// New creates a new BroadcastWriter.
func New() *BroadcastWriter {
return &BroadcastWriter{
writers: make(map[io.WriteCloser]struct{}),
}
return &BroadcastWriter{}
}

View file

@ -3,6 +3,7 @@ package broadcastwriter
import (
"bytes"
"errors"
"strings"
"testing"
)
@ -82,6 +83,23 @@ func TestBroadcastWriter(t *testing.T) {
t.Errorf("Buffer contains %v", bufferC.String())
}
// Test4: Test eviction on multiple simultaneous failures
bufferB.failOnWrite = true
bufferC.failOnWrite = true
bufferD := &dummyWriter{}
writer.AddWriter(bufferD)
writer.Write([]byte("yo"))
writer.Write([]byte("ink"))
if strings.Contains(bufferB.String(), "yoink") {
t.Errorf("bufferB received write. contents: %q", bufferB)
}
if strings.Contains(bufferC.String(), "yoink") {
t.Errorf("bufferC received write. contents: %q", bufferC)
}
if g, w := bufferD.String(), "yoink"; g != w {
t.Errorf("bufferD = %q, want %q", g, w)
}
writer.Clean()
}