Reformat Python code with psf/black

This commit is contained in:
cclauss 2019-11-20 08:58:47 +01:00
parent f915352138
commit aa13f95ca5
746 changed files with 103596 additions and 76051 deletions

View file

@ -7,40 +7,45 @@ from util.config.provider import get_config_provider
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/")) CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/"))
STATIC_DIR = os.path.join(ROOT_DIR, 'static/') STATIC_DIR = os.path.join(ROOT_DIR, "static/")
STATIC_LDN_DIR = os.path.join(STATIC_DIR, 'ldn/') STATIC_LDN_DIR = os.path.join(STATIC_DIR, "ldn/")
STATIC_FONTS_DIR = os.path.join(STATIC_DIR, 'fonts/') STATIC_FONTS_DIR = os.path.join(STATIC_DIR, "fonts/")
STATIC_WEBFONTS_DIR = os.path.join(STATIC_DIR, 'webfonts/') STATIC_WEBFONTS_DIR = os.path.join(STATIC_DIR, "webfonts/")
TEMPLATE_DIR = os.path.join(ROOT_DIR, 'templates/') TEMPLATE_DIR = os.path.join(ROOT_DIR, "templates/")
IS_TESTING = 'TEST' in os.environ IS_TESTING = "TEST" in os.environ
IS_BUILDING = 'BUILDING' in os.environ IS_BUILDING = "BUILDING" in os.environ
IS_KUBERNETES = 'KUBERNETES_SERVICE_HOST' in os.environ IS_KUBERNETES = "KUBERNETES_SERVICE_HOST" in os.environ
OVERRIDE_CONFIG_DIRECTORY = os.path.join(CONF_DIR, 'stack/') OVERRIDE_CONFIG_DIRECTORY = os.path.join(CONF_DIR, "stack/")
config_provider = get_config_provider(OVERRIDE_CONFIG_DIRECTORY, 'config.yaml', 'config.py', config_provider = get_config_provider(
testing=IS_TESTING, kubernetes=IS_KUBERNETES) OVERRIDE_CONFIG_DIRECTORY,
"config.yaml",
"config.py",
testing=IS_TESTING,
kubernetes=IS_KUBERNETES,
)
def _get_version_number_changelog(): def _get_version_number_changelog():
try: try:
with open(os.path.join(ROOT_DIR, 'CHANGELOG.md')) as f: with open(os.path.join(ROOT_DIR, "CHANGELOG.md")) as f:
return re.search(r'(v[0-9]+\.[0-9]+\.[0-9]+)', f.readline()).group(0) return re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", f.readline()).group(0)
except IOError: except IOError:
return '' return ""
def _get_git_sha(): def _get_git_sha():
if os.path.exists("GIT_HEAD"): if os.path.exists("GIT_HEAD"):
with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f: with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
return f.read() return f.read()
else: else:
try: try:
return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8] return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
except (OSError, subprocess.CalledProcessError, Exception): except (OSError, subprocess.CalledProcessError, Exception):
pass pass
return "unknown" return "unknown"
__version__ = _get_version_number_changelog() __version__ = _get_version_number_changelog()

View file

@ -1,22 +1,30 @@
from enum import Enum, unique from enum import Enum, unique
from data.migrationutil import DefinedDataMigration, MigrationPhase from data.migrationutil import DefinedDataMigration, MigrationPhase
@unique @unique
class ERTMigrationFlags(Enum): class ERTMigrationFlags(Enum):
""" Flags for the encrypted robot token migration. """ """ Flags for the encrypted robot token migration. """
READ_OLD_FIELDS = 'read-old'
WRITE_OLD_FIELDS = 'write-old' READ_OLD_FIELDS = "read-old"
WRITE_OLD_FIELDS = "write-old"
ActiveDataMigration = DefinedDataMigration( ActiveDataMigration = DefinedDataMigration(
'encrypted_robot_tokens', "encrypted_robot_tokens",
'ENCRYPTED_ROBOT_TOKEN_MIGRATION_PHASE', "ENCRYPTED_ROBOT_TOKEN_MIGRATION_PHASE",
[ [
MigrationPhase('add-new-fields', 'c13c8052f7a6', [ERTMigrationFlags.READ_OLD_FIELDS, MigrationPhase(
ERTMigrationFlags.WRITE_OLD_FIELDS]), "add-new-fields",
MigrationPhase('backfill-then-read-only-new', "c13c8052f7a6",
'703298a825c2', [ERTMigrationFlags.WRITE_OLD_FIELDS]), [ERTMigrationFlags.READ_OLD_FIELDS, ERTMigrationFlags.WRITE_OLD_FIELDS],
MigrationPhase('stop-writing-both', '703298a825c2', []), ),
MigrationPhase('remove-old-fields', 'c059b952ed76', []), MigrationPhase(
] "backfill-then-read-only-new",
"703298a825c2",
[ERTMigrationFlags.WRITE_OLD_FIELDS],
),
MigrationPhase("stop-writing-both", "703298a825c2", []),
MigrationPhase("remove-old-fields", "c059b952ed76", []),
],
) )

318
app.py
View file

@ -16,8 +16,14 @@ from werkzeug.exceptions import HTTPException
import features import features
from _init import (config_provider, CONF_DIR, IS_KUBERNETES, IS_TESTING, OVERRIDE_CONFIG_DIRECTORY, from _init import (
IS_BUILDING) config_provider,
CONF_DIR,
IS_KUBERNETES,
IS_TESTING,
OVERRIDE_CONFIG_DIRECTORY,
IS_BUILDING,
)
from auth.auth_context import get_authenticated_user from auth.auth_context import get_authenticated_user
from avatars.avatars import Avatar from avatars.avatars import Avatar
@ -35,7 +41,11 @@ from data.userevent import UserEventsBuilderModule
from data.userfiles import Userfiles from data.userfiles import Userfiles
from data.users import UserAuthentication from data.users import UserAuthentication
from data.registry_model import registry_model from data.registry_model import registry_model
from path_converters import RegexConverter, RepositoryPathConverter, APIRepositoryPathConverter from path_converters import (
RegexConverter,
RepositoryPathConverter,
APIRepositoryPathConverter,
)
from oauth.services.github import GithubOAuthService from oauth.services.github import GithubOAuthService
from oauth.services.gitlab import GitLabOAuthService from oauth.services.gitlab import GitLabOAuthService
from oauth.loginmanager import OAuthLoginManager from oauth.loginmanager import OAuthLoginManager
@ -62,13 +72,13 @@ from util.security.instancekeys import InstanceKeys
from util.security.signing import Signer from util.security.signing import Signer
OVERRIDE_CONFIG_YAML_FILENAME = os.path.join(CONF_DIR, 'stack/config.yaml') OVERRIDE_CONFIG_YAML_FILENAME = os.path.join(CONF_DIR, "stack/config.yaml")
OVERRIDE_CONFIG_PY_FILENAME = os.path.join(CONF_DIR, 'stack/config.py') OVERRIDE_CONFIG_PY_FILENAME = os.path.join(CONF_DIR, "stack/config.py")
OVERRIDE_CONFIG_KEY = 'QUAY_OVERRIDE_CONFIG' OVERRIDE_CONFIG_KEY = "QUAY_OVERRIDE_CONFIG"
DOCKER_V2_SIGNINGKEY_FILENAME = 'docker_v2.pem' DOCKER_V2_SIGNINGKEY_FILENAME = "docker_v2.pem"
INIT_SCRIPTS_LOCATION = '/conf/init/' INIT_SCRIPTS_LOCATION = "/conf/init/"
app = Flask(__name__) app = Flask(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -79,62 +89,75 @@ is_kubernetes = IS_KUBERNETES
is_building = IS_BUILDING is_building = IS_BUILDING
if is_testing: if is_testing:
from test.testconfig import TestConfig from test.testconfig import TestConfig
logger.debug('Loading test config.')
app.config.from_object(TestConfig()) logger.debug("Loading test config.")
app.config.from_object(TestConfig())
else: else:
from config import DefaultConfig from config import DefaultConfig
logger.debug('Loading default config.')
app.config.from_object(DefaultConfig()) logger.debug("Loading default config.")
app.teardown_request(database.close_db_filter) app.config.from_object(DefaultConfig())
app.teardown_request(database.close_db_filter)
# Load the override config via the provider. # Load the override config via the provider.
config_provider.update_app_config(app.config) config_provider.update_app_config(app.config)
# Update any configuration found in the override environment variable. # Update any configuration found in the override environment variable.
environ_config = json.loads(os.environ.get(OVERRIDE_CONFIG_KEY, '{}')) environ_config = json.loads(os.environ.get(OVERRIDE_CONFIG_KEY, "{}"))
app.config.update(environ_config) app.config.update(environ_config)
# Fix remote address handling for Flask. # Fix remote address handling for Flask.
if app.config.get('PROXY_COUNT', 1): if app.config.get("PROXY_COUNT", 1):
app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=app.config.get('PROXY_COUNT', 1)) app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=app.config.get("PROXY_COUNT", 1))
# Ensure the V3 upgrade key is specified correctly. If not, simply fail. # Ensure the V3 upgrade key is specified correctly. If not, simply fail.
# TODO: Remove for V3.1. # TODO: Remove for V3.1.
if not is_testing and not is_building and app.config.get('SETUP_COMPLETE', False): if not is_testing and not is_building and app.config.get("SETUP_COMPLETE", False):
v3_upgrade_mode = app.config.get('V3_UPGRADE_MODE') v3_upgrade_mode = app.config.get("V3_UPGRADE_MODE")
if v3_upgrade_mode is None: if v3_upgrade_mode is None:
raise Exception('Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs') raise Exception(
"Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs"
)
if (v3_upgrade_mode != 'background' if (
and v3_upgrade_mode != 'complete' v3_upgrade_mode != "background"
and v3_upgrade_mode != 'production-transition' and v3_upgrade_mode != "complete"
and v3_upgrade_mode != 'post-oci-rollout' and v3_upgrade_mode != "production-transition"
and v3_upgrade_mode != 'post-oci-roll-back-compat'): and v3_upgrade_mode != "post-oci-rollout"
raise Exception('Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs') and v3_upgrade_mode != "post-oci-roll-back-compat"
):
raise Exception(
"Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs"
)
# Split the registry model based on config. # Split the registry model based on config.
# TODO: Remove once we are fully on the OCI data model. # TODO: Remove once we are fully on the OCI data model.
registry_model.setup_split(app.config.get('OCI_NAMESPACE_PROPORTION') or 0, registry_model.setup_split(
app.config.get('OCI_NAMESPACE_WHITELIST') or set(), app.config.get("OCI_NAMESPACE_PROPORTION") or 0,
app.config.get('V22_NAMESPACE_WHITELIST') or set(), app.config.get("OCI_NAMESPACE_WHITELIST") or set(),
app.config.get('V3_UPGRADE_MODE')) app.config.get("V22_NAMESPACE_WHITELIST") or set(),
app.config.get("V3_UPGRADE_MODE"),
)
# Allow user to define a custom storage preference for the local instance. # Allow user to define a custom storage preference for the local instance.
_distributed_storage_preference = os.environ.get('QUAY_DISTRIBUTED_STORAGE_PREFERENCE', '').split() _distributed_storage_preference = os.environ.get(
"QUAY_DISTRIBUTED_STORAGE_PREFERENCE", ""
).split()
if _distributed_storage_preference: if _distributed_storage_preference:
app.config['DISTRIBUTED_STORAGE_PREFERENCE'] = _distributed_storage_preference app.config["DISTRIBUTED_STORAGE_PREFERENCE"] = _distributed_storage_preference
# Generate a secret key if none was specified. # Generate a secret key if none was specified.
if app.config['SECRET_KEY'] is None: if app.config["SECRET_KEY"] is None:
logger.debug('Generating in-memory secret key') logger.debug("Generating in-memory secret key")
app.config['SECRET_KEY'] = generate_secret_key() app.config["SECRET_KEY"] = generate_secret_key()
# If the "preferred" scheme is https, then http is not allowed. Therefore, ensure we have a secure # If the "preferred" scheme is https, then http is not allowed. Therefore, ensure we have a secure
# session cookie. # session cookie.
if (app.config['PREFERRED_URL_SCHEME'] == 'https' and if app.config["PREFERRED_URL_SCHEME"] == "https" and not app.config.get(
not app.config.get('FORCE_NONSECURE_SESSION_COOKIE', False)): "FORCE_NONSECURE_SESSION_COOKIE", False
app.config['SESSION_COOKIE_SECURE'] = True ):
app.config["SESSION_COOKIE_SECURE"] = True
# Load features from config. # Load features from config.
features.import_features(app.config) features.import_features(app.config)
@ -145,65 +168,77 @@ logger.debug("Loaded config", extra={"config": app.config})
class RequestWithId(Request): class RequestWithId(Request):
request_gen = staticmethod(urn_generator(['request'])) request_gen = staticmethod(urn_generator(["request"]))
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(RequestWithId, self).__init__(*args, **kwargs) super(RequestWithId, self).__init__(*args, **kwargs)
self.request_id = self.request_gen() self.request_id = self.request_gen()
@app.before_request @app.before_request
def _request_start(): def _request_start():
if os.getenv('PYDEV_DEBUG', None): if os.getenv("PYDEV_DEBUG", None):
import pydevd import pydevd
host, port = os.getenv('PYDEV_DEBUG').split(':')
pydevd.settrace(host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False)
logger.debug('Starting request: %s (%s)', request.request_id, request.path, host, port = os.getenv("PYDEV_DEBUG").split(":")
extra={"request_id": request.request_id}) pydevd.settrace(
host,
port=int(port),
stdoutToServer=True,
stderrToServer=True,
suspend=False,
)
logger.debug(
"Starting request: %s (%s)",
request.request_id,
request.path,
extra={"request_id": request.request_id},
)
DEFAULT_FILTER = lambda x: '[FILTERED]' DEFAULT_FILTER = lambda x: "[FILTERED]"
FILTERED_VALUES = [ FILTERED_VALUES = [
{'key': ['password'], 'fn': DEFAULT_FILTER}, {"key": ["password"], "fn": DEFAULT_FILTER},
{'key': ['user', 'password'], 'fn': DEFAULT_FILTER}, {"key": ["user", "password"], "fn": DEFAULT_FILTER},
{'key': ['blob'], 'fn': lambda x: x[0:8]} {"key": ["blob"], "fn": lambda x: x[0:8]},
] ]
@app.after_request @app.after_request
def _request_end(resp): def _request_end(resp):
try: try:
jsonbody = request.get_json(force=True, silent=True) jsonbody = request.get_json(force=True, silent=True)
except HTTPException: except HTTPException:
jsonbody = None jsonbody = None
values = request.values.to_dict() values = request.values.to_dict()
if jsonbody and not isinstance(jsonbody, dict): if jsonbody and not isinstance(jsonbody, dict):
jsonbody = {'_parsererror': jsonbody} jsonbody = {"_parsererror": jsonbody}
if isinstance(values, dict): if isinstance(values, dict):
filter_logs(values, FILTERED_VALUES) filter_logs(values, FILTERED_VALUES)
extra = { extra = {
"endpoint": request.endpoint, "endpoint": request.endpoint,
"request_id" : request.request_id, "request_id": request.request_id,
"remote_addr": request.remote_addr, "remote_addr": request.remote_addr,
"http_method": request.method, "http_method": request.method,
"original_url": request.url, "original_url": request.url,
"path": request.path, "path": request.path,
"parameters": values, "parameters": values,
"json_body": jsonbody, "json_body": jsonbody,
"confsha": CONFIG_DIGEST, "confsha": CONFIG_DIGEST,
} }
if request.user_agent is not None: if request.user_agent is not None:
extra["user-agent"] = request.user_agent.string extra["user-agent"] = request.user_agent.string
logger.debug('Ending request: %s (%s)', request.request_id, request.path, extra=extra)
return resp
logger.debug(
"Ending request: %s (%s)", request.request_id, request.path, extra=extra
)
return resp
root_logger = logging.getLogger() root_logger = logging.getLogger()
@ -211,13 +246,13 @@ root_logger = logging.getLogger()
app.request_class = RequestWithId app.request_class = RequestWithId
# Register custom converters. # Register custom converters.
app.url_map.converters['regex'] = RegexConverter app.url_map.converters["regex"] = RegexConverter
app.url_map.converters['repopath'] = RepositoryPathConverter app.url_map.converters["repopath"] = RepositoryPathConverter
app.url_map.converters['apirepopath'] = APIRepositoryPathConverter app.url_map.converters["apirepopath"] = APIRepositoryPathConverter
Principal(app, use_sessions=False) Principal(app, use_sessions=False)
tf = app.config['DB_TRANSACTION_FACTORY'] tf = app.config["DB_TRANSACTION_FACTORY"]
model_cache = get_model_cache(app.config) model_cache = get_model_cache(app.config)
avatar = Avatar(app) avatar = Avatar(app)
@ -225,10 +260,14 @@ login_manager = LoginManager(app)
mail = Mail(app) mail = Mail(app)
prometheus = PrometheusPlugin(app) prometheus = PrometheusPlugin(app)
metric_queue = MetricQueue(prometheus) metric_queue = MetricQueue(prometheus)
chunk_cleanup_queue = WorkQueue(app.config['CHUNK_CLEANUP_QUEUE_NAME'], tf, metric_queue=metric_queue) chunk_cleanup_queue = WorkQueue(
app.config["CHUNK_CLEANUP_QUEUE_NAME"], tf, metric_queue=metric_queue
)
instance_keys = InstanceKeys(app) instance_keys = InstanceKeys(app)
ip_resolver = IPResolver(app) ip_resolver = IPResolver(app)
storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver) storage = Storage(
app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver
)
userfiles = Userfiles(app, storage) userfiles = Userfiles(app, storage)
log_archive = LogArchive(app, storage) log_archive = LogArchive(app, storage)
analytics = Analytics(app) analytics = Analytics(app)
@ -246,55 +285,99 @@ build_canceller = BuildCanceller(app)
start_cloudwatch_sender(metric_queue, app) start_cloudwatch_sender(metric_queue, app)
github_trigger = GithubOAuthService(app.config, 'GITHUB_TRIGGER_CONFIG') github_trigger = GithubOAuthService(app.config, "GITHUB_TRIGGER_CONFIG")
gitlab_trigger = GitLabOAuthService(app.config, 'GITLAB_TRIGGER_CONFIG') gitlab_trigger = GitLabOAuthService(app.config, "GITLAB_TRIGGER_CONFIG")
oauth_login = OAuthLoginManager(app.config) oauth_login = OAuthLoginManager(app.config)
oauth_apps = [github_trigger, gitlab_trigger] oauth_apps = [github_trigger, gitlab_trigger]
image_replication_queue = WorkQueue(app.config['REPLICATION_QUEUE_NAME'], tf, image_replication_queue = WorkQueue(
has_namespace=False, metric_queue=metric_queue) app.config["REPLICATION_QUEUE_NAME"],
dockerfile_build_queue = WorkQueue(app.config['DOCKERFILE_BUILD_QUEUE_NAME'], tf, tf,
metric_queue=metric_queue, has_namespace=False,
reporter=BuildMetricQueueReporter(metric_queue), metric_queue=metric_queue,
has_namespace=True) )
notification_queue = WorkQueue(app.config['NOTIFICATION_QUEUE_NAME'], tf, has_namespace=True, dockerfile_build_queue = WorkQueue(
metric_queue=metric_queue) app.config["DOCKERFILE_BUILD_QUEUE_NAME"],
secscan_notification_queue = WorkQueue(app.config['SECSCAN_NOTIFICATION_QUEUE_NAME'], tf, tf,
has_namespace=False, metric_queue=metric_queue,
metric_queue=metric_queue) reporter=BuildMetricQueueReporter(metric_queue),
export_action_logs_queue = WorkQueue(app.config['EXPORT_ACTION_LOGS_QUEUE_NAME'], tf, has_namespace=True,
has_namespace=True, )
metric_queue=metric_queue) notification_queue = WorkQueue(
app.config["NOTIFICATION_QUEUE_NAME"],
tf,
has_namespace=True,
metric_queue=metric_queue,
)
secscan_notification_queue = WorkQueue(
app.config["SECSCAN_NOTIFICATION_QUEUE_NAME"],
tf,
has_namespace=False,
metric_queue=metric_queue,
)
export_action_logs_queue = WorkQueue(
app.config["EXPORT_ACTION_LOGS_QUEUE_NAME"],
tf,
has_namespace=True,
metric_queue=metric_queue,
)
# Note: We set `has_namespace` to `False` here, as we explicitly want this queue to not be emptied # Note: We set `has_namespace` to `False` here, as we explicitly want this queue to not be emptied
# when a namespace is marked for deletion. # when a namespace is marked for deletion.
namespace_gc_queue = WorkQueue(app.config['NAMESPACE_GC_QUEUE_NAME'], tf, has_namespace=False, namespace_gc_queue = WorkQueue(
metric_queue=metric_queue) app.config["NAMESPACE_GC_QUEUE_NAME"],
tf,
has_namespace=False,
metric_queue=metric_queue,
)
all_queues = [image_replication_queue, dockerfile_build_queue, notification_queue, all_queues = [
secscan_notification_queue, chunk_cleanup_queue, namespace_gc_queue] image_replication_queue,
dockerfile_build_queue,
notification_queue,
secscan_notification_queue,
chunk_cleanup_queue,
namespace_gc_queue,
]
url_scheme_and_hostname = URLSchemeAndHostname(app.config['PREFERRED_URL_SCHEME'], app.config['SERVER_HOSTNAME']) url_scheme_and_hostname = URLSchemeAndHostname(
secscan_api = SecurityScannerAPI(app.config, storage, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'], app.config["PREFERRED_URL_SCHEME"], app.config["SERVER_HOSTNAME"]
uri_creator=get_blob_download_uri_getter(app.test_request_context('/'), url_scheme_and_hostname), )
instance_keys=instance_keys) secscan_api = SecurityScannerAPI(
app.config,
storage,
app.config["SERVER_HOSTNAME"],
app.config["HTTPCLIENT"],
uri_creator=get_blob_download_uri_getter(
app.test_request_context("/"), url_scheme_and_hostname
),
instance_keys=instance_keys,
)
repo_mirror_api = RepoMirrorAPI(app.config, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'], repo_mirror_api = RepoMirrorAPI(
instance_keys=instance_keys) app.config,
app.config["SERVER_HOSTNAME"],
app.config["HTTPCLIENT"],
instance_keys=instance_keys,
)
tuf_metadata_api = TUFMetadataAPI(app, app.config) tuf_metadata_api = TUFMetadataAPI(app, app.config)
# Check for a key in config. If none found, generate a new signing key for Docker V2 manifests. # Check for a key in config. If none found, generate a new signing key for Docker V2 manifests.
_v2_key_path = os.path.join(OVERRIDE_CONFIG_DIRECTORY, DOCKER_V2_SIGNINGKEY_FILENAME) _v2_key_path = os.path.join(OVERRIDE_CONFIG_DIRECTORY, DOCKER_V2_SIGNINGKEY_FILENAME)
if os.path.exists(_v2_key_path): if os.path.exists(_v2_key_path):
docker_v2_signing_key = RSAKey().load(_v2_key_path) docker_v2_signing_key = RSAKey().load(_v2_key_path)
else: else:
docker_v2_signing_key = RSAKey(key=RSA.generate(2048)) docker_v2_signing_key = RSAKey(key=RSA.generate(2048))
# Configure the database. # Configure the database.
if app.config.get('DATABASE_SECRET_KEY') is None and app.config.get('SETUP_COMPLETE', False): if app.config.get("DATABASE_SECRET_KEY") is None and app.config.get(
raise Exception('Missing DATABASE_SECRET_KEY in config; did you perhaps forget to add it?') "SETUP_COMPLETE", False
):
raise Exception(
"Missing DATABASE_SECRET_KEY in config; did you perhaps forget to add it?"
)
database.configure(app.config) database.configure(app.config)
@ -306,8 +389,9 @@ model.config.register_repo_cleanup_callback(tuf_metadata_api.delete_metadata)
@login_manager.user_loader @login_manager.user_loader
def load_user(user_uuid): def load_user(user_uuid):
logger.debug('User loader loading deferred user with uuid: %s', user_uuid) logger.debug("User loader loading deferred user with uuid: %s", user_uuid)
return LoginWrappedDBUser(user_uuid) return LoginWrappedDBUser(user_uuid)
logs_model.configure(app.config) logs_model.configure(app.config)

View file

@ -1,5 +1,6 @@
# NOTE: Must be before we import or call anything that may be synchronous. # NOTE: Must be before we import or call anything that may be synchronous.
from gevent import monkey from gevent import monkey
monkey.patch_all() monkey.patch_all()
import os import os
@ -17,6 +18,6 @@ import registry
import secscan import secscan
if __name__ == '__main__': if __name__ == "__main__":
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False) logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
application.run(port=5000, debug=True, threaded=True, host='0.0.0.0') application.run(port=5000, debug=True, threaded=True, host="0.0.0.0")

View file

@ -1,21 +1,25 @@
from flask import _request_ctx_stack from flask import _request_ctx_stack
def get_authenticated_context(): def get_authenticated_context():
""" Returns the auth context for the current request context, if any. """ """ Returns the auth context for the current request context, if any. """
return getattr(_request_ctx_stack.top, 'authenticated_context', None) return getattr(_request_ctx_stack.top, "authenticated_context", None)
def get_authenticated_user(): def get_authenticated_user():
""" Returns the authenticated user, if any, or None if none. """ """ Returns the authenticated user, if any, or None if none. """
context = get_authenticated_context() context = get_authenticated_context()
return context.authed_user if context else None return context.authed_user if context else None
def get_validated_oauth_token(): def get_validated_oauth_token():
""" Returns the authenticated and validated OAuth access token, if any, or None if none. """ """ Returns the authenticated and validated OAuth access token, if any, or None if none. """
context = get_authenticated_context() context = get_authenticated_context()
return context.authed_oauth_token if context else None return context.authed_oauth_token if context else None
def set_authenticated_context(auth_context): def set_authenticated_context(auth_context):
""" Sets the auth context for the current request context to that given. """ """ Sets the auth context for the current request context to that given. """
ctx = _request_ctx_stack.top ctx = _request_ctx_stack.top
ctx.authenticated_context = auth_context ctx.authenticated_context = auth_context
return auth_context return auth_context

View file

@ -16,422 +16,446 @@ from auth.scopes import scopes_from_scope_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@add_metaclass(ABCMeta) @add_metaclass(ABCMeta)
class AuthContext(object): class AuthContext(object):
""" """
Interface that represents the current context of authentication. Interface that represents the current context of authentication.
""" """
@property @property
@abstractmethod @abstractmethod
def entity_kind(self): def entity_kind(self):
""" Returns the kind of the entity in this auth context. """ """ Returns the kind of the entity in this auth context. """
pass pass
@property @property
@abstractmethod @abstractmethod
def is_anonymous(self): def is_anonymous(self):
""" Returns true if this is an anonymous context. """ """ Returns true if this is an anonymous context. """
pass pass
@property @property
@abstractmethod @abstractmethod
def authed_oauth_token(self): def authed_oauth_token(self):
""" Returns the authenticated OAuth token, if any. """ """ Returns the authenticated OAuth token, if any. """
pass pass
@property @property
@abstractmethod @abstractmethod
def authed_user(self): def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth or access token. Note that """ Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
this property will also return robot accounts. this property will also return robot accounts.
""" """
pass pass
@property @property
@abstractmethod @abstractmethod
def has_nonrobot_user(self): def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """ """ Returns whether a user (not a robot) was authenticated successfully. """
pass pass
@property @property
@abstractmethod @abstractmethod
def identity(self): def identity(self):
""" Returns the identity for the auth context. """ """ Returns the identity for the auth context. """
pass pass
@property @property
@abstractmethod @abstractmethod
def description(self): def description(self):
""" Returns a human-readable and *public* description of the current auth context. """ """ Returns a human-readable and *public* description of the current auth context. """
pass pass
@property @property
@abstractmethod @abstractmethod
def credential_username(self): def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """ """ Returns the username to create credentials for this context's entity, if any. """
pass pass
@abstractmethod @abstractmethod
def analytics_id_and_public_metadata(self): def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """ """ Returns the analytics ID and public log metadata for this auth context. """
pass pass
@abstractmethod @abstractmethod
def apply_to_request_context(self): def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """ """ Applies this auth result to the auth context and Flask-Principal. """
pass pass
@abstractmethod @abstractmethod
def to_signed_dict(self): def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization. form of signed serialization.
""" """
pass pass
@property @property
@abstractmethod @abstractmethod
def unique_key(self): def unique_key(self):
""" Returns a key that is unique to this auth context type and its data. For example, an """ Returns a key that is unique to this auth context type and its data. For example, an
instance of the auth context type for the user might be a string of the form instance of the auth context type for the user might be a string of the form
`user-{user-uuid}`. Callers should treat this key as opaque and not rely on the contents `user-{user-uuid}`. Callers should treat this key as opaque and not rely on the contents
for anything besides uniqueness. This is typically used by callers when they'd like to for anything besides uniqueness. This is typically used by callers when they'd like to
check cache but not hit the database to get a fully validated auth context. check cache but not hit the database to get a fully validated auth context.
""" """
pass pass
class ValidatedAuthContext(AuthContext): class ValidatedAuthContext(AuthContext):
""" ValidatedAuthContext represents the loaded, authenticated and validated auth information """ ValidatedAuthContext represents the loaded, authenticated and validated auth information
for the current request context. for the current request context.
""" """
def __init__(self, user=None, token=None, oauthtoken=None, robot=None, appspecifictoken=None,
signed_data=None):
# Note: These field names *MUST* match the string values of the kinds defined in
# ContextEntityKind.
self.user = user
self.robot = robot
self.token = token
self.oauthtoken = oauthtoken
self.appspecifictoken = appspecifictoken
self.signed_data = signed_data
def tuple(self): def __init__(
return vars(self).values() self,
user=None,
token=None,
oauthtoken=None,
robot=None,
appspecifictoken=None,
signed_data=None,
):
# Note: These field names *MUST* match the string values of the kinds defined in
# ContextEntityKind.
self.user = user
self.robot = robot
self.token = token
self.oauthtoken = oauthtoken
self.appspecifictoken = appspecifictoken
self.signed_data = signed_data
def __eq__(self, other): def tuple(self):
return self.tuple() == other.tuple() return vars(self).values()
@property def __eq__(self, other):
def entity_kind(self): return self.tuple() == other.tuple()
""" Returns the kind of the entity in this auth context. """
for kind in ContextEntityKind:
if hasattr(self, kind.value) and getattr(self, kind.value):
return kind
return ContextEntityKind.anonymous @property
def entity_kind(self):
""" Returns the kind of the entity in this auth context. """
for kind in ContextEntityKind:
if hasattr(self, kind.value) and getattr(self, kind.value):
return kind
@property return ContextEntityKind.anonymous
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth token. Note that this @property
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth token. Note that this
will also return robot accounts. will also return robot accounts.
""" """
authed_user = self._authed_user() authed_user = self._authed_user()
if authed_user is not None and not authed_user.enabled: if authed_user is not None and not authed_user.enabled:
logger.warning('Attempt to reference a disabled user/robot: %s', authed_user.username) logger.warning(
return None "Attempt to reference a disabled user/robot: %s", authed_user.username
)
return None
return authed_user return authed_user
@property @property
def authed_oauth_token(self): def authed_oauth_token(self):
return self.oauthtoken return self.oauthtoken
def _authed_user(self): def _authed_user(self):
if self.oauthtoken: if self.oauthtoken:
return self.oauthtoken.authorized_user return self.oauthtoken.authorized_user
if self.appspecifictoken: if self.appspecifictoken:
return self.appspecifictoken.user return self.appspecifictoken.user
if self.signed_data: if self.signed_data:
return model.user.get_user(self.signed_data['user_context']) return model.user.get_user(self.signed_data["user_context"])
return self.user if self.user else self.robot return self.user if self.user else self.robot
@property @property
def is_anonymous(self): def is_anonymous(self):
""" Returns true if this is an anonymous context. """ """ Returns true if this is an anonymous context. """
return not self.authed_user and not self.token and not self.signed_data return not self.authed_user and not self.token and not self.signed_data
@property @property
def has_nonrobot_user(self): def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """ """ Returns whether a user (not a robot) was authenticated successfully. """
return bool(self.authed_user and not self.robot) return bool(self.authed_user and not self.robot)
@property @property
def identity(self): def identity(self):
""" Returns the identity for the auth context. """ """ Returns the identity for the auth context. """
if self.oauthtoken: if self.oauthtoken:
scope_set = scopes_from_scope_string(self.oauthtoken.scope) scope_set = scopes_from_scope_string(self.oauthtoken.scope)
return QuayDeferredPermissionUser.for_user(self.oauthtoken.authorized_user, scope_set) return QuayDeferredPermissionUser.for_user(
self.oauthtoken.authorized_user, scope_set
)
if self.authed_user: if self.authed_user:
return QuayDeferredPermissionUser.for_user(self.authed_user) return QuayDeferredPermissionUser.for_user(self.authed_user)
if self.token: if self.token:
return Identity(self.token.get_code(), 'token') return Identity(self.token.get_code(), "token")
if self.signed_data: if self.signed_data:
identity = Identity(None, 'signed_grant') identity = Identity(None, "signed_grant")
identity.provides.update(self.signed_data['grants']) identity.provides.update(self.signed_data["grants"])
return identity return identity
return None return None
@property @property
def entity_reference(self): def entity_reference(self):
""" Returns the DB object reference for this context's entity. """ """ Returns the DB object reference for this context's entity. """
if self.entity_kind == ContextEntityKind.anonymous: if self.entity_kind == ContextEntityKind.anonymous:
return None return None
return getattr(self, self.entity_kind.value) return getattr(self, self.entity_kind.value)
@property @property
def description(self): def description(self):
""" Returns a human-readable and *public* description of the current auth context. """ """ Returns a human-readable and *public* description of the current auth context. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]() handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.description(self.entity_reference) return handler.description(self.entity_reference)
@property @property
def credential_username(self): def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """ """ Returns the username to create credentials for this context's entity, if any. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]() handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.credential_username(self.entity_reference) return handler.credential_username(self.entity_reference)
def analytics_id_and_public_metadata(self): def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """ """ Returns the analytics ID and public log metadata for this auth context. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]() handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.analytics_id_and_public_metadata(self.entity_reference) return handler.analytics_id_and_public_metadata(self.entity_reference)
def apply_to_request_context(self): def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """ """ Applies this auth result to the auth context and Flask-Principal. """
# Save to the request context. # Save to the request context.
set_authenticated_context(self) set_authenticated_context(self)
# Set the identity for Flask-Principal. # Set the identity for Flask-Principal.
if self.identity: if self.identity:
identity_changed.send(app, identity=self.identity) identity_changed.send(app, identity=self.identity)
@property @property
def unique_key(self): def unique_key(self):
signed_dict = self.to_signed_dict() signed_dict = self.to_signed_dict()
return '%s-%s' % (signed_dict['entity_kind'], signed_dict.get('entity_reference', '(anon)')) return "%s-%s" % (
signed_dict["entity_kind"],
signed_dict.get("entity_reference", "(anon)"),
)
def to_signed_dict(self): def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization. form of signed serialization.
""" """
dict_data = { dict_data = {"version": 2, "entity_kind": self.entity_kind.value}
'version': 2,
'entity_kind': self.entity_kind.value,
}
if self.entity_kind != ContextEntityKind.anonymous: if self.entity_kind != ContextEntityKind.anonymous:
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]() handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
dict_data.update({ dict_data.update(
'entity_reference': handler.get_serialized_entity_reference(self.entity_reference), {
}) "entity_reference": handler.get_serialized_entity_reference(
self.entity_reference
)
}
)
# Add legacy information. # Add legacy information.
# TODO: Remove this all once the new code is fully deployed. # TODO: Remove this all once the new code is fully deployed.
if self.token: if self.token:
dict_data.update({ dict_data.update({"kind": "token", "token": self.token.code})
'kind': 'token',
'token': self.token.code,
})
if self.oauthtoken: if self.oauthtoken:
dict_data.update({ dict_data.update(
'kind': 'oauth', {
'oauth': self.oauthtoken.uuid, "kind": "oauth",
'user': self.authed_user.username, "oauth": self.oauthtoken.uuid,
}) "user": self.authed_user.username,
}
)
if self.user or self.robot: if self.user or self.robot:
dict_data.update({ dict_data.update({"kind": "user", "user": self.authed_user.username})
'kind': 'user',
'user': self.authed_user.username,
})
if self.appspecifictoken: if self.appspecifictoken:
dict_data.update({ dict_data.update({"kind": "user", "user": self.authed_user.username})
'kind': 'user',
'user': self.authed_user.username,
})
if self.is_anonymous: if self.is_anonymous:
dict_data.update({ dict_data.update({"kind": "anonymous"})
'kind': 'anonymous',
}) # End of legacy information.
return dict_data
# End of legacy information.
return dict_data
class SignedAuthContext(AuthContext): class SignedAuthContext(AuthContext):
""" SignedAuthContext represents an auth context loaded from a signed token of some kind, """ SignedAuthContext represents an auth context loaded from a signed token of some kind,
such as a JWT. Unlike ValidatedAuthContext, SignedAuthContext operates lazily, only loading such as a JWT. Unlike ValidatedAuthContext, SignedAuthContext operates lazily, only loading
the actual {user, robot, token, etc} when requested. This allows registry operations that the actual {user, robot, token, etc} when requested. This allows registry operations that
only need to check if *some* entity is present to do so, without hitting the database. only need to check if *some* entity is present to do so, without hitting the database.
""" """
def __init__(self, kind, signed_data, v1_dict_format):
self.kind = kind
self.signed_data = signed_data
self.v1_dict_format = v1_dict_format
@property def __init__(self, kind, signed_data, v1_dict_format):
def unique_key(self): self.kind = kind
if self.v1_dict_format: self.signed_data = signed_data
# Since V1 data format is verbose, just use the validated version to get the key. self.v1_dict_format = v1_dict_format
return self._get_validated().unique_key
signed_dict = self.signed_data @property
return '%s-%s' % (signed_dict['entity_kind'], signed_dict.get('entity_reference', '(anon)')) def unique_key(self):
if self.v1_dict_format:
# Since V1 data format is verbose, just use the validated version to get the key.
return self._get_validated().unique_key
@classmethod signed_dict = self.signed_data
def build_from_signed_dict(cls, dict_data, v1_dict_format=False): return "%s-%s" % (
if not v1_dict_format: signed_dict["entity_kind"],
entity_kind = ContextEntityKind(dict_data.get('entity_kind', 'anonymous')) signed_dict.get("entity_reference", "(anon)"),
return SignedAuthContext(entity_kind, dict_data, v1_dict_format) )
# Legacy handling. @classmethod
# TODO: Remove this all once the new code is fully deployed. def build_from_signed_dict(cls, dict_data, v1_dict_format=False):
kind_string = dict_data.get('kind', 'anonymous') if not v1_dict_format:
if kind_string == 'oauth': entity_kind = ContextEntityKind(dict_data.get("entity_kind", "anonymous"))
kind_string = 'oauthtoken' return SignedAuthContext(entity_kind, dict_data, v1_dict_format)
kind = ContextEntityKind(kind_string) # Legacy handling.
return SignedAuthContext(kind, dict_data, v1_dict_format) # TODO: Remove this all once the new code is fully deployed.
kind_string = dict_data.get("kind", "anonymous")
if kind_string == "oauth":
kind_string = "oauthtoken"
@lru_cache(maxsize=1) kind = ContextEntityKind(kind_string)
def _get_validated(self): return SignedAuthContext(kind, dict_data, v1_dict_format)
""" Returns a ValidatedAuthContext for this signed context, resolving all the necessary
@lru_cache(maxsize=1)
def _get_validated(self):
""" Returns a ValidatedAuthContext for this signed context, resolving all the necessary
references. references.
""" """
if not self.v1_dict_format: if not self.v1_dict_format:
if self.kind == ContextEntityKind.anonymous: if self.kind == ContextEntityKind.anonymous:
return ValidatedAuthContext() return ValidatedAuthContext()
serialized_entity_reference = self.signed_data['entity_reference'] serialized_entity_reference = self.signed_data["entity_reference"]
handler = CONTEXT_ENTITY_HANDLERS[self.kind]() handler = CONTEXT_ENTITY_HANDLERS[self.kind]()
entity_reference = handler.deserialize_entity_reference(serialized_entity_reference) entity_reference = handler.deserialize_entity_reference(
if entity_reference is None: serialized_entity_reference
logger.debug('Could not deserialize entity reference `%s` under kind `%s`', )
serialized_entity_reference, self.kind) if entity_reference is None:
return ValidatedAuthContext() logger.debug(
"Could not deserialize entity reference `%s` under kind `%s`",
serialized_entity_reference,
self.kind,
)
return ValidatedAuthContext()
return ValidatedAuthContext(**{self.kind.value: entity_reference}) return ValidatedAuthContext(**{self.kind.value: entity_reference})
# Legacy handling. # Legacy handling.
# TODO: Remove this all once the new code is fully deployed. # TODO: Remove this all once the new code is fully deployed.
kind_string = self.signed_data.get('kind', 'anonymous') kind_string = self.signed_data.get("kind", "anonymous")
if kind_string == 'oauth': if kind_string == "oauth":
kind_string = 'oauthtoken' kind_string = "oauthtoken"
kind = ContextEntityKind(kind_string) kind = ContextEntityKind(kind_string)
if kind == ContextEntityKind.anonymous: if kind == ContextEntityKind.anonymous:
return ValidatedAuthContext() return ValidatedAuthContext()
if kind == ContextEntityKind.user or kind == ContextEntityKind.robot: if kind == ContextEntityKind.user or kind == ContextEntityKind.robot:
user = model.user.get_user(self.signed_data.get('user', '')) user = model.user.get_user(self.signed_data.get("user", ""))
if not user: if not user:
return None return None
return ValidatedAuthContext(robot=user) if user.robot else ValidatedAuthContext(user=user) return (
ValidatedAuthContext(robot=user)
if user.robot
else ValidatedAuthContext(user=user)
)
if kind == ContextEntityKind.token: if kind == ContextEntityKind.token:
token = model.token.load_token_data(self.signed_data.get('token')) token = model.token.load_token_data(self.signed_data.get("token"))
if not token: if not token:
return None return None
return ValidatedAuthContext(token=token) return ValidatedAuthContext(token=token)
if kind == ContextEntityKind.oauthtoken: if kind == ContextEntityKind.oauthtoken:
user = model.user.get_user(self.signed_data.get('user', '')) user = model.user.get_user(self.signed_data.get("user", ""))
if not user: if not user:
return None return None
token_uuid = self.signed_data.get('oauth', '') token_uuid = self.signed_data.get("oauth", "")
oauthtoken = model.oauth.lookup_access_token_for_user(user, token_uuid) oauthtoken = model.oauth.lookup_access_token_for_user(user, token_uuid)
if not oauthtoken: if not oauthtoken:
return None return None
return ValidatedAuthContext(oauthtoken=oauthtoken) return ValidatedAuthContext(oauthtoken=oauthtoken)
raise Exception('Unknown auth context kind `%s` when deserializing %s' % (kind, raise Exception(
self.signed_data)) "Unknown auth context kind `%s` when deserializing %s"
# End of legacy handling. % (kind, self.signed_data)
)
# End of legacy handling.
@property @property
def entity_kind(self): def entity_kind(self):
""" Returns the kind of the entity in this auth context. """ """ Returns the kind of the entity in this auth context. """
return self.kind return self.kind
@property @property
def is_anonymous(self): def is_anonymous(self):
""" Returns true if this is an anonymous context. """ """ Returns true if this is an anonymous context. """
return self.kind == ContextEntityKind.anonymous return self.kind == ContextEntityKind.anonymous
@property @property
def authed_user(self): def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth or access token. Note that """ Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
this property will also return robot accounts. this property will also return robot accounts.
""" """
if self.kind == ContextEntityKind.anonymous: if self.kind == ContextEntityKind.anonymous:
return None return None
return self._get_validated().authed_user return self._get_validated().authed_user
@property @property
def authed_oauth_token(self): def authed_oauth_token(self):
if self.kind == ContextEntityKind.anonymous: if self.kind == ContextEntityKind.anonymous:
return None return None
return self._get_validated().authed_oauth_token return self._get_validated().authed_oauth_token
@property @property
def has_nonrobot_user(self): def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """ """ Returns whether a user (not a robot) was authenticated successfully. """
if self.kind == ContextEntityKind.anonymous: if self.kind == ContextEntityKind.anonymous:
return False return False
return self._get_validated().has_nonrobot_user return self._get_validated().has_nonrobot_user
@property @property
def identity(self): def identity(self):
""" Returns the identity for the auth context. """ """ Returns the identity for the auth context. """
return self._get_validated().identity return self._get_validated().identity
@property @property
def description(self): def description(self):
""" Returns a human-readable and *public* description of the current auth context. """ """ Returns a human-readable and *public* description of the current auth context. """
return self._get_validated().description return self._get_validated().description
@property @property
def credential_username(self): def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """ """ Returns the username to create credentials for this context's entity, if any. """
return self._get_validated().credential_username return self._get_validated().credential_username
def analytics_id_and_public_metadata(self): def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """ """ Returns the analytics ID and public log metadata for this auth context. """
return self._get_validated().analytics_id_and_public_metadata() return self._get_validated().analytics_id_and_public_metadata()
def apply_to_request_context(self): def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """ """ Applies this auth result to the auth context and Flask-Principal. """
return self._get_validated().apply_to_request_context() return self._get_validated().apply_to_request_context()
def to_signed_dict(self): def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization. form of signed serialization.
""" """
return self.signed_data return self.signed_data

View file

@ -8,51 +8,54 @@ from auth.validateresult import ValidateResult, AuthKind
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def has_basic_auth(username): def has_basic_auth(username):
""" Returns true if a basic auth header exists with a username and password pair that validates """ Returns true if a basic auth header exists with a username and password pair that validates
against the internal authentication system. Returns True on full success and False on any against the internal authentication system. Returns True on full success and False on any
failure (missing header, invalid header, invalid credentials, etc). failure (missing header, invalid header, invalid credentials, etc).
""" """
auth_header = request.headers.get('authorization', '') auth_header = request.headers.get("authorization", "")
result = validate_basic_auth(auth_header) result = validate_basic_auth(auth_header)
return result.has_nonrobot_user and result.context.user.username == username return result.has_nonrobot_user and result.context.user.username == username
def validate_basic_auth(auth_header): def validate_basic_auth(auth_header):
""" Validates the specified basic auth header, returning whether its credentials point """ Validates the specified basic auth header, returning whether its credentials point
to a valid user or token. to a valid user or token.
""" """
if not auth_header: if not auth_header:
return ValidateResult(AuthKind.basic, missing=True) return ValidateResult(AuthKind.basic, missing=True)
logger.debug('Attempt to process basic auth header') logger.debug("Attempt to process basic auth header")
# Parse the basic auth header. # Parse the basic auth header.
assert isinstance(auth_header, basestring) assert isinstance(auth_header, basestring)
credentials, err = _parse_basic_auth_header(auth_header) credentials, err = _parse_basic_auth_header(auth_header)
if err is not None: if err is not None:
logger.debug('Got invalid basic auth header: %s', auth_header) logger.debug("Got invalid basic auth header: %s", auth_header)
return ValidateResult(AuthKind.basic, missing=True) return ValidateResult(AuthKind.basic, missing=True)
auth_username, auth_password_or_token = credentials auth_username, auth_password_or_token = credentials
result, _ = validate_credentials(auth_username, auth_password_or_token) result, _ = validate_credentials(auth_username, auth_password_or_token)
return result.with_kind(AuthKind.basic) return result.with_kind(AuthKind.basic)
def _parse_basic_auth_header(auth): def _parse_basic_auth_header(auth):
""" Parses the given basic auth header, returning the credentials found inside. """ Parses the given basic auth header, returning the credentials found inside.
""" """
normalized = [part.strip() for part in auth.split(' ') if part] normalized = [part.strip() for part in auth.split(" ") if part]
if normalized[0].lower() != 'basic' or len(normalized) != 2: if normalized[0].lower() != "basic" or len(normalized) != 2:
return None, 'Invalid basic auth header' return None, "Invalid basic auth header"
try: try:
credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)] credentials = [
except (TypeError, UnicodeDecodeError, ValueError): part.decode("utf-8") for part in b64decode(normalized[1]).split(":", 1)
logger.exception('Exception when parsing basic auth header: %s', auth) ]
return None, 'Could not parse basic auth header' except (TypeError, UnicodeDecodeError, ValueError):
logger.exception("Exception when parsing basic auth header: %s", auth)
return None, "Could not parse basic auth header"
if len(credentials) != 2: if len(credentials) != 2:
return None, 'Unexpected number of credentials found in basic auth header' return None, "Unexpected number of credentials found in basic auth header"
return credentials, None return credentials, None

View file

@ -4,200 +4,210 @@ from enum import Enum
from data import model from data import model
from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME, from auth.credential_consts import (
APP_SPECIFIC_TOKEN_USERNAME) ACCESS_TOKEN_USERNAME,
OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME,
)
class ContextEntityKind(Enum): class ContextEntityKind(Enum):
""" Defines the various kinds of entities in an auth context. Note that the string values of """ Defines the various kinds of entities in an auth context. Note that the string values of
these fields *must* match the names of the fields in the ValidatedAuthContext class, as these fields *must* match the names of the fields in the ValidatedAuthContext class, as
we fill them in directly based on the string names here. we fill them in directly based on the string names here.
""" """
anonymous = 'anonymous'
user = 'user' anonymous = "anonymous"
robot = 'robot' user = "user"
token = 'token' robot = "robot"
oauthtoken = 'oauthtoken' token = "token"
appspecifictoken = 'appspecifictoken' oauthtoken = "oauthtoken"
signed_data = 'signed_data' appspecifictoken = "appspecifictoken"
signed_data = "signed_data"
@add_metaclass(ABCMeta) @add_metaclass(ABCMeta)
class ContextEntityHandler(object): class ContextEntityHandler(object):
""" """
Interface that represents handling specific kinds of entities under an auth context. Interface that represents handling specific kinds of entities under an auth context.
""" """
@abstractmethod @abstractmethod
def credential_username(self, entity_reference): def credential_username(self, entity_reference):
""" Returns the username to create credentials for this entity, if any. """ """ Returns the username to create credentials for this entity, if any. """
pass pass
@abstractmethod @abstractmethod
def get_serialized_entity_reference(self, entity_reference): def get_serialized_entity_reference(self, entity_reference):
""" Returns the entity reference for this kind of auth context, serialized into a form that can """ Returns the entity reference for this kind of auth context, serialized into a form that can
be placed into a JSON object and put into a JWT. This is typically a DB UUID or another be placed into a JSON object and put into a JWT. This is typically a DB UUID or another
unique identifier for the object in the DB. unique identifier for the object in the DB.
""" """
pass pass
@abstractmethod @abstractmethod
def deserialize_entity_reference(self, serialized_entity_reference): def deserialize_entity_reference(self, serialized_entity_reference):
""" Returns the deserialized reference to the entity in the database, or None if none. """ """ Returns the deserialized reference to the entity in the database, or None if none. """
pass pass
@abstractmethod @abstractmethod
def description(self, entity_reference): def description(self, entity_reference):
""" Returns a human-readable and *public* description of the current entity. """ """ Returns a human-readable and *public* description of the current entity. """
pass pass
@abstractmethod @abstractmethod
def analytics_id_and_public_metadata(self, entity_reference): def analytics_id_and_public_metadata(self, entity_reference):
""" Returns the analyitics ID and a dict of public metadata for the current entity. """ """ Returns the analyitics ID and a dict of public metadata for the current entity. """
pass pass
class AnonymousEntityHandler(ContextEntityHandler): class AnonymousEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference): def credential_username(self, entity_reference):
return None return None
def get_serialized_entity_reference(self, entity_reference): def get_serialized_entity_reference(self, entity_reference):
return None return None
def deserialize_entity_reference(self, serialized_entity_reference): def deserialize_entity_reference(self, serialized_entity_reference):
return None return None
def description(self, entity_reference): def description(self, entity_reference):
return "anonymous" return "anonymous"
def analytics_id_and_public_metadata(self, entity_reference): def analytics_id_and_public_metadata(self, entity_reference):
return "anonymous", {} return "anonymous", {}
class UserEntityHandler(ContextEntityHandler): class UserEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference): def credential_username(self, entity_reference):
return entity_reference.username return entity_reference.username
def get_serialized_entity_reference(self, entity_reference): def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid return entity_reference.uuid
def deserialize_entity_reference(self, serialized_entity_reference): def deserialize_entity_reference(self, serialized_entity_reference):
return model.user.get_user_by_uuid(serialized_entity_reference) return model.user.get_user_by_uuid(serialized_entity_reference)
def description(self, entity_reference): def description(self, entity_reference):
return "user %s" % entity_reference.username return "user %s" % entity_reference.username
def analytics_id_and_public_metadata(self, entity_reference): def analytics_id_and_public_metadata(self, entity_reference):
return entity_reference.username, { return entity_reference.username, {"username": entity_reference.username}
'username': entity_reference.username,
}
class RobotEntityHandler(ContextEntityHandler): class RobotEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference): def credential_username(self, entity_reference):
return entity_reference.username return entity_reference.username
def get_serialized_entity_reference(self, entity_reference): def get_serialized_entity_reference(self, entity_reference):
return entity_reference.username return entity_reference.username
def deserialize_entity_reference(self, serialized_entity_reference): def deserialize_entity_reference(self, serialized_entity_reference):
return model.user.lookup_robot(serialized_entity_reference) return model.user.lookup_robot(serialized_entity_reference)
def description(self, entity_reference): def description(self, entity_reference):
return "robot %s" % entity_reference.username return "robot %s" % entity_reference.username
def analytics_id_and_public_metadata(self, entity_reference): def analytics_id_and_public_metadata(self, entity_reference):
return entity_reference.username, { return (
'username': entity_reference.username, entity_reference.username,
'is_robot': True, {"username": entity_reference.username, "is_robot": True},
} )
class TokenEntityHandler(ContextEntityHandler): class TokenEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference): def credential_username(self, entity_reference):
return ACCESS_TOKEN_USERNAME return ACCESS_TOKEN_USERNAME
def get_serialized_entity_reference(self, entity_reference): def get_serialized_entity_reference(self, entity_reference):
return entity_reference.get_code() return entity_reference.get_code()
def deserialize_entity_reference(self, serialized_entity_reference): def deserialize_entity_reference(self, serialized_entity_reference):
return model.token.load_token_data(serialized_entity_reference) return model.token.load_token_data(serialized_entity_reference)
def description(self, entity_reference): def description(self, entity_reference):
return "token %s" % entity_reference.friendly_name return "token %s" % entity_reference.friendly_name
def analytics_id_and_public_metadata(self, entity_reference): def analytics_id_and_public_metadata(self, entity_reference):
return 'token:%s' % entity_reference.id, { return (
'token': entity_reference.friendly_name, "token:%s" % entity_reference.id,
} {"token": entity_reference.friendly_name},
)
class OAuthTokenEntityHandler(ContextEntityHandler): class OAuthTokenEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference): def credential_username(self, entity_reference):
return OAUTH_TOKEN_USERNAME return OAUTH_TOKEN_USERNAME
def get_serialized_entity_reference(self, entity_reference): def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid return entity_reference.uuid
def deserialize_entity_reference(self, serialized_entity_reference): def deserialize_entity_reference(self, serialized_entity_reference):
return model.oauth.lookup_access_token_by_uuid(serialized_entity_reference) return model.oauth.lookup_access_token_by_uuid(serialized_entity_reference)
def description(self, entity_reference): def description(self, entity_reference):
return "oauthtoken for user %s" % entity_reference.authorized_user.username return "oauthtoken for user %s" % entity_reference.authorized_user.username
def analytics_id_and_public_metadata(self, entity_reference): def analytics_id_and_public_metadata(self, entity_reference):
return 'oauthtoken:%s' % entity_reference.id, { return (
'oauth_token_id': entity_reference.id, "oauthtoken:%s" % entity_reference.id,
'oauth_token_application_id': entity_reference.application.client_id, {
'oauth_token_application': entity_reference.application.name, "oauth_token_id": entity_reference.id,
'username': entity_reference.authorized_user.username, "oauth_token_application_id": entity_reference.application.client_id,
} "oauth_token_application": entity_reference.application.name,
"username": entity_reference.authorized_user.username,
},
)
class AppSpecificTokenEntityHandler(ContextEntityHandler): class AppSpecificTokenEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference): def credential_username(self, entity_reference):
return APP_SPECIFIC_TOKEN_USERNAME return APP_SPECIFIC_TOKEN_USERNAME
def get_serialized_entity_reference(self, entity_reference): def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid return entity_reference.uuid
def deserialize_entity_reference(self, serialized_entity_reference): def deserialize_entity_reference(self, serialized_entity_reference):
return model.appspecifictoken.get_token_by_uuid(serialized_entity_reference) return model.appspecifictoken.get_token_by_uuid(serialized_entity_reference)
def description(self, entity_reference): def description(self, entity_reference):
tpl = (entity_reference.title, entity_reference.user.username) tpl = (entity_reference.title, entity_reference.user.username)
return "app specific token %s for user %s" % tpl return "app specific token %s for user %s" % tpl
def analytics_id_and_public_metadata(self, entity_reference): def analytics_id_and_public_metadata(self, entity_reference):
return 'appspecifictoken:%s' % entity_reference.id, { return (
'app_specific_token': entity_reference.uuid, "appspecifictoken:%s" % entity_reference.id,
'app_specific_token_title': entity_reference.title, {
'username': entity_reference.user.username, "app_specific_token": entity_reference.uuid,
} "app_specific_token_title": entity_reference.title,
"username": entity_reference.user.username,
},
)
class SignedDataEntityHandler(ContextEntityHandler):
    """ ContextEntityHandler for signed grants; these carry no credential username and
        cannot be (de)serialized.
    """

    def credential_username(self, entity_reference):
        return None

    def get_serialized_entity_reference(self, entity_reference):
        # Signed data is never persisted into a serialized auth context.
        raise NotImplementedError

    def deserialize_entity_reference(self, serialized_entity_reference):
        raise NotImplementedError

    def description(self, entity_reference):
        return "signed"

    def analytics_id_and_public_metadata(self, entity_reference):
        return "signed", {"signed": entity_reference}
# Maps each ContextEntityKind to the handler class that knows how to serialize,
# deserialize and describe entities of that kind.
CONTEXT_ENTITY_HANDLERS = {
    ContextEntityKind.anonymous: AnonymousEntityHandler,
    ContextEntityKind.user: UserEntityHandler,
    ContextEntityKind.robot: RobotEntityHandler,
    ContextEntityKind.token: TokenEntityHandler,
    ContextEntityKind.oauthtoken: OAuthTokenEntityHandler,
    ContextEntityKind.appspecifictoken: AppSpecificTokenEntityHandler,
    ContextEntityKind.signed_data: SignedDataEntityHandler,
}

View file

@ -7,31 +7,40 @@ from auth.validateresult import AuthKind, ValidateResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def validate_session_cookie(auth_header_unusued=None):
    """ Attempts to load a user from a session cookie.

        Returns a ValidateResult of kind `cookie`: `missing=True` when no (valid-format)
        session exists, an `error_message` when the user cannot be used, or the matching
        user on success. The auth-header argument is ignored; it exists so this function
        matches the common handler signature.
    """
    if current_user.is_anonymous:
        return ValidateResult(AuthKind.cookie, missing=True)

    try:
        # Attempt to parse the user uuid to make sure the cookie has the right value type
        UUID(current_user.get_id())
    except ValueError:
        logger.debug("Got non-UUID for session cookie user: %s", current_user.get_id())
        return ValidateResult(AuthKind.cookie, error_message="Invalid session cookie format")

    logger.debug("Loading user from cookie: %s", current_user.get_id())
    db_user = current_user.db_user()
    if db_user is None:
        return ValidateResult(AuthKind.cookie, error_message="Could not find matching user")

    # Don't allow disabled users to login.
    if not db_user.enabled:
        logger.debug("User %s in session cookie is disabled", db_user.username)
        return ValidateResult(AuthKind.cookie, error_message="User account is disabled")

    # Don't allow organizations to "login".
    if db_user.organization:
        logger.debug("User %s in session cookie is in-fact organization", db_user.username)
        return ValidateResult(AuthKind.cookie, error_message="Cannot login to organization")

    return ValidateResult(AuthKind.cookie, user=db_user)

View file

@ -1,3 +1,3 @@
# Sentinel usernames used on the "username" side of basic-auth credentials to signal
# that the password field carries a token of the corresponding kind.
ACCESS_TOKEN_USERNAME = "$token"
OAUTH_TOKEN_USERNAME = "$oauthtoken"
APP_SPECIFIC_TOKEN_USERNAME = "$app"

View file

@ -7,8 +7,11 @@ import features
from app import authentication from app import authentication
from auth.oauth import validate_oauth_token from auth.oauth import validate_oauth_token
from auth.validateresult import ValidateResult, AuthKind from auth.validateresult import ValidateResult, AuthKind
from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME, from auth.credential_consts import (
APP_SPECIFIC_TOKEN_USERNAME) ACCESS_TOKEN_USERNAME,
OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME,
)
from data import model from data import model
from util.names import parse_robot_username from util.names import parse_robot_username
@ -16,70 +19,116 @@ logger = logging.getLogger(__name__)
class CredentialKind(Enum):
    """ The kind of credential pair handed to `validate_credentials`; token kinds reuse
        their sentinel usernames as values.
    """

    user = "user"
    robot = "robot"
    token = ACCESS_TOKEN_USERNAME
    oauth_token = OAUTH_TOKEN_USERNAME
    app_specific_token = APP_SPECIFIC_TOKEN_USERNAME
def validate_credentials(auth_username, auth_password_or_token):
    """ Validates a pair of auth username and password/token credentials.

        Dispatches on the sentinel username (access token, app-specific token, OAuth
        token), then robot, then falls back to a standard user login. Returns a tuple
        of (ValidateResult, CredentialKind).
    """
    # Check for access tokens.
    if auth_username == ACCESS_TOKEN_USERNAME:
        logger.debug("Found credentials for access token")
        try:
            token = model.token.load_token_data(auth_password_or_token)
            logger.debug(
                "Successfully validated credentials for access token %s", token.id
            )
            return (
                ValidateResult(AuthKind.credentials, token=token),
                CredentialKind.token,
            )
        except model.DataModelException:
            logger.warning(
                "Failed to validate credentials for access token %s",
                auth_password_or_token,
            )
            return (
                ValidateResult(AuthKind.credentials, error_message="Invalid access token"),
                CredentialKind.token,
            )

    # Check for App Specific tokens.
    if features.APP_SPECIFIC_TOKENS and auth_username == APP_SPECIFIC_TOKEN_USERNAME:
        logger.debug("Found credentials for app specific auth token")
        token = model.appspecifictoken.access_valid_token(auth_password_or_token)
        if token is None:
            logger.debug(
                "Failed to validate credentials for app specific token: %s",
                auth_password_or_token,
            )
            return (
                ValidateResult(AuthKind.credentials, error_message="Invalid token"),
                CredentialKind.app_specific_token,
            )

        if not token.user.enabled:
            logger.debug(
                "Tried to use an app specific token for a disabled user: %s", token.uuid
            )
            return (
                ValidateResult(
                    AuthKind.credentials,
                    error_message="This user has been disabled. Please contact your administrator.",
                ),
                CredentialKind.app_specific_token,
            )

        logger.debug(
            "Successfully validated credentials for app specific token %s", token.id
        )
        return (
            ValidateResult(AuthKind.credentials, appspecifictoken=token),
            CredentialKind.app_specific_token,
        )

    # Check for OAuth tokens.
    if auth_username == OAUTH_TOKEN_USERNAME:
        return validate_oauth_token(auth_password_or_token), CredentialKind.oauth_token

    # Check for robots and users.
    is_robot = parse_robot_username(auth_username)
    if is_robot:
        logger.debug("Found credentials header for robot %s", auth_username)
        try:
            robot = model.user.verify_robot(auth_username, auth_password_or_token)
            logger.debug("Successfully validated credentials for robot %s", auth_username)
            return (
                ValidateResult(AuthKind.credentials, robot=robot),
                CredentialKind.robot,
            )
        except model.InvalidRobotException as ire:
            logger.warning(
                "Failed to validate credentials for robot %s: %s", auth_username, ire
            )
            return (
                ValidateResult(AuthKind.credentials, error_message=str(ire)),
                CredentialKind.robot,
            )

    # Otherwise, treat as a standard user.
    (authenticated, err) = authentication.verify_and_link_user(
        auth_username, auth_password_or_token, basic_auth=True
    )
    if authenticated:
        logger.debug(
            "Successfully validated credentials for user %s", authenticated.username
        )
        return (
            ValidateResult(AuthKind.credentials, user=authenticated),
            CredentialKind.user,
        )
    else:
        logger.warning(
            "Failed to validate credentials for user %s: %s", auth_username, err
        )
        return (
            ValidateResult(AuthKind.credentials, error_message=err),
            CredentialKind.user,
        )

View file

@ -14,83 +14,101 @@ from util.http import abort
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _auth_decorator(pass_result=False, handlers=None):
    """ Builds an auth decorator that runs the given handlers and, if any return successfully,
        sets up the auth context. The wrapped function will be invoked *regardless of success or
        failure of the auth handler(s)*.

        When `pass_result` is True, the last handler's ValidateResult is injected into the
        wrapped function as the `auth_result` keyword argument.
    """

    def processor(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            auth_header = request.headers.get("authorization", "")
            result = None

            for handler in handlers:
                result = handler(auth_header)

                # If the handler was missing the necessary information, skip it and try
                # the next one.
                if result.missing:
                    continue

                # Check for a valid result.
                if result.auth_valid:
                    logger.debug("Found valid auth result: %s", result.tuple())

                    # Set the various pieces of the auth context.
                    result.apply_to_context()

                    # Log the metric.
                    metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
                    break

                # Otherwise, report the error.
                if result.error_message is not None:
                    # Log the failure.
                    metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
                    break

            if pass_result:
                kwargs["auth_result"] = result

            return func(*args, **kwargs)

        return wrapper

    return processor
# Ready-made decorators combining the shared handler chains: each tries its handlers
# in order and stops at the first one that is not missing.
process_oauth = _auth_decorator(handlers=[validate_bearer_auth, validate_session_cookie])
process_auth = _auth_decorator(handlers=[validate_signed_grant, validate_basic_auth])
process_auth_or_cookie = _auth_decorator(handlers=[validate_basic_auth, validate_session_cookie])
process_basic_auth = _auth_decorator(handlers=[validate_basic_auth], pass_result=True)
process_basic_auth_no_pass = _auth_decorator(handlers=[validate_basic_auth])
def require_session_login(func):
    """ Decorates a function and ensures that a valid session cookie exists or a 401 is raised. If
        a valid session cookie does exist, the authenticated user and identity are also set.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        result = validate_session_cookie()
        if result.has_nonrobot_user:
            result.apply_to_context()
            metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
            return func(*args, **kwargs)
        elif not result.missing:
            # A cookie was present but invalid; count it as a failed authentication.
            metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])

        abort(401, message="Method requires login and no valid login could be loaded.")

    return wrapper
def extract_namespace_repo_from_session(func):
    """ Extracts the namespace and repository name from the current session (which must exist)
        and passes them into the decorated function as the first and second arguments. If the
        session doesn't exist or does not contain these arguments, a 400 error is raised.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if "namespace" not in session or "repository" not in session:
            logger.error("Unable to load namespace or repository from session: %s", session)
            abort(400, message="Missing namespace in request")

        return func(session["namespace"], session["repository"], *args, **kwargs)

    return wrapper

View file

@ -8,41 +8,47 @@ from data import model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def validate_bearer_auth(auth_header):
    """ Validates an OAuth token found inside a basic auth `Bearer` token, returning whether it
        points to a valid OAuth token.

        Returns a `missing` ValidateResult when the header is absent or not of the form
        "Bearer <token>".
    """
    if not auth_header:
        return ValidateResult(AuthKind.oauth, missing=True)

    normalized = [part.strip() for part in auth_header.split(" ") if part]
    if normalized[0].lower() != "bearer" or len(normalized) != 2:
        logger.debug("Got invalid bearer token format: %s", auth_header)
        return ValidateResult(AuthKind.oauth, missing=True)

    (_, oauth_token) = normalized
    return validate_oauth_token(oauth_token)
def validate_oauth_token(token):
    """ Validates the specified OAuth token, returning whether it points to a valid OAuth token.

        Rejects tokens that cannot be looked up, have expired, or whose granting user is
        disabled; otherwise returns a ValidateResult carrying the token.
    """
    validated = model.oauth.validate_access_token(token)
    if not validated:
        logger.warning("OAuth access token could not be validated: %s", token)
        return ValidateResult(
            AuthKind.oauth, error_message="OAuth access token could not be validated"
        )

    if validated.expires_at <= datetime.utcnow():
        logger.warning("OAuth access with an expired token: %s", token)
        return ValidateResult(AuthKind.oauth, error_message="OAuth access token has expired")

    # Don't allow disabled users to login.
    if not validated.authorized_user.enabled:
        return ValidateResult(
            AuthKind.oauth, error_message="Granter of the oauth access token is disabled"
        )

    # We have a valid token
    scope_set = scopes_from_scope_string(validated.scope)
    logger.debug("Successfully validated oauth access token with scope: %s", scope_set)
    return ValidateResult(AuthKind.oauth, oauthtoken=validated)

View file

@ -14,351 +14,399 @@ from data import model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_ResourceNeed = namedtuple('resource', ['type', 'namespace', 'name', 'role']) _ResourceNeed = namedtuple("resource", ["type", "namespace", "name", "role"])
_RepositoryNeed = partial(_ResourceNeed, 'repository') _RepositoryNeed = partial(_ResourceNeed, "repository")
_NamespaceWideNeed = namedtuple('namespacewide', ['type', 'namespace', 'role']) _NamespaceWideNeed = namedtuple("namespacewide", ["type", "namespace", "role"])
_OrganizationNeed = partial(_NamespaceWideNeed, 'organization') _OrganizationNeed = partial(_NamespaceWideNeed, "organization")
_OrganizationRepoNeed = partial(_NamespaceWideNeed, 'organizationrepo') _OrganizationRepoNeed = partial(_NamespaceWideNeed, "organizationrepo")
_TeamTypeNeed = namedtuple('teamwideneed', ['type', 'orgname', 'teamname', 'role']) _TeamTypeNeed = namedtuple("teamwideneed", ["type", "orgname", "teamname", "role"])
_TeamNeed = partial(_TeamTypeNeed, 'orgteam') _TeamNeed = partial(_TeamTypeNeed, "orgteam")
_UserTypeNeed = namedtuple('userspecificneed', ['type', 'username', 'role']) _UserTypeNeed = namedtuple("userspecificneed", ["type", "username", "role"])
_UserNeed = partial(_UserTypeNeed, 'user') _UserNeed = partial(_UserTypeNeed, "user")
_SuperUserNeed = partial(namedtuple('superuserneed', ['type']), 'superuser') _SuperUserNeed = partial(namedtuple("superuserneed", ["type"]), "superuser")
REPO_ROLES = [None, 'read', 'write', 'admin'] REPO_ROLES = [None, "read", "write", "admin"]
TEAM_ROLES = [None, 'member', 'creator', 'admin'] TEAM_ROLES = [None, "member", "creator", "admin"]
USER_ROLES = [None, 'read', 'admin'] USER_ROLES = [None, "read", "admin"]
# For each team role, the organization-wide repository role it implies (only org admins
# get an org-wide repo role).
TEAM_ORGWIDE_REPO_ROLES = {"admin": "admin", "creator": None, "member": None}

# Maximum role each OAuth scope may confer; any scope not listed defaults to None
# (no access) via the defaultdict factory.
SCOPE_MAX_REPO_ROLES = defaultdict(lambda: None)
SCOPE_MAX_REPO_ROLES.update(
    {
        scopes.READ_REPO: "read",
        scopes.WRITE_REPO: "write",
        scopes.ADMIN_REPO: "admin",
        scopes.DIRECT_LOGIN: "admin",
    }
)

SCOPE_MAX_TEAM_ROLES = defaultdict(lambda: None)
SCOPE_MAX_TEAM_ROLES.update(
    {
        scopes.CREATE_REPO: "creator",
        scopes.DIRECT_LOGIN: "admin",
        scopes.ORG_ADMIN: "admin",
    }
)

SCOPE_MAX_USER_ROLES = defaultdict(lambda: None)
SCOPE_MAX_USER_ROLES.update(
    {scopes.READ_USER: "read", scopes.DIRECT_LOGIN: "admin", scopes.ADMIN_USER: "admin"}
)
def repository_read_grant(namespace, repository):
    """ Returns a read-role need for the given repository. """
    return _RepositoryNeed(namespace, repository, "read")


def repository_write_grant(namespace, repository):
    """ Returns a write-role need for the given repository. """
    return _RepositoryNeed(namespace, repository, "write")


def repository_admin_grant(namespace, repository):
    """ Returns an admin-role need for the given repository. """
    return _RepositoryNeed(namespace, repository, "admin")
class QuayDeferredPermissionUser(Identity): class QuayDeferredPermissionUser(Identity):
def __init__(self, uuid, auth_type, auth_scopes, user=None): def __init__(self, uuid, auth_type, auth_scopes, user=None):
super(QuayDeferredPermissionUser, self).__init__(uuid, auth_type) super(QuayDeferredPermissionUser, self).__init__(uuid, auth_type)
self._namespace_wide_loaded = set() self._namespace_wide_loaded = set()
self._repositories_loaded = set() self._repositories_loaded = set()
self._personal_loaded = False self._personal_loaded = False
self._scope_set = auth_scopes self._scope_set = auth_scopes
self._user_object = user self._user_object = user
@staticmethod @staticmethod
def for_id(uuid, auth_scopes=None): def for_id(uuid, auth_scopes=None):
auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN} auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
return QuayDeferredPermissionUser(uuid, 'user_uuid', auth_scopes) return QuayDeferredPermissionUser(uuid, "user_uuid", auth_scopes)
@staticmethod @staticmethod
def for_user(user, auth_scopes=None): def for_user(user, auth_scopes=None):
auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN} auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
return QuayDeferredPermissionUser(user.uuid, 'user_uuid', auth_scopes, user=user) return QuayDeferredPermissionUser(
user.uuid, "user_uuid", auth_scopes, user=user
)
def _translate_role_for_scopes(self, cardinality, max_roles, role): def _translate_role_for_scopes(self, cardinality, max_roles, role):
if self._scope_set is None: if self._scope_set is None:
return role return role
max_for_scopes = max({cardinality.index(max_roles[scope]) for scope in self._scope_set}) max_for_scopes = max(
{cardinality.index(max_roles[scope]) for scope in self._scope_set}
)
if max_for_scopes < cardinality.index(role): if max_for_scopes < cardinality.index(role):
logger.debug('Translated permission %s -> %s', role, cardinality[max_for_scopes]) logger.debug(
return cardinality[max_for_scopes] "Translated permission %s -> %s", role, cardinality[max_for_scopes]
else: )
return role return cardinality[max_for_scopes]
else:
return role
def _team_role_for_scopes(self, role): def _team_role_for_scopes(self, role):
return self._translate_role_for_scopes(TEAM_ROLES, SCOPE_MAX_TEAM_ROLES, role) return self._translate_role_for_scopes(TEAM_ROLES, SCOPE_MAX_TEAM_ROLES, role)
def _repo_role_for_scopes(self, role): def _repo_role_for_scopes(self, role):
return self._translate_role_for_scopes(REPO_ROLES, SCOPE_MAX_REPO_ROLES, role) return self._translate_role_for_scopes(REPO_ROLES, SCOPE_MAX_REPO_ROLES, role)
def _user_role_for_scopes(self, role): def _user_role_for_scopes(self, role):
return self._translate_role_for_scopes(USER_ROLES, SCOPE_MAX_USER_ROLES, role) return self._translate_role_for_scopes(USER_ROLES, SCOPE_MAX_USER_ROLES, role)
def _populate_user_provides(self, user_object): def _populate_user_provides(self, user_object):
""" Populates the provides that naturally apply to a user, such as being the admin of """ Populates the provides that naturally apply to a user, such as being the admin of
their own namespace. their own namespace.
""" """
# Add the user specific permissions, only for non-oauth permission # Add the user specific permissions, only for non-oauth permission
user_grant = _UserNeed(user_object.username, self._user_role_for_scopes('admin')) user_grant = _UserNeed(
logger.debug('User permission: {0}'.format(user_grant)) user_object.username, self._user_role_for_scopes("admin")
self.provides.add(user_grant) )
logger.debug("User permission: {0}".format(user_grant))
self.provides.add(user_grant)
# Every user is the admin of their own 'org' # Every user is the admin of their own 'org'
user_namespace = _OrganizationNeed(user_object.username, self._team_role_for_scopes('admin')) user_namespace = _OrganizationNeed(
logger.debug('User namespace permission: {0}'.format(user_namespace)) user_object.username, self._team_role_for_scopes("admin")
self.provides.add(user_namespace) )
logger.debug("User namespace permission: {0}".format(user_namespace))
self.provides.add(user_namespace)
# Org repo roles can differ for scopes # Org repo roles can differ for scopes
user_repos = _OrganizationRepoNeed(user_object.username, self._repo_role_for_scopes('admin')) user_repos = _OrganizationRepoNeed(
logger.debug('User namespace repo permission: {0}'.format(user_repos)) user_object.username, self._repo_role_for_scopes("admin")
self.provides.add(user_repos) )
logger.debug("User namespace repo permission: {0}".format(user_repos))
self.provides.add(user_repos)
if ((scopes.SUPERUSER in self._scope_set or scopes.DIRECT_LOGIN in self._scope_set) and if (
superusers.is_superuser(user_object.username)): scopes.SUPERUSER in self._scope_set
logger.debug('Adding superuser to user: %s', user_object.username) or scopes.DIRECT_LOGIN in self._scope_set
self.provides.add(_SuperUserNeed()) ) and superusers.is_superuser(user_object.username):
logger.debug("Adding superuser to user: %s", user_object.username)
self.provides.add(_SuperUserNeed())
def _populate_namespace_wide_provides(self, user_object, namespace_filter): def _populate_namespace_wide_provides(self, user_object, namespace_filter):
""" Populates the namespace-wide provides for a particular user under a particular namespace. """ Populates the namespace-wide provides for a particular user under a particular namespace.
This method does *not* add any provides for specific repositories. This method does *not* add any provides for specific repositories.
""" """
for team in model.permission.get_org_wide_permissions(user_object, org_filter=namespace_filter): for team in model.permission.get_org_wide_permissions(
team_org_grant = _OrganizationNeed(team.organization.username, user_object, org_filter=namespace_filter
self._team_role_for_scopes(team.role.name)) ):
logger.debug('Organization team added permission: {0}'.format(team_org_grant)) team_org_grant = _OrganizationNeed(
self.provides.add(team_org_grant) team.organization.username, self._team_role_for_scopes(team.role.name)
)
logger.debug(
"Organization team added permission: {0}".format(team_org_grant)
)
self.provides.add(team_org_grant)
team_repo_role = TEAM_ORGWIDE_REPO_ROLES[team.role.name] team_repo_role = TEAM_ORGWIDE_REPO_ROLES[team.role.name]
org_repo_grant = _OrganizationRepoNeed(team.organization.username, org_repo_grant = _OrganizationRepoNeed(
self._repo_role_for_scopes(team_repo_role)) team.organization.username, self._repo_role_for_scopes(team_repo_role)
logger.debug('Organization team added repo permission: {0}'.format(org_repo_grant)) )
self.provides.add(org_repo_grant) logger.debug(
"Organization team added repo permission: {0}".format(org_repo_grant)
)
self.provides.add(org_repo_grant)
team_grant = _TeamNeed(team.organization.username, team.name, team_grant = _TeamNeed(
self._team_role_for_scopes(team.role.name)) team.organization.username,
logger.debug('Team added permission: {0}'.format(team_grant)) team.name,
self.provides.add(team_grant) self._team_role_for_scopes(team.role.name),
)
logger.debug("Team added permission: {0}".format(team_grant))
self.provides.add(team_grant)
def _populate_repository_provides(self, user_object, namespace_filter, repository_name): def _populate_repository_provides(
""" Populates the repository-specific provides for a particular user and repository. """ self, user_object, namespace_filter, repository_name
):
""" Populates the repository-specific provides for a particular user and repository. """
if namespace_filter and repository_name: if namespace_filter and repository_name:
permissions = model.permission.get_user_repository_permissions(user_object, namespace_filter, permissions = model.permission.get_user_repository_permissions(
repository_name) user_object, namespace_filter, repository_name
else: )
permissions = model.permission.get_all_user_repository_permissions(user_object) else:
permissions = model.permission.get_all_user_repository_permissions(
user_object
)
for perm in permissions: for perm in permissions:
repo_grant = _RepositoryNeed(perm.repository.namespace_user.username, perm.repository.name, repo_grant = _RepositoryNeed(
self._repo_role_for_scopes(perm.role.name)) perm.repository.namespace_user.username,
logger.debug('User added permission: {0}'.format(repo_grant)) perm.repository.name,
self.provides.add(repo_grant) self._repo_role_for_scopes(perm.role.name),
)
logger.debug("User added permission: {0}".format(repo_grant))
self.provides.add(repo_grant)
def can(self, permission): def can(self, permission):
logger.debug('Loading user permissions after deferring for: %s', self.id) logger.debug("Loading user permissions after deferring for: %s", self.id)
user_object = self._user_object or model.user.get_user_by_uuid(self.id) user_object = self._user_object or model.user.get_user_by_uuid(self.id)
if user_object is None: if user_object is None:
return super(QuayDeferredPermissionUser, self).can(permission) return super(QuayDeferredPermissionUser, self).can(permission)
# Add the user-specific provides. # Add the user-specific provides.
if not self._personal_loaded: if not self._personal_loaded:
self._populate_user_provides(user_object) self._populate_user_provides(user_object)
self._personal_loaded = True self._personal_loaded = True
# If we now have permission, no need to load any more permissions. # If we now have permission, no need to load any more permissions.
if super(QuayDeferredPermissionUser, self).can(permission): if super(QuayDeferredPermissionUser, self).can(permission):
return super(QuayDeferredPermissionUser, self).can(permission) return super(QuayDeferredPermissionUser, self).can(permission)
# Check for namespace and/or repository permissions. # Check for namespace and/or repository permissions.
perm_namespace = permission.namespace perm_namespace = permission.namespace
perm_repo_name = permission.repo_name perm_repo_name = permission.repo_name
perm_repository = None perm_repository = None
if perm_namespace and perm_repo_name: if perm_namespace and perm_repo_name:
perm_repository = '%s/%s' % (perm_namespace, perm_repo_name) perm_repository = "%s/%s" % (perm_namespace, perm_repo_name)
if not perm_namespace and not perm_repo_name: if not perm_namespace and not perm_repo_name:
# Nothing more to load, so just check directly. # Nothing more to load, so just check directly.
return super(QuayDeferredPermissionUser, self).can(permission) return super(QuayDeferredPermissionUser, self).can(permission)
# Lazy-load the repository-specific permissions. # Lazy-load the repository-specific permissions.
if perm_repository and perm_repository not in self._repositories_loaded: if perm_repository and perm_repository not in self._repositories_loaded:
self._populate_repository_provides(user_object, perm_namespace, perm_repo_name) self._populate_repository_provides(
self._repositories_loaded.add(perm_repository) user_object, perm_namespace, perm_repo_name
)
self._repositories_loaded.add(perm_repository)
# If we now have permission, no need to load any more permissions.
if super(QuayDeferredPermissionUser, self).can(permission):
return super(QuayDeferredPermissionUser, self).can(permission)
# Lazy-load the namespace-wide-only permissions.
if perm_namespace and perm_namespace not in self._namespace_wide_loaded:
self._populate_namespace_wide_provides(user_object, perm_namespace)
self._namespace_wide_loaded.add(perm_namespace)
# If we now have permission, no need to load any more permissions.
if super(QuayDeferredPermissionUser, self).can(permission):
return super(QuayDeferredPermissionUser, self).can(permission) return super(QuayDeferredPermissionUser, self).can(permission)
# Lazy-load the namespace-wide-only permissions.
if perm_namespace and perm_namespace not in self._namespace_wide_loaded:
self._populate_namespace_wide_provides(user_object, perm_namespace)
self._namespace_wide_loaded.add(perm_namespace)
return super(QuayDeferredPermissionUser, self).can(permission)
class QuayPermission(Permission): class QuayPermission(Permission):
""" Base for all permissions in Quay. """ """ Base for all permissions in Quay. """
namespace = None
repo_name = None namespace = None
repo_name = None
class ModifyRepositoryPermission(QuayPermission): class ModifyRepositoryPermission(QuayPermission):
def __init__(self, namespace, name): def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, 'admin') admin_need = _RepositoryNeed(namespace, name, "admin")
write_need = _RepositoryNeed(namespace, name, 'write') write_need = _RepositoryNeed(namespace, name, "write")
org_admin_need = _OrganizationRepoNeed(namespace, 'admin') org_admin_need = _OrganizationRepoNeed(namespace, "admin")
org_write_need = _OrganizationRepoNeed(namespace, 'write') org_write_need = _OrganizationRepoNeed(namespace, "write")
self.namespace = namespace self.namespace = namespace
self.repo_name = name self.repo_name = name
super(ModifyRepositoryPermission, self).__init__(admin_need, write_need, org_admin_need, super(ModifyRepositoryPermission, self).__init__(
org_write_need) admin_need, write_need, org_admin_need, org_write_need
)
class ReadRepositoryPermission(QuayPermission): class ReadRepositoryPermission(QuayPermission):
def __init__(self, namespace, name): def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, 'admin') admin_need = _RepositoryNeed(namespace, name, "admin")
write_need = _RepositoryNeed(namespace, name, 'write') write_need = _RepositoryNeed(namespace, name, "write")
read_need = _RepositoryNeed(namespace, name, 'read') read_need = _RepositoryNeed(namespace, name, "read")
org_admin_need = _OrganizationRepoNeed(namespace, 'admin') org_admin_need = _OrganizationRepoNeed(namespace, "admin")
org_write_need = _OrganizationRepoNeed(namespace, 'write') org_write_need = _OrganizationRepoNeed(namespace, "write")
org_read_need = _OrganizationRepoNeed(namespace, 'read') org_read_need = _OrganizationRepoNeed(namespace, "read")
self.namespace = namespace self.namespace = namespace
self.repo_name = name self.repo_name = name
super(ReadRepositoryPermission, self).__init__(admin_need, write_need, read_need, super(ReadRepositoryPermission, self).__init__(
org_admin_need, org_read_need, org_write_need) admin_need,
write_need,
read_need,
org_admin_need,
org_read_need,
org_write_need,
)
class AdministerRepositoryPermission(QuayPermission): class AdministerRepositoryPermission(QuayPermission):
def __init__(self, namespace, name): def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, 'admin') admin_need = _RepositoryNeed(namespace, name, "admin")
org_admin_need = _OrganizationRepoNeed(namespace, 'admin') org_admin_need = _OrganizationRepoNeed(namespace, "admin")
self.namespace = namespace self.namespace = namespace
self.repo_name = name self.repo_name = name
super(AdministerRepositoryPermission, self).__init__(admin_need, super(AdministerRepositoryPermission, self).__init__(admin_need, org_admin_need)
org_admin_need)
class CreateRepositoryPermission(QuayPermission): class CreateRepositoryPermission(QuayPermission):
def __init__(self, namespace): def __init__(self, namespace):
admin_org = _OrganizationNeed(namespace, 'admin') admin_org = _OrganizationNeed(namespace, "admin")
create_repo_org = _OrganizationNeed(namespace, 'creator') create_repo_org = _OrganizationNeed(namespace, "creator")
self.namespace = namespace self.namespace = namespace
super(CreateRepositoryPermission, self).__init__(admin_org, create_repo_org)
super(CreateRepositoryPermission, self).__init__(admin_org,
create_repo_org)
class SuperUserPermission(QuayPermission): class SuperUserPermission(QuayPermission):
def __init__(self): def __init__(self):
need = _SuperUserNeed() need = _SuperUserNeed()
super(SuperUserPermission, self).__init__(need) super(SuperUserPermission, self).__init__(need)
class UserAdminPermission(QuayPermission): class UserAdminPermission(QuayPermission):
def __init__(self, username): def __init__(self, username):
user_admin = _UserNeed(username, 'admin') user_admin = _UserNeed(username, "admin")
super(UserAdminPermission, self).__init__(user_admin) super(UserAdminPermission, self).__init__(user_admin)
class UserReadPermission(QuayPermission): class UserReadPermission(QuayPermission):
def __init__(self, username): def __init__(self, username):
user_admin = _UserNeed(username, 'admin') user_admin = _UserNeed(username, "admin")
user_read = _UserNeed(username, 'read') user_read = _UserNeed(username, "read")
super(UserReadPermission, self).__init__(user_read, user_admin) super(UserReadPermission, self).__init__(user_read, user_admin)
class AdministerOrganizationPermission(QuayPermission): class AdministerOrganizationPermission(QuayPermission):
def __init__(self, org_name): def __init__(self, org_name):
admin_org = _OrganizationNeed(org_name, 'admin') admin_org = _OrganizationNeed(org_name, "admin")
self.namespace = org_name self.namespace = org_name
super(AdministerOrganizationPermission, self).__init__(admin_org) super(AdministerOrganizationPermission, self).__init__(admin_org)
class OrganizationMemberPermission(QuayPermission): class OrganizationMemberPermission(QuayPermission):
def __init__(self, org_name): def __init__(self, org_name):
admin_org = _OrganizationNeed(org_name, 'admin') admin_org = _OrganizationNeed(org_name, "admin")
repo_creator_org = _OrganizationNeed(org_name, 'creator') repo_creator_org = _OrganizationNeed(org_name, "creator")
org_member = _OrganizationNeed(org_name, 'member') org_member = _OrganizationNeed(org_name, "member")
self.namespace = org_name self.namespace = org_name
super(OrganizationMemberPermission, self).__init__(admin_org, org_member, super(OrganizationMemberPermission, self).__init__(
repo_creator_org) admin_org, org_member, repo_creator_org
)
class ViewTeamPermission(QuayPermission): class ViewTeamPermission(QuayPermission):
def __init__(self, org_name, team_name): def __init__(self, org_name, team_name):
team_admin = _TeamNeed(org_name, team_name, 'admin') team_admin = _TeamNeed(org_name, team_name, "admin")
team_creator = _TeamNeed(org_name, team_name, 'creator') team_creator = _TeamNeed(org_name, team_name, "creator")
team_member = _TeamNeed(org_name, team_name, 'member') team_member = _TeamNeed(org_name, team_name, "member")
admin_org = _OrganizationNeed(org_name, 'admin') admin_org = _OrganizationNeed(org_name, "admin")
self.namespace = org_name self.namespace = org_name
super(ViewTeamPermission, self).__init__(team_admin, team_creator, super(ViewTeamPermission, self).__init__(
team_member, admin_org) team_admin, team_creator, team_member, admin_org
)
class AlwaysFailPermission(QuayPermission): class AlwaysFailPermission(QuayPermission):
def can(self): def can(self):
return False return False
@identity_loaded.connect_via(app) @identity_loaded.connect_via(app)
def on_identity_loaded(sender, identity): def on_identity_loaded(sender, identity):
logger.debug('Identity loaded: %s' % identity) logger.debug("Identity loaded: %s" % identity)
# We have verified an identity, load in all of the permissions # We have verified an identity, load in all of the permissions
if isinstance(identity, QuayDeferredPermissionUser): if isinstance(identity, QuayDeferredPermissionUser):
logger.debug('Deferring permissions for user with uuid: %s', identity.id) logger.debug("Deferring permissions for user with uuid: %s", identity.id)
elif identity.auth_type == 'user_uuid': elif identity.auth_type == "user_uuid":
logger.debug('Switching username permission to deferred object with uuid: %s', identity.id) logger.debug(
switch_to_deferred = QuayDeferredPermissionUser.for_id(identity.id) "Switching username permission to deferred object with uuid: %s",
identity_changed.send(app, identity=switch_to_deferred) identity.id,
)
switch_to_deferred = QuayDeferredPermissionUser.for_id(identity.id)
identity_changed.send(app, identity=switch_to_deferred)
elif identity.auth_type == 'token': elif identity.auth_type == "token":
logger.debug('Loading permissions for token: %s', identity.id) logger.debug("Loading permissions for token: %s", identity.id)
token_data = model.token.load_token_data(identity.id) token_data = model.token.load_token_data(identity.id)
repo_grant = _RepositoryNeed(token_data.repository.namespace_user.username, repo_grant = _RepositoryNeed(
token_data.repository.name, token_data.repository.namespace_user.username,
token_data.role.name) token_data.repository.name,
logger.debug('Delegate token added permission: %s', repo_grant) token_data.role.name,
identity.provides.add(repo_grant) )
logger.debug("Delegate token added permission: %s", repo_grant)
identity.provides.add(repo_grant)
elif identity.auth_type == 'signed_grant' or identity.auth_type == 'signed_jwt': elif identity.auth_type == "signed_grant" or identity.auth_type == "signed_jwt":
logger.debug('Loaded %s identity for: %s', identity.auth_type, identity.id) logger.debug("Loaded %s identity for: %s", identity.auth_type, identity.id)
else: else:
logger.error('Unknown identity auth type: %s', identity.auth_type) logger.error("Unknown identity auth type: %s", identity.auth_type)

View file

@ -9,156 +9,166 @@ from flask_principal import identity_changed, Identity
from app import app, get_app_url, instance_keys, metric_queue from app import app, get_app_url, instance_keys, metric_queue
from auth.auth_context import set_authenticated_context from auth.auth_context import set_authenticated_context
from auth.auth_context_type import SignedAuthContext from auth.auth_context_type import SignedAuthContext
from auth.permissions import repository_read_grant, repository_write_grant, repository_admin_grant from auth.permissions import (
repository_read_grant,
repository_write_grant,
repository_admin_grant,
)
from util.http import abort from util.http import abort
from util.names import parse_namespace_repository from util.names import parse_namespace_repository
from util.security.registry_jwt import (ANONYMOUS_SUB, decode_bearer_header, from util.security.registry_jwt import (
InvalidBearerTokenException) ANONYMOUS_SUB,
decode_bearer_header,
InvalidBearerTokenException,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ACCESS_SCHEMA = { ACCESS_SCHEMA = {
'type': 'array', "type": "array",
'description': 'List of access granted to the subject', "description": "List of access granted to the subject",
'items': { "items": {
'type': 'object', "type": "object",
'required': [ "required": ["type", "name", "actions"],
'type', "properties": {
'name', "type": {
'actions', "type": "string",
], "description": "We only allow repository permissions",
'properties': { "enum": ["repository"],
'type': { },
'type': 'string', "name": {
'description': 'We only allow repository permissions', "type": "string",
'enum': [ "description": "The name of the repository for which we are receiving access",
'repository', },
], "actions": {
}, "type": "array",
'name': { "description": "List of specific verbs which can be performed against repository",
'type': 'string', "items": {"type": "string", "enum": ["push", "pull", "*"]},
'description': 'The name of the repository for which we are receiving access' },
},
'actions': {
'type': 'array',
'description': 'List of specific verbs which can be performed against repository',
'items': {
'type': 'string',
'enum': [
'push',
'pull',
'*',
],
}, },
},
}, },
},
} }
class InvalidJWTException(Exception): class InvalidJWTException(Exception):
pass pass
def get_auth_headers(repository=None, scopes=None): def get_auth_headers(repository=None, scopes=None):
""" Returns a dictionary of headers for auth responses. """ """ Returns a dictionary of headers for auth responses. """
headers = {} headers = {}
realm_auth_path = url_for('v2.generate_registry_jwt') realm_auth_path = url_for("v2.generate_registry_jwt")
authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(get_app_url(), authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(
realm_auth_path, get_app_url(), realm_auth_path, app.config["SERVER_HOSTNAME"]
app.config['SERVER_HOSTNAME']) )
if repository: if repository:
scopes_string = "repository:{0}".format(repository) scopes_string = "repository:{0}".format(repository)
if scopes: if scopes:
scopes_string += ':' + ','.join(scopes) scopes_string += ":" + ",".join(scopes)
authenticate += ',scope="{0}"'.format(scopes_string) authenticate += ',scope="{0}"'.format(scopes_string)
headers['WWW-Authenticate'] = authenticate headers["WWW-Authenticate"] = authenticate
headers['Docker-Distribution-API-Version'] = 'registry/2.0' headers["Docker-Distribution-API-Version"] = "registry/2.0"
return headers return headers
def identity_from_bearer_token(bearer_header): def identity_from_bearer_token(bearer_header):
""" Process a bearer header and return the loaded identity, or raise InvalidJWTException if an """ Process a bearer header and return the loaded identity, or raise InvalidJWTException if an
identity could not be loaded. Expects tokens and grants in the format of the Docker registry identity could not be loaded. Expects tokens and grants in the format of the Docker registry
v2 auth spec: https://docs.docker.com/registry/spec/auth/token/ v2 auth spec: https://docs.docker.com/registry/spec/auth/token/
""" """
logger.debug('Validating auth header: %s', bearer_header) logger.debug("Validating auth header: %s", bearer_header)
try:
payload = decode_bearer_header(bearer_header, instance_keys, app.config,
metric_queue=metric_queue)
except InvalidBearerTokenException as bte:
logger.exception('Invalid bearer token: %s', bte)
raise InvalidJWTException(bte)
loaded_identity = Identity(payload['sub'], 'signed_jwt')
# Process the grants from the payload
if 'access' in payload:
try: try:
validate(payload['access'], ACCESS_SCHEMA) payload = decode_bearer_header(
except ValidationError: bearer_header, instance_keys, app.config, metric_queue=metric_queue
logger.exception('We should not be minting invalid credentials') )
raise InvalidJWTException('Token contained invalid or malformed access grants') except InvalidBearerTokenException as bte:
logger.exception("Invalid bearer token: %s", bte)
raise InvalidJWTException(bte)
lib_namespace = app.config['LIBRARY_NAMESPACE'] loaded_identity = Identity(payload["sub"], "signed_jwt")
for grant in payload['access']:
namespace, repo_name = parse_namespace_repository(grant['name'], lib_namespace)
if '*' in grant['actions']: # Process the grants from the payload
loaded_identity.provides.add(repository_admin_grant(namespace, repo_name)) if "access" in payload:
elif 'push' in grant['actions']: try:
loaded_identity.provides.add(repository_write_grant(namespace, repo_name)) validate(payload["access"], ACCESS_SCHEMA)
elif 'pull' in grant['actions']: except ValidationError:
loaded_identity.provides.add(repository_read_grant(namespace, repo_name)) logger.exception("We should not be minting invalid credentials")
raise InvalidJWTException(
"Token contained invalid or malformed access grants"
)
default_context = { lib_namespace = app.config["LIBRARY_NAMESPACE"]
'kind': 'anonymous' for grant in payload["access"]:
} namespace, repo_name = parse_namespace_repository(
grant["name"], lib_namespace
)
if payload['sub'] != ANONYMOUS_SUB: if "*" in grant["actions"]:
default_context = { loaded_identity.provides.add(
'kind': 'user', repository_admin_grant(namespace, repo_name)
'user': payload['sub'], )
} elif "push" in grant["actions"]:
loaded_identity.provides.add(
repository_write_grant(namespace, repo_name)
)
elif "pull" in grant["actions"]:
loaded_identity.provides.add(
repository_read_grant(namespace, repo_name)
)
return loaded_identity, payload.get('context', default_context) default_context = {"kind": "anonymous"}
if payload["sub"] != ANONYMOUS_SUB:
default_context = {"kind": "user", "user": payload["sub"]}
return loaded_identity, payload.get("context", default_context)
def process_registry_jwt_auth(scopes=None): def process_registry_jwt_auth(scopes=None):
""" Processes the registry JWT auth token found in the authorization header. If none found, """ Processes the registry JWT auth token found in the authorization header. If none found,
no error is returned. If an invalid token is found, raises a 401. no error is returned. If an invalid token is found, raises a 401.
""" """
def inner(func):
@wraps(func)
def wrapper(*args, **kwargs):
logger.debug('Called with params: %s, %s', args, kwargs)
auth = request.headers.get('authorization', '').strip()
if auth:
try:
extracted_identity, context_dict = identity_from_bearer_token(auth)
identity_changed.send(app, identity=extracted_identity)
logger.debug('Identity changed to %s', extracted_identity.id)
auth_context = SignedAuthContext.build_from_signed_dict(context_dict) def inner(func):
if auth_context is not None: @wraps(func)
logger.debug('Auth context set to %s', auth_context.signed_data) def wrapper(*args, **kwargs):
set_authenticated_context(auth_context) logger.debug("Called with params: %s, %s", args, kwargs)
auth = request.headers.get("authorization", "").strip()
if auth:
try:
extracted_identity, context_dict = identity_from_bearer_token(auth)
identity_changed.send(app, identity=extracted_identity)
logger.debug("Identity changed to %s", extracted_identity.id)
except InvalidJWTException as ije: auth_context = SignedAuthContext.build_from_signed_dict(
repository = None context_dict
if 'namespace_name' in kwargs and 'repo_name' in kwargs: )
repository = kwargs['namespace_name'] + '/' + kwargs['repo_name'] if auth_context is not None:
logger.debug("Auth context set to %s", auth_context.signed_data)
set_authenticated_context(auth_context)
abort(401, message=ije.message, headers=get_auth_headers(repository=repository, except InvalidJWTException as ije:
scopes=scopes)) repository = None
else: if "namespace_name" in kwargs and "repo_name" in kwargs:
logger.debug('No auth header.') repository = (
kwargs["namespace_name"] + "/" + kwargs["repo_name"]
)
return func(*args, **kwargs) abort(
return wrapper 401,
return inner message=ije.message,
headers=get_auth_headers(repository=repository, scopes=scopes),
)
else:
logger.debug("No auth header.")
return func(*args, **kwargs)
return wrapper
return inner

View file

@ -2,145 +2,194 @@ from collections import namedtuple
import features import features
import re import re
Scope = namedtuple('scope', ['scope', 'icon', 'dangerous', 'title', 'description']) Scope = namedtuple("scope", ["scope", "icon", "dangerous", "title", "description"])
READ_REPO = Scope(scope='repo:read', READ_REPO = Scope(
icon='fa-hdd-o', scope="repo:read",
dangerous=False, icon="fa-hdd-o",
title='View all visible repositories', dangerous=False,
description=('This application will be able to view and pull all repositories ' title="View all visible repositories",
'visible to the granting user or robot account')) description=(
"This application will be able to view and pull all repositories "
"visible to the granting user or robot account"
),
)
WRITE_REPO = Scope(scope='repo:write', WRITE_REPO = Scope(
icon='fa-hdd-o', scope="repo:write",
dangerous=False, icon="fa-hdd-o",
title='Read/Write to any accessible repositories', dangerous=False,
description=('This application will be able to view, push and pull to all ' title="Read/Write to any accessible repositories",
'repositories to which the granting user or robot account has ' description=(
'write access')) "This application will be able to view, push and pull to all "
"repositories to which the granting user or robot account has "
"write access"
),
)
ADMIN_REPO = Scope(scope='repo:admin', ADMIN_REPO = Scope(
icon='fa-hdd-o', scope="repo:admin",
dangerous=False, icon="fa-hdd-o",
title='Administer Repositories', dangerous=False,
description=('This application will have administrator access to all ' title="Administer Repositories",
'repositories to which the granting user or robot account has ' description=(
'access')) "This application will have administrator access to all "
"repositories to which the granting user or robot account has "
"access"
),
)
CREATE_REPO = Scope(scope='repo:create', CREATE_REPO = Scope(
icon='fa-plus', scope="repo:create",
dangerous=False, icon="fa-plus",
title='Create Repositories', dangerous=False,
description=('This application will be able to create repositories in to any ' title="Create Repositories",
'namespaces that the granting user or robot account is allowed ' description=(
'to create repositories')) "This application will be able to create repositories in to any "
"namespaces that the granting user or robot account is allowed "
"to create repositories"
),
)
READ_USER = Scope(scope= 'user:read', READ_USER = Scope(
icon='fa-user', scope="user:read",
dangerous=False, icon="fa-user",
title='Read User Information', dangerous=False,
description=('This application will be able to read user information such as ' title="Read User Information",
'username and email address.')) description=(
"This application will be able to read user information such as "
"username and email address."
),
)
ADMIN_USER = Scope(scope= 'user:admin', ADMIN_USER = Scope(
icon='fa-gear', scope="user:admin",
dangerous=True, icon="fa-gear",
title='Administer User', dangerous=True,
description=('This application will be able to administer your account ' title="Administer User",
'including creating robots and granting them permissions ' description=(
'to your repositories. You should have absolute trust in the ' "This application will be able to administer your account "
'requesting application before granting this permission.')) "including creating robots and granting them permissions "
"to your repositories. You should have absolute trust in the "
"requesting application before granting this permission."
),
)
ORG_ADMIN = Scope(scope='org:admin', ORG_ADMIN = Scope(
icon='fa-gear', scope="org:admin",
dangerous=True, icon="fa-gear",
title='Administer Organization', dangerous=True,
description=('This application will be able to administer your organizations ' title="Administer Organization",
'including creating robots, creating teams, adjusting team ' description=(
'membership, and changing billing settings. You should have ' "This application will be able to administer your organizations "
'absolute trust in the requesting application before granting this ' "including creating robots, creating teams, adjusting team "
'permission.')) "membership, and changing billing settings. You should have "
"absolute trust in the requesting application before granting this "
"permission."
),
)
DIRECT_LOGIN = Scope(scope='direct_user_login', DIRECT_LOGIN = Scope(
icon='fa-exclamation-triangle', scope="direct_user_login",
dangerous=True, icon="fa-exclamation-triangle",
title='Full Access', dangerous=True,
description=('This scope should not be available to OAuth applications. ' title="Full Access",
'Never approve a request for this scope!')) description=(
"This scope should not be available to OAuth applications. "
"Never approve a request for this scope!"
),
)
SUPERUSER = Scope(scope='super:user', SUPERUSER = Scope(
icon='fa-street-view', scope="super:user",
dangerous=True, icon="fa-street-view",
title='Super User Access', dangerous=True,
description=('This application will be able to administer your installation ' title="Super User Access",
'including managing users, managing organizations and other ' description=(
'features found in the superuser panel. You should have ' "This application will be able to administer your installation "
'absolute trust in the requesting application before granting this ' "including managing users, managing organizations and other "
'permission.')) "features found in the superuser panel. You should have "
"absolute trust in the requesting application before granting this "
"permission."
),
)
ALL_SCOPES = {scope.scope: scope for scope in (READ_REPO, WRITE_REPO, ADMIN_REPO, CREATE_REPO, ALL_SCOPES = {
READ_USER, ORG_ADMIN, SUPERUSER, ADMIN_USER)} scope.scope: scope
for scope in (
READ_REPO,
WRITE_REPO,
ADMIN_REPO,
CREATE_REPO,
READ_USER,
ORG_ADMIN,
SUPERUSER,
ADMIN_USER,
)
}
IMPLIED_SCOPES = { IMPLIED_SCOPES = {
ADMIN_REPO: {ADMIN_REPO, WRITE_REPO, READ_REPO}, ADMIN_REPO: {ADMIN_REPO, WRITE_REPO, READ_REPO},
WRITE_REPO: {WRITE_REPO, READ_REPO}, WRITE_REPO: {WRITE_REPO, READ_REPO},
READ_REPO: {READ_REPO}, READ_REPO: {READ_REPO},
CREATE_REPO: {CREATE_REPO}, CREATE_REPO: {CREATE_REPO},
READ_USER: {READ_USER}, READ_USER: {READ_USER},
ORG_ADMIN: {ORG_ADMIN}, ORG_ADMIN: {ORG_ADMIN},
SUPERUSER: {SUPERUSER}, SUPERUSER: {SUPERUSER},
ADMIN_USER: {ADMIN_USER}, ADMIN_USER: {ADMIN_USER},
None: set(), None: set(),
} }
def app_scopes(app_config): def app_scopes(app_config):
scopes_from_config = dict(ALL_SCOPES) scopes_from_config = dict(ALL_SCOPES)
if not app_config.get('FEATURE_SUPER_USERS', False): if not app_config.get("FEATURE_SUPER_USERS", False):
del scopes_from_config[SUPERUSER.scope] del scopes_from_config[SUPERUSER.scope]
return scopes_from_config return scopes_from_config
def scopes_from_scope_string(scopes): def scopes_from_scope_string(scopes):
if not scopes: if not scopes:
scopes = '' scopes = ""
# Note: The scopes string should be space seperated according to the spec: # Note: The scopes string should be space seperated according to the spec:
# https://tools.ietf.org/html/rfc6749#section-3.3 # https://tools.ietf.org/html/rfc6749#section-3.3
# However, we also support commas for backwards compatibility with existing callers to our code. # However, we also support commas for backwards compatibility with existing callers to our code.
scope_set = {ALL_SCOPES.get(scope, None) for scope in re.split(' |,', scopes)} scope_set = {ALL_SCOPES.get(scope, None) for scope in re.split(" |,", scopes)}
return scope_set if not None in scope_set else set() return scope_set if not None in scope_set else set()
def validate_scope_string(scopes): def validate_scope_string(scopes):
decoded = scopes_from_scope_string(scopes) decoded = scopes_from_scope_string(scopes)
return len(decoded) > 0 return len(decoded) > 0
def is_subset_string(full_string, expected_string): def is_subset_string(full_string, expected_string):
""" Returns true if the scopes found in expected_string are also found """ Returns true if the scopes found in expected_string are also found
in full_string. in full_string.
""" """
full_scopes = scopes_from_scope_string(full_string) full_scopes = scopes_from_scope_string(full_string)
if not full_scopes: if not full_scopes:
return False return False
full_implied_scopes = set.union(*[IMPLIED_SCOPES[scope] for scope in full_scopes]) full_implied_scopes = set.union(*[IMPLIED_SCOPES[scope] for scope in full_scopes])
expected_scopes = scopes_from_scope_string(expected_string) expected_scopes = scopes_from_scope_string(expected_string)
return expected_scopes.issubset(full_implied_scopes) return expected_scopes.issubset(full_implied_scopes)
def get_scope_information(scopes_string): def get_scope_information(scopes_string):
scopes = scopes_from_scope_string(scopes_string) scopes = scopes_from_scope_string(scopes_string)
scope_info = [] scope_info = []
for scope in scopes: for scope in scopes:
scope_info.append({ scope_info.append(
'title': scope.title, {
'scope': scope.scope, "title": scope.title,
'description': scope.description, "scope": scope.scope,
'icon': scope.icon, "description": scope.description,
'dangerous': scope.dangerous, "icon": scope.icon,
}) "dangerous": scope.dangerous,
}
)
return scope_info return scope_info

View file

@ -8,48 +8,49 @@ from auth.validateresult import AuthKind, ValidateResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The prefix for all signatures of signed granted. # The prefix for all signatures of signed granted.
SIGNATURE_PREFIX = 'sigv2=' SIGNATURE_PREFIX = "sigv2="
def generate_signed_token(grants, user_context): def generate_signed_token(grants, user_context):
""" Generates a signed session token with the given grants and user context. """ """ Generates a signed session token with the given grants and user context. """
ser = SecureCookieSessionInterface().get_signing_serializer(app) ser = SecureCookieSessionInterface().get_signing_serializer(app)
data_to_sign = { data_to_sign = {"grants": grants, "user_context": user_context}
'grants': grants,
'user_context': user_context,
}
encrypted = ser.dumps(data_to_sign) encrypted = ser.dumps(data_to_sign)
return '{0}{1}'.format(SIGNATURE_PREFIX, encrypted) return "{0}{1}".format(SIGNATURE_PREFIX, encrypted)
def validate_signed_grant(auth_header): def validate_signed_grant(auth_header):
""" Validates a signed grant as found inside an auth header and returns whether it points to """ Validates a signed grant as found inside an auth header and returns whether it points to
a valid grant. a valid grant.
""" """
if not auth_header: if not auth_header:
return ValidateResult(AuthKind.signed_grant, missing=True) return ValidateResult(AuthKind.signed_grant, missing=True)
# Try to parse the token from the header. # Try to parse the token from the header.
normalized = [part.strip() for part in auth_header.split(' ') if part] normalized = [part.strip() for part in auth_header.split(" ") if part]
if normalized[0].lower() != 'token' or len(normalized) != 2: if normalized[0].lower() != "token" or len(normalized) != 2:
logger.debug('Not a token: %s', auth_header) logger.debug("Not a token: %s", auth_header)
return ValidateResult(AuthKind.signed_grant, missing=True) return ValidateResult(AuthKind.signed_grant, missing=True)
# Check that it starts with the expected prefix. # Check that it starts with the expected prefix.
if not normalized[1].startswith(SIGNATURE_PREFIX): if not normalized[1].startswith(SIGNATURE_PREFIX):
logger.debug('Not a signed grant token: %s', auth_header) logger.debug("Not a signed grant token: %s", auth_header)
return ValidateResult(AuthKind.signed_grant, missing=True) return ValidateResult(AuthKind.signed_grant, missing=True)
# Decrypt the grant. # Decrypt the grant.
encrypted = normalized[1][len(SIGNATURE_PREFIX):] encrypted = normalized[1][len(SIGNATURE_PREFIX) :]
ser = SecureCookieSessionInterface().get_signing_serializer(app) ser = SecureCookieSessionInterface().get_signing_serializer(app)
try: try:
token_data = ser.loads(encrypted, max_age=app.config['SIGNED_GRANT_EXPIRATION_SEC']) token_data = ser.loads(
except BadSignature: encrypted, max_age=app.config["SIGNED_GRANT_EXPIRATION_SEC"]
logger.warning('Signed grant could not be validated: %s', encrypted) )
return ValidateResult(AuthKind.signed_grant, except BadSignature:
error_message='Signed grant could not be validated') logger.warning("Signed grant could not be validated: %s", encrypted)
return ValidateResult(
AuthKind.signed_grant, error_message="Signed grant could not be validated"
)
logger.debug('Successfully validated signed grant with data: %s', token_data) logger.debug("Successfully validated signed grant with data: %s", token_data)
return ValidateResult(AuthKind.signed_grant, signed_data=token_data) return ValidateResult(AuthKind.signed_grant, signed_data=token_data)

View file

@ -1,51 +1,65 @@
import pytest import pytest
from auth.auth_context_type import SignedAuthContext, ValidatedAuthContext, ContextEntityKind from auth.auth_context_type import (
SignedAuthContext,
ValidatedAuthContext,
ContextEntityKind,
)
from data import model, database from data import model, database
from test.fixtures import * from test.fixtures import *
def get_oauth_token(_): def get_oauth_token(_):
return database.OAuthAccessToken.get() return database.OAuthAccessToken.get()
@pytest.mark.parametrize('kind, entity_reference, loader', [ @pytest.mark.parametrize(
(ContextEntityKind.anonymous, None, None), "kind, entity_reference, loader",
(ContextEntityKind.appspecifictoken, '%s%s' % ('a' * 60, 'b' * 60), [
model.appspecifictoken.access_valid_token), (ContextEntityKind.anonymous, None, None),
(ContextEntityKind.oauthtoken, None, get_oauth_token), (
(ContextEntityKind.robot, 'devtable+dtrobot', model.user.lookup_robot), ContextEntityKind.appspecifictoken,
(ContextEntityKind.user, 'devtable', model.user.get_user), "%s%s" % ("a" * 60, "b" * 60),
]) model.appspecifictoken.access_valid_token,
@pytest.mark.parametrize('v1_dict_format', [ ),
(True), (ContextEntityKind.oauthtoken, None, get_oauth_token),
(False), (ContextEntityKind.robot, "devtable+dtrobot", model.user.lookup_robot),
]) (ContextEntityKind.user, "devtable", model.user.get_user),
def test_signed_auth_context(kind, entity_reference, loader, v1_dict_format, initialized_db): ],
if kind == ContextEntityKind.anonymous: )
validated = ValidatedAuthContext() @pytest.mark.parametrize("v1_dict_format", [(True), (False)])
assert validated.is_anonymous def test_signed_auth_context(
else: kind, entity_reference, loader, v1_dict_format, initialized_db
ref = loader(entity_reference) ):
validated = ValidatedAuthContext(**{kind.value: ref}) if kind == ContextEntityKind.anonymous:
assert not validated.is_anonymous validated = ValidatedAuthContext()
assert validated.is_anonymous
else:
ref = loader(entity_reference)
validated = ValidatedAuthContext(**{kind.value: ref})
assert not validated.is_anonymous
assert validated.entity_kind == kind assert validated.entity_kind == kind
assert validated.unique_key assert validated.unique_key
signed = SignedAuthContext.build_from_signed_dict(validated.to_signed_dict(), signed = SignedAuthContext.build_from_signed_dict(
v1_dict_format=v1_dict_format) validated.to_signed_dict(), v1_dict_format=v1_dict_format
)
if not v1_dict_format: if not v1_dict_format:
# Under legacy V1 format, we don't track the app specific token, merely its associated user. # Under legacy V1 format, we don't track the app specific token, merely its associated user.
assert signed.entity_kind == kind assert signed.entity_kind == kind
assert signed.description == validated.description assert signed.description == validated.description
assert signed.credential_username == validated.credential_username assert signed.credential_username == validated.credential_username
assert signed.analytics_id_and_public_metadata() == validated.analytics_id_and_public_metadata() assert (
assert signed.unique_key == validated.unique_key signed.analytics_id_and_public_metadata()
== validated.analytics_id_and_public_metadata()
)
assert signed.unique_key == validated.unique_key
assert signed.is_anonymous == validated.is_anonymous assert signed.is_anonymous == validated.is_anonymous
assert signed.authed_user == validated.authed_user assert signed.authed_user == validated.authed_user
assert signed.has_nonrobot_user == validated.has_nonrobot_user assert signed.has_nonrobot_user == validated.has_nonrobot_user
assert signed.to_signed_dict() == validated.to_signed_dict() assert signed.to_signed_dict() == validated.to_signed_dict()

View file

@ -5,8 +5,11 @@ import pytest
from base64 import b64encode from base64 import b64encode
from auth.basic import validate_basic_auth from auth.basic import validate_basic_auth
from auth.credentials import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME, from auth.credentials import (
APP_SPECIFIC_TOKEN_USERNAME) ACCESS_TOKEN_USERNAME,
OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME,
)
from auth.validateresult import AuthKind, ValidateResult from auth.validateresult import AuthKind, ValidateResult
from data import model from data import model
@ -14,85 +17,120 @@ from test.fixtures import *
def _token(username, password): def _token(username, password):
assert isinstance(username, basestring) assert isinstance(username, basestring)
assert isinstance(password, basestring) assert isinstance(password, basestring)
return 'basic ' + b64encode('%s:%s' % (username, password)) return "basic " + b64encode("%s:%s" % (username, password))
@pytest.mark.parametrize('token, expected_result', [ @pytest.mark.parametrize(
('', ValidateResult(AuthKind.basic, missing=True)), "token, expected_result",
('someinvalidtoken', ValidateResult(AuthKind.basic, missing=True)), [
('somefoobartoken', ValidateResult(AuthKind.basic, missing=True)), ("", ValidateResult(AuthKind.basic, missing=True)),
('basic ', ValidateResult(AuthKind.basic, missing=True)), ("someinvalidtoken", ValidateResult(AuthKind.basic, missing=True)),
('basic some token', ValidateResult(AuthKind.basic, missing=True)), ("somefoobartoken", ValidateResult(AuthKind.basic, missing=True)),
('basic sometoken', ValidateResult(AuthKind.basic, missing=True)), ("basic ", ValidateResult(AuthKind.basic, missing=True)),
(_token(APP_SPECIFIC_TOKEN_USERNAME, 'invalid'), ValidateResult(AuthKind.basic, ("basic some token", ValidateResult(AuthKind.basic, missing=True)),
error_message='Invalid token')), ("basic sometoken", ValidateResult(AuthKind.basic, missing=True)),
(_token(ACCESS_TOKEN_USERNAME, 'invalid'), ValidateResult(AuthKind.basic, (
error_message='Invalid access token')), _token(APP_SPECIFIC_TOKEN_USERNAME, "invalid"),
(_token(OAUTH_TOKEN_USERNAME, 'invalid'), ValidateResult(AuthKind.basic, error_message="Invalid token"),
ValidateResult(AuthKind.basic, error_message='OAuth access token could not be validated')), ),
(_token('devtable', 'invalid'), ValidateResult(AuthKind.basic, (
error_message='Invalid Username or Password')), _token(ACCESS_TOKEN_USERNAME, "invalid"),
(_token('devtable+somebot', 'invalid'), ValidateResult( ValidateResult(AuthKind.basic, error_message="Invalid access token"),
AuthKind.basic, error_message='Could not find robot with username: devtable+somebot')), ),
(_token('disabled', 'password'), ValidateResult( (
AuthKind.basic, _token(OAUTH_TOKEN_USERNAME, "invalid"),
error_message='This user has been disabled. Please contact your administrator.')),]) ValidateResult(
AuthKind.basic,
error_message="OAuth access token could not be validated",
),
),
(
_token("devtable", "invalid"),
ValidateResult(
AuthKind.basic, error_message="Invalid Username or Password"
),
),
(
_token("devtable+somebot", "invalid"),
ValidateResult(
AuthKind.basic,
error_message="Could not find robot with username: devtable+somebot",
),
),
(
_token("disabled", "password"),
ValidateResult(
AuthKind.basic,
error_message="This user has been disabled. Please contact your administrator.",
),
),
],
)
def test_validate_basic_auth_token(token, expected_result, app): def test_validate_basic_auth_token(token, expected_result, app):
result = validate_basic_auth(token) result = validate_basic_auth(token)
assert result == expected_result assert result == expected_result
def test_valid_user(app): def test_valid_user(app):
token = _token('devtable', 'password') token = _token("devtable", "password")
result = validate_basic_auth(token) result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, user=model.user.get_user('devtable')) assert result == ValidateResult(
AuthKind.basic, user=model.user.get_user("devtable")
)
def test_valid_robot(app): def test_valid_robot(app):
robot, password = model.user.create_robot('somerobot', model.user.get_user('devtable')) robot, password = model.user.create_robot(
token = _token(robot.username, password) "somerobot", model.user.get_user("devtable")
result = validate_basic_auth(token) )
assert result == ValidateResult(AuthKind.basic, robot=robot) token = _token(robot.username, password)
result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, robot=robot)
def test_valid_token(app): def test_valid_token(app):
access_token = model.token.create_delegate_token('devtable', 'simple', 'sometoken') access_token = model.token.create_delegate_token("devtable", "simple", "sometoken")
token = _token(ACCESS_TOKEN_USERNAME, access_token.get_code()) token = _token(ACCESS_TOKEN_USERNAME, access_token.get_code())
result = validate_basic_auth(token) result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, token=access_token) assert result == ValidateResult(AuthKind.basic, token=access_token)
def test_valid_oauth(app): def test_valid_oauth(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
app = model.oauth.list_applications_for_org(model.user.get_user_or_org('buynlarge'))[0] app = model.oauth.list_applications_for_org(
oauth_token, code = model.oauth.create_access_token_for_testing(user, app.client_id, 'repo:read') model.user.get_user_or_org("buynlarge")
token = _token(OAUTH_TOKEN_USERNAME, code) )[0]
result = validate_basic_auth(token) oauth_token, code = model.oauth.create_access_token_for_testing(
assert result == ValidateResult(AuthKind.basic, oauthtoken=oauth_token) user, app.client_id, "repo:read"
)
token = _token(OAUTH_TOKEN_USERNAME, code)
result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, oauthtoken=oauth_token)
def test_valid_app_specific_token(app): def test_valid_app_specific_token(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
app_specific_token = model.appspecifictoken.create_token(user, 'some token') app_specific_token = model.appspecifictoken.create_token(user, "some token")
full_token = model.appspecifictoken.get_full_token_string(app_specific_token) full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
token = _token(APP_SPECIFIC_TOKEN_USERNAME, full_token) token = _token(APP_SPECIFIC_TOKEN_USERNAME, full_token)
result = validate_basic_auth(token) result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, appspecifictoken=app_specific_token) assert result == ValidateResult(AuthKind.basic, appspecifictoken=app_specific_token)
def test_invalid_unicode(app): def test_invalid_unicode(app):
token = '\xebOH' token = "\xebOH"
header = 'basic ' + b64encode(token) header = "basic " + b64encode(token)
result = validate_basic_auth(header) result = validate_basic_auth(header)
assert result == ValidateResult(AuthKind.basic, missing=True) assert result == ValidateResult(AuthKind.basic, missing=True)
def test_invalid_unicode_2(app): def test_invalid_unicode_2(app):
token = '“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”' token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
header = 'basic ' + b64encode('devtable+somerobot:%s' % token) header = "basic " + b64encode("devtable+somerobot:%s" % token)
result = validate_basic_auth(header) result = validate_basic_auth(header)
assert result == ValidateResult( assert result == ValidateResult(
AuthKind.basic, AuthKind.basic,
error_message='Could not find robot with username: devtable+somerobot and supplied password.') error_message="Could not find robot with username: devtable+somerobot and supplied password.",
)

View file

@ -9,58 +9,58 @@ from test.fixtures import *
def test_anonymous_cookie(app): def test_anonymous_cookie(app):
assert validate_session_cookie().missing assert validate_session_cookie().missing
def test_invalidformatted_cookie(app): def test_invalidformatted_cookie(app):
# "Login" with a non-UUID reference. # "Login" with a non-UUID reference.
someuser = model.user.get_user('devtable') someuser = model.user.get_user("devtable")
login_user(LoginWrappedDBUser('somenonuuid', someuser)) login_user(LoginWrappedDBUser("somenonuuid", someuser))
# Ensure we get an invalid session cookie format error. # Ensure we get an invalid session cookie format error.
result = validate_session_cookie() result = validate_session_cookie()
assert result.authed_user is None assert result.authed_user is None
assert result.context.identity is None assert result.context.identity is None
assert not result.has_nonrobot_user assert not result.has_nonrobot_user
assert result.error_message == 'Invalid session cookie format' assert result.error_message == "Invalid session cookie format"
def test_disabled_user(app): def test_disabled_user(app):
# "Login" with a disabled user. # "Login" with a disabled user.
someuser = model.user.get_user('disabled') someuser = model.user.get_user("disabled")
login_user(LoginWrappedDBUser(someuser.uuid, someuser)) login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# Ensure we get an invalid session cookie format error. # Ensure we get an invalid session cookie format error.
result = validate_session_cookie() result = validate_session_cookie()
assert result.authed_user is None assert result.authed_user is None
assert result.context.identity is None assert result.context.identity is None
assert not result.has_nonrobot_user assert not result.has_nonrobot_user
assert result.error_message == 'User account is disabled' assert result.error_message == "User account is disabled"
def test_valid_user(app): def test_valid_user(app):
# Login with a valid user. # Login with a valid user.
someuser = model.user.get_user('devtable') someuser = model.user.get_user("devtable")
login_user(LoginWrappedDBUser(someuser.uuid, someuser)) login_user(LoginWrappedDBUser(someuser.uuid, someuser))
result = validate_session_cookie() result = validate_session_cookie()
assert result.authed_user == someuser assert result.authed_user == someuser
assert result.context.identity is not None assert result.context.identity is not None
assert result.has_nonrobot_user assert result.has_nonrobot_user
assert result.error_message is None assert result.error_message is None
def test_valid_organization(app): def test_valid_organization(app):
# "Login" with a valid organization. # "Login" with a valid organization.
someorg = model.user.get_namespace_user('buynlarge') someorg = model.user.get_namespace_user("buynlarge")
someorg.uuid = str(uuid.uuid4()) someorg.uuid = str(uuid.uuid4())
someorg.verified = True someorg.verified = True
someorg.save() someorg.save()
login_user(LoginWrappedDBUser(someorg.uuid, someorg)) login_user(LoginWrappedDBUser(someorg.uuid, someorg))
result = validate_session_cookie() result = validate_session_cookie()
assert result.authed_user is None assert result.authed_user is None
assert result.context.identity is None assert result.context.identity is None
assert not result.has_nonrobot_user assert not result.has_nonrobot_user
assert result.error_message == 'Cannot login to organization' assert result.error_message == "Cannot login to organization"

View file

@ -1,147 +1,184 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from auth.credentials import validate_credentials, CredentialKind from auth.credentials import validate_credentials, CredentialKind
from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME, from auth.credential_consts import (
APP_SPECIFIC_TOKEN_USERNAME) ACCESS_TOKEN_USERNAME,
OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME,
)
from auth.validateresult import AuthKind, ValidateResult from auth.validateresult import AuthKind, ValidateResult
from data import model from data import model
from test.fixtures import * from test.fixtures import *
def test_valid_user(app): def test_valid_user(app):
result, kind = validate_credentials('devtable', 'password') result, kind = validate_credentials("devtable", "password")
assert kind == CredentialKind.user assert kind == CredentialKind.user
assert result == ValidateResult(AuthKind.credentials, user=model.user.get_user('devtable')) assert result == ValidateResult(
AuthKind.credentials, user=model.user.get_user("devtable")
)
def test_valid_robot(app): def test_valid_robot(app):
robot, password = model.user.create_robot('somerobot', model.user.get_user('devtable')) robot, password = model.user.create_robot(
result, kind = validate_credentials(robot.username, password) "somerobot", model.user.get_user("devtable")
assert kind == CredentialKind.robot )
assert result == ValidateResult(AuthKind.credentials, robot=robot) result, kind = validate_credentials(robot.username, password)
assert kind == CredentialKind.robot
assert result == ValidateResult(AuthKind.credentials, robot=robot)
def test_valid_robot_for_disabled_user(app): def test_valid_robot_for_disabled_user(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
user.enabled = False user.enabled = False
user.save() user.save()
robot, password = model.user.create_robot('somerobot', user) robot, password = model.user.create_robot("somerobot", user)
result, kind = validate_credentials(robot.username, password) result, kind = validate_credentials(robot.username, password)
assert kind == CredentialKind.robot assert kind == CredentialKind.robot
err = "This user has been disabled. Please contact your administrator."
assert result == ValidateResult(AuthKind.credentials, error_message=err)
err = 'This user has been disabled. Please contact your administrator.'
assert result == ValidateResult(AuthKind.credentials, error_message=err)
def test_valid_token(app): def test_valid_token(app):
access_token = model.token.create_delegate_token('devtable', 'simple', 'sometoken') access_token = model.token.create_delegate_token("devtable", "simple", "sometoken")
result, kind = validate_credentials(ACCESS_TOKEN_USERNAME, access_token.get_code()) result, kind = validate_credentials(ACCESS_TOKEN_USERNAME, access_token.get_code())
assert kind == CredentialKind.token assert kind == CredentialKind.token
assert result == ValidateResult(AuthKind.credentials, token=access_token) assert result == ValidateResult(AuthKind.credentials, token=access_token)
def test_valid_oauth(app): def test_valid_oauth(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
app = model.oauth.list_applications_for_org(model.user.get_user_or_org('buynlarge'))[0] app = model.oauth.list_applications_for_org(
oauth_token, code = model.oauth.create_access_token_for_testing(user, app.client_id, 'repo:read') model.user.get_user_or_org("buynlarge")
result, kind = validate_credentials(OAUTH_TOKEN_USERNAME, code) )[0]
assert kind == CredentialKind.oauth_token oauth_token, code = model.oauth.create_access_token_for_testing(
assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token) user, app.client_id, "repo:read"
)
result, kind = validate_credentials(OAUTH_TOKEN_USERNAME, code)
assert kind == CredentialKind.oauth_token
assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token)
def test_invalid_user(app): def test_invalid_user(app):
result, kind = validate_credentials('devtable', 'somepassword') result, kind = validate_credentials("devtable", "somepassword")
assert kind == CredentialKind.user assert kind == CredentialKind.user
assert result == ValidateResult(AuthKind.credentials, assert result == ValidateResult(
error_message='Invalid Username or Password') AuthKind.credentials, error_message="Invalid Username or Password"
)
def test_valid_app_specific_token(app): def test_valid_app_specific_token(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
app_specific_token = model.appspecifictoken.create_token(user, 'some token') app_specific_token = model.appspecifictoken.create_token(user, "some token")
full_token = model.appspecifictoken.get_full_token_string(app_specific_token) full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token) result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
assert kind == CredentialKind.app_specific_token assert kind == CredentialKind.app_specific_token
assert result == ValidateResult(AuthKind.credentials, appspecifictoken=app_specific_token) assert result == ValidateResult(
AuthKind.credentials, appspecifictoken=app_specific_token
)
def test_valid_app_specific_token_for_disabled_user(app): def test_valid_app_specific_token_for_disabled_user(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
user.enabled = False user.enabled = False
user.save() user.save()
app_specific_token = model.appspecifictoken.create_token(user, 'some token') app_specific_token = model.appspecifictoken.create_token(user, "some token")
full_token = model.appspecifictoken.get_full_token_string(app_specific_token) full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token) result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
assert kind == CredentialKind.app_specific_token assert kind == CredentialKind.app_specific_token
err = "This user has been disabled. Please contact your administrator."
assert result == ValidateResult(AuthKind.credentials, error_message=err)
err = 'This user has been disabled. Please contact your administrator.'
assert result == ValidateResult(AuthKind.credentials, error_message=err)
def test_invalid_app_specific_token(app): def test_invalid_app_specific_token(app):
result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, 'somecode') result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, "somecode")
assert kind == CredentialKind.app_specific_token assert kind == CredentialKind.app_specific_token
assert result == ValidateResult(AuthKind.credentials, error_message='Invalid token') assert result == ValidateResult(AuthKind.credentials, error_message="Invalid token")
def test_invalid_app_specific_token_code(app): def test_invalid_app_specific_token_code(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
app_specific_token = model.appspecifictoken.create_token(user, 'some token') app_specific_token = model.appspecifictoken.create_token(user, "some token")
full_token = app_specific_token.token_name + 'something' full_token = app_specific_token.token_name + "something"
result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token) result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
assert kind == CredentialKind.app_specific_token assert kind == CredentialKind.app_specific_token
assert result == ValidateResult(AuthKind.credentials, error_message='Invalid token') assert result == ValidateResult(AuthKind.credentials, error_message="Invalid token")
def test_unicode(app): def test_unicode(app):
result, kind = validate_credentials('someusername', 'some₪code') result, kind = validate_credentials("someusername", "some₪code")
assert kind == CredentialKind.user assert kind == CredentialKind.user
assert not result.auth_valid assert not result.auth_valid
assert result == ValidateResult(AuthKind.credentials, assert result == ValidateResult(
error_message='Invalid Username or Password') AuthKind.credentials, error_message="Invalid Username or Password"
)
def test_unicode_robot(app): def test_unicode_robot(app):
robot, _ = model.user.create_robot('somerobot', model.user.get_user('devtable')) robot, _ = model.user.create_robot("somerobot", model.user.get_user("devtable"))
result, kind = validate_credentials(robot.username, 'some₪code') result, kind = validate_credentials(robot.username, "some₪code")
assert kind == CredentialKind.robot assert kind == CredentialKind.robot
assert not result.auth_valid assert not result.auth_valid
msg = (
"Could not find robot with username: devtable+somerobot and supplied password."
)
assert result == ValidateResult(AuthKind.credentials, error_message=msg)
msg = 'Could not find robot with username: devtable+somerobot and supplied password.'
assert result == ValidateResult(AuthKind.credentials, error_message=msg)
def test_invalid_user(app): def test_invalid_user(app):
result, kind = validate_credentials('someinvaliduser', 'password') result, kind = validate_credentials("someinvaliduser", "password")
assert kind == CredentialKind.user assert kind == CredentialKind.user
assert not result.authed_user assert not result.authed_user
assert not result.auth_valid assert not result.auth_valid
def test_invalid_user_password(app): def test_invalid_user_password(app):
result, kind = validate_credentials('devtable', 'somepassword') result, kind = validate_credentials("devtable", "somepassword")
assert kind == CredentialKind.user assert kind == CredentialKind.user
assert not result.authed_user assert not result.authed_user
assert not result.auth_valid assert not result.auth_valid
def test_invalid_robot(app): def test_invalid_robot(app):
result, kind = validate_credentials('devtable+doesnotexist', 'password') result, kind = validate_credentials("devtable+doesnotexist", "password")
assert kind == CredentialKind.robot assert kind == CredentialKind.robot
assert not result.authed_user assert not result.authed_user
assert not result.auth_valid assert not result.auth_valid
def test_invalid_robot_token(app): def test_invalid_robot_token(app):
robot, _ = model.user.create_robot('somerobot', model.user.get_user('devtable')) robot, _ = model.user.create_robot("somerobot", model.user.get_user("devtable"))
result, kind = validate_credentials(robot.username, 'invalidpassword') result, kind = validate_credentials(robot.username, "invalidpassword")
assert kind == CredentialKind.robot assert kind == CredentialKind.robot
assert not result.authed_user assert not result.authed_user
assert not result.auth_valid assert not result.auth_valid
def test_invalid_unicode_robot(app): def test_invalid_unicode_robot(app):
token = '“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”' token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
result, kind = validate_credentials('devtable+somerobot', token) result, kind = validate_credentials("devtable+somerobot", token)
assert kind == CredentialKind.robot assert kind == CredentialKind.robot
assert not result.auth_valid assert not result.auth_valid
msg = 'Could not find robot with username: devtable+somerobot' msg = "Could not find robot with username: devtable+somerobot"
assert result == ValidateResult(AuthKind.credentials, error_message=msg) assert result == ValidateResult(AuthKind.credentials, error_message=msg)
def test_invalid_unicode_robot_2(app): def test_invalid_unicode_robot_2(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
robot, password = model.user.create_robot('somerobot', user) robot, password = model.user.create_robot("somerobot", user)
token = '“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”' token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
result, kind = validate_credentials('devtable+somerobot', token) result, kind = validate_credentials("devtable+somerobot", token)
assert kind == CredentialKind.robot assert kind == CredentialKind.robot
assert not result.auth_valid assert not result.auth_valid
msg = 'Could not find robot with username: devtable+somerobot and supplied password.' msg = (
assert result == ValidateResult(AuthKind.credentials, error_message=msg) "Could not find robot with username: devtable+somerobot and supplied password."
)
assert result == ValidateResult(AuthKind.credentials, error_message=msg)

View file

@ -7,99 +7,102 @@ from werkzeug.exceptions import HTTPException
from app import LoginWrappedDBUser from app import LoginWrappedDBUser
from auth.auth_context import get_authenticated_user from auth.auth_context import get_authenticated_user
from auth.decorators import ( from auth.decorators import (
extract_namespace_repo_from_session, require_session_login, process_auth_or_cookie) extract_namespace_repo_from_session,
require_session_login,
process_auth_or_cookie,
)
from data import model from data import model
from test.fixtures import * from test.fixtures import *
def test_extract_namespace_repo_from_session_missing(app): def test_extract_namespace_repo_from_session_missing(app):
def emptyfunc(): def emptyfunc():
pass pass
session.clear() session.clear()
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
extract_namespace_repo_from_session(emptyfunc)() extract_namespace_repo_from_session(emptyfunc)()
def test_extract_namespace_repo_from_session_present(app): def test_extract_namespace_repo_from_session_present(app):
encountered = [] encountered = []
def somefunc(namespace, repository): def somefunc(namespace, repository):
encountered.append(namespace) encountered.append(namespace)
encountered.append(repository) encountered.append(repository)
# Add the namespace and repository to the session. # Add the namespace and repository to the session.
session.clear() session.clear()
session['namespace'] = 'foo' session["namespace"] = "foo"
session['repository'] = 'bar' session["repository"] = "bar"
# Call the decorated method. # Call the decorated method.
extract_namespace_repo_from_session(somefunc)() extract_namespace_repo_from_session(somefunc)()
assert encountered[0] == 'foo' assert encountered[0] == "foo"
assert encountered[1] == 'bar' assert encountered[1] == "bar"
def test_require_session_login_missing(app): def test_require_session_login_missing(app):
def emptyfunc(): def emptyfunc():
pass pass
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
require_session_login(emptyfunc)() require_session_login(emptyfunc)()
def test_require_session_login_valid_user(app): def test_require_session_login_valid_user(app):
def emptyfunc(): def emptyfunc():
pass pass
# Login as a valid user. # Login as a valid user.
someuser = model.user.get_user('devtable') someuser = model.user.get_user("devtable")
login_user(LoginWrappedDBUser(someuser.uuid, someuser)) login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# Call the function. # Call the function.
require_session_login(emptyfunc)() require_session_login(emptyfunc)()
# Ensure the authenticated user was updated. # Ensure the authenticated user was updated.
assert get_authenticated_user() == someuser assert get_authenticated_user() == someuser
def test_require_session_login_invalid_user(app): def test_require_session_login_invalid_user(app):
def emptyfunc(): def emptyfunc():
pass pass
# "Login" as a disabled user. # "Login" as a disabled user.
someuser = model.user.get_user('disabled') someuser = model.user.get_user("disabled")
login_user(LoginWrappedDBUser(someuser.uuid, someuser)) login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# Call the function. # Call the function.
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
require_session_login(emptyfunc)() require_session_login(emptyfunc)()
# Ensure the authenticated user was not updated. # Ensure the authenticated user was not updated.
assert get_authenticated_user() is None assert get_authenticated_user() is None
def test_process_auth_or_cookie_invalid_user(app): def test_process_auth_or_cookie_invalid_user(app):
def emptyfunc(): def emptyfunc():
pass pass
# Call the function. # Call the function.
process_auth_or_cookie(emptyfunc)() process_auth_or_cookie(emptyfunc)()
# Ensure the authenticated user was not updated. # Ensure the authenticated user was not updated.
assert get_authenticated_user() is None assert get_authenticated_user() is None
def test_process_auth_or_cookie_valid_user(app): def test_process_auth_or_cookie_valid_user(app):
def emptyfunc(): def emptyfunc():
pass pass
# Login as a valid user. # Login as a valid user.
someuser = model.user.get_user('devtable') someuser = model.user.get_user("devtable")
login_user(LoginWrappedDBUser(someuser.uuid, someuser)) login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# Call the function. # Call the function.
process_auth_or_cookie(emptyfunc)() process_auth_or_cookie(emptyfunc)()
# Ensure the authenticated user was updated. # Ensure the authenticated user was updated.
assert get_authenticated_user() == someuser assert get_authenticated_user() == someuser

View file

@ -6,50 +6,63 @@ from data import model
from test.fixtures import * from test.fixtures import *
@pytest.mark.parametrize('header, expected_result', [ @pytest.mark.parametrize(
('', ValidateResult(AuthKind.oauth, missing=True)), "header, expected_result",
('somerandomtoken', ValidateResult(AuthKind.oauth, missing=True)), [
('bearer some random token', ValidateResult(AuthKind.oauth, missing=True)), ("", ValidateResult(AuthKind.oauth, missing=True)),
('bearer invalidtoken', ("somerandomtoken", ValidateResult(AuthKind.oauth, missing=True)),
ValidateResult(AuthKind.oauth, error_message='OAuth access token could not be validated')),]) ("bearer some random token", ValidateResult(AuthKind.oauth, missing=True)),
(
"bearer invalidtoken",
ValidateResult(
AuthKind.oauth,
error_message="OAuth access token could not be validated",
),
),
],
)
def test_bearer(header, expected_result, app): def test_bearer(header, expected_result, app):
assert validate_bearer_auth(header) == expected_result assert validate_bearer_auth(header) == expected_result
def test_valid_oauth(app): def test_valid_oauth(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
app = model.oauth.list_applications_for_org(model.user.get_user_or_org('buynlarge'))[0] app = model.oauth.list_applications_for_org(
token_string = '%s%s' % ('a' * 20, 'b' * 20) model.user.get_user_or_org("buynlarge")
oauth_token, _ = model.oauth.create_access_token_for_testing(user, app.client_id, 'repo:read', )[0]
access_token=token_string) token_string = "%s%s" % ("a" * 20, "b" * 20)
result = validate_bearer_auth('bearer ' + token_string) oauth_token, _ = model.oauth.create_access_token_for_testing(
assert result.context.oauthtoken == oauth_token user, app.client_id, "repo:read", access_token=token_string
assert result.authed_user == user )
assert result.auth_valid result = validate_bearer_auth("bearer " + token_string)
assert result.context.oauthtoken == oauth_token
assert result.authed_user == user
assert result.auth_valid
def test_disabled_user_oauth(app): def test_disabled_user_oauth(app):
user = model.user.get_user('disabled') user = model.user.get_user("disabled")
token_string = '%s%s' % ('a' * 20, 'b' * 20) token_string = "%s%s" % ("a" * 20, "b" * 20)
oauth_token, _ = model.oauth.create_access_token_for_testing(user, 'deadbeef', 'repo:admin', oauth_token, _ = model.oauth.create_access_token_for_testing(
access_token=token_string) user, "deadbeef", "repo:admin", access_token=token_string
)
result = validate_bearer_auth('bearer ' + token_string) result = validate_bearer_auth("bearer " + token_string)
assert result.context.oauthtoken is None assert result.context.oauthtoken is None
assert result.authed_user is None assert result.authed_user is None
assert not result.auth_valid assert not result.auth_valid
assert result.error_message == 'Granter of the oauth access token is disabled' assert result.error_message == "Granter of the oauth access token is disabled"
def test_expired_token(app): def test_expired_token(app):
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
token_string = '%s%s' % ('a' * 20, 'b' * 20) token_string = "%s%s" % ("a" * 20, "b" * 20)
oauth_token, _ = model.oauth.create_access_token_for_testing(user, 'deadbeef', 'repo:admin', oauth_token, _ = model.oauth.create_access_token_for_testing(
access_token=token_string, user, "deadbeef", "repo:admin", access_token=token_string, expires_in=-1000
expires_in=-1000) )
result = validate_bearer_auth('bearer ' + token_string) result = validate_bearer_auth("bearer " + token_string)
assert result.context.oauthtoken is None assert result.context.oauthtoken is None
assert result.authed_user is None assert result.authed_user is None
assert not result.auth_valid assert not result.auth_valid
assert result.error_message == 'OAuth access token has expired' assert result.error_message == "OAuth access token has expired"

View file

@ -6,32 +6,33 @@ from data import model
from test.fixtures import * from test.fixtures import *
SUPER_USERNAME = 'devtable' SUPER_USERNAME = "devtable"
UNSUPER_USERNAME = 'freshuser' UNSUPER_USERNAME = "freshuser"
@pytest.fixture() @pytest.fixture()
def superuser(initialized_db): def superuser(initialized_db):
return model.user.get_user(SUPER_USERNAME) return model.user.get_user(SUPER_USERNAME)
@pytest.fixture() @pytest.fixture()
def normie(initialized_db): def normie(initialized_db):
return model.user.get_user(UNSUPER_USERNAME) return model.user.get_user(UNSUPER_USERNAME)
def test_superuser_matrix(superuser, normie): def test_superuser_matrix(superuser, normie):
test_cases = [ test_cases = [
(superuser, {scopes.SUPERUSER}, True), (superuser, {scopes.SUPERUSER}, True),
(superuser, {scopes.DIRECT_LOGIN}, True), (superuser, {scopes.DIRECT_LOGIN}, True),
(superuser, {scopes.READ_USER, scopes.SUPERUSER}, True), (superuser, {scopes.READ_USER, scopes.SUPERUSER}, True),
(superuser, {scopes.READ_USER}, False), (superuser, {scopes.READ_USER}, False),
(normie, {scopes.SUPERUSER}, False), (normie, {scopes.SUPERUSER}, False),
(normie, {scopes.DIRECT_LOGIN}, False), (normie, {scopes.DIRECT_LOGIN}, False),
(normie, {scopes.READ_USER, scopes.SUPERUSER}, False), (normie, {scopes.READ_USER, scopes.SUPERUSER}, False),
(normie, {scopes.READ_USER}, False), (normie, {scopes.READ_USER}, False),
] ]
for user_obj, scope_set, expected in test_cases: for user_obj, scope_set, expected in test_cases:
perm_user = QuayDeferredPermissionUser.for_user(user_obj, scope_set) perm_user = QuayDeferredPermissionUser.for_user(user_obj, scope_set)
has_su = perm_user.can(SuperUserPermission()) has_su = perm_user.can(SuperUserPermission())
assert has_su == expected assert has_su == expected

View file

@ -14,190 +14,226 @@ from initdb import setup_database_for_testing, finished_database_for_testing
from util.morecollections import AttrDict from util.morecollections import AttrDict
from util.security.registry_jwt import ANONYMOUS_SUB, build_context_and_subject from util.security.registry_jwt import ANONYMOUS_SUB, build_context_and_subject
TEST_AUDIENCE = app.config['SERVER_HOSTNAME'] TEST_AUDIENCE = app.config["SERVER_HOSTNAME"]
TEST_USER = AttrDict({'username': 'joeuser', 'uuid': 'foobar', 'enabled': True}) TEST_USER = AttrDict({"username": "joeuser", "uuid": "foobar", "enabled": True})
MAX_SIGNED_S = 3660 MAX_SIGNED_S = 3660
TOKEN_VALIDITY_LIFETIME_S = 60 * 60 # 1 hour TOKEN_VALIDITY_LIFETIME_S = 60 * 60 # 1 hour
ANONYMOUS_SUB = '(anonymous)' ANONYMOUS_SUB = "(anonymous)"
SERVICE_NAME = 'quay' SERVICE_NAME = "quay"
# This import has to come below any references to "app". # This import has to come below any references to "app".
from test.fixtures import * from test.fixtures import *
def _access(typ='repository', name='somens/somerepo', actions=None): def _access(typ="repository", name="somens/somerepo", actions=None):
actions = [] if actions is None else actions actions = [] if actions is None else actions
return [{ return [{"type": typ, "name": name, "actions": actions}]
'type': typ,
'name': name,
'actions': actions,
}]
def _delete_field(token_data, field_name): def _delete_field(token_data, field_name):
token_data.pop(field_name) token_data.pop(field_name)
return token_data return token_data
def _token_data(access=[], context=None, audience=TEST_AUDIENCE, user=TEST_USER, iat=None, def _token_data(
exp=None, nbf=None, iss=None, subject=None): access=[],
if subject is None: context=None,
_, subject = build_context_and_subject(ValidatedAuthContext(user=user)) audience=TEST_AUDIENCE,
return { user=TEST_USER,
'iss': iss or instance_keys.service_name, iat=None,
'aud': audience, exp=None,
'nbf': nbf if nbf is not None else int(time.time()), nbf=None,
'iat': iat if iat is not None else int(time.time()), iss=None,
'exp': exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S), subject=None,
'sub': subject, ):
'access': access, if subject is None:
'context': context, _, subject = build_context_and_subject(ValidatedAuthContext(user=user))
} return {
"iss": iss or instance_keys.service_name,
"aud": audience,
"nbf": nbf if nbf is not None else int(time.time()),
"iat": iat if iat is not None else int(time.time()),
"exp": exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
"sub": subject,
"access": access,
"context": context,
}
def _token(token_data, key_id=None, private_key=None, skip_header=False, alg=None): def _token(token_data, key_id=None, private_key=None, skip_header=False, alg=None):
key_id = key_id or instance_keys.local_key_id key_id = key_id or instance_keys.local_key_id
private_key = private_key or instance_keys.local_private_key private_key = private_key or instance_keys.local_private_key
if alg == "none": if alg == "none":
private_key = None private_key = None
token_headers = {'kid': key_id} token_headers = {"kid": key_id}
if skip_header: if skip_header:
token_headers = {} token_headers = {}
token_data = jwt.encode(token_data, private_key, alg or 'RS256', headers=token_headers) token_data = jwt.encode(
return 'Bearer {0}'.format(token_data) token_data, private_key, alg or "RS256", headers=token_headers
)
return "Bearer {0}".format(token_data)
def _parse_token(token): def _parse_token(token):
return identity_from_bearer_token(token)[0] return identity_from_bearer_token(token)[0]
def test_accepted_token(initialized_db): def test_accepted_token(initialized_db):
token = _token(_token_data()) token = _token(_token_data())
identity = _parse_token(token) identity = _parse_token(token)
assert identity.id == TEST_USER.username, 'should be %s, but was %s' % (TEST_USER.username, assert identity.id == TEST_USER.username, "should be %s, but was %s" % (
identity.id) TEST_USER.username,
assert len(identity.provides) == 0 identity.id,
)
assert len(identity.provides) == 0
anon_token = _token(_token_data(user=None)) anon_token = _token(_token_data(user=None))
anon_identity = _parse_token(anon_token) anon_identity = _parse_token(anon_token)
assert anon_identity.id == ANONYMOUS_SUB, 'should be %s, but was %s' % (ANONYMOUS_SUB, assert anon_identity.id == ANONYMOUS_SUB, "should be %s, but was %s" % (
anon_identity.id) ANONYMOUS_SUB,
assert len(identity.provides) == 0 anon_identity.id,
)
assert len(identity.provides) == 0
@pytest.mark.parametrize('access', [ @pytest.mark.parametrize(
(_access(actions=['pull', 'push'])), "access",
(_access(actions=['pull', '*'])), [
(_access(actions=['*', 'push'])), (_access(actions=["pull", "push"])),
(_access(actions=['*'])), (_access(actions=["pull", "*"])),
(_access(actions=['pull', '*', 'push'])),]) (_access(actions=["*", "push"])),
(_access(actions=["*"])),
(_access(actions=["pull", "*", "push"])),
],
)
def test_token_with_access(access, initialized_db): def test_token_with_access(access, initialized_db):
token = _token(_token_data(access=access)) token = _token(_token_data(access=access))
identity = _parse_token(token) identity = _parse_token(token)
assert identity.id == TEST_USER.username, 'should be %s, but was %s' % (TEST_USER.username, assert identity.id == TEST_USER.username, "should be %s, but was %s" % (
identity.id) TEST_USER.username,
assert len(identity.provides) == 1 identity.id,
)
assert len(identity.provides) == 1
role = list(identity.provides)[0][3] role = list(identity.provides)[0][3]
if "*" in access[0]['actions']: if "*" in access[0]["actions"]:
assert role == 'admin' assert role == "admin"
elif "push" in access[0]['actions']: elif "push" in access[0]["actions"]:
assert role == 'write' assert role == "write"
elif "pull" in access[0]['actions']: elif "pull" in access[0]["actions"]:
assert role == 'read' assert role == "read"
@pytest.mark.parametrize('token', [ @pytest.mark.parametrize(
pytest.param(_token( "token",
_token_data(access=[{ [
'toipe': 'repository', pytest.param(
'namesies': 'somens/somerepo', _token(
'akshuns': ['pull', 'push', '*']}])), id='bad access'), _token_data(
pytest.param(_token(_token_data(audience='someotherapp')), id='bad aud'), access=[
pytest.param(_token(_delete_field(_token_data(), 'aud')), id='no aud'), {
pytest.param(_token(_token_data(nbf=int(time.time()) + 600)), id='future nbf'), "toipe": "repository",
pytest.param(_token(_delete_field(_token_data(), 'nbf')), id='no nbf'), "namesies": "somens/somerepo",
pytest.param(_token(_token_data(iat=int(time.time()) + 600)), id='future iat'), "akshuns": ["pull", "push", "*"],
pytest.param(_token(_delete_field(_token_data(), 'iat')), id='no iat'), }
pytest.param(_token(_token_data(exp=int(time.time()) + MAX_SIGNED_S * 2)), id='exp too long'), ]
pytest.param(_token(_token_data(exp=int(time.time()) - 60)), id='expired'), )
pytest.param(_token(_delete_field(_token_data(), 'exp')), id='no exp'), ),
pytest.param(_token(_delete_field(_token_data(), 'sub')), id='no sub'), id="bad access",
pytest.param(_token(_token_data(iss='badissuer')), id='bad iss'), ),
pytest.param(_token(_delete_field(_token_data(), 'iss')), id='no iss'), pytest.param(_token(_token_data(audience="someotherapp")), id="bad aud"),
pytest.param(_token(_token_data(), skip_header=True), id='no header'), pytest.param(_token(_delete_field(_token_data(), "aud")), id="no aud"),
pytest.param(_token(_token_data(), key_id='someunknownkey'), id='bad key'), pytest.param(_token(_token_data(nbf=int(time.time()) + 600)), id="future nbf"),
pytest.param(_token(_token_data(), key_id='kid7'), id='bad key :: kid7'), pytest.param(_token(_delete_field(_token_data(), "nbf")), id="no nbf"),
pytest.param(_token(_token_data(), alg='none', private_key=None), id='none alg'), pytest.param(_token(_token_data(iat=int(time.time()) + 600)), id="future iat"),
pytest.param('some random token', id='random token'), pytest.param(_token(_delete_field(_token_data(), "iat")), id="no iat"),
pytest.param('Bearer: sometokenhere', id='extra bearer'), pytest.param(
pytest.param('\nBearer: dGVzdA', id='leading newline'), _token(_token_data(exp=int(time.time()) + MAX_SIGNED_S * 2)),
]) id="exp too long",
),
pytest.param(_token(_token_data(exp=int(time.time()) - 60)), id="expired"),
pytest.param(_token(_delete_field(_token_data(), "exp")), id="no exp"),
pytest.param(_token(_delete_field(_token_data(), "sub")), id="no sub"),
pytest.param(_token(_token_data(iss="badissuer")), id="bad iss"),
pytest.param(_token(_delete_field(_token_data(), "iss")), id="no iss"),
pytest.param(_token(_token_data(), skip_header=True), id="no header"),
pytest.param(_token(_token_data(), key_id="someunknownkey"), id="bad key"),
pytest.param(_token(_token_data(), key_id="kid7"), id="bad key :: kid7"),
pytest.param(
_token(_token_data(), alg="none", private_key=None), id="none alg"
),
pytest.param("some random token", id="random token"),
pytest.param("Bearer: sometokenhere", id="extra bearer"),
pytest.param("\nBearer: dGVzdA", id="leading newline"),
],
)
def test_invalid_jwt(token, initialized_db): def test_invalid_jwt(token, initialized_db):
with pytest.raises(InvalidJWTException): with pytest.raises(InvalidJWTException):
_parse_token(token) _parse_token(token)
def test_mixing_keys_e2e(initialized_db): def test_mixing_keys_e2e(initialized_db):
token_data = _token_data() token_data = _token_data()
# Create a new key for testing. # Create a new key for testing.
p, key = model.service_keys.generate_service_key(instance_keys.service_name, None, kid='newkey', p, key = model.service_keys.generate_service_key(
name='newkey', metadata={}) instance_keys.service_name, None, kid="newkey", name="newkey", metadata={}
private_key = p.exportKey('PEM') )
private_key = p.exportKey("PEM")
# Test first with the new valid, but unapproved key. # Test first with the new valid, but unapproved key.
unapproved_key_token = _token(token_data, key_id='newkey', private_key=private_key) unapproved_key_token = _token(token_data, key_id="newkey", private_key=private_key)
with pytest.raises(InvalidJWTException): with pytest.raises(InvalidJWTException):
_parse_token(unapproved_key_token) _parse_token(unapproved_key_token)
# Approve the key and try again. # Approve the key and try again.
admin_user = model.user.get_user('devtable') admin_user = model.user.get_user("devtable")
model.service_keys.approve_service_key(key.kid, ServiceKeyApprovalType.SUPERUSER, approver=admin_user) model.service_keys.approve_service_key(
key.kid, ServiceKeyApprovalType.SUPERUSER, approver=admin_user
)
valid_token = _token(token_data, key_id='newkey', private_key=private_key) valid_token = _token(token_data, key_id="newkey", private_key=private_key)
identity = _parse_token(valid_token) identity = _parse_token(valid_token)
assert identity.id == TEST_USER.username assert identity.id == TEST_USER.username
assert len(identity.provides) == 0 assert len(identity.provides) == 0
# Try using a different private key with the existing key ID. # Try using a different private key with the existing key ID.
bad_private_token = _token(token_data, key_id='newkey', bad_private_token = _token(
private_key=instance_keys.local_private_key) token_data, key_id="newkey", private_key=instance_keys.local_private_key
with pytest.raises(InvalidJWTException): )
_parse_token(bad_private_token) with pytest.raises(InvalidJWTException):
_parse_token(bad_private_token)
# Try using a different key ID with the existing private key. # Try using a different key ID with the existing private key.
kid_mismatch_token = _token(token_data, key_id=instance_keys.local_key_id, kid_mismatch_token = _token(
private_key=private_key) token_data, key_id=instance_keys.local_key_id, private_key=private_key
with pytest.raises(InvalidJWTException): )
_parse_token(kid_mismatch_token) with pytest.raises(InvalidJWTException):
_parse_token(kid_mismatch_token)
# Delete the new key. # Delete the new key.
key.delete_instance(recursive=True) key.delete_instance(recursive=True)
# Ensure it still works (via the cache.) # Ensure it still works (via the cache.)
deleted_key_token = _token(token_data, key_id='newkey', private_key=private_key) deleted_key_token = _token(token_data, key_id="newkey", private_key=private_key)
identity = _parse_token(deleted_key_token) identity = _parse_token(deleted_key_token)
assert identity.id == TEST_USER.username assert identity.id == TEST_USER.username
assert len(identity.provides) == 0 assert len(identity.provides) == 0
# Break the cache. # Break the cache.
instance_keys.clear_cache() instance_keys.clear_cache()
# Ensure the key no longer works. # Ensure the key no longer works.
with pytest.raises(InvalidJWTException): with pytest.raises(InvalidJWTException):
_parse_token(deleted_key_token) _parse_token(deleted_key_token)
@pytest.mark.parametrize('token', [ @pytest.mark.parametrize("token", [u"someunicodetoken✡", u"\xc9\xad\xbd"])
u'someunicodetoken✡',
u'\xc9\xad\xbd',
])
def test_unicode_token(token): def test_unicode_token(token):
with pytest.raises(InvalidJWTException): with pytest.raises(InvalidJWTException):
_parse_token(token) _parse_token(token)

View file

@ -1,50 +1,55 @@
import pytest import pytest
from auth.scopes import ( from auth.scopes import (
scopes_from_scope_string, validate_scope_string, ALL_SCOPES, is_subset_string) scopes_from_scope_string,
validate_scope_string,
ALL_SCOPES,
is_subset_string,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'scopes_string, expected', "scopes_string, expected",
[ [
# Valid single scopes. # Valid single scopes.
('repo:read', ['repo:read']), ("repo:read", ["repo:read"]),
('repo:admin', ['repo:admin']), ("repo:admin", ["repo:admin"]),
# Invalid scopes.
# Invalid scopes. ("not:valid", []),
('not:valid', []), ("repo:admins", []),
('repo:admins', []), # Valid scope strings.
("repo:read repo:admin", ["repo:read", "repo:admin"]),
# Valid scope strings. ("repo:read,repo:admin", ["repo:read", "repo:admin"]),
('repo:read repo:admin', ['repo:read', 'repo:admin']), ("repo:read,repo:admin repo:write", ["repo:read", "repo:admin", "repo:write"]),
('repo:read,repo:admin', ['repo:read', 'repo:admin']), # Partially invalid scopes.
('repo:read,repo:admin repo:write', ['repo:read', 'repo:admin', 'repo:write']), ("repo:read,not:valid", []),
("repo:read repo:admins", []),
# Partially invalid scopes. # Invalid scope strings.
('repo:read,not:valid', []), ("repo:read|repo:admin", []),
('repo:read repo:admins', []), # Mixture of delimiters.
("repo:read, repo:admin", []),
# Invalid scope strings. ],
('repo:read|repo:admin', []), )
# Mixture of delimiters.
('repo:read, repo:admin', []),])
def test_parsing(scopes_string, expected): def test_parsing(scopes_string, expected):
expected_scope_set = {ALL_SCOPES[scope_name] for scope_name in expected} expected_scope_set = {ALL_SCOPES[scope_name] for scope_name in expected}
parsed_scope_set = scopes_from_scope_string(scopes_string) parsed_scope_set = scopes_from_scope_string(scopes_string)
assert parsed_scope_set == expected_scope_set assert parsed_scope_set == expected_scope_set
assert validate_scope_string(scopes_string) == bool(expected) assert validate_scope_string(scopes_string) == bool(expected)
@pytest.mark.parametrize('superset, subset, result', [ @pytest.mark.parametrize(
('repo:read', 'repo:read', True), "superset, subset, result",
('repo:read repo:admin', 'repo:read', True), [
('repo:read,repo:admin', 'repo:read', True), ("repo:read", "repo:read", True),
('repo:read,repo:admin', 'repo:admin', True), ("repo:read repo:admin", "repo:read", True),
('repo:read,repo:admin', 'repo:admin repo:read', True), ("repo:read,repo:admin", "repo:read", True),
('', 'repo:read', False), ("repo:read,repo:admin", "repo:admin", True),
('unknown:tag', 'repo:read', False), ("repo:read,repo:admin", "repo:admin repo:read", True),
('repo:read unknown:tag', 'repo:read', False), ("", "repo:read", False),
('repo:read,unknown:tag', 'repo:read', False),]) ("unknown:tag", "repo:read", False),
("repo:read unknown:tag", "repo:read", False),
("repo:read,unknown:tag", "repo:read", False),
],
)
def test_subset_string(superset, subset, result): def test_subset_string(superset, subset, result):
assert is_subset_string(superset, subset) == result assert is_subset_string(superset, subset) == result

View file

@ -1,32 +1,47 @@
import pytest import pytest
from auth.signedgrant import validate_signed_grant, generate_signed_token, SIGNATURE_PREFIX from auth.signedgrant import (
validate_signed_grant,
generate_signed_token,
SIGNATURE_PREFIX,
)
from auth.validateresult import AuthKind, ValidateResult from auth.validateresult import AuthKind, ValidateResult
@pytest.mark.parametrize('header, expected_result', [ @pytest.mark.parametrize(
pytest.param('', ValidateResult(AuthKind.signed_grant, missing=True), id='Missing'), "header, expected_result",
pytest.param('somerandomtoken', ValidateResult(AuthKind.signed_grant, missing=True), [
id='Invalid header'), pytest.param(
pytest.param('token somerandomtoken', ValidateResult(AuthKind.signed_grant, missing=True), "", ValidateResult(AuthKind.signed_grant, missing=True), id="Missing"
id='Random Token'), ),
pytest.param('token ' + SIGNATURE_PREFIX + 'foo', pytest.param(
ValidateResult(AuthKind.signed_grant, "somerandomtoken",
error_message='Signed grant could not be validated'), ValidateResult(AuthKind.signed_grant, missing=True),
id='Invalid token'), id="Invalid header",
]) ),
pytest.param(
"token somerandomtoken",
ValidateResult(AuthKind.signed_grant, missing=True),
id="Random Token",
),
pytest.param(
"token " + SIGNATURE_PREFIX + "foo",
ValidateResult(
AuthKind.signed_grant,
error_message="Signed grant could not be validated",
),
id="Invalid token",
),
],
)
def test_token(header, expected_result): def test_token(header, expected_result):
assert validate_signed_grant(header) == expected_result assert validate_signed_grant(header) == expected_result
def test_valid_grant(): def test_valid_grant():
header = 'token ' + generate_signed_token({'a': 'b'}, {'c': 'd'}) header = "token " + generate_signed_token({"a": "b"}, {"c": "d"})
expected = ValidateResult(AuthKind.signed_grant, signed_data={ expected = ValidateResult(
'grants': { AuthKind.signed_grant,
'a': 'b', signed_data={"grants": {"a": "b"}, "user_context": {"c": "d"}},
}, )
'user_context': { assert validate_signed_grant(header) == expected
'c': 'd'
},
})
assert validate_signed_grant(header) == expected

View file

@ -6,58 +6,68 @@ from data import model
from data.database import AppSpecificAuthToken from data.database import AppSpecificAuthToken
from test.fixtures import * from test.fixtures import *
def get_user(): def get_user():
return model.user.get_user('devtable') return model.user.get_user("devtable")
def get_app_specific_token(): def get_app_specific_token():
return AppSpecificAuthToken.get() return AppSpecificAuthToken.get()
def get_robot(): def get_robot():
robot, _ = model.user.create_robot('somebot', get_user()) robot, _ = model.user.create_robot("somebot", get_user())
return robot return robot
def get_token(): def get_token():
return model.token.create_delegate_token('devtable', 'simple', 'sometoken') return model.token.create_delegate_token("devtable", "simple", "sometoken")
def get_oauthtoken(): def get_oauthtoken():
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
return list(model.oauth.list_access_tokens_for_user(user))[0] return list(model.oauth.list_access_tokens_for_user(user))[0]
def get_signeddata(): def get_signeddata():
return {'grants': {'a': 'b'}, 'user_context': {'c': 'd'}} return {"grants": {"a": "b"}, "user_context": {"c": "d"}}
@pytest.mark.parametrize('get_entity,entity_kind', [
(get_user, 'user'), @pytest.mark.parametrize(
(get_robot, 'robot'), "get_entity,entity_kind",
(get_token, 'token'), [
(get_oauthtoken, 'oauthtoken'), (get_user, "user"),
(get_signeddata, 'signed_data'), (get_robot, "robot"),
(get_app_specific_token, 'appspecifictoken'), (get_token, "token"),
]) (get_oauthtoken, "oauthtoken"),
(get_signeddata, "signed_data"),
(get_app_specific_token, "appspecifictoken"),
],
)
def test_apply_context(get_entity, entity_kind, app): def test_apply_context(get_entity, entity_kind, app):
assert get_authenticated_context() is None assert get_authenticated_context() is None
entity = get_entity() entity = get_entity()
args = {} args = {}
args[entity_kind] = entity args[entity_kind] = entity
result = ValidateResult(AuthKind.basic, **args) result = ValidateResult(AuthKind.basic, **args)
result.apply_to_context() result.apply_to_context()
expected_user = entity if entity_kind == 'user' or entity_kind == 'robot' else None expected_user = entity if entity_kind == "user" or entity_kind == "robot" else None
if entity_kind == 'oauthtoken': if entity_kind == "oauthtoken":
expected_user = entity.authorized_user expected_user = entity.authorized_user
if entity_kind == 'appspecifictoken': if entity_kind == "appspecifictoken":
expected_user = entity.user expected_user = entity.user
expected_token = entity if entity_kind == 'token' else None expected_token = entity if entity_kind == "token" else None
expected_oauth = entity if entity_kind == 'oauthtoken' else None expected_oauth = entity if entity_kind == "oauthtoken" else None
expected_appspecifictoken = entity if entity_kind == 'appspecifictoken' else None expected_appspecifictoken = entity if entity_kind == "appspecifictoken" else None
expected_grant = entity if entity_kind == 'signed_data' else None expected_grant = entity if entity_kind == "signed_data" else None
assert get_authenticated_context().authed_user == expected_user assert get_authenticated_context().authed_user == expected_user
assert get_authenticated_context().token == expected_token assert get_authenticated_context().token == expected_token
assert get_authenticated_context().oauthtoken == expected_oauth assert get_authenticated_context().oauthtoken == expected_oauth
assert get_authenticated_context().appspecifictoken == expected_appspecifictoken assert get_authenticated_context().appspecifictoken == expected_appspecifictoken
assert get_authenticated_context().signed_data == expected_grant assert get_authenticated_context().signed_data == expected_grant

View file

@ -3,54 +3,76 @@ from auth.auth_context_type import ValidatedAuthContext, ContextEntityKind
class AuthKind(Enum): class AuthKind(Enum):
cookie = 'cookie' cookie = "cookie"
basic = 'basic' basic = "basic"
oauth = 'oauth' oauth = "oauth"
signed_grant = 'signed_grant' signed_grant = "signed_grant"
credentials = 'credentials' credentials = "credentials"
class ValidateResult(object): class ValidateResult(object):
""" A result of validating auth in one form or another. """ """ A result of validating auth in one form or another. """
def __init__(self, kind, missing=False, user=None, token=None, oauthtoken=None,
robot=None, appspecifictoken=None, signed_data=None, error_message=None):
self.kind = kind
self.missing = missing
self.error_message = error_message
self.context = ValidatedAuthContext(user=user, token=token, oauthtoken=oauthtoken, robot=robot,
appspecifictoken=appspecifictoken, signed_data=signed_data)
def tuple(self): def __init__(
return (self.kind, self.missing, self.error_message, self.context.tuple()) self,
kind,
missing=False,
user=None,
token=None,
oauthtoken=None,
robot=None,
appspecifictoken=None,
signed_data=None,
error_message=None,
):
self.kind = kind
self.missing = missing
self.error_message = error_message
self.context = ValidatedAuthContext(
user=user,
token=token,
oauthtoken=oauthtoken,
robot=robot,
appspecifictoken=appspecifictoken,
signed_data=signed_data,
)
def __eq__(self, other): def tuple(self):
return self.tuple() == other.tuple() return (self.kind, self.missing, self.error_message, self.context.tuple())
def apply_to_context(self): def __eq__(self, other):
""" Applies this auth result to the auth context and Flask-Principal. """ return self.tuple() == other.tuple()
self.context.apply_to_request_context()
def with_kind(self, kind): def apply_to_context(self):
""" Returns a copy of this result, but with the kind replaced. """ """ Applies this auth result to the auth context and Flask-Principal. """
result = ValidateResult(kind, missing=self.missing, error_message=self.error_message) self.context.apply_to_request_context()
result.context = self.context
return result
def __repr__(self): def with_kind(self, kind):
return 'ValidateResult: %s (missing: %s, error: %s)' % (self.kind, self.missing, """ Returns a copy of this result, but with the kind replaced. """
self.error_message) result = ValidateResult(
kind, missing=self.missing, error_message=self.error_message
)
result.context = self.context
return result
@property def __repr__(self):
def authed_user(self): return "ValidateResult: %s (missing: %s, error: %s)" % (
""" Returns the authenticated user, whether directly, or via an OAuth token. """ self.kind,
return self.context.authed_user self.missing,
self.error_message,
)
@property @property
def has_nonrobot_user(self): def authed_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """ """ Returns the authenticated user, whether directly, or via an OAuth token. """
return self.context.has_nonrobot_user return self.context.authed_user
@property @property
def auth_valid(self): def has_nonrobot_user(self):
""" Returns whether authentication successfully occurred. """ """ Returns whether a user (not a robot) was authenticated successfully. """
return self.context.entity_kind != ContextEntityKind.anonymous return self.context.has_nonrobot_user
@property
def auth_valid(self):
""" Returns whether authentication successfully occurred. """
return self.context.entity_kind != ContextEntityKind.anonymous

View file

@ -6,110 +6,133 @@ from requests.exceptions import RequestException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Avatar(object): class Avatar(object):
def __init__(self, app=None): def __init__(self, app=None):
self.app = app self.app = app
self.state = self._init_app(app) self.state = self._init_app(app)
def _init_app(self, app): def _init_app(self, app):
return AVATAR_CLASSES[app.config.get('AVATAR_KIND', 'Gravatar')]( return AVATAR_CLASSES[app.config.get("AVATAR_KIND", "Gravatar")](
app.config['PREFERRED_URL_SCHEME'], app.config['AVATAR_COLORS'], app.config['HTTPCLIENT']) app.config["PREFERRED_URL_SCHEME"],
app.config["AVATAR_COLORS"],
app.config["HTTPCLIENT"],
)
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.state, name, None) return getattr(self.state, name, None)
class BaseAvatar(object): class BaseAvatar(object):
""" Base class for all avatar implementations. """ """ Base class for all avatar implementations. """
def __init__(self, preferred_url_scheme, colors, http_client):
self.preferred_url_scheme = preferred_url_scheme
self.colors = colors
self.http_client = http_client
def get_mail_html(self, name, email_or_id, size=16, kind='user'): def __init__(self, preferred_url_scheme, colors, http_client):
""" Returns the full HTML and CSS for viewing the avatar of the given name and email address, self.preferred_url_scheme = preferred_url_scheme
self.colors = colors
self.http_client = http_client
def get_mail_html(self, name, email_or_id, size=16, kind="user"):
""" Returns the full HTML and CSS for viewing the avatar of the given name and email address,
with an optional size. with an optional size.
""" """
data = self.get_data(name, email_or_id, kind) data = self.get_data(name, email_or_id, kind)
url = self._get_url(data['hash'], size) if kind != 'team' else None url = self._get_url(data["hash"], size) if kind != "team" else None
font_size = size - 6 font_size = size - 6
if url is not None: if url is not None:
# Try to load the gravatar. If we get a non-404 response, then we use it in place of # Try to load the gravatar. If we get a non-404 response, then we use it in place of
# the CSS avatar. # the CSS avatar.
try: try:
response = self.http_client.get(url, timeout=5) response = self.http_client.get(url, timeout=5)
if response.status_code == 200: if response.status_code == 200:
return """<img src="%s" width="%s" height="%s" alt="%s" return """<img src="%s" width="%s" height="%s" alt="%s"
style="vertical-align: middle;">""" % (url, size, size, kind) style="vertical-align: middle;">""" % (
except RequestException: url,
logger.exception('Could not retrieve avatar for user %s', name) size,
size,
kind,
)
except RequestException:
logger.exception("Could not retrieve avatar for user %s", name)
radius = '50%' if kind == 'team' else '0%' radius = "50%" if kind == "team" else "0%"
letter = '&Omega;' if kind == 'team' and data['name'] == 'owners' else data['name'].upper()[0] letter = (
"&Omega;"
if kind == "team" and data["name"] == "owners"
else data["name"].upper()[0]
)
return """ return """
<span style="width: %spx; height: %spx; background-color: %s; font-size: %spx; <span style="width: %spx; height: %spx; background-color: %s; font-size: %spx;
line-height: %spx; margin-left: 2px; margin-right: 2px; display: inline-block; line-height: %spx; margin-left: 2px; margin-right: 2px; display: inline-block;
vertical-align: middle; text-align: center; color: white; border-radius: %s"> vertical-align: middle; text-align: center; color: white; border-radius: %s">
%s %s
</span> </span>
""" % (size, size, data['color'], font_size, size, radius, letter) """ % (
size,
size,
data["color"],
font_size,
size,
radius,
letter,
)
def get_data_for_user(self, user): def get_data_for_user(self, user):
return self.get_data(user.username, user.email, 'robot' if user.robot else 'user') return self.get_data(
user.username, user.email, "robot" if user.robot else "user"
)
def get_data_for_team(self, team): def get_data_for_team(self, team):
return self.get_data(team.name, team.name, 'team') return self.get_data(team.name, team.name, "team")
def get_data_for_org(self, org): def get_data_for_org(self, org):
return self.get_data(org.username, org.email, 'org') return self.get_data(org.username, org.email, "org")
def get_data_for_external_user(self, external_user): def get_data_for_external_user(self, external_user):
return self.get_data(external_user.username, external_user.email, 'user') return self.get_data(external_user.username, external_user.email, "user")
def get_data(self, name, email_or_id, kind='user'): def get_data(self, name, email_or_id, kind="user"):
""" Computes and returns the full data block for the avatar: """ Computes and returns the full data block for the avatar:
{ {
'name': name, 'name': name,
'hash': The gravatar hash, if any. 'hash': The gravatar hash, if any.
'color': The color for the avatar 'color': The color for the avatar
} }
""" """
colors = self.colors colors = self.colors
# Note: email_or_id may be None if gotten from external auth when email is disabled, # Note: email_or_id may be None if gotten from external auth when email is disabled,
# so use the username in that case. # so use the username in that case.
username_email_or_id = email_or_id or name username_email_or_id = email_or_id or name
hash_value = hashlib.md5(username_email_or_id.strip().lower()).hexdigest() hash_value = hashlib.md5(username_email_or_id.strip().lower()).hexdigest()
byte_count = int(math.ceil(math.log(len(colors), 16))) byte_count = int(math.ceil(math.log(len(colors), 16)))
byte_data = hash_value[0:byte_count] byte_data = hash_value[0:byte_count]
hash_color = colors[int(byte_data, 16) % len(colors)] hash_color = colors[int(byte_data, 16) % len(colors)]
return { return {"name": name, "hash": hash_value, "color": hash_color, "kind": kind}
'name': name,
'hash': hash_value,
'color': hash_color,
'kind': kind
}
def _get_url(self, hash_value, size): def _get_url(self, hash_value, size):
""" Returns the URL for displaying the overlay avatar. """ """ Returns the URL for displaying the overlay avatar. """
return None return None
class GravatarAvatar(BaseAvatar): class GravatarAvatar(BaseAvatar):
""" Avatar system that uses gravatar for generating avatars. """ """ Avatar system that uses gravatar for generating avatars. """
def _get_url(self, hash_value, size=16):
return '%s://www.gravatar.com/avatar/%s?d=404&size=%s' % (self.preferred_url_scheme, def _get_url(self, hash_value, size=16):
hash_value, size) return "%s://www.gravatar.com/avatar/%s?d=404&size=%s" % (
self.preferred_url_scheme,
hash_value,
size,
)
class LocalAvatar(BaseAvatar): class LocalAvatar(BaseAvatar):
""" Avatar system that uses the local system for generating avatars. """ """ Avatar system that uses the local system for generating avatars. """
pass
AVATAR_CLASSES = { pass
'gravatar': GravatarAvatar,
'local': LocalAvatar
} AVATAR_CLASSES = {"gravatar": GravatarAvatar, "local": LocalAvatar}

165
boot.py
View file

@ -16,7 +16,7 @@ from data.model.release import set_region_release
from data.model.service_keys import get_service_key from data.model.service_keys import get_service_key
from util.config.database import sync_database_with_config from util.config.database import sync_database_with_config
from util.generatepresharedkey import generate_key from util.generatepresharedkey import generate_key
from _init import CONF_DIR from _init import CONF_DIR
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,108 +24,117 @@ logger = logging.getLogger(__name__)
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def get_audience(): def get_audience():
audience = app.config.get('JWTPROXY_AUDIENCE') audience = app.config.get("JWTPROXY_AUDIENCE")
if audience: if audience:
return audience return audience
scheme = app.config.get('PREFERRED_URL_SCHEME') scheme = app.config.get("PREFERRED_URL_SCHEME")
hostname = app.config.get('SERVER_HOSTNAME') hostname = app.config.get("SERVER_HOSTNAME")
# hostname includes port, use that # hostname includes port, use that
if ':' in hostname: if ":" in hostname:
return urlunparse((scheme, hostname, '', '', '', '')) return urlunparse((scheme, hostname, "", "", "", ""))
# no port, guess based on scheme # no port, guess based on scheme
if scheme == 'https': if scheme == "https":
port = '443' port = "443"
else: else:
port = '80' port = "80"
return urlunparse((scheme, hostname + ':' + port, '', '', '', '')) return urlunparse((scheme, hostname + ":" + port, "", "", "", ""))
def _verify_service_key(): def _verify_service_key():
try:
with open(app.config['INSTANCE_SERVICE_KEY_KID_LOCATION']) as f:
quay_key_id = f.read()
try: try:
get_service_key(quay_key_id, approved_only=False) with open(app.config["INSTANCE_SERVICE_KEY_KID_LOCATION"]) as f:
assert os.path.exists(app.config['INSTANCE_SERVICE_KEY_LOCATION']) quay_key_id = f.read()
return quay_key_id
except ServiceKeyDoesNotExist:
logger.exception('Could not find non-expired existing service key %s; creating a new one',
quay_key_id)
return None
# Found a valid service key, so exiting. try:
except IOError: get_service_key(quay_key_id, approved_only=False)
logger.exception('Could not load existing service key; creating a new one') assert os.path.exists(app.config["INSTANCE_SERVICE_KEY_LOCATION"])
return None return quay_key_id
except ServiceKeyDoesNotExist:
logger.exception(
"Could not find non-expired existing service key %s; creating a new one",
quay_key_id,
)
return None
# Found a valid service key, so exiting.
except IOError:
logger.exception("Could not load existing service key; creating a new one")
return None
def setup_jwt_proxy(): def setup_jwt_proxy():
""" """
Creates a service key for quay to use in the jwtproxy and generates the JWT proxy configuration. Creates a service key for quay to use in the jwtproxy and generates the JWT proxy configuration.
""" """
if os.path.exists(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml')): if os.path.exists(os.path.join(CONF_DIR, "jwtproxy_conf.yaml")):
# Proxy is already setup. Make sure the service key is still valid. # Proxy is already setup. Make sure the service key is still valid.
quay_key_id = _verify_service_key() quay_key_id = _verify_service_key()
if quay_key_id is not None: if quay_key_id is not None:
return return
# Ensure we have an existing key if in read-only mode. # Ensure we have an existing key if in read-only mode.
if app.config.get('REGISTRY_STATE', 'normal') == 'readonly': if app.config.get("REGISTRY_STATE", "normal") == "readonly":
quay_key_id = _verify_service_key() quay_key_id = _verify_service_key()
if quay_key_id is None: if quay_key_id is None:
raise Exception('No valid service key found for read-only registry.') raise Exception("No valid service key found for read-only registry.")
else: else:
# Generate the key for this Quay instance to use. # Generate the key for this Quay instance to use.
minutes_until_expiration = app.config.get('INSTANCE_SERVICE_KEY_EXPIRATION', 120) minutes_until_expiration = app.config.get(
expiration = datetime.now() + timedelta(minutes=minutes_until_expiration) "INSTANCE_SERVICE_KEY_EXPIRATION", 120
quay_key, quay_key_id = generate_key(app.config['INSTANCE_SERVICE_KEY_SERVICE'], )
get_audience(), expiration_date=expiration) expiration = datetime.now() + timedelta(minutes=minutes_until_expiration)
quay_key, quay_key_id = generate_key(
app.config["INSTANCE_SERVICE_KEY_SERVICE"],
get_audience(),
expiration_date=expiration,
)
with open(app.config['INSTANCE_SERVICE_KEY_KID_LOCATION'], mode='w') as f: with open(app.config["INSTANCE_SERVICE_KEY_KID_LOCATION"], mode="w") as f:
f.truncate(0) f.truncate(0)
f.write(quay_key_id) f.write(quay_key_id)
with open(app.config['INSTANCE_SERVICE_KEY_LOCATION'], mode='w') as f: with open(app.config["INSTANCE_SERVICE_KEY_LOCATION"], mode="w") as f:
f.truncate(0) f.truncate(0)
f.write(quay_key.exportKey()) f.write(quay_key.exportKey())
# Generate the JWT proxy configuration. # Generate the JWT proxy configuration.
audience = get_audience() audience = get_audience()
registry = audience + '/keys' registry = audience + "/keys"
security_issuer = app.config.get('SECURITY_SCANNER_ISSUER_NAME', 'security_scanner') security_issuer = app.config.get("SECURITY_SCANNER_ISSUER_NAME", "security_scanner")
with open(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml.jnj')) as f: with open(os.path.join(CONF_DIR, "jwtproxy_conf.yaml.jnj")) as f:
template = Template(f.read()) template = Template(f.read())
rendered = template.render( rendered = template.render(
conf_dir=CONF_DIR, conf_dir=CONF_DIR,
audience=audience, audience=audience,
registry=registry, registry=registry,
key_id=quay_key_id, key_id=quay_key_id,
security_issuer=security_issuer, security_issuer=security_issuer,
service_key_location=app.config['INSTANCE_SERVICE_KEY_LOCATION'], service_key_location=app.config["INSTANCE_SERVICE_KEY_LOCATION"],
) )
with open(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml'), 'w') as f: with open(os.path.join(CONF_DIR, "jwtproxy_conf.yaml"), "w") as f:
f.write(rendered) f.write(rendered)
def main(): def main():
if not app.config.get('SETUP_COMPLETE', False): if not app.config.get("SETUP_COMPLETE", False):
raise Exception('Your configuration bundle is either not mounted or setup has not been completed') raise Exception(
"Your configuration bundle is either not mounted or setup has not been completed"
)
sync_database_with_config(app.config) sync_database_with_config(app.config)
setup_jwt_proxy() setup_jwt_proxy()
# Record deploy # Record deploy
if release.REGION and release.GIT_HEAD: if release.REGION and release.GIT_HEAD:
set_region_release(release.SERVICE, release.REGION, release.GIT_HEAD) set_region_release(release.SERVICE, release.REGION, release.GIT_HEAD)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -5,38 +5,39 @@ from trollius import get_event_loop, coroutine
def wrap_with_threadpool(obj, worker_threads=1): def wrap_with_threadpool(obj, worker_threads=1):
""" """
Wraps a class in an async executor so that it can be safely used in an event loop like trollius. Wraps a class in an async executor so that it can be safely used in an event loop like trollius.
""" """
async_executor = ThreadPoolExecutor(worker_threads) async_executor = ThreadPoolExecutor(worker_threads)
return AsyncWrapper(obj, executor=async_executor), async_executor return AsyncWrapper(obj, executor=async_executor), async_executor
class AsyncWrapper(object): class AsyncWrapper(object):
""" Wrapper class which will transform a syncronous library to one that can be used with """ Wrapper class which will transform a syncronous library to one that can be used with
trollius coroutines. trollius coroutines.
""" """
def __init__(self, delegate, loop=None, executor=None):
self._loop = loop if loop is not None else get_event_loop()
self._delegate = delegate
self._executor = executor
def __getattr__(self, attrib): def __init__(self, delegate, loop=None, executor=None):
delegate_attr = getattr(self._delegate, attrib) self._loop = loop if loop is not None else get_event_loop()
self._delegate = delegate
self._executor = executor
if not callable(delegate_attr): def __getattr__(self, attrib):
return delegate_attr delegate_attr = getattr(self._delegate, attrib)
def wrapper(*args, **kwargs): if not callable(delegate_attr):
""" Wraps the delegate_attr with primitives that will transform sync calls to ones shelled return delegate_attr
def wrapper(*args, **kwargs):
""" Wraps the delegate_attr with primitives that will transform sync calls to ones shelled
out to a thread pool. out to a thread pool.
""" """
callable_delegate_attr = partial(delegate_attr, *args, **kwargs) callable_delegate_attr = partial(delegate_attr, *args, **kwargs)
return self._loop.run_in_executor(self._executor, callable_delegate_attr) return self._loop.run_in_executor(self._executor, callable_delegate_attr)
return wrapper return wrapper
@coroutine @coroutine
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
callable_delegate_attr = partial(self._delegate, *args, **kwargs) callable_delegate_attr = partial(self._delegate, *args, **kwargs)
return self._loop.run_in_executor(self._executor, callable_delegate_attr) return self._loop.run_in_executor(self._executor, callable_delegate_attr)

View file

@ -18,80 +18,104 @@ from raven.conf import setup_logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BUILD_MANAGERS = { BUILD_MANAGERS = {"enterprise": EnterpriseManager, "ephemeral": EphemeralBuilderManager}
'enterprise': EnterpriseManager,
'ephemeral': EphemeralBuilderManager,
}
EXTERNALLY_MANAGED = 'external' EXTERNALLY_MANAGED = "external"
DEFAULT_WEBSOCKET_PORT = 8787 DEFAULT_WEBSOCKET_PORT = 8787
DEFAULT_CONTROLLER_PORT = 8686 DEFAULT_CONTROLLER_PORT = 8686
LOG_FORMAT = "%(asctime)s [%(process)d] [%(levelname)s] [%(name)s] %(message)s" LOG_FORMAT = "%(asctime)s [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
def run_build_manager(): def run_build_manager():
if not features.BUILD_SUPPORT: if not features.BUILD_SUPPORT:
logger.debug('Building is disabled. Please enable the feature flag') logger.debug("Building is disabled. Please enable the feature flag")
while True: while True:
time.sleep(1000) time.sleep(1000)
return return
if app.config.get('REGISTRY_STATE', 'normal') == 'readonly': if app.config.get("REGISTRY_STATE", "normal") == "readonly":
logger.debug('Building is disabled while in read-only mode.') logger.debug("Building is disabled while in read-only mode.")
while True: while True:
time.sleep(1000) time.sleep(1000)
return return
build_manager_config = app.config.get('BUILD_MANAGER') build_manager_config = app.config.get("BUILD_MANAGER")
if build_manager_config is None: if build_manager_config is None:
return return
# If the build system is externally managed, then we just sleep this process. # If the build system is externally managed, then we just sleep this process.
if build_manager_config[0] == EXTERNALLY_MANAGED: if build_manager_config[0] == EXTERNALLY_MANAGED:
logger.debug('Builds are externally managed.') logger.debug("Builds are externally managed.")
while True: while True:
time.sleep(1000) time.sleep(1000)
return return
logger.debug('Asking to start build manager with lifecycle "%s"', build_manager_config[0]) logger.debug(
manager_klass = BUILD_MANAGERS.get(build_manager_config[0]) 'Asking to start build manager with lifecycle "%s"', build_manager_config[0]
if manager_klass is None: )
return manager_klass = BUILD_MANAGERS.get(build_manager_config[0])
if manager_klass is None:
return
manager_hostname = os.environ.get('BUILDMAN_HOSTNAME', manager_hostname = os.environ.get(
app.config.get('BUILDMAN_HOSTNAME', "BUILDMAN_HOSTNAME",
app.config['SERVER_HOSTNAME'])) app.config.get("BUILDMAN_HOSTNAME", app.config["SERVER_HOSTNAME"]),
websocket_port = int(os.environ.get('BUILDMAN_WEBSOCKET_PORT', )
app.config.get('BUILDMAN_WEBSOCKET_PORT', websocket_port = int(
DEFAULT_WEBSOCKET_PORT))) os.environ.get(
controller_port = int(os.environ.get('BUILDMAN_CONTROLLER_PORT', "BUILDMAN_WEBSOCKET_PORT",
app.config.get('BUILDMAN_CONTROLLER_PORT', app.config.get("BUILDMAN_WEBSOCKET_PORT", DEFAULT_WEBSOCKET_PORT),
DEFAULT_CONTROLLER_PORT))) )
)
controller_port = int(
os.environ.get(
"BUILDMAN_CONTROLLER_PORT",
app.config.get("BUILDMAN_CONTROLLER_PORT", DEFAULT_CONTROLLER_PORT),
)
)
logger.debug('Will pass buildman hostname %s to builders for websocket connection', logger.debug(
manager_hostname) "Will pass buildman hostname %s to builders for websocket connection",
manager_hostname,
)
logger.debug('Starting build manager with lifecycle "%s"', build_manager_config[0]) logger.debug('Starting build manager with lifecycle "%s"', build_manager_config[0])
ssl_context = None ssl_context = None
if os.environ.get('SSL_CONFIG'): if os.environ.get("SSL_CONFIG"):
logger.debug('Loading SSL cert and key') logger.debug("Loading SSL cert and key")
ssl_context = SSLContext() ssl_context = SSLContext()
ssl_context.load_cert_chain(os.path.join(os.environ.get('SSL_CONFIG'), 'ssl.cert'), ssl_context.load_cert_chain(
os.path.join(os.environ.get('SSL_CONFIG'), 'ssl.key')) os.path.join(os.environ.get("SSL_CONFIG"), "ssl.cert"),
os.path.join(os.environ.get("SSL_CONFIG"), "ssl.key"),
)
server = BuilderServer(app.config['SERVER_HOSTNAME'], dockerfile_build_queue, build_logs, server = BuilderServer(
user_files, manager_klass, build_manager_config[1], manager_hostname) app.config["SERVER_HOSTNAME"],
server.run('0.0.0.0', websocket_port, controller_port, ssl=ssl_context) dockerfile_build_queue,
build_logs,
user_files,
manager_klass,
build_manager_config[1],
manager_hostname,
)
server.run("0.0.0.0", websocket_port, controller_port, ssl=ssl_context)
if __name__ == '__main__':
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
logging.getLogger('peewee').setLevel(logging.WARN)
logging.getLogger('boto').setLevel(logging.WARN)
if app.config.get('EXCEPTION_LOG_TYPE', 'FakeSentry') == 'Sentry': if __name__ == "__main__":
buildman_name = '%s:buildman' % socket.gethostname() logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
setup_logging(SentryHandler(app.config.get('SENTRY_DSN', ''), name=buildman_name, logging.getLogger("peewee").setLevel(logging.WARN)
level=logging.ERROR)) logging.getLogger("boto").setLevel(logging.WARN)
run_build_manager() if app.config.get("EXCEPTION_LOG_TYPE", "FakeSentry") == "Sentry":
buildman_name = "%s:buildman" % socket.gethostname()
setup_logging(
SentryHandler(
app.config.get("SENTRY_DSN", ""),
name=buildman_name,
level=logging.ERROR,
)
)
run_build_manager()

View file

@ -1,13 +1,15 @@
from autobahn.asyncio.wamp import ApplicationSession from autobahn.asyncio.wamp import ApplicationSession
class BaseComponent(ApplicationSession):
""" Base class for all registered component sessions in the server. """
def __init__(self, config, **kwargs):
ApplicationSession.__init__(self, config)
self.server = None
self.parent_manager = None
self.build_logs = None
self.user_files = None
def kind(self): class BaseComponent(ApplicationSession):
raise NotImplementedError """ Base class for all registered component sessions in the server. """
def __init__(self, config, **kwargs):
ApplicationSession.__init__(self, config)
self.server = None
self.parent_manager = None
self.build_logs = None
self.user_files = None
def kind(self):
raise NotImplementedError

File diff suppressed because it is too large Load diff

View file

@ -1,15 +1,16 @@
import re import re
def extract_current_step(current_status_string): def extract_current_step(current_status_string):
""" Attempts to extract the current step numeric identifier from the given status string. Returns the step """ Attempts to extract the current step numeric identifier from the given status string. Returns the step
number or None if none. number or None if none.
""" """
# Older format: `Step 12 :` # Older format: `Step 12 :`
# Newer format: `Step 4/13 :` # Newer format: `Step 4/13 :`
step_increment = re.search(r'Step ([0-9]+)/([0-9]+) :', current_status_string) step_increment = re.search(r"Step ([0-9]+)/([0-9]+) :", current_status_string)
if step_increment: if step_increment:
return int(step_increment.group(1)) return int(step_increment.group(1))
step_increment = re.search(r'Step ([0-9]+) :', current_status_string) step_increment = re.search(r"Step ([0-9]+) :", current_status_string)
if step_increment: if step_increment:
return int(step_increment.group(1)) return int(step_increment.group(1))

View file

@ -3,34 +3,62 @@ import pytest
from buildman.component.buildcomponent import BuildComponent from buildman.component.buildcomponent import BuildComponent
@pytest.mark.parametrize('input,expected_path,expected_file', [ @pytest.mark.parametrize(
("", "/", "Dockerfile"), "input,expected_path,expected_file",
("/", "/", "Dockerfile"), [
("/Dockerfile", "/", "Dockerfile"), ("", "/", "Dockerfile"),
("/server.Dockerfile", "/", "server.Dockerfile"), ("/", "/", "Dockerfile"),
("/somepath", "/somepath", "Dockerfile"), ("/Dockerfile", "/", "Dockerfile"),
("/somepath/", "/somepath", "Dockerfile"), ("/server.Dockerfile", "/", "server.Dockerfile"),
("/somepath/Dockerfile", "/somepath", "Dockerfile"), ("/somepath", "/somepath", "Dockerfile"),
("/somepath/server.Dockerfile", "/somepath", "server.Dockerfile"), ("/somepath/", "/somepath", "Dockerfile"),
("/somepath/some_other_path", "/somepath/some_other_path", "Dockerfile"), ("/somepath/Dockerfile", "/somepath", "Dockerfile"),
("/somepath/some_other_path/", "/somepath/some_other_path", "Dockerfile"), ("/somepath/server.Dockerfile", "/somepath", "server.Dockerfile"),
("/somepath/some_other_path/Dockerfile", "/somepath/some_other_path", "Dockerfile"), ("/somepath/some_other_path", "/somepath/some_other_path", "Dockerfile"),
("/somepath/some_other_path/server.Dockerfile", "/somepath/some_other_path", "server.Dockerfile"), ("/somepath/some_other_path/", "/somepath/some_other_path", "Dockerfile"),
]) (
"/somepath/some_other_path/Dockerfile",
"/somepath/some_other_path",
"Dockerfile",
),
(
"/somepath/some_other_path/server.Dockerfile",
"/somepath/some_other_path",
"server.Dockerfile",
),
],
)
def test_path_is_dockerfile(input, expected_path, expected_file): def test_path_is_dockerfile(input, expected_path, expected_file):
actual_path, actual_file = BuildComponent.name_and_path(input) actual_path, actual_file = BuildComponent.name_and_path(input)
assert actual_path == expected_path assert actual_path == expected_path
assert actual_file == expected_file assert actual_file == expected_file
@pytest.mark.parametrize('build_config,context,dockerfile_path', [
({}, '', ''), @pytest.mark.parametrize(
({'build_subdir': '/builddir/Dockerfile'}, '', '/builddir/Dockerfile'), "build_config,context,dockerfile_path",
({'context': '/builddir'}, '/builddir', ''), [
({'context': '/builddir', 'build_subdir': '/builddir/Dockerfile'}, '/builddir', 'Dockerfile'), ({}, "", ""),
({'context': '/some_other_dir/Dockerfile', 'build_subdir': '/builddir/Dockerfile'}, '/builddir', 'Dockerfile'), ({"build_subdir": "/builddir/Dockerfile"}, "", "/builddir/Dockerfile"),
({'context': '/', 'build_subdir':'Dockerfile'}, '/', 'Dockerfile') ({"context": "/builddir"}, "/builddir", ""),
]) (
{"context": "/builddir", "build_subdir": "/builddir/Dockerfile"},
"/builddir",
"Dockerfile",
),
(
{
"context": "/some_other_dir/Dockerfile",
"build_subdir": "/builddir/Dockerfile",
},
"/builddir",
"Dockerfile",
),
({"context": "/", "build_subdir": "Dockerfile"}, "/", "Dockerfile"),
],
)
def test_extract_dockerfile_args(build_config, context, dockerfile_path): def test_extract_dockerfile_args(build_config, context, dockerfile_path):
actual_context, actual_dockerfile_path = BuildComponent.extract_dockerfile_args(build_config) actual_context, actual_dockerfile_path = BuildComponent.extract_dockerfile_args(
assert context == actual_context build_config
assert dockerfile_path == actual_dockerfile_path )
assert context == actual_context
assert dockerfile_path == actual_dockerfile_path

View file

@ -3,14 +3,17 @@ import pytest
from buildman.component.buildparse import extract_current_step from buildman.component.buildparse import extract_current_step
@pytest.mark.parametrize('input,expected_step', [ @pytest.mark.parametrize(
("", None), "input,expected_step",
("Step a :", None), [
("Step 1 :", 1), ("", None),
("Step 1 : ", 1), ("Step a :", None),
("Step 1/2 : ", 1), ("Step 1 :", 1),
("Step 2/17 : ", 2), ("Step 1 : ", 1),
("Step 4/13 : ARG somearg=foo", 4), ("Step 1/2 : ", 1),
]) ("Step 2/17 : ", 2),
("Step 4/13 : ARG somearg=foo", 4),
],
)
def test_extract_current_step(input, expected_step): def test_extract_current_step(input, expected_step):
assert extract_current_step(input) == expected_step assert extract_current_step(input) == expected_step

View file

@ -1,21 +1,25 @@
from data.database import BUILD_PHASE from data.database import BUILD_PHASE
class BuildJobResult(object): class BuildJobResult(object):
""" Build job result enum """ """ Build job result enum """
INCOMPLETE = 'incomplete'
COMPLETE = 'complete' INCOMPLETE = "incomplete"
ERROR = 'error' COMPLETE = "complete"
ERROR = "error"
class BuildServerStatus(object): class BuildServerStatus(object):
""" Build server status enum """ """ Build server status enum """
STARTING = 'starting'
RUNNING = 'running' STARTING = "starting"
SHUTDOWN = 'shutting_down' RUNNING = "running"
EXCEPTION = 'exception' SHUTDOWN = "shutting_down"
EXCEPTION = "exception"
RESULT_PHASES = { RESULT_PHASES = {
BuildJobResult.INCOMPLETE: BUILD_PHASE.INTERNAL_ERROR, BuildJobResult.INCOMPLETE: BUILD_PHASE.INTERNAL_ERROR,
BuildJobResult.COMPLETE: BUILD_PHASE.COMPLETE, BuildJobResult.COMPLETE: BUILD_PHASE.COMPLETE,
BuildJobResult.ERROR: BUILD_PHASE.ERROR, BuildJobResult.ERROR: BUILD_PHASE.ERROR,
} }

View file

@ -14,170 +14,196 @@ logger = logging.getLogger(__name__)
class BuildJobLoadException(Exception): class BuildJobLoadException(Exception):
""" Exception raised if a build job could not be instantiated for some reason. """ """ Exception raised if a build job could not be instantiated for some reason. """
pass
pass
class BuildJob(object): class BuildJob(object):
""" Represents a single in-progress build job. """ """ Represents a single in-progress build job. """
def __init__(self, job_item):
self.job_item = job_item
try: def __init__(self, job_item):
self.job_details = json.loads(job_item.body) self.job_item = job_item
self.build_notifier = BuildJobNotifier(self.build_uuid)
except ValueError:
raise BuildJobLoadException(
'Could not parse build queue item config with ID %s' % self.job_details['build_uuid']
)
@property try:
def retries_remaining(self): self.job_details = json.loads(job_item.body)
return self.job_item.retries_remaining self.build_notifier = BuildJobNotifier(self.build_uuid)
except ValueError:
raise BuildJobLoadException(
"Could not parse build queue item config with ID %s"
% self.job_details["build_uuid"]
)
def has_retries_remaining(self): @property
return self.job_item.retries_remaining > 0 def retries_remaining(self):
return self.job_item.retries_remaining
def send_notification(self, kind, error_message=None, image_id=None, manifest_digests=None): def has_retries_remaining(self):
self.build_notifier.send_notification(kind, error_message, image_id, manifest_digests) return self.job_item.retries_remaining > 0
@lru_cache(maxsize=1) def send_notification(
def _load_repo_build(self): self, kind, error_message=None, image_id=None, manifest_digests=None
with UseThenDisconnect(app.config): ):
try: self.build_notifier.send_notification(
return model.build.get_repository_build(self.build_uuid) kind, error_message, image_id, manifest_digests
except model.InvalidRepositoryBuildException: )
raise BuildJobLoadException(
'Could not load repository build with ID %s' % self.build_uuid)
@property @lru_cache(maxsize=1)
def build_uuid(self): def _load_repo_build(self):
""" Returns the unique UUID for this build job. """ with UseThenDisconnect(app.config):
return self.job_details['build_uuid'] try:
return model.build.get_repository_build(self.build_uuid)
except model.InvalidRepositoryBuildException:
raise BuildJobLoadException(
"Could not load repository build with ID %s" % self.build_uuid
)
@property @property
def namespace(self): def build_uuid(self):
""" Returns the namespace under which this build is running. """ """ Returns the unique UUID for this build job. """
return self.repo_build.repository.namespace_user.username return self.job_details["build_uuid"]
@property @property
def repo_name(self): def namespace(self):
""" Returns the name of the repository under which this build is running. """ """ Returns the namespace under which this build is running. """
return self.repo_build.repository.name return self.repo_build.repository.namespace_user.username
@property @property
def repo_build(self): def repo_name(self):
return self._load_repo_build() """ Returns the name of the repository under which this build is running. """
return self.repo_build.repository.name
def get_build_package_url(self, user_files): @property
""" Returns the URL of the build package for this build, if any or empty string if none. """ def repo_build(self):
archive_url = self.build_config.get('archive_url', None) return self._load_repo_build()
if archive_url:
return archive_url
if not self.repo_build.resource_key: def get_build_package_url(self, user_files):
return '' """ Returns the URL of the build package for this build, if any or empty string if none. """
archive_url = self.build_config.get("archive_url", None)
if archive_url:
return archive_url
return user_files.get_file_url(self.repo_build.resource_key, '127.0.0.1', requires_cors=False) if not self.repo_build.resource_key:
return ""
@property return user_files.get_file_url(
def pull_credentials(self): self.repo_build.resource_key, "127.0.0.1", requires_cors=False
""" Returns the pull credentials for this job, or None if none. """ )
return self.job_details.get('pull_credentials')
@property @property
def build_config(self): def pull_credentials(self):
try: """ Returns the pull credentials for this job, or None if none. """
return json.loads(self.repo_build.job_config) return self.job_details.get("pull_credentials")
except ValueError:
raise BuildJobLoadException(
'Could not parse repository build job config with ID %s' % self.job_details['build_uuid']
)
def determine_cached_tag(self, base_image_id=None, cache_comments=None): @property
""" Returns the tag to pull to prime the cache or None if none. """ def build_config(self):
cached_tag = self._determine_cached_tag_by_tag() try:
logger.debug('Determined cached tag %s for %s: %s', cached_tag, base_image_id, cache_comments) return json.loads(self.repo_build.job_config)
return cached_tag except ValueError:
raise BuildJobLoadException(
"Could not parse repository build job config with ID %s"
% self.job_details["build_uuid"]
)
def _determine_cached_tag_by_tag(self): def determine_cached_tag(self, base_image_id=None, cache_comments=None):
""" Determines the cached tag by looking for one of the tags being built, and seeing if it """ Returns the tag to pull to prime the cache or None if none. """
cached_tag = self._determine_cached_tag_by_tag()
logger.debug(
"Determined cached tag %s for %s: %s",
cached_tag,
base_image_id,
cache_comments,
)
return cached_tag
def _determine_cached_tag_by_tag(self):
""" Determines the cached tag by looking for one of the tags being built, and seeing if it
exists in the repository. This is a fallback for when no comment information is available. exists in the repository. This is a fallback for when no comment information is available.
""" """
with UseThenDisconnect(app.config): with UseThenDisconnect(app.config):
tags = self.build_config.get('docker_tags', ['latest']) tags = self.build_config.get("docker_tags", ["latest"])
repository = RepositoryReference.for_repo_obj(self.repo_build.repository) repository = RepositoryReference.for_repo_obj(self.repo_build.repository)
matching_tag = registry_model.find_matching_tag(repository, tags) matching_tag = registry_model.find_matching_tag(repository, tags)
if matching_tag is not None: if matching_tag is not None:
return matching_tag.name return matching_tag.name
most_recent_tag = registry_model.get_most_recent_tag(repository) most_recent_tag = registry_model.get_most_recent_tag(repository)
if most_recent_tag is not None: if most_recent_tag is not None:
return most_recent_tag.name return most_recent_tag.name
return None return None
class BuildJobNotifier(object): class BuildJobNotifier(object):
""" A class for sending notifications to a job that only relies on the build_uuid """ """ A class for sending notifications to a job that only relies on the build_uuid """
def __init__(self, build_uuid): def __init__(self, build_uuid):
self.build_uuid = build_uuid self.build_uuid = build_uuid
@property @property
def repo_build(self): def repo_build(self):
return self._load_repo_build() return self._load_repo_build()
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def _load_repo_build(self): def _load_repo_build(self):
try: try:
return model.build.get_repository_build(self.build_uuid) return model.build.get_repository_build(self.build_uuid)
except model.InvalidRepositoryBuildException: except model.InvalidRepositoryBuildException:
raise BuildJobLoadException( raise BuildJobLoadException(
'Could not load repository build with ID %s' % self.build_uuid) "Could not load repository build with ID %s" % self.build_uuid
)
@property @property
def build_config(self): def build_config(self):
try: try:
return json.loads(self.repo_build.job_config) return json.loads(self.repo_build.job_config)
except ValueError: except ValueError:
raise BuildJobLoadException( raise BuildJobLoadException(
'Could not parse repository build job config with ID %s' % self.repo_build.uuid "Could not parse repository build job config with ID %s"
) % self.repo_build.uuid
)
def send_notification(self, kind, error_message=None, image_id=None, manifest_digests=None): def send_notification(
with UseThenDisconnect(app.config): self, kind, error_message=None, image_id=None, manifest_digests=None
tags = self.build_config.get('docker_tags', ['latest']) ):
trigger = self.repo_build.trigger with UseThenDisconnect(app.config):
if trigger is not None and trigger.id is not None: tags = self.build_config.get("docker_tags", ["latest"])
trigger_kind = trigger.service.name trigger = self.repo_build.trigger
else: if trigger is not None and trigger.id is not None:
trigger_kind = None trigger_kind = trigger.service.name
else:
trigger_kind = None
event_data = { event_data = {
'build_id': self.repo_build.uuid, "build_id": self.repo_build.uuid,
'build_name': self.repo_build.display_name, "build_name": self.repo_build.display_name,
'docker_tags': tags, "docker_tags": tags,
'trigger_id': trigger.uuid if trigger is not None else None, "trigger_id": trigger.uuid if trigger is not None else None,
'trigger_kind': trigger_kind, "trigger_kind": trigger_kind,
'trigger_metadata': self.build_config.get('trigger_metadata', {}) "trigger_metadata": self.build_config.get("trigger_metadata", {}),
} }
if image_id is not None: if image_id is not None:
event_data['image_id'] = image_id event_data["image_id"] = image_id
if manifest_digests: if manifest_digests:
event_data['manifest_digests'] = manifest_digests event_data["manifest_digests"] = manifest_digests
if error_message is not None: if error_message is not None:
event_data['error_message'] = error_message event_data["error_message"] = error_message
# TODO: remove when more endpoints have been converted to using # TODO: remove when more endpoints have been converted to using
# interfaces # interfaces
repo = AttrDict({ repo = AttrDict(
'namespace_name': self.repo_build.repository.namespace_user.username, {
'name': self.repo_build.repository.name, "namespace_name": self.repo_build.repository.namespace_user.username,
}) "name": self.repo_build.repository.name,
spawn_notification(repo, kind, event_data, }
subpage='build/%s' % self.repo_build.uuid, )
pathargs=['build', self.repo_build.uuid]) spawn_notification(
repo,
kind,
event_data,
subpage="build/%s" % self.repo_build.uuid,
pathargs=["build", self.repo_build.uuid],
)

View file

@ -13,76 +13,94 @@ logger = logging.getLogger(__name__)
class StatusHandler(object): class StatusHandler(object):
""" Context wrapper for writing status to build logs. """ """ Context wrapper for writing status to build logs. """
def __init__(self, build_logs, repository_build_uuid): def __init__(self, build_logs, repository_build_uuid):
self._current_phase = None self._current_phase = None
self._current_command = None self._current_command = None
self._uuid = repository_build_uuid self._uuid = repository_build_uuid
self._build_logs = AsyncWrapper(build_logs) self._build_logs = AsyncWrapper(build_logs)
self._sync_build_logs = build_logs self._sync_build_logs = build_logs
self._build_model = AsyncWrapper(model.build) self._build_model = AsyncWrapper(model.build)
self._status = { self._status = {
'total_commands': 0, "total_commands": 0,
'current_command': None, "current_command": None,
'push_completion': 0.0, "push_completion": 0.0,
'pull_completion': 0.0, "pull_completion": 0.0,
} }
# Write the initial status. # Write the initial status.
self.__exit__(None, None, None) self.__exit__(None, None, None)
@coroutine @coroutine
def _append_log_message(self, log_message, log_type=None, log_data=None): def _append_log_message(self, log_message, log_type=None, log_data=None):
log_data = log_data or {} log_data = log_data or {}
log_data['datetime'] = str(datetime.datetime.now()) log_data["datetime"] = str(datetime.datetime.now())
try: try:
yield From(self._build_logs.append_log_message(self._uuid, log_message, log_type, log_data)) yield From(
except RedisError: self._build_logs.append_log_message(
logger.exception('Could not save build log for build %s: %s', self._uuid, log_message) self._uuid, log_message, log_type, log_data
)
)
except RedisError:
logger.exception(
"Could not save build log for build %s: %s", self._uuid, log_message
)
@coroutine @coroutine
def append_log(self, log_message, extra_data=None): def append_log(self, log_message, extra_data=None):
if log_message is None: if log_message is None:
return return
yield From(self._append_log_message(log_message, log_data=extra_data)) yield From(self._append_log_message(log_message, log_data=extra_data))
@coroutine @coroutine
def set_command(self, command, extra_data=None): def set_command(self, command, extra_data=None):
if self._current_command == command: if self._current_command == command:
raise Return() raise Return()
self._current_command = command self._current_command = command
yield From(self._append_log_message(command, self._build_logs.COMMAND, extra_data)) yield From(
self._append_log_message(command, self._build_logs.COMMAND, extra_data)
)
@coroutine @coroutine
def set_error(self, error_message, extra_data=None, internal_error=False, requeued=False): def set_error(
error_phase = BUILD_PHASE.INTERNAL_ERROR if internal_error and requeued else BUILD_PHASE.ERROR self, error_message, extra_data=None, internal_error=False, requeued=False
yield From(self.set_phase(error_phase)) ):
error_phase = (
BUILD_PHASE.INTERNAL_ERROR
if internal_error and requeued
else BUILD_PHASE.ERROR
)
yield From(self.set_phase(error_phase))
extra_data = extra_data or {} extra_data = extra_data or {}
extra_data['internal_error'] = internal_error extra_data["internal_error"] = internal_error
yield From(self._append_log_message(error_message, self._build_logs.ERROR, extra_data)) yield From(
self._append_log_message(error_message, self._build_logs.ERROR, extra_data)
)
@coroutine @coroutine
def set_phase(self, phase, extra_data=None): def set_phase(self, phase, extra_data=None):
if phase == self._current_phase: if phase == self._current_phase:
raise Return(False) raise Return(False)
self._current_phase = phase self._current_phase = phase
yield From(self._append_log_message(phase, self._build_logs.PHASE, extra_data)) yield From(self._append_log_message(phase, self._build_logs.PHASE, extra_data))
# Update the repository build with the new phase # Update the repository build with the new phase
raise Return(self._build_model.update_phase_then_close(self._uuid, phase)) raise Return(self._build_model.update_phase_then_close(self._uuid, phase))
def __enter__(self): def __enter__(self):
return self._status return self._status
def __exit__(self, exc_type, value, traceback): def __exit__(self, exc_type, value, traceback):
try: try:
self._sync_build_logs.set_status(self._uuid, self._status) self._sync_build_logs.set_status(self._uuid, self._status)
except RedisError: except RedisError:
logger.exception('Could not set status of build %s to %s', self._uuid, self._status) logger.exception(
"Could not set status of build %s to %s", self._uuid, self._status
)

View file

@ -1,119 +1,99 @@
class WorkerError(object): class WorkerError(object):
""" Helper class which represents errors raised by a build worker. """ """ Helper class which represents errors raised by a build worker. """
def __init__(self, error_code, base_message=None):
self._error_code = error_code
self._base_message = base_message
self._error_handlers = { def __init__(self, error_code, base_message=None):
'io.quay.builder.buildpackissue': { self._error_code = error_code
'message': 'Could not load build package', self._base_message = base_message
'is_internal': True,
},
'io.quay.builder.gitfailure': { self._error_handlers = {
'message': 'Could not clone git repository', "io.quay.builder.buildpackissue": {
'show_base_error': True, "message": "Could not load build package",
}, "is_internal": True,
},
"io.quay.builder.gitfailure": {
"message": "Could not clone git repository",
"show_base_error": True,
},
"io.quay.builder.gitcheckout": {
"message": "Could not checkout git ref. If you force pushed recently, "
+ "the commit may be missing.",
"show_base_error": True,
},
"io.quay.builder.cannotextractbuildpack": {
"message": "Could not extract the contents of the build package"
},
"io.quay.builder.cannotpullforcache": {
"message": "Could not pull cached image",
"is_internal": True,
},
"io.quay.builder.dockerfileissue": {
"message": "Could not find or parse Dockerfile",
"show_base_error": True,
},
"io.quay.builder.cannotpullbaseimage": {
"message": "Could not pull base image",
"show_base_error": True,
},
"io.quay.builder.internalerror": {
"message": "An internal error occurred while building. Please submit a ticket.",
"is_internal": True,
},
"io.quay.builder.buildrunerror": {
"message": "Could not start the build process",
"is_internal": True,
},
"io.quay.builder.builderror": {
"message": "A build step failed",
"show_base_error": True,
},
"io.quay.builder.tagissue": {
"message": "Could not tag built image",
"is_internal": True,
},
"io.quay.builder.pushissue": {
"message": "Could not push built image",
"show_base_error": True,
"is_internal": True,
},
"io.quay.builder.dockerconnecterror": {
"message": "Could not connect to Docker daemon",
"is_internal": True,
},
"io.quay.builder.missingorinvalidargument": {
"message": "Missing required arguments for builder",
"is_internal": True,
},
"io.quay.builder.cachelookupissue": {
"message": "Error checking for a cached tag",
"is_internal": True,
},
"io.quay.builder.errorduringphasetransition": {
"message": "Error during phase transition. If this problem persists "
+ "please contact customer support.",
"is_internal": True,
},
"io.quay.builder.clientrejectedtransition": {
"message": "Build can not be finished due to user cancellation."
},
}
'io.quay.builder.gitcheckout': { def is_internal_error(self):
'message': 'Could not checkout git ref. If you force pushed recently, ' + handler = self._error_handlers.get(self._error_code)
'the commit may be missing.', return handler.get("is_internal", False) if handler else True
'show_base_error': True,
},
'io.quay.builder.cannotextractbuildpack': { def public_message(self):
'message': 'Could not extract the contents of the build package' handler = self._error_handlers.get(self._error_code)
}, if not handler:
return "An unknown error occurred"
'io.quay.builder.cannotpullforcache': { message = handler["message"]
'message': 'Could not pull cached image', if handler.get("show_base_error", False) and self._base_message:
'is_internal': True message = message + ": " + self._base_message
},
'io.quay.builder.dockerfileissue': { return message
'message': 'Could not find or parse Dockerfile',
'show_base_error': True
},
'io.quay.builder.cannotpullbaseimage': { def extra_data(self):
'message': 'Could not pull base image', if self._base_message:
'show_base_error': True return {"base_error": self._base_message, "error_code": self._error_code}
},
'io.quay.builder.internalerror': { return {"error_code": self._error_code}
'message': 'An internal error occurred while building. Please submit a ticket.',
'is_internal': True
},
'io.quay.builder.buildrunerror': {
'message': 'Could not start the build process',
'is_internal': True
},
'io.quay.builder.builderror': {
'message': 'A build step failed',
'show_base_error': True
},
'io.quay.builder.tagissue': {
'message': 'Could not tag built image',
'is_internal': True
},
'io.quay.builder.pushissue': {
'message': 'Could not push built image',
'show_base_error': True,
'is_internal': True
},
'io.quay.builder.dockerconnecterror': {
'message': 'Could not connect to Docker daemon',
'is_internal': True
},
'io.quay.builder.missingorinvalidargument': {
'message': 'Missing required arguments for builder',
'is_internal': True
},
'io.quay.builder.cachelookupissue': {
'message': 'Error checking for a cached tag',
'is_internal': True
},
'io.quay.builder.errorduringphasetransition': {
'message': 'Error during phase transition. If this problem persists ' +
'please contact customer support.',
'is_internal': True
},
'io.quay.builder.clientrejectedtransition': {
'message': 'Build can not be finished due to user cancellation.',
}
}
def is_internal_error(self):
handler = self._error_handlers.get(self._error_code)
return handler.get('is_internal', False) if handler else True
def public_message(self):
handler = self._error_handlers.get(self._error_code)
if not handler:
return 'An unknown error occurred'
message = handler['message']
if handler.get('show_base_error', False) and self._base_message:
message = message + ': ' + self._base_message
return message
def extra_data(self):
if self._base_message:
return {
'base_error': self._base_message,
'error_code': self._error_code
}
return {
'error_code': self._error_code
}

View file

@ -1,71 +1,80 @@
from trollius import coroutine from trollius import coroutine
class BaseManager(object): class BaseManager(object):
""" Base for all worker managers. """ """ Base for all worker managers. """
def __init__(self, register_component, unregister_component, job_heartbeat_callback,
job_complete_callback, manager_hostname, heartbeat_period_sec):
self.register_component = register_component
self.unregister_component = unregister_component
self.job_heartbeat_callback = job_heartbeat_callback
self.job_complete_callback = job_complete_callback
self.manager_hostname = manager_hostname
self.heartbeat_period_sec = heartbeat_period_sec
@coroutine def __init__(
def job_heartbeat(self, build_job): self,
""" Method invoked to tell the manager that a job is still running. This method will be called register_component,
unregister_component,
job_heartbeat_callback,
job_complete_callback,
manager_hostname,
heartbeat_period_sec,
):
self.register_component = register_component
self.unregister_component = unregister_component
self.job_heartbeat_callback = job_heartbeat_callback
self.job_complete_callback = job_complete_callback
self.manager_hostname = manager_hostname
self.heartbeat_period_sec = heartbeat_period_sec
@coroutine
def job_heartbeat(self, build_job):
""" Method invoked to tell the manager that a job is still running. This method will be called
every few minutes. """ every few minutes. """
self.job_heartbeat_callback(build_job) self.job_heartbeat_callback(build_job)
def overall_setup_time(self): def overall_setup_time(self):
""" Returns the number of seconds that the build system should wait before allowing the job """ Returns the number of seconds that the build system should wait before allowing the job
to be picked up again after called 'schedule'. to be picked up again after called 'schedule'.
""" """
raise NotImplementedError raise NotImplementedError
def shutdown(self): def shutdown(self):
""" Indicates that the build controller server is in a shutdown state and that no new jobs """ Indicates that the build controller server is in a shutdown state and that no new jobs
or workers should be performed. Existing workers should be cleaned up once their jobs or workers should be performed. Existing workers should be cleaned up once their jobs
have completed have completed
""" """
raise NotImplementedError raise NotImplementedError
@coroutine @coroutine
def schedule(self, build_job): def schedule(self, build_job):
""" Schedules a queue item to be built. Returns a 2-tuple with (True, None) if the item was """ Schedules a queue item to be built. Returns a 2-tuple with (True, None) if the item was
properly scheduled and (False, a retry timeout in seconds) if all workers are busy or an properly scheduled and (False, a retry timeout in seconds) if all workers are busy or an
error occurs. error occurs.
""" """
raise NotImplementedError raise NotImplementedError
def initialize(self, manager_config): def initialize(self, manager_config):
""" Runs any initialization code for the manager. Called once the server is in a ready state. """ Runs any initialization code for the manager. Called once the server is in a ready state.
""" """
raise NotImplementedError raise NotImplementedError
@coroutine @coroutine
def build_component_ready(self, build_component): def build_component_ready(self, build_component):
""" Method invoked whenever a build component announces itself as ready. """ Method invoked whenever a build component announces itself as ready.
""" """
raise NotImplementedError raise NotImplementedError
def build_component_disposed(self, build_component, timed_out): def build_component_disposed(self, build_component, timed_out):
""" Method invoked whenever a build component has been disposed. The timed_out boolean indicates """ Method invoked whenever a build component has been disposed. The timed_out boolean indicates
whether the component's heartbeat timed out. whether the component's heartbeat timed out.
""" """
raise NotImplementedError raise NotImplementedError
@coroutine @coroutine
def job_completed(self, build_job, job_status, build_component): def job_completed(self, build_job, job_status, build_component):
""" Method invoked once a job_item has completed, in some manner. The job_status will be """ Method invoked once a job_item has completed, in some manner. The job_status will be
one of: incomplete, error, complete. Implementations of this method should call coroutine one of: incomplete, error, complete. Implementations of this method should call coroutine
self.job_complete_callback with a status of Incomplete if they wish for the job to be self.job_complete_callback with a status of Incomplete if they wish for the job to be
automatically requeued. automatically requeued.
""" """
raise NotImplementedError raise NotImplementedError
def num_workers(self): def num_workers(self):
""" Returns the number of active build workers currently registered. This includes those """ Returns the number of active build workers currently registered. This includes those
that are currently busy and awaiting more work. that are currently busy and awaiting more work.
""" """
raise NotImplementedError raise NotImplementedError

View file

@ -5,23 +5,23 @@ from buildman.manager.noop_canceller import NoopCanceller
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CANCELLERS = {'ephemeral': OrchestratorCanceller} CANCELLERS = {"ephemeral": OrchestratorCanceller}
class BuildCanceller(object): class BuildCanceller(object):
""" A class to manage cancelling a build """ """ A class to manage cancelling a build """
def __init__(self, app=None): def __init__(self, app=None):
self.build_manager_config = app.config.get('BUILD_MANAGER') self.build_manager_config = app.config.get("BUILD_MANAGER")
if app is None or self.build_manager_config is None: if app is None or self.build_manager_config is None:
self.handler = NoopCanceller() self.handler = NoopCanceller()
else: else:
self.handler = None self.handler = None
def try_cancel_build(self, uuid): def try_cancel_build(self, uuid):
""" A method to kill a running build """ """ A method to kill a running build """
if self.handler is None: if self.handler is None:
canceller = CANCELLERS.get(self.build_manager_config[0], NoopCanceller) canceller = CANCELLERS.get(self.build_manager_config[0], NoopCanceller)
self.handler = canceller(self.build_manager_config[1]) self.handler = canceller(self.build_manager_config[1])
return self.handler.try_cancel_build(uuid) return self.handler.try_cancel_build(uuid)

View file

@ -7,86 +7,89 @@ from buildman.manager.basemanager import BaseManager
from trollius import From, Return, coroutine from trollius import From, Return, coroutine
REGISTRATION_REALM = 'registration' REGISTRATION_REALM = "registration"
RETRY_TIMEOUT = 5 RETRY_TIMEOUT = 5
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DynamicRegistrationComponent(BaseComponent): class DynamicRegistrationComponent(BaseComponent):
""" Component session that handles dynamic registration of the builder components. """ """ Component session that handles dynamic registration of the builder components. """
def onConnect(self): def onConnect(self):
self.join(REGISTRATION_REALM) self.join(REGISTRATION_REALM)
def onJoin(self, details): def onJoin(self, details):
logger.debug('Registering registration method') logger.debug("Registering registration method")
yield From(self.register(self._worker_register, u'io.quay.buildworker.register')) yield From(
self.register(self._worker_register, u"io.quay.buildworker.register")
)
def _worker_register(self): def _worker_register(self):
realm = self.parent_manager.add_build_component() realm = self.parent_manager.add_build_component()
logger.debug('Registering new build component+worker with realm %s', realm) logger.debug("Registering new build component+worker with realm %s", realm)
return realm return realm
def kind(self): def kind(self):
return 'registration' return "registration"
class EnterpriseManager(BaseManager): class EnterpriseManager(BaseManager):
""" Build manager implementation for the Enterprise Registry. """ """ Build manager implementation for the Enterprise Registry. """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.ready_components = set() self.ready_components = set()
self.all_components = set() self.all_components = set()
self.shutting_down = False self.shutting_down = False
super(EnterpriseManager, self).__init__(*args, **kwargs) super(EnterpriseManager, self).__init__(*args, **kwargs)
def initialize(self, manager_config): def initialize(self, manager_config):
# Add a component which is used by build workers for dynamic registration. Unlike # Add a component which is used by build workers for dynamic registration. Unlike
# production, build workers in enterprise are long-lived and register dynamically. # production, build workers in enterprise are long-lived and register dynamically.
self.register_component(REGISTRATION_REALM, DynamicRegistrationComponent) self.register_component(REGISTRATION_REALM, DynamicRegistrationComponent)
def overall_setup_time(self): def overall_setup_time(self):
# Builders are already registered, so the setup time should be essentially instant. We therefore # Builders are already registered, so the setup time should be essentially instant. We therefore
# only return a minute here. # only return a minute here.
return 60 return 60
def add_build_component(self): def add_build_component(self):
""" Adds a new build component for an Enterprise Registry. """ """ Adds a new build component for an Enterprise Registry. """
# Generate a new unique realm ID for the build worker. # Generate a new unique realm ID for the build worker.
realm = str(uuid.uuid4()) realm = str(uuid.uuid4())
new_component = self.register_component(realm, BuildComponent, token="") new_component = self.register_component(realm, BuildComponent, token="")
self.all_components.add(new_component) self.all_components.add(new_component)
return realm return realm
@coroutine @coroutine
def schedule(self, build_job): def schedule(self, build_job):
""" Schedules a build for an Enterprise Registry. """ """ Schedules a build for an Enterprise Registry. """
if self.shutting_down or not self.ready_components: if self.shutting_down or not self.ready_components:
raise Return(False, RETRY_TIMEOUT) raise Return(False, RETRY_TIMEOUT)
component = self.ready_components.pop() component = self.ready_components.pop()
yield From(component.start_build(build_job)) yield From(component.start_build(build_job))
raise Return(True, None) raise Return(True, None)
@coroutine @coroutine
def build_component_ready(self, build_component): def build_component_ready(self, build_component):
self.ready_components.add(build_component) self.ready_components.add(build_component)
def shutdown(self): def shutdown(self):
self.shutting_down = True self.shutting_down = True
@coroutine @coroutine
def job_completed(self, build_job, job_status, build_component): def job_completed(self, build_job, job_status, build_component):
yield From(self.job_complete_callback(build_job, job_status)) yield From(self.job_complete_callback(build_job, job_status))
def build_component_disposed(self, build_component, timed_out): def build_component_disposed(self, build_component, timed_out):
self.all_components.remove(build_component) self.all_components.remove(build_component)
if build_component in self.ready_components: if build_component in self.ready_components:
self.ready_components.remove(build_component) self.ready_components.remove(build_component)
self.unregister_component(build_component) self.unregister_component(build_component)
def num_workers(self): def num_workers(self):
return len(self.all_components) return len(self.all_components)

View file

@ -5,33 +5,36 @@ logger = logging.getLogger(__name__)
class EtcdCanceller(object): class EtcdCanceller(object):
""" A class that sends a message to etcd to cancel a build """ """ A class that sends a message to etcd to cancel a build """
def __init__(self, config): def __init__(self, config):
etcd_host = config.get('ETCD_HOST', '127.0.0.1') etcd_host = config.get("ETCD_HOST", "127.0.0.1")
etcd_port = config.get('ETCD_PORT', 2379) etcd_port = config.get("ETCD_PORT", 2379)
etcd_ca_cert = config.get('ETCD_CA_CERT', None) etcd_ca_cert = config.get("ETCD_CA_CERT", None)
etcd_auth = config.get('ETCD_CERT_AND_KEY', None) etcd_auth = config.get("ETCD_CERT_AND_KEY", None)
if etcd_auth is not None: if etcd_auth is not None:
etcd_auth = tuple(etcd_auth) etcd_auth = tuple(etcd_auth)
etcd_protocol = 'http' if etcd_auth is None else 'https' etcd_protocol = "http" if etcd_auth is None else "https"
logger.debug('Connecting to etcd on %s:%s', etcd_host, etcd_port) logger.debug("Connecting to etcd on %s:%s", etcd_host, etcd_port)
self._cancel_prefix = config.get('ETCD_CANCEL_PREFIX', 'cancel/') self._cancel_prefix = config.get("ETCD_CANCEL_PREFIX", "cancel/")
self._etcd_client = etcd.Client( self._etcd_client = etcd.Client(
host=etcd_host, host=etcd_host,
port=etcd_port, port=etcd_port,
cert=etcd_auth, cert=etcd_auth,
ca_cert=etcd_ca_cert, ca_cert=etcd_ca_cert,
protocol=etcd_protocol, protocol=etcd_protocol,
read_timeout=5) read_timeout=5,
)
def try_cancel_build(self, build_uuid): def try_cancel_build(self, build_uuid):
""" Writes etcd message to cancel build_uuid. """ """ Writes etcd message to cancel build_uuid. """
logger.info("Cancelling build %s".format(build_uuid)) logger.info("Cancelling build %s".format(build_uuid))
try: try:
self._etcd_client.write("{}{}".format(self._cancel_prefix, build_uuid), build_uuid, ttl=60) self._etcd_client.write(
return True "{}{}".format(self._cancel_prefix, build_uuid), build_uuid, ttl=60
except etcd.EtcdException: )
logger.exception("Failed to write to etcd client %s", build_uuid) return True
return False except etcd.EtcdException:
logger.exception("Failed to write to etcd client %s", build_uuid)
return False

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,9 @@
class NoopCanceller(object): class NoopCanceller(object):
""" A class that can not cancel a build """ """ A class that can not cancel a build """
def __init__(self, config=None):
pass
def try_cancel_build(self, uuid): def __init__(self, config=None):
""" Does nothing and fails to cancel build. """ pass
return False
def try_cancel_build(self, uuid):
""" Does nothing and fails to cancel build. """
return False

View file

@ -7,20 +7,23 @@ from util import slash_join
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CANCEL_PREFIX = 'cancel/' CANCEL_PREFIX = "cancel/"
class OrchestratorCanceller(object): class OrchestratorCanceller(object):
""" An asynchronous way to cancel a build with any Orchestrator. """ """ An asynchronous way to cancel a build with any Orchestrator. """
def __init__(self, config):
self._orchestrator = orchestrator_from_config(config, canceller_only=True)
def try_cancel_build(self, build_uuid): def __init__(self, config):
logger.info('Cancelling build %s', build_uuid) self._orchestrator = orchestrator_from_config(config, canceller_only=True)
cancel_key = slash_join(CANCEL_PREFIX, build_uuid)
try: def try_cancel_build(self, build_uuid):
self._orchestrator.set_key_sync(cancel_key, build_uuid, expiration=60) logger.info("Cancelling build %s", build_uuid)
return True cancel_key = slash_join(CANCEL_PREFIX, build_uuid)
except OrchestratorError: try:
logger.exception('Failed to write cancel action to redis with uuid %s', build_uuid) self._orchestrator.set_key_sync(cancel_key, build_uuid, expiration=60)
return False return True
except OrchestratorError:
logger.exception(
"Failed to write cancel action to redis with uuid %s", build_uuid
)
return False

File diff suppressed because it is too large Load diff

View file

@ -2,4 +2,3 @@ import buildtrigger.bitbuckethandler
import buildtrigger.customhandler import buildtrigger.customhandler
import buildtrigger.githubhandler import buildtrigger.githubhandler
import buildtrigger.gitlabhandler import buildtrigger.gitlabhandler

View file

@ -9,359 +9,360 @@ from data import model
from buildtrigger.triggerutil import get_trigger_config, InvalidServiceException from buildtrigger.triggerutil import get_trigger_config, InvalidServiceException
NAMESPACES_SCHEMA = { NAMESPACES_SCHEMA = {
'type': 'array', "type": "array",
'items': { "items": {
'type': 'object', "type": "object",
'properties': { "properties": {
'personal': { "personal": {
'type': 'boolean', "type": "boolean",
'description': 'True if the namespace is the user\'s personal namespace', "description": "True if the namespace is the user's personal namespace",
}, },
'score': { "score": {
'type': 'number', "type": "number",
'description': 'Score of the relevance of the namespace', "description": "Score of the relevance of the namespace",
}, },
'avatar_url': { "avatar_url": {
'type': ['string', 'null'], "type": ["string", "null"],
'description': 'URL of the avatar for this namespace', "description": "URL of the avatar for this namespace",
}, },
'url': { "url": {
'type': 'string', "type": "string",
'description': 'URL of the website to view the namespace', "description": "URL of the website to view the namespace",
}, },
'id': { "id": {
'type': 'string', "type": "string",
'description': 'Trigger-internal ID of the namespace', "description": "Trigger-internal ID of the namespace",
}, },
'title': { "title": {
'type': 'string', "type": "string",
'description': 'Human-readable title of the namespace', "description": "Human-readable title of the namespace",
}, },
},
"required": ["personal", "score", "avatar_url", "id", "title"],
}, },
'required': ['personal', 'score', 'avatar_url', 'id', 'title'],
},
} }
BUILD_SOURCES_SCHEMA = { BUILD_SOURCES_SCHEMA = {
'type': 'array', "type": "array",
'items': { "items": {
'type': 'object', "type": "object",
'properties': { "properties": {
'name': { "name": {
'type': 'string', "type": "string",
'description': 'The name of the repository, without its namespace', "description": "The name of the repository, without its namespace",
}, },
'full_name': { "full_name": {
'type': 'string', "type": "string",
'description': 'The name of the repository, with its namespace', "description": "The name of the repository, with its namespace",
}, },
'description': { "description": {
'type': 'string', "type": "string",
'description': 'The description of the repository. May be an empty string', "description": "The description of the repository. May be an empty string",
}, },
'last_updated': { "last_updated": {
'type': 'number', "type": "number",
'description': 'The date/time when the repository was last updated, since epoch in UTC', "description": "The date/time when the repository was last updated, since epoch in UTC",
}, },
'url': { "url": {
'type': 'string', "type": "string",
'description': 'The URL at which to view the repository in the browser', "description": "The URL at which to view the repository in the browser",
}, },
'has_admin_permissions': { "has_admin_permissions": {
'type': 'boolean', "type": "boolean",
'description': 'True if the current user has admin permissions on the repository', "description": "True if the current user has admin permissions on the repository",
}, },
'private': { "private": {
'type': 'boolean', "type": "boolean",
'description': 'True if the repository is private', "description": "True if the repository is private",
}, },
},
"required": [
"name",
"full_name",
"description",
"last_updated",
"has_admin_permissions",
"private",
],
}, },
'required': ['name', 'full_name', 'description', 'last_updated',
'has_admin_permissions', 'private'],
},
} }
METADATA_SCHEMA = { METADATA_SCHEMA = {
'type': 'object', "type": "object",
'properties': { "properties": {
'commit': { "commit": {
'type': 'string', "type": "string",
'description': 'first 7 characters of the SHA-1 identifier for a git commit', "description": "first 7 characters of the SHA-1 identifier for a git commit",
'pattern': '^([A-Fa-f0-9]{7,})$', "pattern": "^([A-Fa-f0-9]{7,})$",
},
'git_url': {
'type': 'string',
'description': 'The GIT url to use for the checkout',
},
'ref': {
'type': 'string',
'description': 'git reference for a git commit',
'pattern': r'^refs\/(heads|tags|remotes)\/(.+)$',
},
'default_branch': {
'type': 'string',
'description': 'default branch of the git repository',
},
'commit_info': {
'type': 'object',
'description': 'metadata about a git commit',
'properties': {
'url': {
'type': 'string',
'description': 'URL to view a git commit',
}, },
'message': { "git_url": {
'type': 'string', "type": "string",
'description': 'git commit message', "description": "The GIT url to use for the checkout",
}, },
'date': { "ref": {
'type': 'string', "type": "string",
'description': 'timestamp for a git commit' "description": "git reference for a git commit",
"pattern": r"^refs\/(heads|tags|remotes)\/(.+)$",
}, },
'author': { "default_branch": {
'type': 'object', "type": "string",
'description': 'metadata about the author of a git commit', "description": "default branch of the git repository",
'properties': {
'username': {
'type': 'string',
'description': 'username of the author',
},
'url': {
'type': 'string',
'description': 'URL to view the profile of the author',
},
'avatar_url': {
'type': 'string',
'description': 'URL to view the avatar of the author',
},
},
'required': ['username'],
}, },
'committer': { "commit_info": {
'type': 'object', "type": "object",
'description': 'metadata about the committer of a git commit', "description": "metadata about a git commit",
'properties': { "properties": {
'username': { "url": {"type": "string", "description": "URL to view a git commit"},
'type': 'string', "message": {"type": "string", "description": "git commit message"},
'description': 'username of the committer', "date": {"type": "string", "description": "timestamp for a git commit"},
"author": {
"type": "object",
"description": "metadata about the author of a git commit",
"properties": {
"username": {
"type": "string",
"description": "username of the author",
},
"url": {
"type": "string",
"description": "URL to view the profile of the author",
},
"avatar_url": {
"type": "string",
"description": "URL to view the avatar of the author",
},
},
"required": ["username"],
},
"committer": {
"type": "object",
"description": "metadata about the committer of a git commit",
"properties": {
"username": {
"type": "string",
"description": "username of the committer",
},
"url": {
"type": "string",
"description": "URL to view the profile of the committer",
},
"avatar_url": {
"type": "string",
"description": "URL to view the avatar of the committer",
},
},
"required": ["username"],
},
}, },
'url': { "required": ["message"],
'type': 'string',
'description': 'URL to view the profile of the committer',
},
'avatar_url': {
'type': 'string',
'description': 'URL to view the avatar of the committer',
},
},
'required': ['username'],
}, },
},
'required': ['message'],
}, },
}, "required": ["commit", "git_url"],
'required': ['commit', 'git_url'],
} }
@add_metaclass(ABCMeta) @add_metaclass(ABCMeta)
class BuildTriggerHandler(object): class BuildTriggerHandler(object):
def __init__(self, trigger, override_config=None): def __init__(self, trigger, override_config=None):
self.trigger = trigger self.trigger = trigger
self.config = override_config or get_trigger_config(trigger) self.config = override_config or get_trigger_config(trigger)
@property @property
def auth_token(self): def auth_token(self):
""" Returns the auth token for the trigger. """ """ Returns the auth token for the trigger. """
# NOTE: This check is for testing. # NOTE: This check is for testing.
if isinstance(self.trigger.auth_token, str): if isinstance(self.trigger.auth_token, str):
return self.trigger.auth_token return self.trigger.auth_token
# TODO(remove-unenc): Remove legacy field. # TODO(remove-unenc): Remove legacy field.
if self.trigger.secure_auth_token is not None: if self.trigger.secure_auth_token is not None:
return self.trigger.secure_auth_token.decrypt() return self.trigger.secure_auth_token.decrypt()
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS): if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
return self.trigger.auth_token return self.trigger.auth_token
return None return None
@abstractmethod @abstractmethod
def load_dockerfile_contents(self): def load_dockerfile_contents(self):
""" """
Loads the Dockerfile found for the trigger's config and returns them or None if none could Loads the Dockerfile found for the trigger's config and returns them or None if none could
be found/loaded. be found/loaded.
""" """
pass pass
@abstractmethod @abstractmethod
def list_build_source_namespaces(self): def list_build_source_namespaces(self):
""" """
Take the auth information for the specific trigger type and load the Take the auth information for the specific trigger type and load the
list of namespaces that can contain build sources. list of namespaces that can contain build sources.
""" """
pass pass
@abstractmethod @abstractmethod
def list_build_sources_for_namespace(self, namespace): def list_build_sources_for_namespace(self, namespace):
""" """
Take the auth information for the specific trigger type and load the Take the auth information for the specific trigger type and load the
list of repositories under the given namespace. list of repositories under the given namespace.
""" """
pass pass
@abstractmethod @abstractmethod
def list_build_subdirs(self): def list_build_subdirs(self):
""" """
Take the auth information and the specified config so far and list all of Take the auth information and the specified config so far and list all of
the possible subdirs containing dockerfiles. the possible subdirs containing dockerfiles.
""" """
pass pass
@abstractmethod @abstractmethod
def handle_trigger_request(self, request): def handle_trigger_request(self, request):
""" """
Transform the incoming request data into a set of actions. Returns a PreparedBuild. Transform the incoming request data into a set of actions. Returns a PreparedBuild.
""" """
pass pass
@abstractmethod @abstractmethod
def is_active(self): def is_active(self):
""" """
Returns True if the current build trigger is active. Inactive means further Returns True if the current build trigger is active. Inactive means further
setup is needed. setup is needed.
""" """
pass pass
@abstractmethod @abstractmethod
def activate(self, standard_webhook_url): def activate(self, standard_webhook_url):
""" """
Activates the trigger for the service, with the given new configuration. Activates the trigger for the service, with the given new configuration.
Returns new public and private config that should be stored if successful. Returns new public and private config that should be stored if successful.
""" """
pass pass
@abstractmethod @abstractmethod
def deactivate(self): def deactivate(self):
""" """
Deactivates the trigger for the service, removing any hooks installed in Deactivates the trigger for the service, removing any hooks installed in
the remote service. Returns the new config that should be stored if this the remote service. Returns the new config that should be stored if this
trigger is going to be re-activated. trigger is going to be re-activated.
""" """
pass pass
@abstractmethod @abstractmethod
def manual_start(self, run_parameters=None): def manual_start(self, run_parameters=None):
""" """
Manually creates a repository build for this trigger. Returns a PreparedBuild. Manually creates a repository build for this trigger. Returns a PreparedBuild.
""" """
pass pass
@abstractmethod @abstractmethod
def list_field_values(self, field_name, limit=None): def list_field_values(self, field_name, limit=None):
""" """
Lists all values for the given custom trigger field. For example, a trigger might have a Lists all values for the given custom trigger field. For example, a trigger might have a
field named "branches", and this method would return all branches. field named "branches", and this method would return all branches.
""" """
pass pass
@abstractmethod @abstractmethod
def get_repository_url(self): def get_repository_url(self):
""" Returns the URL of the current trigger's repository. Note that this operation """ Returns the URL of the current trigger's repository. Note that this operation
can be called in a loop, so it should be as fast as possible. """ can be called in a loop, so it should be as fast as possible. """
pass pass
@classmethod @classmethod
def filename_is_dockerfile(cls, file_name): def filename_is_dockerfile(cls, file_name):
""" Returns whether the file is named Dockerfile or follows the convention <name>.Dockerfile""" """ Returns whether the file is named Dockerfile or follows the convention <name>.Dockerfile"""
return file_name.endswith(".Dockerfile") or u"Dockerfile" == file_name return file_name.endswith(".Dockerfile") or u"Dockerfile" == file_name
@classmethod @classmethod
def service_name(cls): def service_name(cls):
""" """
Particular service implemented by subclasses. Particular service implemented by subclasses.
""" """
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def get_handler(cls, trigger, override_config=None): def get_handler(cls, trigger, override_config=None):
for subc in cls.__subclasses__(): for subc in cls.__subclasses__():
if subc.service_name() == trigger.service.name: if subc.service_name() == trigger.service.name:
return subc(trigger, override_config) return subc(trigger, override_config)
raise InvalidServiceException('Unable to find service: %s' % trigger.service.name) raise InvalidServiceException(
"Unable to find service: %s" % trigger.service.name
)
def put_config_key(self, key, value): def put_config_key(self, key, value):
""" Updates a config key in the trigger, saving it to the DB. """ """ Updates a config key in the trigger, saving it to the DB. """
self.config[key] = value self.config[key] = value
model.build.update_build_trigger(self.trigger, self.config) model.build.update_build_trigger(self.trigger, self.config)
def set_auth_token(self, auth_token): def set_auth_token(self, auth_token):
""" Sets the auth token for the trigger, saving it to the DB. """ """ Sets the auth token for the trigger, saving it to the DB. """
model.build.update_build_trigger(self.trigger, self.config, auth_token=auth_token) model.build.update_build_trigger(
self.trigger, self.config, auth_token=auth_token
)
def get_dockerfile_path(self): def get_dockerfile_path(self):
""" Returns the normalized path to the Dockerfile found in the subdirectory """ Returns the normalized path to the Dockerfile found in the subdirectory
in the config. """ in the config. """
dockerfile_path = self.config.get('dockerfile_path') or 'Dockerfile' dockerfile_path = self.config.get("dockerfile_path") or "Dockerfile"
if dockerfile_path[0] == '/': if dockerfile_path[0] == "/":
dockerfile_path = dockerfile_path[1:] dockerfile_path = dockerfile_path[1:]
return dockerfile_path return dockerfile_path
def prepare_build(self, metadata, is_manual=False): def prepare_build(self, metadata, is_manual=False):
# Ensure that the metadata meets the scheme. # Ensure that the metadata meets the scheme.
validate(metadata, METADATA_SCHEMA) validate(metadata, METADATA_SCHEMA)
config = self.config config = self.config
ref = metadata.get('ref', None) ref = metadata.get("ref", None)
commit_sha = metadata['commit'] commit_sha = metadata["commit"]
default_branch = metadata.get('default_branch', None) default_branch = metadata.get("default_branch", None)
prepared = PreparedBuild(self.trigger) prepared = PreparedBuild(self.trigger)
prepared.name_from_sha(commit_sha) prepared.name_from_sha(commit_sha)
prepared.subdirectory = config.get('dockerfile_path', None) prepared.subdirectory = config.get("dockerfile_path", None)
prepared.context = config.get('context', None) prepared.context = config.get("context", None)
prepared.is_manual = is_manual prepared.is_manual = is_manual
prepared.metadata = metadata prepared.metadata = metadata
if ref is not None: if ref is not None:
prepared.tags_from_ref(ref, default_branch) prepared.tags_from_ref(ref, default_branch)
else: else:
prepared.tags = [commit_sha[:7]] prepared.tags = [commit_sha[:7]]
return prepared return prepared
@classmethod @classmethod
def build_sources_response(cls, sources): def build_sources_response(cls, sources):
validate(sources, BUILD_SOURCES_SCHEMA) validate(sources, BUILD_SOURCES_SCHEMA)
return sources return sources
@classmethod @classmethod
def build_namespaces_response(cls, namespaces_dict): def build_namespaces_response(cls, namespaces_dict):
namespaces = list(namespaces_dict.values()) namespaces = list(namespaces_dict.values())
validate(namespaces, NAMESPACES_SCHEMA) validate(namespaces, NAMESPACES_SCHEMA)
return namespaces return namespaces
@classmethod @classmethod
def get_parent_directory_mappings(cls, dockerfile_path, current_paths=None): def get_parent_directory_mappings(cls, dockerfile_path, current_paths=None):
""" Returns a map of dockerfile_paths to it's possible contexts. """ """ Returns a map of dockerfile_paths to it's possible contexts. """
if dockerfile_path == "": if dockerfile_path == "":
return {} return {}
if dockerfile_path[0] != os.path.sep: if dockerfile_path[0] != os.path.sep:
dockerfile_path = os.path.sep + dockerfile_path dockerfile_path = os.path.sep + dockerfile_path
dockerfile_path = os.path.normpath(dockerfile_path) dockerfile_path = os.path.normpath(dockerfile_path)
all_paths = set() all_paths = set()
path, _ = os.path.split(dockerfile_path) path, _ = os.path.split(dockerfile_path)
if path == "": if path == "":
path = os.path.sep path = os.path.sep
all_paths.add(path) all_paths.add(path)
for i in range(1, len(path.split(os.path.sep))): for i in range(1, len(path.split(os.path.sep))):
path, _ = os.path.split(path) path, _ = os.path.split(path)
all_paths.add(path) all_paths.add(path)
if current_paths: if current_paths:
return dict({dockerfile_path: list(all_paths)}, **current_paths) return dict({dockerfile_path: list(all_paths)}, **current_paths)
return {dockerfile_path: list(all_paths)} return {dockerfile_path: list(all_paths)}

View file

@ -9,537 +9,551 @@ from jsonschema import validate
from app import app, get_app_url from app import app, get_app_url
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException, from buildtrigger.triggerutil import (
TriggerDeactivationException, TriggerStartException, RepositoryReadException,
InvalidPayloadException, TriggerProviderException, TriggerActivationException,
SkipRequestException, TriggerDeactivationException,
determine_build_ref, raise_if_skipped_build, TriggerStartException,
find_matching_branches) InvalidPayloadException,
TriggerProviderException,
SkipRequestException,
determine_build_ref,
raise_if_skipped_build,
find_matching_branches,
)
from util.dict_wrappers import JSONPathDict, SafeDictSetter from util.dict_wrappers import JSONPathDict, SafeDictSetter
from util.security.ssh import generate_ssh_keypair from util.security.ssh import generate_ssh_keypair
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_BITBUCKET_COMMIT_URL = 'https://bitbucket.org/%s/commits/%s' _BITBUCKET_COMMIT_URL = "https://bitbucket.org/%s/commits/%s"
_RAW_AUTHOR_REGEX = re.compile(r'.*<(.+)>') _RAW_AUTHOR_REGEX = re.compile(r".*<(.+)>")
BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA = { BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA = {
'type': 'object', "type": "object",
'properties': { "properties": {
'repository': { "repository": {
'type': 'object', "type": "object",
'properties': { "properties": {"full_name": {"type": "string"}},
'full_name': { "required": ["full_name"],
'type': 'string', }, # /Repository
}, "push": {
}, "type": "object",
'required': ['full_name'], "properties": {
}, # /Repository "changes": {
'push': { "type": "array",
'type': 'object', "items": {
'properties': { "type": "object",
'changes': { "properties": {
'type': 'array', "new": {
'items': { "type": "object",
'type': 'object', "properties": {
'properties': { "target": {
'new': { "type": "object",
'type': 'object', "properties": {
'properties': { "hash": {"type": "string"},
'target': { "message": {"type": "string"},
'type': 'object', "date": {"type": "string"},
'properties': { "author": {
'hash': { "type": "object",
'type': 'string' "properties": {
}, "user": {
'message': { "type": "object",
'type': 'string' "properties": {
}, "display_name": {
'date': { "type": "string"
'type': 'string' },
}, "account_id": {
'author': { "type": "string"
'type': 'object', },
'properties': { "links": {
'user': { "type": "object",
'type': 'object', "properties": {
'properties': { "avatar": {
'display_name': { "type": "object",
'type': 'string', "properties": {
}, "href": {
'account_id': { "type": "string"
'type': 'string', }
}, },
'links': { "required": [
'type': 'object', "href"
'properties': { ],
'avatar': { }
'type': 'object', },
'properties': { "required": ["avatar"],
'href': { }, # /User
'type': 'string', },
}, } # /Author
}, },
'required': ['href'], },
}, },
"required": ["hash", "message", "date"],
} # /Target
}, },
'required': ['avatar'], "required": ["name", "target"],
}, # /User } # /New
},
}, # /Author
}, },
}, }, # /Changes item
}, } # /Changes
'required': ['hash', 'message', 'date'],
}, # /Target
},
'required': ['name', 'target'],
}, # /New
}, },
}, # /Changes item "required": ["changes"],
}, # /Changes }, # / Push
},
'required': ['changes'],
}, # / Push
},
'actor': {
'type': 'object',
'properties': {
'account_id': {
'type': 'string',
},
'display_name': {
'type': 'string',
},
'links': {
'type': 'object',
'properties': {
'avatar': {
'type': 'object',
'properties': {
'href': {
'type': 'string',
},
},
'required': ['href'],
},
},
'required': ['avatar'],
},
}, },
}, # /Actor "actor": {
'required': ['push', 'repository'], "type": "object",
} # /Root "properties": {
"account_id": {"type": "string"},
"display_name": {"type": "string"},
"links": {
"type": "object",
"properties": {
"avatar": {
"type": "object",
"properties": {"href": {"type": "string"}},
"required": ["href"],
}
},
"required": ["avatar"],
},
},
}, # /Actor
"required": ["push", "repository"],
} # /Root
BITBUCKET_COMMIT_INFO_SCHEMA = { BITBUCKET_COMMIT_INFO_SCHEMA = {
'type': 'object', "type": "object",
'properties': { "properties": {
'node': { "node": {"type": "string"},
'type': 'string', "message": {"type": "string"},
"timestamp": {"type": "string"},
"raw_author": {"type": "string"},
}, },
'message': { "required": ["node", "message", "timestamp"],
'type': 'string',
},
'timestamp': {
'type': 'string',
},
'raw_author': {
'type': 'string',
},
},
'required': ['node', 'message', 'timestamp']
} }
def get_transformed_commit_info(bb_commit, ref, default_branch, repository_name, lookup_author):
""" Returns the BitBucket commit information transformed into our own def get_transformed_commit_info(
bb_commit, ref, default_branch, repository_name, lookup_author
):
""" Returns the BitBucket commit information transformed into our own
payload format. payload format.
""" """
try: try:
validate(bb_commit, BITBUCKET_COMMIT_INFO_SCHEMA) validate(bb_commit, BITBUCKET_COMMIT_INFO_SCHEMA)
except Exception as exc: except Exception as exc:
logger.exception('Exception when validating Bitbucket commit information: %s from %s', exc.message, bb_commit) logger.exception(
raise InvalidPayloadException(exc.message) "Exception when validating Bitbucket commit information: %s from %s",
exc.message,
bb_commit,
)
raise InvalidPayloadException(exc.message)
commit = JSONPathDict(bb_commit) commit = JSONPathDict(bb_commit)
config = SafeDictSetter() config = SafeDictSetter()
config['commit'] = commit['node'] config["commit"] = commit["node"]
config['ref'] = ref config["ref"] = ref
config['default_branch'] = default_branch config["default_branch"] = default_branch
config['git_url'] = 'git@bitbucket.org:%s.git' % repository_name config["git_url"] = "git@bitbucket.org:%s.git" % repository_name
config['commit_info.url'] = _BITBUCKET_COMMIT_URL % (repository_name, commit['node']) config["commit_info.url"] = _BITBUCKET_COMMIT_URL % (
config['commit_info.message'] = commit['message'] repository_name,
config['commit_info.date'] = commit['timestamp'] commit["node"],
)
config["commit_info.message"] = commit["message"]
config["commit_info.date"] = commit["timestamp"]
match = _RAW_AUTHOR_REGEX.match(commit['raw_author']) match = _RAW_AUTHOR_REGEX.match(commit["raw_author"])
if match: if match:
author = lookup_author(match.group(1)) author = lookup_author(match.group(1))
author_info = JSONPathDict(author) if author is not None else None author_info = JSONPathDict(author) if author is not None else None
if author_info: if author_info:
config['commit_info.author.username'] = author_info['user.display_name'] config["commit_info.author.username"] = author_info["user.display_name"]
config['commit_info.author.avatar_url'] = author_info['user.avatar'] config["commit_info.author.avatar_url"] = author_info["user.avatar"]
return config.dict_value() return config.dict_value()
def get_transformed_webhook_payload(bb_payload, default_branch=None): def get_transformed_webhook_payload(bb_payload, default_branch=None):
""" Returns the BitBucket webhook JSON payload transformed into our own payload """ Returns the BitBucket webhook JSON payload transformed into our own payload
format. If the bb_payload is not valid, returns None. format. If the bb_payload is not valid, returns None.
""" """
try: try:
validate(bb_payload, BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA) validate(bb_payload, BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA)
except Exception as exc: except Exception as exc:
logger.exception('Exception when validating Bitbucket webhook payload: %s from %s', exc.message, logger.exception(
bb_payload) "Exception when validating Bitbucket webhook payload: %s from %s",
raise InvalidPayloadException(exc.message) exc.message,
bb_payload,
)
raise InvalidPayloadException(exc.message)
payload = JSONPathDict(bb_payload) payload = JSONPathDict(bb_payload)
change = payload['push.changes[-1].new'] change = payload["push.changes[-1].new"]
if not change: if not change:
raise SkipRequestException raise SkipRequestException
is_branch = change['type'] == 'branch' is_branch = change["type"] == "branch"
ref = 'refs/heads/' + change['name'] if is_branch else 'refs/tags/' + change['name'] ref = "refs/heads/" + change["name"] if is_branch else "refs/tags/" + change["name"]
repository_name = payload['repository.full_name'] repository_name = payload["repository.full_name"]
target = change['target'] target = change["target"]
config = SafeDictSetter() config = SafeDictSetter()
config['commit'] = target['hash'] config["commit"] = target["hash"]
config['ref'] = ref config["ref"] = ref
config['default_branch'] = default_branch config["default_branch"] = default_branch
config['git_url'] = 'git@bitbucket.org:%s.git' % repository_name config["git_url"] = "git@bitbucket.org:%s.git" % repository_name
config['commit_info.url'] = target['links.html.href'] or '' config["commit_info.url"] = target["links.html.href"] or ""
config['commit_info.message'] = target['message'] config["commit_info.message"] = target["message"]
config['commit_info.date'] = target['date'] config["commit_info.date"] = target["date"]
config['commit_info.author.username'] = target['author.user.display_name'] config["commit_info.author.username"] = target["author.user.display_name"]
config['commit_info.author.avatar_url'] = target['author.user.links.avatar.href'] config["commit_info.author.avatar_url"] = target["author.user.links.avatar.href"]
config['commit_info.committer.username'] = payload['actor.display_name'] config["commit_info.committer.username"] = payload["actor.display_name"]
config['commit_info.committer.avatar_url'] = payload['actor.links.avatar.href'] config["commit_info.committer.avatar_url"] = payload["actor.links.avatar.href"]
return config.dict_value() return config.dict_value()
class BitbucketBuildTrigger(BuildTriggerHandler): class BitbucketBuildTrigger(BuildTriggerHandler):
""" """
BuildTrigger for Bitbucket. BuildTrigger for Bitbucket.
""" """
@classmethod
def service_name(cls):
return 'bitbucket'
def _get_client(self): @classmethod
""" Returns a BitBucket API client for this trigger's config. """ def service_name(cls):
key = app.config.get('BITBUCKET_TRIGGER_CONFIG', {}).get('CONSUMER_KEY', '') return "bitbucket"
secret = app.config.get('BITBUCKET_TRIGGER_CONFIG', {}).get('CONSUMER_SECRET', '')
trigger_uuid = self.trigger.uuid def _get_client(self):
callback_url = '%s/oauth1/bitbucket/callback/trigger/%s' % (get_app_url(), trigger_uuid) """ Returns a BitBucket API client for this trigger's config. """
key = app.config.get("BITBUCKET_TRIGGER_CONFIG", {}).get("CONSUMER_KEY", "")
secret = app.config.get("BITBUCKET_TRIGGER_CONFIG", {}).get(
"CONSUMER_SECRET", ""
)
return BitBucket(key, secret, callback_url, timeout=15) trigger_uuid = self.trigger.uuid
callback_url = "%s/oauth1/bitbucket/callback/trigger/%s" % (
get_app_url(),
trigger_uuid,
)
def _get_authorized_client(self): return BitBucket(key, secret, callback_url, timeout=15)
""" Returns an authorized API client. """
base_client = self._get_client()
auth_token = self.auth_token or 'invalid:invalid'
token_parts = auth_token.split(':')
if len(token_parts) != 2:
token_parts = ['invalid', 'invalid']
(access_token, access_token_secret) = token_parts def _get_authorized_client(self):
return base_client.get_authorized_client(access_token, access_token_secret) """ Returns an authorized API client. """
base_client = self._get_client()
auth_token = self.auth_token or "invalid:invalid"
token_parts = auth_token.split(":")
if len(token_parts) != 2:
token_parts = ["invalid", "invalid"]
def _get_repository_client(self): (access_token, access_token_secret) = token_parts
""" Returns an API client for working with this config's BB repository. """ return base_client.get_authorized_client(access_token, access_token_secret)
source = self.config['build_source']
(namespace, name) = source.split('/')
bitbucket_client = self._get_authorized_client()
return bitbucket_client.for_namespace(namespace).repositories().get(name)
def _get_default_branch(self, repository, default_value='master'): def _get_repository_client(self):
""" Returns the default branch for the repository or the value given. """ """ Returns an API client for working with this config's BB repository. """
(result, data, _) = repository.get_main_branch() source = self.config["build_source"]
if result: (namespace, name) = source.split("/")
return data['name'] bitbucket_client = self._get_authorized_client()
return bitbucket_client.for_namespace(namespace).repositories().get(name)
return default_value def _get_default_branch(self, repository, default_value="master"):
""" Returns the default branch for the repository or the value given. """
(result, data, _) = repository.get_main_branch()
if result:
return data["name"]
def get_oauth_url(self): return default_value
""" Returns the OAuth URL to authorize Bitbucket. """
bitbucket_client = self._get_client()
(result, data, err_msg) = bitbucket_client.get_authorization_url()
if not result:
raise TriggerProviderException(err_msg)
return data def get_oauth_url(self):
""" Returns the OAuth URL to authorize Bitbucket. """
bitbucket_client = self._get_client()
(result, data, err_msg) = bitbucket_client.get_authorization_url()
if not result:
raise TriggerProviderException(err_msg)
def exchange_verifier(self, verifier): return data
""" Exchanges the given verifier token to setup this trigger. """
bitbucket_client = self._get_client()
access_token = self.config.get('access_token', '')
access_token_secret = self.auth_token
# Exchange the verifier for a new access token. def exchange_verifier(self, verifier):
(result, data, _) = bitbucket_client.verify_token(access_token, access_token_secret, verifier) """ Exchanges the given verifier token to setup this trigger. """
if not result: bitbucket_client = self._get_client()
return False access_token = self.config.get("access_token", "")
access_token_secret = self.auth_token
# Save the updated access token and secret. # Exchange the verifier for a new access token.
self.set_auth_token(data[0] + ':' + data[1]) (result, data, _) = bitbucket_client.verify_token(
access_token, access_token_secret, verifier
)
if not result:
return False
# Retrieve the current authorized user's information and store the username in the config. # Save the updated access token and secret.
authorized_client = self._get_authorized_client() self.set_auth_token(data[0] + ":" + data[1])
(result, data, _) = authorized_client.get_current_user()
if not result:
return False
self.put_config_key('account_id', data['user']['account_id']) # Retrieve the current authorized user's information and store the username in the config.
self.put_config_key('nickname', data['user']['nickname']) authorized_client = self._get_authorized_client()
return True (result, data, _) = authorized_client.get_current_user()
if not result:
return False
def is_active(self): self.put_config_key("account_id", data["user"]["account_id"])
return 'webhook_id' in self.config self.put_config_key("nickname", data["user"]["nickname"])
return True
def activate(self, standard_webhook_url): def is_active(self):
config = self.config return "webhook_id" in self.config
# Add a deploy key to the repository. def activate(self, standard_webhook_url):
public_key, private_key = generate_ssh_keypair() config = self.config
config['credentials'] = [
{
'name': 'SSH Public Key',
'value': public_key,
},
]
repository = self._get_repository_client() # Add a deploy key to the repository.
(result, created_deploykey, err_msg) = repository.deploykeys().create( public_key, private_key = generate_ssh_keypair()
app.config['REGISTRY_TITLE'] + ' webhook key', public_key) config["credentials"] = [{"name": "SSH Public Key", "value": public_key}]
if not result: repository = self._get_repository_client()
msg = 'Unable to add deploy key to repository: %s' % err_msg (result, created_deploykey, err_msg) = repository.deploykeys().create(
raise TriggerActivationException(msg) app.config["REGISTRY_TITLE"] + " webhook key", public_key
)
config['deploy_key_id'] = created_deploykey['pk'] if not result:
msg = "Unable to add deploy key to repository: %s" % err_msg
raise TriggerActivationException(msg)
# Add a webhook callback. config["deploy_key_id"] = created_deploykey["pk"]
description = 'Webhook for invoking builds on %s' % app.config['REGISTRY_TITLE_SHORT']
webhook_events = ['repo:push']
(result, created_webhook, err_msg) = repository.webhooks().create(
description, standard_webhook_url, webhook_events)
if not result: # Add a webhook callback.
msg = 'Unable to add webhook to repository: %s' % err_msg description = (
raise TriggerActivationException(msg) "Webhook for invoking builds on %s" % app.config["REGISTRY_TITLE_SHORT"]
)
webhook_events = ["repo:push"]
(result, created_webhook, err_msg) = repository.webhooks().create(
description, standard_webhook_url, webhook_events
)
config['webhook_id'] = created_webhook['uuid'] if not result:
self.config = config msg = "Unable to add webhook to repository: %s" % err_msg
return config, {'private_key': private_key} raise TriggerActivationException(msg)
def deactivate(self): config["webhook_id"] = created_webhook["uuid"]
config = self.config self.config = config
return config, {"private_key": private_key}
webhook_id = config.pop('webhook_id', None) def deactivate(self):
deploy_key_id = config.pop('deploy_key_id', None) config = self.config
repository = self._get_repository_client()
# Remove the webhook. webhook_id = config.pop("webhook_id", None)
if webhook_id is not None: deploy_key_id = config.pop("deploy_key_id", None)
(result, _, err_msg) = repository.webhooks().delete(webhook_id) repository = self._get_repository_client()
if not result:
msg = 'Unable to remove webhook from repository: %s' % err_msg
raise TriggerDeactivationException(msg)
# Remove the public key. # Remove the webhook.
if deploy_key_id is not None: if webhook_id is not None:
(result, _, err_msg) = repository.deploykeys().delete(deploy_key_id) (result, _, err_msg) = repository.webhooks().delete(webhook_id)
if not result: if not result:
msg = 'Unable to remove deploy key from repository: %s' % err_msg msg = "Unable to remove webhook from repository: %s" % err_msg
raise TriggerDeactivationException(msg) raise TriggerDeactivationException(msg)
return config # Remove the public key.
if deploy_key_id is not None:
(result, _, err_msg) = repository.deploykeys().delete(deploy_key_id)
if not result:
msg = "Unable to remove deploy key from repository: %s" % err_msg
raise TriggerDeactivationException(msg)
def list_build_source_namespaces(self): return config
bitbucket_client = self._get_authorized_client()
(result, data, err_msg) = bitbucket_client.get_visible_repositories()
if not result:
raise RepositoryReadException('Could not read repository list: ' + err_msg)
namespaces = {} def list_build_source_namespaces(self):
for repo in data: bitbucket_client = self._get_authorized_client()
owner = repo['owner'] (result, data, err_msg) = bitbucket_client.get_visible_repositories()
if not result:
raise RepositoryReadException("Could not read repository list: " + err_msg)
if owner in namespaces: namespaces = {}
namespaces[owner]['score'] = namespaces[owner]['score'] + 1 for repo in data:
else: owner = repo["owner"]
namespaces[owner] = {
'personal': owner == self.config.get('nickname', self.config.get('username')),
'id': owner,
'title': owner,
'avatar_url': repo['logo'],
'url': 'https://bitbucket.org/%s' % (owner),
'score': 1,
}
return BuildTriggerHandler.build_namespaces_response(namespaces) if owner in namespaces:
namespaces[owner]["score"] = namespaces[owner]["score"] + 1
else:
namespaces[owner] = {
"personal": owner
== self.config.get("nickname", self.config.get("username")),
"id": owner,
"title": owner,
"avatar_url": repo["logo"],
"url": "https://bitbucket.org/%s" % (owner),
"score": 1,
}
def list_build_sources_for_namespace(self, namespace): return BuildTriggerHandler.build_namespaces_response(namespaces)
def repo_view(repo):
last_modified = dateutil.parser.parse(repo['utc_last_updated'])
return { def list_build_sources_for_namespace(self, namespace):
'name': repo['slug'], def repo_view(repo):
'full_name': '%s/%s' % (repo['owner'], repo['slug']), last_modified = dateutil.parser.parse(repo["utc_last_updated"])
'description': repo['description'] or '',
'last_updated': timegm(last_modified.utctimetuple()),
'url': 'https://bitbucket.org/%s/%s' % (repo['owner'], repo['slug']),
'has_admin_permissions': repo['read_only'] is False,
'private': repo['is_private'],
}
bitbucket_client = self._get_authorized_client() return {
(result, data, err_msg) = bitbucket_client.get_visible_repositories() "name": repo["slug"],
if not result: "full_name": "%s/%s" % (repo["owner"], repo["slug"]),
raise RepositoryReadException('Could not read repository list: ' + err_msg) "description": repo["description"] or "",
"last_updated": timegm(last_modified.utctimetuple()),
"url": "https://bitbucket.org/%s/%s" % (repo["owner"], repo["slug"]),
"has_admin_permissions": repo["read_only"] is False,
"private": repo["is_private"],
}
repos = [repo_view(repo) for repo in data if repo['owner'] == namespace] bitbucket_client = self._get_authorized_client()
return BuildTriggerHandler.build_sources_response(repos) (result, data, err_msg) = bitbucket_client.get_visible_repositories()
if not result:
raise RepositoryReadException("Could not read repository list: " + err_msg)
def list_build_subdirs(self): repos = [repo_view(repo) for repo in data if repo["owner"] == namespace]
config = self.config return BuildTriggerHandler.build_sources_response(repos)
repository = self._get_repository_client()
# Find the first matching branch. def list_build_subdirs(self):
repo_branches = self.list_field_values('branch_name') or [] config = self.config
branches = find_matching_branches(config, repo_branches) repository = self._get_repository_client()
if not branches:
branches = [self._get_default_branch(repository)]
(result, data, err_msg) = repository.get_path_contents('', revision=branches[0]) # Find the first matching branch.
if not result: repo_branches = self.list_field_values("branch_name") or []
raise RepositoryReadException(err_msg) branches = find_matching_branches(config, repo_branches)
if not branches:
branches = [self._get_default_branch(repository)]
files = set([f['path'] for f in data['files']]) (result, data, err_msg) = repository.get_path_contents("", revision=branches[0])
return ["/" + file_path for file_path in files if self.filename_is_dockerfile(os.path.basename(file_path))] if not result:
raise RepositoryReadException(err_msg)
def load_dockerfile_contents(self): files = set([f["path"] for f in data["files"]])
repository = self._get_repository_client() return [
path = self.get_dockerfile_path() "/" + file_path
for file_path in files
if self.filename_is_dockerfile(os.path.basename(file_path))
]
(result, data, err_msg) = repository.get_raw_path_contents(path, revision='master') def load_dockerfile_contents(self):
if not result: repository = self._get_repository_client()
return None path = self.get_dockerfile_path()
return data (result, data, err_msg) = repository.get_raw_path_contents(
path, revision="master"
)
if not result:
return None
def list_field_values(self, field_name, limit=None): return data
if 'build_source' not in self.config:
return None
source = self.config['build_source'] def list_field_values(self, field_name, limit=None):
(namespace, name) = source.split('/') if "build_source" not in self.config:
return None
bitbucket_client = self._get_authorized_client() source = self.config["build_source"]
repository = bitbucket_client.for_namespace(namespace).repositories().get(name) (namespace, name) = source.split("/")
bitbucket_client = self._get_authorized_client()
repository = bitbucket_client.for_namespace(namespace).repositories().get(name)
if field_name == "refs":
(result, data, _) = repository.get_branches_and_tags()
if not result:
return None
branches = [b["name"] for b in data["branches"]]
tags = [t["name"] for t in data["tags"]]
return [{"kind": "branch", "name": b} for b in branches] + [
{"kind": "tag", "name": tag} for tag in tags
]
if field_name == "tag_name":
(result, data, _) = repository.get_tags()
if not result:
return None
tags = list(data.keys())
if limit:
tags = tags[0:limit]
return tags
if field_name == "branch_name":
(result, data, _) = repository.get_branches()
if not result:
return None
branches = list(data.keys())
if limit:
branches = branches[0:limit]
return branches
if field_name == 'refs':
(result, data, _) = repository.get_branches_and_tags()
if not result:
return None return None
branches = [b['name'] for b in data['branches']] def get_repository_url(self):
tags = [t['name'] for t in data['tags']] source = self.config["build_source"]
(namespace, name) = source.split("/")
return "https://bitbucket.org/%s/%s" % (namespace, name)
return ([{'kind': 'branch', 'name': b} for b in branches] + def handle_trigger_request(self, request):
[{'kind': 'tag', 'name': tag} for tag in tags]) payload = request.get_json()
if payload is None:
raise InvalidPayloadException("Missing payload")
if field_name == 'tag_name': logger.debug("Got BitBucket request: %s", payload)
(result, data, _) = repository.get_tags()
if not result:
return None
tags = list(data.keys()) repository = self._get_repository_client()
if limit: default_branch = self._get_default_branch(repository)
tags = tags[0:limit]
return tags metadata = get_transformed_webhook_payload(
payload, default_branch=default_branch
)
prepared = self.prepare_build(metadata)
if field_name == 'branch_name': # Check if we should skip this build.
(result, data, _) = repository.get_branches() raise_if_skipped_build(prepared, self.config)
if not result: return prepared
return None
branches = list(data.keys()) def manual_start(self, run_parameters=None):
if limit: run_parameters = run_parameters or {}
branches = branches[0:limit] repository = self._get_repository_client()
bitbucket_client = self._get_authorized_client()
return branches def get_branch_sha(branch_name):
# Lookup the commit SHA for the branch.
(result, data, _) = repository.get_branch(branch_name)
if not result:
raise TriggerStartException("Could not find branch in repository")
return None return data["target"]["hash"]
def get_repository_url(self): def get_tag_sha(tag_name):
source = self.config['build_source'] # Lookup the commit SHA for the tag.
(namespace, name) = source.split('/') (result, data, _) = repository.get_tag(tag_name)
return 'https://bitbucket.org/%s/%s' % (namespace, name) if not result:
raise TriggerStartException("Could not find tag in repository")
def handle_trigger_request(self, request): return data["target"]["hash"]
payload = request.get_json()
if payload is None:
raise InvalidPayloadException('Missing payload')
logger.debug('Got BitBucket request: %s', payload) def lookup_author(email_address):
(result, data, _) = bitbucket_client.accounts().get_profile(email_address)
return data if result else None
repository = self._get_repository_client() # Find the branch or tag to build.
default_branch = self._get_default_branch(repository) default_branch = self._get_default_branch(repository)
(commit_sha, ref) = determine_build_ref(
run_parameters, get_branch_sha, get_tag_sha, default_branch
)
metadata = get_transformed_webhook_payload(payload, default_branch=default_branch) # Lookup the commit SHA in BitBucket.
prepared = self.prepare_build(metadata) (result, commit_info, _) = repository.changesets().get(commit_sha)
if not result:
raise TriggerStartException("Could not lookup commit SHA")
# Check if we should skip this build. # Return a prepared build for the commit.
raise_if_skipped_build(prepared, self.config) repository_name = "%s/%s" % (repository.namespace, repository.repository_name)
return prepared metadata = get_transformed_commit_info(
commit_info, ref, default_branch, repository_name, lookup_author
)
def manual_start(self, run_parameters=None): return self.prepare_build(metadata, is_manual=True)
run_parameters = run_parameters or {}
repository = self._get_repository_client()
bitbucket_client = self._get_authorized_client()
def get_branch_sha(branch_name):
# Lookup the commit SHA for the branch.
(result, data, _) = repository.get_branch(branch_name)
if not result:
raise TriggerStartException('Could not find branch in repository')
return data['target']['hash']
def get_tag_sha(tag_name):
# Lookup the commit SHA for the tag.
(result, data, _) = repository.get_tag(tag_name)
if not result:
raise TriggerStartException('Could not find tag in repository')
return data['target']['hash']
def lookup_author(email_address):
(result, data, _) = bitbucket_client.accounts().get_profile(email_address)
return data if result else None
# Find the branch or tag to build.
default_branch = self._get_default_branch(repository)
(commit_sha, ref) = determine_build_ref(run_parameters, get_branch_sha, get_tag_sha,
default_branch)
# Lookup the commit SHA in BitBucket.
(result, commit_info, _) = repository.changesets().get(commit_sha)
if not result:
raise TriggerStartException('Could not lookup commit SHA')
# Return a prepared build for the commit.
repository_name = '%s/%s' % (repository.namespace, repository.repository_name)
metadata = get_transformed_commit_info(commit_info, ref, default_branch,
repository_name, lookup_author)
return self.prepare_build(metadata, is_manual=True)

View file

@ -2,22 +2,33 @@ import logging
import json import json
from jsonschema import validate, ValidationError from jsonschema import validate, ValidationError
from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException, from buildtrigger.triggerutil import (
TriggerStartException, ValidationRequestException, RepositoryReadException,
InvalidPayloadException, TriggerActivationException,
SkipRequestException, raise_if_skipped_build, TriggerStartException,
find_matching_branches) ValidationRequestException,
InvalidPayloadException,
SkipRequestException,
raise_if_skipped_build,
find_matching_branches,
)
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
from buildtrigger.bitbuckethandler import (BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA as bb_schema, from buildtrigger.bitbuckethandler import (
get_transformed_webhook_payload as bb_payload) BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA as bb_schema,
get_transformed_webhook_payload as bb_payload,
)
from buildtrigger.githubhandler import (GITHUB_WEBHOOK_PAYLOAD_SCHEMA as gh_schema, from buildtrigger.githubhandler import (
get_transformed_webhook_payload as gh_payload) GITHUB_WEBHOOK_PAYLOAD_SCHEMA as gh_schema,
get_transformed_webhook_payload as gh_payload,
)
from buildtrigger.gitlabhandler import (GITLAB_WEBHOOK_PAYLOAD_SCHEMA as gl_schema, from buildtrigger.gitlabhandler import (
get_transformed_webhook_payload as gl_payload) GITLAB_WEBHOOK_PAYLOAD_SCHEMA as gl_schema,
get_transformed_webhook_payload as gl_payload,
)
from util.security.ssh import generate_ssh_keypair from util.security.ssh import generate_ssh_keypair
@ -27,203 +38,191 @@ logger = logging.getLogger(__name__)
# Defines an ordered set of tuples of the schemas and associated transformation functions # Defines an ordered set of tuples of the schemas and associated transformation functions
# for incoming webhook payloads. # for incoming webhook payloads.
SCHEMA_AND_HANDLERS = [ SCHEMA_AND_HANDLERS = [
(gh_schema, gh_payload), (gh_schema, gh_payload),
(bb_schema, bb_payload), (bb_schema, bb_payload),
(gl_schema, gl_payload), (gl_schema, gl_payload),
] ]
def custom_trigger_payload(metadata, git_url): def custom_trigger_payload(metadata, git_url):
# First try the customhandler schema. If it matches, nothing more to do. # First try the customhandler schema. If it matches, nothing more to do.
custom_handler_validation_error = None custom_handler_validation_error = None
try:
validate(metadata, CustomBuildTrigger.payload_schema)
except ValidationError as vex:
custom_handler_validation_error = vex
# Otherwise, try the defined schemas, in order, until we find a match.
for schema, handler in SCHEMA_AND_HANDLERS:
try: try:
validate(metadata, schema) validate(metadata, CustomBuildTrigger.payload_schema)
except ValidationError: except ValidationError as vex:
continue custom_handler_validation_error = vex
result = handler(metadata) # Otherwise, try the defined schemas, in order, until we find a match.
result['git_url'] = git_url for schema, handler in SCHEMA_AND_HANDLERS:
return result try:
validate(metadata, schema)
except ValidationError:
continue
# If we have reached this point and no other schemas validated, then raise the error for the result = handler(metadata)
# custom schema. result["git_url"] = git_url
if custom_handler_validation_error is not None: return result
raise InvalidPayloadException(custom_handler_validation_error.message)
metadata['git_url'] = git_url # If we have reached this point and no other schemas validated, then raise the error for the
return metadata # custom schema.
if custom_handler_validation_error is not None:
raise InvalidPayloadException(custom_handler_validation_error.message)
metadata["git_url"] = git_url
return metadata
class CustomBuildTrigger(BuildTriggerHandler): class CustomBuildTrigger(BuildTriggerHandler):
payload_schema = { payload_schema = {
'type': 'object', "type": "object",
'properties': { "properties": {
'commit': { "commit": {
'type': 'string', "type": "string",
'description': 'first 7 characters of the SHA-1 identifier for a git commit', "description": "first 7 characters of the SHA-1 identifier for a git commit",
'pattern': '^([A-Fa-f0-9]{7,})$', "pattern": "^([A-Fa-f0-9]{7,})$",
},
'ref': {
'type': 'string',
'description': 'git reference for a git commit',
'pattern': '^refs\/(heads|tags|remotes)\/(.+)$',
},
'default_branch': {
'type': 'string',
'description': 'default branch of the git repository',
},
'commit_info': {
'type': 'object',
'description': 'metadata about a git commit',
'properties': {
'url': {
'type': 'string',
'description': 'URL to view a git commit',
},
'message': {
'type': 'string',
'description': 'git commit message',
},
'date': {
'type': 'string',
'description': 'timestamp for a git commit'
},
'author': {
'type': 'object',
'description': 'metadata about the author of a git commit',
'properties': {
'username': {
'type': 'string',
'description': 'username of the author',
},
'url': {
'type': 'string',
'description': 'URL to view the profile of the author',
},
'avatar_url': {
'type': 'string',
'description': 'URL to view the avatar of the author',
},
}, },
'required': ['username', 'url', 'avatar_url'], "ref": {
}, "type": "string",
'committer': { "description": "git reference for a git commit",
'type': 'object', "pattern": "^refs\/(heads|tags|remotes)\/(.+)$",
'description': 'metadata about the committer of a git commit', },
'properties': { "default_branch": {
'username': { "type": "string",
'type': 'string', "description": "default branch of the git repository",
'description': 'username of the committer', },
}, "commit_info": {
'url': { "type": "object",
'type': 'string', "description": "metadata about a git commit",
'description': 'URL to view the profile of the committer', "properties": {
}, "url": {
'avatar_url': { "type": "string",
'type': 'string', "description": "URL to view a git commit",
'description': 'URL to view the avatar of the committer', },
}, "message": {"type": "string", "description": "git commit message"},
"date": {
"type": "string",
"description": "timestamp for a git commit",
},
"author": {
"type": "object",
"description": "metadata about the author of a git commit",
"properties": {
"username": {
"type": "string",
"description": "username of the author",
},
"url": {
"type": "string",
"description": "URL to view the profile of the author",
},
"avatar_url": {
"type": "string",
"description": "URL to view the avatar of the author",
},
},
"required": ["username", "url", "avatar_url"],
},
"committer": {
"type": "object",
"description": "metadata about the committer of a git commit",
"properties": {
"username": {
"type": "string",
"description": "username of the committer",
},
"url": {
"type": "string",
"description": "URL to view the profile of the committer",
},
"avatar_url": {
"type": "string",
"description": "URL to view the avatar of the committer",
},
},
"required": ["username", "url", "avatar_url"],
},
},
"required": ["url", "message", "date"],
}, },
'required': ['username', 'url', 'avatar_url'],
},
}, },
'required': ['url', 'message', 'date'], "required": ["commit", "ref", "default_branch"],
},
},
'required': ['commit', 'ref', 'default_branch'],
}
@classmethod
def service_name(cls):
return 'custom-git'
def is_active(self):
return self.config.has_key('credentials')
def _metadata_from_payload(self, payload, git_url):
# Parse the JSON payload.
try:
metadata = json.loads(payload)
except ValueError as vex:
raise InvalidPayloadException(vex.message)
return custom_trigger_payload(metadata, git_url)
def handle_trigger_request(self, request):
payload = request.data
if not payload:
raise InvalidPayloadException('Missing expected payload')
logger.debug('Payload %s', payload)
metadata = self._metadata_from_payload(payload, self.config['build_source'])
prepared = self.prepare_build(metadata)
# Check if we should skip this build.
raise_if_skipped_build(prepared, self.config)
return prepared
def manual_start(self, run_parameters=None):
# commit_sha is the only required parameter
commit_sha = run_parameters.get('commit_sha')
if commit_sha is None:
raise TriggerStartException('missing required parameter')
config = self.config
metadata = {
'commit': commit_sha,
'git_url': config['build_source'],
} }
try: @classmethod
return self.prepare_build(metadata, is_manual=True) def service_name(cls):
except ValidationError as ve: return "custom-git"
raise TriggerStartException(ve.message)
def activate(self, standard_webhook_url): def is_active(self):
config = self.config return self.config.has_key("credentials")
public_key, private_key = generate_ssh_keypair()
config['credentials'] = [
{
'name': 'SSH Public Key',
'value': public_key,
},
{
'name': 'Webhook Endpoint URL',
'value': standard_webhook_url,
},
]
self.config = config
return config, {'private_key': private_key}
def deactivate(self): def _metadata_from_payload(self, payload, git_url):
config = self.config # Parse the JSON payload.
config.pop('credentials', None) try:
self.config = config metadata = json.loads(payload)
return config except ValueError as vex:
raise InvalidPayloadException(vex.message)
def get_repository_url(self): return custom_trigger_payload(metadata, git_url)
return None
def list_build_source_namespaces(self): def handle_trigger_request(self, request):
raise NotImplementedError payload = request.data
if not payload:
raise InvalidPayloadException("Missing expected payload")
def list_build_sources_for_namespace(self, namespace): logger.debug("Payload %s", payload)
raise NotImplementedError
def list_build_subdirs(self): metadata = self._metadata_from_payload(payload, self.config["build_source"])
raise NotImplementedError prepared = self.prepare_build(metadata)
def list_field_values(self, field_name, limit=None): # Check if we should skip this build.
raise NotImplementedError raise_if_skipped_build(prepared, self.config)
def load_dockerfile_contents(self): return prepared
raise NotImplementedError
def manual_start(self, run_parameters=None):
# commit_sha is the only required parameter
commit_sha = run_parameters.get("commit_sha")
if commit_sha is None:
raise TriggerStartException("missing required parameter")
config = self.config
metadata = {"commit": commit_sha, "git_url": config["build_source"]}
try:
return self.prepare_build(metadata, is_manual=True)
except ValidationError as ve:
raise TriggerStartException(ve.message)
def activate(self, standard_webhook_url):
config = self.config
public_key, private_key = generate_ssh_keypair()
config["credentials"] = [
{"name": "SSH Public Key", "value": public_key},
{"name": "Webhook Endpoint URL", "value": standard_webhook_url},
]
self.config = config
return config, {"private_key": private_key}
def deactivate(self):
config = self.config
config.pop("credentials", None)
self.config = config
return config
def get_repository_url(self):
return None
def list_build_source_namespaces(self):
raise NotImplementedError
def list_build_sources_for_namespace(self, namespace):
raise NotImplementedError
def list_build_subdirs(self):
raise NotImplementedError
def list_field_values(self, field_name, limit=None):
raise NotImplementedError
def load_dockerfile_contents(self):
raise NotImplementedError

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -4,156 +4,168 @@ from mock import Mock
from buildtrigger.bitbuckethandler import BitbucketBuildTrigger from buildtrigger.bitbuckethandler import BitbucketBuildTrigger
from util.morecollections import AttrDict from util.morecollections import AttrDict
def get_bitbucket_trigger(dockerfile_path=''):
trigger_obj = AttrDict(dict(auth_token='foobar', id='sometrigger'))
trigger = BitbucketBuildTrigger(trigger_obj, {
'build_source': 'foo/bar',
'dockerfile_path': dockerfile_path,
'nickname': 'knownuser',
'account_id': 'foo',
})
trigger._get_client = get_mock_bitbucket def get_bitbucket_trigger(dockerfile_path=""):
return trigger trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
trigger = BitbucketBuildTrigger(
trigger_obj,
{
"build_source": "foo/bar",
"dockerfile_path": dockerfile_path,
"nickname": "knownuser",
"account_id": "foo",
},
)
trigger._get_client = get_mock_bitbucket
return trigger
def get_repo_path_contents(path, revision): def get_repo_path_contents(path, revision):
data = { data = {"files": [{"path": "Dockerfile"}]}
'files': [{'path': 'Dockerfile'}],
} return (True, data, None)
return (True, data, None)
def get_raw_path_contents(path, revision): def get_raw_path_contents(path, revision):
if path == 'Dockerfile': if path == "Dockerfile":
return (True, 'hello world', None) return (True, "hello world", None)
if path == 'somesubdir/Dockerfile': if path == "somesubdir/Dockerfile":
return (True, 'hi universe', None) return (True, "hi universe", None)
return (False, None, None)
return (False, None, None)
def get_branches_and_tags(): def get_branches_and_tags():
data = { data = {
'branches': [{'name': 'master'}, {'name': 'otherbranch'}], "branches": [{"name": "master"}, {"name": "otherbranch"}],
'tags': [{'name': 'sometag'}, {'name': 'someothertag'}], "tags": [{"name": "sometag"}, {"name": "someothertag"}],
} }
return (True, data, None) return (True, data, None)
def get_branches(): def get_branches():
return (True, {'master': {}, 'otherbranch': {}}, None) return (True, {"master": {}, "otherbranch": {}}, None)
def get_tags(): def get_tags():
return (True, {'sometag': {}, 'someothertag': {}}, None) return (True, {"sometag": {}, "someothertag": {}}, None)
def get_branch(branch_name): def get_branch(branch_name):
if branch_name != 'master': if branch_name != "master":
return (False, None, None) return (False, None, None)
data = { data = {"target": {"hash": "aaaaaaa"}}
'target': {
'hash': 'aaaaaaa', return (True, data, None)
},
}
return (True, data, None)
def get_tag(tag_name): def get_tag(tag_name):
if tag_name != 'sometag': if tag_name != "sometag":
return (False, None, None) return (False, None, None)
data = { data = {"target": {"hash": "aaaaaaa"}}
'target': {
'hash': 'aaaaaaa', return (True, data, None)
},
}
return (True, data, None)
def get_changeset_mock(commit_sha): def get_changeset_mock(commit_sha):
if commit_sha != 'aaaaaaa': if commit_sha != "aaaaaaa":
return (False, None, 'Not found') return (False, None, "Not found")
data = { data = {
'node': 'aaaaaaa', "node": "aaaaaaa",
'message': 'some message', "message": "some message",
'timestamp': 'now', "timestamp": "now",
'raw_author': 'foo@bar.com', "raw_author": "foo@bar.com",
} }
return (True, data, None)
return (True, data, None)
def get_changesets(): def get_changesets():
changesets_mock = Mock() changesets_mock = Mock()
changesets_mock.get = Mock(side_effect=get_changeset_mock) changesets_mock.get = Mock(side_effect=get_changeset_mock)
return changesets_mock return changesets_mock
def get_deploykeys(): def get_deploykeys():
deploykeys_mock = Mock() deploykeys_mock = Mock()
deploykeys_mock.create = Mock(return_value=(True, {'pk': 'someprivatekey'}, None)) deploykeys_mock.create = Mock(return_value=(True, {"pk": "someprivatekey"}, None))
deploykeys_mock.delete = Mock(return_value=(True, {}, None)) deploykeys_mock.delete = Mock(return_value=(True, {}, None))
return deploykeys_mock return deploykeys_mock
def get_webhooks(): def get_webhooks():
webhooks_mock = Mock() webhooks_mock = Mock()
webhooks_mock.create = Mock(return_value=(True, {'uuid': 'someuuid'}, None)) webhooks_mock.create = Mock(return_value=(True, {"uuid": "someuuid"}, None))
webhooks_mock.delete = Mock(return_value=(True, {}, None)) webhooks_mock.delete = Mock(return_value=(True, {}, None))
return webhooks_mock return webhooks_mock
def get_repo_mock(name): def get_repo_mock(name):
if name != 'bar': if name != "bar":
return None return None
repo_mock = Mock() repo_mock = Mock()
repo_mock.get_main_branch = Mock(return_value=(True, {'name': 'master'}, None)) repo_mock.get_main_branch = Mock(return_value=(True, {"name": "master"}, None))
repo_mock.get_path_contents = Mock(side_effect=get_repo_path_contents) repo_mock.get_path_contents = Mock(side_effect=get_repo_path_contents)
repo_mock.get_raw_path_contents = Mock(side_effect=get_raw_path_contents) repo_mock.get_raw_path_contents = Mock(side_effect=get_raw_path_contents)
repo_mock.get_branches_and_tags = Mock(side_effect=get_branches_and_tags) repo_mock.get_branches_and_tags = Mock(side_effect=get_branches_and_tags)
repo_mock.get_branches = Mock(side_effect=get_branches) repo_mock.get_branches = Mock(side_effect=get_branches)
repo_mock.get_tags = Mock(side_effect=get_tags) repo_mock.get_tags = Mock(side_effect=get_tags)
repo_mock.get_branch = Mock(side_effect=get_branch) repo_mock.get_branch = Mock(side_effect=get_branch)
repo_mock.get_tag = Mock(side_effect=get_tag) repo_mock.get_tag = Mock(side_effect=get_tag)
repo_mock.changesets = Mock(side_effect=get_changesets)
repo_mock.deploykeys = Mock(side_effect=get_deploykeys)
repo_mock.webhooks = Mock(side_effect=get_webhooks)
return repo_mock
repo_mock.changesets = Mock(side_effect=get_changesets)
repo_mock.deploykeys = Mock(side_effect=get_deploykeys)
repo_mock.webhooks = Mock(side_effect=get_webhooks)
return repo_mock
def get_repositories_mock(): def get_repositories_mock():
repos_mock = Mock() repos_mock = Mock()
repos_mock.get = Mock(side_effect=get_repo_mock) repos_mock.get = Mock(side_effect=get_repo_mock)
return repos_mock return repos_mock
def get_namespace_mock(namespace): def get_namespace_mock(namespace):
namespace_mock = Mock() namespace_mock = Mock()
namespace_mock.repositories = Mock(side_effect=get_repositories_mock) namespace_mock.repositories = Mock(side_effect=get_repositories_mock)
return namespace_mock return namespace_mock
def get_repo(namespace, name): def get_repo(namespace, name):
return { return {
'owner': namespace, "owner": namespace,
'logo': 'avatarurl', "logo": "avatarurl",
'slug': name, "slug": name,
'description': 'some %s repo' % (name), "description": "some %s repo" % (name),
'utc_last_updated': str(datetime.utcfromtimestamp(0)), "utc_last_updated": str(datetime.utcfromtimestamp(0)),
'read_only': namespace != 'knownuser', "read_only": namespace != "knownuser",
'is_private': name == 'somerepo', "is_private": name == "somerepo",
} }
def get_visible_repos(): def get_visible_repos():
repos = [ repos = [
get_repo('knownuser', 'somerepo'), get_repo("knownuser", "somerepo"),
get_repo('someorg', 'somerepo'), get_repo("someorg", "somerepo"),
get_repo('someorg', 'anotherrepo'), get_repo("someorg", "anotherrepo"),
] ]
return (True, repos, None) return (True, repos, None)
def get_authed_mock(token, secret): def get_authed_mock(token, secret):
authed_mock = Mock() authed_mock = Mock()
authed_mock.for_namespace = Mock(side_effect=get_namespace_mock) authed_mock.for_namespace = Mock(side_effect=get_namespace_mock)
authed_mock.get_visible_repositories = Mock(side_effect=get_visible_repos) authed_mock.get_visible_repositories = Mock(side_effect=get_visible_repos)
return authed_mock return authed_mock
def get_mock_bitbucket(): def get_mock_bitbucket():
bitbucket_mock = Mock() bitbucket_mock = Mock()
bitbucket_mock.get_authorized_client = Mock(side_effect=get_authed_mock) bitbucket_mock.get_authorized_client = Mock(side_effect=get_authed_mock)
return bitbucket_mock return bitbucket_mock

View file

@ -6,173 +6,178 @@ from github import GithubException
from buildtrigger.githubhandler import GithubBuildTrigger from buildtrigger.githubhandler import GithubBuildTrigger
from util.morecollections import AttrDict from util.morecollections import AttrDict
def get_github_trigger(dockerfile_path=''):
trigger_obj = AttrDict(dict(auth_token='foobar', id='sometrigger')) def get_github_trigger(dockerfile_path=""):
trigger = GithubBuildTrigger(trigger_obj, {'build_source': 'foo', 'dockerfile_path': dockerfile_path}) trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
trigger._get_client = get_mock_github trigger = GithubBuildTrigger(
return trigger trigger_obj, {"build_source": "foo", "dockerfile_path": dockerfile_path}
)
trigger._get_client = get_mock_github
return trigger
def get_mock_github(): def get_mock_github():
def get_commit_mock(commit_sha): def get_commit_mock(commit_sha):
if commit_sha == 'aaaaaaa': if commit_sha == "aaaaaaa":
commit_mock = Mock() commit_mock = Mock()
commit_mock.sha = commit_sha commit_mock.sha = commit_sha
commit_mock.html_url = 'http://url/to/commit' commit_mock.html_url = "http://url/to/commit"
commit_mock.last_modified = 'now' commit_mock.last_modified = "now"
commit_mock.commit = Mock() commit_mock.commit = Mock()
commit_mock.commit.message = 'some cool message' commit_mock.commit.message = "some cool message"
commit_mock.committer = Mock() commit_mock.committer = Mock()
commit_mock.committer.login = 'someuser' commit_mock.committer.login = "someuser"
commit_mock.committer.avatar_url = 'avatarurl' commit_mock.committer.avatar_url = "avatarurl"
commit_mock.committer.html_url = 'htmlurl' commit_mock.committer.html_url = "htmlurl"
commit_mock.author = Mock() commit_mock.author = Mock()
commit_mock.author.login = 'someuser' commit_mock.author.login = "someuser"
commit_mock.author.avatar_url = 'avatarurl' commit_mock.author.avatar_url = "avatarurl"
commit_mock.author.html_url = 'htmlurl' commit_mock.author.html_url = "htmlurl"
return commit_mock return commit_mock
raise GithubException(None, None) raise GithubException(None, None)
def get_branch_mock(branch_name): def get_branch_mock(branch_name):
if branch_name == 'master': if branch_name == "master":
branch_mock = Mock() branch_mock = Mock()
branch_mock.commit = Mock() branch_mock.commit = Mock()
branch_mock.commit.sha = 'aaaaaaa' branch_mock.commit.sha = "aaaaaaa"
return branch_mock return branch_mock
raise GithubException(None, None) raise GithubException(None, None)
    def get_repo_mock(namespace, name):
        # Builds a Mock shaped like a repository object for namespace/name.
        repo_mock = Mock()
        repo_mock.owner = Mock()
        repo_mock.owner.login = namespace
        repo_mock.full_name = "%s/%s" % (namespace, name)
        repo_mock.name = name
        repo_mock.description = "some %s repo" % (name)
        # "anotherrepo" is the one fixture repo with no recorded push time.
        if name != "anotherrepo":
            repo_mock.pushed_at = datetime.utcfromtimestamp(0)
        else:
            repo_mock.pushed_at = None
        # NOTE(review): bitbucket.org host inside the GitHub mock — presumably
        # copied from the Bitbucket fixture; tests match the string verbatim.
        repo_mock.html_url = "https://bitbucket.org/%s/%s" % (namespace, name)
        # Only "somerepo" is private; only "knownuser" has admin rights.
        repo_mock.private = name == "somerepo"
        repo_mock.permissions = Mock()
        repo_mock.permissions.admin = namespace == "knownuser"
        return repo_mock
    def get_user_repos_mock(type="all", sort="created"):
        # Repos visible to the known user; type/sort are accepted but unused.
        return [get_repo_mock("knownuser", "somerepo")]
    def get_org_repos_mock(type="all"):
        # The fixture org owns two repos; "type" is accepted but unused.
        return [
            get_repo_mock("someorg", "somerepo"),
            get_repo_mock("someorg", "anotherrepo"),
        ]
    def get_orgs_mock():
        # The known user belongs to a single fixture organization.
        return [get_org_mock("someorg")]
    def get_user_mock(username="knownuser"):
        # Returns a user Mock for the one known username; any other lookup
        # raises GithubException (the miss case).
        if username == "knownuser":
            user_mock = Mock()
            user_mock.name = username
            user_mock.plan = Mock()
            user_mock.plan.private_repos = 1
            user_mock.login = username
            # NOTE(review): bitbucket.org URL inside the GitHub mock —
            # presumably copied from the Bitbucket fixture; tests compare the
            # string verbatim, so leave as-is.
            user_mock.html_url = "https://bitbucket.org/%s" % (username)
            user_mock.avatar_url = "avatarurl"
            user_mock.get_repos = Mock(side_effect=get_user_repos_mock)
            user_mock.get_orgs = Mock(side_effect=get_orgs_mock)
            return user_mock
        raise GithubException(None, None)
    def get_org_mock(namespace):
        # Returns an organization Mock for the one known org; any other
        # namespace raises GithubException (the miss case).
        if namespace == "someorg":
            org_mock = Mock()
            org_mock.get_repos = Mock(side_effect=get_org_repos_mock)
            org_mock.login = namespace
            # NOTE(review): bitbucket.org URL inside the GitHub mock —
            # presumably copied from the Bitbucket fixture; left verbatim.
            org_mock.html_url = "https://bitbucket.org/%s" % (namespace)
            org_mock.avatar_url = "avatarurl"
            org_mock.name = namespace
            org_mock.plan = Mock()
            org_mock.plan.private_repos = 2
            return org_mock
        raise GithubException(None, None)
    def get_tags_mock():
        # Two tag Mocks, both pointing at the known commit "aaaaaaa".
        sometag = Mock()
        sometag.name = "sometag"
        sometag.commit = get_commit_mock("aaaaaaa")
        someothertag = Mock()
        someothertag.name = "someothertag"
        someothertag.commit = get_commit_mock("aaaaaaa")
        return [sometag, someothertag]
    def get_branches_mock():
        # Two branch Mocks ("master", "otherbranch"), both at commit "aaaaaaa".
        master = Mock()
        master.name = "master"
        master.commit = get_commit_mock("aaaaaaa")
        otherbranch = Mock()
        otherbranch.name = "otherbranch"
        otherbranch.commit = get_commit_mock("aaaaaaa")
        return [master, otherbranch]
    def get_contents_mock(filepath):
        # Raw file contents keyed by path; unknown paths raise GithubException.
        if filepath == "Dockerfile":
            m = Mock()
            m.content = "hello world"
            return m
        if filepath == "somesubdir/Dockerfile":
            m = Mock()
            m.content = "hi universe"
            return m
        raise GithubException(None, None)
    def get_git_tree_mock(commit_sha, recursive=False):
        # Fake git tree: for the known commit sha it contains two "blob"
        # entries (Dockerfile, somesubdir/Dockerfile) and one non-blob entry;
        # any other sha yields an empty tree. "recursive" is accepted but
        # unused.
        first_file = Mock()
        first_file.type = "blob"
        first_file.path = "Dockerfile"
        second_file = Mock()
        second_file.type = "other"
        second_file.path = "/some/Dockerfile"
        third_file = Mock()
        third_file.type = "blob"
        third_file.path = "somesubdir/Dockerfile"
        t = Mock()
        if commit_sha == "aaaaaaa":
            t.tree = [first_file, second_file, third_file]
        else:
            t.tree = []
        return t
def get_repo_mock(namespace, name):
repo_mock = Mock() repo_mock = Mock()
repo_mock.owner = Mock() repo_mock.default_branch = "master"
repo_mock.owner.login = namespace repo_mock.ssh_url = "ssh_url"
repo_mock.full_name = '%s/%s' % (namespace, name) repo_mock.get_branch = Mock(side_effect=get_branch_mock)
repo_mock.name = name repo_mock.get_tags = Mock(side_effect=get_tags_mock)
repo_mock.description = 'some %s repo' % (name) repo_mock.get_branches = Mock(side_effect=get_branches_mock)
repo_mock.get_commit = Mock(side_effect=get_commit_mock)
repo_mock.get_contents = Mock(side_effect=get_contents_mock)
repo_mock.get_git_tree = Mock(side_effect=get_git_tree_mock)
if name != 'anotherrepo': gh_mock = Mock()
repo_mock.pushed_at = datetime.utcfromtimestamp(0) gh_mock.get_repo = Mock(return_value=repo_mock)
else: gh_mock.get_user = Mock(side_effect=get_user_mock)
repo_mock.pushed_at = None gh_mock.get_organization = Mock(side_effect=get_org_mock)
return gh_mock
repo_mock.html_url = 'https://bitbucket.org/%s/%s' % (namespace, name)
repo_mock.private = name == 'somerepo'
repo_mock.permissions = Mock()
repo_mock.permissions.admin = namespace == 'knownuser'
return repo_mock
def get_user_repos_mock(type='all', sort='created'):
return [get_repo_mock('knownuser', 'somerepo')]
def get_org_repos_mock(type='all'):
return [get_repo_mock('someorg', 'somerepo'), get_repo_mock('someorg', 'anotherrepo')]
def get_orgs_mock():
return [get_org_mock('someorg')]
def get_user_mock(username='knownuser'):
if username == 'knownuser':
user_mock = Mock()
user_mock.name = username
user_mock.plan = Mock()
user_mock.plan.private_repos = 1
user_mock.login = username
user_mock.html_url = 'https://bitbucket.org/%s' % (username)
user_mock.avatar_url = 'avatarurl'
user_mock.get_repos = Mock(side_effect=get_user_repos_mock)
user_mock.get_orgs = Mock(side_effect=get_orgs_mock)
return user_mock
raise GithubException(None, None)
def get_org_mock(namespace):
if namespace == 'someorg':
org_mock = Mock()
org_mock.get_repos = Mock(side_effect=get_org_repos_mock)
org_mock.login = namespace
org_mock.html_url = 'https://bitbucket.org/%s' % (namespace)
org_mock.avatar_url = 'avatarurl'
org_mock.name = namespace
org_mock.plan = Mock()
org_mock.plan.private_repos = 2
return org_mock
raise GithubException(None, None)
def get_tags_mock():
sometag = Mock()
sometag.name = 'sometag'
sometag.commit = get_commit_mock('aaaaaaa')
someothertag = Mock()
someothertag.name = 'someothertag'
someothertag.commit = get_commit_mock('aaaaaaa')
return [sometag, someothertag]
def get_branches_mock():
master = Mock()
master.name = 'master'
master.commit = get_commit_mock('aaaaaaa')
otherbranch = Mock()
otherbranch.name = 'otherbranch'
otherbranch.commit = get_commit_mock('aaaaaaa')
return [master, otherbranch]
def get_contents_mock(filepath):
if filepath == 'Dockerfile':
m = Mock()
m.content = 'hello world'
return m
if filepath == 'somesubdir/Dockerfile':
m = Mock()
m.content = 'hi universe'
return m
raise GithubException(None, None)
def get_git_tree_mock(commit_sha, recursive=False):
first_file = Mock()
first_file.type = 'blob'
first_file.path = 'Dockerfile'
second_file = Mock()
second_file.type = 'other'
second_file.path = '/some/Dockerfile'
third_file = Mock()
third_file.type = 'blob'
third_file.path = 'somesubdir/Dockerfile'
t = Mock()
if commit_sha == 'aaaaaaa':
t.tree = [
first_file, second_file, third_file,
]
else:
t.tree = []
return t
repo_mock = Mock()
repo_mock.default_branch = 'master'
repo_mock.ssh_url = 'ssh_url'
repo_mock.get_branch = Mock(side_effect=get_branch_mock)
repo_mock.get_tags = Mock(side_effect=get_tags_mock)
repo_mock.get_branches = Mock(side_effect=get_branches_mock)
repo_mock.get_commit = Mock(side_effect=get_commit_mock)
repo_mock.get_contents = Mock(side_effect=get_contents_mock)
repo_mock.get_git_tree = Mock(side_effect=get_git_tree_mock)
gh_mock = Mock()
gh_mock.get_repo = Mock(return_value=repo_mock)
gh_mock.get_user = Mock(side_effect=get_user_mock)
gh_mock.get_organization = Mock(side_effect=get_org_mock)
return gh_mock

File diff suppressed because it is too large Load diff

View file

@ -3,53 +3,74 @@ import pytest
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
@pytest.mark.parametrize('input,output', [ @pytest.mark.parametrize(
("Dockerfile", True), "input,output",
("server.Dockerfile", True), [
(u"Dockerfile", True), ("Dockerfile", True),
(u"server.Dockerfile", True), ("server.Dockerfile", True),
("bad file name", False), (u"Dockerfile", True),
(u"bad file name", False), (u"server.Dockerfile", True),
]) ("bad file name", False),
(u"bad file name", False),
],
)
def test_path_is_dockerfile(input, output): def test_path_is_dockerfile(input, output):
assert BuildTriggerHandler.filename_is_dockerfile(input) == output assert BuildTriggerHandler.filename_is_dockerfile(input) == output
@pytest.mark.parametrize('input,output', [ @pytest.mark.parametrize(
("", {}), "input,output",
("/a", {"/a": ["/"]}), [
("a", {"/a": ["/"]}), ("", {}),
("/b/a", {"/b/a": ["/b", "/"]}), ("/a", {"/a": ["/"]}),
("b/a", {"/b/a": ["/b", "/"]}), ("a", {"/a": ["/"]}),
("/c/b/a", {"/c/b/a": ["/c/b", "/c", "/"]}), ("/b/a", {"/b/a": ["/b", "/"]}),
("/a//b//c", {"/a/b/c": ["/", "/a", "/a/b"]}), ("b/a", {"/b/a": ["/b", "/"]}),
("/a", {"/a": ["/"]}), ("/c/b/a", {"/c/b/a": ["/c/b", "/c", "/"]}),
]) ("/a//b//c", {"/a/b/c": ["/", "/a", "/a/b"]}),
("/a", {"/a": ["/"]}),
],
)
def test_subdir_path_map_no_previous(input, output): def test_subdir_path_map_no_previous(input, output):
actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(input) actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(input)
for key in actual_mapping: for key in actual_mapping:
value = actual_mapping[key] value = actual_mapping[key]
actual_mapping[key] = value.sort() actual_mapping[key] = value.sort()
for key in output: for key in output:
value = output[key] value = output[key]
output[key] = value.sort() output[key] = value.sort()
assert actual_mapping == output assert actual_mapping == output
@pytest.mark.parametrize('new_path,original_dictionary,output', [ @pytest.mark.parametrize(
("/a", {}, {"/a": ["/"]}), "new_path,original_dictionary,output",
("b", {"/a": ["some_path", "another_path"]}, {"/a": ["some_path", "another_path"], "/b": ["/"]}), [
("/a/b/c/d", {"/e": ["some_path", "another_path"]}, ("/a", {}, {"/a": ["/"]}),
{"/e": ["some_path", "another_path"], "/a/b/c/d": ["/", "/a", "/a/b", "/a/b/c"]}), (
]) "b",
{"/a": ["some_path", "another_path"]},
{"/a": ["some_path", "another_path"], "/b": ["/"]},
),
(
"/a/b/c/d",
{"/e": ["some_path", "another_path"]},
{
"/e": ["some_path", "another_path"],
"/a/b/c/d": ["/", "/a", "/a/b", "/a/b/c"],
},
),
],
)
def test_subdir_path_map(new_path, original_dictionary, output): def test_subdir_path_map(new_path, original_dictionary, output):
actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(new_path, original_dictionary) actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(
for key in actual_mapping: new_path, original_dictionary
value = actual_mapping[key] )
actual_mapping[key] = value.sort() for key in actual_mapping:
for key in output: value = actual_mapping[key]
value = output[key] actual_mapping[key] = value.sort()
output[key] = value.sort() for key in output:
value = output[key]
output[key] = value.sort()
assert actual_mapping == output assert actual_mapping == output

View file

@ -2,35 +2,44 @@ import json
import pytest import pytest
from buildtrigger.test.bitbucketmock import get_bitbucket_trigger from buildtrigger.test.bitbucketmock import get_bitbucket_trigger
from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException, from buildtrigger.triggerutil import (
InvalidPayloadException) SkipRequestException,
ValidationRequestException,
InvalidPayloadException,
)
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from util.morecollections import AttrDict from util.morecollections import AttrDict
@pytest.fixture @pytest.fixture
def bitbucket_trigger(): def bitbucket_trigger():
return get_bitbucket_trigger() return get_bitbucket_trigger()
def test_list_build_subdirs(bitbucket_trigger): def test_list_build_subdirs(bitbucket_trigger):
assert bitbucket_trigger.list_build_subdirs() == ["/Dockerfile"] assert bitbucket_trigger.list_build_subdirs() == ["/Dockerfile"]
@pytest.mark.parametrize('dockerfile_path, contents', [ @pytest.mark.parametrize(
('/Dockerfile', 'hello world'), "dockerfile_path, contents",
('somesubdir/Dockerfile', 'hi universe'), [
('unknownpath', None), ("/Dockerfile", "hello world"),
]) ("somesubdir/Dockerfile", "hi universe"),
("unknownpath", None),
],
)
def test_load_dockerfile_contents(dockerfile_path, contents): def test_load_dockerfile_contents(dockerfile_path, contents):
trigger = get_bitbucket_trigger(dockerfile_path) trigger = get_bitbucket_trigger(dockerfile_path)
assert trigger.load_dockerfile_contents() == contents assert trigger.load_dockerfile_contents() == contents
@pytest.mark.parametrize('payload, expected_error, expected_message', [ @pytest.mark.parametrize(
('{}', InvalidPayloadException, "'push' is a required property"), "payload, expected_error, expected_message",
[
# Valid payload: ("{}", InvalidPayloadException, "'push' is a required property"),
('''{ # Valid payload:
(
"""{
"push": { "push": {
"changes": [{ "changes": [{
"new": { "new": {
@ -51,10 +60,13 @@ def test_load_dockerfile_contents(dockerfile_path, contents):
"repository": { "repository": {
"full_name": "foo/bar" "full_name": "foo/bar"
} }
}''', None, None), }""",
None,
# Skip message: None,
('''{ ),
# Skip message:
(
"""{
"push": { "push": {
"changes": [{ "changes": [{
"new": { "new": {
@ -75,17 +87,25 @@ def test_load_dockerfile_contents(dockerfile_path, contents):
"repository": { "repository": {
"full_name": "foo/bar" "full_name": "foo/bar"
} }
}''', SkipRequestException, ''), }""",
]) SkipRequestException,
def test_handle_trigger_request(bitbucket_trigger, payload, expected_error, expected_message): "",
def get_payload(): ),
return json.loads(payload) ],
)
def test_handle_trigger_request(
bitbucket_trigger, payload, expected_error, expected_message
):
def get_payload():
return json.loads(payload)
request = AttrDict(dict(get_json=get_payload)) request = AttrDict(dict(get_json=get_payload))
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
bitbucket_trigger.handle_trigger_request(request) bitbucket_trigger.handle_trigger_request(request)
assert str(ipe.value) == expected_message assert str(ipe.value) == expected_message
else: else:
assert isinstance(bitbucket_trigger.handle_trigger_request(request), PreparedBuild) assert isinstance(
bitbucket_trigger.handle_trigger_request(request), PreparedBuild
)

View file

@ -1,20 +1,32 @@
import pytest import pytest
from buildtrigger.customhandler import CustomBuildTrigger from buildtrigger.customhandler import CustomBuildTrigger
from buildtrigger.triggerutil import (InvalidPayloadException, SkipRequestException, from buildtrigger.triggerutil import (
TriggerStartException) InvalidPayloadException,
SkipRequestException,
TriggerStartException,
)
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from util.morecollections import AttrDict from util.morecollections import AttrDict
@pytest.mark.parametrize('payload, expected_error, expected_message', [
('', InvalidPayloadException, 'Missing expected payload'),
('{}', InvalidPayloadException, "'commit' is a required property"),
('{"commit": "foo", "ref": "refs/heads/something", "default_branch": "baz"}', @pytest.mark.parametrize(
InvalidPayloadException, "u'foo' does not match '^([A-Fa-f0-9]{7,})$'"), "payload, expected_error, expected_message",
[
('{"commit": "11d6fbc", "ref": "refs/heads/something", "default_branch": "baz"}', None, None), ("", InvalidPayloadException, "Missing expected payload"),
('''{ ("{}", InvalidPayloadException, "'commit' is a required property"),
(
'{"commit": "foo", "ref": "refs/heads/something", "default_branch": "baz"}',
InvalidPayloadException,
"u'foo' does not match '^([A-Fa-f0-9]{7,})$'",
),
(
'{"commit": "11d6fbc", "ref": "refs/heads/something", "default_branch": "baz"}',
None,
None,
),
(
"""{
"commit": "11d6fbc", "commit": "11d6fbc",
"ref": "refs/heads/something", "ref": "refs/heads/something",
"default_branch": "baz", "default_branch": "baz",
@ -23,29 +35,41 @@ from util.morecollections import AttrDict
"url": "http://foo.bar", "url": "http://foo.bar",
"date": "NOW" "date": "NOW"
} }
}''', SkipRequestException, ''), }""",
]) SkipRequestException,
"",
),
],
)
def test_handle_trigger_request(payload, expected_error, expected_message): def test_handle_trigger_request(payload, expected_error, expected_message):
trigger = CustomBuildTrigger(None, {'build_source': 'foo'}) trigger = CustomBuildTrigger(None, {"build_source": "foo"})
request = AttrDict(dict(data=payload)) request = AttrDict(dict(data=payload))
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
trigger.handle_trigger_request(request) trigger.handle_trigger_request(request)
assert str(ipe.value) == expected_message assert str(ipe.value) == expected_message
else: else:
assert isinstance(trigger.handle_trigger_request(request), PreparedBuild) assert isinstance(trigger.handle_trigger_request(request), PreparedBuild)
@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [
({}, TriggerStartException, 'missing required parameter'), @pytest.mark.parametrize(
({'commit_sha': 'foo'}, TriggerStartException, "'foo' does not match '^([A-Fa-f0-9]{7,})$'"), "run_parameters, expected_error, expected_message",
({'commit_sha': '11d6fbc'}, None, None), [
]) ({}, TriggerStartException, "missing required parameter"),
(
{"commit_sha": "foo"},
TriggerStartException,
"'foo' does not match '^([A-Fa-f0-9]{7,})$'",
),
({"commit_sha": "11d6fbc"}, None, None),
],
)
def test_manual_start(run_parameters, expected_error, expected_message): def test_manual_start(run_parameters, expected_error, expected_message):
trigger = CustomBuildTrigger(None, {'build_source': 'foo'}) trigger = CustomBuildTrigger(None, {"build_source": "foo"})
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
trigger.manual_start(run_parameters) trigger.manual_start(run_parameters)
assert str(ipe.value) == expected_message assert str(ipe.value) == expected_message
else: else:
assert isinstance(trigger.manual_start(run_parameters), PreparedBuild) assert isinstance(trigger.manual_start(run_parameters), PreparedBuild)

View file

@ -9,113 +9,145 @@ from endpoints.building import PreparedBuild
# in this fixture. Each trigger's mock is expected to return the same data for all of these calls. # in this fixture. Each trigger's mock is expected to return the same data for all of these calls.
@pytest.fixture(params=[get_github_trigger(), get_bitbucket_trigger()]) @pytest.fixture(params=[get_github_trigger(), get_bitbucket_trigger()])
def githost_trigger(request): def githost_trigger(request):
return request.param return request.param
@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [
# No branch or tag specified: use the commit of the default branch.
({}, None, None),
# Invalid branch.
({'refs': {'kind': 'branch', 'name': 'invalid'}}, TriggerStartException,
'Could not find branch in repository'),
# Invalid tag.
({'refs': {'kind': 'tag', 'name': 'invalid'}}, TriggerStartException,
'Could not find tag in repository'),
# Valid branch.
({'refs': {'kind': 'branch', 'name': 'master'}}, None, None),
# Valid tag.
({'refs': {'kind': 'tag', 'name': 'sometag'}}, None, None),
])
def test_manual_start(run_parameters, expected_error, expected_message, githost_trigger):
if expected_error is not None:
with pytest.raises(expected_error) as ipe:
githost_trigger.manual_start(run_parameters)
assert str(ipe.value) == expected_message
else:
assert isinstance(githost_trigger.manual_start(run_parameters), PreparedBuild)
@pytest.mark.parametrize('name, expected', [ @pytest.mark.parametrize(
('refs', [ "run_parameters, expected_error, expected_message",
{'kind': 'branch', 'name': 'master'}, [
{'kind': 'branch', 'name': 'otherbranch'}, # No branch or tag specified: use the commit of the default branch.
{'kind': 'tag', 'name': 'sometag'}, ({}, None, None),
{'kind': 'tag', 'name': 'someothertag'}, # Invalid branch.
]), (
('tag_name', set(['sometag', 'someothertag'])), {"refs": {"kind": "branch", "name": "invalid"}},
('branch_name', set(['master', 'otherbranch'])), TriggerStartException,
('invalid', None) "Could not find branch in repository",
]) ),
# Invalid tag.
(
{"refs": {"kind": "tag", "name": "invalid"}},
TriggerStartException,
"Could not find tag in repository",
),
# Valid branch.
({"refs": {"kind": "branch", "name": "master"}}, None, None),
# Valid tag.
({"refs": {"kind": "tag", "name": "sometag"}}, None, None),
],
)
def test_manual_start(
run_parameters, expected_error, expected_message, githost_trigger
):
if expected_error is not None:
with pytest.raises(expected_error) as ipe:
githost_trigger.manual_start(run_parameters)
assert str(ipe.value) == expected_message
else:
assert isinstance(githost_trigger.manual_start(run_parameters), PreparedBuild)
@pytest.mark.parametrize(
"name, expected",
[
(
"refs",
[
{"kind": "branch", "name": "master"},
{"kind": "branch", "name": "otherbranch"},
{"kind": "tag", "name": "sometag"},
{"kind": "tag", "name": "someothertag"},
],
),
("tag_name", set(["sometag", "someothertag"])),
("branch_name", set(["master", "otherbranch"])),
("invalid", None),
],
)
def test_list_field_values(name, expected, githost_trigger): def test_list_field_values(name, expected, githost_trigger):
if expected is None: if expected is None:
assert githost_trigger.list_field_values(name) is None assert githost_trigger.list_field_values(name) is None
elif isinstance(expected, set): elif isinstance(expected, set):
assert set(githost_trigger.list_field_values(name)) == set(expected) assert set(githost_trigger.list_field_values(name)) == set(expected)
else: else:
assert githost_trigger.list_field_values(name) == expected assert githost_trigger.list_field_values(name) == expected
def test_list_build_source_namespaces(): def test_list_build_source_namespaces():
namespaces_expected = [ namespaces_expected = [
{ {
'personal': True, "personal": True,
'score': 1, "score": 1,
'avatar_url': 'avatarurl', "avatar_url": "avatarurl",
'id': 'knownuser', "id": "knownuser",
'title': 'knownuser', "title": "knownuser",
'url': 'https://bitbucket.org/knownuser', "url": "https://bitbucket.org/knownuser",
}, },
{ {
'score': 2, "score": 2,
'title': 'someorg', "title": "someorg",
'personal': False, "personal": False,
'url': 'https://bitbucket.org/someorg', "url": "https://bitbucket.org/someorg",
'avatar_url': 'avatarurl', "avatar_url": "avatarurl",
'id': 'someorg' "id": "someorg",
} },
] ]
found = get_bitbucket_trigger().list_build_source_namespaces() found = get_bitbucket_trigger().list_build_source_namespaces()
found.sort() found.sort()
namespaces_expected.sort() namespaces_expected.sort()
assert found == namespaces_expected assert found == namespaces_expected
@pytest.mark.parametrize('namespace, expected', [ @pytest.mark.parametrize(
('', []), "namespace, expected",
('unknown', []), [
("", []),
('knownuser', [ ("unknown", []),
{ (
'last_updated': 0, 'name': 'somerepo', "knownuser",
'url': 'https://bitbucket.org/knownuser/somerepo', 'private': True, [
'full_name': 'knownuser/somerepo', 'has_admin_permissions': True, {
'description': 'some somerepo repo' "last_updated": 0,
}]), "name": "somerepo",
"url": "https://bitbucket.org/knownuser/somerepo",
('someorg', [ "private": True,
{ "full_name": "knownuser/somerepo",
'last_updated': 0, 'name': 'somerepo', "has_admin_permissions": True,
'url': 'https://bitbucket.org/someorg/somerepo', 'private': True, "description": "some somerepo repo",
'full_name': 'someorg/somerepo', 'has_admin_permissions': False, }
'description': 'some somerepo repo' ],
}, ),
{ (
'last_updated': 0, 'name': 'anotherrepo', "someorg",
'url': 'https://bitbucket.org/someorg/anotherrepo', 'private': False, [
'full_name': 'someorg/anotherrepo', 'has_admin_permissions': False, {
'description': 'some anotherrepo repo' "last_updated": 0,
}]), "name": "somerepo",
]) "url": "https://bitbucket.org/someorg/somerepo",
"private": True,
"full_name": "someorg/somerepo",
"has_admin_permissions": False,
"description": "some somerepo repo",
},
{
"last_updated": 0,
"name": "anotherrepo",
"url": "https://bitbucket.org/someorg/anotherrepo",
"private": False,
"full_name": "someorg/anotherrepo",
"has_admin_permissions": False,
"description": "some anotherrepo repo",
},
],
),
],
)
def test_list_build_sources_for_namespace(namespace, expected, githost_trigger): def test_list_build_sources_for_namespace(namespace, expected, githost_trigger):
assert githost_trigger.list_build_sources_for_namespace(namespace) == expected assert githost_trigger.list_build_sources_for_namespace(namespace) == expected
def test_activate_and_deactivate(githost_trigger): def test_activate_and_deactivate(githost_trigger):
_, private_key = githost_trigger.activate('http://some/url') _, private_key = githost_trigger.activate("http://some/url")
assert 'private_key' in private_key assert "private_key" in private_key
githost_trigger.deactivate() githost_trigger.deactivate()

View file

@ -2,24 +2,33 @@ import json
import pytest import pytest
from buildtrigger.test.githubmock import get_github_trigger from buildtrigger.test.githubmock import get_github_trigger
from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException, from buildtrigger.triggerutil import (
InvalidPayloadException) SkipRequestException,
ValidationRequestException,
InvalidPayloadException,
)
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from util.morecollections import AttrDict from util.morecollections import AttrDict
@pytest.fixture @pytest.fixture
def github_trigger(): def github_trigger():
return get_github_trigger() return get_github_trigger()
@pytest.mark.parametrize('payload, expected_error, expected_message', [ @pytest.mark.parametrize(
('{"zen": true}', SkipRequestException, ""), "payload, expected_error, expected_message",
[
('{}', InvalidPayloadException, "Missing 'repository' on request"), ('{"zen": true}', SkipRequestException, ""),
('{"repository": "foo"}', InvalidPayloadException, "Missing 'owner' on repository"), ("{}", InvalidPayloadException, "Missing 'repository' on request"),
(
# Valid payload: '{"repository": "foo"}',
('''{ InvalidPayloadException,
"Missing 'owner' on repository",
),
# Valid payload:
(
"""{
"repository": { "repository": {
"owner": { "owner": {
"name": "someguy" "name": "someguy"
@ -34,10 +43,13 @@ def github_trigger():
"message": "some message", "message": "some message",
"timestamp": "NOW" "timestamp": "NOW"
} }
}''', None, None), }""",
None,
# Skip message: None,
('''{ ),
# Skip message:
(
"""{
"repository": { "repository": {
"owner": { "owner": {
"name": "someguy" "name": "someguy"
@ -52,66 +64,84 @@ def github_trigger():
"message": "[skip build]", "message": "[skip build]",
"timestamp": "NOW" "timestamp": "NOW"
} }
}''', SkipRequestException, ''), }""",
]) SkipRequestException,
def test_handle_trigger_request(github_trigger, payload, expected_error, expected_message): "",
def get_payload(): ),
return json.loads(payload) ],
)
def test_handle_trigger_request(
github_trigger, payload, expected_error, expected_message
):
def get_payload():
return json.loads(payload)
request = AttrDict(dict(get_json=get_payload)) request = AttrDict(dict(get_json=get_payload))
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
github_trigger.handle_trigger_request(request) github_trigger.handle_trigger_request(request)
assert str(ipe.value) == expected_message assert str(ipe.value) == expected_message
else: else:
assert isinstance(github_trigger.handle_trigger_request(request), PreparedBuild) assert isinstance(github_trigger.handle_trigger_request(request), PreparedBuild)
@pytest.mark.parametrize('dockerfile_path, contents', [ @pytest.mark.parametrize(
('/Dockerfile', 'hello world'), "dockerfile_path, contents",
('somesubdir/Dockerfile', 'hi universe'), [
('unknownpath', None), ("/Dockerfile", "hello world"),
]) ("somesubdir/Dockerfile", "hi universe"),
("unknownpath", None),
],
)
def test_load_dockerfile_contents(dockerfile_path, contents): def test_load_dockerfile_contents(dockerfile_path, contents):
trigger = get_github_trigger(dockerfile_path) trigger = get_github_trigger(dockerfile_path)
assert trigger.load_dockerfile_contents() == contents assert trigger.load_dockerfile_contents() == contents
@pytest.mark.parametrize('username, expected_response', [ @pytest.mark.parametrize(
('unknownuser', None), "username, expected_response",
('knownuser', {'html_url': 'https://bitbucket.org/knownuser', 'avatar_url': 'avatarurl'}), [
]) ("unknownuser", None),
(
"knownuser",
{"html_url": "https://bitbucket.org/knownuser", "avatar_url": "avatarurl"},
),
],
)
def test_lookup_user(username, expected_response, github_trigger): def test_lookup_user(username, expected_response, github_trigger):
assert github_trigger.lookup_user(username) == expected_response assert github_trigger.lookup_user(username) == expected_response
def test_list_build_subdirs(github_trigger): def test_list_build_subdirs(github_trigger):
assert github_trigger.list_build_subdirs() == ['Dockerfile', 'somesubdir/Dockerfile'] assert github_trigger.list_build_subdirs() == [
"Dockerfile",
"somesubdir/Dockerfile",
]
def test_list_build_source_namespaces(github_trigger): def test_list_build_source_namespaces(github_trigger):
namespaces_expected = [ namespaces_expected = [
{ {
'personal': True, "personal": True,
'score': 1, "score": 1,
'avatar_url': 'avatarurl', "avatar_url": "avatarurl",
'id': 'knownuser', "id": "knownuser",
'title': 'knownuser', "title": "knownuser",
'url': 'https://bitbucket.org/knownuser', "url": "https://bitbucket.org/knownuser",
}, },
{ {
'score': 0, "score": 0,
'title': 'someorg', "title": "someorg",
'personal': False, "personal": False,
'url': '', "url": "",
'avatar_url': 'avatarurl', "avatar_url": "avatarurl",
'id': 'someorg' "id": "someorg",
} },
] ]
found = github_trigger.list_build_source_namespaces() found = github_trigger.list_build_source_namespaces()
found.sort() found.sort()
namespaces_expected.sort() namespaces_expected.sort()
assert found == namespaces_expected assert found == namespaces_expected

View file

@ -4,91 +4,111 @@ import pytest
from mock import Mock from mock import Mock
from buildtrigger.test.gitlabmock import get_gitlab_trigger from buildtrigger.test.gitlabmock import get_gitlab_trigger
from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException, from buildtrigger.triggerutil import (
InvalidPayloadException, TriggerStartException) SkipRequestException,
ValidationRequestException,
InvalidPayloadException,
TriggerStartException,
)
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from util.morecollections import AttrDict from util.morecollections import AttrDict
@pytest.fixture() @pytest.fixture()
def gitlab_trigger(): def gitlab_trigger():
with get_gitlab_trigger() as t: with get_gitlab_trigger() as t:
yield t yield t
def test_list_build_subdirs(gitlab_trigger): def test_list_build_subdirs(gitlab_trigger):
assert gitlab_trigger.list_build_subdirs() == ['Dockerfile'] assert gitlab_trigger.list_build_subdirs() == ["Dockerfile"]
@pytest.mark.parametrize('dockerfile_path, contents', [ @pytest.mark.parametrize(
('/Dockerfile', 'hello world'), "dockerfile_path, contents",
('somesubdir/Dockerfile', 'hi universe'), [
('unknownpath', None), ("/Dockerfile", "hello world"),
]) ("somesubdir/Dockerfile", "hi universe"),
("unknownpath", None),
],
)
def test_load_dockerfile_contents(dockerfile_path, contents): def test_load_dockerfile_contents(dockerfile_path, contents):
with get_gitlab_trigger(dockerfile_path=dockerfile_path) as trigger: with get_gitlab_trigger(dockerfile_path=dockerfile_path) as trigger:
assert trigger.load_dockerfile_contents() == contents assert trigger.load_dockerfile_contents() == contents
@pytest.mark.parametrize('email, expected_response', [ @pytest.mark.parametrize(
('unknown@email.com', None), "email, expected_response",
('knownuser', {'username': 'knownuser', 'html_url': 'https://bitbucket.org/knownuser', [
'avatar_url': 'avatarurl'}), ("unknown@email.com", None),
]) (
"knownuser",
{
"username": "knownuser",
"html_url": "https://bitbucket.org/knownuser",
"avatar_url": "avatarurl",
},
),
],
)
def test_lookup_user(email, expected_response, gitlab_trigger): def test_lookup_user(email, expected_response, gitlab_trigger):
assert gitlab_trigger.lookup_user(email) == expected_response assert gitlab_trigger.lookup_user(email) == expected_response
def test_null_permissions(): def test_null_permissions():
with get_gitlab_trigger(add_permissions=False) as trigger: with get_gitlab_trigger(add_permissions=False) as trigger:
sources = trigger.list_build_sources_for_namespace('someorg') sources = trigger.list_build_sources_for_namespace("someorg")
source = sources[0] source = sources[0]
assert source['has_admin_permissions'] assert source["has_admin_permissions"]
def test_list_build_sources(): def test_list_build_sources():
with get_gitlab_trigger() as trigger: with get_gitlab_trigger() as trigger:
sources = trigger.list_build_sources_for_namespace('someorg') sources = trigger.list_build_sources_for_namespace("someorg")
assert sources == [ assert sources == [
{ {
'last_updated': 1380548762, "last_updated": 1380548762,
'name': u'someproject', "name": u"someproject",
'url': u'http://example.com/someorg/someproject', "url": u"http://example.com/someorg/someproject",
'private': True, "private": True,
'full_name': u'someorg/someproject', "full_name": u"someorg/someproject",
'has_admin_permissions': False, "has_admin_permissions": False,
'description': '' "description": "",
}, },
{ {
'last_updated': 1380548762, "last_updated": 1380548762,
'name': u'anotherproject', "name": u"anotherproject",
'url': u'http://example.com/someorg/anotherproject', "url": u"http://example.com/someorg/anotherproject",
'private': False, "private": False,
'full_name': u'someorg/anotherproject', "full_name": u"someorg/anotherproject",
'has_admin_permissions': True, "has_admin_permissions": True,
'description': '', "description": "",
}] },
]
def test_null_avatar(): def test_null_avatar():
with get_gitlab_trigger(missing_avatar_url=True) as trigger: with get_gitlab_trigger(missing_avatar_url=True) as trigger:
namespace_data = trigger.list_build_source_namespaces() namespace_data = trigger.list_build_source_namespaces()
expected = { expected = {
'avatar_url': None, "avatar_url": None,
'personal': False, "personal": False,
'title': u'someorg', "title": u"someorg",
'url': u'http://gitlab.com/groups/someorg', "url": u"http://gitlab.com/groups/someorg",
'score': 1, "score": 1,
'id': '2', "id": "2",
} }
assert namespace_data == [expected] assert namespace_data == [expected]
@pytest.mark.parametrize('payload, expected_error, expected_message', [ @pytest.mark.parametrize(
('{}', InvalidPayloadException, ''), "payload, expected_error, expected_message",
[
# Valid payload: ("{}", InvalidPayloadException, ""),
('''{ # Valid payload:
(
"""{
"object_kind": "push", "object_kind": "push",
"ref": "refs/heads/master", "ref": "refs/heads/master",
"checkout_sha": "aaaaaaa", "checkout_sha": "aaaaaaa",
@ -103,10 +123,13 @@ def test_null_avatar():
"timestamp": "now" "timestamp": "now"
} }
] ]
}''', None, None), }""",
None,
# Skip message: None,
('''{ ),
# Skip message:
(
"""{
"object_kind": "push", "object_kind": "push",
"ref": "refs/heads/master", "ref": "refs/heads/master",
"checkout_sha": "aaaaaaa", "checkout_sha": "aaaaaaa",
@ -121,111 +144,136 @@ def test_null_avatar():
"timestamp": "now" "timestamp": "now"
} }
] ]
}''', SkipRequestException, ''), }""",
]) SkipRequestException,
def test_handle_trigger_request(gitlab_trigger, payload, expected_error, expected_message): "",
def get_payload(): ),
return json.loads(payload) ],
)
def test_handle_trigger_request(
gitlab_trigger, payload, expected_error, expected_message
):
def get_payload():
return json.loads(payload)
request = AttrDict(dict(get_json=get_payload)) request = AttrDict(dict(get_json=get_payload))
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
gitlab_trigger.handle_trigger_request(request) gitlab_trigger.handle_trigger_request(request)
assert str(ipe.value) == expected_message assert str(ipe.value) == expected_message
else: else:
assert isinstance(gitlab_trigger.handle_trigger_request(request), PreparedBuild) assert isinstance(gitlab_trigger.handle_trigger_request(request), PreparedBuild)
@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [ @pytest.mark.parametrize(
# No branch or tag specified: use the commit of the default branch. "run_parameters, expected_error, expected_message",
({}, None, None), [
# No branch or tag specified: use the commit of the default branch.
# Invalid branch. ({}, None, None),
({'refs': {'kind': 'branch', 'name': 'invalid'}}, TriggerStartException, # Invalid branch.
'Could not find branch in repository'), (
{"refs": {"kind": "branch", "name": "invalid"}},
# Invalid tag. TriggerStartException,
({'refs': {'kind': 'tag', 'name': 'invalid'}}, TriggerStartException, "Could not find branch in repository",
'Could not find tag in repository'), ),
# Invalid tag.
# Valid branch. (
({'refs': {'kind': 'branch', 'name': 'master'}}, None, None), {"refs": {"kind": "tag", "name": "invalid"}},
TriggerStartException,
# Valid tag. "Could not find tag in repository",
({'refs': {'kind': 'tag', 'name': 'sometag'}}, None, None), ),
]) # Valid branch.
({"refs": {"kind": "branch", "name": "master"}}, None, None),
# Valid tag.
({"refs": {"kind": "tag", "name": "sometag"}}, None, None),
],
)
def test_manual_start(run_parameters, expected_error, expected_message, gitlab_trigger): def test_manual_start(run_parameters, expected_error, expected_message, gitlab_trigger):
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
gitlab_trigger.manual_start(run_parameters) gitlab_trigger.manual_start(run_parameters)
assert str(ipe.value) == expected_message assert str(ipe.value) == expected_message
else: else:
assert isinstance(gitlab_trigger.manual_start(run_parameters), PreparedBuild) assert isinstance(gitlab_trigger.manual_start(run_parameters), PreparedBuild)
def test_activate_and_deactivate(gitlab_trigger): def test_activate_and_deactivate(gitlab_trigger):
_, private_key = gitlab_trigger.activate('http://some/url') _, private_key = gitlab_trigger.activate("http://some/url")
assert 'private_key' in private_key assert "private_key" in private_key
gitlab_trigger.deactivate() gitlab_trigger.deactivate()
@pytest.mark.parametrize('name, expected', [ @pytest.mark.parametrize(
('refs', [ "name, expected",
{'kind': 'branch', 'name': 'master'}, [
{'kind': 'branch', 'name': 'otherbranch'}, (
{'kind': 'tag', 'name': 'sometag'}, "refs",
{'kind': 'tag', 'name': 'someothertag'}, [
]), {"kind": "branch", "name": "master"},
('tag_name', set(['sometag', 'someothertag'])), {"kind": "branch", "name": "otherbranch"},
('branch_name', set(['master', 'otherbranch'])), {"kind": "tag", "name": "sometag"},
('invalid', None) {"kind": "tag", "name": "someothertag"},
]) ],
),
("tag_name", set(["sometag", "someothertag"])),
("branch_name", set(["master", "otherbranch"])),
("invalid", None),
],
)
def test_list_field_values(name, expected, gitlab_trigger): def test_list_field_values(name, expected, gitlab_trigger):
if expected is None: if expected is None:
assert gitlab_trigger.list_field_values(name) is None assert gitlab_trigger.list_field_values(name) is None
elif isinstance(expected, set): elif isinstance(expected, set):
assert set(gitlab_trigger.list_field_values(name)) == set(expected) assert set(gitlab_trigger.list_field_values(name)) == set(expected)
else: else:
assert gitlab_trigger.list_field_values(name) == expected assert gitlab_trigger.list_field_values(name) == expected
@pytest.mark.parametrize('namespace, expected', [ @pytest.mark.parametrize(
('', []), "namespace, expected",
('unknown', []), [
("", []),
('knownuser', [ ("unknown", []),
{ (
'last_updated': 1380548762, "knownuser",
'name': u'anotherproject', [
'url': u'http://example.com/knownuser/anotherproject', {
'private': False, "last_updated": 1380548762,
'full_name': u'knownuser/anotherproject', "name": u"anotherproject",
'has_admin_permissions': True, "url": u"http://example.com/knownuser/anotherproject",
'description': '' "private": False,
}, "full_name": u"knownuser/anotherproject",
]), "has_admin_permissions": True,
"description": "",
('someorg', [ }
{ ],
'last_updated': 1380548762, ),
'name': u'someproject', (
'url': u'http://example.com/someorg/someproject', "someorg",
'private': True, [
'full_name': u'someorg/someproject', {
'has_admin_permissions': False, "last_updated": 1380548762,
'description': '' "name": u"someproject",
}, "url": u"http://example.com/someorg/someproject",
{ "private": True,
'last_updated': 1380548762, "full_name": u"someorg/someproject",
'name': u'anotherproject', "has_admin_permissions": False,
'url': u'http://example.com/someorg/anotherproject', "description": "",
'private': False, },
'full_name': u'someorg/anotherproject', {
'has_admin_permissions': True, "last_updated": 1380548762,
'description': '', "name": u"anotherproject",
}]), "url": u"http://example.com/someorg/anotherproject",
]) "private": False,
"full_name": u"someorg/anotherproject",
"has_admin_permissions": True,
"description": "",
},
],
),
],
)
def test_list_build_sources_for_namespace(namespace, expected, gitlab_trigger): def test_list_build_sources_for_namespace(namespace, expected, gitlab_trigger):
assert gitlab_trigger.list_build_sources_for_namespace(namespace) == expected assert gitlab_trigger.list_build_sources_for_namespace(namespace) == expected

File diff suppressed because it is too large Load diff

View file

@ -4,22 +4,43 @@ import pytest
from buildtrigger.triggerutil import matches_ref from buildtrigger.triggerutil import matches_ref
@pytest.mark.parametrize('ref, filt, matches', [
('ref/heads/master', '.+', True),
('ref/heads/master', 'heads/.+', True),
('ref/heads/master', 'heads/master', True),
('ref/heads/slash/branch', 'heads/slash/branch', True),
('ref/heads/slash/branch', 'heads/.+', True),
('ref/heads/foobar', 'heads/master', False), @pytest.mark.parametrize(
('ref/heads/master', 'tags/master', False), "ref, filt, matches",
[
('ref/heads/master', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True), ("ref/heads/master", ".+", True),
('ref/heads/alpha', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True), ("ref/heads/master", "heads/.+", True),
('ref/heads/beta', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True), ("ref/heads/master", "heads/master", True),
('ref/heads/gamma', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True), ("ref/heads/slash/branch", "heads/slash/branch", True),
("ref/heads/slash/branch", "heads/.+", True),
('ref/heads/delta', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', False), ("ref/heads/foobar", "heads/master", False),
]) ("ref/heads/master", "tags/master", False),
(
"ref/heads/master",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/alpha",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/beta",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/gamma",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/delta",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
False,
),
],
)
def test_matches_ref(ref, filt, matches): def test_matches_ref(ref, filt, matches):
assert matches_ref(ref, re.compile(filt)) == matches assert matches_ref(ref, re.compile(filt)) == matches

View file

@ -3,128 +3,146 @@ import io
import logging import logging
import re import re
class TriggerException(Exception): class TriggerException(Exception):
pass pass
class TriggerAuthException(TriggerException): class TriggerAuthException(TriggerException):
pass pass
class InvalidPayloadException(TriggerException): class InvalidPayloadException(TriggerException):
pass pass
class BuildArchiveException(TriggerException): class BuildArchiveException(TriggerException):
pass pass
class InvalidServiceException(TriggerException): class InvalidServiceException(TriggerException):
pass pass
class TriggerActivationException(TriggerException): class TriggerActivationException(TriggerException):
pass pass
class TriggerDeactivationException(TriggerException): class TriggerDeactivationException(TriggerException):
pass pass
class TriggerStartException(TriggerException): class TriggerStartException(TriggerException):
pass pass
class ValidationRequestException(TriggerException): class ValidationRequestException(TriggerException):
pass pass
class SkipRequestException(TriggerException): class SkipRequestException(TriggerException):
pass pass
class EmptyRepositoryException(TriggerException): class EmptyRepositoryException(TriggerException):
pass pass
class RepositoryReadException(TriggerException): class RepositoryReadException(TriggerException):
pass pass
class TriggerProviderException(TriggerException): class TriggerProviderException(TriggerException):
pass pass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, default_branch): def determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, default_branch):
run_parameters = run_parameters or {} run_parameters = run_parameters or {}
kind = '' kind = ""
value = '' value = ""
if 'refs' in run_parameters and run_parameters['refs']: if "refs" in run_parameters and run_parameters["refs"]:
kind = run_parameters['refs']['kind'] kind = run_parameters["refs"]["kind"]
value = run_parameters['refs']['name'] value = run_parameters["refs"]["name"]
elif 'branch_name' in run_parameters: elif "branch_name" in run_parameters:
kind = 'branch' kind = "branch"
value = run_parameters['branch_name'] value = run_parameters["branch_name"]
kind = kind or 'branch' kind = kind or "branch"
value = value or default_branch or 'master' value = value or default_branch or "master"
ref = 'refs/tags/' + value if kind == 'tag' else 'refs/heads/' + value ref = "refs/tags/" + value if kind == "tag" else "refs/heads/" + value
commit_sha = get_tag_sha(value) if kind == 'tag' else get_branch_sha(value) commit_sha = get_tag_sha(value) if kind == "tag" else get_branch_sha(value)
return (commit_sha, ref) return (commit_sha, ref)
def find_matching_branches(config, branches): def find_matching_branches(config, branches):
if 'branchtag_regex' in config: if "branchtag_regex" in config:
try: try:
regex = re.compile(config['branchtag_regex']) regex = re.compile(config["branchtag_regex"])
return [branch for branch in branches return [
if matches_ref('refs/heads/' + branch, regex)] branch
except: for branch in branches
pass if matches_ref("refs/heads/" + branch, regex)
]
except:
pass
return branches return branches
def should_skip_commit(metadata): def should_skip_commit(metadata):
if 'commit_info' in metadata: if "commit_info" in metadata:
message = metadata['commit_info']['message'] message = metadata["commit_info"]["message"]
return '[skip build]' in message or '[build skip]' in message return "[skip build]" in message or "[build skip]" in message
return False return False
def raise_if_skipped_build(prepared_build, config): def raise_if_skipped_build(prepared_build, config):
""" Raises a SkipRequestException if the given build should be skipped. """ """ Raises a SkipRequestException if the given build should be skipped. """
# Check to ensure we have metadata. # Check to ensure we have metadata.
if not prepared_build.metadata: if not prepared_build.metadata:
logger.debug('Skipping request due to missing metadata for prepared build') logger.debug("Skipping request due to missing metadata for prepared build")
raise SkipRequestException() raise SkipRequestException()
# Check the branchtag regex. # Check the branchtag regex.
if 'branchtag_regex' in config: if "branchtag_regex" in config:
try: try:
regex = re.compile(config['branchtag_regex']) regex = re.compile(config["branchtag_regex"])
except: except:
regex = re.compile('.*') regex = re.compile(".*")
if not matches_ref(prepared_build.metadata.get('ref'), regex): if not matches_ref(prepared_build.metadata.get("ref"), regex):
raise SkipRequestException() raise SkipRequestException()
# Check the commit message. # Check the commit message.
if should_skip_commit(prepared_build.metadata): if should_skip_commit(prepared_build.metadata):
logger.debug('Skipping request due to commit message request') logger.debug("Skipping request due to commit message request")
raise SkipRequestException() raise SkipRequestException()
def matches_ref(ref, regex): def matches_ref(ref, regex):
match_string = ref.split('/', 1)[1] match_string = ref.split("/", 1)[1]
if not regex: if not regex:
return False return False
m = regex.match(match_string) m = regex.match(match_string)
if not m: if not m:
return False return False
return len(m.group(0)) == len(match_string) return len(m.group(0)) == len(match_string)
def raise_unsupported(): def raise_unsupported():
raise io.UnsupportedOperation raise io.UnsupportedOperation
def get_trigger_config(trigger): def get_trigger_config(trigger):
try: try:
return json.loads(trigger.config) return json.loads(trigger.config)
except ValueError: except ValueError:
return {} return {}

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,18 +11,24 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=True) logconfig = logfile_path(debug=True)
bind = '0.0.0.0:5000' bind = "0.0.0.0:5000"
workers = get_worker_count('local', 2, minimum=2, maximum=8) workers = get_worker_count("local", 2, minimum=2, maximum=8)
worker_class = 'gevent' worker_class = "gevent"
daemon = False daemon = False
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class) logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,19 +11,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False) logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_registry.sock' bind = "unix:/tmp/gunicorn_registry.sock"
workers = get_worker_count('registry', 4, minimum=8, maximum=64) workers = get_worker_count("registry", 4, minimum=8, maximum=64)
worker_class = 'gevent' worker_class = "gevent"
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting registry gunicorn with %s workers and %s worker class', workers, logger.debug(
worker_class) "Starting registry gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,19 +11,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False) logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_secscan.sock' bind = "unix:/tmp/gunicorn_secscan.sock"
workers = get_worker_count('secscan', 2, minimum=2, maximum=4) workers = get_worker_count("secscan", 2, minimum=2, maximum=4)
worker_class = 'gevent' worker_class = "gevent"
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting secscan gunicorn with %s workers and %s worker class', workers, logger.debug(
worker_class) "Starting secscan gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,18 +11,21 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False) logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_verbs.sock' bind = "unix:/tmp/gunicorn_verbs.sock"
workers = get_worker_count('verbs', 2, minimum=2, maximum=32) workers = get_worker_count("verbs", 2, minimum=2, maximum=32)
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
timeout = 2000 # Because sync workers timeout = 2000 # Because sync workers
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting verbs gunicorn with %s workers and sync worker class', workers) logger.debug(
"Starting verbs gunicorn with %s workers and sync worker class", workers
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -11,18 +12,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False) logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_web.sock' bind = "unix:/tmp/gunicorn_web.sock"
workers = get_worker_count('web', 2, minimum=2, maximum=32) workers = get_worker_count("web", 2, minimum=2, maximum=32)
worker_class = 'gevent' worker_class = "gevent"
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting web gunicorn with %s workers and %s worker class', workers, logger.debug(
worker_class) "Starting web gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -7,120 +7,130 @@ import jinja2
QUAYPATH = os.getenv("QUAYPATH", ".") QUAYPATH = os.getenv("QUAYPATH", ".")
QUAYDIR = os.getenv("QUAYDIR", "/") QUAYDIR = os.getenv("QUAYDIR", "/")
QUAYCONF_DIR = os.getenv("QUAYCONF", os.path.join(QUAYDIR, QUAYPATH, "conf")) QUAYCONF_DIR = os.getenv("QUAYCONF", os.path.join(QUAYDIR, QUAYPATH, "conf"))
STATIC_DIR = os.path.join(QUAYDIR, 'static') STATIC_DIR = os.path.join(QUAYDIR, "static")
SSL_PROTOCOL_DEFAULTS = ['TLSv1', 'TLSv1.1', 'TLSv1.2'] SSL_PROTOCOL_DEFAULTS = ["TLSv1", "TLSv1.1", "TLSv1.2"]
SSL_CIPHER_DEFAULTS = [ SSL_CIPHER_DEFAULTS = [
'ECDHE-RSA-AES128-GCM-SHA256', "ECDHE-RSA-AES128-GCM-SHA256",
'ECDHE-ECDSA-AES128-GCM-SHA256', "ECDHE-ECDSA-AES128-GCM-SHA256",
'ECDHE-RSA-AES256-GCM-SHA384', "ECDHE-RSA-AES256-GCM-SHA384",
'ECDHE-ECDSA-AES256-GCM-SHA384', "ECDHE-ECDSA-AES256-GCM-SHA384",
'DHE-RSA-AES128-GCM-SHA256', "DHE-RSA-AES128-GCM-SHA256",
'DHE-DSS-AES128-GCM-SHA256', "DHE-DSS-AES128-GCM-SHA256",
'kEDH+AESGCM', "kEDH+AESGCM",
'ECDHE-RSA-AES128-SHA256', "ECDHE-RSA-AES128-SHA256",
'ECDHE-ECDSA-AES128-SHA256', "ECDHE-ECDSA-AES128-SHA256",
'ECDHE-RSA-AES128-SHA', "ECDHE-RSA-AES128-SHA",
'ECDHE-ECDSA-AES128-SHA', "ECDHE-ECDSA-AES128-SHA",
'ECDHE-RSA-AES256-SHA384', "ECDHE-RSA-AES256-SHA384",
'ECDHE-ECDSA-AES256-SHA384', "ECDHE-ECDSA-AES256-SHA384",
'ECDHE-RSA-AES256-SHA', "ECDHE-RSA-AES256-SHA",
'ECDHE-ECDSA-AES256-SHA', "ECDHE-ECDSA-AES256-SHA",
'DHE-RSA-AES128-SHA256', "DHE-RSA-AES128-SHA256",
'DHE-RSA-AES128-SHA', "DHE-RSA-AES128-SHA",
'DHE-DSS-AES128-SHA256', "DHE-DSS-AES128-SHA256",
'DHE-RSA-AES256-SHA256', "DHE-RSA-AES256-SHA256",
'DHE-DSS-AES256-SHA', "DHE-DSS-AES256-SHA",
'DHE-RSA-AES256-SHA', "DHE-RSA-AES256-SHA",
'AES128-GCM-SHA256', "AES128-GCM-SHA256",
'AES256-GCM-SHA384', "AES256-GCM-SHA384",
'AES128-SHA256', "AES128-SHA256",
'AES256-SHA256', "AES256-SHA256",
'AES128-SHA', "AES128-SHA",
'AES256-SHA', "AES256-SHA",
'AES', "AES",
'CAMELLIA', "CAMELLIA",
'!3DES', "!3DES",
'!aNULL', "!aNULL",
'!eNULL', "!eNULL",
'!EXPORT', "!EXPORT",
'!DES', "!DES",
'!RC4', "!RC4",
'!MD5', "!MD5",
'!PSK', "!PSK",
'!aECDH', "!aECDH",
'!EDH-DSS-DES-CBC3-SHA', "!EDH-DSS-DES-CBC3-SHA",
'!EDH-RSA-DES-CBC3-SHA', "!EDH-RSA-DES-CBC3-SHA",
'!KRB5-DES-CBC3-SHA', "!KRB5-DES-CBC3-SHA",
] ]
def write_config(filename, **kwargs):
with open(filename + ".jnj") as f:
template = jinja2.Template(f.read())
rendered = template.render(kwargs)
with open(filename, 'w') as f: def write_config(filename, **kwargs):
f.write(rendered) with open(filename + ".jnj") as f:
template = jinja2.Template(f.read())
rendered = template.render(kwargs)
with open(filename, "w") as f:
f.write(rendered)
def generate_nginx_config(config): def generate_nginx_config(config):
""" """
Generates nginx config from the app config Generates nginx config from the app config
""" """
config = config or {} config = config or {}
use_https = os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/ssl.key')) use_https = os.path.exists(os.path.join(QUAYCONF_DIR, "stack/ssl.key"))
use_old_certs = os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/ssl.old.key')) use_old_certs = os.path.exists(os.path.join(QUAYCONF_DIR, "stack/ssl.old.key"))
v1_only_domain = config.get('V1_ONLY_DOMAIN', None) v1_only_domain = config.get("V1_ONLY_DOMAIN", None)
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False) enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
ssl_protocols = config.get('SSL_PROTOCOLS', SSL_PROTOCOL_DEFAULTS) ssl_protocols = config.get("SSL_PROTOCOLS", SSL_PROTOCOL_DEFAULTS)
ssl_ciphers = config.get('SSL_CIPHERS', SSL_CIPHER_DEFAULTS) ssl_ciphers = config.get("SSL_CIPHERS", SSL_CIPHER_DEFAULTS)
write_config(os.path.join(QUAYCONF_DIR, 'nginx/nginx.conf'), use_https=use_https, write_config(
use_old_certs=use_old_certs, os.path.join(QUAYCONF_DIR, "nginx/nginx.conf"),
enable_rate_limits=enable_rate_limits, use_https=use_https,
v1_only_domain=v1_only_domain, use_old_certs=use_old_certs,
ssl_protocols=ssl_protocols, enable_rate_limits=enable_rate_limits,
ssl_ciphers=':'.join(ssl_ciphers)) v1_only_domain=v1_only_domain,
ssl_protocols=ssl_protocols,
ssl_ciphers=":".join(ssl_ciphers),
)
def generate_server_config(config): def generate_server_config(config):
""" """
Generates server config from the app config Generates server config from the app config
""" """
config = config or {} config = config or {}
tuf_server = config.get('TUF_SERVER', None) tuf_server = config.get("TUF_SERVER", None)
tuf_host = config.get('TUF_HOST', None) tuf_host = config.get("TUF_HOST", None)
signing_enabled = config.get('FEATURE_SIGNING', False) signing_enabled = config.get("FEATURE_SIGNING", False)
maximum_layer_size = config.get('MAXIMUM_LAYER_SIZE', '20G') maximum_layer_size = config.get("MAXIMUM_LAYER_SIZE", "20G")
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False) enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
write_config( write_config(
os.path.join(QUAYCONF_DIR, 'nginx/server-base.conf'), tuf_server=tuf_server, tuf_host=tuf_host, os.path.join(QUAYCONF_DIR, "nginx/server-base.conf"),
signing_enabled=signing_enabled, maximum_layer_size=maximum_layer_size, tuf_server=tuf_server,
enable_rate_limits=enable_rate_limits, tuf_host=tuf_host,
static_dir=STATIC_DIR) signing_enabled=signing_enabled,
maximum_layer_size=maximum_layer_size,
enable_rate_limits=enable_rate_limits,
static_dir=STATIC_DIR,
)
def generate_rate_limiting_config(config): def generate_rate_limiting_config(config):
""" """
Generates rate limiting config from the app config Generates rate limiting config from the app config
""" """
config = config or {} config = config or {}
non_rate_limited_namespaces = config.get('NON_RATE_LIMITED_NAMESPACES') or set() non_rate_limited_namespaces = config.get("NON_RATE_LIMITED_NAMESPACES") or set()
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False) enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
write_config( write_config(
os.path.join(QUAYCONF_DIR, 'nginx/rate-limiting.conf'), os.path.join(QUAYCONF_DIR, "nginx/rate-limiting.conf"),
non_rate_limited_namespaces=non_rate_limited_namespaces, non_rate_limited_namespaces=non_rate_limited_namespaces,
enable_rate_limits=enable_rate_limits, enable_rate_limits=enable_rate_limits,
static_dir=STATIC_DIR) static_dir=STATIC_DIR,
)
if __name__ == "__main__": if __name__ == "__main__":
if os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/config.yaml')): if os.path.exists(os.path.join(QUAYCONF_DIR, "stack/config.yaml")):
with open(os.path.join(QUAYCONF_DIR, 'stack/config.yaml'), 'r') as f: with open(os.path.join(QUAYCONF_DIR, "stack/config.yaml"), "r") as f:
config = yaml.load(f) config = yaml.load(f)
else: else:
config = None config = None
generate_rate_limiting_config(config) generate_rate_limiting_config(config)
generate_server_config(config) generate_server_config(config)
generate_nginx_config(config) generate_nginx_config(config)

View file

@ -12,136 +12,74 @@ QUAY_OVERRIDE_SERVICES = os.getenv("QUAY_OVERRIDE_SERVICES", [])
def default_services(): def default_services():
return { return {
"blobuploadcleanupworker": { "blobuploadcleanupworker": {"autostart": "true"},
"autostart": "true" "buildlogsarchiver": {"autostart": "true"},
}, "builder": {"autostart": "true"},
"buildlogsarchiver": { "chunkcleanupworker": {"autostart": "true"},
"autostart": "true" "expiredappspecifictokenworker": {"autostart": "true"},
}, "exportactionlogsworker": {"autostart": "true"},
"builder": { "gcworker": {"autostart": "true"},
"autostart": "true" "globalpromstats": {"autostart": "true"},
}, "labelbackfillworker": {"autostart": "true"},
"chunkcleanupworker": { "logrotateworker": {"autostart": "true"},
"autostart": "true" "namespacegcworker": {"autostart": "true"},
}, "notificationworker": {"autostart": "true"},
"expiredappspecifictokenworker": { "queuecleanupworker": {"autostart": "true"},
"autostart": "true" "repositoryactioncounter": {"autostart": "true"},
}, "security_notification_worker": {"autostart": "true"},
"exportactionlogsworker": { "securityworker": {"autostart": "true"},
"autostart": "true" "storagereplication": {"autostart": "true"},
}, "tagbackfillworker": {"autostart": "true"},
"gcworker": { "teamsyncworker": {"autostart": "true"},
"autostart": "true" "dnsmasq": {"autostart": "true"},
}, "gunicorn-registry": {"autostart": "true"},
"globalpromstats": { "gunicorn-secscan": {"autostart": "true"},
"autostart": "true" "gunicorn-verbs": {"autostart": "true"},
}, "gunicorn-web": {"autostart": "true"},
"labelbackfillworker": { "ip-resolver-update-worker": {"autostart": "true"},
"autostart": "true" "jwtproxy": {"autostart": "true"},
}, "memcache": {"autostart": "true"},
"logrotateworker": { "nginx": {"autostart": "true"},
"autostart": "true" "prometheus-aggregator": {"autostart": "true"},
}, "servicekey": {"autostart": "true"},
"namespacegcworker": { "repomirrorworker": {"autostart": "false"},
"autostart": "true"
},
"notificationworker": {
"autostart": "true"
},
"queuecleanupworker": {
"autostart": "true"
},
"repositoryactioncounter": {
"autostart": "true"
},
"security_notification_worker": {
"autostart": "true"
},
"securityworker": {
"autostart": "true"
},
"storagereplication": {
"autostart": "true"
},
"tagbackfillworker": {
"autostart": "true"
},
"teamsyncworker": {
"autostart": "true"
},
"dnsmasq": {
"autostart": "true"
},
"gunicorn-registry": {
"autostart": "true"
},
"gunicorn-secscan": {
"autostart": "true"
},
"gunicorn-verbs": {
"autostart": "true"
},
"gunicorn-web": {
"autostart": "true"
},
"ip-resolver-update-worker": {
"autostart": "true"
},
"jwtproxy": {
"autostart": "true"
},
"memcache": {
"autostart": "true"
},
"nginx": {
"autostart": "true"
},
"prometheus-aggregator": {
"autostart": "true"
},
"servicekey": {
"autostart": "true"
},
"repomirrorworker": {
"autostart": "false"
} }
}
def generate_supervisord_config(filename, config): def generate_supervisord_config(filename, config):
with open(filename + ".jnj") as f: with open(filename + ".jnj") as f:
template = jinja2.Template(f.read()) template = jinja2.Template(f.read())
rendered = template.render(config=config) rendered = template.render(config=config)
with open(filename, 'w') as f: with open(filename, "w") as f:
f.write(rendered) f.write(rendered)
def limit_services(config, enabled_services): def limit_services(config, enabled_services):
if enabled_services == []: if enabled_services == []:
return return
for service in config.keys(): for service in config.keys():
if service in enabled_services: if service in enabled_services:
config[service]["autostart"] = "true" config[service]["autostart"] = "true"
else: else:
config[service]["autostart"] = "false" config[service]["autostart"] = "false"
def override_services(config, override_services): def override_services(config, override_services):
if override_services == []: if override_services == []:
return return
for service in config.keys(): for service in config.keys():
if service + "=true" in override_services: if service + "=true" in override_services:
config[service]["autostart"] = "true" config[service]["autostart"] = "true"
elif service + "=false" in override_services: elif service + "=false" in override_services:
config[service]["autostart"] = "false" config[service]["autostart"] = "false"
if __name__ == "__main__": if __name__ == "__main__":
config = default_services() config = default_services()
limit_services(config, QUAY_SERVICES) limit_services(config, QUAY_SERVICES)
override_services(config, QUAY_OVERRIDE_SERVICES) override_services(config, QUAY_OVERRIDE_SERVICES)
generate_supervisord_config(os.path.join(QUAYCONF_DIR, 'supervisord.conf'), config) generate_supervisord_config(os.path.join(QUAYCONF_DIR, "supervisord.conf"), config)

View file

@ -6,17 +6,23 @@ import jinja2
from ..supervisord_conf_create import QUAYCONF_DIR, default_services, limit_services from ..supervisord_conf_create import QUAYCONF_DIR, default_services, limit_services
def render_supervisord_conf(config): def render_supervisord_conf(config):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../supervisord.conf.jnj")) as f: with open(
template = jinja2.Template(f.read()) os.path.join(
return template.render(config=config) os.path.dirname(os.path.abspath(__file__)), "../../supervisord.conf.jnj"
)
) as f:
template = jinja2.Template(f.read())
return template.render(config=config)
def test_supervisord_conf_create_defaults(): def test_supervisord_conf_create_defaults():
config = default_services() config = default_services()
limit_services(config, []) limit_services(config, [])
rendered = render_supervisord_conf(config) rendered = render_supervisord_conf(config)
expected = """[supervisord] expected = """[supervisord]
nodaemon=true nodaemon=true
[unix_http_server] [unix_http_server]
@ -392,14 +398,15 @@ stderr_logfile_maxbytes=0
stdout_events_enabled = true stdout_events_enabled = true
stderr_events_enabled = true stderr_events_enabled = true
# EOF NO NEWLINE""" # EOF NO NEWLINE"""
assert rendered == expected assert rendered == expected
def test_supervisord_conf_create_all_overrides(): def test_supervisord_conf_create_all_overrides():
config = default_services() config = default_services()
limit_services(config, "servicekey,prometheus-aggregator") limit_services(config, "servicekey,prometheus-aggregator")
rendered = render_supervisord_conf(config) rendered = render_supervisord_conf(config)
expected = """[supervisord] expected = """[supervisord]
nodaemon=true nodaemon=true
[unix_http_server] [unix_http_server]
@ -775,4 +782,4 @@ stderr_logfile_maxbytes=0
stdout_events_enabled = true stdout_events_enabled = true
stderr_events_enabled = true stderr_events_enabled = true
# EOF NO NEWLINE""" # EOF NO NEWLINE"""
assert rendered == expected assert rendered == expected

943
config.py

File diff suppressed because it is too large Load diff

View file

@ -7,31 +7,31 @@ import subprocess
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/")) CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/"))
STATIC_DIR = os.path.join(ROOT_DIR, 'static/') STATIC_DIR = os.path.join(ROOT_DIR, "static/")
STATIC_LDN_DIR = os.path.join(STATIC_DIR, 'ldn/') STATIC_LDN_DIR = os.path.join(STATIC_DIR, "ldn/")
STATIC_FONTS_DIR = os.path.join(STATIC_DIR, 'fonts/') STATIC_FONTS_DIR = os.path.join(STATIC_DIR, "fonts/")
TEMPLATE_DIR = os.path.join(ROOT_DIR, 'templates/') TEMPLATE_DIR = os.path.join(ROOT_DIR, "templates/")
IS_KUBERNETES = 'KUBERNETES_SERVICE_HOST' in os.environ IS_KUBERNETES = "KUBERNETES_SERVICE_HOST" in os.environ
def _get_version_number_changelog(): def _get_version_number_changelog():
try: try:
with open(os.path.join(ROOT_DIR, 'CHANGELOG.md')) as f: with open(os.path.join(ROOT_DIR, "CHANGELOG.md")) as f:
return re.search(r'(v[0-9]+\.[0-9]+\.[0-9]+)', f.readline()).group(0) return re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", f.readline()).group(0)
except IOError: except IOError:
return '' return ""
def _get_git_sha(): def _get_git_sha():
if os.path.exists("GIT_HEAD"): if os.path.exists("GIT_HEAD"):
with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f: with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
return f.read() return f.read()
else: else:
try: try:
return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8] return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
except (OSError, subprocess.CalledProcessError): except (OSError, subprocess.CalledProcessError):
pass pass
return "unknown" return "unknown"
__version__ = _get_version_number_changelog() __version__ = _get_version_number_changelog()

View file

@ -15,28 +15,29 @@ app = Flask(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OVERRIDE_CONFIG_DIRECTORY = os.path.join(ROOT_DIR, 'config_app/conf/stack') OVERRIDE_CONFIG_DIRECTORY = os.path.join(ROOT_DIR, "config_app/conf/stack")
INIT_SCRIPTS_LOCATION = '/conf/init/' INIT_SCRIPTS_LOCATION = "/conf/init/"
is_testing = 'TEST' in os.environ is_testing = "TEST" in os.environ
is_kubernetes = IS_KUBERNETES is_kubernetes = IS_KUBERNETES
logger.debug('Configuration is on a kubernetes deployment: %s' % IS_KUBERNETES) logger.debug("Configuration is on a kubernetes deployment: %s" % IS_KUBERNETES)
config_provider = get_config_provider(OVERRIDE_CONFIG_DIRECTORY, 'config.yaml', 'config.py', config_provider = get_config_provider(
testing=is_testing) OVERRIDE_CONFIG_DIRECTORY, "config.yaml", "config.py", testing=is_testing
)
if is_testing: if is_testing:
from test.testconfig import TestConfig from test.testconfig import TestConfig
logger.debug('Loading test config.') logger.debug("Loading test config.")
app.config.from_object(TestConfig()) app.config.from_object(TestConfig())
else: else:
from config import DefaultConfig from config import DefaultConfig
logger.debug('Loading default config.') logger.debug("Loading default config.")
app.config.from_object(DefaultConfig()) app.config.from_object(DefaultConfig())
app.teardown_request(database.close_db_filter) app.teardown_request(database.close_db_filter)
# Load the override config via the provider. # Load the override config via the provider.
config_provider.update_app_config(app.config) config_provider.update_app_config(app.config)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -9,18 +10,24 @@ from config_app.config_util.log import logfile_path
logconfig = logfile_path(debug=True) logconfig = logfile_path(debug=True)
bind = '0.0.0.0:5000' bind = "0.0.0.0:5000"
workers = 1 workers = 1
worker_class = 'gevent' worker_class = "gevent"
daemon = False daemon = False
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class) logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,17 +11,23 @@ from config_app.config_util.log import logfile_path
logconfig = logfile_path(debug=True) logconfig = logfile_path(debug=True)
bind = 'unix:/tmp/gunicorn_web.sock' bind = "unix:/tmp/gunicorn_web.sock"
workers = 1 workers = 1
worker_class = 'gevent' worker_class = "gevent"
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class) logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -3,6 +3,6 @@ from config_app.c_app import app as application
# Bind all of the blueprints # Bind all of the blueprints
import config_web import config_web
if __name__ == '__main__': if __name__ == "__main__":
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False) logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
application.run(port=5000, debug=True, threaded=True, host='0.0.0.0') application.run(port=5000, debug=True, threaded=True, host="0.0.0.0")

View file

@ -13,141 +13,153 @@ from config_app.c_app import app, IS_KUBERNETES
from config_app.config_endpoints.exception import InvalidResponse, InvalidRequest from config_app.config_endpoints.exception import InvalidResponse, InvalidRequest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
api_bp = Blueprint('api', __name__) api_bp = Blueprint("api", __name__)
CROSS_DOMAIN_HEADERS = ['Authorization', 'Content-Type', 'X-Requested-With'] CROSS_DOMAIN_HEADERS = ["Authorization", "Content-Type", "X-Requested-With"]
class ApiExceptionHandlingApi(Api): class ApiExceptionHandlingApi(Api):
pass pass
@crossdomain(origin='*', headers=CROSS_DOMAIN_HEADERS) @crossdomain(origin="*", headers=CROSS_DOMAIN_HEADERS)
def handle_error(self, error): def handle_error(self, error):
return super(ApiExceptionHandlingApi, self).handle_error(error) return super(ApiExceptionHandlingApi, self).handle_error(error)
api = ApiExceptionHandlingApi() api = ApiExceptionHandlingApi()
api.init_app(api_bp) api.init_app(api_bp)
def log_action(kind, user_or_orgname, metadata=None, repo=None, repo_name=None): def log_action(kind, user_or_orgname, metadata=None, repo=None, repo_name=None):
if not metadata: if not metadata:
metadata = {} metadata = {}
if repo: if repo:
repo_name = repo.name repo_name = repo.name
model.log.log_action(
kind, user_or_orgname, repo_name, user_or_orgname, request.remote_addr, metadata
)
model.log.log_action(kind, user_or_orgname, repo_name, user_or_orgname, request.remote_addr, metadata)
def format_date(date): def format_date(date):
""" Output an RFC822 date format. """ """ Output an RFC822 date format. """
if date is None: if date is None:
return None return None
return formatdate(timegm(date.utctimetuple())) return formatdate(timegm(date.utctimetuple()))
def resource(*urls, **kwargs): def resource(*urls, **kwargs):
def wrapper(api_resource): def wrapper(api_resource):
if not api_resource: if not api_resource:
return None return None
api_resource.registered = True api_resource.registered = True
api.add_resource(api_resource, *urls, **kwargs) api.add_resource(api_resource, *urls, **kwargs)
return api_resource return api_resource
return wrapper return wrapper
class ApiResource(Resource): class ApiResource(Resource):
registered = False registered = False
method_decorators = [] method_decorators = []
def options(self): def options(self):
return None, 200 return None, 200
def add_method_metadata(name, value): def add_method_metadata(name, value):
def modifier(func): def modifier(func):
if func is None: if func is None:
return None return None
if '__api_metadata' not in dir(func): if "__api_metadata" not in dir(func):
func.__api_metadata = {} func.__api_metadata = {}
func.__api_metadata[name] = value func.__api_metadata[name] = value
return func return func
return modifier return modifier
def method_metadata(func, name): def method_metadata(func, name):
if func is None: if func is None:
return None return None
if '__api_metadata' in dir(func): if "__api_metadata" in dir(func):
return func.__api_metadata.get(name, None) return func.__api_metadata.get(name, None)
return None return None
def no_cache(f): def no_cache(f):
@wraps(f) @wraps(f)
def add_no_cache(*args, **kwargs): def add_no_cache(*args, **kwargs):
response = f(*args, **kwargs) response = f(*args, **kwargs)
if response is not None: if response is not None:
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
return response return response
return add_no_cache
return add_no_cache
def define_json_response(schema_name): def define_json_response(schema_name):
def wrapper(func): def wrapper(func):
@add_method_metadata('response_schema', schema_name) @add_method_metadata("response_schema", schema_name)
@wraps(func) @wraps(func)
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
schema = self.schemas[schema_name] schema = self.schemas[schema_name]
resp = func(self, *args, **kwargs) resp = func(self, *args, **kwargs)
if app.config['TESTING']: if app.config["TESTING"]:
try: try:
validate(resp, schema) validate(resp, schema)
except ValidationError as ex: except ValidationError as ex:
raise InvalidResponse(ex.message) raise InvalidResponse(ex.message)
return resp return resp
return wrapped
return wrapper return wrapped
return wrapper
def validate_json_request(schema_name, optional=False): def validate_json_request(schema_name, optional=False):
def wrapper(func): def wrapper(func):
@add_method_metadata('request_schema', schema_name) @add_method_metadata("request_schema", schema_name)
@wraps(func) @wraps(func)
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
schema = self.schemas[schema_name] schema = self.schemas[schema_name]
try: try:
json_data = request.get_json() json_data = request.get_json()
if json_data is None: if json_data is None:
if not optional: if not optional:
raise InvalidRequest('Missing JSON body') raise InvalidRequest("Missing JSON body")
else: else:
validate(json_data, schema) validate(json_data, schema)
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
except ValidationError as ex: except ValidationError as ex:
raise InvalidRequest(ex.message) raise InvalidRequest(ex.message)
return wrapped
return wrapper return wrapped
return wrapper
def kubernetes_only(f): def kubernetes_only(f):
""" Aborts the request with a 400 if the app is not running on kubernetes """ """ Aborts the request with a 400 if the app is not running on kubernetes """
@wraps(f)
def abort_if_not_kube(*args, **kwargs):
if not IS_KUBERNETES:
abort(400)
return f(*args, **kwargs) @wraps(f)
return abort_if_not_kube def abort_if_not_kube(*args, **kwargs):
if not IS_KUBERNETES:
abort(400)
nickname = partial(add_method_metadata, 'nickname') return f(*args, **kwargs)
return abort_if_not_kube
nickname = partial(add_method_metadata, "nickname")
import config_app.config_endpoints.api.discovery import config_app.config_endpoints.api.discovery

View file

@ -5,250 +5,253 @@ from collections import OrderedDict
from config_app.c_app import app from config_app.c_app import app
from config_app.config_endpoints.api import method_metadata from config_app.config_endpoints.api import method_metadata
from config_app.config_endpoints.common import fully_qualified_name, PARAM_REGEX, TYPE_CONVERTER from config_app.config_endpoints.common import (
fully_qualified_name,
PARAM_REGEX,
TYPE_CONVERTER,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def generate_route_data(): def generate_route_data():
include_internal = True include_internal = True
compact = True compact = True
def swagger_parameter(
    name,
    description,
    kind="path",
    param_type="string",
    required=True,
    enum=None,
    schema=None,
):
    """ Builds a Swagger 2.0 parameter object for the discovery document.

        name: the parameter's name.
        description: human-readable description (not emitted into the dict here;
                     kept for call-site readability).
        kind: the parameter location ("path", "query", "body", ...), emitted as "in".
        param_type: the Swagger primitive type; ignored when `schema` is given.
        required: whether the parameter is mandatory.
        enum: optional iterable of allowed values; only emitted when non-empty.
        schema: optional model name, emitted as a "$ref" instead of a plain type.

        Returns a dict conforming to the Swagger parameterObject spec.
    """
    # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#parameterObject
    parameter_info = {"name": name, "in": kind, "required": required}

    if schema:
        # A referenced model replaces the primitive type entirely.
        parameter_info["schema"] = {"$ref": "#/definitions/%s" % schema}
    else:
        parameter_info["type"] = param_type

    if enum is not None and len(list(enum)) > 0:
        parameter_info["enum"] = list(enum)

    return parameter_info
paths = {}
models = {}
tags = []
tags_added = set()
operation_ids = set()
for rule in app.url_map.iter_rules():
endpoint_method = app.view_functions[rule.endpoint]
# Verify that we have a view class for this API method.
if not 'view_class' in dir(endpoint_method):
continue
view_class = endpoint_method.view_class
# Hide the class if it is internal.
internal = method_metadata(view_class, 'internal')
if not include_internal and internal:
continue
# Build the tag.
parts = fully_qualified_name(view_class).split('.')
tag_name = parts[-2]
if not tag_name in tags_added:
tags_added.add(tag_name)
tags.append({
'name': tag_name,
'description': (sys.modules[view_class.__module__].__doc__ or '').strip()
})
# Build the Swagger data for the path.
swagger_path = PARAM_REGEX.sub(r'{\2}', rule.rule)
full_name = fully_qualified_name(view_class)
path_swagger = {
'x-name': full_name,
'x-path': swagger_path,
'x-tag': tag_name
}
related_user_res = method_metadata(view_class, 'related_user_resource')
if related_user_res is not None:
path_swagger['x-user-related'] = fully_qualified_name(related_user_res)
paths[swagger_path] = path_swagger
# Add any global path parameters.
param_data_map = view_class.__api_path_params if '__api_path_params' in dir(
view_class) else {}
if param_data_map:
path_parameters_swagger = []
for path_parameter in param_data_map:
description = param_data_map[path_parameter].get('description')
path_parameters_swagger.append(swagger_parameter(path_parameter, description))
path_swagger['parameters'] = path_parameters_swagger
# Add the individual HTTP operations.
method_names = list(rule.methods.difference(['HEAD', 'OPTIONS']))
for method_name in method_names:
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
method = getattr(view_class, method_name.lower(), None)
if method is None:
logger.debug('Unable to find method for %s in class %s', method_name, view_class)
continue
operationId = method_metadata(method, 'nickname')
operation_swagger = {
'operationId': operationId,
'parameters': [],
}
if operationId is None:
continue
if operationId in operation_ids:
raise Exception('Duplicate operation Id: %s' % operationId)
operation_ids.add(operationId)
# Mark the method as internal.
internal = method_metadata(method, 'internal')
if internal is not None:
operation_swagger['x-internal'] = True
if include_internal:
requires_fresh_login = method_metadata(method, 'requires_fresh_login')
if requires_fresh_login is not None:
operation_swagger['x-requires-fresh-login'] = True
# Add the path parameters.
if rule.arguments:
for path_parameter in rule.arguments:
description = param_data_map.get(path_parameter, {}).get('description')
operation_swagger['parameters'].append(
swagger_parameter(path_parameter, description))
# Add the query parameters.
if '__api_query_params' in dir(method):
for query_parameter_info in method.__api_query_params:
name = query_parameter_info['name']
description = query_parameter_info['help']
param_type = TYPE_CONVERTER[query_parameter_info['type']]
required = query_parameter_info['required']
operation_swagger['parameters'].append(
swagger_parameter(name, description, kind='query',
param_type=param_type,
required=required,
enum=query_parameter_info['choices']))
# Add the OAuth security block.
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
scope = method_metadata(method, 'oauth2_scope')
if scope and not compact:
operation_swagger['security'] = [{'oauth2_implicit': [scope.scope]}]
# Add the responses block.
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
response_schema_name = method_metadata(method, 'response_schema')
if not compact:
if response_schema_name:
models[response_schema_name] = view_class.schemas[response_schema_name]
models['ApiError'] = {
'type': 'object',
'properties': {
'status': {
'type': 'integer',
'description': 'Status code of the response.'
},
'type': {
'type': 'string',
'description': 'Reference to the type of the error.'
},
'detail': {
'type': 'string',
'description': 'Details about the specific instance of the error.'
},
'title': {
'type': 'string',
'description': 'Unique error code to identify the type of error.'
},
'error_message': {
'type': 'string',
'description': 'Deprecated; alias for detail'
},
'error_type': {
'type': 'string',
'description': 'Deprecated; alias for detail'
}
},
'required': [
'status',
'type',
'title',
]
}
responses = {
'400': {
'description': 'Bad Request',
},
'401': {
'description': 'Session required',
},
'403': {
'description': 'Unauthorized access',
},
'404': {
'description': 'Not found',
},
}
for _, body in responses.items():
body['schema'] = {'$ref': '#/definitions/ApiError'}
if method_name == 'DELETE':
responses['204'] = {
'description': 'Deleted'
}
elif method_name == 'POST':
responses['201'] = {
'description': 'Successful creation'
}
else: else:
responses['200'] = { parameter_info["type"] = param_type
'description': 'Successful invocation'
}
if response_schema_name: if enum is not None and len(list(enum)) > 0:
responses['200']['schema'] = { parameter_info["enum"] = list(enum)
'$ref': '#/definitions/%s' % response_schema_name
}
operation_swagger['responses'] = responses return parameter_info
# Add the request block. paths = {}
request_schema_name = method_metadata(method, 'request_schema') models = {}
if request_schema_name and not compact: tags = []
models[request_schema_name] = view_class.schemas[request_schema_name] tags_added = set()
operation_ids = set()
operation_swagger['parameters'].append( for rule in app.url_map.iter_rules():
swagger_parameter('body', 'Request body contents.', kind='body', endpoint_method = app.view_functions[rule.endpoint]
schema=request_schema_name))
# Add the operation to the parent path. # Verify that we have a view class for this API method.
if not internal or (internal and include_internal): if not "view_class" in dir(endpoint_method):
path_swagger[method_name.lower()] = operation_swagger continue
tags.sort(key=lambda t: t['name']) view_class = endpoint_method.view_class
paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]['x-tag']))
if compact: # Hide the class if it is internal.
return {'paths': paths} internal = method_metadata(view_class, "internal")
if not include_internal and internal:
continue
# Build the tag.
parts = fully_qualified_name(view_class).split(".")
tag_name = parts[-2]
if not tag_name in tags_added:
tags_added.add(tag_name)
tags.append(
{
"name": tag_name,
"description": (
sys.modules[view_class.__module__].__doc__ or ""
).strip(),
}
)
# Build the Swagger data for the path.
swagger_path = PARAM_REGEX.sub(r"{\2}", rule.rule)
full_name = fully_qualified_name(view_class)
path_swagger = {"x-name": full_name, "x-path": swagger_path, "x-tag": tag_name}
related_user_res = method_metadata(view_class, "related_user_resource")
if related_user_res is not None:
path_swagger["x-user-related"] = fully_qualified_name(related_user_res)
paths[swagger_path] = path_swagger
# Add any global path parameters.
param_data_map = (
view_class.__api_path_params
if "__api_path_params" in dir(view_class)
else {}
)
if param_data_map:
path_parameters_swagger = []
for path_parameter in param_data_map:
description = param_data_map[path_parameter].get("description")
path_parameters_swagger.append(
swagger_parameter(path_parameter, description)
)
path_swagger["parameters"] = path_parameters_swagger
# Add the individual HTTP operations.
method_names = list(rule.methods.difference(["HEAD", "OPTIONS"]))
for method_name in method_names:
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
method = getattr(view_class, method_name.lower(), None)
if method is None:
logger.debug(
"Unable to find method for %s in class %s", method_name, view_class
)
continue
operationId = method_metadata(method, "nickname")
operation_swagger = {"operationId": operationId, "parameters": []}
if operationId is None:
continue
if operationId in operation_ids:
raise Exception("Duplicate operation Id: %s" % operationId)
operation_ids.add(operationId)
# Mark the method as internal.
internal = method_metadata(method, "internal")
if internal is not None:
operation_swagger["x-internal"] = True
if include_internal:
requires_fresh_login = method_metadata(method, "requires_fresh_login")
if requires_fresh_login is not None:
operation_swagger["x-requires-fresh-login"] = True
# Add the path parameters.
if rule.arguments:
for path_parameter in rule.arguments:
description = param_data_map.get(path_parameter, {}).get(
"description"
)
operation_swagger["parameters"].append(
swagger_parameter(path_parameter, description)
)
# Add the query parameters.
if "__api_query_params" in dir(method):
for query_parameter_info in method.__api_query_params:
name = query_parameter_info["name"]
description = query_parameter_info["help"]
param_type = TYPE_CONVERTER[query_parameter_info["type"]]
required = query_parameter_info["required"]
operation_swagger["parameters"].append(
swagger_parameter(
name,
description,
kind="query",
param_type=param_type,
required=required,
enum=query_parameter_info["choices"],
)
)
# Add the OAuth security block.
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
scope = method_metadata(method, "oauth2_scope")
if scope and not compact:
operation_swagger["security"] = [{"oauth2_implicit": [scope.scope]}]
# Add the responses block.
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
response_schema_name = method_metadata(method, "response_schema")
if not compact:
if response_schema_name:
models[response_schema_name] = view_class.schemas[
response_schema_name
]
models["ApiError"] = {
"type": "object",
"properties": {
"status": {
"type": "integer",
"description": "Status code of the response.",
},
"type": {
"type": "string",
"description": "Reference to the type of the error.",
},
"detail": {
"type": "string",
"description": "Details about the specific instance of the error.",
},
"title": {
"type": "string",
"description": "Unique error code to identify the type of error.",
},
"error_message": {
"type": "string",
"description": "Deprecated; alias for detail",
},
"error_type": {
"type": "string",
"description": "Deprecated; alias for detail",
},
},
"required": ["status", "type", "title"],
}
responses = {
"400": {"description": "Bad Request"},
"401": {"description": "Session required"},
"403": {"description": "Unauthorized access"},
"404": {"description": "Not found"},
}
for _, body in responses.items():
body["schema"] = {"$ref": "#/definitions/ApiError"}
if method_name == "DELETE":
responses["204"] = {"description": "Deleted"}
elif method_name == "POST":
responses["201"] = {"description": "Successful creation"}
else:
responses["200"] = {"description": "Successful invocation"}
if response_schema_name:
responses["200"]["schema"] = {
"$ref": "#/definitions/%s" % response_schema_name
}
operation_swagger["responses"] = responses
# Add the request block.
request_schema_name = method_metadata(method, "request_schema")
if request_schema_name and not compact:
models[request_schema_name] = view_class.schemas[request_schema_name]
operation_swagger["parameters"].append(
swagger_parameter(
"body",
"Request body contents.",
kind="body",
schema=request_schema_name,
)
)
# Add the operation to the parent path.
if not internal or (internal and include_internal):
path_swagger[method_name.lower()] = operation_swagger
tags.sort(key=lambda t: t["name"])
paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]["x-tag"]))
if compact:
return {"paths": paths}

View file

@ -6,138 +6,152 @@ from config_app.config_util.config import get_config_as_kube_secret
from data.database import configure from data.database import configure
from config_app.c_app import app, config_provider from config_app.c_app import app, config_provider
from config_app.config_endpoints.api import resource, ApiResource, nickname, kubernetes_only, validate_json_request from config_app.config_endpoints.api import (
from config_app.config_util.k8saccessor import KubernetesAccessorSingleton, K8sApiException resource,
ApiResource,
nickname,
kubernetes_only,
validate_json_request,
)
from config_app.config_util.k8saccessor import (
KubernetesAccessorSingleton,
K8sApiException,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@resource("/v1/kubernetes/deployments/")
class SuperUserKubernetesDeployment(ApiResource):
    """ Resource for the getting the status of Red Hat Quay deployments and cycling them """

    schemas = {
        "ValidateDeploymentNames": {
            "type": "object",
            "description": "Validates deployment names for cycling",
            "required": ["deploymentNames"],
            "properties": {
                "deploymentNames": {
                    "type": "array",
                    "description": "The names of the deployments to cycle",
                }
            },
        }
    }

    @kubernetes_only
    @nickname("scGetNumDeployments")
    def get(self):
        """ Returns the Quay Enterprise deployments known to the cluster. """
        return KubernetesAccessorSingleton.get_instance().get_qe_deployments()

    @kubernetes_only
    @validate_json_request("ValidateDeploymentNames")
    @nickname("scCycleQEDeployments")
    def put(self):
        """ Cycles (restarts) the named deployments so their pods pick up new config. """
        deployment_names = request.get_json()["deploymentNames"]
        return KubernetesAccessorSingleton.get_instance().cycle_qe_deployments(
            deployment_names
        )
@resource("/v1/kubernetes/deployment/<deployment>/status")
class QEDeploymentRolloutStatus(ApiResource):
    """ Resource reporting the rollout status of a single named deployment. """

    @kubernetes_only
    @nickname("scGetDeploymentRolloutStatus")
    def get(self, deployment):
        """ Returns the rollout status and message for the given deployment name. """
        deployment_rollout_status = KubernetesAccessorSingleton.get_instance().get_deployment_rollout_status(
            deployment
        )
        return {
            "status": deployment_rollout_status.status,
            "message": deployment_rollout_status.message,
        }
@resource("/v1/kubernetes/deployments/rollback")
class QEDeploymentRollback(ApiResource):
    """ Resource for rolling back deployments """

    schemas = {
        "ValidateDeploymentNames": {
            "type": "object",
            "description": "Validates deployment names for rolling back",
            "required": ["deploymentNames"],
            "properties": {
                "deploymentNames": {
                    "type": "array",
                    "description": "The names of the deployments to rollback",
                }
            },
        }
    }

    @kubernetes_only
    @nickname("scRollbackDeployments")
    @validate_json_request("ValidateDeploymentNames")
    def post(self):
        """
        Returns the config to its original state and rolls back deployments
        :return:
        """
        deployment_names = request.get_json()["deploymentNames"]

        # To roll back a deployment, we must do 2 things:
        # 1. Roll back the config secret to its old value (discarding changes we made in this session)
        # 2. Trigger a rollback to the previous revision, so that the pods will be restarted with
        #    the old config
        old_secret = get_config_as_kube_secret(config_provider.get_old_config_dir())
        kube_accessor = KubernetesAccessorSingleton.get_instance()
        kube_accessor.replace_qe_secret(old_secret)

        try:
            for name in deployment_names:
                kube_accessor.rollback_deployment(name)
        except K8sApiException as e:
            logger.exception("Failed to rollback deployment.")
            return make_response(e.message, 503)

        return make_response("Ok", 204)
@resource("/v1/kubernetes/config")
class SuperUserKubernetesConfiguration(ApiResource):
    """ Resource for saving the config files to kubernetes secrets. """

    @kubernetes_only
    @nickname("scDeployConfiguration")
    def post(self):
        """ Pushes the current local config directory into the cluster's secret. """
        try:
            new_secret = get_config_as_kube_secret(
                config_provider.get_config_dir_path()
            )
            KubernetesAccessorSingleton.get_instance().replace_qe_secret(new_secret)
        except K8sApiException as e:
            logger.exception("Failed to deploy qe config secret to kubernetes.")
            return make_response(e.message, 503)

        return make_response("Ok", 201)
@resource("/v1/kubernetes/config/populate")
class KubernetesConfigurationPopulator(ApiResource):
    """ Resource for populating the local configuration from the cluster's kubernetes secrets. """

    @kubernetes_only
    @nickname("scKubePopulateConfig")
    def post(self):
        """ Pulls the cluster's config secret into a fresh local config dir and
            reconfigures the database connection from it. """
        # Get a clean transient directory to write the config into
        config_provider.new_config_dir()

        kube_accessor = KubernetesAccessorSingleton.get_instance()
        kube_accessor.save_secret_to_directory(config_provider.get_config_dir_path())
        config_provider.create_copy_of_config_dir()

        # We update the db configuration to connect to their specified one
        # (Note, even if this DB isn't valid, it won't affect much in the config app, since we'll report an error,
        # and all of the options create a new clean dir, so we'll never pollute configs)
        combined = dict(**app.config)
        combined.update(config_provider.get_config())
        configure(combined)

        return 200

View file

@ -2,301 +2,283 @@ import logging
from flask import abort, request from flask import abort, request
from config_app.config_endpoints.api.suconfig_models_pre_oci import pre_oci_model as model from config_app.config_endpoints.api.suconfig_models_pre_oci import (
from config_app.config_endpoints.api import resource, ApiResource, nickname, validate_json_request pre_oci_model as model,
from config_app.c_app import (app, config_provider, superusers, ip_resolver, )
instance_keys, INIT_SCRIPTS_LOCATION) from config_app.config_endpoints.api import (
resource,
ApiResource,
nickname,
validate_json_request,
)
from config_app.c_app import (
app,
config_provider,
superusers,
ip_resolver,
instance_keys,
INIT_SCRIPTS_LOCATION,
)
from data.database import configure from data.database import configure
from data.runmigration import run_alembic_migration from data.runmigration import run_alembic_migration
from util.config.configutil import add_enterprise_config_defaults from util.config.configutil import add_enterprise_config_defaults
from util.config.validator import validate_service_for_config, ValidatorContext, \ from util.config.validator import (
is_valid_config_upload_filename validate_service_for_config,
ValidatorContext,
is_valid_config_upload_filename,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def database_is_valid():
    """ Returns whether the database, as configured, is valid. """
    # Delegates to the pre-OCI model layer, which attempts a real connection.
    return model.is_valid()
def database_has_users():
    """ Returns whether the database has any users defined. """
    return model.has_users()
@resource("/v1/superuser/config")
class SuperUserConfig(ApiResource):
    """ Resource for fetching and updating the current configuration, if any. """

    schemas = {
        "UpdateConfig": {
            "type": "object",
            "description": "Updates the YAML config file",
            "required": ["config"],
            "properties": {
                "config": {"type": "object"},
                "password": {"type": "string"},
            },
        }
    }

    @nickname("scGetConfig")
    def get(self):
        """ Returns the currently defined configuration, if any. """
        config_object = config_provider.get_config()
        return {"config": config_object}

    @nickname("scUpdateConfig")
    @validate_json_request("UpdateConfig")
    def put(self):
        """ Updates the config override file. """
        # Note: This method is called to set the database configuration before super users exists,
        # so we also allow it to be called if there is no valid registry configuration setup.
        config_object = request.get_json()["config"]

        # Add any enterprise defaults missing from the config.
        add_enterprise_config_defaults(config_object, app.config["SECRET_KEY"])

        # Write the configuration changes to the config override file.
        config_provider.save_config(config_object)

        # now try to connect to the db provided in their config to validate it works
        combined = dict(**app.config)
        combined.update(config_provider.get_config())
        configure(combined, testing=app.config["TESTING"])

        return {"exists": True, "config": config_object}
@resource("/v1/superuser/registrystatus")
class SuperUserRegistryStatus(ApiResource):
    """ Resource for determining the status of the registry, such as if config exists,
        if a database is configured, and if it has any defined users.
    """

    @nickname("scRegistryStatus")
    def get(self):
        """ Returns the status of the registry. """
        # If there is no config file, we need to setup the database.
        if not config_provider.config_exists():
            return {"status": "config-db"}

        # If the database isn't yet valid, then we need to set it up.
        if not database_is_valid():
            return {"status": "setup-db"}

        config = config_provider.get_config()
        if config and config.get("SETUP_COMPLETE"):
            return {"status": "config"}

        # A valid, incomplete setup: prompt for a superuser only if none exist yet.
        return {"status": "create-superuser" if not database_has_users() else "config"}
class _AlembicLogHandler(logging.Handler): class _AlembicLogHandler(logging.Handler):
def __init__(self): def __init__(self):
super(_AlembicLogHandler, self).__init__() super(_AlembicLogHandler, self).__init__()
self.records = [] self.records = []
def emit(self, record): def emit(self, record):
self.records.append({ self.records.append({"level": record.levelname, "message": record.getMessage()})
'level': record.levelname,
'message': record.getMessage()
})
def _reload_config():
    """ Merges the app's static config with the provider's current overrides,
        reconfigures the database connection, and returns the merged dict. """
    combined = dict(**app.config)
    combined.update(config_provider.get_config())
    configure(combined)
    return combined
@resource("/v1/superuser/setupdb")
class SuperUserSetupDatabase(ApiResource):
    """ Resource for invoking alembic to setup the database. """

    @nickname("scSetupDatabase")
    def get(self):
        """ Invokes the alembic upgrade process. """
        # Note: This method is called after the database configured is saved, but before the
        # database has any tables. Therefore, we only allow it to be run in that unique case.
        if config_provider.config_exists() and not database_is_valid():
            combined = _reload_config()

            app.config["DB_URI"] = combined["DB_URI"]
            db_uri = app.config["DB_URI"]
            # Alembic interpolates '%' in URIs; escape literal percents first.
            escaped_db_uri = db_uri.replace("%", "%%")

            log_handler = _AlembicLogHandler()

            try:
                run_alembic_migration(escaped_db_uri, log_handler, setup_app=False)
            except Exception as ex:
                return {"error": str(ex)}

            return {"logs": log_handler.records}

        abort(403)
@resource('/v1/superuser/config/createsuperuser') @resource("/v1/superuser/config/createsuperuser")
class SuperUserCreateInitialSuperUser(ApiResource): class SuperUserCreateInitialSuperUser(ApiResource):
""" Resource for creating the initial super user. """ """ Resource for creating the initial super user. """
schemas = {
'CreateSuperUser': {
'type': 'object',
'description': 'Information for creating the initial super user',
'required': [
'username',
'password',
'email'
],
'properties': {
'username': {
'type': 'string',
'description': 'The username for the superuser'
},
'password': {
'type': 'string',
'description': 'The password for the superuser'
},
'email': {
'type': 'string',
'description': 'The e-mail address for the superuser'
},
},
},
}
@nickname('scCreateInitialSuperuser') schemas = {
@validate_json_request('CreateSuperUser') "CreateSuperUser": {
def post(self): "type": "object",
""" Creates the initial super user, updates the underlying configuration and "description": "Information for creating the initial super user",
"required": ["username", "password", "email"],
"properties": {
"username": {
"type": "string",
"description": "The username for the superuser",
},
"password": {
"type": "string",
"description": "The password for the superuser",
},
"email": {
"type": "string",
"description": "The e-mail address for the superuser",
},
},
}
}
@nickname("scCreateInitialSuperuser")
@validate_json_request("CreateSuperUser")
def post(self):
""" Creates the initial super user, updates the underlying configuration and
sets the current session to have that super user. """ sets the current session to have that super user. """
_reload_config() _reload_config()
# Special security check: This method is only accessible when: # Special security check: This method is only accessible when:
# - There is a valid config YAML file. # - There is a valid config YAML file.
# - There are currently no users in the database (clean install) # - There are currently no users in the database (clean install)
# #
# We do this special security check because at the point this method is called, the database # We do this special security check because at the point this method is called, the database
# is clean but does not (yet) have any super users for our permissions code to check against. # is clean but does not (yet) have any super users for our permissions code to check against.
if config_provider.config_exists() and not database_has_users(): if config_provider.config_exists() and not database_has_users():
data = request.get_json() data = request.get_json()
username = data['username'] username = data["username"]
password = data['password'] password = data["password"]
email = data['email'] email = data["email"]
# Create the user in the database. # Create the user in the database.
superuser_uuid = model.create_superuser(username, password, email) superuser_uuid = model.create_superuser(username, password, email)
# Add the user to the config. # Add the user to the config.
config_object = config_provider.get_config() config_object = config_provider.get_config()
config_object['SUPER_USERS'] = [username] config_object["SUPER_USERS"] = [username]
config_provider.save_config(config_object) config_provider.save_config(config_object)
# Update the in-memory config for the new superuser. # Update the in-memory config for the new superuser.
superusers.register_superuser(username) superusers.register_superuser(username)
return { return {"status": True}
'status': True
}
abort(403) abort(403)
@resource('/v1/superuser/config/validate/<service>') @resource("/v1/superuser/config/validate/<service>")
class SuperUserConfigValidate(ApiResource): class SuperUserConfigValidate(ApiResource):
""" Resource for validating a block of configuration against an external service. """ """ Resource for validating a block of configuration against an external service. """
schemas = {
'ValidateConfig': { schemas = {
'type': 'object', "ValidateConfig": {
'description': 'Validates configuration', "type": "object",
'required': [ "description": "Validates configuration",
'config' "required": ["config"],
], "properties": {
'properties': { "config": {"type": "object"},
'config': { "password": {
'type': 'object' "type": "string",
}, "description": "The users password, used for auth validation",
'password': { },
'type': 'string', },
'description': 'The users password, used for auth validation'
} }
}, }
},
}
@nickname('scValidateConfig') @nickname("scValidateConfig")
@validate_json_request('ValidateConfig') @validate_json_request("ValidateConfig")
def post(self, service): def post(self, service):
""" Validates the given config for the given service. """ """ Validates the given config for the given service. """
# Note: This method is called to validate the database configuration before super users exists, # Note: This method is called to validate the database configuration before super users exists,
# so we also allow it to be called if there is no valid registry configuration setup. Note that # so we also allow it to be called if there is no valid registry configuration setup. Note that
# this is also safe since this method does not access any information not given in the request. # this is also safe since this method does not access any information not given in the request.
config = request.get_json()['config'] config = request.get_json()["config"]
validator_context = ValidatorContext.from_app(app, config, validator_context = ValidatorContext.from_app(
request.get_json().get('password', ''), app,
instance_keys=instance_keys, config,
ip_resolver=ip_resolver, request.get_json().get("password", ""),
config_provider=config_provider, instance_keys=instance_keys,
init_scripts_location=INIT_SCRIPTS_LOCATION) ip_resolver=ip_resolver,
config_provider=config_provider,
init_scripts_location=INIT_SCRIPTS_LOCATION,
)
return validate_service_for_config(service, validator_context) return validate_service_for_config(service, validator_context)
@resource('/v1/superuser/config/file/<filename>') @resource("/v1/superuser/config/file/<filename>")
class SuperUserConfigFile(ApiResource): class SuperUserConfigFile(ApiResource):
""" Resource for fetching the status of config files and overriding them. """ """ Resource for fetching the status of config files and overriding them. """
@nickname('scConfigFileExists') @nickname("scConfigFileExists")
def get(self, filename): def get(self, filename):
""" Returns whether the configuration file with the given name exists. """ """ Returns whether the configuration file with the given name exists. """
if not is_valid_config_upload_filename(filename): if not is_valid_config_upload_filename(filename):
abort(404) abort(404)
return { return {"exists": config_provider.volume_file_exists(filename)}
'exists': config_provider.volume_file_exists(filename)
}
@nickname('scUpdateConfigFile') @nickname("scUpdateConfigFile")
def post(self, filename): def post(self, filename):
""" Updates the configuration file with the given name. """ """ Updates the configuration file with the given name. """
if not is_valid_config_upload_filename(filename): if not is_valid_config_upload_filename(filename):
abort(404) abort(404)
# Note: This method can be called before the configuration exists # Note: This method can be called before the configuration exists
# to upload the database SSL cert. # to upload the database SSL cert.
uploaded_file = request.files['file'] uploaded_file = request.files["file"]
if not uploaded_file: if not uploaded_file:
abort(400) abort(400)
config_provider.save_volume_file(filename, uploaded_file) config_provider.save_volume_file(filename, uploaded_file)
return { return {"status": True}
'status': True
}

View file

@ -4,36 +4,36 @@ from six import add_metaclass
@add_metaclass(ABCMeta) @add_metaclass(ABCMeta)
class SuperuserConfigDataInterface(object): class SuperuserConfigDataInterface(object):
""" """
Interface that represents all data store interactions required by the superuser config API. Interface that represents all data store interactions required by the superuser config API.
""" """
@abstractmethod @abstractmethod
def is_valid(self): def is_valid(self):
""" """
Returns true if the configured database is valid. Returns true if the configured database is valid.
""" """
@abstractmethod @abstractmethod
def has_users(self): def has_users(self):
""" """
Returns true if there are any users defined. Returns true if there are any users defined.
""" """
@abstractmethod @abstractmethod
def create_superuser(self, username, password, email): def create_superuser(self, username, password, email):
""" """
Creates a new superuser with the given username, password and email. Returns the user's UUID. Creates a new superuser with the given username, password and email. Returns the user's UUID.
""" """
@abstractmethod @abstractmethod
def has_federated_login(self, username, service_name): def has_federated_login(self, username, service_name):
""" """
Returns true if the matching user has a federated login under the matching service. Returns true if the matching user has a federated login under the matching service.
""" """
@abstractmethod @abstractmethod
def attach_federated_login(self, username, service_name, federated_username): def attach_federated_login(self, username, service_name, federated_username):
""" """
Attaches a federatated login to the matching user, under the given service. Attaches a federatated login to the matching user, under the given service.
""" """

View file

@ -1,37 +1,39 @@
from data import model from data import model
from data.database import User from data.database import User
from config_app.config_endpoints.api.suconfig_models_interface import SuperuserConfigDataInterface from config_app.config_endpoints.api.suconfig_models_interface import (
SuperuserConfigDataInterface,
)
class PreOCIModel(SuperuserConfigDataInterface): class PreOCIModel(SuperuserConfigDataInterface):
# Note: this method is different than has_users: the user select will throw if the user # Note: this method is different than has_users: the user select will throw if the user
# table does not exist, whereas has_users assumes the table is valid # table does not exist, whereas has_users assumes the table is valid
def is_valid(self): def is_valid(self):
try: try:
list(User.select().limit(1)) list(User.select().limit(1))
return True return True
except: except:
return False return False
def has_users(self): def has_users(self):
return bool(list(User.select().limit(1))) return bool(list(User.select().limit(1)))
def create_superuser(self, username, password, email): def create_superuser(self, username, password, email):
return model.user.create_user(username, password, email, auto_verify=True).uuid return model.user.create_user(username, password, email, auto_verify=True).uuid
def has_federated_login(self, username, service_name): def has_federated_login(self, username, service_name):
user = model.user.get_user(username) user = model.user.get_user(username)
if user is None: if user is None:
return False return False
return bool(model.user.lookup_federated_login(user, service_name)) return bool(model.user.lookup_federated_login(user, service_name))
def attach_federated_login(self, username, service_name, federated_username): def attach_federated_login(self, username, service_name, federated_username):
user = model.user.get_user(username) user = model.user.get_user(username)
if user is None: if user is None:
return False return False
model.user.attach_federated_login(user, service_name, federated_username) model.user.attach_federated_login(user, service_name, federated_username)
pre_oci_model = PreOCIModel() pre_oci_model = PreOCIModel()

View file

@ -12,7 +12,13 @@ from data.model import ServiceKeyDoesNotExist
from util.config.validator import EXTRA_CA_DIRECTORY from util.config.validator import EXTRA_CA_DIRECTORY
from config_app.config_endpoints.exception import InvalidRequest from config_app.config_endpoints.exception import InvalidRequest
from config_app.config_endpoints.api import resource, ApiResource, nickname, log_action, validate_json_request from config_app.config_endpoints.api import (
resource,
ApiResource,
nickname,
log_action,
validate_json_request,
)
from config_app.config_endpoints.api.superuser_models_pre_oci import pre_oci_model from config_app.config_endpoints.api.superuser_models_pre_oci import pre_oci_model
from config_app.config_util.ssl import load_certificate, CertInvalidException from config_app.config_util.ssl import load_certificate, CertInvalidException
from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION
@ -21,228 +27,233 @@ from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@resource('/v1/superuser/customcerts/<certpath>') @resource("/v1/superuser/customcerts/<certpath>")
class SuperUserCustomCertificate(ApiResource): class SuperUserCustomCertificate(ApiResource):
""" Resource for managing a custom certificate. """ """ Resource for managing a custom certificate. """
@nickname('uploadCustomCertificate') @nickname("uploadCustomCertificate")
def post(self, certpath): def post(self, certpath):
uploaded_file = request.files['file'] uploaded_file = request.files["file"]
if not uploaded_file: if not uploaded_file:
raise InvalidRequest('Missing certificate file') raise InvalidRequest("Missing certificate file")
# Save the certificate. # Save the certificate.
certpath = pathvalidate.sanitize_filename(certpath) certpath = pathvalidate.sanitize_filename(certpath)
if not certpath.endswith('.crt'): if not certpath.endswith(".crt"):
raise InvalidRequest('Invalid certificate file: must have suffix `.crt`') raise InvalidRequest("Invalid certificate file: must have suffix `.crt`")
logger.debug('Saving custom certificate %s', certpath) logger.debug("Saving custom certificate %s", certpath)
cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath) cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
config_provider.save_volume_file(cert_full_path, uploaded_file) config_provider.save_volume_file(cert_full_path, uploaded_file)
logger.debug('Saved custom certificate %s', certpath) logger.debug("Saved custom certificate %s", certpath)
# Validate the certificate. # Validate the certificate.
try: try:
logger.debug('Loading custom certificate %s', certpath) logger.debug("Loading custom certificate %s", certpath)
with config_provider.get_volume_file(cert_full_path) as f: with config_provider.get_volume_file(cert_full_path) as f:
load_certificate(f.read()) load_certificate(f.read())
except CertInvalidException: except CertInvalidException:
logger.exception('Got certificate invalid error for cert %s', certpath) logger.exception("Got certificate invalid error for cert %s", certpath)
return '', 204 return "", 204
except IOError: except IOError:
logger.exception('Got IO error for cert %s', certpath) logger.exception("Got IO error for cert %s", certpath)
return '', 204 return "", 204
# Call the update script with config dir location to install the certificate immediately. # Call the update script with config dir location to install the certificate immediately.
if not app.config['TESTING']: if not app.config["TESTING"]:
cert_dir = os.path.join(config_provider.get_config_dir_path(), EXTRA_CA_DIRECTORY) cert_dir = os.path.join(
if subprocess.call([os.path.join(INIT_SCRIPTS_LOCATION, 'certs_install.sh')], env={ 'CERTDIR': cert_dir }) != 0: config_provider.get_config_dir_path(), EXTRA_CA_DIRECTORY
raise Exception('Could not install certificates') )
if (
subprocess.call(
[os.path.join(INIT_SCRIPTS_LOCATION, "certs_install.sh")],
env={"CERTDIR": cert_dir},
)
!= 0
):
raise Exception("Could not install certificates")
return '', 204 return "", 204
@nickname('deleteCustomCertificate') @nickname("deleteCustomCertificate")
def delete(self, certpath): def delete(self, certpath):
cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath) cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
config_provider.remove_volume_file(cert_full_path) config_provider.remove_volume_file(cert_full_path)
return '', 204 return "", 204
@resource('/v1/superuser/customcerts') @resource("/v1/superuser/customcerts")
class SuperUserCustomCertificates(ApiResource): class SuperUserCustomCertificates(ApiResource):
""" Resource for managing custom certificates. """ """ Resource for managing custom certificates. """
@nickname('getCustomCertificates') @nickname("getCustomCertificates")
def get(self): def get(self):
has_extra_certs_path = config_provider.volume_file_exists(EXTRA_CA_DIRECTORY) has_extra_certs_path = config_provider.volume_file_exists(EXTRA_CA_DIRECTORY)
extra_certs_found = config_provider.list_volume_directory(EXTRA_CA_DIRECTORY) extra_certs_found = config_provider.list_volume_directory(EXTRA_CA_DIRECTORY)
if extra_certs_found is None: if extra_certs_found is None:
return { return {"status": "file" if has_extra_certs_path else "none"}
'status': 'file' if has_extra_certs_path else 'none',
}
cert_views = [] cert_views = []
for extra_cert_path in extra_certs_found: for extra_cert_path in extra_certs_found:
try: try:
cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, extra_cert_path) cert_full_path = config_provider.get_volume_path(
with config_provider.get_volume_file(cert_full_path) as f: EXTRA_CA_DIRECTORY, extra_cert_path
certificate = load_certificate(f.read()) )
cert_views.append({ with config_provider.get_volume_file(cert_full_path) as f:
'path': extra_cert_path, certificate = load_certificate(f.read())
'names': list(certificate.names), cert_views.append(
'expired': certificate.expired, {
}) "path": extra_cert_path,
except CertInvalidException as cie: "names": list(certificate.names),
cert_views.append({ "expired": certificate.expired,
'path': extra_cert_path, }
'error': cie.message, )
}) except CertInvalidException as cie:
except IOError as ioe: cert_views.append({"path": extra_cert_path, "error": cie.message})
cert_views.append({ except IOError as ioe:
'path': extra_cert_path, cert_views.append({"path": extra_cert_path, "error": ioe.message})
'error': ioe.message,
})
return { return {"status": "directory", "certs": cert_views}
'status': 'directory',
'certs': cert_views,
}
@resource('/v1/superuser/keys') @resource("/v1/superuser/keys")
class SuperUserServiceKeyManagement(ApiResource): class SuperUserServiceKeyManagement(ApiResource):
""" Resource for managing service keys.""" """ Resource for managing service keys."""
schemas = {
'CreateServiceKey': {
'id': 'CreateServiceKey',
'type': 'object',
'description': 'Description of creation of a service key',
'required': ['service', 'expiration'],
'properties': {
'service': {
'type': 'string',
'description': 'The service authenticating with this key',
},
'name': {
'type': 'string',
'description': 'The friendly name of a service key',
},
'metadata': {
'type': 'object',
'description': 'The key/value pairs of this key\'s metadata',
},
'notes': {
'type': 'string',
'description': 'If specified, the extra notes for the key',
},
'expiration': {
'description': 'The expiration date as a unix timestamp',
'anyOf': [{'type': 'number'}, {'type': 'null'}],
},
},
},
}
@nickname('listServiceKeys') schemas = {
def get(self): "CreateServiceKey": {
keys = pre_oci_model.list_all_service_keys() "id": "CreateServiceKey",
"type": "object",
return jsonify({ "description": "Description of creation of a service key",
'keys': [key.to_dict() for key in keys], "required": ["service", "expiration"],
}) "properties": {
"service": {
@nickname('createServiceKey') "type": "string",
@validate_json_request('CreateServiceKey') "description": "The service authenticating with this key",
def post(self): },
body = request.get_json() "name": {
"type": "string",
# Ensure we have a valid expiration date if specified. "description": "The friendly name of a service key",
expiration_date = body.get('expiration', None) },
if expiration_date is not None: "metadata": {
try: "type": "object",
expiration_date = datetime.utcfromtimestamp(float(expiration_date)) "description": "The key/value pairs of this key's metadata",
except ValueError as ve: },
raise InvalidRequest('Invalid expiration date: %s' % ve) "notes": {
"type": "string",
if expiration_date <= datetime.now(): "description": "If specified, the extra notes for the key",
raise InvalidRequest('Expiration date cannot be in the past') },
"expiration": {
# Create the metadata for the key. "description": "The expiration date as a unix timestamp",
metadata = body.get('metadata', {}) "anyOf": [{"type": "number"}, {"type": "null"}],
metadata.update({ },
'created_by': 'Quay Superuser Panel', },
'ip': request.remote_addr, }
})
# Generate a key with a private key that we *never save*.
(private_key, key_id) = pre_oci_model.generate_service_key(body['service'], expiration_date,
metadata=metadata,
name=body.get('name', ''))
# Auto-approve the service key.
pre_oci_model.approve_service_key(key_id, ServiceKeyApprovalType.SUPERUSER,
notes=body.get('notes', ''))
# Log the creation and auto-approval of the service key.
key_log_metadata = {
'kid': key_id,
'preshared': True,
'service': body['service'],
'name': body.get('name', ''),
'expiration_date': expiration_date,
'auto_approved': True,
} }
log_action('service_key_create', None, key_log_metadata) @nickname("listServiceKeys")
log_action('service_key_approve', None, key_log_metadata) def get(self):
keys = pre_oci_model.list_all_service_keys()
return jsonify({ return jsonify({"keys": [key.to_dict() for key in keys]})
'kid': key_id,
'name': body.get('name', ''),
'service': body['service'],
'public_key': private_key.publickey().exportKey('PEM'),
'private_key': private_key.exportKey('PEM'),
})
@resource('/v1/superuser/approvedkeys/<kid>') @nickname("createServiceKey")
@validate_json_request("CreateServiceKey")
def post(self):
body = request.get_json()
# Ensure we have a valid expiration date if specified.
expiration_date = body.get("expiration", None)
if expiration_date is not None:
try:
expiration_date = datetime.utcfromtimestamp(float(expiration_date))
except ValueError as ve:
raise InvalidRequest("Invalid expiration date: %s" % ve)
if expiration_date <= datetime.now():
raise InvalidRequest("Expiration date cannot be in the past")
# Create the metadata for the key.
metadata = body.get("metadata", {})
metadata.update(
{"created_by": "Quay Superuser Panel", "ip": request.remote_addr}
)
# Generate a key with a private key that we *never save*.
(private_key, key_id) = pre_oci_model.generate_service_key(
body["service"],
expiration_date,
metadata=metadata,
name=body.get("name", ""),
)
# Auto-approve the service key.
pre_oci_model.approve_service_key(
key_id, ServiceKeyApprovalType.SUPERUSER, notes=body.get("notes", "")
)
# Log the creation and auto-approval of the service key.
key_log_metadata = {
"kid": key_id,
"preshared": True,
"service": body["service"],
"name": body.get("name", ""),
"expiration_date": expiration_date,
"auto_approved": True,
}
log_action("service_key_create", None, key_log_metadata)
log_action("service_key_approve", None, key_log_metadata)
return jsonify(
{
"kid": key_id,
"name": body.get("name", ""),
"service": body["service"],
"public_key": private_key.publickey().exportKey("PEM"),
"private_key": private_key.exportKey("PEM"),
}
)
@resource("/v1/superuser/approvedkeys/<kid>")
class SuperUserServiceKeyApproval(ApiResource): class SuperUserServiceKeyApproval(ApiResource):
""" Resource for approving service keys. """ """ Resource for approving service keys. """
schemas = { schemas = {
'ApproveServiceKey': { "ApproveServiceKey": {
'id': 'ApproveServiceKey', "id": "ApproveServiceKey",
'type': 'object', "type": "object",
'description': 'Information for approving service keys', "description": "Information for approving service keys",
'properties': { "properties": {
'notes': { "notes": {"type": "string", "description": "Optional approval notes"}
'type': 'string', },
'description': 'Optional approval notes', }
}, }
},
},
}
@nickname('approveServiceKey') @nickname("approveServiceKey")
@validate_json_request('ApproveServiceKey') @validate_json_request("ApproveServiceKey")
def post(self, kid): def post(self, kid):
notes = request.get_json().get('notes', '') notes = request.get_json().get("notes", "")
try: try:
key = pre_oci_model.approve_service_key(kid, ServiceKeyApprovalType.SUPERUSER, notes=notes) key = pre_oci_model.approve_service_key(
kid, ServiceKeyApprovalType.SUPERUSER, notes=notes
)
# Log the approval of the service key. # Log the approval of the service key.
key_log_metadata = { key_log_metadata = {
'kid': kid, "kid": kid,
'service': key.service, "service": key.service,
'name': key.name, "name": key.name,
'expiration_date': key.expiration_date, "expiration_date": key.expiration_date,
} }
# Note: this may not actually be the current person modifying the config, but if they're in the config tool, # Note: this may not actually be the current person modifying the config, but if they're in the config tool,
# they have full access to the DB and could pretend to be any user, so pulling any superuser is likely fine # they have full access to the DB and could pretend to be any user, so pulling any superuser is likely fine
super_user = app.config.get('SUPER_USERS', [None])[0] super_user = app.config.get("SUPER_USERS", [None])[0]
log_action('service_key_approve', super_user, key_log_metadata) log_action("service_key_approve", super_user, key_log_metadata)
except ServiceKeyDoesNotExist: except ServiceKeyDoesNotExist:
raise NotFound() raise NotFound()
except ServiceKeyAlreadyApproved: except ServiceKeyAlreadyApproved:
pass pass
return make_response('', 201) return make_response("", 201)

View file

@ -6,21 +6,33 @@ from config_app.config_endpoints.api import format_date
def user_view(user): def user_view(user):
return { return {"name": user.username, "kind": "user", "is_robot": user.robot}
'name': user.username,
'kind': 'user',
'is_robot': user.robot,
}
class RepositoryBuild(namedtuple('RepositoryBuild', class RepositoryBuild(
['uuid', 'logs_archived', 'repository_namespace_user_username', namedtuple(
'repository_name', "RepositoryBuild",
'can_write', 'can_read', 'pull_robot', 'resource_key', 'trigger', [
'display_name', "uuid",
'started', 'job_config', 'phase', 'status', 'error', "logs_archived",
'archive_url'])): "repository_namespace_user_username",
""" "repository_name",
"can_write",
"can_read",
"pull_robot",
"resource_key",
"trigger",
"display_name",
"started",
"job_config",
"phase",
"status",
"error",
"archive_url",
],
)
):
"""
RepositoryBuild represents a build associated with a repostiory RepositoryBuild represents a build associated with a repostiory
:type uuid: string :type uuid: string
:type logs_archived: boolean :type logs_archived: boolean
@ -40,42 +52,46 @@ class RepositoryBuild(namedtuple('RepositoryBuild',
:type archive_url: string :type archive_url: string
""" """
def to_dict(self): def to_dict(self):
resp = { resp = {
'id': self.uuid, "id": self.uuid,
'phase': self.phase, "phase": self.phase,
'started': format_date(self.started), "started": format_date(self.started),
'display_name': self.display_name, "display_name": self.display_name,
'status': self.status or {}, "status": self.status or {},
'subdirectory': self.job_config.get('build_subdir', ''), "subdirectory": self.job_config.get("build_subdir", ""),
'dockerfile_path': self.job_config.get('build_subdir', ''), "dockerfile_path": self.job_config.get("build_subdir", ""),
'context': self.job_config.get('context', ''), "context": self.job_config.get("context", ""),
'tags': self.job_config.get('docker_tags', []), "tags": self.job_config.get("docker_tags", []),
'manual_user': self.job_config.get('manual_user', None), "manual_user": self.job_config.get("manual_user", None),
'is_writer': self.can_write, "is_writer": self.can_write,
'trigger': self.trigger.to_dict(), "trigger": self.trigger.to_dict(),
'trigger_metadata': self.job_config.get('trigger_metadata', None) if self.can_read else None, "trigger_metadata": self.job_config.get("trigger_metadata", None)
'resource_key': self.resource_key, if self.can_read
'pull_robot': user_view(self.pull_robot) if self.pull_robot else None, else None,
'repository': { "resource_key": self.resource_key,
'namespace': self.repository_namespace_user_username, "pull_robot": user_view(self.pull_robot) if self.pull_robot else None,
'name': self.repository_name "repository": {
}, "namespace": self.repository_namespace_user_username,
'error': self.error, "name": self.repository_name,
} },
"error": self.error,
}
if self.can_write: if self.can_write:
if self.resource_key is not None: if self.resource_key is not None:
resp['archive_url'] = self.archive_url resp["archive_url"] = self.archive_url
elif self.job_config.get('archive_url', None): elif self.job_config.get("archive_url", None):
resp['archive_url'] = self.job_config['archive_url'] resp["archive_url"] = self.job_config["archive_url"]
return resp return resp
class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_date', 'notes'])): class Approval(
""" namedtuple("Approval", ["approver", "approval_type", "approved_date", "notes"])
):
"""
Approval represents whether a key has been approved or not Approval represents whether a key has been approved or not
:type approver: User :type approver: User
:type approval_type: string :type approval_type: string
@ -83,19 +99,32 @@ class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_da
:type notes: string :type notes: string
""" """
def to_dict(self): def to_dict(self):
return { return {
'approver': self.approver.to_dict() if self.approver else None, "approver": self.approver.to_dict() if self.approver else None,
'approval_type': self.approval_type, "approval_type": self.approval_type,
'approved_date': self.approved_date, "approved_date": self.approved_date,
'notes': self.notes, "notes": self.notes,
} }
class ServiceKey( class ServiceKey(
namedtuple('ServiceKey', ['name', 'kid', 'service', 'jwk', 'metadata', 'created_date', namedtuple(
'expiration_date', 'rotation_duration', 'approval'])): "ServiceKey",
""" [
"name",
"kid",
"service",
"jwk",
"metadata",
"created_date",
"expiration_date",
"rotation_duration",
"approval",
],
)
):
"""
ServiceKey is an apostille signing key ServiceKey is an apostille signing key
:type name: string :type name: string
:type kid: int :type kid: int
@ -109,22 +138,22 @@ class ServiceKey(
""" """
def to_dict(self): def to_dict(self):
return { return {
'name': self.name, "name": self.name,
'kid': self.kid, "kid": self.kid,
'service': self.service, "service": self.service,
'jwk': self.jwk, "jwk": self.jwk,
'metadata': self.metadata, "metadata": self.metadata,
'created_date': self.created_date, "created_date": self.created_date,
'expiration_date': self.expiration_date, "expiration_date": self.expiration_date,
'rotation_duration': self.rotation_duration, "rotation_duration": self.rotation_duration,
'approval': self.approval.to_dict() if self.approval is not None else None, "approval": self.approval.to_dict() if self.approval is not None else None,
} }
class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robot'])): class User(namedtuple("User", ["username", "email", "verified", "enabled", "robot"])):
""" """
User represents a single user. User represents a single user.
:type username: string :type username: string
:type email: string :type email: string
@ -133,41 +162,38 @@ class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robo
:type robot: User :type robot: User
""" """
def to_dict(self): def to_dict(self):
user_data = { user_data = {
'kind': 'user', "kind": "user",
'name': self.username, "name": self.username,
'username': self.username, "username": self.username,
'email': self.email, "email": self.email,
'verified': self.verified, "verified": self.verified,
'enabled': self.enabled, "enabled": self.enabled,
} }
return user_data return user_data
class Organization(namedtuple('Organization', ['username', 'email'])): class Organization(namedtuple("Organization", ["username", "email"])):
""" """
Organization represents a single org. Organization represents a single org.
:type username: string :type username: string
:type email: string :type email: string
""" """
def to_dict(self): def to_dict(self):
return { return {"name": self.username, "email": self.email}
'name': self.username,
'email': self.email,
}
@add_metaclass(ABCMeta) @add_metaclass(ABCMeta)
class SuperuserDataInterface(object): class SuperuserDataInterface(object):
""" """
Interface that represents all data store interactions required by a superuser api. Interface that represents all data store interactions required by a superuser api.
""" """
@abstractmethod @abstractmethod
def list_all_service_keys(self): def list_all_service_keys(self):
""" """
Returns a list of service keys Returns a list of service keys
""" """

View file

@ -1,60 +1,85 @@
from data import model from data import model
from config_app.config_endpoints.api.superuser_models_interface import (SuperuserDataInterface, User, ServiceKey, from config_app.config_endpoints.api.superuser_models_interface import (
Approval) SuperuserDataInterface,
User,
ServiceKey,
Approval,
)
def _create_user(user): def _create_user(user):
if user is None: if user is None:
return None return None
return User(user.username, user.email, user.verified, user.enabled, user.robot) return User(user.username, user.email, user.verified, user.enabled, user.robot)
def _create_key(key): def _create_key(key):
approval = None approval = None
if key.approval is not None: if key.approval is not None:
approval = Approval(_create_user(key.approval.approver), key.approval.approval_type, approval = Approval(
key.approval.approved_date, _create_user(key.approval.approver),
key.approval.notes) key.approval.approval_type,
key.approval.approved_date,
key.approval.notes,
)
return ServiceKey(key.name, key.kid, key.service, key.jwk, key.metadata, key.created_date, return ServiceKey(
key.expiration_date, key.name,
key.rotation_duration, approval) key.kid,
key.service,
key.jwk,
key.metadata,
key.created_date,
key.expiration_date,
key.rotation_duration,
approval,
)
class ServiceKeyDoesNotExist(Exception): class ServiceKeyDoesNotExist(Exception):
pass pass
class ServiceKeyAlreadyApproved(Exception): class ServiceKeyAlreadyApproved(Exception):
pass pass
class PreOCIModel(SuperuserDataInterface): class PreOCIModel(SuperuserDataInterface):
""" """
PreOCIModel implements the data model for the SuperUser using a database schema PreOCIModel implements the data model for the SuperUser using a database schema
before it was changed to support the OCI specification. before it was changed to support the OCI specification.
""" """
def list_all_service_keys(self): def list_all_service_keys(self):
keys = model.service_keys.list_all_keys() keys = model.service_keys.list_all_keys()
return [_create_key(key) for key in keys] return [_create_key(key) for key in keys]
def approve_service_key(self, kid, approval_type, notes=''): def approve_service_key(self, kid, approval_type, notes=""):
try: try:
key = model.service_keys.approve_service_key(kid, approval_type, notes=notes) key = model.service_keys.approve_service_key(
return _create_key(key) kid, approval_type, notes=notes
except model.ServiceKeyDoesNotExist: )
raise ServiceKeyDoesNotExist return _create_key(key)
except model.ServiceKeyAlreadyApproved: except model.ServiceKeyDoesNotExist:
raise ServiceKeyAlreadyApproved raise ServiceKeyDoesNotExist
except model.ServiceKeyAlreadyApproved:
raise ServiceKeyAlreadyApproved
def generate_service_key(self, service, expiration_date, kid=None, name='', metadata=None, def generate_service_key(
rotation_duration=None): self,
(private_key, key) = model.service_keys.generate_service_key(service, expiration_date, service,
metadata=metadata, name=name) expiration_date,
kid=None,
name="",
metadata=None,
rotation_duration=None,
):
(private_key, key) = model.service_keys.generate_service_key(
service, expiration_date, metadata=metadata, name=name
)
return private_key, key.kid return private_key, key.kid
pre_oci_model = PreOCIModel() pre_oci_model = PreOCIModel()

View file

@ -10,53 +10,59 @@ from data.database import configure
from config_app.c_app import app, config_provider from config_app.c_app import app, config_provider
from config_app.config_endpoints.api import resource, ApiResource, nickname from config_app.config_endpoints.api import resource, ApiResource, nickname
from config_app.config_util.tar import tarinfo_filter_partial, strip_absolute_path_and_add_trailing_dir from config_app.config_util.tar import (
tarinfo_filter_partial,
strip_absolute_path_and_add_trailing_dir,
)
@resource('/v1/configapp/initialization') @resource("/v1/configapp/initialization")
class ConfigInitialization(ApiResource): class ConfigInitialization(ApiResource):
""" """
Resource for dealing with any initialization logic for the config app Resource for dealing with any initialization logic for the config app
""" """
@nickname('scStartNewConfig') @nickname("scStartNewConfig")
def post(self): def post(self):
config_provider.new_config_dir() config_provider.new_config_dir()
return make_response('OK') return make_response("OK")
@resource('/v1/configapp/tarconfig') @resource("/v1/configapp/tarconfig")
class TarConfigLoader(ApiResource): class TarConfigLoader(ApiResource):
""" """
Resource for dealing with configuration as a tarball, Resource for dealing with configuration as a tarball,
including loading and generating functions including loading and generating functions
""" """
@nickname('scGetConfigTarball') @nickname("scGetConfigTarball")
def get(self): def get(self):
config_path = config_provider.get_config_dir_path() config_path = config_provider.get_config_dir_path()
tar_dir_prefix = strip_absolute_path_and_add_trailing_dir(config_path) tar_dir_prefix = strip_absolute_path_and_add_trailing_dir(config_path)
temp = tempfile.NamedTemporaryFile() temp = tempfile.NamedTemporaryFile()
with closing(tarfile.open(temp.name, mode="w|gz")) as tar: with closing(tarfile.open(temp.name, mode="w|gz")) as tar:
for name in os.listdir(config_path): for name in os.listdir(config_path):
tar.add(os.path.join(config_path, name), filter=tarinfo_filter_partial(tar_dir_prefix)) tar.add(
return send_file(temp.name, mimetype='application/gzip') os.path.join(config_path, name),
filter=tarinfo_filter_partial(tar_dir_prefix),
)
return send_file(temp.name, mimetype="application/gzip")
@nickname('scUploadTarballConfig') @nickname("scUploadTarballConfig")
def put(self): def put(self):
""" Loads tarball config into the config provider """ """ Loads tarball config into the config provider """
# Generate a new empty dir to load the config into # Generate a new empty dir to load the config into
config_provider.new_config_dir() config_provider.new_config_dir()
input_stream = request.stream input_stream = request.stream
with tarfile.open(mode="r|gz", fileobj=input_stream) as tar_stream: with tarfile.open(mode="r|gz", fileobj=input_stream) as tar_stream:
tar_stream.extractall(config_provider.get_config_dir_path()) tar_stream.extractall(config_provider.get_config_dir_path())
config_provider.create_copy_of_config_dir() config_provider.create_copy_of_config_dir()
# now try to connect to the db provided in their config to validate it works # now try to connect to the db provided in their config to validate it works
combined = dict(**app.config) combined = dict(**app.config)
combined.update(config_provider.get_config()) combined.update(config_provider.get_config())
configure(combined) configure(combined)
return make_response('OK') return make_response("OK")

View file

@ -3,12 +3,12 @@ from config_app.config_endpoints.api import resource, ApiResource, nickname
from config_app.config_endpoints.api.superuser_models_interface import user_view from config_app.config_endpoints.api.superuser_models_interface import user_view
@resource('/v1/user/') @resource("/v1/user/")
class User(ApiResource): class User(ApiResource):
""" Operations related to users. """ """ Operations related to users. """
@nickname('getLoggedInUser') @nickname("getLoggedInUser")
def get(self): def get(self):
""" Get user information for the authenticated user. """ """ Get user information for the authenticated user. """
user = get_authenticated_user() user = get_authenticated_user()
return user_view(user) return user_view(user)

View file

@ -14,60 +14,72 @@ from config_app.config_util.k8sconfig import get_k8s_namespace
def truthy_bool(param): def truthy_bool(param):
return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'} return param not in {False, "false", "False", "0", "FALSE", "", "null"}
DEFAULT_JS_BUNDLE_NAME = 'configapp' DEFAULT_JS_BUNDLE_NAME = "configapp"
PARAM_REGEX = re.compile(r'<([^:>]+:)*([\w]+)>') PARAM_REGEX = re.compile(r"<([^:>]+:)*([\w]+)>")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TYPE_CONVERTER = { TYPE_CONVERTER = {
truthy_bool: 'boolean', truthy_bool: "boolean",
str: 'string', str: "string",
basestring: 'string', basestring: "string",
reqparse.text_type: 'string', reqparse.text_type: "string",
int: 'integer', int: "integer",
} }
def _list_files(path, extension, contains=""): def _list_files(path, extension, contains=""):
""" Returns a list of all the files with the given extension found under the given path. """ """ Returns a list of all the files with the given extension found under the given path. """
def matches(f): def matches(f):
return os.path.splitext(f)[1] == '.' + extension and contains in os.path.splitext(f)[0] return (
os.path.splitext(f)[1] == "." + extension
and contains in os.path.splitext(f)[0]
)
def join_path(dp, f): def join_path(dp, f):
# Remove the static/ prefix. It is added in the template. # Remove the static/ prefix. It is added in the template.
return os.path.join(dp, f)[len(ROOT_DIR) + 1 + len('config_app/static/'):] return os.path.join(dp, f)[len(ROOT_DIR) + 1 + len("config_app/static/") :]
filepath = os.path.join(os.path.join(ROOT_DIR, 'config_app/static/'), path) filepath = os.path.join(os.path.join(ROOT_DIR, "config_app/static/"), path)
return [join_path(dp, f) for dp, _, files in os.walk(filepath) for f in files if matches(f)] return [
join_path(dp, f)
for dp, _, files in os.walk(filepath)
for f in files
if matches(f)
]
FONT_AWESOME_4 = 'netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css' FONT_AWESOME_4 = "netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css"
def render_page_template(name, route_data=None, js_bundle_name=DEFAULT_JS_BUNDLE_NAME, **kwargs): def render_page_template(
""" Renders the page template with the given name as the response and returns its contents. """ name, route_data=None, js_bundle_name=DEFAULT_JS_BUNDLE_NAME, **kwargs
main_scripts = _list_files('build', 'js', js_bundle_name) ):
""" Renders the page template with the given name as the response and returns its contents. """
main_scripts = _list_files("build", "js", js_bundle_name)
use_cdn = os.getenv('TESTING') == 'true' use_cdn = os.getenv("TESTING") == "true"
external_styles = get_external_css(local=not use_cdn, exclude=FONT_AWESOME_4) external_styles = get_external_css(local=not use_cdn, exclude=FONT_AWESOME_4)
external_scripts = get_external_javascript(local=not use_cdn) external_scripts = get_external_javascript(local=not use_cdn)
contents = render_template(name, contents = render_template(
route_data=route_data, name,
main_scripts=main_scripts, route_data=route_data,
external_styles=external_styles, main_scripts=main_scripts,
external_scripts=external_scripts, external_styles=external_styles,
config_set=frontend_visible_config(app.config), external_scripts=external_scripts,
kubernetes_namespace=IS_KUBERNETES and get_k8s_namespace(), config_set=frontend_visible_config(app.config),
**kwargs) kubernetes_namespace=IS_KUBERNETES and get_k8s_namespace(),
**kwargs
)
resp = make_response(contents) resp = make_response(contents)
resp.headers['X-FRAME-OPTIONS'] = 'DENY' resp.headers["X-FRAME-OPTIONS"] = "DENY"
return resp return resp
def fully_qualified_name(method_view_class): def fully_qualified_name(method_view_class):
return '%s.%s' % (method_view_class.__module__, method_view_class.__name__) return "%s.%s" % (method_view_class.__module__, method_view_class.__name__)

View file

@ -5,11 +5,11 @@ from werkzeug.exceptions import HTTPException
class ApiErrorType(Enum): class ApiErrorType(Enum):
invalid_request = 'invalid_request' invalid_request = "invalid_request"
class ApiException(HTTPException): class ApiException(HTTPException):
""" """
Represents an error in the application/problem+json format. Represents an error in the application/problem+json format.
See: https://tools.ietf.org/html/rfc7807 See: https://tools.ietf.org/html/rfc7807
@ -31,36 +31,42 @@ class ApiException(HTTPException):
information if dereferenced. information if dereferenced.
""" """
def __init__(self, error_type, status_code, error_description, payload=None): def __init__(self, error_type, status_code, error_description, payload=None):
Exception.__init__(self) Exception.__init__(self)
self.error_description = error_description self.error_description = error_description
self.code = status_code self.code = status_code
self.payload = payload self.payload = payload
self.error_type = error_type self.error_type = error_type
self.data = self.to_dict() self.data = self.to_dict()
super(ApiException, self).__init__(error_description, None) super(ApiException, self).__init__(error_description, None)
def to_dict(self): def to_dict(self):
rv = dict(self.payload or ()) rv = dict(self.payload or ())
if self.error_description is not None: if self.error_description is not None:
rv['detail'] = self.error_description rv["detail"] = self.error_description
rv['error_message'] = self.error_description # TODO: deprecate rv["error_message"] = self.error_description # TODO: deprecate
rv['error_type'] = self.error_type.value # TODO: deprecate rv["error_type"] = self.error_type.value # TODO: deprecate
rv['title'] = self.error_type.value rv["title"] = self.error_type.value
rv['type'] = url_for('api.error', error_type=self.error_type.value, _external=True) rv["type"] = url_for(
rv['status'] = self.code "api.error", error_type=self.error_type.value, _external=True
)
rv["status"] = self.code
return rv return rv
class InvalidRequest(ApiException): class InvalidRequest(ApiException):
def __init__(self, error_description, payload=None): def __init__(self, error_description, payload=None):
ApiException.__init__(self, ApiErrorType.invalid_request, 400, error_description, payload) ApiException.__init__(
self, ApiErrorType.invalid_request, 400, error_description, payload
)
class InvalidResponse(ApiException): class InvalidResponse(ApiException):
def __init__(self, error_description, payload=None): def __init__(self, error_description, payload=None):
ApiException.__init__(self, ApiErrorType.invalid_response, 400, error_description, payload) ApiException.__init__(
self, ApiErrorType.invalid_response, 400, error_description, payload
)

View file

@ -5,19 +5,21 @@ from config_app.config_endpoints.common import render_page_template
from config_app.config_endpoints.api.discovery import generate_route_data from config_app.config_endpoints.api.discovery import generate_route_data
from config_app.config_endpoints.api import no_cache from config_app.config_endpoints.api import no_cache
setup_web = Blueprint('setup_web', __name__, template_folder='templates') setup_web = Blueprint("setup_web", __name__, template_folder="templates")
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def _get_route_data(): def _get_route_data():
return generate_route_data() return generate_route_data()
def render_page_template_with_routedata(name, *args, **kwargs): def render_page_template_with_routedata(name, *args, **kwargs):
return render_page_template(name, _get_route_data(), *args, **kwargs) return render_page_template(name, _get_route_data(), *args, **kwargs)
@no_cache @no_cache
@setup_web.route('/', methods=['GET'], defaults={'path': ''}) @setup_web.route("/", methods=["GET"], defaults={"path": ""})
def index(path, **kwargs): def index(path, **kwargs):
return render_page_template_with_routedata('index.html', js_bundle_name='configapp', **kwargs) return render_page_template_with_routedata(
"index.html", js_bundle_name="configapp", **kwargs
)

View file

@ -5,204 +5,242 @@ from data import database, model
from util.security.test.test_ssl_util import generate_test_cert from util.security.test.test_ssl_util import generate_test_cert
from config_app.c_app import app from config_app.c_app import app
from config_app.config_test import ApiTestCase, all_queues, ADMIN_ACCESS_USER, ADMIN_ACCESS_EMAIL from config_app.config_test import (
ApiTestCase,
all_queues,
ADMIN_ACCESS_USER,
ADMIN_ACCESS_EMAIL,
)
from config_app.config_endpoints.api import api_bp from config_app.config_endpoints.api import api_bp
from config_app.config_endpoints.api.superuser import SuperUserCustomCertificate, SuperUserCustomCertificates from config_app.config_endpoints.api.superuser import (
from config_app.config_endpoints.api.suconfig import SuperUserConfig, SuperUserCreateInitialSuperUser, \ SuperUserCustomCertificate,
SuperUserConfigFile, SuperUserRegistryStatus SuperUserCustomCertificates,
)
from config_app.config_endpoints.api.suconfig import (
SuperUserConfig,
SuperUserCreateInitialSuperUser,
SuperUserConfigFile,
SuperUserRegistryStatus,
)
try: try:
app.register_blueprint(api_bp, url_prefix='/api') app.register_blueprint(api_bp, url_prefix="/api")
except ValueError: except ValueError:
# This blueprint was already registered # This blueprint was already registered
pass pass
class TestSuperUserCreateInitialSuperUser(ApiTestCase): class TestSuperUserCreateInitialSuperUser(ApiTestCase):
def test_create_superuser(self): def test_create_superuser(self):
data = { data = {
'username': 'newsuper', "username": "newsuper",
'password': 'password', "password": "password",
'email': 'jschorr+fake@devtable.com', "email": "jschorr+fake@devtable.com",
} }
# Add some fake config. # Add some fake config.
fake_config = { fake_config = {"AUTHENTICATION_TYPE": "Database", "SECRET_KEY": "fakekey"}
'AUTHENTICATION_TYPE': 'Database',
'SECRET_KEY': 'fakekey',
}
self.putJsonResponse(SuperUserConfig, data=dict(config=fake_config, hostname='fakehost')) self.putJsonResponse(
SuperUserConfig, data=dict(config=fake_config, hostname="fakehost")
)
# Try to write with config. Should 403 since there are users in the DB. # Try to write with config. Should 403 since there are users in the DB.
self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403) self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
# Delete all users in the DB. # Delete all users in the DB.
for user in list(database.User.select()): for user in list(database.User.select()):
model.user.delete_user(user, all_queues) model.user.delete_user(user, all_queues)
# Create the superuser. # Create the superuser.
self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data) self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
# Ensure the user exists in the DB. # Ensure the user exists in the DB.
self.assertIsNotNone(model.user.get_user('newsuper')) self.assertIsNotNone(model.user.get_user("newsuper"))
# Ensure that the current user is a superuser in the config. # Ensure that the current user is a superuser in the config.
json = self.getJsonResponse(SuperUserConfig) json = self.getJsonResponse(SuperUserConfig)
self.assertEquals(['newsuper'], json['config']['SUPER_USERS']) self.assertEquals(["newsuper"], json["config"]["SUPER_USERS"])
# Ensure that the current user is a superuser in memory by trying to call an API # Ensure that the current user is a superuser in memory by trying to call an API
# that will fail otherwise. # that will fail otherwise.
self.getResponse(SuperUserConfigFile, params=dict(filename='ssl.cert')) self.getResponse(SuperUserConfigFile, params=dict(filename="ssl.cert"))
class TestSuperUserConfig(ApiTestCase): class TestSuperUserConfig(ApiTestCase):
def test_get_status_update_config(self): def test_get_status_update_config(self):
# With no config the status should be 'config-db'. # With no config the status should be 'config-db'.
json = self.getJsonResponse(SuperUserRegistryStatus) json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config-db', json['status']) self.assertEquals("config-db", json["status"])
# Add some fake config. # Add some fake config.
fake_config = { fake_config = {"AUTHENTICATION_TYPE": "Database", "SECRET_KEY": "fakekey"}
'AUTHENTICATION_TYPE': 'Database',
'SECRET_KEY': 'fakekey',
}
json = self.putJsonResponse(SuperUserConfig, data=dict(config=fake_config, json = self.putJsonResponse(
hostname='fakehost')) SuperUserConfig, data=dict(config=fake_config, hostname="fakehost")
self.assertEquals('fakekey', json['config']['SECRET_KEY']) )
self.assertEquals('fakehost', json['config']['SERVER_HOSTNAME']) self.assertEquals("fakekey", json["config"]["SECRET_KEY"])
self.assertEquals('Database', json['config']['AUTHENTICATION_TYPE']) self.assertEquals("fakehost", json["config"]["SERVER_HOSTNAME"])
self.assertEquals("Database", json["config"]["AUTHENTICATION_TYPE"])
# With config the status should be 'setup-db'. # With config the status should be 'setup-db'.
# TODO: fix this test # TODO: fix this test
# json = self.getJsonResponse(SuperUserRegistryStatus) # json = self.getJsonResponse(SuperUserRegistryStatus)
# self.assertEquals('setup-db', json['status']) # self.assertEquals('setup-db', json['status'])
def test_config_file(self): def test_config_file(self):
# Try for an invalid file. Should 404. # Try for an invalid file. Should 404.
self.getResponse(SuperUserConfigFile, params=dict(filename='foobar'), expected_code=404) self.getResponse(
SuperUserConfigFile, params=dict(filename="foobar"), expected_code=404
)
# Try for a valid filename. Should not exist. # Try for a valid filename. Should not exist.
json = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert')) json = self.getJsonResponse(
self.assertFalse(json['exists']) SuperUserConfigFile, params=dict(filename="ssl.cert")
)
self.assertFalse(json["exists"])
# Add the file. # Add the file.
self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), self.postResponse(
file=(StringIO('my file contents'), 'ssl.cert')) SuperUserConfigFile,
params=dict(filename="ssl.cert"),
file=(StringIO("my file contents"), "ssl.cert"),
)
# Should now exist. # Should now exist.
json = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert')) json = self.getJsonResponse(
self.assertTrue(json['exists']) SuperUserConfigFile, params=dict(filename="ssl.cert")
)
self.assertTrue(json["exists"])
def test_update_with_external_auth(self): def test_update_with_external_auth(self):
# Run a mock LDAP. # Run a mock LDAP.
mockldap = MockLdap({ mockldap = MockLdap(
'dc=quay,dc=io': { {
'dc': ['quay', 'io'] "dc=quay,dc=io": {"dc": ["quay", "io"]},
}, "ou=employees,dc=quay,dc=io": {"dc": ["quay", "io"], "ou": "employees"},
'ou=employees,dc=quay,dc=io': { "uid="
'dc': ['quay', 'io'], + ADMIN_ACCESS_USER
'ou': 'employees' + ",ou=employees,dc=quay,dc=io": {
}, "dc": ["quay", "io"],
'uid=' + ADMIN_ACCESS_USER + ',ou=employees,dc=quay,dc=io': { "ou": "employees",
'dc': ['quay', 'io'], "uid": [ADMIN_ACCESS_USER],
'ou': 'employees', "userPassword": ["password"],
'uid': [ADMIN_ACCESS_USER], "mail": [ADMIN_ACCESS_EMAIL],
'userPassword': ['password'], },
'mail': [ADMIN_ACCESS_EMAIL], }
}, )
})
config = { config = {
'AUTHENTICATION_TYPE': 'LDAP', "AUTHENTICATION_TYPE": "LDAP",
'LDAP_BASE_DN': ['dc=quay', 'dc=io'], "LDAP_BASE_DN": ["dc=quay", "dc=io"],
'LDAP_ADMIN_DN': 'uid=devtable,ou=employees,dc=quay,dc=io', "LDAP_ADMIN_DN": "uid=devtable,ou=employees,dc=quay,dc=io",
'LDAP_ADMIN_PASSWD': 'password', "LDAP_ADMIN_PASSWD": "password",
'LDAP_USER_RDN': ['ou=employees'], "LDAP_USER_RDN": ["ou=employees"],
'LDAP_UID_ATTR': 'uid', "LDAP_UID_ATTR": "uid",
'LDAP_EMAIL_ATTR': 'mail', "LDAP_EMAIL_ATTR": "mail",
} }
mockldap.start() mockldap.start()
try: try:
# Write the config with the valid password. # Write the config with the valid password.
self.putResponse(SuperUserConfig, self.putResponse(
data={'config': config, SuperUserConfig,
'password': 'password', data={"config": config, "password": "password", "hostname": "foo"},
'hostname': 'foo'}, expected_code=200) expected_code=200,
)
# Ensure that the user row has been linked.
# TODO: fix this test
# self.assertEquals(ADMIN_ACCESS_USER,
# model.user.verify_federated_login('ldap', ADMIN_ACCESS_USER).username)
finally:
mockldap.stop()
# Ensure that the user row has been linked.
# TODO: fix this test
# self.assertEquals(ADMIN_ACCESS_USER,
# model.user.verify_federated_login('ldap', ADMIN_ACCESS_USER).username)
finally:
mockldap.stop()
class TestSuperUserCustomCertificates(ApiTestCase): class TestSuperUserCustomCertificates(ApiTestCase):
def test_custom_certificates(self): def test_custom_certificates(self):
# Upload a certificate. # Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', san_list=['DNS:bar', 'DNS:baz']) cert_contents, _ = generate_test_cert(
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'), hostname="somecoolhost", san_list=["DNS:bar", "DNS:baz"]
file=(StringIO(cert_contents), 'testcert.crt'), expected_code=204) )
self.postResponse(
SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO(cert_contents), "testcert.crt"),
expected_code=204,
)
# Make sure it is present. # Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs'])) self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0] cert_info = json["certs"][0]
self.assertEquals('testcert.crt', cert_info['path']) self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals(set(['somecoolhost', 'bar', 'baz']), set(cert_info['names'])) self.assertEquals(set(["somecoolhost", "bar", "baz"]), set(cert_info["names"]))
self.assertFalse(cert_info['expired']) self.assertFalse(cert_info["expired"])
# Remove the certificate. # Remove the certificate.
self.deleteResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt')) self.deleteResponse(
SuperUserCustomCertificate, params=dict(certpath="testcert.crt")
)
# Make sure it is gone. # Make sure it is gone.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(0, len(json['certs'])) self.assertEquals(0, len(json["certs"]))
def test_expired_custom_certificate(self): def test_expired_custom_certificate(self):
# Upload a certificate. # Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', expires=-10) cert_contents, _ = generate_test_cert(hostname="somecoolhost", expires=-10)
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'), self.postResponse(
file=(StringIO(cert_contents), 'testcert.crt'), expected_code=204) SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO(cert_contents), "testcert.crt"),
expected_code=204,
)
# Make sure it is present. # Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs'])) self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0] cert_info = json["certs"][0]
self.assertEquals('testcert.crt', cert_info['path']) self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals(set(['somecoolhost']), set(cert_info['names'])) self.assertEquals(set(["somecoolhost"]), set(cert_info["names"]))
self.assertTrue(cert_info['expired']) self.assertTrue(cert_info["expired"])
def test_invalid_custom_certificate(self): def test_invalid_custom_certificate(self):
# Upload an invalid certificate. # Upload an invalid certificate.
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'), self.postResponse(
file=(StringIO('some contents'), 'testcert.crt'), expected_code=204) SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO("some contents"), "testcert.crt"),
expected_code=204,
)
# Make sure it is present but invalid. # Make sure it is present but invalid.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs'])) self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0] cert_info = json["certs"][0]
self.assertEquals('testcert.crt', cert_info['path']) self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals('no start line', cert_info['error']) self.assertEquals("no start line", cert_info["error"])
def test_path_sanitization(self): def test_path_sanitization(self):
# Upload a certificate. # Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', expires=-10) cert_contents, _ = generate_test_cert(hostname="somecoolhost", expires=-10)
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert/../foobar.crt'), self.postResponse(
file=(StringIO(cert_contents), 'testcert/../foobar.crt'), expected_code=204) SuperUserCustomCertificate,
params=dict(certpath="testcert/../foobar.crt"),
file=(StringIO(cert_contents), "testcert/../foobar.crt"),
expected_code=204,
)
# Make sure it is present. # Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs'])) self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0]
self.assertEquals('foobar.crt', cert_info['path'])
cert_info = json["certs"][0]
self.assertEquals("foobar.crt", cert_info["path"])

View file

@ -4,176 +4,235 @@ import mock
from data.database import User from data.database import User
from data import model from data import model
from config_app.config_endpoints.api.suconfig import SuperUserConfig, SuperUserConfigValidate, SuperUserConfigFile, \ from config_app.config_endpoints.api.suconfig import (
SuperUserRegistryStatus, SuperUserCreateInitialSuperUser SuperUserConfig,
SuperUserConfigValidate,
SuperUserConfigFile,
SuperUserRegistryStatus,
SuperUserCreateInitialSuperUser,
)
from config_app.config_endpoints.api import api_bp from config_app.config_endpoints.api import api_bp
from config_app.config_test import ApiTestCase, READ_ACCESS_USER, ADMIN_ACCESS_USER from config_app.config_test import ApiTestCase, READ_ACCESS_USER, ADMIN_ACCESS_USER
from config_app.c_app import app, config_provider from config_app.c_app import app, config_provider
try: try:
app.register_blueprint(api_bp, url_prefix='/api') app.register_blueprint(api_bp, url_prefix="/api")
except ValueError: except ValueError:
# This blueprint was already registered # This blueprint was already registered
pass pass
# OVERRIDES FROM PORTING FROM OLD APP: # OVERRIDES FROM PORTING FROM OLD APP:
all_queues = [] # the config app doesn't have any queues all_queues = [] # the config app doesn't have any queues
class FreshConfigProvider(object): class FreshConfigProvider(object):
def __enter__(self): def __enter__(self):
config_provider.reset_for_test() config_provider.reset_for_test()
return config_provider return config_provider
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
config_provider.reset_for_test() config_provider.reset_for_test()
class TestSuperUserRegistryStatus(ApiTestCase):
    """Exercises the registry-status endpoint across the setup state machine."""

    def test_registry_status_no_config(self):
        # With no config.yaml at all, the status is the initial "config-db".
        with FreshConfigProvider():
            response = self.getJsonResponse(SuperUserRegistryStatus)
            self.assertEquals("config-db", response["status"])

    @mock.patch(
        "config_app.config_endpoints.api.suconfig.database_is_valid",
        mock.Mock(return_value=False),
    )
    def test_registry_status_no_database(self):
        # Config exists but the database is not yet valid -> "setup-db".
        with FreshConfigProvider():
            config_provider.save_config({"key": "value"})
            response = self.getJsonResponse(SuperUserRegistryStatus)
            self.assertEquals("setup-db", response["status"])

    @mock.patch(
        "config_app.config_endpoints.api.suconfig.database_is_valid",
        mock.Mock(return_value=True),
    )
    def test_registry_status_db_has_superuser(self):
        # Config exists and the database is valid -> "config".
        with FreshConfigProvider():
            config_provider.save_config({"key": "value"})
            response = self.getJsonResponse(SuperUserRegistryStatus)
            self.assertEquals("config", response["status"])

    @mock.patch(
        "config_app.config_endpoints.api.suconfig.database_is_valid",
        mock.Mock(return_value=True),
    )
    @mock.patch(
        "config_app.config_endpoints.api.suconfig.database_has_users",
        mock.Mock(return_value=False),
    )
    def test_registry_status_db_no_superuser(self):
        # Valid database with no users -> "create-superuser".
        with FreshConfigProvider():
            config_provider.save_config({"key": "value"})
            response = self.getJsonResponse(SuperUserRegistryStatus)
            self.assertEquals("create-superuser", response["status"])

    @mock.patch(
        "config_app.config_endpoints.api.suconfig.database_is_valid",
        mock.Mock(return_value=True),
    )
    @mock.patch(
        "config_app.config_endpoints.api.suconfig.database_has_users",
        mock.Mock(return_value=True),
    )
    def test_registry_status_setup_complete(self):
        # Fully set up (users exist, SETUP_COMPLETE flagged) -> "config".
        with FreshConfigProvider():
            config_provider.save_config({"key": "value", "SETUP_COMPLETE": True})
            response = self.getJsonResponse(SuperUserRegistryStatus)
            self.assertEquals("config", response["status"])
class TestSuperUserConfigFile(ApiTestCase):
    """Tests for fetching and uploading individual configuration files."""

    def test_get_superuser_invalid_filename(self):
        # Unknown filenames are rejected outright.
        with FreshConfigProvider():
            self.getResponse(
                SuperUserConfigFile, params=dict(filename="somefile"), expected_code=404
            )

    def test_get_superuser(self):
        # A known filename that has not been uploaded reports exists=False.
        with FreshConfigProvider():
            result = self.getJsonResponse(
                SuperUserConfigFile, params=dict(filename="ssl.cert")
            )
            self.assertFalse(result["exists"])

    def test_post_no_file(self):
        with FreshConfigProvider():
            # No file
            self.postResponse(
                SuperUserConfigFile, params=dict(filename="ssl.cert"), expected_code=400
            )

    def test_post_superuser_invalid_filename(self):
        with FreshConfigProvider():
            self.postResponse(
                SuperUserConfigFile, params=dict(filename="somefile"), expected_code=404
            )

    def test_post_superuser(self):
        with FreshConfigProvider():
            self.postResponse(
                SuperUserConfigFile, params=dict(filename="ssl.cert"), expected_code=400
            )
class TestSuperUserCreateInitialSuperUser(ApiTestCase):
    """Security tests around creating the very first superuser account."""

    def test_no_config_file(self):
        with FreshConfigProvider():
            # If there is no config.yaml, then this method should security fail.
            payload = dict(
                username="cooluser", password="password", email="fake@example.com"
            )
            self.postResponse(
                SuperUserCreateInitialSuperUser, data=payload, expected_code=403
            )

    def test_config_file_with_db_users(self):
        with FreshConfigProvider():
            # Write some config.
            self.putJsonResponse(
                SuperUserConfig, data=dict(config={}, hostname="foobar")
            )

            # If there is a config.yaml, but existing DB users exist, then this
            # method should security fail.
            payload = dict(
                username="cooluser", password="password", email="fake@example.com"
            )
            self.postResponse(
                SuperUserCreateInitialSuperUser, data=payload, expected_code=403
            )

    def test_config_file_with_no_db_users(self):
        with FreshConfigProvider():
            # Write some config.
            self.putJsonResponse(
                SuperUserConfig, data=dict(config={}, hostname="foobar")
            )

            # Delete all the users in the DB.
            for user in list(User.select()):
                model.user.delete_user(user, all_queues)

            # This method should now succeed.
            payload = dict(
                username="cooluser", password="password", email="fake@example.com"
            )
            result = self.postJsonResponse(SuperUserCreateInitialSuperUser, data=payload)
            self.assertTrue(result["status"])

            # Verify the superuser was created.
            User.get(User.username == "cooluser")

            # Verify the superuser was placed into the config.
            result = self.getJsonResponse(SuperUserConfig)
            self.assertEquals(["cooluser"], result["config"]["SUPER_USERS"])
class TestSuperUserConfigValidate(ApiTestCase):
    """Tests for validating candidate configuration as a non-superuser."""

    def test_nonsuperuser_noconfig(self):
        with FreshConfigProvider():
            result = self.postJsonResponse(
                SuperUserConfigValidate,
                params=dict(service="someservice"),
                data=dict(config={}),
            )

            self.assertFalse(result["status"])

    def test_nonsuperuser_config(self):
        with FreshConfigProvider():
            # The validate config call works if there is no config.yaml OR the
            # user is a superuser. Add a config, and verify it breaks when
            # unauthenticated.
            json = self.putJsonResponse(
                SuperUserConfig, data=dict(config={}, hostname="foobar")
            )
            self.assertTrue(json["exists"])

            result = self.postJsonResponse(
                SuperUserConfigValidate,
                params=dict(service="someservice"),
                data=dict(config={}),
            )

            self.assertFalse(result["status"])
class TestSuperUserConfig(ApiTestCase):
    """Tests for reading and writing the top-level configuration blob."""

    def test_get_superuser(self):
        with FreshConfigProvider():
            json = self.getJsonResponse(SuperUserConfig)

            # Note: We expect the config to be none because a config.yaml should
            # never be checked into the directory.
            self.assertIsNone(json["config"])

    def test_put(self):
        with FreshConfigProvider() as config:
            json = self.putJsonResponse(
                SuperUserConfig, data=dict(config={}, hostname="foobar")
            )
            self.assertTrue(json["exists"])

            # Verify the config file exists.
            self.assertTrue(config.config_exists())

            # This should succeed.
            json = self.putJsonResponse(
                SuperUserConfig, data=dict(config={}, hostname="barbaz")
            )
            self.assertTrue(json["exists"])

            json = self.getJsonResponse(SuperUserConfig)
            self.assertIsNotNone(json["config"])
# Allow the test module to be run directly as a script.
if __name__ == "__main__":
    unittest.main()

View file

@ -5,58 +5,63 @@ from backports.tempfile import TemporaryDirectory
from config_app.config_util.config.fileprovider import FileConfigProvider from config_app.config_util.config.fileprovider import FileConfigProvider
# Subdirectory used to hold the saved copy of the previous configuration.
OLD_CONFIG_SUBDIR = "old/"
class TransientDirectoryProvider(FileConfigProvider):
    """ Implementation of the config provider that reads and writes the data
        from/to the file system, only using temporary directories,
        deleting old dirs and creating new ones as requested.
    """

    def __init__(self, config_volume, yaml_filename, py_filename):
        # A fresh temp directory is created per config path so that an
        # uploaded config can never pollute later modifications/creations.
        temp_dir = TemporaryDirectory()
        self.temp_dir = temp_dir
        self.old_config_dir = None
        super(TransientDirectoryProvider, self).__init__(
            temp_dir.name, yaml_filename, py_filename
        )

    @property
    def provider_id(self):
        return "transient"

    def new_config_dir(self):
        """ Update the path with a new temporary directory, deleting the old one in the process """
        self.temp_dir.cleanup()
        replacement = TemporaryDirectory()

        self.config_volume = replacement.name
        self.temp_dir = replacement
        self.yaml_path = os.path.join(replacement.name, self.yaml_filename)

    def create_copy_of_config_dir(self):
        """ Create a directory to store loaded/populated configuration (for rollback if necessary) """
        if self.old_config_dir is not None:
            self.old_config_dir.cleanup()

        backup_dir = TemporaryDirectory()
        self.old_config_dir = backup_dir

        # Python 2.7's shutil.copy() doesn't allow for copying to existing
        # directories, so the saved config is written under a subdirectory
        # via shutil.copytree() (and read back from the same place).
        copytree(self.config_volume, os.path.join(backup_dir.name, OLD_CONFIG_SUBDIR))

    def get_config_dir_path(self):
        return self.config_volume

    def get_old_config_dir(self):
        # create_copy_of_config_dir() must have been called first.
        if self.old_config_dir is None:
            raise Exception(
                "Cannot return a configuration that was no old configuration"
            )

        return os.path.join(self.old_config_dir.name, OLD_CONFIG_SUBDIR)

View file

@ -3,37 +3,40 @@ import os
from config_app.config_util.config.fileprovider import FileConfigProvider from config_app.config_util.config.fileprovider import FileConfigProvider
from config_app.config_util.config.testprovider import TestConfigProvider from config_app.config_util.config.testprovider import TestConfigProvider
from config_app.config_util.config.TransientDirectoryProvider import TransientDirectoryProvider from config_app.config_util.config.TransientDirectoryProvider import (
TransientDirectoryProvider,
)
from util.config.validator import EXTRA_CA_DIRECTORY, EXTRA_CA_DIRECTORY_PREFIX from util.config.validator import EXTRA_CA_DIRECTORY, EXTRA_CA_DIRECTORY_PREFIX
def get_config_provider(config_volume, yaml_filename, py_filename, testing=False):
    """ Loads and returns the config provider for the current environment. """
    if not testing:
        # Normal operation: back the config with throwaway temp directories.
        return TransientDirectoryProvider(config_volume, yaml_filename, py_filename)

    return TestConfigProvider()
def get_config_as_kube_secret(config_path):
    """Return the files under config_path as a base64-encoded dict for a Kubernetes secret.

    Kubernetes secrets don't have sub-directories, so the certs under the
    extra_ca_certs dir are stored flat with a well-known prefix; one of the
    init scripts (02_get_kube_certs.sh) expands the prefixed entries back
    into the equivalent directory so the certs_install script picks them up
    on startup.
    """
    data = {}

    certs_dir = os.path.join(config_path, EXTRA_CA_DIRECTORY)
    if os.path.exists(certs_dir):
        for extra_cert in os.listdir(certs_dir):
            with open(os.path.join(certs_dir, extra_cert)) as f:
                data[EXTRA_CA_DIRECTORY_PREFIX + extra_cert] = base64.b64encode(f.read())

    # Top-level config files are stored under their own names; directories
    # (already handled above) are skipped.
    for entry in os.listdir(config_path):
        entry_path = os.path.join(config_path, entry)
        if not os.path.isdir(entry_path):
            with open(entry_path) as f:
                data[entry] = base64.b64encode(f.read())

    return data

Some files were not shown because too many files have changed in this diff Show more