Merge pull request #2396 from jzelinskie/fix-status-code

Fix Security Scanner Status Code Exception
This commit is contained in:
Jimmy Zelinskie 2017-03-01 13:42:38 -05:00 committed by GitHub
commit c54a99b2c2
3 changed files with 66 additions and 51 deletions

View file

@ -7,11 +7,12 @@ logger = logging.getLogger(__name__)
class FailoverException(Exception): class FailoverException(Exception):
""" Exception raised when an operation should be retried by the failover decorator. """ """ Exception raised when an operation should be retried by the failover decorator.
def __init__(self, return_value, message): Wraps the exception of the initial failure.
"""
def __init__(self, exception):
super(FailoverException, self).__init__() super(FailoverException, self).__init__()
self.return_value = return_value self.exception = exception
self.message = message
def failover(func): def failover(func):
""" Wraps a function such that it can be retried on specified failures. """ Wraps a function such that it can be retried on specified failures.
@ -21,9 +22,10 @@ def failover(func):
@failover @failover
def get_google(scheme, use_www=False): def get_google(scheme, use_www=False):
www = 'www.' if use_www else '' www = 'www.' if use_www else ''
try:
r = requests.get(scheme + '://' + www + 'google.com') r = requests.get(scheme + '://' + www + 'google.com')
if r.status_code != 200: except requests.RequestException as ex:
raise FailoverException('non 200 response from Google' ) raise FailoverException(ex)
return r return r
def GooglePingTest(): def GooglePingTest():
@ -41,8 +43,8 @@ def failover(func):
try: try:
return func(*arg_set[0], **arg_set[1]) return func(*arg_set[0], **arg_set[1])
except FailoverException as ex: except FailoverException as ex:
logger.debug('failing over: %s', ex.message) logger.debug('failing over')
return_value = ex.return_value exception = ex.exception
continue continue
return return_value raise exception
return wrapper return wrapper

View file

@ -43,6 +43,12 @@ class InvalidLayerException(AnalyzeLayerException):
class APIRequestFailure(Exception): class APIRequestFailure(Exception):
""" Exception raised when there is a failure to conduct an API request. """ """ Exception raised when there is a failure to conduct an API request. """
class Non200ResponseException(Exception):
""" Exception raised when the upstream API returns a non-200 HTTP status code. """
def __init__(self, response):
super(Non200ResponseException, self).__init__()
self.response = response
_API_METHOD_INSERT = 'layers' _API_METHOD_INSERT = 'layers'
_API_METHOD_GET_LAYER = 'layers/%s' _API_METHOD_GET_LAYER = 'layers/%s'
@ -176,8 +182,10 @@ class SecurityScannerAPI(object):
""" """
layer_id = compute_layer_id(layer) layer_id = compute_layer_id(layer)
try: try:
response = self._call('DELETE', _API_METHOD_DELETE_LAYER % layer_id) self._call('DELETE', _API_METHOD_DELETE_LAYER % layer_id)
return response.status_code / 100 == 2 return True
except Non200ResponseException:
return False
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
logger.exception('Failed to delete layer: %s', layer_id) logger.exception('Failed to delete layer: %s', layer_id)
return False return False
@ -187,6 +195,13 @@ class SecurityScannerAPI(object):
Returns the analysis version on success or raises an exception deriving from Returns the analysis version on success or raises an exception deriving from
AnalyzeLayerException on failure. Callers should handle all cases of AnalyzeLayerException. AnalyzeLayerException on failure. Callers should handle all cases of AnalyzeLayerException.
""" """
def _response_json(request, response):
try:
return response.json()
except ValueError:
logger.exception('Failed to decode JSON when analyzing layer %s', request['Layer']['Name'])
raise AnalyzeLayerException
request = self._new_analyze_request(layer) request = self._new_analyze_request(layer)
if not request: if not request:
raise AnalyzeLayerException raise AnalyzeLayerException
@ -194,42 +209,35 @@ class SecurityScannerAPI(object):
logger.info('Analyzing layer %s', request['Layer']['Name']) logger.info('Analyzing layer %s', request['Layer']['Name'])
try: try:
response = self._call('POST', _API_METHOD_INSERT, body=request) response = self._call('POST', _API_METHOD_INSERT, body=request)
json_response = response.json()
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
logger.exception('Timeout when trying to post layer data response for %s', layer.id) logger.exception('Timeout when trying to post layer data response for %s', layer.id)
raise AnalyzeLayerRetryException raise AnalyzeLayerRetryException
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
logger.exception('Connection error when trying to post layer data response for %s', layer.id) logger.exception('Connection error when trying to post layer data response for %s', layer.id)
raise AnalyzeLayerRetryException raise AnalyzeLayerRetryException
except (requests.exceptions.RequestException, ValueError) as re: except (requests.exceptions.RequestException) as re:
logger.exception('Failed to post layer data response for %s: %s', layer.id, re) logger.exception('Failed to post layer data response for %s: %s', layer.id, re)
raise AnalyzeLayerException raise AnalyzeLayerException
except Non200ResponseException as ex:
# Handle any errors from the security scanner. message = _response_json(request, ex.response).get('Error').get('Message', '')
if response.status_code != 201:
message = json_response.get('Error').get('Message', '')
logger.warning('A warning event occurred when analyzing layer %s (status code %s): %s', logger.warning('A warning event occurred when analyzing layer %s (status code %s): %s',
request['Layer']['Name'], response.status_code, message) request['Layer']['Name'], ex.response.status_code, message)
# 400 means the layer could not be analyzed due to a bad request. # 400 means the layer could not be analyzed due to a bad request.
if response.status_code == 400: if ex.response.status_code == 400:
if message == UNKNOWN_PARENT_LAYER_ERROR_MSG: if message == UNKNOWN_PARENT_LAYER_ERROR_MSG:
raise MissingParentLayerException('Bad request to security scanner: %s' % message) raise MissingParentLayerException('Bad request to security scanner: %s' % message)
else: else:
raise AnalyzeLayerException('Bad request to security scanner: %s' % message) raise AnalyzeLayerException('Bad request to security scanner: %s' % message)
# 422 means that the layer could not be analyzed: # 422 means that the layer could not be analyzed:
# - the layer could not be extracted (might be a manifest or an invalid .tar.gz) # - the layer could not be extracted (might be a manifest or an invalid .tar.gz)
# - the layer operating system / package manager is unsupported # - the layer operating system / package manager is unsupported
elif response.status_code == 422: elif ex.response.status_code == 422:
raise InvalidLayerException raise InvalidLayerException
# Otherwise, it is some other error and we should retry. # Otherwise, it is some other error and we should retry.
else:
raise AnalyzeLayerRetryException raise AnalyzeLayerRetryException
# Return the parsed API version. # Return the parsed API version.
return json_response['Layer']['IndexedByVersion'] return _response_json(request, response)['Layer']['IndexedByVersion']
def check_layer_vulnerable(self, layer_id, cve_name): def check_layer_vulnerable(self, layer_id, cve_name):
""" Checks to see if the layer with the given ID is vulnerable to the specified CVE. """ """ Checks to see if the layer with the given ID is vulnerable to the specified CVE. """
@ -267,17 +275,18 @@ class SecurityScannerAPI(object):
except (requests.exceptions.RequestException, ValueError): except (requests.exceptions.RequestException, ValueError):
logger.exception('Failed to get notification for %s', notification_name) logger.exception('Failed to get notification for %s', notification_name)
return None, False return None, False
except Non200ResponseException as ex:
if response.status_code != 200: return None, ex.response.status_code != 404 and ex.response.status_code != 400
return None, response.status_code != 404 and response.status_code != 400
return json_response, False return json_response, False
def mark_notification_read(self, notification_name): def mark_notification_read(self, notification_name):
""" Marks a security scanner notification as read. """ """ Marks a security scanner notification as read. """
try: try:
response = self._call('DELETE', _API_METHOD_MARK_NOTIFICATION_READ % notification_name) self._call('DELETE', _API_METHOD_MARK_NOTIFICATION_READ % notification_name)
return response.status_code / 100 == 2 return True
except Non200ResponseException:
return False
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
logger.exception('Failed to mark notification as read: %s', notification_name) logger.exception('Failed to mark notification as read: %s', notification_name)
return False return False
@ -299,13 +308,16 @@ class SecurityScannerAPI(object):
response = self._call('GET', _API_METHOD_GET_LAYER % layer_id, params=params) response = self._call('GET', _API_METHOD_GET_LAYER % layer_id, params=params)
logger.debug('Got response %s for vulnerabilities for layer %s', logger.debug('Got response %s for vulnerabilities for layer %s',
response.status_code, layer_id) response.status_code, layer_id)
if response.status_code == 404: except Non200ResponseException as ex:
logger.debug('Got failed response %s for vulnerabilities for layer %s',
ex.response.status_code, layer_id)
if ex.response.status_code == 404:
return None return None
elif response.status_code // 100 == 5: elif ex.response.status_code // 100 == 5:
logger.error( logger.error(
'downstream security service failure: status %d, text: %s', 'downstream security service failure: status %d, text: %s',
response.status_code, ex.response.status_code,
response.text, ex.response.text,
) )
raise APIRequestFailure('Downstream service returned 5xx') raise APIRequestFailure('Downstream service returned 5xx')
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
@ -331,10 +343,13 @@ class SecurityScannerAPI(object):
signer_proxy_url = self._config.get('JWTPROXY_SIGNER', 'localhost:8080') signer_proxy_url = self._config.get('JWTPROXY_SIGNER', 'localhost:8080')
logger.debug('%sing security URL %s', method.upper(), url) logger.debug('%sing security URL %s', method.upper(), url)
return self._client.request(method, url, json=body, params=params, timeout=timeout, resp = self._client.request(method, url, json=body, params=params, timeout=timeout,
verify=MITM_CERT_PATH, headers=DEFAULT_HTTP_HEADERS, verify=MITM_CERT_PATH, headers=DEFAULT_HTTP_HEADERS,
proxies={'https': 'https://' + signer_proxy_url, proxies={'https': 'https://' + signer_proxy_url,
'http': 'http://' + signer_proxy_url}) 'http': 'http://' + signer_proxy_url})
if resp.status_code // 100 != 2:
raise Non200ResponseException(resp)
return resp
def _call(self, method, path, params=None, body=None): def _call(self, method, path, params=None, body=None):
""" Issues an HTTP request to the security endpoint handling the logic of using an alternative """ Issues an HTTP request to the security endpoint handling the logic of using an alternative
@ -356,11 +371,8 @@ class SecurityScannerAPI(object):
# The request is read-only and can failover. # The request is read-only and can failover.
all_endpoints = [endpoint] + self._config.get('SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS', []) all_endpoints = [endpoint] + self._config.get('SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS', [])
try:
return _failover_read_request(*[((self._request, endpoint, path, body, params, timeout), {}) return _failover_read_request(*[((self._request, endpoint, path, body, params, timeout), {})
for endpoint in all_endpoints]) for endpoint in all_endpoints])
except FailoverException as ex:
return ex.return_value
def _join_api_url(endpoint, api_version, path): def _join_api_url(endpoint, api_version, path):
@ -372,9 +384,6 @@ def _join_api_url(endpoint, api_version, path):
def _failover_read_request(request_fn, endpoint, path, body, params, timeout): def _failover_read_request(request_fn, endpoint, path, body, params, timeout):
""" This function auto-retries read-only requests until they return a 2xx status code. """ """ This function auto-retries read-only requests until they return a 2xx status code. """
try: try:
resp = request_fn('GET', endpoint, path, body, params, timeout) return request_fn('GET', endpoint, path, body, params, timeout)
if resp.status_code / 100 != 2: except (requests.exceptions.RequestException, Non200ResponseException) as ex:
raise FailoverException(resp, 'status code was not 2xx') raise FailoverException(ex)
return resp
except requests.exceptions.RequestException:
raise FailoverException(None, 'connection failure')

View file

@ -2,6 +2,9 @@ import pytest
from util.failover import failover, FailoverException from util.failover import failover, FailoverException
class FinishedException(Exception):
""" Exception raised at the end of every iteration to force failover. """
class Counter(object): class Counter(object):
""" Wraps a counter in an object so that it'll be passed by reference. """ """ Wraps a counter in an object so that it'll be passed by reference. """
@ -18,7 +21,7 @@ def my_failover_func(i, should_raise=None):
i.increment() i.increment()
if should_raise is not None: if should_raise is not None:
raise should_raise() raise should_raise()
raise FailoverException(None, 'incrementing') raise FailoverException(FinishedException())
@pytest.mark.parametrize('stop_on,exception', [ @pytest.mark.parametrize('stop_on,exception', [
@ -40,5 +43,6 @@ def test_readonly_failover(stop_on, exception):
with pytest.raises(exception): with pytest.raises(exception):
my_failover_func(*arg_sets) my_failover_func(*arg_sets)
else: else:
with pytest.raises(FinishedException):
my_failover_func(*arg_sets) my_failover_func(*arg_sets)
assert counter.calls == stop_on assert counter.calls == stop_on