Unify the database connection lifecycle across all workers

Jake Moshenko 2015-12-04 15:51:53 -05:00
parent 38cb63d195
commit 2f626f2691
7 changed files with 111 additions and 130 deletions
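For context: the two helpers this commit standardizes on, UseThenDisconnect and CloseForLongOperation, live in data/database.py, and their implementations are not part of this diff. A plausible minimal sketch of the pattern they name, assuming a peewee-style lazily-connecting database handle (the db object below is a stand-in, not Quay's actual wiring):

# Sketch only: the real implementations are in data/database.py and are not
# shown in this commit. The peewee handle below is an assumed stand-in.
from peewee import SqliteDatabase

db = SqliteDatabase(':memory:')  # stand-in for the configured database

class UseThenDisconnect(object):
  """ Do a unit of database work, then disconnect. """
  def __init__(self, config_object):
    self.config_object = config_object  # kept for parity with the call sites

  def __enter__(self):
    pass  # peewee connects lazily; the first query opens the connection

  def __exit__(self, exc_type, exc_value, exc_tb):
    if not db.is_closed():
      db.close()  # release the connection once the block finishes

class CloseForLongOperation(object):
  """ Drop the connection before a long non-database operation. """
  def __init__(self, config_object):
    self.config_object = config_object

  def __enter__(self):
    if not db.is_closed():
      db.close()  # don't hold an idle connection through slow work

  def __exit__(self, exc_type, exc_value, exc_tb):
    pass  # the next query reconnects transparently

The division of labor: UseThenDisconnect brackets a unit of database work and releases the connection afterward, so an idle worker does not pin a connection between polls; CloseForLongOperation closes the connection before a long stretch of non-database work, so a held connection cannot idle out mid-job.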

View file

@@ -1,13 +1,12 @@
 import logging
 
-from peewee import fn
 from tempfile import SpooledTemporaryFile
 from gzip import GzipFile
 
 from data import model
 from data.archivedlogs import JSON_MIMETYPE
-from data.database import RepositoryBuild, db_random_func
-from app import build_logs, log_archive
+from data.database import CloseForLongOperation
+from app import build_logs, log_archive, app
 from util.streamingjsonencoder import StreamingJSONEncoder
 from workers.worker import Worker
@@ -39,19 +38,21 @@ class ArchiveBuildLogsWorker(Worker):
       'logs': entries,
     }
 
-    with SpooledTemporaryFile(MEMORY_TEMPFILE_SIZE) as tempfile:
-      with GzipFile('testarchive', fileobj=tempfile) as zipstream:
-        for chunk in StreamingJSONEncoder().iterencode(to_encode):
-          zipstream.write(chunk)
-
-      tempfile.seek(0)
-      log_archive.store_file(tempfile, JSON_MIMETYPE, content_encoding='gzip',
-                             file_id=to_archive.uuid)
-
-    to_archive.logs_archived = True
-    to_archive.save()
+    with CloseForLongOperation(app.config):
+      with SpooledTemporaryFile(MEMORY_TEMPFILE_SIZE) as tempfile:
+        with GzipFile('testarchive', fileobj=tempfile) as zipstream:
+          for chunk in StreamingJSONEncoder().iterencode(to_encode):
+            zipstream.write(chunk)
+
+        tempfile.seek(0)
+        log_archive.store_file(tempfile, JSON_MIMETYPE, content_encoding='gzip',
+                               file_id=to_archive.uuid)
+
+    to_update = model.build.get_repository_build(to_archive.uuid)
+    to_update.logs_archived = True
+    to_update.save()
 
-    build_logs.expire_log_entries(to_archive.uuid)
+    build_logs.expire_log_entries(to_update.uuid)
 
 if __name__ == "__main__":

View file

@@ -1,7 +1,6 @@
 import logging
 
 from app import app
-from data.database import UseThenDisconnect
 from data.model.repository import find_repository_with_garbage, garbage_collect_repo
 from workers.worker import Worker
@@ -14,17 +13,15 @@ class GarbageCollectionWorker(Worker):
   def _garbage_collection_repos(self):
     """ Performs garbage collection on repositories. """
-    with UseThenDisconnect(app.config):
-      repository = find_repository_with_garbage()
-      if repository is None:
-        logger.debug('No repository with garbage found')
-        return
-
-      logger.debug('Starting GC of repository #%s (%s)', repository.id, repository.name)
-      garbage_collect_repo(repository)
-      logger.debug('Finished GC of repository #%s (%s)', repository.id, repository.name)
+    repository = find_repository_with_garbage()
+    if repository is None:
+      logger.debug('No repository with garbage found')
+      return
+
+    logger.debug('Starting GC of repository #%s (%s)', repository.id, repository.name)
+    garbage_collect_repo(repository)
+    logger.debug('Finished GC of repository #%s (%s)', repository.id, repository.name)
 
 if __name__ == "__main__":
   worker = GarbageCollectionWorker()
   worker.start()

View file

@@ -1,19 +1,11 @@
 import logging
 import json
-import signal
-import sys
 
 from threading import Event, Lock
-from datetime import datetime, timedelta
-from threading import Thread
-from time import sleep
 
 from app import app
 from data.model import db
-from data.queue import WorkQueue
-from data.database import UseThenDisconnect
+from data.database import CloseForLongOperation
 from workers.worker import Worker
 
 logger = logging.getLogger(__name__)
@@ -92,20 +84,20 @@ class QueueWorker(Worker):
       with self._current_item_lock:
         current_queue_item = self.current_queue_item
         if current_queue_item is None:
-          # Close the db handle.
-          self._close_db_handle()
           break
 
       logger.debug('Queue gave us some work: %s', current_queue_item.body)
       job_details = json.loads(current_queue_item.body)
 
       try:
-        self.process_queue_item(job_details)
+        with CloseForLongOperation(app.config):
+          self.process_queue_item(job_details)
+
         self.mark_current_complete()
       except JobException as jex:
         logger.warning('An error occurred processing request: %s', current_queue_item.body)
-        logger.warning('Job exception: %s' % jex)
+        logger.warning('Job exception: %s', jex)
         self.mark_current_incomplete(restore_retry=False)
 
       except WorkerUnhealthyException as exc:
@@ -114,10 +106,6 @@ class QueueWorker(Worker):
         self.mark_current_incomplete(restore_retry=True)
         self._stop.set()
 
-      finally:
-        # Close the db handle.
-        self._close_db_handle()
-
       if not self._stop.is_set():
         with self._current_item_lock:
           self.current_queue_item = self._queue.get(processing_time=self._reservation_seconds)
@@ -126,8 +114,7 @@ class QueueWorker(Worker):
         logger.debug('No more work.')
 
   def update_queue_metrics(self):
-    with UseThenDisconnect(app.config):
-      self._queue.update_metrics()
+    self._queue.update_metrics()
 
   def mark_current_incomplete(self, restore_retry=False):
     with self._current_item_lock:

View file

@@ -1,9 +1,8 @@
 import logging
 
 from app import app
-from data.database import (Repository, LogEntry, RepositoryActionCount, db_random_func, fn,
-                           UseThenDisconnect)
-from datetime import date, datetime, timedelta
+from data.database import Repository, LogEntry, RepositoryActionCount, db_random_func
+from datetime import date, timedelta
 from workers.worker import Worker
 
 POLL_PERIOD_SECONDS = 10
@@ -11,33 +10,32 @@ POLL_PERIOD_SECONDS = 10
 logger = logging.getLogger(__name__)
 
 def count_repository_actions():
-  with UseThenDisconnect(app.config):
+  try:
+    # Get a random repository to count.
+    today = date.today()
+    yesterday = today - timedelta(days=1)
+    has_yesterday_actions = (RepositoryActionCount.select(RepositoryActionCount.repository)
+                             .where(RepositoryActionCount.date == yesterday))
+
+    to_count = (Repository.select()
+                .where(~(Repository.id << (has_yesterday_actions)))
+                .order_by(db_random_func()).get())
+
+    logger.debug('Counting: %s', to_count.id)
+
+    actions = (LogEntry.select()
+               .where(LogEntry.repository == to_count,
+                      LogEntry.datetime >= yesterday,
+                      LogEntry.datetime < today)
+               .count())
+
+    # Create the row.
     try:
-      # Get a random repository to count.
-      today = date.today()
-      yesterday = today - timedelta(days=1)
-      has_yesterday_actions = (RepositoryActionCount.select(RepositoryActionCount.repository)
-                               .where(RepositoryActionCount.date == yesterday))
-
-      to_count = (Repository.select()
-                  .where(~(Repository.id << (has_yesterday_actions)))
-                  .order_by(db_random_func()).get())
-
-      logger.debug('Counting: %s', to_count.id)
-
-      actions = (LogEntry.select()
-                 .where(LogEntry.repository == to_count,
-                        LogEntry.datetime >= yesterday,
-                        LogEntry.datetime < today)
-                 .count())
-
-      # Create the row.
-      try:
-        RepositoryActionCount.create(repository=to_count, date=yesterday, count=actions)
-      except:
-        logger.exception('Exception when writing count')
-    except Repository.DoesNotExist:
-      logger.debug('No further repositories to count')
+      RepositoryActionCount.create(repository=to_count, date=yesterday, count=actions)
+    except:
+      logger.exception('Exception when writing count')
+  except Repository.DoesNotExist:
+    logger.debug('No further repositories to count')
 
 class RepositoryActionCountWorker(Worker):

View file

@@ -13,7 +13,7 @@ from data import model
 from data.model.tag import filter_tags_have_repository_event, get_tags_for_image
 from data.model.image import get_secscan_candidates, set_secscan_status
 from data.model.storage import get_storage_locations
-from data.database import (UseThenDisconnect, ExternalNotificationEvent)
+from data.database import ExternalNotificationEvent
 from util.secscan.api import SecurityConfigValidator
 
 logger = logging.getLogger(__name__)
@@ -150,68 +150,67 @@ class SecurityWorker(Worker):
     logger.debug('Started indexing')
     event = ExternalNotificationEvent.get(name='vulnerability_found')
 
-    with UseThenDisconnect(app.config):
-      while True:
-        # Lookup the images to index.
-        images = []
-        logger.debug('Looking up images to index')
-        images = get_secscan_candidates(self._target_version, BATCH_SIZE)
-
-        if not images:
-          logger.debug('No more images left to analyze')
-          return
-
-        logger.debug('Found %d images to index', len(images))
-
-        for image in images:
-          # If we couldn't analyze the parent, we can't analyze this image.
-          if (image.parent and not image.parent.security_indexed and
-              image.parent.security_indexed_engine >= self._target_version):
-            set_secscan_status(image, False, self._target_version)
-            continue
-
-          # Analyze the image.
-          analyzed = self._analyze_image(image)
-          if not analyzed:
-            return
-
-          # Get the tags of the image we analyzed
-          matching = list(filter_tags_have_repository_event(get_tags_for_image(image.id), event))
-
-          repository_map = defaultdict(list)
-
-          for tag in matching:
-            repository_map[tag.repository_id].append(tag)
-
-          # If there is at least one tag,
-          # Lookup the vulnerabilities for the image, now that it is analyzed.
-          if len(repository_map) > 0:
-            logger.debug('Loading vulnerabilities for layer %s', image.id)
-            sec_data = self._get_vulnerabilities(image)
-
-            if sec_data is None:
-              continue
-
-            if not sec_data.get('Vulnerabilities'):
-              continue
-
-            # Dispatch events for any detected vulnerabilities
-            logger.debug('Got vulnerabilities for layer %s: %s', image.id, sec_data)
-
-            for repository_id in repository_map:
-              tags = repository_map[repository_id]
-
-              for vuln in sec_data['Vulnerabilities']:
-                event_data = {
-                  'tags': [tag.name for tag in tags],
-                  'vulnerability': {
-                    'id': vuln['ID'],
-                    'description': vuln['Description'],
-                    'link': vuln['Link'],
-                    'priority': vuln['Priority'],
-                  },
-                }
-
-                spawn_notification(tags[0].repository, 'vulnerability_found', event_data)
+    while True:
+      # Lookup the images to index.
+      images = []
+      logger.debug('Looking up images to index')
+      images = get_secscan_candidates(self._target_version, BATCH_SIZE)
+
+      if not images:
+        logger.debug('No more images left to analyze')
+        return
+
+      logger.debug('Found %d images to index', len(images))
+
+      for image in images:
+        # If we couldn't analyze the parent, we can't analyze this image.
+        if (image.parent and not image.parent.security_indexed and
+            image.parent.security_indexed_engine >= self._target_version):
+          set_secscan_status(image, False, self._target_version)
+          continue
+
+        # Analyze the image.
+        analyzed = self._analyze_image(image)
+        if not analyzed:
+          return
+
+        # Get the tags of the image we analyzed
+        matching = list(filter_tags_have_repository_event(get_tags_for_image(image.id), event))
+
+        repository_map = defaultdict(list)
+
+        for tag in matching:
+          repository_map[tag.repository_id].append(tag)
+
+        # If there is at least one tag,
+        # Lookup the vulnerabilities for the image, now that it is analyzed.
+        if len(repository_map) > 0:
+          logger.debug('Loading vulnerabilities for layer %s', image.id)
+          sec_data = self._get_vulnerabilities(image)
+
+          if sec_data is None:
+            continue
+
+          if not sec_data.get('Vulnerabilities'):
+            continue
+
+          # Dispatch events for any detected vulnerabilities
+          logger.debug('Got vulnerabilities for layer %s: %s', image.id, sec_data)
+
+          for repository_id in repository_map:
+            tags = repository_map[repository_id]
+
+            for vuln in sec_data['Vulnerabilities']:
+              event_data = {
+                'tags': [tag.name for tag in tags],
+                'vulnerability': {
+                  'id': vuln['ID'],
+                  'description': vuln['Description'],
+                  'link': vuln['Link'],
+                  'priority': vuln['Priority'],
+                },
+              }
+
+              spawn_notification(tags[0].repository, 'vulnerability_found', event_data)
 
 if __name__ == '__main__':
   if not features.SECURITY_SCANNER:

View file

@@ -3,7 +3,7 @@ import features
 import time
 
 from app import app, storage, image_replication_queue
-from data.database import UseThenDisconnect, CloseForLongOperation
+from data.database import CloseForLongOperation
 from data import model
 from storage.basestorage import StoragePaths
 from workers.queueworker import QueueWorker

View file

@@ -6,12 +6,10 @@ import socket
 from threading import Event
 from apscheduler.schedulers.background import BackgroundScheduler
 from datetime import datetime, timedelta
-from threading import Thread
-from time import sleep
 from raven import Client
 
 from app import app
-from data.model import db
+from data.database import UseThenDisconnect
 from functools import wraps
 
 logger = logging.getLogger(__name__)
@@ -44,7 +42,8 @@ class Worker(object):
     @wraps(operation_func)
     def _operation_func():
       try:
-        return operation_func()
+        with UseThenDisconnect(app.config):
+          return operation_func()
       except Exception:
         logger.exception('Operation raised exception')
         if self._raven_client:
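Net effect of the worker.py change above: every scheduled operation now runs inside UseThenDisconnect via the _operation_func wrapper, and individually long steps inside an operation opt out of holding a connection with CloseForLongOperation. A sketch of how the two compose in a worker body (ExampleWorker and its helpers are hypothetical names for illustration, not part of the codebase):

class ExampleWorker(Worker):
  def _example_operation(self):
    # The base Worker now wraps this whole method in UseThenDisconnect,
    # so the connection is released as soon as we return.
    item = fetch_next_item()  # hypothetical DB query; connects lazily

    if item is None:
      return

    with CloseForLongOperation(app.config):
      # The connection is closed while the slow, database-free work runs.
      result = do_expensive_transform(item)  # hypothetical long operation

    record_result(result)  # the next query reconnects automatically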