diff --git a/data/model/tag.py b/data/model/tag.py index 79f132109..14abc026d 100644 --- a/data/model/tag.py +++ b/data/model/tag.py @@ -53,10 +53,30 @@ def _tag_alive(query, now_ts=None): (RepositoryTag.lifetime_end_ts > now_ts)) +def filter_has_repository_event(query, event): + """ Filters the query by ensuring the repositories returned have the given event. """ + return (query + .join(Repository) + .join(RepositoryNotification) + .where(RepositoryNotification.event == event)) + + +def filter_tags_have_repository_event(query, event): + """ Filters the query by ensuring the repository tags live in a repository that has the given + event. Also returns the image storage for the tag's image and orders the results by + lifetime_start_ts. + """ + query = filter_has_repository_event(query, event) + query = query.switch(Image).join(ImageStorage) + query = query.switch(RepositoryTag).order_by(RepositoryTag.lifetime_start_ts.desc()) + return query + + _MAX_SUB_QUERIES = 100 _MAX_IMAGE_LOOKUP_COUNT = 500 -def get_matching_tags_for_images(image_pairs, filter_query=None, selections=None): +def get_matching_tags_for_images(image_pairs, filter_images=None, filter_tags=None, + selections=None): """ Returns all tags that contain the images with the given docker_image_id and storage_uuid, as specified as an iterable of pairs. """ if not image_pairs: @@ -76,7 +96,11 @@ def get_matching_tags_for_images(image_pairs, filter_query=None, selections=None images_query = (Image .select(Image.id, Image.docker_image_id, Image.ancestors, ImageStorage.uuid) .join(ImageStorage) - .where(Image.docker_image_id << ids, ImageStorage.uuid << uuids)) + .where(Image.docker_image_id << ids, ImageStorage.uuid << uuids) + .switch(Image)) + + if filter_images is not None: + images_query = filter_images(images_query) images.extend(list(images_query)) image_pairs = image_pairs[_MAX_IMAGE_LOOKUP_COUNT:] @@ -119,10 +143,11 @@ def get_matching_tags_for_images(image_pairs, filter_query=None, selections=None .distinct() .join(Image) .where(RepositoryTag.hidden == False) - .where(Image.id << query))) + .where(Image.id << query) + .switch(RepositoryTag))) - if filter_query is not None: - tag_query = filter_query(tag_query) + if filter_tags is not None: + tag_query = filter_tags(tag_query) for tag in tag_query: tags[tag.id] = tag @@ -156,15 +181,6 @@ def get_tags_for_image(image_id, *args): RepositoryTag.hidden == False)) -def filter_tags_have_repository_event(query, event): - return (query - .switch(RepositoryTag) - .join(Repository) - .join(RepositoryNotification) - .where(RepositoryNotification.event == event) - .order_by(RepositoryTag.lifetime_start_ts.desc())) - - def get_tag_manifest_digests(tags): """ Returns a map from tag ID to its associated manifest digest, if any. """ if not tags: diff --git a/data/model/test/test_tag.py b/data/model/test/test_tag.py index fa993eedc..8fa0eb852 100644 --- a/data/model/test/test_tag.py +++ b/data/model/test/test_tag.py @@ -94,7 +94,8 @@ def test_get_matching_tag_ids_images_filtered(initialized_db): for image in filtered_images: pairs.append((image.docker_image_id, image.storage.uuid)) - matching_tags = get_matching_tags_for_images(pairs, filter_query=filter_query) + matching_tags = get_matching_tags_for_images(pairs, filter_images=filter_query, + filter_tags=filter_query) expected_tag_ids = set([tag.id for tag in expected_tags_query]) matching_tags_ids = set([tag.id for tag in matching_tags]) diff --git a/util/secscan/notifier.py b/util/secscan/notifier.py index ad31f876b..c71209ebe 100644 --- a/util/secscan/notifier.py +++ b/util/secscan/notifier.py @@ -5,7 +5,9 @@ from collections import defaultdict from enum import Enum from app import secscan_api -from data.model.tag import filter_tags_have_repository_event, get_matching_tags_for_images +from data.model.tag import (filter_has_repository_event, filter_tags_have_repository_event, + get_matching_tags_for_images) + from data.database import (Image, ImageStorage, ExternalNotificationEvent, Repository, RepositoryTag) from endpoints.notificationhelper import notification_batch @@ -159,12 +161,15 @@ class SecurityNotificationHandler(object): pairs = [tuple(layer_id.split('.', 2)) for layer_id in self.layer_ids] def filter_notifying_repos(query): - query = query.join(ImageStorage) + return filter_has_repository_event(query, self.event) + + def filter_and_order(query): return filter_tags_have_repository_event(query, self.event) # Find the matching tags. tags = get_matching_tags_for_images(pairs, selections=[RepositoryTag, Image, ImageStorage], - filter_query=filter_notifying_repos) + filter_images=filter_notifying_repos, + filter_tags=filter_and_order) cve_id = self.vulnerability_info['Name'] for tag in tags: