Use a slice instead of a map of io.WriteClosers in broadcastwriter
Maps rely on the keys being comparable. Using an interface type as the map key is dangerous, because some interface types are not comparable. I talked about this in my "Stupid Gopher Tricks" talk: https://talks.golang.org/2015/tricks.slide In this case, if the user-provided writer is backed by a slice (such as io.MultiWriter) then the code will panic at run time. Signed-off-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
parent
82a251318a
commit
3a41b3f1ce
2 changed files with 37 additions and 17 deletions
|
@ -7,46 +7,48 @@ import (
|
||||||
|
|
||||||
// BroadcastWriter accumulate multiple io.WriteCloser by stream.
|
// BroadcastWriter accumulate multiple io.WriteCloser by stream.
|
||||||
type BroadcastWriter struct {
|
type BroadcastWriter struct {
|
||||||
sync.Mutex
|
mu sync.Mutex
|
||||||
writers map[io.WriteCloser]struct{}
|
writers []io.WriteCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddWriter adds new io.WriteCloser.
|
// AddWriter adds new io.WriteCloser.
|
||||||
func (w *BroadcastWriter) AddWriter(writer io.WriteCloser) {
|
func (w *BroadcastWriter) AddWriter(writer io.WriteCloser) {
|
||||||
w.Lock()
|
w.mu.Lock()
|
||||||
w.writers[writer] = struct{}{}
|
w.writers = append(w.writers, writer)
|
||||||
w.Unlock()
|
w.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes bytes to all writers. Failed writers will be evicted during
|
// Write writes bytes to all writers. Failed writers will be evicted during
|
||||||
// this call.
|
// this call.
|
||||||
func (w *BroadcastWriter) Write(p []byte) (n int, err error) {
|
func (w *BroadcastWriter) Write(p []byte) (n int, err error) {
|
||||||
w.Lock()
|
w.mu.Lock()
|
||||||
for sw := range w.writers {
|
var evict []int
|
||||||
|
for i, sw := range w.writers {
|
||||||
if n, err := sw.Write(p); err != nil || n != len(p) {
|
if n, err := sw.Write(p); err != nil || n != len(p) {
|
||||||
// On error, evict the writer
|
// On error, evict the writer
|
||||||
delete(w.writers, sw)
|
evict = append(evict, i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.Unlock()
|
for n, i := range evict {
|
||||||
|
w.writers = append(w.writers[:i-n], w.writers[i-n+1:]...)
|
||||||
|
}
|
||||||
|
w.mu.Unlock()
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean closes and removes all writers. Last non-eol-terminated part of data
|
// Clean closes and removes all writers. Last non-eol-terminated part of data
|
||||||
// will be saved.
|
// will be saved.
|
||||||
func (w *BroadcastWriter) Clean() error {
|
func (w *BroadcastWriter) Clean() error {
|
||||||
w.Lock()
|
w.mu.Lock()
|
||||||
for w := range w.writers {
|
for _, sw := range w.writers {
|
||||||
w.Close()
|
sw.Close()
|
||||||
}
|
}
|
||||||
w.writers = make(map[io.WriteCloser]struct{})
|
w.writers = nil
|
||||||
w.Unlock()
|
w.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new BroadcastWriter.
|
// New creates a new BroadcastWriter.
|
||||||
func New() *BroadcastWriter {
|
func New() *BroadcastWriter {
|
||||||
return &BroadcastWriter{
|
return &BroadcastWriter{}
|
||||||
writers: make(map[io.WriteCloser]struct{}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package broadcastwriter
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -82,6 +83,23 @@ func TestBroadcastWriter(t *testing.T) {
|
||||||
t.Errorf("Buffer contains %v", bufferC.String())
|
t.Errorf("Buffer contains %v", bufferC.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test4: Test eviction on multiple simultaneous failures
|
||||||
|
bufferB.failOnWrite = true
|
||||||
|
bufferC.failOnWrite = true
|
||||||
|
bufferD := &dummyWriter{}
|
||||||
|
writer.AddWriter(bufferD)
|
||||||
|
writer.Write([]byte("yo"))
|
||||||
|
writer.Write([]byte("ink"))
|
||||||
|
if strings.Contains(bufferB.String(), "yoink") {
|
||||||
|
t.Errorf("bufferB received write. contents: %q", bufferB)
|
||||||
|
}
|
||||||
|
if strings.Contains(bufferC.String(), "yoink") {
|
||||||
|
t.Errorf("bufferC received write. contents: %q", bufferC)
|
||||||
|
}
|
||||||
|
if g, w := bufferD.String(), "yoink"; g != w {
|
||||||
|
t.Errorf("bufferD = %q, want %q", g, w)
|
||||||
|
}
|
||||||
|
|
||||||
writer.Clean()
|
writer.Clean()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue