diff --git a/chanotify/chanotify.go b/chanotify/chanotify.go index a317fa9..f996421 100644 --- a/chanotify/chanotify.go +++ b/chanotify/chanotify.go @@ -1,6 +1,7 @@ package chanotify import ( + "errors" "sync" ) @@ -33,23 +34,33 @@ func (n *Notifier) Chan() <-chan interface{} { return n.c } +// Add adds new notification channel to Notifier. +func (n *Notifier) Add(id interface{}, ch <-chan struct{}) error { + n.m.Lock() + defer n.m.Unlock() + + if n.closed { + return errors.New("notifier closed; cannot add the channel on the notifier") + } + if _, ok := n.doneCh[id]; ok { + return errors.New("cannot register duplicate key") + } + + done := make(chan struct{}) + n.doneCh[id] = done + + n.startWorker(ch, id, done) + return nil +} + func (n *Notifier) killWorker(id interface{}, done chan struct{}) { n.m.Lock() delete(n.doneCh, id) n.m.Unlock() } -// Add adds new notification channel to Notifier. -func (n *Notifier) Add(id interface{}, ch <-chan struct{}) { - done := make(chan struct{}) - n.m.Lock() - if n.closed { - panic("notifier closed; cannot add the channel") - } - n.doneCh[id] = done - n.m.Unlock() - - go func(ch <-chan struct{}, id interface{}, done chan struct{}) { +func (n *Notifier) startWorker(ch <-chan struct{}, id interface{}, done chan struct{}) { + go func() { for { select { case _, ok := <-ch: @@ -66,7 +77,7 @@ func (n *Notifier) Add(id interface{}, ch <-chan struct{}) { return } } - }(ch, id, done) + }() } // Close closes the notifier and releases its underlying resources. diff --git a/chanotify/chanotify_test.go b/chanotify/chanotify_test.go index 9a9e409..ef770bb 100644 --- a/chanotify/chanotify_test.go +++ b/chanotify/chanotify_test.go @@ -13,8 +13,12 @@ func TestNotifier(t *testing.T) { id1 := "1" id2 := "2" - s.Add(id1, ch1) - s.Add(id2, ch2) + if err := s.Add(id1, ch1); err != nil { + t.Fatal(err) + } + if err := s.Add(id2, ch2); err != nil { + t.Fatal(err) + } s.m.Lock() if len(s.doneCh) != 2 { t.Fatalf("want 2 channels, got %d", len(s.doneCh)) @@ -43,7 +47,9 @@ func TestConcurrentNotifier(t *testing.T) { var chs []chan struct{} for i := 0; i < 8; i++ { ch := make(chan struct{}, 2) - s.Add(i, ch) + if err := s.Add(i, ch); err != nil { + t.Fatal(err) + } chs = append(chs, ch) } testCounter := make(map[interface{}]int) @@ -86,10 +92,26 @@ func TestAddToBlocked(t *testing.T) { go func() { // give some time to start first select time.Sleep(1 * time.Second) - s.Add(id, ch) + if err := s.Add(id, ch); err != nil { + t.Fatal(err) + } ch <- struct{}{} }() if got, want := <-s.Chan(), id; got != want { t.Fatalf("got %v; want %v", got, want) } } + +func TestAddDuplicate(t *testing.T) { + s := New() + ch1 := make(chan struct{}, 1) + ch2 := make(chan struct{}, 1) + + if err := s.Add(1, ch1); err != nil { + t.Fatalf("cannot add; err = %v", err) + } + + if err := s.Add(1, ch2); err == nil { + t.Fatalf("duplicate keys are not allowed; but Add succeeded") + } +}