diff --git a/pubsub/publisher.go b/pubsub/publisher.go new file mode 100644 index 0000000..98d0356 --- /dev/null +++ b/pubsub/publisher.go @@ -0,0 +1,66 @@ +package pubsub + +import ( + "sync" + "time" +) + +// NewPublisher creates a new pub/sub publisher to broadcast messages. +// The duration is used as the send timeout as to not block the publisher publishing +// messages to other clients if one client is slow or unresponsive. +// The buffer is used when creating new channels for subscribers. +func NewPublisher(publishTimeout time.Duration, buffer int) *Publisher { + return &Publisher{ + buffer: buffer, + timeout: publishTimeout, + subscribers: make(map[subscriber]struct{}), + } +} + +type subscriber chan interface{} + +type Publisher struct { + m sync.RWMutex + buffer int + timeout time.Duration + subscribers map[subscriber]struct{} +} + +// Subscribe adds a new subscriber to the publisher returning the channel. +func (p *Publisher) Subscribe() chan interface{} { + ch := make(chan interface{}, p.buffer) + p.m.Lock() + p.subscribers[ch] = struct{}{} + p.m.Unlock() + return ch +} + +// Evict removes the specified subscriber from receiving any more messages. +func (p *Publisher) Evict(sub chan interface{}) { + p.m.Lock() + delete(p.subscribers, sub) + close(sub) + p.m.Unlock() +} + +// Publish sends the data in v to all subscribers currently registered with the publisher. +func (p *Publisher) Publish(v interface{}) { + p.m.RLock() + for sub := range p.subscribers { + // send under a select as to not block if the receiver is unavailable + select { + case sub <- v: + case <-time.After(p.timeout): + } + } + p.m.RUnlock() +} + +// Close closes the channels to all subscribers registered with the publisher. +func (p *Publisher) Close() { + p.m.Lock() + for sub := range p.subscribers { + close(sub) + } + p.m.Unlock() +} diff --git a/pubsub/publisher_test.go b/pubsub/publisher_test.go new file mode 100644 index 0000000..c19059a --- /dev/null +++ b/pubsub/publisher_test.go @@ -0,0 +1,63 @@ +package pubsub + +import ( + "testing" + "time" +) + +func TestSendToOneSub(t *testing.T) { + p := NewPublisher(100*time.Millisecond, 10) + c := p.Subscribe() + + p.Publish("hi") + + msg := <-c + if msg.(string) != "hi" { + t.Fatalf("expected message hi but received %v", msg) + } +} + +func TestSendToMultipleSubs(t *testing.T) { + p := NewPublisher(100*time.Millisecond, 10) + subs := []chan interface{}{} + subs = append(subs, p.Subscribe(), p.Subscribe(), p.Subscribe()) + + p.Publish("hi") + + for _, c := range subs { + msg := <-c + if msg.(string) != "hi" { + t.Fatalf("expected message hi but received %v", msg) + } + } +} + +func TestEvictOneSub(t *testing.T) { + p := NewPublisher(100*time.Millisecond, 10) + s1 := p.Subscribe() + s2 := p.Subscribe() + + p.Evict(s1) + p.Publish("hi") + if _, ok := <-s1; ok { + t.Fatal("expected s1 to not receive the published message") + } + + msg := <-s2 + if msg.(string) != "hi" { + t.Fatalf("expected message hi but received %v", msg) + } +} + +func TestClosePublisher(t *testing.T) { + p := NewPublisher(100*time.Millisecond, 10) + subs := []chan interface{}{} + subs = append(subs, p.Subscribe(), p.Subscribe(), p.Subscribe()) + p.Close() + + for _, c := range subs { + if _, ok := <-c; ok { + t.Fatal("expected all subscriber channels to be closed") + } + } +}