diff --git a/util/failover.py b/util/failover.py index 7e75d2140..c88d47a6f 100644 --- a/util/failover.py +++ b/util/failover.py @@ -7,11 +7,12 @@ logger = logging.getLogger(__name__) class FailoverException(Exception): - """ Exception raised when an operation should be retried by the failover decorator. """ - def __init__(self, return_value, message): + """ Exception raised when an operation should be retried by the failover decorator. + Wraps the exception of the initial failure. + """ + def __init__(self, exception): super(FailoverException, self).__init__() - self.return_value = return_value - self.message = message + self.exception = exception def failover(func): """ Wraps a function such that it can be retried on specified failures. @@ -21,9 +22,10 @@ def failover(func): @failover def get_google(scheme, use_www=False): www = 'www.' if use_www else '' - r = requests.get(scheme + '://' + www + 'google.com') - if r.status_code != 200: - raise FailoverException('non 200 response from Google' ) + try: + r = requests.get(scheme + '://' + www + 'google.com') + except requests.RequestException as ex: + raise FailoverException(ex) return r def GooglePingTest(): @@ -41,8 +43,8 @@ def failover(func): try: return func(*arg_set[0], **arg_set[1]) except FailoverException as ex: - logger.debug('failing over: %s', ex.message) - return_value = ex.return_value + logger.debug('failing over') + exception = ex.exception continue - return return_value + raise exception return wrapper diff --git a/util/test/test_failover.py b/util/test/test_failover.py index 369ddea3a..333c39362 100644 --- a/util/test/test_failover.py +++ b/util/test/test_failover.py @@ -2,6 +2,9 @@ import pytest from util.failover import failover, FailoverException +class FinishedException(Exception): + """ Exception raised at the end of every iteration to force failover. """ + class Counter(object): """ Wraps a counter in an object so that it'll be passed by reference. """ @@ -18,7 +21,7 @@ def my_failover_func(i, should_raise=None): i.increment() if should_raise is not None: raise should_raise() - raise FailoverException(None, 'incrementing') + raise FailoverException(FinishedException()) @pytest.mark.parametrize('stop_on,exception', [ @@ -40,5 +43,6 @@ def test_readonly_failover(stop_on, exception): with pytest.raises(exception): my_failover_func(*arg_sets) else: - my_failover_func(*arg_sets) - assert counter.calls == stop_on + with pytest.raises(FinishedException): + my_failover_func(*arg_sets) + assert counter.calls == stop_on