diff --git a/locker/locker_test.go b/locker/locker_test.go index 20b3c14..631aaf3 100644 --- a/locker/locker_test.go +++ b/locker/locker_test.go @@ -1,8 +1,8 @@ package locker import ( - "runtime" "testing" + "time" ) func TestLockCounter(t *testing.T) { @@ -34,7 +34,21 @@ func TestLockerLock(t *testing.T) { close(chDone) }() - runtime.Gosched() + chWaiting := make(chan struct{}) + go func() { + for range time.Tick(1 * time.Millisecond) { + if ctr.count() == 1 { + close(chWaiting) + break + } + } + }() + + select { + case <-chWaiting: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for lock waiters to be incremented") + } select { case <-chDone: @@ -42,25 +56,14 @@ func TestLockerLock(t *testing.T) { default: } - if ctr.count() != 1 { - t.Fatalf("expected waiters to be 1, got: %d", ctr.count()) - } - if err := l.Unlock("test"); err != nil { t.Fatal(err) } - runtime.Gosched() select { case <-chDone: - default: - // one more time just to be sure - runtime.Gosched() - select { - case <-chDone: - default: - t.Fatalf("lock should have completed") - } + case <-time.After(3 * time.Second): + t.Fatalf("lock should have completed") } if ctr.count() != 0 { @@ -80,11 +83,9 @@ func TestLockerUnlock(t *testing.T) { close(chDone) }() - runtime.Gosched() - select { case <-chDone: - default: + case <-time.After(3 * time.Second): t.Fatalf("lock should not be blocked") } }