initial import for Open Source 🎉
This commit is contained in:
parent
1898c361f3
commit
9c0dd3b722
2048 changed files with 218743 additions and 0 deletions
19
util/__init__.py
Normal file
19
util/__init__.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
|
||||
def get_app_url(config):
|
||||
""" Returns the application's URL, based on the given config. """
|
||||
return '%s://%s' % (config['PREFERRED_URL_SCHEME'], config['SERVER_HOSTNAME'])
|
||||
|
||||
def slash_join(*args):
|
||||
"""
|
||||
Joins together strings and guarantees there is only one '/' in between the
|
||||
each string joined. Double slashes ('//') are assumed to be intentional and
|
||||
are not deduplicated.
|
||||
"""
|
||||
def rmslash(path):
|
||||
path = path[1:] if len(path) > 0 and path[0] == '/' else path
|
||||
path = path[:-1] if len(path) > 0 and path[-1] == '/' else path
|
||||
return path
|
||||
|
||||
args = [rmslash(path) for path in args]
|
||||
return '/'.join(args)
|
||||
|
19
util/abchelpers.py
Normal file
19
util/abchelpers.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
class NoopIsANoopException(TypeError):
|
||||
""" Raised if the nooper decorator is unnecessary on a class. """
|
||||
pass
|
||||
|
||||
|
||||
def nooper(cls):
|
||||
""" Decorates a class that derives from an ABCMeta, filling in any unimplemented methods with
|
||||
no-ops.
|
||||
"""
|
||||
def empty_func(*args, **kwargs):
|
||||
# pylint: disable=unused-argument
|
||||
pass
|
||||
|
||||
empty_methods = {m_name: empty_func for m_name in cls.__abstractmethods__}
|
||||
|
||||
if not empty_methods:
|
||||
raise NoopIsANoopException('nooper implemented no abstract methods on %s' % cls)
|
||||
|
||||
return type(cls.__name__, (cls,), empty_methods)
|
68
util/asyncwrapper.py
Normal file
68
util/asyncwrapper.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
import queue
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from concurrent.futures import Executor, Future, CancelledError
|
||||
|
||||
|
||||
class AsyncExecutorWrapper(object):
|
||||
""" This class will wrap a syncronous library transparently in a way which
|
||||
will move all calls off to an asynchronous Executor, and will change all
|
||||
returned values to be Future objects.
|
||||
"""
|
||||
SYNC_FLAG_FIELD = '__AsyncExecutorWrapper__sync__'
|
||||
|
||||
def __init__(self, delegate, executor):
|
||||
""" Wrap the specified synchronous delegate instance, and submit() all
|
||||
method calls to the specified Executor instance.
|
||||
"""
|
||||
self._delegate = delegate
|
||||
self._executor = executor
|
||||
|
||||
def __getattr__(self, attr_name):
|
||||
maybe_callable = getattr(self._delegate, attr_name) # Will raise proper attribute error
|
||||
if callable(maybe_callable):
|
||||
# Build a callable which when executed places the request
|
||||
# onto a queue
|
||||
@wraps(maybe_callable)
|
||||
def wrapped_method(*args, **kwargs):
|
||||
if getattr(maybe_callable, self.SYNC_FLAG_FIELD, False):
|
||||
sync_result = Future()
|
||||
try:
|
||||
sync_result.set_result(maybe_callable(*args, **kwargs))
|
||||
except Exception as ex:
|
||||
sync_result.set_exception(ex)
|
||||
return sync_result
|
||||
|
||||
try:
|
||||
return self._executor.submit(maybe_callable, *args, **kwargs)
|
||||
except queue.Full as ex:
|
||||
queue_full = Future()
|
||||
queue_full.set_exception(ex)
|
||||
return queue_full
|
||||
|
||||
return wrapped_method
|
||||
else:
|
||||
return maybe_callable
|
||||
|
||||
@classmethod
|
||||
def sync(cls, f):
|
||||
""" Annotate the given method to flag it as synchronous so that AsyncExecutorWrapper
|
||||
will return the result immediately without submitting it to the executor.
|
||||
"""
|
||||
setattr(f, cls.SYNC_FLAG_FIELD, True)
|
||||
return f
|
||||
|
||||
|
||||
class NullExecutorCancelled(CancelledError):
|
||||
def __init__(self):
|
||||
super(NullExecutorCancelled, self).__init__('Null executor always fails.')
|
||||
|
||||
|
||||
class NullExecutor(Executor):
|
||||
""" Executor instance which always returns a Future completed with a
|
||||
CancelledError exception. """
|
||||
def submit(self, _, *args, **kwargs):
|
||||
always_fail = Future()
|
||||
always_fail.set_exception(NullExecutorCancelled())
|
||||
return always_fail
|
90
util/audit.py
Normal file
90
util/audit.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
import logging
|
||||
import random
|
||||
|
||||
from collections import namedtuple
|
||||
from urlparse import urlparse
|
||||
|
||||
from flask import request
|
||||
|
||||
from app import analytics, userevents, ip_resolver
|
||||
from auth.auth_context import get_authenticated_context, get_authenticated_user
|
||||
from data.logs_model import logs_model
|
||||
from util.request import get_request_ip
|
||||
|
||||
from data.readreplica import ReadOnlyModeException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Repository = namedtuple('Repository', ['namespace_name', 'name', 'id', 'is_free_namespace'])
|
||||
|
||||
def wrap_repository(repo_obj):
|
||||
return Repository(namespace_name=repo_obj.namespace_user.username, name=repo_obj.name,
|
||||
id=repo_obj.id, is_free_namespace=repo_obj.namespace_user.stripe_id is None)
|
||||
|
||||
|
||||
def track_and_log(event_name, repo_obj, analytics_name=None, analytics_sample=1, **kwargs):
|
||||
repo_name = repo_obj.name
|
||||
namespace_name = repo_obj.namespace_name
|
||||
metadata = {
|
||||
'repo': repo_name,
|
||||
'namespace': namespace_name,
|
||||
}
|
||||
metadata.update(kwargs)
|
||||
|
||||
is_free_namespace = False
|
||||
if hasattr(repo_obj, 'is_free_namespace'):
|
||||
is_free_namespace = repo_obj.is_free_namespace
|
||||
|
||||
# Add auth context metadata.
|
||||
analytics_id = 'anonymous'
|
||||
auth_context = get_authenticated_context()
|
||||
if auth_context is not None:
|
||||
analytics_id, context_metadata = auth_context.analytics_id_and_public_metadata()
|
||||
metadata.update(context_metadata)
|
||||
|
||||
# Publish the user event (if applicable)
|
||||
logger.debug('Checking publishing %s to the user events system', event_name)
|
||||
if auth_context and auth_context.has_nonrobot_user:
|
||||
logger.debug('Publishing %s to the user events system', event_name)
|
||||
user_event_data = {
|
||||
'action': event_name,
|
||||
'repository': repo_name,
|
||||
'namespace': namespace_name,
|
||||
}
|
||||
|
||||
event = userevents.get_event(auth_context.authed_user.username)
|
||||
event.publish_event_data('docker-cli', user_event_data)
|
||||
|
||||
# Save the action to mixpanel.
|
||||
if random.random() < analytics_sample:
|
||||
if analytics_name is None:
|
||||
analytics_name = event_name
|
||||
|
||||
logger.debug('Logging the %s to analytics engine', analytics_name)
|
||||
|
||||
request_parsed = urlparse(request.url_root)
|
||||
extra_params = {
|
||||
'repository': '%s/%s' % (namespace_name, repo_name),
|
||||
'user-agent': request.user_agent.string,
|
||||
'hostname': request_parsed.hostname,
|
||||
}
|
||||
|
||||
analytics.track(analytics_id, analytics_name, extra_params)
|
||||
|
||||
# Add the resolved information to the metadata.
|
||||
logger.debug('Resolving IP address %s', get_request_ip())
|
||||
resolved_ip = ip_resolver.resolve_ip(get_request_ip())
|
||||
if resolved_ip is not None:
|
||||
metadata['resolved_ip'] = resolved_ip._asdict()
|
||||
|
||||
logger.debug('Resolved IP address %s', get_request_ip())
|
||||
|
||||
# Log the action to the database.
|
||||
logger.debug('Logging the %s to logs system', event_name)
|
||||
try:
|
||||
logs_model.log_action(event_name, namespace_name, performer=get_authenticated_user(),
|
||||
ip=get_request_ip(), metadata=metadata, repository=repo_obj,
|
||||
is_free_namespace=is_free_namespace)
|
||||
logger.debug('Track and log of %s complete', event_name)
|
||||
except ReadOnlyModeException:
|
||||
pass
|
48
util/backfillreplication.py
Normal file
48
util/backfillreplication.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
import features
|
||||
|
||||
from app import storage, image_replication_queue
|
||||
from data.database import (Image, ImageStorage, Repository, User, ImageStoragePlacement,
|
||||
ImageStorageLocation)
|
||||
from data import model
|
||||
from util.registry.replication import queue_storage_replication
|
||||
|
||||
def backfill_replication():
|
||||
encountered = set()
|
||||
query = (Image
|
||||
.select(Image, ImageStorage, Repository, User)
|
||||
.join(ImageStorage)
|
||||
.switch(Image)
|
||||
.join(Repository)
|
||||
.join(User))
|
||||
|
||||
for image in query:
|
||||
if image.storage.uuid in encountered:
|
||||
continue
|
||||
|
||||
namespace = image.repository.namespace_user
|
||||
locations = model.user.get_region_locations(namespace)
|
||||
locations_required = locations | set(storage.default_locations)
|
||||
|
||||
query = (ImageStoragePlacement
|
||||
.select(ImageStoragePlacement, ImageStorageLocation)
|
||||
.where(ImageStoragePlacement.storage == image.storage)
|
||||
.join(ImageStorageLocation))
|
||||
|
||||
existing_locations = set([p.location.name for p in query])
|
||||
locations_missing = locations_required - existing_locations
|
||||
if locations_missing:
|
||||
print "Enqueueing image storage %s to be replicated" % (image.storage.uuid)
|
||||
encountered.add(image.storage.uuid)
|
||||
|
||||
if not image_replication_queue.alive([image.storage.uuid]):
|
||||
queue_storage_replication(image.repository.namespace_user.username, image.storage)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
if not features.STORAGE_REPLICATION:
|
||||
print "Storage replication is not enabled"
|
||||
else:
|
||||
backfill_replication()
|
||||
|
5
util/backoff.py
Normal file
5
util/backoff.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
def exponential_backoff(attempts, scaling_factor, base):
|
||||
backoff = 5 * (pow(2, attempts) - 1)
|
||||
backoff_time = backoff * scaling_factor
|
||||
retry_at = backoff_time/10 + base
|
||||
return retry_at
|
32
util/bytes.py
Normal file
32
util/bytes.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
class Bytes(object):
|
||||
""" Wrapper around strings and unicode objects to ensure we are always using
|
||||
the correct encoded or decoded data.
|
||||
"""
|
||||
def __init__(self, data):
|
||||
assert isinstance(data, str)
|
||||
self._encoded_data = data
|
||||
|
||||
@classmethod
|
||||
def for_string_or_unicode(cls, input):
|
||||
# If the string is a unicode string, then encode its data as UTF-8. Note that
|
||||
# we don't catch any decode exceptions here, as we want those to be raised.
|
||||
if isinstance(input, unicode):
|
||||
return Bytes(input.encode('utf-8'))
|
||||
|
||||
# Next, try decoding as UTF-8. If we have a utf-8 encoded string, then we have no
|
||||
# additional conversion to do.
|
||||
try:
|
||||
input.decode('utf-8')
|
||||
return Bytes(input)
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
|
||||
# Finally, if the data is (somehow) a unicode string inside a `str` type, then
|
||||
# re-encoded the data.
|
||||
return Bytes(input.encode('utf-8'))
|
||||
|
||||
def as_encoded_str(self):
|
||||
return self._encoded_data
|
||||
|
||||
def as_unicode(self):
|
||||
return self._encoded_data.decode('utf-8')
|
36
util/cache.py
Normal file
36
util/cache.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
from functools import wraps
|
||||
|
||||
from flask_restful.utils import unpack
|
||||
|
||||
|
||||
def cache_control(max_age=55):
|
||||
def wrap(f):
|
||||
@wraps(f)
|
||||
def add_max_age(*args, **kwargs):
|
||||
response = f(*args, **kwargs)
|
||||
response.headers['Cache-Control'] = 'max-age=%d' % max_age
|
||||
return response
|
||||
return add_max_age
|
||||
return wrap
|
||||
|
||||
|
||||
def cache_control_flask_restful(max_age=55):
|
||||
def wrap(f):
|
||||
@wraps(f)
|
||||
def add_max_age(*args, **kwargs):
|
||||
response = f(*args, **kwargs)
|
||||
body, status_code, headers = unpack(response)
|
||||
headers['Cache-Control'] = 'max-age=%d' % max_age
|
||||
return body, status_code, headers
|
||||
return add_max_age
|
||||
return wrap
|
||||
|
||||
|
||||
def no_cache(f):
|
||||
@wraps(f)
|
||||
def add_no_cache(*args, **kwargs):
|
||||
response = f(*args, **kwargs)
|
||||
if response is not None:
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
|
||||
return response
|
||||
return add_no_cache
|
18
util/canonicaljson.py
Normal file
18
util/canonicaljson.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
import collections
|
||||
|
||||
def canonicalize(json_obj):
|
||||
"""This function canonicalizes a Python object that will be serialized as JSON.
|
||||
|
||||
Args:
|
||||
json_obj (object): the Python object that will later be serialized as JSON.
|
||||
|
||||
Returns:
|
||||
object: json_obj now sorted to its canonical form.
|
||||
|
||||
"""
|
||||
if isinstance(json_obj, collections.MutableMapping):
|
||||
sorted_obj = sorted({key: canonicalize(val) for key, val in json_obj.items()}.items())
|
||||
return collections.OrderedDict(sorted_obj)
|
||||
elif isinstance(json_obj, (list, tuple)):
|
||||
return [canonicalize(val) for val in json_obj]
|
||||
return json_obj
|
29
util/config/__init__.py
Normal file
29
util/config/__init__.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
class URLSchemeAndHostname:
|
||||
"""
|
||||
Immutable configuration for a given preferred url scheme (e.g. http or https), and a hostname (e.g. localhost:5000)
|
||||
"""
|
||||
def __init__(self, url_scheme, hostname):
|
||||
self._url_scheme = url_scheme
|
||||
self._hostname = hostname
|
||||
|
||||
@classmethod
|
||||
def from_app_config(cls, app_config):
|
||||
"""
|
||||
Helper method to instantiate class from app config, a frequent pattern
|
||||
:param app_config:
|
||||
:return:
|
||||
"""
|
||||
return cls(app_config['PREFERRED_URL_SCHEME'], app_config['SERVER_HOSTNAME'])
|
||||
|
||||
@property
|
||||
def url_scheme(self):
|
||||
return self._url_scheme
|
||||
|
||||
@property
|
||||
def hostname(self):
|
||||
return self._hostname
|
||||
|
||||
def get_url(self):
|
||||
""" Returns the application's URL, based on the given url scheme and hostname. """
|
||||
return '%s://%s' % (self._url_scheme, self._hostname)
|
||||
|
43
util/config/configdocs/configdoc.py
Normal file
43
util/config/configdocs/configdoc.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
""" Generates html documentation from JSON Schema """
|
||||
|
||||
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
import docsmodel
|
||||
import html_output
|
||||
|
||||
from util.config.schema import CONFIG_SCHEMA
|
||||
|
||||
|
||||
def make_custom_sort(orders):
|
||||
""" Sort in a specified order any dictionary nested in a complex structure """
|
||||
|
||||
orders = [{k: -i for (i, k) in enumerate(reversed(order), 1)} for order in orders]
|
||||
def process(stuff):
|
||||
if isinstance(stuff, dict):
|
||||
l = [(k, process(v)) for (k, v) in stuff.iteritems()]
|
||||
keys = set(stuff)
|
||||
for order in orders:
|
||||
if keys.issubset(order) or keys.issuperset(order):
|
||||
return OrderedDict(sorted(l, key=lambda x: order.get(x[0], 0)))
|
||||
return OrderedDict(sorted(l))
|
||||
if isinstance(stuff, list):
|
||||
return [process(x) for x in stuff]
|
||||
return stuff
|
||||
return process
|
||||
|
||||
SCHEMA_HTML_FILE = "schema.html"
|
||||
|
||||
schema = json.dumps(CONFIG_SCHEMA, sort_keys = True)
|
||||
schema = json.loads(schema, object_pairs_hook = OrderedDict)
|
||||
|
||||
req = sorted(schema["required"])
|
||||
custom_sort = make_custom_sort([req])
|
||||
schema = custom_sort(schema)
|
||||
|
||||
parsed_items = docsmodel.DocsModel().parse(schema)[1:]
|
||||
output = html_output.HtmlOutput().generate_output(parsed_items)
|
||||
|
||||
with open(SCHEMA_HTML_FILE, 'wt') as f:
|
||||
f.write(output)
|
93
util/config/configdocs/docsmodel.py
Normal file
93
util/config/configdocs/docsmodel.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
import json
|
||||
import collections
|
||||
|
||||
class ParsedItem(dict):
|
||||
""" Parsed Schema item """
|
||||
|
||||
def __init__(self, json_object, name, required, level):
|
||||
"""Fills dict with basic item information"""
|
||||
super(ParsedItem, self).__init__()
|
||||
self['name'] = name
|
||||
self['title'] = json_object.get('title', '')
|
||||
self['type'] = json_object.get('type')
|
||||
self['description'] = json_object.get('description', '')
|
||||
self['level'] = level
|
||||
self['required'] = required
|
||||
self['x-reference'] = json_object.get('x-reference', '')
|
||||
self['x-example'] = json_object.get('x-example', '')
|
||||
self['pattern'] = json_object.get('pattern', '')
|
||||
self['enum'] = json_object.get('enum', '')
|
||||
|
||||
class DocsModel:
|
||||
""" Documentation model and Schema Parser """
|
||||
|
||||
def __init__(self):
|
||||
self.__parsed_items = None
|
||||
|
||||
def parse(self, json_object):
|
||||
""" Returns multi-level list of recursively parsed items """
|
||||
|
||||
self.__parsed_items = list()
|
||||
self.__parse_schema(json_object, 'root', True, 0)
|
||||
return self.__parsed_items
|
||||
|
||||
def __parse_schema(self, schema, name, required, level):
|
||||
""" Parses schema, which type is object, array or leaf.
|
||||
Appends new ParsedItem to self.__parsed_items lis """
|
||||
parsed_item = ParsedItem(schema, name, required, level)
|
||||
self.__parsed_items.append(parsed_item)
|
||||
required = schema.get('required', [])
|
||||
|
||||
if 'enum' in schema:
|
||||
parsed_item['item'] = schema.get('enum')
|
||||
item_type = schema.get('type')
|
||||
if item_type == 'object' and name != 'DISTRIBUTED_STORAGE_CONFIG':
|
||||
self.__parse_object(parsed_item, schema, required, level)
|
||||
elif item_type == 'array':
|
||||
self.__parse_array(parsed_item, schema, required, level)
|
||||
else:
|
||||
parse_leaf(parsed_item, schema)
|
||||
|
||||
def __parse_object(self, parsed_item, schema, required, level):
|
||||
""" Parses schema of type object """
|
||||
for key, value in schema.get('properties', {}).items():
|
||||
self.__parse_schema(value, key, key in required, level + 1)
|
||||
|
||||
def __parse_array(self, parsed_item, schema, required, level):
|
||||
""" Parses schema of type array """
|
||||
items = schema.get('items')
|
||||
parsed_item['minItems'] = schema.get('minItems', None)
|
||||
parsed_item['maxItems'] = schema.get('maxItems', None)
|
||||
parsed_item['uniqueItems'] = schema.get('uniqueItems', False)
|
||||
if isinstance(items, dict):
|
||||
# item is single schema describing all elements in an array
|
||||
self.__parse_schema(
|
||||
items,
|
||||
'array item',
|
||||
required,
|
||||
level + 1)
|
||||
|
||||
elif isinstance(items, list):
|
||||
# item is a list of schemas
|
||||
for index, list_item in enumerate(items):
|
||||
self.__parse_schema(
|
||||
list_item,
|
||||
'array item {}'.format(index),
|
||||
index in required,
|
||||
level + 1)
|
||||
|
||||
def parse_leaf(parsed_item, schema):
|
||||
""" Parses schema of a number and a string """
|
||||
if parsed_item['name'] != 'root':
|
||||
parsed_item['description'] = schema.get('description','')
|
||||
parsed_item['x-reference'] = schema.get('x-reference','')
|
||||
parsed_item['pattern'] = schema.get('pattern','')
|
||||
parsed_item['enum'] = ", ".join(schema.get('enum','')).encode()
|
||||
|
||||
ex = schema.get('x-example', '')
|
||||
if isinstance(ex, list):
|
||||
parsed_item['x-example'] = ", ".join(ex).encode()
|
||||
elif isinstance(ex, collections.OrderedDict):
|
||||
parsed_item['x-example'] = json.dumps(ex)
|
||||
else:
|
||||
parsed_item['x-example'] = ex
|
63
util/config/configdocs/html_output.py
Normal file
63
util/config/configdocs/html_output.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
class HtmlOutput:
|
||||
""" Generates HTML from documentation model """
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def generate_output(self, parsed_items):
|
||||
"""Returns generated HTML strin"""
|
||||
return self.__get_html_begin() + \
|
||||
self.__get_html_middle(parsed_items) + \
|
||||
self.__get_html_end()
|
||||
|
||||
def __get_html_begin(self):
|
||||
return '<!DOCTYPE html>\n<html>\n<head>\n<link rel="stylesheet" type="text/css" href="style.css" />\n</head>\n<body>\n'
|
||||
|
||||
def __get_html_end(self):
|
||||
return '</body>\n</html>'
|
||||
|
||||
def __get_html_middle(self, parsed_items):
|
||||
output = ''
|
||||
root_item = parsed_items[0]
|
||||
|
||||
#output += '<h1 class="root_title">{}</h1>\n'.format(root_item['title'])
|
||||
#output += '<h1 class="root_title">{}</h1>\n'.format(root_item['title'])
|
||||
output += "Schema for Red Hat Quay"
|
||||
|
||||
output += '<ul class="level0">\n'
|
||||
last_level = 0
|
||||
is_root = True
|
||||
for item in parsed_items:
|
||||
level = item['level'] - 1
|
||||
if last_level < level:
|
||||
output += '<ul class="level{}">\n'.format(level)
|
||||
for i in range(last_level - level):
|
||||
output += '</ul>\n'
|
||||
last_level = level
|
||||
output += self.__get_html_item(item, is_root)
|
||||
is_root = False
|
||||
output += '</ul>\n'
|
||||
return output
|
||||
|
||||
def __get_required_field(self, parsed_item):
|
||||
return 'required' if parsed_item['required'] else ''
|
||||
|
||||
def __get_html_item(self, parsed_item, is_root):
|
||||
item = '<li class="schema item"> \n'
|
||||
item += '<div class="name">{}</div> \n'.format(parsed_item['name'])
|
||||
item += '<div class="type">[{}]</div> \n'.format(parsed_item['type'])
|
||||
item += '<div class="required">{}</div> \n'.format(self.__get_required_field(parsed_item))
|
||||
item += '<div class="docs">\n' if not is_root else '<div class="root_docs">\n'
|
||||
item += '<div class="title">{}</div>\n'.format(parsed_item['title'])
|
||||
item += ': ' if parsed_item['title'] != '' and parsed_item['description'] != '' else ''
|
||||
item += '<div class="description">{}</div>\n'.format(parsed_item['description'])
|
||||
item += '<div class="enum">enum: {}</div>\n'.format(parsed_item['enum']) if parsed_item['enum'] != '' else ''
|
||||
item += '<div class="minItems">Min Items: {}</div>\n'.format(parsed_item['minItems']) if parsed_item['type'] == "array" and parsed_item['minItems'] != "None" else ''
|
||||
item += '<div class="uniqueItems">Unique Items: {}</div>\n'.format(parsed_item['uniqueItems']) if parsed_item['type'] == "array" and parsed_item['uniqueItems'] else ''
|
||||
item += '<div class="pattern">Pattern: {}</div>\n'.format(parsed_item['pattern']) if parsed_item['pattern'] != 'None' and parsed_item['pattern'] != '' else ''
|
||||
item += '<div class="x-reference"><a href="{}">Reference: {}</a></div>\n'.format(parsed_item['x-reference'],parsed_item['x-reference']) if parsed_item['x-reference'] != '' else ''
|
||||
item += '<div class="x-example">Example: <code>{}</code></div>\n'.format(parsed_item['x-example']) if parsed_item['x-example'] != '' else ''
|
||||
item += '</div>\n'
|
||||
item += '</li>\n'
|
||||
return item
|
||||
|
1062
util/config/configdocs/schema.html
Normal file
1062
util/config/configdocs/schema.html
Normal file
File diff suppressed because it is too large
Load diff
78
util/config/configdocs/style.css
Normal file
78
util/config/configdocs/style.css
Normal file
|
@ -0,0 +1,78 @@
|
|||
body {
|
||||
font-family: sans-serif;
|
||||
}
|
||||
|
||||
pre, code{
|
||||
white-space: normal;
|
||||
}
|
||||
|
||||
div.root docs{
|
||||
display: none;
|
||||
}
|
||||
|
||||
div.name {
|
||||
display: inline;
|
||||
}
|
||||
|
||||
div.type {
|
||||
display: inline;
|
||||
font-weight: bold;
|
||||
color: blue;
|
||||
}
|
||||
|
||||
div.required {
|
||||
display: inline;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
div.docs {
|
||||
display: block;
|
||||
}
|
||||
|
||||
div.title {
|
||||
display: block;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
div.description {
|
||||
display: block;
|
||||
font-family: serif;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
div.enum {
|
||||
display: block;
|
||||
font-family: serif;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
div.x-example {
|
||||
display: block;
|
||||
font-family: serif;
|
||||
font-style: italic;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
div.pattern {
|
||||
display: block;
|
||||
font-family: serif;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
div.x-reference {
|
||||
display: block;
|
||||
font-family: serif;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
div.uniqueItems {
|
||||
display: block;
|
||||
}
|
||||
|
||||
div.minItems {
|
||||
display: block;
|
||||
}
|
||||
|
||||
div.maxItems {
|
||||
display: block;
|
||||
}
|
108
util/config/configutil.py
Normal file
108
util/config/configutil.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
from random import SystemRandom
|
||||
from uuid import uuid4
|
||||
|
||||
def generate_secret_key():
|
||||
cryptogen = SystemRandom()
|
||||
return str(cryptogen.getrandbits(256))
|
||||
|
||||
|
||||
def add_enterprise_config_defaults(config_obj, current_secret_key):
|
||||
""" Adds/Sets the config defaults for enterprise registry config. """
|
||||
# These have to be false.
|
||||
config_obj['TESTING'] = False
|
||||
config_obj['USE_CDN'] = False
|
||||
|
||||
# Default for V3 upgrade.
|
||||
config_obj['V3_UPGRADE_MODE'] = config_obj.get('V3_UPGRADE_MODE', 'complete')
|
||||
|
||||
# Defaults for Red Hat Quay.
|
||||
config_obj['REGISTRY_TITLE'] = config_obj.get('REGISTRY_TITLE', 'Red Hat Quay')
|
||||
config_obj['REGISTRY_TITLE_SHORT'] = config_obj.get('REGISTRY_TITLE_SHORT', 'Red Hat Quay')
|
||||
|
||||
# Default features that are on.
|
||||
config_obj['FEATURE_USER_LOG_ACCESS'] = config_obj.get('FEATURE_USER_LOG_ACCESS', True)
|
||||
config_obj['FEATURE_USER_CREATION'] = config_obj.get('FEATURE_USER_CREATION', True)
|
||||
config_obj['FEATURE_ANONYMOUS_ACCESS'] = config_obj.get('FEATURE_ANONYMOUS_ACCESS', True)
|
||||
config_obj['FEATURE_REQUIRE_TEAM_INVITE'] = config_obj.get('FEATURE_REQUIRE_TEAM_INVITE', True)
|
||||
config_obj['FEATURE_CHANGE_TAG_EXPIRATION'] = config_obj.get('FEATURE_CHANGE_TAG_EXPIRATION',
|
||||
True)
|
||||
config_obj['FEATURE_DIRECT_LOGIN'] = config_obj.get('FEATURE_DIRECT_LOGIN', True)
|
||||
config_obj['FEATURE_APP_SPECIFIC_TOKENS'] = config_obj.get('FEATURE_APP_SPECIFIC_TOKENS', True)
|
||||
config_obj['FEATURE_PARTIAL_USER_AUTOCOMPLETE'] = config_obj.get('FEATURE_PARTIAL_USER_AUTOCOMPLETE', True)
|
||||
config_obj['FEATURE_USERNAME_CONFIRMATION'] = config_obj.get('FEATURE_USERNAME_CONFIRMATION', True)
|
||||
config_obj['FEATURE_RESTRICTED_V1_PUSH'] = config_obj.get('FEATURE_RESTRICTED_V1_PUSH', True)
|
||||
|
||||
# Default features that are off.
|
||||
config_obj['FEATURE_MAILING'] = config_obj.get('FEATURE_MAILING', False)
|
||||
config_obj['FEATURE_BUILD_SUPPORT'] = config_obj.get('FEATURE_BUILD_SUPPORT', False)
|
||||
config_obj['FEATURE_ACI_CONVERSION'] = config_obj.get('FEATURE_ACI_CONVERSION', False)
|
||||
config_obj['FEATURE_APP_REGISTRY'] = config_obj.get('FEATURE_APP_REGISTRY', False)
|
||||
config_obj['FEATURE_REPO_MIRROR'] = config_obj.get('FEATURE_REPO_MIRROR', False)
|
||||
|
||||
# Default repo mirror config.
|
||||
config_obj['REPO_MIRROR_TLS_VERIFY'] = config_obj.get('REPO_MIRROR_TLS_VERIFY', True)
|
||||
config_obj['REPO_MIRROR_SERVER_HOSTNAME'] = config_obj.get('REPO_MIRROR_SERVER_HOSTNAME', None)
|
||||
|
||||
# Default the signer config.
|
||||
config_obj['GPG2_PRIVATE_KEY_FILENAME'] = config_obj.get('GPG2_PRIVATE_KEY_FILENAME',
|
||||
'signing-private.gpg')
|
||||
config_obj['GPG2_PUBLIC_KEY_FILENAME'] = config_obj.get('GPG2_PUBLIC_KEY_FILENAME',
|
||||
'signing-public.gpg')
|
||||
config_obj['SIGNING_ENGINE'] = config_obj.get('SIGNING_ENGINE', 'gpg2')
|
||||
|
||||
# Default security scanner config.
|
||||
config_obj['FEATURE_SECURITY_NOTIFICATIONS'] = config_obj.get(
|
||||
'FEATURE_SECURITY_NOTIFICATIONS', True)
|
||||
|
||||
config_obj['FEATURE_SECURITY_SCANNER'] = config_obj.get(
|
||||
'FEATURE_SECURITY_SCANNER', False)
|
||||
|
||||
config_obj['SECURITY_SCANNER_ISSUER_NAME'] = config_obj.get(
|
||||
'SECURITY_SCANNER_ISSUER_NAME', 'security_scanner')
|
||||
|
||||
# Default time machine config.
|
||||
config_obj['TAG_EXPIRATION_OPTIONS'] = config_obj.get('TAG_EXPIRATION_OPTIONS',
|
||||
['0s', '1d', '1w', '2w', '4w'])
|
||||
config_obj['DEFAULT_TAG_EXPIRATION'] = config_obj.get('DEFAULT_TAG_EXPIRATION', '2w')
|
||||
|
||||
# Default mail setings.
|
||||
config_obj['MAIL_USE_TLS'] = config_obj.get('MAIL_USE_TLS', True)
|
||||
config_obj['MAIL_PORT'] = config_obj.get('MAIL_PORT', 587)
|
||||
config_obj['MAIL_DEFAULT_SENDER'] = config_obj.get('MAIL_DEFAULT_SENDER', 'support@quay.io')
|
||||
|
||||
# Default auth type.
|
||||
if not 'AUTHENTICATION_TYPE' in config_obj:
|
||||
config_obj['AUTHENTICATION_TYPE'] = 'Database'
|
||||
|
||||
# Default secret key.
|
||||
if not 'SECRET_KEY' in config_obj:
|
||||
if current_secret_key:
|
||||
config_obj['SECRET_KEY'] = current_secret_key
|
||||
else:
|
||||
config_obj['SECRET_KEY'] = generate_secret_key()
|
||||
|
||||
# Default database secret key.
|
||||
if not 'DATABASE_SECRET_KEY' in config_obj:
|
||||
config_obj['DATABASE_SECRET_KEY'] = generate_secret_key()
|
||||
|
||||
# Default torrent pepper.
|
||||
if not 'BITTORRENT_FILENAME_PEPPER' in config_obj:
|
||||
config_obj['BITTORRENT_FILENAME_PEPPER'] = str(uuid4())
|
||||
|
||||
# Default storage configuration.
|
||||
if not 'DISTRIBUTED_STORAGE_CONFIG' in config_obj:
|
||||
config_obj['DISTRIBUTED_STORAGE_PREFERENCE'] = ['default']
|
||||
config_obj['DISTRIBUTED_STORAGE_CONFIG'] = {
|
||||
'default': ['LocalStorage', {'storage_path': '/datastorage/registry'}]
|
||||
}
|
||||
|
||||
config_obj['USERFILES_LOCATION'] = 'default'
|
||||
config_obj['USERFILES_PATH'] = 'userfiles/'
|
||||
|
||||
config_obj['LOG_ARCHIVE_LOCATION'] = 'default'
|
||||
|
||||
# Misc configuration.
|
||||
config_obj['PREFERRED_URL_SCHEME'] = config_obj.get('PREFERRED_URL_SCHEME', 'http')
|
||||
config_obj['ENTERPRISE_LOGO_URL'] = config_obj.get(
|
||||
'ENTERPRISE_LOGO_URL', '/static/img/quay-horizontal-color.svg')
|
||||
config_obj['TEAM_RESYNC_STALE_TIME'] = config_obj.get('TEAM_RESYNC_STALE_TIME', '60m')
|
12
util/config/database.py
Normal file
12
util/config/database.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
from data import model
|
||||
from data.appr_model import blob
|
||||
from data.appr_model.models import NEW_MODELS
|
||||
|
||||
|
||||
def sync_database_with_config(config):
|
||||
""" This ensures all implicitly required reference table entries exist in the database. """
|
||||
|
||||
location_names = config.get('DISTRIBUTED_STORAGE_CONFIG', {}).keys()
|
||||
if location_names:
|
||||
model.image.ensure_image_locations(*location_names)
|
||||
blob.ensure_blob_locations(NEW_MODELS, *location_names)
|
13
util/config/provider/__init__.py
Normal file
13
util/config/provider/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from util.config.provider.fileprovider import FileConfigProvider
|
||||
from util.config.provider.testprovider import TestConfigProvider
|
||||
from util.config.provider.k8sprovider import KubernetesConfigProvider
|
||||
|
||||
def get_config_provider(config_volume, yaml_filename, py_filename, testing=False, kubernetes=False):
|
||||
""" Loads and returns the config provider for the current environment. """
|
||||
if testing:
|
||||
return TestConfigProvider()
|
||||
|
||||
if kubernetes:
|
||||
return KubernetesConfigProvider(config_volume, yaml_filename, py_filename)
|
||||
|
||||
return FileConfigProvider(config_volume, yaml_filename, py_filename)
|
62
util/config/provider/basefileprovider.py
Normal file
62
util/config/provider/basefileprovider.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
import os
|
||||
import logging
|
||||
|
||||
from util.config.provider.baseprovider import (BaseProvider, import_yaml, export_yaml,
|
||||
CannotWriteConfigException)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseFileProvider(BaseProvider):
|
||||
""" Base implementation of the config provider that reads the data from the file system. """
|
||||
def __init__(self, config_volume, yaml_filename, py_filename):
|
||||
self.config_volume = config_volume
|
||||
self.yaml_filename = yaml_filename
|
||||
self.py_filename = py_filename
|
||||
|
||||
self.yaml_path = os.path.join(config_volume, yaml_filename)
|
||||
self.py_path = os.path.join(config_volume, py_filename)
|
||||
|
||||
def update_app_config(self, app_config):
|
||||
if os.path.exists(self.py_path):
|
||||
logger.debug('Applying config file: %s', self.py_path)
|
||||
app_config.from_pyfile(self.py_path)
|
||||
|
||||
if os.path.exists(self.yaml_path):
|
||||
logger.debug('Applying config file: %s', self.yaml_path)
|
||||
import_yaml(app_config, self.yaml_path)
|
||||
|
||||
def get_config(self):
|
||||
if not self.config_exists():
|
||||
return None
|
||||
|
||||
config_obj = {}
|
||||
import_yaml(config_obj, self.yaml_path)
|
||||
return config_obj
|
||||
|
||||
def config_exists(self):
|
||||
return self.volume_file_exists(self.yaml_filename)
|
||||
|
||||
def volume_exists(self):
|
||||
return os.path.exists(self.config_volume)
|
||||
|
||||
def volume_file_exists(self, relative_file_path):
|
||||
return os.path.exists(os.path.join(self.config_volume, relative_file_path))
|
||||
|
||||
def get_volume_file(self, relative_file_path, mode='r'):
|
||||
return open(os.path.join(self.config_volume, relative_file_path), mode=mode)
|
||||
|
||||
def get_volume_path(self, directory, relative_file_path):
|
||||
return os.path.join(directory, relative_file_path)
|
||||
|
||||
def list_volume_directory(self, path):
|
||||
dirpath = os.path.join(self.config_volume, path)
|
||||
if not os.path.exists(dirpath):
|
||||
return None
|
||||
|
||||
if not os.path.isdir(dirpath):
|
||||
return None
|
||||
|
||||
return os.listdir(dirpath)
|
||||
|
||||
def get_config_root(self):
|
||||
return self.config_volume
|
123
util/config/provider/baseprovider.py
Normal file
123
util/config/provider/baseprovider.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
import logging
|
||||
import yaml
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from six import add_metaclass
|
||||
|
||||
from jsonschema import validate, ValidationError
|
||||
|
||||
from util.config.schema import CONFIG_SCHEMA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CannotWriteConfigException(Exception):
|
||||
""" Exception raised when the config cannot be written. """
|
||||
pass
|
||||
|
||||
|
||||
class SetupIncompleteException(Exception):
|
||||
""" Exception raised when attempting to verify config that has not yet been setup. """
|
||||
pass
|
||||
|
||||
|
||||
def import_yaml(config_obj, config_file):
|
||||
with open(config_file) as f:
|
||||
c = yaml.safe_load(f)
|
||||
if not c:
|
||||
logger.debug('Empty YAML config file')
|
||||
return
|
||||
|
||||
if isinstance(c, str):
|
||||
raise Exception('Invalid YAML config file: ' + str(c))
|
||||
|
||||
for key in c.iterkeys():
|
||||
if key.isupper():
|
||||
config_obj[key] = c[key]
|
||||
|
||||
if config_obj.get('SETUP_COMPLETE', True):
|
||||
try:
|
||||
validate(config_obj, CONFIG_SCHEMA)
|
||||
except ValidationError:
|
||||
# TODO: Change this into a real error
|
||||
logger.exception('Could not validate config schema')
|
||||
else:
|
||||
logger.debug('Skipping config schema validation because setup is not complete')
|
||||
|
||||
return config_obj
|
||||
|
||||
|
||||
def get_yaml(config_obj):
|
||||
return yaml.safe_dump(config_obj, encoding='utf-8', allow_unicode=True)
|
||||
|
||||
|
||||
def export_yaml(config_obj, config_file):
|
||||
try:
|
||||
with open(config_file, 'w') as f:
|
||||
f.write(get_yaml(config_obj))
|
||||
except IOError as ioe:
|
||||
raise CannotWriteConfigException(str(ioe))
|
||||
|
||||
|
||||
@add_metaclass(ABCMeta)
|
||||
class BaseProvider(object):
|
||||
""" A configuration provider helps to load, save, and handle config override in the application.
|
||||
"""
|
||||
|
||||
@property
|
||||
def provider_id(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_app_config(self, app_config):
|
||||
""" Updates the given application config object with the loaded override config. """
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self):
|
||||
""" Returns the contents of the config override file, or None if none. """
|
||||
|
||||
@abstractmethod
|
||||
def save_config(self, config_object):
|
||||
""" Updates the contents of the config override file to those given. """
|
||||
|
||||
@abstractmethod
|
||||
def config_exists(self):
|
||||
""" Returns true if a config override file exists in the config volume. """
|
||||
|
||||
@abstractmethod
|
||||
def volume_exists(self):
|
||||
""" Returns whether the config override volume exists. """
|
||||
|
||||
@abstractmethod
|
||||
def volume_file_exists(self, relative_file_path):
|
||||
""" Returns whether the file with the given relative path exists under the config override
|
||||
volume. """
|
||||
|
||||
@abstractmethod
|
||||
def get_volume_file(self, relative_file_path, mode='r'):
|
||||
""" Returns a Python file referring to the given path under the config override volume. """
|
||||
|
||||
@abstractmethod
|
||||
def remove_volume_file(self, relative_file_path):
|
||||
""" Removes the config override volume file with the given path. """
|
||||
|
||||
@abstractmethod
|
||||
def list_volume_directory(self, path):
|
||||
""" Returns a list of strings representing the names of the files found in the config override
|
||||
directory under the given path. If the path doesn't exist, returns None.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def save_volume_file(self, flask_file, relative_file_path):
|
||||
""" Saves the given flask file to the config override volume, with the given
|
||||
relative path.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_volume_path(self, directory, filename):
|
||||
""" Helper for constructing relative file paths, which may differ between providers.
|
||||
For example, kubernetes can't have subfolders in configmaps """
|
||||
|
||||
@abstractmethod
|
||||
def get_config_root(self):
|
||||
""" Returns the config root directory. """
|
47
util/config/provider/fileprovider.py
Normal file
47
util/config/provider/fileprovider.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
import os
|
||||
import logging
|
||||
|
||||
from util.config.provider.baseprovider import export_yaml, CannotWriteConfigException
|
||||
from util.config.provider.basefileprovider import BaseFileProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _ensure_parent_dir(filepath):
|
||||
""" Ensures that the parent directory of the given file path exists. """
|
||||
try:
|
||||
parentpath = os.path.abspath(os.path.join(filepath, os.pardir))
|
||||
if not os.path.isdir(parentpath):
|
||||
os.makedirs(parentpath)
|
||||
except IOError as ioe:
|
||||
raise CannotWriteConfigException(str(ioe))
|
||||
|
||||
|
||||
class FileConfigProvider(BaseFileProvider):
|
||||
""" Implementation of the config provider that reads and writes the data
|
||||
from/to the file system. """
|
||||
def __init__(self, config_volume, yaml_filename, py_filename):
|
||||
super(FileConfigProvider, self).__init__(config_volume, yaml_filename, py_filename)
|
||||
|
||||
@property
|
||||
def provider_id(self):
|
||||
return 'file'
|
||||
|
||||
def save_config(self, config_obj):
|
||||
export_yaml(config_obj, self.yaml_path)
|
||||
|
||||
def remove_volume_file(self, relative_file_path):
|
||||
filepath = os.path.join(self.config_volume, relative_file_path)
|
||||
os.remove(filepath)
|
||||
|
||||
def save_volume_file(self, flask_file, relative_file_path):
|
||||
filepath = os.path.join(self.config_volume, relative_file_path)
|
||||
_ensure_parent_dir(filepath)
|
||||
|
||||
# Write the file.
|
||||
try:
|
||||
flask_file.save(filepath)
|
||||
except IOError as ioe:
|
||||
raise CannotWriteConfigException(str(ioe))
|
||||
|
||||
return filepath
|
188
util/config/provider/k8sprovider.py
Normal file
188
util/config/provider/k8sprovider.py
Normal file
|
@ -0,0 +1,188 @@
|
|||
import os
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import time
|
||||
|
||||
from cStringIO import StringIO
|
||||
from requests import Request, Session
|
||||
|
||||
from util.config.provider.baseprovider import CannotWriteConfigException, get_yaml
|
||||
from util.config.provider.basefileprovider import BaseFileProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KUBERNETES_API_HOST = os.environ.get('KUBERNETES_SERVICE_HOST', '')
|
||||
port = os.environ.get('KUBERNETES_SERVICE_PORT')
|
||||
if port:
|
||||
KUBERNETES_API_HOST += ':' + port
|
||||
|
||||
SERVICE_ACCOUNT_TOKEN_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/token'
|
||||
|
||||
QE_NAMESPACE = os.environ.get('QE_K8S_NAMESPACE', 'quay-enterprise')
|
||||
QE_CONFIG_SECRET = os.environ.get('QE_K8S_CONFIG_SECRET', 'quay-enterprise-config-secret')
|
||||
|
||||
class KubernetesConfigProvider(BaseFileProvider):
|
||||
""" Implementation of the config provider that reads and writes configuration
|
||||
data from a Kubernetes Secret. """
|
||||
def __init__(self, config_volume, yaml_filename, py_filename, api_host=None,
|
||||
service_account_token_path=None):
|
||||
super(KubernetesConfigProvider, self).__init__(config_volume, yaml_filename, py_filename)
|
||||
service_account_token_path = service_account_token_path or SERVICE_ACCOUNT_TOKEN_PATH
|
||||
api_host = api_host or KUBERNETES_API_HOST
|
||||
|
||||
# Load the service account token from the local store.
|
||||
if not os.path.exists(service_account_token_path):
|
||||
raise Exception('Cannot load Kubernetes service account token')
|
||||
|
||||
with open(service_account_token_path, 'r') as f:
|
||||
self._service_token = f.read()
|
||||
|
||||
self._api_host = api_host
|
||||
|
||||
@property
|
||||
def provider_id(self):
|
||||
return 'k8s'
|
||||
|
||||
def get_volume_path(self, directory, filename):
|
||||
# NOTE: Overridden to ensure we don't have subdirectories, which aren't supported
|
||||
# in Kubernetes secrets.
|
||||
return "_".join([directory.rstrip('/'), filename])
|
||||
|
||||
def volume_exists(self):
|
||||
secret = self._lookup_secret()
|
||||
return secret is not None
|
||||
|
||||
def volume_file_exists(self, relative_file_path):
|
||||
if '/' in relative_file_path:
|
||||
raise Exception('Expected path from get_volume_path, but found slashes')
|
||||
|
||||
# NOTE: Overridden because we don't have subdirectories, which aren't supported
|
||||
# in Kubernetes secrets.
|
||||
secret = self._lookup_secret()
|
||||
if not secret or not secret.get('data'):
|
||||
return False
|
||||
return relative_file_path in secret['data']
|
||||
|
||||
def list_volume_directory(self, path):
|
||||
# NOTE: Overridden because we don't have subdirectories, which aren't supported
|
||||
# in Kubernetes secrets.
|
||||
secret = self._lookup_secret()
|
||||
|
||||
if not secret:
|
||||
return []
|
||||
|
||||
paths = []
|
||||
for filename in secret.get('data', {}):
|
||||
if filename.startswith(path):
|
||||
paths.append(filename[len(path) + 1:])
|
||||
return paths
|
||||
|
||||
def save_config(self, config_obj):
|
||||
self._update_secret_file(self.yaml_filename, get_yaml(config_obj))
|
||||
|
||||
def remove_volume_file(self, relative_file_path):
|
||||
try:
|
||||
self._update_secret_file(relative_file_path, None)
|
||||
except IOError as ioe:
|
||||
raise CannotWriteConfigException(str(ioe))
|
||||
|
||||
def save_volume_file(self, flask_file, relative_file_path):
|
||||
# Write the file to a temp location.
|
||||
buf = StringIO()
|
||||
try:
|
||||
try:
|
||||
flask_file.save(buf)
|
||||
except IOError as ioe:
|
||||
raise CannotWriteConfigException(str(ioe))
|
||||
|
||||
self._update_secret_file(relative_file_path, buf.getvalue())
|
||||
finally:
|
||||
buf.close()
|
||||
|
||||
def _assert_success(self, response):
|
||||
if response.status_code != 200:
|
||||
logger.error('Kubernetes API call failed with response: %s => %s', response.status_code,
|
||||
response.text)
|
||||
raise CannotWriteConfigException('Kubernetes API call failed: %s' % response.text)
|
||||
|
||||
def _update_secret_file(self, relative_file_path, value=None):
|
||||
if '/' in relative_file_path:
|
||||
raise Exception('Expected path from get_volume_path, but found slashes')
|
||||
|
||||
# Check first that the namespace for Red Hat Quay exists. If it does not, report that
|
||||
# as an error, as it seems to be a common issue.
|
||||
namespace_url = 'namespaces/%s' % (QE_NAMESPACE)
|
||||
response = self._execute_k8s_api('GET', namespace_url)
|
||||
if response.status_code // 100 != 2:
|
||||
msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE
|
||||
raise CannotWriteConfigException(msg)
|
||||
|
||||
# Check if the secret exists. If not, then we create an empty secret and then update the file
|
||||
# inside.
|
||||
secret_url = 'namespaces/%s/secrets/%s' % (QE_NAMESPACE, QE_CONFIG_SECRET)
|
||||
secret = self._lookup_secret()
|
||||
if secret is None:
|
||||
self._assert_success(self._execute_k8s_api('POST', secret_url, {
|
||||
"kind": "Secret",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {
|
||||
"name": QE_CONFIG_SECRET
|
||||
},
|
||||
"data": {}
|
||||
}))
|
||||
|
||||
# Update the secret to reflect the file change.
|
||||
secret['data'] = secret.get('data', {})
|
||||
|
||||
if value is not None:
|
||||
secret['data'][relative_file_path] = base64.b64encode(value)
|
||||
else:
|
||||
secret['data'].pop(relative_file_path)
|
||||
|
||||
self._assert_success(self._execute_k8s_api('PUT', secret_url, secret))
|
||||
|
||||
# Wait until the local mounted copy of the secret has been updated, as
|
||||
# this is an eventual consistency operation, but the caller expects immediate
|
||||
# consistency.
|
||||
while True:
|
||||
matching_files = set()
|
||||
for secret_filename, encoded_value in secret['data'].iteritems():
|
||||
expected_value = base64.b64decode(encoded_value)
|
||||
try:
|
||||
with self.get_volume_file(secret_filename) as f:
|
||||
contents = f.read()
|
||||
|
||||
if contents == expected_value:
|
||||
matching_files.add(secret_filename)
|
||||
except IOError:
|
||||
continue
|
||||
|
||||
if matching_files == set(secret['data'].keys()):
|
||||
break
|
||||
|
||||
# Sleep for a second and then try again.
|
||||
time.sleep(1)
|
||||
|
||||
def _lookup_secret(self):
|
||||
secret_url = 'namespaces/%s/secrets/%s' % (QE_NAMESPACE, QE_CONFIG_SECRET)
|
||||
response = self._execute_k8s_api('GET', secret_url)
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
return json.loads(response.text)
|
||||
|
||||
def _execute_k8s_api(self, method, relative_url, data=None):
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + self._service_token
|
||||
}
|
||||
|
||||
if data:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
|
||||
data = json.dumps(data) if data else None
|
||||
session = Session()
|
||||
url = 'https://%s/api/v1/%s' % (self._api_host, relative_url)
|
||||
|
||||
request = Request(method, url, data=data, headers=headers)
|
||||
return session.send(request.prepare(), verify=False, timeout=2)
|
29
util/config/provider/test/test_fileprovider.py
Normal file
29
util/config/provider/test/test_fileprovider.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import pytest
|
||||
|
||||
from util.config.provider import FileConfigProvider
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
|
||||
class TestFileConfigProvider(FileConfigProvider):
|
||||
def __init__(self):
|
||||
self.yaml_filename = 'yaml_filename'
|
||||
self._service_token = 'service_token'
|
||||
self.config_volume = 'config_volume'
|
||||
self.py_filename = 'py_filename'
|
||||
self.yaml_path = os.path.join(self.config_volume, self.yaml_filename)
|
||||
self.py_path = os.path.join(self.config_volume, self.py_filename)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('directory,filename,expected', [
|
||||
("directory", "file", "directory/file"),
|
||||
("directory/dir", "file", "directory/dir/file"),
|
||||
("directory/dir/", "file", "directory/dir/file"),
|
||||
("directory", "file/test", "directory/file/test"),
|
||||
])
|
||||
def test_get_volume_path(directory, filename, expected):
|
||||
provider = TestFileConfigProvider()
|
||||
|
||||
assert expected == provider.get_volume_path(directory, filename)
|
||||
|
||||
|
138
util/config/provider/test/test_k8sprovider.py
Normal file
138
util/config/provider/test/test_k8sprovider.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
import base64
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
from util.config.provider import KubernetesConfigProvider
|
||||
|
||||
def normalize_path(path):
|
||||
return path.replace('/', '_')
|
||||
|
||||
@contextmanager
|
||||
def fake_kubernetes_api(tmpdir_factory, files=None):
|
||||
hostname = 'kubapi'
|
||||
service_account_token_path = str(tmpdir_factory.mktemp("k8s").join("serviceaccount"))
|
||||
auth_header = str(uuid.uuid4())
|
||||
|
||||
with open(service_account_token_path, 'w') as f:
|
||||
f.write(auth_header)
|
||||
|
||||
global secret
|
||||
secret = {
|
||||
'data': {}
|
||||
}
|
||||
|
||||
def write_file(config_dir, filepath, value):
|
||||
normalized_path = normalize_path(filepath)
|
||||
absolute_path = str(config_dir.join(normalized_path))
|
||||
try:
|
||||
os.makedirs(os.path.dirname(absolute_path))
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
with open(absolute_path, 'w') as f:
|
||||
f.write(value)
|
||||
|
||||
config_dir = tmpdir_factory.mktemp("config")
|
||||
if files:
|
||||
for filepath, value in files.iteritems():
|
||||
normalized_path = normalize_path(filepath)
|
||||
write_file(config_dir, filepath, value)
|
||||
secret['data'][normalized_path] = base64.b64encode(value)
|
||||
|
||||
@urlmatch(netloc=hostname,
|
||||
path='/api/v1/namespaces/quay-enterprise/secrets/quay-enterprise-config-secret$',
|
||||
method='get')
|
||||
def get_secret(_, __):
|
||||
return {'status_code': 200, 'content': json.dumps(secret)}
|
||||
|
||||
@urlmatch(netloc=hostname,
|
||||
path='/api/v1/namespaces/quay-enterprise/secrets/quay-enterprise-config-secret$',
|
||||
method='put')
|
||||
def put_secret(_, request):
|
||||
updated_secret = json.loads(request.body)
|
||||
for filepath, value in updated_secret['data'].iteritems():
|
||||
if filepath not in secret['data']:
|
||||
# Add
|
||||
write_file(config_dir, filepath, base64.b64decode(value))
|
||||
|
||||
for filepath in secret['data']:
|
||||
if filepath not in updated_secret['data']:
|
||||
# Remove.
|
||||
normalized_path = normalize_path(filepath)
|
||||
os.remove(str(config_dir.join(normalized_path)))
|
||||
|
||||
secret['data'] = updated_secret['data']
|
||||
return {'status_code': 200, 'content': json.dumps(secret)}
|
||||
|
||||
@urlmatch(netloc=hostname, path='/api/v1/namespaces/quay-enterprise$')
|
||||
def get_namespace(_, __):
|
||||
return {'status_code': 200, 'content': json.dumps({})}
|
||||
|
||||
@urlmatch(netloc=hostname)
|
||||
def catch_all(url, _):
|
||||
print url
|
||||
return {'status_code': 404, 'content': '{}'}
|
||||
|
||||
with HTTMock(get_secret, put_secret, get_namespace, catch_all):
|
||||
provider = KubernetesConfigProvider(str(config_dir), 'config.yaml', 'config.py',
|
||||
api_host=hostname,
|
||||
service_account_token_path=service_account_token_path)
|
||||
|
||||
# Validate all the files.
|
||||
for filepath, value in files.iteritems():
|
||||
normalized_path = normalize_path(filepath)
|
||||
assert provider.volume_file_exists(normalized_path)
|
||||
with provider.get_volume_file(normalized_path) as f:
|
||||
assert f.read() == value
|
||||
|
||||
yield provider
|
||||
|
||||
|
||||
def test_basic_config(tmpdir_factory):
|
||||
basic_files = {
|
||||
'config.yaml': 'FOO: bar',
|
||||
}
|
||||
|
||||
with fake_kubernetes_api(tmpdir_factory, files=basic_files) as provider:
|
||||
assert provider.config_exists()
|
||||
assert provider.get_config() is not None
|
||||
assert provider.get_config()['FOO'] == 'bar'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('filepath', [
|
||||
'foo',
|
||||
'foo/meh',
|
||||
'foo/bar/baz',
|
||||
])
|
||||
def test_remove_file(filepath, tmpdir_factory):
|
||||
basic_files = {
|
||||
filepath: 'foo',
|
||||
}
|
||||
|
||||
with fake_kubernetes_api(tmpdir_factory, files=basic_files) as provider:
|
||||
normalized_path = normalize_path(filepath)
|
||||
assert provider.volume_file_exists(normalized_path)
|
||||
provider.remove_volume_file(normalized_path)
|
||||
assert not provider.volume_file_exists(normalized_path)
|
||||
|
||||
|
||||
class TestFlaskFile(object):
|
||||
def save(self, buf):
|
||||
buf.write('hello world!')
|
||||
|
||||
|
||||
def test_save_file(tmpdir_factory):
|
||||
basic_files = {}
|
||||
|
||||
with fake_kubernetes_api(tmpdir_factory, files=basic_files) as provider:
|
||||
assert not provider.volume_file_exists('testfile')
|
||||
flask_file = TestFlaskFile()
|
||||
provider.save_volume_file(flask_file, 'testfile')
|
||||
assert provider.volume_file_exists('testfile')
|
77
util/config/provider/testprovider.py
Normal file
77
util/config/provider/testprovider.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
import json
|
||||
import io
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from util.config.provider.baseprovider import BaseProvider
|
||||
|
||||
REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg', 'test/data/test.pem']
|
||||
|
||||
class TestConfigProvider(BaseProvider):
|
||||
""" Implementation of the config provider for testing. Everything is kept in-memory instead on
|
||||
the real file system. """
|
||||
|
||||
def get_config_root(self):
|
||||
raise Exception('Test Config does not have a config root')
|
||||
|
||||
def __init__(self):
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self.files = {}
|
||||
self._config = {}
|
||||
|
||||
@property
|
||||
def provider_id(self):
|
||||
return 'test'
|
||||
|
||||
def update_app_config(self, app_config):
|
||||
self._config = app_config
|
||||
|
||||
def get_config(self):
|
||||
if not 'config.yaml' in self.files:
|
||||
return None
|
||||
|
||||
return json.loads(self.files.get('config.yaml', '{}'))
|
||||
|
||||
def save_config(self, config_obj):
|
||||
self.files['config.yaml'] = json.dumps(config_obj)
|
||||
|
||||
def config_exists(self):
|
||||
return 'config.yaml' in self.files
|
||||
|
||||
def volume_exists(self):
|
||||
return True
|
||||
|
||||
def volume_file_exists(self, filename):
|
||||
if filename in REAL_FILES:
|
||||
return True
|
||||
|
||||
return filename in self.files
|
||||
|
||||
def save_volume_file(self, flask_file, filename):
|
||||
self.files[filename] = flask_file.read()
|
||||
|
||||
def get_volume_file(self, filename, mode='r'):
|
||||
if filename in REAL_FILES:
|
||||
return open(filename, mode=mode)
|
||||
|
||||
return io.BytesIO(self.files[filename])
|
||||
|
||||
def remove_volume_file(self, filename):
|
||||
self.files.pop(filename, None)
|
||||
|
||||
def list_volume_directory(self, path):
|
||||
paths = []
|
||||
for filename in self.files:
|
||||
if filename.startswith(path):
|
||||
paths.append(filename[len(path)+1:])
|
||||
|
||||
return paths
|
||||
|
||||
def reset_for_test(self):
|
||||
self._config['SUPER_USERS'] = ['devtable']
|
||||
self.files = {}
|
||||
|
||||
def get_volume_path(self, directory, filename):
|
||||
return os.path.join(directory, filename)
|
1232
util/config/schema.py
Normal file
1232
util/config/schema.py
Normal file
File diff suppressed because it is too large
Load diff
38
util/config/superusermanager.py
Normal file
38
util/config/superusermanager.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
from multiprocessing.sharedctypes import Array
|
||||
from util.validation import MAX_USERNAME_LENGTH
|
||||
|
||||
class SuperUserManager(object):
|
||||
""" In-memory helper class for quickly accessing (and updating) the valid
|
||||
set of super users. This class communicates across processes to ensure
|
||||
that the shared set is always the same.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
usernames = app.config.get('SUPER_USERS', [])
|
||||
usernames_str = ','.join(usernames)
|
||||
|
||||
self._max_length = len(usernames_str) + MAX_USERNAME_LENGTH + 1
|
||||
self._array = Array('c', self._max_length, lock=True)
|
||||
self._array.value = usernames_str
|
||||
|
||||
def is_superuser(self, username):
|
||||
""" Returns if the given username represents a super user. """
|
||||
usernames = self._array.value.split(',')
|
||||
return username in usernames
|
||||
|
||||
def register_superuser(self, username):
|
||||
""" Registers a new username as a super user for the duration of the container.
|
||||
Note that this does *not* change any underlying config files.
|
||||
"""
|
||||
usernames = self._array.value.split(',')
|
||||
usernames.append(username)
|
||||
new_string = ','.join(usernames)
|
||||
|
||||
if len(new_string) <= self._max_length:
|
||||
self._array.value = new_string
|
||||
else:
|
||||
raise Exception('Maximum superuser count reached. Please report this to support.')
|
||||
|
||||
def has_superusers(self):
|
||||
""" Returns whether there are any superusers defined. """
|
||||
return bool(self._array.value)
|
7
util/config/test/test_schema.py
Normal file
7
util/config/test/test_schema.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
from config import DefaultConfig
|
||||
from util.config.schema import CONFIG_SCHEMA, INTERNAL_ONLY_PROPERTIES
|
||||
|
||||
def test_ensure_schema_defines_all_fields():
|
||||
for key in vars(DefaultConfig):
|
||||
has_key = key in CONFIG_SCHEMA['properties'] or key in INTERNAL_ONLY_PROPERTIES
|
||||
assert has_key, "Property `%s` is missing from config schema" % key
|
32
util/config/test/test_validator.py
Normal file
32
util/config/test/test_validator.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import is_valid_config_upload_filename
|
||||
from util.config.validator import CONFIG_FILENAMES, CONFIG_FILE_SUFFIXES
|
||||
|
||||
def test_valid_config_upload_filenames():
|
||||
for filename in CONFIG_FILENAMES:
|
||||
assert is_valid_config_upload_filename(filename)
|
||||
|
||||
for suffix in CONFIG_FILE_SUFFIXES:
|
||||
assert is_valid_config_upload_filename('foo' + suffix)
|
||||
assert not is_valid_config_upload_filename(suffix + 'foo')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('filename, expect_valid', [
|
||||
('', False),
|
||||
('foo', False),
|
||||
('config.yaml', False),
|
||||
|
||||
('ssl.cert', True),
|
||||
('ssl.key', True),
|
||||
|
||||
('ssl.crt', False),
|
||||
|
||||
('foobar-cloudfront-signing-key.pem', True),
|
||||
('foobaz-cloudfront-signing-key.pem', True),
|
||||
('barbaz-cloudfront-signing-key.pem', True),
|
||||
|
||||
('barbaz-cloudfront-signing-key.pem.bak', False),
|
||||
])
|
||||
def test_is_valid_config_upload_filename(filename, expect_valid):
|
||||
assert is_valid_config_upload_filename(filename) == expect_valid
|
152
util/config/validator.py
Normal file
152
util/config/validator.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
import logging
|
||||
|
||||
from auth.auth_context import get_authenticated_user
|
||||
from data.users import LDAP_CERT_FILENAME
|
||||
from util.secscan.secscan_util import get_blob_download_uri_getter
|
||||
from util.config import URLSchemeAndHostname
|
||||
|
||||
from util.config.validators.validate_database import DatabaseValidator
|
||||
from util.config.validators.validate_redis import RedisValidator
|
||||
from util.config.validators.validate_storage import StorageValidator
|
||||
from util.config.validators.validate_ldap import LDAPValidator
|
||||
from util.config.validators.validate_keystone import KeystoneValidator
|
||||
from util.config.validators.validate_jwt import JWTAuthValidator
|
||||
from util.config.validators.validate_secscan import SecurityScannerValidator
|
||||
from util.config.validators.validate_signer import SignerValidator
|
||||
from util.config.validators.validate_torrent import BittorrentValidator
|
||||
from util.config.validators.validate_ssl import SSLValidator, SSL_FILENAMES
|
||||
from util.config.validators.validate_google_login import GoogleLoginValidator
|
||||
from util.config.validators.validate_bitbucket_trigger import BitbucketTriggerValidator
|
||||
from util.config.validators.validate_gitlab_trigger import GitLabTriggerValidator
|
||||
from util.config.validators.validate_github import GitHubLoginValidator, GitHubTriggerValidator
|
||||
from util.config.validators.validate_oidc import OIDCLoginValidator
|
||||
from util.config.validators.validate_timemachine import TimeMachineValidator
|
||||
from util.config.validators.validate_access import AccessSettingsValidator
|
||||
from util.config.validators.validate_actionlog_archiving import ActionLogArchivingValidator
|
||||
from util.config.validators.validate_apptokenauth import AppTokenAuthValidator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ConfigValidationException(Exception):
|
||||
""" Exception raised when the configuration fails to validate for a known reason. """
|
||||
pass
|
||||
|
||||
# Note: Only add files required for HTTPS to the SSL_FILESNAMES list.
|
||||
DB_SSL_FILENAMES = ['database.pem']
|
||||
JWT_FILENAMES = ['jwt-authn.cert']
|
||||
ACI_CERT_FILENAMES = ['signing-public.gpg', 'signing-private.gpg']
|
||||
LDAP_FILENAMES = [LDAP_CERT_FILENAME]
|
||||
CONFIG_FILENAMES = (SSL_FILENAMES + DB_SSL_FILENAMES + JWT_FILENAMES + ACI_CERT_FILENAMES +
|
||||
LDAP_FILENAMES)
|
||||
CONFIG_FILE_SUFFIXES = ['-cloudfront-signing-key.pem']
|
||||
EXTRA_CA_DIRECTORY = 'extra_ca_certs'
|
||||
EXTRA_CA_DIRECTORY_PREFIX = 'extra_ca_certs_'
|
||||
|
||||
VALIDATORS = {
|
||||
DatabaseValidator.name: DatabaseValidator.validate,
|
||||
RedisValidator.name: RedisValidator.validate,
|
||||
StorageValidator.name: StorageValidator.validate,
|
||||
GitHubLoginValidator.name: GitHubLoginValidator.validate,
|
||||
GitHubTriggerValidator.name: GitHubTriggerValidator.validate,
|
||||
GitLabTriggerValidator.name: GitLabTriggerValidator.validate,
|
||||
BitbucketTriggerValidator.name: BitbucketTriggerValidator.validate,
|
||||
GoogleLoginValidator.name: GoogleLoginValidator.validate,
|
||||
SSLValidator.name: SSLValidator.validate,
|
||||
LDAPValidator.name: LDAPValidator.validate,
|
||||
JWTAuthValidator.name: JWTAuthValidator.validate,
|
||||
KeystoneValidator.name: KeystoneValidator.validate,
|
||||
SignerValidator.name: SignerValidator.validate,
|
||||
SecurityScannerValidator.name: SecurityScannerValidator.validate,
|
||||
BittorrentValidator.name: BittorrentValidator.validate,
|
||||
OIDCLoginValidator.name: OIDCLoginValidator.validate,
|
||||
TimeMachineValidator.name: TimeMachineValidator.validate,
|
||||
AccessSettingsValidator.name: AccessSettingsValidator.validate,
|
||||
ActionLogArchivingValidator.name: ActionLogArchivingValidator.validate,
|
||||
AppTokenAuthValidator.name: AppTokenAuthValidator.validate,
|
||||
}
|
||||
|
||||
def validate_service_for_config(service, validator_context):
|
||||
""" Attempts to validate the configuration for the given service. """
|
||||
if not service in VALIDATORS:
|
||||
return {
|
||||
'status': False
|
||||
}
|
||||
|
||||
try:
|
||||
VALIDATORS[service](validator_context)
|
||||
return {
|
||||
'status': True
|
||||
}
|
||||
except Exception as ex:
|
||||
logger.exception('Validation exception')
|
||||
return {
|
||||
'status': False,
|
||||
'reason': str(ex)
|
||||
}
|
||||
|
||||
|
||||
def is_valid_config_upload_filename(filename):
|
||||
""" Returns true if and only if the given filename is one which is supported for upload
|
||||
from the configuration UI tool.
|
||||
"""
|
||||
if filename in CONFIG_FILENAMES:
|
||||
return True
|
||||
|
||||
return any([filename.endswith(suffix) for suffix in CONFIG_FILE_SUFFIXES])
|
||||
|
||||
|
||||
class ValidatorContext(object):
|
||||
""" Context to run validators in, with any additional runtime configuration they need
|
||||
"""
|
||||
def __init__(self, config, user_password=None, http_client=None, context=None,
|
||||
url_scheme_and_hostname=None, jwt_auth_max=None, registry_title=None,
|
||||
ip_resolver=None, feature_sec_scanner=False, is_testing=False,
|
||||
uri_creator=None, config_provider=None, instance_keys=None,
|
||||
init_scripts_location=None):
|
||||
self.config = config
|
||||
self.user = get_authenticated_user()
|
||||
self.user_password = user_password
|
||||
self.http_client = http_client
|
||||
self.context = context
|
||||
self.url_scheme_and_hostname = url_scheme_and_hostname
|
||||
self.jwt_auth_max = jwt_auth_max
|
||||
self.registry_title = registry_title
|
||||
self.ip_resolver = ip_resolver
|
||||
self.feature_sec_scanner = feature_sec_scanner
|
||||
self.is_testing = is_testing
|
||||
self.uri_creator = uri_creator
|
||||
self.config_provider = config_provider
|
||||
self.instance_keys = instance_keys
|
||||
self.init_scripts_location = init_scripts_location
|
||||
|
||||
@classmethod
|
||||
def from_app(cls, app, config, user_password, ip_resolver, instance_keys, client=None,
|
||||
config_provider=None, init_scripts_location=None):
|
||||
"""
|
||||
Creates a ValidatorContext from an app config, with a given config to validate
|
||||
:param app: the Flask app to pull configuration information from
|
||||
:param config: the config to validate
|
||||
:param user_password: request password
|
||||
:param instance_keys: The instance keys handler
|
||||
:param ip_resolver: an App
|
||||
:param client: http client used to connect to services
|
||||
:param config_provider: config provider used to access config volume(s)
|
||||
:param init_scripts_location: location where initial load scripts are stored
|
||||
:return: ValidatorContext
|
||||
"""
|
||||
url_scheme_and_hostname = URLSchemeAndHostname.from_app_config(app.config)
|
||||
|
||||
return cls(config,
|
||||
user_password=user_password,
|
||||
http_client=client or app.config['HTTPCLIENT'],
|
||||
context=app.app_context,
|
||||
url_scheme_and_hostname=url_scheme_and_hostname,
|
||||
jwt_auth_max=app.config.get('JWT_AUTH_MAX_FRESH_S', 300),
|
||||
registry_title=app.config['REGISTRY_TITLE'],
|
||||
ip_resolver=ip_resolver,
|
||||
feature_sec_scanner=app.config.get('FEATURE_SECURITY_SCANNER', False),
|
||||
is_testing=app.config.get('TESTING', False),
|
||||
uri_creator=get_blob_download_uri_getter(app.test_request_context('/'), url_scheme_and_hostname),
|
||||
config_provider=config_provider,
|
||||
instance_keys=instance_keys,
|
||||
init_scripts_location=init_scripts_location)
|
20
util/config/validators/__init__.py
Normal file
20
util/config/validators/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
from abc import ABCMeta, abstractmethod, abstractproperty
|
||||
from six import add_metaclass
|
||||
|
||||
class ConfigValidationException(Exception):
|
||||
""" Exception raised when the configuration fails to validate for a known reason. """
|
||||
pass
|
||||
|
||||
|
||||
@add_metaclass(ABCMeta)
|
||||
class BaseValidator(object):
|
||||
@abstractproperty
|
||||
def name(self):
|
||||
""" The key for the validation API. """
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Raises Exception if failure to validate. """
|
||||
pass
|
29
util/config/validators/test/test_validate_access.py
Normal file
29
util/config/validators/test/test_validate_access.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_access import AccessSettingsValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config, expected_exception', [
|
||||
({}, None),
|
||||
({'FEATURE_DIRECT_LOGIN': False}, ConfigValidationException),
|
||||
({'FEATURE_DIRECT_LOGIN': False, 'SOMETHING_LOGIN_CONFIG': {}}, None),
|
||||
({'FEATURE_DIRECT_LOGIN': False, 'FEATURE_GITHUB_LOGIN': True}, None),
|
||||
({'FEATURE_DIRECT_LOGIN': False, 'FEATURE_GOOGLE_LOGIN': True}, None),
|
||||
({'FEATURE_USER_CREATION': True, 'FEATURE_INVITE_ONLY_USER_CREATION': False}, None),
|
||||
({'FEATURE_USER_CREATION': True, 'FEATURE_INVITE_ONLY_USER_CREATION': True}, None),
|
||||
({'FEATURE_INVITE_ONLY_USER_CREATION': True}, None),
|
||||
({'FEATURE_USER_CREATION': False, 'FEATURE_INVITE_ONLY_USER_CREATION': True},
|
||||
ConfigValidationException),
|
||||
({'FEATURE_USER_CREATION': False, 'FEATURE_INVITE_ONLY_USER_CREATION': False}, None),
|
||||
])
|
||||
def test_validate_invalid_oidc_login_config(unvalidated_config, expected_exception, app):
|
||||
validator = AccessSettingsValidator()
|
||||
|
||||
if expected_exception is not None:
|
||||
with pytest.raises(expected_exception):
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
else:
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
|
@ -0,0 +1,52 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_actionlog_archiving import ActionLogArchivingValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({}),
|
||||
({'ACTION_LOG_ARCHIVE_PATH': 'foo'}),
|
||||
({'ACTION_LOG_ARCHIVE_LOCATION': ''}),
|
||||
])
|
||||
def test_skip_validate_actionlog(unvalidated_config, app):
|
||||
validator = ActionLogArchivingValidator()
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('config, expected_error', [
|
||||
({'FEATURE_ACTION_LOG_ROTATION': True}, 'Missing action log archive path'),
|
||||
({'FEATURE_ACTION_LOG_ROTATION': True,
|
||||
'ACTION_LOG_ARCHIVE_PATH': ''}, 'Missing action log archive path'),
|
||||
({'FEATURE_ACTION_LOG_ROTATION': True,
|
||||
'ACTION_LOG_ARCHIVE_PATH': 'foo'}, 'Missing action log archive storage location'),
|
||||
({'FEATURE_ACTION_LOG_ROTATION': True,
|
||||
'ACTION_LOG_ARCHIVE_PATH': 'foo',
|
||||
'ACTION_LOG_ARCHIVE_LOCATION': ''}, 'Missing action log archive storage location'),
|
||||
({'FEATURE_ACTION_LOG_ROTATION': True,
|
||||
'ACTION_LOG_ARCHIVE_PATH': 'foo',
|
||||
'ACTION_LOG_ARCHIVE_LOCATION': 'invalid'},
|
||||
'Action log archive storage location `invalid` not found in storage config'),
|
||||
])
|
||||
def test_invalid_config(config, expected_error, app):
|
||||
validator = ActionLogArchivingValidator()
|
||||
|
||||
with pytest.raises(ConfigValidationException) as ipe:
|
||||
validator.validate(ValidatorContext(config))
|
||||
|
||||
assert str(ipe.value) == expected_error
|
||||
|
||||
def test_valid_config(app):
|
||||
config = ValidatorContext({
|
||||
'FEATURE_ACTION_LOG_ROTATION': True,
|
||||
'ACTION_LOG_ARCHIVE_PATH': 'somepath',
|
||||
'ACTION_LOG_ARCHIVE_LOCATION': 'somelocation',
|
||||
'DISTRIBUTED_STORAGE_CONFIG': {
|
||||
'somelocation': {},
|
||||
},
|
||||
})
|
||||
|
||||
validator = ActionLogArchivingValidator()
|
||||
validator.validate(config)
|
30
util/config/validators/test/test_validate_apptokenauth.py
Normal file
30
util/config/validators/test/test_validate_apptokenauth.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_apptokenauth import AppTokenAuthValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({'AUTHENTICATION_TYPE': 'AppToken'}),
|
||||
({'AUTHENTICATION_TYPE': 'AppToken', 'FEATURE_APP_SPECIFIC_TOKENS': False}),
|
||||
({'AUTHENTICATION_TYPE': 'AppToken', 'FEATURE_APP_SPECIFIC_TOKENS': True,
|
||||
'FEATURE_DIRECT_LOGIN': True}),
|
||||
])
|
||||
def test_validate_invalid_auth_config(unvalidated_config, app):
|
||||
validator = AppTokenAuthValidator()
|
||||
|
||||
with pytest.raises(ConfigValidationException):
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
|
||||
def test_validate_auth(app):
|
||||
config = ValidatorContext({
|
||||
'AUTHENTICATION_TYPE': 'AppToken',
|
||||
'FEATURE_APP_SPECIFIC_TOKENS': True,
|
||||
'FEATURE_DIRECT_LOGIN': False,
|
||||
})
|
||||
|
||||
validator = AppTokenAuthValidator()
|
||||
validator.validate(config)
|
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
from util.config import URLSchemeAndHostname
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_bitbucket_trigger import BitbucketTriggerValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
(ValidatorContext({})),
|
||||
(ValidatorContext({'BITBUCKET_TRIGGER_CONFIG': {}})),
|
||||
(ValidatorContext({'BITBUCKET_TRIGGER_CONFIG': {'CONSUMER_KEY': 'foo'}})),
|
||||
(ValidatorContext({'BITBUCKET_TRIGGER_CONFIG': {'CONSUMER_SECRET': 'foo'}})),
|
||||
])
|
||||
def test_validate_invalid_bitbucket_trigger_config(unvalidated_config, app):
|
||||
validator = BitbucketTriggerValidator()
|
||||
|
||||
with pytest.raises(ConfigValidationException):
|
||||
validator.validate(unvalidated_config)
|
||||
|
||||
def test_validate_bitbucket_trigger(app):
|
||||
url_hit = [False]
|
||||
|
||||
@urlmatch(netloc=r'bitbucket.org')
|
||||
def handler(url, request):
|
||||
url_hit[0] = True
|
||||
return {
|
||||
'status_code': 200,
|
||||
'content': 'oauth_token=foo&oauth_token_secret=bar',
|
||||
}
|
||||
|
||||
with HTTMock(handler):
|
||||
validator = BitbucketTriggerValidator()
|
||||
|
||||
url_scheme_and_hostname = URLSchemeAndHostname('http', 'localhost:5000')
|
||||
unvalidated_config = ValidatorContext({
|
||||
'BITBUCKET_TRIGGER_CONFIG': {
|
||||
'CONSUMER_KEY': 'foo',
|
||||
'CONSUMER_SECRET': 'bar',
|
||||
},
|
||||
}, url_scheme_and_hostname=url_scheme_and_hostname)
|
||||
|
||||
validator.validate(unvalidated_config)
|
||||
|
||||
assert url_hit[0]
|
23
util/config/validators/test/test_validate_database.py
Normal file
23
util/config/validators/test/test_validate_database.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_database import DatabaseValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config,user,user_password,expected', [
|
||||
(ValidatorContext(None), None, None, TypeError),
|
||||
(ValidatorContext({}), None, None, KeyError),
|
||||
(ValidatorContext({'DB_URI': 'sqlite:///:memory:'}), None, None, None),
|
||||
(ValidatorContext({'DB_URI': 'invalid:///:memory:'}), None, None, KeyError),
|
||||
(ValidatorContext({'DB_NOTURI': 'sqlite:///:memory:'}), None, None, KeyError),
|
||||
])
|
||||
def test_validate_database(unvalidated_config, user, user_password, expected, app):
|
||||
validator = DatabaseValidator()
|
||||
|
||||
if expected is not None:
|
||||
with pytest.raises(expected):
|
||||
validator.validate(unvalidated_config)
|
||||
else:
|
||||
validator.validate(unvalidated_config)
|
69
util/config/validators/test/test_validate_github.py
Normal file
69
util/config/validators/test/test_validate_github.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
import pytest
|
||||
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
from config import build_requests_session
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_github import GitHubLoginValidator, GitHubTriggerValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.fixture(params=[GitHubLoginValidator, GitHubTriggerValidator])
|
||||
def github_validator(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.mark.parametrize('github_config', [
|
||||
({}),
|
||||
({'GITHUB_ENDPOINT': 'foo'}),
|
||||
({'GITHUB_ENDPOINT': 'http://github.com'}),
|
||||
({'GITHUB_ENDPOINT': 'http://github.com', 'CLIENT_ID': 'foo'}),
|
||||
({'GITHUB_ENDPOINT': 'http://github.com', 'CLIENT_SECRET': 'foo'}),
|
||||
({
|
||||
'GITHUB_ENDPOINT': 'http://github.com',
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'foo',
|
||||
'ORG_RESTRICT': True
|
||||
}),
|
||||
({
|
||||
'GITHUB_ENDPOINT': 'http://github.com',
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'foo',
|
||||
'ORG_RESTRICT': True,
|
||||
'ALLOWED_ORGANIZATIONS': [],
|
||||
}),
|
||||
])
|
||||
def test_validate_invalid_github_config(github_config, github_validator, app):
|
||||
with pytest.raises(ConfigValidationException):
|
||||
unvalidated_config = {}
|
||||
unvalidated_config[github_validator.config_key] = github_config
|
||||
github_validator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
def test_validate_github(github_validator, app):
|
||||
url_hit = [False, False]
|
||||
|
||||
@urlmatch(netloc=r'somehost')
|
||||
def handler(url, request):
|
||||
url_hit[0] = True
|
||||
return {'status_code': 200, 'content': '', 'headers': {'X-GitHub-Request-Id': 'foo'}}
|
||||
|
||||
@urlmatch(netloc=r'somehost', path=r'/api/v3/applications/foo/tokens/foo')
|
||||
def app_handler(url, request):
|
||||
url_hit[1] = True
|
||||
return {'status_code': 404, 'content': '', 'headers': {'X-GitHub-Request-Id': 'foo'}}
|
||||
|
||||
with HTTMock(app_handler, handler):
|
||||
unvalidated_config = ValidatorContext({
|
||||
github_validator.config_key: {
|
||||
'GITHUB_ENDPOINT': 'http://somehost',
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
},
|
||||
})
|
||||
|
||||
unvalidated_config.http_client = build_requests_session()
|
||||
github_validator.validate(unvalidated_config)
|
||||
|
||||
assert url_hit[0]
|
||||
assert url_hit[1]
|
49
util/config/validators/test/test_validate_gitlab_trigger.py
Normal file
49
util/config/validators/test/test_validate_gitlab_trigger.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
import json
|
||||
import pytest
|
||||
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
from config import build_requests_session
|
||||
from util.config import URLSchemeAndHostname
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_gitlab_trigger import GitLabTriggerValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({}),
|
||||
({'GITLAB_TRIGGER_CONFIG': {'GITLAB_ENDPOINT': 'foo'}}),
|
||||
({'GITLAB_TRIGGER_CONFIG': {'GITLAB_ENDPOINT': 'http://someendpoint', 'CLIENT_ID': 'foo'}}),
|
||||
({'GITLAB_TRIGGER_CONFIG': {'GITLAB_ENDPOINT': 'http://someendpoint', 'CLIENT_SECRET': 'foo'}}),
|
||||
])
|
||||
def test_validate_invalid_gitlab_trigger_config(unvalidated_config, app):
|
||||
validator = GitLabTriggerValidator()
|
||||
|
||||
with pytest.raises(ConfigValidationException):
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
def test_validate_gitlab_enterprise_trigger(app):
|
||||
url_hit = [False]
|
||||
|
||||
@urlmatch(netloc=r'somegitlab', path='/oauth/token')
|
||||
def handler(_, __):
|
||||
url_hit[0] = True
|
||||
return {'status_code': 400, 'content': json.dumps({'error': 'invalid code'})}
|
||||
|
||||
with HTTMock(handler):
|
||||
validator = GitLabTriggerValidator()
|
||||
|
||||
url_scheme_and_hostname = URLSchemeAndHostname('http', 'localhost:5000')
|
||||
|
||||
unvalidated_config = ValidatorContext({
|
||||
'GITLAB_TRIGGER_CONFIG': {
|
||||
'GITLAB_ENDPOINT': 'http://somegitlab',
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
},
|
||||
}, http_client=build_requests_session(), url_scheme_and_hostname=url_scheme_and_hostname)
|
||||
|
||||
validator.validate(unvalidated_config)
|
||||
|
||||
assert url_hit[0]
|
45
util/config/validators/test/test_validate_google_login.py
Normal file
45
util/config/validators/test/test_validate_google_login.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
from config import build_requests_session
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_google_login import GoogleLoginValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({}),
|
||||
({'GOOGLE_LOGIN_CONFIG': {}}),
|
||||
({'GOOGLE_LOGIN_CONFIG': {'CLIENT_ID': 'foo'}}),
|
||||
({'GOOGLE_LOGIN_CONFIG': {'CLIENT_SECRET': 'foo'}}),
|
||||
])
|
||||
def test_validate_invalid_google_login_config(unvalidated_config, app):
|
||||
validator = GoogleLoginValidator()
|
||||
|
||||
with pytest.raises(ConfigValidationException):
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
def test_validate_google_login(app):
|
||||
url_hit = [False]
|
||||
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v3/token')
|
||||
def handler(_, __):
|
||||
url_hit[0] = True
|
||||
return {'status_code': 200, 'content': ''}
|
||||
|
||||
validator = GoogleLoginValidator()
|
||||
|
||||
with HTTMock(handler):
|
||||
unvalidated_config = ValidatorContext({
|
||||
'GOOGLE_LOGIN_CONFIG': {
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
},
|
||||
})
|
||||
|
||||
unvalidated_config.http_client = build_requests_session()
|
||||
|
||||
validator.validate(unvalidated_config)
|
||||
|
||||
assert url_hit[0]
|
65
util/config/validators/test/test_validate_jwt.py
Normal file
65
util/config/validators/test/test_validate_jwt.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
import pytest
|
||||
|
||||
from config import build_requests_session
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_jwt import JWTAuthValidator
|
||||
from util.morecollections import AttrDict
|
||||
|
||||
from test.test_external_jwt_authn import fake_jwt
|
||||
|
||||
from test.fixtures import *
|
||||
from app import config_provider
|
||||
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({}),
|
||||
({'AUTHENTICATION_TYPE': 'Database'}),
|
||||
])
|
||||
def test_validate_noop(unvalidated_config, app):
|
||||
config = ValidatorContext(unvalidated_config)
|
||||
config.config_provider = config_provider
|
||||
JWTAuthValidator.validate(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({'AUTHENTICATION_TYPE': 'JWT'}),
|
||||
({'AUTHENTICATION_TYPE': 'JWT', 'JWT_AUTH_ISSUER': 'foo'}),
|
||||
({'AUTHENTICATION_TYPE': 'JWT', 'JWT_VERIFY_ENDPOINT': 'foo'}),
|
||||
])
|
||||
def test_invalid_config(unvalidated_config, app):
|
||||
with pytest.raises(ConfigValidationException):
|
||||
config = ValidatorContext(unvalidated_config)
|
||||
config.config_provider = config_provider
|
||||
JWTAuthValidator.validate(config)
|
||||
|
||||
|
||||
# TODO: fix these when re-adding jwt auth mechanism to jwt validators
|
||||
@pytest.mark.skip(reason='No way of currently testing this')
|
||||
@pytest.mark.parametrize('username, password, expected_exception', [
|
||||
('invaliduser', 'invalidpass', ConfigValidationException),
|
||||
('cool.user', 'invalidpass', ConfigValidationException),
|
||||
('invaliduser', 'somepass', ConfigValidationException),
|
||||
('cool.user', 'password', None),
|
||||
])
|
||||
def test_validated_jwt(username, password, expected_exception, app):
|
||||
with fake_jwt() as jwt_auth:
|
||||
config = {}
|
||||
config['AUTHENTICATION_TYPE'] = 'JWT'
|
||||
config['JWT_AUTH_ISSUER'] = jwt_auth.issuer
|
||||
config['JWT_VERIFY_ENDPOINT'] = jwt_auth.verify_url
|
||||
config['JWT_QUERY_ENDPOINT'] = jwt_auth.query_url
|
||||
config['JWT_GETUSER_ENDPOINT'] = jwt_auth.getuser_url
|
||||
|
||||
unvalidated_config = ValidatorContext(config)
|
||||
unvalidated_config.user = AttrDict(dict(username=username))
|
||||
unvalidated_config.user_password = password
|
||||
unvalidated_config.config_provider = config_provider
|
||||
|
||||
unvalidated_config.http_client = build_requests_session()
|
||||
|
||||
if expected_exception is not None:
|
||||
with pytest.raises(ConfigValidationException):
|
||||
JWTAuthValidator.validate(unvalidated_config, public_key_path=jwt_auth.public_key_path)
|
||||
else:
|
||||
JWTAuthValidator.validate(unvalidated_config, public_key_path=jwt_auth.public_key_path)
|
54
util/config/validators/test/test_validate_keystone.py
Normal file
54
util/config/validators/test/test_validate_keystone.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_keystone import KeystoneValidator
|
||||
|
||||
from test.test_keystone_auth import fake_keystone
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({}),
|
||||
({'AUTHENTICATION_TYPE': 'Database'}),
|
||||
])
|
||||
def test_validate_noop(unvalidated_config, app):
|
||||
KeystoneValidator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({'AUTHENTICATION_TYPE': 'Keystone'}),
|
||||
({'AUTHENTICATION_TYPE': 'Keystone', 'KEYSTONE_AUTH_URL': 'foo'}),
|
||||
({'AUTHENTICATION_TYPE': 'Keystone', 'KEYSTONE_AUTH_URL': 'foo',
|
||||
'KEYSTONE_ADMIN_USERNAME': 'bar'}),
|
||||
({'AUTHENTICATION_TYPE': 'Keystone', 'KEYSTONE_AUTH_URL': 'foo',
|
||||
'KEYSTONE_ADMIN_USERNAME': 'bar', 'KEYSTONE_ADMIN_PASSWORD': 'baz'}),
|
||||
])
|
||||
def test_invalid_config(unvalidated_config, app):
|
||||
with pytest.raises(ConfigValidationException):
|
||||
KeystoneValidator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('admin_tenant_id, expected_exception', [
|
||||
('somegroupid', None),
|
||||
('groupwithnousers', ConfigValidationException),
|
||||
('somegroupid', None),
|
||||
('groupwithnousers', ConfigValidationException),
|
||||
])
|
||||
def test_validated_keystone(admin_tenant_id, expected_exception, app):
|
||||
with fake_keystone(2) as keystone_auth:
|
||||
auth_url = keystone_auth.auth_url
|
||||
|
||||
config = {}
|
||||
config['AUTHENTICATION_TYPE'] = 'Keystone'
|
||||
config['KEYSTONE_AUTH_URL'] = auth_url
|
||||
config['KEYSTONE_ADMIN_USERNAME'] = 'adminuser'
|
||||
config['KEYSTONE_ADMIN_PASSWORD'] = 'adminpass'
|
||||
config['KEYSTONE_ADMIN_TENANT'] = admin_tenant_id
|
||||
|
||||
unvalidated_config = ValidatorContext(config)
|
||||
|
||||
if expected_exception is not None:
|
||||
with pytest.raises(ConfigValidationException):
|
||||
KeystoneValidator.validate(unvalidated_config)
|
||||
else:
|
||||
KeystoneValidator.validate(unvalidated_config)
|
72
util/config/validators/test/test_validate_ldap.py
Normal file
72
util/config/validators/test/test_validate_ldap.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_ldap import LDAPValidator
|
||||
from util.morecollections import AttrDict
|
||||
|
||||
from test.test_ldap import mock_ldap
|
||||
|
||||
from test.fixtures import *
|
||||
from app import config_provider
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({}),
|
||||
({'AUTHENTICATION_TYPE': 'Database'}),
|
||||
])
|
||||
def test_validate_noop(unvalidated_config, app):
|
||||
config = ValidatorContext(unvalidated_config, config_provider=config_provider)
|
||||
LDAPValidator.validate(config)
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({'AUTHENTICATION_TYPE': 'LDAP'}),
|
||||
({'AUTHENTICATION_TYPE': 'LDAP', 'LDAP_ADMIN_DN': 'foo'}),
|
||||
])
|
||||
def test_invalid_config(unvalidated_config, app):
|
||||
with pytest.raises(ConfigValidationException):
|
||||
config = ValidatorContext(unvalidated_config, config_provider=config_provider)
|
||||
LDAPValidator.validate(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('uri', [
|
||||
'foo',
|
||||
'http://foo',
|
||||
'ldap:foo',
|
||||
])
|
||||
def test_invalid_uri(uri, app):
|
||||
config = {}
|
||||
config['AUTHENTICATION_TYPE'] = 'LDAP'
|
||||
config['LDAP_BASE_DN'] = ['dc=quay', 'dc=io']
|
||||
config['LDAP_ADMIN_DN'] = 'uid=testy,ou=employees,dc=quay,dc=io'
|
||||
config['LDAP_ADMIN_PASSWD'] = 'password'
|
||||
config['LDAP_USER_RDN'] = ['ou=employees']
|
||||
config['LDAP_URI'] = uri
|
||||
|
||||
with pytest.raises(ConfigValidationException):
|
||||
config = ValidatorContext(config, config_provider=config_provider)
|
||||
LDAPValidator.validate(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('admin_dn, admin_passwd, user_rdn, expected_exception', [
|
||||
('uid=testy,ou=employees,dc=quay,dc=io', 'password', ['ou=employees'], None),
|
||||
('uid=invalidadmindn', 'password', ['ou=employees'], ConfigValidationException),
|
||||
('uid=testy,ou=employees,dc=quay,dc=io', 'invalid_password', ['ou=employees'], ConfigValidationException),
|
||||
('uid=testy,ou=employees,dc=quay,dc=io', 'password', ['ou=invalidgroup'], ConfigValidationException),
|
||||
])
|
||||
def test_validated_ldap(admin_dn, admin_passwd, user_rdn, expected_exception, app):
|
||||
config = {}
|
||||
config['AUTHENTICATION_TYPE'] = 'LDAP'
|
||||
config['LDAP_BASE_DN'] = ['dc=quay', 'dc=io']
|
||||
config['LDAP_ADMIN_DN'] = admin_dn
|
||||
config['LDAP_ADMIN_PASSWD'] = admin_passwd
|
||||
config['LDAP_USER_RDN'] = user_rdn
|
||||
|
||||
unvalidated_config = ValidatorContext(config, config_provider=config_provider)
|
||||
|
||||
if expected_exception is not None:
|
||||
with pytest.raises(ConfigValidationException):
|
||||
with mock_ldap():
|
||||
LDAPValidator.validate(unvalidated_config)
|
||||
else:
|
||||
with mock_ldap():
|
||||
LDAPValidator.validate(unvalidated_config)
|
50
util/config/validators/test/test_validate_oidc.py
Normal file
50
util/config/validators/test/test_validate_oidc.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import json
|
||||
import pytest
|
||||
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
from config import build_requests_session
|
||||
from oauth.oidc import OIDC_WELLKNOWN
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_oidc import OIDCLoginValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({'SOMETHING_LOGIN_CONFIG': {}}),
|
||||
({'SOMETHING_LOGIN_CONFIG': {'OIDC_SERVER': 'foo'}}),
|
||||
({'SOMETHING_LOGIN_CONFIG': {'OIDC_SERVER': 'foo', 'CLIENT_ID': 'foobar'}}),
|
||||
({'SOMETHING_LOGIN_CONFIG': {'OIDC_SERVER': 'foo', 'CLIENT_SECRET': 'foobar'}}),
|
||||
])
|
||||
def test_validate_invalid_oidc_login_config(unvalidated_config, app):
|
||||
validator = OIDCLoginValidator()
|
||||
|
||||
with pytest.raises(ConfigValidationException):
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
def test_validate_oidc_login(app):
|
||||
url_hit = [False]
|
||||
@urlmatch(netloc=r'someserver', path=r'/\.well-known/openid-configuration')
|
||||
def handler(_, __):
|
||||
url_hit[0] = True
|
||||
data = {
|
||||
'token_endpoint': 'foobar',
|
||||
}
|
||||
return {'status_code': 200, 'content': json.dumps(data)}
|
||||
|
||||
with HTTMock(handler):
|
||||
validator = OIDCLoginValidator()
|
||||
unvalidated_config = ValidatorContext({
|
||||
'SOMETHING_LOGIN_CONFIG': {
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
'OIDC_SERVER': 'http://someserver',
|
||||
'DEBUGGING': True, # Allows for HTTP.
|
||||
},
|
||||
})
|
||||
unvalidated_config.http_client = build_requests_session()
|
||||
|
||||
validator.validate(unvalidated_config)
|
||||
|
||||
assert url_hit[0]
|
34
util/config/validators/test/test_validate_redis.py
Normal file
34
util/config/validators/test/test_validate_redis.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
import redis
|
||||
|
||||
from mock import patch
|
||||
|
||||
from mockredis import mock_strict_redis_client
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_redis import RedisValidator
|
||||
|
||||
from test.fixtures import *
|
||||
from util.morecollections import AttrDict
|
||||
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config,user,user_password,use_mock,expected', [
|
||||
({}, None, None, False, ConfigValidationException),
|
||||
({'BUILDLOGS_REDIS': {}}, None, None, False, ConfigValidationException),
|
||||
({'BUILDLOGS_REDIS': {'host': 'somehost'}}, None, None, False, redis.ConnectionError),
|
||||
({'BUILDLOGS_REDIS': {'host': 'localhost'}}, None, None, True, None),
|
||||
])
|
||||
def test_validate_redis(unvalidated_config, user, user_password, use_mock, expected, app):
|
||||
with patch('redis.StrictRedis' if use_mock else 'redis.None', mock_strict_redis_client):
|
||||
validator = RedisValidator()
|
||||
unvalidated_config = ValidatorContext(unvalidated_config)
|
||||
|
||||
unvalidated_config.user = AttrDict(dict(username=user))
|
||||
unvalidated_config.user_password = user_password
|
||||
|
||||
if expected is not None:
|
||||
with pytest.raises(expected):
|
||||
validator.validate(unvalidated_config)
|
||||
else:
|
||||
validator.validate(unvalidated_config)
|
48
util/config/validators/test/test_validate_secscan.py
Normal file
48
util/config/validators/test/test_validate_secscan.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
|
||||
from config import build_requests_session
|
||||
from util.config import URLSchemeAndHostname
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators.validate_secscan import SecurityScannerValidator
|
||||
from util.secscan.fake import fake_security_scanner
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({'DISTRIBUTED_STORAGE_PREFERENCE': []}),
|
||||
])
|
||||
def test_validate_noop(unvalidated_config, app):
|
||||
|
||||
unvalidated_config = ValidatorContext(unvalidated_config, feature_sec_scanner=False, is_testing=True,
|
||||
http_client=build_requests_session(),
|
||||
url_scheme_and_hostname=URLSchemeAndHostname('http', 'localhost:5000'))
|
||||
|
||||
SecurityScannerValidator.validate(unvalidated_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config, expected_error', [
|
||||
({
|
||||
'TESTING': True,
|
||||
'DISTRIBUTED_STORAGE_PREFERENCE': [],
|
||||
'FEATURE_SECURITY_SCANNER': True,
|
||||
'SECURITY_SCANNER_ENDPOINT': 'http://invalidhost',
|
||||
}, Exception),
|
||||
|
||||
({
|
||||
'TESTING': True,
|
||||
'DISTRIBUTED_STORAGE_PREFERENCE': [],
|
||||
'FEATURE_SECURITY_SCANNER': True,
|
||||
'SECURITY_SCANNER_ENDPOINT': 'http://fakesecurityscanner',
|
||||
}, None),
|
||||
])
|
||||
def test_validate(unvalidated_config, expected_error, app):
|
||||
unvalidated_config = ValidatorContext(unvalidated_config, feature_sec_scanner=True, is_testing=True,
|
||||
http_client=build_requests_session(),
|
||||
url_scheme_and_hostname=URLSchemeAndHostname('http', 'localhost:5000'))
|
||||
|
||||
with fake_security_scanner(hostname='fakesecurityscanner'):
|
||||
if expected_error is not None:
|
||||
with pytest.raises(expected_error):
|
||||
SecurityScannerValidator.validate(unvalidated_config)
|
||||
else:
|
||||
SecurityScannerValidator.validate(unvalidated_config)
|
20
util/config/validators/test/test_validate_signer.py
Normal file
20
util/config/validators/test/test_validate_signer.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_signer import SignerValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config,expected', [
|
||||
({}, None),
|
||||
({'SIGNING_ENGINE': 'foobar'}, ConfigValidationException),
|
||||
({'SIGNING_ENGINE': 'gpg2'}, Exception),
|
||||
])
|
||||
def test_validate_signer(unvalidated_config, expected, app):
|
||||
validator = SignerValidator()
|
||||
if expected is not None:
|
||||
with pytest.raises(expected):
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
else:
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
75
util/config/validators/test/test_validate_ssl.py
Normal file
75
util/config/validators/test/test_validate_ssl.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import pytest
|
||||
|
||||
from mock import patch
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_ssl import SSLValidator, SSL_FILENAMES
|
||||
from util.security.test.test_ssl_util import generate_test_cert
|
||||
|
||||
from test.fixtures import *
|
||||
from app import config_provider
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({}),
|
||||
({'PREFERRED_URL_SCHEME': 'http'}),
|
||||
({'PREFERRED_URL_SCHEME': 'https', 'EXTERNAL_TLS_TERMINATION': True}),
|
||||
])
|
||||
def test_skip_validate_ssl(unvalidated_config, app):
|
||||
validator = SSLValidator()
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cert, server_hostname, expected_error, error_message', [
|
||||
('invalidcert', 'someserver', ConfigValidationException, 'Could not load SSL certificate: no start line'),
|
||||
(generate_test_cert(hostname='someserver'), 'someserver', None, None),
|
||||
(generate_test_cert(hostname='invalidserver'), 'someserver', ConfigValidationException,
|
||||
'Supported names "invalidserver" in SSL cert do not match server hostname "someserver"'),
|
||||
(generate_test_cert(hostname='someserver'), 'someserver:1234', None, None),
|
||||
(generate_test_cert(hostname='invalidserver'), 'someserver:1234', ConfigValidationException,
|
||||
'Supported names "invalidserver" in SSL cert do not match server hostname "someserver"'),
|
||||
(generate_test_cert(hostname='someserver:1234'), 'someserver:1234', ConfigValidationException,
|
||||
'Supported names "someserver:1234" in SSL cert do not match server hostname "someserver"'),
|
||||
(generate_test_cert(hostname='someserver:more'), 'someserver:more', None, None),
|
||||
(generate_test_cert(hostname='someserver:more'), 'someserver:more:1234', None, None),
|
||||
])
|
||||
def test_validate_ssl(cert, server_hostname, expected_error, error_message, app):
|
||||
with NamedTemporaryFile(delete=False) as cert_file:
|
||||
cert_file.write(cert[0])
|
||||
cert_file.seek(0)
|
||||
|
||||
with NamedTemporaryFile(delete=False) as key_file:
|
||||
key_file.write(cert[1])
|
||||
key_file.seek(0)
|
||||
|
||||
def return_true(filename):
|
||||
return True
|
||||
|
||||
def get_volume_file(filename):
|
||||
if filename == SSL_FILENAMES[0]:
|
||||
return open(cert_file.name)
|
||||
|
||||
if filename == SSL_FILENAMES[1]:
|
||||
return open(key_file.name)
|
||||
|
||||
return None
|
||||
|
||||
config = {
|
||||
'PREFERRED_URL_SCHEME': 'https',
|
||||
'SERVER_HOSTNAME': server_hostname,
|
||||
}
|
||||
|
||||
with patch('app.config_provider.volume_file_exists', return_true):
|
||||
with patch('app.config_provider.get_volume_file', get_volume_file):
|
||||
validator = SSLValidator()
|
||||
config = ValidatorContext(config)
|
||||
config.config_provider = config_provider
|
||||
|
||||
if expected_error is not None:
|
||||
with pytest.raises(expected_error) as ipe:
|
||||
validator.validate(config)
|
||||
|
||||
assert str(ipe.value) == error_message
|
||||
else:
|
||||
validator.validate(config)
|
40
util/config/validators/test/test_validate_storage.py
Normal file
40
util/config/validators/test/test_validate_storage.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import pytest
|
||||
|
||||
from moto import mock_s3_deprecated as mock_s3
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_storage import StorageValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config, expected', [
|
||||
({}, ConfigValidationException),
|
||||
({'DISTRIBUTED_STORAGE_CONFIG': {}}, ConfigValidationException),
|
||||
({'DISTRIBUTED_STORAGE_CONFIG': {'local': None}}, ConfigValidationException),
|
||||
({'DISTRIBUTED_STORAGE_CONFIG': {'local': ['FakeStorage', {}]}}, None),
|
||||
])
|
||||
def test_validate_storage(unvalidated_config, expected, app):
|
||||
validator = StorageValidator()
|
||||
if expected is not None:
|
||||
with pytest.raises(expected):
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
else:
|
||||
validator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
def test_validate_s3_storage(app):
|
||||
validator = StorageValidator()
|
||||
with mock_s3():
|
||||
with pytest.raises(ConfigValidationException) as ipe:
|
||||
validator.validate(ValidatorContext({
|
||||
'DISTRIBUTED_STORAGE_CONFIG': {
|
||||
'default': ('S3Storage', {
|
||||
's3_access_key': 'invalid',
|
||||
's3_secret_key': 'invalid',
|
||||
's3_bucket': 'somebucket',
|
||||
'storage_path': ''
|
||||
}),
|
||||
}
|
||||
}))
|
||||
|
||||
assert str(ipe.value) == 'Invalid storage configuration: default: S3ResponseError: 404 Not Found'
|
32
util/config/validators/test/test_validate_timemachine.py
Normal file
32
util/config/validators/test/test_validate_timemachine.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import pytest
|
||||
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_timemachine import TimeMachineValidator
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config', [
|
||||
({}),
|
||||
])
|
||||
def test_validate_noop(unvalidated_config):
|
||||
TimeMachineValidator.validate(ValidatorContext(unvalidated_config))
|
||||
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('default_exp,options,expected_exception', [
|
||||
('2d', ['1w', '2d'], None),
|
||||
|
||||
('2d', ['1w'], 'Default expiration must be in expiration options set'),
|
||||
('2d', ['2d', '1M'], 'Invalid tag expiration option: 1M'),
|
||||
])
|
||||
def test_validate(default_exp, options, expected_exception, app):
|
||||
config = {}
|
||||
config['DEFAULT_TAG_EXPIRATION'] = default_exp
|
||||
config['TAG_EXPIRATION_OPTIONS'] = options
|
||||
|
||||
if expected_exception is not None:
|
||||
with pytest.raises(ConfigValidationException) as cve:
|
||||
TimeMachineValidator.validate(ValidatorContext(config))
|
||||
assert str(cve.value) == str(expected_exception)
|
||||
else:
|
||||
TimeMachineValidator.validate(ValidatorContext(config))
|
39
util/config/validators/test/test_validate_torrent.py
Normal file
39
util/config/validators/test/test_validate_torrent.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
import pytest
|
||||
|
||||
from config import build_requests_session
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
from app import instance_keys
|
||||
from util.config.validator import ValidatorContext
|
||||
from util.config.validators import ConfigValidationException
|
||||
from util.config.validators.validate_torrent import BittorrentValidator
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('unvalidated_config,expected', [
|
||||
({}, ConfigValidationException),
|
||||
({'BITTORRENT_ANNOUNCE_URL': 'http://faketorrent/announce'}, None),
|
||||
])
|
||||
def test_validate_torrent(unvalidated_config, expected, app):
|
||||
announcer_hit = [False]
|
||||
|
||||
@urlmatch(netloc=r'faketorrent', path='/announce')
|
||||
def handler(url, request):
|
||||
announcer_hit[0] = True
|
||||
return {'status_code': 200, 'content': ''}
|
||||
|
||||
with HTTMock(handler):
|
||||
validator = BittorrentValidator()
|
||||
if expected is not None:
|
||||
with pytest.raises(expected):
|
||||
config = ValidatorContext(unvalidated_config, instance_keys=instance_keys)
|
||||
config.http_client = build_requests_session()
|
||||
|
||||
validator.validate(config)
|
||||
assert not announcer_hit[0]
|
||||
else:
|
||||
config = ValidatorContext(unvalidated_config, instance_keys=instance_keys)
|
||||
config.http_client = build_requests_session()
|
||||
|
||||
validator.validate(config)
|
||||
assert announcer_hit[0]
|
28
util/config/validators/validate_access.py
Normal file
28
util/config/validators/validate_access.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
from oauth.loginmanager import OAuthLoginManager
|
||||
from oauth.oidc import OIDCLoginService
|
||||
|
||||
class AccessSettingsValidator(BaseValidator):
|
||||
name = "access"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
config = validator_context.config
|
||||
client = validator_context.http_client
|
||||
|
||||
if not config.get('FEATURE_DIRECT_LOGIN', True):
|
||||
# Make sure we have at least one OIDC enabled.
|
||||
github_login = config.get('FEATURE_GITHUB_LOGIN', False)
|
||||
google_login = config.get('FEATURE_GOOGLE_LOGIN', False)
|
||||
|
||||
login_manager = OAuthLoginManager(config, client=client)
|
||||
custom_oidc = [s for s in login_manager.services if isinstance(s, OIDCLoginService)]
|
||||
|
||||
if not github_login and not google_login and not custom_oidc:
|
||||
msg = 'Cannot disable credentials login to UI without configured OIDC service'
|
||||
raise ConfigValidationException(msg)
|
||||
|
||||
if (not config.get('FEATURE_USER_CREATION', True) and
|
||||
config.get('FEATURE_INVITE_ONLY_USER_CREATION', False)):
|
||||
msg = "Invite only user creation requires user creation to be enabled"
|
||||
raise ConfigValidationException(msg)
|
24
util/config/validators/validate_actionlog_archiving.py
Normal file
24
util/config/validators/validate_actionlog_archiving.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class ActionLogArchivingValidator(BaseValidator):
|
||||
name = "actionlogarchiving"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
config = validator_context.config
|
||||
|
||||
""" Validates the action log archiving configuration. """
|
||||
if not config.get('FEATURE_ACTION_LOG_ROTATION', False):
|
||||
return
|
||||
|
||||
if not config.get('ACTION_LOG_ARCHIVE_PATH'):
|
||||
raise ConfigValidationException('Missing action log archive path')
|
||||
|
||||
if not config.get('ACTION_LOG_ARCHIVE_LOCATION'):
|
||||
raise ConfigValidationException('Missing action log archive storage location')
|
||||
|
||||
location = config['ACTION_LOG_ARCHIVE_LOCATION']
|
||||
storage_config = config.get('DISTRIBUTED_STORAGE_CONFIG') or {}
|
||||
if location not in storage_config:
|
||||
msg = 'Action log archive storage location `%s` not found in storage config' % location
|
||||
raise ConfigValidationException(msg)
|
21
util/config/validators/validate_apptokenauth.py
Normal file
21
util/config/validators/validate_apptokenauth.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class AppTokenAuthValidator(BaseValidator):
|
||||
name = "apptoken-auth"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
config = validator_context.config
|
||||
|
||||
if config.get('AUTHENTICATION_TYPE', 'Database') != 'AppToken':
|
||||
return
|
||||
|
||||
# Ensure that app tokens are enabled, as they are required.
|
||||
if not config.get('FEATURE_APP_SPECIFIC_TOKENS', False):
|
||||
msg = 'Application token support must be enabled to use External Application Token auth'
|
||||
raise ConfigValidationException(msg)
|
||||
|
||||
# Ensure that direct login is disabled.
|
||||
if config.get('FEATURE_DIRECT_LOGIN', True):
|
||||
msg = 'Direct login must be disabled to use External Application Token auth'
|
||||
raise ConfigValidationException(msg)
|
30
util/config/validators/validate_bitbucket_trigger.py
Normal file
30
util/config/validators/validate_bitbucket_trigger.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from bitbucket import BitBucket
|
||||
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class BitbucketTriggerValidator(BaseValidator):
|
||||
name = "bitbucket-trigger"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the config for BitBucket. """
|
||||
config = validator_context.config
|
||||
|
||||
trigger_config = config.get('BITBUCKET_TRIGGER_CONFIG')
|
||||
if not trigger_config:
|
||||
raise ConfigValidationException('Missing client ID and client secret')
|
||||
|
||||
if not trigger_config.get('CONSUMER_KEY'):
|
||||
raise ConfigValidationException('Missing Consumer Key')
|
||||
|
||||
if not trigger_config.get('CONSUMER_SECRET'):
|
||||
raise ConfigValidationException('Missing Consumer Secret')
|
||||
|
||||
key = trigger_config['CONSUMER_KEY']
|
||||
secret = trigger_config['CONSUMER_SECRET']
|
||||
callback_url = '%s/oauth1/bitbucket/callback/trigger/' % (validator_context.url_scheme_and_hostname.get_url())
|
||||
|
||||
bitbucket_client = BitBucket(key, secret, callback_url)
|
||||
(result, _, _) = bitbucket_client.get_authorization_url()
|
||||
if not result:
|
||||
raise ConfigValidationException('Invalid consumer key or secret')
|
20
util/config/validators/validate_database.py
Normal file
20
util/config/validators/validate_database.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
from peewee import OperationalError
|
||||
|
||||
from data.database import validate_database_precondition
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class DatabaseValidator(BaseValidator):
|
||||
name = "database"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates connecting to the database. """
|
||||
config = validator_context.config
|
||||
|
||||
try:
|
||||
validate_database_precondition(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {}))
|
||||
except OperationalError as ex:
|
||||
if ex.args and len(ex.args) > 1:
|
||||
raise ConfigValidationException(ex.args[1])
|
||||
else:
|
||||
raise ex
|
53
util/config/validators/validate_github.py
Normal file
53
util/config/validators/validate_github.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
from oauth.services.github import GithubOAuthService
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class BaseGitHubValidator(BaseValidator):
|
||||
name = None
|
||||
config_key = None
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the OAuth credentials and API endpoint for a Github service. """
|
||||
config = validator_context.config
|
||||
client = validator_context.http_client
|
||||
url_scheme_and_hostname = validator_context.url_scheme_and_hostname
|
||||
|
||||
github_config = config.get(cls.config_key)
|
||||
if not github_config:
|
||||
raise ConfigValidationException('Missing GitHub client id and client secret')
|
||||
|
||||
endpoint = github_config.get('GITHUB_ENDPOINT')
|
||||
if not endpoint:
|
||||
raise ConfigValidationException('Missing GitHub Endpoint')
|
||||
|
||||
if endpoint.find('http://') != 0 and endpoint.find('https://') != 0:
|
||||
raise ConfigValidationException('Github Endpoint must start with http:// or https://')
|
||||
|
||||
if not github_config.get('CLIENT_ID'):
|
||||
raise ConfigValidationException('Missing Client ID')
|
||||
|
||||
if not github_config.get('CLIENT_SECRET'):
|
||||
raise ConfigValidationException('Missing Client Secret')
|
||||
|
||||
if github_config.get('ORG_RESTRICT') and not github_config.get('ALLOWED_ORGANIZATIONS'):
|
||||
raise ConfigValidationException('Organization restriction must have at least one allowed ' +
|
||||
'organization')
|
||||
|
||||
oauth = GithubOAuthService(config, cls.config_key)
|
||||
result = oauth.validate_client_id_and_secret(client, url_scheme_and_hostname)
|
||||
if not result:
|
||||
raise ConfigValidationException('Invalid client id or client secret')
|
||||
|
||||
if github_config.get('ALLOWED_ORGANIZATIONS'):
|
||||
for org_id in github_config.get('ALLOWED_ORGANIZATIONS'):
|
||||
if not oauth.validate_organization(org_id, client):
|
||||
raise ConfigValidationException('Invalid organization: %s' % org_id)
|
||||
|
||||
|
||||
class GitHubLoginValidator(BaseGitHubValidator):
|
||||
name = "github-login"
|
||||
config_key = "GITHUB_LOGIN_CONFIG"
|
||||
|
||||
class GitHubTriggerValidator(BaseGitHubValidator):
|
||||
name = "github-trigger"
|
||||
config_key = "GITHUB_TRIGGER_CONFIG"
|
32
util/config/validators/validate_gitlab_trigger.py
Normal file
32
util/config/validators/validate_gitlab_trigger.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from oauth.services.gitlab import GitLabOAuthService
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class GitLabTriggerValidator(BaseValidator):
|
||||
name = "gitlab-trigger"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the OAuth credentials and API endpoint for a GitLab service. """
|
||||
config = validator_context.config
|
||||
url_scheme_and_hostname = validator_context.url_scheme_and_hostname
|
||||
client = validator_context.http_client
|
||||
|
||||
github_config = config.get('GITLAB_TRIGGER_CONFIG')
|
||||
if not github_config:
|
||||
raise ConfigValidationException('Missing GitLab client id and client secret')
|
||||
|
||||
endpoint = github_config.get('GITLAB_ENDPOINT')
|
||||
if endpoint:
|
||||
if endpoint.find('http://') != 0 and endpoint.find('https://') != 0:
|
||||
raise ConfigValidationException('GitLab Endpoint must start with http:// or https://')
|
||||
|
||||
if not github_config.get('CLIENT_ID'):
|
||||
raise ConfigValidationException('Missing Client ID')
|
||||
|
||||
if not github_config.get('CLIENT_SECRET'):
|
||||
raise ConfigValidationException('Missing Client Secret')
|
||||
|
||||
oauth = GitLabOAuthService(config, 'GITLAB_TRIGGER_CONFIG')
|
||||
result = oauth.validate_client_id_and_secret(client, url_scheme_and_hostname)
|
||||
if not result:
|
||||
raise ConfigValidationException('Invalid client id or client secret')
|
27
util/config/validators/validate_google_login.py
Normal file
27
util/config/validators/validate_google_login.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from oauth.services.google import GoogleOAuthService
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class GoogleLoginValidator(BaseValidator):
|
||||
name = "google-login"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the Google Login client ID and secret. """
|
||||
config = validator_context.config
|
||||
client = validator_context.http_client
|
||||
url_scheme_and_hostname = validator_context.url_scheme_and_hostname
|
||||
|
||||
google_login_config = config.get('GOOGLE_LOGIN_CONFIG')
|
||||
if not google_login_config:
|
||||
raise ConfigValidationException('Missing client ID and client secret')
|
||||
|
||||
if not google_login_config.get('CLIENT_ID'):
|
||||
raise ConfigValidationException('Missing Client ID')
|
||||
|
||||
if not google_login_config.get('CLIENT_SECRET'):
|
||||
raise ConfigValidationException('Missing Client Secret')
|
||||
|
||||
oauth = GoogleOAuthService(config, 'GOOGLE_LOGIN_CONFIG')
|
||||
result = oauth.validate_client_id_and_secret(client, url_scheme_and_hostname)
|
||||
if not result:
|
||||
raise ConfigValidationException('Invalid client id or client secret')
|
48
util/config/validators/validate_jwt.py
Normal file
48
util/config/validators/validate_jwt.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import os
|
||||
from data.users.externaljwt import ExternalJWTAuthN
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class JWTAuthValidator(BaseValidator):
|
||||
name = "jwt"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context, public_key_path=None):
|
||||
""" Validates the JWT authentication system. """
|
||||
config = validator_context.config
|
||||
http_client = validator_context.http_client
|
||||
jwt_auth_max = validator_context.jwt_auth_max
|
||||
config_provider = validator_context.config_provider
|
||||
|
||||
if config.get('AUTHENTICATION_TYPE', 'Database') != 'JWT':
|
||||
return
|
||||
|
||||
verify_endpoint = config.get('JWT_VERIFY_ENDPOINT')
|
||||
query_endpoint = config.get('JWT_QUERY_ENDPOINT', None)
|
||||
getuser_endpoint = config.get('JWT_GETUSER_ENDPOINT', None)
|
||||
|
||||
issuer = config.get('JWT_AUTH_ISSUER')
|
||||
|
||||
if not verify_endpoint:
|
||||
raise ConfigValidationException('Missing JWT Verification endpoint')
|
||||
|
||||
if not issuer:
|
||||
raise ConfigValidationException('Missing JWT Issuer ID')
|
||||
|
||||
|
||||
override_config_directory = config_provider.get_config_dir_path()
|
||||
|
||||
# Try to instatiate the JWT authentication mechanism. This will raise an exception if
|
||||
# the key cannot be found.
|
||||
users = ExternalJWTAuthN(verify_endpoint, query_endpoint, getuser_endpoint, issuer,
|
||||
override_config_directory,
|
||||
http_client,
|
||||
jwt_auth_max,
|
||||
public_key_path=public_key_path,
|
||||
requires_email=config.get('FEATURE_MAILING', True))
|
||||
|
||||
# Verify that we can reach the jwt server
|
||||
(result, err_msg) = users.ping()
|
||||
if not result:
|
||||
msg = ('Verification of JWT failed: %s. \n\nWe cannot reach the JWT server' +
|
||||
'OR JWT auth is misconfigured') % err_msg
|
||||
raise ConfigValidationException(msg)
|
44
util/config/validators/validate_keystone.py
Normal file
44
util/config/validators/validate_keystone.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
from data.users.keystone import get_keystone_users
|
||||
|
||||
class KeystoneValidator(BaseValidator):
|
||||
name = "keystone"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the Keystone authentication system. """
|
||||
config = validator_context.config
|
||||
|
||||
if config.get('AUTHENTICATION_TYPE', 'Database') != 'Keystone':
|
||||
return
|
||||
|
||||
auth_url = config.get('KEYSTONE_AUTH_URL')
|
||||
auth_version = int(config.get('KEYSTONE_AUTH_VERSION', 2))
|
||||
admin_username = config.get('KEYSTONE_ADMIN_USERNAME')
|
||||
admin_password = config.get('KEYSTONE_ADMIN_PASSWORD')
|
||||
admin_tenant = config.get('KEYSTONE_ADMIN_TENANT')
|
||||
|
||||
if not auth_url:
|
||||
raise ConfigValidationException('Missing authentication URL')
|
||||
|
||||
if not admin_username:
|
||||
raise ConfigValidationException('Missing admin username')
|
||||
|
||||
if not admin_password:
|
||||
raise ConfigValidationException('Missing admin password')
|
||||
|
||||
if not admin_tenant:
|
||||
raise ConfigValidationException('Missing admin tenant')
|
||||
|
||||
requires_email = config.get('FEATURE_MAILING', True)
|
||||
users = get_keystone_users(auth_version, auth_url, admin_username, admin_password, admin_tenant,
|
||||
requires_email)
|
||||
|
||||
# Verify that the superuser exists. If not, raise an exception.
|
||||
(result, err_msg) = users.at_least_one_user_exists()
|
||||
if not result:
|
||||
msg = ('Verification that users exist failed: %s. \n\nNo users exist ' +
|
||||
'in the admin tenant/project ' +
|
||||
'in the remote authentication system ' +
|
||||
'OR Keystone auth is misconfigured.') % err_msg
|
||||
raise ConfigValidationException(msg)
|
68
util/config/validators/validate_ldap.py
Normal file
68
util/config/validators/validate_ldap.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
import os
|
||||
import ldap
|
||||
import subprocess
|
||||
|
||||
from data.users import LDAP_CERT_FILENAME
|
||||
from data.users.externalldap import LDAPConnection, LDAPUsers
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class LDAPValidator(BaseValidator):
|
||||
name = "ldap"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the LDAP connection. """
|
||||
config = validator_context.config
|
||||
config_provider = validator_context.config_provider
|
||||
init_scripts_location = validator_context.init_scripts_location
|
||||
|
||||
if config.get('AUTHENTICATION_TYPE', 'Database') != 'LDAP':
|
||||
return
|
||||
|
||||
# If there is a custom LDAP certificate, then reinstall the certificates for the container.
|
||||
if config_provider.volume_file_exists(LDAP_CERT_FILENAME):
|
||||
subprocess.check_call([os.path.join(init_scripts_location, 'certs_install.sh')],
|
||||
env={ 'QUAYCONFIG': config_provider.get_config_dir_path() })
|
||||
|
||||
# Note: raises ldap.INVALID_CREDENTIALS on failure
|
||||
admin_dn = config.get('LDAP_ADMIN_DN')
|
||||
admin_passwd = config.get('LDAP_ADMIN_PASSWD')
|
||||
|
||||
if not admin_dn:
|
||||
raise ConfigValidationException('Missing Admin DN for LDAP configuration')
|
||||
|
||||
if not admin_passwd:
|
||||
raise ConfigValidationException('Missing Admin Password for LDAP configuration')
|
||||
|
||||
ldap_uri = config.get('LDAP_URI', 'ldap://localhost')
|
||||
if not ldap_uri.startswith('ldap://') and not ldap_uri.startswith('ldaps://'):
|
||||
raise ConfigValidationException('LDAP URI must start with ldap:// or ldaps://')
|
||||
|
||||
allow_tls_fallback = config.get('LDAP_ALLOW_INSECURE_FALLBACK', False)
|
||||
|
||||
try:
|
||||
with LDAPConnection(ldap_uri, admin_dn, admin_passwd, allow_tls_fallback):
|
||||
pass
|
||||
except ldap.LDAPError as ex:
|
||||
values = ex.args[0] if ex.args else {}
|
||||
if not isinstance(values, dict):
|
||||
raise ConfigValidationException(str(ex.args))
|
||||
|
||||
raise ConfigValidationException(values.get('desc', 'Unknown error'))
|
||||
|
||||
base_dn = config.get('LDAP_BASE_DN')
|
||||
user_rdn = config.get('LDAP_USER_RDN', [])
|
||||
uid_attr = config.get('LDAP_UID_ATTR', 'uid')
|
||||
email_attr = config.get('LDAP_EMAIL_ATTR', 'mail')
|
||||
requires_email = config.get('FEATURE_MAILING', True)
|
||||
|
||||
users = LDAPUsers(ldap_uri, base_dn, admin_dn, admin_passwd, user_rdn, uid_attr, email_attr,
|
||||
allow_tls_fallback, requires_email=requires_email)
|
||||
|
||||
# Ensure at least one user exists to verify the connection is setup properly.
|
||||
(result, err_msg) = users.at_least_one_user_exists()
|
||||
if not result:
|
||||
msg = ('Verification that users exist failed: %s. \n\nNo users exist ' +
|
||||
'in the remote authentication system ' +
|
||||
'OR LDAP auth is misconfigured.') % err_msg
|
||||
raise ConfigValidationException(msg)
|
36
util/config/validators/validate_oidc.py
Normal file
36
util/config/validators/validate_oidc.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
from oauth.loginmanager import OAuthLoginManager
|
||||
from oauth.oidc import OIDCLoginService, DiscoveryFailureException
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class OIDCLoginValidator(BaseValidator):
|
||||
name = "oidc-login"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
config = validator_context.config
|
||||
client = validator_context.http_client
|
||||
|
||||
login_manager = OAuthLoginManager(config, client=client)
|
||||
for service in login_manager.services:
|
||||
if not isinstance(service, OIDCLoginService):
|
||||
continue
|
||||
|
||||
if service.config.get('OIDC_SERVER') is None:
|
||||
msg = 'Missing OIDC_SERVER on OIDC service %s' % service.service_id()
|
||||
raise ConfigValidationException(msg)
|
||||
|
||||
if service.config.get('CLIENT_ID') is None:
|
||||
msg = 'Missing CLIENT_ID on OIDC service %s' % service.service_id()
|
||||
raise ConfigValidationException(msg)
|
||||
|
||||
if service.config.get('CLIENT_SECRET') is None:
|
||||
msg = 'Missing CLIENT_SECRET on OIDC service %s' % service.service_id()
|
||||
raise ConfigValidationException(msg)
|
||||
|
||||
try:
|
||||
if not service.validate():
|
||||
msg = 'Could not validate OIDC service %s' % service.service_id()
|
||||
raise ConfigValidationException(msg)
|
||||
except DiscoveryFailureException as dfe:
|
||||
msg = 'Could not validate OIDC service %s: %s' % (service.service_id(), dfe.message)
|
||||
raise ConfigValidationException(msg)
|
18
util/config/validators/validate_redis.py
Normal file
18
util/config/validators/validate_redis.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
import redis
|
||||
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class RedisValidator(BaseValidator):
|
||||
name = "redis"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates connecting to redis. """
|
||||
config = validator_context.config
|
||||
|
||||
redis_config = config.get('BUILDLOGS_REDIS', {})
|
||||
if not 'host' in redis_config:
|
||||
raise ConfigValidationException('Missing redis hostname')
|
||||
|
||||
client = redis.StrictRedis(socket_connect_timeout=5, **redis_config)
|
||||
client.ping()
|
52
util/config/validators/validate_secscan.py
Normal file
52
util/config/validators/validate_secscan.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import time
|
||||
|
||||
# from boot import setup_jwt_proxy
|
||||
from util.secscan.api import SecurityScannerAPI
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
class SecurityScannerValidator(BaseValidator):
|
||||
name = "security-scanner"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the configuration for talking to a Quay Security Scanner. """
|
||||
config = validator_context.config
|
||||
client = validator_context.http_client
|
||||
feature_sec_scanner = validator_context.feature_sec_scanner
|
||||
is_testing = validator_context.is_testing
|
||||
|
||||
server_hostname = validator_context.url_scheme_and_hostname.hostname
|
||||
uri_creator = validator_context.uri_creator
|
||||
|
||||
if not feature_sec_scanner:
|
||||
return
|
||||
|
||||
api = SecurityScannerAPI(config, None, server_hostname, client=client, skip_validation=True, uri_creator=uri_creator)
|
||||
|
||||
# if not is_testing:
|
||||
# Generate a temporary Quay key to use for signing the outgoing requests.
|
||||
# setup_jwt_proxy()
|
||||
|
||||
# We have to wait for JWT proxy to restart with the newly generated key.
|
||||
max_tries = 5
|
||||
response = None
|
||||
last_exception = None
|
||||
|
||||
while max_tries > 0:
|
||||
try:
|
||||
response = api.ping()
|
||||
last_exception = None
|
||||
if response.status_code == 200:
|
||||
return
|
||||
except Exception as ex:
|
||||
last_exception = ex
|
||||
|
||||
time.sleep(1)
|
||||
max_tries = max_tries - 1
|
||||
|
||||
if last_exception is not None:
|
||||
message = str(last_exception)
|
||||
raise ConfigValidationException('Could not ping security scanner: %s' % message)
|
||||
else:
|
||||
message = 'Expected 200 status code, got %s: %s' % (response.status_code, response.text)
|
||||
raise ConfigValidationException('Could not ping security scanner: %s' % message)
|
22
util/config/validators/validate_signer.py
Normal file
22
util/config/validators/validate_signer.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
from StringIO import StringIO
|
||||
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
from util.security.signing import SIGNING_ENGINES
|
||||
|
||||
class SignerValidator(BaseValidator):
|
||||
name = "signer"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the GPG public+private key pair used for signing converted ACIs. """
|
||||
config = validator_context.config
|
||||
config_provider = validator_context.config_provider
|
||||
|
||||
if config.get('SIGNING_ENGINE') is None:
|
||||
return
|
||||
|
||||
if config['SIGNING_ENGINE'] not in SIGNING_ENGINES:
|
||||
raise ConfigValidationException('Unknown signing engine: %s' % config['SIGNING_ENGINE'])
|
||||
|
||||
engine = SIGNING_ENGINES[config['SIGNING_ENGINE']](config, config_provider)
|
||||
engine.detached_sign(StringIO('test string'))
|
72
util/config/validators/validate_ssl.py
Normal file
72
util/config/validators/validate_ssl.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
|
||||
|
||||
SSL_FILENAMES = ['ssl.cert', 'ssl.key']
|
||||
|
||||
class SSLValidator(BaseValidator):
|
||||
name = "ssl"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the SSL configuration (if enabled). """
|
||||
config = validator_context.config
|
||||
config_provider = validator_context.config_provider
|
||||
|
||||
# Skip if non-SSL.
|
||||
if config.get('PREFERRED_URL_SCHEME', 'http') != 'https':
|
||||
return
|
||||
|
||||
# Skip if externally terminated.
|
||||
if config.get('EXTERNAL_TLS_TERMINATION', False) is True:
|
||||
return
|
||||
|
||||
# Verify that we have all the required SSL files.
|
||||
for filename in SSL_FILENAMES:
|
||||
if not config_provider.volume_file_exists(filename):
|
||||
raise ConfigValidationException('Missing required SSL file: %s' % filename)
|
||||
|
||||
# Read the contents of the SSL certificate.
|
||||
with config_provider.get_volume_file(SSL_FILENAMES[0]) as f:
|
||||
cert_contents = f.read()
|
||||
|
||||
# Validate the certificate.
|
||||
try:
|
||||
certificate = load_certificate(cert_contents)
|
||||
except CertInvalidException as cie:
|
||||
raise ConfigValidationException('Could not load SSL certificate: %s' % cie)
|
||||
|
||||
# Verify the certificate has not expired.
|
||||
if certificate.expired:
|
||||
raise ConfigValidationException('The specified SSL certificate has expired.')
|
||||
|
||||
# Verify the hostname matches the name in the certificate.
|
||||
if not certificate.matches_name(_ssl_cn(config['SERVER_HOSTNAME'])):
|
||||
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
|
||||
(', '.join(list(certificate.names)), _ssl_cn(config['SERVER_HOSTNAME'])))
|
||||
raise ConfigValidationException(msg)
|
||||
|
||||
# Verify the private key against the certificate.
|
||||
private_key_path = None
|
||||
with config_provider.get_volume_file(SSL_FILENAMES[1]) as f:
|
||||
private_key_path = f.name
|
||||
|
||||
if not private_key_path:
|
||||
# Only in testing.
|
||||
return
|
||||
|
||||
try:
|
||||
certificate.validate_private_key(private_key_path)
|
||||
except KeyInvalidException as kie:
|
||||
raise ConfigValidationException('SSL private key failed to validate: %s' % kie)
|
||||
|
||||
|
||||
def _ssl_cn(server_hostname):
|
||||
""" Return the common name (fully qualified host name) from the SERVER_HOSTNAME. """
|
||||
host_port = server_hostname.rsplit(':', 1)
|
||||
|
||||
# SERVER_HOSTNAME includes the port
|
||||
if len(host_port) == 2:
|
||||
if host_port[-1].isdigit():
|
||||
return host_port[-2]
|
||||
|
||||
return server_hostname
|
54
util/config/validators/validate_storage.py
Normal file
54
util/config/validators/validate_storage.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
from storage import get_storage_driver, TYPE_LOCAL_STORAGE
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
|
||||
|
||||
class StorageValidator(BaseValidator):
|
||||
name = "registry-storage"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates registry storage. """
|
||||
config = validator_context.config
|
||||
client = validator_context.http_client
|
||||
ip_resolver = validator_context.ip_resolver
|
||||
config_provider = validator_context.config_provider
|
||||
|
||||
replication_enabled = config.get('FEATURE_STORAGE_REPLICATION', False)
|
||||
|
||||
providers = _get_storage_providers(config, ip_resolver, config_provider).items()
|
||||
if not providers:
|
||||
raise ConfigValidationException('Storage configuration required')
|
||||
|
||||
for name, (storage_type, driver) in providers:
|
||||
# We can skip localstorage validation, since we can't guarantee that
|
||||
# this will be the same machine Q.E. will run under
|
||||
if storage_type == TYPE_LOCAL_STORAGE:
|
||||
continue
|
||||
|
||||
try:
|
||||
if replication_enabled and storage_type == 'LocalStorage':
|
||||
raise ConfigValidationException('Locally mounted directory not supported ' +
|
||||
'with storage replication')
|
||||
|
||||
# Run validation on the driver.
|
||||
driver.validate(client)
|
||||
|
||||
# Run setup on the driver if the read/write succeeded.
|
||||
driver.setup()
|
||||
except Exception as ex:
|
||||
msg = str(ex).strip().split("\n")[0]
|
||||
raise ConfigValidationException('Invalid storage configuration: %s: %s' % (name, msg))
|
||||
|
||||
|
||||
def _get_storage_providers(config, ip_resolver, config_provider):
|
||||
storage_config = config.get('DISTRIBUTED_STORAGE_CONFIG', {})
|
||||
drivers = {}
|
||||
|
||||
try:
|
||||
for name, parameters in storage_config.items():
|
||||
driver = get_storage_driver(None, None, None, config_provider, ip_resolver, parameters)
|
||||
drivers[name] = (parameters[0], driver)
|
||||
except TypeError:
|
||||
raise ConfigValidationException('Missing required parameter(s) for storage %s' % name)
|
||||
|
||||
return drivers
|
31
util/config/validators/validate_timemachine.py
Normal file
31
util/config/validators/validate_timemachine.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
import logging
|
||||
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
from util.timedeltastring import convert_to_timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TimeMachineValidator(BaseValidator):
|
||||
name = "time-machine"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
config = validator_context.config
|
||||
|
||||
if not 'DEFAULT_TAG_EXPIRATION' in config:
|
||||
# Old style config
|
||||
return
|
||||
|
||||
try:
|
||||
convert_to_timedelta(config['DEFAULT_TAG_EXPIRATION']).total_seconds()
|
||||
except ValueError as ve:
|
||||
raise ConfigValidationException('Invalid default expiration: %s' % ve.message)
|
||||
|
||||
if not config['DEFAULT_TAG_EXPIRATION'] in config.get('TAG_EXPIRATION_OPTIONS', []):
|
||||
raise ConfigValidationException('Default expiration must be in expiration options set')
|
||||
|
||||
for ts in config.get('TAG_EXPIRATION_OPTIONS', []):
|
||||
try:
|
||||
convert_to_timedelta(ts)
|
||||
except ValueError as ve:
|
||||
raise ConfigValidationException('Invalid tag expiration option: %s' % ts)
|
59
util/config/validators/validate_torrent.py
Normal file
59
util/config/validators/validate_torrent.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import logging
|
||||
|
||||
from hashlib import sha1
|
||||
|
||||
from util.config.validators import BaseValidator, ConfigValidationException
|
||||
from util.registry.torrent import jwt_from_infohash, TorrentConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BittorrentValidator(BaseValidator):
|
||||
name = "bittorrent"
|
||||
|
||||
@classmethod
|
||||
def validate(cls, validator_context):
|
||||
""" Validates the configuration for using BitTorrent for downloads. """
|
||||
config = validator_context.config
|
||||
client = validator_context.http_client
|
||||
|
||||
announce_url = config.get('BITTORRENT_ANNOUNCE_URL')
|
||||
if not announce_url:
|
||||
raise ConfigValidationException('Missing announce URL')
|
||||
|
||||
# Ensure that the tracker is reachable and accepts requests signed with a registry key.
|
||||
params = {
|
||||
'info_hash': sha1('test').digest(),
|
||||
'peer_id': '-QUAY00-6wfG2wk6wWLc',
|
||||
'uploaded': 0,
|
||||
'downloaded': 0,
|
||||
'left': 0,
|
||||
'numwant': 0,
|
||||
'port': 80,
|
||||
}
|
||||
|
||||
torrent_config = TorrentConfiguration.for_testing(validator_context.instance_keys, announce_url,
|
||||
validator_context.registry_title)
|
||||
encoded_jwt = jwt_from_infohash(torrent_config, params['info_hash'])
|
||||
params['jwt'] = encoded_jwt
|
||||
|
||||
resp = client.get(announce_url, timeout=5, params=params)
|
||||
logger.debug('Got tracker response: %s: %s', resp.status_code, resp.text)
|
||||
|
||||
if resp.status_code == 404:
|
||||
raise ConfigValidationException('Announce path not found; did you forget `/announce`?')
|
||||
|
||||
if resp.status_code == 500:
|
||||
raise ConfigValidationException('Did not get expected response from Tracker; ' +
|
||||
'please check your settings')
|
||||
|
||||
if resp.status_code == 200:
|
||||
if 'invalid jwt' in resp.text:
|
||||
raise ConfigValidationException('Could not authorize to Tracker; is your Tracker ' +
|
||||
'properly configured?')
|
||||
|
||||
if 'failure reason' in resp.text:
|
||||
raise ConfigValidationException('Could not validate signed announce request: ' + resp.text)
|
||||
|
||||
if 'go_goroutines' in resp.text:
|
||||
raise ConfigValidationException('Could not validate signed announce request: ' +
|
||||
'provided port is used for Prometheus')
|
92
util/dict_wrappers.py
Normal file
92
util/dict_wrappers.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
import json
|
||||
from jsonpath_rw import parse
|
||||
|
||||
class SafeDictSetter(object):
|
||||
""" Specialized write-only dictionary wrapper class that allows for setting
|
||||
nested keys via a path syntax.
|
||||
|
||||
Example:
|
||||
sds = SafeDictSetter()
|
||||
sds['foo.bar.baz'] = 'hello' # Sets 'foo' = {'bar': {'baz': 'hello'}}
|
||||
sds['somekey'] = None # Does not set the key since the value is None
|
||||
"""
|
||||
def __init__(self, initial_object=None):
|
||||
self._object = initial_object or {}
|
||||
|
||||
def __setitem__(self, path, value):
|
||||
self.set(path, value)
|
||||
|
||||
def set(self, path, value, allow_none=False):
|
||||
""" Sets the value of the given path to the given value. """
|
||||
if value is None and not allow_none:
|
||||
return
|
||||
|
||||
pieces = path.split('.')
|
||||
current = self._object
|
||||
|
||||
for piece in pieces[:len(pieces)-1]:
|
||||
current_obj = current.get(piece, {})
|
||||
if not isinstance(current_obj, dict):
|
||||
raise Exception('Key %s is a non-object value: %s' % (piece, current_obj))
|
||||
|
||||
current[piece] = current_obj
|
||||
current = current_obj
|
||||
|
||||
current[pieces[-1]] = value
|
||||
|
||||
def dict_value(self):
|
||||
""" Returns the dict value built. """
|
||||
return self._object
|
||||
|
||||
def json_value(self):
|
||||
""" Returns the JSON string value of the dictionary built. """
|
||||
return json.dumps(self._object)
|
||||
|
||||
|
||||
class JSONPathDict(object):
|
||||
""" Specialized read-only dictionary wrapper class that uses the jsonpath_rw library
|
||||
to access keys via an X-Path-like syntax.
|
||||
|
||||
Example:
|
||||
pd = JSONPathDict({'hello': {'hi': 'there'}})
|
||||
pd['hello.hi'] # Returns 'there'
|
||||
"""
|
||||
def __init__(self, dict_value):
|
||||
""" Init the helper with the JSON object.
|
||||
"""
|
||||
self._object = dict_value
|
||||
|
||||
def __getitem__(self, path):
|
||||
return self.get(path)
|
||||
|
||||
def __iter__(self):
|
||||
return self._object.itervalues()
|
||||
|
||||
def iterkeys(self):
|
||||
return self._object.iterkeys()
|
||||
|
||||
def get(self, path, not_found_handler=None):
|
||||
""" Returns the value found at the given path. Path is a json-path expression. """
|
||||
if self._object == {} or self._object is None:
|
||||
return None
|
||||
jsonpath_expr = parse(path)
|
||||
|
||||
try:
|
||||
matches = jsonpath_expr.find(self._object)
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
if not matches:
|
||||
return not_found_handler() if not_found_handler else None
|
||||
|
||||
match = matches[0].value
|
||||
if not match:
|
||||
return not_found_handler() if not_found_handler else None
|
||||
|
||||
if isinstance(match, dict):
|
||||
return JSONPathDict(match)
|
||||
|
||||
return match
|
||||
|
||||
def keys(self):
|
||||
return self._object.keys()
|
110
util/disableabuser.py
Normal file
110
util/disableabuser.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
import argparse
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from app import tf
|
||||
from data import model
|
||||
from data.model import db_transaction
|
||||
from data.database import QueueItem, Repository, RepositoryBuild, RepositoryBuildTrigger, RepoMirrorConfig
|
||||
from data.queue import WorkQueue
|
||||
|
||||
def ask_disable_namespace(username, queue_name):
|
||||
user = model.user.get_namespace_user(username)
|
||||
if user is None:
|
||||
raise Exception('Unknown user or organization %s' % username)
|
||||
|
||||
if not user.enabled:
|
||||
print "NOTE: Namespace %s is already disabled" % username
|
||||
|
||||
queue_prefix = '%s/%s/%%' % (queue_name, username)
|
||||
existing_queue_item_count = (QueueItem
|
||||
.select()
|
||||
.where(QueueItem.queue_name ** queue_prefix)
|
||||
.where(QueueItem.available == 1,
|
||||
QueueItem.retries_remaining > 0,
|
||||
QueueItem.processing_expires > datetime.now())
|
||||
.count())
|
||||
|
||||
repository_trigger_count = (RepositoryBuildTrigger
|
||||
.select()
|
||||
.join(Repository)
|
||||
.where(Repository.namespace_user == user)
|
||||
.count())
|
||||
|
||||
print "============================================="
|
||||
print "For namespace %s" % username
|
||||
print "============================================="
|
||||
|
||||
print "User %s has email address %s" % (username, user.email)
|
||||
print "User %s has %s queued builds in their namespace" % (username, existing_queue_item_count)
|
||||
print "User %s has %s build triggers in their namespace" % (username, repository_trigger_count)
|
||||
|
||||
confirm_msg = "Would you like to disable this user and delete their triggers and builds? [y/N]> "
|
||||
letter = str(raw_input(confirm_msg))
|
||||
if letter.lower() != 'y':
|
||||
print "Action canceled"
|
||||
return
|
||||
|
||||
print "============================================="
|
||||
|
||||
triggers = []
|
||||
count_removed = 0
|
||||
with db_transaction():
|
||||
user.enabled = False
|
||||
user.save()
|
||||
|
||||
repositories_query = Repository.select().where(Repository.namespace_user == user)
|
||||
if len(repositories_query.clone()):
|
||||
builds = list(RepositoryBuild
|
||||
.select()
|
||||
.where(RepositoryBuild.repository << list(repositories_query)))
|
||||
|
||||
triggers = list(RepositoryBuildTrigger
|
||||
.select()
|
||||
.where(RepositoryBuildTrigger.repository << list(repositories_query)))
|
||||
|
||||
mirrors = list(RepoMirrorConfig
|
||||
.select()
|
||||
.where(RepoMirrorConfig.repository << list(repositories_query)))
|
||||
|
||||
# Delete all builds for the user's repositories.
|
||||
if builds:
|
||||
RepositoryBuild.delete().where(RepositoryBuild.id << builds).execute()
|
||||
|
||||
# Delete all build triggers for the user's repositories.
|
||||
if triggers:
|
||||
RepositoryBuildTrigger.delete().where(RepositoryBuildTrigger.id << triggers).execute()
|
||||
|
||||
# Delete all mirrors for the user's repositories.
|
||||
if mirrors:
|
||||
RepoMirrorConfig.delete().where(RepoMirrorConfig.id << mirrors).execute()
|
||||
|
||||
# Delete all queue items for the user's namespace.
|
||||
dockerfile_build_queue = WorkQueue(queue_name, tf, has_namespace=True)
|
||||
count_removed = dockerfile_build_queue.delete_namespaced_items(user.username)
|
||||
|
||||
info = (user.username, len(triggers), count_removed, len(mirrors))
|
||||
print "Namespace %s disabled, %s triggers deleted, %s queued builds removed, %s mirrors deleted" % info
|
||||
return user
|
||||
|
||||
|
||||
def disable_abusing_user(username, queue_name):
|
||||
if not username:
|
||||
raise Exception('Must enter a username')
|
||||
|
||||
# Disable the namespace itself.
|
||||
user = ask_disable_namespace(username, queue_name)
|
||||
|
||||
# If an organization, ask if all team members should be disabled as well.
|
||||
if user.organization:
|
||||
members = model.organization.get_organization_member_set(user)
|
||||
for membername in members:
|
||||
ask_disable_namespace(membername, queue_name)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Disables a user abusing the build system')
|
||||
parser.add_argument('username', help='The username of the abuser')
|
||||
parser.add_argument('queuename', help='The name of the dockerfile build queue ' +
|
||||
'(e.g. `dockerfilebuild` or `dockerfilebuildstaging`)')
|
||||
args = parser.parse_args()
|
||||
disable_abusing_user(args.username, args.queuename)
|
106
util/dockerfileparse.py
Normal file
106
util/dockerfileparse.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
import re
|
||||
|
||||
LINE_CONTINUATION_REGEX = re.compile(r'(\s)*\\(\s)*\n')
|
||||
COMMAND_REGEX = re.compile('([A-Za-z]+)\s(.*)')
|
||||
|
||||
COMMENT_CHARACTER = '#'
|
||||
LATEST_TAG = 'latest'
|
||||
|
||||
class ParsedDockerfile(object):
|
||||
def __init__(self, commands):
|
||||
self.commands = commands
|
||||
|
||||
def _get_commands_of_kind(self, kind):
|
||||
return [command for command in self.commands if command['command'] == kind]
|
||||
|
||||
def _get_from_image_identifier(self):
|
||||
from_commands = self._get_commands_of_kind('FROM')
|
||||
if not from_commands:
|
||||
return None
|
||||
|
||||
return from_commands[-1]['parameters']
|
||||
|
||||
@staticmethod
|
||||
def parse_image_identifier(image_identifier):
|
||||
""" Parses a docker image identifier, and returns a tuple of image name and tag, where the tag
|
||||
is filled in with "latest" if left unspecified.
|
||||
"""
|
||||
# Note:
|
||||
# Dockerfile images references can be of multiple forms:
|
||||
# server:port/some/path
|
||||
# somepath
|
||||
# server/some/path
|
||||
# server/some/path:tag
|
||||
# server:port/some/path:tag
|
||||
parts = image_identifier.strip().split(':')
|
||||
|
||||
if len(parts) == 1:
|
||||
# somepath
|
||||
return (parts[0], LATEST_TAG)
|
||||
|
||||
# Otherwise, determine if the last part is a port
|
||||
# or a tag.
|
||||
if parts[-1].find('/') >= 0:
|
||||
# Last part is part of the hostname.
|
||||
return (image_identifier, LATEST_TAG)
|
||||
|
||||
# Remaining cases:
|
||||
# server/some/path:tag
|
||||
# server:port/some/path:tag
|
||||
return (':'.join(parts[0:-1]), parts[-1])
|
||||
|
||||
def get_base_image(self):
|
||||
""" Return the base image without the tag name. """
|
||||
return self.get_image_and_tag()[0]
|
||||
|
||||
def get_image_and_tag(self):
|
||||
""" Returns the image and tag from the FROM line of the dockerfile. """
|
||||
image_identifier = self._get_from_image_identifier()
|
||||
if image_identifier is None:
|
||||
return (None, None)
|
||||
|
||||
return self.parse_image_identifier(image_identifier)
|
||||
|
||||
|
||||
def strip_comments(contents):
|
||||
lines = []
|
||||
for line in contents.split('\n'):
|
||||
index = line.find(COMMENT_CHARACTER)
|
||||
if index < 0:
|
||||
lines.append(line)
|
||||
continue
|
||||
|
||||
line = line[:index]
|
||||
lines.append(line)
|
||||
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
def join_continued_lines(contents):
|
||||
return LINE_CONTINUATION_REGEX.sub('', contents)
|
||||
|
||||
|
||||
def parse_dockerfile(contents):
|
||||
# If we receive ASCII, translate into unicode.
|
||||
try:
|
||||
contents = contents.decode('utf-8')
|
||||
except ValueError:
|
||||
# Already unicode or unable to convert.
|
||||
pass
|
||||
|
||||
contents = join_continued_lines(strip_comments(contents))
|
||||
lines = [line.strip() for line in contents.split('\n') if len(line) > 0]
|
||||
|
||||
commands = []
|
||||
for line in lines:
|
||||
match_command = COMMAND_REGEX.match(line)
|
||||
if match_command:
|
||||
command = match_command.group(1).upper()
|
||||
parameters = match_command.group(2)
|
||||
|
||||
commands.append({
|
||||
'command': command,
|
||||
'parameters': parameters
|
||||
})
|
||||
|
||||
return ParsedDockerfile(commands)
|
7
util/dynamic.py
Normal file
7
util/dynamic.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
def import_class(module_name, class_name):
|
||||
""" Import a class given the specified module name and class name. """
|
||||
klass = __import__(module_name)
|
||||
class_segments = class_name.split('.')
|
||||
for segment in class_segments:
|
||||
klass = getattr(klass, segment)
|
||||
return klass
|
84
util/expiresdict.py
Normal file
84
util/expiresdict.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
from datetime import datetime
|
||||
|
||||
from six import iteritems
|
||||
|
||||
|
||||
class ExpiresEntry(object):
|
||||
""" A single entry under a ExpiresDict. """
|
||||
def __init__(self, value, expires=None):
|
||||
self.value = value
|
||||
self._expiration = expires
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
if self._expiration is None:
|
||||
return False
|
||||
|
||||
return datetime.now() >= self._expiration
|
||||
|
||||
|
||||
class ExpiresDict(object):
|
||||
""" ExpiresDict defines a dictionary-like class whose keys have expiration. The rebuilder is
|
||||
a function that returns the full contents of the cached dictionary as a dict of the keys
|
||||
and whose values are TTLEntry's. If the rebuilder is None, then no rebuilding is performed.
|
||||
"""
|
||||
def __init__(self, rebuilder=None):
|
||||
self._rebuilder = rebuilder
|
||||
self._items = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
found = self.get(key)
|
||||
if found is None:
|
||||
raise KeyError
|
||||
|
||||
return found
|
||||
|
||||
def get(self, key, default_value=None):
|
||||
# Check the cache first. If the key is found and it has not yet expired,
|
||||
# return it.
|
||||
found = self._items.get(key)
|
||||
if found is not None and not found.expired:
|
||||
return found.value
|
||||
|
||||
# Otherwise the key has expired or was not found. Rebuild the cache and check it again.
|
||||
items = self._rebuild()
|
||||
found_item = items.get(key)
|
||||
if found_item is None:
|
||||
return default_value
|
||||
|
||||
return found_item.value
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.get(key) is not None
|
||||
|
||||
def _rebuild(self):
|
||||
if self._rebuilder is None:
|
||||
return self._items
|
||||
|
||||
items = self._rebuilder()
|
||||
self._items = items
|
||||
return items
|
||||
|
||||
def _alive_items(self):
|
||||
return {k: entry.value for (k, entry) in self._items.items() if not entry.expired}
|
||||
|
||||
def items(self):
|
||||
return self._alive_items().items()
|
||||
|
||||
def iteritems(self):
|
||||
return iteritems(self._alive_items())
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._alive_items())
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self._items[key]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._alive_items())
|
||||
|
||||
def set(self, key, value, expires=None):
|
||||
self._items[key] = ExpiresEntry(value, expires=expires)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return self.set(key, value)
|
50
util/failover.py
Normal file
50
util/failover.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import logging
|
||||
|
||||
from functools import wraps
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FailoverException(Exception):
|
||||
""" Exception raised when an operation should be retried by the failover decorator.
|
||||
Wraps the exception of the initial failure.
|
||||
"""
|
||||
def __init__(self, exception):
|
||||
super(FailoverException, self).__init__()
|
||||
self.exception = exception
|
||||
|
||||
def failover(func):
|
||||
""" Wraps a function such that it can be retried on specified failures.
|
||||
Raises FailoverException when all failovers are exhausted.
|
||||
Example:
|
||||
|
||||
@failover
|
||||
def get_google(scheme, use_www=False):
|
||||
www = 'www.' if use_www else ''
|
||||
try:
|
||||
r = requests.get(scheme + '://' + www + 'google.com')
|
||||
except requests.RequestException as ex:
|
||||
raise FailoverException(ex)
|
||||
return r
|
||||
|
||||
def GooglePingTest():
|
||||
r = get_google(
|
||||
(('http'), {'use_www': False}),
|
||||
(('http'), {'use_www': True}),
|
||||
(('https'), {'use_www': False}),
|
||||
(('https'), {'use_www': True}),
|
||||
)
|
||||
print('Successfully contacted ' + r.url)
|
||||
"""
|
||||
@wraps(func)
|
||||
def wrapper(*args_sets):
|
||||
for arg_set in args_sets:
|
||||
try:
|
||||
return func(*arg_set[0], **arg_set[1])
|
||||
except FailoverException as ex:
|
||||
logger.debug('failing over')
|
||||
exception = ex.exception
|
||||
continue
|
||||
raise exception
|
||||
return wrapper
|
70
util/fixuseradmin.py
Normal file
70
util/fixuseradmin.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import argparse
|
||||
import sys
|
||||
|
||||
from app import app
|
||||
from data.database import Namespace, Repository, RepositoryPermission, Role
|
||||
from data.model.permission import get_user_repo_permissions
|
||||
from data.model.user import get_active_users, get_nonrobot_user
|
||||
|
||||
DESCRIPTION = '''
|
||||
Fix user repositories missing admin permissions for owning user.
|
||||
'''
|
||||
|
||||
parser = argparse.ArgumentParser(description=DESCRIPTION)
|
||||
parser.add_argument('users', nargs='*', help='Users to check')
|
||||
parser.add_argument('-a', '--all', action='store_true', help='Check all users')
|
||||
parser.add_argument('-n', '--dry-run', action='store_true', help="Don't act")
|
||||
|
||||
ADMIN = Role.get(name='admin')
|
||||
|
||||
|
||||
def repos_for_namespace(namespace):
|
||||
return (Repository
|
||||
.select(Repository.id, Repository.name, Namespace.username)
|
||||
.join(Namespace)
|
||||
.where(Namespace.username == namespace))
|
||||
|
||||
|
||||
def has_admin(user, repo):
|
||||
perms = get_user_repo_permissions(user, repo)
|
||||
return any(p.role == ADMIN for p in perms)
|
||||
|
||||
|
||||
def get_users(all_users=False, users_list=None):
|
||||
if all_users:
|
||||
return get_active_users(disabled=False)
|
||||
|
||||
return map(get_nonrobot_user, users_list)
|
||||
|
||||
|
||||
def ensure_admin(user, repos, dry_run=False):
|
||||
repos = [repo for repo in repos if not has_admin(user, repo)]
|
||||
|
||||
for repo in repos:
|
||||
print('User {} missing admin on: {}'.format(user.username, repo.name))
|
||||
|
||||
if not dry_run:
|
||||
RepositoryPermission.create(user=user, repository=repo, role=ADMIN)
|
||||
print('Granted {} admin on: {}'.format(user.username, repo.name))
|
||||
|
||||
return len(repos)
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
found = 0
|
||||
|
||||
if not args.all and len(args.users) == 0:
|
||||
sys.exit('Need a list of users or --all')
|
||||
|
||||
for user in get_users(all_users=args.all, users_list=args.users):
|
||||
if user is not None:
|
||||
repos = repos_for_namespace(user.username)
|
||||
found += ensure_admin(user, repos, dry_run=args.dry_run)
|
||||
|
||||
print('\nFound {} user repos missing admin'
|
||||
' permissions for owner.'.format(found))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
56
util/generatepresharedkey.py
Normal file
56
util/generatepresharedkey.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
import argparse
|
||||
|
||||
from dateutil.parser import parse as parse_date
|
||||
|
||||
from app import app
|
||||
from data import model
|
||||
from data.database import ServiceKeyApprovalType
|
||||
from data.logs_model import logs_model
|
||||
|
||||
|
||||
def generate_key(service, name, expiration_date=None, notes=None):
|
||||
metadata = {
|
||||
'created_by': 'CLI tool',
|
||||
}
|
||||
|
||||
# Generate a key with a private key that we *never save*.
|
||||
(private_key, key) = model.service_keys.generate_service_key(service, expiration_date,
|
||||
metadata=metadata,
|
||||
name=name)
|
||||
# Auto-approve the service key.
|
||||
model.service_keys.approve_service_key(key.kid, ServiceKeyApprovalType.AUTOMATIC, notes=notes or '')
|
||||
|
||||
# Log the creation and auto-approval of the service key.
|
||||
key_log_metadata = {
|
||||
'kid': key.kid,
|
||||
'preshared': True,
|
||||
'service': service,
|
||||
'name': name,
|
||||
'expiration_date': expiration_date,
|
||||
'auto_approved': True,
|
||||
}
|
||||
|
||||
logs_model.log_action('service_key_create', metadata=key_log_metadata)
|
||||
logs_model.log_action('service_key_approve', metadata=key_log_metadata)
|
||||
return private_key, key.kid
|
||||
|
||||
|
||||
def valid_date(s):
|
||||
try:
|
||||
return parse_date(s)
|
||||
except ValueError:
|
||||
msg = "Not a valid date: '{0}'.".format(s)
|
||||
raise argparse.ArgumentTypeError(msg)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Generates a preshared key')
|
||||
parser.add_argument('service', help='The service name for which the key is being generated')
|
||||
parser.add_argument('name', help='The friendly name for the key')
|
||||
parser.add_argument('--expiration', default=None, type=valid_date,
|
||||
help='The optional expiration date for the key')
|
||||
parser.add_argument('--notes', help='Optional notes about the key', default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
generated, _ = generate_key(args.service, args.name, args.expiration, args.notes)
|
||||
print generated.exportKey('PEM')
|
20
util/headers.py
Normal file
20
util/headers.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import base64
|
||||
|
||||
def parse_basic_auth(header_value):
|
||||
""" Attempts to parse the given header value as a Base64-encoded Basic auth header. """
|
||||
|
||||
if not header_value:
|
||||
return None
|
||||
|
||||
parts = header_value.split(' ')
|
||||
if len(parts) != 2 or parts[0].lower() != 'basic':
|
||||
return None
|
||||
|
||||
try:
|
||||
basic_parts = base64.b64decode(parts[1]).split(':', 1)
|
||||
if len(basic_parts) != 2:
|
||||
return None
|
||||
|
||||
return basic_parts
|
||||
except ValueError:
|
||||
return None
|
60
util/html.py
Normal file
60
util/html.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from bs4 import BeautifulSoup, Tag, NavigableString
|
||||
|
||||
_NEWLINE_INDICATOR = '<<<newline>>>'
|
||||
|
||||
def _bold(elem):
|
||||
elem.replace_with('*%s*' % elem.text)
|
||||
|
||||
def _unordered_list(elem):
|
||||
constructed = ''
|
||||
for child in elem.children:
|
||||
if child.name == 'li':
|
||||
constructed += '* %s\n' % child.text
|
||||
elem.replace_with(constructed)
|
||||
|
||||
def _horizontal_rule(elem):
|
||||
elem.replace_with('%s\n' % ('-' * 80))
|
||||
|
||||
def _anchor(elem):
|
||||
elem.replace_with('[%s](%s)' % (elem.text, elem['href']))
|
||||
|
||||
def _table(elem):
|
||||
elem.replace_with('%s%s' % (elem.text, _NEWLINE_INDICATOR))
|
||||
|
||||
|
||||
_ELEMENT_REPLACER = {
|
||||
'b': _bold,
|
||||
'strong': _bold,
|
||||
'ul': _unordered_list,
|
||||
'hr': _horizontal_rule,
|
||||
'a': _anchor,
|
||||
'table': _table,
|
||||
}
|
||||
|
||||
def _collapse_whitespace(text):
|
||||
new_lines = []
|
||||
lines = text.split('\n')
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
new_lines.append(line.strip().replace(_NEWLINE_INDICATOR, '\n'))
|
||||
|
||||
return '\n'.join(new_lines)
|
||||
|
||||
def html2text(html):
|
||||
soup = BeautifulSoup(html, 'html5lib')
|
||||
_html2text(soup)
|
||||
return _collapse_whitespace(soup.text)
|
||||
|
||||
def _html2text(elem):
|
||||
for child in elem.children:
|
||||
if isinstance(child, Tag):
|
||||
_html2text(child)
|
||||
elif isinstance(child, NavigableString):
|
||||
# No changes necessary
|
||||
continue
|
||||
|
||||
if elem.parent:
|
||||
if elem.name in _ELEMENT_REPLACER:
|
||||
_ELEMENT_REPLACER[elem.name](elem)
|
82
util/http.py
Normal file
82
util/http.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import logging
|
||||
import json
|
||||
|
||||
from flask import request, make_response, current_app
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from app import analytics
|
||||
from auth.auth_context import get_authenticated_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_MESSAGE = {}
|
||||
DEFAULT_MESSAGE[400] = 'Invalid Request'
|
||||
DEFAULT_MESSAGE[401] = 'Unauthorized'
|
||||
DEFAULT_MESSAGE[403] = 'Permission Denied'
|
||||
DEFAULT_MESSAGE[404] = 'Not Found'
|
||||
DEFAULT_MESSAGE[409] = 'Conflict'
|
||||
DEFAULT_MESSAGE[501] = 'Not Implemented'
|
||||
|
||||
|
||||
def _abort(status_code, data_object, description, headers):
|
||||
# Add CORS headers to all errors
|
||||
options_resp = current_app.make_default_options_response()
|
||||
headers['Access-Control-Allow-Origin'] = '*'
|
||||
headers['Access-Control-Allow-Methods'] = options_resp.headers['allow']
|
||||
headers['Access-Control-Max-Age'] = str(21600)
|
||||
headers['Access-Control-Allow-Headers'] = ['Authorization', 'Content-Type']
|
||||
|
||||
resp = make_response(json.dumps(data_object), status_code, headers)
|
||||
|
||||
# Report the abort to the user.
|
||||
# Raising HTTPException as workaround for https://github.com/pallets/werkzeug/issues/1098
|
||||
new_exception = HTTPException(response=resp, description=description)
|
||||
new_exception.code = status_code
|
||||
raise new_exception
|
||||
|
||||
|
||||
def exact_abort(status_code, message=None):
|
||||
data = {}
|
||||
|
||||
if message is not None:
|
||||
data['error'] = message
|
||||
|
||||
_abort(status_code, data, message or None, {})
|
||||
|
||||
|
||||
def abort(status_code, message=None, issue=None, headers=None, **kwargs):
|
||||
message = (str(message) % kwargs if message else
|
||||
DEFAULT_MESSAGE.get(status_code, ''))
|
||||
|
||||
params = dict(request.view_args or {})
|
||||
params.update(kwargs)
|
||||
|
||||
params['url'] = request.url
|
||||
params['status_code'] = status_code
|
||||
params['message'] = message
|
||||
|
||||
# Add the user information.
|
||||
auth_context = get_authenticated_context()
|
||||
if auth_context is not None:
|
||||
message = '%s (authorized: %s)' % (message, auth_context.description)
|
||||
|
||||
# Log the abort.
|
||||
logger.error('Error %s: %s; Arguments: %s' % (status_code, message, params))
|
||||
|
||||
# Calculate the issue URL (if the issue ID was supplied).
|
||||
issue_url = None
|
||||
if issue:
|
||||
issue_url = 'http://docs.quay.io/issues/%s.html' % (issue)
|
||||
|
||||
# Create the final response data and message.
|
||||
data = {}
|
||||
data['error'] = message
|
||||
|
||||
if issue_url:
|
||||
data['info_url'] = issue_url
|
||||
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
_abort(status_code, data, message, headers)
|
60
util/invoice.py
Normal file
60
util/invoice.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from datetime import datetime
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from xhtml2pdf import pisa
|
||||
|
||||
import StringIO
|
||||
|
||||
from app import app
|
||||
|
||||
jinja_options = {
|
||||
"loader": FileSystemLoader('util'),
|
||||
}
|
||||
|
||||
env = Environment(**jinja_options)
|
||||
|
||||
|
||||
def renderInvoiceToPdf(invoice, user):
|
||||
""" Renders a nice PDF display for the given invoice. """
|
||||
sourceHtml = renderInvoiceToHtml(invoice, user)
|
||||
output = StringIO.StringIO()
|
||||
pisaStatus = pisa.CreatePDF(sourceHtml, dest=output)
|
||||
if pisaStatus.err:
|
||||
return None
|
||||
|
||||
value = output.getvalue()
|
||||
output.close()
|
||||
return value
|
||||
|
||||
|
||||
def renderInvoiceToHtml(invoice, user):
|
||||
""" Renders a nice HTML display for the given invoice. """
|
||||
from endpoints.api.billing import get_invoice_fields
|
||||
|
||||
def get_price(price):
|
||||
if not price:
|
||||
return '$0'
|
||||
|
||||
return '$' + '{0:.2f}'.format(float(price) / 100)
|
||||
|
||||
def get_range(line):
|
||||
if line.period and line.period.start and line.period.end:
|
||||
return ': ' + format_date(line.period.start) + ' - ' + format_date(line.period.end)
|
||||
return ''
|
||||
|
||||
def format_date(timestamp):
|
||||
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d')
|
||||
|
||||
app_logo = app.config.get('ENTERPRISE_LOGO_URL', 'https://quay.io/static/img/quay-logo.png')
|
||||
data = {
|
||||
'user': user.username,
|
||||
'invoice': invoice,
|
||||
'invoice_date': format_date(invoice.date),
|
||||
'getPrice': get_price,
|
||||
'getRange': get_range,
|
||||
'custom_fields': get_invoice_fields(user)[0],
|
||||
'logo': app_logo,
|
||||
}
|
||||
|
||||
template = env.get_template('invoice.tmpl')
|
||||
rendered = template.render(data)
|
||||
return rendered
|
69
util/invoice.tmpl
Normal file
69
util/invoice.tmpl
Normal file
|
@ -0,0 +1,69 @@
|
|||
<html>
|
||||
<body>
|
||||
<table width="100%" style="max-width: 640px">
|
||||
<tr>
|
||||
<td valign="center" style="padding: 10px;">
|
||||
<img src="{{ logo }}" alt="Quay" style="width: 100px;">
|
||||
</td>
|
||||
<td valign="center">
|
||||
<h3>Quay</h3>
|
||||
<p style="font-size: 12px; -webkit-text-adjust: none">
|
||||
Red Hat, Inc<br>
|
||||
https://redhat.com<br>
|
||||
100 East Davie Street<br>
|
||||
Raleigh, North Carolina 27601
|
||||
</p>
|
||||
</td>
|
||||
<td align="right" width="100%">
|
||||
<h1 style="color: #ddd;">RECEIPT</h1>
|
||||
<table>
|
||||
<tr><td>Date:</td><td>{{ invoice_date }}</td></tr>
|
||||
<tr><td>Invoice #:</td><td style="font-size: 10px">{{ invoice.id }}</td></tr>
|
||||
{% for custom_field in custom_fields %}
|
||||
<tr>
|
||||
<td>*{{ custom_field['title'] }}:</td>
|
||||
<td style="font-size: 10px">{{ custom_field['value'] }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<hr>
|
||||
|
||||
<table width="100%" style="max-width: 640px">
|
||||
<thead>
|
||||
<th style="padding: 4px; background: #eee; text-align: center; font-weight: bold">Description</th>
|
||||
<th style="padding: 4px; background: #eee; text-align: center; font-weight: bold">Line Total</th>
|
||||
</thead>
|
||||
<tbody>
|
||||
{%- for line in invoice.lines.data -%}
|
||||
<tr>
|
||||
<td width="100%" style="padding: 4px;">{{ line.description or ('Plan Subscription' + getRange(line)) }}</td>
|
||||
<td style="padding: 4px; min-width: 150px;">{{ getPrice(line.amount) }}</td>
|
||||
</tr>
|
||||
{%- endfor -%}
|
||||
|
||||
|
||||
<tr>
|
||||
<td></td>
|
||||
<td valign="right">
|
||||
<table>
|
||||
<tr><td><b>Subtotal: </b></td><td>{{ getPrice(invoice.subtotal) }}</td></tr>
|
||||
<tr><td><b>Total: </b></td><td>{{ getPrice(invoice.total) }}</td></tr>
|
||||
<tr><td><b>Paid: </b></td><td>{{ getPrice(invoice.total) if invoice.paid else 0 }}</td></tr>
|
||||
<tr><td><b>Total Due:</b></td>
|
||||
<td>{{ getPrice(invoice.ending_balance) }}</td></tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<div style="margin: 6px; padding: 6px; width: 100%; max-width: 640px; border-top: 2px solid #eee; text-align: center; font-size: 14px; -webkit-text-adjust: none; font-weight: bold;">
|
||||
We thank you for your continued business!
|
||||
</div>
|
||||
|
||||
</body>
|
||||
</html>
|
BIN
util/ipresolver/GeoLite2-Country.mmdb
Normal file
BIN
util/ipresolver/GeoLite2-Country.mmdb
Normal file
Binary file not shown.
153
util/ipresolver/__init__.py
Normal file
153
util/ipresolver/__init__.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
import logging
|
||||
import json
|
||||
import time
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
from threading import Thread, Lock
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from six import add_metaclass
|
||||
from cachetools.func import ttl_cache, lru_cache
|
||||
from netaddr import IPNetwork, IPAddress, IPSet, AddrFormatError
|
||||
|
||||
import geoip2.database
|
||||
import geoip2.errors
|
||||
import requests
|
||||
|
||||
from util.abchelpers import nooper
|
||||
|
||||
ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'service', 'sync_token',
|
||||
'country_iso_code'])
|
||||
|
||||
AWS_SERVICES = {'EC2', 'CODEBUILD'}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_aws_ip_ranges():
|
||||
try:
|
||||
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
|
||||
return json.loads(f.read())
|
||||
except IOError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
except ValueError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
except TypeError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
|
||||
|
||||
@add_metaclass(ABCMeta)
|
||||
class IPResolverInterface(object):
|
||||
""" Helper class for resolving information about an IP address. """
|
||||
@abstractmethod
|
||||
def resolve_ip(self, ip_address):
|
||||
""" Attempts to return resolved information about the specified IP Address. If such an attempt
|
||||
fails, returns None.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_ip_possible_threat(self, ip_address):
|
||||
""" Attempts to return whether the given IP address is a possible abuser or spammer.
|
||||
Returns False if the IP address information could not be looked up.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@nooper
|
||||
class NoopIPResolver(IPResolverInterface):
|
||||
""" No-op version of the security scanner API. """
|
||||
pass
|
||||
|
||||
|
||||
class IPResolver(IPResolverInterface):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')
|
||||
self.amazon_ranges = None
|
||||
self.sync_token = None
|
||||
|
||||
logger.info('Loading AWS IP ranges from disk')
|
||||
aws_ip_ranges_data = _get_aws_ip_ranges()
|
||||
if aws_ip_ranges_data is not None and aws_ip_ranges_data.get('syncToken'):
|
||||
logger.info('Building AWS IP ranges')
|
||||
self.amazon_ranges = IPResolver._parse_amazon_ranges(aws_ip_ranges_data)
|
||||
self.sync_token = aws_ip_ranges_data['syncToken']
|
||||
logger.info('Finished building AWS IP ranges')
|
||||
|
||||
@ttl_cache(maxsize=100, ttl=600)
|
||||
def is_ip_possible_threat(self, ip_address):
|
||||
if self.app.config.get('THREAT_NAMESPACE_MAXIMUM_BUILD_COUNT') is None:
|
||||
return False
|
||||
|
||||
if self.app.config.get('IP_DATA_API_KEY') is None:
|
||||
return False
|
||||
|
||||
if not ip_address:
|
||||
return False
|
||||
|
||||
api_key = self.app.config['IP_DATA_API_KEY']
|
||||
|
||||
try:
|
||||
logger.debug('Requesting IP data for IP %s', ip_address)
|
||||
r = requests.get('https://api.ipdata.co/%s/threat?api-key=%s' % (ip_address, api_key),
|
||||
timeout=1)
|
||||
if r.status_code != 200:
|
||||
logger.debug('Got non-200 response for IP %s: %s', ip_address, r.status_code)
|
||||
return False
|
||||
|
||||
logger.debug('Got IP data for IP %s: %s => %s', ip_address, r.status_code, r.json())
|
||||
threat_data = r.json()
|
||||
return threat_data.get('is_threat', False) or threat_data.get('is_bogon', False)
|
||||
except requests.RequestException:
|
||||
logger.exception('Got exception when trying to lookup IP Address')
|
||||
except ValueError:
|
||||
logger.exception('Got exception when trying to lookup IP Address')
|
||||
except Exception:
|
||||
logger.exception('Got exception when trying to lookup IP Address')
|
||||
|
||||
return False
|
||||
|
||||
def resolve_ip(self, ip_address):
|
||||
""" Attempts to return resolved information about the specified IP Address. If such an attempt
|
||||
fails, returns None.
|
||||
"""
|
||||
if not ip_address:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed_ip = IPAddress(ip_address)
|
||||
except AddrFormatError:
|
||||
return ResolvedLocation('invalid_ip', None, self.sync_token, None)
|
||||
|
||||
# Try geoip classification
|
||||
try:
|
||||
geoinfo = self.geoip_db.country(ip_address)
|
||||
except geoip2.errors.AddressNotFoundError:
|
||||
geoinfo = None
|
||||
|
||||
if self.amazon_ranges is None or parsed_ip not in self.amazon_ranges:
|
||||
if geoinfo:
|
||||
return ResolvedLocation(
|
||||
'internet',
|
||||
geoinfo.country.iso_code,
|
||||
self.sync_token,
|
||||
geoinfo.country.iso_code,
|
||||
)
|
||||
|
||||
return ResolvedLocation('internet', None, self.sync_token, None)
|
||||
|
||||
return ResolvedLocation('aws', None, self.sync_token,
|
||||
geoinfo.country.iso_code if geoinfo else None)
|
||||
|
||||
@staticmethod
|
||||
def _parse_amazon_ranges(ranges):
|
||||
all_amazon = IPSet()
|
||||
for service_description in ranges['prefixes']:
|
||||
if service_description['service'] in AWS_SERVICES:
|
||||
all_amazon.add(IPNetwork(service_description['ip_prefix']))
|
||||
|
||||
return all_amazon
|
0
util/ipresolver/test/__init__.py
Normal file
0
util/ipresolver/test/__init__.py
Normal file
51
util/ipresolver/test/test_ipresolver.py
Normal file
51
util/ipresolver/test/test_ipresolver.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import pytest
|
||||
|
||||
from mock import patch
|
||||
|
||||
from util.ipresolver import IPResolver, ResolvedLocation
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.fixture()
|
||||
def test_aws_ip():
|
||||
return '10.0.0.1'
|
||||
|
||||
@pytest.fixture()
|
||||
def aws_ip_range_data():
|
||||
fake_range_doc = {
|
||||
'syncToken': 123456789,
|
||||
'prefixes': [
|
||||
{
|
||||
'ip_prefix': '10.0.0.0/8',
|
||||
'region': 'GLOBAL',
|
||||
'service': 'EC2',
|
||||
},
|
||||
{
|
||||
'ip_prefix': '6.0.0.0/8',
|
||||
'region': 'GLOBAL',
|
||||
'service': 'EC2',
|
||||
},
|
||||
],
|
||||
}
|
||||
return fake_range_doc
|
||||
|
||||
@pytest.fixture()
|
||||
def test_ip_range_cache(aws_ip_range_data):
|
||||
sync_token = aws_ip_range_data['syncToken']
|
||||
all_amazon = IPResolver._parse_amazon_ranges(aws_ip_range_data)
|
||||
fake_cache = {
|
||||
'sync_token': sync_token,
|
||||
'all_amazon': all_amazon,
|
||||
}
|
||||
return fake_cache
|
||||
|
||||
|
||||
def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
|
||||
ipresolver = IPResolver(app)
|
||||
ipresolver.amazon_ranges = test_ip_range_cache['all_amazon']
|
||||
ipresolver.sync_token = test_ip_range_cache['sync_token']
|
||||
|
||||
assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', service=None, sync_token=123456789, country_iso_code=None)
|
||||
assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', service=None, sync_token=123456789, country_iso_code=None)
|
||||
assert ipresolver.resolve_ip('6.0.0.2') == ResolvedLocation(provider='aws', service=None, sync_token=123456789, country_iso_code=u'US')
|
||||
assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', service=u'US', sync_token=123456789, country_iso_code=u'US')
|
||||
assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', service=None, sync_token=123456789, country_iso_code=None)
|
6
util/itertoolrecipes.py
Normal file
6
util/itertoolrecipes.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from itertools import islice
|
||||
|
||||
# From: https://docs.python.org/2/library/itertools.html
|
||||
def take(n, iterable):
|
||||
""" Return first n items of the iterable as a list """
|
||||
return list(islice(iterable, n))
|
99
util/jinjautil.py
Normal file
99
util/jinjautil.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
from app import get_app_url, avatar
|
||||
from data import model
|
||||
from util.names import parse_robot_username
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
|
||||
def icon_path(icon_name):
|
||||
return '%s/static/img/icons/%s.png' % (get_app_url(), icon_name)
|
||||
|
||||
def icon_image(icon_name):
|
||||
return '<img src="%s" alt="%s">' % (icon_path(icon_name), icon_name)
|
||||
|
||||
def team_reference(teamname):
|
||||
avatar_html = avatar.get_mail_html(teamname, teamname, 24, 'team')
|
||||
return "<span>%s <b>%s</b></span>" % (avatar_html, teamname)
|
||||
|
||||
def user_reference(username):
|
||||
user = model.user.get_namespace_user(username)
|
||||
if not user:
|
||||
return username
|
||||
|
||||
if user.robot:
|
||||
parts = parse_robot_username(username)
|
||||
user = model.user.get_namespace_user(parts[0])
|
||||
|
||||
return """<span><img src="%s" alt="Robot"> <b>%s</b></span>""" % (icon_path('wrench'), username)
|
||||
|
||||
avatar_html = avatar.get_mail_html(user.username, user.email, 24,
|
||||
'org' if user.organization else 'user')
|
||||
|
||||
return """
|
||||
<span>
|
||||
%s
|
||||
<b>%s</b>
|
||||
</span>""" % (avatar_html, username)
|
||||
|
||||
|
||||
def repository_tag_reference(repository_path_and_tag):
|
||||
(repository_path, tag) = repository_path_and_tag
|
||||
(namespace, repository) = repository_path.split('/')
|
||||
owner = model.user.get_namespace_user(namespace)
|
||||
if not owner:
|
||||
return tag
|
||||
|
||||
return """<a href="%s/repository/%s/%s?tag=%s&tab=tags">%s</a>""" % (get_app_url(), namespace,
|
||||
repository, tag, tag)
|
||||
|
||||
def repository_reference(pair):
|
||||
if isinstance(pair, tuple):
|
||||
(namespace, repository) = pair
|
||||
else:
|
||||
pair = pair.split('/')
|
||||
namespace = pair[0]
|
||||
repository = pair[1]
|
||||
|
||||
owner = model.user.get_namespace_user(namespace)
|
||||
if not owner:
|
||||
return "%s/%s" % (namespace, repository)
|
||||
|
||||
avatar_html = avatar.get_mail_html(owner.username, owner.email, 16,
|
||||
'org' if owner.organization else 'user')
|
||||
|
||||
return """
|
||||
<span style="white-space: nowrap;">
|
||||
%s
|
||||
<a href="%s/repository/%s/%s">%s/%s</a>
|
||||
</span>
|
||||
""" % (avatar_html, get_app_url(), namespace, repository, namespace, repository)
|
||||
|
||||
|
||||
def admin_reference(username):
|
||||
user = model.user.get_user_or_org(username)
|
||||
if not user:
|
||||
return 'account settings'
|
||||
|
||||
if user.organization:
|
||||
return """
|
||||
<a href="%s/organization/%s?tab=settings">organization's admin setting</a>
|
||||
""" % (get_app_url(), username)
|
||||
else:
|
||||
return """
|
||||
<a href="%s/user/">account settings</a>
|
||||
""" % (get_app_url())
|
||||
|
||||
|
||||
def get_template_env(searchpath):
|
||||
template_loader = FileSystemLoader(searchpath=searchpath)
|
||||
template_env = Environment(loader=template_loader)
|
||||
add_filters(template_env)
|
||||
return template_env
|
||||
|
||||
|
||||
def add_filters(template_env):
|
||||
template_env.filters['icon_image'] = icon_image
|
||||
template_env.filters['team_reference'] = team_reference
|
||||
template_env.filters['user_reference'] = user_reference
|
||||
template_env.filters['admin_reference'] = admin_reference
|
||||
template_env.filters['repository_reference'] = repository_reference
|
||||
template_env.filters['repository_tag_reference'] = repository_tag_reference
|
20
util/label_validator.py
Normal file
20
util/label_validator.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
class LabelValidator(object):
|
||||
""" Helper class for validating that labels meet prefix requirements. """
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
overridden_prefixes = app.config.get('LABEL_KEY_RESERVED_PREFIXES', [])
|
||||
for prefix in overridden_prefixes:
|
||||
if not prefix.endswith('.'):
|
||||
raise Exception('Prefix "%s" in LABEL_KEY_RESERVED_PREFIXES must end in a dot', prefix)
|
||||
|
||||
default_prefixes = app.config.get('DEFAULT_LABEL_KEY_RESERVED_PREFIXES', [])
|
||||
self.reserved_prefixed_set = set(default_prefixes + overridden_prefixes)
|
||||
|
||||
def has_reserved_prefix(self, label_key):
|
||||
""" Validates that the provided label key does not match any reserved prefixes. """
|
||||
for prefix in self.reserved_prefixed_set:
|
||||
if label_key.startswith(prefix):
|
||||
return True
|
||||
|
||||
return False
|
66
util/locking.py
Normal file
66
util/locking.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
import logging
|
||||
|
||||
from redis import RedisError
|
||||
from redlock import RedLock, RedLockError
|
||||
|
||||
from app import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LockNotAcquiredException(Exception):
|
||||
""" Exception raised if a GlobalLock could not be acquired. """
|
||||
|
||||
|
||||
class GlobalLock(object):
|
||||
""" A lock object that blocks globally via Redis. Note that Redis is not considered a tier-1
|
||||
service, so this lock should not be used for any critical code paths.
|
||||
"""
|
||||
def __init__(self, name, lock_ttl=600):
|
||||
self._lock_name = name
|
||||
self._redis_info = dict(app.config['USER_EVENTS_REDIS'])
|
||||
self._redis_info.update({'socket_connect_timeout': 5,
|
||||
'socket_timeout': 5,
|
||||
'single_connection_client': True})
|
||||
|
||||
self._lock_ttl = lock_ttl
|
||||
self._redlock = None
|
||||
|
||||
def __enter__(self):
|
||||
if not self.acquire():
|
||||
raise LockNotAcquiredException()
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.release()
|
||||
|
||||
def acquire(self):
|
||||
logger.debug('Acquiring global lock %s', self._lock_name)
|
||||
try:
|
||||
self._redlock = RedLock(self._lock_name, connection_details=[self._redis_info],
|
||||
ttl=self._lock_ttl)
|
||||
acquired = self._redlock.acquire()
|
||||
if not acquired:
|
||||
logger.debug('Was unable to not acquire lock %s', self._lock_name)
|
||||
return False
|
||||
|
||||
logger.debug('Acquired lock %s', self._lock_name)
|
||||
return True
|
||||
except RedLockError:
|
||||
logger.debug('Could not acquire lock %s', self._lock_name)
|
||||
return False
|
||||
except RedisError as re:
|
||||
logger.debug('Could not connect to Redis for lock %s: %s', self._lock_name, re)
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
if self._redlock is not None:
|
||||
logger.debug('Releasing lock %s', self._lock_name)
|
||||
try:
|
||||
self._redlock.release()
|
||||
except RedLockError:
|
||||
logger.debug('Could not release lock %s', self._lock_name)
|
||||
except RedisError as re:
|
||||
logger.debug('Could not connect to Redis for releasing lock %s: %s', self._lock_name, re)
|
||||
|
||||
logger.debug('Released lock %s', self._lock_name)
|
||||
self._redlock = None
|
47
util/log.py
Normal file
47
util/log.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
import os
|
||||
from _init import CONF_DIR
|
||||
|
||||
|
||||
def logfile_path(jsonfmt=False, debug=False):
|
||||
"""
|
||||
Returns the a logfileconf path following this rules:
|
||||
- conf/logging_debug_json.conf # jsonfmt=true, debug=true
|
||||
- conf/logging_json.conf # jsonfmt=true, debug=false
|
||||
- conf/logging_debug.conf # jsonfmt=false, debug=true
|
||||
- conf/logging.conf # jsonfmt=false, debug=false
|
||||
Can be parametrized via envvars: JSONLOG=true, DEBUGLOG=true
|
||||
"""
|
||||
_json = ""
|
||||
_debug = ""
|
||||
|
||||
if jsonfmt or os.getenv('JSONLOG', 'false').lower() == 'true':
|
||||
_json = "_json"
|
||||
|
||||
if debug or os.getenv('DEBUGLOG', 'false').lower() == 'true':
|
||||
_debug = "_debug"
|
||||
|
||||
return os.path.join(CONF_DIR, "logging%s%s.conf" % (_debug, _json))
|
||||
|
||||
|
||||
def filter_logs(values, filtered_fields):
|
||||
"""
|
||||
Takes a dict and a list of keys to filter.
|
||||
eg:
|
||||
with filtered_fields:
|
||||
[{'key': ['k1', k2'], 'fn': lambda x: 'filtered'}]
|
||||
and values:
|
||||
{'k1': {'k2': 'some-secret'}, 'k3': 'some-value'}
|
||||
the returned dict is:
|
||||
{'k1': {k2: 'filtered'}, 'k3': 'some-value'}
|
||||
"""
|
||||
for field in filtered_fields:
|
||||
cdict = values
|
||||
|
||||
for key in field['key'][:-1]:
|
||||
if key in cdict:
|
||||
cdict = cdict[key]
|
||||
|
||||
last_key = field['key'][-1]
|
||||
|
||||
if last_key in cdict and cdict[last_key]:
|
||||
cdict[last_key] = field['fn'](cdict[last_key])
|
0
util/metrics/__init__.py
Normal file
0
util/metrics/__init__.py
Normal file
210
util/metrics/metricqueue.py
Normal file
210
util/metrics/metricqueue.py
Normal file
|
@ -0,0 +1,210 @@
|
|||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
from functools import wraps
|
||||
from Queue import Queue, Full
|
||||
|
||||
from flask import g, request
|
||||
from trollius import Return
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Buckets for the API response times.
|
||||
API_RESPONSE_TIME_BUCKETS = [.01, .025, .05, .1, .25, .5, 1.0, 2.5, 5.0]
|
||||
|
||||
# Buckets for the builder start times.
|
||||
BUILDER_START_TIME_BUCKETS = [.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0, 180.0, 240.0, 300.0, 600.0]
|
||||
|
||||
|
||||
class MetricQueue(object):
|
||||
""" Object to which various metrics are written, for distribution to metrics collection
|
||||
system(s) such as Prometheus.
|
||||
"""
|
||||
def __init__(self, prom):
|
||||
# Define the various exported metrics.
|
||||
self.resp_time = prom.create_histogram('response_time', 'HTTP response time in seconds',
|
||||
labelnames=['endpoint'],
|
||||
buckets=API_RESPONSE_TIME_BUCKETS)
|
||||
self.resp_code = prom.create_counter('response_code', 'HTTP response code',
|
||||
labelnames=['endpoint', 'code'])
|
||||
self.non_200 = prom.create_counter('response_non200', 'Non-200 HTTP response codes',
|
||||
labelnames=['endpoint'])
|
||||
self.error_500 = prom.create_counter('response_500', '5XX HTTP response codes',
|
||||
labelnames=['endpoint'])
|
||||
self.multipart_upload_start = prom.create_counter('multipart_upload_start',
|
||||
'Multipart upload started')
|
||||
self.multipart_upload_end = prom.create_counter('multipart_upload_end',
|
||||
'Multipart upload ends.', labelnames=['type'])
|
||||
self.build_capacity_shortage = prom.create_gauge('build_capacity_shortage',
|
||||
'Build capacity shortage.')
|
||||
self.builder_time_to_start = prom.create_histogram('builder_tts',
|
||||
'Time from triggering to starting a builder.',
|
||||
labelnames=['builder_type'],
|
||||
buckets=BUILDER_START_TIME_BUCKETS)
|
||||
self.builder_time_to_build = prom.create_histogram('builder_ttb',
|
||||
'Time from triggering to actually starting a build',
|
||||
labelnames=['builder_type'],
|
||||
buckets=BUILDER_START_TIME_BUCKETS)
|
||||
self.build_time = prom.create_histogram('build_time', 'Time spent building', labelnames=['builder_type'])
|
||||
self.builder_fallback = prom.create_counter('builder_fallback', 'Builder fell back to secondary executor')
|
||||
self.build_start_success = prom.create_counter('build_start_success', 'Executor succeeded in starting a build', labelnames=['builder_type'])
|
||||
self.build_start_failure = prom.create_counter('build_start_failure', 'Executor failed to start a build', labelnames=['builder_type'])
|
||||
self.percent_building = prom.create_gauge('build_percent_building', 'Percent building.')
|
||||
self.build_counter = prom.create_counter('builds', 'Number of builds', labelnames=['name'])
|
||||
self.ephemeral_build_workers = prom.create_counter('ephemeral_build_workers',
|
||||
'Number of started ephemeral build workers')
|
||||
self.ephemeral_build_worker_failure = prom.create_counter('ephemeral_build_worker_failure',
|
||||
'Number of failed-to-start ephemeral build workers')
|
||||
|
||||
self.work_queue_running = prom.create_gauge('work_queue_running', 'Running items in a queue',
|
||||
labelnames=['queue_name'])
|
||||
self.work_queue_available = prom.create_gauge('work_queue_available',
|
||||
'Available items in a queue',
|
||||
labelnames=['queue_name'])
|
||||
|
||||
self.work_queue_available_not_running = prom.create_gauge('work_queue_available_not_running',
|
||||
'Available items that are not yet running',
|
||||
labelnames=['queue_name'])
|
||||
|
||||
self.repository_pull = prom.create_counter('repository_pull', 'Repository Pull Count',
|
||||
labelnames=['namespace', 'repo_name', 'protocol',
|
||||
'status'])
|
||||
|
||||
self.repository_push = prom.create_counter('repository_push', 'Repository Push Count',
|
||||
labelnames=['namespace', 'repo_name', 'protocol',
|
||||
'status'])
|
||||
|
||||
self.repository_build_queued = prom.create_counter('repository_build_queued',
|
||||
'Repository Build Queued Count',
|
||||
labelnames=['namespace', 'repo_name'])
|
||||
|
||||
self.repository_build_completed = prom.create_counter('repository_build_completed',
|
||||
'Repository Build Complete Count',
|
||||
labelnames=['namespace', 'repo_name',
|
||||
'status', 'executor'])
|
||||
|
||||
self.chunk_size = prom.create_histogram('chunk_size',
|
||||
'Registry blob chunk size',
|
||||
labelnames=['storage_region'])
|
||||
|
||||
self.chunk_upload_time = prom.create_histogram('chunk_upload_time',
|
||||
'Registry blob chunk upload time',
|
||||
labelnames=['storage_region'])
|
||||
|
||||
self.authentication_count = prom.create_counter('authentication_count',
|
||||
'Authentication count',
|
||||
labelnames=['kind', 'status'])
|
||||
|
||||
self.repository_count = prom.create_gauge('repository_count', 'Number of repositories')
|
||||
self.user_count = prom.create_gauge('user_count', 'Number of users')
|
||||
self.org_count = prom.create_gauge('org_count', 'Number of Organizations')
|
||||
self.robot_count = prom.create_gauge('robot_count', 'Number of robot accounts')
|
||||
|
||||
self.instance_key_renewal_success = prom.create_counter('instance_key_renewal_success',
|
||||
'Instance Key Renewal Success Count',
|
||||
labelnames=['key_id'])
|
||||
|
||||
self.instance_key_renewal_failure = prom.create_counter('instance_key_renewal_failure',
|
||||
'Instance Key Renewal Failure Count',
|
||||
labelnames=['key_id'])
|
||||
|
||||
self.invalid_instance_key_count = prom.create_counter('invalid_registry_instance_key_count',
|
||||
'Invalid registry instance key count',
|
||||
labelnames=['key_id'])
|
||||
|
||||
self.verb_action_passes = prom.create_counter('verb_action_passes', 'Verb Pass Count',
|
||||
labelnames=['kind', 'pass_count'])
|
||||
|
||||
self.push_byte_count = prom.create_counter('registry_push_byte_count',
|
||||
'Number of bytes pushed to the registry')
|
||||
|
||||
self.pull_byte_count = prom.create_counter('estimated_registry_pull_byte_count',
|
||||
'Number of (estimated) bytes pulled from the registry',
|
||||
labelnames=['protocol_version'])
|
||||
|
||||
# Deprecated: Define an in-memory queue for reporting metrics to CloudWatch or another
|
||||
# provider.
|
||||
self._queue = None
|
||||
|
||||
def enable_deprecated(self, maxsize=10000):
|
||||
self._queue = Queue(maxsize)
|
||||
|
||||
def put_deprecated(self, name, value, **kwargs):
|
||||
if self._queue is None:
|
||||
logger.debug('No metric queue %s %s %s', name, value, kwargs)
|
||||
return
|
||||
|
||||
try:
|
||||
kwargs.setdefault('timestamp', datetime.datetime.now())
|
||||
kwargs.setdefault('dimensions', {})
|
||||
self._queue.put_nowait((name, value, kwargs))
|
||||
except Full:
|
||||
logger.error('Metric queue full')
|
||||
|
||||
def get_deprecated(self):
|
||||
return self._queue.get()
|
||||
|
||||
def get_nowait_deprecated(self):
|
||||
return self._queue.get_nowait()
|
||||
|
||||
|
||||
def duration_collector_async(metric, labelvalues):
|
||||
""" Decorates a method to have its duration time logged to the metric. """
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
trigger_time = time.time()
|
||||
try:
|
||||
rv = func(*args, **kwargs)
|
||||
except Return as e:
|
||||
metric.Observe(time.time() - trigger_time, labelvalues=labelvalues)
|
||||
raise e
|
||||
return rv
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def time_decorator(name, metric_queue):
|
||||
""" Decorates an endpoint method to have its request time logged to the metrics queue. """
|
||||
after = _time_after_request(name, metric_queue)
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
_time_before_request()
|
||||
rv = func(*args, **kwargs)
|
||||
after(rv)
|
||||
return rv
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def time_blueprint(bp, metric_queue):
|
||||
""" Decorates a blueprint to have its request time logged to the metrics queue. """
|
||||
bp.before_request(_time_before_request)
|
||||
bp.after_request(_time_after_request(bp.name, metric_queue))
|
||||
|
||||
|
||||
def _time_before_request():
|
||||
g._request_start_time = time.time()
|
||||
|
||||
|
||||
def _time_after_request(name, metric_queue):
|
||||
def f(r):
|
||||
start = getattr(g, '_request_start_time', None)
|
||||
if start is None:
|
||||
return r
|
||||
|
||||
dur = time.time() - start
|
||||
|
||||
metric_queue.resp_time.Observe(dur, labelvalues=[request.endpoint])
|
||||
metric_queue.resp_code.Inc(labelvalues=[request.endpoint, r.status_code])
|
||||
|
||||
if r.status_code >= 500:
|
||||
metric_queue.error_500.Inc(labelvalues=[request.endpoint])
|
||||
elif r.status_code < 200 or r.status_code >= 300:
|
||||
metric_queue.non_200.Inc(labelvalues=[request.endpoint])
|
||||
|
||||
return r
|
||||
return f
|
168
util/metrics/prometheus.py
Normal file
168
util/metrics/prometheus.py
Normal file
|
@ -0,0 +1,168 @@
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
|
||||
from Queue import Queue, Full, Empty
|
||||
from threading import Thread
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QUEUE_MAX = 1000
|
||||
MAX_BATCH_SIZE = 100
|
||||
REGISTER_WAIT = datetime.timedelta(hours=1)
|
||||
|
||||
class PrometheusPlugin(object):
|
||||
""" Application plugin for reporting metrics to Prometheus. """
|
||||
def __init__(self, app=None):
|
||||
self.app = app
|
||||
if app is not None:
|
||||
self.state = self.init_app(app)
|
||||
else:
|
||||
self.state = None
|
||||
|
||||
def init_app(self, app):
|
||||
prom_url = app.config.get('PROMETHEUS_AGGREGATOR_URL')
|
||||
prom_namespace = app.config.get('PROMETHEUS_NAMESPACE')
|
||||
logger.debug('Initializing prometheus with aggregator url: %s', prom_url)
|
||||
prometheus = Prometheus(prom_url, prom_namespace)
|
||||
|
||||
# register extension with app
|
||||
app.extensions = getattr(app, 'extensions', {})
|
||||
app.extensions['prometheus'] = prometheus
|
||||
return prometheus
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.state, name, None)
|
||||
|
||||
|
||||
class Prometheus(object):
|
||||
""" Aggregator for collecting stats that are reported to Prometheus. """
|
||||
def __init__(self, url=None, namespace=None):
|
||||
self._metric_collectors = []
|
||||
self._url = url
|
||||
self._namespace = namespace or ''
|
||||
|
||||
if url is not None:
|
||||
self._queue = Queue(QUEUE_MAX)
|
||||
self._sender = _QueueSender(self._queue, url, self._metric_collectors)
|
||||
self._sender.start()
|
||||
logger.debug('Prometheus aggregator sending to %s', url)
|
||||
else:
|
||||
self._queue = None
|
||||
logger.debug('Prometheus aggregator disabled')
|
||||
|
||||
def enqueue(self, call, data):
|
||||
if not self._queue:
|
||||
return
|
||||
|
||||
v = json.dumps({
|
||||
'Call': call,
|
||||
'Data': data,
|
||||
})
|
||||
|
||||
if call == 'register':
|
||||
self._metric_collectors.append(v)
|
||||
return
|
||||
|
||||
try:
|
||||
self._queue.put_nowait(v)
|
||||
except Full:
|
||||
# If the queue is full, it is because 1) no aggregator was enabled or 2)
|
||||
# the aggregator is taking a long time to respond to requests. In the case
|
||||
# of 1, it's probably enterprise mode and we don't care. In the case of 2,
|
||||
# the response timeout error is printed inside the queue handler. In either case,
|
||||
# we don't need to print an error here.
|
||||
pass
|
||||
|
||||
def create_gauge(self, *args, **kwargs):
|
||||
return self._create_collector('Gauge', args, kwargs)
|
||||
|
||||
def create_counter(self, *args, **kwargs):
|
||||
return self._create_collector('Counter', args, kwargs)
|
||||
|
||||
def create_summary(self, *args, **kwargs):
|
||||
return self._create_collector('Summary', args, kwargs)
|
||||
|
||||
def create_histogram(self, *args, **kwargs):
|
||||
return self._create_collector('Histogram', args, kwargs)
|
||||
|
||||
def create_untyped(self, *args, **kwargs):
|
||||
return self._create_collector('Untyped', args, kwargs)
|
||||
|
||||
def _create_collector(self, collector_type, args, kwargs):
|
||||
kwargs['namespace'] = kwargs.get('namespace', self._namespace)
|
||||
return _Collector(self.enqueue, collector_type, *args, **kwargs)
|
||||
|
||||
|
||||
class _QueueSender(Thread):
|
||||
""" Helper class which uses a thread to asynchronously send metrics to the local Prometheus
|
||||
aggregator. """
|
||||
def __init__(self, queue, url, metric_collectors):
|
||||
Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.next_register = datetime.datetime.now()
|
||||
self._queue = queue
|
||||
self._url = url
|
||||
self._metric_collectors = metric_collectors
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
reqs = []
|
||||
reqs.append(self._queue.get())
|
||||
|
||||
while len(reqs) < MAX_BATCH_SIZE:
|
||||
try:
|
||||
req = self._queue.get_nowait()
|
||||
reqs.append(req)
|
||||
except Empty:
|
||||
break
|
||||
|
||||
try:
|
||||
resp = requests.post(self._url + '/call', '\n'.join(reqs))
|
||||
if resp.status_code == 500 and self.next_register <= datetime.datetime.now():
|
||||
resp = requests.post(self._url + '/call', '\n'.join(self._metric_collectors))
|
||||
self.next_register = datetime.datetime.now() + REGISTER_WAIT
|
||||
logger.debug('Register returned %s for %s metrics; setting next to %s', resp.status_code,
|
||||
len(self._metric_collectors), self.next_register)
|
||||
elif resp.status_code != 200:
|
||||
logger.debug('Failed sending to prometheus: %s: %s: %s', resp.status_code, resp.text,
|
||||
', '.join(reqs))
|
||||
else:
|
||||
logger.debug('Sent %d prometheus metrics', len(reqs))
|
||||
except:
|
||||
logger.exception('Failed to write to prometheus aggregator: %s', reqs)
|
||||
|
||||
|
||||
class _Collector(object):
|
||||
""" Collector for a Prometheus metric. """
|
||||
def __init__(self, enqueue_method, collector_type, collector_name, collector_help,
|
||||
namespace='', subsystem='', **kwargs):
|
||||
self._enqueue_method = enqueue_method
|
||||
self._base_args = {
|
||||
'Name': collector_name,
|
||||
'Namespace': namespace,
|
||||
'Subsystem': subsystem,
|
||||
'Type': collector_type,
|
||||
}
|
||||
|
||||
registration_params = dict(kwargs)
|
||||
registration_params.update(self._base_args)
|
||||
registration_params['Help'] = collector_help
|
||||
|
||||
self._enqueue_method('register', registration_params)
|
||||
|
||||
def __getattr__(self, method):
|
||||
def f(value=0, labelvalues=()):
|
||||
data = dict(self._base_args)
|
||||
data.update({
|
||||
'Value': value,
|
||||
'LabelValues': [str(i) for i in labelvalues],
|
||||
'Method': method,
|
||||
})
|
||||
|
||||
self._enqueue_method('put', data)
|
||||
|
||||
return f
|
58
util/metrics/test/test_metricqueue.py
Normal file
58
util/metrics/test/test_metricqueue.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from mock import Mock
|
||||
from trollius import coroutine, Return, get_event_loop, From
|
||||
|
||||
from util.metrics.metricqueue import duration_collector_async
|
||||
|
||||
|
||||
mock_histogram = Mock()
|
||||
|
||||
class NonReturn(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@coroutine
|
||||
@duration_collector_async(mock_histogram, labelvalues=["testlabel"])
|
||||
def duration_decorated():
|
||||
time.sleep(1)
|
||||
raise Return("fin")
|
||||
|
||||
|
||||
@coroutine
|
||||
@duration_collector_async(mock_histogram, labelvalues=["testlabel"])
|
||||
def duration_decorated_error():
|
||||
raise NonReturn("not a Return error")
|
||||
|
||||
@coroutine
|
||||
def calls_decorated():
|
||||
yield From(duration_decorated())
|
||||
|
||||
|
||||
def test_duration_decorator():
|
||||
loop = get_event_loop()
|
||||
loop.run_until_complete(duration_decorated())
|
||||
assert mock_histogram.Observe.called
|
||||
assert 1 - mock_histogram.Observe.call_args[0][0] < 1 # duration should be close to 1s
|
||||
assert mock_histogram.Observe.call_args[1]["labelvalues"] == ["testlabel"]
|
||||
|
||||
|
||||
def test_duration_decorator_error():
|
||||
loop = get_event_loop()
|
||||
mock_histogram.reset_mock()
|
||||
|
||||
with pytest.raises(NonReturn):
|
||||
loop.run_until_complete(duration_decorated_error())
|
||||
assert not mock_histogram.Observe.called
|
||||
|
||||
|
||||
def test_duration_decorator_caller():
|
||||
mock_histogram.reset_mock()
|
||||
|
||||
loop = get_event_loop()
|
||||
loop.run_until_complete(calls_decorated())
|
||||
assert mock_histogram.Observe.called
|
||||
assert 1 - mock_histogram.Observe.call_args[0][0] < 1 # duration should be close to 1s
|
||||
assert mock_histogram.Observe.call_args[1]["labelvalues"] == ["testlabel"]
|
38
util/migrate/__init__.py
Normal file
38
util/migrate/__init__.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import logging
|
||||
|
||||
from sqlalchemy.types import TypeDecorator, Text, String
|
||||
from sqlalchemy.dialects.mysql import TEXT as MySQLText, LONGTEXT, VARCHAR as MySQLString
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UTF8LongText(TypeDecorator):
|
||||
""" Platform-independent UTF-8 LONGTEXT type.
|
||||
|
||||
Uses MySQL's LongText with charset utf8mb4, otherwise uses TEXT, because
|
||||
other engines default to UTF-8 and have longer TEXT fields.
|
||||
"""
|
||||
impl = Text
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == 'mysql':
|
||||
return dialect.type_descriptor(LONGTEXT(charset='utf8mb4', collation='utf8mb4_unicode_ci'))
|
||||
else:
|
||||
return dialect.type_descriptor(Text())
|
||||
|
||||
|
||||
class UTF8CharField(TypeDecorator):
|
||||
""" Platform-independent UTF-8 Char type.
|
||||
|
||||
Uses MySQL's VARCHAR with charset utf8mb4, otherwise uses String, because
|
||||
other engines default to UTF-8.
|
||||
"""
|
||||
impl = String
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == 'mysql':
|
||||
return dialect.type_descriptor(MySQLString(charset='utf8mb4', collation='utf8mb4_unicode_ci',
|
||||
length=self.impl.length))
|
||||
else:
|
||||
return dialect.type_descriptor(String(length=self.impl.length))
|
175
util/migrate/allocator.py
Normal file
175
util/migrate/allocator.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
import logging
|
||||
import random
|
||||
|
||||
from bintrees import RBTree
|
||||
from threading import Event
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class NoAvailableKeysError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class CompletedKeys(object):
|
||||
def __init__(self, max_index, min_index=0):
|
||||
self._max_index = max_index
|
||||
self._min_index = min_index
|
||||
self.num_remaining = max_index - min_index
|
||||
self._slabs = RBTree()
|
||||
|
||||
def _get_previous_or_none(self, index):
|
||||
try:
|
||||
return self._slabs.floor_item(index)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def is_available(self, index):
|
||||
logger.debug('Testing index %s', index)
|
||||
if index >= self._max_index or index < self._min_index:
|
||||
logger.debug('Index out of range')
|
||||
return False
|
||||
|
||||
try:
|
||||
prev_start, prev_length = self._slabs.floor_item(index)
|
||||
logger.debug('Prev range: %s-%s', prev_start, prev_start + prev_length)
|
||||
return (prev_start + prev_length) <= index
|
||||
except KeyError:
|
||||
return True
|
||||
|
||||
def mark_completed(self, start_index, past_last_index):
|
||||
logger.debug('Marking the range completed: %s-%s', start_index, past_last_index)
|
||||
num_completed = min(past_last_index, self._max_index) - max(start_index, self._min_index)
|
||||
|
||||
# Find the item directly before this and see if there is overlap
|
||||
to_discard = set()
|
||||
try:
|
||||
prev_start, prev_length = self._slabs.floor_item(start_index)
|
||||
max_prev_completed = prev_start + prev_length
|
||||
if max_prev_completed >= start_index:
|
||||
# we are going to merge with the range before us
|
||||
logger.debug('Merging with the prev range: %s-%s', prev_start, prev_start + prev_length)
|
||||
to_discard.add(prev_start)
|
||||
num_completed = max(num_completed - (max_prev_completed - start_index), 0)
|
||||
start_index = prev_start
|
||||
past_last_index = max(past_last_index, prev_start + prev_length)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Find all keys between the start and last index and merge them into one block
|
||||
for merge_start, merge_length in self._slabs.iter_items(start_index, past_last_index + 1):
|
||||
if merge_start in to_discard:
|
||||
logger.debug('Already merged with block %s-%s', merge_start, merge_start + merge_length)
|
||||
continue
|
||||
|
||||
candidate_next_index = merge_start + merge_length
|
||||
logger.debug('Merging with block %s-%s', merge_start, candidate_next_index)
|
||||
num_completed -= merge_length - max(candidate_next_index - past_last_index, 0)
|
||||
to_discard.add(merge_start)
|
||||
past_last_index = max(past_last_index, candidate_next_index)
|
||||
|
||||
# write the new block which is fully merged
|
||||
discard = False
|
||||
if past_last_index >= self._max_index:
|
||||
logger.debug('Discarding block and setting new max to: %s', start_index)
|
||||
self._max_index = start_index
|
||||
discard = True
|
||||
|
||||
if start_index <= self._min_index:
|
||||
logger.debug('Discarding block and setting new min to: %s', past_last_index)
|
||||
self._min_index = past_last_index
|
||||
discard = True
|
||||
|
||||
if to_discard:
|
||||
logger.debug('Discarding %s obsolete blocks', len(to_discard))
|
||||
self._slabs.remove_items(to_discard)
|
||||
|
||||
if not discard:
|
||||
logger.debug('Writing new block with range: %s-%s', start_index, past_last_index)
|
||||
self._slabs.insert(start_index, past_last_index - start_index)
|
||||
|
||||
# Update the number of remaining items with the adjustments we've made
|
||||
assert num_completed >= 0
|
||||
self.num_remaining -= num_completed
|
||||
logger.debug('Total blocks: %s', len(self._slabs))
|
||||
|
||||
def get_block_start_index(self, block_size_estimate):
|
||||
logger.debug('Total range: %s-%s', self._min_index, self._max_index)
|
||||
if self._max_index <= self._min_index:
|
||||
raise NoAvailableKeysError('All indexes have been marked completed')
|
||||
|
||||
num_holes = len(self._slabs) + 1
|
||||
random_hole = random.randint(0, num_holes - 1)
|
||||
logger.debug('Selected random hole %s with %s total holes', random_hole, num_holes)
|
||||
|
||||
hole_start = self._min_index
|
||||
past_hole_end = self._max_index
|
||||
|
||||
# Now that we have picked a hole, we need to define the bounds
|
||||
if random_hole > 0:
|
||||
# There will be a slab before this hole, find where it ends
|
||||
bound_entries = self._slabs.nsmallest(random_hole + 1)[-2:]
|
||||
left_index, left_len = bound_entries[0]
|
||||
logger.debug('Left range %s-%s', left_index, left_index + left_len)
|
||||
hole_start = left_index + left_len
|
||||
|
||||
if len(bound_entries) > 1:
|
||||
right_index, right_len = bound_entries[1]
|
||||
logger.debug('Right range %s-%s', right_index, right_index + right_len)
|
||||
past_hole_end, _ = bound_entries[1]
|
||||
elif not self._slabs.is_empty():
|
||||
right_index, right_len = self._slabs.nsmallest(1)[0]
|
||||
logger.debug('Right range %s-%s', right_index, right_index + right_len)
|
||||
past_hole_end, _ = self._slabs.nsmallest(1)[0]
|
||||
|
||||
# Now that we have our hole bounds, select a random block from [0:len - block_size_estimate]
|
||||
logger.debug('Selecting from hole range: %s-%s', hole_start, past_hole_end)
|
||||
rand_max_bound = max(hole_start, past_hole_end - block_size_estimate)
|
||||
logger.debug('Rand max bound: %s', rand_max_bound)
|
||||
return random.randint(hole_start, rand_max_bound)
|
||||
|
||||
|
||||
def yield_random_entries(batch_query, primary_key_field, batch_size, max_id, min_id=0):
|
||||
""" This method will yield items from random blocks in the database. We will track metadata
|
||||
about which keys are available for work, and we will complete the backfill when there is no
|
||||
more work to be done. The method yields tuples of (candidate, Event), and if the work was
|
||||
already done by another worker, the caller should set the event. Batch candidates must have
|
||||
an "id" field which can be inspected.
|
||||
"""
|
||||
|
||||
min_id = max(min_id, 0)
|
||||
max_id = max(max_id, 1)
|
||||
allocator = CompletedKeys(max_id + 1, min_id)
|
||||
|
||||
try:
|
||||
while True:
|
||||
start_index = allocator.get_block_start_index(batch_size)
|
||||
end_index = min(start_index + batch_size, max_id + 1)
|
||||
all_candidates = list(batch_query()
|
||||
.where(primary_key_field >= start_index,
|
||||
primary_key_field < end_index)
|
||||
.order_by(primary_key_field))
|
||||
|
||||
if len(all_candidates) == 0:
|
||||
logger.info('No candidates, marking entire block completed %s-%s', start_index, end_index)
|
||||
allocator.mark_completed(start_index, end_index)
|
||||
continue
|
||||
|
||||
logger.info('Found %s candidates, processing block', len(all_candidates))
|
||||
batch_completed = 0
|
||||
for candidate in all_candidates:
|
||||
abort_early = Event()
|
||||
yield candidate, abort_early, allocator.num_remaining - batch_completed
|
||||
batch_completed += 1
|
||||
if abort_early.is_set():
|
||||
logger.info('Overlap with another worker, aborting')
|
||||
break
|
||||
|
||||
completed_through = candidate.id + 1
|
||||
logger.info('Marking id range as completed: %s-%s', start_index, completed_through)
|
||||
allocator.mark_completed(start_index, completed_through)
|
||||
|
||||
except NoAvailableKeysError:
|
||||
logger.info('No more work')
|
54
util/migrate/cleanup_old_robots.py
Normal file
54
util/migrate/cleanup_old_robots.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
import logging
|
||||
|
||||
from app import app
|
||||
from data.database import User
|
||||
from util.names import parse_robot_username
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def cleanup_old_robots(page_size=50, force=False):
|
||||
""" Deletes any robots that live under namespaces that no longer exist. """
|
||||
if not force and not app.config.get('SETUP_COMPLETE', False):
|
||||
return
|
||||
|
||||
# Collect the robot accounts to delete.
|
||||
page_number = 1
|
||||
to_delete = []
|
||||
encountered_namespaces = {}
|
||||
|
||||
while True:
|
||||
found_bots = False
|
||||
for robot in list(User.select().where(User.robot == True).paginate(page_number, page_size)):
|
||||
found_bots = True
|
||||
logger.info("Checking robot %s (page %s)", robot.username, page_number)
|
||||
parsed = parse_robot_username(robot.username)
|
||||
if parsed is None:
|
||||
continue
|
||||
|
||||
namespace, _ = parsed
|
||||
if namespace in encountered_namespaces:
|
||||
if not encountered_namespaces[namespace]:
|
||||
logger.info('Marking %s to be deleted', robot.username)
|
||||
to_delete.append(robot)
|
||||
else:
|
||||
try:
|
||||
User.get(username=namespace)
|
||||
encountered_namespaces[namespace] = True
|
||||
except User.DoesNotExist:
|
||||
# Save the robot account for deletion.
|
||||
logger.info('Marking %s to be deleted', robot.username)
|
||||
to_delete.append(robot)
|
||||
encountered_namespaces[namespace] = False
|
||||
|
||||
if not found_bots:
|
||||
break
|
||||
|
||||
page_number = page_number + 1
|
||||
|
||||
# Cleanup any robot accounts whose corresponding namespace doesn't exist.
|
||||
logger.info('Found %s robots to delete', len(to_delete))
|
||||
for index, robot in enumerate(to_delete):
|
||||
logger.info('Deleting robot %s of %s (%s)', index, len(to_delete), robot.username)
|
||||
robot.delete_instance(recursive=True, delete_nullable=True)
|
48
util/migrate/delete_access_tokens.py
Normal file
48
util/migrate/delete_access_tokens.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from data.database import RepositoryBuild, AccessToken
|
||||
from app import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_SIZE = 1000
|
||||
|
||||
def delete_temporary_access_tokens(older_than):
|
||||
# Find the highest ID up to which we should delete
|
||||
up_to_id = (AccessToken
|
||||
.select(AccessToken.id)
|
||||
.where(AccessToken.created < older_than)
|
||||
.limit(1)
|
||||
.order_by(AccessToken.id.desc())
|
||||
.get().id)
|
||||
logger.debug('Deleting temporary access tokens with ids lower than: %s', up_to_id)
|
||||
|
||||
|
||||
access_tokens_in_builds = (RepositoryBuild.select(RepositoryBuild.access_token).distinct())
|
||||
|
||||
while up_to_id > 0:
|
||||
starting_at_id = max(up_to_id - BATCH_SIZE, 0)
|
||||
logger.debug('Deleting tokens with ids between %s and %s', starting_at_id, up_to_id)
|
||||
start_time = datetime.utcnow()
|
||||
(AccessToken
|
||||
.delete()
|
||||
.where(AccessToken.id >= starting_at_id,
|
||||
AccessToken.id < up_to_id,
|
||||
AccessToken.temporary == True,
|
||||
~(AccessToken.id << access_tokens_in_builds))
|
||||
.execute())
|
||||
|
||||
time_to_delete = datetime.utcnow() - start_time
|
||||
|
||||
up_to_id -= BATCH_SIZE
|
||||
|
||||
logger.debug('Sleeping for %s seconds', time_to_delete.total_seconds())
|
||||
time.sleep(time_to_delete.total_seconds())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
delete_temporary_access_tokens(datetime.utcnow() - timedelta(days=2))
|
13
util/migrate/table_ops.py
Normal file
13
util/migrate/table_ops.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
def copy_table_contents(source_table, destination_table, conn):
|
||||
if conn.engine.name == 'postgresql':
|
||||
conn.execute('INSERT INTO "%s" SELECT * FROM "%s"' % (destination_table, source_table))
|
||||
result = list(conn.execute('Select Max(id) from "%s"' % destination_table))[0]
|
||||
if result[0] is not None:
|
||||
new_start_id = result[0] + 1
|
||||
conn.execute('ALTER SEQUENCE "%s_id_seq" RESTART WITH %s' % (destination_table, new_start_id))
|
||||
else:
|
||||
conn.execute("INSERT INTO `%s` SELECT * FROM `%s` WHERE 1" % (destination_table, source_table))
|
||||
result = list(conn.execute('Select Max(id) from `%s` WHERE 1' % destination_table))[0]
|
||||
if result[0] is not None:
|
||||
new_start_id = result[0] + 1
|
||||
conn.execute("ALTER TABLE `%s` AUTO_INCREMENT = %s" % (destination_table, new_start_id))
|
Some files were not shown because too many files have changed in this diff Show more
Reference in a new issue