Merge pull request #2607 from coreos-inc/faster-security-notify

Batch the tag lookups in the security notification worker in an attempt to significantly reduce load
josephschorr 2017-05-03 13:49:13 -04:00 committed by GitHub
commit 19f67bfa1b
6 changed files with 277 additions and 67 deletions

View file

@@ -6,6 +6,29 @@ from data.database import (Repository, User, Team, TeamMember, RepositoryPermiss
                            Namespace, Visibility, ImageStorage, Image, RepositoryKind,
                            db_for_update)

+def reduce_as_tree(queries_to_reduce):
+  """ This method will split a list of queries into halves recursively until we reach individual
+      queries, at which point it will start unioning the queries, or the already unioned subqueries.
+      This works around a bug in peewee SQL generation where reducing linearly generates a chain
+      of queries that will exceed the recursion depth limit when it has around 80 queries.
+  """
+  mid = len(queries_to_reduce) / 2
+  left = queries_to_reduce[:mid]
+  right = queries_to_reduce[mid:]
+
+  to_reduce_right = right[0]
+  if len(right) > 1:
+    to_reduce_right = reduce_as_tree(right)
+
+  if len(left) > 1:
+    to_reduce_left = reduce_as_tree(left)
+  elif len(left) == 1:
+    to_reduce_left = left[0]
+  else:
+    return to_reduce_right
+
+  return to_reduce_left.union_all(to_reduce_right)
+
 def get_existing_repository(namespace_name, repository_name, for_update=False, kind_filter=None):
   query = (Repository
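To make the docstring's claim concrete: a linear reduce nests one UNION ALL inside another per query, so SQL generation recurses once per query, while the tree split keeps the nesting logarithmic. A minimal standalone sketch, plain Python with no peewee, where nested tuples stand in for union_all() nodes:

def depth(node):
  # Leaves are individual queries; tuples are union_all() nodes.
  if not isinstance(node, tuple):
    return 1
  return 1 + max(depth(node[0]), depth(node[1]))

def linear_union(queries):
  acc = queries[0]
  for q in queries[1:]:
    acc = (acc, q)  # one extra level of nesting per query
  return acc

def tree_union(queries):
  if len(queries) == 1:
    return queries[0]
  mid = len(queries) // 2  # '//' so the sketch also runs under Python 3
  return (tree_union(queries[:mid]), tree_union(queries[mid:]))

queries = list(range(80))
print(depth(linear_union(queries)))  # 80: linear nesting, the case that tripped peewee
print(depth(tree_union(queries)))    # 8: logarithmic nesting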

View file

@@ -17,23 +17,18 @@ from util.canonicaljson import canonicalize
 logger = logging.getLogger(__name__)

-def get_repository_image_and_deriving(docker_image_id, storage_uuid):
-  """ Returns all matching images with the given docker image ID and storage uuid, along with any
-      images which have the image ID as parents.
+def get_image_with_storage(docker_image_id, storage_uuid):
+  """ Returns the image with the given docker image ID and storage uuid or None if none.
   """
   try:
-    image_found = (Image
-                   .select()
-                   .join(ImageStorage)
-                   .where(Image.docker_image_id == docker_image_id,
-                          ImageStorage.uuid == storage_uuid)
-                   .get())
+    return (Image
+            .select()
+            .join(ImageStorage)
+            .where(Image.docker_image_id == docker_image_id,
+                   ImageStorage.uuid == storage_uuid)
+            .get())
   except Image.DoesNotExist:
-    return Image.select().where(Image.id < 0)  # Empty query
-
-  ancestors_pattern = '%s%s/%%' % (image_found.ancestors, image_found.id)
-  return Image.select().where((Image.ancestors ** ancestors_pattern) |
-                              (Image.id == image_found.id))
+    return None

 def get_parent_images_with_placements(namespace_name, repository_name, image_obj):
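The descendant lookup removed here (and re-created in tag.py below) relies on Quay's ancestry encoding: Image.ancestors stores the ancestor IDs as a '/'-delimited path, and peewee's ** operator performs a case-insensitive LIKE. A standalone sketch of how the '%s%s/%%' pattern selects descendants, with hypothetical IDs and fnmatch standing in for SQL LIKE:

import fnmatch

# Hypothetical parent image: ancestors path '/1/2/', own id 3.
parent_ancestors, parent_id = '/1/2/', 3
like_pattern = '%s%s/%%' % (parent_ancestors, parent_id)  # '/1/2/3/%'

rows = {4: '/1/2/3/', 5: '/1/2/3/4/', 6: '/1/2/'}  # id -> ancestors column

glob = like_pattern.replace('%', '*')  # SQL LIKE '%' wildcard -> fnmatch '*'
descendants = sorted(i for i, a in rows.items() if fnmatch.fnmatch(a, glob))
print(descendants)  # [4, 5]; image 6 is a sibling of image 3, not a descendant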

View file

@@ -68,7 +68,8 @@ def _orphaned_storage_query(candidate_ids):
                            .from_(storage_subq))

   # Build the set of storages that are missing. These storages are orphaned.
-  nonorphaned_storage_ids = {storage.id for storage in _reduce_as_tree(nonorphaned_queries)}
+  nonorphaned_storage_ids = {storage.id for storage
+                             in _basequery.reduce_as_tree(nonorphaned_queries)}

   return list(candidate_ids - nonorphaned_storage_ids)

@@ -275,31 +276,7 @@ def lookup_repo_storages_by_content_checksum(repo, checksums):
              .select(SQL('*'))
              .from_(candidate_subq))

-  return _reduce_as_tree(queries)
-
-def _reduce_as_tree(queries_to_reduce):
-  """ This method will split a list of queries into halves recursively until we reach individual
-      queries, at which point it will start unioning the queries, or the already unioned subqueries.
-      This works around a bug in peewee SQL generation where reducing linearly generates a chain
-      of queries that will exceed the recursion depth limit when it has around 80 queries.
-  """
-  mid = len(queries_to_reduce) / 2
-  left = queries_to_reduce[:mid]
-  right = queries_to_reduce[mid:]
-
-  to_reduce_right = right[0]
-  if len(right) > 1:
-    to_reduce_right = _reduce_as_tree(right)
-
-  if len(left) > 1:
-    to_reduce_left = _reduce_as_tree(left)
-  elif len(left) == 1:
-    to_reduce_left = left[0]
-  else:
-    return to_reduce_right
-
-  return to_reduce_left.union_all(to_reduce_right)
+  return _basequery.reduce_as_tree(queries)

 def set_image_storage_metadata(docker_image_id, namespace_name, repository_name, image_size,

View file

@@ -53,17 +53,128 @@ def _tag_alive(query, now_ts=None):
                  (RepositoryTag.lifetime_end_ts > now_ts))

+def filter_has_repository_event(query, event):
+  """ Filters the query by ensuring the repositories returned have the given event. """
+  return (query
+          .join(Repository)
+          .join(RepositoryNotification)
+          .where(RepositoryNotification.event == event))
+
+def filter_tags_have_repository_event(query, event):
+  """ Filters the query by ensuring the repository tags live in a repository that has the given
+      event. Also returns the image storage for the tag's image and orders the results by
+      lifetime_start_ts.
+  """
+  query = filter_has_repository_event(query, event)
+  query = query.switch(Image).join(ImageStorage)
+  query = query.switch(RepositoryTag).order_by(RepositoryTag.lifetime_start_ts.desc())
+  return query
+
+_MAX_SUB_QUERIES = 100
+_MAX_IMAGE_LOOKUP_COUNT = 500
+
+def get_matching_tags_for_images(image_pairs, filter_images=None, filter_tags=None,
+                                 selections=None):
+  """ Returns all tags that contain the images with the given docker_image_id and storage_uuid,
+      as specified as an iterable of pairs. """
+  if not image_pairs:
+    return []
+
+  image_pairs_set = set(image_pairs)
+
+  # Find all possible matching image+storages.
+  images = []
+
+  while image_pairs:
+    image_pairs_slice = image_pairs[:_MAX_IMAGE_LOOKUP_COUNT]
+
+    ids = [pair[0] for pair in image_pairs_slice]
+    uuids = [pair[1] for pair in image_pairs_slice]
+
+    images_query = (Image
+                    .select(Image.id, Image.docker_image_id, Image.ancestors, ImageStorage.uuid)
+                    .join(ImageStorage)
+                    .where(Image.docker_image_id << ids, ImageStorage.uuid << uuids)
+                    .switch(Image))
+
+    if filter_images is not None:
+      images_query = filter_images(images_query)
+
+    images.extend(list(images_query))
+
+    image_pairs = image_pairs[_MAX_IMAGE_LOOKUP_COUNT:]
+
+  # Filter down to those images actually in the pairs set and build the set of queries to run.
+  individual_image_queries = []
+
+  for img in images:
+    # Make sure the image found is in the set of those requested, and that we haven't already
+    # processed it. We need this check because the query above checks for images with matching
+    # IDs OR storage UUIDs, rather than the expected ID+UUID pair. We do this for efficiency
+    # reasons, and it is highly unlikely we'll find an image with a mismatch, but we need this
+    # check to be absolutely sure.
+    pair = (img.docker_image_id, img.storage.uuid)
+    if pair not in image_pairs_set:
+      continue
+
+    # Remove the pair so we don't try it again.
+    image_pairs_set.remove(pair)
+
+    ancestors_str = '%s%s/%%' % (img.ancestors, img.id)
+    query = (Image
+             .select(Image.id)
+             .where((Image.id == img.id) | (Image.ancestors ** ancestors_str)))
+    individual_image_queries.append(query)
+
+  if not individual_image_queries:
+    return []
+
+  # Shard based on the max subquery count. This is used to prevent going over the DB's max query
+  # size, as well as to prevent the DB from locking up on a massive query.
+  sharded_queries = []
+  while individual_image_queries:
+    shard = individual_image_queries[:_MAX_SUB_QUERIES]
+    sharded_queries.append(_basequery.reduce_as_tree(shard))
+    individual_image_queries = individual_image_queries[_MAX_SUB_QUERIES:]
+
+  # Collect IDs of the tags found for each query.
+  tags = {}
+  for query in sharded_queries:
+    tag_query = (_tag_alive(RepositoryTag
+                            .select(*(selections or []))
+                            .distinct()
+                            .join(Image)
+                            .where(RepositoryTag.hidden == False)
+                            .where(Image.id << query)
+                            .switch(RepositoryTag)))
+
+    if filter_tags is not None:
+      tag_query = filter_tags(tag_query)
+
+    for tag in tag_query:
+      tags[tag.id] = tag
+
+  return tags.values()
+
 def get_matching_tags(docker_image_id, storage_uuid, *args):
   """ Returns a query pointing to all tags that contain the image with the
       given docker_image_id and storage_uuid. """
-  image_query = image.get_repository_image_and_deriving(docker_image_id, storage_uuid)
+  image_row = image.get_image_with_storage(docker_image_id, storage_uuid)
+  if image_row is None:
+    return RepositoryTag.select().where(RepositoryTag.id < 0)  # Empty query.
+
+  ancestors_str = '%s%s/%%' % (image_row.ancestors, image_row.id)
   return _tag_alive(RepositoryTag
                     .select(*args)
                     .distinct()
                     .join(Image)
                     .join(ImageStorage)
-                    .where(Image.id << image_query, RepositoryTag.hidden == False))
+                    .where(RepositoryTag.hidden == False)
+                    .where((Image.id == image_row.id) |
+                           (Image.ancestors ** ancestors_str)))

 def get_tags_for_image(image_id, *args):
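Two details of get_matching_tags_for_images above are worth unpacking: the pairs are looked up in fixed-size slices to keep the IN clauses (peewee's << operator) bounded, and because the ids and uuids go into two independent IN clauses, each returned row must be re-checked against the requested pairs. A standalone sketch of both behaviors, with hypothetical ids and a plain list standing in for the database:

_MAX_LOOKUP = 2  # tiny limit so the sketch exercises several batches

requested = [('id1', 'uuid1'), ('id2', 'uuid2'), ('id3', 'uuid3')]
table = [('id1', 'uuid1'), ('id1', 'uuid2'),  # cross-match: hits both IN clauses
         ('id2', 'uuid2'), ('id3', 'uuid3'), ('id9', 'uuid9')]

requested_set = set(requested)
found = []
pairs = list(requested)
while pairs:
  batch = pairs[:_MAX_LOOKUP]
  ids = set(p[0] for p in batch)
  uuids = set(p[1] for p in batch)
  # Stand-in for: Image.docker_image_id << ids, ImageStorage.uuid << uuids
  candidates = [row for row in table if row[0] in ids and row[1] in uuids]
  # The re-check from the loop above: keep only exact requested pairs.
  found.extend(row for row in candidates if row in requested_set)
  pairs = pairs[_MAX_LOOKUP:]

print(found)  # [('id1', 'uuid1'), ('id2', 'uuid2'), ('id3', 'uuid3')]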
@@ -74,15 +185,6 @@ def get_tags_for_image(image_id, *args):
                     RepositoryTag.hidden == False))

-def filter_tags_have_repository_event(query, event):
-  return (query
-          .switch(RepositoryTag)
-          .join(Repository)
-          .join(RepositoryNotification)
-          .where(RepositoryNotification.event == event)
-          .order_by(RepositoryTag.lifetime_start_ts.desc()))
-
 def get_tag_manifest_digests(tags):
   """ Returns a map from tag ID to its associated manifest digest, if any. """
   if not tags:
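Taken together, callers can now resolve the tags for many layers with one batched call instead of one get_matching_tags() per layer. A hypothetical usage sketch (the ids and uuids are placeholders):

pairs = [('abc123', 'uuid-1'), ('def456', 'uuid-2')]  # hypothetical values
tags = get_matching_tags_for_images(pairs, selections=[RepositoryTag, Image, ImageStorage])
for tag in tags:
  print(tag.name, tag.repository_id)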

View file

@@ -1,8 +1,109 @@
+import pytest
+
+from mock import patch
+
+from data.database import Image, RepositoryTag, ImageStorage, Repository
 from data.model.repository import create_repository
-from data.model.tag import list_active_repo_tags, create_or_update_tag, delete_tag
+from data.model.tag import (list_active_repo_tags, create_or_update_tag, delete_tag,
+                            get_matching_tags, _tag_alive, get_matching_tags_for_images)
 from data.model.image import find_create_or_link_image
 from test.fixtures import *

+def _get_expected_tags(image):
+  expected_query = (RepositoryTag
+                    .select()
+                    .join(Image)
+                    .where(RepositoryTag.hidden == False)
+                    .where((Image.id == image.id) | (Image.ancestors ** ('%%/%s/%%' % image.id))))
+  return set([tag.id for tag in _tag_alive(expected_query)])
+
+@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
+  (1, 1),
+  (10, 10),
+  (100, 500),
+])
+def test_get_matching_tags(max_subqueries, max_image_lookup_count, initialized_db):
+  with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
+    with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
+      # Test for every image in the test database.
+      for image in Image.select(Image, ImageStorage).join(ImageStorage):
+        matching_query = get_matching_tags(image.docker_image_id, image.storage.uuid)
+        matching_tags = set([tag.id for tag in matching_query])
+        expected_tags = _get_expected_tags(image)
+        assert matching_tags == expected_tags, "mismatch for image %s" % image.id
+
+@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
+  (1, 1),
+  (10, 10),
+  (100, 500),
+])
+def test_get_matching_tag_ids_for_images(max_subqueries, max_image_lookup_count, initialized_db):
+  with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
+    with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
+      # Try for various sets of the first N images.
+      for count in [5, 10, 15]:
+        pairs = []
+        expected_tags_ids = set()
+        for image in Image.select(Image, ImageStorage).join(ImageStorage):
+          if len(pairs) >= count:
+            break
+
+          pairs.append((image.docker_image_id, image.storage.uuid))
+          expected_tags_ids.update(_get_expected_tags(image))
+
+        matching_tags_ids = set([tag.id for tag in get_matching_tags_for_images(pairs)])
+        assert matching_tags_ids == expected_tags_ids
+
+@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
+  (1, 1),
+  (10, 10),
+  (100, 500),
+])
+def test_get_matching_tag_ids_for_all_images(max_subqueries, max_image_lookup_count, initialized_db):
+  with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
+    with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
+      pairs = []
+      for image in Image.select(Image, ImageStorage).join(ImageStorage):
+        pairs.append((image.docker_image_id, image.storage.uuid))
+
+      expected_tags_ids = set([tag.id for tag in _tag_alive(RepositoryTag.select())])
+      matching_tags_ids = set([tag.id for tag in get_matching_tags_for_images(pairs)])
+
+      # Ensure every alive tag was found.
+      assert matching_tags_ids == expected_tags_ids
+
+def test_get_matching_tag_ids_images_filtered(initialized_db):
+  def filter_query(query):
+    return query.join(Repository).where(Repository.name == 'simple')
+
+  filtered_images = filter_query(Image
+                                 .select(Image, ImageStorage)
+                                 .join(RepositoryTag)
+                                 .switch(Image)
+                                 .join(ImageStorage)
+                                 .switch(Image))
+
+  expected_tags_query = _tag_alive(filter_query(RepositoryTag
+                                                .select()))
+
+  pairs = []
+  for image in filtered_images:
+    pairs.append((image.docker_image_id, image.storage.uuid))
+
+  matching_tags = get_matching_tags_for_images(pairs, filter_images=filter_query,
+                                               filter_tags=filter_query)
+
+  expected_tag_ids = set([tag.id for tag in expected_tags_query])
+  matching_tags_ids = set([tag.id for tag in matching_tags])
+
+  # Ensure every alive tag was found.
+  assert matching_tags_ids == expected_tag_ids
+
 def assert_tags(repository, *args):
   tags = list(list_active_repo_tags(repository))
   assert len(tags) == len(args)
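The parametrization works because mock.patch can swap a module-level constant for the duration of a with-block, forcing the sharding and batching code paths even on a small fixture database. A minimal standalone sketch of the technique, with a toy class standing in for the data.model.tag module:

from mock import patch  # 'unittest.mock' on Python 3

class FakeTagModule(object):
  _MAX_SUB_QUERIES = 100

with patch.object(FakeTagModule, '_MAX_SUB_QUERIES', 1):
  assert FakeTagModule._MAX_SUB_QUERIES == 1  # tiny limit while patched
assert FakeTagModule._MAX_SUB_QUERIES == 100  # restored on exit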

View file

@@ -5,7 +5,9 @@ from collections import defaultdict
 from enum import Enum

 from app import secscan_api
-from data.model.tag import filter_tags_have_repository_event, get_matching_tags
+from data.model.tag import (filter_has_repository_event, filter_tags_have_repository_event,
+                            get_matching_tags_for_images)
 from data.database import (Image, ImageStorage, ExternalNotificationEvent, Repository,
                            RepositoryTag)
 from endpoints.notificationhelper import notification_batch

@@ -33,10 +35,10 @@ class SecurityNotificationHandler(object):
     self.tag_map = defaultdict(set)
     self.repository_map = {}
     self.check_map = {}
+    self.layer_ids = set()

     self.stream_tracker = None
     self.results_per_stream = results_per_stream
-    self.reporting_failed = False
     self.vulnerability_info = None

     self.event = ExternalNotificationEvent.get(name='vulnerability_found')

@@ -133,10 +135,6 @@ class SecurityNotificationHandler(object):
       self.stream_tracker.push_new(new_layer_ids)
       self.stream_tracker.push_old(old_layer_ids)

-    # If the reporting failed at any point, nothing more we can do.
-    if self.reporting_failed:
-      return ProcessNotificationPageResult.FAILED
-
     # Check to see if there are any additional pages to process.
     if 'NextPage' not in notification_page_data:
       return self._done()

@@ -145,21 +143,33 @@ class SecurityNotificationHandler(object):
   def _done(self):
     if self.stream_tracker is not None:
+      # Mark the tracker as done, so that it finishes reporting any outstanding layers.
       self.stream_tracker.done()

-    if self.reporting_failed:
-      return ProcessNotificationPageResult.FAILED
+    # Process all the layers.
+    if self.vulnerability_info is not None:
+      if not self._process_layers():
+        return ProcessNotificationPageResult.FAILED

     return ProcessNotificationPageResult.FINISHED_PROCESSING

   def _report(self, new_layer_id):
-    # Split the layer ID into its Docker Image ID and storage ID.
-    (docker_image_id, storage_uuid) = new_layer_id.split('.', 2)
+    self.layer_ids.add(new_layer_id)
+
+  def _process_layers(self):
+    # Builds the pairs of layer ID and storage uuid.
+    pairs = [tuple(layer_id.split('.', 2)) for layer_id in self.layer_ids]
+
+    def filter_notifying_repos(query):
+      return filter_has_repository_event(query, self.event)
+
+    def filter_and_order(query):
+      return filter_tags_have_repository_event(query, self.event)

     # Find the matching tags.
-    matching = get_matching_tags(docker_image_id, storage_uuid, RepositoryTag, Repository,
-                                 Image, ImageStorage)
-    tags = list(filter_tags_have_repository_event(matching, self.event))
+    tags = get_matching_tags_for_images(pairs, selections=[RepositoryTag, Image, ImageStorage],
+                                        filter_images=filter_notifying_repos,
+                                        filter_tags=filter_and_order)

     cve_id = self.vulnerability_info['Name']
     for tag in tags:

@@ -170,12 +180,14 @@ class SecurityNotificationHandler(object):
       try:
         self.check_map[tag_layer_id] = secscan_api.check_layer_vulnerable(tag_layer_id, cve_id)
       except APIRequestFailure:
-        self.reporting_failed = True
-        return
+        return False

       logger.debug('Result of layer %s is vulnerable to %s check: %s', tag_layer_id, cve_id,
                    self.check_map[tag_layer_id])

       if self.check_map[tag_layer_id]:
         # Add the vulnerable tag to the list.
         self.tag_map[tag.repository_id].add(tag.name)
         self.repository_map[tag.repository_id] = tag.repository
+
+    return True
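The shape of the refactor, isolated: _report() used to run one tag lookup per reported layer, while now it only accumulates layer ids, and _done() resolves them in a single batched call. A minimal standalone sketch with a hypothetical layer id, following the convention visible above that a scanner layer id concatenates the Docker image ID and the storage UUID with a dot:

class BatchedReporter(object):
  """ Toy stand-in for SecurityNotificationHandler's accumulate-then-process flow. """
  def __init__(self):
    self.layer_ids = set()

  def report(self, layer_id):
    self.layer_ids.add(layer_id)  # cheap: no per-layer database work

  def done(self):
    # One batched lookup (get_matching_tags_for_images) would take these pairs.
    return [tuple(layer_id.split('.', 2)) for layer_id in self.layer_ids]

reporter = BatchedReporter()
reporter.report('94d1ff74bbae.0a1b2c3d')  # hypothetical layer id
print(reporter.done())  # [('94d1ff74bbae', '0a1b2c3d')]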