initial import for Open Source 🎉

Jimmy Zelinskie 2019-11-12 11:09:47 -05:00
parent 1898c361f3
commit 9c0dd3b722
2048 changed files with 218743 additions and 0 deletions

19
util/__init__.py Normal file
@@ -0,0 +1,19 @@
def get_app_url(config):
""" Returns the application's URL, based on the given config. """
return '%s://%s' % (config['PREFERRED_URL_SCHEME'], config['SERVER_HOSTNAME'])
def slash_join(*args):
"""
Joins together strings and guarantees there is only one '/' between each
joined string. Double slashes ('//') are assumed to be intentional and
are not deduplicated.
"""
def rmslash(path):
path = path[1:] if len(path) > 0 and path[0] == '/' else path
path = path[:-1] if len(path) > 0 and path[-1] == '/' else path
return path
args = [rmslash(path) for path in args]
return '/'.join(args)
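
A brief usage sketch for the two helpers above; the config values are illustrative, and any mapping with the two keys works:

from util import get_app_url, slash_join

config = {'PREFERRED_URL_SCHEME': 'https', 'SERVER_HOSTNAME': 'quay.example.com'}
assert get_app_url(config) == 'https://quay.example.com'

# Leading/trailing slashes on each segment are trimmed before joining, but a
# '//' embedded inside a segment (such as the scheme separator) is preserved.
assert slash_join(get_app_url(config), '/v2/', 'library/nginx') == 'https://quay.example.com/v2/library/nginx'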

19
util/abchelpers.py Normal file
@@ -0,0 +1,19 @@
class NoopIsANoopException(TypeError):
""" Raised if the nooper decorator is unnecessary on a class. """
pass
def nooper(cls):
""" Decorates a class that derives from an ABCMeta, filling in any unimplemented methods with
no-ops.
"""
def empty_func(*args, **kwargs):
# pylint: disable=unused-argument
pass
empty_methods = {m_name: empty_func for m_name in cls.__abstractmethods__}
if not empty_methods:
raise NoopIsANoopException('nooper implemented no abstract methods on %s' % cls)
return type(cls.__name__, (cls,), empty_methods)
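
For illustration, a hypothetical ABC decorated with nooper; the generated subclass answers every unimplemented abstract method with a no-op that returns None:

from abc import ABCMeta, abstractmethod
from six import add_metaclass

from util.abchelpers import nooper

@add_metaclass(ABCMeta)
class Notifier(object):
  @abstractmethod
  def notify(self, message):
    """ Sends the given message somewhere. """

@nooper
class NoopNotifier(Notifier):
  pass

NoopNotifier().notify('ignored')  # no-op, returns None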

68
util/asyncwrapper.py Normal file
@@ -0,0 +1,68 @@
import queue
from functools import wraps
from concurrent.futures import Executor, Future, CancelledError
class AsyncExecutorWrapper(object):
""" This class will wrap a syncronous library transparently in a way which
will move all calls off to an asynchronous Executor, and will change all
returned values to be Future objects.
"""
SYNC_FLAG_FIELD = '__AsyncExecutorWrapper__sync__'
def __init__(self, delegate, executor):
""" Wrap the specified synchronous delegate instance, and submit() all
method calls to the specified Executor instance.
"""
self._delegate = delegate
self._executor = executor
def __getattr__(self, attr_name):
maybe_callable = getattr(self._delegate, attr_name) # Will raise proper attribute error
if callable(maybe_callable):
# Build a callable which when executed places the request
# onto a queue
@wraps(maybe_callable)
def wrapped_method(*args, **kwargs):
if getattr(maybe_callable, self.SYNC_FLAG_FIELD, False):
sync_result = Future()
try:
sync_result.set_result(maybe_callable(*args, **kwargs))
except Exception as ex:
sync_result.set_exception(ex)
return sync_result
try:
return self._executor.submit(maybe_callable, *args, **kwargs)
except queue.Full as ex:
queue_full = Future()
queue_full.set_exception(ex)
return queue_full
return wrapped_method
else:
return maybe_callable
@classmethod
def sync(cls, f):
""" Annotate the given method to flag it as synchronous so that AsyncExecutorWrapper
will return the result immediately without submitting it to the executor.
"""
setattr(f, cls.SYNC_FLAG_FIELD, True)
return f
class NullExecutorCancelled(CancelledError):
def __init__(self):
super(NullExecutorCancelled, self).__init__('Null executor always fails.')
class NullExecutor(Executor):
""" Executor instance which always returns a Future completed with a
CancelledError exception. """
def submit(self, _, *args, **kwargs):
always_fail = Future()
always_fail.set_exception(NullExecutorCancelled())
return always_fail
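
A sketch of wrapping a synchronous client with a standard thread pool; the SlowClient class is hypothetical:

from concurrent.futures import ThreadPoolExecutor

from util.asyncwrapper import AsyncExecutorWrapper

class SlowClient(object):
  def fetch(self, key):
    return 'value-for-' + key

wrapped = AsyncExecutorWrapper(SlowClient(), ThreadPoolExecutor(max_workers=4))
future = wrapped.fetch('foo')  # submitted to the executor, returns a Future
assert future.result() == 'value-for-foo'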

90
util/audit.py Normal file
@@ -0,0 +1,90 @@
import logging
import random
from collections import namedtuple
from urlparse import urlparse
from flask import request
from app import analytics, userevents, ip_resolver
from auth.auth_context import get_authenticated_context, get_authenticated_user
from data.logs_model import logs_model
from util.request import get_request_ip
from data.readreplica import ReadOnlyModeException
logger = logging.getLogger(__name__)
Repository = namedtuple('Repository', ['namespace_name', 'name', 'id', 'is_free_namespace'])
def wrap_repository(repo_obj):
return Repository(namespace_name=repo_obj.namespace_user.username, name=repo_obj.name,
id=repo_obj.id, is_free_namespace=repo_obj.namespace_user.stripe_id is None)
def track_and_log(event_name, repo_obj, analytics_name=None, analytics_sample=1, **kwargs):
repo_name = repo_obj.name
namespace_name = repo_obj.namespace_name
metadata = {
'repo': repo_name,
'namespace': namespace_name,
}
metadata.update(kwargs)
is_free_namespace = False
if hasattr(repo_obj, 'is_free_namespace'):
is_free_namespace = repo_obj.is_free_namespace
# Add auth context metadata.
analytics_id = 'anonymous'
auth_context = get_authenticated_context()
if auth_context is not None:
analytics_id, context_metadata = auth_context.analytics_id_and_public_metadata()
metadata.update(context_metadata)
# Publish the user event (if applicable)
logger.debug('Checking publishing %s to the user events system', event_name)
if auth_context and auth_context.has_nonrobot_user:
logger.debug('Publishing %s to the user events system', event_name)
user_event_data = {
'action': event_name,
'repository': repo_name,
'namespace': namespace_name,
}
event = userevents.get_event(auth_context.authed_user.username)
event.publish_event_data('docker-cli', user_event_data)
# Save the action to mixpanel.
if random.random() < analytics_sample:
if analytics_name is None:
analytics_name = event_name
logger.debug('Logging the %s to analytics engine', analytics_name)
request_parsed = urlparse(request.url_root)
extra_params = {
'repository': '%s/%s' % (namespace_name, repo_name),
'user-agent': request.user_agent.string,
'hostname': request_parsed.hostname,
}
analytics.track(analytics_id, analytics_name, extra_params)
# Add the resolved information to the metadata.
logger.debug('Resolving IP address %s', get_request_ip())
resolved_ip = ip_resolver.resolve_ip(get_request_ip())
if resolved_ip is not None:
metadata['resolved_ip'] = resolved_ip._asdict()
logger.debug('Resolved IP address %s', get_request_ip())
# Log the action to the database.
logger.debug('Logging the %s to logs system', event_name)
try:
logs_model.log_action(event_name, namespace_name, performer=get_authenticated_user(),
ip=get_request_ip(), metadata=metadata, repository=repo_obj,
is_free_namespace=is_free_namespace)
logger.debug('Track and log of %s complete', event_name)
except ReadOnlyModeException:
pass
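
As a rough sketch, a registry endpoint that has already loaded a repository row could record a pull like this. The handler name and the extra `tag` metadata keyword are illustrative, and the call must run inside a Flask request context so the request URL and client IP are available:

from util.audit import track_and_log, wrap_repository

def handle_pull(repo_row, tag_name):
  track_and_log('pull_repo', wrap_repository(repo_row),
                analytics_name='pull_repo', tag=tag_name)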

@@ -0,0 +1,48 @@
import logging
import features
from app import storage, image_replication_queue
from data.database import (Image, ImageStorage, Repository, User, ImageStoragePlacement,
ImageStorageLocation)
from data import model
from util.registry.replication import queue_storage_replication
def backfill_replication():
encountered = set()
query = (Image
.select(Image, ImageStorage, Repository, User)
.join(ImageStorage)
.switch(Image)
.join(Repository)
.join(User))
for image in query:
if image.storage.uuid in encountered:
continue
namespace = image.repository.namespace_user
locations = model.user.get_region_locations(namespace)
locations_required = locations | set(storage.default_locations)
query = (ImageStoragePlacement
.select(ImageStoragePlacement, ImageStorageLocation)
.where(ImageStoragePlacement.storage == image.storage)
.join(ImageStorageLocation))
existing_locations = set([p.location.name for p in query])
locations_missing = locations_required - existing_locations
if locations_missing:
print "Enqueueing image storage %s to be replicated" % (image.storage.uuid)
encountered.add(image.storage.uuid)
if not image_replication_queue.alive([image.storage.uuid]):
queue_storage_replication(image.repository.namespace_user.username, image.storage)
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
if not features.STORAGE_REPLICATION:
print "Storage replication is not enabled"
else:
backfill_replication()

5
util/backoff.py Normal file
@@ -0,0 +1,5 @@
def exponential_backoff(attempts, scaling_factor, base):
backoff = 5 * (pow(2, attempts) - 1)
backoff_time = backoff * scaling_factor
retry_at = backoff_time/10 + base
return retry_at
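
A worked example of the arithmetic, assuming a scaling factor of 10 and a base of 0 (in practice base is typically a reference timestamp such as time.time()):

from util.backoff import exponential_backoff

# attempts=1: 5 * (2**1 - 1) = 5  -> 5 * 10 / 10 + 0 = 5
# attempts=3: 5 * (2**3 - 1) = 35 -> 35 * 10 / 10 + 0 = 35
assert exponential_backoff(1, 10, 0) == 5
assert exponential_backoff(3, 10, 0) == 35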

32
util/bytes.py Normal file
@@ -0,0 +1,32 @@
class Bytes(object):
""" Wrapper around strings and unicode objects to ensure we are always using
the correct encoded or decoded data.
"""
def __init__(self, data):
assert isinstance(data, str)
self._encoded_data = data
@classmethod
def for_string_or_unicode(cls, input):
# If the string is a unicode string, then encode its data as UTF-8. Note that
# we don't catch any decode exceptions here, as we want those to be raised.
if isinstance(input, unicode):
return Bytes(input.encode('utf-8'))
# Next, try decoding as UTF-8. If we have a utf-8 encoded string, then we have no
# additional conversion to do.
try:
input.decode('utf-8')
return Bytes(input)
except UnicodeDecodeError:
pass
# Finally, if the data is (somehow) a unicode string inside a `str` type, then
# re-encode the data.
return Bytes(input.encode('utf-8'))
def as_encoded_str(self):
return self._encoded_data
def as_unicode(self):
return self._encoded_data.decode('utf-8')
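
Illustrative round-trips under the Python 2 string model (str is a byte string, unicode is a separate type), which is what the class above assumes:

from util.bytes import Bytes

wrapped = Bytes.for_string_or_unicode(u'caf\xe9')
assert wrapped.as_encoded_str() == 'caf\xc3\xa9'  # UTF-8 encoded bytes
assert wrapped.as_unicode() == u'caf\xe9'

# Data that is already UTF-8 encoded passes through unchanged.
assert Bytes.for_string_or_unicode('caf\xc3\xa9').as_unicode() == u'caf\xe9'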

36
util/cache.py Normal file
@@ -0,0 +1,36 @@
from functools import wraps
from flask_restful.utils import unpack
def cache_control(max_age=55):
def wrap(f):
@wraps(f)
def add_max_age(*args, **kwargs):
response = f(*args, **kwargs)
response.headers['Cache-Control'] = 'max-age=%d' % max_age
return response
return add_max_age
return wrap
def cache_control_flask_restful(max_age=55):
def wrap(f):
@wraps(f)
def add_max_age(*args, **kwargs):
response = f(*args, **kwargs)
body, status_code, headers = unpack(response)
headers['Cache-Control'] = 'max-age=%d' % max_age
return body, status_code, headers
return add_max_age
return wrap
def no_cache(f):
@wraps(f)
def add_no_cache(*args, **kwargs):
response = f(*args, **kwargs)
if response is not None:
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
return response
return add_no_cache
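
A sketch of attaching the decorators to Flask views; the app and endpoints are hypothetical:

from flask import Flask, jsonify

from util.cache import cache_control, no_cache

app = Flask(__name__)

@app.route('/api/discovery')
@cache_control(max_age=120)
def discovery():
  return jsonify({'ok': True})  # served with Cache-Control: max-age=120

@app.route('/api/user')
@no_cache
def current_user():
  return jsonify({'username': 'devtable'})  # served with no-cache, no-store, must-revalidate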

18
util/canonicaljson.py Normal file
@@ -0,0 +1,18 @@
import collections
def canonicalize(json_obj):
"""This function canonicalizes a Python object that will be serialized as JSON.
Args:
json_obj (object): the Python object that will later be serialized as JSON.
Returns:
object: json_obj now sorted to its canonical form.
"""
if isinstance(json_obj, collections.MutableMapping):
sorted_obj = sorted({key: canonicalize(val) for key, val in json_obj.items()}.items())
return collections.OrderedDict(sorted_obj)
elif isinstance(json_obj, (list, tuple)):
return [canonicalize(val) for val in json_obj]
return json_obj
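
For example, nested mappings come back as OrderedDicts sorted by key at every level, while list order is preserved:

import json

from util.canonicaljson import canonicalize

obj = {'b': {'z': 1, 'a': 2}, 'a': [3, {'y': 1, 'x': 2}]}
assert json.dumps(canonicalize(obj)) == '{"a": [3, {"x": 2, "y": 1}], "b": {"a": 2, "z": 1}}'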

29
util/config/__init__.py Normal file
@@ -0,0 +1,29 @@
class URLSchemeAndHostname:
"""
Immutable configuration for a given preferred url scheme (e.g. http or https), and a hostname (e.g. localhost:5000)
"""
def __init__(self, url_scheme, hostname):
self._url_scheme = url_scheme
self._hostname = hostname
@classmethod
def from_app_config(cls, app_config):
"""
Helper method to instantiate class from app config, a frequent pattern
:param app_config:
:return:
"""
return cls(app_config['PREFERRED_URL_SCHEME'], app_config['SERVER_HOSTNAME'])
@property
def url_scheme(self):
return self._url_scheme
@property
def hostname(self):
return self._hostname
def get_url(self):
""" Returns the application's URL, based on the given url scheme and hostname. """
return '%s://%s' % (self._url_scheme, self._hostname)
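
Usage sketch with illustrative values:

from util.config import URLSchemeAndHostname

pair = URLSchemeAndHostname('https', 'quay.example.com')
assert pair.get_url() == 'https://quay.example.com'

# Or built directly from a Flask-style config mapping.
pair = URLSchemeAndHostname.from_app_config({'PREFERRED_URL_SCHEME': 'http',
                                             'SERVER_HOSTNAME': 'localhost:5000'})
assert pair.get_url() == 'http://localhost:5000'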

@@ -0,0 +1,43 @@
""" Generates html documentation from JSON Schema """
import json
from collections import OrderedDict
import docsmodel
import html_output
from util.config.schema import CONFIG_SCHEMA
def make_custom_sort(orders):
""" Sort in a specified order any dictionary nested in a complex structure """
orders = [{k: -i for (i, k) in enumerate(reversed(order), 1)} for order in orders]
def process(stuff):
if isinstance(stuff, dict):
l = [(k, process(v)) for (k, v) in stuff.iteritems()]
keys = set(stuff)
for order in orders:
if keys.issubset(order) or keys.issuperset(order):
return OrderedDict(sorted(l, key=lambda x: order.get(x[0], 0)))
return OrderedDict(sorted(l))
if isinstance(stuff, list):
return [process(x) for x in stuff]
return stuff
return process
SCHEMA_HTML_FILE = "schema.html"
schema = json.dumps(CONFIG_SCHEMA, sort_keys = True)
schema = json.loads(schema, object_pairs_hook = OrderedDict)
req = sorted(schema["required"])
custom_sort = make_custom_sort([req])
schema = custom_sort(schema)
parsed_items = docsmodel.DocsModel().parse(schema)[1:]
output = html_output.HtmlOutput().generate_output(parsed_items)
with open(SCHEMA_HTML_FILE, 'wt') as f:
f.write(output)

@@ -0,0 +1,93 @@
import json
import collections
class ParsedItem(dict):
""" Parsed Schema item """
def __init__(self, json_object, name, required, level):
"""Fills dict with basic item information"""
super(ParsedItem, self).__init__()
self['name'] = name
self['title'] = json_object.get('title', '')
self['type'] = json_object.get('type')
self['description'] = json_object.get('description', '')
self['level'] = level
self['required'] = required
self['x-reference'] = json_object.get('x-reference', '')
self['x-example'] = json_object.get('x-example', '')
self['pattern'] = json_object.get('pattern', '')
self['enum'] = json_object.get('enum', '')
class DocsModel:
""" Documentation model and Schema Parser """
def __init__(self):
self.__parsed_items = None
def parse(self, json_object):
""" Returns multi-level list of recursively parsed items """
self.__parsed_items = list()
self.__parse_schema(json_object, 'root', True, 0)
return self.__parsed_items
def __parse_schema(self, schema, name, required, level):
""" Parses schema, which type is object, array or leaf.
Appends new ParsedItem to self.__parsed_items lis """
parsed_item = ParsedItem(schema, name, required, level)
self.__parsed_items.append(parsed_item)
required = schema.get('required', [])
if 'enum' in schema:
parsed_item['item'] = schema.get('enum')
item_type = schema.get('type')
if item_type == 'object' and name != 'DISTRIBUTED_STORAGE_CONFIG':
self.__parse_object(parsed_item, schema, required, level)
elif item_type == 'array':
self.__parse_array(parsed_item, schema, required, level)
else:
parse_leaf(parsed_item, schema)
def __parse_object(self, parsed_item, schema, required, level):
""" Parses schema of type object """
for key, value in schema.get('properties', {}).items():
self.__parse_schema(value, key, key in required, level + 1)
def __parse_array(self, parsed_item, schema, required, level):
""" Parses schema of type array """
items = schema.get('items')
parsed_item['minItems'] = schema.get('minItems', None)
parsed_item['maxItems'] = schema.get('maxItems', None)
parsed_item['uniqueItems'] = schema.get('uniqueItems', False)
if isinstance(items, dict):
# item is single schema describing all elements in an array
self.__parse_schema(
items,
'array item',
required,
level + 1)
elif isinstance(items, list):
# item is a list of schemas
for index, list_item in enumerate(items):
self.__parse_schema(
list_item,
'array item {}'.format(index),
index in required,
level + 1)
def parse_leaf(parsed_item, schema):
""" Parses schema of a number and a string """
if parsed_item['name'] != 'root':
parsed_item['description'] = schema.get('description','')
parsed_item['x-reference'] = schema.get('x-reference','')
parsed_item['pattern'] = schema.get('pattern','')
parsed_item['enum'] = ", ".join(schema.get('enum','')).encode()
ex = schema.get('x-example', '')
if isinstance(ex, list):
parsed_item['x-example'] = ", ".join(ex).encode()
elif isinstance(ex, collections.OrderedDict):
parsed_item['x-example'] = json.dumps(ex)
else:
parsed_item['x-example'] = ex

@@ -0,0 +1,63 @@
class HtmlOutput:
""" Generates HTML from documentation model """
def __init__(self):
pass
def generate_output(self, parsed_items):
"""Returns generated HTML strin"""
return self.__get_html_begin() + \
self.__get_html_middle(parsed_items) + \
self.__get_html_end()
def __get_html_begin(self):
return '<!DOCTYPE html>\n<html>\n<head>\n<link rel="stylesheet" type="text/css" href="style.css" />\n</head>\n<body>\n'
def __get_html_end(self):
return '</body>\n</html>'
def __get_html_middle(self, parsed_items):
output = ''
root_item = parsed_items[0]
#output += '<h1 class="root_title">{}</h1>\n'.format(root_item['title'])
#output += '<h1 class="root_title">{}</h1>\n'.format(root_item['title'])
output += "Schema for Red Hat Quay"
output += '<ul class="level0">\n'
last_level = 0
is_root = True
for item in parsed_items:
level = item['level'] - 1
if last_level < level:
output += '<ul class="level{}">\n'.format(level)
for i in range(last_level - level):
output += '</ul>\n'
last_level = level
output += self.__get_html_item(item, is_root)
is_root = False
output += '</ul>\n'
return output
def __get_required_field(self, parsed_item):
return 'required' if parsed_item['required'] else ''
def __get_html_item(self, parsed_item, is_root):
item = '<li class="schema item"> \n'
item += '<div class="name">{}</div> \n'.format(parsed_item['name'])
item += '<div class="type">[{}]</div> \n'.format(parsed_item['type'])
item += '<div class="required">{}</div> \n'.format(self.__get_required_field(parsed_item))
item += '<div class="docs">\n' if not is_root else '<div class="root_docs">\n'
item += '<div class="title">{}</div>\n'.format(parsed_item['title'])
item += ': ' if parsed_item['title'] != '' and parsed_item['description'] != '' else ''
item += '<div class="description">{}</div>\n'.format(parsed_item['description'])
item += '<div class="enum">enum: {}</div>\n'.format(parsed_item['enum']) if parsed_item['enum'] != '' else ''
item += '<div class="minItems">Min Items: {}</div>\n'.format(parsed_item['minItems']) if parsed_item['type'] == "array" and parsed_item['minItems'] != "None" else ''
item += '<div class="uniqueItems">Unique Items: {}</div>\n'.format(parsed_item['uniqueItems']) if parsed_item['type'] == "array" and parsed_item['uniqueItems'] else ''
item += '<div class="pattern">Pattern: {}</div>\n'.format(parsed_item['pattern']) if parsed_item['pattern'] != 'None' and parsed_item['pattern'] != '' else ''
item += '<div class="x-reference"><a href="{}">Reference: {}</a></div>\n'.format(parsed_item['x-reference'],parsed_item['x-reference']) if parsed_item['x-reference'] != '' else ''
item += '<div class="x-example">Example: <code>{}</code></div>\n'.format(parsed_item['x-example']) if parsed_item['x-example'] != '' else ''
item += '</div>\n'
item += '</li>\n'
return item

File diff suppressed because it is too large

@@ -0,0 +1,78 @@
body {
font-family: sans-serif;
}
pre, code{
white-space: normal;
}
div.root_docs {
display: none;
}
div.name {
display: inline;
}
div.type {
display: inline;
font-weight: bold;
color: blue;
}
div.required {
display: inline;
font-weight: bold;
}
div.docs {
display: block;
}
div.title {
display: block;
font-weight: bold;
}
div.description {
display: block;
font-family: serif;
font-style: italic;
}
div.enum {
display: block;
font-family: serif;
font-style: italic;
}
div.x-example {
display: block;
font-family: serif;
font-style: italic;
margin-bottom: 10px;
}
div.pattern {
display: block;
font-family: serif;
font-style: italic;
}
div.x-reference {
display: block;
font-family: serif;
font-style: italic;
}
div.uniqueItems {
display: block;
}
div.minItems {
display: block;
}
div.maxItems {
display: block;
}

108
util/config/configutil.py Normal file
@@ -0,0 +1,108 @@
from random import SystemRandom
from uuid import uuid4
def generate_secret_key():
cryptogen = SystemRandom()
return str(cryptogen.getrandbits(256))
def add_enterprise_config_defaults(config_obj, current_secret_key):
""" Adds/Sets the config defaults for enterprise registry config. """
# These have to be false.
config_obj['TESTING'] = False
config_obj['USE_CDN'] = False
# Default for V3 upgrade.
config_obj['V3_UPGRADE_MODE'] = config_obj.get('V3_UPGRADE_MODE', 'complete')
# Defaults for Red Hat Quay.
config_obj['REGISTRY_TITLE'] = config_obj.get('REGISTRY_TITLE', 'Red Hat Quay')
config_obj['REGISTRY_TITLE_SHORT'] = config_obj.get('REGISTRY_TITLE_SHORT', 'Red Hat Quay')
# Default features that are on.
config_obj['FEATURE_USER_LOG_ACCESS'] = config_obj.get('FEATURE_USER_LOG_ACCESS', True)
config_obj['FEATURE_USER_CREATION'] = config_obj.get('FEATURE_USER_CREATION', True)
config_obj['FEATURE_ANONYMOUS_ACCESS'] = config_obj.get('FEATURE_ANONYMOUS_ACCESS', True)
config_obj['FEATURE_REQUIRE_TEAM_INVITE'] = config_obj.get('FEATURE_REQUIRE_TEAM_INVITE', True)
config_obj['FEATURE_CHANGE_TAG_EXPIRATION'] = config_obj.get('FEATURE_CHANGE_TAG_EXPIRATION',
True)
config_obj['FEATURE_DIRECT_LOGIN'] = config_obj.get('FEATURE_DIRECT_LOGIN', True)
config_obj['FEATURE_APP_SPECIFIC_TOKENS'] = config_obj.get('FEATURE_APP_SPECIFIC_TOKENS', True)
config_obj['FEATURE_PARTIAL_USER_AUTOCOMPLETE'] = config_obj.get('FEATURE_PARTIAL_USER_AUTOCOMPLETE', True)
config_obj['FEATURE_USERNAME_CONFIRMATION'] = config_obj.get('FEATURE_USERNAME_CONFIRMATION', True)
config_obj['FEATURE_RESTRICTED_V1_PUSH'] = config_obj.get('FEATURE_RESTRICTED_V1_PUSH', True)
# Default features that are off.
config_obj['FEATURE_MAILING'] = config_obj.get('FEATURE_MAILING', False)
config_obj['FEATURE_BUILD_SUPPORT'] = config_obj.get('FEATURE_BUILD_SUPPORT', False)
config_obj['FEATURE_ACI_CONVERSION'] = config_obj.get('FEATURE_ACI_CONVERSION', False)
config_obj['FEATURE_APP_REGISTRY'] = config_obj.get('FEATURE_APP_REGISTRY', False)
config_obj['FEATURE_REPO_MIRROR'] = config_obj.get('FEATURE_REPO_MIRROR', False)
# Default repo mirror config.
config_obj['REPO_MIRROR_TLS_VERIFY'] = config_obj.get('REPO_MIRROR_TLS_VERIFY', True)
config_obj['REPO_MIRROR_SERVER_HOSTNAME'] = config_obj.get('REPO_MIRROR_SERVER_HOSTNAME', None)
# Default the signer config.
config_obj['GPG2_PRIVATE_KEY_FILENAME'] = config_obj.get('GPG2_PRIVATE_KEY_FILENAME',
'signing-private.gpg')
config_obj['GPG2_PUBLIC_KEY_FILENAME'] = config_obj.get('GPG2_PUBLIC_KEY_FILENAME',
'signing-public.gpg')
config_obj['SIGNING_ENGINE'] = config_obj.get('SIGNING_ENGINE', 'gpg2')
# Default security scanner config.
config_obj['FEATURE_SECURITY_NOTIFICATIONS'] = config_obj.get(
'FEATURE_SECURITY_NOTIFICATIONS', True)
config_obj['FEATURE_SECURITY_SCANNER'] = config_obj.get(
'FEATURE_SECURITY_SCANNER', False)
config_obj['SECURITY_SCANNER_ISSUER_NAME'] = config_obj.get(
'SECURITY_SCANNER_ISSUER_NAME', 'security_scanner')
# Default time machine config.
config_obj['TAG_EXPIRATION_OPTIONS'] = config_obj.get('TAG_EXPIRATION_OPTIONS',
['0s', '1d', '1w', '2w', '4w'])
config_obj['DEFAULT_TAG_EXPIRATION'] = config_obj.get('DEFAULT_TAG_EXPIRATION', '2w')
# Default mail settings.
config_obj['MAIL_USE_TLS'] = config_obj.get('MAIL_USE_TLS', True)
config_obj['MAIL_PORT'] = config_obj.get('MAIL_PORT', 587)
config_obj['MAIL_DEFAULT_SENDER'] = config_obj.get('MAIL_DEFAULT_SENDER', 'support@quay.io')
# Default auth type.
if not 'AUTHENTICATION_TYPE' in config_obj:
config_obj['AUTHENTICATION_TYPE'] = 'Database'
# Default secret key.
if not 'SECRET_KEY' in config_obj:
if current_secret_key:
config_obj['SECRET_KEY'] = current_secret_key
else:
config_obj['SECRET_KEY'] = generate_secret_key()
# Default database secret key.
if not 'DATABASE_SECRET_KEY' in config_obj:
config_obj['DATABASE_SECRET_KEY'] = generate_secret_key()
# Default torrent pepper.
if not 'BITTORRENT_FILENAME_PEPPER' in config_obj:
config_obj['BITTORRENT_FILENAME_PEPPER'] = str(uuid4())
# Default storage configuration.
if not 'DISTRIBUTED_STORAGE_CONFIG' in config_obj:
config_obj['DISTRIBUTED_STORAGE_PREFERENCE'] = ['default']
config_obj['DISTRIBUTED_STORAGE_CONFIG'] = {
'default': ['LocalStorage', {'storage_path': '/datastorage/registry'}]
}
config_obj['USERFILES_LOCATION'] = 'default'
config_obj['USERFILES_PATH'] = 'userfiles/'
config_obj['LOG_ARCHIVE_LOCATION'] = 'default'
# Misc configuration.
config_obj['PREFERRED_URL_SCHEME'] = config_obj.get('PREFERRED_URL_SCHEME', 'http')
config_obj['ENTERPRISE_LOGO_URL'] = config_obj.get(
'ENTERPRISE_LOGO_URL', '/static/img/quay-horizontal-color.svg')
config_obj['TEAM_RESYNC_STALE_TIME'] = config_obj.get('TEAM_RESYNC_STALE_TIME', '60m')
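
A minimal sketch of how the defaults are layered onto a partial config; the starting dict is illustrative and is mutated in place:

from util.config.configutil import add_enterprise_config_defaults

config = {'SERVER_HOSTNAME': 'quay.example.com'}
add_enterprise_config_defaults(config, current_secret_key=None)

assert config['AUTHENTICATION_TYPE'] == 'Database'
assert config['DEFAULT_TAG_EXPIRATION'] == '2w'
assert config['DISTRIBUTED_STORAGE_PREFERENCE'] == ['default']
assert 'SECRET_KEY' in config and 'DATABASE_SECRET_KEY' in config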

12
util/config/database.py Normal file
@@ -0,0 +1,12 @@
from data import model
from data.appr_model import blob
from data.appr_model.models import NEW_MODELS
def sync_database_with_config(config):
""" This ensures all implicitly required reference table entries exist in the database. """
location_names = config.get('DISTRIBUTED_STORAGE_CONFIG', {}).keys()
if location_names:
model.image.ensure_image_locations(*location_names)
blob.ensure_blob_locations(NEW_MODELS, *location_names)

@@ -0,0 +1,13 @@
from util.config.provider.fileprovider import FileConfigProvider
from util.config.provider.testprovider import TestConfigProvider
from util.config.provider.k8sprovider import KubernetesConfigProvider
def get_config_provider(config_volume, yaml_filename, py_filename, testing=False, kubernetes=False):
""" Loads and returns the config provider for the current environment. """
if testing:
return TestConfigProvider()
if kubernetes:
return KubernetesConfigProvider(config_volume, yaml_filename, py_filename)
return FileConfigProvider(config_volume, yaml_filename, py_filename)
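
For instance, a non-Kubernetes deployment reading its configuration from a mounted volume might construct a provider like this; the volume path is illustrative:

from util.config.provider import get_config_provider

provider = get_config_provider('/conf/stack', 'config.yaml', 'config.py',
                               testing=False, kubernetes=False)
if provider.config_exists():
  config = provider.get_config()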

@@ -0,0 +1,62 @@
import os
import logging
from util.config.provider.baseprovider import (BaseProvider, import_yaml, export_yaml,
CannotWriteConfigException)
logger = logging.getLogger(__name__)
class BaseFileProvider(BaseProvider):
""" Base implementation of the config provider that reads the data from the file system. """
def __init__(self, config_volume, yaml_filename, py_filename):
self.config_volume = config_volume
self.yaml_filename = yaml_filename
self.py_filename = py_filename
self.yaml_path = os.path.join(config_volume, yaml_filename)
self.py_path = os.path.join(config_volume, py_filename)
def update_app_config(self, app_config):
if os.path.exists(self.py_path):
logger.debug('Applying config file: %s', self.py_path)
app_config.from_pyfile(self.py_path)
if os.path.exists(self.yaml_path):
logger.debug('Applying config file: %s', self.yaml_path)
import_yaml(app_config, self.yaml_path)
def get_config(self):
if not self.config_exists():
return None
config_obj = {}
import_yaml(config_obj, self.yaml_path)
return config_obj
def config_exists(self):
return self.volume_file_exists(self.yaml_filename)
def volume_exists(self):
return os.path.exists(self.config_volume)
def volume_file_exists(self, relative_file_path):
return os.path.exists(os.path.join(self.config_volume, relative_file_path))
def get_volume_file(self, relative_file_path, mode='r'):
return open(os.path.join(self.config_volume, relative_file_path), mode=mode)
def get_volume_path(self, directory, relative_file_path):
return os.path.join(directory, relative_file_path)
def list_volume_directory(self, path):
dirpath = os.path.join(self.config_volume, path)
if not os.path.exists(dirpath):
return None
if not os.path.isdir(dirpath):
return None
return os.listdir(dirpath)
def get_config_root(self):
return self.config_volume

@@ -0,0 +1,123 @@
import logging
import yaml
from abc import ABCMeta, abstractmethod
from six import add_metaclass
from jsonschema import validate, ValidationError
from util.config.schema import CONFIG_SCHEMA
logger = logging.getLogger(__name__)
class CannotWriteConfigException(Exception):
""" Exception raised when the config cannot be written. """
pass
class SetupIncompleteException(Exception):
""" Exception raised when attempting to verify config that has not yet been setup. """
pass
def import_yaml(config_obj, config_file):
with open(config_file) as f:
c = yaml.safe_load(f)
if not c:
logger.debug('Empty YAML config file')
return
if isinstance(c, str):
raise Exception('Invalid YAML config file: ' + str(c))
for key in c.iterkeys():
if key.isupper():
config_obj[key] = c[key]
if config_obj.get('SETUP_COMPLETE', True):
try:
validate(config_obj, CONFIG_SCHEMA)
except ValidationError:
# TODO: Change this into a real error
logger.exception('Could not validate config schema')
else:
logger.debug('Skipping config schema validation because setup is not complete')
return config_obj
def get_yaml(config_obj):
return yaml.safe_dump(config_obj, encoding='utf-8', allow_unicode=True)
def export_yaml(config_obj, config_file):
try:
with open(config_file, 'w') as f:
f.write(get_yaml(config_obj))
except IOError as ioe:
raise CannotWriteConfigException(str(ioe))
@add_metaclass(ABCMeta)
class BaseProvider(object):
""" A configuration provider helps to load, save, and handle config override in the application.
"""
@property
def provider_id(self):
raise NotImplementedError
@abstractmethod
def update_app_config(self, app_config):
""" Updates the given application config object with the loaded override config. """
@abstractmethod
def get_config(self):
""" Returns the contents of the config override file, or None if none. """
@abstractmethod
def save_config(self, config_object):
""" Updates the contents of the config override file to those given. """
@abstractmethod
def config_exists(self):
""" Returns true if a config override file exists in the config volume. """
@abstractmethod
def volume_exists(self):
""" Returns whether the config override volume exists. """
@abstractmethod
def volume_file_exists(self, relative_file_path):
""" Returns whether the file with the given relative path exists under the config override
volume. """
@abstractmethod
def get_volume_file(self, relative_file_path, mode='r'):
""" Returns a Python file referring to the given path under the config override volume. """
@abstractmethod
def remove_volume_file(self, relative_file_path):
""" Removes the config override volume file with the given path. """
@abstractmethod
def list_volume_directory(self, path):
""" Returns a list of strings representing the names of the files found in the config override
directory under the given path. If the path doesn't exist, returns None.
"""
@abstractmethod
def save_volume_file(self, flask_file, relative_file_path):
""" Saves the given flask file to the config override volume, with the given
relative path.
"""
@abstractmethod
def get_volume_path(self, directory, filename):
""" Helper for constructing relative file paths, which may differ between providers.
For example, kubernetes can't have subfolders in configmaps """
@abstractmethod
def get_config_root(self):
""" Returns the config root directory. """

@@ -0,0 +1,47 @@
import os
import logging
from util.config.provider.baseprovider import export_yaml, CannotWriteConfigException
from util.config.provider.basefileprovider import BaseFileProvider
logger = logging.getLogger(__name__)
def _ensure_parent_dir(filepath):
""" Ensures that the parent directory of the given file path exists. """
try:
parentpath = os.path.abspath(os.path.join(filepath, os.pardir))
if not os.path.isdir(parentpath):
os.makedirs(parentpath)
except IOError as ioe:
raise CannotWriteConfigException(str(ioe))
class FileConfigProvider(BaseFileProvider):
""" Implementation of the config provider that reads and writes the data
from/to the file system. """
def __init__(self, config_volume, yaml_filename, py_filename):
super(FileConfigProvider, self).__init__(config_volume, yaml_filename, py_filename)
@property
def provider_id(self):
return 'file'
def save_config(self, config_obj):
export_yaml(config_obj, self.yaml_path)
def remove_volume_file(self, relative_file_path):
filepath = os.path.join(self.config_volume, relative_file_path)
os.remove(filepath)
def save_volume_file(self, flask_file, relative_file_path):
filepath = os.path.join(self.config_volume, relative_file_path)
_ensure_parent_dir(filepath)
# Write the file.
try:
flask_file.save(filepath)
except IOError as ioe:
raise CannotWriteConfigException(str(ioe))
return filepath

@@ -0,0 +1,188 @@
import os
import logging
import json
import base64
import time
from cStringIO import StringIO
from requests import Request, Session
from util.config.provider.baseprovider import CannotWriteConfigException, get_yaml
from util.config.provider.basefileprovider import BaseFileProvider
logger = logging.getLogger(__name__)
KUBERNETES_API_HOST = os.environ.get('KUBERNETES_SERVICE_HOST', '')
port = os.environ.get('KUBERNETES_SERVICE_PORT')
if port:
KUBERNETES_API_HOST += ':' + port
SERVICE_ACCOUNT_TOKEN_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/token'
QE_NAMESPACE = os.environ.get('QE_K8S_NAMESPACE', 'quay-enterprise')
QE_CONFIG_SECRET = os.environ.get('QE_K8S_CONFIG_SECRET', 'quay-enterprise-config-secret')
class KubernetesConfigProvider(BaseFileProvider):
""" Implementation of the config provider that reads and writes configuration
data from a Kubernetes Secret. """
def __init__(self, config_volume, yaml_filename, py_filename, api_host=None,
service_account_token_path=None):
super(KubernetesConfigProvider, self).__init__(config_volume, yaml_filename, py_filename)
service_account_token_path = service_account_token_path or SERVICE_ACCOUNT_TOKEN_PATH
api_host = api_host or KUBERNETES_API_HOST
# Load the service account token from the local store.
if not os.path.exists(service_account_token_path):
raise Exception('Cannot load Kubernetes service account token')
with open(service_account_token_path, 'r') as f:
self._service_token = f.read()
self._api_host = api_host
@property
def provider_id(self):
return 'k8s'
def get_volume_path(self, directory, filename):
# NOTE: Overridden to ensure we don't have subdirectories, which aren't supported
# in Kubernetes secrets.
return "_".join([directory.rstrip('/'), filename])
def volume_exists(self):
secret = self._lookup_secret()
return secret is not None
def volume_file_exists(self, relative_file_path):
if '/' in relative_file_path:
raise Exception('Expected path from get_volume_path, but found slashes')
# NOTE: Overridden because we don't have subdirectories, which aren't supported
# in Kubernetes secrets.
secret = self._lookup_secret()
if not secret or not secret.get('data'):
return False
return relative_file_path in secret['data']
def list_volume_directory(self, path):
# NOTE: Overridden because we don't have subdirectories, which aren't supported
# in Kubernetes secrets.
secret = self._lookup_secret()
if not secret:
return []
paths = []
for filename in secret.get('data', {}):
if filename.startswith(path):
paths.append(filename[len(path) + 1:])
return paths
def save_config(self, config_obj):
self._update_secret_file(self.yaml_filename, get_yaml(config_obj))
def remove_volume_file(self, relative_file_path):
try:
self._update_secret_file(relative_file_path, None)
except IOError as ioe:
raise CannotWriteConfigException(str(ioe))
def save_volume_file(self, flask_file, relative_file_path):
# Write the file to a temp location.
buf = StringIO()
try:
try:
flask_file.save(buf)
except IOError as ioe:
raise CannotWriteConfigException(str(ioe))
self._update_secret_file(relative_file_path, buf.getvalue())
finally:
buf.close()
def _assert_success(self, response):
if response.status_code != 200:
logger.error('Kubernetes API call failed with response: %s => %s', response.status_code,
response.text)
raise CannotWriteConfigException('Kubernetes API call failed: %s' % response.text)
def _update_secret_file(self, relative_file_path, value=None):
if '/' in relative_file_path:
raise Exception('Expected path from get_volume_path, but found slashes')
# Check first that the namespace for Red Hat Quay exists. If it does not, report that
# as an error, as it seems to be a common issue.
namespace_url = 'namespaces/%s' % (QE_NAMESPACE)
response = self._execute_k8s_api('GET', namespace_url)
if response.status_code // 100 != 2:
msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE
raise CannotWriteConfigException(msg)
# Check if the secret exists. If not, then we create an empty secret and then update the file
# inside.
secret_url = 'namespaces/%s/secrets/%s' % (QE_NAMESPACE, QE_CONFIG_SECRET)
secret = self._lookup_secret()
if secret is None:
self._assert_success(self._execute_k8s_api('POST', secret_url, {
"kind": "Secret",
"apiVersion": "v1",
"metadata": {
"name": QE_CONFIG_SECRET
},
"data": {}
}))
# Update the secret to reflect the file change.
secret['data'] = secret.get('data', {})
if value is not None:
secret['data'][relative_file_path] = base64.b64encode(value)
else:
secret['data'].pop(relative_file_path)
self._assert_success(self._execute_k8s_api('PUT', secret_url, secret))
# Wait until the local mounted copy of the secret has been updated, as
# this is an eventual consistency operation, but the caller expects immediate
# consistency.
while True:
matching_files = set()
for secret_filename, encoded_value in secret['data'].iteritems():
expected_value = base64.b64decode(encoded_value)
try:
with self.get_volume_file(secret_filename) as f:
contents = f.read()
if contents == expected_value:
matching_files.add(secret_filename)
except IOError:
continue
if matching_files == set(secret['data'].keys()):
break
# Sleep for a second and then try again.
time.sleep(1)
def _lookup_secret(self):
secret_url = 'namespaces/%s/secrets/%s' % (QE_NAMESPACE, QE_CONFIG_SECRET)
response = self._execute_k8s_api('GET', secret_url)
if response.status_code != 200:
return None
return json.loads(response.text)
def _execute_k8s_api(self, method, relative_url, data=None):
headers = {
'Authorization': 'Bearer ' + self._service_token
}
if data:
headers['Content-Type'] = 'application/json'
data = json.dumps(data) if data else None
session = Session()
url = 'https://%s/api/v1/%s' % (self._api_host, relative_url)
request = Request(method, url, data=data, headers=headers)
return session.send(request.prepare(), verify=False, timeout=2)

@@ -0,0 +1,29 @@
import pytest
from util.config.provider import FileConfigProvider
from test.fixtures import *
class TestFileConfigProvider(FileConfigProvider):
def __init__(self):
self.yaml_filename = 'yaml_filename'
self._service_token = 'service_token'
self.config_volume = 'config_volume'
self.py_filename = 'py_filename'
self.yaml_path = os.path.join(self.config_volume, self.yaml_filename)
self.py_path = os.path.join(self.config_volume, self.py_filename)
@pytest.mark.parametrize('directory,filename,expected', [
("directory", "file", "directory/file"),
("directory/dir", "file", "directory/dir/file"),
("directory/dir/", "file", "directory/dir/file"),
("directory", "file/test", "directory/file/test"),
])
def test_get_volume_path(directory, filename, expected):
provider = TestFileConfigProvider()
assert expected == provider.get_volume_path(directory, filename)

@@ -0,0 +1,138 @@
import base64
import os
import json
import uuid
import pytest
from contextlib import contextmanager
from collections import namedtuple
from httmock import urlmatch, HTTMock
from util.config.provider import KubernetesConfigProvider
def normalize_path(path):
return path.replace('/', '_')
@contextmanager
def fake_kubernetes_api(tmpdir_factory, files=None):
hostname = 'kubapi'
service_account_token_path = str(tmpdir_factory.mktemp("k8s").join("serviceaccount"))
auth_header = str(uuid.uuid4())
with open(service_account_token_path, 'w') as f:
f.write(auth_header)
global secret
secret = {
'data': {}
}
def write_file(config_dir, filepath, value):
normalized_path = normalize_path(filepath)
absolute_path = str(config_dir.join(normalized_path))
try:
os.makedirs(os.path.dirname(absolute_path))
except OSError:
pass
with open(absolute_path, 'w') as f:
f.write(value)
config_dir = tmpdir_factory.mktemp("config")
if files:
for filepath, value in files.iteritems():
normalized_path = normalize_path(filepath)
write_file(config_dir, filepath, value)
secret['data'][normalized_path] = base64.b64encode(value)
@urlmatch(netloc=hostname,
path='/api/v1/namespaces/quay-enterprise/secrets/quay-enterprise-config-secret$',
method='get')
def get_secret(_, __):
return {'status_code': 200, 'content': json.dumps(secret)}
@urlmatch(netloc=hostname,
path='/api/v1/namespaces/quay-enterprise/secrets/quay-enterprise-config-secret$',
method='put')
def put_secret(_, request):
updated_secret = json.loads(request.body)
for filepath, value in updated_secret['data'].iteritems():
if filepath not in secret['data']:
# Add
write_file(config_dir, filepath, base64.b64decode(value))
for filepath in secret['data']:
if filepath not in updated_secret['data']:
# Remove.
normalized_path = normalize_path(filepath)
os.remove(str(config_dir.join(normalized_path)))
secret['data'] = updated_secret['data']
return {'status_code': 200, 'content': json.dumps(secret)}
@urlmatch(netloc=hostname, path='/api/v1/namespaces/quay-enterprise$')
def get_namespace(_, __):
return {'status_code': 200, 'content': json.dumps({})}
@urlmatch(netloc=hostname)
def catch_all(url, _):
print url
return {'status_code': 404, 'content': '{}'}
with HTTMock(get_secret, put_secret, get_namespace, catch_all):
provider = KubernetesConfigProvider(str(config_dir), 'config.yaml', 'config.py',
api_host=hostname,
service_account_token_path=service_account_token_path)
# Validate all the files.
for filepath, value in files.iteritems():
normalized_path = normalize_path(filepath)
assert provider.volume_file_exists(normalized_path)
with provider.get_volume_file(normalized_path) as f:
assert f.read() == value
yield provider
def test_basic_config(tmpdir_factory):
basic_files = {
'config.yaml': 'FOO: bar',
}
with fake_kubernetes_api(tmpdir_factory, files=basic_files) as provider:
assert provider.config_exists()
assert provider.get_config() is not None
assert provider.get_config()['FOO'] == 'bar'
@pytest.mark.parametrize('filepath', [
'foo',
'foo/meh',
'foo/bar/baz',
])
def test_remove_file(filepath, tmpdir_factory):
basic_files = {
filepath: 'foo',
}
with fake_kubernetes_api(tmpdir_factory, files=basic_files) as provider:
normalized_path = normalize_path(filepath)
assert provider.volume_file_exists(normalized_path)
provider.remove_volume_file(normalized_path)
assert not provider.volume_file_exists(normalized_path)
class TestFlaskFile(object):
def save(self, buf):
buf.write('hello world!')
def test_save_file(tmpdir_factory):
basic_files = {}
with fake_kubernetes_api(tmpdir_factory, files=basic_files) as provider:
assert not provider.volume_file_exists('testfile')
flask_file = TestFlaskFile()
provider.save_volume_file(flask_file, 'testfile')
assert provider.volume_file_exists('testfile')

@@ -0,0 +1,77 @@
import json
import io
import os
from datetime import datetime, timedelta
from util.config.provider.baseprovider import BaseProvider
REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg', 'test/data/test.pem']
class TestConfigProvider(BaseProvider):
""" Implementation of the config provider for testing. Everything is kept in-memory instead on
the real file system. """
def get_config_root(self):
raise Exception('Test Config does not have a config root')
def __init__(self):
self.clear()
def clear(self):
self.files = {}
self._config = {}
@property
def provider_id(self):
return 'test'
def update_app_config(self, app_config):
self._config = app_config
def get_config(self):
if not 'config.yaml' in self.files:
return None
return json.loads(self.files.get('config.yaml', '{}'))
def save_config(self, config_obj):
self.files['config.yaml'] = json.dumps(config_obj)
def config_exists(self):
return 'config.yaml' in self.files
def volume_exists(self):
return True
def volume_file_exists(self, filename):
if filename in REAL_FILES:
return True
return filename in self.files
def save_volume_file(self, flask_file, filename):
self.files[filename] = flask_file.read()
def get_volume_file(self, filename, mode='r'):
if filename in REAL_FILES:
return open(filename, mode=mode)
return io.BytesIO(self.files[filename])
def remove_volume_file(self, filename):
self.files.pop(filename, None)
def list_volume_directory(self, path):
paths = []
for filename in self.files:
if filename.startswith(path):
paths.append(filename[len(path)+1:])
return paths
def reset_for_test(self):
self._config['SUPER_USERS'] = ['devtable']
self.files = {}
def get_volume_path(self, directory, filename):
return os.path.join(directory, filename)

1232
util/config/schema.py Normal file

File diff suppressed because it is too large

@@ -0,0 +1,38 @@
from multiprocessing.sharedctypes import Array
from util.validation import MAX_USERNAME_LENGTH
class SuperUserManager(object):
""" In-memory helper class for quickly accessing (and updating) the valid
set of super users. This class communicates across processes to ensure
that the shared set is always the same.
"""
def __init__(self, app):
usernames = app.config.get('SUPER_USERS', [])
usernames_str = ','.join(usernames)
self._max_length = len(usernames_str) + MAX_USERNAME_LENGTH + 1
self._array = Array('c', self._max_length, lock=True)
self._array.value = usernames_str
def is_superuser(self, username):
""" Returns if the given username represents a super user. """
usernames = self._array.value.split(',')
return username in usernames
def register_superuser(self, username):
""" Registers a new username as a super user for the duration of the container.
Note that this does *not* change any underlying config files.
"""
usernames = self._array.value.split(',')
usernames.append(username)
new_string = ','.join(usernames)
if len(new_string) <= self._max_length:
self._array.value = new_string
else:
raise Exception('Maximum superuser count reached. Please report this to support.')
def has_superusers(self):
""" Returns whether there are any superusers defined. """
return bool(self._array.value)
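
A sketch of the behaviour using a stand-in for the Flask app (only its config attribute is read); registrations live in shared memory only and never touch the config files:

from util.config.superusermanager import SuperUserManager

class FakeApp(object):
  config = {'SUPER_USERS': ['devtable', 'admin']}

manager = SuperUserManager(FakeApp())
assert manager.is_superuser('devtable')
assert not manager.is_superuser('someuser')

manager.register_superuser('someuser')
assert manager.is_superuser('someuser')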

@@ -0,0 +1,7 @@
from config import DefaultConfig
from util.config.schema import CONFIG_SCHEMA, INTERNAL_ONLY_PROPERTIES
def test_ensure_schema_defines_all_fields():
for key in vars(DefaultConfig):
has_key = key in CONFIG_SCHEMA['properties'] or key in INTERNAL_ONLY_PROPERTIES
assert has_key, "Property `%s` is missing from config schema" % key

@@ -0,0 +1,32 @@
import pytest
from util.config.validator import is_valid_config_upload_filename
from util.config.validator import CONFIG_FILENAMES, CONFIG_FILE_SUFFIXES
def test_valid_config_upload_filenames():
for filename in CONFIG_FILENAMES:
assert is_valid_config_upload_filename(filename)
for suffix in CONFIG_FILE_SUFFIXES:
assert is_valid_config_upload_filename('foo' + suffix)
assert not is_valid_config_upload_filename(suffix + 'foo')
@pytest.mark.parametrize('filename, expect_valid', [
('', False),
('foo', False),
('config.yaml', False),
('ssl.cert', True),
('ssl.key', True),
('ssl.crt', False),
('foobar-cloudfront-signing-key.pem', True),
('foobaz-cloudfront-signing-key.pem', True),
('barbaz-cloudfront-signing-key.pem', True),
('barbaz-cloudfront-signing-key.pem.bak', False),
])
def test_is_valid_config_upload_filename(filename, expect_valid):
assert is_valid_config_upload_filename(filename) == expect_valid

152
util/config/validator.py Normal file
@@ -0,0 +1,152 @@
import logging
from auth.auth_context import get_authenticated_user
from data.users import LDAP_CERT_FILENAME
from util.secscan.secscan_util import get_blob_download_uri_getter
from util.config import URLSchemeAndHostname
from util.config.validators.validate_database import DatabaseValidator
from util.config.validators.validate_redis import RedisValidator
from util.config.validators.validate_storage import StorageValidator
from util.config.validators.validate_ldap import LDAPValidator
from util.config.validators.validate_keystone import KeystoneValidator
from util.config.validators.validate_jwt import JWTAuthValidator
from util.config.validators.validate_secscan import SecurityScannerValidator
from util.config.validators.validate_signer import SignerValidator
from util.config.validators.validate_torrent import BittorrentValidator
from util.config.validators.validate_ssl import SSLValidator, SSL_FILENAMES
from util.config.validators.validate_google_login import GoogleLoginValidator
from util.config.validators.validate_bitbucket_trigger import BitbucketTriggerValidator
from util.config.validators.validate_gitlab_trigger import GitLabTriggerValidator
from util.config.validators.validate_github import GitHubLoginValidator, GitHubTriggerValidator
from util.config.validators.validate_oidc import OIDCLoginValidator
from util.config.validators.validate_timemachine import TimeMachineValidator
from util.config.validators.validate_access import AccessSettingsValidator
from util.config.validators.validate_actionlog_archiving import ActionLogArchivingValidator
from util.config.validators.validate_apptokenauth import AppTokenAuthValidator
logger = logging.getLogger(__name__)
class ConfigValidationException(Exception):
""" Exception raised when the configuration fails to validate for a known reason. """
pass
# Note: Only add files required for HTTPS to the SSL_FILENAMES list.
DB_SSL_FILENAMES = ['database.pem']
JWT_FILENAMES = ['jwt-authn.cert']
ACI_CERT_FILENAMES = ['signing-public.gpg', 'signing-private.gpg']
LDAP_FILENAMES = [LDAP_CERT_FILENAME]
CONFIG_FILENAMES = (SSL_FILENAMES + DB_SSL_FILENAMES + JWT_FILENAMES + ACI_CERT_FILENAMES +
LDAP_FILENAMES)
CONFIG_FILE_SUFFIXES = ['-cloudfront-signing-key.pem']
EXTRA_CA_DIRECTORY = 'extra_ca_certs'
EXTRA_CA_DIRECTORY_PREFIX = 'extra_ca_certs_'
VALIDATORS = {
DatabaseValidator.name: DatabaseValidator.validate,
RedisValidator.name: RedisValidator.validate,
StorageValidator.name: StorageValidator.validate,
GitHubLoginValidator.name: GitHubLoginValidator.validate,
GitHubTriggerValidator.name: GitHubTriggerValidator.validate,
GitLabTriggerValidator.name: GitLabTriggerValidator.validate,
BitbucketTriggerValidator.name: BitbucketTriggerValidator.validate,
GoogleLoginValidator.name: GoogleLoginValidator.validate,
SSLValidator.name: SSLValidator.validate,
LDAPValidator.name: LDAPValidator.validate,
JWTAuthValidator.name: JWTAuthValidator.validate,
KeystoneValidator.name: KeystoneValidator.validate,
SignerValidator.name: SignerValidator.validate,
SecurityScannerValidator.name: SecurityScannerValidator.validate,
BittorrentValidator.name: BittorrentValidator.validate,
OIDCLoginValidator.name: OIDCLoginValidator.validate,
TimeMachineValidator.name: TimeMachineValidator.validate,
AccessSettingsValidator.name: AccessSettingsValidator.validate,
ActionLogArchivingValidator.name: ActionLogArchivingValidator.validate,
AppTokenAuthValidator.name: AppTokenAuthValidator.validate,
}
def validate_service_for_config(service, validator_context):
""" Attempts to validate the configuration for the given service. """
if not service in VALIDATORS:
return {
'status': False
}
try:
VALIDATORS[service](validator_context)
return {
'status': True
}
except Exception as ex:
logger.exception('Validation exception')
return {
'status': False,
'reason': str(ex)
}
def is_valid_config_upload_filename(filename):
""" Returns true if and only if the given filename is one which is supported for upload
from the configuration UI tool.
"""
if filename in CONFIG_FILENAMES:
return True
return any([filename.endswith(suffix) for suffix in CONFIG_FILE_SUFFIXES])
class ValidatorContext(object):
""" Context to run validators in, with any additional runtime configuration they need
"""
def __init__(self, config, user_password=None, http_client=None, context=None,
url_scheme_and_hostname=None, jwt_auth_max=None, registry_title=None,
ip_resolver=None, feature_sec_scanner=False, is_testing=False,
uri_creator=None, config_provider=None, instance_keys=None,
init_scripts_location=None):
self.config = config
self.user = get_authenticated_user()
self.user_password = user_password
self.http_client = http_client
self.context = context
self.url_scheme_and_hostname = url_scheme_and_hostname
self.jwt_auth_max = jwt_auth_max
self.registry_title = registry_title
self.ip_resolver = ip_resolver
self.feature_sec_scanner = feature_sec_scanner
self.is_testing = is_testing
self.uri_creator = uri_creator
self.config_provider = config_provider
self.instance_keys = instance_keys
self.init_scripts_location = init_scripts_location
@classmethod
def from_app(cls, app, config, user_password, ip_resolver, instance_keys, client=None,
config_provider=None, init_scripts_location=None):
"""
Creates a ValidatorContext from an app config, with a given config to validate
:param app: the Flask app to pull configuration information from
:param config: the config to validate
:param user_password: request password
:param instance_keys: The instance keys handler
:param ip_resolver: the IP address resolver instance
:param client: http client used to connect to services
:param config_provider: config provider used to access config volume(s)
:param init_scripts_location: location where initial load scripts are stored
:return: ValidatorContext
"""
url_scheme_and_hostname = URLSchemeAndHostname.from_app_config(app.config)
return cls(config,
user_password=user_password,
http_client=client or app.config['HTTPCLIENT'],
context=app.app_context,
url_scheme_and_hostname=url_scheme_and_hostname,
jwt_auth_max=app.config.get('JWT_AUTH_MAX_FRESH_S', 300),
registry_title=app.config['REGISTRY_TITLE'],
ip_resolver=ip_resolver,
feature_sec_scanner=app.config.get('FEATURE_SECURITY_SCANNER', False),
is_testing=app.config.get('TESTING', False),
uri_creator=get_blob_download_uri_getter(app.test_request_context('/'), url_scheme_and_hostname),
config_provider=config_provider,
instance_keys=instance_keys,
init_scripts_location=init_scripts_location)

@@ -0,0 +1,20 @@
from abc import ABCMeta, abstractmethod, abstractproperty
from six import add_metaclass
class ConfigValidationException(Exception):
""" Exception raised when the configuration fails to validate for a known reason. """
pass
@add_metaclass(ABCMeta)
class BaseValidator(object):
@abstractproperty
def name(self):
""" The key for the validation API. """
pass
@classmethod
@abstractmethod
def validate(cls, validator_context):
""" Raises Exception if failure to validate. """
pass
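
A hypothetical validator written against this interface; the concrete validators imported by util/config/validator.py follow the same shape (a name class attribute plus a validate classmethod that raises on failure):

from util.config.validators import BaseValidator, ConfigValidationException

class HostnameValidator(BaseValidator):
  name = 'hostname'

  @classmethod
  def validate(cls, validator_context):
    if not validator_context.config.get('SERVER_HOSTNAME'):
      raise ConfigValidationException('SERVER_HOSTNAME must be set')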

@@ -0,0 +1,29 @@
import pytest
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_access import AccessSettingsValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config, expected_exception', [
({}, None),
({'FEATURE_DIRECT_LOGIN': False}, ConfigValidationException),
({'FEATURE_DIRECT_LOGIN': False, 'SOMETHING_LOGIN_CONFIG': {}}, None),
({'FEATURE_DIRECT_LOGIN': False, 'FEATURE_GITHUB_LOGIN': True}, None),
({'FEATURE_DIRECT_LOGIN': False, 'FEATURE_GOOGLE_LOGIN': True}, None),
({'FEATURE_USER_CREATION': True, 'FEATURE_INVITE_ONLY_USER_CREATION': False}, None),
({'FEATURE_USER_CREATION': True, 'FEATURE_INVITE_ONLY_USER_CREATION': True}, None),
({'FEATURE_INVITE_ONLY_USER_CREATION': True}, None),
({'FEATURE_USER_CREATION': False, 'FEATURE_INVITE_ONLY_USER_CREATION': True},
ConfigValidationException),
({'FEATURE_USER_CREATION': False, 'FEATURE_INVITE_ONLY_USER_CREATION': False}, None),
])
def test_validate_invalid_oidc_login_config(unvalidated_config, expected_exception, app):
validator = AccessSettingsValidator()
if expected_exception is not None:
with pytest.raises(expected_exception):
validator.validate(ValidatorContext(unvalidated_config))
else:
validator.validate(ValidatorContext(unvalidated_config))

@@ -0,0 +1,52 @@
import pytest
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_actionlog_archiving import ActionLogArchivingValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config', [
({}),
({'ACTION_LOG_ARCHIVE_PATH': 'foo'}),
({'ACTION_LOG_ARCHIVE_LOCATION': ''}),
])
def test_skip_validate_actionlog(unvalidated_config, app):
validator = ActionLogArchivingValidator()
validator.validate(ValidatorContext(unvalidated_config))
@pytest.mark.parametrize('config, expected_error', [
({'FEATURE_ACTION_LOG_ROTATION': True}, 'Missing action log archive path'),
({'FEATURE_ACTION_LOG_ROTATION': True,
'ACTION_LOG_ARCHIVE_PATH': ''}, 'Missing action log archive path'),
({'FEATURE_ACTION_LOG_ROTATION': True,
'ACTION_LOG_ARCHIVE_PATH': 'foo'}, 'Missing action log archive storage location'),
({'FEATURE_ACTION_LOG_ROTATION': True,
'ACTION_LOG_ARCHIVE_PATH': 'foo',
'ACTION_LOG_ARCHIVE_LOCATION': ''}, 'Missing action log archive storage location'),
({'FEATURE_ACTION_LOG_ROTATION': True,
'ACTION_LOG_ARCHIVE_PATH': 'foo',
'ACTION_LOG_ARCHIVE_LOCATION': 'invalid'},
'Action log archive storage location `invalid` not found in storage config'),
])
def test_invalid_config(config, expected_error, app):
validator = ActionLogArchivingValidator()
with pytest.raises(ConfigValidationException) as ipe:
validator.validate(ValidatorContext(config))
assert str(ipe.value) == expected_error
def test_valid_config(app):
config = ValidatorContext({
'FEATURE_ACTION_LOG_ROTATION': True,
'ACTION_LOG_ARCHIVE_PATH': 'somepath',
'ACTION_LOG_ARCHIVE_LOCATION': 'somelocation',
'DISTRIBUTED_STORAGE_CONFIG': {
'somelocation': {},
},
})
validator = ActionLogArchivingValidator()
validator.validate(config)

View file

@ -0,0 +1,30 @@
import pytest
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_apptokenauth import AppTokenAuthValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config', [
({'AUTHENTICATION_TYPE': 'AppToken'}),
({'AUTHENTICATION_TYPE': 'AppToken', 'FEATURE_APP_SPECIFIC_TOKENS': False}),
({'AUTHENTICATION_TYPE': 'AppToken', 'FEATURE_APP_SPECIFIC_TOKENS': True,
'FEATURE_DIRECT_LOGIN': True}),
])
def test_validate_invalid_auth_config(unvalidated_config, app):
validator = AppTokenAuthValidator()
with pytest.raises(ConfigValidationException):
validator.validate(ValidatorContext(unvalidated_config))
def test_validate_auth(app):
config = ValidatorContext({
'AUTHENTICATION_TYPE': 'AppToken',
'FEATURE_APP_SPECIFIC_TOKENS': True,
'FEATURE_DIRECT_LOGIN': False,
})
validator = AppTokenAuthValidator()
validator.validate(config)

View file

@ -0,0 +1,48 @@
import pytest
from httmock import urlmatch, HTTMock
from util.config import URLSchemeAndHostname
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_bitbucket_trigger import BitbucketTriggerValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config', [
(ValidatorContext({})),
(ValidatorContext({'BITBUCKET_TRIGGER_CONFIG': {}})),
(ValidatorContext({'BITBUCKET_TRIGGER_CONFIG': {'CONSUMER_KEY': 'foo'}})),
(ValidatorContext({'BITBUCKET_TRIGGER_CONFIG': {'CONSUMER_SECRET': 'foo'}})),
])
def test_validate_invalid_bitbucket_trigger_config(unvalidated_config, app):
validator = BitbucketTriggerValidator()
with pytest.raises(ConfigValidationException):
validator.validate(unvalidated_config)
def test_validate_bitbucket_trigger(app):
url_hit = [False]
@urlmatch(netloc=r'bitbucket.org')
def handler(url, request):
url_hit[0] = True
return {
'status_code': 200,
'content': 'oauth_token=foo&oauth_token_secret=bar',
}
with HTTMock(handler):
validator = BitbucketTriggerValidator()
url_scheme_and_hostname = URLSchemeAndHostname('http', 'localhost:5000')
unvalidated_config = ValidatorContext({
'BITBUCKET_TRIGGER_CONFIG': {
'CONSUMER_KEY': 'foo',
'CONSUMER_SECRET': 'bar',
},
}, url_scheme_and_hostname=url_scheme_and_hostname)
validator.validate(unvalidated_config)
assert url_hit[0]

View file

@ -0,0 +1,23 @@
import pytest
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_database import DatabaseValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config,user,user_password,expected', [
(ValidatorContext(None), None, None, TypeError),
(ValidatorContext({}), None, None, KeyError),
(ValidatorContext({'DB_URI': 'sqlite:///:memory:'}), None, None, None),
(ValidatorContext({'DB_URI': 'invalid:///:memory:'}), None, None, KeyError),
(ValidatorContext({'DB_NOTURI': 'sqlite:///:memory:'}), None, None, KeyError),
])
def test_validate_database(unvalidated_config, user, user_password, expected, app):
validator = DatabaseValidator()
if expected is not None:
with pytest.raises(expected):
validator.validate(unvalidated_config)
else:
validator.validate(unvalidated_config)

View file

@ -0,0 +1,69 @@
import pytest
from httmock import urlmatch, HTTMock
from config import build_requests_session
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_github import GitHubLoginValidator, GitHubTriggerValidator
from test.fixtures import *
@pytest.fixture(params=[GitHubLoginValidator, GitHubTriggerValidator])
def github_validator(request):
return request.param
@pytest.mark.parametrize('github_config', [
({}),
({'GITHUB_ENDPOINT': 'foo'}),
({'GITHUB_ENDPOINT': 'http://github.com'}),
({'GITHUB_ENDPOINT': 'http://github.com', 'CLIENT_ID': 'foo'}),
({'GITHUB_ENDPOINT': 'http://github.com', 'CLIENT_SECRET': 'foo'}),
({
'GITHUB_ENDPOINT': 'http://github.com',
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'foo',
'ORG_RESTRICT': True
}),
({
'GITHUB_ENDPOINT': 'http://github.com',
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'foo',
'ORG_RESTRICT': True,
'ALLOWED_ORGANIZATIONS': [],
}),
])
def test_validate_invalid_github_config(github_config, github_validator, app):
with pytest.raises(ConfigValidationException):
unvalidated_config = {}
unvalidated_config[github_validator.config_key] = github_config
github_validator.validate(ValidatorContext(unvalidated_config))
def test_validate_github(github_validator, app):
url_hit = [False, False]
@urlmatch(netloc=r'somehost')
def handler(url, request):
url_hit[0] = True
return {'status_code': 200, 'content': '', 'headers': {'X-GitHub-Request-Id': 'foo'}}
@urlmatch(netloc=r'somehost', path=r'/api/v3/applications/foo/tokens/foo')
def app_handler(url, request):
url_hit[1] = True
return {'status_code': 404, 'content': '', 'headers': {'X-GitHub-Request-Id': 'foo'}}
with HTTMock(app_handler, handler):
unvalidated_config = ValidatorContext({
github_validator.config_key: {
'GITHUB_ENDPOINT': 'http://somehost',
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
},
})
unvalidated_config.http_client = build_requests_session()
github_validator.validate(unvalidated_config)
assert url_hit[0]
assert url_hit[1]

View file

@ -0,0 +1,49 @@
import json
import pytest
from httmock import urlmatch, HTTMock
from config import build_requests_session
from util.config import URLSchemeAndHostname
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_gitlab_trigger import GitLabTriggerValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config', [
({}),
({'GITLAB_TRIGGER_CONFIG': {'GITLAB_ENDPOINT': 'foo'}}),
({'GITLAB_TRIGGER_CONFIG': {'GITLAB_ENDPOINT': 'http://someendpoint', 'CLIENT_ID': 'foo'}}),
({'GITLAB_TRIGGER_CONFIG': {'GITLAB_ENDPOINT': 'http://someendpoint', 'CLIENT_SECRET': 'foo'}}),
])
def test_validate_invalid_gitlab_trigger_config(unvalidated_config, app):
validator = GitLabTriggerValidator()
with pytest.raises(ConfigValidationException):
validator.validate(ValidatorContext(unvalidated_config))
def test_validate_gitlab_enterprise_trigger(app):
url_hit = [False]
@urlmatch(netloc=r'somegitlab', path='/oauth/token')
def handler(_, __):
url_hit[0] = True
return {'status_code': 400, 'content': json.dumps({'error': 'invalid code'})}
with HTTMock(handler):
validator = GitLabTriggerValidator()
url_scheme_and_hostname = URLSchemeAndHostname('http', 'localhost:5000')
unvalidated_config = ValidatorContext({
'GITLAB_TRIGGER_CONFIG': {
'GITLAB_ENDPOINT': 'http://somegitlab',
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
},
}, http_client=build_requests_session(), url_scheme_and_hostname=url_scheme_and_hostname)
validator.validate(unvalidated_config)
assert url_hit[0]

View file

@ -0,0 +1,45 @@
import pytest
from httmock import urlmatch, HTTMock
from config import build_requests_session
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_google_login import GoogleLoginValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config', [
({}),
({'GOOGLE_LOGIN_CONFIG': {}}),
({'GOOGLE_LOGIN_CONFIG': {'CLIENT_ID': 'foo'}}),
({'GOOGLE_LOGIN_CONFIG': {'CLIENT_SECRET': 'foo'}}),
])
def test_validate_invalid_google_login_config(unvalidated_config, app):
validator = GoogleLoginValidator()
with pytest.raises(ConfigValidationException):
validator.validate(ValidatorContext(unvalidated_config))
def test_validate_google_login(app):
url_hit = [False]
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v3/token')
def handler(_, __):
url_hit[0] = True
return {'status_code': 200, 'content': ''}
validator = GoogleLoginValidator()
with HTTMock(handler):
unvalidated_config = ValidatorContext({
'GOOGLE_LOGIN_CONFIG': {
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
},
})
unvalidated_config.http_client = build_requests_session()
validator.validate(unvalidated_config)
assert url_hit[0]

View file

@ -0,0 +1,65 @@
import pytest
from config import build_requests_session
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_jwt import JWTAuthValidator
from util.morecollections import AttrDict
from test.test_external_jwt_authn import fake_jwt
from test.fixtures import *
from app import config_provider
@pytest.mark.parametrize('unvalidated_config', [
({}),
({'AUTHENTICATION_TYPE': 'Database'}),
])
def test_validate_noop(unvalidated_config, app):
config = ValidatorContext(unvalidated_config)
config.config_provider = config_provider
JWTAuthValidator.validate(config)
@pytest.mark.parametrize('unvalidated_config', [
({'AUTHENTICATION_TYPE': 'JWT'}),
({'AUTHENTICATION_TYPE': 'JWT', 'JWT_AUTH_ISSUER': 'foo'}),
({'AUTHENTICATION_TYPE': 'JWT', 'JWT_VERIFY_ENDPOINT': 'foo'}),
])
def test_invalid_config(unvalidated_config, app):
with pytest.raises(ConfigValidationException):
config = ValidatorContext(unvalidated_config)
config.config_provider = config_provider
JWTAuthValidator.validate(config)
# TODO: fix these when re-adding jwt auth mechanism to jwt validators
@pytest.mark.skip(reason='No way of currently testing this')
@pytest.mark.parametrize('username, password, expected_exception', [
('invaliduser', 'invalidpass', ConfigValidationException),
('cool.user', 'invalidpass', ConfigValidationException),
('invaliduser', 'somepass', ConfigValidationException),
('cool.user', 'password', None),
])
def test_validated_jwt(username, password, expected_exception, app):
with fake_jwt() as jwt_auth:
config = {}
config['AUTHENTICATION_TYPE'] = 'JWT'
config['JWT_AUTH_ISSUER'] = jwt_auth.issuer
config['JWT_VERIFY_ENDPOINT'] = jwt_auth.verify_url
config['JWT_QUERY_ENDPOINT'] = jwt_auth.query_url
config['JWT_GETUSER_ENDPOINT'] = jwt_auth.getuser_url
unvalidated_config = ValidatorContext(config)
unvalidated_config.user = AttrDict(dict(username=username))
unvalidated_config.user_password = password
unvalidated_config.config_provider = config_provider
unvalidated_config.http_client = build_requests_session()
if expected_exception is not None:
with pytest.raises(ConfigValidationException):
JWTAuthValidator.validate(unvalidated_config, public_key_path=jwt_auth.public_key_path)
else:
JWTAuthValidator.validate(unvalidated_config, public_key_path=jwt_auth.public_key_path)

View file

@ -0,0 +1,54 @@
import pytest
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_keystone import KeystoneValidator
from test.test_keystone_auth import fake_keystone
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config', [
({}),
({'AUTHENTICATION_TYPE': 'Database'}),
])
def test_validate_noop(unvalidated_config, app):
KeystoneValidator.validate(ValidatorContext(unvalidated_config))
@pytest.mark.parametrize('unvalidated_config', [
({'AUTHENTICATION_TYPE': 'Keystone'}),
({'AUTHENTICATION_TYPE': 'Keystone', 'KEYSTONE_AUTH_URL': 'foo'}),
({'AUTHENTICATION_TYPE': 'Keystone', 'KEYSTONE_AUTH_URL': 'foo',
'KEYSTONE_ADMIN_USERNAME': 'bar'}),
({'AUTHENTICATION_TYPE': 'Keystone', 'KEYSTONE_AUTH_URL': 'foo',
'KEYSTONE_ADMIN_USERNAME': 'bar', 'KEYSTONE_ADMIN_PASSWORD': 'baz'}),
])
def test_invalid_config(unvalidated_config, app):
with pytest.raises(ConfigValidationException):
KeystoneValidator.validate(ValidatorContext(unvalidated_config))
@pytest.mark.parametrize('admin_tenant_id, expected_exception', [
('somegroupid', None),
('groupwithnousers', ConfigValidationException),
('somegroupid', None),
('groupwithnousers', ConfigValidationException),
])
def test_validated_keystone(admin_tenant_id, expected_exception, app):
with fake_keystone(2) as keystone_auth:
auth_url = keystone_auth.auth_url
config = {}
config['AUTHENTICATION_TYPE'] = 'Keystone'
config['KEYSTONE_AUTH_URL'] = auth_url
config['KEYSTONE_ADMIN_USERNAME'] = 'adminuser'
config['KEYSTONE_ADMIN_PASSWORD'] = 'adminpass'
config['KEYSTONE_ADMIN_TENANT'] = admin_tenant_id
unvalidated_config = ValidatorContext(config)
if expected_exception is not None:
with pytest.raises(ConfigValidationException):
KeystoneValidator.validate(unvalidated_config)
else:
KeystoneValidator.validate(unvalidated_config)

View file

@ -0,0 +1,72 @@
import pytest
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_ldap import LDAPValidator
from util.morecollections import AttrDict
from test.test_ldap import mock_ldap
from test.fixtures import *
from app import config_provider
@pytest.mark.parametrize('unvalidated_config', [
({}),
({'AUTHENTICATION_TYPE': 'Database'}),
])
def test_validate_noop(unvalidated_config, app):
config = ValidatorContext(unvalidated_config, config_provider=config_provider)
LDAPValidator.validate(config)
@pytest.mark.parametrize('unvalidated_config', [
({'AUTHENTICATION_TYPE': 'LDAP'}),
({'AUTHENTICATION_TYPE': 'LDAP', 'LDAP_ADMIN_DN': 'foo'}),
])
def test_invalid_config(unvalidated_config, app):
with pytest.raises(ConfigValidationException):
config = ValidatorContext(unvalidated_config, config_provider=config_provider)
LDAPValidator.validate(config)
@pytest.mark.parametrize('uri', [
'foo',
'http://foo',
'ldap:foo',
])
def test_invalid_uri(uri, app):
config = {}
config['AUTHENTICATION_TYPE'] = 'LDAP'
config['LDAP_BASE_DN'] = ['dc=quay', 'dc=io']
config['LDAP_ADMIN_DN'] = 'uid=testy,ou=employees,dc=quay,dc=io'
config['LDAP_ADMIN_PASSWD'] = 'password'
config['LDAP_USER_RDN'] = ['ou=employees']
config['LDAP_URI'] = uri
with pytest.raises(ConfigValidationException):
config = ValidatorContext(config, config_provider=config_provider)
LDAPValidator.validate(config)
@pytest.mark.parametrize('admin_dn, admin_passwd, user_rdn, expected_exception', [
('uid=testy,ou=employees,dc=quay,dc=io', 'password', ['ou=employees'], None),
('uid=invalidadmindn', 'password', ['ou=employees'], ConfigValidationException),
('uid=testy,ou=employees,dc=quay,dc=io', 'invalid_password', ['ou=employees'], ConfigValidationException),
('uid=testy,ou=employees,dc=quay,dc=io', 'password', ['ou=invalidgroup'], ConfigValidationException),
])
def test_validated_ldap(admin_dn, admin_passwd, user_rdn, expected_exception, app):
config = {}
config['AUTHENTICATION_TYPE'] = 'LDAP'
config['LDAP_BASE_DN'] = ['dc=quay', 'dc=io']
config['LDAP_ADMIN_DN'] = admin_dn
config['LDAP_ADMIN_PASSWD'] = admin_passwd
config['LDAP_USER_RDN'] = user_rdn
unvalidated_config = ValidatorContext(config, config_provider=config_provider)
if expected_exception is not None:
with pytest.raises(ConfigValidationException):
with mock_ldap():
LDAPValidator.validate(unvalidated_config)
else:
with mock_ldap():
LDAPValidator.validate(unvalidated_config)

View file

@ -0,0 +1,50 @@
import json
import pytest
from httmock import urlmatch, HTTMock
from config import build_requests_session
from oauth.oidc import OIDC_WELLKNOWN
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_oidc import OIDCLoginValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config', [
({'SOMETHING_LOGIN_CONFIG': {}}),
({'SOMETHING_LOGIN_CONFIG': {'OIDC_SERVER': 'foo'}}),
({'SOMETHING_LOGIN_CONFIG': {'OIDC_SERVER': 'foo', 'CLIENT_ID': 'foobar'}}),
({'SOMETHING_LOGIN_CONFIG': {'OIDC_SERVER': 'foo', 'CLIENT_SECRET': 'foobar'}}),
])
def test_validate_invalid_oidc_login_config(unvalidated_config, app):
validator = OIDCLoginValidator()
with pytest.raises(ConfigValidationException):
validator.validate(ValidatorContext(unvalidated_config))
def test_validate_oidc_login(app):
url_hit = [False]
@urlmatch(netloc=r'someserver', path=r'/\.well-known/openid-configuration')
def handler(_, __):
url_hit[0] = True
data = {
'token_endpoint': 'foobar',
}
return {'status_code': 200, 'content': json.dumps(data)}
with HTTMock(handler):
validator = OIDCLoginValidator()
unvalidated_config = ValidatorContext({
'SOMETHING_LOGIN_CONFIG': {
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
'OIDC_SERVER': 'http://someserver',
'DEBUGGING': True, # Allows for HTTP.
},
})
unvalidated_config.http_client = build_requests_session()
validator.validate(unvalidated_config)
assert url_hit[0]

View file

@ -0,0 +1,34 @@
import pytest
import redis
from mock import patch
from mockredis import mock_strict_redis_client
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_redis import RedisValidator
from test.fixtures import *
from util.morecollections import AttrDict
@pytest.mark.parametrize('unvalidated_config,user,user_password,use_mock,expected', [
({}, None, None, False, ConfigValidationException),
({'BUILDLOGS_REDIS': {}}, None, None, False, ConfigValidationException),
({'BUILDLOGS_REDIS': {'host': 'somehost'}}, None, None, False, redis.ConnectionError),
({'BUILDLOGS_REDIS': {'host': 'localhost'}}, None, None, True, None),
])
def test_validate_redis(unvalidated_config, user, user_password, use_mock, expected, app):
with patch('redis.StrictRedis' if use_mock else 'redis.None', mock_strict_redis_client):
validator = RedisValidator()
unvalidated_config = ValidatorContext(unvalidated_config)
unvalidated_config.user = AttrDict(dict(username=user))
unvalidated_config.user_password = user_password
if expected is not None:
with pytest.raises(expected):
validator.validate(unvalidated_config)
else:
validator.validate(unvalidated_config)

View file

@ -0,0 +1,48 @@
import pytest
from config import build_requests_session
from util.config import URLSchemeAndHostname
from util.config.validator import ValidatorContext
from util.config.validators.validate_secscan import SecurityScannerValidator
from util.secscan.fake import fake_security_scanner
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config', [
({'DISTRIBUTED_STORAGE_PREFERENCE': []}),
])
def test_validate_noop(unvalidated_config, app):
unvalidated_config = ValidatorContext(unvalidated_config, feature_sec_scanner=False, is_testing=True,
http_client=build_requests_session(),
url_scheme_and_hostname=URLSchemeAndHostname('http', 'localhost:5000'))
SecurityScannerValidator.validate(unvalidated_config)
@pytest.mark.parametrize('unvalidated_config, expected_error', [
({
'TESTING': True,
'DISTRIBUTED_STORAGE_PREFERENCE': [],
'FEATURE_SECURITY_SCANNER': True,
'SECURITY_SCANNER_ENDPOINT': 'http://invalidhost',
}, Exception),
({
'TESTING': True,
'DISTRIBUTED_STORAGE_PREFERENCE': [],
'FEATURE_SECURITY_SCANNER': True,
'SECURITY_SCANNER_ENDPOINT': 'http://fakesecurityscanner',
}, None),
])
def test_validate(unvalidated_config, expected_error, app):
unvalidated_config = ValidatorContext(unvalidated_config, feature_sec_scanner=True, is_testing=True,
http_client=build_requests_session(),
url_scheme_and_hostname=URLSchemeAndHostname('http', 'localhost:5000'))
with fake_security_scanner(hostname='fakesecurityscanner'):
if expected_error is not None:
with pytest.raises(expected_error):
SecurityScannerValidator.validate(unvalidated_config)
else:
SecurityScannerValidator.validate(unvalidated_config)

View file

@ -0,0 +1,20 @@
import pytest
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_signer import SignerValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config,expected', [
({}, None),
({'SIGNING_ENGINE': 'foobar'}, ConfigValidationException),
({'SIGNING_ENGINE': 'gpg2'}, Exception),
])
def test_validate_signer(unvalidated_config, expected, app):
validator = SignerValidator()
if expected is not None:
with pytest.raises(expected):
validator.validate(ValidatorContext(unvalidated_config))
else:
validator.validate(ValidatorContext(unvalidated_config))

View file

@ -0,0 +1,75 @@
import pytest
from mock import patch
from tempfile import NamedTemporaryFile
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_ssl import SSLValidator, SSL_FILENAMES
from util.security.test.test_ssl_util import generate_test_cert
from test.fixtures import *
from app import config_provider
@pytest.mark.parametrize('unvalidated_config', [
({}),
({'PREFERRED_URL_SCHEME': 'http'}),
({'PREFERRED_URL_SCHEME': 'https', 'EXTERNAL_TLS_TERMINATION': True}),
])
def test_skip_validate_ssl(unvalidated_config, app):
validator = SSLValidator()
validator.validate(ValidatorContext(unvalidated_config))
@pytest.mark.parametrize('cert, server_hostname, expected_error, error_message', [
('invalidcert', 'someserver', ConfigValidationException, 'Could not load SSL certificate: no start line'),
(generate_test_cert(hostname='someserver'), 'someserver', None, None),
(generate_test_cert(hostname='invalidserver'), 'someserver', ConfigValidationException,
'Supported names "invalidserver" in SSL cert do not match server hostname "someserver"'),
(generate_test_cert(hostname='someserver'), 'someserver:1234', None, None),
(generate_test_cert(hostname='invalidserver'), 'someserver:1234', ConfigValidationException,
'Supported names "invalidserver" in SSL cert do not match server hostname "someserver"'),
(generate_test_cert(hostname='someserver:1234'), 'someserver:1234', ConfigValidationException,
'Supported names "someserver:1234" in SSL cert do not match server hostname "someserver"'),
(generate_test_cert(hostname='someserver:more'), 'someserver:more', None, None),
(generate_test_cert(hostname='someserver:more'), 'someserver:more:1234', None, None),
])
def test_validate_ssl(cert, server_hostname, expected_error, error_message, app):
with NamedTemporaryFile(delete=False) as cert_file:
cert_file.write(cert[0])
cert_file.seek(0)
with NamedTemporaryFile(delete=False) as key_file:
key_file.write(cert[1])
key_file.seek(0)
def return_true(filename):
return True
def get_volume_file(filename):
if filename == SSL_FILENAMES[0]:
return open(cert_file.name)
if filename == SSL_FILENAMES[1]:
return open(key_file.name)
return None
config = {
'PREFERRED_URL_SCHEME': 'https',
'SERVER_HOSTNAME': server_hostname,
}
with patch('app.config_provider.volume_file_exists', return_true):
with patch('app.config_provider.get_volume_file', get_volume_file):
validator = SSLValidator()
config = ValidatorContext(config)
config.config_provider = config_provider
if expected_error is not None:
with pytest.raises(expected_error) as ipe:
validator.validate(config)
assert str(ipe.value) == error_message
else:
validator.validate(config)

View file

@ -0,0 +1,40 @@
import pytest
from moto import mock_s3_deprecated as mock_s3
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_storage import StorageValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config, expected', [
({}, ConfigValidationException),
({'DISTRIBUTED_STORAGE_CONFIG': {}}, ConfigValidationException),
({'DISTRIBUTED_STORAGE_CONFIG': {'local': None}}, ConfigValidationException),
({'DISTRIBUTED_STORAGE_CONFIG': {'local': ['FakeStorage', {}]}}, None),
])
def test_validate_storage(unvalidated_config, expected, app):
validator = StorageValidator()
if expected is not None:
with pytest.raises(expected):
validator.validate(ValidatorContext(unvalidated_config))
else:
validator.validate(ValidatorContext(unvalidated_config))
def test_validate_s3_storage(app):
validator = StorageValidator()
with mock_s3():
with pytest.raises(ConfigValidationException) as ipe:
validator.validate(ValidatorContext({
'DISTRIBUTED_STORAGE_CONFIG': {
'default': ('S3Storage', {
's3_access_key': 'invalid',
's3_secret_key': 'invalid',
's3_bucket': 'somebucket',
'storage_path': ''
}),
}
}))
assert str(ipe.value) == 'Invalid storage configuration: default: S3ResponseError: 404 Not Found'

View file

@ -0,0 +1,32 @@
import pytest
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_timemachine import TimeMachineValidator
@pytest.mark.parametrize('unvalidated_config', [
({}),
])
def test_validate_noop(unvalidated_config):
TimeMachineValidator.validate(ValidatorContext(unvalidated_config))
from test.fixtures import *
@pytest.mark.parametrize('default_exp,options,expected_exception', [
('2d', ['1w', '2d'], None),
('2d', ['1w'], 'Default expiration must be in expiration options set'),
('2d', ['2d', '1M'], 'Invalid tag expiration option: 1M'),
])
def test_validate(default_exp, options, expected_exception, app):
config = {}
config['DEFAULT_TAG_EXPIRATION'] = default_exp
config['TAG_EXPIRATION_OPTIONS'] = options
if expected_exception is not None:
with pytest.raises(ConfigValidationException) as cve:
TimeMachineValidator.validate(ValidatorContext(config))
assert str(cve.value) == str(expected_exception)
else:
TimeMachineValidator.validate(ValidatorContext(config))

View file

@ -0,0 +1,39 @@
import pytest
from config import build_requests_session
from httmock import urlmatch, HTTMock
from app import instance_keys
from util.config.validator import ValidatorContext
from util.config.validators import ConfigValidationException
from util.config.validators.validate_torrent import BittorrentValidator
from test.fixtures import *
@pytest.mark.parametrize('unvalidated_config,expected', [
({}, ConfigValidationException),
({'BITTORRENT_ANNOUNCE_URL': 'http://faketorrent/announce'}, None),
])
def test_validate_torrent(unvalidated_config, expected, app):
announcer_hit = [False]
@urlmatch(netloc=r'faketorrent', path='/announce')
def handler(url, request):
announcer_hit[0] = True
return {'status_code': 200, 'content': ''}
with HTTMock(handler):
validator = BittorrentValidator()
if expected is not None:
with pytest.raises(expected):
config = ValidatorContext(unvalidated_config, instance_keys=instance_keys)
config.http_client = build_requests_session()
validator.validate(config)
assert not announcer_hit[0]
else:
config = ValidatorContext(unvalidated_config, instance_keys=instance_keys)
config.http_client = build_requests_session()
validator.validate(config)
assert announcer_hit[0]

View file

@ -0,0 +1,28 @@
from util.config.validators import BaseValidator, ConfigValidationException
from oauth.loginmanager import OAuthLoginManager
from oauth.oidc import OIDCLoginService
class AccessSettingsValidator(BaseValidator):
name = "access"
@classmethod
def validate(cls, validator_context):
config = validator_context.config
client = validator_context.http_client
if not config.get('FEATURE_DIRECT_LOGIN', True):
# Make sure at least one external login service (GitHub, Google, or OIDC) is enabled.
github_login = config.get('FEATURE_GITHUB_LOGIN', False)
google_login = config.get('FEATURE_GOOGLE_LOGIN', False)
login_manager = OAuthLoginManager(config, client=client)
custom_oidc = [s for s in login_manager.services if isinstance(s, OIDCLoginService)]
if not github_login and not google_login and not custom_oidc:
msg = 'Cannot disable credentials login to UI without configured OIDC service'
raise ConfigValidationException(msg)
if (not config.get('FEATURE_USER_CREATION', True) and
config.get('FEATURE_INVITE_ONLY_USER_CREATION', False)):
msg = "Invite only user creation requires user creation to be enabled"
raise ConfigValidationException(msg)

View file

@ -0,0 +1,24 @@
from util.config.validators import BaseValidator, ConfigValidationException
class ActionLogArchivingValidator(BaseValidator):
name = "actionlogarchiving"
@classmethod
def validate(cls, validator_context):
""" Validates the action log archiving configuration. """
config = validator_context.config
if not config.get('FEATURE_ACTION_LOG_ROTATION', False):
return
if not config.get('ACTION_LOG_ARCHIVE_PATH'):
raise ConfigValidationException('Missing action log archive path')
if not config.get('ACTION_LOG_ARCHIVE_LOCATION'):
raise ConfigValidationException('Missing action log archive storage location')
location = config['ACTION_LOG_ARCHIVE_LOCATION']
storage_config = config.get('DISTRIBUTED_STORAGE_CONFIG') or {}
if location not in storage_config:
msg = 'Action log archive storage location `%s` not found in storage config' % location
raise ConfigValidationException(msg)
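# An illustrative configuration this validator accepts (the values are made up;
# the archive location key simply has to exist under DISTRIBUTED_STORAGE_CONFIG):
#
#   FEATURE_ACTION_LOG_ROTATION: true
#   ACTION_LOG_ARCHIVE_PATH: archives/actionlogs
#   ACTION_LOG_ARCHIVE_LOCATION: local_us
#   DISTRIBUTED_STORAGE_CONFIG:
#     local_us: [LocalStorage, {storage_path: /datastorage}]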

View file

@ -0,0 +1,21 @@
from util.config.validators import BaseValidator, ConfigValidationException
class AppTokenAuthValidator(BaseValidator):
name = "apptoken-auth"
@classmethod
def validate(cls, validator_context):
config = validator_context.config
if config.get('AUTHENTICATION_TYPE', 'Database') != 'AppToken':
return
# Ensure that app tokens are enabled, as they are required.
if not config.get('FEATURE_APP_SPECIFIC_TOKENS', False):
msg = 'Application token support must be enabled to use External Application Token auth'
raise ConfigValidationException(msg)
# Ensure that direct login is disabled.
if config.get('FEATURE_DIRECT_LOGIN', True):
msg = 'Direct login must be disabled to use External Application Token auth'
raise ConfigValidationException(msg)

View file

@ -0,0 +1,30 @@
from bitbucket import BitBucket
from util.config.validators import BaseValidator, ConfigValidationException
class BitbucketTriggerValidator(BaseValidator):
name = "bitbucket-trigger"
@classmethod
def validate(cls, validator_context):
""" Validates the config for BitBucket. """
config = validator_context.config
trigger_config = config.get('BITBUCKET_TRIGGER_CONFIG')
if not trigger_config:
raise ConfigValidationException('Missing client ID and client secret')
if not trigger_config.get('CONSUMER_KEY'):
raise ConfigValidationException('Missing Consumer Key')
if not trigger_config.get('CONSUMER_SECRET'):
raise ConfigValidationException('Missing Consumer Secret')
key = trigger_config['CONSUMER_KEY']
secret = trigger_config['CONSUMER_SECRET']
callback_url = '%s/oauth1/bitbucket/callback/trigger/' % (validator_context.url_scheme_and_hostname.get_url())
bitbucket_client = BitBucket(key, secret, callback_url)
(result, _, _) = bitbucket_client.get_authorization_url()
if not result:
raise ConfigValidationException('Invalid consumer key or secret')

View file

@ -0,0 +1,20 @@
from peewee import OperationalError
from data.database import validate_database_precondition
from util.config.validators import BaseValidator, ConfigValidationException
class DatabaseValidator(BaseValidator):
name = "database"
@classmethod
def validate(cls, validator_context):
""" Validates connecting to the database. """
config = validator_context.config
try:
validate_database_precondition(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {}))
except OperationalError as ex:
if ex.args and len(ex.args) > 1:
raise ConfigValidationException(ex.args[1])
else:
raise ex

View file

@ -0,0 +1,53 @@
from oauth.services.github import GithubOAuthService
from util.config.validators import BaseValidator, ConfigValidationException
class BaseGitHubValidator(BaseValidator):
name = None
config_key = None
@classmethod
def validate(cls, validator_context):
""" Validates the OAuth credentials and API endpoint for a Github service. """
config = validator_context.config
client = validator_context.http_client
url_scheme_and_hostname = validator_context.url_scheme_and_hostname
github_config = config.get(cls.config_key)
if not github_config:
raise ConfigValidationException('Missing GitHub client id and client secret')
endpoint = github_config.get('GITHUB_ENDPOINT')
if not endpoint:
raise ConfigValidationException('Missing GitHub Endpoint')
if not endpoint.startswith(('http://', 'https://')):
raise ConfigValidationException('Github Endpoint must start with http:// or https://')
if not github_config.get('CLIENT_ID'):
raise ConfigValidationException('Missing Client ID')
if not github_config.get('CLIENT_SECRET'):
raise ConfigValidationException('Missing Client Secret')
if github_config.get('ORG_RESTRICT') and not github_config.get('ALLOWED_ORGANIZATIONS'):
raise ConfigValidationException('Organization restriction must have at least one allowed ' +
'organization')
oauth = GithubOAuthService(config, cls.config_key)
result = oauth.validate_client_id_and_secret(client, url_scheme_and_hostname)
if not result:
raise ConfigValidationException('Invalid client id or client secret')
if github_config.get('ALLOWED_ORGANIZATIONS'):
for org_id in github_config.get('ALLOWED_ORGANIZATIONS'):
if not oauth.validate_organization(org_id, client):
raise ConfigValidationException('Invalid organization: %s' % org_id)
class GitHubLoginValidator(BaseGitHubValidator):
name = "github-login"
config_key = "GITHUB_LOGIN_CONFIG"
class GitHubTriggerValidator(BaseGitHubValidator):
name = "github-trigger"
config_key = "GITHUB_TRIGGER_CONFIG"

View file

@ -0,0 +1,32 @@
from oauth.services.gitlab import GitLabOAuthService
from util.config.validators import BaseValidator, ConfigValidationException
class GitLabTriggerValidator(BaseValidator):
name = "gitlab-trigger"
@classmethod
def validate(cls, validator_context):
""" Validates the OAuth credentials and API endpoint for a GitLab service. """
config = validator_context.config
url_scheme_and_hostname = validator_context.url_scheme_and_hostname
client = validator_context.http_client
github_config = config.get('GITLAB_TRIGGER_CONFIG')
if not github_config:
raise ConfigValidationException('Missing GitLab client id and client secret')
endpoint = github_config.get('GITLAB_ENDPOINT')
if endpoint:
if not endpoint.startswith(('http://', 'https://')):
raise ConfigValidationException('GitLab Endpoint must start with http:// or https://')
if not github_config.get('CLIENT_ID'):
raise ConfigValidationException('Missing Client ID')
if not github_config.get('CLIENT_SECRET'):
raise ConfigValidationException('Missing Client Secret')
oauth = GitLabOAuthService(config, 'GITLAB_TRIGGER_CONFIG')
result = oauth.validate_client_id_and_secret(client, url_scheme_and_hostname)
if not result:
raise ConfigValidationException('Invalid client id or client secret')

View file

@ -0,0 +1,27 @@
from oauth.services.google import GoogleOAuthService
from util.config.validators import BaseValidator, ConfigValidationException
class GoogleLoginValidator(BaseValidator):
name = "google-login"
@classmethod
def validate(cls, validator_context):
""" Validates the Google Login client ID and secret. """
config = validator_context.config
client = validator_context.http_client
url_scheme_and_hostname = validator_context.url_scheme_and_hostname
google_login_config = config.get('GOOGLE_LOGIN_CONFIG')
if not google_login_config:
raise ConfigValidationException('Missing client ID and client secret')
if not google_login_config.get('CLIENT_ID'):
raise ConfigValidationException('Missing Client ID')
if not google_login_config.get('CLIENT_SECRET'):
raise ConfigValidationException('Missing Client Secret')
oauth = GoogleOAuthService(config, 'GOOGLE_LOGIN_CONFIG')
result = oauth.validate_client_id_and_secret(client, url_scheme_and_hostname)
if not result:
raise ConfigValidationException('Invalid client id or client secret')

View file

@ -0,0 +1,48 @@
import os
from data.users.externaljwt import ExternalJWTAuthN
from util.config.validators import BaseValidator, ConfigValidationException
class JWTAuthValidator(BaseValidator):
name = "jwt"
@classmethod
def validate(cls, validator_context, public_key_path=None):
""" Validates the JWT authentication system. """
config = validator_context.config
http_client = validator_context.http_client
jwt_auth_max = validator_context.jwt_auth_max
config_provider = validator_context.config_provider
if config.get('AUTHENTICATION_TYPE', 'Database') != 'JWT':
return
verify_endpoint = config.get('JWT_VERIFY_ENDPOINT')
query_endpoint = config.get('JWT_QUERY_ENDPOINT', None)
getuser_endpoint = config.get('JWT_GETUSER_ENDPOINT', None)
issuer = config.get('JWT_AUTH_ISSUER')
if not verify_endpoint:
raise ConfigValidationException('Missing JWT Verification endpoint')
if not issuer:
raise ConfigValidationException('Missing JWT Issuer ID')
override_config_directory = config_provider.get_config_dir_path()
# Try to instantiate the JWT authentication mechanism. This will raise an exception if
# the key cannot be found.
users = ExternalJWTAuthN(verify_endpoint, query_endpoint, getuser_endpoint, issuer,
override_config_directory,
http_client,
jwt_auth_max,
public_key_path=public_key_path,
requires_email=config.get('FEATURE_MAILING', True))
# Verify that we can reach the jwt server
(result, err_msg) = users.ping()
if not result:
msg = ('Verification of JWT failed: %s. \n\nWe cannot reach the JWT server ' +
'OR JWT auth is misconfigured') % err_msg
raise ConfigValidationException(msg)

View file

@ -0,0 +1,44 @@
from util.config.validators import BaseValidator, ConfigValidationException
from data.users.keystone import get_keystone_users
class KeystoneValidator(BaseValidator):
name = "keystone"
@classmethod
def validate(cls, validator_context):
""" Validates the Keystone authentication system. """
config = validator_context.config
if config.get('AUTHENTICATION_TYPE', 'Database') != 'Keystone':
return
auth_url = config.get('KEYSTONE_AUTH_URL')
auth_version = int(config.get('KEYSTONE_AUTH_VERSION', 2))
admin_username = config.get('KEYSTONE_ADMIN_USERNAME')
admin_password = config.get('KEYSTONE_ADMIN_PASSWORD')
admin_tenant = config.get('KEYSTONE_ADMIN_TENANT')
if not auth_url:
raise ConfigValidationException('Missing authentication URL')
if not admin_username:
raise ConfigValidationException('Missing admin username')
if not admin_password:
raise ConfigValidationException('Missing admin password')
if not admin_tenant:
raise ConfigValidationException('Missing admin tenant')
requires_email = config.get('FEATURE_MAILING', True)
users = get_keystone_users(auth_version, auth_url, admin_username, admin_password, admin_tenant,
requires_email)
# Verify that the superuser exists. If not, raise an exception.
(result, err_msg) = users.at_least_one_user_exists()
if not result:
msg = ('Verification that users exist failed: %s. \n\nNo users exist ' +
'in the admin tenant/project ' +
'in the remote authentication system ' +
'OR Keystone auth is misconfigured.') % err_msg
raise ConfigValidationException(msg)

View file

@ -0,0 +1,68 @@
import os
import ldap
import subprocess
from data.users import LDAP_CERT_FILENAME
from data.users.externalldap import LDAPConnection, LDAPUsers
from util.config.validators import BaseValidator, ConfigValidationException
class LDAPValidator(BaseValidator):
name = "ldap"
@classmethod
def validate(cls, validator_context):
""" Validates the LDAP connection. """
config = validator_context.config
config_provider = validator_context.config_provider
init_scripts_location = validator_context.init_scripts_location
if config.get('AUTHENTICATION_TYPE', 'Database') != 'LDAP':
return
# If there is a custom LDAP certificate, then reinstall the certificates for the container.
if config_provider.volume_file_exists(LDAP_CERT_FILENAME):
subprocess.check_call([os.path.join(init_scripts_location, 'certs_install.sh')],
env={ 'QUAYCONFIG': config_provider.get_config_dir_path() })
# Note: raises ldap.INVALID_CREDENTIALS on failure
admin_dn = config.get('LDAP_ADMIN_DN')
admin_passwd = config.get('LDAP_ADMIN_PASSWD')
if not admin_dn:
raise ConfigValidationException('Missing Admin DN for LDAP configuration')
if not admin_passwd:
raise ConfigValidationException('Missing Admin Password for LDAP configuration')
ldap_uri = config.get('LDAP_URI', 'ldap://localhost')
if not ldap_uri.startswith('ldap://') and not ldap_uri.startswith('ldaps://'):
raise ConfigValidationException('LDAP URI must start with ldap:// or ldaps://')
allow_tls_fallback = config.get('LDAP_ALLOW_INSECURE_FALLBACK', False)
try:
with LDAPConnection(ldap_uri, admin_dn, admin_passwd, allow_tls_fallback):
pass
except ldap.LDAPError as ex:
values = ex.args[0] if ex.args else {}
if not isinstance(values, dict):
raise ConfigValidationException(str(ex.args))
raise ConfigValidationException(values.get('desc', 'Unknown error'))
base_dn = config.get('LDAP_BASE_DN')
user_rdn = config.get('LDAP_USER_RDN', [])
uid_attr = config.get('LDAP_UID_ATTR', 'uid')
email_attr = config.get('LDAP_EMAIL_ATTR', 'mail')
requires_email = config.get('FEATURE_MAILING', True)
users = LDAPUsers(ldap_uri, base_dn, admin_dn, admin_passwd, user_rdn, uid_attr, email_attr,
allow_tls_fallback, requires_email=requires_email)
# Ensure at least one user exists to verify the connection is setup properly.
(result, err_msg) = users.at_least_one_user_exists()
if not result:
msg = ('Verification that users exist failed: %s. \n\nNo users exist ' +
'in the remote authentication system ' +
'OR LDAP auth is misconfigured.') % err_msg
raise ConfigValidationException(msg)

View file

@ -0,0 +1,36 @@
from oauth.loginmanager import OAuthLoginManager
from oauth.oidc import OIDCLoginService, DiscoveryFailureException
from util.config.validators import BaseValidator, ConfigValidationException
class OIDCLoginValidator(BaseValidator):
name = "oidc-login"
@classmethod
def validate(cls, validator_context):
config = validator_context.config
client = validator_context.http_client
login_manager = OAuthLoginManager(config, client=client)
for service in login_manager.services:
if not isinstance(service, OIDCLoginService):
continue
if service.config.get('OIDC_SERVER') is None:
msg = 'Missing OIDC_SERVER on OIDC service %s' % service.service_id()
raise ConfigValidationException(msg)
if service.config.get('CLIENT_ID') is None:
msg = 'Missing CLIENT_ID on OIDC service %s' % service.service_id()
raise ConfigValidationException(msg)
if service.config.get('CLIENT_SECRET') is None:
msg = 'Missing CLIENT_SECRET on OIDC service %s' % service.service_id()
raise ConfigValidationException(msg)
try:
if not service.validate():
msg = 'Could not validate OIDC service %s' % service.service_id()
raise ConfigValidationException(msg)
except DiscoveryFailureException as dfe:
msg = 'Could not validate OIDC service %s: %s' % (service.service_id(), dfe.message)
raise ConfigValidationException(msg)

View file

@ -0,0 +1,18 @@
import redis
from util.config.validators import BaseValidator, ConfigValidationException
class RedisValidator(BaseValidator):
name = "redis"
@classmethod
def validate(cls, validator_context):
""" Validates connecting to redis. """
config = validator_context.config
redis_config = config.get('BUILDLOGS_REDIS', {})
if 'host' not in redis_config:
raise ConfigValidationException('Missing redis hostname')
client = redis.StrictRedis(socket_connect_timeout=5, **redis_config)
client.ping()

View file

@ -0,0 +1,52 @@
import time
# from boot import setup_jwt_proxy
from util.secscan.api import SecurityScannerAPI
from util.config.validators import BaseValidator, ConfigValidationException
class SecurityScannerValidator(BaseValidator):
name = "security-scanner"
@classmethod
def validate(cls, validator_context):
""" Validates the configuration for talking to a Quay Security Scanner. """
config = validator_context.config
client = validator_context.http_client
feature_sec_scanner = validator_context.feature_sec_scanner
is_testing = validator_context.is_testing
server_hostname = validator_context.url_scheme_and_hostname.hostname
uri_creator = validator_context.uri_creator
if not feature_sec_scanner:
return
api = SecurityScannerAPI(config, None, server_hostname, client=client, skip_validation=True, uri_creator=uri_creator)
# if not is_testing:
# Generate a temporary Quay key to use for signing the outgoing requests.
# setup_jwt_proxy()
# We have to wait for JWT proxy to restart with the newly generated key.
max_tries = 5
response = None
last_exception = None
while max_tries > 0:
try:
response = api.ping()
last_exception = None
if response.status_code == 200:
return
except Exception as ex:
last_exception = ex
time.sleep(1)
max_tries = max_tries - 1
if last_exception is not None:
message = str(last_exception)
raise ConfigValidationException('Could not ping security scanner: %s' % message)
else:
message = 'Expected 200 status code, got %s: %s' % (response.status_code, response.text)
raise ConfigValidationException('Could not ping security scanner: %s' % message)

View file

@ -0,0 +1,22 @@
from StringIO import StringIO
from util.config.validators import BaseValidator, ConfigValidationException
from util.security.signing import SIGNING_ENGINES
class SignerValidator(BaseValidator):
name = "signer"
@classmethod
def validate(cls, validator_context):
""" Validates the GPG public+private key pair used for signing converted ACIs. """
config = validator_context.config
config_provider = validator_context.config_provider
if config.get('SIGNING_ENGINE') is None:
return
if config['SIGNING_ENGINE'] not in SIGNING_ENGINES:
raise ConfigValidationException('Unknown signing engine: %s' % config['SIGNING_ENGINE'])
engine = SIGNING_ENGINES[config['SIGNING_ENGINE']](config, config_provider)
engine.detached_sign(StringIO('test string'))

View file

@ -0,0 +1,72 @@
from util.config.validators import BaseValidator, ConfigValidationException
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
SSL_FILENAMES = ['ssl.cert', 'ssl.key']
class SSLValidator(BaseValidator):
name = "ssl"
@classmethod
def validate(cls, validator_context):
""" Validates the SSL configuration (if enabled). """
config = validator_context.config
config_provider = validator_context.config_provider
# Skip if non-SSL.
if config.get('PREFERRED_URL_SCHEME', 'http') != 'https':
return
# Skip if externally terminated.
if config.get('EXTERNAL_TLS_TERMINATION', False) is True:
return
# Verify that we have all the required SSL files.
for filename in SSL_FILENAMES:
if not config_provider.volume_file_exists(filename):
raise ConfigValidationException('Missing required SSL file: %s' % filename)
# Read the contents of the SSL certificate.
with config_provider.get_volume_file(SSL_FILENAMES[0]) as f:
cert_contents = f.read()
# Validate the certificate.
try:
certificate = load_certificate(cert_contents)
except CertInvalidException as cie:
raise ConfigValidationException('Could not load SSL certificate: %s' % cie)
# Verify the certificate has not expired.
if certificate.expired:
raise ConfigValidationException('The specified SSL certificate has expired.')
# Verify the hostname matches the name in the certificate.
if not certificate.matches_name(_ssl_cn(config['SERVER_HOSTNAME'])):
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
(', '.join(list(certificate.names)), _ssl_cn(config['SERVER_HOSTNAME'])))
raise ConfigValidationException(msg)
# Verify the private key against the certificate.
private_key_path = None
with config_provider.get_volume_file(SSL_FILENAMES[1]) as f:
private_key_path = f.name
if not private_key_path:
# Only in testing.
return
try:
certificate.validate_private_key(private_key_path)
except KeyInvalidException as kie:
raise ConfigValidationException('SSL private key failed to validate: %s' % kie)
def _ssl_cn(server_hostname):
""" Return the common name (fully qualified host name) from the SERVER_HOSTNAME. """
host_port = server_hostname.rsplit(':', 1)
# SERVER_HOSTNAME includes the port
if len(host_port) == 2:
if host_port[-1].isdigit():
return host_port[-2]
return server_hostname
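# Illustrative values: _ssl_cn('registry.example.com:8443') returns
# 'registry.example.com', while _ssl_cn('registry.example.com') and
# _ssl_cn('someserver:more') come back unchanged because there is no numeric
# port suffix to strip.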

View file

@ -0,0 +1,54 @@
from storage import get_storage_driver, TYPE_LOCAL_STORAGE
from util.config.validators import BaseValidator, ConfigValidationException
class StorageValidator(BaseValidator):
name = "registry-storage"
@classmethod
def validate(cls, validator_context):
""" Validates registry storage. """
config = validator_context.config
client = validator_context.http_client
ip_resolver = validator_context.ip_resolver
config_provider = validator_context.config_provider
replication_enabled = config.get('FEATURE_STORAGE_REPLICATION', False)
providers = _get_storage_providers(config, ip_resolver, config_provider).items()
if not providers:
raise ConfigValidationException('Storage configuration required')
for name, (storage_type, driver) in providers:
# We can skip localstorage validation, since we can't guarantee that
# this will be the same machine Quay Enterprise will run under
if storage_type == TYPE_LOCAL_STORAGE:
continue
try:
if replication_enabled and storage_type == 'LocalStorage':
raise ConfigValidationException('Locally mounted directory not supported ' +
'with storage replication')
# Run validation on the driver.
driver.validate(client)
# Run setup on the driver if the read/write succeeded.
driver.setup()
except Exception as ex:
msg = str(ex).strip().split("\n")[0]
raise ConfigValidationException('Invalid storage configuration: %s: %s' % (name, msg))
def _get_storage_providers(config, ip_resolver, config_provider):
storage_config = config.get('DISTRIBUTED_STORAGE_CONFIG', {})
drivers = {}
try:
for name, parameters in storage_config.items():
driver = get_storage_driver(None, None, None, config_provider, ip_resolver, parameters)
drivers[name] = (parameters[0], driver)
except TypeError:
raise ConfigValidationException('Missing required parameter(s) for storage %s' % name)
return drivers

View file

@ -0,0 +1,31 @@
import logging
from util.config.validators import BaseValidator, ConfigValidationException
from util.timedeltastring import convert_to_timedelta
logger = logging.getLogger(__name__)
class TimeMachineValidator(BaseValidator):
name = "time-machine"
@classmethod
def validate(cls, validator_context):
config = validator_context.config
if 'DEFAULT_TAG_EXPIRATION' not in config:
# Old style config
return
try:
convert_to_timedelta(config['DEFAULT_TAG_EXPIRATION']).total_seconds()
except ValueError as ve:
raise ConfigValidationException('Invalid default expiration: %s' % ve.message)
if config['DEFAULT_TAG_EXPIRATION'] not in config.get('TAG_EXPIRATION_OPTIONS', []):
raise ConfigValidationException('Default expiration must be in expiration options set')
for ts in config.get('TAG_EXPIRATION_OPTIONS', []):
try:
convert_to_timedelta(ts)
except ValueError as ve:
raise ConfigValidationException('Invalid tag expiration option: %s' % ts)
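# Expiration values are timedelta strings such as '2d' or '1w' (parsed by
# util.timedeltastring.convert_to_timedelta); a value that fails to parse,
# e.g. '1M', is reported as an invalid tag expiration option.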

View file

@ -0,0 +1,59 @@
import logging
from hashlib import sha1
from util.config.validators import BaseValidator, ConfigValidationException
from util.registry.torrent import jwt_from_infohash, TorrentConfiguration
logger = logging.getLogger(__name__)
class BittorrentValidator(BaseValidator):
name = "bittorrent"
@classmethod
def validate(cls, validator_context):
""" Validates the configuration for using BitTorrent for downloads. """
config = validator_context.config
client = validator_context.http_client
announce_url = config.get('BITTORRENT_ANNOUNCE_URL')
if not announce_url:
raise ConfigValidationException('Missing announce URL')
# Ensure that the tracker is reachable and accepts requests signed with a registry key.
params = {
'info_hash': sha1('test').digest(),
'peer_id': '-QUAY00-6wfG2wk6wWLc',
'uploaded': 0,
'downloaded': 0,
'left': 0,
'numwant': 0,
'port': 80,
}
torrent_config = TorrentConfiguration.for_testing(validator_context.instance_keys, announce_url,
validator_context.registry_title)
encoded_jwt = jwt_from_infohash(torrent_config, params['info_hash'])
params['jwt'] = encoded_jwt
resp = client.get(announce_url, timeout=5, params=params)
logger.debug('Got tracker response: %s: %s', resp.status_code, resp.text)
if resp.status_code == 404:
raise ConfigValidationException('Announce path not found; did you forget `/announce`?')
if resp.status_code == 500:
raise ConfigValidationException('Did not get expected response from Tracker; ' +
'please check your settings')
if resp.status_code == 200:
if 'invalid jwt' in resp.text:
raise ConfigValidationException('Could not authorize to Tracker; is your Tracker ' +
'properly configured?')
if 'failure reason' in resp.text:
raise ConfigValidationException('Could not validate signed announce request: ' + resp.text)
if 'go_goroutines' in resp.text:
raise ConfigValidationException('Could not validate signed announce request: ' +
'provided port is used for Prometheus')

92
util/dict_wrappers.py Normal file
View file

@ -0,0 +1,92 @@
import json
from jsonpath_rw import parse
class SafeDictSetter(object):
""" Specialized write-only dictionary wrapper class that allows for setting
nested keys via a path syntax.
Example:
sds = SafeDictSetter()
sds['foo.bar.baz'] = 'hello' # Sets 'foo' = {'bar': {'baz': 'hello'}}
sds['somekey'] = None # Does not set the key since the value is None
"""
def __init__(self, initial_object=None):
self._object = initial_object or {}
def __setitem__(self, path, value):
self.set(path, value)
def set(self, path, value, allow_none=False):
""" Sets the value of the given path to the given value. """
if value is None and not allow_none:
return
pieces = path.split('.')
current = self._object
for piece in pieces[:len(pieces)-1]:
current_obj = current.get(piece, {})
if not isinstance(current_obj, dict):
raise Exception('Key %s is a non-object value: %s' % (piece, current_obj))
current[piece] = current_obj
current = current_obj
current[pieces[-1]] = value
def dict_value(self):
""" Returns the dict value built. """
return self._object
def json_value(self):
""" Returns the JSON string value of the dictionary built. """
return json.dumps(self._object)
class JSONPathDict(object):
""" Specialized read-only dictionary wrapper class that uses the jsonpath_rw library
to access keys via an X-Path-like syntax.
Example:
pd = JSONPathDict({'hello': {'hi': 'there'}})
pd['hello.hi'] # Returns 'there'
"""
def __init__(self, dict_value):
""" Init the helper with the JSON object.
"""
self._object = dict_value
def __getitem__(self, path):
return self.get(path)
def __iter__(self):
return self._object.itervalues()
def iterkeys(self):
return self._object.iterkeys()
def get(self, path, not_found_handler=None):
""" Returns the value found at the given path. Path is a json-path expression. """
if self._object == {} or self._object is None:
return None
jsonpath_expr = parse(path)
try:
matches = jsonpath_expr.find(self._object)
except IndexError:
return None
if not matches:
return not_found_handler() if not_found_handler else None
match = matches[0].value
if not match:
return not_found_handler() if not_found_handler else None
if isinstance(match, dict):
return JSONPathDict(match)
return match
def keys(self):
return self._object.keys()

110
util/disableabuser.py Normal file
View file

@ -0,0 +1,110 @@
import argparse
from datetime import datetime
from app import tf
from data import model
from data.model import db_transaction
from data.database import QueueItem, Repository, RepositoryBuild, RepositoryBuildTrigger, RepoMirrorConfig
from data.queue import WorkQueue
def ask_disable_namespace(username, queue_name):
user = model.user.get_namespace_user(username)
if user is None:
raise Exception('Unknown user or organization %s' % username)
if not user.enabled:
print "NOTE: Namespace %s is already disabled" % username
queue_prefix = '%s/%s/%%' % (queue_name, username)
existing_queue_item_count = (QueueItem
.select()
.where(QueueItem.queue_name ** queue_prefix)
.where(QueueItem.available == 1,
QueueItem.retries_remaining > 0,
QueueItem.processing_expires > datetime.now())
.count())
repository_trigger_count = (RepositoryBuildTrigger
.select()
.join(Repository)
.where(Repository.namespace_user == user)
.count())
print "============================================="
print "For namespace %s" % username
print "============================================="
print "User %s has email address %s" % (username, user.email)
print "User %s has %s queued builds in their namespace" % (username, existing_queue_item_count)
print "User %s has %s build triggers in their namespace" % (username, repository_trigger_count)
confirm_msg = "Would you like to disable this user and delete their triggers and builds? [y/N]> "
letter = str(raw_input(confirm_msg))
if letter.lower() != 'y':
print "Action canceled"
return
print "============================================="
triggers = []
count_removed = 0
with db_transaction():
user.enabled = False
user.save()
repositories_query = Repository.select().where(Repository.namespace_user == user)
if len(repositories_query.clone()):
builds = list(RepositoryBuild
.select()
.where(RepositoryBuild.repository << list(repositories_query)))
triggers = list(RepositoryBuildTrigger
.select()
.where(RepositoryBuildTrigger.repository << list(repositories_query)))
mirrors = list(RepoMirrorConfig
.select()
.where(RepoMirrorConfig.repository << list(repositories_query)))
# Delete all builds for the user's repositories.
if builds:
RepositoryBuild.delete().where(RepositoryBuild.id << builds).execute()
# Delete all build triggers for the user's repositories.
if triggers:
RepositoryBuildTrigger.delete().where(RepositoryBuildTrigger.id << triggers).execute()
# Delete all mirrors for the user's repositories.
if mirrors:
RepoMirrorConfig.delete().where(RepoMirrorConfig.id << mirrors).execute()
# Delete all queue items for the user's namespace.
dockerfile_build_queue = WorkQueue(queue_name, tf, has_namespace=True)
count_removed = dockerfile_build_queue.delete_namespaced_items(user.username)
info = (user.username, len(triggers), count_removed, len(mirrors))
print "Namespace %s disabled, %s triggers deleted, %s queued builds removed, %s mirrors deleted" % info
return user
def disable_abusing_user(username, queue_name):
if not username:
raise Exception('Must enter a username')
# Disable the namespace itself.
user = ask_disable_namespace(username, queue_name)
# If an organization, ask if all team members should be disabled as well.
if user is not None and user.organization:
members = model.organization.get_organization_member_set(user)
for membername in members:
ask_disable_namespace(membername, queue_name)
parser = argparse.ArgumentParser(description='Disables a user abusing the build system')
parser.add_argument('username', help='The username of the abuser')
parser.add_argument('queuename', help='The name of the dockerfile build queue ' +
'(e.g. `dockerfilebuild` or `dockerfilebuildstaging`)')
args = parser.parse_args()
disable_abusing_user(args.username, args.queuename)

106
util/dockerfileparse.py Normal file
View file

@ -0,0 +1,106 @@
import re
LINE_CONTINUATION_REGEX = re.compile(r'(\s)*\\(\s)*\n')
COMMAND_REGEX = re.compile(r'([A-Za-z]+)\s(.*)')
COMMENT_CHARACTER = '#'
LATEST_TAG = 'latest'
class ParsedDockerfile(object):
def __init__(self, commands):
self.commands = commands
def _get_commands_of_kind(self, kind):
return [command for command in self.commands if command['command'] == kind]
def _get_from_image_identifier(self):
from_commands = self._get_commands_of_kind('FROM')
if not from_commands:
return None
return from_commands[-1]['parameters']
@staticmethod
def parse_image_identifier(image_identifier):
""" Parses a docker image identifier, and returns a tuple of image name and tag, where the tag
is filled in with "latest" if left unspecified.
"""
# Note:
# Dockerfile image references can take multiple forms:
# server:port/some/path
# somepath
# server/some/path
# server/some/path:tag
# server:port/some/path:tag
parts = image_identifier.strip().split(':')
if len(parts) == 1:
# somepath
return (parts[0], LATEST_TAG)
# Otherwise, determine if the last part is a port
# or a tag.
if parts[-1].find('/') >= 0:
# Last part is part of the hostname.
return (image_identifier, LATEST_TAG)
# Remaining cases:
# server/some/path:tag
# server:port/some/path:tag
return (':'.join(parts[0:-1]), parts[-1])
def get_base_image(self):
""" Return the base image without the tag name. """
return self.get_image_and_tag()[0]
def get_image_and_tag(self):
""" Returns the image and tag from the FROM line of the dockerfile. """
image_identifier = self._get_from_image_identifier()
if image_identifier is None:
return (None, None)
return self.parse_image_identifier(image_identifier)
def strip_comments(contents):
lines = []
for line in contents.split('\n'):
index = line.find(COMMENT_CHARACTER)
if index < 0:
lines.append(line)
continue
line = line[:index]
lines.append(line)
return '\n'.join(lines)
def join_continued_lines(contents):
return LINE_CONTINUATION_REGEX.sub('', contents)
def parse_dockerfile(contents):
# If we receive ASCII, translate into unicode.
try:
contents = contents.decode('utf-8')
except ValueError:
# Already unicode or unable to convert.
pass
contents = join_continued_lines(strip_comments(contents))
lines = [line.strip() for line in contents.split('\n') if len(line) > 0]
commands = []
for line in lines:
match_command = COMMAND_REGEX.match(line)
if match_command:
command = match_command.group(1).upper()
parameters = match_command.group(2)
commands.append({
'command': command,
'parameters': parameters
})
return ParsedDockerfile(commands)
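# A small usage sketch of the parser above; the Dockerfile contents are made up.
EXAMPLE_DOCKERFILE = """
# pinned base image
FROM quay.io/myorg/base:2.1
RUN apt-get update && \\
    apt-get install -y curl
"""
parsed = parse_dockerfile(EXAMPLE_DOCKERFILE)
print(parsed.get_image_and_tag())  # ('quay.io/myorg/base', '2.1')
print(parsed.get_base_image())     # 'quay.io/myorg/base'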

7
util/dynamic.py Normal file
View file

@ -0,0 +1,7 @@
def import_class(module_name, class_name):
""" Import a class given the specified module name and class name. """
klass = __import__(module_name)
class_segments = class_name.split('.')
for segment in class_segments:
klass = getattr(klass, segment)
return klass
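# A quick illustration of import_class; the module and class names here are
# just standard-library examples.
OrderedDict = import_class('collections', 'OrderedDict')
od = OrderedDict([('a', 1), ('b', 2)])
print(list(od.keys()))  # ['a', 'b']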

84
util/expiresdict.py Normal file
View file

@ -0,0 +1,84 @@
from datetime import datetime
from six import iteritems
class ExpiresEntry(object):
""" A single entry under a ExpiresDict. """
def __init__(self, value, expires=None):
self.value = value
self._expiration = expires
@property
def expired(self):
if self._expiration is None:
return False
return datetime.now() >= self._expiration
class ExpiresDict(object):
""" ExpiresDict defines a dictionary-like class whose keys have expiration. The rebuilder is
a function that returns the full contents of the cached dictionary as a dict of the keys
and whose values are ExpiresEntry instances. If the rebuilder is None, no rebuilding is performed.
"""
def __init__(self, rebuilder=None):
self._rebuilder = rebuilder
self._items = {}
def __getitem__(self, key):
found = self.get(key)
if found is None:
raise KeyError
return found
def get(self, key, default_value=None):
# Check the cache first. If the key is found and it has not yet expired,
# return it.
found = self._items.get(key)
if found is not None and not found.expired:
return found.value
# Otherwise the key has expired or was not found. Rebuild the cache and check it again.
items = self._rebuild()
found_item = items.get(key)
if found_item is None:
return default_value
return found_item.value
def __contains__(self, key):
return self.get(key) is not None
def _rebuild(self):
if self._rebuilder is None:
return self._items
items = self._rebuilder()
self._items = items
return items
def _alive_items(self):
return {k: entry.value for (k, entry) in self._items.items() if not entry.expired}
def items(self):
return self._alive_items().items()
def iteritems(self):
return iteritems(self._alive_items())
def __iter__(self):
return iter(self._alive_items())
def __delitem__(self, key):
del self._items[key]
def __len__(self):
return len(self._alive_items())
def set(self, key, value, expires=None):
self._items[key] = ExpiresEntry(value, expires=expires)
def __setitem__(self, key, value):
return self.set(key, value)
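# A minimal sketch of ExpiresDict with a rebuilder; the key names and TTL below
# are illustrative.
from datetime import timedelta
def _rebuild_cache():
    # A rebuilder must return a dict mapping keys to ExpiresEntry objects.
    soon = datetime.now() + timedelta(seconds=30)
    return {'token': ExpiresEntry('abc123', expires=soon)}
cache = ExpiresDict(rebuilder=_rebuild_cache)
print(cache.get('token'))  # 'abc123' (rebuilt on the first miss)
cache.set('static', 42)    # entries without an expiration never expire
print('static' in cache)   # True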

50
util/failover.py Normal file
View file

@ -0,0 +1,50 @@
import logging
from functools import wraps
logger = logging.getLogger(__name__)
class FailoverException(Exception):
""" Exception raised when an operation should be retried by the failover decorator.
Wraps the exception of the initial failure.
"""
def __init__(self, exception):
super(FailoverException, self).__init__()
self.exception = exception
def failover(func):
""" Wraps a function such that it can be retried on specified failures.
Raises FailoverException when all failovers are exhausted.
Example:
@failover
def get_google(scheme, use_www=False):
www = 'www.' if use_www else ''
try:
r = requests.get(scheme + '://' + www + 'google.com')
except requests.RequestException as ex:
raise FailoverException(ex)
return r
def GooglePingTest():
r = get_google(
(('http'), {'use_www': False}),
(('http'), {'use_www': True}),
(('https'), {'use_www': False}),
(('https'), {'use_www': True}),
)
print('Successfully contacted ' + r.url)
"""
@wraps(func)
def wrapper(*args_sets):
for arg_set in args_sets:
try:
return func(*arg_set[0], **arg_set[1])
except FailoverException as ex:
logger.debug('failing over')
exception = ex.exception
continue
raise exception
return wrapper

70
util/fixuseradmin.py Normal file
View file

@ -0,0 +1,70 @@
import argparse
import sys
from app import app
from data.database import Namespace, Repository, RepositoryPermission, Role
from data.model.permission import get_user_repo_permissions
from data.model.user import get_active_users, get_nonrobot_user
DESCRIPTION = '''
Fix user repositories missing admin permissions for owning user.
'''
parser = argparse.ArgumentParser(description=DESCRIPTION)
parser.add_argument('users', nargs='*', help='Users to check')
parser.add_argument('-a', '--all', action='store_true', help='Check all users')
parser.add_argument('-n', '--dry-run', action='store_true', help="Don't act")
ADMIN = Role.get(name='admin')
def repos_for_namespace(namespace):
return (Repository
.select(Repository.id, Repository.name, Namespace.username)
.join(Namespace)
.where(Namespace.username == namespace))
def has_admin(user, repo):
perms = get_user_repo_permissions(user, repo)
return any(p.role == ADMIN for p in perms)
def get_users(all_users=False, users_list=None):
if all_users:
return get_active_users(disabled=False)
return map(get_nonrobot_user, users_list)
def ensure_admin(user, repos, dry_run=False):
repos = [repo for repo in repos if not has_admin(user, repo)]
for repo in repos:
print('User {} missing admin on: {}'.format(user.username, repo.name))
if not dry_run:
RepositoryPermission.create(user=user, repository=repo, role=ADMIN)
print('Granted {} admin on: {}'.format(user.username, repo.name))
return len(repos)
def main():
args = parser.parse_args()
found = 0
if not args.all and len(args.users) == 0:
sys.exit('Need a list of users or --all')
for user in get_users(all_users=args.all, users_list=args.users):
if user is not None:
repos = repos_for_namespace(user.username)
found += ensure_admin(user, repos, dry_run=args.dry_run)
print('\nFound {} user repos missing admin'
' permissions for owner.'.format(found))
if __name__ == '__main__':
main()

View file

@ -0,0 +1,56 @@
import argparse
from dateutil.parser import parse as parse_date
from app import app
from data import model
from data.database import ServiceKeyApprovalType
from data.logs_model import logs_model
def generate_key(service, name, expiration_date=None, notes=None):
metadata = {
'created_by': 'CLI tool',
}
# Generate a key with a private key that we *never save*.
(private_key, key) = model.service_keys.generate_service_key(service, expiration_date,
metadata=metadata,
name=name)
# Auto-approve the service key.
model.service_keys.approve_service_key(key.kid, ServiceKeyApprovalType.AUTOMATIC, notes=notes or '')
# Log the creation and auto-approval of the service key.
key_log_metadata = {
'kid': key.kid,
'preshared': True,
'service': service,
'name': name,
'expiration_date': expiration_date,
'auto_approved': True,
}
logs_model.log_action('service_key_create', metadata=key_log_metadata)
logs_model.log_action('service_key_approve', metadata=key_log_metadata)
return private_key, key.kid
def valid_date(s):
try:
return parse_date(s)
except ValueError:
msg = "Not a valid date: '{0}'.".format(s)
raise argparse.ArgumentTypeError(msg)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generates a preshared key')
parser.add_argument('service', help='The service name for which the key is being generated')
parser.add_argument('name', help='The friendly name for the key')
parser.add_argument('--expiration', default=None, type=valid_date,
help='The optional expiration date for the key')
parser.add_argument('--notes', help='Optional notes about the key', default=None)
args = parser.parse_args()
generated, _ = generate_key(args.service, args.name, args.expiration, args.notes)
print generated.exportKey('PEM')

20
util/headers.py Normal file
View file

@ -0,0 +1,20 @@
import base64
def parse_basic_auth(header_value):
""" Attempts to parse the given header value as a Base64-encoded Basic auth header. """
if not header_value:
return None
parts = header_value.split(' ')
if len(parts) != 2 or parts[0].lower() != 'basic':
return None
try:
basic_parts = base64.b64decode(parts[1]).split(':', 1)
if len(basic_parts) != 2:
return None
return basic_parts
except ValueError:
return None
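# A short usage sketch (Python 2, like the rest of this module); the
# credentials below are made up.
encoded = base64.b64encode('someuser:somepass')
print(parse_basic_auth('Basic ' + encoded))  # ['someuser', 'somepass']
print(parse_basic_auth('Bearer xyz'))        # None (not Basic auth)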

60
util/html.py Normal file
View file

@ -0,0 +1,60 @@
from bs4 import BeautifulSoup, Tag, NavigableString
_NEWLINE_INDICATOR = '<<<newline>>>'
def _bold(elem):
elem.replace_with('*%s*' % elem.text)
def _unordered_list(elem):
constructed = ''
for child in elem.children:
if child.name == 'li':
constructed += '* %s\n' % child.text
elem.replace_with(constructed)
def _horizontal_rule(elem):
elem.replace_with('%s\n' % ('-' * 80))
def _anchor(elem):
elem.replace_with('[%s](%s)' % (elem.text, elem['href']))
def _table(elem):
elem.replace_with('%s%s' % (elem.text, _NEWLINE_INDICATOR))
_ELEMENT_REPLACER = {
'b': _bold,
'strong': _bold,
'ul': _unordered_list,
'hr': _horizontal_rule,
'a': _anchor,
'table': _table,
}
def _collapse_whitespace(text):
new_lines = []
lines = text.split('\n')
for line in lines:
if not line.strip():
continue
new_lines.append(line.strip().replace(_NEWLINE_INDICATOR, '\n'))
return '\n'.join(new_lines)
def html2text(html):
soup = BeautifulSoup(html, 'html5lib')
_html2text(soup)
return _collapse_whitespace(soup.text)
def _html2text(elem):
for child in elem.children:
if isinstance(child, Tag):
_html2text(child)
elif isinstance(child, NavigableString):
# No changes necessary
continue
if elem.parent:
if elem.name in _ELEMENT_REPLACER:
_ELEMENT_REPLACER[elem.name](elem)
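# A minimal sketch of html2text on a small fragment (relies on the
# beautifulsoup4/html5lib imports above); the snippet is made up.
SNIPPET = '<p>See the <a href="https://quay.io">registry</a> for <b>details</b>.</p>'
print(html2text(SNIPPET))
# Roughly: 'See the [registry](https://quay.io) for *details*.'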

82
util/http.py Normal file
View file

@ -0,0 +1,82 @@
import logging
import json
from flask import request, make_response, current_app
from werkzeug.exceptions import HTTPException
from app import analytics
from auth.auth_context import get_authenticated_context
logger = logging.getLogger(__name__)
DEFAULT_MESSAGE = {}
DEFAULT_MESSAGE[400] = 'Invalid Request'
DEFAULT_MESSAGE[401] = 'Unauthorized'
DEFAULT_MESSAGE[403] = 'Permission Denied'
DEFAULT_MESSAGE[404] = 'Not Found'
DEFAULT_MESSAGE[409] = 'Conflict'
DEFAULT_MESSAGE[501] = 'Not Implemented'
def _abort(status_code, data_object, description, headers):
# Add CORS headers to all errors
options_resp = current_app.make_default_options_response()
headers['Access-Control-Allow-Origin'] = '*'
headers['Access-Control-Allow-Methods'] = options_resp.headers['allow']
headers['Access-Control-Max-Age'] = str(21600)
headers['Access-Control-Allow-Headers'] = ['Authorization', 'Content-Type']
resp = make_response(json.dumps(data_object), status_code, headers)
# Report the abort to the user.
# Raising HTTPException as workaround for https://github.com/pallets/werkzeug/issues/1098
new_exception = HTTPException(response=resp, description=description)
new_exception.code = status_code
raise new_exception
def exact_abort(status_code, message=None):
data = {}
if message is not None:
data['error'] = message
_abort(status_code, data, message or None, {})
def abort(status_code, message=None, issue=None, headers=None, **kwargs):
message = (str(message) % kwargs if message else
DEFAULT_MESSAGE.get(status_code, ''))
params = dict(request.view_args or {})
params.update(kwargs)
params['url'] = request.url
params['status_code'] = status_code
params['message'] = message
# Add the user information.
auth_context = get_authenticated_context()
if auth_context is not None:
message = '%s (authorized: %s)' % (message, auth_context.description)
# Log the abort.
logger.error('Error %s: %s; Arguments: %s' % (status_code, message, params))
# Calculate the issue URL (if the issue ID was supplied).
issue_url = None
if issue:
issue_url = 'http://docs.quay.io/issues/%s.html' % (issue)
# Create the final response data and message.
data = {}
data['error'] = message
if issue_url:
data['info_url'] = issue_url
if headers is None:
headers = {}
_abort(status_code, data, message, headers)

60
util/invoice.py Normal file
View file

@ -0,0 +1,60 @@
from datetime import datetime
from jinja2 import Environment, FileSystemLoader
from xhtml2pdf import pisa
import StringIO
from app import app
jinja_options = {
"loader": FileSystemLoader('util'),
}
env = Environment(**jinja_options)
def renderInvoiceToPdf(invoice, user):
""" Renders a nice PDF display for the given invoice. """
sourceHtml = renderInvoiceToHtml(invoice, user)
output = StringIO.StringIO()
pisaStatus = pisa.CreatePDF(sourceHtml, dest=output)
if pisaStatus.err:
return None
value = output.getvalue()
output.close()
return value
def renderInvoiceToHtml(invoice, user):
""" Renders a nice HTML display for the given invoice. """
from endpoints.api.billing import get_invoice_fields
def get_price(price):
if not price:
return '$0'
return '$' + '{0:.2f}'.format(float(price) / 100)
def get_range(line):
if line.period and line.period.start and line.period.end:
return ': ' + format_date(line.period.start) + ' - ' + format_date(line.period.end)
return ''
def format_date(timestamp):
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d')
app_logo = app.config.get('ENTERPRISE_LOGO_URL', 'https://quay.io/static/img/quay-logo.png')
data = {
'user': user.username,
'invoice': invoice,
'invoice_date': format_date(invoice.date),
'getPrice': get_price,
'getRange': get_range,
'custom_fields': get_invoice_fields(user)[0],
'logo': app_logo,
}
template = env.get_template('invoice.tmpl')
rendered = template.render(data)
return rendered

69
util/invoice.tmpl Normal file
View file

@ -0,0 +1,69 @@
<html>
<body>
<table width="100%" style="max-width: 640px">
<tr>
<td valign="center" style="padding: 10px;">
<img src="{{ logo }}" alt="Quay" style="width: 100px;">
</td>
<td valign="center">
<h3>Quay</h3>
<p style="font-size: 12px; -webkit-text-adjust: none">
Red Hat, Inc<br>
https://redhat.com<br>
100 East Davie Street<br>
Raleigh, North Carolina 27601
</p>
</td>
<td align="right" width="100%">
<h1 style="color: #ddd;">RECEIPT</h1>
<table>
<tr><td>Date:</td><td>{{ invoice_date }}</td></tr>
<tr><td>Invoice #:</td><td style="font-size: 10px">{{ invoice.id }}</td></tr>
{% for custom_field in custom_fields %}
<tr>
<td>*{{ custom_field['title'] }}:</td>
<td style="font-size: 10px">{{ custom_field['value'] }}</td>
</tr>
{% endfor %}
</table>
</td>
</tr>
</table>
<hr>
<table width="100%" style="max-width: 640px">
<thead>
<th style="padding: 4px; background: #eee; text-align: center; font-weight: bold">Description</th>
<th style="padding: 4px; background: #eee; text-align: center; font-weight: bold">Line Total</th>
</thead>
<tbody>
{%- for line in invoice.lines.data -%}
<tr>
<td width="100%" style="padding: 4px;">{{ line.description or ('Plan Subscription' + getRange(line)) }}</td>
<td style="padding: 4px; min-width: 150px;">{{ getPrice(line.amount) }}</td>
</tr>
{%- endfor -%}
<tr>
<td></td>
<td valign="right">
<table>
<tr><td><b>Subtotal: </b></td><td>{{ getPrice(invoice.subtotal) }}</td></tr>
<tr><td><b>Total: </b></td><td>{{ getPrice(invoice.total) }}</td></tr>
<tr><td><b>Paid: </b></td><td>{{ getPrice(invoice.total) if invoice.paid else 0 }}</td></tr>
<tr><td><b>Total Due:</b></td>
<td>{{ getPrice(invoice.ending_balance) }}</td></tr>
</table>
</td>
</tr>
</tbody>
</table>
<div style="margin: 6px; padding: 6px; width: 100%; max-width: 640px; border-top: 2px solid #eee; text-align: center; font-size: 14px; -webkit-text-adjust: none; font-weight: bold;">
We thank you for your continued business!
</div>
</body>
</html>

Binary file not shown.

153
util/ipresolver/__init__.py Normal file
View file

@ -0,0 +1,153 @@
import logging
import json
import time
from collections import namedtuple
from threading import Thread, Lock
from abc import ABCMeta, abstractmethod
from six import add_metaclass
from cachetools.func import ttl_cache, lru_cache
from netaddr import IPNetwork, IPAddress, IPSet, AddrFormatError
import geoip2.database
import geoip2.errors
import requests
from util.abchelpers import nooper
ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'service', 'sync_token',
'country_iso_code'])
AWS_SERVICES = {'EC2', 'CODEBUILD'}
logger = logging.getLogger(__name__)
def _get_aws_ip_ranges():
try:
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
return json.loads(f.read())
except IOError:
logger.exception('Could not load AWS IP Ranges')
return None
except ValueError:
logger.exception('Could not load AWS IP Ranges')
return None
except TypeError:
logger.exception('Could not load AWS IP Ranges')
return None
@add_metaclass(ABCMeta)
class IPResolverInterface(object):
""" Helper class for resolving information about an IP address. """
@abstractmethod
def resolve_ip(self, ip_address):
""" Attempts to return resolved information about the specified IP Address. If such an attempt
fails, returns None.
"""
pass
@abstractmethod
def is_ip_possible_threat(self, ip_address):
""" Attempts to return whether the given IP address is a possible abuser or spammer.
Returns False if the IP address information could not be looked up.
"""
pass
@nooper
class NoopIPResolver(IPResolverInterface):
""" No-op version of the security scanner API. """
pass
class IPResolver(IPResolverInterface):
def __init__(self, app):
self.app = app
self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')
self.amazon_ranges = None
self.sync_token = None
logger.info('Loading AWS IP ranges from disk')
aws_ip_ranges_data = _get_aws_ip_ranges()
if aws_ip_ranges_data is not None and aws_ip_ranges_data.get('syncToken'):
logger.info('Building AWS IP ranges')
self.amazon_ranges = IPResolver._parse_amazon_ranges(aws_ip_ranges_data)
self.sync_token = aws_ip_ranges_data['syncToken']
logger.info('Finished building AWS IP ranges')
@ttl_cache(maxsize=100, ttl=600)
def is_ip_possible_threat(self, ip_address):
if self.app.config.get('THREAT_NAMESPACE_MAXIMUM_BUILD_COUNT') is None:
return False
if self.app.config.get('IP_DATA_API_KEY') is None:
return False
if not ip_address:
return False
api_key = self.app.config['IP_DATA_API_KEY']
try:
logger.debug('Requesting IP data for IP %s', ip_address)
r = requests.get('https://api.ipdata.co/%s/threat?api-key=%s' % (ip_address, api_key),
timeout=1)
if r.status_code != 200:
logger.debug('Got non-200 response for IP %s: %s', ip_address, r.status_code)
return False
logger.debug('Got IP data for IP %s: %s => %s', ip_address, r.status_code, r.json())
threat_data = r.json()
return threat_data.get('is_threat', False) or threat_data.get('is_bogon', False)
except requests.RequestException:
logger.exception('Got exception when trying to lookup IP Address')
except ValueError:
logger.exception('Got exception when trying to lookup IP Address')
except Exception:
logger.exception('Got exception when trying to lookup IP Address')
return False
def resolve_ip(self, ip_address):
""" Attempts to return resolved information about the specified IP Address. If such an attempt
fails, returns None.
"""
if not ip_address:
return None
try:
parsed_ip = IPAddress(ip_address)
except AddrFormatError:
return ResolvedLocation('invalid_ip', None, self.sync_token, None)
# Try geoip classification
try:
geoinfo = self.geoip_db.country(ip_address)
except geoip2.errors.AddressNotFoundError:
geoinfo = None
if self.amazon_ranges is None or parsed_ip not in self.amazon_ranges:
if geoinfo:
return ResolvedLocation(
'internet',
geoinfo.country.iso_code,
self.sync_token,
geoinfo.country.iso_code,
)
return ResolvedLocation('internet', None, self.sync_token, None)
return ResolvedLocation('aws', None, self.sync_token,
geoinfo.country.iso_code if geoinfo else None)
@staticmethod
def _parse_amazon_ranges(ranges):
all_amazon = IPSet()
for service_description in ranges['prefixes']:
if service_description['service'] in AWS_SERVICES:
all_amazon.add(IPNetwork(service_description['ip_prefix']))
return all_amazon

View file

View file

@ -0,0 +1,51 @@
import pytest
from mock import patch
from util.ipresolver import IPResolver, ResolvedLocation
from test.fixtures import *
@pytest.fixture()
def test_aws_ip():
return '10.0.0.1'
@pytest.fixture()
def aws_ip_range_data():
fake_range_doc = {
'syncToken': 123456789,
'prefixes': [
{
'ip_prefix': '10.0.0.0/8',
'region': 'GLOBAL',
'service': 'EC2',
},
{
'ip_prefix': '6.0.0.0/8',
'region': 'GLOBAL',
'service': 'EC2',
},
],
}
return fake_range_doc
@pytest.fixture()
def test_ip_range_cache(aws_ip_range_data):
sync_token = aws_ip_range_data['syncToken']
all_amazon = IPResolver._parse_amazon_ranges(aws_ip_range_data)
fake_cache = {
'sync_token': sync_token,
'all_amazon': all_amazon,
}
return fake_cache
def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
ipresolver = IPResolver(app)
ipresolver.amazon_ranges = test_ip_range_cache['all_amazon']
ipresolver.sync_token = test_ip_range_cache['sync_token']
assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', service=None, sync_token=123456789, country_iso_code=None)
assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', service=None, sync_token=123456789, country_iso_code=None)
assert ipresolver.resolve_ip('6.0.0.2') == ResolvedLocation(provider='aws', service=None, sync_token=123456789, country_iso_code=u'US')
assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', service=u'US', sync_token=123456789, country_iso_code=u'US')
assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', service=None, sync_token=123456789, country_iso_code=None)

6
util/itertoolrecipes.py Normal file
View file

@ -0,0 +1,6 @@
from itertools import islice
# From: https://docs.python.org/2/library/itertools.html
def take(n, iterable):
""" Return first n items of the iterable as a list """
return list(islice(iterable, n))
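# Example: take the first three squares from a generator.
print(take(3, (x * x for x in range(10))))  # [0, 1, 4]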

99
util/jinjautil.py Normal file
View file

@ -0,0 +1,99 @@
from app import get_app_url, avatar
from data import model
from util.names import parse_robot_username
from jinja2 import Environment, FileSystemLoader
def icon_path(icon_name):
return '%s/static/img/icons/%s.png' % (get_app_url(), icon_name)
def icon_image(icon_name):
return '<img src="%s" alt="%s">' % (icon_path(icon_name), icon_name)
def team_reference(teamname):
avatar_html = avatar.get_mail_html(teamname, teamname, 24, 'team')
return "<span>%s <b>%s</b></span>" % (avatar_html, teamname)
def user_reference(username):
user = model.user.get_namespace_user(username)
if not user:
return username
if user.robot:
parts = parse_robot_username(username)
user = model.user.get_namespace_user(parts[0])
return """<span><img src="%s" alt="Robot"> <b>%s</b></span>""" % (icon_path('wrench'), username)
avatar_html = avatar.get_mail_html(user.username, user.email, 24,
'org' if user.organization else 'user')
return """
<span>
%s
<b>%s</b>
</span>""" % (avatar_html, username)
def repository_tag_reference(repository_path_and_tag):
(repository_path, tag) = repository_path_and_tag
(namespace, repository) = repository_path.split('/')
owner = model.user.get_namespace_user(namespace)
if not owner:
return tag
return """<a href="%s/repository/%s/%s?tag=%s&tab=tags">%s</a>""" % (get_app_url(), namespace,
repository, tag, tag)
def repository_reference(pair):
if isinstance(pair, tuple):
(namespace, repository) = pair
else:
pair = pair.split('/')
namespace = pair[0]
repository = pair[1]
owner = model.user.get_namespace_user(namespace)
if not owner:
return "%s/%s" % (namespace, repository)
avatar_html = avatar.get_mail_html(owner.username, owner.email, 16,
'org' if owner.organization else 'user')
return """
<span style="white-space: nowrap;">
%s
<a href="%s/repository/%s/%s">%s/%s</a>
</span>
""" % (avatar_html, get_app_url(), namespace, repository, namespace, repository)
def admin_reference(username):
user = model.user.get_user_or_org(username)
if not user:
return 'account settings'
if user.organization:
return """
<a href="%s/organization/%s?tab=settings">organization's admin setting</a>
""" % (get_app_url(), username)
else:
return """
<a href="%s/user/">account settings</a>
""" % (get_app_url())
def get_template_env(searchpath):
template_loader = FileSystemLoader(searchpath=searchpath)
template_env = Environment(loader=template_loader)
add_filters(template_env)
return template_env
def add_filters(template_env):
template_env.filters['icon_image'] = icon_image
template_env.filters['team_reference'] = team_reference
template_env.filters['user_reference'] = user_reference
template_env.filters['admin_reference'] = admin_reference
template_env.filters['repository_reference'] = repository_reference
template_env.filters['repository_tag_reference'] = repository_tag_reference

20
util/label_validator.py Normal file
View file

@ -0,0 +1,20 @@
class LabelValidator(object):
""" Helper class for validating that labels meet prefix requirements. """
def __init__(self, app):
self.app = app
overridden_prefixes = app.config.get('LABEL_KEY_RESERVED_PREFIXES', [])
for prefix in overridden_prefixes:
if not prefix.endswith('.'):
raise Exception('Prefix "%s" in LABEL_KEY_RESERVED_PREFIXES must end in a dot' % prefix)
default_prefixes = app.config.get('DEFAULT_LABEL_KEY_RESERVED_PREFIXES', [])
self.reserved_prefixed_set = set(default_prefixes + overridden_prefixes)
def has_reserved_prefix(self, label_key):
""" Validates that the provided label key does not match any reserved prefixes. """
for prefix in self.reserved_prefixed_set:
if label_key.startswith(prefix):
return True
return False
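# A small sketch of LabelValidator with a stand-in app object; the config
# values below are illustrative, not Quay defaults.
class _FakeApp(object):
    config = {
        'DEFAULT_LABEL_KEY_RESERVED_PREFIXES': ['quay.', 'io.quay.'],
        'LABEL_KEY_RESERVED_PREFIXES': ['acme.internal.'],
    }
validator = LabelValidator(_FakeApp())
print(validator.has_reserved_prefix('quay.expires-after'))   # True
print(validator.has_reserved_prefix('acme.internal.owner'))  # True
print(validator.has_reserved_prefix('team'))                 # False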

66
util/locking.py Normal file
View file

@ -0,0 +1,66 @@
import logging
from redis import RedisError
from redlock import RedLock, RedLockError
from app import app
logger = logging.getLogger(__name__)
class LockNotAcquiredException(Exception):
""" Exception raised if a GlobalLock could not be acquired. """
class GlobalLock(object):
""" A lock object that blocks globally via Redis. Note that Redis is not considered a tier-1
service, so this lock should not be used for any critical code paths.
"""
def __init__(self, name, lock_ttl=600):
self._lock_name = name
self._redis_info = dict(app.config['USER_EVENTS_REDIS'])
self._redis_info.update({'socket_connect_timeout': 5,
'socket_timeout': 5,
'single_connection_client': True})
self._lock_ttl = lock_ttl
self._redlock = None
def __enter__(self):
if not self.acquire():
raise LockNotAcquiredException()
def __exit__(self, type, value, traceback):
self.release()
def acquire(self):
logger.debug('Acquiring global lock %s', self._lock_name)
try:
self._redlock = RedLock(self._lock_name, connection_details=[self._redis_info],
ttl=self._lock_ttl)
acquired = self._redlock.acquire()
if not acquired:
logger.debug('Was unable to acquire lock %s', self._lock_name)
return False
logger.debug('Acquired lock %s', self._lock_name)
return True
except RedLockError:
logger.debug('Could not acquire lock %s', self._lock_name)
return False
except RedisError as re:
logger.debug('Could not connect to Redis for lock %s: %s', self._lock_name, re)
return False
def release(self):
if self._redlock is not None:
logger.debug('Releasing lock %s', self._lock_name)
try:
self._redlock.release()
except RedLockError:
logger.debug('Could not release lock %s', self._lock_name)
except RedisError as re:
logger.debug('Could not connect to Redis for releasing lock %s: %s', self._lock_name, re)
logger.debug('Released lock %s', self._lock_name)
self._redlock = None
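# A usage sketch of GlobalLock; the lock name is illustrative, and callers
# must be prepared for the lock not being acquired.
try:
    with GlobalLock('EXAMPLE_LOCK_my-namespace', lock_ttl=120):
        pass  # perform the globally-exclusive work here
except LockNotAcquiredException:
    logger.warning('Another worker holds the lock; skipping this run')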

47
util/log.py Normal file
View file

@ -0,0 +1,47 @@
import os
from _init import CONF_DIR
def logfile_path(jsonfmt=False, debug=False):
"""
Returns a logging config file path according to these rules:
- conf/logging_debug_json.conf # jsonfmt=true, debug=true
- conf/logging_json.conf # jsonfmt=true, debug=false
- conf/logging_debug.conf # jsonfmt=false, debug=true
- conf/logging.conf # jsonfmt=false, debug=false
Can be parametrized via envvars: JSONLOG=true, DEBUGLOG=true
"""
_json = ""
_debug = ""
if jsonfmt or os.getenv('JSONLOG', 'false').lower() == 'true':
_json = "_json"
if debug or os.getenv('DEBUGLOG', 'false').lower() == 'true':
_debug = "_debug"
return os.path.join(CONF_DIR, "logging%s%s.conf" % (_debug, _json))
def filter_logs(values, filtered_fields):
"""
Takes a dict and a list of keys to filter.
eg:
with filtered_fields:
[{'key': ['k1', 'k2'], 'fn': lambda x: 'filtered'}]
and values:
{'k1': {'k2': 'some-secret'}, 'k3': 'some-value'}
the dict is filtered in place to:
{'k1': {'k2': 'filtered'}, 'k3': 'some-value'}
"""
for field in filtered_fields:
cdict = values
for key in field['key'][:-1]:
if key in cdict:
cdict = cdict[key]
last_key = field['key'][-1]
if last_key in cdict and cdict[last_key]:
cdict[last_key] = field['fn'](cdict[last_key])
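# Runnable version of the docstring example above; note that filter_logs
# mutates the dict in place.
values = {'k1': {'k2': 'some-secret'}, 'k3': 'some-value'}
filter_logs(values, [{'key': ['k1', 'k2'], 'fn': lambda x: 'filtered'}])
print(values)  # {'k1': {'k2': 'filtered'}, 'k3': 'some-value'}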

0
util/metrics/__init__.py Normal file
View file

210
util/metrics/metricqueue.py Normal file
View file

@ -0,0 +1,210 @@
import datetime
import logging
import time
from functools import wraps
from Queue import Queue, Full
from flask import g, request
from trollius import Return
logger = logging.getLogger(__name__)
# Buckets for the API response times.
API_RESPONSE_TIME_BUCKETS = [.01, .025, .05, .1, .25, .5, 1.0, 2.5, 5.0]
# Buckets for the builder start times.
BUILDER_START_TIME_BUCKETS = [.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0, 180.0, 240.0, 300.0, 600.0]
class MetricQueue(object):
""" Object to which various metrics are written, for distribution to metrics collection
system(s) such as Prometheus.
"""
def __init__(self, prom):
# Define the various exported metrics.
self.resp_time = prom.create_histogram('response_time', 'HTTP response time in seconds',
labelnames=['endpoint'],
buckets=API_RESPONSE_TIME_BUCKETS)
self.resp_code = prom.create_counter('response_code', 'HTTP response code',
labelnames=['endpoint', 'code'])
self.non_200 = prom.create_counter('response_non200', 'Non-200 HTTP response codes',
labelnames=['endpoint'])
self.error_500 = prom.create_counter('response_500', '5XX HTTP response codes',
labelnames=['endpoint'])
self.multipart_upload_start = prom.create_counter('multipart_upload_start',
'Multipart upload started')
self.multipart_upload_end = prom.create_counter('multipart_upload_end',
'Multipart upload ends.', labelnames=['type'])
self.build_capacity_shortage = prom.create_gauge('build_capacity_shortage',
'Build capacity shortage.')
self.builder_time_to_start = prom.create_histogram('builder_tts',
'Time from triggering to starting a builder.',
labelnames=['builder_type'],
buckets=BUILDER_START_TIME_BUCKETS)
self.builder_time_to_build = prom.create_histogram('builder_ttb',
'Time from triggering to actually starting a build',
labelnames=['builder_type'],
buckets=BUILDER_START_TIME_BUCKETS)
self.build_time = prom.create_histogram('build_time', 'Time spent building', labelnames=['builder_type'])
self.builder_fallback = prom.create_counter('builder_fallback', 'Builder fell back to secondary executor')
self.build_start_success = prom.create_counter('build_start_success', 'Executor succeeded in starting a build', labelnames=['builder_type'])
self.build_start_failure = prom.create_counter('build_start_failure', 'Executor failed to start a build', labelnames=['builder_type'])
self.percent_building = prom.create_gauge('build_percent_building', 'Percent building.')
self.build_counter = prom.create_counter('builds', 'Number of builds', labelnames=['name'])
self.ephemeral_build_workers = prom.create_counter('ephemeral_build_workers',
'Number of started ephemeral build workers')
self.ephemeral_build_worker_failure = prom.create_counter('ephemeral_build_worker_failure',
'Number of failed-to-start ephemeral build workers')
self.work_queue_running = prom.create_gauge('work_queue_running', 'Running items in a queue',
labelnames=['queue_name'])
self.work_queue_available = prom.create_gauge('work_queue_available',
'Available items in a queue',
labelnames=['queue_name'])
self.work_queue_available_not_running = prom.create_gauge('work_queue_available_not_running',
'Available items that are not yet running',
labelnames=['queue_name'])
self.repository_pull = prom.create_counter('repository_pull', 'Repository Pull Count',
labelnames=['namespace', 'repo_name', 'protocol',
'status'])
self.repository_push = prom.create_counter('repository_push', 'Repository Push Count',
labelnames=['namespace', 'repo_name', 'protocol',
'status'])
self.repository_build_queued = prom.create_counter('repository_build_queued',
'Repository Build Queued Count',
labelnames=['namespace', 'repo_name'])
self.repository_build_completed = prom.create_counter('repository_build_completed',
'Repository Build Complete Count',
labelnames=['namespace', 'repo_name',
'status', 'executor'])
self.chunk_size = prom.create_histogram('chunk_size',
'Registry blob chunk size',
labelnames=['storage_region'])
self.chunk_upload_time = prom.create_histogram('chunk_upload_time',
'Registry blob chunk upload time',
labelnames=['storage_region'])
self.authentication_count = prom.create_counter('authentication_count',
'Authentication count',
labelnames=['kind', 'status'])
self.repository_count = prom.create_gauge('repository_count', 'Number of repositories')
self.user_count = prom.create_gauge('user_count', 'Number of users')
self.org_count = prom.create_gauge('org_count', 'Number of Organizations')
self.robot_count = prom.create_gauge('robot_count', 'Number of robot accounts')
self.instance_key_renewal_success = prom.create_counter('instance_key_renewal_success',
'Instance Key Renewal Success Count',
labelnames=['key_id'])
self.instance_key_renewal_failure = prom.create_counter('instance_key_renewal_failure',
'Instance Key Renewal Failure Count',
labelnames=['key_id'])
self.invalid_instance_key_count = prom.create_counter('invalid_registry_instance_key_count',
'Invalid registry instance key count',
labelnames=['key_id'])
self.verb_action_passes = prom.create_counter('verb_action_passes', 'Verb Pass Count',
labelnames=['kind', 'pass_count'])
self.push_byte_count = prom.create_counter('registry_push_byte_count',
'Number of bytes pushed to the registry')
self.pull_byte_count = prom.create_counter('estimated_registry_pull_byte_count',
'Number of (estimated) bytes pulled from the registry',
labelnames=['protocol_version'])
# Deprecated: Define an in-memory queue for reporting metrics to CloudWatch or another
# provider.
self._queue = None
def enable_deprecated(self, maxsize=10000):
self._queue = Queue(maxsize)
def put_deprecated(self, name, value, **kwargs):
if self._queue is None:
logger.debug('No metric queue %s %s %s', name, value, kwargs)
return
try:
kwargs.setdefault('timestamp', datetime.datetime.now())
kwargs.setdefault('dimensions', {})
self._queue.put_nowait((name, value, kwargs))
except Full:
logger.error('Metric queue full')
def get_deprecated(self):
return self._queue.get()
def get_nowait_deprecated(self):
return self._queue.get_nowait()
def duration_collector_async(metric, labelvalues):
""" Decorates a method to have its duration time logged to the metric. """
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
trigger_time = time.time()
try:
rv = func(*args, **kwargs)
except Return as e:
metric.Observe(time.time() - trigger_time, labelvalues=labelvalues)
raise e
return rv
return wrapper
return decorator
def time_decorator(name, metric_queue):
""" Decorates an endpoint method to have its request time logged to the metrics queue. """
after = _time_after_request(name, metric_queue)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
_time_before_request()
rv = func(*args, **kwargs)
after(rv)
return rv
return wrapper
return decorator
def time_blueprint(bp, metric_queue):
""" Decorates a blueprint to have its request time logged to the metrics queue. """
bp.before_request(_time_before_request)
bp.after_request(_time_after_request(bp.name, metric_queue))
def _time_before_request():
g._request_start_time = time.time()
def _time_after_request(name, metric_queue):
def f(r):
start = getattr(g, '_request_start_time', None)
if start is None:
return r
dur = time.time() - start
metric_queue.resp_time.Observe(dur, labelvalues=[request.endpoint])
metric_queue.resp_code.Inc(labelvalues=[request.endpoint, r.status_code])
if r.status_code >= 500:
metric_queue.error_500.Inc(labelvalues=[request.endpoint])
elif r.status_code < 200 or r.status_code >= 300:
metric_queue.non_200.Inc(labelvalues=[request.endpoint])
return r
return f

168
util/metrics/prometheus.py Normal file
View file

@ -0,0 +1,168 @@
import datetime
import json
import logging
from Queue import Queue, Full, Empty
from threading import Thread
import requests
logger = logging.getLogger(__name__)
QUEUE_MAX = 1000
MAX_BATCH_SIZE = 100
REGISTER_WAIT = datetime.timedelta(hours=1)
class PrometheusPlugin(object):
""" Application plugin for reporting metrics to Prometheus. """
def __init__(self, app=None):
self.app = app
if app is not None:
self.state = self.init_app(app)
else:
self.state = None
def init_app(self, app):
prom_url = app.config.get('PROMETHEUS_AGGREGATOR_URL')
prom_namespace = app.config.get('PROMETHEUS_NAMESPACE')
logger.debug('Initializing prometheus with aggregator url: %s', prom_url)
prometheus = Prometheus(prom_url, prom_namespace)
# register extension with app
app.extensions = getattr(app, 'extensions', {})
app.extensions['prometheus'] = prometheus
return prometheus
def __getattr__(self, name):
return getattr(self.state, name, None)
class Prometheus(object):
""" Aggregator for collecting stats that are reported to Prometheus. """
def __init__(self, url=None, namespace=None):
self._metric_collectors = []
self._url = url
self._namespace = namespace or ''
if url is not None:
self._queue = Queue(QUEUE_MAX)
self._sender = _QueueSender(self._queue, url, self._metric_collectors)
self._sender.start()
logger.debug('Prometheus aggregator sending to %s', url)
else:
self._queue = None
logger.debug('Prometheus aggregator disabled')
def enqueue(self, call, data):
if not self._queue:
return
v = json.dumps({
'Call': call,
'Data': data,
})
if call == 'register':
self._metric_collectors.append(v)
return
try:
self._queue.put_nowait(v)
except Full:
# If the queue is full, it is because 1) no aggregator was enabled or 2)
# the aggregator is taking a long time to respond to requests. In the case
# of 1, it's probably enterprise mode and we don't care. In the case of 2,
# the response timeout error is printed inside the queue handler. In either case,
# we don't need to print an error here.
pass
def create_gauge(self, *args, **kwargs):
return self._create_collector('Gauge', args, kwargs)
def create_counter(self, *args, **kwargs):
return self._create_collector('Counter', args, kwargs)
def create_summary(self, *args, **kwargs):
return self._create_collector('Summary', args, kwargs)
def create_histogram(self, *args, **kwargs):
return self._create_collector('Histogram', args, kwargs)
def create_untyped(self, *args, **kwargs):
return self._create_collector('Untyped', args, kwargs)
def _create_collector(self, collector_type, args, kwargs):
kwargs['namespace'] = kwargs.get('namespace', self._namespace)
return _Collector(self.enqueue, collector_type, *args, **kwargs)
class _QueueSender(Thread):
""" Helper class which uses a thread to asynchronously send metrics to the local Prometheus
aggregator. """
def __init__(self, queue, url, metric_collectors):
Thread.__init__(self)
self.daemon = True
self.next_register = datetime.datetime.now()
self._queue = queue
self._url = url
self._metric_collectors = metric_collectors
def run(self):
while True:
reqs = []
reqs.append(self._queue.get())
while len(reqs) < MAX_BATCH_SIZE:
try:
req = self._queue.get_nowait()
reqs.append(req)
except Empty:
break
try:
resp = requests.post(self._url + '/call', '\n'.join(reqs))
if resp.status_code == 500 and self.next_register <= datetime.datetime.now():
resp = requests.post(self._url + '/call', '\n'.join(self._metric_collectors))
self.next_register = datetime.datetime.now() + REGISTER_WAIT
logger.debug('Register returned %s for %s metrics; setting next to %s', resp.status_code,
len(self._metric_collectors), self.next_register)
elif resp.status_code != 200:
logger.debug('Failed sending to prometheus: %s: %s: %s', resp.status_code, resp.text,
', '.join(reqs))
else:
logger.debug('Sent %d prometheus metrics', len(reqs))
except:
logger.exception('Failed to write to prometheus aggregator: %s', reqs)
class _Collector(object):
""" Collector for a Prometheus metric. """
def __init__(self, enqueue_method, collector_type, collector_name, collector_help,
namespace='', subsystem='', **kwargs):
self._enqueue_method = enqueue_method
self._base_args = {
'Name': collector_name,
'Namespace': namespace,
'Subsystem': subsystem,
'Type': collector_type,
}
registration_params = dict(kwargs)
registration_params.update(self._base_args)
registration_params['Help'] = collector_help
self._enqueue_method('register', registration_params)
def __getattr__(self, method):
def f(value=0, labelvalues=()):
data = dict(self._base_args)
data.update({
'Value': value,
'LabelValues': [str(i) for i in labelvalues],
'Method': method,
})
self._enqueue_method('put', data)
return f
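# A sketch of how MetricQueue and PrometheusPlugin fit together; the
# aggregator URL and label values below are illustrative.
from flask import Flask
from util.metrics.metricqueue import MetricQueue
app = Flask(__name__)
app.config['PROMETHEUS_AGGREGATOR_URL'] = 'http://localhost:9092'
app.config['PROMETHEUS_NAMESPACE'] = 'quay'
prometheus = PrometheusPlugin(app)
metric_queue = MetricQueue(prometheus)
metric_queue.repository_pull.Inc(labelvalues=['myorg', 'myrepo', 'v2', 'success'])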

View file

@ -0,0 +1,58 @@
import time
import pytest
from mock import Mock
from trollius import coroutine, Return, get_event_loop, From
from util.metrics.metricqueue import duration_collector_async
mock_histogram = Mock()
class NonReturn(Exception):
pass
@coroutine
@duration_collector_async(mock_histogram, labelvalues=["testlabel"])
def duration_decorated():
time.sleep(1)
raise Return("fin")
@coroutine
@duration_collector_async(mock_histogram, labelvalues=["testlabel"])
def duration_decorated_error():
raise NonReturn("not a Return error")
@coroutine
def calls_decorated():
yield From(duration_decorated())
def test_duration_decorator():
loop = get_event_loop()
loop.run_until_complete(duration_decorated())
assert mock_histogram.Observe.called
assert 1 - mock_histogram.Observe.call_args[0][0] < 1 # duration should be close to 1s
assert mock_histogram.Observe.call_args[1]["labelvalues"] == ["testlabel"]
def test_duration_decorator_error():
loop = get_event_loop()
mock_histogram.reset_mock()
with pytest.raises(NonReturn):
loop.run_until_complete(duration_decorated_error())
assert not mock_histogram.Observe.called
def test_duration_decorator_caller():
mock_histogram.reset_mock()
loop = get_event_loop()
loop.run_until_complete(calls_decorated())
assert mock_histogram.Observe.called
assert 1 - mock_histogram.Observe.call_args[0][0] < 1 # duration should be close to 1s
assert mock_histogram.Observe.call_args[1]["labelvalues"] == ["testlabel"]

38
util/migrate/__init__.py Normal file
View file

@ -0,0 +1,38 @@
import logging
from sqlalchemy.types import TypeDecorator, Text, String
from sqlalchemy.dialects.mysql import TEXT as MySQLText, LONGTEXT, VARCHAR as MySQLString
logger = logging.getLogger(__name__)
class UTF8LongText(TypeDecorator):
""" Platform-independent UTF-8 LONGTEXT type.
Uses MySQL's LongText with charset utf8mb4, otherwise uses TEXT, because
other engines default to UTF-8 and have longer TEXT fields.
"""
impl = Text
def load_dialect_impl(self, dialect):
if dialect.name == 'mysql':
return dialect.type_descriptor(LONGTEXT(charset='utf8mb4', collation='utf8mb4_unicode_ci'))
else:
return dialect.type_descriptor(Text())
class UTF8CharField(TypeDecorator):
""" Platform-independent UTF-8 Char type.
Uses MySQL's VARCHAR with charset utf8mb4, otherwise uses String, because
other engines default to UTF-8.
"""
impl = String
def load_dialect_impl(self, dialect):
if dialect.name == 'mysql':
return dialect.type_descriptor(MySQLString(charset='utf8mb4', collation='utf8mb4_unicode_ci',
length=self.impl.length))
else:
return dialect.type_descriptor(String(length=self.impl.length))
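# Example column definitions using the types above in a SQLAlchemy table;
# the table and column names are illustrative.
from sqlalchemy import Table, MetaData, Column, Integer
metadata = MetaData()
example_table = Table('example', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('title', UTF8CharField(length=255)),
                      Column('description', UTF8LongText()))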

175
util/migrate/allocator.py Normal file
View file

@ -0,0 +1,175 @@
import logging
import random
from bintrees import RBTree
from threading import Event
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class NoAvailableKeysError(ValueError):
pass
class CompletedKeys(object):
def __init__(self, max_index, min_index=0):
self._max_index = max_index
self._min_index = min_index
self.num_remaining = max_index - min_index
self._slabs = RBTree()
def _get_previous_or_none(self, index):
try:
return self._slabs.floor_item(index)
except KeyError:
return None
def is_available(self, index):
logger.debug('Testing index %s', index)
if index >= self._max_index or index < self._min_index:
logger.debug('Index out of range')
return False
try:
prev_start, prev_length = self._slabs.floor_item(index)
logger.debug('Prev range: %s-%s', prev_start, prev_start + prev_length)
return (prev_start + prev_length) <= index
except KeyError:
return True
def mark_completed(self, start_index, past_last_index):
logger.debug('Marking the range completed: %s-%s', start_index, past_last_index)
num_completed = min(past_last_index, self._max_index) - max(start_index, self._min_index)
# Find the item directly before this and see if there is overlap
to_discard = set()
try:
prev_start, prev_length = self._slabs.floor_item(start_index)
max_prev_completed = prev_start + prev_length
if max_prev_completed >= start_index:
# we are going to merge with the range before us
logger.debug('Merging with the prev range: %s-%s', prev_start, prev_start + prev_length)
to_discard.add(prev_start)
num_completed = max(num_completed - (max_prev_completed - start_index), 0)
start_index = prev_start
past_last_index = max(past_last_index, prev_start + prev_length)
except KeyError:
pass
# Find all keys between the start and last index and merge them into one block
for merge_start, merge_length in self._slabs.iter_items(start_index, past_last_index + 1):
if merge_start in to_discard:
logger.debug('Already merged with block %s-%s', merge_start, merge_start + merge_length)
continue
candidate_next_index = merge_start + merge_length
logger.debug('Merging with block %s-%s', merge_start, candidate_next_index)
num_completed -= merge_length - max(candidate_next_index - past_last_index, 0)
to_discard.add(merge_start)
past_last_index = max(past_last_index, candidate_next_index)
# write the new block which is fully merged
discard = False
if past_last_index >= self._max_index:
logger.debug('Discarding block and setting new max to: %s', start_index)
self._max_index = start_index
discard = True
if start_index <= self._min_index:
logger.debug('Discarding block and setting new min to: %s', past_last_index)
self._min_index = past_last_index
discard = True
if to_discard:
logger.debug('Discarding %s obsolete blocks', len(to_discard))
self._slabs.remove_items(to_discard)
if not discard:
logger.debug('Writing new block with range: %s-%s', start_index, past_last_index)
self._slabs.insert(start_index, past_last_index - start_index)
# Update the number of remaining items with the adjustments we've made
assert num_completed >= 0
self.num_remaining -= num_completed
logger.debug('Total blocks: %s', len(self._slabs))
def get_block_start_index(self, block_size_estimate):
logger.debug('Total range: %s-%s', self._min_index, self._max_index)
if self._max_index <= self._min_index:
raise NoAvailableKeysError('All indexes have been marked completed')
num_holes = len(self._slabs) + 1
random_hole = random.randint(0, num_holes - 1)
logger.debug('Selected random hole %s with %s total holes', random_hole, num_holes)
hole_start = self._min_index
past_hole_end = self._max_index
# Now that we have picked a hole, we need to define the bounds
if random_hole > 0:
# There will be a slab before this hole, find where it ends
bound_entries = self._slabs.nsmallest(random_hole + 1)[-2:]
left_index, left_len = bound_entries[0]
logger.debug('Left range %s-%s', left_index, left_index + left_len)
hole_start = left_index + left_len
if len(bound_entries) > 1:
right_index, right_len = bound_entries[1]
logger.debug('Right range %s-%s', right_index, right_index + right_len)
past_hole_end, _ = bound_entries[1]
elif not self._slabs.is_empty():
right_index, right_len = self._slabs.nsmallest(1)[0]
logger.debug('Right range %s-%s', right_index, right_index + right_len)
past_hole_end, _ = self._slabs.nsmallest(1)[0]
# Now that we have our hole bounds, select a random block from [0:len - block_size_estimate]
logger.debug('Selecting from hole range: %s-%s', hole_start, past_hole_end)
rand_max_bound = max(hole_start, past_hole_end - block_size_estimate)
logger.debug('Rand max bound: %s', rand_max_bound)
return random.randint(hole_start, rand_max_bound)
def yield_random_entries(batch_query, primary_key_field, batch_size, max_id, min_id=0):
""" This method will yield items from random blocks in the database. We will track metadata
about which keys are available for work, and we will complete the backfill when there is no
more work to be done. The method yields tuples of (candidate, Event), and if the work was
already done by another worker, the caller should set the event. Batch candidates must have
an "id" field which can be inspected.
"""
min_id = max(min_id, 0)
max_id = max(max_id, 1)
allocator = CompletedKeys(max_id + 1, min_id)
try:
while True:
start_index = allocator.get_block_start_index(batch_size)
end_index = min(start_index + batch_size, max_id + 1)
all_candidates = list(batch_query()
.where(primary_key_field >= start_index,
primary_key_field < end_index)
.order_by(primary_key_field))
if len(all_candidates) == 0:
logger.info('No candidates, marking entire block completed %s-%s', start_index, end_index)
allocator.mark_completed(start_index, end_index)
continue
logger.info('Found %s candidates, processing block', len(all_candidates))
batch_completed = 0
for candidate in all_candidates:
abort_early = Event()
yield candidate, abort_early, allocator.num_remaining - batch_completed
batch_completed += 1
if abort_early.is_set():
logger.info('Overlap with another worker, aborting')
break
completed_through = candidate.id + 1
logger.info('Marking id range as completed: %s-%s', start_index, completed_through)
allocator.mark_completed(start_index, completed_through)
except NoAvailableKeysError:
logger.info('No more work')
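# A small standalone sketch of CompletedKeys, the bookkeeping structure used
# by yield_random_entries above; the index range is arbitrary.
tracker = CompletedKeys(max_index=100)
start = tracker.get_block_start_index(block_size_estimate=10)
tracker.mark_completed(start, start + 10)  # a worker finished this block
print(tracker.num_remaining)               # 90
print(tracker.is_available(start))         # False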

View file

@ -0,0 +1,54 @@
import logging
from app import app
from data.database import User
from util.names import parse_robot_username
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def cleanup_old_robots(page_size=50, force=False):
""" Deletes any robots that live under namespaces that no longer exist. """
if not force and not app.config.get('SETUP_COMPLETE', False):
return
# Collect the robot accounts to delete.
page_number = 1
to_delete = []
encountered_namespaces = {}
while True:
found_bots = False
for robot in list(User.select().where(User.robot == True).paginate(page_number, page_size)):
found_bots = True
logger.info("Checking robot %s (page %s)", robot.username, page_number)
parsed = parse_robot_username(robot.username)
if parsed is None:
continue
namespace, _ = parsed
if namespace in encountered_namespaces:
if not encountered_namespaces[namespace]:
logger.info('Marking %s to be deleted', robot.username)
to_delete.append(robot)
else:
try:
User.get(username=namespace)
encountered_namespaces[namespace] = True
except User.DoesNotExist:
# Save the robot account for deletion.
logger.info('Marking %s to be deleted', robot.username)
to_delete.append(robot)
encountered_namespaces[namespace] = False
if not found_bots:
break
page_number = page_number + 1
# Cleanup any robot accounts whose corresponding namespace doesn't exist.
logger.info('Found %s robots to delete', len(to_delete))
for index, robot in enumerate(to_delete):
logger.info('Deleting robot %s of %s (%s)', index, len(to_delete), robot.username)
robot.delete_instance(recursive=True, delete_nullable=True)

View file

@ -0,0 +1,48 @@
import logging
import time
from datetime import datetime, timedelta
from data.database import RepositoryBuild, AccessToken
from app import app
logger = logging.getLogger(__name__)
BATCH_SIZE = 1000
def delete_temporary_access_tokens(older_than):
# Find the highest ID up to which we should delete
up_to_id = (AccessToken
.select(AccessToken.id)
.where(AccessToken.created < older_than)
.limit(1)
.order_by(AccessToken.id.desc())
.get().id)
logger.debug('Deleting temporary access tokens with ids lower than: %s', up_to_id)
access_tokens_in_builds = (RepositoryBuild.select(RepositoryBuild.access_token).distinct())
while up_to_id > 0:
starting_at_id = max(up_to_id - BATCH_SIZE, 0)
logger.debug('Deleting tokens with ids between %s and %s', starting_at_id, up_to_id)
start_time = datetime.utcnow()
(AccessToken
.delete()
.where(AccessToken.id >= starting_at_id,
AccessToken.id < up_to_id,
AccessToken.temporary == True,
~(AccessToken.id << access_tokens_in_builds))
.execute())
time_to_delete = datetime.utcnow() - start_time
up_to_id -= BATCH_SIZE
logger.debug('Sleeping for %s seconds', time_to_delete.total_seconds())
time.sleep(time_to_delete.total_seconds())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
delete_temporary_access_tokens(datetime.utcnow() - timedelta(days=2))

13
util/migrate/table_ops.py Normal file
View file

@ -0,0 +1,13 @@
def copy_table_contents(source_table, destination_table, conn):
if conn.engine.name == 'postgresql':
conn.execute('INSERT INTO "%s" SELECT * FROM "%s"' % (destination_table, source_table))
result = list(conn.execute('Select Max(id) from "%s"' % destination_table))[0]
if result[0] is not None:
new_start_id = result[0] + 1
conn.execute('ALTER SEQUENCE "%s_id_seq" RESTART WITH %s' % (destination_table, new_start_id))
else:
conn.execute("INSERT INTO `%s` SELECT * FROM `%s` WHERE 1" % (destination_table, source_table))
result = list(conn.execute('Select Max(id) from `%s` WHERE 1' % destination_table))[0]
if result[0] is not None:
new_start_id = result[0] + 1
conn.execute("ALTER TABLE `%s` AUTO_INCREMENT = %s" % (destination_table, new_start_id))

Some files were not shown because too many files have changed in this diff.