diff --git a/util/failover.py b/util/failover.py index 7e75d2140..c88d47a6f 100644 --- a/util/failover.py +++ b/util/failover.py @@ -7,11 +7,12 @@ logger = logging.getLogger(__name__) class FailoverException(Exception): - """ Exception raised when an operation should be retried by the failover decorator. """ - def __init__(self, return_value, message): + """ Exception raised when an operation should be retried by the failover decorator. + Wraps the exception of the initial failure. + """ + def __init__(self, exception): super(FailoverException, self).__init__() - self.return_value = return_value - self.message = message + self.exception = exception def failover(func): """ Wraps a function such that it can be retried on specified failures. @@ -21,9 +22,10 @@ def failover(func): @failover def get_google(scheme, use_www=False): www = 'www.' if use_www else '' - r = requests.get(scheme + '://' + www + 'google.com') - if r.status_code != 200: - raise FailoverException('non 200 response from Google' ) + try: + r = requests.get(scheme + '://' + www + 'google.com') + except requests.RequestException as ex: + raise FailoverException(ex) return r def GooglePingTest(): @@ -41,8 +43,8 @@ def failover(func): try: return func(*arg_set[0], **arg_set[1]) except FailoverException as ex: - logger.debug('failing over: %s', ex.message) - return_value = ex.return_value + logger.debug('failing over') + exception = ex.exception continue - return return_value + raise exception return wrapper diff --git a/util/secscan/api.py b/util/secscan/api.py index efcd79571..f90b3ba37 100644 --- a/util/secscan/api.py +++ b/util/secscan/api.py @@ -43,6 +43,12 @@ class InvalidLayerException(AnalyzeLayerException): class APIRequestFailure(Exception): """ Exception raised when there is a failure to conduct an API request. """ +class Non200ResponseException(Exception): + """ Exception raised when the upstream API returns a non-200 HTTP status code. """ + def __init__(self, response): + super(Non200ResponseException, self).__init__() + self.response = response + _API_METHOD_INSERT = 'layers' _API_METHOD_GET_LAYER = 'layers/%s' @@ -176,8 +182,10 @@ class SecurityScannerAPI(object): """ layer_id = compute_layer_id(layer) try: - response = self._call('DELETE', _API_METHOD_DELETE_LAYER % layer_id) - return response.status_code / 100 == 2 + self._call('DELETE', _API_METHOD_DELETE_LAYER % layer_id) + return True + except Non200ResponseException: + return False except requests.exceptions.RequestException: logger.exception('Failed to delete layer: %s', layer_id) return False @@ -187,6 +195,13 @@ class SecurityScannerAPI(object): Returns the analysis version on success or raises an exception deriving from AnalyzeLayerException on failure. Callers should handle all cases of AnalyzeLayerException. """ + def _response_json(request, response): + try: + return response.json() + except ValueError: + logger.exception('Failed to decode JSON when analyzing layer %s', request['Layer']['Name']) + raise AnalyzeLayerException + request = self._new_analyze_request(layer) if not request: raise AnalyzeLayerException @@ -194,42 +209,35 @@ class SecurityScannerAPI(object): logger.info('Analyzing layer %s', request['Layer']['Name']) try: response = self._call('POST', _API_METHOD_INSERT, body=request) - json_response = response.json() except requests.exceptions.Timeout: logger.exception('Timeout when trying to post layer data response for %s', layer.id) raise AnalyzeLayerRetryException except requests.exceptions.ConnectionError: logger.exception('Connection error when trying to post layer data response for %s', layer.id) raise AnalyzeLayerRetryException - except (requests.exceptions.RequestException, ValueError) as re: + except (requests.exceptions.RequestException) as re: logger.exception('Failed to post layer data response for %s: %s', layer.id, re) raise AnalyzeLayerException - - # Handle any errors from the security scanner. - if response.status_code != 201: - message = json_response.get('Error').get('Message', '') + except Non200ResponseException as ex: + message = _response_json(request, ex.response).get('Error').get('Message', '') logger.warning('A warning event occurred when analyzing layer %s (status code %s): %s', - request['Layer']['Name'], response.status_code, message) - + request['Layer']['Name'], ex.response.status_code, message) # 400 means the layer could not be analyzed due to a bad request. - if response.status_code == 400: + if ex.response.status_code == 400: if message == UNKNOWN_PARENT_LAYER_ERROR_MSG: raise MissingParentLayerException('Bad request to security scanner: %s' % message) else: raise AnalyzeLayerException('Bad request to security scanner: %s' % message) - # 422 means that the layer could not be analyzed: # - the layer could not be extracted (might be a manifest or an invalid .tar.gz) # - the layer operating system / package manager is unsupported - elif response.status_code == 422: + elif ex.response.status_code == 422: raise InvalidLayerException - # Otherwise, it is some other error and we should retry. - else: - raise AnalyzeLayerRetryException + raise AnalyzeLayerRetryException # Return the parsed API version. - return json_response['Layer']['IndexedByVersion'] + return _response_json(request, response)['Layer']['IndexedByVersion'] def check_layer_vulnerable(self, layer_id, cve_name): """ Checks to see if the layer with the given ID is vulnerable to the specified CVE. """ @@ -267,17 +275,18 @@ class SecurityScannerAPI(object): except (requests.exceptions.RequestException, ValueError): logger.exception('Failed to get notification for %s', notification_name) return None, False - - if response.status_code != 200: - return None, response.status_code != 404 and response.status_code != 400 + except Non200ResponseException as ex: + return None, ex.response.status_code != 404 and ex.response.status_code != 400 return json_response, False def mark_notification_read(self, notification_name): """ Marks a security scanner notification as read. """ try: - response = self._call('DELETE', _API_METHOD_MARK_NOTIFICATION_READ % notification_name) - return response.status_code / 100 == 2 + self._call('DELETE', _API_METHOD_MARK_NOTIFICATION_READ % notification_name) + return True + except Non200ResponseException: + return False except requests.exceptions.RequestException: logger.exception('Failed to mark notification as read: %s', notification_name) return False @@ -299,13 +308,16 @@ class SecurityScannerAPI(object): response = self._call('GET', _API_METHOD_GET_LAYER % layer_id, params=params) logger.debug('Got response %s for vulnerabilities for layer %s', response.status_code, layer_id) - if response.status_code == 404: + except Non200ResponseException as ex: + logger.debug('Got failed response %s for vulnerabilities for layer %s', + ex.response.status_code, layer_id) + if ex.response.status_code == 404: return None - elif response.status_code // 100 == 5: + elif ex.response.status_code // 100 == 5: logger.error( 'downstream security service failure: status %d, text: %s', - response.status_code, - response.text, + ex.response.status_code, + ex.response.text, ) raise APIRequestFailure('Downstream service returned 5xx') except requests.exceptions.Timeout: @@ -331,10 +343,13 @@ class SecurityScannerAPI(object): signer_proxy_url = self._config.get('JWTPROXY_SIGNER', 'localhost:8080') logger.debug('%sing security URL %s', method.upper(), url) - return self._client.request(method, url, json=body, params=params, timeout=timeout, + resp = self._client.request(method, url, json=body, params=params, timeout=timeout, verify=MITM_CERT_PATH, headers=DEFAULT_HTTP_HEADERS, proxies={'https': 'https://' + signer_proxy_url, 'http': 'http://' + signer_proxy_url}) + if resp.status_code // 100 != 2: + raise Non200ResponseException(resp) + return resp def _call(self, method, path, params=None, body=None): """ Issues an HTTP request to the security endpoint handling the logic of using an alternative @@ -356,11 +371,8 @@ class SecurityScannerAPI(object): # The request is read-only and can failover. all_endpoints = [endpoint] + self._config.get('SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS', []) - try: - return _failover_read_request(*[((self._request, endpoint, path, body, params, timeout), {}) - for endpoint in all_endpoints]) - except FailoverException as ex: - return ex.return_value + return _failover_read_request(*[((self._request, endpoint, path, body, params, timeout), {}) + for endpoint in all_endpoints]) def _join_api_url(endpoint, api_version, path): @@ -372,9 +384,6 @@ def _join_api_url(endpoint, api_version, path): def _failover_read_request(request_fn, endpoint, path, body, params, timeout): """ This function auto-retries read-only requests until they return a 2xx status code. """ try: - resp = request_fn('GET', endpoint, path, body, params, timeout) - if resp.status_code / 100 != 2: - raise FailoverException(resp, 'status code was not 2xx') - return resp - except requests.exceptions.RequestException: - raise FailoverException(None, 'connection failure') + return request_fn('GET', endpoint, path, body, params, timeout) + except (requests.exceptions.RequestException, Non200ResponseException) as ex: + raise FailoverException(ex) diff --git a/util/test/test_failover.py b/util/test/test_failover.py index 369ddea3a..333c39362 100644 --- a/util/test/test_failover.py +++ b/util/test/test_failover.py @@ -2,6 +2,9 @@ import pytest from util.failover import failover, FailoverException +class FinishedException(Exception): + """ Exception raised at the end of every iteration to force failover. """ + class Counter(object): """ Wraps a counter in an object so that it'll be passed by reference. """ @@ -18,7 +21,7 @@ def my_failover_func(i, should_raise=None): i.increment() if should_raise is not None: raise should_raise() - raise FailoverException(None, 'incrementing') + raise FailoverException(FinishedException()) @pytest.mark.parametrize('stop_on,exception', [ @@ -40,5 +43,6 @@ def test_readonly_failover(stop_on, exception): with pytest.raises(exception): my_failover_func(*arg_sets) else: - my_failover_func(*arg_sets) - assert counter.calls == stop_on + with pytest.raises(FinishedException): + my_failover_func(*arg_sets) + assert counter.calls == stop_on