diff --git a/test/test_api_usage.py b/test/test_api_usage.py index d5c738e59..277410234 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -7,10 +7,10 @@ import time import re import json as py_json -from contextlib import contextmanager -from calendar import timegm -from httmock import urlmatch, HTTMock, all_requests from StringIO import StringIO +from calendar import timegm +from contextlib import contextmanager +from httmock import urlmatch, HTTMock, all_requests from urllib import urlencode from urlparse import urlparse, urlunparse, parse_qs @@ -119,13 +119,28 @@ CSRF_TOKEN_KEY = '_csrf_token' CSRF_TOKEN = '123csrfforme' -class ConfigForTesting(object): +class AppConfigChange(object): + """ AppConfigChange takes a dictionary that overrides the global app config + for a given block of code. The values are restored on exit. """ + def __init__(self, changes=None): + self._changes = changes or {} + self._originals = {} + self._to_rm = [] + def __enter__(self): - config_provider.reset_for_test() - return config_provider + for key in self._changes.keys(): + try: + self._originals[key] = app.config[key] + except KeyError: + self._to_rm.append(key) + app.config[key] = self._changes[key] def __exit__(self, type, value, traceback): - config_provider.reset_for_test() + for key in self._originals.keys(): + app.config[key] = self._originals[key] + + for key in self._to_rm: + del app.config[key] class ApiTestCase(unittest.TestCase): @@ -4335,25 +4350,23 @@ class TestRepositoryImageSecurity(ApiTestCase): self.assertEquals(1, image_response['data']['Layer']['IndexedByVersion']) def test_get_vulnerabilities_read_failover(self): - with ConfigForTesting(): - self.login(ADMIN_ACCESS_USER) + self.login(ADMIN_ACCESS_USER) - # Get a layer and mark it as indexed. - layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, 'simple', 'latest') - layer.security_indexed = True - layer.security_indexed_engine = app.config['SECURITY_SCANNER_ENGINE_VERSION_TARGET'] - layer.save() + # Get a layer and mark it as indexed. + layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, 'simple', 'latest') + layer.security_indexed = True + layer.security_indexed_engine = app.config['SECURITY_SCANNER_ENGINE_VERSION_TARGET'] + layer.save() - with fake_security_scanner(hostname='failoverscanner') as security_scanner: - # Query the wrong security scanner URL without failover. - self.getResponse(RepositoryImageSecurity, - params=dict(repository=ADMIN_ACCESS_USER + '/simple', - imageid=layer.docker_image_id, vulnerabilities='true'), - expected_code=520) - - # Set the failover URL. - app.config['SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS'] = ['http://failoverscanner'] + with fake_security_scanner(hostname='failoverscanner') as security_scanner: + # Query the wrong security scanner URL without failover. + self.getResponse(RepositoryImageSecurity, + params=dict(repository=ADMIN_ACCESS_USER + '/simple', + imageid=layer.docker_image_id, vulnerabilities='true'), + expected_code=520) + # Set the failover URL in the global config. + with AppConfigChange({'SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS': ['https://failoverscanner']}): # Configure the API to return 200 for this layer. layer_id = security_scanner.layer_id(layer) security_scanner.set_ok_layer_id(layer_id) diff --git a/test/test_suconfig_api.py b/test/test_suconfig_api.py index ea3bef651..f24b28f19 100644 --- a/test/test_suconfig_api.py +++ b/test/test_suconfig_api.py @@ -1,4 +1,4 @@ -from test.test_api_usage import ApiTestCase, READ_ACCESS_USER, ADMIN_ACCESS_USER, ConfigForTesting +from test.test_api_usage import ApiTestCase, READ_ACCESS_USER, ADMIN_ACCESS_USER from endpoints.api.suconfig import (SuperUserRegistryStatus, SuperUserConfig, SuperUserConfigFile, SuperUserCreateInitialSuperUser, SuperUserConfigValidate) from app import config_provider, all_queues @@ -8,16 +8,25 @@ from data import model import unittest +class FreshConfigProvider(object): + def __enter__(self): + config_provider.reset_for_test() + return config_provider + + def __exit__(self, type, value, traceback): + config_provider.reset_for_test() + + class TestSuperUserRegistryStatus(ApiTestCase): def test_registry_status(self): - with ConfigForTesting(): + with FreshConfigProvider(): json = self.getJsonResponse(SuperUserRegistryStatus) self.assertEquals('upload-license', json['status']) class TestSuperUserConfigFile(ApiTestCase): def test_get_non_superuser(self): - with ConfigForTesting(): + with FreshConfigProvider(): # No user. self.getResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=403) @@ -26,18 +35,18 @@ class TestSuperUserConfigFile(ApiTestCase): self.getResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=403) def test_get_superuser_invalid_filename(self): - with ConfigForTesting(): + with FreshConfigProvider(): self.login(ADMIN_ACCESS_USER) self.getResponse(SuperUserConfigFile, params=dict(filename='somefile'), expected_code=404) def test_get_superuser(self): - with ConfigForTesting(): + with FreshConfigProvider(): self.login(ADMIN_ACCESS_USER) result = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert')) self.assertFalse(result['exists']) def test_post_non_superuser(self): - with ConfigForTesting(): + with FreshConfigProvider(): # No user, before config.yaml exists. self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=400) @@ -52,25 +61,25 @@ class TestSuperUserConfigFile(ApiTestCase): self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=403) def test_post_superuser_invalid_filename(self): - with ConfigForTesting(): + with FreshConfigProvider(): self.login(ADMIN_ACCESS_USER) self.postResponse(SuperUserConfigFile, params=dict(filename='somefile'), expected_code=404) def test_post_superuser(self): - with ConfigForTesting(): + with FreshConfigProvider(): self.login(ADMIN_ACCESS_USER) self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=400) class TestSuperUserCreateInitialSuperUser(ApiTestCase): def test_no_config_file(self): - with ConfigForTesting(): + with FreshConfigProvider(): # If there is no config.yaml, then this method should security fail. data = dict(username='cooluser', password='password', email='fake@example.com') self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403) def test_config_file_with_db_users(self): - with ConfigForTesting(): + with FreshConfigProvider(): # Write some config. self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar')) @@ -80,7 +89,7 @@ class TestSuperUserCreateInitialSuperUser(ApiTestCase): self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403) def test_config_file_with_no_db_users(self): - with ConfigForTesting(): + with FreshConfigProvider(): # Write some config. self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar')) @@ -103,7 +112,7 @@ class TestSuperUserCreateInitialSuperUser(ApiTestCase): class TestSuperUserConfigValidate(ApiTestCase): def test_nonsuperuser_noconfig(self): - with ConfigForTesting(): + with FreshConfigProvider(): self.login(ADMIN_ACCESS_USER) result = self.postJsonResponse(SuperUserConfigValidate, params=dict(service='someservice'), data=dict(config={})) @@ -112,7 +121,7 @@ class TestSuperUserConfigValidate(ApiTestCase): def test_nonsuperuser_config(self): - with ConfigForTesting(): + with FreshConfigProvider(): # The validate config call works if there is no config.yaml OR the user is a superuser. # Add a config, and verify it breaks when unauthenticated. json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar')) @@ -132,7 +141,7 @@ class TestSuperUserConfigValidate(ApiTestCase): class TestSuperUserConfig(ApiTestCase): def test_get_non_superuser(self): - with ConfigForTesting(): + with FreshConfigProvider(): # No user. self.getResponse(SuperUserConfig, expected_code=401) @@ -141,7 +150,7 @@ class TestSuperUserConfig(ApiTestCase): self.getResponse(SuperUserConfig, expected_code=403) def test_get_superuser(self): - with ConfigForTesting(): + with FreshConfigProvider(): self.login(ADMIN_ACCESS_USER) json = self.getJsonResponse(SuperUserConfig) @@ -150,7 +159,7 @@ class TestSuperUserConfig(ApiTestCase): self.assertIsNone(json['config']) def test_put(self): - with ConfigForTesting() as config: + with FreshConfigProvider() as config: # The update config call works if there is no config.yaml OR the user is a superuser. First # try writing it without a superuser present. json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar')) diff --git a/util/failover.py b/util/failover.py index c1490b772..7e75d2140 100644 --- a/util/failover.py +++ b/util/failover.py @@ -8,8 +8,9 @@ logger = logging.getLogger(__name__) class FailoverException(Exception): """ Exception raised when an operation should be retried by the failover decorator. """ - def __init__(self, message): + def __init__(self, return_value, message): super(FailoverException, self).__init__() + self.return_value = return_value self.message = message def failover(func): @@ -41,6 +42,7 @@ def failover(func): return func(*arg_set[0], **arg_set[1]) except FailoverException as ex: logger.debug('failing over: %s', ex.message) + return_value = ex.return_value continue - raise FailoverException('exhausted all possible failovers') + return return_value return wrapper diff --git a/util/secscan/api.py b/util/secscan/api.py index 9bd0d8a49..84a16ad8b 100644 --- a/util/secscan/api.py +++ b/util/secscan/api.py @@ -288,30 +288,39 @@ class SecurityScannerAPI(object): return self._get_layer_data(layer_id, include_features, include_vulnerabilities) def _get_layer_data(self, layer_id, include_features=False, include_vulnerabilities=False): + params = {} + if include_features: + params = {'features': True} + + if include_vulnerabilities: + params = {'vulnerabilities': True} + try: - params = {} - if include_features: - params = {'features': True} - - if include_vulnerabilities: - params = {'vulnerabilities': True} - response = self._call('GET', _API_METHOD_GET_LAYER % layer_id, params=params) logger.debug('Got response %s for vulnerabilities for layer %s', response.status_code, layer_id) - json_response = response.json() + if response.status_code == 404: + return None + elif response.status_code // 100 == 5: + logger.error( + 'downstream security service failure: status %d, text: %s', + response.status_code, + response.text, + ) + raise APIRequestFailure('Downstream service returned 5xx') except requests.exceptions.Timeout: raise APIRequestFailure('API call timed out') except requests.exceptions.ConnectionError: raise APIRequestFailure('Could not connect to security service') - except (requests.exceptions.RequestException, ValueError): + except requests.exceptions.RequestException: logger.exception('Failed to get layer data response for %s', layer_id) raise APIRequestFailure() - if response.status_code == 404: - return None + try: + return response.json() + except ValueError: + logger.exception('Failed to decode response JSON') - return json_response def _request(self, method, endpoint, path, body, params, timeout): """ Issues an HTTP request to the security endpoint. """ @@ -334,7 +343,7 @@ class SecurityScannerAPI(object): if self._config is None: raise Exception('Cannot call unconfigured security system') - timeout = self._config['SECURITY_SCANNER_API_TIMEOUT_SECONDS'] + timeout = self._config.get('SECURITY_SCANNER_API_TIMEOUT_SECONDS', 1) endpoint = self._config['SECURITY_SCANNER_ENDPOINT'] with CloseForLongOperation(self._config): @@ -346,12 +355,12 @@ class SecurityScannerAPI(object): return self._request(method, endpoint, path, body, params, timeout) # The request is read-only and can failover. - all_endpoints = [endpoint] + self._config['SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS'] + all_endpoints = [endpoint] + self._config.get('SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS', []) try: - return _failover_read_request(*[((self._request, endpoint, path, body, params, timeout), {}) + return _failover_read_request(*[((self._request, endpoint, path, body, params, timeout), {}) for endpoint in all_endpoints]) - except FailoverException: - raise APIRequestFailure() + except FailoverException as ex: + return ex.return_value def _join_api_url(endpoint, api_version, path): @@ -364,5 +373,5 @@ def _failover_read_request(request_fn, endpoint, path, body, params, timeout): """ This function auto-retries read-only requests until they return a 2xx status code. """ resp = request_fn('GET', endpoint, path, body, params, timeout) if resp.status_code / 100 != 2: - raise FailoverException('status code was not 2xx') + raise FailoverException(resp, 'status code was not 2xx') return resp diff --git a/util/secscan/fake.py b/util/secscan/fake.py index 49a7e493a..0f39f3ff9 100644 --- a/util/secscan/fake.py +++ b/util/secscan/fake.py @@ -5,6 +5,7 @@ import urlparse from contextlib import contextmanager from httmock import urlmatch, HTTMock, all_requests + from util.secscan.api import UNKNOWN_PARENT_LAYER_ERROR_MSG, compute_layer_id @contextmanager @@ -170,7 +171,6 @@ class FakeSecurityScanner(object): def get_endpoints(self): """ Returns the HTTMock endpoint definitions for the fake security scanner. """ - @urlmatch(netloc=r'(.*\.)?' + self.hostname, path=r'/v1/layers/(.+)', method='GET') def get_layer_mock(url, request): layer_id = url.path[len('/v1/layers/'):] @@ -320,7 +320,7 @@ class FakeSecurityScanner(object): def response_content(url, _): return { 'status_code': 500, - 'content': '', + 'content': json.dumps({'Error': {'Message': 'Unknown endpoint %s' % url.path}}), } return [get_layer_mock, post_layer_mock, remove_layer_mock, get_notification, diff --git a/util/test/test_failover.py b/util/test/test_failover.py index 340092179..369ddea3a 100644 --- a/util/test/test_failover.py +++ b/util/test/test_failover.py @@ -18,7 +18,7 @@ def my_failover_func(i, should_raise=None): i.increment() if should_raise is not None: raise should_raise() - raise FailoverException('incrementing') + raise FailoverException(None, 'incrementing') @pytest.mark.parametrize('stop_on,exception', [