import pytest

from util.failover import failover, FailoverException


class Counter(object):
    """ Wraps a counter in an object so that it'll be passed by reference. """

    def __init__(self):
        self.calls = 0

    def increment(self):
        """ Records one more call. """
        self.calls += 1


@failover
def my_failover_func(i, should_raise=None):
    """ Increments a counter and raises an exception when told.

    Args:
        i: the Counter instance to increment on every call.
        should_raise: an exception class to instantiate and raise, or None.

    Raises:
        should_raise(): when `should_raise` is not None.
        FailoverException: otherwise. Presumably the @failover decorator
            treats this as "try the next argument set"; see util.failover
            to confirm.
    """
    i.increment()
    if should_raise is not None:
        raise should_raise()
    raise FailoverException(None, 'incrementing')


@pytest.mark.parametrize('stop_on,exception', [
    (10, None),
    (5, IndexError),
])
def test_readonly_failover(stop_on, exception):
    """ Generates failover arguments and checks against a counter to ensure
    that the failover function has been called the proper amount of times and
    stops at unhandled exceptions.

    Args:
        stop_on: how many argument sets to pass, and therefore how many calls
            the counter is expected to record.
        exception: an exception class that the final call should raise and
            that the test expects to propagate, or None.
    """
    counter = Counter()

    # Build one ((args), {kwargs}) pair per attempt. Only the final attempt is
    # told to raise `exception`; every earlier one raises FailoverException.
    arg_sets = []
    for attempt in range(stop_on):  # range (not xrange) runs on Python 2 and 3
        is_last = attempt == stop_on - 1
        should_raise = exception if exception is not None and is_last else None
        arg_sets.append(((counter,), {'should_raise': should_raise}))

    if exception is not None:
        with pytest.raises(exception):
            my_failover_func(*arg_sets)
    else:
        # As written, the test expects this call to return without raising.
        my_failover_func(*arg_sets)

    assert counter.calls == stop_on