diff --git a/util/failover.py b/util/failover.py index c1490b772..7e75d2140 100644 --- a/util/failover.py +++ b/util/failover.py @@ -8,8 +8,9 @@ logger = logging.getLogger(__name__) class FailoverException(Exception): """ Exception raised when an operation should be retried by the failover decorator. """ - def __init__(self, message): + def __init__(self, return_value, message): super(FailoverException, self).__init__() + self.return_value = return_value self.message = message def failover(func): @@ -41,6 +42,7 @@ def failover(func): return func(*arg_set[0], **arg_set[1]) except FailoverException as ex: logger.debug('failing over: %s', ex.message) + return_value = ex.return_value continue - raise FailoverException('exhausted all possible failovers') + return return_value return wrapper diff --git a/util/secscan/api.py b/util/secscan/api.py index 9bd0d8a49..5c242f46a 100644 --- a/util/secscan/api.py +++ b/util/secscan/api.py @@ -334,7 +334,7 @@ class SecurityScannerAPI(object): if self._config is None: raise Exception('Cannot call unconfigured security system') - timeout = self._config['SECURITY_SCANNER_API_TIMEOUT_SECONDS'] + timeout = self._config.get('SECURITY_SCANNER_API_TIMEOUT_SECONDS', 1) endpoint = self._config['SECURITY_SCANNER_ENDPOINT'] with CloseForLongOperation(self._config): @@ -346,12 +346,12 @@ class SecurityScannerAPI(object): return self._request(method, endpoint, path, body, params, timeout) # The request is read-only and can failover. - all_endpoints = [endpoint] + self._config['SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS'] + all_endpoints = [endpoint] + self._config.get('SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS', []) try: - return _failover_read_request(*[((self._request, endpoint, path, body, params, timeout), {}) + return _failover_read_request(*[((self._request, endpoint, path, body, params, timeout), {}) for endpoint in all_endpoints]) - except FailoverException: - raise APIRequestFailure() + except FailoverException as ex: + return ex.return_value def _join_api_url(endpoint, api_version, path): @@ -364,5 +364,5 @@ def _failover_read_request(request_fn, endpoint, path, body, params, timeout): """ This function auto-retries read-only requests until they return a 2xx status code. """ resp = request_fn('GET', endpoint, path, body, params, timeout) if resp.status_code / 100 != 2: - raise FailoverException('status code was not 2xx') + raise FailoverException(resp, 'status code was not 2xx') return resp diff --git a/util/test/test_failover.py b/util/test/test_failover.py index 340092179..369ddea3a 100644 --- a/util/test/test_failover.py +++ b/util/test/test_failover.py @@ -18,7 +18,7 @@ def my_failover_func(i, should_raise=None): i.increment() if should_raise is not None: raise should_raise() - raise FailoverException('incrementing') + raise FailoverException(None, 'incrementing') @pytest.mark.parametrize('stop_on,exception', [