diff --git a/test/test_morecollections.py b/test/test_morecollections.py new file mode 100644 index 000000000..b739b24e7 --- /dev/null +++ b/test/test_morecollections.py @@ -0,0 +1,310 @@ +import unittest + +from util.morecollections import (FastIndexList, StreamingDiffTracker, + IndexedStreamingDiffTracker) + +class FastIndexListTests(unittest.TestCase): + def test_basic_usage(self): + indexlist = FastIndexList() + + # Add 1 + indexlist.add(1) + self.assertEquals([1], indexlist.values()) + self.assertEquals(0, indexlist.index(1)) + + # Add 2 + indexlist.add(2) + self.assertEquals([1, 2], indexlist.values()) + self.assertEquals(0, indexlist.index(1)) + self.assertEquals(1, indexlist.index(2)) + + # Pop nothing. + indexlist.pop_until(-1) + self.assertEquals([1, 2], indexlist.values()) + self.assertEquals(0, indexlist.index(1)) + self.assertEquals(1, indexlist.index(2)) + + # Pop 1. + self.assertEquals([1], indexlist.pop_until(0)) + self.assertEquals([2], indexlist.values()) + self.assertIsNone(indexlist.index(1)) + self.assertEquals(0, indexlist.index(2)) + + # Add 3. + indexlist.add(3) + self.assertEquals([2, 3], indexlist.values()) + self.assertEquals(0, indexlist.index(2)) + self.assertEquals(1, indexlist.index(3)) + + # Pop 2, 3. + self.assertEquals([2, 3], indexlist.pop_until(1)) + self.assertEquals([], indexlist.values()) + self.assertIsNone(indexlist.index(1)) + self.assertIsNone(indexlist.index(2)) + self.assertIsNone(indexlist.index(3)) + + def test_popping(self): + indexlist = FastIndexList() + indexlist.add('hello') + indexlist.add('world') + indexlist.add('you') + indexlist.add('rock') + + self.assertEquals(0, indexlist.index('hello')) + self.assertEquals(1, indexlist.index('world')) + self.assertEquals(2, indexlist.index('you')) + self.assertEquals(3, indexlist.index('rock')) + + indexlist.pop_until(1) + self.assertEquals(0, indexlist.index('you')) + self.assertEquals(1, indexlist.index('rock')) + + +class IndexedStreamingDiffTrackerTests(unittest.TestCase): + def test_basic(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 3) + tracker.push_new([('a', 0), ('b', 1), ('c', 2)]) + tracker.push_old([('b', 1)]) + tracker.done() + + self.assertEquals(['a', 'c'], added) + + def test_multiple_done(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 3) + tracker.push_new([('a', 0), ('b', 1), ('c', 2)]) + tracker.push_old([('b', 1)]) + tracker.done() + tracker.done() + + self.assertEquals(['a', 'c'], added) + + def test_same_streams(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 3) + tracker.push_new([('a', 0), ('b', 1), ('c', 2)]) + tracker.push_old([('a', 0), ('b', 1), ('c', 2)]) + tracker.done() + + self.assertEquals([], added) + + def test_only_new(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 3) + tracker.push_new([('a', 0), ('b', 1), ('c', 2)]) + tracker.push_old([]) + tracker.done() + + self.assertEquals(['a', 'b', 'c'], added) + + def test_pagination(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 2) + tracker.push_new([('a', 0), ('b', 1)]) + tracker.push_old([]) + + tracker.push_new([('c', 2)]) + tracker.push_old([]) + + tracker.done() + + self.assertEquals(['a', 'b', 'c'], added) + + def test_old_pagination_no_repeat(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 2) + tracker.push_new([('new1', 3), ('new2', 4)]) + tracker.push_old([('old1', 1), ('old2', 2)]) + + tracker.push_new([]) + tracker.push_old([('new1', 
3)]) + + tracker.done() + + self.assertEquals(['new2'], added) + + def test_old_pagination(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 2) + tracker.push_new([('a', 10), ('b', 11)]) + tracker.push_old([('z', 1), ('y', 2)]) + + tracker.push_new([('c', 12)]) + tracker.push_old([('a', 10)]) + + tracker.done() + + self.assertEquals(['b', 'c'], added) + + def test_very_offset(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 2) + tracker.push_new([('a', 10), ('b', 11)]) + tracker.push_old([('z', 1), ('y', 2)]) + + tracker.push_new([('c', 12), ('d', 13)]) + tracker.push_old([('x', 3), ('w', 4)]) + + tracker.push_new([('e', 14)]) + tracker.push_old([('a', 10), ('d', 13)]) + + tracker.done() + + self.assertEquals(['b', 'c', 'e'], added) + + def test_many_old(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 2) + tracker.push_new([('z', 26), ('hello', 100)]) + tracker.push_old([('a', 1), ('b', 2)]) + + tracker.push_new([]) + tracker.push_old([('c', 1), ('d', 2)]) + + tracker.push_new([]) + tracker.push_old([('e', 3), ('f', 4)]) + + tracker.push_new([]) + tracker.push_old([('g', 5), ('z', 26)]) + + tracker.done() + + self.assertEquals(['hello'], added) + + def test_high_old_bound(self): + added = [] + + tracker = IndexedStreamingDiffTracker(added.append, 2) + tracker.push_new([('z', 26), ('hello', 100)]) + tracker.push_old([('end1', 999), ('end2', 1000)]) + + tracker.push_new([]) + tracker.push_old([]) + + tracker.done() + + self.assertEquals(['z', 'hello'], added) + + +class StreamingDiffTrackerTests(unittest.TestCase): + def test_basic(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 3) + tracker.push_new(['a', 'b', 'c']) + tracker.push_old(['b']) + tracker.done() + + self.assertEquals(['a', 'c'], added) + + def test_same_streams(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 3) + tracker.push_new(['a', 'b', 'c']) + tracker.push_old(['a', 'b', 'c']) + tracker.done() + + self.assertEquals([], added) + + def test_some_new(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 5) + tracker.push_new(['a', 'b', 'c', 'd', 'e']) + tracker.push_old(['a', 'b', 'c']) + tracker.done() + + self.assertEquals(['d', 'e'], added) + + def test_offset_new(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 5) + tracker.push_new(['b', 'c', 'd', 'e']) + tracker.push_old(['a', 'b', 'c']) + tracker.done() + + self.assertEquals(['d', 'e'], added) + + def test_multiple_calls(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 3) + tracker.push_new(['a', 'b', 'c']) + tracker.push_old(['b', 'd', 'e']) + + tracker.push_new(['f', 'g', 'h']) + tracker.push_old(['g', 'h']) + tracker.done() + + self.assertEquals(['a', 'c', 'f'], added) + + def test_empty_old(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 3) + tracker.push_new(['a', 'b', 'c']) + tracker.push_old([]) + + tracker.push_new(['f', 'g', 'h']) + tracker.push_old([]) + tracker.done() + + self.assertEquals(['a', 'b', 'c', 'f', 'g', 'h'], added) + + def test_more_old(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 2) + tracker.push_new(['c', 'd']) + tracker.push_old(['a', 'b']) + + tracker.push_new([]) + tracker.push_old(['c']) + tracker.done() + + self.assertEquals(['d'], added) + + def test_more_new(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 4) + tracker.push_new(['a', 'b', 'c', 'd']) + tracker.push_old(['r']) + + 
tracker.push_new(['e', 'f', 'r', 'z']) + tracker.push_old([]) + tracker.done() + + self.assertEquals(['a', 'b', 'c', 'd', 'e', 'f', 'z'], added) + + def test_more_new2(self): + added = [] + + tracker = StreamingDiffTracker(added.append, 4) + tracker.push_new(['a', 'b', 'c', 'd']) + tracker.push_old(['r']) + + tracker.push_new(['e', 'f', 'g', 'h']) + tracker.push_old([]) + + tracker.push_new(['i', 'j', 'r', 'z']) + tracker.push_old([]) + tracker.done() + + self.assertEquals(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'z'], added) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/test_secscan.py b/test/test_secscan.py index bb6fee94b..7a40d81a9 100644 --- a/test/test_secscan.py +++ b/test/test_secscan.py @@ -11,7 +11,7 @@ from initdb import setup_database_for_testing, finished_database_for_testing from util.secscan.api import SecurityScannerAPI from util.secscan.analyzer import LayerAnalyzer from util.secscan.fake import fake_security_scanner -from util.secscan.notifier import process_notification_data +from util.secscan.notifier import SecurityNotificationHandler, ProcessNotificationPageResult from workers.security_notification_worker import SecurityNotificationWorker @@ -20,6 +20,13 @@ SIMPLE_REPO = 'simple' COMPLEX_REPO = 'complex' +def process_notification_data(notification_data): + handler = SecurityNotificationHandler(100) + result = handler.process_notification_page_data(notification_data) + handler.send_notifications() + return result == ProcessNotificationPageResult.FINISHED_PROCESSING + + class TestSecurityScanner(unittest.TestCase): def setUp(self): # Enable direct download in fake storage. @@ -57,7 +64,6 @@ class TestSecurityScanner(unittest.TestCase): self.assertEquals(engineVersion, parent.security_indexed_engine) self.assertTrue(security_scanner.has_layer(security_scanner.layer_id(parent))) - def test_get_layer(self): """ Test for basic retrieval of layers from the security scanner. """ layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True) @@ -75,7 +81,6 @@ class TestSecurityScanner(unittest.TestCase): self.assertIsNotNone(result) self.assertEquals(result['Layer']['Name'], security_scanner.layer_id(layer)) - def test_analyze_layer_nodirectdownload_success(self): """ Tests analyzing a layer when direct download is disabled. """ @@ -114,7 +119,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, True, 1) - def test_analyze_layer_success(self): """ Tests that analyzing a layer successfully marks it as analyzed. """ @@ -129,7 +133,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, True, 1) - def test_analyze_layer_failure(self): """ Tests that failing to analyze a layer (because it 422s) marks it as analyzed but failed. """ @@ -146,7 +149,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, False, 1) - def test_analyze_layer_internal_error(self): """ Tests that failing to analyze a layer (because it 500s) marks it as not analyzed. 
""" @@ -163,7 +165,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, False, -1) - def test_analyze_layer_error(self): """ Tests that failing to analyze a layer (because it 400s) marks it as analyzed but failed. """ @@ -183,7 +184,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, False, 1) - def test_analyze_layer_missing_parent_handled(self): """ Tests that a missing parent causes an automatic reanalysis, which succeeds. """ @@ -214,7 +214,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, True, 1) - def test_analyze_layer_invalid_parent(self): """ Tests that trying to reanalyze a parent that is invalid causes the layer to be marked as analyzed, but failed. @@ -250,7 +249,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, False, 1) - def test_analyze_layer_unsupported_parent(self): """ Tests that attempting to analyze a layer whose parent is unanalyzable, results in the layer being marked as analyzed, but failed. @@ -271,7 +269,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, False, 1) - def test_analyze_layer_missing_storage(self): """ Tests trying to analyze a layer with missing storage. """ @@ -292,7 +289,6 @@ class TestSecurityScanner(unittest.TestCase): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest') self.assertAnalyzed(layer, security_scanner, False, 1) - def assert_analyze_layer_notify(self, security_indexed_engine, security_indexed, expect_notification): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True) @@ -350,22 +346,18 @@ class TestSecurityScanner(unittest.TestCase): self.assertEquals(updated_layer.id, layer.id) self.assertTrue(updated_layer.security_indexed_engine > 0) - def test_analyze_layer_success_events(self): # Not previously indexed at all => Notification self.assert_analyze_layer_notify(IMAGE_NOT_SCANNED_ENGINE_VERSION, False, True) - def test_analyze_layer_success_no_notification(self): # Previously successfully indexed => No notification self.assert_analyze_layer_notify(0, True, False) - def test_analyze_layer_failed_then_success_notification(self): # Previously failed to index => Notification self.assert_analyze_layer_notify(0, False, True) - def test_notification_new_layers_not_vulnerable(self): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True) layer_id = '%s.%s' % (layer.docker_image_id, layer.storage.uuid) @@ -395,7 +387,6 @@ class TestSecurityScanner(unittest.TestCase): # Ensure that there are no event queue items for the layer. self.assertIsNone(notification_queue.get()) - def test_notification_delete(self): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True) layer_id = '%s.%s' % (layer.docker_image_id, layer.storage.uuid) @@ -425,7 +416,6 @@ class TestSecurityScanner(unittest.TestCase): # Ensure that there are no event queue items for the layer. 
self.assertIsNone(notification_queue.get()) - def test_notification_new_layers(self): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True) layer_id = '%s.%s' % (layer.docker_image_id, layer.storage.uuid) @@ -452,7 +442,7 @@ class TestSecurityScanner(unittest.TestCase): "Description": "Some service", "Link": "https://security-tracker.debian.org/tracker/CVE-2014-9471", "Severity": "Low", - "FixedIn": {'Version': "9.23-5"}, + "FixedIn": {"Version": "9.23-5"}, } security_scanner.set_vulns(layer_id, [vuln_info]) @@ -473,7 +463,6 @@ class TestSecurityScanner(unittest.TestCase): self.assertEquals('Low', item_body['event_data']['vulnerability']['priority']) self.assertTrue(item_body['event_data']['vulnerability']['has_fix']) - def test_notification_no_new_layers(self): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True) @@ -502,7 +491,6 @@ class TestSecurityScanner(unittest.TestCase): # Ensure that there are no event queue items for the layer. self.assertIsNone(notification_queue.get()) - def test_notification_no_new_layers_increased_severity(self): layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True) layer_id = '%s.%s' % (layer.docker_image_id, layer.storage.uuid) @@ -577,7 +565,6 @@ class TestSecurityScanner(unittest.TestCase): {'level': 0}) self.assertFalse(VulnerabilityFoundEvent().should_perform(event_data, notification)) - def test_select_images_to_scan(self): # Set all images to have a security index of a version to that of the config. expected_version = app.config['SECURITY_SCANNER_ENGINE_VERSION_TARGET'] @@ -591,7 +578,6 @@ class TestSecurityScanner(unittest.TestCase): self.assertIsNotNone(model.image.get_min_id_for_sec_scan(expected_version + 1)) self.assertTrue(len(model.image.get_images_eligible_for_scan(expected_version + 1)) > 0) - def test_notification_worker(self): layer1 = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True) layer2 = model.tag.get_tag_image(ADMIN_ACCESS_USER, COMPLEX_REPO, 'prod', include_storage=True) @@ -634,7 +620,7 @@ class TestSecurityScanner(unittest.TestCase): security_scanner.set_vulns(security_scanner.layer_id(layer2), [new_vuln_info]) layer_ids = [security_scanner.layer_id(layer1), security_scanner.layer_id(layer2)] - notification_data = security_scanner.add_notification([], layer_ids, {}, new_vuln_info) + notification_data = security_scanner.add_notification([], layer_ids, None, new_vuln_info) # Test with a known notification with pages. data = { @@ -642,13 +628,103 @@ class TestSecurityScanner(unittest.TestCase): } worker = SecurityNotificationWorker(None) - self.assertTrue(worker.perform_notification_work(data, layer_limit=1)) + self.assertTrue(worker.perform_notification_work(data, layer_limit=2)) # Make sure all pages were processed by ensuring we have two notifications. time.sleep(1) self.assertIsNotNone(notification_queue.get()) self.assertIsNotNone(notification_queue.get()) + def test_notification_worker_offset_pages_not_indexed(self): + # Try without indexes. + self.assert_notification_worker_offset_pages(indexed=False) + + def test_notification_worker_offset_pages_indexed(self): + # Try with indexes. 
+    self.assert_notification_worker_offset_pages(indexed=True)
+
+  def assert_notification_worker_offset_pages(self, indexed=False):
+    layer1 = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
+    layer2 = model.tag.get_tag_image(ADMIN_ACCESS_USER, COMPLEX_REPO, 'prod', include_storage=True)
+
+    # Add repo events for the layers.
+    simple_repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
+    complex_repo = model.repository.get_repository(ADMIN_ACCESS_USER, COMPLEX_REPO)
+
+    model.notification.create_repo_notification(simple_repo, 'vulnerability_found',
+                                                'quay_notification', {}, {'level': 100})
+    model.notification.create_repo_notification(complex_repo, 'vulnerability_found',
+                                                'quay_notification', {}, {'level': 100})
+
+    # Ensure that there are no event queue items for the layer.
+    self.assertIsNone(notification_queue.get())
+
+    with fake_security_scanner() as security_scanner:
+      # Test with an unknown notification.
+      worker = SecurityNotificationWorker(None)
+      self.assertFalse(worker.perform_notification_work({
+        'Name': 'unknownnotification'
+      }))
+
+      # Add some analyzed layers.
+      analyzer = LayerAnalyzer(app.config, self.api)
+      analyzer.analyze_recursively(layer1)
+      analyzer.analyze_recursively(layer2)
+
+      # Add a notification with pages of data.
+      new_vuln_info = {
+        "Name": "CVE-TEST",
+        "Namespace": "debian:8",
+        "Description": "Some service",
+        "Link": "https://security-tracker.debian.org/tracker/CVE-2014-9471",
+        "Severity": "Critical",
+        "FixedIn": {'Version': "9.23-5"},
+      }
+
+      security_scanner.set_vulns(security_scanner.layer_id(layer1), [new_vuln_info])
+      security_scanner.set_vulns(security_scanner.layer_id(layer2), [new_vuln_info])
+
+      # Define offsetting sets of layer IDs, to test cross-pagination support. In this test, we
+      # will only serve 2 layer IDs per page: the first page will serve both of the 'New' layer IDs,
+      # but since the first 2 'Old' layer IDs are "earlier" than the shared ID of
+      # `devtable/simple:latest`, that shared ID won't get served in the 'Old' list until the
+      # *second* page. The notification handling system should correctly not notify for this layer,
+      # even though it is marked 'New' on page 1 and marked 'Old' on page 2. Clair will serve these
+      # IDs sorted in the same manner.
+      idx_old_layer_ids = [{'LayerName': 'old1', 'Index': 1},
+                           {'LayerName': 'old2', 'Index': 2},
+                           {'LayerName': security_scanner.layer_id(layer1), 'Index': 3}]
+
+      idx_new_layer_ids = [{'LayerName': security_scanner.layer_id(layer1), 'Index': 3},
+                           {'LayerName': security_scanner.layer_id(layer2), 'Index': 4}]
+
+      old_layer_ids = [t['LayerName'] for t in idx_old_layer_ids]
+      new_layer_ids = [t['LayerName'] for t in idx_new_layer_ids]
+
+      if not indexed:
+        idx_old_layer_ids = None
+        idx_new_layer_ids = None
+
+      notification_data = security_scanner.add_notification(old_layer_ids, new_layer_ids, None,
+                                                            new_vuln_info, max_per_page=2,
+                                                            indexed_old_layer_ids=idx_old_layer_ids,
+                                                            indexed_new_layer_ids=idx_new_layer_ids)
+
+      # Test with a known notification with pages.
+      data = {
+        'Name': notification_data['Name'],
+      }
+
+      worker = SecurityNotificationWorker(None)
+      self.assertTrue(worker.perform_notification_work(data, layer_limit=2))
+
+      # Make sure all pages were processed by ensuring we have only one notification. If the second
+      # page was not processed, then the `Old` entry for layer1 will not be found, and we'd get two
+      # notifications.
+      time.sleep(1)
+      self.assertIsNotNone(notification_queue.get())
+      self.assertIsNone(notification_queue.get())
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/util/morecollections.py b/util/morecollections.py
index 6d05c4d25..c9f5ff0cb 100644
--- a/util/morecollections.py
+++ b/util/morecollections.py
@@ -10,3 +10,217 @@ class AttrDict(dict):
       if isinstance(value, AttrDict):
         copy[key] = cls.deep_copy(value)
     return copy
+
+
+class FastIndexList(object):
+  """ List which keeps track of the indices of its items in a fast manner, and allows for
+      quick removal of items.
+  """
+  def __init__(self):
+    self._list = []
+    self._index_map = {}
+    self._index_offset = 0
+    self._counter = 0
+
+  def add(self, item):
+    """ Adds an item to the index list. """
+    self._list.append(item)
+    self._index_map[item] = self._counter
+    self._counter = self._counter + 1
+
+  def values(self):
+    """ Returns an iterable stream of all the values in the list. """
+    return list(self._list)
+
+  def index(self, value):
+    """ Returns the index of the given item in the list or None if none. """
+    found = self._index_map.get(value, None)
+    if found is None:
+      return None
+
+    return found - self._index_offset
+
+  def pop_until(self, index_inclusive):
+    """ Pops off any items in the list until the given index, inclusive, and returns them. """
+    values = self._list[0:index_inclusive+1]
+    for value in values:
+      self._index_map.pop(value, None)
+
+    self._index_offset = self._index_offset + index_inclusive + 1
+    self._list = self._list[index_inclusive+1:]
+    return values
+
+
+class IndexedStreamingDiffTracker(object):
+  """ Helper class which tracks the difference between two streams of strings,
+      calling the `reporter` callback for strings when they are successfully verified
+      as being present in the first stream and not present in the second stream.
+      Unlike StreamingDiffTracker, this class expects each string value to have an
+      associated `index` value, which must be the same for equal values in both
+      streams and *must* be in order. This allows us to be a bit more efficient
+      in clearing up items that we know won't be present. The `index` is *not*
+      assumed to start at 0 or be contiguous, merely increasing.
+  """
+  def __init__(self, reporter, result_per_stream):
+    self._reporter = reporter
+    self._reports_per_stream = result_per_stream
+    self._new_stream_finished = False
+    self._old_stream_finished = False
+
+    self._new_stream = []
+    self._old_stream = []
+
+    self._new_stream_map = {}
+    self._old_stream_map = {}
+
+  def push_new(self, stream_tuples):
+    """ Pushes a list of values for the `New` stream.
+    """
+    stream_tuples_list = list(stream_tuples)
+    assert len(stream_tuples_list) <= self._reports_per_stream
+
+    if len(stream_tuples_list) < self._reports_per_stream:
+      self._new_stream_finished = True
+
+    for (item, index) in stream_tuples_list:
+      if self._new_stream:
+        assert index > self._new_stream[-1].index
+
+      self._new_stream_map[index] = item
+      self._new_stream.append(AttrDict(item=item, index=index))
+
+    self._process()
+
+  def push_old(self, stream_tuples):
+    """ Pushes a list of values for the `Old` stream.
+    """
+    if self._new_stream_finished and not self._new_stream:
+      # Nothing more to do.
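+      # The `New` stream has ended and every pending new entry has been resolved, so any
+      # remaining `Old` data cannot change the outcome.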
+      return
+
+    stream_tuples_list = list(stream_tuples)
+    assert len(stream_tuples_list) <= self._reports_per_stream
+
+    if len(stream_tuples_list) < self._reports_per_stream:
+      self._old_stream_finished = True
+
+    for (item, index) in stream_tuples_list:
+      if self._old_stream:
+        assert index > self._old_stream[-1].index
+
+      self._old_stream_map[index] = item
+      self._old_stream.append(AttrDict(item=item, index=index))
+
+    self._process()
+
+  def done(self):
+    self._old_stream_finished = True
+    self._process()
+
+  def _process(self):
+    # Process any new items that can be reported.
+    old_lower_bound = self._old_stream[0].index if self._old_stream else -1
+    for item_info in self._new_stream:
+      # If the new item's index <= the old_lower_bound, then we know
+      # we can check the old item map for it.
+      if item_info.index <= old_lower_bound or self._old_stream_finished:
+        if self._old_stream_map.get(item_info.index, None) is None:
+          self._reporter(item_info.item)
+
+        # Remove the item from the map.
+        self._new_stream_map.pop(item_info.index, None)
+
+    # Rebuild the new stream list (faster than just removing).
+    self._new_stream = [item_info for item_info in self._new_stream
+                        if self._new_stream_map.get(item_info.index)]
+
+    # Process any old items that can be removed.
+    new_lower_bound = self._new_stream[0].index if self._new_stream else -1
+    for item_info in list(self._old_stream):
+      # Any items with indexes below the new lower bound can be removed,
+      # as any comparison from the new stream was done above.
+      if item_info.index < new_lower_bound:
+        self._old_stream_map.pop(item_info.index, None)
+
+    # Rebuild the old stream list (faster than just removing).
+    self._old_stream = [item_info for item_info in self._old_stream
+                        if self._old_stream_map.get(item_info.index)]
+
+
+
+class StreamingDiffTracker(object):
+  """ Helper class which tracks the difference between two streams of strings, calling the
+      `reporter` callback for strings when they are successfully verified as being present in
+      the first stream and not present in the second stream. This class requires that the
+      streams of strings be consistently ordered *in some way common to both* (but the
+      strings themselves do not need to be sorted).
+  """
+  def __init__(self, reporter, result_per_stream):
+    self._reporter = reporter
+    self._reports_per_stream = result_per_stream
+    self._old_stream_finished = False
+
+    self._old_stream = FastIndexList()
+    self._new_stream = FastIndexList()
+
+  def done(self):
+    self._old_stream_finished = True
+    self.push_new([])
+
+  def push_new(self, stream_values):
+    """ Pushes a list of values for the `New` stream.
+    """
+
+    # Add all the new values to the list.
+    counter = 0
+    for value in stream_values:
+      self._new_stream.add(value)
+      counter = counter + 1
+
+    assert counter <= self._reports_per_stream
+
+    # Process them all to see if anything has changed.
+    for value in self._new_stream.values():
+      old_index = self._old_stream.index(value)
+      if old_index is not None:
+        # The item is present, so we cannot report it. However, since we've reached this point,
+        # all items *before* this item in the `Old` stream are no longer necessary, so we can
+        # throw them out, along with this item.
+        self._old_stream.pop_until(old_index)
+      else:
+        # If the old stream has completely finished, then we can report, knowing no more old
+        # information will be present.
+ if self._old_stream_finished: + self._reporter(value) + self._new_stream.pop_until(self._new_stream.index(value)) + + def push_old(self, stream_values): + """ Pushes a stream of values for the `Old` stream. + """ + + if self._old_stream_finished: + return + + value_list = list(stream_values) + assert len(value_list) <= self._reports_per_stream + + for value in value_list: + # If the value exists in the new stream somewhere, then we know that all items *before* + # that index in the new stream will not be in the old stream, so we can report them. We can + # also remove the matching `New` item, as it is clearly in both streams. + new_index = self._new_stream.index(value) + if new_index is not None: + # Report all items up to the current item. + for item in self._new_stream.pop_until(new_index - 1): + self._reporter(item) + + # Remove the current item from the new stream. + self._new_stream.pop_until(0) + else: + # This item may be seen later. Add it to the old stream set. + self._old_stream.add(value) + + # Check to see if the `Old` stream has finished. + if len(value_list) < self._reports_per_stream: + self._old_stream_finished = True + diff --git a/util/secscan/fake.py b/util/secscan/fake.py index a741868c5..0ed5c12f5 100644 --- a/util/secscan/fake.py +++ b/util/secscan/fake.py @@ -58,15 +58,22 @@ class FakeSecurityScanner(object): """ Returns whether a notification with the given ID is found in the scanner. """ return notification_id in self.notifications - def add_notification(self, old_layer_ids, new_layer_ids, old_vuln, new_vuln): + def add_notification(self, old_layer_ids, new_layer_ids, old_vuln, new_vuln, max_per_page=100, + indexed_old_layer_ids=None, indexed_new_layer_ids=None): """ Adds a new notification over the given sets of layer IDs and vulnerability information, returning the structural data of the notification created. """ notification_id = str(uuid.uuid4()) + if old_vuln is None: + old_vuln = dict(new_vuln) + self.notifications[notification_id] = dict(old_layer_ids=old_layer_ids, new_layer_ids=new_layer_ids, old_vuln=old_vuln, - new_vuln=new_vuln) + new_vuln=new_vuln, + max_per_page=max_per_page, + indexed_old_layer_ids=indexed_old_layer_ids, + indexed_new_layer_ids=indexed_new_layer_ids) return self._get_notification_data(notification_id, 0, 100) @@ -106,6 +113,8 @@ class FakeSecurityScanner(object): """ Returns the structural data for the notification with the given ID, paginated using the given page and limit. 
""" notification = self.notifications[notification_id] + limit = min(limit, notification['max_per_page']) + notification_data = { "Name": notification_id, "Created": "1456247389", @@ -127,6 +136,11 @@ class FakeSecurityScanner(object): 'LayersIntroducingVulnerability': old_layer_ids, } + if notification.get('indexed_old_layer_ids', None): + indexed_old_layer_ids = notification['indexed_old_layer_ids'][start_index:end_index] + notification_data['Old']['OrderedLayersIntroducingVulnerability'] = indexed_old_layer_ids + + if notification.get('new_vuln'): new_layer_ids = notification['new_layer_ids'] new_layer_ids = new_layer_ids[start_index:end_index] @@ -137,6 +151,11 @@ class FakeSecurityScanner(object): 'LayersIntroducingVulnerability': new_layer_ids, } + if notification.get('indexed_new_layer_ids', None): + indexed_new_layer_ids = notification['indexed_new_layer_ids'][start_index:end_index] + notification_data['New']['OrderedLayersIntroducingVulnerability'] = indexed_new_layer_ids + + if has_additional_page: notification_data['NextPage'] = str(page+1) diff --git a/util/secscan/notifier.py b/util/secscan/notifier.py index e1aa68731..336514ce9 100644 --- a/util/secscan/notifier.py +++ b/util/secscan/notifier.py @@ -2,6 +2,7 @@ import logging import sys from collections import defaultdict +from enum import Enum from app import secscan_api from data.model.tag import filter_tags_have_repository_event, get_matching_tags @@ -10,105 +11,169 @@ from data.database import (Image, ImageStorage, ExternalNotificationEvent, Repos from endpoints.notificationhelper import notification_batch from util.secscan import PRIORITY_LEVELS from util.secscan.api import APIRequestFailure -from util.morecollections import AttrDict +from util.morecollections import AttrDict, StreamingDiffTracker, IndexedStreamingDiffTracker + logger = logging.getLogger(__name__) +class ProcessNotificationPageResult(Enum): + FINISHED_PAGE = 'Finished Page' + FINISHED_PROCESSING = 'Finished Processing' + FAILED = 'Failed' -def process_notification_data(notification_data): - """ Processes the given notification data to spawn vulnerability notifications as necessary. - Returns whether the processing succeeded. + +class SecurityNotificationHandler(object): + """ Class to process paginated notifications from the security scanner and issue + Quay vulnerability_found notifications for all necessary tags. Callers should + initialize, call process_notification_page_data for each page until it returns + FINISHED_PROCESSING or FAILED and, if succeeded, then call send_notifications + to send out the notifications queued. """ - if not 'New' in notification_data: - # Nothing to do. - return True + def __init__(self, results_per_stream): + self.tag_map = defaultdict(set) + self.repository_map = {} + self.check_map = {} - new_data = notification_data['New'] - old_data = notification_data.get('Old', {}) + self.stream_tracker = None + self.results_per_stream = results_per_stream + self.reporting_failed = False + self.vulnerability_info = None - new_vuln = new_data['Vulnerability'] - old_vuln = old_data.get('Vulnerability', {}) + self.event = ExternalNotificationEvent.get(name='vulnerability_found') - new_layer_ids = set(new_data.get('LayersIntroducingVulnerability', [])) - old_layer_ids = set(old_data.get('LayersIntroducingVulnerability', [])) + def send_notifications(self): + """ Sends all queued up notifications. 
""" + if self.vulnerability_info is None: + return - new_severity = PRIORITY_LEVELS.get(new_vuln.get('Severity', 'Unknown'), {'index': sys.maxint}) - old_severity = PRIORITY_LEVELS.get(old_vuln.get('Severity', 'Unknown'), {'index': sys.maxint}) + new_vuln = self.vulnerability_info + new_severity = PRIORITY_LEVELS.get(new_vuln.get('Severity', 'Unknown'), {'index': sys.maxint}) - # By default we only notify the new layers that are affected by the vulnerability. If, however, - # the severity of the vulnerability has increased, we need to notify *all* layers, as we might - # need to send new notifications for older layers. - notify_layers = new_layer_ids - old_layer_ids - if new_severity['index'] < old_severity['index']: - notify_layers = new_layer_ids | old_layer_ids + # For each of the tags found, issue a notification. + with notification_batch() as spawn_notification: + for repository_id in self.tag_map: + tags = self.tag_map[repository_id] + event_data = { + 'tags': list(tags), + 'vulnerability': { + 'id': new_vuln['Name'], + 'description': new_vuln.get('Description', None), + 'link': new_vuln.get('Link', None), + 'priority': new_severity['title'], + 'has_fix': 'FixedIn' in new_vuln, + }, + } - if not notify_layers: - # Nothing more to do. - return True + # TODO(jzelinskie): remove when more endpoints have been converted to using interfaces + repository = AttrDict({ + 'namespace_name': self.repository_map[repository_id].namespace_user.username, + 'name': self.repository_map[repository_id].name, + }) - # Lookup the external event for when we have vulnerabilities. - event = ExternalNotificationEvent.get(name='vulnerability_found') + spawn_notification(repository, 'vulnerability_found', event_data) - # For each layer, retrieving the matching tags and join with repository to determine which - # require new notifications. - tag_map = defaultdict(set) - repository_map = {} - cve_id = new_vuln['Name'] + def process_notification_page_data(self, notification_page_data): + """ Processes the given notification page data to spawn vulnerability notifications as + necessary. Returns the status of the processing. + """ + if not 'New' in notification_page_data: + return self._done() - # Find all tags that contain the layer(s) introducing the vulnerability, - # in repositories that have the event setup. - for layer_id in notify_layers: + new_data = notification_page_data['New'] + old_data = notification_page_data.get('Old', {}) + + new_vuln = new_data['Vulnerability'] + old_vuln = old_data.get('Vulnerability', {}) + + self.vulnerability_info = new_vuln + + new_layer_ids = new_data.get('LayersIntroducingVulnerability', []) + old_layer_ids = old_data.get('LayersIntroducingVulnerability', []) + + new_severity = PRIORITY_LEVELS.get(new_vuln.get('Severity', 'Unknown'), {'index': sys.maxint}) + old_severity = PRIORITY_LEVELS.get(old_vuln.get('Severity', 'Unknown'), {'index': sys.maxint}) + + # Check if the severity of the vulnerability has increased. If so, then we report this + # vulnerability for *all* layers, rather than a difference, as it is important for everyone. + if new_severity['index'] < old_severity['index']: + # The vulnerability has had its severity increased. Report for *all* layers. + all_layer_ids = set(new_layer_ids) | set(old_layer_ids) + for layer_id in all_layer_ids: + self._report(layer_id) + + if 'NextPage' not in notification_page_data: + return self._done() + else: + return ProcessNotificationPageResult.FINISHED_PAGE + + # Otherwise, only send the notification to new layers. 
To find only the new layers, we + # need to do a streaming diff vs the old layer IDs stream. + + # Check for ordered data. If found, we use the indexed tracker, which is faster and + # more memory efficient. + is_indexed = False + if 'OrderedLayersIntroducingVulnerability' in new_data: + def tuplize(stream): + return [(entry['LayerName'], entry['Index']) for entry in stream] + + new_layer_ids = tuplize(new_data.get('OrderedLayersIntroducingVulnerability', [])) + old_layer_ids = tuplize(old_data.get('OrderedLayersIntroducingVulnerability', [])) + is_indexed = True + + # If this is the first call, initialize the tracker. + if self.stream_tracker is None: + self.stream_tracker = (IndexedStreamingDiffTracker(self._report, self.results_per_stream) + if is_indexed + else StreamingDiffTracker(self._report, self.results_per_stream)) + + # Call to add the old and new layer ID streams to the tracker. The tracker itself will + # call _report whenever it has determined a new layer has been found. + self.stream_tracker.push_new(new_layer_ids) + self.stream_tracker.push_old(old_layer_ids) + + # If the reporting failed at any point, nothing more we can do. + if self.reporting_failed: + return ProcessNotificationPageResult.FAILED + + # Check to see if there are any additional pages to process. + if 'NextPage' not in notification_page_data: + return self._done() + else: + return ProcessNotificationPageResult.FINISHED_PAGE + + def _done(self): + if self.stream_tracker is not None: + self.stream_tracker.done() + + if self.reporting_failed: + return ProcessNotificationPageResult.FAILED + + return ProcessNotificationPageResult.FINISHED_PROCESSING + + def _report(self, new_layer_id): # Split the layer ID into its Docker Image ID and storage ID. - (docker_image_id, storage_uuid) = layer_id.split('.', 2) + (docker_image_id, storage_uuid) = new_layer_id.split('.', 2) # Find the matching tags. matching = get_matching_tags(docker_image_id, storage_uuid, RepositoryTag, Repository, Image, ImageStorage) - tags = list(filter_tags_have_repository_event(matching, event)) + tags = list(filter_tags_have_repository_event(matching, self.event)) - check_map = {} + cve_id = self.vulnerability_info['Name'] for tag in tags: - # Verify that the tag's root image has the vulnerability. + # Verify that the tag's *top layer* has the vulnerability. tag_layer_id = '%s.%s' % (tag.image.docker_image_id, tag.image.storage.uuid) - logger.debug('Checking if layer %s is vulnerable to %s', tag_layer_id, cve_id) - - if not tag_layer_id in check_map: + if not tag_layer_id in self.check_map: + logger.debug('Checking if layer %s is vulnerable to %s', tag_layer_id, cve_id) try: - is_vulerable = secscan_api.check_layer_vulnerable(tag_layer_id, cve_id) + self.check_map[tag_layer_id] = secscan_api.check_layer_vulnerable(tag_layer_id, cve_id) except APIRequestFailure: - return False - - check_map[tag_layer_id] = is_vulerable + self.reporting_failed = True + return logger.debug('Result of layer %s is vulnerable to %s check: %s', tag_layer_id, cve_id, - check_map[tag_layer_id]) - - if check_map[tag_layer_id]: + self.check_map[tag_layer_id]) + if self.check_map[tag_layer_id]: # Add the vulnerable tag to the list. - tag_map[tag.repository_id].add(tag.name) - repository_map[tag.repository_id] = tag.repository - - # For each of the tags found, issue a notification. 
- with notification_batch() as spawn_notification: - for repository_id in tag_map: - tags = tag_map[repository_id] - event_data = { - 'tags': list(tags), - 'vulnerability': { - 'id': cve_id, - 'description': new_vuln.get('Description', None), - 'link': new_vuln.get('Link', None), - 'priority': new_severity['title'], - 'has_fix': 'FixedIn' in new_vuln, - }, - } - - # TODO(jzelinskie): remove when more endpoints have been converted to using interfaces - repository = AttrDict({ - 'namespace_name': repository_map[repository_id].namespace_user.username, - 'name': repository_map[repository_id].name, - }) - spawn_notification(repository, 'vulnerability_found', event_data) - - return True - + self.tag_map[tag.repository_id].add(tag.name) + self.repository_map[tag.repository_id] = tag.repository diff --git a/workers/security_notification_worker.py b/workers/security_notification_worker.py index 7717048df..11e15f295 100644 --- a/workers/security_notification_worker.py +++ b/workers/security_notification_worker.py @@ -6,7 +6,7 @@ import features from app import secscan_notification_queue, secscan_api from workers.queueworker import QueueWorker, JobException -from util.secscan.notifier import process_notification_data +from util.secscan.notifier import SecurityNotificationHandler, ProcessNotificationPageResult logger = logging.getLogger(__name__) @@ -28,11 +28,15 @@ class SecurityNotificationWorker(QueueWorker): notification_name = data['Name'] current_page = data.get('page', None) + handler = SecurityNotificationHandler(layer_limit) while True: + # Retrieve the current page of notification data from the security scanner. (response_data, should_retry) = secscan_api.get_notification(notification_name, layer_limit=layer_limit, page=current_page) + + # If no response, something went wrong. if response_data is None: if should_retry: raise JobException() @@ -44,25 +48,34 @@ class SecurityNotificationWorker(QueueWorker): # Return to mark the job as "complete", as we'll never be able to finish it. return False + # Extend processing on the queue item so it doesn't expire while we're working. self.extend_processing(_PROCESSING_SECONDS, json.dumps(data)) - notification_data = response_data['Notification'] - if not process_notification_data(notification_data): - raise JobException() - # Check for a next page of results. If none, we're done. - if 'NextPage' not in notification_data: - # Mark the notification as read and processed. + # Process the notification data. + notification_data = response_data['Notification'] + result = handler.process_notification_page_data(notification_data) + + # Possible states after processing: failed to process, finished processing entirely + # or finished processing the page. + if result == ProcessNotificationPageResult.FAILED: + # Something went wrong. + raise JobException + + if result == ProcessNotificationPageResult.FINISHED_PROCESSING: + # Mark the notification as read. if not secscan_api.mark_notification_read(notification_name): # Return to mark the job as "complete", as we'll never be able to finish it. logger.error('Failed to mark notification %s as read', notification_name) return False + # Send the generated Quay notifications. + handler.send_notifications() return True - # Otherwise, save the next page token into the queue item (so we can pick up from here if - # something goes wrong in the next loop iteration), and continue. 
- current_page = notification_data['NextPage'] - data['page'] = current_page + if result == ProcessNotificationPageResult.FINISHED_PAGE: + # Continue onto the next page. + current_page = notification_data['NextPage'] + continue if __name__ == '__main__':
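
For reference, a minimal usage sketch of the two new streaming diff trackers, driven the same way
the notification handler drives them: alternating paginated pushes of the 'New' and 'Old' layer-ID
streams, followed by done(). The inputs and expected results below are taken from the unit tests in
test/test_morecollections.py, and the import assumes util/morecollections.py as added by this change.

from util.morecollections import IndexedStreamingDiffTracker, StreamingDiffTracker

# Indexed variant: each layer name carries a monotonically increasing index, which lets the
# tracker discard entries that can no longer match as soon as both streams have moved past them.
added = []
tracker = IndexedStreamingDiffTracker(added.append, 2)  # 2 results per stream page
tracker.push_new([('a', 10), ('b', 11)])                # page 1 of the 'New' stream
tracker.push_old([('z', 1), ('y', 2)])                  # page 1 of the 'Old' stream
tracker.push_new([('c', 12)])                           # short page: the 'New' stream is finished
tracker.push_old([('a', 10)])                           # 'a' appears in both streams, so it is not reported
tracker.done()
assert added == ['b', 'c']

# Unindexed variant: no indices required, only an ordering that is consistent across both streams.
added = []
tracker = StreamingDiffTracker(added.append, 2)
tracker.push_new(['c', 'd'])
tracker.push_old(['a', 'b'])
tracker.push_new([])
tracker.push_old(['c'])
tracker.done()
assert added == ['d']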