Unify the database connection lifecycle across all workers

Jake Moshenko 2015-12-04 15:51:53 -05:00
parent 38cb63d195
commit 2f626f2691
7 changed files with 111 additions and 130 deletions
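This commit standardizes every worker on two context managers from data.database: UseThenDisconnect, for short database work that should not leave an idle connection behind, and CloseForLongOperation, which releases the connection before a slow non-database step. The sketch below illustrates the intent only; the class names match the imports in the diff, but the bodies are an assumption built around a peewee-style database object, not Quay's actual data/database.py.

from peewee import SqliteDatabase

db = SqliteDatabase('example.db')  # stand-in for the configured application database

class UseThenDisconnect(object):
  """Run some database work, then drop the connection so an idle worker
  does not keep one open between polls."""
  def __init__(self, config):
    self.config = config

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    if not db.is_closed():
      db.close()

class CloseForLongOperation(object):
  """Close the connection *before* a long operation (gzip, object-storage
  upload, external API call) so it is not held open and timed out; the
  next query reconnects lazily."""
  def __init__(self, config):
    self.config = config

  def __enter__(self):
    if not db.is_closed():
      db.close()

  def __exit__(self, exc_type, exc_value, traceback):
    return None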


@@ -1,13 +1,12 @@
import logging
from peewee import fn
from tempfile import SpooledTemporaryFile
from gzip import GzipFile
from data import model
from data.archivedlogs import JSON_MIMETYPE
from data.database import RepositoryBuild, db_random_func
from app import build_logs, log_archive
from data.database import CloseForLongOperation
from app import build_logs, log_archive, app
from util.streamingjsonencoder import StreamingJSONEncoder
from workers.worker import Worker
@@ -39,19 +38,21 @@ class ArchiveBuildLogsWorker(Worker):
'logs': entries,
}
with SpooledTemporaryFile(MEMORY_TEMPFILE_SIZE) as tempfile:
with GzipFile('testarchive', fileobj=tempfile) as zipstream:
for chunk in StreamingJSONEncoder().iterencode(to_encode):
zipstream.write(chunk)
with CloseForLongOperation(app.config):
with SpooledTemporaryFile(MEMORY_TEMPFILE_SIZE) as tempfile:
with GzipFile('testarchive', fileobj=tempfile) as zipstream:
for chunk in StreamingJSONEncoder().iterencode(to_encode):
zipstream.write(chunk)
tempfile.seek(0)
log_archive.store_file(tempfile, JSON_MIMETYPE, content_encoding='gzip',
file_id=to_archive.uuid)
tempfile.seek(0)
log_archive.store_file(tempfile, JSON_MIMETYPE, content_encoding='gzip',
file_id=to_archive.uuid)
to_archive.logs_archived = True
to_archive.save()
to_update = model.build.get_repository_build(to_archive.uuid)
to_update.logs_archived = True
to_update.save()
build_logs.expire_log_entries(to_archive.uuid)
build_logs.expire_log_entries(to_update.uuid)
if __name__ == "__main__":
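Read past the missing diff markers, the archiver's new code path is roughly the following (a condensed restatement of the added lines in the hunk above, with indentation reconstructed):

# Inside the archiver's build-log archiving method (condensed; simplified):
with CloseForLongOperation(app.config):
  # No database connection is held while the logs are gzipped and pushed
  # to log storage, which can take a while for large builds.
  with SpooledTemporaryFile(MEMORY_TEMPFILE_SIZE) as tempfile:
    with GzipFile('testarchive', fileobj=tempfile) as zipstream:
      for chunk in StreamingJSONEncoder().iterencode(to_encode):
        zipstream.write(chunk)
    tempfile.seek(0)
    log_archive.store_file(tempfile, JSON_MIMETYPE, content_encoding='gzip',
                           file_id=to_archive.uuid)

# Back on the database: re-fetch the build (reconnecting lazily) and flag it.
to_update = model.build.get_repository_build(to_archive.uuid)
to_update.logs_archived = True
to_update.save()
build_logs.expire_log_entries(to_update.uuid)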


@@ -1,7 +1,6 @@
import logging
from app import app
from data.database import UseThenDisconnect
from data.model.repository import find_repository_with_garbage, garbage_collect_repo
from workers.worker import Worker
@@ -14,17 +13,15 @@ class GarbageCollectionWorker(Worker):
def _garbage_collection_repos(self):
""" Performs garbage collection on repositories. """
with UseThenDisconnect(app.config):
repository = find_repository_with_garbage()
if repository is None:
logger.debug('No repository with garbage found')
return
logger.debug('Starting GC of repository #%s (%s)', repository.id, repository.name)
garbage_collect_repo(repository)
logger.debug('Finished GC of repository #%s (%s)', repository.id, repository.name)
repository = find_repository_with_garbage()
if repository is None:
logger.debug('No repository with garbage found')
return
logger.debug('Starting GC of repository #%s (%s)', repository.id, repository.name)
garbage_collect_repo(repository)
logger.debug('Finished GC of repository #%s (%s)', repository.id, repository.name)
if __name__ == "__main__":
worker = GarbageCollectionWorker()
worker.start()
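The garbage-collection worker (and the other periodic workers below) can drop their explicit UseThenDisconnect blocks because the base Worker now applies it around every scheduled operation; see the worker.py hunk at the end of this diff. A periodic worker after this change reduces to something like the following sketch (illustrative; the add_operation registration and the period value are assumptions based on how these workers are wired up, not part of this hunk):

class GarbageCollectionWorker(Worker):
  def __init__(self):
    super(GarbageCollectionWorker, self).__init__()
    # Period value assumed for illustration; the real constant lives in the worker module.
    self.add_operation(self._garbage_collection_repos, 60)

  def _garbage_collection_repos(self):
    """ Performs garbage collection on repositories. """
    # Plain database calls; the base Worker wraps this whole method in
    # UseThenDisconnect, so no connection is left open once it returns.
    repository = find_repository_with_garbage()
    if repository is None:
      logger.debug('No repository with garbage found')
      return
    logger.debug('Starting GC of repository #%s (%s)', repository.id, repository.name)
    garbage_collect_repo(repository)
    logger.debug('Finished GC of repository #%s (%s)', repository.id, repository.name)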


@@ -1,19 +1,11 @@
import logging
import json
import signal
import sys
from threading import Event, Lock
from datetime import datetime, timedelta
from threading import Thread
from time import sleep
from app import app
from data.model import db
from data.queue import WorkQueue
from data.database import UseThenDisconnect
from data.database import CloseForLongOperation
from workers.worker import Worker
logger = logging.getLogger(__name__)
@@ -92,20 +84,20 @@ class QueueWorker(Worker):
with self._current_item_lock:
current_queue_item = self.current_queue_item
if current_queue_item is None:
# Close the db handle.
self._close_db_handle()
break
logger.debug('Queue gave us some work: %s', current_queue_item.body)
job_details = json.loads(current_queue_item.body)
try:
self.process_queue_item(job_details)
with CloseForLongOperation(app.config):
self.process_queue_item(job_details)
self.mark_current_complete()
except JobException as jex:
logger.warning('An error occurred processing request: %s', current_queue_item.body)
logger.warning('Job exception: %s' % jex)
logger.warning('Job exception: %s', jex)
self.mark_current_incomplete(restore_retry=False)
except WorkerUnhealthyException as exc:
@@ -114,10 +106,6 @@ class QueueWorker(Worker):
self.mark_current_incomplete(restore_retry=True)
self._stop.set()
finally:
# Close the db handle.
self._close_db_handle()
if not self._stop.is_set():
with self._current_item_lock:
self.current_queue_item = self._queue.get(processing_time=self._reservation_seconds)
@@ -126,8 +114,7 @@ class QueueWorker(Worker):
logger.debug('No more work.')
def update_queue_metrics(self):
with UseThenDisconnect(app.config):
self._queue.update_metrics()
self._queue.update_metrics()
def mark_current_incomplete(self, restore_retry=False):
with self._current_item_lock:
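In the queue worker, the manual _close_db_handle() calls around each item are replaced by CloseForLongOperation around the job itself, so the connection is released while the potentially long-running job executes and re-established lazily when the item is marked complete. Condensed from the hunk above:

job_details = json.loads(current_queue_item.body)
try:
  with CloseForLongOperation(app.config):
    # The job may take minutes (builds, replication, etc.); no database
    # connection is held open while it runs.
    self.process_queue_item(job_details)
  self.mark_current_complete()  # reconnects lazily to update the queue item
except JobException as jex:
  logger.warning('An error occurred processing request: %s', current_queue_item.body)
  logger.warning('Job exception: %s', jex)
  self.mark_current_incomplete(restore_retry=False)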


@@ -1,9 +1,8 @@
import logging
from app import app
from data.database import (Repository, LogEntry, RepositoryActionCount, db_random_func, fn,
UseThenDisconnect)
from datetime import date, datetime, timedelta
from data.database import Repository, LogEntry, RepositoryActionCount, db_random_func
from datetime import date, timedelta
from workers.worker import Worker
POLL_PERIOD_SECONDS = 10
@@ -11,33 +10,32 @@ POLL_PERIOD_SECONDS = 10
logger = logging.getLogger(__name__)
def count_repository_actions():
with UseThenDisconnect(app.config):
try:
# Get a random repository to count.
today = date.today()
yesterday = today - timedelta(days=1)
has_yesterday_actions = (RepositoryActionCount.select(RepositoryActionCount.repository)
.where(RepositoryActionCount.date == yesterday))
to_count = (Repository.select()
.where(~(Repository.id << (has_yesterday_actions)))
.order_by(db_random_func()).get())
logger.debug('Counting: %s', to_count.id)
actions = (LogEntry.select()
.where(LogEntry.repository == to_count,
LogEntry.datetime >= yesterday,
LogEntry.datetime < today)
.count())
# Create the row.
try:
# Get a random repository to count.
today = date.today()
yesterday = today - timedelta(days=1)
has_yesterday_actions = (RepositoryActionCount.select(RepositoryActionCount.repository)
.where(RepositoryActionCount.date == yesterday))
to_count = (Repository.select()
.where(~(Repository.id << (has_yesterday_actions)))
.order_by(db_random_func()).get())
logger.debug('Counting: %s', to_count.id)
actions = (LogEntry.select()
.where(LogEntry.repository == to_count,
LogEntry.datetime >= yesterday,
LogEntry.datetime < today)
.count())
# Create the row.
try:
RepositoryActionCount.create(repository=to_count, date=yesterday, count=actions)
except:
logger.exception('Exception when writing count')
except Repository.DoesNotExist:
logger.debug('No further repositories to count')
RepositoryActionCount.create(repository=to_count, date=yesterday, count=actions)
except:
logger.exception('Exception when writing count')
except Repository.DoesNotExist:
logger.debug('No further repositories to count')
class RepositoryActionCountWorker(Worker):


@@ -13,7 +13,7 @@ from data import model
from data.model.tag import filter_tags_have_repository_event, get_tags_for_image
from data.model.image import get_secscan_candidates, set_secscan_status
from data.model.storage import get_storage_locations
from data.database import (UseThenDisconnect, ExternalNotificationEvent)
from data.database import ExternalNotificationEvent
from util.secscan.api import SecurityConfigValidator
logger = logging.getLogger(__name__)
@@ -150,68 +150,67 @@ class SecurityWorker(Worker):
logger.debug('Started indexing')
event = ExternalNotificationEvent.get(name='vulnerability_found')
with UseThenDisconnect(app.config):
while True:
# Lookup the images to index.
images = []
logger.debug('Looking up images to index')
images = get_secscan_candidates(self._target_version, BATCH_SIZE)
while True:
# Lookup the images to index.
images = []
logger.debug('Looking up images to index')
images = get_secscan_candidates(self._target_version, BATCH_SIZE)
if not images:
logger.debug('No more images left to analyze')
if not images:
logger.debug('No more images left to analyze')
return
logger.debug('Found %d images to index', len(images))
for image in images:
# If we couldn't analyze the parent, we can't analyze this image.
if (image.parent and not image.parent.security_indexed and
image.parent.security_indexed_engine >= self._target_version):
set_secscan_status(image, False, self._target_version)
continue
# Analyze the image.
analyzed = self._analyze_image(image)
if not analyzed:
return
logger.debug('Found %d images to index', len(images))
for image in images:
# If we couldn't analyze the parent, we can't analyze this image.
if (image.parent and not image.parent.security_indexed and
image.parent.security_indexed_engine >= self._target_version):
set_secscan_status(image, False, self._target_version)
# Get the tags of the image we analyzed
matching = list(filter_tags_have_repository_event(get_tags_for_image(image.id), event))
repository_map = defaultdict(list)
for tag in matching:
repository_map[tag.repository_id].append(tag)
# If there is at least one tag,
# Lookup the vulnerabilities for the image, now that it is analyzed.
if len(repository_map) > 0:
logger.debug('Loading vulnerabilities for layer %s', image.id)
sec_data = self._get_vulnerabilities(image)
if sec_data is None:
continue
# Analyze the image.
analyzed = self._analyze_image(image)
if not analyzed:
return
if not sec_data.get('Vulnerabilities'):
continue
# Get the tags of the image we analyzed
matching = list(filter_tags_have_repository_event(get_tags_for_image(image.id), event))
# Dispatch events for any detected vulnerabilities
logger.debug('Got vulnerabilities for layer %s: %s', image.id, sec_data)
repository_map = defaultdict(list)
for repository_id in repository_map:
tags = repository_map[repository_id]
for tag in matching:
repository_map[tag.repository_id].append(tag)
for vuln in sec_data['Vulnerabilities']:
event_data = {
'tags': [tag.name for tag in tags],
'vulnerability': {
'id': vuln['ID'],
'description': vuln['Description'],
'link': vuln['Link'],
'priority': vuln['Priority'],
},
}
# If there is at least one tag,
# Lookup the vulnerabilities for the image, now that it is analyzed.
if len(repository_map) > 0:
logger.debug('Loading vulnerabilities for layer %s', image.id)
sec_data = self._get_vulnerabilities(image)
if sec_data is None:
continue
if not sec_data.get('Vulnerabilities'):
continue
# Dispatch events for any detected vulnerabilities
logger.debug('Got vulnerabilities for layer %s: %s', image.id, sec_data)
for repository_id in repository_map:
tags = repository_map[repository_id]
for vuln in sec_data['Vulnerabilities']:
event_data = {
'tags': [tag.name for tag in tags],
'vulnerability': {
'id': vuln['ID'],
'description': vuln['Description'],
'link': vuln['Link'],
'priority': vuln['Priority'],
},
}
spawn_notification(tags[0].repository, 'vulnerability_found', event_data)
spawn_notification(tags[0].repository, 'vulnerability_found', event_data)
if __name__ == '__main__':
if not features.SECURITY_SCANNER:


@@ -3,7 +3,7 @@ import features
import time
from app import app, storage, image_replication_queue
from data.database import UseThenDisconnect, CloseForLongOperation
from data.database import CloseForLongOperation
from data import model
from storage.basestorage import StoragePaths
from workers.queueworker import QueueWorker


@@ -6,12 +6,10 @@ import socket
from threading import Event
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime, timedelta
from threading import Thread
from time import sleep
from raven import Client
from app import app
from data.model import db
from data.database import UseThenDisconnect
from functools import wraps
logger = logging.getLogger(__name__)
@@ -44,7 +42,8 @@ class Worker(object):
@wraps(operation_func)
def _operation_func():
try:
return operation_func()
with UseThenDisconnect(app.config):
return operation_func()
except Exception:
logger.exception('Operation raised exception')
if self._raven_client:
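The worker.py hunk above is the piece that makes the per-worker cleanup unnecessary: the base class wraps every scheduled operation in UseThenDisconnect. The two context managers also compose; an operation that mixes quick queries with a slow step can still carve out the slow part, as in this hedged sketch (the helper names are hypothetical, not from the diff):

def _example_operation(self):
  # Runs inside UseThenDisconnect via the base Worker.
  record = fetch_some_row()        # short query; opens a connection (hypothetical helper)

  with CloseForLongOperation(app.config):
    push_blob_to_storage(record)   # slow I/O; connection dropped for the duration (hypothetical helper)

  record.processed = True          # the next query reconnects lazily
  record.save()
  # When the operation returns, the base Worker's UseThenDisconnect closes the connection again.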